Confused when trying to run inference with trained model
I trained the model with all default parameters on the MVTec single-label data. Training finished, and a file called snapshot.pt was saved in the results directory (I'm not sure this is correct, since the file is only 31 MB). I tried loading this file as a state_dict for a VGG19 base model and got the following error:
```
Traceback (most recent call last):
  File "/Users/pratikh/Desktop/anomaly_detection/fcdd/inf_fcdd.py", line 13, in <module>
    model.load_state_dict(state_dict["opt"]["state"])
  File "/Users/pratikh/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2138, in load_state_dict
    load(self, state_dict)
  File "/Users/pratikh/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2120, in load
    module._load_from_state_dict(
  File "/Users/pratikh/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2061, in _load_from_state_dict
    if key.startswith(prefix) and key != extra_state_key:
AttributeError: 'int' object has no attribute 'startswith'
```
This is the code I am using:
```python
import sys
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

model = models.vgg19(pretrained=False)
state_dict = torch.load(sys.argv[1])
model.load_state_dict(state_dict["opt"]["state"])
model.eval()

image = Image.open(sys.argv[2])
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)

if torch.cuda.is_available():
    input_batch = input_batch.to("cuda")
    model.to("cuda")

with torch.no_grad():
    output = model(input_batch)
    _, predicted = torch.max(output, 1)
    print(predicted)
```
What is the correct way to load the trained model for inference? Any help would be appreciated.
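The `AttributeError` comes from `load_state_dict` treating every key as a parameter name and calling `key.startswith(prefix)` on it. An optimizer's `state_dict()["state"]` is keyed by integer parameter indices, so passing it to a model fails on the first `int` key. The helpers below are a hypothetical, heuristic sketch (not part of FCDD or PyTorch) for listing which key paths in a loaded checkpoint look like a module state_dict, i.e. a flat mapping from string names to non-container values such as tensors.

```python
from collections.abc import Mapping


def is_module_state_dict(obj) -> bool:
    """Heuristic: a module state_dict maps str parameter names to tensors,
    so every key is a str and no value is itself a container."""
    if not isinstance(obj, Mapping) or not obj:
        return False
    return all(
        isinstance(k, str) and not isinstance(v, (Mapping, list, tuple))
        for k, v in obj.items()
    )


def find_state_dict_paths(ckpt, path=()):
    """Return key paths inside a loaded checkpoint that look like module state_dicts."""
    if is_module_state_dict(ckpt):
        return [path]
    # Skip non-mappings and integer-keyed dicts (e.g. optimizer "state").
    if not isinstance(ckpt, Mapping) or not all(isinstance(k, str) for k in ckpt):
        return []
    found = []
    for key, value in ckpt.items():
        found.extend(find_state_dict_paths(value, path + (key,)))
    return found
```

In practice one would call it on the result of `torch.load("snapshot.pt", map_location="cpu")` and print the returned paths before deciding what to pass to `load_state_dict`.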
Hi. There's a script for running inference with a trained model: https://github.com/liznerski/fcdd/blob/master/python/fcdd/runners/run_prediction_with_snapshot.py. It uses the trainer class to load the model (line 79). Loading the model is defined here. As you can see, the network's state dict is extracted with `snapshot.pop('net', None)`. In your case, that is `state_dict['net']` rather than `state_dict["opt"]["state"]`. I suspect that's the source of your error.
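The fix can be illustrated with a small, self-contained sketch. It simulates a snapshot containing the two keys mentioned above (`'net'` and `'opt'`), shows that the optimizer state is keyed by integers, and restores the weights from `'net'`. `make_net` is a hypothetical stand-in: with a real snapshot, the network must be the same FCDD architecture that was trained, and a torchvision `vgg19` will likely not have matching key names. For real use, the prediction script linked above is the safer route.

```python
import io

import torch
import torch.nn as nn


def make_net():
    # Hypothetical stand-in for the trained FCDD network.
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 1, 1))


# Train one step so the Adam optimizer has per-parameter state.
net = make_net()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
net(torch.randn(1, 3, 16, 16)).sum().backward()
opt.step()

# Save a snapshot-like dict to memory instead of a file.
buffer = io.BytesIO()
torch.save({"net": net.state_dict(), "opt": opt.state_dict()}, buffer)
buffer.seek(0)

snapshot = torch.load(buffer, map_location="cpu")
print(sorted(snapshot["opt"]["state"].keys()))  # integer indices: [0, 1, 2, 3]

# Load the network weights (string keys), not the optimizer state.
restored = make_net()
restored.load_state_dict(snapshot.pop("net"))
restored.eval()
```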
Thanks @liznerski, I wasn't aware there was a script for running inference. I ran inference with the script, and it works as intended.