POSTECH-CVLab/PyTorch-StudioGAN

Training a cGAN on MNIST

Adversarian opened this issue · 3 comments

Hi, I hope your day is going well.

I would like to ask whether it is possible to train a cGAN on MNIST using the StudioGAN framework. I have modified the code to allow single-channel images, and I'm using a pre-resizer to bring the 28x28 images up to 32x32 so that I can feed them into the existing model configurations in StudioGAN. I have also introduced a new ResNet18 eval backbone, pretrained on MNIST, to obtain meaningful IS/FID metrics.
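For reference, two common ways to bring 28x28 digits up to 32x32 are zero-padding by 2 pixels per side (the digit keeps its scale) or interpolation (the whole image is rescaled). The sketch below shows both with plain PyTorch ops. It is an illustration only, not StudioGAN's own pre-resizer, and the helper name `resize_mnist_to_32` is hypothetical.

```python
import torch
import torch.nn.functional as F


def resize_mnist_to_32(x: torch.Tensor, mode: str = "pad", fill: float = 0.0) -> torch.Tensor:
    """Bring a batch of 28x28 images shaped (N, C, 28, 28) up to 32x32.

    Hypothetical helper for illustration; not StudioGAN's pre-resizer.
    mode="pad": add a 2-pixel border of `fill` on every side (digit scale unchanged).
    mode="bilinear": rescale the whole image by bilinear interpolation.
    """
    if x.dim() != 4 or tuple(x.shape[-2:]) != (28, 28):
        raise ValueError(f"expected (N, C, 28, 28), got {tuple(x.shape)}")
    if mode == "pad":
        # Padding tuple is (left, right, top, bottom) for the last two dimensions.
        return F.pad(x, (2, 2, 2, 2), mode="constant", value=fill)
    if mode == "bilinear":
        return F.interpolate(x, size=(32, 32), mode="bilinear", align_corners=False)
    raise ValueError(f"unknown mode: {mode!r}")


x = torch.rand(16, 1, 28, 28)       # pixel values in [0, 1]
print(resize_mnist_to_32(x).shape)  # torch.Size([16, 1, 32, 32])
```

If the images are already normalized to [-1, 1] when they are resized, passing `fill=-1.0` keeps the border the same colour as the black MNIST background.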

However, I can't seem to get any form of cGAN to work on MNIST (ReACGAN and ADCGAN specifically, but I've also tried WGAN-DRA as a sanity check). The generator falls behind extremely quickly and never seems to catch up. Favoring the generator in the number of updates per step (by setting d_updates_per_step to 1 and g_updates_per_step to a number greater than 1) doesn't solve the issue either.
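A minimal sketch of what the two knobs are assumed to control within a single training step: the discriminator is updated d_updates_per_step times, then the generator g_updates_per_step times. The D-first ordering and the function name `run_step` are assumptions for illustration, not StudioGAN's actual training loop.

```python
from typing import Callable, List, Tuple


def run_step(
    d_update: Callable[[], float],
    g_update: Callable[[], float],
    d_updates_per_step: int = 1,
    g_updates_per_step: int = 1,
) -> Tuple[List[float], List[float]]:
    """One training step: D updates first, then G updates (assumed ordering).

    Each callable performs one optimizer step and returns its loss value.
    """
    if d_updates_per_step < 1 or g_updates_per_step < 1:
        raise ValueError("update counts must be >= 1")
    d_losses = [d_update() for _ in range(d_updates_per_step)]
    g_losses = [g_update() for _ in range(g_updates_per_step)]
    return d_losses, g_losses


if __name__ == "__main__":
    calls = []

    def fake_d() -> float:
        calls.append("D")
        return 0.0

    def fake_g() -> float:
        calls.append("G")
        return 0.0

    # Favoring the generator as described above: 1 D update, 3 G updates per step.
    run_step(fake_d, fake_g, d_updates_per_step=1, g_updates_per_step=3)
    print(calls)  # ['D', 'G', 'G', 'G']
```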

I would really appreciate some insight on this. If any more information is required of me, I'd be happy to oblige.

ronny3 commented

Hi. Did you find a solution? I would also like to use a custom eval backbone; would you care to share your code?
I also wonder how you handled the single-channel case.

Adversarian commented

Hi, unfortunately I haven't managed to get it to work yet, but essentially what I did was manually append an img_channels field to the DATA section of the configuration that is built at the start of a run, and then pass it down to the modules that build the generators and discriminators. If you peruse the source code under src/models (e.g. src/models/big_resnet.py), you will see that the number of image channels is hardcoded to 3. For instance, the Discriminator in big_resnet.py constructs a collection, d_in_dims_collection, that holds the input dimension of each convolution layer, with 3 hardcoded as the first entry; the same can be observed for the output of the generator's last deconvolution layer. Instead, I simply pass the DATA portion of the config to the model builders, so that they can read the newly appended img_channels field.
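The approach described above can be sketched in plain Python. A `SimpleNamespace` stands in for the DATA section of the run config, and `IMG_CHANNELS`, `add_img_channels` and `d_in_dims` are hypothetical names. The block widths after the first entry are illustrative only and are not copied from big_resnet.py; the point is that the first entry comes from the config instead of a hardcoded 3.

```python
from types import SimpleNamespace
from typing import Dict, List

# Hypothetical per-dataset channel counts; unknown datasets default to RGB.
IMG_CHANNELS: Dict[str, int] = {"MNIST": 1, "CIFAR10": 3}


def add_img_channels(data_cfg: SimpleNamespace) -> SimpleNamespace:
    """Append an img_channels field to a DATA-like config object."""
    data_cfg.img_channels = IMG_CHANNELS.get(data_cfg.name, 3)
    return data_cfg


def d_in_dims(img_size: int, d_conv_dim: int, img_channels: int) -> List[int]:
    """Input channels of each discriminator block.

    The first entry is the image channel count read from the config rather than
    a hardcoded 3; the remaining widths are illustrative placeholders.
    """
    widths: Dict[int, List[int]] = {
        32: [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
        64: [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
    }
    return [img_channels] + widths[img_size]  # KeyError for unsupported sizes


data = add_img_channels(SimpleNamespace(name="MNIST", img_size=32))
print(d_in_dims(data.img_size, 64, data.img_channels))  # [1, 128, 128, 128]
```

The generator side follows the same pattern: its last layer would use `data.img_channels` as its number of output channels instead of 3.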

As for the custom eval backbone, I trained a small ResNet18 on MNIST and then edited LoadEvalModel in src/metrics/preparation.py to handle MNIST as a special case. Most of the code there can be reused and extended for other use cases.

I didn't fork the project and made all of the edits on the fly in a private repository, so I'm afraid I can't share any source code at the moment, but I hope my explanations suffice. If you manage to get this to work some other way, I'd appreciate it if you shared your method here as well.

ronny3 commented

Thanks for this! I'll try to see what it's about.