Currently, a common method to enhance image classification involves expanding the training set with synthetic datasets generated by T2I models. Here, we propose an inter-class data augmentation method, Diff-Mix. Diff-Mix expands the dataset by conducting image translation in an inter-class manner, significantly improving the diversity of synthetic data. We observe an improved trade-off between faithfulness and diversity with Diff-Mix, resulting in a significant performance gain across various image classification settings, including few-shot classification, conventional classification, and long-tail classification, particularly for domain-specific datasets.
For convenience, well-structured datasets hosted on Hugging Face can be used. The fine-grained datasets CUB and Aircraft we experimented with can be downloaded from `Multimodal-Fatima/CUB_train` and `Multimodal-Fatima/FGVC_Aircraft_train`, respectively. If you encounter network connection problems during training, please pre-download the data from the website and specify the saved local path `HUG_LOCAL_IMAGE_TRAIN_DIR` in `semantic_aug/datasets/cub.py`.
We fine-tune both the textual tokens and the U-Net (with LoRA, via diffusers) of the pre-trained Stable Diffusion model to expedite the fine-tuning process.
To simplify usage, the concrete fine-tuning command is wrapped in the script `scripts/finetune.sh`. Distributed training is performed with the `accelerate` tool, and the GPUs should be specified via the environment variable `CUDA_VISIBLE_DEVICES`. The simplified command for fine-tuning on the full training set of CUB with a total of 35000 steps is:

```bash
source scripts/finetune.sh
# args: <dataset> <finetune mode> <shots (-1 = full training set)> <steps>
bash finetune 'cub' 'ti_db' -1 35000
```
To fine-tune in a 5-shot setting, modify the shell command to:

```bash
source scripts/finetune.sh
bash finetune 'cub' 'ti_db' 5 35000
```
The fine-tuned checkpoints will be saved under `outputs/finetune_model/finetune_ti_db{_5shot}/cub/`. After that, please manually add the meta-information of the checkpoints into `config/finetuned_ckpts.yaml`, structured in the following format:

```yaml
cub:
  ti_db_latest:
    model_path: "runwayml/stable-diffusion-v1-5"
    lora_path: "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/checkpoint-35000/pytorch_model.bin"
    embed_path: "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/learned_embeds-steps-35000.bin"
```

This structure allows you to locate the checkpoint paths simply by using the key pair `('cub', 'ti_db_latest')`.
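Once the YAML file is parsed (e.g. with `yaml.safe_load`), the lookup reduces to two nested dictionary accesses. A minimal sketch, using an inline dict in place of the loaded file and a hypothetical helper name:

```python
# Parsed form of config/finetuned_ckpts.yaml (inlined here for illustration;
# in practice it would come from yaml.safe_load on the file contents).
FINETUNED_CKPTS = {
    "cub": {
        "ti_db_latest": {
            "model_path": "runwayml/stable-diffusion-v1-5",
            "lora_path": "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/checkpoint-35000/pytorch_model.bin",
            "embed_path": "outputs/finetune_model/finetune_ti_db/sd-cub-model-lora-rank10/learned_embeds-steps-35000.bin",
        }
    }
}

def lookup_ckpt(dataset: str, key: str) -> dict:
    """Resolve checkpoint paths from the (dataset, key) pair."""
    return FINETUNED_CKPTS[dataset][key]
```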
Similarly, we wrap the sampling command details in the script `scripts/sample.sh`. To expedite the inference process, we use the `multiprocessing` tool to launch multiple inference processes. The desired processes should be specified via the environment variable `GPU_IDS`, where each item in the list denotes a process running on the GPU with that index.
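The `GPU_IDS` convention (e.g. `(0 0 0 1 1 1)` launches three processes on GPU 0 and three on GPU 1) can be sketched as follows. The helper names are hypothetical and only illustrate how the sampling work might be split across processes:

```python
import os
from multiprocessing import Process

def shard(work_items, n):
    """Round-robin split of work_items into n shards, one per process."""
    return [work_items[i::n] for i in range(n)]

def worker(gpu_id, items):
    # Pin this process to a single GPU before any CUDA initialization;
    # the diffusion sampling loop over `items` would run here.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

def launch(gpu_ids, work_items):
    """Start one process per entry in gpu_ids, e.g. [0, 0, 0, 1, 1, 1]."""
    procs = [Process(target=worker, args=(g, s))
             for g, s in zip(gpu_ids, shard(work_items, len(gpu_ids)))]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```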
The simplified command for sampling a synthetic subset with a translation strength of 0.7 is:

```bash
source scripts/sample.sh
export GPU_IDS=(0 0 0 1 1 1)
bash sample 'cub' 'ti_db_latest' 'diff-mix' 0.7
```
One can also construct the synthetic subset using other expansion strategies by replacing `diff-mix` with:

- `diff-aug` (Diff-Aug, fine-tuned intra-class translation method)
- `real-mix` (Real-Mix, pre-trained inter-class translation method)
- `real-guidance` (Real-Aug, pre-trained intra-class translation method)
To sample in a 5-shot setting, modify the shell command to:

```bash
source scripts/sample.sh
export GPU_IDS=(0 0 0 1 1 1)
bash sample_fewshot 5 'cub' '5shot_ti_db_latest' 'diff-mix' 0.7
```
The sampled subset will be cached at `outputs/aug_samples{_5shot}/cub`. After that, please manually add the meta-information of the subset into `synthetic_datasets.yaml`, structured in the following form:

```yaml
cub:
  diffmix_0.7: 'outputs/aug_samples/cub/diff-mix-Multi7-ti_db35000-Strength0.7'
  5shot_diffmix_0.7: 'outputs/aug_samples_5shot/cub/diff-mix-Multi7-ti_db35000-Strength0.7'
```

This allows you to locate the synthetic paths simply by using the key pair `('cub', 'diffmix_0.7')` in case there are multiple subsets.
After completing the sampling process, you can integrate the synthetic data into downstream classification and initiate training using the following commands:

```bash
source scripts/classification.sh
# main_cls {dataset_name} {gpu} {seed} {model} {resolution} {nepoch} {syndata_key} {gamma} {synthetic_prob}
main_cls 'cub' '0' 2020 'resnet50' '224' 120 'diffmix_0.7' 0.5 0.1
```
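As a rough illustration of how the `synthetic_prob` argument could act during training, the sketch below draws each training sample from the synthetic subset with probability `synthetic_prob` and from the real set otherwise. The function name is hypothetical, and the actual mixing logic (including how `gamma` weights the labels of translated images) lives in the project's training code:

```python
import random

def pick_training_sample(real_set, synthetic_set, synthetic_prob, rng=random):
    """Draw from the synthetic subset with probability `synthetic_prob`,
    otherwise from the real training set (hypothetical helper).
    Returns (sample, is_synthetic)."""
    if rng.random() < synthetic_prob:
        return rng.choice(synthetic_set), True   # synthetic sample
    return rng.choice(real_set), False           # real sample
```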
This project is built upon the repositories Da-fusion and diffusers. Special thanks to the contributors.