BicycleGAN-pytorch

PyTorch implementation of BicycleGAN: Toward Multimodal Image-to-Image Translation.

Result

Edges2Shoes

Images are 128 x 128, and a normal (unconditional) discriminator is used rather than a conditional discriminator. The conditional discriminator is described in Advanced-BicycleGAN in this repository; that variant generates slightly more diverse, clearer and more realistic images than the ones below.

  • Random sampling

  • Linear interpolated sampling

Model description

cVAE-GAN

cVAE-GAN is an image reconstruction process. Through it, the encoder learns to extract a latent code z that captures the features of the given image 'B'. The generator can then produce an image that has the features of 'B', but it must also fool the discriminator. Furthermore, cVAE-GAN uses a KL-divergence loss so that, at test time, the generator can produce images from z sampled randomly from a normal distribution.

cLR-GAN

This is a latent code reconstruction process. When many latent codes correspond to the same output mode, this is called mode collapse. The main purpose of cLR-GAN is to make the mapping between B and z invertible. This leads to bijective consistency between latent encodings and output modes, which is important for preventing mode collapse.

Prerequisites

Training step  

Before getting started, suppose that we want to optimize G, which converts domain A into domain B.

real_B : A real image of domain B from training data set
fake_B : A fake image of domain B made by the generator
encoded_z : Latent code z made by the encoder
random_z : Latent code z sampled randomly from normal distribution

1. Optimize D

  • Optimize D in cVAE-GAN using real_B and fake_B made with encoded_z (adversarial loss).
  • Optimize D in cLR-GAN using real_B and fake_B made with random_z (adversarial loss).

2. Optimize G or E

  • Optimize G and E in cVAE-GAN using fake_B made with encoded_z (adversarial loss).
  • Optimize G and E in cVAE-GAN using real_B and fake_B made with encoded_z (image reconstruction loss).
  • Optimize E in cVAE-GAN using the encoder outputs, mu and log_variance (KL-divergence loss).
  • Optimize G in cLR-GAN using fake_B made with random_z (adversarial loss).

3. Optimize ONLY G (do not update E)

  • Optimize G in cLR-GAN using random_z and the encoder output mu (latent code reconstruction loss).
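The point of step 3 is that the latent code reconstruction loss flows through both G and E, but only G is updated. The following is a minimal plain-Python sketch of that idea, not the repository's code. It assumes hypothetical scalar stand-ins for the networks, G(z) = g * z and E(x) = e * x, and a hand-written gradient; the function name and the learning rate are also illustrative.

```python
def latent_step_update_g_only(g, e, random_z, lr=0.1):
    """One toy latent-reconstruction step that updates g and leaves e alone.

    Hypothetical scalar stand-ins: G(z) = g * z and E(x) = e * x.
    """
    mu = e * (g * random_z)              # encoder output for the fake image
    loss = abs(mu - random_z)            # L1 loss between mu and random_z
    sign = (mu > random_z) - (mu < random_z)
    grad_g = sign * e * random_z         # d(loss) / d(g), through the encoder
    g_new = g - lr * grad_g              # gradient step on the generator only
    return g_new, e, loss                # e is returned unchanged


g_after, e_after, loss_before = latent_step_update_g_only(0.5, 1.0, random_z=2.0)
_, _, loss_after = latent_step_update_g_only(g_after, e_after, random_z=2.0)
print(g_after, e_after, loss_before, loss_after)
```

In a PyTorch training loop the same effect is usually obtained by calling only the generator's optimizer step for this loss; how this repository does it should be checked in train.py.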

Implementation details

  • Multi discriminator
    First, two discriminators with different final output sizes (PatchGAN), 14x14 and 30x30, are used so that the discriminator learns from two different scales.
    Second, each of these is duplicated, because the two kinds of fake images are made with encoded_z (cVAE-GAN) and random_z (cLR-GAN), drawn from N(mu, std) and N(0, 1) respectively. Using a separate discriminator for each distribution works better than sharing one discriminator across both.
    In total, four discriminators are used: (cVAE-GAN, 14x14), (cVAE-GAN, 30x30), (cLR-GAN, 14x14) and (cLR-GAN, 30x30).

  • Encoder
    E_ResNet is used, not E_CNN. The residual block in the encoder is slightly different from the usual one. Check the ResBlock and Encoder classes in model.py.

  • How to inject the latent code z into the generator
    Inject z only at the input by concatenation, not into all intermediate layers.

  • Training data
    Batch size is 1 for both cVAE-GAN and cLR-GAN, which means that two images are drawn from the dataloader and one is given to each of cVAE-GAN and cLR-GAN.

  • How to encode with encoder
    The encoder returns mu and log_variance. The reparameterization trick is used, so encoded_z = random_z * std + mu, where std = exp(log_variance / 2).

  • How to calculate KL divergence
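    The following formula is from here. If you want to see simple and clean VAE code, you can also check here.

The KL divergence is measured against N(0, 1), which leads to the following formula, where std^2 = exp(log_variance) and the sum runs over the latent dimensions:

    KL(N(mu, std^2) || N(0, 1)) = 0.5 * sum(mu^2 + std^2 - log_variance - 1)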
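The encoding step and the KL term described above can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the repository's implementation: latent codes are ordinary lists of floats rather than tensors, and the function names `reparameterize` and `kl_divergence` are hypothetical. The KL expression is the standard closed form for a diagonal Gaussian against N(0, 1).

```python
import math
import random


def reparameterize(mu, log_variance, rng=random):
    """Return encoded_z = random_z * std + mu, with std = exp(log_variance / 2)."""
    encoded_z = []
    for m, lv in zip(mu, log_variance):
        std = math.exp(lv / 2)
        random_z = rng.gauss(0.0, 1.0)   # random_z ~ N(0, 1)
        encoded_z.append(random_z * std + m)
    return encoded_z


def kl_divergence(mu, log_variance):
    """KL(N(mu, std^2) || N(0, 1)), summed over the latent dimensions."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0
                     for m, lv in zip(mu, log_variance))


mu = [0.5, -0.25]            # hypothetical encoder outputs
log_variance = [0.0, -1.0]
encoded_z = reparameterize(mu, log_variance, rng=random.Random(0))
print(encoded_z, kl_divergence(mu, log_variance))
```

The KL term is zero exactly when mu = 0 and log_variance = 0 in every dimension, which is why minimizing it pulls the encoder's distribution toward the N(0, 1) prior used for sampling at test time.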

  • How to reconstruct z in cLR-GAN
    mu and log_variance are obtained from the encoder in cLR-GAN. Use L1 loss between mu and random_z, not between encoded_z and random_z. The reasons are as follows; you can also check here.

    1. cLR-GAN is for point estimation, not distribution estimation.
    2. If std is too big, L1 loss between encoded_z and random_z can be unstable.
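To make the two reasons concrete, here is a small plain-Python sketch comparing the two candidate losses. All values are made up for illustration, and `l1_loss` is a hypothetical helper that averages absolute differences over the latent dimensions. The loss on mu is deterministic for a given encoder output, while the loss on encoded_z changes with every noise draw and grows with std.

```python
import math
import random


def l1_loss(a, b):
    """Mean absolute difference between two latent codes of equal length."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


random_z = [0.3, -1.2, 0.8]        # code that was fed to the generator
mu = [0.25, -1.0, 0.9]             # hypothetical encoder output, close to random_z
log_variance = [2.0, 2.0, 2.0]     # hypothetical large variance, std = e

# Loss used for cLR-GAN: compares the point estimate mu with random_z.
latent_loss = l1_loss(mu, random_z)

# Alternative that is avoided: compares a noisy sample with random_z.
rng = random.Random(0)
encoded_z = [rng.gauss(0.0, 1.0) * math.exp(lv / 2) + m
             for m, lv in zip(mu, log_variance)]
noisy_loss = l1_loss(encoded_z, random_z)

print(latent_loss, noisy_loss)
```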

Dataset

You can download many datasets for BicycleGAN from here.

  • Training images : data/edges2shoes/train
  • Test images : data/edges2shoes/test

How to use

Train

python train.py --root=data/edges2shoes --result_dir=result --weight_dir=weight

Test

Random sample

  • Most recent
    python test.py --sample_type=random --root=data/edges2shoes --result_dir=test --weight_dir=weight --img_num=5

  • Set epoch
    python test.py --sample_type=random --root=data/edges2shoes --result_dir=test --weight_dir=weight --img_num=5 --epoch=55

Interpolation

  • Most recent
    python test.py --sample_type=interpolation --root=data/edges2shoes --result_dir=test --weight_dir=weight --img_num=10

  • Set epoch
    python test.py --sample_type=interpolation --root=data/edges2shoes --result_dir=test --weight_dir=weight --img_num=10 --epoch=55
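As a rough picture of what linear interpolated sampling means, the sketch below produces evenly spaced latent codes on the straight line between two endpoints; each code would then be passed to the generator with the same input image. This is an assumption about the general technique, not a description of test.py: the function name `interpolate_z` is hypothetical, and whether --img_num corresponds to the number of interpolation steps is not confirmed here.

```python
def interpolate_z(z_start, z_end, img_num):
    """Return img_num latent codes evenly spaced from z_start to z_end."""
    if img_num < 2:
        return [list(z_start)]
    codes = []
    for i in range(img_num):
        t = i / (img_num - 1)            # t goes from 0.0 to 1.0 inclusive
        codes.append([(1 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return codes


codes = interpolate_z([0.0, 2.0], [1.0, 0.0], img_num=5)
for z in codes:
    print(z)
```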

Future work

  • Training with other datasets.
  • A new model using a conditional discriminator is currently being trained. Check Advanced-BicycleGAN.