lucidrains/magvit2-pytorch

Running with GAN raises RuntimeError

jpfeil opened this issue · 8 comments

jpfeil commented

v0.1.32 works without the GAN, but I still get an error when the GAN is enabled.

import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    channels=1,
    use_gan=True,
    use_fsq=False,
    codebook_size=2**13,
    init_dim=64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True, "mixed_precision": "bf16"},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)


with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()
Traceback (most recent call last):
  File "/projects/users/pfeiljx/magvit/slurm/mnist/run-mnist-test-run.py", line 46, in <module>
    trainer.train()
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 520, in train
    self.train_step(dl_iter)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 341, in train_step
    loss, loss_breakdown = self.model(
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x7fff42669b40>", line 53, in forward
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 1832, in forward
    norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.grad_layer_wrt_loss) at 0x7fff42659900>", line 50, in grad_layer_wrt_loss
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 129, in grad_layer_wrt_loss
    return torch_grad(
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/autograd/__init__.py", line 394, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
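
That RuntimeError comes from torch.autograd.grad being asked to differentiate a loss tensor that was never connected to the decoder output (so it has no grad_fn). A minimal standalone reproduction of the same failure:

import torch

layer_output = torch.randn(4, 8, requires_grad=True)
loss = torch.tensor(0.)  # a loss that was never computed from layer_output, so it has no grad_fn

try:
    torch.autograd.grad(loss, layer_output)
except RuntimeError as e:
    print(e)  # element 0 of tensors does not require grad and does not have a grad_fn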

lucidrains commented

@jpfeil nice! ok, so i forgot to handle the greyscale edge case for the perceptual loss

try 0.1.33?
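
For anyone on an older version, the usual workaround is to expand greyscale frames to three channels before the VGG-style perceptual network sees them. A minimal sketch (the helper name is just for illustration, not the library's API):

import torch

def expand_greyscale(images):
    # perceptual networks such as VGG are trained on RGB, so a single
    # greyscale channel is repeated across the channel dimension
    if images.shape[1] == 1:
        images = images.repeat(1, 3, 1, 1)
    return images

frames = torch.randn(10, 1, 32, 32)
print(expand_greyscale(frames).shape)  # torch.Size([10, 3, 32, 32])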

jpfeil commented

Thanks, @lucidrains! I tried 0.1.33, but I got this error:

Traceback (most recent call last):
  File "/projects/users/pfeiljx/magvit/slurm/mnist/run-mnist-test-run.py", line 46, in <module>
    trainer.train()
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 520, in train
    self.train_step(dl_iter)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 341, in train_step
    loss, loss_breakdown = self.model(
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x7fff42669b40>", line 53, in forward
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 1849, in forward
    fake_logits = self.discr(recon_video_frames)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 641, in forward
    x = block(x)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 545, in forward
    res = self.conv_res(x)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [512, 3, 1, 1], expected input[10, 1, 32, 32] to have 3 channels, but got 1 channels instead
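
The error itself is a plain channel mismatch: the discriminator's first convolution is built for 3-channel input, while the tokenizer above uses channels=1. A standalone reproduction, with a channel-repeat workaround (the proper fix is for the discriminator to be constructed with the configured channel count):

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 512, kernel_size=1)  # first discriminator conv, expecting RGB
x = torch.randn(10, 1, 32, 32)           # greyscale batch, as in the trainer above

try:
    conv(x)                              # raises the same RuntimeError as in the traceback
except RuntimeError as e:
    print(e)

print(conv(x.repeat(1, 3, 1, 1)).shape)  # workaround: torch.Size([10, 512, 32, 32])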

lucidrains commented

@jpfeil you are seeing correct reconstructions without adversarial training, right?

jpfeil commented

Yeah, I found I needed to increase the codebook size, and now I can get reconstructions that look decent without adversarial training. The fine details aren't there yet, but the general features are encoded.

[attached image: sampled 47]

lucidrains commented

nice! thank you!

jpfeil commented

The discriminator code runs now. The discriminator loss converges to zero, but I'll open a separate issue for that.

lucidrains commented

have you set the adversarial loss weight to be greater than 0?
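
For reference, a sketch of how that weight would be raised when constructing the tokenizer; the adversarial_loss_weight kwarg name is an assumption here, so check the VideoTokenizer signature in the installed version:

from magvit2_pytorch import VideoTokenizer

tokenizer = VideoTokenizer(
    image_size = 32,
    channels = 1,
    use_gan = True,
    use_fsq = False,
    codebook_size = 2 ** 13,
    init_dim = 64,
    adversarial_loss_weight = 1.,  # assumed kwarg; if left at 0 the generator gets no GAN gradient
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)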