An example (Attentional) Encoder-Decoder RNN in PyTorch, applied to a toy data set: translating back and forth between numbers-as-text ("one thousand and two") and numbers-as-digits ("1002").
This project began as an attempt to implement attentional RNNs on a simple data set, in order to understand the inner workings of the attention mechanism. It turns out that the implementation can get fairly non-trivial.
In particular, the attention portions of the code heavily referenced the implementation in https://github.com/spro/practical-pytorch/tree/master/seq2seq-translation.
Attention plot between the input (top) and the output (left). The attention isn't perfect, because our models can be heavily over-parameterized relative to the problem.
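To make the plot easier to read, here is a minimal, framework-free sketch of one common attention variant, dot-product ("Luong dot") scoring. It is an illustration, not num2num's code: the function names (`dot_attention`, `attention_matrix`) and the toy vectors are hypothetical, and plain Python lists stand in for PyTorch tensors. At each decoder step, every encoder state is scored against the decoder state, the scores are softmax-normalized into weights, and the weighted sum of encoder states becomes the context vector. Stacking the weights for all output steps gives a matrix with one row per output token and one column per input token, which matches the layout the caption describes (input along the top, output down the left).

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)  # subtract the max before exponentiating to avoid overflow
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot_attention(decoder_state, encoder_states):
    """Dot-product attention for a single decoder step.

    decoder_state:  list of H floats (current decoder hidden state)
    encoder_states: list of T lists of H floats (one per input token)
    Returns (weights, context): T weights summing to 1, and the
    H-dimensional weighted sum of the encoder states.
    """
    # Score each input position by its dot product with the decoder state.
    scores = [sum(d * e for d, e in zip(decoder_state, enc))
              for enc in encoder_states]
    weights = softmax(scores)
    # Context vector: attention-weighted average of the encoder states.
    context = [sum(w * enc[i] for w, enc in zip(weights, encoder_states))
               for i in range(len(decoder_state))]
    return weights, context


def attention_matrix(decoder_states, encoder_states):
    """One row per output step, one column per input token."""
    return [dot_attention(d, encoder_states)[0] for d in decoder_states]


if __name__ == "__main__":
    encoder = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 toy input tokens
    decoder = [[5.0, 0.0], [0.0, 5.0]]              # 2 toy output steps
    for row in attention_matrix(decoder, encoder):
        print([round(w, 3) for w in row])
```

In a full model the decoder state comes from the RNN, and variants such as "general" or "concat" scoring replace the plain dot product with learned parameters.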
- Python 3.6
- PyTorch 0.2.0
- (For generating data) num2words, tqdm
- pandas, argparse, matplotlib, seaborn
Using num2words, we can generate random numbers and get the "ground-truth" text versions of those numbers. Commas are removed, and the remaining tokens have been pre-computed and saved in [datafiles/word2num_tokens](datafiles/word2num_tokens).
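The generation step can be sketched as below, assuming only what this paragraph states: draw random integers, convert each to text, and strip commas. This is not the contents of `gendata.py`; `make_pairs`, `write_csv`, the `max_value` range, the seed, and the `number`/`words` column names are all hypothetical. The converter is passed in as `to_words` so the sketch runs without num2words installed.

```python
import csv
import io
import random


def make_pairs(size, to_words, max_value=10**6, seed=0):
    """Return `size` (digits, words) pairs for random integers in [0, max_value].

    to_words: any callable int -> str (e.g. num2words's `num2words`).
    Commas are stripped from the text form, as the README describes.
    max_value and seed are hypothetical defaults chosen for this sketch.
    """
    rng = random.Random(seed)  # seeded so the data set is reproducible
    pairs = []
    for _ in range(size):
        n = rng.randint(0, max_value)  # randint is inclusive on both ends
        pairs.append((str(n), to_words(n).replace(",", "")))
    return pairs


def write_csv(pairs, fileobj):
    """Write pairs as CSV under hypothetical column names."""
    writer = csv.writer(fileobj)
    writer.writerow(["number", "words"])
    writer.writerows(pairs)


if __name__ == "__main__":
    def with_commas(n):
        # Stand-in for num2words: format with thousands separators, e.g. "1,002"
        return f"{n:,}"

    buf = io.StringIO()
    write_csv(make_pairs(3, with_commas), buf)
    print(buf.getvalue())
```

With num2words installed, `from num2words import num2words` followed by `make_pairs(100000, num2words)` would pair each number with text in the style of the "one thousand and two" example.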
Note that num2num supports both word-level and character-level modeling, so both tokenizations have been pre-computed.
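To illustrate the two tokenizations, here is a minimal sketch of word-level and character-level splitting plus a token-to-index vocabulary, the usual preprocessing step before feeding sequences to an RNN. It assumes whitespace splitting for words and treats every character (spaces included) as a token at character level; the repository's pre-computed tokens may be built differently, and the names `tokenize`, `build_vocab`, `encode`, and the `<s>`/`</s>` markers are hypothetical.

```python
SOS, EOS = "<s>", "</s>"  # hypothetical start/end-of-sequence markers


def tokenize(text, level="word"):
    """Split comma-free text into word-level or character-level tokens."""
    text = text.replace(",", "")
    if level == "word":
        return text.split()   # "one thousand and two" -> 4 tokens
    if level == "char":
        return list(text)     # every character, spaces included
    raise ValueError(f"unknown tokenization level: {level!r}")


def build_vocab(texts, level="word"):
    """Map each token to an integer id, reserving 0 and 1 for SOS/EOS."""
    vocab = {SOS: 0, EOS: 1}
    for text in texts:
        for tok in tokenize(text, level):
            vocab.setdefault(tok, len(vocab))  # new tokens get the next free id
    return vocab


def encode(text, vocab, level="word"):
    """Convert text to ids, appending EOS so the model learns where to stop."""
    return [vocab[tok] for tok in tokenize(text, level)] + [vocab[EOS]]


if __name__ == "__main__":
    words = build_vocab(["one thousand and two"])
    digits = build_vocab(["1002"], level="char")
    print(encode("one thousand and two", words))
    print(encode("1002", digits, level="char"))
```

Reserving fixed ids for the start and end markers lets the decoder be primed with `<s>` and trained to emit `</s>` once a number is complete.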
- Git clone this repository.
- (Optional) Generate training and validation datasets. Very small sample training and validation datasets are included with the project.
    - E.g.
      ```shell
      python gendata.py \
          --output=datafiles/train_numbers.csv \
          --size=100000
      python gendata.py \
          --output=datafiles/val_numbers.csv \
          --size=10000
      ```
    - Run `python gendata.py -h` or see `gendata.py` for details and more options.
- Train the model:
    - E.g.
      ```shell
      python run_train.py \
          --train_data_path=datafiles/train_numbers.csv \
          --val_data_path=datafiles/val_numbers.csv \
          --plot_attn_show=False \
          --plot_attn_save_path="output/attn_plots" \
          --model_save_path="output/models"
      ```
    - Run `python run_train.py -h` or see `orchestrate.py` and `num2num/config.py` for details and more options.
- Sample from the model:
    - E.g.
      ```shell
      python run_sample.py \
          --val_data_path=datafiles/val_numbers.csv \
          --model_save_path="output/models/my_favorite_model"
      ```
    - Run `python run_sample.py -h` or see `orchestrate.py` and `num2num/config.py` for details and more options.
- Documentation
- Tests
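The README does not say how `run_sample.py` turns decoder outputs into a prediction. One common choice for seq2seq models is greedy decoding: start from a start-of-sequence token, pick the highest-scoring output at each step, feed it back in, and stop at the end-of-sequence token or a length cap. The sketch below assumes that approach and is not num2num's code; `greedy_decode`, the `step` interface, and the toy ids are hypothetical, with `step` standing in for one decoder-plus-attention step of a trained model.

```python
def greedy_decode(step, init_state, sos_id, eos_id, max_len=50):
    """Greedily decode one output sequence.

    step: hypothetical callable (prev_token_id, state) -> (scores, new_state),
          where scores is a list of floats over the output vocabulary.
    Returns the generated ids, excluding SOS and EOS.
    """
    token, state, output = sos_id, init_state, []
    for _ in range(max_len):  # length cap guards against never emitting EOS
        scores, state = step(token, state)
        # Argmax: the highest-scoring id becomes the next input token.
        token = max(range(len(scores)), key=scores.__getitem__)
        if token == eos_id:
            break
        output.append(token)
    return output


if __name__ == "__main__":
    target = [2, 3, 3, 4]  # toy ids, e.g. for the characters "1", "0", "0", "2"

    def toy_step(prev_token, t):
        # Stand-in for a trained decoder: the state is just a step counter,
        # and all score mass goes to the next target id, then to EOS (id 1).
        scores = [0.0] * 5
        scores[target[t] if t < len(target) else 1] = 1.0
        return scores, t + 1

    print(greedy_decode(toy_step, 0, sos_id=0, eos_id=1))
```

Greedy decoding is simple but commits to one token at a time; beam search is the usual alternative when that proves too myopic.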