On DEIT distilled network
kspruthviraj opened this issue · 2 comments
Hi Hila,
Thanks for sharing the script to visualize the attention maps.
I am trying to run your DeiT example on a custom model (a DeiT-base distilled network with 19 classes), but so far without success. I keep getting this error: "AttributeError: 'VisionTransformer' object has no attribute 'relprop'"
Here are my saved model weights and biases:
https://drive.switch.ch/index.php/s/dimybgHdzyE90gB
This is how I first load the model from timm:
import timm
basemodel = timm.create_model('deit_base_distilled_patch16_224', pretrained=True, num_classes=19)
model = basemodel
Then I load the trained weights and biases from my checkpoint:
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model)  # the checkpoint was trained with a DataParallel wrapper
model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=3e-5)
PATH = checkpoint_path + '/trained_model.pth'  # Saved model path (shared in the link above)
checkpoint = torch.load(PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
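Since this load_state_dict call succeeds on the nn.DataParallel wrapper, the saved keys presumably carry DataParallel's "module." prefix (for example "module.blocks.0.attn.qkv.weight"). Loading them into an unwrapped model, such as the repo's LRP-enabled ViT, would then need that prefix removed. A minimal stdlib-only sketch; the helper name strip_module_prefix is hypothetical:

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DataParallel prefix removed from each key.

    Hypothetical helper: keys without the prefix are kept unchanged.
    """
    stripped = OrderedDict()
    for key, value in state_dict.items():
        # "module.blocks.0.attn.qkv.weight" -> "blocks.0.attn.qkv.weight"
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Hypothetical usage with an unwrapped, LRP-enabled model:
# lrp_model.load_state_dict(strip_module_prefix(checkpoint['model_state_dict']))
```

Alternatively, calling model.module.state_dict() on the wrapped model yields unprefixed keys directly.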
When I try to use this model in your example code, I keep getting this error "AttributeError: 'VisionTransformer' object has no attribute 'relprop'"
I was wondering if you would have time to take a look at the uploaded model and see whether you can generate an attention map?
Thanks a lot
Hi @kspruthviraj, thanks for your interest in our work!
Please note that I do not use the out-of-the-box implementation of ViT. My code contains a modified implementation that adds a relevance propagation function, relprop, to each layer in the network.
Thus, when you load your weights, load them into the model implemented in this repo to get LRP working.
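To picture what a per-layer relevance propagation function does, here is a toy fully connected layer in pure Python that caches its inputs in forward and redistributes output relevance back to its inputs in relprop. It uses the LRP-epsilon rule, R_i = sum_j a_i w_ji / (z_j + eps * sign(z_j)) R_j, chosen for illustration; this is an assumption, not necessarily the rule this repo's layers implement, and the class LinearLRP is hypothetical. A stock timm module defines only forward, which is why asking it for relprop raises the AttributeError above.

```python
class LinearLRP:
    """Toy dense layer (hypothetical) with a forward pass and an LRP-epsilon relprop."""

    def __init__(self, weight, bias, eps=1e-6):
        self.weight = weight  # weight[j][i] connects input i to output j
        self.bias = bias
        self.eps = eps
        self.inputs = None
        self.z = None

    def forward(self, x):
        # Cache activations a_i and pre-activations z_j; relprop needs both.
        self.inputs = list(x)
        self.z = [
            sum(w * a for w, a in zip(row, self.inputs)) + b
            for row, b in zip(self.weight, self.bias)
        ]
        return list(self.z)

    def relprop(self, relevance_out):
        # Redistribute each output's relevance R_j to the inputs in proportion
        # to their contribution a_i * w_ji to z_j; eps stabilizes the division.
        relevance_in = [0.0] * len(self.inputs)
        for row, z_j, r_j in zip(self.weight, self.z, relevance_out):
            denom = z_j + self.eps * (1.0 if z_j >= 0 else -1.0)
            for i, (w, a) in enumerate(zip(row, self.inputs)):
                relevance_in[i] += a * w / denom * r_j
        return relevance_in


layer = LinearLRP(weight=[[1.0, 2.0], [3.0, -1.0]], bias=[0.0, 0.0])
out = layer.forward([1.0, 1.0])  # [3.0, 2.0]
rel = layer.relprop(out)         # approximately [4.0, 1.0]
```

With zero bias, the input relevances sum (up to eps) to the total output relevance, which is the conservation property LRP aims for.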
Best,
Hila
Hi @hila-chefer ,
Thanks for getting back to me. Okay, now I get it.
I guess the models implemented in this repo are the DeiT-small and DeiT-base networks, and they do not include DeiT-base-distilled. Since I am using the DeiT-base-distilled network, I might not be able to load my weights directly into the models implemented in this repo.
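One way to confirm which parameters block a direct load is to compare parameter names and shapes between the distilled checkpoint and the target model. The sketch below works on plain {name: shape} dictionaries (in practice built with something like {k: tuple(v.shape) for k, v in state_dict.items()}); the helper compare_state_dicts is hypothetical, and the two example dictionaries are a hand-written excerpt assuming timm's DeiT-base layout with 19 classes, not a dump of the shared checkpoint.

```python
def compare_state_dicts(ckpt_shapes, model_shapes):
    """Return (missing, unexpected, mismatched) parameter names between two {name: shape} maps."""
    missing = sorted(k for k in model_shapes if k not in ckpt_shapes)
    unexpected = sorted(k for k in ckpt_shapes if k not in model_shapes)
    mismatched = sorted(
        k for k in ckpt_shapes
        if k in model_shapes and tuple(ckpt_shapes[k]) != tuple(model_shapes[k])
    )
    return missing, unexpected, mismatched


# Hand-written excerpt (assumed, not read from the checkpoint): the distilled model
# adds a distillation token, a second head, and one extra position embedding.
distilled = {
    "cls_token": (1, 1, 768),
    "dist_token": (1, 1, 768),
    "pos_embed": (1, 198, 768),  # 196 patches + class token + distillation token
    "head.weight": (19, 768),
    "head.bias": (19,),
    "head_dist.weight": (19, 768),
    "head_dist.bias": (19,),
}
base = {
    "cls_token": (1, 1, 768),
    "pos_embed": (1, 197, 768),  # 196 patches + class token
    "head.weight": (19, 768),
    "head.bias": (19,),
}

missing, unexpected, mismatched = compare_state_dicts(distilled, base)
print(missing)     # []
print(unexpected)  # ['dist_token', 'head_dist.bias', 'head_dist.weight']
print(mismatched)  # ['pos_embed']
```

In PyTorch, load_state_dict(..., strict=False) would skip the extra dist_token and head_dist entries, but the pos_embed size mismatch still raises an error, so using distilled weights with LRP would presumably require a distilled variant of the repo's model (with the extra token and head, each supporting relprop) rather than a plain weight load.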
Best,
Sreenath