lucidrains/byol-pytorch

Fine-tune VIT model with BYOL method

khawar-islam opened this issue · 0 comments

I am fine-tuning the dataset on VIT using the below line

model = timm.create_model('vit_base_resnet50_384', pretrained=True, num_classes=7)
The accuracy is not that much good so I decided to integrate BYOL paper which is very easy to integrate with VIT.
https://github.com/lucidrains/byol-pytorch

Code

self.learner = BYOL(
            model,
            image_size=224,
            hidden_layer=model.cls_token
        )
        optimizer = optim.SGD(model, weight_decay=.0005, momentum=.9, nesterov=args.nesterov, lr=args.lr) 

    def _do_epoch(self, epoch=None):
        criterion = nn.CrossEntropyLoss()

        self.model.train()
        for it, ((data, jig_l, class_l), d_idx) in enumerate(self.source_loader):
            data, jig_l, class_l, d_idx = data.to(self.device), jig_l.to(self.device), class_l.to(
                self.device), d_idx.to(self.device)
            self.optimizer.zero_grad()

            data_flip = torch.flip(data, (3,)).detach().clone()
            data = torch.cat((data, data_flip))
            class_l = torch.cat((class_l, class_l))

            class_logit = self.model(data, class_l, True, epoch)
            class_loss = criterion(class_logit, class_l)
            _, cls_pred = class_logit.max(dim=1)
            loss = class_loss

            loss.backward()
            self.optimizer.step()
            self.learner.update_moving_average() #byol code

            self.logger.log(it, len(self.source_loader),
                            {"class": class_loss.item()},
                            {"class": torch.sum(cls_pred == class_l.data).item(), }, data.shape[0])
            del loss, class_

Traceback:

Traceback (most recent call last):
  File "/media/khawar/HDD_Khawar/RSC/Domain_Generalization/train.py", line 193, in <module>
    main()
  File "/media/khawar/HDD_Khawar/RSC/Domain_Generalization/train.py", line 187, in main
    trainer = Trainer(args, device)
  File "/media/khawar/HDD_Khawar/RSC/Domain_Generalization/train.py", line 87, in __init__
    hidden_layer=model.cls_token
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 211, in __init__
    self.forward(torch.randn(2, 3, image_size, image_size, device=device))
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 239, in forward
    online_proj_one, _ = self.online_encoder(image_one)
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 149, in forward
    representation = self.get_representation(x)
  File "/home/khawar/anaconda3/envs/RSC/lib/python3.7/site-packages/byol_pytorch/byol_pytorch.py", line 134, in get_representation
    if self.layer == -1:
RuntimeError: Boolean value of Tensor with more than one value is ambiguous