Robust-Fill

Primary language: Python. License: MIT.

PyTorch implementation of RobustFill

Original Paper: https://arxiv.org/pdf/1703.07469.pdf

Model Checkpoint: https://huggingface.co/eddyyeo/robustfill

The RobustFill network by Devlin et al. is trained for the following task: given a few example input-output string pairs, generate a program in a domain-specific language (DSL) that transforms each given input into its given output. This program can then be used to transform unseen inputs. For example:

Given these pairs:

Input                Output
Jacob Devlin         Devlin, J.
Eddy Yeo             Yeo, E.
Andrej Karpathy      Karpathy, A.
Anatoly Yakovenko    Yakovenko, A.

The RobustFill network generates a program that can then transform any number of unseen inputs:

Unseen input         Transformed output
Elon Musk            Musk, E.
Joe Rogan            Rogan, J.
Balaji Srinivasan    Srinivasan, B.
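The task's success criterion can be written down directly: a program solves an instance when it reproduces every example output exactly, and it is then reused unchanged on new inputs. The sketch below is illustrative only; `is_consistent` and `candidate` are hypothetical names, and the candidate is hand-written Python standing in for a generated DSL program.

```python
from typing import Callable, Sequence, Tuple

Example = Tuple[str, str]

# The four example pairs from the table above.
EXAMPLES: Sequence[Example] = [
    ("Jacob Devlin", "Devlin, J."),
    ("Eddy Yeo", "Yeo, E."),
    ("Andrej Karpathy", "Karpathy, A."),
    ("Anatoly Yakovenko", "Yakovenko, A."),
]


def is_consistent(program: Callable[[str], str], examples: Sequence[Example]) -> bool:
    """Return True if `program` maps every example input to its expected output."""
    return all(program(inp) == out for inp, out in examples)


def candidate(name: str) -> str:
    """Hypothetical hand-written stand-in: '<last word>, <first initial>.'"""
    first, *_, last = name.split()
    return f"{last}, {first[0]}."


# The candidate matches all examples, so it is applied to unseen inputs.
assert is_consistent(candidate, EXAMPLES)
print([candidate(n) for n in ["Elon Musk", "Joe Rogan", "Balaji Srinivasan"]])
# prints ['Musk, E.', 'Rogan, J.', 'Srinivasan, B.']
```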

The program generated by our trained network for the example above is as follows:

Concat(
    Compose(
        Trim(),
        GetFrom(<Type.LOWER: 6>)
    ),
    ConstStr(','),
    ConstStr(' '),
    GetUpto(<Type.CHAR: 8>),
    ConstStr('.')
)
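To make the generated program's behavior concrete, here is a minimal Python sketch of an interpreter for just the five operations it uses. The semantics are assumptions chosen so that the program reproduces the tables above, not code taken from this repo: `GetFrom(t)` is taken to return everything after the first match of token type `t`, `GetUpto(t)` everything up to and including that match, and `Compose(f, g)` applies `g` first and then `f`. The `LOWER`/`CHAR` regexes and all helper names (`TYPE_REGEX`, `get_from`, `get_upto`, and so on) are hypothetical.

```python
import re
from typing import Callable

Expr = Callable[[str], str]

# Hypothetical regexes approximating the DSL token types in the program;
# the repo's actual Type definitions may differ.
TYPE_REGEX = {
    "LOWER": r"[a-z]+",     # a run of lowercase letters
    "CHAR": r"[A-Za-z0-9]",  # a single alphanumeric character
}


def const_str(c: str) -> Expr:
    # Ignore the input and emit a fixed string.
    return lambda s: c


def get_upto(t: str) -> Expr:
    # Prefix of s up to and including the first match of type t.
    def f(s: str) -> str:
        m = re.search(TYPE_REGEX[t], s)
        return s[: m.end()] if m else ""
    return f


def get_from(t: str) -> Expr:
    # Suffix of s starting right after the first match of type t.
    def f(s: str) -> str:
        m = re.search(TYPE_REGEX[t], s)
        return s[m.end():] if m else ""
    return f


def trim() -> Expr:
    # Strip leading and trailing whitespace.
    return lambda s: s.strip()


def compose(outer: Expr, inner: Expr) -> Expr:
    # Apply `inner` to the input first, then `outer` to its result.
    return lambda s: outer(inner(s))


def concat(*exprs: Expr) -> Expr:
    # Evaluate every sub-expression on the same input and join the results.
    return lambda s: "".join(e(s) for e in exprs)


# Transcription of the generated program shown above.
program = concat(
    compose(trim(), get_from("LOWER")),
    const_str(","),
    const_str(" "),
    get_upto("CHAR"),
    const_str("."),
)

for name in ["Jacob Devlin", "Elon Musk", "Balaji Srinivasan"]:
    print(name, "->", program(name))
```

For "Elon Musk", `get_from("LOWER")` skips past "lon" to give " Musk", `trim()` removes the leading space, and `get_upto("CHAR")` returns "E", so the result is "Musk, E.". Note that `Compose` order matters: trimming before `GetFrom` would leave the leading space in the output.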

See the demo notebook to reproduce the result with the model checkpoint:

The network was trained on Google Cloud on 4 NVIDIA Tesla P4 GPUs using PyTorch's DistributedDataParallel (DDP).

Instructions

Set up environment:

python3 -m venv env
source env/bin/activate
pip install -r requirements.txt

Train the network. The script automatically uses GPU(s) if they are available.

python train.py --mode full

For testing purposes, run a smaller network (on CPU) with a smaller problem size to check that the loss goes to 0.

python train.py --mode easy

Run profiler:

python train.py --mode profile

Run unit tests:

python -m unittest

Lint:

flake8