This is the code for the paper "Prioritized training on points that are learnable, worth learning, and not yet learned".
The code uses PyTorch Lightning, Hydra for config file management, and Weights & Biases for logging. The codebase is adapted from this great template.
Conda: conda env create -f my_environment.yaml
Poetry: poetry install
The repository also contains a Singularity container definition file that can be built and used to run the experiments; see the singularity folder.
tutorial.ipynb contains the full training pipeline (irreducible loss model training and target model training) on CIFAR-10. This is the best place to start if you want to understand the code or reproduce our results.
The codebase contains the functionality for all the experiments in the paper (and more 😜).
Start with run_irreducible.py (which then calls src/train_irreducible.py). The base config file is configs/irreducible_training.yaml.
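Any value in that config can be overridden on the command line in the usual Hydra way, for example (illustrative values only; the exact keys depend on configs/irreducible_training.yaml):

python run_irreducible.py trainer.max_epochs=100 datamodule.data_dir=$DATA_DIR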
Start with run.py (which then calls src/train.py). The base config file is configs/config.yaml. A key file is src/models/MultiModels.py; this is the LightningModule that handles the training loop, including batch selection.
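To give a rough idea of what that module does, the sketch below (hypothetical names throughout, not the actual code in MultiModels.py) scores a large candidate batch, asks a selection method for the most useful points, and takes a gradient step only on the selected subset:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class SelectiveTrainingModule(pl.LightningModule):
    """Illustrative sketch of online batch selection; names are hypothetical."""

    def __init__(self, model, selection_method, selected_batch_size=32, lr=1e-3):
        super().__init__()
        self.model = model
        self.selection_method = selection_method  # e.g. reducible loss selection
        self.selected_batch_size = selected_batch_size
        self.lr = lr

    def training_step(self, batch, batch_idx):
        # Datasets in this repo return (index, input, target); the global index
        # lets us look up the precomputed irreducible loss of each point.
        index, x, y = batch

        # Score the full candidate batch without building a graph.
        with torch.no_grad():
            losses = F.cross_entropy(self.model(x), y, reduction="none")

        # Keep only the points the selection method deems worth training on.
        selected = self.selection_method(index, losses, self.selected_batch_size)

        # Standard supervised update on the selected subset.
        return F.cross_entropy(self.model(x[selected]), y[selected])

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)
```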
The datamodules are implemented in src/datamodules/datamodules.py, the individual datasets in src/datamodules/dataset/sequence_datasets. If you want to add your own dataset, note that __getitem__() needs to return the tuple (index, input, target), where index is the index of the datapoint with respect to the overall dataset (this is required so that we can match the irreducible losses to the correct datapoints).
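As a minimal sketch of that requirement (a hypothetical wrapper for illustration, not one of the datasets shipped in sequence_datasets), a dataset that also returns its global index could look like this:

```python
from torch.utils.data import Dataset


class IndexedDataset(Dataset):
    """Wraps any map-style dataset so __getitem__ also returns the global index."""

    def __init__(self, base_dataset):
        self.base_dataset = base_dataset

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, index):
        input, target = self.base_dataset[index]
        # index is the position in the overall dataset, so the irreducible loss
        # precomputed for this point can be matched to it later.
        return index, input, target
```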
All the selection methods mentioned in the paper (and more) are implemented in src/curricula/selection_methods.py.
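For intuition: reducible loss selection ranks the candidate points by their current training loss minus their precomputed irreducible loss and keeps the top k. The snippet below is a simplified sketch of that idea (not the actual code in selection_methods.py), written so it would fit the hypothetical training-step sketch above:

```python
import torch


def make_reducible_loss_selector(irreducible_losses):
    """Build a selector around precomputed irreducible losses, indexed by
    global dataset index (hypothetical helper, for illustration only)."""

    def select(index, losses, k):
        # Reducible loss = current training loss minus precomputed irreducible loss.
        reducible_loss = losses - irreducible_losses[index]
        # Train on the k points with the largest reducible loss.
        return torch.topk(reducible_loss, k).indices

    return select
```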
All ALBERT experiments are implemented in a separate branch, which is a bit less clean. Good luck :-)
This repo can be used to reproduce all the experiments in the paper. Check out configs/experiment for some example experiment configs. The experiment files for the main results are listed below (an example invocation follows the list):
- CIFAR-10: cifar10_resnet18_irred.yaml and cifar10_resnet18_main.yaml
- CINIC-10: cinic10_resnet18_irred.yaml and cinic10_resnet18_main.yaml
- CIFAR-100: cifar100_resnet18_irred.yaml and cifar100_resnet18_main.yaml
- Clothing-1M: c1m_resnet18_irred.yaml and c1m_resnet50_main.yaml
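For example, the CIFAR-10 results would be reproduced roughly as follows (an assumed invocation: the +experiment= override syntax is borrowed from the commands below, and the target run may additionally need to be pointed at the trained irreducible loss model, see configs/config.yaml):

python run_irreducible.py +experiment=cifar10_resnet18_irred.yaml
python run.py +experiment=cifar10_resnet18_main.yaml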
NLP datasets, on a separate branch:
- CoLA:
  - Irreducible loss model training:
    python run_irreducible_nlp.py +experiment=nlp trainer.max_epochs=10 callbacks=val_loss datamodule.task_name=cola trainer.val_check_interval=0.05
  - Target model training:
    python run_nlp.py +experiment=nlp datamodule.task_name=cola trainer.max_epochs=100 irreducible_loss_generator.f=path/to/file selection_method_nlp=reducible_loss_selection
- SST2:
  - Irreducible loss model training:
    python run_irreducible_nlp.py +experiment=nlp trainer.max_epochs=10 callbacks=val_loss datamodule.task_name=sst2 trainer.val_check_interval=0.05
  - Target model training:
    python run_nlp.py +experiment=nlp trainer.max_epochs=15 datamodule.task_name=sst2 +trainer.val_check_interval=0.2 irreducible_loss_generator.f=path/to/file selection_method_nlp=reducible_loss_selection
To run the importance sampling experiments, e.g. importance sampling on CINIC-10:
python3 run_simple.py datamodule.data_dir=$DATA_DIR +experiment=importance_sampling_baseline.yaml