/PoDD

Official PyTorch Implementation for the "Distilling Datasets Into Less Than One Image" paper.

Primary LanguagePython

Distilling Datasets Into Less Than One Image

Official PyTorch Implementation for the "Distilling Datasets Into Less Than One Image" paper.

PWC PWC PWC PWC

🌐 Project | 📃 Paper

Poster Dataset Distillation (PoDD): We propose PoDD, a new dataset distillation setting for a tiny, under 1 image-per-class (IPC) budget. In this example, the standard method attains an accuracy of 35.5% on CIFAR-100 with approximately 100k pixels, PoDD achieves an accuracy of 35.7% with less than half the pixels (roughly 40k)

___

Distilling Datasets Into Less Than One Image
Asaf Shul*, Eliahu Horwitz*, Yedid Hoshen
*Equal contribution
https://arxiv.org/abs/2403.12040

Abstract: Dataset distillation aims to compress a dataset into a much smaller one so that a model trained on the distilled dataset achieves high accuracy. Current methods frame this as maximizing the distilled classification accuracy for a budget of K distilled images-per-class, where K is a positive integer. In this paper, we push the boundaries of dataset distillation, compressing the dataset into less than an image-per-class. It is important to realize that the meaningful quantity is not the number of distilled images-per-class but the number of distilled pixels-per-dataset. We therefore, propose Poster Dataset Distillation (PoDD), a new approach that distills the entire original dataset into a single poster. The poster approach motivates new technical solutions for creating training images and learnable labels. Our method can achieve comparable or better performance with less than an image-per-class compared to existing methods that use one image-per-class. Specifically, our method establishes a new state-of-the-art performance on CIFAR-10, CIFAR-100, and CUB200 using as little as 0.3 images-per-class.

Poster distillation progress over time followed by a semantic visualization of the distilled classes using a poster of CIFAR-10 with 1 IPC

Project Structure

This project consists of:

  • main.py - Main entry point (handles user run arguments).
  • src/base.py - Main worker for the distillation process.
  • src/PoDD.py - PoDD implementation using RaT-BPTT as the underlying dataset distillation algorithm.
  • src/PoCO.py - PoCO class ordering strategy implementation, using CLIP text embeddings.
  • src/PoDDL.py - PoDDL soft labeling strategy implementation.
  • src/PoDD_utils.py - Utility functions for PoDD.
  • src/data_utils.py - Utility functions for data handling.
  • src/util.py - General utility functions.
  • src/convnet.py - ConvNet model for the distillation process.

Installation

  1. Clone the repo:
git clone https://github.com/AsafShul/PoDD
cd PoDD
  1. Create a new environment with needed libraries from the environment.yml file, then activate it:
conda env create -f environment.yml
conda activate podd

Dataset Preparation

This implementation supports the following 4 datasets:

CIFAR-10 and CIFAR-100

Both the CIFAR-10 and CIFAR-100 datasets are built-in and will be downloaded automatically.

CUB200

  1. Download the data from here
  2. Extract the dataset into ./datasets/CUB200

Tiny ImageNet

  1. Download the dataset by running wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
  2. Extract the dataset into ./tiny-imagenet-200/tiny-imagenet-200
  3. Preprocess the validation split of the dataset to fit torchvision's ImageFolder structure. This can be done by running the function format_tiny_imagenet_val located in ./src/data_utils.py

Running PoDD

The main.py script is the main script in this project.

Below are examples for running PoDD on CIFAR-10, CIFAR100, CUB200 and Tiny ImageNet datasets for 0.9 IPC.

CIFAR-10

python main.py --name=PoDD-CIFAR10-LT1-90 --distill_batch_size=96 --patch_num_x=16 --patch_num_y=6 --dataset=cifar10 --num_train_eval=8 --update_steps=1 --batch_size=5000 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=10 --print_freq=10 --arch=convnet --window=60 --minwindow=0 --totwindow=200 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.001 --lr=0.001 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=0 --zca --comp_ipc=1 --class_area_width=32 --class_area_height=32 --poster_width=153 --poster_height=60 --poster_class_num_x=5 --poster_class_num_y=2

CIFAR-100

python main.py --name=PoDD-CIFAR100-LT1-90 --distill_batch_size=50 --patch_num_x=20 --patch_num_y=20 --dataset=cifar100 --num_train_eval=8 --update_steps=1 --batch_size=2000 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=10 --print_freq=10 --arch=convnet --window=100 --minwindow=0 --totwindow=300 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.001 --lr=0.001 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=0 --zca --comp_ipc=1 --class_area_width=32 --class_area_height=32 --poster_width=303 --poster_height=303 --poster_class_num_x=10 --poster_class_num_y=10 --train_y

CUB200

python main.py --name=PoDD-CUB200-LT1-90 --distill_batch_size=200 --patch_num_x=60 --patch_num_y=30 --dataset=cub-200 --num_train_eval=8 --update_steps=1 --batch_size=3000 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=25 --print_freq=10 --arch=convnet --window=60 --minwindow=0 --totwindow=200 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.001 --lr=0.001 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=1 --zca --comp_ipc=1 --class_area_width=32 --class_area_height=32 --poster_width=610 --poster_height=302 --poster_class_num_x=20 --poster_class_num_y=10 --train_y

Tiny ImageNet

python main.py --name=PoDD_TinyImageNet-LT1-90 --distill_batch_size=30 --patch_num_x=40 --patch_num_y=20 --dataset=tiny-imagenet-200 --num_train_eval=8 --update_steps=1 --batch_size=500 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=5 --print_freq=1 --arch=convnet --window=100 --minwindow=0 --totwindow=300 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.0005 --lr=0.0005 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=0 --zca --comp_ipc=1 --class_area_width=64 --class_area_height=64 --poster_width=1211 --poster_height=608 --poster_class_num_x=20 --poster_class_num_y=10 --train_y

Important Hyper-parameters

  • --patch_num_x and --patch_num_y - The number of extracted overlapping patches in the x and y axis of the poster.
  • --poster_width and --poster_height - The width and height of the poster (controls the distillation data budget).
  • --poster_class_num_x and --poster_class_num_y - The class layout dimensions within the poster as a 2d array (e.g., 10X10 or 20X5), (the product must be equal to the number of classes).
  • --train_y - If set, the model will also optimize a set of learnable labels for the poster.

Tip

Increase the distill_batch_size and batch_size as your GPU memory limitations allow.

Using PoDD with other Dataset Distillation Algorithms

Although we use RaT-BPTT as the underlying distillation algorithm, using PoDD with other dataset distillation algorithms should be straight forward. The main change is replacing the distillation functionality in src/base.py and src/PoDD.py with the desired distillation algorithm.

Citation

If you find this useful for your research, please use the following.

@article{shul2024distilling,
  title={Distilling Datasets Into Less Than One Image},
  author={Shul, Asaf and Horwitz, Eliahu and Hoshen, Yedid},
  journal={arXiv preprint arXiv:2403.12040},
  year={2024}
}

Acknowledgments

  • This repo uses RaT-BPTT as the underlying distillation algorithm, the implementation of RaT-BPTT is based on the following code found in their supplementary materials.