This repository contains scripts, configuration files, and examples for a Generative Adversarial Network (GAN) for population genetic alignments. The paper associated with this work can be found here: https://doi.org/10.1093/genetics/iyad063
The scripts here train and evaluate a GAN that learns the distribution of, and mimics, population genetic alignments. Within layers.py there are several generator and discriminator architectures to choose from, but the best performing is a Deep Convolutional GAN trained with a Wasserstein loss with gradient penalty (DCWGAN-GP). Full details of the architecture are depicted below.
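For orientation, here is a minimal sketch of the gradient-penalty term at the heart of WGAN-GP training. The `critic` call, the 4D tensor shape, and the default weight are illustrative assumptions, not the repository's exact code:

```python
import torch

def gradient_penalty(critic, real, fake, gp_lambda=10.0):
    """WGAN-GP penalty: push the critic's gradient norm toward 1
    on random interpolations between real and fake alignments."""
    batch_size = real.size(0)
    # One interpolation weight per sample; assumes 4D (N, C, H, W) inputs
    alpha = torch.rand(batch_size, 1, 1, 1, device=real.device)
    interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolated)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0]
    grad_norm = grads.view(batch_size, -1).norm(2, dim=1)
    return gp_lambda * ((grad_norm - 1) ** 2).mean()
```

The critic loss is then `critic(fake).mean() - critic(real).mean()` plus this penalty, which is the term the `--gp_lambda` flag below scales.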
PG-Alignments-GAN is implemented in Python (3.9.7) using PyTorch (1.9.1).
All dependencies and libraries are best installed with Conda using the provided environment file. To do so:
git clone git@github.com:SchriderLab/PG-Alignments-GAN.git
cd PG-Alignments-GAN
conda env create -f PG-Alignments-GAN.yml
conda activate PGA-GAN
To train the GAN:
python3 src/train_wgan_v2.py --odir ODIR --idir IDIR --use_cuda --plot
Optional arguments:
-h, --help show this help message and exit
--latent_size LATENT_SIZE
size of latent/noise vector
--idir IDIR input directory
--odir ODIR output directory
--plot plot summaries in output
--gen GEN set what type of generator to be used. Options: sigGen tanGen tanNorm
--loss LOSS whether to use gp or div to make the loss 1-Lipschitz compatible
--gen_lr GEN_LR generator learning rate
--disc_lr DISC_LR discriminator learning rate
--num_in NUM_IN number of input alignments
--use_cuda use cuda?
--save_freq SAVE_FREQ
save model every save_freq epochs
--batch_size BATCH_SIZE
set batch size
--epochs EPOCHS total number of epochs
--critic_iter CRITIC_ITER
number of critic iterations per generator iteration
--gp_lambda GP_LAMBDA
lambda for gradient penalty
--use_buffer use a buffer for fake data sampling (see the sketch after this list)
--buffer_n BUFFER_N the buffer size will be this many batches large (integer)
--permute permute real data along the individual axis
--label_smooth label smooth both real and fake data
--label_noise LABEL_NOISE
upper bound of the uniform distribution used to label smooth
--mono_switch switch some input sites to monomorphic for training
--normalize normalize inputs for tanh activation
--shuffle_inds shuffle individuals in each input alignment
--verbose verbose output to log
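The `--use_buffer` option samples some of the critic's fake examples from a reservoir of past generator outputs rather than only from the freshest batch. Below is a minimal sketch of such a buffer; the class name and the half-and-half swap policy are illustrative assumptions, not the script's exact implementation:

```python
import random
import torch

class FakeBuffer:
    """Reservoir of past generator outputs; mixing older fakes into
    critic batches can keep the critic from chasing only the
    generator's most recent mode."""
    def __init__(self, n_batches, batch_size):
        self.capacity = n_batches * batch_size  # --buffer_n batches large
        self.data = []

    def push_and_sample(self, fake_batch):
        out = []
        for img in fake_batch.detach():
            if len(self.data) < self.capacity:
                self.data.append(img)
                out.append(img)
            elif random.random() < 0.5:
                # Swap the new fake in, hand the stored one to the critic
                idx = random.randrange(len(self.data))
                out.append(self.data[idx])
                self.data[idx] = img
            else:
                out.append(img)
        return torch.stack(out)
```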
Input to the GAN can be simulated on the fly using SimulatorGenerator as the data loader, or loaded from a folder of CSV files of pre-simulated alignments using DataGeneratorDisk. SimulatorGenerator works with simple models and can be edited in the train_wgan_v2.py script to use different models, but for more complex models, and those using discoal or stdpopsim, it is probably easiest to simulate the data prior to training. After simulating, the output can be piped to either the convert_ms.py script if using a fixed number of sites, or the convert_relative.py script if data are simulated under a model with a variable number of sites. The latter will choose the 64 sites surrounding your desired location on the chromosome and normalize the positions from 0 to 1. As an example, you can pipe the output of your simulation directly:
ms 64 20000 -s 64 | python convert_ms.py outdir/
or save your simulations to a text file first and run:
cat sim.txt | python convert_ms.py outdir/
The outdir here is what will be used as the input directory (--idir) for training.
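Conceptually, the relative conversion keeps the segregating sites nearest a focal position and rescales their coordinates to span 0 to 1. A minimal sketch of that selection step; the function and argument names are illustrative, assuming ms-style positions already on [0, 1]:

```python
import numpy as np

def take_center_sites(haplotypes, positions, target=0.5, n_sites=64):
    """Keep the n_sites segregating sites nearest `target`, then
    renormalize their positions to span 0..1.

    haplotypes: (n_individuals, n_total_sites) 0/1 matrix
    positions:  (n_total_sites,) positions on [0, 1]
    """
    nearest = np.argsort(np.abs(positions - target))[:n_sites]
    keep = np.sort(nearest)  # restore genomic order
    sub_pos = positions[keep]
    sub_pos = (sub_pos - sub_pos.min()) / (sub_pos.max() - sub_pos.min())
    return haplotypes[:, keep], sub_pos
```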
Below are some example input and generated alignments, evaluated at the point where the 2D Sliced Wasserstein Distance (2DSWD; see Evaluation) between input and generated alignments is minimized.
The GAN can be evaluated in a number of ways. One way we did so is to calculate the 2D Sliced Wasserstein Distance (2DSWD), computed from the site frequency spectrum (SFS), between the input and generated alignments. This measurement summarizes the difference between the input and generated data distributions in multidimensional space. It is calculated every SAVE_FREQ epochs, and an example is shown below. Here, the minimum is reached relatively early and is stably maintained; in runs where the GAN struggles, this line may be more erratic or increase after reaching a minimum.
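For reference, a sliced Wasserstein distance between two equally sized point sets can be approximated by projecting both onto random directions and comparing the sorted 1D projections. A minimal NumPy sketch, where the number of projections is an arbitrary choice:

```python
import numpy as np

def sliced_wasserstein(X, Y, n_proj=128, rng=None):
    """Approximate SWD between point sets X, Y of shape (n, d):
    average 1D Wasserstein distance over random projections."""
    if rng is None:
        rng = np.random.default_rng()
    d = X.shape[1]
    total = 0.0
    for _ in range(n_proj):
        theta = rng.normal(size=d)
        theta /= np.linalg.norm(theta)  # random unit direction
        px, py = np.sort(X @ theta), np.sort(Y @ theta)
        total += np.abs(px - py).mean()  # 1D W1 between equal-size samples
    return total / n_proj
```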
Another way to evaluate the GAN is to calculate the Adversarial Accuracy (AA). This measurement gauges the level of overfitting or underfitting of the network, where the ideal value for all AA metrics is 0.5. Essentially, this measurement looks at how often the nearest neighbor of a generated or input alignment is another generated (AAsynth) or input (AAtruth) alignment, respectively, in some multidimensional space. For a perfectly fit model, generated alignments would have a generated nearest neighbor 50% of the time, and similarly for input alignments, resulting in an AAts score of 0.5. For more information see Yelmen et al. (2021). Below, AAts is above 0.5, indicating the model is underfitting, but it closely tracks the AAtruth and AAsynth values, meaning the underfitting does not stem from the model focusing on some smaller part of the input alignment distribution.
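A minimal sketch of the nearest-neighbor computation behind these AA metrics, following the definition in Yelmen et al. (2021); the distance space and input shapes are illustrative assumptions:

```python
import numpy as np
from scipy.spatial.distance import cdist

def adversarial_accuracy(truth, synth):
    """AAtruth: fraction of input points whose nearest neighbor
    (excluding self) is another input point; AAsynth: the same for
    generated points; AAts is their mean, ideally 0.5."""
    d_tt = cdist(truth, truth); np.fill_diagonal(d_tt, np.inf)
    d_ss = cdist(synth, synth); np.fill_diagonal(d_ss, np.inf)
    d_ts = cdist(truth, synth)
    aa_truth = np.mean(d_tt.min(axis=1) < d_ts.min(axis=1))
    aa_synth = np.mean(d_ss.min(axis=1) < d_ts.min(axis=0))
    return aa_truth, aa_synth, (aa_truth + aa_synth) / 2
```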
Additional ways to evaluate the GAN include investigating its output in more detail and looking at metrics relevant to population genetics, such as the SFS (see the sketch below). Enabling plotting (--plot) will automatically generate these (along with the above 2DSWD and AA plots) in your output directory.
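As an example of such a metric, the unfolded SFS of a 0/1 alignment matrix can be read directly off the per-site derived-allele counts; a minimal sketch:

```python
import numpy as np

def site_frequency_spectrum(haplotypes):
    """Unfolded SFS from a (n_individuals, n_sites) 0/1 matrix:
    counts of sites with derived-allele count 1..n-1."""
    n = haplotypes.shape[0]
    counts = haplotypes.sum(axis=0).astype(int)
    return np.bincount(counts, minlength=n + 1)[1:n]
```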
Yelmen, Burak, et al. "Creating artificial human genomes using generative neural networks." PLoS Genetics 17.2 (2021): e1009303.