torch.load(pruned_model)
wuzhiyang2016 opened this issue · 7 comments
When I use torch.load() to load the pruned model, this error happens: AttributeError: 'module' object has no attribute 'ModifiedVGG16Model'. Has anyone else met this problem?
@wuzhiyang2016 I've got the same problem. I added the ModifiedVGG16Model class to my test script, with no success. Do you still have this issue?
@jacobgil Any idea?
The code in the repo pickles and saves the entire model instead of the state dict, which is bad practice.
A better way is to save only the state_dict, and then load it.
Going to change it now to
state_dict = model.state_dict()
torch.save(state_dict, 'model.chkpt')
Then you can load the model like this:
model = ModifiedVGG16Model()
checkpoint = torch.load(checkpoint_path,
                        map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint)
model.eval()
As the author says, the ModifiedVGG16Model class must be defined (importable) in the script before loading the saved model.
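For context, the original error comes from pickling: torch.save(model) stores the full object, and torch.load() can only rebuild it if the class is importable in the loading script. A minimal sketch of that requirement (the tiny ModifiedVGG16Model here is a stand-in, not the repo's real class):

```python
import io
import torch
import torch.nn as nn

# torch.load() unpickles the whole object, so the class definition must be
# importable under the same name it had when the model was saved.
class ModifiedVGG16Model(nn.Module):
    def __init__(self):
        super().__init__()
        # tiny stand-in for the real VGG16-based feature extractor
        self.features = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1))

    def forward(self, x):
        return self.features(x)

model = ModifiedVGG16Model()
buf = io.BytesIO()
torch.save(model, buf)  # pickles the full model object, architecture included
buf.seek(0)
# weights_only=False is required on recent PyTorch (>= 2.6) to unpickle
# a full model; loading fails with the AttributeError above if the class
# is not in scope at this point.
loaded = torch.load(buf, weights_only=False)
```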
@jacobgil Hey! There is a problem with a size mismatch. For example:
RuntimeError: Error(s) in loading state_dict for ModifiedVGG16Model:
    size mismatch for features.0.weight: copying a param with shape torch.Size([50, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 3, 3]).
    size mismatch for features.0.bias: copying a param with shape torch.Size([50]) from checkpoint, the shape in current model is torch.Size([64]).
    size mismatch for features.2.weight: copying a param with shape torch.Size([44, 50, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
Is there a way to solve it?
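That mismatch is expected: pruning changes the layer shapes, so the pruned state_dict cannot be loaded into the unpruned class. A small sketch reproducing the error with stand-in layers (not the repo's actual model):

```python
import torch
import torch.nn as nn

# Unpruned architecture: the first conv keeps all 64 filters.
original = nn.Sequential(nn.Conv2d(3, 64, 3))
# Pruned architecture: the same conv with only 50 filters left.
pruned = nn.Sequential(nn.Conv2d(3, 50, 3))

caught = None
try:
    # Shapes differ ([50, 3, 3, 3] vs [64, 3, 3, 3]) -> RuntimeError
    original.load_state_dict(pruned.state_dict())
except RuntimeError as err:
    caught = err
```

So a pruned checkpoint needs either the full pickled model (architecture included) or a model class rebuilt with the pruned layer widths.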
@jacobgil: Since the pruned model has different in and out filter counts in each layer, is there any way to load the pruned model's state dictionary with the original model class?
@ms-krajesh Probably not a general solution, but I just rewrote the model architecture with the altered layer widths.
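One way to automate that rewrite is to read the pruned widths back out of the checkpoint itself. A sketch under the assumption of a plain chain of 3x3 convs (the helper names are made up for illustration, not from the repo):

```python
import torch
import torch.nn as nn

def conv_widths_from_state_dict(state):
    """Collect the out-channel count of every 4-D conv weight, in order."""
    return [t.shape[0] for name, t in state.items()
            if name.endswith('weight') and t.dim() == 4]

def build_pruned_features(widths, in_channels=3):
    """Rebuild a conv/ReLU chain whose widths match the checkpoint."""
    layers = []
    for w in widths:
        layers += [nn.Conv2d(in_channels, w, 3, padding=1),
                   nn.ReLU(inplace=True)]
        in_channels = w
    return nn.Sequential(*layers)

# Example: a checkpoint produced by a pruned model with 50 and 44 filters.
pruned = nn.Sequential(nn.Conv2d(3, 50, 3, padding=1), nn.ReLU(inplace=True),
                       nn.Conv2d(50, 44, 3, padding=1), nn.ReLU(inplace=True))
state = pruned.state_dict()

widths = conv_widths_from_state_dict(state)
rebuilt = build_pruned_features(widths)
rebuilt.load_state_dict(state)  # shapes now match, no size-mismatch error
```

A real ModifiedVGG16Model would also need the classifier's input size recomputed from the last conv width, which this sketch omits.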