Diversity is Definitely Needed: Improving Model-Agnostic Zero-shot Classification via Stable Diffusion (CVPRW)
Jordan Shipard1, Arnold Wiliem1,2, Kien Nguyen Thanh1, Wei Xiang3, Clinton Fookes1
1Signal Processing, Artificial Intelligence and Vision Technologies (SAIVT), Queensland University of Technology, Australia
2Sentient Vision Systems, Australia
3School of Computing, Engineering and Mathematical Sciences, La Trobe University, Australia
Accepted to the Generative Models for Computer Vision Workshop at CVPR 2023
For training and testing
- Python 3.8+
- PyTorch 1.13.0
- torchvision 0.14.0
For dataset generation
- 'ldm' environment from stable-diffusion
Datasets are hosted on Zenodo with the download links provided in the table below.
Dataset | Download |
---|---|
CIFAR10 Base Class | cifar10_generated_32A.tar.gz |
CIFAR10 Class Prompt | cifar10_generated_class_prompt_32A.tar.gz |
CIFAR10 Multi-Domain | cifar10_generated_multidomain_32A.tar.gz |
CIFAR10 Random Guidance | cifar10_generated_random_scale_32A.tar.gz |
CIFAR10 Merged | cifar10_generated_merged_32A.tar.gz |
CIFAR100 Base Class | cifar100_generated_32A.tar.gz |
CIFAR100 Multi-Domain | cifar100_generated_multidomain_32A.tar.gz |
CIFAR100 Random Scale | cifar100_generated_random_scale_32A.tar.gz |
CIFAR100 Merged | Cifar100_generated_merged_32A.tar.gz |
EuroSAT Base Class | EuroSat_generated_64.tar.gz |
EuroSAT Random Scale | EuroSat_generated_random_scale_64.tar.gz |
EuroSAT Merged | EuroSat_generated_merged_64.tar.gz |
These are the exact generated synthetic datasets and images used to train the networks in the paper. All datasets were generated using Stable Diffusion V1.4. '32A' refers to an image size of 32x32 pixels, resized from 512x512 with anti-aliasing; '64' refers to 64x64, resized from 512x512 without anti-aliasing. Only the datasets that improve performance over the base class (i.e., the best tricks) are currently hosted. If you would like any of the other datasets from the paper, either raise an issue or email me at jordan.shipard@hdr.qut.edu.au.
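The two resize conventions above can be reproduced with Pillow. Below is a minimal sketch; the exact resampling filters used by the authors are an assumption inferred from the "with/without anti-aliasing" description:

```python
from PIL import Image

def downsample(img: Image.Image, size: int, anti_alias: bool = True) -> Image.Image:
    """Resize a generated 512x512 image down to size x size.
    LANCZOS applies anti-aliasing (assumed for the '32A' datasets);
    NEAREST does not (assumed for the '64' EuroSAT datasets)."""
    resample = Image.LANCZOS if anti_alias else Image.NEAREST
    return img.resize((size, size), resample=resample)

# Example with a dummy 512x512 RGB image
src = Image.new("RGB", (512, 512), color=(120, 60, 200))
small = downsample(src, 32)
print(small.size)  # (32, 32)
```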
You can generate your own synthetic datasets using one of the tricks with the `create_dataset.py` file. First, ensure the file is located in the same directory as your Stable Diffusion repository, as it will attempt to run `scripts/txt2img.py`.

e.g. `python create_dataset.py --classes dog cat --trick class_prompt --outdir synthetic_datasets/cats_and_dogs`
This file has the following arguments:
- `--classes` A list of the class labels used in generating images.
- `--trick` The specific trick to use when generating the dataset. Limited to `"class_prompt"`, `"multidomain"`, `"random_scale"`.
- `--outdir` The directory to save the generated images in.
- `--domains` A list of domains to use when generating the synthetic images.
- `--min_scale` The minimum possible value for the unconditional random guidance. Default 1.
- `--max_scale` The maximum possible value for the unconditional random guidance. Default 5.
- `--n_samples` The number of images to produce in a single round of generation. Default 2.
- `--n_iter` The number of iterations of producing `n_samples` images to run. Default 1000.
- `--ddim_steps` The number of DDIM sampling steps. Default 40.
- `--seed` The seed (for reproducible sampling). Default 64.
- `--H` The height of the images to generate. Default 512.
- `--W` The width of the images to generate. Default 512.
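To illustrate what the tricks vary at generation time, here is a small sketch. The prompt templates and per-batch guidance sampling are assumptions for illustration, not the repository's verbatim code:

```python
import random
from typing import Optional

def build_prompt(cls: str, trick: str, domain: Optional[str] = None) -> str:
    """Assumed prompt templates for the generation tricks."""
    if trick == "class_prompt":
        return cls                       # prompt is just the class name
    if trick == "multidomain":
        return f"a {domain} of a {cls}"  # e.g. "a painting of a dog"
    return f"a photo of a {cls}"         # base-class style prompt

def sample_guidance(min_scale: float = 1.0, max_scale: float = 5.0) -> float:
    """random_scale trick: draw an unconditional guidance scale at random."""
    return random.uniform(min_scale, max_scale)

print(build_prompt("dog", "multidomain", domain="painting"))  # a painting of a dog
scale = sample_guidance()
assert 1.0 <= scale <= 5.0
```

Varying the prompt domain and the guidance scale are the paper's mechanisms for increasing the diversity of the generated training images.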
You can train a network on a synthetic dataset while testing it on a real dataset using `train_network.py`.

e.g. `python train_network.py --model Vit-B --epoch 10 --batch_size 32 --dataset cifar100_generated_32A --syn_data_location data/synthetic_cifar10 --real_data_location data/real_cifar10`
- `--model` The image classification model to train. Limited to `MBV3, Vit-B, Vit-S, RS50, RS101, convnext, convnext-s`.
- `--epoch` Default 50. The number of epochs to train for.
- `--batch_Size` Default 64. The batch size to use for training and testing.
- `--dataset` The name of the synthetic dataset to use, e.g. `cifar100_generated_32A`. NOTE: The dataset needs to already be extracted from its `.tar.gz` compressed version if downloaded from one of the above links.
- `--img_size` Default 32. The image size of the dataset; can be used to resize the dataset.
- `--lr` Default 1e-4. The initial learning rate.
- `--wd` Default 0.9. Weight decay used in training.
- `--model_path` Optional. Path to a PyTorch model; the script will load the weights from this path.
- `--wandb` Default False. Flags the use of the wandb logger. NOTE: Please check the init settings for the `logger` variable inside the training script if you wish to use the wandb logger.
- `--syn_data_location` The location of the synthetic dataset.
- `--real_data_location` The location of the real dataset.
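A common convention for image-classification datasets on disk, and an assumption here since the loading code is not shown, is a torchvision `ImageFolder`-style layout with one sub-directory per class. This sketch builds such a layout for a toy synthetic dataset:

```python
import os
from PIL import Image

def make_class_folders(root: str, classes, n_per_class: int = 2, size: int = 32) -> None:
    """Create an ImageFolder-style layout: root/<class>/<idx>.png.
    Whether train_network.py expects exactly this layout is an assumption."""
    for cls in classes:
        cls_dir = os.path.join(root, cls)
        os.makedirs(cls_dir, exist_ok=True)
        for i in range(n_per_class):
            # Placeholder images standing in for Stable Diffusion outputs
            Image.new("RGB", (size, size)).save(os.path.join(cls_dir, f"{i}.png"))

make_class_folders("data/synthetic_cats_and_dogs", ["cat", "dog"])
```

The resulting root directory would then be passed as `--syn_data_location`.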
You can test networks using `eval_network.py`.

e.g. `python eval_network.py --model Vit-B --model_path saved_models/trained_model.pt --dataset cifar10 --real_data_location data/real_cifar10`
- `--model` The image classification model to evaluate. Limited to `MBV3, Vit-B, Vit-S, RS50, RS101, convnext, convnext-s`.
- `--batch_Size` Default 64. The batch size to use for testing.
- `--dataset` The name of the real dataset to use, e.g. `cifar100`.
- `--img_size` Default 32. The image size of the dataset; can be used to resize the dataset.
- `--real_data_location` The location of the real dataset.
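The evaluation script reports classification accuracy on the real test set. For reference, top-1 accuracy over predicted and ground-truth labels is simply:

```python
def top1_accuracy(preds, labels) -> float:
    """Fraction of predictions that match the ground-truth labels."""
    assert len(preds) == len(labels) and len(labels) > 0
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(labels)

print(top1_accuracy([0, 1, 2, 2], [0, 1, 1, 2]))  # 0.75
```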
This work has been supported by the SmartSat CRC, whose activities are funded by the Australian Government’s CRC Program; and partly supported by Sentient Vision Systems. Sentient Vision Systems is one of the leading Australian developers of computer vision and artificial intelligence software solutions for defence and civilian applications.
@inproceedings{
shipard2023DDN,
title={Diversity is Definitely Needed: Improving Model-Agnostic Zero-shot Classification via Stable Diffusion},
author={Jordan Shipard and Arnold Wiliem and Kien Nguyen Thanh and Wei Xiang and Clinton Fookes},
booktitle={Computer Vision and Pattern Recognition Workshop on Generative Models for Computer Vision},
year={2023},
url={https://arxiv.org/abs/2302.03298}
}