liznerski/fcdd

Confused when trying to run inference with trained model


I trained the model with all default parameters on the MVTec single-label data. Training finished, and I can see that a file called snapshot.pt is saved in the results directory (I'm not sure this is correct, since the file is only 31 MB). I tried loading this file as a state dict for a VGG19 base model and got the error below:
```
Traceback (most recent call last):
  File "/Users/pratikh/Desktop/anomaly_detection/fcdd/inf_fcdd.py", line 13, in <module>
    model.load_state_dict(state_dict["opt"]["state"])
  File "/Users/pratikh/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2138, in load_state_dict
    load(self, state_dict)
  File "/Users/pratikh/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2120, in load
    module._load_from_state_dict(
  File "/Users/pratikh/venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2061, in _load_from_state_dict
    if key.startswith(prefix) and key != extra_state_key:
AttributeError: 'int' object has no attribute 'startswith'
```

This is the code I am using:
```python
import sys
import torch
import torchvision.models as models
from torchvision import transforms
from PIL import Image

model = models.vgg19(pretrained=False)

state_dict = torch.load(sys.argv[1])
model.load_state_dict(state_dict["opt"]["state"])
model.eval()

image = Image.open(sys.argv[2])
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
input_tensor = preprocess(image)
input_batch = input_tensor.unsqueeze(0)

if torch.cuda.is_available():
    input_batch = input_batch.to("cuda")
    model.to("cuda")

with torch.no_grad():
    output = model(input_batch)

_, predicted = torch.max(output, 1)
print(predicted)
```

What is the correct way to load the trained model for inference? Any help would be appreciated.

Hi. There's a script for running inference with a trained model: https://github.com/liznerski/fcdd/blob/master/python/fcdd/runners/run_prediction_with_snapshot.py. It uses the trainer class to load the model on line 79. Loading the model is defined here. As you can see, the network state dict is extracted with `snapshot.pop('net', None)`. In your case, that is `state_dict['net']` rather than `state_dict["opt"]["state"]`. I suspect that's the source of your error: `state_dict["opt"]["state"]` appears to be the optimizer's state, whose keys are integer parameter indices rather than parameter names, which would explain `'int' object has no attribute 'startswith'`.
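A minimal sketch of the distinction described above, assuming (as the answer states) that the snapshot is a dict whose `'net'` entry holds the network's state dict. The helper name `extract_net_state` and the toy snapshot are hypothetical; tensors are replaced by placeholder values so the check runs without PyTorch. It rejects a mapping with non-string keys, which is what an optimizer's `state` entry looks like.

```python
from collections.abc import Mapping


def extract_net_state(snapshot):
    """Return the network weights stored in a snapshot dict.

    Assumption: the weights live under the 'net' key, per the
    maintainer's answer (hypothetical helper, not part of fcdd).
    """
    if not isinstance(snapshot, Mapping):
        raise TypeError(f"expected a dict-like snapshot, got {type(snapshot).__name__}")
    net_state = snapshot.get("net")
    if net_state is None:
        raise KeyError(f"no 'net' entry; top-level keys are {sorted(map(str, snapshot))}")
    # A module state dict is keyed by parameter names such as 'conv1.weight'.
    # Optimizer state is keyed by integer parameter indices, which is what makes
    # load_state_dict fail with "'int' object has no attribute 'startswith'".
    non_str = [k for k in net_state if not isinstance(k, str)]
    if non_str:
        raise TypeError(f"'net' has non-string keys {non_str[:3]}; not a module state dict")
    return net_state


# Hypothetical toy snapshot mimicking the layout discussed in this issue.
toy_snapshot = {
    "net": {"conv1.weight": "<tensor>", "conv1.bias": "<tensor>"},
    "opt": {"state": {0: {"exp_avg": "<tensor>"}}, "param_groups": [{"params": [0]}]},
}
print(sorted(extract_net_state(toy_snapshot)))  # ['conv1.bias', 'conv1.weight']
```

With a real snapshot, usage would presumably be `snapshot = torch.load(path, map_location="cpu")` followed by `net.load_state_dict(extract_net_state(snapshot))`, assuming `net` is an instance of the same network the FCDD trainer built during training. The linked runner script handles that setup for you.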

Thanks @liznerski, I wasn't aware there was a script for running inference. I ran inference with the script, and it works as intended.