pytorch/multipy

torch.package issues with functions defined in C extensions

d4l3k opened this issue · 1 comment

d4l3k commented
import torch

from torch.package import PackageExporter, PackageImporter

# Fixed location of the package archive that save_load() writes and then reads back.
output_path = "/tmp/model.pt"



def save_load(model, path=None):
    """Round-trip ``model`` through a torch.package archive and return the copy.

    The model is pickled into a package (package name ``"model"``, resource
    ``"model.pkl"``) and then immediately re-loaded from the same file with a
    fresh ``PackageImporter``.

    Args:
        model: Object to package and reload.
        path: Archive file to write and read back. When ``None`` (the
            default), the module-level ``output_path`` is used, which matches
            the original behaviour.

    Returns:
        The object unpickled from the package.

    Raises:
        Propagates any error from ``PackageExporter`` / ``PackageImporter``.
        Per this issue, a model that stores ``F.gelu`` fails at load time with
        ``ModuleNotFoundError: No module named 'torch._C._nn'``.
    """
    if path is None:
        path = output_path

    with PackageExporter(path) as exporter:
        # torch modules are not bundled; they resolve against the installed torch.
        exporter.extern("torch.**")
        # Other modules the pickle depends on are matched by "**" and interned.
        exporter.intern("**")
        exporter.save_pickle("model", "model.pkl", model)

    importer = PackageImporter(path)
    loaded = importer.load_pickle("model", "model.pkl")
    # In the original script this print came after `return` and never ran.
    print("pass")
    return loaded


# Repro: a pre-norm encoder layer configured with activation='gelu'. According
# to the issue text, the layer ends up holding F.gelu, and reloading it from
# the package then fails on 'torch._C._nn'.
model = torch.nn.TransformerEncoderLayer(
    d_model=64,
    nhead=2,
    dim_feedforward=64,
    dropout=1.0,
    activation='gelu',
    batch_first=True,
    norm_first=True,
)
save_load(model)

The issue is that `F.gelu` can't be loaded from the package because of an import error:

ModuleNotFoundError: No module named 'torch._C._nn'; 'torch._C' is not a package
d4l3k commented

You can work around this by not storing any `torch.nn.functional` functions as attributes on the module, i.e. avoid `self.foo = F.gelu`.