Smaller images
snarb opened this issue · 10 comments
I want to try training it on a dataset of small 32×32 images. This should make training relatively fast and is OK for my task. I have changed the dataset code. Can you please suggest how to estimate the correct hparams and what I will need to change? It looks like it is not enough to change image_size and patch_size in the enhancing.modules.stage1.vitvqgan.ViTVQ config.
If your dataset is only images then you can use the attached config below. Remember to change the image size to your desired size (32 in your case), and your dataset path must follow this structure:
/path/to/your/dataset
- train
  - images  # a folder named "images" that contains all training images
- val
  - images  # a folder named "images" that contains all validation images
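Just to illustrate the layout, a dataset organised like that can be read with a tiny loader along these lines (the class and argument names below are only an illustration, not the actual loader in this repo):

import os
from glob import glob
from PIL import Image
from torch.utils.data import Dataset

class FolderImageDataset(Dataset):
    # reads every file under <root>/<split>/images, e.g. /path/to/your/dataset/train/images
    def __init__(self, root, split="train", transform=None):
        self.paths = sorted(glob(os.path.join(root, split, "images", "*")))
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(img) if self.transform is not None else img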
Thanks. I have changed the dataset and it is OK. The problem is with configuring the model parameters. If I change the ViT image_size and patch_size I get some latent-dimension mismatches. Maybe you have some hints on what else I need to change to make them match.
@thuanz123 In https://github.com/lucidrains/parti-pytorch, training of the VitVQGanVAE works fine with 32×32 crops. The only modification I made to get it working was changing the number of discriminator layers from 4 to 3.
The discriminator in my code expects 256×256 resolution by default. If you don't change that in the config file, it will raise an error at this line, where size=256. You can add two extra lines at the end of the loss config, like the following:
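Something along these lines (the two extra parameter names below are illustrative and may not match the code exactly, so check them against the constructor of VQLPIPSWithDiscriminator in enhancing/losses/vqperceptual.py):

loss:
  target: enhancing.losses.vqperceptual.VQLPIPSWithDiscriminator
  params:
    image_size: 32      # illustrative name: the resolution the StyleGAN discriminator is built for
    disc_num_layers: 3  # illustrative name: fewer downsampling layers for 32×32 inputs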
@thuanz123 thanks, it is working now. But the results look disappointing. I tried to overfit on a super simple toy dataset: random 32×32 crops of a single 80×80 input image for both train and validation. I expected the model to converge super fast on such a simple task, but it did not. I have trained for 12+ hours on 4 A100 GPUs. With your implementation the results are slightly less blurred than with lucidrains' implementation, but small pixel-level details are still not preserved and remain blurred. It is improving slowly, so maybe I just need to train for a week :) What do you think, could this be because of the crop size (32 instead of the default 256)?
Hi @snarb, be aware that both the StyleGAN discriminator and ViT are slow to converge. Maybe lucidrains' code converges faster, as he uses a lot of tricks for ViT and his discriminator is not a StyleGAN discriminator. Also, I have never tried crops as small as 32×32, so I'm not sure my code and hyper-parameters work at this resolution. I will try to run some small experiments to see. If possible, can you upload the image you used?
@thuanz123 no, it looks like your code is converging faster than lucidrains'.
I am using this image:
https://github.com/augmentedperception/spaces_dataset/blob/master/data/800/scene_036/cam_09/image_008.JPG
In the beginning I tried to overfit on the whole image with random crops, which didn't work well, so I then simplified the task to:
from torchvision import transforms as T

transform = T.Compose([
    T.CenterCrop(80),   # take the central 80×80 region of the image
    T.RandomCrop(32),   # then sample a random 32×32 crop from it
    T.ToTensor()
])
I think it could make sense to first try training without the discriminator and perceptual loss, then enable them and see how they improve things. There could be some tricky points, for example with the perceptual loss: I remember that when I built a VGG loss model in TF I had to pad the 32-pixel input to 40 px with reflective padding to match the minimal receptive field. So I will try plain L2-loss training first. Can I just switch the loss and disable the discriminator from the config?
To disable the discriminator, just change the line
target: enhancing.losses.vqperceptual.VQLPIPSWithDiscriminator
to
target: enhancing.losses.vqperceptual.VQLPIPS
and change the weight of each loss as you desire. For reference, please see the code here.
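For example, the loss section could end up looking roughly like this (the weight names are illustrative; use whichever parameters VQLPIPS actually exposes):

loss:
  target: enhancing.losses.vqperceptual.VQLPIPS
  params:
    l1_weight: 1.0          # illustrative: pixel reconstruction weight
    perceptual_weight: 0.0  # illustrative: set to 0 to disable the LPIPS term for plain L2-style training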
And thanks for the image. I will try some experiments at 32×32 resolution whenever I'm free and see if there is anything wrong. You can also try training longer, as ViT-VQGAN is slow to converge.
Having run some experiments on small datasets like Oxford Flowers and your image, I have come to the conclusion that ViT-VQGAN with a StyleGAN discriminator does not do well with limited data; it really needs a large and diverse dataset to work well. Since there hasn't been any prior work on ViT-based autoencoders, I don't really know the reason behind this, but here are some possible causes I can come up with:
- ViT does not have the spatial inductive bias of a CNN, so ViT-VQGAN needs to see many diverse images
- The StyleGAN discriminator is known for performing poorly on limited data; you can try tricks for limited-data training such as DiffAug or Diffusion-GAN (see the sketch after this list)
- The vector quantizer used in my code is a simple one and may be the bottleneck; you can try a more advanced quantizer here
- ViT and the StyleGAN discriminator usually require a large batch size and a long training time
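For reference, here is a minimal sketch of DiffAug-style differentiable augmentation (a simplified version with only brightness / translation / cutout; the official DiffAugment repo has the full implementation), applied to both real and fake batches right before the discriminator:

import torch
import torch.nn.functional as F

def rand_brightness(x):
    # add a per-image brightness offset drawn uniformly from [-0.5, 0.5)
    return x + (torch.rand(x.size(0), 1, 1, 1, device=x.device) - 0.5)

def rand_translation(x, ratio=0.125):
    # shift each image by up to `ratio` of its height/width, zero-padding the borders
    shift = max(1, int(x.size(2) * ratio))
    padded = F.pad(x, [shift, shift, shift, shift])
    out = []
    for i in range(x.size(0)):
        tx = torch.randint(0, 2 * shift + 1, (1,)).item()
        ty = torch.randint(0, 2 * shift + 1, (1,)).item()
        out.append(padded[i, :, tx:tx + x.size(2), ty:ty + x.size(3)])
    return torch.stack(out)

def rand_cutout(x, ratio=0.25):
    # zero out one random square patch per image
    size = max(1, int(x.size(2) * ratio))
    out = x.clone()
    for i in range(x.size(0)):
        cx = torch.randint(0, x.size(2) - size + 1, (1,)).item()
        cy = torch.randint(0, x.size(3) - size + 1, (1,)).item()
        out[i, :, cx:cx + size, cy:cy + size] = 0
    return out

def diff_augment(x):
    # the same differentiable augmentations must be applied to real and fake batches
    for fn in (rand_brightness, rand_translation, rand_cutout):
        x = fn(x)
    return x

# usage in the discriminator step (sketch):
# logits_real = discriminator(diff_augment(real_images))
# logits_fake = discriminator(diff_augment(fake_images))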