Problems regarding training 3D Vision transformer : model does not converge
Uljibuh opened this issue · 0 comments
Hi, this is my first time working with a transformer model, in this case a 3D vision transformer.
I am working on a 3D medical image classification task with a training set of around 300 3D images. The input shape is (1, 224, 224, 32), where 1 is the number of channels and 32 is the z-dimension size. Training the same dataset on a 3D EfficientNet gave about 80% accuracy, but when I tried a 3D vision transformer, the model does not converge. Can you please review the code below? Why does the model not learn, and am I doing something wrong? Any help or suggestions would be appreciated. Thank you in advance.
This is the forward pass:

```python
def forward(self, img):
    print("img,input shape before patch embedding", img.shape)
    x = self.to_patch_embedding(img)
    print("after patch embedding", x.shape)
    b, n, _ = x.shape

    # cls_tokens = self.cls_token.expand(b, -1, -1)
    cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
    print("cls token shape", cls_tokens.shape)
    x = torch.cat((cls_tokens, x), dim=1)
    print("after cls_token", x.shape)
    x += self.pos_embedding[:, :(n + 1)]
    print("after position embedding", x.shape)
    x = self.dropout(x)

    x = self.transformer(x)
    print("after transformer", x.shape)

    x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
    x = self.to_latent(x)
    print("after latent", x.shape)
    return self.mlp_head(x)
```
3D vision transformer model configuration:
```python
ViTmodel = ViT(
    image_size = 224,
    image_patch_size = 16,
    frames = 32,
    frame_patch_size = 4,
    num_classes = 2,
    dim = 1024,
    depth = 6,
    heads = 2,
    mlp_dim = 1024,
    pool = 'cls',
    channels = 1,
    dim_head = 64,
    dropout = 0.2,
    emb_dropout = 0.1
)
```
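For reference, here is my own sanity check of the patch count implied by this configuration, assuming the vit-pytorch-style 3D patching where the spatial patches and the frame (depth) patches multiply:

```python
# Number of patches implied by the ViT config above:
# the 224x224 plane is split into 16x16 patches, and the
# 32-frame depth axis into patches of 4 frames each.
image_size, image_patch_size = 224, 16
frames, frame_patch_size = 32, 4

spatial_patches = (image_size // image_patch_size) ** 2  # 14 * 14 = 196
depth_patches = frames // frame_patch_size               # 32 // 4 = 8
num_patches = spatial_patches * depth_patches

print(num_patches)  # 1568, matching the "after patch embedding" printout
```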
```python
optimizer = optim.Adam(ViTmodel.parameters(), lr=0.002)
criterion = nn.CrossEntropyLoss().to(device)
```
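One thing I am unsure about is the learning rate: 2e-3 with plain Adam and no warmup may be too aggressive for a ViT trained from scratch on only ~300 images. As an experiment (not a confirmed fix), I am considering AdamW with a lower learning rate, some weight decay, and a linear warmup; the `nn.Linear` model below is just a stand-in for `ViTmodel` in this sketch:

```python
import torch
from torch import nn, optim

model = nn.Linear(1024, 2)  # stand-in for ViTmodel in this sketch

# AdamW with a much smaller base LR and some weight decay.
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

# Linear warmup over the first 100 optimizer steps, then constant LR.
warmup_steps = 100
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min(1.0, (step + 1) / warmup_steps)
)
```

`scheduler.step()` would then be called once after each `optimizer.step()` in the training loop.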
Shape printouts from the forward pass:

```
img,input shape before patch embedding torch.Size([4, 1, 224, 224, 32])
after patch embedding torch.Size([4, 1568, 1024])
cls token shape torch.Size([4, 1, 1024])
after cls_token torch.Size([4, 1569, 1024])
after position embedding torch.Size([4, 1569, 1024])
after transformer torch.Size([4, 1569, 1024])
after latent torch.Size([4, 1024])
Input shape: torch.Size([4, 1, 224, 224, 32])
Output shape: torch.Size([4, 2])
```
Model training results:

```
Epoch 1/10 (Training): 100%|██████████| 56/56 [00:59<00:00, 1.07s/it]
Epoch 1/10, Training Loss: 0.5908904586519513, Training Accuracy: 0.7142857142857143
Epoch 1/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.41it/s]
Epoch 1/10, Validation Loss: 0.5275474616459438, Validation Accuracy: 0.7798165137614679
Best model saved at epoch 1
Epoch 2/10 (Training): 100%|██████████| 56/56 [00:58<00:00, 1.04s/it]
Epoch 2/10, Training Loss: 0.5878153315612248, Training Accuracy: 0.7210884353741497
Epoch 2/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.40it/s]
Epoch 2/10, Validation Loss: 0.532904612166541, Validation Accuracy: 0.7798165137614679
Epoch 3/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 3/10, Training Loss: 0.5878153315612248, Training Accuracy: 0.7210884353741497
Epoch 3/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.40it/s]
Epoch 3/10, Validation Loss: 0.527547470160893, Validation Accuracy: 0.7798165137614679
Epoch 4/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 4/10, Training Loss: 0.5878153358186994, Training Accuracy: 0.7210884353741497
Epoch 4/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.41it/s]
Epoch 4/10, Validation Loss: 0.5329046036515918, Validation Accuracy: 0.7798165137614679
Epoch 5/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 5/10, Training Loss: 0.6034403315612248, Training Accuracy: 0.7210884353741497
Epoch 5/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.44it/s]
Epoch 5/10, Validation Loss: 0.532904612166541, Validation Accuracy: 0.7798165137614679
Epoch 6/10 (Training): 100%|██████████| 56/56 [00:57<00:00, 1.03s/it]
Epoch 6/10, Training Loss: 0.5878153379474368, Training Accuracy: 0.7210884353741497
Epoch 6/10 (Validation): 100%|██████████| 14/14 [00:09<00:00, 1.44it/s]
Epoch 6/10, Validation Loss: 0.527547470160893, Validation Accuracy: 0.7798165137614679
```
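Since the loss is essentially frozen across epochs, I also plan to try the standard sanity check of overfitting a single small batch: if the model cannot drive the loss toward zero on just 4 fixed samples, something is broken in the model or training loop rather than in the data. A minimal sketch of that check (the tiny MLP here is a hypothetical stand-in for my actual ViT, so the whole thing runs quickly):

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Stand-in model; in practice this would be ViTmodel.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(8 * 8, 32),
    nn.ReLU(),
    nn.Linear(32, 2),
)

x = torch.randn(4, 1, 8, 8)      # one fixed mini-batch
y = torch.tensor([0, 1, 0, 1])   # fixed labels

optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Repeatedly fit the same batch; loss should fall toward zero.
for step in range(300):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()

print(loss.item())  # should be close to 0 if the training loop is healthy
```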