
A State-Space Model with Rational Transfer Function Representation.

License: Apache License 2.0 (Apache-2.0)

RTF

This repository contains the official implementation of the rational transfer function (RTF) parametrization for state-space layers.


Repository Structure

Setup and Usage Guides

Experiment-specific setup and usage guides:

Setup for standalone rtf.py:

pip3 install -r requirements.txt

Example Usage

import torch

from rtf import RTF

seq_len = 1024
d_model = 32
init = "xavier"  # Other options: "zeros" (default), "montel"
constraint = "l1_montel"  # Other options: "no" or None (default)
batch_size = 1

# Random input of shape (batch_size, seq_len, d_model).
# Named `x` to avoid shadowing Python's built-in `input`.
x = torch.rand(batch_size, seq_len, d_model)

model = RTF(
    d_model=d_model,
    state_size=128,
    trunc_len=seq_len,
    init=init,
    constraint=constraint,
)

output = model(x)
print(output.shape)
# Expected output: torch.Size([1, 1024, 32])
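
For intuition about what a rational transfer function computes, here is a minimal pure-Python sketch. It is not the code in rtf.py: the function names (`impulse_response`, `causal_conv`, `recurrent_filter`) and the coefficient convention H(z) = (b[0] + b[1] z^-1 + ...) / (1 + a[0] z^-1 + a[1] z^-2 + ...) are assumptions chosen for illustration. The sketch shows that convolving an input with the truncated impulse response of H(z) gives the same output as running the equivalent recurrence from a zero initial state.

```python
def impulse_response(b, a, length):
    """First `length` taps of H(z) = B(z) / A(z) with a monic denominator.

    b: numerator coefficients b[0], b[1], ... (powers z^0, z^-1, ...)
    a: denominator coefficients for z^-1, z^-2, ... (leading 1 is implicit)
    """
    h = []
    for n in range(length):
        acc = b[n] if n < len(b) else 0.0
        for k, a_k in enumerate(a, start=1):
            if n - k >= 0:
                acc -= a_k * h[n - k]
        h.append(acc)
    return h


def causal_conv(h, u):
    """Causal convolution y[n] = sum_{k<=n} h[k] * u[n-k]; needs len(h) >= len(u)."""
    return [sum(h[k] * u[n - k] for k in range(n + 1)) for n in range(len(u))]


def recurrent_filter(b, a, u):
    """Same filter evaluated step by step, starting from a zero initial state."""
    y = []
    for n in range(len(u)):
        acc = sum(b_k * u[n - k] for k, b_k in enumerate(b) if n - k >= 0)
        acc -= sum(a_k * y[n - k] for k, a_k in enumerate(a, start=1) if n - k >= 0)
        y.append(acc)
    return y


# Hypothetical toy coefficients: one zero, one pole at z = 0.9.
b = [1.0, 0.5]
a = [-0.9]
u = [1.0, 2.0, 0.0, -1.0, 3.0]

h = impulse_response(b, a, len(u))  # kernel truncated to the sequence length
y_conv = causal_conv(h, u)
y_rec = recurrent_filter(b, a, u)
max_diff = max(abs(p - q) for p, q in zip(y_conv, y_rec))
print(max_diff)  # close to zero: both evaluations agree
```

The nested loops make this quadratic in the sequence length, which is fine for a toy check. A practical layer would typically compute the convolution with FFTs instead.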

Tutorials

Citation

You can cite our work with:

@article{parnichkun2024statefree,
  title={State-Free Inference of State-Space Models: The Transfer Function Approach}, 
  author={Rom N. Parnichkun and Stefano Massaroli and Alessandro Moro and Jimmy T. H. Smith and Ramin Hasani and Mathias Lechner and Qi An and Christopher Ré and Hajime Asama and Stefano Ermon and Taiji Suzuki and Atsushi Yamashita and Michael Poli},
  journal={International Conference on Machine Learning},
  year={2024}
}