This repository contains examples trained using the Python package pro-gan-pth. You can find the GitHub repository for the project at github-repository and the PyPI package at pypi.
Two examples are presented here: one for the LFW dataset and one for the MNIST dataset. Refer to the following sections for how to train these models and/or load the provided trained weights.
Before running any of the following training experiments, please set up your virtualenv with the required packages for this project. Importantly, install the progan package using `$ pip install pro-gan-pth` and the appropriate GPU or CPU build of PyTorch 0.4.0. Once this is done, you can proceed with the following experiments.
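To confirm the setup before training, a small check like the one below can help. It is a minimal sketch using only the standard library; `check_modules` is a hypothetical helper, not part of pro-gan-pth. It assumes the package is importable as `pro_gan_pytorch` (the name used in the sample code later in this README) and reports `"unknown"` for packages that do not define `__version__`.

```python
import importlib
import importlib.util


def check_modules(names=("torch", "pro_gan_pytorch")):
    """Return {module_name: version string, "unknown", or None if missing}.

    Hypothetical helper for sanity-checking the virtualenv.
    """
    report = {}
    for name in names:
        # find_spec returns None when a top-level module cannot be located
        if importlib.util.find_spec(name) is None:
            report[name] = None
            continue
        module = importlib.import_module(name)
        # not every package defines __version__
        report[name] = getattr(module, "__version__", "unknown")
    return report


if __name__ == "__main__":
    report = check_modules()
    for name, version in report.items():
        print("{}: {}".format(name, version if version else "NOT INSTALLED"))
    if report.get("torch"):
        import torch
        # tells you whether the GPU build of PyTorch can see a GPU
        print("CUDA available:", torch.cuda.is_available())
```

Running it inside the activated virtualenv should list a version for both packages; a `NOT INSTALLED` entry means the corresponding install step above needs to be repeated.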
The configuration used for the LFW training experiment can be found in `implementation/configs/lfw.conf` in this repository. The training was performed using the wgan-gp loss function.
The configuration used for the MNIST training experiment can be found in `implementation/configs/mnist.conf` in this repository. The training was performed using the lsgan loss function.
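For reference, the standard textbook forms of the two losses are sketched below, with $D$ the discriminator (critic), $G$ the generator, $x$ a real sample, $z$ a latent vector, and $\hat{x} = \epsilon x + (1 - \epsilon) G(z)$ with $\epsilon \sim U[0, 1]$. The WGAN-GP paper uses $\lambda = 10$, and the ProGAN paper adds a drift term with $\epsilon_{\text{drift}} = 0.001$ to the critic loss. These are the general formulations, not necessarily the exact code inside pro-gan-pth, which may differ in constants or extra terms.

```latex
\begin{aligned}
L_D^{\text{WGAN-GP}} &= \mathbb{E}_{z}\big[D(G(z))\big] - \mathbb{E}_{x}\big[D(x)\big]
  + \lambda\, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big] \\
L_D^{\text{WGAN-GP, ProGAN}} &= L_D^{\text{WGAN-GP}} + \epsilon_{\text{drift}}\, \mathbb{E}_{x}\big[D(x)^2\big] \\
L_G^{\text{WGAN-GP}} &= -\,\mathbb{E}_{z}\big[D(G(z))\big] \\[4pt]
L_D^{\text{LSGAN}} &= \tfrac{1}{2}\,\mathbb{E}_{x}\big[(D(x) - 1)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{z}\big[D(G(z))^2\big] \\
L_G^{\text{LSGAN}} &= \tfrac{1}{2}\,\mathbb{E}_{z}\big[(D(G(z)) - 1)^2\big]
\end{aligned}
```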
To run the training script, use the following commands (replace `mnist.conf` with `lfw.conf` for the LFW experiment):
$ cd implementation
$ python train_network.py --config=configs/mnist.conf
You can modify the configuration to get the desired behaviour. The training script also demonstrates some use cases of the package's API.
You can generate loss plots from the `loss-logs` using the provided script. The logs are written while training is in progress.
$ python generate_loss_plots.py --logdir=training_runs/mnist/losses/ \
--plotdir=training_runs/mnist/losses/loss_plots/
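If you want to inspect the raw logs yourself instead of going through the plotting script, a minimal parser is sketched below. The log format is not documented in this README, so the code assumes one record per line with numeric fields separated by whitespace, commas, or tabs; `parse_numeric_log` and `load_numeric_log` are hypothetical helpers. Check an actual file under `training_runs/mnist/losses/` and adapt before relying on it.

```python
import re

# Matches integers, decimals and scientific notation such as -1.5e-03
NUMBER = re.compile(r"[-+]?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][-+]?\d+)?")


def parse_numeric_log(text):
    """Return the numeric fields of a text log as columns of floats.

    ASSUMPTION: one record per line with numeric fields separated by
    whitespace, commas or tabs; lines without numbers (e.g. headers)
    are skipped.
    """
    rows = []
    for line in text.splitlines():
        values = [float(token) for token in NUMBER.findall(line)]
        if values:
            rows.append(values)
    # Truncate every row to the narrowest one so the columns line up
    width = min((len(row) for row in rows), default=0)
    return [[row[i] for row in rows] for i in range(width)]


def load_numeric_log(path):
    """Read a log file from disk and parse it with parse_numeric_log."""
    with open(path) as handle:
        return parse_numeric_log(handle.read())
```

Each returned column can then be passed straight to `matplotlib.pyplot.plot` once you know which column holds which loss.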
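To generate a sample from the trained weights, you can use the following code:
import torch as th
import pro_gan_pytorch.PRO_GAN as pg
import matplotlib.pyplot as plt
# use the GPU if one is available
device = th.device("cuda" if th.cuda.is_available() else "cpu")
# depth and latent_size must match the configuration used for training
gen = pg.Generator(depth=4, latent_size=128,
                   use_eql=False).to(device)
# map_location lets GPU-trained weights load on a CPU-only machine;
# adjust the path to the saved_models directory of your training run
gen.load_state_dict(
    th.load("training_runs/saved_models/GAN_GEN_3.pth",
            map_location=lambda storage, loc: storage)
)
# sample a random latent vector and generate at the highest depth (3)
noise = th.randn(1, 128).to(device)
sample_image = gen(noise, depth=3, alpha=1).detach().cpu()
# move channels last and rescale from [-1, 1] to [0, 1] for display
plt.imshow((sample_image[0].permute(1, 2, 0) / 2 + 0.5).clamp(0, 1).numpy())
plt.show()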
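The `depth` and `alpha` arguments passed to `gen(...)` come from progressive growing. The sketch below illustrates them under two assumptions taken from the ProGAN paper rather than from the pro-gan-pth source: the generator starts at 4x4 and doubles the resolution at each depth, and `alpha` ramps linearly from 0 to 1 while a newly added layer is faded in. Both helper names are hypothetical.

```python
def resolution_at(depth_index, base=4):
    """Output side length at a 0-based depth index, ASSUMING the
    generator starts at base x base and doubles it at every depth."""
    return base * 2 ** depth_index


def fade_in_alpha(step, fade_in_steps):
    """Linear fade-in weight for the newest layer at the current depth.

    alpha ramps from 0 to 1 over the first fade_in_steps steps, then
    stays at 1; alpha=1 means the new layer is fully blended in.
    """
    if fade_in_steps <= 0:
        return 1.0
    return min(1.0, step / fade_in_steps)


if __name__ == "__main__":
    # Generator(depth=4) has depth indices 0..3, so depth=3 is the final stage
    print([resolution_at(d) for d in range(4)])     # [4, 8, 16, 32]
    print([fade_in_alpha(s, 4) for s in range(6)])  # [0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```

Under these assumptions, `gen(noise, depth=3, alpha=1)` asks for the fully faded-in output of the last stage, which would be 32x32 for a `depth=4` generator.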
The trained weights can be found in the `saved_models` directory inside the respective `training_runs` directory.
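To locate the generator checkpoints without browsing the directories by hand, something like the following sketch may help. It assumes the generator files follow the `GAN_GEN_<n>.pth` naming seen in `GAN_GEN_3.pth` and live in `saved_models` folders somewhere under `training_runs`; `find_generator_checkpoints` is a hypothetical helper, not part of the repository.

```python
import re
from pathlib import Path

# ASSUMED file-name pattern, based on the GAN_GEN_3.pth example
CHECKPOINT = re.compile(r"^GAN_GEN_(\d+)\.pth$")


def find_generator_checkpoints(root="training_runs"):
    """Map each saved_models directory under root to its generator
    checkpoints, sorted by the numeric suffix in GAN_GEN_<n>.pth."""
    found = {}
    # "**/saved_models" matches saved_models at any depth below root
    for models_dir in sorted(Path(root).glob("**/saved_models")):
        if not models_dir.is_dir():
            continue
        matches = []
        for path in models_dir.iterdir():
            match = CHECKPOINT.match(path.name)
            if match:
                matches.append((int(match.group(1)), path))
        if matches:
            found[str(models_dir)] = [p for _, p in sorted(matches)]
    return found


if __name__ == "__main__":
    for directory, paths in find_generator_checkpoints().items():
        print(directory, "->", [p.name for p in paths])
```

Sorting on the parsed integer rather than the file name keeps `GAN_GEN_10.pth` after `GAN_GEN_3.pth`, which plain string sorting would not.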
This code can be run on Google Colaboratory using GPU acceleration. At the time of writing, Colab offered a free Tesla K80 GPU with up to ~12GB of VRAM. However, the runtime of an instance is limited: it shuts down after a certain time, and all installed libraries and saved files are reset in the process. A workaround is to save training results to Google Drive. The packages need to be reinstalled after every instance reset.
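One way to save training results to Google Drive is sketched below. `google.colab.drive.mount` is Colab's own API; the Drive folder `/content/drive/MyDrive/progan` and the helper names are hypothetical choices, and the script assumes it runs from the directory that contains `training_runs`. `dirs_exist_ok` requires Python 3.8 or newer, which current Colab runtimes provide.

```python
import shutil

# Hypothetical backup location inside the mounted Google Drive
DRIVE_DIR = "/content/drive/MyDrive/progan/training_runs"


def backup_training_runs(src="training_runs", dst=DRIVE_DIR):
    """Copy logs and saved models to dst, overwriting earlier copies."""
    # dirs_exist_ok=True lets repeated backups update the same folder
    return shutil.copytree(src, dst, dirs_exist_ok=True)


def restore_training_runs(src=DRIVE_DIR, dst="training_runs"):
    """Copy a previous backup back into a freshly reset instance."""
    return shutil.copytree(src, dst, dirs_exist_ok=True)


if __name__ == "__main__":
    try:
        from google.colab import drive  # only available inside Colab
    except ImportError:
        print("Not running in Colab; skipping Drive backup.")
    else:
        drive.mount("/content/drive")
        backup_training_runs()
```

Calling `backup_training_runs()` periodically (for example after each depth finishes) limits how much progress is lost when the instance resets; `restore_training_runs()` brings the files back after reinstalling the packages.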
Step-by-step instructions for running this on Google Colab are available in the ProGAN Colaboratory Notebook.
Please feel free to open PRs here if you train on other datasets using this package.
Best regards,
@akanimax :)