SD2GAN: A Siamese Dual Discriminator Generative Adversarial Network for Mode Collapse Reduction.
This project implements the SD2GAN algorithm, presented in this preprint paper.
Manal Allahyani, Rahaf Alsulami, Taif Alwafi, Tarik Alafif, Heyfa Ammar, Sari Sabban, Xuewen Chen
This repo contains example notebooks with a TensorFlow implementation of SD2GAN on the MNIST and Fashion-MNIST datasets.
SD2GAN combines the regular GAN architecture with a diversity network, which consists of a Siamese network and an additional discriminator. The SD2GAN architecture is illustrated in this figure:
The Siamese network is used to measure the similarity within a batch of data. We used the following Siamese network architecture:
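One way to picture batch similarity is the plain-Python sketch below. It is not the architecture in the figure: the shared branch is a toy linear projection with hypothetical weights (`WEIGHTS`, `embed`), and both the `exp(-d)` mapping from embedding distance to similarity and the pairwise averaging in `batch_similarity` are illustrative assumptions. What it does show is the defining Siamese property: both inputs of a pair pass through the same weights.

```python
import itertools
import math

# Hypothetical shared weights: a 2x3 linear projection used by BOTH branches.
WEIGHTS = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.5]]


def embed(x):
    """Shared branch of the toy Siamese network (a linear projection)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in WEIGHTS]


def similarity(a, b):
    """Map the Euclidean distance between the two embeddings to (0, 1]; 1 means identical."""
    ea, eb = embed(a), embed(b)
    dist = math.sqrt(sum((u - v) ** 2 for u, v in zip(ea, eb)))
    return math.exp(-dist)


def batch_similarity(batch):
    """Mean similarity over all distinct pairs in a batch; high values mean low diversity."""
    pairs = list(itertools.combinations(batch, 2))
    return sum(similarity(a, b) for a, b in pairs) / len(pairs)
```

A batch of near-identical samples, a typical symptom of mode collapse, scores close to 1, while a varied batch scores lower.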
The algorithm used to train SD2GAN is described as follows:
The goal of this algorithm is to obtain an efficient generator that produces diverse and realistic data. To achieve this, the algorithm steps include other components besides the generator: the Siamese network, discriminator2, and discriminator1. Each component has its own inputs and outputs, and all of them contribute to training the generator.
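As a hedged illustration of how two discriminators can both feed into generator training, the sketch below sums a realism term and a diversity term. It assumes discriminator1 outputs the probability that a generated sample is real, discriminator2 outputs a probability that generated data is diverse, both terms use the common non-saturating binary cross-entropy form, and the diversity term is weighted by a hypothetical factor `lam`. The paper's actual loss may be defined differently.

```python
import math


def bce(prob, target, eps=1e-7):
    """Binary cross-entropy for one predicted probability (clamped for numerical safety)."""
    p = min(max(prob, eps), 1.0 - eps)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def generator_loss(d1_fake_probs, d2_fake_probs, lam=1.0):
    """Hypothetical two-discriminator generator objective.

    d1_fake_probs: discriminator1 outputs on generated samples (realism).
    d2_fake_probs: discriminator2 outputs on generated samples (diversity).
    The generator is rewarded when both discriminators output values near 1.
    """
    realism = sum(bce(p, 1.0) for p in d1_fake_probs) / len(d1_fake_probs)
    diversity = sum(bce(p, 1.0) for p in d2_fake_probs) / len(d2_fake_probs)
    return realism + lam * diversity
```

With `lam=0.0` this reduces to an ordinary single-discriminator generator loss, which makes the contribution of the diversity branch easy to isolate.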
The first step is training the Siamese network. Its training set combines the training dataset with the noise set produced by the generator before the generator is trained, which resembles what the generator produces at the beginning of its training. This yields an efficient Siamese network that can detect the similarity between real data and generated data during the training loop of the other components.
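A minimal sketch of assembling such a Siamese training set, assuming a contrastive pairing scheme: (real, real) pairs are labelled 1 (similar) and (real, generated) pairs are labelled 0 (dissimilar), in equal numbers. The labelling scheme and the names `untrained_generator_samples` and `make_siamese_pairs` are hypothetical; an untrained generator is approximated here by uniform random noise.

```python
import random


def untrained_generator_samples(n, dim, seed=1):
    """Stand-in for an untrained generator: its outputs look like random noise in [-1, 1]."""
    rng = random.Random(seed)
    return [[rng.uniform(-1.0, 1.0) for _ in range(dim)] for _ in range(n)]


def make_siamese_pairs(real, fake, n_pairs, seed=0):
    """Build labelled pairs: 1 = both items real, 0 = one real and one generated."""
    rng = random.Random(seed)
    pairs = []
    for i in range(n_pairs):
        anchor = rng.choice(real)
        if i % 2 == 0:
            pairs.append((anchor, rng.choice(real), 1))  # similar pair
        else:
            pairs.append((anchor, rng.choice(fake), 0))  # dissimilar pair
    rng.shuffle(pairs)  # avoid a fixed similar/dissimilar ordering
    return pairs
```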
- Python 3.6+
- TensorFlow 2.0+
Open SN_SD2GAN.ipynb, set the params, and then run the cells. The default params are:
params = {
'dataset': "mnist", # 'mnist' , 'fashion'
'z_dim' : 100, # latent dim
'lr' : 0.0001, # learning rate of the generator and the first discriminator
'lr_disc2' : 0.00009, # learning rate of the second discriminator
'beta' : 0.5,
'bs' : 32, # for GAN training
'epochs' : 30000 # for GAN training
}
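If you edit these values, a small check can catch typos before a long run. The helper below is a hypothetical sketch, not part of the notebook: it fills in the defaults listed above and rejects invalid values. The `beta` range check assumes `beta` is a momentum-style coefficient such as Adam's beta_1, which this README does not state.

```python
DEFAULT_PARAMS = {
    'dataset': "mnist",
    'z_dim': 100,
    'lr': 0.0001,
    'lr_disc2': 0.00009,
    'beta': 0.5,
    'bs': 32,
    'epochs': 30000,
}


def check_params(params):
    """Hypothetical helper: merge user params over the defaults and validate them."""
    merged = {**DEFAULT_PARAMS, **params}
    if merged['dataset'] not in ('mnist', 'fashion'):
        raise ValueError(f"unknown dataset: {merged['dataset']!r}")
    for key in ('z_dim', 'bs', 'epochs'):
        if not isinstance(merged[key], int) or merged[key] <= 0:
            raise ValueError(f"{key} must be a positive integer")
    for key in ('lr', 'lr_disc2'):
        if merged[key] <= 0:
            raise ValueError(f"{key} must be positive")
    # Assumption: beta behaves like Adam's beta_1, so it should lie in [0, 1).
    if not 0.0 <= merged['beta'] < 1.0:
        raise ValueError("beta must lie in [0, 1)")
    return merged
```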
Open SD2GAN.ipynb, then:
- To use the pretrained model, check the parameters and make sure the 'train' option is set to False:
params = {
'dataset': "mnist", # 'mnist' , 'fashion'
'train' : False,
'pretrained': "SD2GAN",
'z_dim' : 100,
'lr' : 0.0001,
'lr_disc2' : 0.00009,
'beta' : 0.5,
'bs' : 32,
'epochs' : 30000
}
- To train SD2GAN, make sure the 'train' option is set to True:
params = {
'dataset': "mnist", # 'mnist' , 'fashion'
'train' : True,
'pretrained': "SD2GAN",
'z_dim' : 100,
'lr' : 0.0001,
'lr_disc2' : 0.00009,
'beta' : 0.5,
'bs' : 32,
'epochs' : 30000
}
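The two configurations differ only in the 'train' flag. The hypothetical helper below sketches that branching: `train_fn` stands in for the notebook's training cells and `load_fn` for loading the pretrained weights named by 'pretrained'. Neither name comes from the notebook itself.

```python
def run_sd2gan(params, train_fn, load_fn):
    """Hypothetical dispatcher on the 'train' flag.

    train_fn(params): stand-in for training SD2GAN from scratch.
    load_fn(name, dataset): stand-in for loading the pretrained model.
    """
    if params['train']:
        return train_fn(params)
    return load_fn(params['pretrained'], params['dataset'])
```

For example, `run_sd2gan({'train': False, 'pretrained': "SD2GAN", 'dataset': "mnist"}, train_fn, load_fn)` calls only `load_fn`, so no training happens.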