package issues with functions under C extensions
d4l3k opened this issue · 1 comments
d4l3k commented
import torch
from torch.package import PackageExporter, PackageImporter
# Location of the package archive written and then re-read by save_load().
output_path = "/tmp/model.pt"


def save_load(model):
    """Round-trip ``model`` through ``torch.package`` and return the reloaded copy.

    The object is pickled into a package archive at ``output_path`` and then
    unpickled again through a fresh ``PackageImporter``.

    Args:
        model: The object to save (the repro passes a ``torch.nn.Module``).

    Returns:
        The object produced by loading ``model.pkl`` back out of the package.
    """
    with PackageExporter(output_path) as e:
        # Modules matching "torch.**" are declared extern (not copied into the
        # archive); everything else matched by "**" is interned.
        e.extern("torch.**")
        e.intern("**")
        e.save_pickle("model", "model.pkl", model)

    # The importer is created only after the exporter's ``with`` block has
    # exited, so the archive is complete before it is read back.
    imp = PackageImporter(output_path)
    return imp.load_pickle("model", "model.pkl")
# NOTE(review): this message is emitted before the repro below runs; its
# position may be an artifact of how the snippet was pasted - confirm.
print("pass")

# Constructor arguments for the encoder layer used in the repro. The string
# activation 'gelu' is the setting the report ties to the failing F.gelu load.
layer_kwargs = {
    "d_model": 64,
    "nhead": 2,
    "dim_feedforward": 64,
    "dropout": 1.0,
    "batch_first": True,
    "activation": 'gelu',
    "norm_first": True,
}
model = torch.nn.TransformerEncoderLayer(**layer_kwargs)

# Save the layer into a package and load it back; per the report this is the
# call that raises ModuleNotFoundError.
save_load(model)
The issue is that F.gelu can't be loaded from the package due to an import error:
ModuleNotFoundError: No module named 'torch._C._nn'; 'torch._C' is not a package
d4l3k commented
You can work around this by not storing any functional methods as attributes on the class, i.e. avoid self.foo = F.gelu