jacobgil/pytorch-pruning

torch.load(pruned_model)

wuzhiyang2016 opened this issue · 7 comments

When I use torch.load() to load the pruned model, an error happens: AttributeError: 'module' object has no attribute 'ModifiedVGG16Model'. Has anyone else met this problem?

@wuzhiyang2016 I've got the same problem. I added the ModifiedVGG16Model class in my test script, with no success. Do you still have the same issue?

The code in the repo saves the entire model with pickling instead of saving the state dict, which is bad practice.
A better way is to save only the state_dict, and then load it.
I'm going to change it now to:

torch.save(model.state_dict(), 'model.chkpt')

Then you can load the model like this:

model = ModifiedVGG16Model()
checkpoint = torch.load(checkpoint_path,
        map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint)
model.eval()

As the author says, the ModifiedVGG16Model class must be defined (or importable) in the script that loads the saved model.
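The reason is how pickling works: the checkpoint stores only a reference to the class (module path + class name), not the class code itself, so loading fails with an AttributeError if the class isn't importable. A minimal sketch of this with the stdlib pickle module (a stand-in here for torch.save / torch.load on a whole model; the class body is hypothetical):

```python
import io
import pickle

class ModifiedVGG16Model:        # stand-in for the real model class
    def __init__(self):
        self.pruned = True

# Like torch.save(model, path): pickles the instance, referencing the
# class by name, not by code.
buf = io.BytesIO()
pickle.dump(ModifiedVGG16Model(), buf)

# Like torch.load(path): this succeeds only because ModifiedVGG16Model
# is defined in the loading module; otherwise it raises AttributeError.
buf.seek(0)
restored = pickle.load(buf)
print(type(restored).__name__)   # → ModifiedVGG16Model
```

Saving only the state_dict avoids this problem for module lookup, but you still need the class definition to build the model before calling load_state_dict.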

@jacobgil Hey! There is a problem with a size mismatch. For example:

RuntimeError: Error(s) in loading state_dict for ModifiedVGG16Model:
size mismatch for features.0.weight: copying a param with shape torch.Size([50, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
size mismatch for features.0.bias: copying a param with shape torch.Size([50]) from checkpoint, the shape in current model is torch.Size([64]).
size mismatch for features.2.weight: copying a param with shape torch.Size([44, 50, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).

Is there a way to solve it?

@jacobgil : Since the pruned model will have different in and out filter counts in each layer, is there any way to load the pruned model's state dictionary with the original model class?

@ms-krajesh Probably not a general solution, but I just rewrote the model architecture with the pruned layer sizes.
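Instead of hard-coding the pruned sizes, you can read the channel counts straight out of the checkpoint's tensor shapes and rebuild the model to match before calling load_state_dict. A sketch; the shapes below are hypothetical, copied from the error message in this thread, and conv_channels is a helper name I made up. In a real script you would get the shapes from the checkpoint, e.g. {k: tuple(v.shape) for k, v in torch.load(path).items()}:

```python
# Hypothetical checkpoint shapes, mirroring the size-mismatch error above.
# Conv2d weights are 4-D: (out_channels, in_channels, kH, kW).
checkpoint_shapes = {
    "features.0.weight": (50, 3, 3, 3),
    "features.0.bias":   (50,),
    "features.2.weight": (44, 50, 3, 3),
    "features.2.bias":   (44,),
}

def conv_channels(shapes):
    """Return {layer_name: (in_channels, out_channels)} for 4-D conv weights."""
    channels = {}
    for name, shape in shapes.items():
        if name.endswith(".weight") and len(shape) == 4:
            layer = name[: -len(".weight")]
            channels[layer] = (shape[1], shape[0])   # (in_ch, out_ch)
    return channels

print(conv_channels(checkpoint_shapes))
# → {'features.0': (3, 50), 'features.2': (50, 44)}
```

With those per-layer counts you can construct each nn.Conv2d with the pruned in/out channels so that load_state_dict succeeds without size mismatches.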