lucidrains/magvit2-pytorch

Reconstruction image is always a solid color

jpfeil opened this issue · 20 comments

jpfeil commented

Hello,

I've been training this on ImageNet, but I'm concerned I'm doing something wrong because the reconstructions are always a solid color. I haven't trained it very long (~1500 steps, batch size 10), but I wanted to check whether this is expected.

1300 steps: [reconstruction image]

1200 steps: [reconstruction image]
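To tell a fully collapsed reconstruction (a single flat color) apart from one that is merely blurry, it can help to measure the per-image spatial standard deviation instead of relying only on the sampled images. A minimal NumPy sketch, assuming reconstructions are available as a float array of shape (batch, channels, height, width); the helper `is_solid_color` and its 1e-3 threshold are hypothetical, not part of magvit2-pytorch:

```python
import numpy as np

def is_solid_color(recon: np.ndarray, threshold: float = 1e-3) -> np.ndarray:
    """Return one boolean per image: True if every channel is (nearly) constant.

    Assumes recon has shape (batch, channels, height, width).
    The threshold is an arbitrary illustrative value.
    """
    # standard deviation over the spatial dims, per image and per channel
    spatial_std = recon.std(axis=(2, 3))
    # an image counts as "solid" only if all of its channels are flat
    return (spatial_std < threshold).all(axis=1)

# usage: one flat gray image and one noisy image
rng = np.random.default_rng(0)
flat = np.full((1, 3, 8, 8), 0.5, dtype=np.float32)
noisy = rng.random((1, 3, 8, 8), dtype=np.float32)
print(is_solid_color(np.concatenate([flat, noisy])))
```

Logging the fraction of solid-color reconstructions per step would show whether the collapse is permanent or recovers with more training.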

from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

tokenizer = VideoTokenizer(
    image_size = 256,
    codebook_size=1_024,
    use_gan=True,
    use_fsq=True,
    init_dim=128, 
    adversarial_loss_weight=0.1, # From the paper
    perceptual_loss_weight=0.1, # From the paper
    grad_penalty_loss_weight=10.0,
    lfq_entropy_loss_weight=0.3, # From the paper
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/imagenet/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 8,
    num_train_steps = 1_000_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=1e-4, # From the paper
    accelerate_kwargs={"split_batches": True, "mixed_precision": 'fp16'},
    random_split_seed=171,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={}
)

trainer.train()

@jpfeil could you retry with fp32 and train until 5000 steps? Also, grad accum of 4-6 is sufficient (32-64 effective batch size).
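The "effective batch size" here is the number of samples that contribute to each optimizer step, i.e. batch_size × grad_accum_every. With the config above that is 10 × 8 = 80, above the suggested 32-64 range. A small sketch of the arithmetic, assuming Hugging Face Accelerate's `split_batches` semantics (enabled in the config above), where `batch_size` is the global batch split across processes; `effective_batch_size` is a hypothetical helper, not a trainer API:

```python
def effective_batch_size(batch_size: int, grad_accum_every: int,
                         num_processes: int = 1, split_batches: bool = True) -> int:
    """Samples contributing to each optimizer step.

    Assumption: with split_batches=True, batch_size is the global batch that
    Accelerate splits across processes; otherwise each process loads
    batch_size samples on its own.
    """
    per_forward = batch_size if split_batches else batch_size * num_processes
    return per_forward * grad_accum_every

print(effective_batch_size(10, 8))  # original config
print(effective_batch_size(10, 4))  # lower end of the suggestion
print(effective_batch_size(10, 6))  # upper end of the suggestion
```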

@jpfeil also share your training curve; try out wandb's report feature for easy sharing

jpfeil commented

Thanks @lucidrains I'll let you know when the wandb report is ready.

jpfeil commented

@lucidrains This was run on 0.1.24, so I'm going to pull the latest version and retry. The loss was slowly improving, but around step 1000 it became NaN. The only change I've made is adding a cosine schedule with warmup. I'm also still using bf16, so I'll change that in the next run.
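For reference, a "cosine schedule with warmup" typically ramps the learning rate linearly from near zero up to the base value, then decays it along a half cosine. A minimal sketch of the multiplier, assuming linear warmup and decay to `min_factor`; the function name and the 1000-step warmup in the usage are hypothetical, since the warmup length used in this run isn't stated. One thing worth checking is whether a NaN coincides with the end of warmup, where the learning rate peaks.

```python
import math

def warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int,
                         min_factor: float = 0.0) -> float:
    """LR multiplier: linear warmup up to 1.0, then cosine decay to min_factor."""
    if step < warmup_steps:
        # linear ramp; reaches exactly 1.0 on the last warmup step
        return (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_factor after total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_factor + (1.0 - min_factor) * cosine

# base LR from the config above, with a hypothetical 1000-step warmup
base_lr = 1e-4
for step in (0, 499, 999, 1000, 50_000, 100_000):
    print(step, base_lr * warmup_cosine_factor(step, 1000, 100_000))
```

In plain PyTorch the same function can drive `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: warmup_cosine_factor(s, 1000, 100_000))`, stepping the scheduler once per optimizer step.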

https://api.wandb.ai/links/pfeiljx/p2x7x2x2

jpfeil commented

Hi @lucidrains

I ran it using fp32 and trained for 5000 steps, but I did not see any improvement.

https://api.wandb.ai/links/pfeiljx/8kqeyypi

Let me know if you have any suggestions.

jpfeil commented

@lucidrains I trained on Fashion-MNIST last night, and the model was able to converge:

https://api.wandb.ai/links/pfeiljx/udspvdgu

import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d %H:%M:%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    codebook_size=1_024,
    use_gan=True,
    use_fsq=True,
    init_dim=128, # From the paper
    adversarial_loss_weight=0.1, # From the paper
    perceptual_loss_weight=0.1, # From the paper
    grad_penalty_loss_weight=10.0,
    lfq_entropy_loss_weight=0.3, # From the paper
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 5,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)


with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 {RUNTIME}'):
    trainer.train()

@jpfeil @jacobpfeil i think this repository should support pretraining with 2D conv layers, plus a way to convert them to 3D for video. but let me meditate on the simplest way to achieve this
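One common technique for reusing 2D-pretrained conv weights in a 3D network is kernel "inflation" (popularized by I3D): add a time axis to the 2D kernel. A minimal NumPy sketch, assuming PyTorch-style weight layouts (out, in, kh, kw) → (out, in, kt, kh, kw); `inflate_conv2d_weight` is a hypothetical helper, not an API of magvit2-pytorch, and not necessarily the approach the repo ended up using. The "average" mode spreads the kernel evenly over time; the "last" mode places it only on the newest frame, which suits a causal 3D conv whose last kernel slice aligns with the current frame.

```python
import numpy as np

def inflate_conv2d_weight(w2d: np.ndarray, time_kernel: int, mode: str = "average") -> np.ndarray:
    """Inflate a 2D conv kernel (out, in, kh, kw) into a 3D kernel (out, in, kt, kh, kw).

    mode="average": copy the kernel to every time slice and divide by kt, so a clip
        of identical frames produces the same response as the original 2D conv.
    mode="last": put the kernel in the last time slice only (zeros elsewhere), so a
        causal 3D conv initially acts like a per-frame 2D conv.
    """
    out_c, in_c, kh, kw = w2d.shape
    if mode == "average":
        w3d = np.repeat(w2d[:, :, None], time_kernel, axis=2) / time_kernel
    elif mode == "last":
        w3d = np.zeros((out_c, in_c, time_kernel, kh, kw), dtype=w2d.dtype)
        w3d[:, :, -1] = w2d
    else:
        raise ValueError(f"unknown mode: {mode}")
    return w3d

# usage: a static clip (same frame repeated) gives the same output as the 2D kernel
rng = np.random.default_rng(0)
w2d = rng.standard_normal((4, 3, 3, 3))        # (out, in, kh, kw)
w3d = inflate_conv2d_weight(w2d, time_kernel=3)
patch = rng.standard_normal((3, 3, 3))         # (in, kh, kw): one spatial window
clip = np.stack([patch] * 3, axis=1)           # (in, kt, kh, kw): that window over 3 frames
out2d = np.einsum("oihw,ihw->o", w2d, patch)
out3d = np.einsum("oithw,ithw->o", w3d, clip)
print(np.allclose(out2d, out3d))
```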

jpfeil commented

Thanks @lucidrains. Let me know if I can help run some tests. I have access to a few A100 GPUs.

@jpfeil sounds good

let me think about this for a few days or the code will come out wrong

measure twice, cut once kinda thing

jpfeil commented

@lucidrains After reviewing the FashionMNIST results, it looks like the discriminator collapsed to zero loss, so I think learning stopped prematurely. I'm also not getting good reconstructions.

[image: sampled 17]

For VQ-GAN, I've read that the autoencoder needs a couple of epochs to generate good images before the discriminator is switched on. Is there a way to do that here?
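In VQGAN-style training this is usually done by keeping the adversarial loss weight at zero until a chosen step (taming-transformers, for example, exposes a `disc_start` threshold for this). A minimal plain-Python sketch of that gating; `adversarial_weight`, `generator_loss`, and the `disc_start` value are hypothetical illustrations, magvit2-pytorch may not expose such an option, and the real tokenizer loss has more terms than shown. Note that a zero weight before `disc_start` is the same as setting `adversarial_loss_weight` to 0 for those steps.

```python
def adversarial_weight(step: int, base_weight: float, disc_start: int) -> float:
    """Adversarial loss weight: 0 until disc_start, then base_weight."""
    return base_weight if step >= disc_start else 0.0

def generator_loss(recon_loss: float, perceptual_loss: float, adv_loss: float, step: int,
                   perceptual_weight: float = 0.1, adv_weight: float = 0.1,
                   disc_start: int = 10_000) -> float:
    """Illustrative weighted sum of autoencoder loss terms.

    Before disc_start the GAN term contributes nothing, so the autoencoder
    learns from reconstruction + perceptual loss alone.
    """
    return (recon_loss
            + perceptual_weight * perceptual_loss
            + adversarial_weight(step, adv_weight, disc_start) * adv_loss)

print(generator_loss(1.0, 2.0, 5.0, step=500))     # GAN term gated off
print(generator_loss(1.0, 2.0, 5.0, step=20_000))  # GAN term active
```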

@jpfeil yea i could add that, but only if need be

what happens if you set adversarial_loss_weight to 0?

it really should converge for fashion mnist quite quickly, even without the GAN system

jpfeil commented

I get an assertion error because the self.has_gan attribute gets set to False. Is it okay to override that assertion?

@jpfeil could you point to the line number?

could you also give 0.1.29 a quick try? may be a bug but not entirely sure

@jpfeil oh nvm, yes i see it. we should be able to turn off adversarial loss, let me fix

@jpfeil try 0.1.31 with use_gan = False on the VideoTokenizer

jpfeil commented

Whoops. My tokenizer change wasn't saved. Running now...

@jpfeil give the imagenet run another try

there may have been a bug with how I zeroed the gradients a few patches ago
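For context on why gradient zeroing matters with `grad_accum_every > 1`: gradients should be reset once per optimizer step, before accumulating. Zeroing inside the micro-batch loop discards all but the last micro-batch, and never zeroing leaks stale gradients into later updates. A minimal pure-Python sketch on a toy one-parameter least-squares model; `grad_mse` and `accumulated_step` are hypothetical illustrations, not the trainer's actual code:

```python
def grad_mse(w: float, batch: list) -> float:
    """d/dw of mean((w * x - y) ** 2) over a batch of (x, y) pairs."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def accumulated_step(w: float, micro_batches: list, lr: float) -> float:
    """One optimizer step using gradient accumulation over equal-size micro-batches."""
    grad = 0.0  # zero the gradient once, before accumulating
    for mb in micro_batches:
        # scale each micro-batch gradient by 1/accum (equivalent to scaling the loss)
        grad += grad_mse(w, mb) / len(micro_batches)
    return w - lr * grad  # step; the next call starts again from zero

data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # y = 2x
micro = [data[:2], data[2:]]

# accumulating over two micro-batches matches the full-batch gradient
full = grad_mse(0.0, data)
accum = sum(grad_mse(0.0, mb) for mb in micro) / len(micro)
print(full, accum)

w = 0.0
for _ in range(50):
    w = accumulated_step(w, micro, lr=0.05)
print(w)  # converges toward the true slope of 2
```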

jpfeil commented

This is resolved for Fashion-MNIST, but I haven't been able to run through enough ImageNet data to see whether it works there. I'm going to close this now, and if it comes up again for ImageNet, I'll open a new issue.

Hi @jpfeil, do you mind sharing how you ended up solving it? I'm running into the same issue (#25).

jpfeil commented

Hi @coolbunnyx,

Sorry for the delay. I think you've already solved it, but I was able to get good reconstructions after training for longer.