pfnet-research/sngan_projection

Request for conditional CIFAR10 codes

mbsariyildiz opened this issue · 16 comments

Hello! I have recently been trying hard to reproduce the conditional SNGAN results on CIFAR10, but my results are far from those stated in the paper. I started coding in PyTorch, but somehow couldn't get an Inception score (IS) above 7. While looking around, I saw you mention this link in another issue, which reproduces the DCGAN results (with standard CNNs) on CIFAR10. So I basically replaced the conv2ds in common/net/DownResBlockXs with snconv2ds, and the linear layer in common/net/ResnetDiscriminator with an snlinear layer. I also modified the wgan_gp updater to optimize the hinge loss and used the optimal hyperparameters. But neither dcgan nor the updated wgan_gp (with lam=0.) achieved even relatively competitive results. Could you please provide the scripts that reproduce the conditional SNGAN results?
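For readers following the `snconv2d`/`snlinear` swap described above: spectral normalization divides each weight matrix by an estimate of its largest singular value, obtained by power iteration. The NumPy sketch below illustrates only that normalization step. `spectral_normalize` is a hypothetical helper, not the API of this repo's `SNLinear`/`SNConvolution2D` links; following the SN-GAN paper, a real layer would keep the vector `u` between training steps and run a single iteration per update.

```python
import numpy as np


def spectral_normalize(W, n_iters=1, u=None, rng=None):
    """Return (W / sigma, sigma, u), where sigma estimates the largest singular value of W.

    Convolution kernels of shape (out, in, kh, kw) are flattened to
    (out, in * kh * kw) before the power iteration.
    """
    if n_iters < 1:
        raise ValueError("n_iters must be >= 1")
    W_mat = W.reshape(W.shape[0], -1)
    if u is None:
        rng = np.random.default_rng() if rng is None else rng
        u = rng.normal(size=W_mat.shape[0])
    for _ in range(n_iters):
        # Alternate between the right and left singular vector estimates.
        v = W_mat.T @ u
        v /= np.linalg.norm(v)
        u = W_mat @ v
        u /= np.linalg.norm(u)
    sigma = u @ W_mat @ v  # Rayleigh-quotient style estimate of the top singular value
    return W / sigma, sigma, u


rng = np.random.default_rng(0)
W_sn, sigma, _ = spectral_normalize(np.diag([3.0, 1.0]), n_iters=50, rng=rng)
```

Passing the returned `u` back in on the next call is what lets a single iteration per step stay accurate as the weights change slowly during training.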

What does "conditional SNGAN results" refer to? Is it the scores shown in Table 3 of https://arxiv.org/pdf/1802.05637.pdf?

@takerum Oh, yes. I should have indicated that table in the first place.

The hyper-parameters are the same as the ones used in the ImageNet experiments, except that:

  • The number of iterations is 50K
  • We linearly decay the learning rate, starting from the beginning of training, so that it reaches 0 at the end
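A minimal sketch of the schedule in the second bullet, assuming the decay runs linearly from iteration 0 down to zero at iteration 50K. `linear_decay_lr` is a hypothetical helper; its `decay_start` argument mirrors the `iteration_decay_start` key that appears in a config later in this thread, and the base rate of 0.0002 is only an example value.

```python
def linear_decay_lr(base_lr, iteration, total_iterations, decay_start=0):
    """Learning rate that stays at base_lr until decay_start, then falls linearly to 0."""
    if iteration <= decay_start:
        return base_lr
    if iteration >= total_iterations:
        return 0.0
    frac = (iteration - decay_start) / (total_iterations - decay_start)
    return base_lr * (1.0 - frac)


# With decay starting at 0, the rate is halved at the midpoint of a 50K-iteration run.
halfway = linear_decay_lr(2e-4, 25000, 50000)
```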

The following are the models used in that paper.

Discriminator:

import chainer
import chainer.functions as F
from source.links.sn_embed_id import SNEmbedID
from source.links.sn_linear import SNLinear
from dis_models.resblocks import Block, OptimizedBlock


class SNResNetDiscriminator32(chainer.Chain):
    def __init__(self, ch=128, n_classes=10, activation=F.relu):
        super(SNResNetDiscriminator32, self).__init__()
        self.activation = activation
        with self.init_scope():
            self.block1 = OptimizedBlock(3, ch)
            self.block2 = Block(ch, ch, activation=activation, downsample=True)
            self.block3 = Block(ch, ch, activation=activation, downsample=False)
            self.block4 = Block(ch, ch, activation=activation, downsample=False)
            self.l5 = SNLinear(ch, 1, initialW=chainer.initializers.GlorotUniform(), nobias=True)
            if n_classes > 0:
                # Class embedding used by the projection term
                self.l_y = SNEmbedID(n_classes, ch, initialW=chainer.initializers.GlorotUniform())

    def __call__(self, x, y=None):
        h = x
        h = self.block1(h)
        h = self.block2(h)
        h = self.block3(h)
        h = self.block4(h)
        h = self.activation(h)
        # Global sum pooling over the spatial axes
        h = F.sum(h, axis=(2, 3))
        output = self.l5(h)
        if y is not None:
            # Projection: add the inner product of the class embedding and the pooled features
            w_y = self.l_y(y)
            output += F.sum(w_y * h, axis=1, keepdims=True)
        return output
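The last lines of `__call__` implement the projection discriminator output: an unconditional term `l5(h)` plus the inner product between the class embedding `l_y(y)` and the pooled feature `h`. The NumPy sketch below reproduces only that head with plain arrays. `projection_head` is a hypothetical name, `w` stands in for the weights of `l5` (no bias, matching `nobias=True`), `V` stands in for the embedding table of `l_y`, and spectral normalization of both is omitted.

```python
import numpy as np


def projection_head(h, y, w, V):
    """Projection-discriminator logits for pooled features h of shape (N, C).

    w: (C,) weights of the unconditional linear term (no bias).
    V: (n_classes, C) class-embedding table; y: (N,) integer labels.
    Returns an (N, 1) array, matching the output shape of __call__ above.
    """
    unconditional = h @ w                  # corresponds to self.l5(h)
    projection = np.sum(V[y] * h, axis=1)  # corresponds to F.sum(w_y * h, axis=1)
    return (unconditional + projection)[:, None]


rng = np.random.default_rng(0)
logits = projection_head(rng.normal(size=(4, 6)), np.array([0, 3, 3, 9]),
                         rng.normal(size=6), rng.normal(size=(10, 6)))
```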

Generator:

import chainer
import chainer.links as L
from chainer import functions as F
from gen_models.resblocks import Block
from source.miscs.random_samples import sample_categorical, sample_continuous


class ResNetGenerator32(chainer.Chain):
    # Set n_classes to 10 to train a conditional GAN.
    def __init__(self, ch=256, dim_z=128, bottom_width=4, activation=F.relu, n_classes=10,
                 distribution="normal"):
        super(ResNetGenerator32, self).__init__()
        self.bottom_width = bottom_width
        self.activation = activation
        self.distribution = distribution
        self.dim_z = dim_z
        self.n_classes = n_classes
        with self.init_scope():
            self.l1 = L.Linear(dim_z, (bottom_width ** 2) * ch, initialW=chainer.initializers.GlorotUniform())
            self.block2 = Block(ch, ch, activation=activation, upsample=True, n_classes=n_classes)
            self.block3 = Block(ch, ch, activation=activation, upsample=True, n_classes=n_classes)
            self.block4 = Block(ch, ch, activation=activation, upsample=True, n_classes=n_classes)
            self.b5 = L.BatchNormalization(ch)
            self.c5 = L.Convolution2D(ch, 3, ksize=3, stride=1, pad=1, initialW=chainer.initializers.GlorotUniform())

    def sample_z(self, batchsize=64):
        return sample_continuous(self.dim_z, batchsize, distribution=self.distribution, xp=self.xp)

    def sample_y(self, batchsize=64):
        return sample_categorical(self.n_classes, batchsize, distribution="uniform", xp=self.xp)

    def __call__(self, batchsize=64, z=None, y=None):
        if z is None:
            z = self.sample_z(batchsize)
        if y is None:
            y = self.sample_y(batchsize) if self.n_classes > 0 else None
        if (y is not None) and z.shape[0] != y.shape[0]:
            raise ValueError('z.shape[0] != y.shape[0]')
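        h = z
        # Project z and reshape it to a (batch, ch, bottom_width, bottom_width) feature map
        h = self.l1(h)
        h = F.reshape(h, (h.shape[0], -1, self.bottom_width, self.bottom_width))
        # Three upsampling residual blocks conditioned on y: 4x4 -> 8x8 -> 16x16 -> 32x32
        h = self.block2(h, y)
        h = self.block3(h, y)
        h = self.block4(h, y)
        h = self.b5(h)
        h = self.activation(h)
        # Map to 3 RGB channels in [-1, 1]
        h = F.tanh(self.c5(h))
        return h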
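The generator passes `y` into every `Block`. Assuming those blocks follow the paper's design, they use categorical conditional batch normalization: features are batch-normalized per channel and then scaled and shifted by a per-class gamma and beta looked up from `y`. `gen_models/resblocks.py` is not shown in this thread, so the NumPy sketch below (function and parameter names hypothetical) only illustrates the training-mode normalization, without running statistics.

```python
import numpy as np


def categorical_conditional_bn(x, y, gammas, betas, eps=2e-5):
    """x: (N, C, H, W) features; y: (N,) int labels; gammas, betas: (n_classes, C)."""
    # Standard batch statistics, shared across all classes.
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    # Per-sample affine parameters selected by class, broadcast over H and W.
    gamma = gammas[y][:, :, None, None]
    beta = betas[y][:, :, None, None]
    return gamma * x_hat + beta
```

The only difference from ordinary batch normalization is that the learned scale and shift are indexed by the label, which is how class information enters every upsampling block.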

@takerum Thank you! I will try ASAP.

I made some changes to the code written by you, @mbsariyildiz, and I hope you can now reproduce the results.
I trained the model with the latest commit, and saw that the score reaches around 8.7 at the end of the training.
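For reference, the Inception score quoted throughout this thread is exp(E_x[KL(p(y|x) || p(y))]), computed from an Inception network's class posteriors on generated images. The sketch below assumes those softmax outputs are already available as a NumPy array (`inception_score` is a hypothetical helper). It does not load the Inception model and omits the usual averaging over several splits, so its numbers are not directly comparable to the reported scores.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """probs: (N, n_classes) softmax outputs p(y|x) for N generated images."""
    p_y = probs.mean(axis=0, keepdims=True)  # marginal class distribution p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))
```

The score is 1 when every image gets the same posterior and reaches the number of classes when each image is confidently classified and the classes are evenly covered.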

@takerum, I saw the changes and ran an experiment right away 😄 I got the same results. Now I can focus on figuring out why the PyTorch version does not reproduce the results.

Great, and thanks again for your contribution!

@takerum, I tried to reproduce the unconditional SNGAN results in the paper with the following config file:

batchsize: 64
iteration: 50000
iteration_decay_start: 0
seed: 0
display_interval: 100
progressbar_interval: 100
snapshot_interval: 10000
evaluation_interval: 1000

models:
  generator:
    fn: gen_models/resnet_32.py
    name: ResNetGenerator
    args:
      dim_z: 128
      bottom_width: 4
      ch: 256
      n_classes: 1

  discriminator:
    fn: dis_models/snresnet_32.py
    name: SNResNetProjectionDiscriminator
    args:
      ch: 128
      n_classes: 1

dataset:
  dataset_fn: datasets/cifar10.py
  dataset_name: CIFAR10Dataset
  args:
    test: False

adam:
  alpha: 0.0002
  beta1: 0.5
  beta2: 0.999

updater:
  fn: updater.py
  name: Updater
  args:
    n_dis: 1
    n_gen_samples: 128
    conditional: False
    loss_type: hinge
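Setting `loss_type: hinge` selects the hinge version of the adversarial loss used in the SN-GAN and projection papers. The NumPy sketch below shows the two losses on raw discriminator outputs. The function names are hypothetical; a Chainer updater would compute the same quantities with `chainer.functions` (for example `F.relu` and `F.mean`) so that they can be backpropagated.

```python
import numpy as np


def hinge_dis_loss(d_real, d_fake):
    """Discriminator hinge loss: E[max(0, 1 - D(x))] + E[max(0, 1 + D(G(z)))]."""
    return np.mean(np.maximum(0.0, 1.0 - d_real)) + np.mean(np.maximum(0.0, 1.0 + d_fake))


def hinge_gen_loss(d_fake):
    """Generator loss paired with the hinge discriminator loss: -E[D(G(z))]."""
    return -np.mean(d_fake)
```

Real outputs above 1 and fake outputs below -1 contribute nothing to the discriminator loss, so the discriminator only receives gradient from samples inside the margin.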

Everything else was the same as in the conditional GAN setup. However, the IS oscillates between 4.3 and 4.6. Do you see what I am doing wrong here?

Which results do you want to reproduce?

If it's the score with standard CNNs (7.42 in the paper), you should use different architectures from those specified in the config files.
Please see this repository: https://github.com/pfnet-research/chainer-gan-lib. Running the following command should reproduce the score:
python train.py --gpu 0 --algorithm stdgan --architecture sndcgan --out result_sndcgan --n_dis 1 --adam_beta1 0.5 --adam_beta2 0.999

If you mean the score with ResNet, you can reproduce it with the same config file as https://github.com/pfnet-research/sngan_projection/blob/master/configs/sn_cifar10.yml, but please change the parameters so that label supervision is not used.

I will make the config for the latter; just a moment.

I made and uploaded the code just now; please check it.

Thank you. I'm running an experiment now.

With the code that you provided above, I do get an Inception score of around 8.7 as you mentioned, but I get an intra-FID of around 23. Is there something I am missing? @takerum

Thank you for your help!

Sorry for the late reply.
The FIDs on CIFAR10 and CIFAR100 shown in the paper are not the intra-FID that I used for the ImageNet experiments.
They are just the conventional FID, i.e., the distance between q(x) and p(x), not between q(x|y) and p(x|y).
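To make the distinction concrete: FID compares Gaussians fitted to Inception features of real and generated images, d^2 = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2)), while intra-FID computes the same distance separately for each class and averages the results. The NumPy sketch below assumes the Inception features have already been extracted (function names hypothetical). Instead of a matrix square root routine, it uses the fact that the trace of (S_r S_g)^(1/2) equals the sum of the square roots of the eigenvalues of S_r S_g.

```python
import numpy as np


def fid_from_stats(mu1, s1, mu2, s2):
    diff = mu1 - mu2
    # Eigenvalues of s1 @ s2 are real and non-negative for PSD s1, s2;
    # clip round-off (tiny negative or imaginary parts) before the square root.
    eig = np.linalg.eigvals(s1 @ s2)
    tr_sqrt = np.sum(np.sqrt(np.clip(eig.real, 0.0, None)))
    return float(diff @ diff + np.trace(s1) + np.trace(s2) - 2.0 * tr_sqrt)


def fid(feat_real, feat_fake):
    """Conventional FID between feature sets of shape (N, D), D >= 2, ignoring labels."""
    return fid_from_stats(feat_real.mean(axis=0), np.cov(feat_real, rowvar=False),
                          feat_fake.mean(axis=0), np.cov(feat_fake, rowvar=False))


def intra_fid(feat_real, y_real, feat_fake, y_fake):
    """Mean of per-class FIDs, i.e. distances between q(x|y) and p(x|y)."""
    classes = np.unique(y_real)
    return float(np.mean([fid(feat_real[y_real == c], feat_fake[y_fake == c])
                          for c in classes]))
```

Because each per-class FID is estimated from far fewer samples than the pooled FID, and because a model can match q(x) while mixing up classes, intra-FID is generally not comparable to the conventional FID.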

@takerum Table 8 of the paper (where the IS of 8.6 on CIFAR10 is reported) says, if I am reading it correctly, that the model used a gradient penalty. However, I can't find where this is done in the code. Could you clarify whether this result indeed includes a gradient penalty and, if so, where this computation happens in this repo?

Thank you so much for your time!

Edit: To add to this, can anyone show how the gradient penalty computation works for the class-conditional GAN? I have not been able to find a good resource on this.
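The thread does not settle whether the CIFAR10 result uses a gradient penalty, but the WGAN-GP style computation itself carries over to the class-conditional case with one adjustment: the gradient is taken with respect to the interpolated image only, while the label is held fixed, so real and fake samples are usually paired under the same labels y. The NumPy sketch below (all names hypothetical) assumes a function that returns dD(x, y)/dx and uses a toy linear projection discriminator whose input gradient is analytic. In a real framework the gradient would come from automatic differentiation with double backprop enabled (in Chainer, for example, `chainer.grad(..., enable_double_backprop=True)`), so that the penalty can be backpropagated into the discriminator's weights.

```python
import numpy as np


def gradient_penalty(grad_x_fn, x_real, x_fake, y, rng, target=1.0):
    """WGAN-GP style penalty E[(||dD(x_hat, y)/dx_hat||_2 - target)^2].

    x_real, x_fake: (N, ...) images that share the same labels y of shape (N,).
    grad_x_fn(x, y): returns dD(x, y)/dx with the same shape as x.
    """
    # One interpolation coefficient per sample, broadcast over the other axes.
    eps = rng.uniform(size=(x_real.shape[0],) + (1,) * (x_real.ndim - 1))
    x_hat = eps * x_real + (1.0 - eps) * x_fake
    grads = grad_x_fn(x_hat, y)  # differentiate w.r.t. x_hat only; y is not interpolated
    norms = np.sqrt(np.sum(grads.reshape(grads.shape[0], -1) ** 2, axis=1))
    return float(np.mean((norms - target) ** 2))


def make_toy_grad_fn(w, V, image_shape):
    """Hypothetical toy discriminator D(x, y) = <w + V[y], flatten(x)>.

    Its gradient with respect to x is w + V[y], reshaped to the image shape.
    """
    def grad_x_fn(x, y):
        return (w + V[y]).reshape((x.shape[0],) + image_shape)
    return grad_x_fn
```

With a toy linear discriminator the gradient does not depend on `x_hat`, so the penalty here depends only on the class-dependent weights; with a real network the interpolation point matters.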

@amartya18x Like you, I get an intra-FID of 23, but my conventional FID is 14.5. What's your conventional FID?