This is the official code for the paper "Dataset Distillation with Pre-trained Models: A Contrastive Approach".
Abstract: Dataset distillation is a prominent technique that compresses knowledge from a large-scale original dataset into a small synthetic dataset for efficient training. Recent advances in dataset distillation have demonstrated that pre-trained models can guide the distillation process by providing auxiliary information. However, most methods only support in-distribution pre-trained models, whose label space must match that of the dataset intended for distillation. We argue that this limitation underestimates the broader potential of out-of-distribution pre-trained models, such as foundational models. To support a flexible and wide range of pre-trained models, we introduce a plug-and-play loss term, namely the Contrastive Loss of pre-trained Model (CLoM). Specifically, CLoM contrasts original and synthetic examples under a distance measure and treats them as a positive pair if their labels are the same. Equipped with CLoM, we conduct parallel experiments on both in-distribution and out-of-distribution pre-trained models. Extensive experimental evaluations demonstrate the effectiveness of CLoM in enhancing the performance and cross-architecture generalization of synthetic datasets. Furthermore, guided by CLoM, we elucidate the beneficial impact of foundational models on the distillation process, unlocking their potential in dataset distillation.
The Contrastive Loss of pre-trained Model (CLoM) contrasts synthetic examples against original examples through one or more pre-trained models: a synthetic example and an original example form a positive pair if they share the same label, and CLoM pulls positive pairs together while pushing negative pairs apart under a distance measure $D(\cdot,\cdot)$ computed in the output or feature space of the pre-trained model. When several pre-trained models are used, their contributions are combined with per-architecture weights (the `--alphas` option below).
We employ the Cross-Entropy (CE) distance for in-distribution models and the cosine distance for out-of-distribution models, respectively:

- CE distance:

  $$D_{\mathrm{CE}}(x, \tilde{x}) = -\sum_{c} f_\theta(x)_c \,\log f_\theta(\tilde{x})_c,$$

  where $f_\theta(\cdot)$ denotes the class-probability (softmax) output of the pre-trained model, $x$ is an original example, and $\tilde{x}$ is a synthetic example.

- cosine distance:

  $$D_{\cos}(x, \tilde{x}) = 1 - \frac{g_\theta(x)^\top g_\theta(\tilde{x})}{\lVert g_\theta(x)\rVert\,\lVert g_\theta(\tilde{x})\rVert},$$

  where $g_\theta(\cdot)$ denotes the feature extractor of the pre-trained model.
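For illustration, the snippet below is a minimal PyTorch-style sketch of these two distances and the label-based positive pairing. The function names, tensor shapes, and the exact way positive and negative pairs are combined are assumptions for illustration only; the repository's actual implementation is invoked through `methods/DC_DSA_DM/main.py`.

```python
import torch
import torch.nn.functional as F

def ce_distance(p_real, p_syn):
    # Cross-entropy between the class-probability outputs of the pre-trained
    # model on real and synthetic examples (in-distribution case).
    return -(p_real * torch.log(p_syn + 1e-8)).sum(dim=-1)

def cos_distance(f_real, f_syn):
    # 1 - cosine similarity between feature vectors (out-of-distribution case).
    return 1.0 - F.cosine_similarity(f_real, f_syn, dim=-1)

def clom_loss(out_real, out_syn, y_real, y_syn, distance):
    # Pairwise distances between every synthetic and every real example.
    d = distance(out_real.unsqueeze(0), out_syn.unsqueeze(1))   # [n_syn, n_real]
    pos = (y_syn.unsqueeze(1) == y_real.unsqueeze(0)).float()   # same label -> positive pair
    # Pull positive pairs together, push negative pairs apart
    # (the exact contrastive formulation is an assumption).
    return (pos * d - (1.0 - pos) * d).mean()
```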
Install the packages listed in `requirements.txt`.
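For example, with pip:

```bash
pip install -r requirements.txt
```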
The following command will train 10 ConvNet models on CIFAR-10.
```bash
python train_original_models.py --normalize_data --dataset CIFAR10 --model ConvNet --num 10
```
- Set `--num` to specify how many models to train. The pre-trained models will be saved at `./pretrained_model/[dataset]/original/[model]/...`.
- Use the default original-dataset path `./data/[dataset]`, or set `--data_path` to specify another path.
- Change `--model` to train models with different architectures, such as AlexNet, VGG11, or ResNet18 (see the example below).
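For example, to train 10 ResNet18 models while pointing `--data_path` at a custom dataset location (the path below is a placeholder):

```bash
python train_original_models.py --normalize_data --dataset CIFAR10 --model ResNet18 --num 10 --data_path <path_to_dataset>
```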
Execute the following command to generate a synthetic dataset with CLoM on in-distribution pre-trained models using the CE distance:
```bash
python methods/DC_DSA_DM/main.py --method DC --dataset CIFAR10 --model ConvNet --ipc 10 --CLoM --CLoM_distance ce --models_pool ConvNet --alphas 1000 --model_num 1 --epoch 150 --source_dataset CIFAR10 --CLoM_batch_size 8196
```
or generate a synthetic dataset with CLoM on out-of-distribution pre-trained models using the cosine distance:
```bash
python methods/DC_DSA_DM/main.py --method DC --dataset CIFAR10 --model ConvNet --ipc 10 --CLoM --CLoM_distance cos --models_pool ConvNet --alphas 1000 --model_num 1 --epoch 150 --source_dataset CIFAR100 --CLoM_batch_size 8196
```
- Set `--CLoM` to enable CLoM.
- The synthetic dataset will be saved at `./condensed/[dataset]/[method]/...`.
- Use the default original-dataset path `./data/[dataset]`, or set `--data_path` to specify another path.
- `--alphas`: the weight of each model architecture.
- `--models_pool`: the model architectures to use (see the example after this list).
- `--model_num`: the number of models per architecture (different initialization parameters).
- `--epoch`: the training epoch the pre-trained models are taken from.
- `--source_dataset`: the dataset (domain) the pre-trained models were trained on.
- `--CLoM_batch_size`: the batch size used for CLoM.
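For example, CLoM can be guided by several architectures at once by listing them in `--models_pool` with one weight per architecture in `--alphas`. The exact argument format (space-separated values) is an assumption; check the argument parser in `methods/DC_DSA_DM/main.py` if it differs:

```bash
python methods/DC_DSA_DM/main.py --method DC --dataset CIFAR10 --model ConvNet --ipc 10 --CLoM --CLoM_distance ce --models_pool ConvNet ResNet18 --alphas 1000 1000 --model_num 1 --epoch 150 --source_dataset CIFAR10 --CLoM_batch_size 8196
```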
The following command is an example of validating a specified synthetic dataset:
```bash
python validate.py --normalize_data --dataset CIFAR10 --model ConvNet --dsa --method DC --ipc 10 --synthesis_data_path <specified_path>
```
- Set `--save_model` to save the trained model (see the example below).
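For example, keeping `<specified_path>` as the placeholder for your synthetic-dataset path, and assuming `--save_model` is a simple on/off flag:

```bash
python validate.py --normalize_data --dataset CIFAR10 --model ConvNet --dsa --method DC --ipc 10 --synthesis_data_path <specified_path> --save_model
```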