A general distillation pipeline into which a teacher model can easily be plugged and distilled into a student. It currently supports only models from the Hugging Face Transformers library, for language modelling tasks.
- train.py: entry file called for training the student
- utils.py: contains utility functions used by the trainer
- distiller.py: main file holding the distillation model structure
- preprocess_data.py, grouped_batch_sampler.py, lm_seqs_dataset: files for preprocessing the raw input into binarized and tokenized text, called by the train.py file
A specific example of using the code to distill a student for the masked language modelling task:
pip3 install -r requirements.txt
python3 train.py \
--student_name distilroberta-base \
--teacher_name roberta-base \
--teacher_pretrained trained_Roberta_checkpoint \
--alpha_ce 5.0 --alpha_mlm 2.0 --alpha_cos 1.0 --alpha_clm 0.0 --mlm \
--dump_path output/train1 \
--data_file data/dump.txt \
--force --n_gpu 0
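The `--alpha_*` flags in the command above weight the individual loss terms. The sketch below shows how such a weighted objective is typically combined, assuming the flags follow the convention of the Hugging Face distillation example this repository is adapted from. It is plain Python for a single token position, not the repository's implementation; the function names and the temperature value are illustrative assumptions, and the causal LM term is left out because `--alpha_clm` is 0.0 in the command.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax over a list of logits, softened by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) between two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def cosine_loss(u, v):
    """1 - cosine similarity; close to zero when the vectors are aligned."""
    dot = sum(a * b for a, b in zip(u, v))
    norms = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / norms

def distillation_loss(student_logits, teacher_logits, target_index,
                      student_hidden, teacher_hidden,
                      alpha_ce=5.0, alpha_mlm=2.0, alpha_cos=1.0,
                      temperature=2.0):  # temperature value is an assumption
    # Soft-target term: match the teacher's softened distribution.
    loss_ce = kl_divergence(softmax(teacher_logits, temperature),
                            softmax(student_logits, temperature)) * temperature ** 2
    # Hard-target term: negative log-likelihood of the true masked token.
    loss_mlm = -math.log(softmax(student_logits)[target_index])
    # Alignment term: keep student and teacher hidden states pointing the same way.
    loss_cos = cosine_loss(student_hidden, teacher_hidden)
    return alpha_ce * loss_ce + alpha_mlm * loss_mlm + alpha_cos * loss_cos

# Toy example: one masked position, a 3-token vocabulary, 2-d hidden states.
loss = distillation_loss([2.0, 0.5, -1.0], [2.5, 0.2, -1.2], 0,
                         [0.1, 0.3], [0.2, 0.25])
```

With these weights, the soft-target term dominates whenever the student's distribution drifts from the teacher's, while the other two terms act as regularisers.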
The dump path holds the following files after training:
- checkpoint.pth: The latest saved checkpoint for the model
- model_epoch_<epoch_no>.pth: Checkpoint saved after the completion of the respective epoch
- config.json: Configuration file for your student
- parmeters.json: Parameters used for distillation
- logs: Directory holding the tensorboard logs
- <teacher_name>.pickle*: Binarized data file created using the input file
- <teacher_name>.token_counts.pickle*: Token counts for the input file, used in the MLM smoothing task
*These files take time to create during the first run, so it is advisable to save them in a different directory and reuse them directly in later runs through the arguments preprocessed_data_file and preprocessed_token_counts.
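To illustrate what the token counts file is for, here is a minimal sketch of count-based MLM smoothing. It assumes the approach of the Hugging Face distillation example, where each token's chance of being masked is weighted by its corpus count raised to a negative power so that rare tokens are masked more often. The function names, the toy corpus, and the exponent value are illustrative assumptions, not taken from this repository.

```python
from collections import Counter

def token_counts(sequences):
    """Count how often each token id occurs across all sequences."""
    counts = Counter()
    for seq in sequences:
        counts.update(seq)
    return counts

def masking_weights(counts, smoothing=0.7):  # exponent is an assumed value
    """Weight each token by count ** -smoothing so rare tokens are favoured."""
    return {tok: max(c, 1) ** -smoothing for tok, c in counts.items()}

# Hypothetical token ids standing in for a binarized corpus.
corpus = [[5, 5, 5, 9], [5, 7, 5]]
counts = token_counts(corpus)
weights = masking_weights(counts)
```

Here token 5 occurs five times and receives a smaller weight than tokens 7 and 9, which occur once each. Counting the whole corpus is the slow part, which is why caching the result between runs pays off.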
For help with the arguments accepted by train.py, run:
python3 train.py -h
The config.json and checkpoint.pth files in the dump path are used to load the student. Run the following code to load it:
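import torch
from transformers import AutoConfig, AutoModelForMaskedLM

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the student architecture from the saved configuration
stu_config = AutoConfig.from_pretrained('config.json')
stu_config.output_hidden_states = True
student_model = AutoModelForMaskedLM.from_config(stu_config)
# Load the distilled weights onto the chosen device
student_model.load_state_dict(torch.load("checkpoint.pth", map_location=device))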
Here device refers to the device, either CPU or GPU, on which the model is loaded.
(https://drive.google.com/drive/folders/12PmvrGTB6WWjridWUsDbI6JV7Tj3SDSs?usp=sharing)
This repository was used to distill the knowledge of RoBERTa and XLM-RoBERTa models fine-tuned on Twitter datasets into the DistilRoBERTa student model. The distilled models are:
- RoBERTa fine-tuned on English Tweets
- XLM-RoBERTa fine-tuned on Hindi Tweets
- XLM-RoBERTa fine-tuned on Latin Tweets
Some empirical results, in which the distilled model captures Twitter-specific lingo better than the base RoBERTa model, are given below. For each word, the table lists its closest words by cosine similarity in the respective model, with the similarity in parentheses:
Word | Fine-tuned model | Base model |
---|---|---|
ma | my (0.99755), anna (0.9975) | son (0.99207), ji (0.99207) |
af | f**k (0.99625), ma (0.99607) | if (0.99241), ash (0.99198) |
lmao | Lmao (0.99693), Lmaoooo (0.99401) | Lmao (0.94968), lady (0.94861) |
Yeah | Okay (0.99899), Yo (0.99889) | Yes (0.98318), yeah (0.97725) |
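A nearest-neighbour comparison of this kind can be sketched as follows. This is a minimal illustration that assumes each word has already been mapped to an embedding vector; how the vectors behind the table were extracted from the models is not described here, so the toy vectors and function names below are made up for the example.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def nearest_words(query, embeddings, k=2):
    """Return the k words most similar to `query`, excluding the query itself."""
    scores = [(word, cosine_similarity(embeddings[query], vec))
              for word, vec in embeddings.items() if word != query]
    return sorted(scores, key=lambda pair: pair[1], reverse=True)[:k]

# Made-up 3-d vectors for illustration only; real ones would come from the model.
toy_embeddings = {
    "lmao": [0.9, 0.1, 0.0],
    "Lmao": [0.88, 0.12, 0.01],
    "lady": [0.1, 0.9, 0.2],
    "yes": [0.0, 0.2, 0.9],
}
neighbours = nearest_words("lmao", toy_embeddings)
```

Running the same query against embeddings from two different models and comparing the returned lists gives a side-by-side view like the table above.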
This project was part of the Sprinklr Inc. ML internship 2021.
Adapted in part from the Hugging Face DistilBERT distillation example (https://github.com/huggingface/transformers/tree/master/examples/research_projects/distillation)
Author : Mayank Musaddi (mayankmusaddi1997@gmail.com)
Mentor : Ratnesh Jamidar (ratnesh.jamidar@sprinklr.com)