Direct Adversarial Training of GANs
Generative Adversarial Networks (GANs) are the most popular models for image generation, optimizing the discriminator and generator jointly and gradually. However, instability of the training process is still an open problem for all GAN-based algorithms. To stabilize training, several regularization and normalization techniques have been proposed to make the discriminator satisfy a Lipschitz continuity constraint. In this paper, a new approach inspired by work on adversarial attacks is proposed to stabilize the training process of GANs. We find that the images produced by the generator sometimes act like adversarial examples for the discriminator during training, which might be part of the reason for the instability. Based on this observation, we introduce an adversarial training method, Direct Adversarial Training (DAT), into the training process of GANs to improve stability. We prove that DAT can limit the Lipschitz constant of the discriminator adaptively. The improved performance of the proposed method is verified on multiple baseline and SOTA networks, such as DCGAN, WGAN, Spectral Normalization GAN, Self-Supervised GAN and Information Maximizing GAN.
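For reference, the Lipschitz continuity constraint mentioned above requires that the discriminator D satisfy, for some constant K and all inputs x_1, x_2 (standard notation; the paper's exact formulation may differ):

\[ \lVert D(x_1) - D(x_2) \rVert \le K \, \lVert x_1 - x_2 \rVert \]

DAT does not fix K in advance: as the abstract states, training the discriminator on bounded adversarial perturbations of its inputs limits the Lipschitz constant adaptively.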
The library is forked from Mimicry, a lightweight PyTorch library aimed towards the reproducibility of GAN research. You need to install the following packages before running the code.
conda install pytorch torchvision cudatoolkit -c pytorch
conda install tensorflow
pip install torch-mimicry
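As a quick sanity check that the environment is set up (a minimal sketch, not part of this repository):

import torch
import torch_mimicry

# Print the PyTorch version and whether CUDA is usable.
print(torch.__version__, torch.cuda.is_available())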
Run SNGAN:
python sngan_example.py
You need to replace the code in the package torch_mimicry in your environment with the code from this repository.
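sngan_example.py follows the standard Mimicry training loop. If you want to adapt it, a minimal sketch along the lines of the upstream Mimicry tutorial looks like this (the dataset root, hyperparameters, and log directory are placeholders, not values prescribed by this repository):

import torch
import torch.optim as optim
import torch_mimicry as mmc
from torch_mimicry.nets import sngan

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Dataset and loader (CIFAR-10 at 32x32).
dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, num_workers=4)

# Networks and optimizers.
netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))

# Train. With this repository's patched torch_mimicry in place,
# the discriminator updates use the adversarial training step below.
trainer = mmc.training.Trainer(netD=netD,
                               netG=netG,
                               optD=optD,
                               optG=optG,
                               n_dis=5,
                               num_steps=100000,
                               dataloader=dataloader,
                               log_dir='./log/sngan_example',
                               device=device)
trainer.train()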
def advtrain_step(self,
                  real_batch,
                  netG,
                  optD,
                  log_data,
                  device=None,
                  global_step=None,
                  **kwargs):
    r"""
    Takes one adversarial training step for D.

    Args:
        real_batch (Tensor): A batch of real images of shape (N, C, H, W).
        netG (nn.Module): Generator model for obtaining fake images.
        optD (Optimizer): Optimizer for updating discriminator's parameters.
        log_data (MetricLog): Mapping of names to values for logging uses.
        device (torch.device): Device to use for running the model.
        global_step (int): Variable to sync training, logging and checkpointing.
            Useful for dynamic changes to model amidst training.

    Returns:
        MetricLog: Returns MetricLog object containing updated logging
            variables after 1 training step.
    """
    self.zero_grad()

    real_images, real_labels = real_batch
    batch_size = real_images.shape[0]  # Match batch sizes for last iter

    # Produce logits for real images.
    output_real = self.forward(real_images)

    # Produce fake images, detached so no gradients reach G.
    fake_images = netG.generate_images(num_images=batch_size,
                                       device=device).detach()

    # Produce logits for fake images.
    output_fake = self.forward(fake_images)

    # Compute the adversarial examples of the real and fake images.
    # t bounds each pixel of the perturbation (an L-infinity constraint).
    t = 1
    real_value = torch.mean(output_real)
    fake_value = torch.mean(output_fake)

    # Perturb fake images so their mean logit moves towards the real one.
    fake_imgs_adv = fake_images.clone().requires_grad_(True)
    fake_output = self.forward(fake_imgs_adv).mean()
    fake_adv_loss = torch.abs(fake_output - real_value)
    fake_grad = torch.autograd.grad(fake_adv_loss, fake_imgs_adv)
    fake_imgs_adv = (fake_imgs_adv - fake_grad[0].clamp(-t, t)).clamp(-1, 1)

    # Perturb real images so their mean logit moves towards the fake one.
    real_imgs_adv = real_images.clone().requires_grad_(True)
    real_output = self.forward(real_imgs_adv).mean()
    real_adv_loss = torch.abs(real_output - fake_value)
    real_grad = torch.autograd.grad(real_adv_loss, real_imgs_adv)
    # Clamp to the valid image range *before* the forward pass below,
    # otherwise the clamp has no effect on the loss.
    real_imgs_adv = (real_imgs_adv - real_grad[0].clamp(-t, t)).clamp(-1, 1)

    # Produce logits for the adversarial examples; detach so D's update
    # does not backpropagate into the perturbation graph.
    fake_adv_validity = self.forward(fake_imgs_adv.detach())
    real_adv_validity = self.forward(real_imgs_adv.detach())

    # Compute loss for D on the adversarial examples.
    errD = self.compute_gan_loss(output_real=real_adv_validity,
                                 output_fake=fake_adv_validity)

    # Backprop and update gradients.
    errD.backward()
    optD.step()

    # Compute probabilities.
    D_x, D_Gz = self.compute_probs(output_real=real_adv_validity,
                                   output_fake=fake_adv_validity)

    # Log statistics for D once out of loop.
    log_data.add_metric('errD', errD.item(), group='loss')
    log_data.add_metric('D(x)', D_x, group='prob')
    log_data.add_metric('D(G(z))', D_Gz, group='prob')

    return log_data
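To exercise advtrain_step in isolation, a hypothetical smoke test might look like the following (the random batch, the MetricLog import path, and the assumption that the patched discriminator exposes advtrain_step are ours, not guaranteed by the repository):

import torch
import torch.optim as optim
from torch_mimicry.nets import sngan
from torch_mimicry.training import metric_log

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

netG = sngan.SNGANGenerator32().to(device)
netD = sngan.SNGANDiscriminator32().to(device)  # patched with advtrain_step
optD = optim.Adam(netD.parameters(), 2e-4, betas=(0.0, 0.9))

# One step on a random batch in [-1, 1]; a real run uses the CIFAR-10 loader.
real_images = torch.randn(8, 3, 32, 32, device=device).clamp(-1, 1)
real_labels = torch.zeros(8, dtype=torch.long, device=device)

log_data = metric_log.MetricLog()
log_data = netD.advtrain_step((real_images, real_labels),
                              netG, optD, log_data, device=device)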
If you have found this work useful, please consider citing it:
@article{li2020direct,
  title={Direct Adversarial Training for GANs},
  author={Li, Ziqiang},
  journal={arXiv preprint arXiv:2008.09041},
  year={2020},
}
[1] Spectral Normalization for Generative Adversarial Networks
[2] cGANs with Projection Discriminator
[3] Self-Supervised GANs via Auxiliary Rotation Loss
[4] A Large-Scale Study on Regularization and Normalization in GANs
[6] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium