D2A2-GAN

Discriminator-Driven Attention-Aware GAN (D2A2-GAN)
識別器の注視領域に着目して画像を生成するGAN.
GeneratorはDCGAN (Deep Convolutional GAN)とほぼ同様に構築した.
Discminatorは最終層をAttention branchとAdversarial branchに分割した.
Attention branchはABN (Attention Branch Network)を参考に導入しており,入力画像に対するAttention mapの生成及びクラス分類を行う.
Adversarial branchは,Attention機構を用いてAttention mapを特徴マップへ反映して敵対的な誤差を出力する. 正確なAttention mapを得るための工夫としてABNの出力するAttention mapを教師信号としてconsistency lossを計算している.

  • main.py
    D2A2-GANを動かすためのメインのソースコード.

  • utils.py
    Test時に生成した画像及びAttention mapをTesnorboardに書き込むためのソースコード.

  • nets/Generator.py
    Generatorのネットワークが記述してある.

  • nets/Discriminator.py
    Disciminatorのネットワークが記述してある.

  • ABN
    ABNのディレクトリの中にはAttention branch networkが含まれている.

生成した画像,Attention mapや誤差は,tensorboardに書き込むように作成している.

Useage

  • cifar10
python3 main.py --training_data_name cifar10 --epoch 100 --gpu 0
  • svhn
python3 main.py --training_data_name svhn --epoch 100 --gpu 0

generating_images_for_data_augmentation.pyを動かすとクラスラベルと生成画像をディレクトリに保存する. クラスラベルは,onehot及びAttention branchの出力にsoftmax関数を施したものを保存する.

Requirement

  • python3
  • pytorch ver1.0.0以上
  • tqdm