This work was done during a five-month internship at Systran under the supervision of Josep Crego.
The training corpus is not included in this repository; download it on your own.
First, train a tokenizer on the training corpus; we used a BPE tokenizer from OpenNMT Tokenizer (pyonmttok).
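For example, a minimal sketch with pyonmttok (the file names and the 32k merge count are illustrative placeholders, not necessarily the settings we used):

```python
import pyonmttok

# Base tokenization applied before BPE.
tokenizer = pyonmttok.Tokenizer("aggressive", joiner_annotate=True)

# Learn a joint BPE model over both sides of the corpus.
learner = pyonmttok.BPELearner(tokenizer=tokenizer, symbols=32000)
learner.ingest_file("train.en")
learner.ingest_file("train.de")
bpe_tokenizer = learner.learn("bpe.model")

# Tokenize the training files with the learned model.
bpe_tokenizer.tokenize_file("train.en", "train.en.tok")
bpe_tokenizer.tokenize_file("train.de", "train.de.tok")
```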
Then reproduce the experiments as follows:

- Baseline model: run `run.sh` and then `infer.sh` to get a trained baseline model and the baseline translation.
- Inference with penalization: reuse the baseline checkpoint and run `infer_rep_penalty_only.sh` to get a penalized translation with reduced word repetition.
- Training with penalization: run `run_rep_penalty.sh` and then `infer_rep_penalty.sh` to get a model trained with penalization; inference then runs normally (without the inference-time penalization of the previous step) and still yields a translation with reduced word repetition.
- Evaluation: compute the BLEU score with sacrebleu and count word repetitions with the `wordrep` tool in this repo (sketches of the penalty and the evaluation follow below).
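The exact penalty used by these scripts is defined in this repo; as an illustration of the general idea only, the sketch below down-weights the log-probabilities of tokens that already appear in the partial hypothesis (the function name and the penalty value are hypothetical):

```python
import torch

def penalize_repetitions(log_probs: torch.Tensor,
                         prev_tokens: torch.Tensor,
                         penalty: float = 1.2) -> torch.Tensor:
    """Down-weight tokens that already appear in the partial hypothesis.

    log_probs:   (batch, vocab_size) next-token log-probabilities
    prev_tokens: (batch, t) tokens decoded so far
    penalty:     factor > 1.0 (hypothetical default)
    """
    # Look up the current scores of the already-generated tokens.
    scores = log_probs.gather(1, prev_tokens)
    # Log-probabilities are negative, so multiplying by penalty > 1
    # pushes repeated tokens further down the ranking.
    return log_probs.scatter(1, prev_tokens, scores * penalty)
```

In beam search, such a penalty would be applied to each hypothesis right before the top-k selection at every decoding step.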
Training and inference should preferably be run on a GPU.
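For the evaluation step, a minimal sketch (file names are placeholders; `wordrep` in this repo defines the exact repetition metric, the count below is only a naive approximation):

```python
import sacrebleu

with open("output.detok", encoding="utf-8") as f:
    hyps = [line.strip() for line in f]
with open("test.de", encoding="utf-8") as f:
    refs = [line.strip() for line in f]

# Corpus-level BLEU on detokenized text.
print(sacrebleu.corpus_bleu(hyps, [refs]).score)

# Naive count of immediate word repetitions ("the the"); see wordrep
# for the metric actually used in this repo.
repeats = sum(
    sum(a == b for a, b in zip(h.split(), h.split()[1:])) for h in hyps
)
print(repeats)
```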
We recommend using conda/mamba to manage the environment, e.g.:

```bash
micromamba env create -f mini-trans.yml
micromamba activate mini-trans
pip install -r requirements.txt
```
The implementation is built on top of the template written by Guillaume Klein. Below is the original README.
This repository contains an example of Transformer training with PyTorch. While the code is quite minimal, training is faster than with OpenNMT-tf and models can reach similar accuracy.
The code implements:
- pre-norm Transformer
- gradient accumulation
- mixed precision training
- multi-GPU training
- checkpoint averaging
- beam search decoding
The default parameters are mostly copied from the Scaling NMT paper.
```bash
pip install -r requirements.txt
python3 train.py --src train.en.tok --tgt train.de.tok --src_vocab vocab.en --tgt_vocab vocab.de --save_dir checkpoints/
```
For multi-GPU training, use `--num_gpus N`.
Vocabularies that work with OpenNMT-tf also work here. If you are building your own vocabulary, make sure that it meets the following requirements:
- must have one token per line (no token frequencies or other annotations)
- must start with a padding token (use `<blank>` or `<pad>`)
- must contain the tokens `<s>` and `</s>`
- may contain the token `<unk>` (if not present, the token is automatically added during training)
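A minimal sketch that checks a vocabulary file against the requirements above (the function is illustrative, not part of this repo):

```python
def check_vocab(path):
    """Validate a vocabulary file against the requirements listed above."""
    with open(path, encoding="utf-8") as f:
        tokens = [line.rstrip("\n") for line in f]
    assert tokens[0] in ("<blank>", "<pad>"), "first token must be the padding token"
    assert "<s>" in tokens and "</s>" in tokens, "<s> and </s> are required"
    if "<unk>" not in tokens:
        print("<unk> missing; it will be added automatically during training")

check_vocab("vocab.en")
```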
```bash
python3 average.py checkpoints/ --output averaged_checkpoint.pt
python3 beam_search.py --ckpt averaged_checkpoint.pt --src_vocab vocab.en --tgt_vocab vocab.de < test.en.tok
```