Direct Adversarial Training for GANs (avoiding adversarial examples during GAN training)

DAT-GAN

Direct Adversarial Training of GANs.

Abstract

Generative Adversarial Networks (GANs) are the most popular models for image generation, optimizing the discriminator and generator jointly and gradually. However, instability in the training process is still one of the open problems for all GAN-based algorithms. To stabilize training, several regularization and normalization techniques have been proposed to make the discriminator meet the Lipschitz continuity constraint. In this paper, a new approach inspired by work on adversarial attacks is proposed to stabilize the training process of GANs. It is found that the images generated by the generator sometimes act like adversarial examples for the discriminator during training, which might be part of the reason for the unstable training. With this discovery, we propose to introduce an adversarial training method into the training process of GANs to improve its stability. We prove that this Direct Adversarial Training (DAT) can limit the Lipschitz constant of the discriminator adaptively. The improved performance of the proposed method is verified on multiple baseline and SOTA networks, such as DCGAN, WGAN, Spectral Normalization GAN, Self-supervised GAN and Information Maximum GAN.
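
For reference, the Lipschitz continuity constraint mentioned above is conventionally stated as follows. Here $K$ is the Lipschitz constant of the discriminator $D$; the choice of norm is left open, since this README does not fix one (an assumption of this note, not a statement from the paper).

```latex
\| D(x_1) - D(x_2) \| \le K \, \| x_1 - x_2 \| \quad \text{for all inputs } x_1, x_2
```

A smaller $K$ means that a small change in the input image can only cause a small change in the discriminator output.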


Installation

This library is forked from mimicry, a lightweight PyTorch library aimed towards the reproducibility of GAN research. Install the following packages before running the code.

conda install pytorch torchvision cudatoolkit -c pytorch
conda install tensorflow
pip install torch-mimicry
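
A quick way to confirm that the packages above are visible to the active Python environment is sketched below. The helper name `check_packages` is hypothetical (it is not part of this repository), and the list uses the import names that these packages are normally imported under.

```python
import importlib.util


def check_packages(names):
    """Return a dict mapping each top-level package name to whether it can be found."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


# Import names of the packages installed above.
status = check_packages(["torch", "torchvision", "tensorflow", "torch_mimicry"])
for name, found in status.items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

This only checks that each package can be located; it does not verify versions or CUDA support.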

Training

Run SNGAN:

python sngan_example.py

Tips

You need to replace the code of the torch_mimicry package in your environment with the code from this repository.
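
To find where the installed package lives before replacing its files, something like the following may help. This is a minimal sketch: `package_dir` is a hypothetical helper, not part of mimicry or of this repository.

```python
import importlib.util
from pathlib import Path


def package_dir(name):
    """Return the directory of an installed package, or None if it is not installed."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.submodule_search_locations:
        return None
    return Path(list(spec.submodule_search_locations)[0])


# Prints the directory whose contents need to be replaced (None if not installed).
print(package_dir("torch_mimicry"))
```

Backing up the original directory before overwriting it makes it easy to restore the stock library later.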

DAT for discriminator (SNGAN)
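
As a reading aid, the perturbation that the step below applies to each batch can be written compactly. The notation is introduced here for illustration and is not taken from the paper: $x$ is a batch of images (real or fake), $x'$ is the opposite batch, $\bar{D}(\cdot)$ is the batch mean of the discriminator logits (with $\bar{D}(x')$ treated as a constant), and $t$ is the clamp bound, set to 1 in the code.

```latex
x_{\text{adv}} = x - \operatorname{clip}\!\left( \nabla_x \left| \bar{D}(x) - \bar{D}(x') \right|,\; -t,\; t \right)
```

The code additionally clamps the perturbed images to the range $[-1, 1]$.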

    def advtrain_step(self,
                      real_batch,
                      netG,
                      optD,
                      log_data,
                      device=None,
                      global_step=None,
                      **kwargs):
        r"""
        Takes one direct adversarial training (DAT) step for D.

        Args:
            real_batch (tuple): A (images, labels) pair, where images is a
                Tensor of shape (N, C, H, W).
            netG (nn.Module): Generator model for obtaining fake images.
            optD (Optimizer): Optimizer for updating discriminator's parameters.
            device (torch.device): Device to use for running the model.
            log_data (dict): A dict mapping name to values for logging uses.
            global_step (int): Variable to sync training, logging and checkpointing.
                Useful for dynamic changes to model amidst training.

        Returns:
            MetricLog: Returns MetricLog object containing updated logging variables after 1 training step.
        """
        self.zero_grad()
        real_images, real_labels = real_batch
        batch_size = real_images.shape[0]  # Match batch sizes for last iter

        # Produce logits for real images
        output_real = self.forward(real_images)

        # Produce fake images
        fake_images = netG.generate_images(num_images=batch_size,
                                           device=device).detach()

        # Produce logits for fake images
        output_fake = self.forward(fake_images)

        # Compute adversarial samples of the real and fake images.
        # (Requires `import torch` at module level.)
        t = 1  # Clamp bound for the gradient step
        real_value = torch.mean(output_real).detach()
        fake_value = torch.mean(output_fake).detach()
        fake_imgs_adv = fake_images.detach().clone().requires_grad_(True)
        real_imgs_adv = real_images.detach().clone().requires_grad_(True)

        # Perturb fake images so their mean logit moves toward the real mean
        fake_adv_loss = torch.abs(self.forward(fake_imgs_adv).mean() - real_value)
        fake_grad = torch.autograd.grad(fake_adv_loss, fake_imgs_adv)[0]
        fake_imgs_adv = (fake_imgs_adv - fake_grad.clamp(-t, t)).clamp(-1, 1).detach()

        # Perturb real images so their mean logit moves toward the fake mean
        real_adv_loss = torch.abs(self.forward(real_imgs_adv).mean() - fake_value)
        real_grad = torch.autograd.grad(real_adv_loss, real_imgs_adv)[0]
        # Clamp before the forward pass so D only sees images in [-1, 1]
        # (the clamp previously ran after the forward pass and had no effect)
        real_imgs_adv = (real_imgs_adv - real_grad.clamp(-t, t)).clamp(-1, 1).detach()

        # Produce logits for the adversarial samples
        fake_adv_validity = self.forward(fake_imgs_adv)
        real_adv_validity = self.forward(real_imgs_adv)

        # Compute loss for D
        errD = self.compute_gan_loss(output_real=real_adv_validity,
                                     output_fake=fake_adv_validity)

        # Backprop and update gradients
        errD.backward()
        optD.step()

        # Compute probabilities
        D_x, D_Gz = self.compute_probs(output_real=real_adv_validity,
                                       output_fake=fake_adv_validity)

        # Log statistics for D once out of loop
        log_data.add_metric('errD', errD.item(), group='loss')
        log_data.add_metric('D(x)', D_x, group='prob')
        log_data.add_metric('D(G(z))', D_Gz, group='prob')

        return log_data
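
The perturbation logic of the step above can be tried in isolation, without mimicry or a trained model. The sketch below is illustrative only: `dat_perturb` and `toy_D` are hypothetical names, the discriminator is a single linear layer, and the 8x8 random images stand in for real data. It is not the repository's implementation.

```python
import torch
import torch.nn as nn


def dat_perturb(netD, images, target_value, t=1.0):
    """Take one clamped gradient step on |mean(D(images)) - target_value|."""
    images_adv = images.detach().clone().requires_grad_(True)
    loss = torch.abs(netD(images_adv).mean() - target_value)
    grad = torch.autograd.grad(loss, images_adv)[0]
    # Clamp the step to [-t, t], then clamp the images to the range [-1, 1].
    return (images_adv - grad.clamp(-t, t)).clamp(-1, 1).detach()


torch.manual_seed(0)

# Hypothetical toy discriminator and random "images" in [-1, 1].
toy_D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real = torch.rand(4, 3, 8, 8) * 2 - 1
fake = torch.rand(4, 3, 8, 8) * 2 - 1

# Mean logits used as constant targets for the opposite batch.
with torch.no_grad():
    real_value = toy_D(real).mean()
    fake_value = toy_D(fake).mean()

fake_adv = dat_perturb(toy_D, fake, real_value)
real_adv = dat_perturb(toy_D, real, fake_value)

print(fake_adv.shape, real_adv.shape)
```

The returned tensors are detached, so a later discriminator update on them does not backpropagate through the perturbation step.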

Citation

If you have found this work useful, please consider citing our work:

@article{li2020direct,
    title={Direct Adversarial Training for GANs},
    author={Li, Ziqiang},
    journal={arXiv preprint arXiv:2008.09041},
    year={2020},
}
