Cannot export modules that include non-persistent buffers with pfto
Closed this issue · 0 comments
hvy commented
Including non-persistent buffers causes KeyError
during pfto export.
Repro
Add the following unit test.
def test_keep_initializers_as_inputs_non_persistent():
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("var", torch.rand(10, 10), persistent=False)
def forward(self, x):
return self.var + x
model: onnx.ModelProto = run_model_test(Model(), (torch.rand((10,)),), keep_initializers_as_inputs=False)
assert len(model.graph.input) == 1
model = run_model_test(Model(), (torch.rand((1,)),), keep_initializers_as_inputs=True)
assert len(model.graph.input) == 2
I'm not entirely sure what the expected outcomes should be, but at the very least a KeyError should not be raised.
Notice the following error.
...
def _run_trace(self) -> None:
# TODO(twata): Use `torch._C._craete_graph_by_tracing` instead.
# So that we don't need to run heavy models multiple times
self._restore_state()
with grad.init_grad_state():
self.traced: torch.jit.RecursiveScriptModule = torch.jit.trace( # type: ignore
self.original_model,
self.inputs,
check_trace=self.check_trace,
strict=self.strict_trace,
_force_outplace=self.force_outplace_trace,
)
self.graph_doc_string = f"""
# Model: {self.traced.original_name}
"""
self.g: torch._C.Graph = self.traced.inlined_graph
"""
`self.trace` ignores the override of `state_dict` method in `self.original_model`.
Thus, the key name may be different between state dict of `self.trace` and `self.original_model`.
pfto uses the key name of `self.original_model.state_dict()` as the parameter names in ONNX.
To implement this behavior, we have to prepare mapping from name of `self.trace` state_dict to
the name of `self.original_model` state_dict.
"""
self.name_from_trace: Dict[str, str] = {}
vars_in_traced: Dict[str, torch.IValue] = {
_remove_prefix(k, f"{_ppe_ignore_scope}."): v for k, v in self.traced.state_dict().items()
}
if isinstance(self.original_model, torch.nn.Module):
vars_tmp: Dict[str, Any] = {
_remove_prefix(k, f"{_ppe_ignore_scope}."): v for k, v in self.traced.state_dict(keep_vars=True).items()
}
v_to_name: Dict[Any, str] = {v: k for k, v in self.original_model.state_dict(keep_vars=True).items()}
for name, v in vars_tmp.items():
> self.name_from_trace[name] = v_to_name[v]
E KeyError: tensor([[0.7764, 0.1132, 0.1047, 0.5981, 0.7832, 0.7895, 0.9826, 0.9846, 0.8980,
E 0.2705],
E [0.9972, 0.6791, 0.2659, 0.2546, 0.7653, 0.3500, 0.4025, 0.4878, 0.8704,
E 0.9722],
E [0.6881, 0.1135, 0.9128, 0.9024, 0.6175, 0.5133, 0.7324, 0.8930, 0.0665,
E 0.4134],
E [0.6376, 0.5635, 0.5990, 0.8954, 0.0362, 0.5462, 0.9238, 0.3544, 0.5266,
E 0.2794],
E [0.3645, 0.4708, 0.4369, 0.4414, 0.1465, 0.7628, 0.8381, 0.2780, 0.4223,
E 0.0478],
E [0.3141, 0.6476, 0.0248, 0.7387, 0.5339, 0.3359, 0.1631, 0.7937, 0.8205,
E 0.4865],
E [0.9292, 0.3566, 0.7268, 0.5959, 0.5767, 0.3358, 0.5907, 0.4227, 0.6223,
E 0.1155],
E [0.0794, 0.1254, 0.9170, 0.4346, 0.2727, 0.3781, 0.0956, 0.3128, 0.6857,
E 0.8727],
E [0.3904, 0.3542, 0.9379, 0.5776, 0.2552, 0.9822, 0.3229, 0.2070, 0.4196,
E 0.5795],
E [0.7628, 0.9556, 0.1374, 0.2586, 0.9910, 0.0180, 0.5521, 0.2260, 0.5857,
E 0.4884]])
pytorch_pfn_extras/onnx/pfto_exporter/export.py:365: KeyError
This is seemingly because the original model's state_dict() method does not include non-persistent buffers, while the state dict returned by the jit.trace result does.
Please let me know if there is anything I can do to help. 🙇