pfnet/pytorch-pfn-extras

Cannot export modules that include non-persistent buffers with pfto


hvy commented

Exporting a module that contains non-persistent buffers raises a KeyError during pfto export.

Repro

Add the following unit test.

def test_keep_initializers_as_inputs_non_persistent():
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("var", torch.rand(10, 10), persistent=False)

        def forward(self, x):
            return self.var + x

    model: onnx.ModelProto = run_model_test(Model(), (torch.rand((10,)),), keep_initializers_as_inputs=False)
    assert len(model.graph.input) == 1
    model = run_model_test(Model(), (torch.rand((1,)),), keep_initializers_as_inputs=True)
    assert len(model.graph.input) == 2

I'm not entirely sure what the expected outcomes should be, but at the very least the export should not raise a KeyError.

Running the test produces the following error.

...
        def _run_trace(self) -> None:
            # TODO(twata): Use `torch._C._craete_graph_by_tracing` instead.
            # So that we don't need to run heavy models multiple times
            self._restore_state()
            with grad.init_grad_state():
                self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace(  # type: ignore
                    self.original_model,
                    self.inputs,
                    check_trace=self.check_trace,
                    strict=self.strict_trace,
                    _force_outplace=self.force_outplace_trace,
                )

            self.graph_doc_string = f"""
    # Model: {self.traced.original_name}
    """

            self.g: torch._C.Graph = self.traced.inlined_graph
            """
            `self.trace` ignores the override of `state_dict` method in `self.original_model`.
            Thus, the key name may be different between state dict of `self.trace` and `self.original_model`.
            pfto uses the key name of `self.original_model.state_dict()` as the parameter names in ONNX.

            To implement this behavior, we have to prepare mapping from name of `self.trace` state_dict to
            the name of `self.original_model` state_dict.
            """
            self.name_from_trace: Dict[str, str] = {}
            vars_in_traced: Dict[str, torch.IValue] = {
                _remove_prefix(k, f"{_ppe_ignore_scope}."): v for k, v in self.traced.state_dict().items()
            }
            if isinstance(self.original_model, torch.nn.Module):
                vars_tmp: Dict[str, Any] = {
                    _remove_prefix(k, f"{_ppe_ignore_scope}."): v for k, v in self.traced.state_dict(keep_vars=True).items()
                }
                v_to_name: Dict[Any, str] = {v: k for k, v in self.original_model.state_dict(keep_vars=True).items()}
                for name, v in vars_tmp.items():
>                   self.name_from_trace[name] = v_to_name[v]
E                   KeyError: tensor([[0.7764, 0.1132, 0.1047, 0.5981, 0.7832, 0.7895, 0.9826, 0.9846, 0.8980,
E                            0.2705],
E                           [0.9972, 0.6791, 0.2659, 0.2546, 0.7653, 0.3500, 0.4025, 0.4878, 0.8704,
E                            0.9722],
E                           [0.6881, 0.1135, 0.9128, 0.9024, 0.6175, 0.5133, 0.7324, 0.8930, 0.0665,
E                            0.4134],
E                           [0.6376, 0.5635, 0.5990, 0.8954, 0.0362, 0.5462, 0.9238, 0.3544, 0.5266,
E                            0.2794],
E                           [0.3645, 0.4708, 0.4369, 0.4414, 0.1465, 0.7628, 0.8381, 0.2780, 0.4223,
E                            0.0478],
E                           [0.3141, 0.6476, 0.0248, 0.7387, 0.5339, 0.3359, 0.1631, 0.7937, 0.8205,
E                            0.4865],
E                           [0.9292, 0.3566, 0.7268, 0.5959, 0.5767, 0.3358, 0.5907, 0.4227, 0.6223,
E                            0.1155],
E                           [0.0794, 0.1254, 0.9170, 0.4346, 0.2727, 0.3781, 0.0956, 0.3128, 0.6857,
E                            0.8727],
E                           [0.3904, 0.3542, 0.9379, 0.5776, 0.2552, 0.9822, 0.3229, 0.2070, 0.4196,
E                            0.5795],
E                           [0.7628, 0.9556, 0.1374, 0.2586, 0.9910, 0.0180, 0.5521, 0.2260, 0.5857,
E                            0.4884]])

pytorch_pfn_extras/onnx/pfto_exporter/export.py:365: KeyError

This seems to be because the original module's state_dict() does not include non-persistent buffers, while the state dict of the module returned by torch.jit.trace does. As a result, the buffer tensor taken from the traced state dict has no entry in v_to_name, which is built from self.original_model.state_dict(keep_vars=True), and the lookup fails.
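
The mismatch can be observed directly, independent of pfto. The following is a minimal sketch of my own (not part of the exporter or the test suite); the printed results are what I would expect given the traceback above, where the traced module's state dict contains the buffer while the original module's does not.

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer: excluded from state_dict() by design.
        self.register_buffer("var", torch.rand(10, 10), persistent=False)

    def forward(self, x):
        return self.var + x

model = Model()
# Expected: [] -- the non-persistent buffer is not in the original state dict.
print(list(model.state_dict(keep_vars=True).keys()))

traced = torch.jit.trace(model, (torch.rand(10,),))
# Expected (per the traceback above): ['var'] -- the traced module's state dict
# still carries the buffer, so pfto's name mapping cannot find it in v_to_name.
print(list(traced.state_dict().keys()))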

Please let me know if there is anything I can do to help. 🙇