Running with GAN raises RuntimeError
jpfeil opened this issue · 8 comments
jpfeil commented
v0.1.32 works without the GAN, but I get an error when using the GAN again.
import torch
from datetime import datetime
from magvit2_pytorch import (
VideoTokenizer,
VideoTokenizerTrainer
)
# Reproduction script: train a VideoTokenizer on MNIST images with the GAN
# (adversarial) path enabled.

# Timestamp for this run; reused below to build per-run checkpoint/results
# folders and to tag the tracker run name.
RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

# NOTE: single-channel input (channels=1) together with use_gan=True is the
# configuration under which the errors in this report were raised.
tokenizer = VideoTokenizer(
    image_size=32,
    channels=1,
    use_gan=True,
    use_fsq=False,
    codebook_size=2**13,
    init_dim=64,
    layers=(
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type='images',  # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size=10,
    grad_accum_every=5,
    num_train_steps=5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True, "mixed_precision": "bf16"},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)},  # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)

# The training call must be indented under the `with` statement; with the
# body at column 0 the script fails with an IndentationError before running.
with trainer.trackers(project_name='magvit', run_name=f'MNIST v0.1.26 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()
Traceback (most recent call last):
File "/projects/users/pfeiljx/magvit/slurm/mnist/run-mnist-test-run.py", line 46, in <module>
trainer.train()
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 520, in train
self.train_step(dl_iter)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 341, in train_step
loss, loss_breakdown = self.model(
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
return model_forward(*args, **kwargs)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x7fff42669b40>", line 53, in forward
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 1832, in forward
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
File "<@beartype(magvit2_pytorch.magvit2_pytorch.grad_layer_wrt_loss) at 0x7fff42659900>", line 50, in grad_layer_wrt_loss
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 129, in grad_layer_wrt_loss
return torch_grad(
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/autograd/__init__.py", line 394, in grad
result = Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
lucidrains commented
jpfeil commented
Thanks, @lucidrains! I tried 0.1.33, but I got this error:
Traceback (most recent call last):
3 File "/projects/users/pfeiljx/magvit/slurm/mnist/run-mnist-test-run.py", line 46, in <module>
4 trainer.train()
5 File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 520, in train
6 self.train_step(dl_iter)
7 File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 341, in train_step
8 loss, loss_breakdown = self.model(
9 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
10 return self._call_impl(*args, **kwargs)
11 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
12 return forward_call(*args, **kwargs)
13 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
14 return model_forward(*args, **kwargs)
15 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
16 return convert_to_fp32(self.model_forward(*args, **kwargs))
17 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
18 return func(*args, **kwargs)
19 File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x7fff42669b40>", line 53, in forward
20 File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 1849, in forward
21 fake_logits = self.discr(recon_video_frames)
22 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
23 return self._call_impl(*args, **kwargs)
24 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
25 return forward_call(*args, **kwargs)
26 File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 641, in forward
27 x = block(x)
28 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
29 return self._call_impl(*args, **kwargs)
30 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
31 return forward_call(*args, **kwargs)
32 File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 545, in forward
33 res = self.conv_res(x)
34 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
35 return self._call_impl(*args, **kwargs)
36 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
37 return forward_call(*args, **kwargs)
38 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
39 return self._conv_forward(input, self.weight, self.bias)
40 File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
41 return F.conv2d(input, weight, bias, self.stride,
42 RuntimeError: Given groups=1, weight of size [512, 3, 1, 1], expected input[10, 1, 32, 32] to have 3 channels, but got 1 channels instead
lucidrains commented
lucidrains commented
@jpfeil you are seeing correct reconstructions without adversarial training, right?
jpfeil commented
lucidrains commented
nice! thank you!
jpfeil commented
The discriminator code runs. The discriminator loss converges to zero, but I'll open this in a different issue.
lucidrains commented
have you set the adversarial loss weight to be greater than 0?