Primary language: Python. License: Apache License 2.0.

Implicit competitive regularization in GANs

This repository contains the code for the experiments in our ICML paper, Implicit competitive regularization in GANs.

For competitive optimization

The optimizers in this package are designed for competitive optimization problems of the form $$ \min_{x}\max_{y} f(x,y) $$
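To make the min-max setting concrete, here is a minimal sketch on the scalar bilinear game f(x, y) = x * y, where x minimizes and y maximizes. It compares simultaneous gradient descent-ascent with the zero-sum CGD update as we read it from the CGD paper (https://arxiv.org/abs/1905.12103), specialised to scalars. The function names, step size, and starting point are illustrative assumptions, and this is not the package's implementation.

```python
import math

def gda_step(x, y, eta):
    """Simultaneous gradient descent-ascent on f(x, y) = x * y."""
    grad_x, grad_y = y, x  # df/dx = y, df/dy = x
    return x - eta * grad_x, y + eta * grad_y

def cgd_step(x, y, eta):
    """Scalar zero-sum CGD update (assumed form, taken from the CGD paper)."""
    grad_x, grad_y = y, x
    d_xy = 1.0  # mixed second derivative of x * y
    denom = 1.0 + eta ** 2 * d_xy * d_xy
    dx = -eta * (grad_x + eta * d_xy * grad_y) / denom
    dy = eta * (grad_y - eta * d_xy * grad_x) / denom
    return x + dx, y + dy

def distance_after(step, steps=100, eta=0.5, start=(1.0, 1.0)):
    """Run `step` repeatedly and return the distance to the equilibrium (0, 0)."""
    x, y = start
    for _ in range(steps):
        x, y = step(x, y, eta)
    return math.hypot(x, y)

gda_distance = distance_after(gda_step)  # grows: the iterates spiral outward
cgd_distance = distance_after(cgd_step)  # shrinks: the iterates spiral inward
print(f"GDA distance: {gda_distance:.3e}, CGD distance: {cgd_distance:.3e}")
```

On this toy game each descent-ascent step multiplies the squared distance to the equilibrium by (1 + eta^2), while each CGD step divides it by the same factor, which is why only the latter converges.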

Install through pip

pip install CGDs

See the CGDs package page on PyPI for details.

You can also directly copy the folder 'optims' to your workspace.

How to use

The package contains the original Competitive Gradient Descent (BCGD) and the Adaptive Competitive Gradient Descent (ACGD) optimizers.

Quickstart with notebook: Examples of using ACGD.

It is important to enable cuDNN benchmarking so that cuDNN picks the fastest algorithm.

For more details, see cgds-package: Package for CGD and ACGD optimizers.

import torch
from CGDs import ACGD

# Let cuDNN benchmark and pick the fastest algorithm.
torch.backends.cudnn.benchmark = True

device = torch.device('cuda:0')
lr = 0.0001
# Generator, Discriminator, dataloader, criterion, batch_size and z_dim are user-defined.
G = Generator().to(device)
D = Discriminator().to(device)
# max_params maximizes the objective function, while min_params minimizes it.
optimizer = ACGD(max_params=G.parameters(), min_params=D.parameters(), lr_max=lr, lr_min=lr, device=device)
# To use the original CGD instead of the adaptive variant (ACGD), construct:
# BCGD(max_params=G.parameters(), min_params=D.parameters(), lr_max=lr, lr_min=lr, device=device)
for img in dataloader:
    img = img.to(device)
    d_real = D(img)
    z = torch.randn((batch_size, z_dim), device=device)
    d_fake = D(G(z))
    loss = criterion(d_real, d_fake)
    optimizer.zero_grad()
    optimizer.step(loss=loss)
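The loop above calls a `criterion` that the snippet leaves undefined. One possible choice, sketched here as an assumption and not taken from the repository, is the discriminator's binary cross-entropy on raw logits: the discriminator (min_params) minimizes it, and the generator (max_params) maximizes it by making fake samples look real. The sketch assumes that D returns unnormalized logits.

```python
import torch
import torch.nn.functional as F

def criterion(d_real: torch.Tensor, d_fake: torch.Tensor) -> torch.Tensor:
    """Hypothetical zero-sum GAN loss: discriminator BCE on raw logits."""
    # Real samples should be classified as 1, generated samples as 0.
    real_term = F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real))
    fake_term = F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
    return real_term + fake_term

# An undecided discriminator (all logits zero) pays log(2) on each term.
undecided = torch.zeros(4, 1)
print(criterion(undecided, undecided).item())
```

Because a single scalar is shared by both players with opposite signs, this loss stays within the zero-sum setting.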

==Warning==:

  1. Zero-sum games only. This implementation uses the conjugate gradient method to solve the required linear system efficiently instead of inverting the matrix explicitly, which requires the matrix to be positive definite. If you are using the competitive gradient descent (CGD) algorithm for non-zero-sum games, see the CGD paper, https://arxiv.org/abs/1905.12103, for details. For example, GMRES (the generalized minimal residual method) can serve as the solver in the non-zero-sum setting.
  2. This implementation does not work with the torch.nn.parallel.DistributedDataParallel module, because it relies on autograd.grad() to compute Hessian-vector products. See the DistributedDataParallel page of the PyTorch 1.7.0 documentation for details.
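The first warning can be illustrated with a minimal, matrix-free conjugate gradient sketch in plain Python. It solves A v = b using only products with A, where A = I + eta^2 * M * M^T is symmetric positive definite by construction. The matrix M stands in for a mixed-derivative block and, like eta, b, and all function names, is a made-up illustration. This is not the solver shipped in the package.

```python
from typing import Callable, List

Vector = List[float]

def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))

def conjugate_gradient(matvec: Callable[[Vector], Vector], b: Vector,
                       tol: float = 1e-12, max_iter: int = 100) -> Vector:
    """Solve A v = b for symmetric positive definite A, given only v -> A v."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - A x, with x = 0 initially
    p = list(r)  # current search direction
    rs_old = dot(r, r)
    for _ in range(max_iter):
        if rs_old ** 0.5 < tol:
            break
        Ap = matvec(p)
        alpha = rs_old / dot(p, Ap)  # p . A p > 0 because A is positive definite
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x

# Toy data (assumptions): M plays the role of a mixed-derivative block.
M = [[1.0, 2.0], [0.0, 1.0]]
eta = 0.1
b = [1.0, -1.0]

def system_matvec(v: Vector) -> Vector:
    """Compute (I + eta^2 * M * M^T) v without forming the matrix."""
    mt_v = [sum(M[i][j] * v[i] for i in range(2)) for j in range(2)]  # M^T v
    m_mt_v = [sum(M[i][j] * mt_v[j] for j in range(2)) for i in range(2)]  # M (M^T v)
    return [vi + eta ** 2 * wi for vi, wi in zip(v, m_mt_v)]

solution = conjugate_gradient(system_matvec, b)
print(solution)
```

The step-size formula divides by p . A p, which is guaranteed to be positive only for a positive definite matrix. That is why a general-purpose solver such as GMRES is suggested once the matrix loses this property.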

Citation

Please cite the following paper if you find this code useful. Thanks!

@misc{schfer2019implicit,
    title={Implicit competitive regularization in GANs},
    author={Florian Schäfer and Hongkai Zheng and Anima Anandkumar},
    year={2019},
    eprint={1910.05852},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}