
A PyTorch implementation of the Reformer Network (https://openreview.net/pdf?id=rkgNKkHtvB).

Primary language: Jupyter Notebook. License: MIT.

Reformer


Much of this codebase is loosely translated from Google's JAX implementation in Trax: https://github.com/google/trax/blob/master/trax/models/research/reformer.py
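The central idea of the Reformer paper is locality-sensitive hashing (LSH) attention: shared query/key vectors are hashed into buckets so that attention is only computed among vectors in the same bucket. The hash the paper uses is angular LSH, h(x) = argmax([xR; -xR]) for a random matrix R with n_buckets / 2 columns. The following is a minimal standalone sketch of that hashing step in plain PyTorch. The function name lsh_buckets and its parameters are hypothetical; the sketch is not taken from this package or from Trax, and it leaves out the sorting, chunking and multi-round hashing that a full LSH attention layer performs.

```python
import torch


def lsh_buckets(x: torch.Tensor, n_buckets: int, seed: int = 0) -> torch.Tensor:
    """Hypothetical helper: assign each vector in x (shape [..., seq_len, d]) to one
    of n_buckets using angular LSH, h(x) = argmax([xR; -xR]) with a random R."""
    if n_buckets % 2 != 0:
        raise ValueError("n_buckets must be even")
    d = x.shape[-1]
    gen = torch.Generator().manual_seed(seed)
    # Random Gaussian projection of shape (d, n_buckets // 2), shared by all positions
    r = torch.randn(d, n_buckets // 2, generator=gen, dtype=x.dtype).to(x.device)
    projected = x @ r  # [..., seq_len, n_buckets // 2]
    # Concatenating xR and -xR and taking the argmax gives a bucket id in [0, n_buckets)
    return torch.cat([projected, -projected], dim=-1).argmax(dim=-1)


x = torch.randn(2, 16, 64)
buckets = lsh_buckets(x, n_buckets=8)
print(buckets.shape)  # torch.Size([2, 16])
```

Because the projection is linear, scaling a vector by a positive constant never changes its bucket, so the hash depends only on the vector's direction (its angle), and nearby directions tend to land in the same bucket.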

How to use

All of the hard work has been taken care of; all you need to do is instantiate the model:

from reformer_lm.reformer_lm import ReformerLM
import torch

# Dummy input: a random float tensor of shape (4, 4, 64)
test = torch.rand((4, 4, 64))
model = ReformerLM(
    vocab_size=300000,
    d_in=test.shape[-2],   # 4
    d_out=test.shape[-1],  # 64
    n_layers=6,
    n_heads=1,
    attn_k=test.shape[-1],
    attn_v=test.shape[-1],
)

# Forward pass
output = model(test)
print(output)
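The Reformer paper also replaces standard residual connections with reversible ones, y1 = x1 + F(x2) and y2 = x2 + G(y1), so that a layer's inputs can be recomputed from its outputs (x2 = y2 - G(y1), then x1 = y1 - F(x2)) instead of being stored for backpropagation. The following sketch uses a hypothetical ReversibleBlock class with stand-in linear sublayers and only demonstrates that invertibility. It is not this package's implementation, and it does not include the custom backward pass that actually recomputes activations to save memory.

```python
import torch
from torch import nn


class ReversibleBlock(nn.Module):
    """Hypothetical RevNet-style block: y1 = x1 + F(x2), y2 = x2 + G(y1)."""

    def __init__(self, f: nn.Module, g: nn.Module):
        super().__init__()
        self.f = f  # in the paper, F is the attention sublayer
        self.g = g  # in the paper, G is the feed-forward sublayer

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

    @torch.no_grad()
    def inverse(self, y1: torch.Tensor, y2: torch.Tensor):
        # Recover the inputs from the outputs, undoing forward() in reverse order
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2


torch.manual_seed(0)
d = 64
block = ReversibleBlock(
    nn.Linear(d, d),  # stand-in for attention
    nn.Sequential(nn.Linear(d, d), nn.ReLU(), nn.Linear(d, d)),  # stand-in for feed-forward
)
x1, x2 = torch.randn(4, 10, d), torch.randn(4, 10, d)
y1, y2 = block(x1, x2)
r1, r2 = block.inverse(y1, y2)
print(torch.allclose(r1, x1, atol=1e-5), torch.allclose(r2, x2, atol=1e-5))
```

In a memory-efficient version, a custom torch.autograd.Function would call inverse() during the backward pass so that only the final layer's outputs need to be kept.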

This model is still being tested and will continue to receive updates. PRs are welcome! Feel free to use the Docker container for development. I have been testing code against the original paper in notebooks and then refactoring it back into the package.
