About DinoV2+BoQ training and released checkpoint
Closed this issue · 3 comments
Hi, Amar.
Are the released DinoV2 weights selected on validation data (i.e., pitts30k_val) or taken from the last epoch? Using validation on pitts30k, it seems R@1 converges around epoch 20, but the batch accuracy (b_accu) is still increasing at epoch 40 and has not converged.
I tested on some indoor data, and the released BoQ+DinoV2 performs exceptionally well, better than SALAD/CricaVPR. But my reproduced version is below the released model (all training params followed #6). I have the implementation in the gsv-cities repo, here: https://github.com/hit2sjtu/gsv-cities/pull/1/files.
Am I missing something?
Hello @hit2sjtu,
Can you share the performance you're getting?
I just checked your code, and one thing I noticed is that you're using weight_decay=0.0 (I think in the first version of the code I stated 0.0 for Adam, something I had read in a paper). But for AdamW, the best weight_decay is somewhere between 0.0001 and 0.01 (I personally use 0.001 for both ResNet and DinoV2).
I checkpoint the weights with the best performance on msls-val (not pitts30k-val, but most of the time they converge together). I have trained dozens of models with slightly different configs (lr, weight_decay, augmentation), and they all get to ~92.5 recall@1 on msls-val and ~94.5 R@1 on pitts30k-val (with 280x280 images).
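For reference, checkpointing on the best msls-val R@1 is just a standard Lightning ModelCheckpoint callback. A minimal sketch (the monitored key "msls_val/R1" is a placeholder, use whatever name you log the metric under):

from pytorch_lightning.callbacks import ModelCheckpoint

# keep only the weights with the highest Recall@1 on msls-val
checkpoint_cb = ModelCheckpoint(
    monitor="msls_val/R1",   # placeholder: must match the key logged during validation
    mode="max",              # higher recall is better
    save_top_k=1,
)
# then pass it to the trainer: pl.Trainer(callbacks=[checkpoint_cb], ...)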
For the last models I trained, I was using a smaller learning rate on the DinoV2 backbone (0.2*lr):
...
elif self.optimizer.lower() == 'adamw':
    optimizer_params = [
        {"params": self.backbone.parameters(),   "lr": self.lr * 0.2, "weight_decay": self.weight_decay},
        {"params": self.aggregator.parameters(), "lr": self.lr,       "weight_decay": self.weight_decay},
    ]
    optimizer = torch.optim.AdamW(optimizer_params)
Now that you have two parameter groups, you have to keep an eye on the warmup, and take this change into account:
def optimizer_step(self, epoch, batch_idx,
                   optimizer, optimizer_closure):
    # warm up lr
    if self.trainer.global_step < self.warmpup_steps:
        lr_scale = min(1., float(self.trainer.global_step + 1) / self.warmpup_steps)
        for pg in optimizer.param_groups:
            # pg['lr'] = lr_scale * self.lr         # old
            pg['lr'] = lr_scale * pg["initial_lr"]  # new
    optimizer.step(closure=optimizer_closure)
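Note that pg["initial_lr"] only exists once an LR scheduler has been attached to the optimizer (PyTorch schedulers set it on each param group when they're constructed). A minimal sketch of a configure_optimizers that works with the warmup code above, using the milestones/lr_mult from the hparams below (illustrative, not the exact repo code):

def configure_optimizers(self):
    optimizer_params = [
        {"params": self.backbone.parameters(),   "lr": self.lr * 0.2, "weight_decay": self.weight_decay},
        {"params": self.aggregator.parameters(), "lr": self.lr,       "weight_decay": self.weight_decay},
    ]
    optimizer = torch.optim.AdamW(optimizer_params)
    # MultiStepLR stores each group's starting lr in pg["initial_lr"],
    # which is what the warmup above scales
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=self.milestones, gamma=self.lr_mult
    )
    return [optimizer], [scheduler]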
One other parameter I played with is the miner_margin of the multi-similarity loss. It has a small performance effect on sequence datasets (especially Nordland): using a smaller miner_margin=0.05 can enhance performance for some datasets (a small miner_margin means fewer but harder pairs in the batch, which may hinder performance if it's too small). You can try 0.075, 0.05 and 0.025. This is a hyperparameter that doesn't get saved, so I can't say which value I used.
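If you're using pytorch-metric-learning (as gsv-cities does), miner_margin corresponds to the epsilon of the MultiSimilarityMiner. A minimal sketch of trying a smaller margin (the alpha/beta/base values are just common defaults, not necessarily what I trained with):

from pytorch_metric_learning import losses, miners

loss_fn = losses.MultiSimilarityLoss(alpha=1.0, beta=50.0, base=0.0)
miner = miners.MultiSimilarityMiner(epsilon=0.05)  # smaller epsilon -> fewer but harder pairs

# in the training step:
# miner_outputs = miner(descriptors, labels)
# loss = loss_fn(descriptors, labels, miner_outputs)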
Here are the hparams that got saved when I trained the model in this repo:
lr: 0.0002
optimizer: adamw
weight_decay: 0.001
warmup_steps: 3900
milestones:
- 10
- 20
- 30
lr_mult: 0.1
batch_size: 160
img_per_place: 4
min_img_per_place: 4
shuffle_all: false
image_size:
- 280
- 280
random_sample_from_each_place: true
Finally, always use BICUBIC interpolation when possible:
T.Resize(image_size, interpolation = T.InterpolationMode.BICUBIC, antialias=True),
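For context, a full resize + normalize pipeline with BICUBIC would look something like this (a sketch; the ImageNet mean/std below are an assumption, use whatever stats your backbone expects):

import torchvision.transforms as T

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

valid_transform = T.Compose([
    T.Resize((280, 280), interpolation=T.InterpolationMode.BICUBIC, antialias=True),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])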
My trainings converge at 94.7%-95.3% R@1 on pitts30k_val, slightly higher than the 94.5% you mentioned, but the models generalize worse on my own indoor dataset. I will make the changes as explained and keep you updated. Thanks again for the detailed info.
I was able to close most of the gap with your suggestions above, so I am closing the issue for now. @amaralibey