
Primary language: Python. License: MIT.

Selective Focusing Learning for Conditional GANs

This repository is the official PyTorch implementation of Selective Focusing Learning for Conditional GANs.


TODO

⬜️ Release pre-trained models on the ImageNet 64x64 and 128x128, CIFAR-10, and CIFAR-100 datasets

⬜️ Release training code for Exact Selective Focusing Learning

About Selective Focusing Learning

Conditional generative adversarial networks (cGANs) have demonstrated remarkable success due to their class-wise controllability and superior quality on complex generation tasks. Typical cGANs solve the joint distribution matching problem by decomposing it into two easier sub-problems: marginal matching and conditional matching. From our toy experiments, we found that it is best to apply only conditional matching to certain samples, owing to the content-aware optimization of the discriminator. This paper proposes a simple (a few lines of code) but effective training methodology, selective focusing learning, which forces the discriminator and generator to learn the easy samples of each class rapidly while maintaining diversity. Our key idea is to selectively apply conditional and joint matching to the data in each mini-batch. We conducted experiments with recent cGAN variants on ImageNet (64x64 and 128x128), CIFAR-10, and CIFAR-100, and improved performance significantly (by up to 35.18% in terms of FID) without sacrificing diversity.
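Reading off the SFL function given later in this README, the selective rule can be written per sample of a mini-batch {(x_i, y_i)} as below. The notation is ours, not the paper's: f_u(x_i) is the discriminator's unconditional term (out_u), f_c(x_i, y_i) its conditional projection term (out_c), and S_k the k samples with the largest f_c in the mini-batch, where k is the focusing rate expressed as a number of samples (Focusing_rate).

```latex
D(x_i, y_i) =
\begin{cases}
  f_c(x_i, y_i), & i \in \mathcal{S}_k \quad \text{(conditional matching only)} \\
  f_c(x_i, y_i) + f_u(x_i), & i \notin \mathcal{S}_k \quad \text{(joint matching)}
\end{cases}
```

For SFL_plus, S_k is instead the k samples with the smallest value of the externally supplied scores. With k = 0 every sample receives joint matching, which recovers the ordinary projection discriminator output up to the order of the batch.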

Requirements

To install requirements:

pip install -r requirements.txt

Training BigGAN with Selective Focusing Learning on ImageNet

To train BigGAN models, we build on the BigGAN-PyTorch [1] and Instance Selection for GANs [2] repositories and make only minimal changes to their code. The main change is to the conditional term of the projection discriminator in BigGAN.py (L391-L402, L415-L447). The focusing-rate update is implemented in train.py (L66-L71, L146-L155, L185-L209).

Preparing Data (Same as Instance Selection for GANs)

To train a BigGAN on ImageNet, you first need to construct an HDF5 dataset file for ImageNet (optional), compute Inception moments for calculating FID, and construct the image manifold for calculating Precision, Recall, Density, and Coverage. All of these steps can be done by modifying and running

bash scripts/utils/prepare_data_imagenet_[res].sh

where [res] is replaced with the desired resolution (options are 64, 128, or 256). These scripts assume that ImageNet is in a folder called data in the instance_selection_for_gans directory; replace this with the path to your copy of ImageNet.

64x64 ImageNet

To replicate our best 64x64 model, run bash scripts/launch_SAGAN_res64_ch32_bs128_dstep_1_rr40.sh. A single GPU with at least 12 GB of memory should be sufficient to train this model, and training is expected to take about 2-3 days on a high-end GPU.

We add only two command-line options, Training_type and maximum_focusing_rate:

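import argparse

# Stand-alone parser for illustration; the repo adds these options to BigGAN's parser.
parser = argparse.ArgumentParser()

# Which training scheme to use: plain cGAN training, SFL, or SFL+.
parser.add_argument(
  '--Training_type', type=str, default='without_SFL',
  choices=['without_SFL', 'SFL', 'SFL+'],
  help='Training type of SFL (default: %(default)s)')

# Upper limit on the focusing rate used during training.
parser.add_argument(
  '--maximum_focusing_rate', type=float, default=1,
  help='The percentage of maximum focusing rate (default: %(default)s)')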
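train.py updates the focusing rate during training, but this README does not spell out the schedule. The sketch below assumes, purely for illustration, a linear ramp from 0 to maximum_focusing_rate over a hypothetical ramp_iters iterations, and reads the rate as a fraction of the batch. The function focusing_count and the parameter ramp_iters are hypothetical names, and the real schedule in train.py may differ. The result is an integer because SFL and SFL_plus use Focusing_rate as a slice index.

```python
def focusing_count(iteration: int, batch_size: int, training_type: str,
                   maximum_focusing_rate: float, ramp_iters: int) -> int:
    """Number of samples per mini-batch that get conditional matching only.

    HYPOTHETICAL schedule, not taken from train.py: the rate ramps linearly
    from 0 to maximum_focusing_rate over ramp_iters iterations and is read
    as a fraction of the batch.
    """
    if training_type == 'without_SFL':
        return 0  # k = 0: every sample gets joint matching
    progress = min(iteration / ramp_iters, 1.0)
    rate = progress * maximum_focusing_rate
    # SFL/SFL_plus slice with this value, so return an int in [0, batch_size].
    return max(0, min(batch_size, int(rate * batch_size)))


for it in (0, 5000, 10000, 20000):
    print(it, focusing_count(it, batch_size=128, training_type='SFL+',
                             maximum_focusing_rate=0.5, ramp_iters=10000))
# 0 0
# 5000 32
# 10000 64
# 20000 64
```

Clamping to [0, batch_size] keeps the slices in SFL well defined even if the assumed rate overshoots.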

Pre-trained weights

SFL+ and SFL: to be released.

Results

Our models achieve the following performance:

Conditional Image Generation on ImageNet 64x64

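Model name     IS ↑    FID ↓   P ↑    R ↑    D ↑    C ↑
SA-GAN         17.77   17.23   0.68   0.66   0.72   0.71
Approx SFL     19.11   16.20   0.69   0.67   0.76   0.76
Approx SFL+    21.50   14.20   0.72   0.68   0.84   0.80
Exact SFL+     21.98   13.55   0.73   0.66   0.85   0.81

IS: Inception Score. FID: Fréchet Inception Distance. P, R, D, C: Precision, Recall, Density, and Coverage. ↑: higher is better; ↓: lower is better.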
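As a worked example, the relative FID reduction of each SFL variant over the SA-GAN baseline on ImageNet 64x64 is 100 × (FID_baseline - FID_SFL) / FID_baseline. The snippet only reuses the FID values reported in the table above.

```python
# FID values copied from the ImageNet 64x64 table (lower is better).
baseline_fid = 17.23  # SA-GAN
fids = {
    'Approx SFL': 16.20,
    'Approx SFL+': 14.20,
    'Exact SFL+': 13.55,
}


def relative_reduction(baseline: float, value: float) -> float:
    """Relative FID reduction over the baseline, in percent."""
    return 100.0 * (baseline - value) / baseline


for name, fid in fids.items():
    print(f'{name}: {relative_reduction(baseline_fid, fid):.2f}%')
# Approx SFL: 5.98%
# Approx SFL+: 17.59%
# Exact SFL+: 21.36%
```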

Applying Selective Focusing Learning to Your Own Dataset or Other cGAN Architectures

Selective Focusing Learning can be applied to any class-labeled PyTorch dataset using the SFL and SFL_plus functions below, each only a few lines long.

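  import torch

  # Both functions are methods of the discriminator. out_c is the conditional
  # (projection) term and out_u the unconditional term, each of shape
  # (batch_size, 1). Focusing_rate is the number of samples (an int) that get
  # conditional matching only; all other samples get joint matching.

  def SFL(self, out_c, out_u, Focusing_rate):
    # Rank samples by their conditional term, highest first.
    out_c, idx_c = torch.sort(out_c, dim=0, descending=True)
    out_u = out_u[idx_c[:, 0]]  # reorder the unconditional term to match
    # Others: conditional + unconditional; top-ranked (placed last): conditional only.
    out = torch.cat([out_c[Focusing_rate:] + out_u[Focusing_rate:], out_c[:Focusing_rate]], 0)
    return out

  def SFL_plus(self, out_c, out_u, Focusing_rate, scores):
    # Rank samples by an external per-sample score, lowest first.
    _, idx_c = torch.sort(scores, dim=0)
    out_c = out_c[idx_c]
    out_u = out_u[idx_c]
    out = torch.cat([out_c[Focusing_rate:] + out_u[Focusing_rate:], out_c[:Focusing_rate]], 0)
    return out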
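To show where out_c and out_u come from, here is a minimal, self-contained sketch of a projection-discriminator head with SFL applied to its output. The head follows the usual projection form (a linear unconditional term plus an inner product between a class embedding and the pooled features). The class ToyProjectionHead, its layer sizes, and the free function sfl (the SFL method above without self) are illustrative assumptions, not code from this repository or from BigGAN-PyTorch.

```python
import torch
import torch.nn as nn


def sfl(out_c, out_u, focusing_rate):
    """Same logic as the SFL method above, written as a free function."""
    out_c, idx_c = torch.sort(out_c, dim=0, descending=True)
    out_u = out_u[idx_c[:, 0]]
    return torch.cat([out_c[focusing_rate:] + out_u[focusing_rate:],
                      out_c[:focusing_rate]], 0)


class ToyProjectionHead(nn.Module):
    """HYPOTHETICAL projection head; feature size and class count are arbitrary."""

    def __init__(self, feat_dim=8, num_classes=10):
        super().__init__()
        self.linear = nn.Linear(feat_dim, 1)              # unconditional term
        self.embed = nn.Embedding(num_classes, feat_dim)  # class embedding

    def forward(self, h, y, focusing_rate=0):
        out_u = self.linear(h)                                 # (B, 1)
        out_c = torch.sum(self.embed(y) * h, 1, keepdim=True)  # (B, 1)
        return sfl(out_c, out_u, focusing_rate)


torch.manual_seed(0)
head = ToyProjectionHead()
h = torch.randn(6, 8)           # stand-in for pooled discriminator features
y = torch.randint(0, 10, (6,))  # class labels
out = head(h, y, focusing_rate=2)
print(out.shape)  # torch.Size([6, 1])
```

Because SFL sorts the batch, its output is a permutation of the input order. This is harmless when the output only feeds a batch-averaged loss such as a mean hinge loss, but per-sample outputs should not be matched back to their inputs by position.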

References

[1] Andrew Brock and Alex Andonian. "BigGAN-PyTorch". https://github.com/ajbrock/BigGAN-PyTorch

[2] Terrance DeVries, Michal Drozdzal, and Graham W. Taylor. "Instance Selection for GANs". https://github.com/uoguelph-mlrg/instance_selection_for_gans