
An implementation of Attribute Group Editing for Reliable Few-shot Image Generation (CVPR 2022)

Primary language: Python. License: MIT.

Attribute Group Editing for Reliable Few-shot Image Generation (CVPR 2022)

In this work, we propose a new “editing-based” method for few-shot image generation, Attribute Group Editing (AGE). The basic assumption is that any image is a collection of attributes and that the editing direction for a specific attribute is shared across all categories. AGE examines the internal representation learned by GANs and identifies semantically meaningful directions. Specifically, the class embedding, i.e., the mean vector of the latent codes from a specific category, is used to represent the category-relevant attributes, while the category-irrelevant attributes are learned globally by Sparse Dictionary Learning on the difference between the sample embedding and the class embedding. Given a GAN well trained on seen categories, diverse images of unseen categories can be synthesized by editing category-irrelevant attributes while keeping category-relevant attributes unchanged. Without re-training the GAN, AGE not only produces more realistic and diverse images for downstream visual applications with limited data but also achieves controllable image editing with interpretable category-irrelevant directions.
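The decomposition described above can be sketched in a few lines of NumPy. This is a minimal illustration under our own assumptions, not the repository's implementation: latent codes are flattened to plain vectors, the dictionary `A` is a random stand-in for the learned one (its columns play the role of the atoms counted by `--A_length`), and the sparse coefficients are found with ISTA, a standard solver for L1-regularized least squares, rather than with whatever optimizer the training code uses. All function names are hypothetical.

```python
import numpy as np


def class_embedding(latents: np.ndarray) -> np.ndarray:
    """Mean latent code of one category (the category-relevant attributes)."""
    return latents.mean(axis=0)


def soft_threshold(x: np.ndarray, t: float) -> np.ndarray:
    """Proximal operator of t * ||x||_1."""
    return np.sign(x) * np.maximum(np.abs(x) - t, 0.0)


def sparse_code(residual: np.ndarray, A: np.ndarray, lam: float = 0.1,
                n_iter: int = 200) -> np.ndarray:
    """ISTA for min_n 0.5 * ||residual - A @ n||^2 + lam * ||n||_1."""
    step = 1.0 / np.linalg.norm(A, 2) ** 2  # 1 / Lipschitz constant of the gradient
    n = np.zeros(A.shape[1])
    for _ in range(n_iter):
        grad = A.T @ (A @ n - residual)
        n = soft_threshold(n - step * grad, step * lam)
    return n


def edit(class_emb: np.ndarray, A: np.ndarray, n: np.ndarray) -> np.ndarray:
    """Keep the category-relevant part fixed; set the category-irrelevant part to A @ n."""
    return class_emb + A @ n


# Toy example (hypothetical sizes): 8 samples of one category, 16-dim latents, 32 atoms.
rng = np.random.default_rng(0)
latents = rng.normal(size=(8, 16))
A = rng.normal(size=(16, 32))
A /= np.linalg.norm(A, axis=0)          # unit-norm atoms

w_bar = class_embedding(latents)
n = sparse_code(latents[0] - w_bar, A)  # sparse coefficients of sample 0
n_new = 1.5 * n                         # an arbitrary edit: amplify the irrelevant attributes
w_new = edit(w_bar, A, n_new)           # new latent code for the same category
print(np.count_nonzero(n), w_new.shape)
```

In AGE proper the dictionary is learned jointly over all seen categories, so the same atoms can be reused to edit the class embedding of an unseen category.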


Comparison between images generated by MatchingGAN, LoFGAN, and AGE on Flowers, Animal Faces, and VGGFaces.

Description

Official implementation of AGE for few-shot image generation. Our code is modified from pSp.

Getting Started

Prerequisites

  • Linux
  • NVIDIA GPU + CUDA and cuDNN (CPU may be possible with some modifications, but is not supported out of the box)
  • Python 3
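A quick way to check these prerequisites from Python is sketched below. It is a hypothetical helper, not part of the repository; it uses only the standard library, plus `torch.cuda.is_available()` and `torch.backends.cudnn.is_available()` when PyTorch is installed.

```python
import importlib.util
import sys


def check_prerequisites() -> dict:
    """Report whether the listed prerequisites appear to be met (hypothetical helper)."""
    report = {
        "linux": sys.platform.startswith("linux"),
        "python3": sys.version_info.major == 3,
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "cuda_available": False,
        "cudnn_available": False,
    }
    if report["torch_installed"]:
        import torch  # imported lazily so the check also runs without PyTorch

        report["cuda_available"] = bool(torch.cuda.is_available())
        report["cudnn_available"] = bool(torch.backends.cudnn.is_available())
    return report


if __name__ == "__main__":
    for name, ok in check_prerequisites().items():
        print(f"{name}: {'ok' if ok else 'missing'}")
```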

Installation

  • Clone this repo:
git clone https://github.com/UniBester/AGE.git
cd AGE
  • Dependencies:
    We recommend running this repository using Anaconda. All dependencies for defining the environment are provided in environment/environment.yaml, which can be used to create the environment with conda env create -f environment/environment.yaml.

Pretrained pSp

Here, we use pSp to find the latent codes of real images in the latent space of a pretrained StyleGAN generator. Follow the pSp instructions to train a pSp model first, or directly download the pretrained pSp models we provide.
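pSp encodes each image into a W+ code: one style vector per style input of the generator. In rosinality's StyleGAN2, which this repository bundles under models/stylegan2, a generator of resolution R has 2 * log2(R) - 2 style inputs of dimension 512 by default. The helper below (hypothetical name) computes that shape; the resolution of the provided checkpoints is not stated here, so the values in the usage lines are illustrative only.

```python
import math


def wplus_shape(resolution: int, style_dim: int = 512) -> tuple:
    """Shape of a W+ latent code for a StyleGAN2 generator of the given resolution."""
    log_size = int(math.log2(resolution))
    if 2 ** log_size != resolution:
        raise ValueError("resolution must be a power of two")
    # Same count as rosinality's Generator.n_latent = log_size * 2 - 2.
    return (log_size * 2 - 2, style_dim)


print(wplus_shape(256))   # shape for a 256x256 generator
print(wplus_shape(1024))  # shape for a 1024x1024 generator
```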

Training

Preparing your Data

  • First download the Animal Faces / Flowers / VGGFaces datasets and organize the file structure as follows:

    └── data_root
        ├── train                      
        |   ├── cate-id_sample-id.jpg                # train-img
        |   └── ...                                  # ...
        └── valid                      
            ├── cate-id_sample-id.jpg                # valid-img
            └── ...                                  # ...
    

    Here, we provide the organized Animal Faces dataset as an example:

    └── data_root
        ├── train
        |   ├── n02085620_25.JPEG_238_24_392_167.jpg
        |   └── ...
        └── valid
            ├── n02093754_14.JPEG_80_18_239_163.jpg
            └── ...
    
  • Currently, we provide configurations for several datasets:

    • Refer to configs/paths_config.py to define the necessary data paths and model paths for training and evaluation.
    • Refer to configs/transforms_config.py for the transforms defined for each dataset.
    • Finally, refer to configs/data_configs.py for the data paths for the train and valid sets as well as the transforms.
  • If you wish to experiment with your own dataset, make the necessary adjustments in:

    1. configs/data_configs.py to define your data paths.
    2. configs/transforms_config.py to define your own data transforms.
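Because the category is encoded in the file name (cate-id_sample-id.jpg), the samples of each category can be recovered by splitting on the first underscore. A minimal sketch, assuming category ids contain no underscore, as in the Animal Faces example (n02085620); the helper names are hypothetical, and the repository's dataset classes may parse names differently.

```python
from collections import defaultdict
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def category_of(filename: str) -> str:
    """Category id = everything before the first underscore of the file name."""
    return Path(filename).name.split("_", 1)[0]


def group_by_category(filenames) -> dict:
    """Map each category id to the sorted list of its file names."""
    groups = defaultdict(list)
    for name in sorted(filenames):
        groups[category_of(name)].append(name)
    return dict(groups)


def list_split(split_dir: str) -> dict:
    """Group the images of one split directory (e.g. data_root/train) by category."""
    names = [p.name for p in Path(split_dir).iterdir() if p.suffix.lower() in IMAGE_EXTS]
    return group_by_category(names)


example = [
    "n02085620_25.JPEG_238_24_392_167.jpg",
    "n02093754_14.JPEG_80_18_239_163.jpg",
    "n02085620_26.jpg",  # made-up name, for illustration only
]
print(group_by_category(example))
```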

Get Class Embedding

To train AGE, the class embedding of each category in both the train and test splits must first be computed with tools/get_class_embedding.py:

python tools/get_class_embedding.py \
--class_embedding_path=/path/to/save/class/embeddings \
--psp_checkpoint_path=/path/to/pretrained/pSp/checkpoint \
--train_data_path=/path/to/training/data \
--test_batch_size=4 \
--test_workers=4
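Conceptually, this step averages the pSp latent codes of all images that share a category id. The sketch below assumes the latent codes have already been computed (here as an (N, n_styles, 512) array) and that categories are read from file names as in the data layout above; the function name is hypothetical, and the format in which the script writes to --class_embedding_path is not described here.

```python
import os

import numpy as np


def class_embeddings(latents, filenames) -> dict:
    """Average the latent codes of each category (category id = prefix before '_')."""
    latents = np.asarray(latents)
    cats = np.array([os.path.basename(f).split("_", 1)[0] for f in filenames])
    return {str(c): latents[cats == c].mean(axis=0) for c in np.unique(cats)}


# Toy data: 4 images from 2 categories, W+ codes with 14 styles of 512 dims.
rng = np.random.default_rng(0)
codes = rng.normal(size=(4, 14, 512))
names = ["n02085620_1.jpg", "n02085620_2.jpg", "n02093754_1.jpg", "n02093754_2.jpg"]
emb = class_embeddings(codes, names)
print({k: v.shape for k, v in emb.items()})
```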

Training AGE

The main training script can be found in tools/train.py.
Intermediate training results, including checkpoints, train outputs, and test outputs, are saved to opts.exp_dir.
If you have TensorBoard installed, you can also visualize the logs in opts.exp_dir/logs.

Training the AGE Model

# Set the GPUs to use.
export CUDA_VISIBLE_DEVICES=0,1,2,3

# Begin training.
python -m torch.distributed.launch \
--nproc_per_node=4 \
tools/train.py \
--dataset_type=af_encode \
--exp_dir=/path/to/experiment/output \
--workers=8 \
--batch_size=8 \
--valid_batch_size=8 \
--valid_workers=8 \
--val_interval=2500 \
--save_interval=5000 \
--start_from_latent_avg \
--l2_lambda=1 \
--sparse_lambda=0.005 \
--orthogonal_lambda=0.0005 \
--A_length=100 \
--psp_checkpoint_path=/path/to/pretrained/pSp/checkpoint \
--class_embedding_path=/path/to/class/embeddings 
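The three weights in this command suggest a training objective of the form l2_lambda * L_rec + sparse_lambda * L_sparse + orthogonal_lambda * L_orth over a dictionary with A_length atoms. The NumPy sketch below illustrates that weighted sum under our own assumptions: a mean-squared reconstruction of the residual latent code, a mean absolute value penalty on the coefficients, and a squared Frobenius penalty ||A^T A - I||^2 on the dictionary. The exact definitions in models/age.py and criteria/ may differ; the function name is hypothetical.

```python
import numpy as np


def weighted_objective(residual, A, n, l2_lambda=1.0, sparse_lambda=0.005,
                       orthogonal_lambda=0.0005):
    """Hypothetical AGE-style objective.

    residual: (d,) sample latent code minus its class embedding
    A:        (d, k) dictionary whose k columns (k = A_length) are the atoms
    n:        (k,) coefficients predicted for this sample
    """
    l2 = np.mean((residual - A @ n) ** 2)                  # reconstruct the residual
    sparse = np.mean(np.abs(n))                            # encourage sparse coefficients
    gram = A.T @ A
    orthogonal = np.sum((gram - np.eye(A.shape[1])) ** 2)  # push atoms toward orthonormality
    total = l2_lambda * l2 + sparse_lambda * sparse + orthogonal_lambda * orthogonal
    return total, {"l2": l2, "sparse": sparse, "orthogonal": orthogonal}


rng = np.random.default_rng(0)
A = rng.normal(size=(512, 100)) / np.sqrt(512)       # roughly unit-norm atoms
n = rng.normal(size=100) * (rng.random(100) < 0.1)   # about 10% non-zero coefficients
residual = A @ n + 0.01 * rng.normal(size=512)
total, parts = weighted_objective(residual, A, n)
print(total, parts)
```

With the weights above, the orthogonality term is down-weighted relative to sparsity, which in turn is far smaller than the reconstruction weight.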

Testing

Inference

After training your model, or using the pretrained models we provide, you can use tools/inference.py to apply the model to a set of images.
For example:

python tools/inference.py \
--output_path=/path/to/output \
--checkpoint_path=/path/to/checkpoint \
--test_data_path=/path/to/test/input \
--train_data_path=/path/to/training/data \
--class_embedding_path=/path/to/class/embeddings \
--n_distribution_path=/path/to/save/n/distribution \
--test_batch_size=4 \
--test_workers=4 \
--n_images=5 \
--alpha=1 \
--beta=0.005
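The flags above suggest that new images are generated by sampling category-irrelevant coefficients from a distribution stored at --n_distribution_path and adding the resulting edit to the class embedding of a test category. The sketch below is one plausible reading under stated assumptions: the distribution is a per-atom Gaussian, alpha scales the edit, and beta is treated as a hypothetical threshold below which sampled coefficients are zeroed. The actual roles of --alpha and --beta in tools/inference.py are not documented here, and all names are hypothetical. The resulting latent codes would then be decoded by the StyleGAN generator.

```python
import numpy as np


def fit_coefficient_distribution(codes):
    """Per-atom Gaussian over training coefficients; codes has shape (N, k)."""
    codes = np.asarray(codes)
    return codes.mean(axis=0), codes.std(axis=0)


def sample_latents(class_emb, A, mean, std, n_images=5, alpha=1.0, beta=0.005, seed=None):
    """Sample n_images latent codes for one (possibly unseen) category.

    Assumption: alpha scales the edit, and sampled coefficients with magnitude
    below beta are zeroed to keep the edit sparse.
    """
    rng = np.random.default_rng(seed)
    n = rng.normal(mean, std, size=(n_images, A.shape[1]))
    n[np.abs(n) < beta] = 0.0
    return class_emb + alpha * (n @ A.T)


rng = np.random.default_rng(0)
A = rng.normal(size=(512, 100)) / np.sqrt(512)  # stand-in for the learned dictionary
train_codes = rng.normal(size=(200, 100)) * (rng.random((200, 100)) < 0.1)
mean, std = fit_coefficient_distribution(train_codes)
w_test = rng.normal(size=512)  # stands in for an unseen category's class embedding
w_new = sample_latents(w_test, A, mean, std, n_images=5, seed=1)
print(w_new.shape)
```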

Repository structure

Path                 Description
AGE                  Repository root folder
├── configs          Folder containing configs defining model/data paths and data transforms
├── criteria         Folder containing various loss criteria for training
├── datasets         Folder with various dataset objects and augmentations
├── environment      Folder containing the Anaconda environment used in our experiments
├── models           Folder containing all the models and training objects
│   ├── encoders     Folder containing our pSp encoder architecture implementation and the ArcFace encoder implementation from TreB1eN
│   ├── stylegan2    StyleGAN2 model from rosinality
│   └── age.py       Implementation of our AGE
├── options          Folder with training and test command-line options
├── tools            Folder with running scripts for training and inference
├── optimizer        Folder with the Ranger implementation from lessw2020
└── utils            Folder with various utility functions

Citation

If you use this code for your research, please cite our paper Attribute Group Editing for Reliable Few-shot Image Generation:

@inproceedings{ding2022attribute,
  title={Attribute Group Editing for Reliable Few-shot Image Generation},
  author={Ding, Guanqi and Han, Xinzhe and Wang, Shuhui and Wu, Shuzhe and Jin, Xin and Tu, Dandan and Huang, Qingming},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022},
}