A PyTorch implementation of the Transformer from "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).
This repo focuses on a clean, readable, and modular implementation of the paper.
- Python 3.6+
- PyTorch 0.4.1+
- NumPy
- NLTK
- tqdm
This repo comes with example data in the data/ directory. To begin, prepare the datasets from the given data as follows:
$ python prepare_datasets.py --train_source=data/example/raw/src-train.txt --train_target=data/example/raw/tgt-train.txt --val_source=data/example/raw/src-val.txt --val_target=data/example/raw/tgt-val.txt --save_data_dir=data/example/processed
The example data is taken from OpenNMT-py. It consists of parallel source (src) and target (tgt) files for training and validation. Each data file contains one sentence per line, with tokens separated by spaces. The provided example data files are:
- src-train.txt
- tgt-train.txt
- src-val.txt
- tgt-val.txt
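To make the file format concrete, here is a minimal sketch of reading a line-aligned source/target pair into token lists. The function name read_parallel is hypothetical and is not taken from prepare_datasets.py; the sketch only assumes the format described above (one sentence per line, tokens separated by spaces).

```python
from pathlib import Path
from typing import List, Tuple


def read_parallel(src_path: str, tgt_path: str) -> List[Tuple[List[str], List[str]]]:
    """Read two line-aligned files into (source_tokens, target_tokens) pairs.

    Hypothetical helper for illustration; not the repo's own loader.
    """
    src_lines = Path(src_path).read_text(encoding="utf-8").splitlines()
    tgt_lines = Path(tgt_path).read_text(encoding="utf-8").splitlines()
    if len(src_lines) != len(tgt_lines):
        raise ValueError("source and target files must have the same number of lines")
    # Tokens are already separated by spaces, so str.split() is enough.
    return [(s.split(), t.split()) for s, t in zip(src_lines, tgt_lines)]
```

Calling read_parallel("data/example/raw/src-train.txt", "data/example/raw/tgt-train.txt") would then yield one pair of token lists per training sentence.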
To train the model, pass the training script the path to the processed data and the paths of the files to save, as follows:
$ python train.py --data_dir=data/example/processed --save_config=checkpoints/example_config.json --save_checkpoint=checkpoints/example_model.pth --save_log=logs/example.log
This saves the model config and checkpoint to the given files.
You can adjust the model's hyperparameters with command line arguments.
For example, add --epochs=300 to set the number of epochs to 300.
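As a rough illustration of how such flags are typically wired up, the sketch below parses the options shown above with argparse. It is not the actual contents of train.py: the function name build_parser and the default of 100 epochs are assumptions made for the example.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags shown in this README."""
    parser = argparse.ArgumentParser(description="Train a Transformer (sketch)")
    # Flags that appear in the training command above.
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--save_config", type=str, default=None)
    parser.add_argument("--save_checkpoint", type=str, default=None)
    parser.add_argument("--save_log", type=str, default=None)
    # The default value here is an assumption for illustration only.
    parser.add_argument("--epochs", type=int, default=100)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch runs anywhere.
args = build_parser().parse_args(["--data_dir=data/example/processed", "--epochs=300"])
print(args.epochs)
```

Any flag left off the command line falls back to its default, which is why only the options you want to change need to be passed.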
To translate a sentence from the source language into the target language:
$ python predict.py --source="There is an imbalance here ." --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth
Candidate 0 : Hier fehlt das Gleichgewicht .
Candidate 1 : Hier fehlt das das Gleichgewicht .
Candidate 2 : Hier fehlt das das das Gleichgewicht .
This prints translation candidates for the given source sentence. You can adjust the number of candidates with a command line argument.
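Ranked candidates like these are what a beam search produces. The following is a simplified, framework-free sketch of the idea and not the code in beam.py: step_fn is a hypothetical callback that returns next-token log-probabilities for a prefix, toy_step is a made-up distribution standing in for a trained model, and refinements such as length normalization are left out.

```python
import math
from typing import Callable, Dict, List, Tuple

Hypothesis = Tuple[float, List[str]]  # (cumulative log-probability, tokens)


def beam_search(
    step_fn: Callable[[List[str]], Dict[str, float]],
    beam_size: int,
    max_len: int,
    eos: str = "</s>",
) -> List[Hypothesis]:
    """Return up to beam_size hypotheses, best first (simplified sketch)."""
    beams: List[Hypothesis] = [(0.0, [])]
    finished: List[Hypothesis] = []
    for _ in range(max_len):
        # Extend every live hypothesis by every possible next token.
        candidates = [
            (score + logp, tokens + [token])
            for score, tokens in beams
            for token, logp in step_fn(tokens).items()
        ]
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        # Keep only the beam_size best extensions.
        for score, tokens in candidates[:beam_size]:
            if tokens[-1] == eos:
                finished.append((score, tokens[:-1]))  # drop the end marker
            else:
                beams.append((score, tokens))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off at max_len
    finished.sort(key=lambda c: c[0], reverse=True)
    return finished[:beam_size]


def toy_step(prefix: List[str]) -> Dict[str, float]:
    """Made-up model: two free choices, then a forced end of sentence."""
    probs = {"a": 0.6, "b": 0.4} if len(prefix) < 2 else {"</s>": 1.0}
    return {token: math.log(p) for token, p in probs.items()}


results = beam_search(toy_step, beam_size=2, max_len=5)
for rank, (score, tokens) in enumerate(results):
    print("Candidate", rank, ":", " ".join(tokens))
```

In this sketch the beam_size argument plays the role of the number of candidates; a larger beam explores more alternatives at a higher decoding cost.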
To calculate the BLEU score of a trained model:
$ python evaluate.py --save_result=logs/example_eval.txt --config=checkpoints/example_config.json --checkpoint=checkpoints/example_model.pth
BLEU score : 0.0007947
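For readers unfamiliar with the metric, here is a small from-scratch sketch of corpus-level BLEU with a single reference per sentence and no smoothing. It is not the code in evaluate.py; the repo lists NLTK among its requirements, and this version avoids that dependency only to stay self-contained. The function names and the example sentences are made up for illustration.

```python
import math
from collections import Counter
from typing import List


def ngrams(tokens: List[str], n: int) -> Counter:
    """Count the n-grams of a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def corpus_bleu(references: List[List[str]], hypotheses: List[List[str]], max_n: int = 4) -> float:
    """Unsmoothed corpus BLEU on a 0-1 scale, one reference per hypothesis."""
    matches = [0] * max_n
    totals = [0] * max_n
    ref_len = hyp_len = 0
    for ref, hyp in zip(references, hypotheses):
        ref_len += len(ref)
        hyp_len += len(hyp)
        for n in range(1, max_n + 1):
            hyp_counts = ngrams(hyp, n)
            ref_counts = ngrams(ref, n)
            # Clip each hypothesis n-gram count by its count in the reference.
            matches[n - 1] += sum(min(c, ref_counts[g]) for g, c in hyp_counts.items())
            totals[n - 1] += sum(hyp_counts.values())
    if min(matches) == 0:
        return 0.0  # without smoothing, one empty n-gram order zeroes the score
    # Geometric mean of the modified n-gram precisions.
    log_precision = sum(math.log(m / t) for m, t in zip(matches, totals)) / max_n
    # Penalize hypotheses that are shorter than the references overall.
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity_penalty * math.exp(log_precision)


reference = "a b c d e f".split()
hypothesis = "a b c d e g".split()
print("BLEU score :", corpus_bleu([reference], [hypothesis]))
```

Because the score multiplies precisions for 1-grams through 4-grams, a model that rarely gets four consecutive tokens right ends up with a value close to zero.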
- models.py: includes Transformer's encoder, decoder, and multi-head attention.
- embeddings.py: contains positional encoding.
- losses.py: contains label smoothing loss.
- optimizers.py: contains Noam optimizer.
- metrics.py: contains accuracy metric.
- beam.py: contains beam search.
- datasets.py: has code for loading and processing data.
- trainer.py: has code for training model.
- prepare_datasets.py: processes data.
- train.py: trains model.
- predict.py: translates given source sentence with a trained model.
- evaluate.py: calculates BLEU score of a trained model.
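As one concrete example of these components, the "Noam" optimizer is assumed here to refer to the learning-rate schedule from Section 5.3 of the paper: linear warmup followed by inverse square-root decay. The sketch below shows only that formula in plain Python. The function name noam_lr and the factor argument are hypothetical, and the defaults (d_model=512, warmup_steps=4000) are the paper's base settings rather than values read from optimizers.py.

```python
def noam_lr(step: int, d_model: int = 512, warmup_steps: int = 4000, factor: float = 1.0) -> float:
    """Learning rate at a given step (1-based), following the paper's formula.

    lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    """
    if step < 1:
        raise ValueError("step must be >= 1")
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# The rate rises linearly during warmup, peaks at warmup_steps, then decays.
for step in (1, 1000, 4000, 16000):
    print(step, noam_lr(step))
```

In a PyTorch training loop, a function like this could be applied by setting the learning rate of each optimizer parameter group before every update step.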