/pytorch-seq2seq-example

Fully batched seq2seq example based on practical-pytorch, with extra features.

Batched Seq2Seq Example

Based on the seq2seq-translation-batched.ipynb notebook from practical-pytorch, with extra features.

This example runs a grammatical error correction task, where the source sequence is a grammatically erroneous English sentence and the target sequence is a grammatically correct English sentence. The corpus and evaluation script can be downloaded from https://github.com/keisks/jfleg.

Extra features

  • Cleaner codebase
  • Very detailed comments for learners
  • PyTorch-native Dataset and DataLoader for batching
  • Correctly handles the hidden state from the bidirectional encoder and passes it to the decoder as the initial hidden state.
  • Fully batched attention computation (only general attention is implemented, but that is sufficient). Note: the original code still computes attention in a for-loop, which is very slow.
  • Support LSTM instead of only GRU
  • Shared embeddings (encoder's input embedding and decoder's input embedding)
  • Pretrained GloVe embeddings
  • Fixed (frozen) embeddings
  • Tied embeddings (decoder's input embedding and decoder's output embedding)
  • Tensorboard visualization
  • Load and save checkpoint
  • Replace unknown words by selecting the source token with the highest attention score. (Translation)
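
The batched attention item above can be illustrated with a minimal sketch of Luong-style "general" attention, score(h_t, h_s) = h_t^T W h_s, computed for the whole batch with `torch.bmm` instead of a Python loop. This is not the notebook's code: the class name `GeneralAttention`, the argument names, and the length-based masking are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GeneralAttention(nn.Module):
    """Hypothetical batched "general" attention: score(h_t, h_s) = h_t^T W h_s."""

    def __init__(self, hidden_size):
        super().__init__()
        self.W = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, decoder_state, encoder_outputs, src_lens):
        # decoder_state:   (batch, 1, hidden)       current decoder output
        # encoder_outputs: (batch, src_len, hidden) all encoder outputs
        # src_lens:        (batch,)                 true source lengths (>= 1)
        # One batched matmul replaces the loop over source positions:
        # (batch, 1, hidden) x (batch, hidden, src_len) -> (batch, 1, src_len)
        scores = torch.bmm(decoder_state, self.W(encoder_outputs).transpose(1, 2))

        # Mask padded source positions so they receive zero attention weight.
        max_len = encoder_outputs.size(1)
        positions = torch.arange(max_len, device=src_lens.device)
        pad_mask = positions[None, :] >= src_lens[:, None]  # (batch, src_len)
        scores = scores.masked_fill(pad_mask.unsqueeze(1), float("-inf"))

        weights = F.softmax(scores, dim=-1)            # (batch, 1, src_len)
        context = torch.bmm(weights, encoder_outputs)  # (batch, 1, hidden)
        return context, weights


torch.manual_seed(0)
attn = GeneralAttention(hidden_size=8)
dec = torch.randn(3, 1, 8)
enc = torch.randn(3, 5, 8)
lens = torch.tensor([5, 3, 1])
context, weights = attn(dec, enc, lens)
```

The returned `weights` are also what an unknown-word replacement step would need: taking `weights.argmax(dim=-1)` gives, per example, the source position with the highest attention score.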

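As a rough illustration of the bidirectional hidden state item in the list above, one common approach is to concatenate the forward and backward final states of each layer, so the decoder's hidden size is twice the encoder's. The sketch below assumes that approach; the helper name `merge_directions` and all sizes are hypothetical, and the notebook may handle this differently (for example, with a projection layer).

```python
import torch
import torch.nn as nn


def merge_directions(h_n, num_layers):
    """Turn a bidirectional final state into a unidirectional initial state.

    h_n has shape (num_layers * 2, batch, hidden), ordered as
    [layer0_fwd, layer0_bwd, layer1_fwd, layer1_bwd, ...].
    Returns a tensor of shape (num_layers, batch, 2 * hidden).
    """
    _, batch, hidden = h_n.shape
    h_n = h_n.reshape(num_layers, 2, batch, hidden)
    # Concatenate the forward and backward states of each layer.
    return torch.cat([h_n[:, 0], h_n[:, 1]], dim=-1)


# Assumed sizes, for illustration only.
encoder = nn.GRU(input_size=16, hidden_size=32, num_layers=2,
                 bidirectional=True, batch_first=True)
decoder = nn.GRU(input_size=16, hidden_size=64, num_layers=2,
                 batch_first=True)

src = torch.randn(4, 7, 16)  # (batch, src_len, emb)
tgt = torch.randn(4, 5, 16)  # (batch, tgt_len, emb)

_, h_n = encoder(src)
dec_init = merge_directions(h_n, num_layers=2)
dec_out, _ = decoder(tgt, dec_init)
```

For an LSTM, the same function would be applied to both the hidden state and the cell state before passing them to the decoder as a tuple.
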
Cons

Compared to the state-of-the-art seq2seq library OpenNMT-py, this codebase lacks the following optimizations:

  • Use CuDNN when possible (always on encoder, on decoder when input_feed=0)
  • Always avoid indexing / loops and use torch primitives.
  • When possible, batch softmax operations across time. (this is the second most complicated part of the code)
  • Batch inference and beam search for translation (this is the most complicated part of the code)

How to speed up RNN training?

Several ways to speed up RNN training:

  • Batching
  • Static padding
  • Dynamic padding
  • Bucketing
  • Truncated BPTT

See "Sequence Models and the RNN API (TensorFlow Dev Summit 2017)" for an explanation of these techniques.
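
To make dynamic padding concrete, here is a minimal sketch using a PyTorch `Dataset` and a custom `collate_fn`: each batch is padded only to the length of its own longest sequence, rather than to a global maximum as in static padding. The names `ToyPairs`, `collate`, and `PAD_ID`, as well as the toy token ids, are assumptions for illustration and are not taken from the notebook.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

PAD_ID = 0  # assumed index of the padding token


class ToyPairs(Dataset):
    """Hypothetical dataset of (source ids, target ids) pairs."""

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        return torch.tensor(src), torch.tensor(tgt)


def collate(batch):
    # Sort by source length, longest first, which is the order
    # pack_padded_sequence expects by default.
    batch = sorted(batch, key=lambda pair: len(pair[0]), reverse=True)
    srcs, tgts = zip(*batch)
    src_lens = torch.tensor([len(s) for s in srcs])
    tgt_lens = torch.tensor([len(t) for t in tgts])
    # Pad only up to the longest sequence in *this* batch.
    src_batch = pad_sequence(list(srcs), batch_first=True, padding_value=PAD_ID)
    tgt_batch = pad_sequence(list(tgts), batch_first=True, padding_value=PAD_ID)
    return src_batch, src_lens, tgt_batch, tgt_lens


pairs = [
    ([5, 6, 7], [8, 9]),
    ([5, 6], [8, 9, 10, 11]),
    ([5], [8]),
    ([5, 6, 7, 8, 9], [8, 9, 10]),
]
loader = DataLoader(ToyPairs(pairs), batch_size=2, shuffle=False, collate_fn=collate)
batches = list(loader)
for src_batch, src_lens, tgt_batch, tgt_lens in batches:
    print(src_batch.shape, src_lens.tolist(), tgt_batch.shape)
```

The returned lengths can then be passed to `torch.nn.utils.rnn.pack_padded_sequence` so the RNN skips the padded positions.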

You can use torchtext or OpenNMT's data iterator to speed up training. It can be 7x faster! (e.g., 7 hours per epoch -> 1 hour!)
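
Much of that speedup plausibly comes from bucketing, which groups sentences of similar length so that little computation is wasted on padding. The sketch below shows the idea in plain Python; it is not how torchtext or OpenNMT implement it, and the function names and the `bucket_size` parameter are hypothetical.

```python
import random


def bucketed_batches(lengths, batch_size, bucket_size, seed=0):
    """Return batches (lists of example indices) with similar lengths.

    Examples are shuffled, split into buckets of `bucket_size`, sorted by
    length inside each bucket, and then cut into batches.
    """
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)
    batches = []
    for start in range(0, len(indices), bucket_size):
        bucket = sorted(indices[start:start + bucket_size],
                        key=lambda i: lengths[i])
        batches.extend(bucket[i:i + batch_size]
                       for i in range(0, len(bucket), batch_size))
    rng.shuffle(batches)  # keep the order of batches random across the epoch
    return batches


def padded_positions(lengths, batches):
    """Count the pad tokens needed when each batch is padded to its longest item."""
    return sum(max(lengths[i] for i in b) * len(b) - sum(lengths[i] for i in b)
               for b in batches)


lengths = [1, 10, 1, 10, 1, 10, 1, 10]  # toy sentence lengths
naive = [[0, 1, 2, 3], [4, 5, 6, 7]]    # batches taken in dataset order
bucketed = bucketed_batches(lengths, batch_size=4, bucket_size=8)

print(padded_positions(lengths, naive))     # 36
print(padded_positions(lengths, bucketed))  # 0
```

A list of index batches like this could be passed to `torch.utils.data.DataLoader` through its `batch_sampler` argument, combined with a collate function that pads each batch dynamically.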

Acknowledgement

Thanks to the author of OpenNMT-py, @srush, for answering my questions! See OpenNMT/OpenNMT-py#552