Distill the knowledge of Google's BERT transformer language model into a smaller transformer. A blog post on the topic can be found here.
Using this repository for knowledge distillation is a five-stage process:
- Download Pretrained Model
- Extract Wikipedia
- Prepare Text For TensorFlow
- Extract Teacher Neural Network Outputs
- Distill Knowledge
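The examples below assume the pretrained BERT-Base, Uncased checkpoint lives in uncased_L-12_H-768_A-12/ and that Wikipedia has already been extracted to a single text file (wikipedia.txt). As a minimal, hedged sketch of the first stage, the checkpoint can be fetched and unpacked like this (the URL is Google's standard BERT-Base, Uncased release; how you extract Wikipedia is up to you):
import urllib.request
import zipfile
# Google's standard BERT-Base, Uncased release (the uncased_L-12_H-768_A-12 model).
BERT_URL = "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"
# Download the archive and unzip it next to the distillation scripts, creating
# the uncased_L-12_H-768_A-12/ directory referenced throughout this README.
urllib.request.urlretrieve(BERT_URL, "uncased_L-12_H-768_A-12.zip")
with zipfile.ZipFile("uncased_L-12_H-768_A-12.zip") as archive:
    archive.extractall(".")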
To prepare the text, first split the extracted Wikipedia dump into smaller txt files with split_text.py:
python split_text.py --read_file wikipedia.txt --split_number 20 --folder data/split_dir --name_base wiki_split
split_text.py has the following arguments:
Args:
read_file (str) : the txt file that will be split
split_number (int) : the number of smaller txt files that will be created
folder (str) : the path where the split txt files will be placed
name_base (str) : the base name of the split txt files. Files will be named name_base_N, where N is the file number
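The exact implementation lives in split_text.py; conceptually it just spreads the lines of read_file evenly across split_number output files, roughly as in this sketch (the splitting-by-lines detail is an assumption, not the script's actual code):
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument("--read_file", type=str)
parser.add_argument("--split_number", type=int)
parser.add_argument("--folder", type=str)
parser.add_argument("--name_base", type=str)
args = parser.parse_args()
os.makedirs(args.folder, exist_ok=True)
with open(args.read_file, encoding="utf-8") as f:
    lines = f.readlines()
# Write roughly equal-sized chunks named name_base_0, name_base_1, ...
chunk = (len(lines) + args.split_number - 1) // args.split_number
for i in range(args.split_number):
    out_path = os.path.join(args.folder, "{}_{}".format(args.name_base, i))
    with open(out_path, "w", encoding="utf-8") as out:
        out.writelines(lines[i * chunk:(i + 1) * chunk])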
After splitting Wikipedia into smaller txt files, we can turn all of them into tfrecord files by running multifile_create_pretraining_data.py:
python multifile_create_pretraining_data.py \
--input_dir data/split_dir/ \
--output_dir data/record_intermed \
--output_base_name wiki_intermed \
--vocab_file uncased_L-12_H-768_A-12/vocab.txt
multifile_create_pretraining_data.py has the following arguments:
Args:
input_dir (str) : Input directory of raw text files
output_dir (str) : Output directory for created tfrecord files
output_base_name (str) : Output base name for TF example files
vocab_file (str) : The vocabulary file that the BERT model was trained on
do_lower_case (bool) : Whether to lower case the input text. Should be True for uncased models and False for cased models
max_seq_length (int) : Maximum sequence length
max_predictions_per_seq (int) : Maximum number of masked LM predictions per sequence
random_seed (int) : Random seed for data generation
dupe_factor (int) : Number of times to duplicate the input data (with different masks)
masked_lm_prob (float) : Masked LM probability
short_seq_prob (float) : Probability of creating sequences which are shorter than the maximum length
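In essence this script loops over every file in input_dir and runs it through BERT's masked-LM pretraining-data pipeline, writing one tfrecord per input file. One hedged way to picture the step (driving Google's create_pretraining_data.py as a subprocess, which approximates the repository script rather than reproducing its internals; the hyperparameter values shown are the BERT defaults and must match the later distillation steps):
import os
import subprocess
# Illustrative paths matching the example command above.
input_dir = "data/split_dir"
output_dir = "data/record_intermed"
vocab_file = "uncased_L-12_H-768_A-12/vocab.txt"
os.makedirs(output_dir, exist_ok=True)
# One tfrecord per split text file: wiki_intermed_0.tfrecord, wiki_intermed_1.tfrecord, ...
for i, name in enumerate(sorted(os.listdir(input_dir))):
    subprocess.run([
        "python", "create_pretraining_data.py",  # script from google-research/bert
        "--input_file=" + os.path.join(input_dir, name),
        "--output_file=" + os.path.join(output_dir, "wiki_intermed_{}.tfrecord".format(i)),
        "--vocab_file=" + vocab_file,
        "--do_lower_case=True",
        "--max_seq_length=128",            # must match the later distillation steps
        "--max_predictions_per_seq=20",    # must match the later distillation steps
        "--masked_lm_prob=0.15",
        "--dupe_factor=10",
    ], check=True)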
With the tfrecord files in place, extract the teacher network's outputs for each file by running extract_teacher_labels_truncated.py:
python extract_teacher_labels_truncated.py \
--bert_config_file uncased_L-12_H-768_A-12/bert_config.json \
--input_file data/record_intermed/wiki_intermed_0.tfrecord \
--output_file data/record_distill/wiki_distill_0.tfrecord \
--truncation_factor 10 \
--init_checkpoint uncased_L-12_H-768_A-12/bert_model.ckpt
extract_teacher_labels_truncated.py has the following arguments:
Args:
bert_config_file (str) : The config json file corresponding to the pre-trained BERT model. This specifies the model architecture
input_file (str) : Input TF example files (can be a glob or comma separated)
output_file (str) : The output file that has transformer inputs and teacher outputs
truncation_factor (int) : Number of top probable words to save from teacher network output
init_checkpoint (str) : Initial checkpoint (usually from a pre-trained BERT model)
max_seq_length (int) : The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded. Must match data generation
max_predictions_per_seq (int) : Maximum number of masked LM predictions per sequence. Must match data generation
batch_size (int) : Total batch size when processing sequences
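The truncation_factor is what keeps the teacher-label files small: rather than storing the teacher's full vocabulary-sized distribution for every masked position, only the top-k most probable tokens and their probabilities are saved alongside the original transformer inputs. Conceptually the truncation looks like this (a TensorFlow sketch; the function and tensor names are illustrative, not the script's actual code):
import tensorflow as tf
# teacher_logits: [batch, max_predictions_per_seq, vocab_size] masked-LM logits
# from the pretrained BERT teacher.
def truncate_teacher_outputs(teacher_logits, truncation_factor=10):
    probs = tf.nn.softmax(teacher_logits, axis=-1)
    # Keep only the k most probable vocabulary entries per masked position;
    # these (probability, token-id) pairs are what end up in the
    # wiki_distill_*.tfrecord files next to the original inputs.
    top_probs, top_ids = tf.nn.top_k(probs, k=truncation_factor)
    return top_probs, top_ids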
Finally, distill the teacher's knowledge into the smaller student network by running network_distillation_single_machine_truncated.py:
python network_distillation_single_machine_truncated.py \
--bert_config_file uncased_L-12_H-768_A-12/bert_config.json \
--input_file data/record_distill/wiki_distill_0.tfrecord \
--output_dir output_dir \
--truncation_factor 10 \
--do_train True \
--do_eval True
network_distillation_single_machine_truncated.py has the following arguments:
Args:
bert_config_file (str) : The config json file corresponding to the pre-trained BERT model. This specifies the model architecture
input_file (str) : Input TF example files (can be a glob or comma separated)
output_dir (str) : The output directory where the model checkpoints will be written
init_checkpoint (str) : Initial checkpoint (usually from a pre-trained BERT model)
truncation_factor (int) : Number of top probable words to save from teacher network output
do_train (bool) : Whether to run training
do_eval (bool) : Whether to run eval on the dev set
max_seq_length (int) : The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded. Must match data generation
max_predictions_per_seq (int) : Maximum number of masked LM predictions per sequence. Must match data generation
train_batch_size (int) : Total batch size for training
eval_batch_size (int) : Total batch size for eval
learning_rate (float) : The initial learning rate for Adam
num_train_steps (int) : Number of training steps
num_warmup_steps (int) : Number of warmup steps
save_checkpoints_steps (int) : How often to save the model checkpoint
iterations_per_loop (int) : How many steps to make in each estimator call
max_eval_steps (int) : Maximum number of eval steps
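network_distillation_single_machine_truncated.py fits the student to the teacher's truncated soft targets. A standard way to write that objective is a cross-entropy between the student's masked-LM log-probabilities and the teacher's saved top-k distribution; the sketch below (TF 1.14+/2.x-style gather, with illustrative tensor names) shows that form of the loss, though the repository's exact loss may differ:
import tensorflow as tf
def truncated_distillation_loss(student_logits, teacher_top_probs, teacher_top_ids):
    # student_logits:    [num_masked, vocab_size] student masked-LM logits
    # teacher_top_probs: [num_masked, truncation_factor] teacher probabilities
    # teacher_top_ids:   [num_masked, truncation_factor] teacher token ids
    log_probs = tf.nn.log_softmax(student_logits, axis=-1)
    # Pick the student's log-probabilities at the teacher's top-k token ids.
    student_top_log_probs = tf.gather(log_probs, teacher_top_ids, batch_dims=1)
    # Weight by the teacher's probabilities and average over masked positions.
    per_example_loss = -tf.reduce_sum(teacher_top_probs * student_top_log_probs, axis=-1)
    return tf.reduce_mean(per_example_loss)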
Now suppose you have a cluster of 8 GPUs. If you have Horovod installed, you can perform distributed training (if you don't have Horovod installed, you can install it here). Run network_distillation_distributed_truncated.py to perform distributed training as follows:
mpirun -np 8 \
-H localhost:8 \
-bind-to none -map-by slot \
-x NCCL_DEBUG=INFO -x LD_LIBRARY_PATH -x PATH \
-mca pml ob1 -mca btl ^openib \
python network_distillation_distributed_truncated.py \
--bert_config_file uncased_L-12_H-768_A-12/bert_config.json \
--input_file data/record_distill/wiki_distill_0.tfrecord \
--output_dir output_dir \
--truncation_factor 10 \
--do_train True \
--do_eval True
network_distillation_distributed_truncated.py has the following arguments:
Args:
bert_config_file (str) : The config json file corresponding to the pre-trained BERT model. This specifies the model architecture
input_file (str) : Input TF example files (can be a glob or comma separated)
output_dir (str) : The output directory where the model checkpoints will be written
init_checkpoint (str) : Initial checkpoint (usually from a pre-trained BERT model)
truncation_factor (int) : Number of top probable words to save from teacher network output
do_train (bool) : Whether to run training
do_eval (bool) : Whether to run eval on the dev set
max_seq_length (int) : The maximum total input sequence length after WordPiece tokenization. Sequences longer than this will be truncated, and sequences shorter than this will be padded. Must match data generation
max_predictions_per_seq (int) : Maximum number of masked LM predictions per sequence. Must match data generation
train_batch_size (int) : Total batch size for training
eval_batch_size (int) : Total batch size for eval
learning_rate (float) : The initial learning rate for Adam
num_train_steps (int) : Number of training steps
num_warmup_steps (int) : Number of warmup steps
save_checkpoints_steps (int) : How often to save the model checkpoint
iterations_per_loop (int) : How many steps to make in each estimator call
max_eval_steps (int) : Maximum number of eval steps
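The distributed script follows the usual Horovod pattern for TensorFlow training: mpirun launches one process per GPU, each process pins its own device, the optimizer is wrapped in hvd.DistributedOptimizer so gradients are averaged across workers, and rank 0 broadcasts the initial variables. A condensed sketch of that pattern (TF 1.x-style, not the script's exact code; the learning-rate scaling shown is the common Horovod convention, not necessarily what the script does):
import horovod.tensorflow as hvd
import tensorflow as tf
hvd.init()  # one process per GPU, launched by mpirun
# Pin each process to a single GPU.
config = tf.ConfigProto()
config.gpu_options.visible_device_list = str(hvd.local_rank())
# Scale the learning rate by the number of workers and wrap the optimizer
# so gradients are averaged across all processes each step.
optimizer = tf.train.AdamOptimizer(learning_rate=2e-5 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)
# Broadcast the initial (or checkpointed) variables from rank 0 so every
# worker starts from identical weights.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]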