Fine-tune VIT model with BYOL method
khawar-islam opened this issue · 0 comments
khawar-islam commented
I am fine-tuning the dataset on VIT using the below line
model = timm.create_model('vit_base_resnet50_384', pretrained=True, num_classes=7)
The accuracy is not that much good so I decided to integrate BYOL paper which is very easy to integrate with VIT.
https://github.com/lucidrains/byol-pytorch
Code
self.learner = BYOL(
model,
image_size=224,
hidden_layer=model.cls_token
)
optimizer = optim.SGD(model, weight_decay=.0005, momentum=.9, nesterov=args.nesterov, lr=args.lr)
def _do_epoch(self, epoch=None):
criterion = nn.CrossEntropyLoss()
self.model.train()
for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(self.device), class_l.to(
self.device), d_idx.to(self.device)
self.optimizer.zero_grad()
data_flip = torch.flip(data, (3,)).detach().clone()
data = torch.cat((data, data_flip))
class_l = torch.cat((class_l, class_l))
class_logit = self.model(data, class_l, True, epoch)
class_loss = criterion(class_logit, class_l)
_, cls_pred = class_logit.max(dim=1)
loss = class_loss
loss.backward()
self.optimizer.step()
self.learner.update_moving_average() #byol code
self.logger.log(it, len(self.source_loader),
{"class": class_loss.item()},
{"class": torch.sum(cls_pred == class_l.data).item(), }, data.shape[0])
del loss, class_
Traceback:
Traceback (most recent call last):
File "/media/khawar/HDD_Khawar/RSC/Domain_Generalization/train.py", line 193, in <module>
main()
File "/media/khawar/HDD_Khawar/RSC/Domain_Generalization/train.py", line 187, in main
trainer = Trainer(args, device)
File "/media/khawar/HDD_Khawar/RSC/Domain_Generalization/train.py", line 87, in __init__
hidden_layer=model.cls_token
File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 211, in __init__
self.forward(torch.randn(2, 3, image_size, image_size, device=device))
File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 239, in forward
online_proj_one, _ = self.online_encoder(image_one)
File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 149, in forward
representation = self.get_representation(x)
File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 134, in get_representation
if self.layer == -1:
RuntimeError: Boolean value of Tensor with more than one value is ambiguous