facebookresearch/dino

Changing number of classes


Hi,
I am trying to retrain the linear classifier by changing the number of classes. This is what I did.
First, I trained the features using the following command:

out_dir=./main_dino_output
python -m torch.distributed.launch main_dino.py --arch vit_small --data_path </path/to/my/datadir> --output_dir $out_dir --epochs 1000

Now that I have trained my features, I train a linear classifier on top of them as follows:

python eval_linear.py --pretrained_weights $out_dir/checkpoint.pth --num_labels 5 --data_path $data_dir --epochs 500 --arch vit_small

This part executes properly. Now I would like to evaluate the trained model. The linear classifier from this step is saved as ./checkpoint.pth.tar, so I run the above command with the --evaluate flag turned on:

python eval_linear.py --evaluate --pretrained_weights ./checkpoint.pth.tar --num_labels 5 --data_path $data_dir

However, in this case I get the following error:
[Screenshot: console output showing the weights download followed by the error]

Loading fails because the checkpoint being loaded has 1000 classes, while my classifier has 5:

size mismatch for module.linear.weight: copying a param with shape torch.Size([1000, 1536]) from checkpoint, the shape in current model is torch.Size([5, 1536]).
size mismatch for module.linear.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([5]).
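
To make the shapes in this message concrete, here is a minimal sketch of where they come from. It assumes ViT-S has an embedding size of 384 and that eval_linear.py concatenates the last 4 blocks by default (I did not pass --n_last_blocks or --avgpool_patchtokens). Both helper functions are hypothetical and are not part of the DINO code.

```python
# Illustrative only: the numbers mirror the error message above.
# Both helpers are hypothetical and not part of the DINO repository.

def linear_head_shape(embed_dim, n_last_blocks, avgpool_patchtokens, num_labels):
    """Expected (out_features, in_features) of the linear classifier weight."""
    in_features = embed_dim * (n_last_blocks + int(avgpool_patchtokens))
    return (num_labels, in_features)


def find_mismatches(checkpoint_shapes, model_shapes):
    """Return the parameter names whose shapes differ between checkpoint and model."""
    return sorted(
        name
        for name, shape in model_shapes.items()
        if name in checkpoint_shapes and checkpoint_shapes[name] != shape
    )


# Assumed values: embed_dim 384 (ViT-S), 4 blocks, no average-pooled patch tokens.
my_weight = linear_head_shape(384, 4, False, num_labels=5)
released_weight = linear_head_shape(384, 4, False, num_labels=1000)

model_shapes = {"module.linear.weight": my_weight, "module.linear.bias": (5,)}
checkpoint_shapes = {
    "module.linear.weight": released_weight,
    "module.linear.bias": (1000,),
}

mismatched = find_mismatches(checkpoint_shapes, model_shapes)
print(my_weight, released_weight, mismatched)
```

Under those assumptions the input dimension (1536) agrees on both sides, and only the number of output classes differs, which matches the two size-mismatch lines above.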

When I look at the code

dino/eval_linear.py

Lines 79 to 83 in 7c446df

if args.evaluate:
    utils.load_pretrained_linear_weights(linear_classifier, args.arch, args.patch_size)
    test_stats = validate_network(val_loader, model, linear_classifier, args.n_last_blocks, args.avgpool_patchtokens)
    print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
    return

With --evaluate, the script downloads new linear weights and tries to run the evaluation on them instead of on the weights I trained.
I think this is a bug: if I provide my own weights, the script should not download weights the way utils.load_pretrained_linear_weights does.
When I comment out Line 80 in the eval_linear.py file, the code works fine.
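
As an alternative to removing the line outright, one option might be to load my own trained head in its place. The sketch below is only an illustration: the function name is hypothetical, and it assumes the saved checkpoint is a dict that stores the classifier weights under a "state_dict" key, which I have not verified. A dummy class stands in for the real classifier so the snippet runs without PyTorch.

```python
def load_local_linear_weights(classifier, checkpoint, key="state_dict"):
    """Load a locally trained linear head instead of downloading one.

    `checkpoint` is the already deserialized dict, e.g. the result of
    torch.load("./checkpoint.pth.tar", map_location="cpu").
    Assumption: the classifier weights are stored under `key`.
    """
    if key not in checkpoint:
        raise KeyError(f"{key!r} not found; available keys: {sorted(checkpoint)}")
    classifier.load_state_dict(checkpoint[key])
    return classifier


class DummyClassifier:
    """Stand-in for the real nn.Module; it only records what was loaded."""

    def __init__(self):
        self.state = None

    def load_state_dict(self, state_dict):
        self.state = dict(state_dict)


# Stand-in for the dict that torch.load would return for my checkpoint.
checkpoint = {"epoch": 500, "state_dict": {"module.linear.bias": [0.0] * 5}}
head = load_local_linear_weights(DummyClassifier(), checkpoint)
print(sorted(head.state))
```

In eval_linear.py, a call like this would take the place of Line 80, with `checkpoint` coming from torch.load on the file I trained. Because the head is built with --num_labels 5, the shapes should then agree.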

Is commenting out that line the right thing to do? Please let me know.

P.S.: I know that posting screenshots is generally not the norm, but I wanted to show that it is downloading new weights.