IcarusWizard/MAE

Unable to reconstruct a distinguishable image

liuziqi opened this issue · 5 comments

Sorry to bother you. I tried to reconstruct single-channel CT images (512x512), but the MSE loss stopped decreasing and stayed at about 0.03 by epoch 1000, and the reconstructed image quality is very poor. What could be causing this? I'm not familiar with Transformers, so I only tried adjusting the patch size (2->16), the embedding dimension (192->768), and the number of encoder/decoder heads (12).

Hi, @liuziqi. I am not familiar with CT images. But the idea of MAE should be generalizable to other images.

Could you provide more details about your data (e.g. how many images are in the training set) and your training parameters (e.g. batch size, warmup epochs), and tell us which part of the code you modified, so that we can try to debug what is going wrong?

The training set has 2200 images of size 1x512x512, with pixel values in [0, 1]. The other parameters are set as follows:

  parser.add_argument('--seed', type=int, default=42)
  parser.add_argument('--batch_size', type=int, default=32)
  parser.add_argument('--base_learning_rate', type=float, default=1.5e-5)
  parser.add_argument('--weight_decay', type=float, default=0.05)
  parser.add_argument('--total_epoch', type=int, default=1001)
  parser.add_argument('--warmup_epoch', type=int, default=40)
  parser.add_argument('--model_path', type=str, default='mae.pt')
  parser.add_argument('--gpu', type=str, default='1')
  # model
  parser.add_argument('--image_size', type=int, default=512)
  parser.add_argument('--in_channel', type=int, default=1)
  parser.add_argument('--patch_size', type=int, default=16)
  parser.add_argument('--emb_dim', type=int, default=768)
  parser.add_argument('--encoder_layer', type=int, default=12)
  parser.add_argument('--encoder_head', type=int, default=12)
  parser.add_argument('--decoder_layer', type=int, default=8)
  parser.add_argument('--decoder_head', type=int, default=16)
  parser.add_argument('--mask_ratio', type=float, default=0.75)

Loss curve:

[image: training loss curve]

Reconstructed image:

[image: reconstructed CT image]

I added per-patch standardization in the Encoder's forward, which makes the loss smaller at the start (roughly 0.8 without it vs. 0.26 with it), but it has no effect on the final convergence or the reconstruction results.

def forward(self, img):
    patches = self.patchify(img)
    patches = rearrange(patches, 'b c h w -> (h w) b c')    # h*w is the number of patches, c is the embedding dimension
    # added: standardize each token embedding
    patches = (patches - patches.mean(dim=-1, keepdim=True)) / (patches.var(dim=-1, unbiased=True, keepdim=True).sqrt() + 1e-6)
    patches = patches + self.pos_embedding

    patches, forward_indexes, backward_indexes = self.shuffle(patches)

    patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
    patches = rearrange(patches, 't b c -> b t c')
    features = self.transformer(patches)
    features = self.layer_norm(features)
    features = rearrange(features, 'b t c -> t b c')
    return features, backward_indexes
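
For reference, a minimal standalone check of that standardization step (toy tensor sizes, not the real patch embeddings) confirms that each token ends up with roughly zero mean and unit variance along the embedding dimension:

import torch

patches = torch.randn(256, 4, 192)   # (tokens, batch, emb_dim); toy sizes for illustration
patches = (patches - patches.mean(dim=-1, keepdim=True)) / (patches.var(dim=-1, unbiased=True, keepdim=True).sqrt() + 1e-6)
print(patches.mean(dim=-1).abs().max())   # ~0
print(patches.std(dim=-1).mean())         # ~1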

Your result is very interesting.

I think the main issue here is that the CT images are much more complex than the images in ImageNet and CIFAR. I think you should set the mask ratio to a smaller value, maybe 0.25 or even smaller.

Besides the main issue, I think the batch size and learning rate are quite low. Keep the base learning rate at the default value. If you chose a small batch size because of memory limits, set batch_size to the value you actually want to train with (maybe 256 is ok) and max_device_batch_size to the largest batch your device can handle; the code will then accumulate gradients over multiple smaller forward/backward passes before each update. Since your dataset is small, it is also better to train for more epochs, say 10000 epochs with 1000 warmup epochs.
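
As a rough illustration of the accumulation idea (a toy model and toy data, not the repo's actual training loop): the effective batch of size batch_size is split into chunks of max_device_batch_size, the scaled gradients are summed, and a single optimizer step is taken at the end.

# Minimal, self-contained sketch of gradient accumulation (toy stand-ins for the MAE model and data).
import torch

torch.manual_seed(0)
model = torch.nn.Linear(16, 16)                 # toy stand-in for the MAE
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4, weight_decay=0.05)

batch_size = 48                                 # effective batch size per optimizer update
max_device_batch_size = 12                      # what fits in GPU memory at once
accum_steps = batch_size // max_device_batch_size

data = torch.randn(batch_size, 16)              # one "logical" batch of data
optimizer.zero_grad()
for i in range(accum_steps):
    chunk = data[i * max_device_batch_size:(i + 1) * max_device_batch_size]
    loss = torch.nn.functional.mse_loss(model(chunk), chunk)   # toy reconstruction loss
    (loss / accum_steps).backward()             # scale so the accumulated gradient matches a full-batch gradient
optimizer.step()                                # single update with the accumulated gradient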

One more thing: I think the network you are testing with is a bit heavy. For debugging purposes, maybe ViT-S or ViT-B with a patch size of 32 is more suitable. You can move to a larger model once you see progress; a rough configuration sketch follows below.
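
For reference, a lighter ViT-S/32-style configuration could look roughly like the sketch below. The encoder values follow the standard ViT-S (384-dim embedding, 12 layers, 6 heads); the decoder values are only illustrative, and the argument names should be adapted to your own script.

# Hedged sketch of a lighter ViT-S/32-style configuration for debugging.
parser.add_argument('--image_size', type=int, default=512)
parser.add_argument('--in_channel', type=int, default=1)
parser.add_argument('--patch_size', type=int, default=32)      # 512/32 = 16, i.e. 256 tokens per image
parser.add_argument('--emb_dim', type=int, default=384)        # ViT-S embedding dimension
parser.add_argument('--encoder_layer', type=int, default=12)
parser.add_argument('--encoder_head', type=int, default=6)     # 384 / 64 = 6 heads
parser.add_argument('--decoder_layer', type=int, default=4)    # lightweight decoder (illustrative)
parser.add_argument('--decoder_head', type=int, default=6)
parser.add_argument('--mask_ratio', type=float, default=0.25)  # smaller mask ratio as suggested above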

Hope these help.

Thanks for your help, it seems to work. The loss is still decreasing after 800 epochs.

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--batch_size', type=int, default=240)
    parser.add_argument('--max_device_batch_size', type=int, default=12)
    parser.add_argument('--base_learning_rate', type=float, default=1.5e-4)
    parser.add_argument('--weight_decay', type=float, default=0.05)
    parser.add_argument('--total_epoch', type=int, default=2001)
    parser.add_argument('--warmup_epoch', type=int, default=200)
    parser.add_argument('--model_path', type=str, default='mae.pt')
    parser.add_argument('--gpu', type=str, default='1')

    # model
    parser.add_argument('--image_size', type=int, default=512)
    parser.add_argument('--in_channel', type=int, default=1)
    parser.add_argument('--patch_size', type=int, default=16)
    parser.add_argument('--encoder_emb_dim', type=int, default=192)
    parser.add_argument('--encoder_layer', type=int, default=12)
    parser.add_argument('--encoder_head', type=int, default=3)
    parser.add_argument('--decoder_emb_dim', type=int, default=512)
    parser.add_argument('--decoder_layer', type=int, default=8)
    parser.add_argument('--decoder_head', type=int, default=16)
    parser.add_argument('--mask_ratio', type=float, default=0.75)

[image: training loss curve]
[image: reconstructed CT image]

Great to see it works!

It turns out that the larger batch size and learning rate play the more important role here. It is a bit surprising that mask_ratio=0.75 still works for such a complex structure. Amazing!