FirasGit/medicaldiffusion

Loss plateau and potential mode collapse in VQ-GAN training on the MRNet dataset

WhenMelancholy opened this issue · 3 comments

When training the VQ-GAN on the MRNet dataset, the loss stopped decreasing after a certain point and then started to increase. We trained the model with the following parameters:

CUDA_VISIBLE_DEVICES=3 PL_TORCH_DISTRIBUTED_BACKEND=gloo PYTHONPATH=.:$PYTHONPATH python \
    train/train_vqgan.py \
    dataset=mrnet \
    dataset.root_dir="~/medicaldiffusion/data/MRNet-v1.0/" \
    model=vq_gan_3d \
    model.gpus=1 \
    model.default_root_dir="~/medicaldiffusion/when/checkpoints/vq_gan2" \
    model.default_root_dir_postfix="mrnet" \
    model.precision=32 \
    model.embedding_dim=8 \
    model.n_hiddens=16 \
    model.downsample=[4,4,4] \
    model.num_workers=32 \
    model.gradient_clip_val=1.0 \
    model.lr=3e-4 \
    model.discriminator_iter_start=10000 \
    model.perceptual_weight=4 \
    model.image_gan_weight=1 \
    model.video_gan_weight=1 \
    model.gan_feat_weight=4 \
    model.batch_size=2 \
    model.n_codes=16384 \
    model.accumulate_grad_batches=1 

An excerpt of the training log showing the abnormal loss behavior is as follows:

......
Epoch 0:   1%|          | 4/565 [00:04<10:59,  1.18s/it, loss=2.48, v_num=0, train/perceptual_loss_step=2.650, train/recon_loss_step=2.390, train/aeloss_step=0.000, train/commitment_loss_step=0.00397, train/perplexity_step=8.32e+3, train/discloss_step=0.000]
......
Epoch 3:  44%|████▎     | 247/565 [02:13<02:52,  1.84it/s, loss=1.11, v_num=0, train/perceptual_loss_step=1.790, train/recon_loss_step=0.387, train/aeloss_step=0.000, train/commitment_loss_step=0.00353, train/perplexity_step=4.45e+3, train/discloss_step=0.000, val/recon_loss=0.478, val/perceptual_loss=1.600, val/perplexity=5.95e+3, val/commitment_loss=0.00394, train/perceptual_loss_epoch=1.690, train/recon_loss_epoch=0.481, train/aeloss_epoch=0.000, train/commitment_loss_epoch=0.00367, train/perplexity_epoch=5.88e+3, train/discloss_epoch=0.000]
......
Epoch 11:   7%|▋         | 41/565 [00:24<05:17,  1.65it/s, loss=4.18, v_num=0, train/perceptual_loss_step=1.160, train/recon_loss_step=0.319, train/aeloss_step=0.326, train/commitment_loss_step=0.00809, train/perplexity_step=5.84e+3, train/discloss_step=1.930, val/recon_loss=0.378, val/perceptual_loss=1.360, val/perplexity=6.32e+3, val/commitment_loss=0.00797, train/perceptual_loss_epoch=1.340, train/recon_loss_epoch=0.382, train/aeloss_epoch=0.000, train/commitment_loss_epoch=0.00779, train/perplexity_epoch=6.86e+3, train/discloss_epoch=0.000]
......
Epoch 14:  49%|████▉     | 278/565 [02:30<02:35,  1.85it/s, loss=5.4, v_num=0, train/perceptual_loss_step=1.440, train/recon_loss_step=0.397, train/aeloss_step=-.0933, train/commitment_loss_step=0.0129, train/perplexity_step=6.43e+3, train/discloss_step=1.940, val/recon_loss=0.422, val/perceptual_loss=1.600, val/perplexity=6.38e+3, val/commitment_loss=0.0126, train/perceptual_loss_epoch=1.640, train/recon_loss_epoch=0.429, train/aeloss_epoch=0.679, train/commitment_loss_epoch=0.012, train/perplexity_epoch=6.49e+3, train/discloss_epoch=1.660]
......

Is this caused by mode collapse in the GAN, or is it due to the training configuration? Are there any good ways to fix it? I would greatly appreciate any suggestions.

@WhenMelancholy What worked for me was starting the discriminator only after 50,000+ steps and also decreasing the GAN loss weights by a factor of 4-5. This will cause the discriminator to train a lot more slowly (so sample quality will drop quickly for a couple of thousand iterations after the discriminator starts training), but it should recover after about 5-6k iterations and lead to a further improvement in sample quality. Precision of 32 was also needed for me in all cases, but you already have that in your config.
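
For reference, here is a minimal sketch of how that suggestion might map onto the command above, assuming "the GAN loss weights" refers to model.image_gan_weight, model.video_gan_weight, and model.gan_feat_weight; the exact start step and scaling factor will likely need tuning:

# Sketch only: same command as before, with the discriminator start pushed back
# and the GAN-related weights scaled down by roughly 4x, per the suggestion above.
CUDA_VISIBLE_DEVICES=3 PL_TORCH_DISTRIBUTED_BACKEND=gloo PYTHONPATH=.:$PYTHONPATH python \
    train/train_vqgan.py \
    dataset=mrnet \
    dataset.root_dir="~/medicaldiffusion/data/MRNet-v1.0/" \
    model=vq_gan_3d \
    model.gpus=1 \
    model.default_root_dir="~/medicaldiffusion/when/checkpoints/vq_gan2" \
    model.default_root_dir_postfix="mrnet" \
    model.precision=32 \
    model.embedding_dim=8 \
    model.n_hiddens=16 \
    model.downsample=[4,4,4] \
    model.num_workers=32 \
    model.gradient_clip_val=1.0 \
    model.lr=3e-4 \
    model.discriminator_iter_start=50000 \
    model.perceptual_weight=4 \
    model.image_gan_weight=0.25 \
    model.video_gan_weight=0.25 \
    model.gan_feat_weight=1 \
    model.batch_size=2 \
    model.n_codes=16384 \
    model.accumulate_grad_batches=1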

Hello, could you tell me what the GAN loss weights include?

Sorry, but I can no longer access the environment where I ran into this problem >_< I will close the issue.