To achieve lossless dataset distillation, an intuitive idea is to increase the size of the synthetic dataset. However, previous dataset distillation methods tend to perform worse than random selection as IPC (images per class, which determines the data keep ratio) increases.
We find that the key to addressing this issue is to align the difficulty of the generated patterns with the size of the synthetic dataset, avoiding patterns that are too easy or too difficult for that size.
By doing so, our method remains effective at high IPC and achieves lossless dataset distillation for the first time.

What do easy patterns and hard patterns look like?
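One way to think about where easy and hard patterns come from: the code builds on trajectory matching (MTT), where the synthetic set is optimized to reproduce segments of expert training trajectories. A minimal sketch, assuming that pattern difficulty is controlled by which expert epochs the matched segments start from (earlier epochs for easier patterns, later epochs for harder ones), is an IPC-dependent sampling window. The `WINDOWS` table, its thresholds, and both function names are hypothetical placeholders, not values or APIs from the DATM configs.

```python
import random
from typing import List, Tuple

# Hypothetical (max_ipc, (min_start_epoch, max_start_epoch)) table.
# Illustrative placeholders only; NOT the values used in the DATM configs.
WINDOWS: List[Tuple[int, Tuple[int, int]]] = [
    (10, (0, 10)),      # small synthetic set: start from early (easier) epochs
    (100, (10, 30)),    # medium synthetic set
    (10**9, (30, 60)),  # large synthetic set: start from later (harder) epochs
]


def start_epoch_window(ipc: int) -> Tuple[int, int]:
    """Return the inclusive window of expert epochs to start matching from."""
    for max_ipc, window in WINDOWS:
        if ipc <= max_ipc:
            return window
    raise ValueError(f"no window configured for ipc={ipc}")


def sample_start_epoch(ipc: int, rng: random.Random) -> int:
    """Sample the expert epoch at which a matched trajectory segment begins."""
    lo, hi = start_epoch_window(ipc)
    return rng.randint(lo, hi)  # randint is inclusive on both ends


if __name__ == "__main__":
    rng = random.Random(0)
    for ipc in (1, 50, 1000):
        print(ipc, start_epoch_window(ipc), sample_start_epoch(ipc, rng))
```

A lookup table keeps the alignment explicit: the only thing that changes with IPC is which part of the expert trajectories the synthetic data is asked to imitate.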
16 May: The DATM_with_TESLA implementation has been merged. Thanks to Yue XU for the PR!
- Create the environment as follows:
```bash
conda env create -f environment.yaml
conda activate distillation
```
- Generate expert trajectories:
```bash
cd buffer
python buffer_FTD.py --dataset=CIFAR10 --model=ConvNet --train_epochs=100 --num_experts=100 --zca --buffer_path=../buffer_storage/ --data_path=../dataset/ --rho_max=0.01 --rho_min=0.01 --alpha=0.3 --lr_teacher=0.01 --mom=0. --batch_train=256
```
- Perform the distillation (`xxxx.yaml` is a placeholder for a config file in `../configs/`):
```bash
cd distill
python DATM.py --cfg ../configs/xxxx.yaml
```
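The two steps above follow the trajectory-matching recipe the code builds on (MTT): first record expert checkpoints while training on real data, then unroll a student from one checkpoint on the synthetic set and compare its endpoint with a later checkpoint. Below is a minimal NumPy sketch on a toy linear-regression model, assuming plain SGD for the expert (`buffer_FTD.py` presumably adds FTD's flatness regularization, given its `--rho_max`/`--rho_min` flags; that is omitted here) and MTT's normalized distance $\lVert\hat\theta_{t+N}-\theta^*_{t+M}\rVert^2 / \lVert\theta^*_t-\theta^*_{t+M}\rVert^2$. All function names are hypothetical. In the real code this loss is backpropagated through the unrolled student steps to update the synthetic images and learning rate; the sketch only evaluates it.

```python
import numpy as np


def record_expert_trajectory(X, y, epochs=5, lr=0.1, batch_size=32, seed=0):
    """Train a toy linear regressor with minibatch SGD on real data and
    snapshot its weights: returns epochs + 1 checkpoints (initial + per epoch)."""
    rng = np.random.default_rng(seed)
    w = rng.normal(scale=0.01, size=X.shape[1])
    trajectory = [w.copy()]
    n = len(X)
    for _ in range(epochs):
        order = rng.permutation(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            xb, yb = X[idx], y[idx]
            grad = 2.0 * xb.T @ (xb @ w - yb) / len(idx)  # MSE gradient
            w -= lr * grad
        trajectory.append(w.copy())
    return trajectory


def unroll_student(theta_start, X_syn, y_syn, syn_lr, n_steps):
    """Run n_steps of full-batch gradient descent on the synthetic set,
    starting from an expert checkpoint."""
    theta = theta_start.copy()
    for _ in range(n_steps):
        grad = 2.0 * X_syn.T @ (X_syn @ theta - y_syn) / len(X_syn)
        theta = theta - syn_lr * grad
    return theta


def trajectory_matching_loss(theta_student, theta_start, theta_target):
    """MTT-style normalized squared distance to the expert target checkpoint."""
    num = np.sum((theta_student - theta_target) ** 2)
    den = np.sum((theta_start - theta_target) ** 2)
    return num / den


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(256, 4))
    y = X @ np.array([1.0, -2.0, 0.5, 3.0])
    traj = record_expert_trajectory(X, y)
    t, M = 1, 2  # hypothetical start epoch and expert span
    X_syn, y_syn = X[:8].copy(), y[:8].copy()  # toy synthetic set, initialized from real samples
    theta_hat = unroll_student(traj[t], X_syn, y_syn, syn_lr=0.05, n_steps=10)
    print(trajectory_matching_loss(theta_hat, traj[t], traj[t + M]))
```

The normalization makes the loss equal to 1 when the student does not move and 0 when it lands exactly on the target, so segments with very different parameter travel contribute on a comparable scale.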
`DATM_tesla.py` is a TESLA-based implementation of DATM that can greatly reduce VRAM usage, e.g., to about 12 GB for CIFAR10 with IPC=1000.
We provide a simple script for evaluating the distilled datasets:
```bash
cd distill
python evaluation.py --lr_dir=path_to_lr --data_dir=path_to_images --label_dir=path_to_labels --zca
```
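Both the buffer and evaluation commands pass `--zca`. Assuming this flag enables ZCA whitening of the input images (as the flag name suggests; the repository's exact preprocessing and regularization are not shown here), a minimal NumPy sketch of fitting and applying such a transform to flattened images looks like this. `fit_zca`, `apply_zca`, and the `eps` regularizer are hypothetical names for illustration.

```python
import numpy as np


def fit_zca(X, eps=1e-5):
    """Fit ZCA whitening on flattened images X of shape (n, d).
    Returns (mean, W) such that (X - mean) @ W has approximately identity covariance."""
    mean = X.mean(axis=0)
    Xc = X - mean
    cov = Xc.T @ Xc / X.shape[0]
    eigvals, eigvecs = np.linalg.eigh(cov)  # symmetric eigendecomposition
    # Rotate into the eigenbasis, rescale each direction, rotate back (the "Z" in ZCA).
    W = eigvecs @ np.diag(1.0 / np.sqrt(eigvals + eps)) @ eigvecs.T
    return mean, W


def apply_zca(X, mean, W):
    """Whiten new data with statistics fitted elsewhere (e.g., on the real training set)."""
    return (X - mean) @ W


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(1000, 3)) @ np.array([[2.0, 0.5, 0.0], [0.0, 1.0, 0.3], [0.0, 0.0, 0.5]])
    mean, W = fit_zca(X)
    Z = apply_zca(X, mean, W)
    print(np.round(Z.T @ Z / len(Z), 3))
```

Whatever the exact implementation, the transform should be fitted once on the real training data and reused unchanged for synthetic and test images, so that all of them live in the same whitened space.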
Our code is built upon MTT, FTD and TESLA.
If you find our code useful for your research, please cite our paper.
```bibtex
@inproceedings{guo2024lossless,
  title={Towards Lossless Dataset Distillation via Difficulty-Aligned Trajectory Matching},
  author={Ziyao Guo and Kai Wang and George Cazenavette and Hui Li and Kaipeng Zhang and Yang You},
  year={2024},
  booktitle={The Twelfth International Conference on Learning Representations}
}
```