yaoppeng/U-Net_v2

[Question] Training on custom dataset


Dear Author,
Thank you for the interesting repo.

I have a question about training on a custom dataset.
I am doing binary image segmentation (1 class),
and I am using the Dice loss (1 - Dice) as the loss function.
However, when I train the model with the code below, the loss becomes negative from the very first epoch.

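import torch
import tqdm

# UNetV2, device, num_epochs, train_dataloader and Dice_loss are defined elsewhere
model = UNetV2(n_classes=1, deep_supervision=True, pretrained_path='./pretrained/pvt_v2_b2.pth')
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_history = []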
for epoch in range(num_epochs):
    model.train()  
    loss_batch, val_loss_batch = [], []
    
    # train
    for inputs, labels in tqdm.tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}'):
        inputs, labels = inputs.to(device).float(), labels.to(device).float()
        # opt zerograd -> model output -> loss -> backward -> opt step
        optimizer.zero_grad()
        outputs = model(inputs)[::-1][-1]  # full-resolution output; same as model(inputs)[0]
        outputs = torch.sigmoid(outputs)  # logits -> probabilities for the Dice loss
        loss = Dice_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        loss_batch.append(loss.detach())
        
    loss_batch = (torch.stack(loss_batch)).mean()
    loss_history.append(loss_batch.cpu())

    print(f"Epoch [{epoch + 1}/{num_epochs}], Training Loss: {loss_batch}")
    #val
labels.shape: (8,1,256,256)  # 8 is the batch size
outputs[::-1][-1].shape:  (8,1,256,256)
outputs[::-1][-2].shape:  (8,1,128,128)

As in the original U-Net, I take model(inputs)[::-1][-1] as the output so that its shape matches the labels.
Could you please check if there is a problem in the code?
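For reference, `[::-1][-1]` on a Python list simply returns the first element, so `model(inputs)[::-1][-1]` is the same as `model(inputs)[0]`. The sketch below uses dummy tensors with the shapes printed above and a hypothetical helper, `pick_full_resolution`, that selects the output matching the label size instead of relying on list order. The assumption that UNetV2 returns a list of tensors comes from the shapes reported here, not from the repository code.

```python
import torch


def pick_full_resolution(outputs, labels):
    """Return the deep-supervision output whose spatial size matches the labels."""
    for out in outputs:
        if out.shape[-2:] == labels.shape[-2:]:
            return out
    raise ValueError("no output matches the label resolution")


# Dummy tensors with the shapes reported above (batch of 8, one class).
# The real list would come from UNetV2(deep_supervision=True); only two entries are mocked.
outputs = [torch.zeros(8, 1, 256, 256), torch.zeros(8, 1, 128, 128)]
labels = torch.zeros(8, 1, 256, 256)

# Reversing a list and then taking its last element is just taking its first element.
assert outputs[::-1][-1] is outputs[0]
print(pick_full_resolution(outputs, labels).shape)  # torch.Size([8, 1, 256, 256])
```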

Could you please show the code of your Dice_loss class?

Sure.
Here is the DiceLoss class.

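import torch.nn as nn


class DiceLoss(nn.Module):
    """Soft Dice loss (1 - Dice) for binary segmentation."""

    def __init__(self):
        super(DiceLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1e-5):

        # inputs = torch.sigmoid(inputs)  # keep commented out when sigmoid is already applied to the outputs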
        
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice = (2.*intersection + smooth) / (inputs.sum() + targets.sum() + smooth)  
        
        return 1 - dice 
    
Dice_loss = DiceLoss()
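A quick check of why this loss can go negative: sigmoid outputs lie in [0, 1], so with targets in {0, 1} the intersection can never exceed either sum, Dice stays at or below 1, and 1 - Dice stays non-negative. Targets larger than 1, such as masks left in the 0..255 range, break that bound. The sketch below re-implements the same formula with `reshape` (which, unlike `view`, also works on non-contiguous tensors) and applies it to small made-up tensors.

```python
import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    """Same formula as above: 1 - (2*sum(P*T) + s) / (sum(P) + sum(T) + s)."""

    def forward(self, inputs, targets, smooth=1e-5):
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)
        intersection = (inputs * targets).sum()
        dice = (2.0 * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
        return 1 - dice


dice_loss = DiceLoss()
probs = torch.full((1, 1, 4, 4), 0.9)  # sigmoid outputs in [0, 1]
mask_01 = torch.ones(1, 1, 4, 4)       # label correctly scaled to {0, 1}
mask_255 = mask_01 * 255               # the same label left in the 0..255 range

print(dice_loss(probs, mask_01).item())   # small and positive (about 0.05)
print(dice_loss(probs, mask_255).item())  # negative, because Dice exceeds 1
```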

I have an additional question.
Suppose we use pvt_v2_b2.pth as the pretrained weights when defining the model.
In that case, is the Pyramid Vision Transformer (the backbone) trainable or frozen?
If it is frozen, can it be made trainable?

Hi, I made a small demo using your code:

import torch

model = UNetV2(n_classes=1, deep_supervision=True, pretrained_path='./pretrained/pvt_v2_b2.pth').cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

x = torch.rand((4, 3, 256, 256)).float().cuda()  # random images
labels = torch.randint(0, 2, (4, 256, 256)).cuda()  # random binary {0, 1} labels
# print(labels.max())
for epoch in range(20):
    optimizer.zero_grad()
    outputs = model(x)[::-1][-1]  # full-resolution output
    outputs = torch.sigmoid(outputs)
    loss = Dice_loss(outputs, labels)
    print("epoch", epoch, "loss:", loss.detach())
    loss.backward()
    optimizer.step()

The following is the loss:

epoch 0 loss: tensor(0.4868, device='cuda:0')
epoch 1 loss: tensor(0.4873, device='cuda:0')
epoch 2 loss: tensor(0.4827, device='cuda:0')
epoch 3 loss: tensor(0.4767, device='cuda:0')
epoch 4 loss: tensor(0.4703, device='cuda:0')
epoch 5 loss: tensor(0.4646, device='cuda:0')
epoch 6 loss: tensor(0.4604, device='cuda:0')
epoch 7 loss: tensor(0.4579, device='cuda:0')
epoch 8 loss: tensor(0.4538, device='cuda:0')
epoch 9 loss: tensor(0.4686, device='cuda:0')
epoch 10 loss: tensor(0.4423, device='cuda:0')
epoch 11 loss: tensor(0.4363, device='cuda:0')
epoch 12 loss: tensor(0.4245, device='cuda:0')
epoch 13 loss: tensor(0.4132, device='cuda:0')
epoch 14 loss: tensor(0.3998, device='cuda:0')
epoch 15 loss: tensor(0.3861, device='cuda:0')
epoch 16 loss: tensor(0.3743, device='cuda:0')
epoch 17 loss: tensor(0.3641, device='cuda:0')
epoch 18 loss: tensor(0.3586, device='cuda:0')
epoch 19 loss: tensor(0.3505, device='cuda:0')

You may want to check your input data, especially the labels.

The parameters of the backbone are trainable by default.
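To verify this, or to freeze the encoder on purpose, you can toggle `requires_grad` on a submodule. The sketch below uses a toy model; `backbone` is a hypothetical attribute name, so on UNetV2 you would first look up which child holds the PVT encoder, for example with `model.named_children()`.

```python
import torch
import torch.nn as nn


def count_trainable(module: nn.Module) -> int:
    """Number of parameter elements that will receive gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Freeze (False) or unfreeze (True) every parameter of `module`."""
    for p in module.parameters():
        p.requires_grad_(trainable)


class ToyNet(nn.Module):
    """Toy stand-in for UNetV2; `backbone` is a hypothetical attribute name."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)  # plays the role of the PVT encoder
        self.head = nn.Linear(4, 1)      # plays the role of the decoder/head


net = ToyNet()
print(count_trainable(net))  # 25: everything is trainable by default
set_trainable(net.backbone, False)
print(count_trainable(net))  # 5: only the head is left
# With a frozen backbone, pass only the trainable parameters to the optimizer.
optimizer = torch.optim.Adam((p for p in net.parameters() if p.requires_grad), lr=1e-3)
```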

@yaoppeng

Thank you for your answer.
There was an error in the label normalization. I fixed it, but the loss is still not decreasing.

Here are the shapes and value ranges of the inputs and labels:

for inputs, labels in train_dataloader:
    print("Input shape:", inputs.shape)
    print("Label shape:", labels.shape)
    print("Input max value:", inputs.max())
    print("Label max value:", labels.max())
    print("Input min value:", inputs.min())
    print("Label min value:", labels.min())
    break  
Input shape: torch.Size([8, 3, 256, 256])
Label shape: torch.Size([8, 1, 256, 256])
Input max value: tensor(1., dtype=torch.float64)
Label max value: tensor(1., dtype=torch.float64)
Input min value: tensor(0., dtype=torch.float64)
Label min value: tensor(0., dtype=torch.float64)
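A minimum of 0 and a maximum of 1 do not guarantee a binary mask: a mask resized with bilinear interpolation, or divided by the wrong constant, can still contain values in between. A hypothetical helper like the one below checks the unique values directly and could be run on a few batches from `train_dataloader`.

```python
import torch


def check_binary_mask(labels: torch.Tensor) -> None:
    """Raise ValueError if a label tensor holds values other than 0 and 1."""
    values = torch.unique(labels)
    extra = values[(values != 0) & (values != 1)]
    if extra.numel() > 0:
        raise ValueError(f"non-binary label values, e.g. {extra[:5].tolist()}")


hard = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
check_binary_mask(hard)  # passes silently

soft = torch.tensor([[0.0, 0.5], [1.0, 0.0]])  # e.g. a mask resized with bilinear interpolation
try:
    check_binary_mask(soft)
except ValueError as err:
    print(err)  # non-binary label values, e.g. [0.5]
```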

Here are the results on our custom dataset:

Epoch 1/500: 100%|██████████| 473/473 [01:27<00:00,  5.40it/s]
Epoch [1/500], Training Loss: 0.9878846406936646
Epoch [1/500], Validation Loss: 0.9915483593940735
Epoch 2/500: 100%|██████████| 473/473 [01:25<00:00,  5.55it/s]
Epoch [2/500], Training Loss: 0.9878699779510498
Epoch [2/500], Validation Loss: 0.9842602610588074
Epoch 3/500: 100%|██████████| 473/473 [01:26<00:00,  5.48it/s]
Epoch [3/500], Training Loss: 0.9864965081214905
Epoch [3/500], Validation Loss: 0.9903872609138489
Epoch 4/500: 100%|██████████| 473/473 [01:27<00:00,  5.39it/s]
Epoch [4/500], Training Loss: 0.9860386848449707
Epoch [4/500], Validation Loss: 0.9844949245452881
Epoch 5/500: 100%|██████████| 473/473 [01:27<00:00,  5.39it/s]
Epoch [5/500], Training Loss: 0.9862862229347229
Epoch [5/500], Validation Loss: 0.9839692711830139
Epoch 6/500: 100%|██████████| 473/473 [01:28<00:00,  5.36it/s]
Epoch [6/500], Training Loss: 0.9860971570014954
Epoch [6/500], Validation Loss: 0.98335862159729
Epoch 7/500: 100%|██████████| 473/473 [01:28<00:00,  5.33it/s]
Epoch [7/500], Training Loss: 0.9881670475006104
Epoch [7/500], Validation Loss: 0.9861281514167786
Epoch 8/500: 100%|██████████| 473/473 [01:29<00:00,  5.30it/s]
Epoch [8/500], Training Loss: 0.9872145652770996
Epoch [8/500], Validation Loss: 0.9875427484512329
Epoch 9/500: 100%|██████████| 473/473 [01:28<00:00,  5.36it/s]
Epoch [9/500], Training Loss: 0.987337052822113
Epoch [9/500], Validation Loss: 0.9868549108505249
Epoch 10/500: 100%|██████████| 473/473 [01:25<00:00,  5.51it/s]
Epoch [10/500], Training Loss: 0.9872326850891113
Epoch [10/500], Validation Loss: 0.9850394129753113
Epoch 11/500: 100%|██████████| 473/473 [01:28<00:00,  5.34it/s]
Epoch [11/500], Training Loss: 0.986293375492096
Epoch [11/500], Validation Loss: 0.9838315844535828
Epoch 12/500: 100%|██████████| 473/473 [01:28<00:00,  5.34it/s]
Epoch [12/500], Training Loss: 0.9871325492858887
Epoch [12/500], Validation Loss: 0.9865005016326904
Epoch 13/500: 100%|██████████| 473/473 [01:25<00:00,  5.54it/s]
Epoch [13/500], Training Loss: 0.9944794178009033
Epoch [52/500], Training Loss: 0.9969722032546997

I also ran the demo code you provided, but the loss does not seem to decrease there either.

epoch 0 loss: tensor(0.5357, device='cuda:0')
epoch 1 loss: tensor(0.5233, device='cuda:0')
epoch 2 loss: tensor(0.5742, device='cuda:0')
epoch 3 loss: tensor(0.5955, device='cuda:0')
epoch 4 loss: tensor(0.6120, device='cuda:0')
epoch 5 loss: tensor(0.6107, device='cuda:0')
epoch 6 loss: tensor(0.5600, device='cuda:0')
epoch 7 loss: tensor(0.5793, device='cuda:0')
epoch 8 loss: tensor(0.5772, device='cuda:0')
epoch 9 loss: tensor(0.6173, device='cuda:0')
epoch 10 loss: tensor(0.6185, device='cuda:0')
epoch 11 loss: tensor(0.5604, device='cuda:0')
epoch 12 loss: tensor(0.6200, device='cuda:0')
epoch 13 loss: tensor(0.6161, device='cuda:0')
epoch 14 loss: tensor(0.6220, device='cuda:0')
epoch 15 loss: tensor(0.6470, device='cuda:0')
epoch 16 loss: tensor(0.5584, device='cuda:0')
epoch 17 loss: tensor(0.6066, device='cuda:0')
epoch 18 loss: tensor(0.6703, device='cuda:0')
epoch 19 loss: tensor(0.5985, device='cuda:0')
epoch 20 loss: tensor(0.6904, device='cuda:0')
epoch 21 loss: tensor(0.5614, device='cuda:0')
epoch 22 loss: tensor(0.6613, device='cuda:0')
epoch 23 loss: tensor(0.6489, device='cuda:0')
epoch 24 loss: tensor(0.6517, device='cuda:0')
epoch 25 loss: tensor(0.5074, device='cuda:0')
epoch 26 loss: tensor(0.7205, device='cuda:0')
epoch 27 loss: tensor(0.5944, device='cuda:0')
epoch 28 loss: tensor(0.5240, device='cuda:0')
epoch 29 loss: tensor(0.7012, device='cuda:0')
epoch 30 loss: tensor(0.6493, device='cuda:0')
epoch 31 loss: tensor(0.7371, device='cuda:0')
epoch 32 loss: tensor(0.5982, device='cuda:0')
epoch 33 loss: tensor(0.5792, device='cuda:0')
epoch 34 loss: tensor(0.6704, device='cuda:0')
epoch 35 loss: tensor(0.5537, device='cuda:0')
epoch 36 loss: tensor(0.7473, device='cuda:0')
epoch 37 loss: tensor(0.6773, device='cuda:0')
epoch 38 loss: tensor(0.6783, device='cuda:0')
epoch 39 loss: tensor(0.6772, device='cuda:0')
epoch 40 loss: tensor(0.5064, device='cuda:0')
epoch 41 loss: tensor(0.6455, device='cuda:0')
epoch 42 loss: tensor(0.5213, device='cuda:0')
epoch 43 loss: tensor(0.7634, device='cuda:0')
epoch 44 loss: tensor(0.7001, device='cuda:0')
epoch 45 loss: tensor(0.6983, device='cuda:0')
epoch 46 loss: tensor(0.7683, device='cuda:0')
epoch 47 loss: tensor(0.6623, device='cuda:0')
epoch 48 loss: tensor(0.5896, device='cuda:0')
epoch 49 loss: tensor(0.6309, device='cuda:0')

When I first instantiate the UNetV2 class, the following messages are printed.
Do these messages indicate a fatal error?

./UNet_v2/unet_v2/pvtv2.py:388: UserWarning: Overwriting pvt_v2_b0 in registry with unet_v2.pvtv2.pvt_v2_b0. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
  class pvt_v2_b0(PyramidVisionTransformerImpr):
(skip)
./UNet_v2/unet_v2/pvtv2.py:431: UserWarning: Overwriting pvt_v2_b5 in registry with unet_v2.pvtv2.pvt_v2_b5. This is because the name being registered conflicts with an existing name. Please check if this is not expected.
  class pvt_v2_b5(PyramidVisionTransformerImpr):

Thanks.

Sometimes this is caused by the GPU, CUDA, or PyTorch version. You can try the demo on your CPU first, or reinstall these packages.
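When comparing GPU and CPU runs, it helps to record the exact versions involved. A minimal sketch using only standard PyTorch attributes; `environment_report` is a hypothetical helper name.

```python
import platform

import torch


def environment_report() -> dict:
    """Collect the version details that most often explain GPU/CPU discrepancies."""
    cuda_ok = torch.cuda.is_available()
    return {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,         # None for CPU-only PyTorch builds
        "cudnn": torch.backends.cudnn.version(),  # None when cuDNN is not available
        "cuda_available": cuda_ok,
        "gpu": torch.cuda.get_device_name(0) if cuda_ok else None,
    }


for key, value in environment_report().items():
    print(f"{key}: {value}")

# To try the demo on the CPU, move the model and tensors there instead of calling .cuda(),
# e.g. model.to(device), x.to(device), labels.to(device).
device = torch.device("cpu")
```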

@yaoppeng
First, I upgraded PyTorch (1.11.x -> 2.), which fixed the loss not decreasing on the random-tensor demo; that problem had occurred even on the CPU.

However, on our dataset the loss still does not decrease, even on the CPU.

using pretrained file: ./pretrained/pvt_v2_b2.pth
Epoch 1/5: 100%|██████████| 817/817 [15:58<00:00,  1.17s/it]
Epoch [1/5], Training Loss: 0.9854771494865417
Epoch [1/5], Validation Loss: 0.9855693578720093
Epoch 2/5: 100%|██████████| 817/817 [16:20<00:00,  1.20s/it]
Epoch [2/5], Training Loss: 0.9854830503463745
Epoch [2/5], Validation Loss: 0.9855701923370361
Epoch 3/5: 100%|██████████| 817/817 [16:07<00:00,  1.18s/it]
Epoch [3/5], Training Loss: 0.985472559928894
Epoch [3/5], Validation Loss: 0.9855678677558899
Epoch 4/5: 100%|██████████| 817/817 [16:35<00:00,  1.22s/it]
Epoch [4/5], Training Loss: 0.9854772090911865
Epoch [4/5], Validation Loss: 0.9855728149414062
Epoch 5/5: 100%|██████████| 817/817 [15:58<00:00,  1.17s/it]
Epoch [5/5], Training Loss: 0.9854697585105896
Epoch [5/5], Validation Loss: 0.9855706691741943

We extracted a single-modality subset from our custom multi-modality dataset and compared U-Net v2 with a simple U-Net baseline.

Judging by the results, U-Net v2 does not clearly outperform even this simple U-Net baseline.


Although not attached here, U-Net v2 also did not outperform the baseline on other, easier datasets.

Do you have any suggestions for improving the performance of UNetV2?

Here is the code of the simple U-Net baseline:

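import torch
import torch.nn as nn


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        # Conv2d -> BatchNorm2d -> ReLU block used throughout the network
        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):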
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers)

            return cbr

        self.enc1_1 = CBR2d(in_channels=3, out_channels=64)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)

        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)

        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)

        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)

        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)

        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)

        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512)
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)

        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)

        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)

        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)

        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)
    
    def forward(self, x):
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)

        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)

        dec5_1 = self.dec5_1(enc5_1)

        unpool4 = self.unpool4(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)

        return x

I am sorry to hear that. I am not familiar with your experimental details.

Here is one of my experiments:

[image: experiments]

You should monitor the validation DSC or accuracy, whichever you care about more; a lower loss does not always give a higher test score. I would also recommend adjusting the channels, trying other CNN backbones (probably the same backbone), choosing a proper learning rate, using learning-rate decay, training for more epochs, etc.
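As a rough illustration of the last two suggestions, the sketch below computes a hard DSC on thresholded predictions (a metric, separate from the training loss) and decays the learning rate with `CosineAnnealingLR`, one of several schedulers in `torch.optim.lr_scheduler`. The 1x1 convolution, the random data, the BCE training loss, and the `dice_score` helper are placeholders chosen for the example, not part of U-Net v2.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def dice_score(probs, targets, threshold=0.5, eps=1e-7):
    """Hard DSC on thresholded probabilities, averaged over the batch."""
    preds = (probs > threshold).float().flatten(1)  # (B, H*W)
    targets = targets.float().flatten(1)
    inter = (preds * targets).sum(dim=1)
    denom = preds.sum(dim=1) + targets.sum(dim=1)
    return ((2 * inter + eps) / (denom + eps)).mean().item()


# Toy setup: a 1x1 conv stands in for UNetV2, random tensors stand in for the data.
torch.manual_seed(0)
model = torch.nn.Conv2d(3, 1, kernel_size=1)
x = torch.rand(2, 3, 8, 8)
y = (torch.rand(2, 1, 8, 8) > 0.5).float()

num_epochs = 5
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # starting LR from the question; tune it
# Cosine decay from the starting LR towards 0 over num_epochs epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    loss = F.binary_cross_entropy_with_logits(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()  # decay once per epoch, after the optimizer step

    model.eval()
    with torch.no_grad():
        val_dsc = dice_score(torch.sigmoid(model(x)), y)  # use a real validation set here
    print(f"epoch {epoch}: loss {loss.item():.4f}, DSC {val_dsc:.4f}, lr {scheduler.get_last_lr()[0]:.2e}")
```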

Another point is that U-Net v2 uses the 4× downsampled prediction as its output, because on my dataset I found this barely affects the result while reducing computation:

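import torch.nn.functional as F
# upsample each 4x-downsampled deep-supervision output back to the input resolution
for i, o in enumerate(seg_outs):
    seg_outs[i] = F.interpolate(o, scale_factor=4, mode='bilinear')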

On your own dataset, this may be different. Your U-Net decodes all the way back to the full input resolution, so you may need to adjust the U-Net v2 code to use all resolutions as well.
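How much a 4× output costs depends on object size. The toy experiment below uses average pooling followed by bilinear upsampling as a crude stand-in for a prediction made at 1/4 resolution (an assumption, not the U-Net v2 architecture): a 1-pixel-wide line disappears entirely, while a large square is almost unaffected. Datasets dominated by thin structures are therefore the ones most likely to benefit from full-resolution outputs.

```python
import torch
import torch.nn.functional as F


def dice(a, b, eps=1e-7):
    """Dice coefficient between two binary masks."""
    inter = (a * b).sum()
    return ((2 * inter + eps) / (a.sum() + b.sum() + eps)).item()


def roundtrip_4x(mask):
    """Downsample 4x (average pooling), upsample back bilinearly, threshold at 0.5."""
    low = F.avg_pool2d(mask, kernel_size=4)  # (B, 1, H/4, W/4)
    up = F.interpolate(low, scale_factor=4, mode="bilinear", align_corners=False)
    return (up > 0.5).float()


thin = torch.zeros(1, 1, 256, 256)
thin[..., 100, :] = 1.0          # a 1-pixel-wide line, e.g. a vessel or a crack
blob = torch.zeros(1, 1, 256, 256)
blob[..., 64:128, 64:128] = 1.0  # a large, compact object

print(f"thin line DSC: {dice(roundtrip_4x(thin), thin):.3f}")   # close to 0: the line vanishes
print(f"large blob DSC: {dice(roundtrip_4x(blob), blob):.3f}")  # close to 1
```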