Selective Knowledge Distillation for Neural Machine Translation

This is the PyTorch implementation of the paper: Selective Knowledge Distillation for Neural Machine Translation (ACL 2021).

We carry out our experiments on the standard Transformer with the fairseq toolkit. If you use any source code included in this repo in your work, please cite the following paper.

@article{wang2021selective,
  title={Selective Knowledge Distillation for Neural Machine Translation},
  author={Wang, Fusheng and Yan, Jianhao and Meng, Fandong and Zhou, Jie},
  journal={arXiv preprint arXiv:2105.12967},
  year={2021}
}

Runtime Environment

  • OS: Ubuntu 16.04.1 LTS, 64-bit
  • Python version >= 3.6
  • PyTorch version >= 1.4
  • To install fairseq and develop locally:
    cd fairseq
    pip install --editable ./
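
A quick sanity check that the local install is picked up correctly (it only imports the package and prints its version):

    python -c "import fairseq; print(fairseq.__version__)"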
    

Training

For selective distillation, first train a teacher model; the training script is the same as for standard fairseq.
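
For instance, a teacher on WMT'14 En-De could be trained with a standard fairseq-train command along the lines of the sketch below. This is only a sketch: the paths are placeholders and the architecture/optimization settings simply mirror the distillation script further down, so adjust them for your own teacher.

data_dir=directory_of_data_bin
teacher_dir=directory_of_teacher_output

export CUDA_VISIBLE_DEVICES=0,1,2,3

fairseq-train $data_dir --arch transformer_wmt_en_de \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --save-dir $teacher_dir \
    --encoder-normalize-before --decoder-normalize-before \
    --lr 7e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 --update-freq 2 \
    --max-update 300000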

Second, train the selective distillation model. The training script is the same as fairseq's, except for the following arguments:

  • add --use-distillation to enable the knowledge distillation method.
  • add --teacher-ckpt-path to specify the path of the teacher model trained in the first step.
  • add --distil-strategy to select the distillation strategy, e.g., batch_level or global_level.
  • add --distil-rate, the hyper-parameter $r$ that controls the proportion of words that receive distillation knowledge; it is 0.5 in the paper.
  • add --difficult-queue-size, the hyper-parameter $Q_{size}$ that controls the size of the global queue. It does not need to be set when using the batch_level strategy. In our method, the most suitable value is 30k for WMT'14 En-De and 50k for WMT'19 Zh-En.

For example, the following script performs global-level training on WMT'14 En-De. The script for WMT'19 Zh-En is the same as for WMT'14 En-De.

output_dir=directory_of_output
teacher_ckpt=path_of_teacher_ckpt/teacher.pt
data_dir=directory_of_data_bin
distil_strategy=global_level
distil_rate=0.5
queue_size=30000


export CUDA_VISIBLE_DEVICES=0,1,2,3

fairseq-train $data_dir --arch transformer_wmt_en_de \
    --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --save-dir $output_dir \
    --max-update 300000 --save-interval-updates 5000 \
    --keep-interval-updates 40 \
    --encoder-normalize-before --decoder-normalize-before \
    --lr 7e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --eval-bleu \
    --eval-bleu-args '{"beam": 4, "max_len_a": 1.2, "max_len_b": 10}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --eval-bleu-print-samples \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
    --update-freq 2 --no-epoch-checkpoints \
    --use-distillation --teacher-ckpt-path $teacher_ckpt --distil-strategy $distil_strategy --distil-rate $distil_rate \
    --difficult-queue-size $queue_size

Note

  • You need to test every checkpoint separately on the validation set and choose the one that performs best, since the checkpoint_best.pt and the log generated by default may not be correct.
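
A minimal sketch of such a sweep, assuming checkpoints were saved as checkpoint_*_*.pt by --save-interval-updates and reusing the beam settings from the training-time evaluation above (fairseq-generate prints its BLEU score on the final "Generate ..." line of its output):

data_dir=directory_of_data_bin
output_dir=directory_of_output

for ckpt in $output_dir/checkpoint_*_*.pt; do
    echo "Evaluating $ckpt"
    fairseq-generate $data_dir \
        --gen-subset valid \
        --path $ckpt \
        --beam 4 --max-len-a 1.2 --max-len-b 10 \
        --remove-bpe \
        --max-tokens 4096 > $output_dir/$(basename $ckpt .pt).valid.out
    # fairseq-generate reports BLEU on the final "Generate valid with beam=..." line
    grep "^Generate" $output_dir/$(basename $ckpt .pt).valid.out
done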