lucidrains/vit-pytorch

Problems regarding training 3D Vision transformer : model does not converge

Uljibuh opened this issue · 0 comments

Hi, this is my first time working on a transformer model; in this case, a 3D vision transformer.

I am working on a 3D medical image classification task with a training set of around 300 3D images. Each input has the shape (1, 224, 224, 32), where 1 is the number of channels and 32 is the size of the z dimension. On this dataset a 3D EfficientNet reached about 80% accuracy, but a 3D vision transformer does not converge. Could you please review the code below? Why does the model not learn, and am I doing something wrong? Any help or suggestions would be appreciated. Thank you in advance.

This is the forward pass:

```python
import torch
from einops import repeat

def forward(self, img):
    print("img,input shape before patch embedding", img.shape)

    x = self.to_patch_embedding(img)
    print("after patch embedding", x.shape)

    b, n, _ = x.shape
    #cls_tokens = self.cls_token.expand(b, -1, -1)
    cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)

    print("cls token shape", cls_tokens.shape)
    x = torch.cat((cls_tokens, x), dim=1)

    print("after cls_token", x.shape)
    x += self.pos_embedding[:, :(n + 1)]

    print("after position embedding", x.shape)
    x = self.dropout(x)

    x = self.transformer(x)
    print("after transformer", x.shape)

    x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

    x = self.to_latent(x)

    print("after latent", x.shape)
    return self.mlp_head(x)
```
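For reference, the commented-out `expand` line and the `einops.repeat` call produce the same class-token tensor. The sketch below reproduces the token bookkeeping of this forward pass with plain PyTorch, using randomly initialized stand-ins (the names `patches`, `cls_token` and `pos_embedding` are assumptions standing in for the model's tensors) and the sizes from the shape log.

```python
import torch

b, n, dim = 4, 1568, 1024  # batch size, patch tokens, embedding dim (from the shape log)

patches = torch.randn(b, n, dim)             # stand-in for self.to_patch_embedding(img)
cls_token = torch.randn(1, 1, dim)           # stand-in for self.cls_token
pos_embedding = torch.randn(1, n + 1, dim)   # stand-in for self.pos_embedding

# expand() gives a (b, 1, dim) view without copying; same values as einops.repeat
cls_tokens = cls_token.expand(b, -1, -1)
x = torch.cat((cls_tokens, patches), dim=1)  # prepend the class token -> (b, n + 1, dim)
x = x + pos_embedding[:, : n + 1]            # positional embedding broadcasts over the batch

pooled_cls = x[:, 0]         # pool == 'cls'
pooled_mean = x.mean(dim=1)  # pool == 'mean'
print(x.shape, pooled_cls.shape, pooled_mean.shape)
```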

3D vision transformer model configuration:

```python
from vit_pytorch.vit_3d import ViT

ViTmodel = ViT(
    image_size = 224,
    image_patch_size = 16,
    frames = 32,
    frame_patch_size = 4,

    num_classes = 2,
    dim = 1024,
    depth = 6,
    heads = 2,
    mlp_dim = 1024,
    pool = 'cls',
    channels = 1,
    dim_head = 64,
    dropout = 0.2,
    emb_dropout = 0.1
)
```
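To make the sizes implied by this configuration explicit, here is a small stdlib-only helper. It assumes the volume is cut into `image_patch_size × image_patch_size × frame_patch_size` patches, that each flattened patch is projected to `dim`, and that the attention width is `heads * dim_head`; the function name `vit3d_sizes` is hypothetical and not part of vit-pytorch.

```python
def vit3d_sizes(image_size, image_patch_size, frames, frame_patch_size,
                channels, heads, dim_head):
    """Derive patch grid, token count, patch size and attention width from a 3D ViT config."""
    # The in-plane and depth extents must be divisible by their patch sizes.
    assert image_size % image_patch_size == 0, "image_size must be divisible by image_patch_size"
    assert frames % frame_patch_size == 0, "frames must be divisible by frame_patch_size"

    grid = (frames // frame_patch_size,
            image_size // image_patch_size,
            image_size // image_patch_size)
    num_patches = grid[0] * grid[1] * grid[2]
    return {
        "grid": grid,
        "num_patches": num_patches,
        "seq_len": num_patches + 1,  # +1 for the class token
        "patch_dim": channels * image_patch_size ** 2 * frame_patch_size,  # raw voxels per patch
        "attn_inner_dim": heads * dim_head,  # width of q, k, v summed over heads
    }

sizes = vit3d_sizes(image_size=224, image_patch_size=16, frames=32,
                    frame_patch_size=4, channels=1, heads=2, dim_head=64)
print(sizes)
```

With these values each sample becomes a sequence of 1569 tokens, each token's 1024 raw voxel values are projected to `dim = 1024`, and attention operates in a 128-wide space (`heads * dim_head`).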

```python
import torch.nn as nn
import torch.optim as optim

optimizer = optim.Adam(ViTmodel.parameters(), lr=0.002)
criterion = nn.CrossEntropyLoss().to(device)
```
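`lr=0.002` is likely on the high side for Adam on a transformer trained with a batch size of 4; ViT training recipes commonly pair a smaller peak learning rate with a warmup phase. Below is a stdlib-only sketch of a linear-warmup plus cosine-decay multiplier. The warmup length, the function name `warmup_cosine`, and the idea that this helps here are assumptions to tune, not a confirmed fix for this issue.

```python
import math

def warmup_cosine(step, warmup_steps, total_steps, min_factor=0.0):
    """Multiplier on the base learning rate: linear warmup, then cosine decay."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return min_factor + (1.0 - min_factor) * 0.5 * (1.0 + math.cos(math.pi * progress))

# Hypothetical schedule: 56 steps per epoch (from the log) for 10 epochs, 1 epoch of warmup.
steps_per_epoch, epochs = 56, 10
total = steps_per_epoch * epochs
factors = [warmup_cosine(s, steps_per_epoch, total) for s in range(total)]
print(factors[0], factors[55], factors[56], factors[-1])
```

In PyTorch such a function can be passed to `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=...)`, with `scheduler.step()` called once per batch; the base learning rate itself would still be a hyperparameter to search.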

Tensor shapes printed during the forward pass:

```
img,input shape before patch embedding torch.Size([4, 1, 224, 224, 32])
after patch embedding torch.Size([4, 1568, 1024])
cls token shape torch.Size([4, 1, 1024])
after cls_token torch.Size([4, 1569, 1024])
after position embedding torch.Size([4, 1569, 1024])
after transformer torch.Size([4, 1569, 1024])
after latent torch.Size([4, 1024])
Input shape: torch.Size([4, 1, 224, 224, 32])
Output shape: torch.Size([4, 2])
```
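The token count in the log (1568) matches the configuration, but that alone does not confirm that the z axis sits where the model expects it. Assuming vit-pytorch's 3D `ViT` patchifies with the einops pattern `b c (f pf) (h p1) (w p2)`, i.e. expects `(batch, channels, frames, height, width)`, an input laid out as `(batch, 1, 224, 224, 32)` with z last is still accepted, because 224 is divisible by `frame_patch_size = 4` and 32 by `image_patch_size = 16`. The stdlib sketch below (hypothetical helper `patch_grid`) computes the patch grid for both layouts.

```python
def patch_grid(shape, frame_patch_size, image_patch_size):
    """Patch grid for a (c, f, h, w) tensor split into pf x p x p patches."""
    c, f, h, w = shape
    # Mirrors the divisibility that a 'c (f pf) (h p1) (w p2)' split requires.
    assert f % frame_patch_size == 0
    assert h % image_patch_size == 0 and w % image_patch_size == 0
    grid = (f // frame_patch_size, h // image_patch_size, w // image_patch_size)
    return grid, grid[0] * grid[1] * grid[2]

# z first, as the assumed pattern expects: (1, 32, 224, 224)
print(patch_grid((1, 32, 224, 224), 4, 16))   # ((8, 14, 14), 1568)
# z last, as in this issue: (1, 224, 224, 32); the first 224 axis is treated as "frames"
print(patch_grid((1, 224, 224, 32), 4, 16))   # ((56, 14, 2), 1568)
```

Both layouts yield 1568 tokens, but with z last each patch covers 4 × 16 × 16 voxels along (H, W, z) instead of 16 × 16 × 4, and the positional embeddings index a 56 × 14 × 2 grid rather than 8 × 14 × 14. If the assumed layout is right, `img.permute(0, 1, 4, 2, 3)` would move z into the frames position.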

Model training results:

```
Epoch 1/10 (Training): 100%|██████████| 56/56 [00:59<00:00, 1.07s/it]
Epoch 1/10, Training Loss: 0.5908904586519513, Training Accuracy: 0.7142857142857143
Epoch 1/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.41it/s]
Epoch 1/10, Validation Loss: 0.5275474616459438, Validation Accuracy: 0.7798165137614679
Best model saved at epoch 1
Epoch 2/10 (Training): 100%|██████████| 56/56 [00:58<00:00, 1.04s/it]
Epoch 2/10, Training Loss: 0.5878153315612248, Training Accuracy: 0.7210884353741497
Epoch 2/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.40it/s]
Epoch 2/10, Validation Loss: 0.532904612166541, Validation Accuracy: 0.7798165137614679
Epoch 3/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 3/10, Training Loss: 0.5878153315612248, Training Accuracy: 0.7210884353741497
Epoch 3/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.40it/s]
Epoch 3/10, Validation Loss: 0.527547470160893, Validation Accuracy: 0.7798165137614679
Epoch 4/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 4/10, Training Loss: 0.5878153358186994, Training Accuracy: 0.7210884353741497
Epoch 4/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.41it/s]
Epoch 4/10, Validation Loss: 0.5329046036515918, Validation Accuracy: 0.7798165137614679
Epoch 5/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 5/10, Training Loss: 0.6034403315612248, Training Accuracy: 0.7210884353741497
Epoch 5/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.44it/s]
Epoch 5/10, Validation Loss: 0.532904612166541, Validation Accuracy: 0.7798165137614679
Epoch 6/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 6/10, Training Loss: 0.5878153379474368, Training Accuracy: 0.7210884353741497
Epoch 6/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.44it/s]
Epoch 6/10, Validation Loss: 0.527547470160893, Validation Accuracy: 0.7798165137614679
```
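Training and validation accuracy stay identical from epoch 2 onward (0.7210884353741497 and 0.7798165137614679), which is consistent with the classifier outputting the same class for every sample. One quick check is to compare the accuracy with the majority-class fraction of the labels and to list the distinct predicted classes. The stdlib sketch below uses made-up labels and predictions purely for illustration; `collapse_report` is a hypothetical helper.

```python
from collections import Counter

def collapse_report(labels, preds):
    """Compare accuracy with the majority-class baseline and list the predicted classes."""
    assert labels and len(labels) == len(preds)
    accuracy = sum(l == p for l, p in zip(labels, preds)) / len(labels)
    majority_class, majority_count = Counter(labels).most_common(1)[0]
    return {
        "accuracy": accuracy,
        "majority_class": majority_class,
        "majority_fraction": majority_count / len(labels),
        "predicted_classes": sorted(set(preds)),
    }

# Illustrative data only: 7 of 10 labels are class 0, and the "model" always predicts 0.
labels = [0, 0, 0, 1, 0, 1, 0, 0, 1, 0]
preds = [0] * len(labels)
print(collapse_report(labels, preds))
```

If, on the real validation set, `accuracy` equals `majority_fraction` and `predicted_classes` has a single entry, the model has collapsed to one class, and the accuracy figures say nothing about what it has learned.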