num2num

PyTorch RNNs on a toy example of translating between numbers in words and digits

Primary language: Python


An example (Attentional) Encoder-Decoder RNN in PyTorch, applied to a toy data set: translating back and forth between numbers-as-text ("one thousand and two") and numbers-as-digits ("1002").

Introduction

This project began as an attempt to implement an attentional RNN against a simple data set, in order to understand the inner workings of the attention mechanism. As it turns out, even on a toy problem the implementation is fairly non-trivial.

The attention portions of the code draw heavily on the implementation in https://github.com/spro/practical-pytorch/tree/master/seq2seq-translation.
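To make the attention step concrete, here is a minimal pure-Python sketch of dot-product ("dot") attention for a single decoder step, one common scoring choice for attentional encoder-decoders. It is an illustrative assumption, not the code in this repository (which works on PyTorch tensors), and the names `softmax` and `dot_attention` are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_attention(decoder_state: Vector,
                  encoder_states: List[Vector]) -> Tuple[List[float], Vector]:
    """Hypothetical dot-product attention for one decoder step.

    Returns (weights, context): one weight per encoder position, and the
    attention-weighted sum of the encoder states.
    """
    # Score each encoder position by its dot product with the decoder state.
    scores = [sum(d * e for d, e in zip(decoder_state, enc))
              for enc in encoder_states]
    weights = softmax(scores)
    # Context vector: weighted sum of encoder states, dimension by dimension.
    hidden_size = len(decoder_state)
    context = [sum(w * enc[i] for w, enc in zip(weights, encoder_states))
               for i in range(hidden_size)]
    return weights, context


# Toy example: the first and third encoder states score equally (2.0),
# so they receive equal weight; the second scores 0.0 and gets less.
weights, context = dot_attention([2.0, 0.0],
                                 [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
```

Stacking the `weights` vectors from successive decoder steps yields a matrix of the kind shown in the attention plot: one row per output token, one column per input token.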

Attention Plot

Attention plot between input (top) and output (left). The attention isn't perfectly clean because our models can be heavily over-parameterized compared to the problem.

Requirements

  • Python 3.6
  • PyTorch 0.2.0
  • (For generating data) num2words, tqdm
  • pandas, argparse, matplotlib, seaborn

Data

Using num2words, we generate random numbers along with their "ground-truth" text versions. Commas are removed, and the resulting tokens have been pre-computed and saved in datafiles/word2num_tokens.

Note that num2num supports both word-level and character-level modeling, so both tokenizations have been pre-computed.
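The difference between the two tokenizations can be sketched as follows. This is a minimal illustration under stated assumptions: commas are stripped (as described above), words are split on whitespace, and hyphenated words such as "twenty-one" stay as one token. The repository's pre-computed tokens may be built differently, and `word_tokens` and `char_tokens` are hypothetical names.

```python
from typing import List


def word_tokens(text: str) -> List[str]:
    # Word-level: drop commas, then split on whitespace.
    # Assumption: hyphenated words stay as a single token.
    return text.replace(",", "").split()


def char_tokens(text: str) -> List[str]:
    # Character-level: drop commas, then every remaining character,
    # including spaces, becomes its own token.
    return list(text.replace(",", ""))


words = word_tokens("one thousand, two hundred and three")
digits = char_tokens("1203")
```

Word-level modeling gives short sequences over a larger vocabulary, while character-level modeling gives longer sequences over a tiny vocabulary; the digit side is naturally character-level either way.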

Usage

  1. Git clone this repository.
  2. (Optional) Generate training and validation datasets. Small sample training and validation datasets are included with the project.
    • E.g.
      python gendata.py \
          --output=datafiles/train_numbers.csv \
          --size=100000
      
      python gendata.py \
          --output=datafiles/val_numbers.csv \
          --size=10000
    • Run python gendata.py -h or see gendata.py for details and more options.
  3. Train the model:
    • E.g.
      python run_train.py \
        --train_data_path=datafiles/train_numbers.csv \
        --val_data_path=datafiles/val_numbers.csv \
        --plot_attn_show=False \
        --plot_attn_save_path="output/attn_plots" \
        --model_save_path="output/models"
    • Run python run_train.py -h or see orchestrate.py and num2num/config.py for details and more options.
  4. Sample from model:
    • E.g.
      python run_sample.py \
        --val_data_path=datafiles/val_numbers.csv \
        --model_save_path="output/models/my_favorite_model"
    • Run python run_sample.py -h or see orchestrate.py and num2num/config.py for details and more options.
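Sampling from an encoder-decoder model generally means producing the output one token at a time. The sketch below shows plain greedy decoding in pure Python, assuming the decoder can be wrapped as a function that returns one score per vocabulary id. `greedy_decode`, `step`, `sos_id`, and `eos_id` are hypothetical names; run_sample.py may carry recurrent hidden state explicitly or use a different search strategy.

```python
from typing import Callable, List, Sequence


def greedy_decode(step: Callable[[List[int]], Sequence[float]],
                  sos_id: int, eos_id: int, max_len: int = 20) -> List[int]:
    """Hypothetical greedy decoder: always pick the highest-scoring token.

    `step` stands in for one decoder step: given the token ids produced
    so far (starting with the start-of-sequence id), it returns a score
    for every id in the vocabulary.
    """
    output = [sos_id]
    for _ in range(max_len):
        scores = step(output)
        # Argmax over the vocabulary.
        next_id = max(range(len(scores)), key=lambda i: scores[i])
        if next_id == eos_id:
            break
        output.append(next_id)
    return output[1:]  # drop the start-of-sequence token


# Toy decoder over a 5-id vocabulary (0 = start, 1 = end) that spells
# out a fixed target sequence and then emits the end token.
target = [4, 2, 3]


def toy_step(prefix: List[int]) -> List[float]:
    scores = [0.0] * 5
    pos = len(prefix) - 1
    scores[target[pos] if pos < len(target) else 1] = 1.0
    return scores


decoded = greedy_decode(toy_step, sos_id=0, eos_id=1)
```

Greedy decoding is the simplest choice; beam search keeps several candidate prefixes at each step and can recover from an early low-scoring mistake at extra cost.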

Road map

  • Documentation
  • Tests