Discriminator loss converges to zero early in training
jpfeil opened this issue · 9 comments
I compared v0.1.26 without the GAN and v0.1.36 with the GAN on the Fashion-MNIST data, and I was able to get better reconstructions without the GAN:
https://api.wandb.ai/links/pfeiljx/f7wdueh0
Do you have any suggestions for improving training?
I'm using a cosine scheduler for the model and discriminator. Should I use a different learning rate schedule for the discriminator?
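For concreteness, this is roughly what I mean, written against plain torch optimizers/schedulers rather than the trainer's internals (the parameter groups and values below are placeholders, not my actual setup):

```python
import torch

# placeholder parameter groups; in practice these would be the tokenizer and discriminator params
gen_params = [torch.nn.Parameter(torch.zeros(1))]
disc_params = [torch.nn.Parameter(torch.zeros(1))]

gen_opt = torch.optim.Adam(gen_params, lr = 2e-5, betas = (0.9, 0.99))
disc_opt = torch.optim.Adam(disc_params, lr = 2e-5, betas = (0.9, 0.99))

num_train_steps = 5_000

# what I have now: cosine decay on both optimizers
gen_sched = torch.optim.lr_scheduler.CosineAnnealingLR(gen_opt, T_max = num_train_steps)

# what I'm considering: keep the discriminator at a flat learning rate instead
disc_sched = torch.optim.lr_scheduler.LambdaLR(disc_opt, lr_lambda = lambda step: 1.0)
```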
I saw similar discriminator collapse with the VQ-GAN, and I read that delaying the discriminator until the generator model is optimized may help. Maybe delaying the discriminator until a certain reconstruction loss is achieved?
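Concretely, I was picturing something like this inside the training step (just a sketch; `recon_loss`, `adv_loss`, and the threshold are placeholders, not the actual trainer code):

```python
RECON_THRESHOLD = 0.05  # hypothetical value, would need tuning per dataset

def combined_generator_loss(recon_loss, adv_loss, adv_weight):
    # keep the adversarial term switched off until reconstruction is good enough,
    # gating on the loss value rather than a fixed step count
    if recon_loss.item() > RECON_THRESHOLD:
        adv_weight = 0.0
    return recon_loss + adv_weight * adv_loss
```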
After googling some strategies, I saw the unrolled GAN where the generator stays a few steps ahead of the discriminator. I'm not sure how difficult it would be to implement a similar strategy here.
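For reference, the rough shape of that idea, heavily simplified (the generator is scored against a discriminator copy that has been trained a few extra steps; the original unrolled-GAN paper also backprops through those extra updates, which this sketch skips, and all names here are placeholders):

```python
import copy
import torch
import torch.nn.functional as F

def unrolled_generator_loss(generator, discriminator, real, k = 5, disc_lr = 2e-5):
    # 1. throwaway copy of the discriminator with its own optimizer
    disc_copy = copy.deepcopy(discriminator)
    disc_opt = torch.optim.Adam(disc_copy.parameters(), lr = disc_lr)

    # 2. unroll: let the copy train for k steps against the current generator
    for _ in range(k):
        fake = generator(real).detach()
        d_loss = F.softplus(disc_copy(fake)).mean() + F.softplus(-disc_copy(real)).mean()
        disc_opt.zero_grad()
        d_loss.backward()
        disc_opt.step()

    # 3. the generator loss is taken against the unrolled copy, so the generator
    #    effectively stays a few steps "ahead" of the real discriminator
    fake = generator(real)
    return F.softplus(-disc_copy(fake)).mean()
```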
I'm just brainstorming, so feel free to address or ignore any of these comments.
import torch
from datetime import datetime

from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    channels = 1,
    use_gan = True,
    use_fsq = False,
    codebook_size = 2 ** 13,
    init_dim = 64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',  # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames = 1,
    max_grad_norm = 1.0,
    learning_rate = 2e-5,
    accelerate_kwargs = {"split_batches": True, "mixed_precision": "fp16"},
    random_split_seed = 85,
    optimizer_kwargs = {"betas": (0.9, 0.99)},  # From the paper
    ema_kwargs = {},
    use_wandb_tracking = True,
    checkpoints_folder = f'./runs/{RUNTIME}/checkpoints',
    results_folder = f'./runs/{RUNTIME}/results',
)

with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()
@jpfeil can you screenshot the paper section where they propose delaying the discriminator training? (and link the paper too)
@jpfeil do you have adversarial_loss_weight greater than 0.? also try another run where your perceptual_loss_weight is 0.1
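e.g. roughly like so on the tokenizer side (sketch only, values are just starting points and worth double-checking against your version):

```python
tokenizer = VideoTokenizer(
    image_size = 32,
    channels = 1,
    use_gan = True,
    use_fsq = False,
    codebook_size = 2 ** 13,
    init_dim = 64,
    adversarial_loss_weight = 1.,   # needs to be > 0 for the GAN term to do anything
    perceptual_loss_weight = 0.1,   # try lowering this for another run
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)
```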
Thanks @lucidrains. I'll try again with those parameters. I saw it in the taming implementation here:
https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/losses/vqperceptual.py#L51
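The relevant bit there is (roughly) a step-gated weight that keeps the discriminator factor at zero until a start iteration, something like:

```python
def adopt_weight(weight, global_step, threshold = 0, value = 0.):
    # before `threshold` steps, replace the weight with `value` (usually 0),
    # which is how taming delays the discriminator
    if global_step < threshold:
        weight = value
    return weight

# used roughly as: disc_factor = adopt_weight(disc_factor, global_step, threshold = discriminator_iter_start)
```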
@jpfeil welp.. whatever Robin and Patrick do goes; they are the best in the world.
let me add that
@jpfeil you don't happen to have relatives in Massachusetts, do you?
@lucidrains Nice. Let me try it out again. No, I don't have any relatives in Massachusetts. Did you meet someone with the last name Pfeil?
yea, I knew someone back in high school with the Pfeil family name. Tragedy struck and they moved away though. You are the second Pfeil I've met!
That's amazing. It's not a common name. Sorry to hear about your friend.