POSTECH-CVLab/PyTorch-StudioGAN

How to instantiate a model, load a checkpoint, and visualize some generated images?

coloedrainbow opened this issue · 5 comments

Could you help with this, please? I'm trying to run a DCGAN model from the StudioGAN repo using the checkpoint on Hugging Face. No matter what I try, it shows a different error. Can anyone please show me how to instantiate one of the GAN models (DCGAN, WGAN, etc.), then load a checkpoint and just generate some data? No training whatsoever.

Because my machine is quite weak, I need to run it on Colab. Here's what I did:

I downloaded the CIFAR10-DCGAN checkpoints from Hugging Face, then made a folder named "checkpoints" and uploaded the checkpoints there. I also made a folder named "results". Then:

!pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
!pip install tqdm ninja h5py kornia matplotlib pandas scikit-learn scipy seaborn wandb PyYaml click requests pyspng imageio-ffmpeg timm
!wandb login xxx # my login info....

Then I cloned the repo:

!git clone https://github.com/POSTECH-CVLab/PyTorch-StudioGAN.git

Then I ran:

import os
os.chdir("PyTorch-StudioGAN")

!CUDA_VISIBLE_DEVICES=0 python3  \
/content/PyTorch-StudioGAN/src/main.py \
-sf -metrics is fid prdc -cfg /content/PyTorch-StudioGAN/src/configs/CIFAR10/DCGAN.yaml \
-data data/cifar10 -ckpt /content/checkpoints -save /content/results  

As far as I can tell, the -t flag is for training and -sf is for saving figures, so I left out -t. It still shows an error, though. What's the issue, and how can I fix it?
Here is the error message:

Traceback (most recent call last):
  File "/content/PyTorch-StudioGAN/src/main.py", line 192, in <module>
    loader.load_worker(local_rank=rank,
  File "/content/PyTorch-StudioGAN/src/loader.py", line 237, in load_worker
    ckpt.load_StudioGAN_ckpts(ckpt_dir=cfgs.RUN.ckpt_dir,
  File "/content/PyTorch-StudioGAN/src/utils/ckpt.py", line 81, in load_StudioGAN_ckpts
    Gen_ckpt_path = glob.glob(glob.escape(x) + '*.pth')[0]
IndexError: list index out of range

It seems there is a problem with loading the checkpoints. Did you change the checkpoint's file name? If so, please retry with the original file name from Hugging Face, since StudioGAN expects a specific file name format when loading (i.e. "model=G-{when}-weights-step=").
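To check a checkpoint folder before launching main.py, a small sketch can reproduce the lookup that fails in the traceback. It assumes, based only on the `glob.glob(glob.escape(x) + '*.pth')[0]` line and the "model=G-{when}-weights-step=" format above, that `x` is the checkpoint directory joined with that prefix. The helper name `find_generator_ckpt` is hypothetical and not part of StudioGAN; it returns None where the real code raises IndexError.

```python
import glob
import os


def find_generator_ckpt(ckpt_dir, when="best"):
    """Hypothetical pre-check mirroring the glob in the traceback above.

    Assumption: the prefix StudioGAN searches for is
    os.path.join(ckpt_dir, f"model=G-{when}-weights-step=").
    Returns the first matching .pth path, or None if nothing matches
    (the case that produces 'IndexError: list index out of range').
    """
    prefix = os.path.join(ckpt_dir, f"model=G-{when}-weights-step=")
    # glob.escape protects any special characters ([, *, ?) in the directory path
    matches = glob.glob(glob.escape(prefix) + "*.pth")
    return matches[0] if matches else None


# Example (Colab layout from the question, assumed):
# print(find_generator_ckpt("/content/checkpoints"))
```

If this returns None, the `-ckpt` directory is either wrong or the files inside it do not start with the expected prefix (for example, a name that still contains `%3D` instead of `=`).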

Hi @alex4727, thank you again very much. No, I didn't change anything. Is there any documentation I can refer to? I really want to work with this repo's awesome GAN collection.
I just downloaded one of the CIFAR10-DCGAN checkpoints (see my screenshot) and uploaded it to Colab.

Have a great day, joonghyuk, and thanks for your help.


The problem seems to be here:

import os
os.chdir("PyTorch-StudioGAN")
!CUDA_VISIBLE_DEVICES=0 python3  \
/content/PyTorch-StudioGAN/src/main.py \
-sf -metrics is fid prdc -cfg /content/PyTorch-StudioGAN/src/configs/CIFAR10/DCGAN.yaml \
-data data/cifar10 -ckpt /content/checkpoints -save /content/results  

Once you call os.chdir("PyTorch-StudioGAN"), you should run the code with paths relative to that directory, like below:

!CUDA_VISIBLE_DEVICES=0 python src/main.py \
-sf -metrics is fid prdc -cfg src/configs/CIFAR10/DCGAN.yaml \
-data data/cifar10 -ckpt /content/DCGAN -save /content/results 

Also, if you don't need metrics and just want to visualize generated images, use the command below:

!CUDA_VISIBLE_DEVICES=0 python src/main.py \
-sf -sf_num NUMBER_OF_IMAGES_TO_GENERATE -metrics none -cfg src/configs/CIFAR10/DCGAN.yaml \
-data data/cifar10 -ckpt /content/DCGAN -save /content/results 

I've checked it on Colab and it worked fine! Let me know if you have other issues.
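If you prefer driving this from a Python cell instead of `!` shell lines, a sketch like the following builds the same argument list and runs it with `subprocess`. It passes `cwd=` instead of calling `os.chdir`, so the relative `src/...` paths always resolve against the repo root no matter how often the cell is rerun. The flags are exactly those used in this thread; the function names and the `/content/...` paths in the example are assumptions for a Colab layout.

```python
import os
import subprocess
import sys


def build_sample_cmd(cfg, ckpt_dir, save_dir, num_images,
                     data_dir="data/cifar10", load_best=False):
    """Build the main.py argument list from this thread (paths relative to the repo root)."""
    # sys.executable reuses the current interpreter instead of relying on `python` in PATH
    cmd = [sys.executable, "src/main.py"]
    if load_best:
        cmd.append("--load_best")
    cmd += [
        "-sf", "-sf_num", str(num_images),  # save figures only
        "-metrics", "none",                 # skip IS/FID/PRDC evaluation
        "-cfg", cfg,
        "-data", data_dir,
        "-ckpt", ckpt_dir,
        "-save", save_dir,
    ]
    return cmd


def run_sampling(repo_dir, cmd, gpu="0"):
    """Run main.py from repo_dir on one GPU; raises CalledProcessError if it fails."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    subprocess.run(cmd, cwd=repo_dir, env=env, check=True)


# Example (Colab layout assumed):
# cmd = build_sample_cmd("src/configs/CIFAR10/DCGAN.yaml", "/content/DCGAN", "/content/results", 100)
# run_sampling("/content/PyTorch-StudioGAN", cmd)
```

Keeping command construction separate from execution makes it easy to print and inspect the exact arguments before spending GPU time.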

Hi @alex4727, thank you very, very, very much! At first I still got IndexError: list index out of range, but after cloning the repo again and repeating your instructions, it's working now. If possible, could you share your Colab notebook too? I'm sure it'd be informative.

Thank you again so much. I'm happy that I've finally understood it with your help :) 👍

@coloedrainbow @alex4727 For anyone interested in running it from the terminal, I've made this small bash script:

How to execute:

 bash get_gan.sh
# Run from the PyTorch-StudioGAN repo root: src/main.py and src/configs below are relative paths.
# export is needed so main.py actually sees the variable.
export CUDA_VISIBLE_DEVICES=0
name="MHGAN"
outdir=./content/$name/checkpoints
# -p also creates ./content if it does not exist yet
mkdir -p "$outdir"
echo "$outdir"
# Download the BEST weights: ALL 3 checkpoints (D, G, G_ema) are needed.
# If you use other checkpoints, remove --load_best from the main.py call below.
wget -P "$outdir" https://huggingface.co/Mingguksky/PyTorch-StudioGAN/resolve/main/studiogan_official_ckpt/CIFAR10_tailored/CIFAR10-MHGAN-train-2022_02_14_18_23_18/model%3DD-best-weights-step%3D98000.pth
wget -P "$outdir" https://huggingface.co/Mingguksky/PyTorch-StudioGAN/resolve/main/studiogan_official_ckpt/CIFAR10_tailored/CIFAR10-MHGAN-train-2022_02_14_18_23_18/model%3DG-best-weights-step%3D98000.pth
wget -P "$outdir" https://huggingface.co/Mingguksky/PyTorch-StudioGAN/resolve/main/studiogan_official_ckpt/CIFAR10_tailored/CIFAR10-MHGAN-train-2022_02_14_18_23_18/model%3DG_ema-best-weights-step%3D98000.pth

python src/main.py --load_best \
-sf -sf_num 100 -metrics none -cfg ./src/configs/CIFAR10/$name.yaml \
-data data/cifar10 -ckpt "$outdir" -save ./content/$name

Best,

Nikolas
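For environments without `wget`, the download step of the bash script above can be sketched with the Python standard library. This is a sketch under assumptions: it reuses the three Hugging Face URLs from the script and explicitly percent-decodes each file name (`%3D` becomes `=`), so the saved files follow the "model=G-{when}-weights-step=" format StudioGAN looks for. The names `local_name` and `download_ckpts` are hypothetical.

```python
import os
import urllib.parse
import urllib.request

# URLs copied from the bash script above (CIFAR10 MHGAN, best weights)
BASE = ("https://huggingface.co/Mingguksky/PyTorch-StudioGAN/resolve/main/"
        "studiogan_official_ckpt/CIFAR10_tailored/CIFAR10-MHGAN-train-2022_02_14_18_23_18/")
FILES = [
    "model%3DD-best-weights-step%3D98000.pth",
    "model%3DG-best-weights-step%3D98000.pth",
    "model%3DG_ema-best-weights-step%3D98000.pth",
]


def local_name(url):
    """Return the decoded last path segment of url, e.g. '%3D' -> '='."""
    return urllib.parse.unquote(os.path.basename(urllib.parse.urlparse(url).path))


def download_ckpts(outdir, base=BASE, files=FILES):
    """Download each checkpoint into outdir (skipping existing files); return local paths."""
    os.makedirs(outdir, exist_ok=True)
    paths = []
    for f in files:
        url = base + f
        dest = os.path.join(outdir, local_name(url))
        if not os.path.exists(dest):
            urllib.request.urlretrieve(url, dest)
        paths.append(dest)
    return paths


# Example: download_ckpts("./content/MHGAN/checkpoints")
```

Decoding the name explicitly avoids depending on how a particular download tool handles percent-encoded characters in file names.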