Adding support for images with transparency
markuschue opened this issue · 7 comments
Hi, I'm training a VAE and DALLE on a custom dataset that must contain .png images with transparency.
What should I change in the code so that the model can learn and then generate images in this format?
I'm a little bit lost and any tip would help, thanks!
You'll have to change the number of channels at https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py#L110 from 3 to 4.
Oh, it looks like the train_vae script doesn't handle it yet. I'll get around to this next week!
@alu0101130507 Hey Markus, do you want to give 1.6.0 a try (bebc280)? You'll have to train your own custom DiscreteVAE with the --transparent flag.
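A minimal sketch of how a boolean flag like --transparent can be mapped to the channel count of 3 versus 4 discussed above. This is a hypothetical reconstruction for illustration only, not the actual argument handling in train_vae.py; `build_parser` and `channels_from_args` are made-up names.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical sketch: NOT the actual train_vae.py argument parser.
    parser = argparse.ArgumentParser(description="Train a DiscreteVAE (illustrative sketch)")
    # store_true: the flag is False by default and True when it is passed
    parser.add_argument(
        "--transparent",
        action="store_true",
        help="train on RGBA images (4 channels) instead of RGB (3 channels)",
    )
    return parser


def channels_from_args(args: argparse.Namespace) -> int:
    # RGBA adds an alpha channel on top of R, G and B
    return 4 if args.transparent else 3
```

For example, `channels_from_args(build_parser().parse_args(["--transparent"]))` returns 4, while an empty argument list gives 3. In a sketch like this, that integer would then be used as the VAE's channel count.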
Thanks a lot! Unfortunately, I'm still having some problems: when running the train_vae.py script I get the following errors:
C:\Users\Marku\AppData\Local\Programs\Python\Python38\lib\site-packages\PIL\Image.py:945: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images
  warnings.warn(
Traceback (most recent call last):
  File ".\train_vae.py", line 234, in <module>
    loss, recons = distr_vae(
  File "C:\Users\Marku\AppData\Local\Programs\Python\Python38\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Marku\Documents\DALLE-1.6.0\dalle_pytorch\dalle_pytorch.py", line 222, in forward
    img = self.norm(img)
  File "C:\Users\Marku\Documents\DALLE-1.6.0\dalle_pytorch\dalle_pytorch.py", line 189, in norm
    images.sub_(means).div_(stds)
RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 1
Maybe it's because the images aren't being converted to RGBA, but I couldn't figure out how to fix it.
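One way to guarantee that every image really has four channels is to convert it to RGBA right after loading, which is also what the Pillow UserWarning at the top of the output above recommends for palette images. A minimal sketch using Pillow; `to_rgba` and `load_rgba` are hypothetical helpers, not part of DALLE-pytorch.

```python
from PIL import Image


def to_rgba(img: Image.Image) -> Image.Image:
    """Return `img` with an explicit alpha band (hypothetical helper)."""
    # Palette ("P"), grayscale and RGB images have no alpha band of their own;
    # for palette images, transparency lives in img.info and is turned into a
    # real alpha channel by convert("RGBA").
    if img.mode != "RGBA":
        img = img.convert("RGBA")
    return img


def load_rgba(path: str) -> Image.Image:
    """Open an image file and always return a 4-channel RGBA image."""
    with Image.open(path) as img:
        # copy() reads the pixel data into memory, so the returned image
        # stays valid after the file is closed
        return to_rgba(img.copy())
```

Images that are already RGBA pass through unchanged; everything else gets an alpha band, so a downstream tensor would always have 4 channels.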
@alu0101130507 Oh yes, there were more things I did not consider (like normalization and validation in the DALL-E class).
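The RuntimeError above comes from per-channel normalization: the mean and standard deviation vectors still have 3 entries while the image batch now has 4 channels along dimension 1. A minimal NumPy sketch of the check and one possible fix, giving the alpha channel its own entry. The values (0.5 for RGB; mean 0 and std 1 for alpha, which leaves alpha unchanged) are assumptions for illustration, not necessarily what DALLE-pytorch 1.6.1 uses.

```python
import numpy as np


def normalize(images: np.ndarray, means, stds) -> np.ndarray:
    """Per-channel normalization of a (batch, channels, height, width) array."""
    # Reshape to (1, C, 1, 1) so the values broadcast along the channel axis
    means = np.asarray(means, dtype=images.dtype).reshape(1, -1, 1, 1)
    stds = np.asarray(stds, dtype=images.dtype).reshape(1, -1, 1, 1)
    channels = images.shape[1]
    if means.shape[1] != channels or stds.shape[1] != channels:
        # The same mismatch that produces "size of tensor a (4) must match
        # the size of tensor b (3)" in PyTorch
        raise ValueError(
            f"got {channels} image channels but "
            f"{means.shape[1]} means / {stds.shape[1]} stds"
        )
    return (images - means) / stds


# Assumed values: RGB mapped from [0, 1] to [-1, 1]; alpha left as-is.
RGBA_MEANS = (0.5, 0.5, 0.5, 0.0)
RGBA_STDS = (0.5, 0.5, 0.5, 1.0)
```

Calling `normalize` on a 4-channel batch with 3-element means and stds raises a ValueError, while `RGBA_MEANS` and `RGBA_STDS` normalize the color channels and leave alpha untouched.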
OK! Try 1.6.1 (a6776c8).
Yes, now it works! Just one more small thing: the TRANSPARENT constant wasn't defined in train_dalle.py, so I added the following line:
TRANSPARENT = True if CHANNELS == 4 else False
Also, generate.py was saving the images as JPEG, which can't store an alpha channel. I changed it to PNG and everything is now working pretty well. Thank you so much! :)
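Since JPEG has no alpha channel, RGBA output has to be written to a format such as PNG. A minimal Pillow sketch; `save_with_alpha` is a hypothetical helper and not the code in generate.py. It relies on Pillow choosing the output format from the file extension.

```python
from pathlib import Path

from PIL import Image


def save_with_alpha(img: Image.Image, path) -> Path:
    """Save an image, switching .jpg/.jpeg to .png so alpha survives."""
    path = Path(path)
    if img.mode == "RGBA" and path.suffix.lower() in (".jpg", ".jpeg"):
        # JPEG cannot store an alpha band (recent Pillow versions raise
        # OSError when asked to write RGBA as JPEG), so use PNG instead
        path = path.with_suffix(".png")
    img.save(path)  # format is inferred from the suffix
    return path
```

PNG is lossless, so a semi-transparent pixel such as (255, 0, 0, 128) reads back exactly as it was written.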
@alu0101130507 Got it, thank you Markus! Feel free to close the issue once you're satisfied with everything :)