lucidrains/magvit2-pytorch

Question about ImageNet Parameters

jpfeil opened this issue · 4 comments

jpfeil commented

Hi @lucidrains ,

Thanks again for this great resource. I'm trying to get the training up and running on ImageNet, but I get a strange error midway through training. I was hoping you could take a quick look to see if I'm doing something that doesn't make sense. Thank you!

```
Traceback (most recent call last):
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/run/test-fashion-mnist.py", line 39, in <module>
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/trainer.py", line 431, in train
    self.train_step(dl_iter)
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/trainer.py", line 290, in train_step
    loss, loss_breakdown = self.model(
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x2ae9f33abb80>", line 53, in forward
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/magvit2_pytorch.py", line 1561, in forward
    x = self.encode(padded_video, cond = cond)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.encode) at 0x2ae9f33ab5e0>", line 53, in encode
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/magvit2_pytorch.py", line 1442, in encode
    x = self.conv_in(video)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/magvit2_pytorch.py", line 867, in forward
    return self.conv(x)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 610, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 605, in _conv_forward
    return F.conv3d(
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7, 7], expected input[1, 1, 10, 230, 230] to have 3 channels, but got 1 channels instead
srun: error: mg092: task 0: Exited with exit code 1
```
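The error message packs two shapes. A PyTorch `Conv3d` weight is laid out as `(out_channels, in_channels / groups, kT, kH, kW)` and its input as `(batch, channels, frames, height, width)`, so `[64, 3, 7, 7, 7]` is a 7x7x7 kernel mapping 3 input channels to 64, while `[1, 1, 10, 230, 230]` is one clip with a single channel. The pure-Python sketch below reproduces the check that fails here; the helper `check_conv3d_channels` is hypothetical and not part of PyTorch or magvit2-pytorch.

```python
from typing import Sequence


def check_conv3d_channels(weight_shape: Sequence[int], input_shape: Sequence[int], groups: int = 1) -> None:
    """Hypothetical helper mimicking the channel check behind the RuntimeError above.

    weight_shape: (out_channels, in_channels // groups, kT, kH, kW)
    input_shape:  (batch, channels, frames, height, width)
    """
    if len(weight_shape) != 5 or len(input_shape) != 5:
        raise ValueError("conv3d expects 5-D weight and input shapes")
    expected = weight_shape[1] * groups  # channels the kernel was built for
    got = input_shape[1]                 # channels actually fed in
    if got != expected:
        raise RuntimeError(
            f"expected input{list(input_shape)} to have {expected} channels, "
            f"but got {got} channels instead"
        )


# The shapes from the traceback: a 3-channel kernel receiving a 1-channel clip.
try:
    check_conv3d_channels([64, 3, 7, 7, 7], [1, 1, 10, 230, 230])
except RuntimeError as err:
    print(err)  # prints: expected input[1, 1, 10, 230, 230] to have 3 channels, but got 1 channels instead
```

Only dimension 1 of each shape matters for this particular error; the frame, height and width sizes are irrelevant to it.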

Here is the code I'm running:

```python
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

tokenizer = VideoTokenizer(
    image_size = 256,
    init_dim = 64,
    max_dim = 512,
    channels = 3,                                   # conv_in is built for 3-channel (RGB) input
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
        'compress_time',
        ('consecutive_residual', 2),
        'compress_time',
        ('consecutive_residual', 2),
        'attend_time',
    )
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = 'imagenet/ILSVRC/Data/CLS-LOC/train/n01440764',   # a single ImageNet class (synset) folder
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 1,
    grad_accum_every = 4,                           # gradients accumulated over 4 batches per optimizer step
    num_train_steps = 1_000
)

trainer.train()
```
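Two quantities follow from this configuration. The sketch below assumes, from the layer names only and without checking the library source, that each `'compress_space'` halves height and width and each `'compress_time'` halves the frame axis; it also computes the samples per optimizer step as `batch_size * grad_accum_every`. The helper `downsample_factors` is hypothetical.

```python
from typing import Sequence, Tuple, Union

Layer = Union[str, Tuple[str, int]]

# Same layer spec as the tokenizer above.
LAYERS: Sequence[Layer] = (
    'residual',
    'compress_space',
    ('consecutive_residual', 2),
    'compress_space',
    ('consecutive_residual', 2),
    'linear_attend_space',
    'compress_space',
    ('consecutive_residual', 2),
    'attend_space',
    'compress_time',
    ('consecutive_residual', 2),
    'compress_time',
    ('consecutive_residual', 2),
    'attend_time',
)


def downsample_factors(layers: Sequence[Layer]) -> Tuple[int, int]:
    """Return (spatial, temporal) factors, ASSUMING every compress_* layer halves its axis."""
    spatial, temporal = 1, 1
    for layer in layers:
        # A tuple like ('consecutive_residual', 2) means "repeat this layer 2 times".
        name, count = layer if isinstance(layer, tuple) else (layer, 1)
        if name == 'compress_space':
            spatial *= 2 ** count
        elif name == 'compress_time':
            temporal *= 2 ** count
    return spatial, temporal


spatial, temporal = downsample_factors(LAYERS)
image_size = 256
print(f"spatial x{spatial}, temporal x{temporal}")                                # spatial x8, temporal x4
print(f"latent grid per frame: {image_size // spatial}x{image_size // spatial}")  # latent grid per frame: 32x32

batch_size, grad_accum_every = 1, 4
print(f"samples per optimizer step: {batch_size * grad_accum_every}")             # samples per optimizer step: 4
```

None of these numbers affect the channel error above; they are only a quick sanity check on what the layer tuple asks for.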

lucidrains commented

@jpfeil I think there's a greyscale image in there (1 channel instead of the 3 the model expects).
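If a greyscale JPEG is the culprit, two things help: finding such files up front, and forcing every image to RGB at load time. Below is a minimal sketch with Pillow; the helper names `find_non_rgb_images` and `load_rgb` and the extension list are hypothetical, and this is not necessarily how 0.1.16 handles it. `Image.open` parses the header lazily, so checking `.mode` per file is cheap, and `convert('RGB')` replicates a single `'L'` channel into three (it also handles modes such as `'P'` or `'CMYK'`).

```python
from pathlib import Path
from typing import List

from PIL import Image

IMAGE_EXTS = {'.jpg', '.jpeg', '.png'}  # assumption: adjust to your dataset


def find_non_rgb_images(folder: str) -> List[Path]:
    """List image files under `folder` whose Pillow mode is not 'RGB' (e.g. 'L' greyscale)."""
    bad = []
    for path in sorted(Path(folder).rglob('*')):
        if not path.is_file() or path.suffix.lower() not in IMAGE_EXTS:
            continue
        with Image.open(path) as img:  # lazy open: only the header is read here
            if img.mode != 'RGB':
                bad.append(path)
    return bad


def load_rgb(path: str) -> Image.Image:
    """Open an image and force 3 channels, so a greyscale file can't reach a 3-channel conv."""
    with Image.open(path) as img:
        return img.convert('RGB')  # returns a new, fully loaded image independent of the file


if __name__ == '__main__':
    folder = 'imagenet/ILSVRC/Data/CLS-LOC/train/n01440764'
    if Path(folder).exists():
        for p in find_non_rgb_images(folder):
            print(p)
```

Scanning first tells you how many files are affected; converting at load time is the more robust option because it also covers files added later.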

lucidrains commented

@jpfeil want to try 0.1.16?
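Since the error went away after upgrading, a training script could refuse to start on an older release. Below is a minimal sketch using the standard library's `importlib.metadata` (Python 3.8+). The distribution name `magvit2-pytorch` and the helpers `parse_version` and `is_at_least` are assumptions, and the plain numeric comparison ignores pre-release suffixes such as `rc1`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple

MIN_VERSION = '0.1.16'  # the version suggested in this thread


def parse_version(v: str) -> Tuple[int, ...]:
    """'0.1.16' -> (0, 1, 16); assumes purely numeric, dot-separated versions."""
    return tuple(int(part) for part in v.split('.'))


def is_at_least(installed: str, minimum: str = MIN_VERSION) -> bool:
    # Tuples compare element-wise, so (0, 1, 9) < (0, 1, 16),
    # whereas the plain strings would wrongly give '0.1.9' > '0.1.16'.
    return parse_version(installed) >= parse_version(minimum)


if __name__ == '__main__':
    installed: Optional[str]
    try:
        installed = version('magvit2-pytorch')  # assumed PyPI distribution name
    except PackageNotFoundError:
        installed = None
    if installed is None or not is_at_least(installed):
        print(f'need magvit2-pytorch >= {MIN_VERSION}, found {installed or "none"}')
```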

jpfeil commented

Thanks, @lucidrains! It is working now.