
Unofficial implementation of Linear Recurrent Units, by DeepMind, in PyTorch

Primary language: Python. License: MIT.


An implementation of Linear Recurrent Units (LRUs), proposed by DeepMind, in PyTorch. LRUs are inspired by deep state-space models (SSMs), particularly S4 and S5.


  • PyTorch does not currently provide an associative (parallel) scan, so this PyTorch implementation will very likely be slower than a JAX implementation, which can use jax.lax.associative_scan.
  • Complex tensors are still in beta in PyTorch and do not fully support .half(), so using torch.float16 is not advised.
  • Certain tensors are created on every forward pass. This is only necessary during training; for inference, these tensors could be computed once and frozen to speed things up.
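Because PyTorch lacks a built-in associative scan, one workaround is to write a parallel prefix scan by hand. The sketch below is not part of LRU_pytorch; the function names `linear_scan_sequential` and `linear_scan_parallel` are hypothetical. It assumes a diagonal recurrence x_k = a_k * x_{k-1} + b_k with zero initial state (the form the LRU state update takes) and combines elements with the associative operator (a_i, b_i) ∘ (a_j, b_j) = (a_j a_i, a_j b_i + b_j), where i comes before j.

```python
import torch


def linear_scan_sequential(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Reference loop: x[k] = a[k] * x[k-1] + b[k], with x[-1] = 0.

    a and b have shape (seq_len, ...) and may be real or complex.
    """
    x = torch.zeros_like(b[0])
    out = []
    for k in range(b.shape[0]):
        x = a[k] * x + b[k]
        out.append(x)
    return torch.stack(out)


def linear_scan_parallel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hillis-Steele inclusive scan over the time axis (dim 0).

    Each pair (A[k], B[k]) summarizes the affine map x -> A[k] * x + B[k]
    for a window of steps ending at k; the window doubles every iteration,
    so only O(log seq_len) vectorized steps are needed.
    """
    A, B = a, b
    seq_len = b.shape[0]
    offset = 1
    while offset < seq_len:
        # Earlier window (ending at k - offset) composed with the later one (ending at k).
        A_prev, B_prev = A[:-offset], B[:-offset]
        A_cur, B_cur = A[offset:], B[offset:]
        new_A = A_cur * A_prev
        new_B = A_cur * B_prev + B_cur
        # Positions k < offset already cover the full prefix and stay unchanged.
        A = torch.cat([A[:offset], new_A])
        B = torch.cat([B[:offset], new_B])
        offset *= 2
    # With zero initial state, B[k] is exactly x[k].
    return B


# Usage: a complex diagonal "Lambda" with |lambda| < 1, shared across time steps.
torch.manual_seed(0)
seq_len, n_state = 64, 8
lam = torch.polar(torch.rand(n_state) * 0.5 + 0.5, torch.rand(n_state) * 6.28)
a = lam.expand(seq_len, n_state)
b = torch.randn(seq_len, n_state, dtype=torch.cfloat)
max_err = (linear_scan_parallel(a, b) - linear_scan_sequential(a, b)).abs().max()
print(f"max abs difference: {max_err.item():.2e}")
```

This uses the simple Hillis-Steele scheme, which performs O(L log L) total work rather than the O(L) of a work-efficient Blelloch scan; whether it beats a plain loop in practice depends on sequence length and hardware.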


$ pip install LRU-pytorch


import torch

from LRU_pytorch import LRU

# Create a single Linear Recurrent Unit that takes inputs of size (batch_size, seq_length, 30) (or (seq_length, 30)),
# has an internal state-space variable of size 10, and returns outputs of size (batch_size, seq_length, 50) (or (seq_length, 50)).

lru = LRU(
    in_features=30,
    out_features=50,
    state_features=10,
)

preds = lru(torch.randn([2, 50, 30]))  # Get predictions of shape (2, 50, 50)


in_features: int. The size of each timestep of the input sequence.

out_features: int. The size of each timestep of the output sequence.

state_features: int. The size of the internal state variable.
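To make the three parameters concrete, here is a minimal sequential reference of an LRU layer following the recurrence described in the paper: a diagonal complex state matrix with exponentially parametrized eigenvalues and a γ normalization of the input. It is a hypothetical sketch, not the package's code: the class name `MinimalLRU`, the ring-initialization bounds `r_min`/`r_max`, `max_phase`, and the weight scalings are assumptions. Note that `lam` and `gamma` are rebuilt from the parameters on every call, which is the kind of per-forward tensor the notes above suggest freezing for inference.

```python
import math

import torch
import torch.nn as nn


class MinimalLRU(nn.Module):
    """Hypothetical reference LRU layer (plain loop, no parallel scan).

    x_k = lam * x_{k-1} + gamma * (B @ u_k)   # complex, diagonal lam
    y_k = Re(C @ x_k) + D @ u_k
    """

    def __init__(self, in_features: int, out_features: int, state_features: int,
                 r_min: float = 0.4, r_max: float = 0.99, max_phase: float = 2 * math.pi):
        super().__init__()
        # Ring init (assumed bounds): |lam| in [r_min, r_max), uniform over the ring's area.
        u1 = torch.rand(state_features)
        u2 = torch.rand(state_features)
        self.nu_log = nn.Parameter(torch.log(-0.5 * torch.log(u1 * (r_max**2 - r_min**2) + r_min**2)))
        self.theta_log = nn.Parameter(torch.log(max_phase * u2))
        # Complex weights stored as real/imaginary pairs so any optimizer can update them.
        scale_b = 1 / math.sqrt(2 * in_features)
        self.B_re = nn.Parameter(torch.randn(state_features, in_features) * scale_b)
        self.B_im = nn.Parameter(torch.randn(state_features, in_features) * scale_b)
        scale_c = 1 / math.sqrt(state_features)
        self.C_re = nn.Parameter(torch.randn(out_features, state_features) * scale_c)
        self.C_im = nn.Parameter(torch.randn(out_features, state_features) * scale_c)
        self.D = nn.Parameter(torch.randn(out_features, in_features) / math.sqrt(in_features))

    def forward(self, u: torch.Tensor) -> torch.Tensor:
        # Accept (batch, seq_len, in_features) or (seq_len, in_features).
        unbatched = u.dim() == 2
        if unbatched:
            u = u.unsqueeze(0)
        # Recomputed every call: lam = exp(-exp(nu_log) + i*exp(theta_log)), so |lam| < 1.
        lam = torch.exp(-torch.exp(self.nu_log) + 1j * torch.exp(self.theta_log))
        gamma = torch.sqrt(1 - lam.abs() ** 2)  # keeps the state scale bounded
        B = torch.complex(self.B_re, self.B_im) * gamma.unsqueeze(-1)  # (state, in)
        C = torch.complex(self.C_re, self.C_im)  # (out, state)
        Bu = u.to(B.dtype) @ B.T  # (batch, seq_len, state)
        x = torch.zeros(u.shape[0], lam.shape[0], dtype=lam.dtype, device=u.device)
        states = []
        for k in range(u.shape[1]):
            x = lam * x + Bu[:, k]
            states.append(x)
        X = torch.stack(states, dim=1)  # (batch, seq_len, state)
        y = (X @ C.T).real + u @ self.D.T  # (batch, seq_len, out)
        return y.squeeze(0) if unbatched else y


layer = MinimalLRU(in_features=30, out_features=50, state_features=10)
print(layer(torch.randn(2, 50, 30)).shape)  # torch.Size([2, 50, 50])
print(layer(torch.randn(50, 30)).shape)     # torch.Size([50, 50])
```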


Resurrecting Recurrent Neural Networks for Long Sequences