facebookresearch/swav

What are good final loss values (subset of ImageNet)

AnanyaKumar opened this issue · 3 comments

Hey! Great repository, and this has been really useful for my research.

I'm running SwAV on subsets of ImageNet (say, around 60-100 of the ImageNet classes). I ran it for 200 epochs, and my loss value is around 4.1547. Does this sound reasonable?

However, when I train a linear probe on my checkpoint on this subset of ImageNet, the best train accuracy I get is around 88% (whereas with the SwAV 200-epoch checkpoint from this repository I get 99% train accuracy). So I wonder whether I haven't trained long enough, or whether I need to change the hyperparameters when pre-training on a smaller dataset (a subset of ImageNet)?

Also, for some other subsets the loss just stays at 8.0. I did try reducing epsilon, but that didn't help. Bullet 2 suggests tuning some other hyperparameters; any suggestions on where to start? Thanks a lot!

Hi @AnanyaKumar

Thanks a lot for your interest in this work! Honestly, it's so motivating to know that this is actually useful to somebody!

In my experience, there is no strong correlation between the loss value and transfer performance. In other words, a lower loss for one run does not mean that its linear evaluation will be better.
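One reference point for reading raw loss values (a sketch added for illustration, not something stated by the SwAV authors): the swapped-prediction loss is a cross-entropy between a code q and a softmax over K prototypes. If the predicted distribution is uniform, every log-probability is log(1/K), so the loss equals ln K for any code that sums to 1. Assuming nmb_prototypes is left at 3000 (check your own setting), that is about 8.006, very close to a loss that "just stays at 8.0". The function name below is hypothetical.

```python
import math
from typing import Sequence


def loss_with_uniform_prediction(q: Sequence[float], num_prototypes: int) -> float:
    """Cross-entropy -sum_k q_k * log p_k when p is uniform over the prototypes.

    Since every log p_k equals log(1/K), the result is ln(K) * sum(q), i.e. ln(K)
    for any code q that sums to 1.
    """
    log_p = -math.log(num_prototypes)  # log(1/K), identical for every prototype
    return -sum(qk * log_p for qk in q)


K = 3000  # assumed number of prototypes; substitute your own setting
one_hot = [1.0] + [0.0] * (K - 1)   # a maximally confident code
uniform = [1.0 / K] * K             # a maximally uncertain code
for q in (one_hot, uniform):
    print(round(loss_with_uniform_prediction(q, K), 3))  # 8.006 in both cases
```

Averaging this value over crops and over the batch leaves it unchanged, so a run whose loss never leaves the neighbourhood of ln K is one whose predictions may not have moved far from uniform.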

If the loss is not decreasing, I would recommend freezing the prototypes for more epochs. Something else that works well in my experiments is to start training without multi-crop (only the two large crops, e.g. 224x224, are used at the beginning of training). Then, once you see the loss decreasing, you can inject multi-crop. I haven't included this option in this codebase, so it would require some additional coding on your part.
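The warm-up idea can be sketched as a small helper, assuming (as in the repository's multi-crop pipeline) that the crops arrive as a list with the large views first. The function crops_for_epoch and its parameter names are hypothetical, not part of the codebase.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def crops_for_epoch(crops: Sequence[T], epoch: int, warmup_epochs: int,
                    num_large: int = 2) -> List[T]:
    """Return the crops to feed to the model at a given epoch.

    During the first `warmup_epochs` epochs only the large views are kept
    (assumed to be the first `num_large` entries); afterwards all crops,
    including the small multi-crop views, are used.
    """
    if warmup_epochs < 0 or num_large < 1:
        raise ValueError("warmup_epochs must be >= 0 and num_large >= 1")
    if epoch < warmup_epochs:
        return list(crops[:num_large])
    return list(crops)


# Usage with placeholder strings standing in for image tensors:
views = ["large_0", "large_1", "small_0", "small_1", "small_2", "small_3"]
print(crops_for_epoch(views, epoch=3, warmup_epochs=10))        # ['large_0', 'large_1']
print(len(crops_for_epoch(views, epoch=10, warmup_epochs=10)))  # 6
```

Because the number of crops now changes during training, any loss code that hard-codes the crop count has to derive it from the inputs instead.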

For example, if n is the number of epochs to train without multi-crop, add the following before this line:

embedding, output = model(inputs)

if epoch < n:
    inputs = inputs[:2]  # you keep only the two large views

and here (swav/main_swav.py, lines 312 to 315 at commit 9a2dc80):

for v in np.delete(np.arange(np.sum(args.nmb_crops)), crop_id):
    x = output[bs * v: bs * (v + 1)] / args.temperature
    subloss -= torch.mean(torch.sum(q * F.log_softmax(x, dim=1), dim=1))
loss += subloss / (np.sum(args.nmb_crops) - 1)
There, you would need to replace args.nmb_crops with len(inputs), so that the loss loops over and averages across the crops actually passed to the model.
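To see why len(inputs) is the right replacement, here is a simplified, framework-free sketch of that loss loop for a single image (plain Python lists instead of PyTorch tensors; the function names, the 0.1 default temperature, and the precomputed codes are assumptions, and the Sinkhorn step that produces the codes is omitted). The code q of each assigned crop is predicted from every other crop and averaged over the len(scores) - 1 other crops, so the same function works with 2 crops during warm-up and with all crops afterwards.

```python
import math
from typing import Dict, List, Sequence


def log_softmax(scores: Sequence[float], temperature: float) -> List[float]:
    """Numerically stable log-softmax of scores / temperature."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
    return [s - log_z for s in scaled]


def swapped_loss(scores: Sequence[Sequence[float]],
                 codes: Dict[int, Sequence[float]],
                 temperature: float = 0.1) -> float:
    """Swapped-prediction loss for one image.

    scores[v] holds the prototype scores of crop v; codes[c] is the code
    (a distribution over prototypes) computed from crop c. The crop count
    comes from len(scores), mirroring the len(inputs) change.
    """
    n_crops = len(scores)
    if n_crops < 2:
        raise ValueError("need at least two crops")
    loss = 0.0
    for crop_id, q in codes.items():
        subloss = 0.0
        for v in range(n_crops):
            if v == crop_id:
                continue  # a crop never predicts its own code
            log_p = log_softmax(scores[v], temperature)
            subloss -= sum(qk * lp for qk, lp in zip(q, log_p))
        loss += subloss / (n_crops - 1)  # average over the other crops
    return loss / len(codes)             # average over the assigned crops


K = 4
codes = {0: [0.25] * K, 1: [0.25] * K}
two_crops = [[0.0] * K for _ in range(2)]  # zero scores -> uniform predictions
six_crops = [[0.0] * K for _ in range(6)]
print(round(swapped_loss(two_crops, codes), 4))  # 1.3863, i.e. ln(4)
print(round(swapped_loss(six_crops, codes), 4))  # 1.3863 as well
```

Had the divisor stayed fixed at the full multi-crop count while only two crops were passed in, the warm-up loss would be scaled down rather than averaged.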

Hope that helps

Thank you so much for your help! We tried things based on your suggestions, and they're looking much better now :)

Hi @AnanyaKumar, do you mind sharing plots of your loss per epoch?
Could you also provide some details of your training setup, dataset, etc.?