
"Contrastive Learning for Unpaired Image-to-Image Translation" in TensorFlow 2


Contrastive Unpaired Translation (CUT)

This is an implementation of Contrastive Learning for Unpaired Image-to-Image Translation in TensorFlow 2.

Contrastive Unpaired Translation (CUT) uses a framework based on contrastive learning. The goal is to associate corresponding input and output patches: the "query" is an output patch, the positive is the corresponding input patch, and the negatives are non-corresponding input patches. Compared to CycleGAN, CUT enables one-sided translation while improving quality and reducing training time.
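To make the query/positive/negative relationship concrete, here is a minimal, framework-free sketch of a patch-wise contrastive (InfoNCE-style) loss for a single query. It is not the code used in this repository: the function name `patch_nce_loss`, the plain-list feature vectors, and the temperature value of 0.07 are assumptions chosen for illustration.

```python
import math


def patch_nce_loss(query, positive, negatives, temperature=0.07):
    """Contrastive loss for one output patch (illustrative sketch).

    query:     feature vector of an output patch
    positive:  feature vector of the corresponding input patch
    negatives: feature vectors of non-corresponding input patches
    temperature: assumed scaling factor, not taken from this repository
    """

    def normalize(v):
        # Scale to unit length so dot products are cosine similarities.
        norm = math.sqrt(sum(x * x for x in v)) or 1.0
        return [x / norm for x in v]

    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    q = normalize(query)
    # Index 0 holds the positive logit; the rest are negative logits.
    logits = [dot(q, normalize(positive)) / temperature]
    logits += [dot(q, normalize(n)) / temperature for n in negatives]

    # Softmax cross-entropy with the positive as the target class,
    # computed with log-sum-exp for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_sum_exp - logits[0]


if __name__ == "__main__":
    negatives = [[0.0, 1.0], [0.0, -1.0]]
    # Query matches its positive: the loss is close to zero.
    print(patch_nce_loss([1.0, 0.0], [1.0, 0.0], negatives))
    # Query matches a negative instead: the loss is large.
    print(patch_nce_loss([1.0, 0.0], [0.0, 1.0], [[1.0, 0.0], [0.0, -1.0]]))
```

The loss is small when the query is more similar to its positive than to any negative, which is the association the training objective encourages.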

Translated examples of summer2winter

Training

Use train.py to train a CUT/FastCUT model on a given dataset. A single training step takes about 340 ms (CUDA ops) or 400 ms (TensorFlow ops) on a GTX 1080 Ti.

Example usage for training on the horse2zebra dataset:

python train.py --mode cut                                    \
                --save_n_epoch 10                             \
                --train_src_dir ./datasets/horse2zebra/trainA \
                --train_tar_dir ./datasets/horse2zebra/trainB \
                --test_src_dir ./datasets/horse2zebra/testA   \
                --test_tar_dir ./datasets/horse2zebra/testB
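Before starting a long training run, it can help to confirm that the four directories passed above exist and contain images. The following is a small sketch, not part of this repository: the helper name `count_images` and the set of accepted file extensions are assumptions. Because the translation is unpaired, the source and target folders do not need to hold the same number of images.

```python
from pathlib import Path

# Assumed image formats; adjust to match your dataset.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def count_images(dataset_root):
    """Count image files in trainA, trainB, testA and testB (hypothetical helper)."""
    root = Path(dataset_root)
    counts = {}
    for name in ("trainA", "trainB", "testA", "testB"):
        folder = root / name
        if not folder.is_dir():
            raise FileNotFoundError(f"missing directory: {folder}")
        # Count only regular files with a recognised image extension.
        counts[name] = sum(
            1
            for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts
```

For example, `count_images("./datasets/horse2zebra")` returns a dictionary mapping each subdirectory name to its image count, or raises `FileNotFoundError` if one of the folders is missing.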

Inference

Use inference.py to translate images from the source domain to the target domain. The pre-trained weights are located here.

Example usage:

python inference.py --mode cut                            \
                    --weights ./output/checkpoints        \
                    --input ./datasets/horse2zebra/testA

Qualitative comparisons between the implementation and the results from the paper.

Requirements

You will need the following to run the above:

  • TensorFlow >= 2.0
  • Python 3, NumPy 1.18, Matplotlib 3.3.1
  • If you want to use custom TensorFlow ops:

Acknowledgements