lucidrains/magvit2-pytorch

Discriminator loss converges to zero early in training

jpfeil opened this issue · 9 comments

jpfeil commented

I compared v0.1.26 without the GAN and v0.1.36 with the GAN on the Fashion-MNIST data, and I got better reconstructions without the GAN:
https://api.wandb.ai/links/pfeiljx/f7wdueh0

Do you have any suggestions for improving training?

I'm using a cosine scheduler for both the model and the discriminator. Should I use a different learning-rate schedule for the discriminator?

I saw similar discriminator collapse with VQ-GAN, and I read that delaying discriminator training until the generator is reasonably well optimized may help. Maybe delay the discriminator until a certain reconstruction loss is reached? (A rough sketch of this idea is below.)
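For concreteness, one way to express that delay outside the library is a simple gate on the adversarial term. The names below (delay_steps, recon_threshold) and the loop shape are hypothetical, not the magvit2-pytorch API:

# Hypothetical sketch (not the magvit2-pytorch API): gate the adversarial
# term on a step budget and a reconstruction-loss threshold, so the
# discriminator only starts once the generator reconstructs reasonably well.

def adversarial_weight(step, recon_loss, delay_steps = 1_000, recon_threshold = 0.05, weight = 1.0):
    if step < delay_steps or recon_loss > recon_threshold:
        return 0.0  # GAN term switched off: generator trains on reconstruction alone
    return weight   # GAN term live from here on

# inside the training loop:
# total_loss = recon_loss + adversarial_weight(step, recon_loss.item()) * gan_loss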

After googling some strategies, I saw the unrolled GAN, where the generator is optimized against a discriminator that has been unrolled a few update steps ahead. I'm not sure how difficult it would be to implement a similar strategy here (rough sketch below).
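For what it's worth, here is a minimal sketch of the stop-gradient approximation of unrolling from Metz et al. (2016), assuming the generator and discriminator are plain nn.Modules and the discriminator outputs raw logits. All names are illustrative, and the full method would also backpropagate through the unrolled updates, which this sketch omits:

import copy
import torch
import torch.nn.functional as F

def unrolled_generator_loss(generator, discriminator, real, z, k = 3, lr = 1e-4):
    # Unroll k discriminator updates on a throwaway copy (stop-gradient
    # approximation: we do not backprop through the copy's updates,
    # we only score the generator against the looked-ahead copy).
    d_copy = copy.deepcopy(discriminator)
    opt = torch.optim.SGD(d_copy.parameters(), lr = lr)
    fake = generator(z)
    for _ in range(k):
        opt.zero_grad()
        d_loss = F.softplus(-d_copy(real)).mean() + F.softplus(d_copy(fake.detach())).mean()
        d_loss.backward()
        opt.step()
    # non-saturating generator loss against the looked-ahead discriminator
    return F.softplus(-d_copy(fake)).mean()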

I'm just brainstorming, so feel free to address or ignore any of these comments.

import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    channels = 1,                # grayscale input
    use_gan = True,
    use_fsq = False,             # default LFQ quantizer rather than FSQ
    codebook_size = 2 ** 13,
    init_dim = 64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',     # 'videos' or 'images'; prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 5,        # effective batch size of 50
    num_train_steps = 5_000,
    num_frames = 1,              # single frames, since the dataset is images
    max_grad_norm = 1.0,
    learning_rate = 2e-5,
    accelerate_kwargs = {"split_batches": True, "mixed_precision": "fp16"},
    random_split_seed = 85,
    optimizer_kwargs = {"betas": (0.9, 0.99)},  # from the paper
    ema_kwargs = {},
    use_wandb_tracking = True,
    checkpoints_folder = f'./runs/{RUNTIME}/checkpoints',
    results_folder = f'./runs/{RUNTIME}/results',
)


with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.36 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()

lucidrains commented

@jpfeil can you screenshot the paper section where they propose delaying the discriminator training? (and link the paper too)

lucidrains commented

@jpfeil do you have adversarial_loss_weight greater than 0? also try another run where your perceptual_loss_weight is 0.1
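For example, raising the adversarial weight and lowering the perceptual weight would look something like this (assuming both are accepted as VideoTokenizer keyword arguments in the installed version; check the signature):

tokenizer = VideoTokenizer(
    image_size = 32,
    channels = 1,
    use_gan = True,
    adversarial_loss_weight = 1.,   # must be > 0 for the GAN term to contribute
    perceptual_loss_weight = 0.1,   # the value suggested above
    # ... remaining settings as in the original script
)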

lucidrains commented

@jpfeil welp.. whatever Robin and Patrick do goes; they are the best in the world.

let me add that

lucidrains commented

@jpfeil ok, added that same functionality here. try removing the learning rate schedule in your next run too; you shouldn't need it for something this easy

lucidrains commented

@jpfeil you don't happen to have relatives in Massachusetts, do you?

jpfeil commented

@lucidrains Nice. Let me try it out again. No, I don't have any relatives in Massachusetts. Did you meet someone with the last name Pfeil?

lucidrains commented

yea, I knew someone back in high school with the Pfeil family name. Tragedy struck and they moved away though. You are the second Pfeil I've met!

jpfeil commented

That's amazing. It's not a common name. Sorry to hear about your friend.