StableSSM: Alleviating the Curse of Memory in State-space Models through Stable Reparameterization

Built with PyTorch Lightning; configuration managed with Hydra.

Description

Use stable reparameterizations to improve long-term memory learning and optimization stability in state-space models (SSMs).

SSMs

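The state-space models considered here are linear RNNs with layer-wise nonlinear activations: the hidden state $h$ follows a linear (affine) recurrence driven by the input $x$, and the nonlinearity $\sigma$ is applied only when reading out the output $y$.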

Discrete-time case: $$h_{k+1} = \Lambda h_k+Ux_k+b$$

$$y_k = c^\top \sigma(h_k)$$
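The discrete-time recurrence can be sketched in plain Python. This is a minimal illustration, not the repository's implementation; it assumes a diagonal $\Lambda$ (stored as the vector `lam`), a zero initial state $h_0 = 0$, and $\sigma = \tanh$ by default. The function name `ssm_discrete` is hypothetical.

```python
import math
from typing import Callable, List, Sequence


def ssm_discrete(
    lam: Sequence[float],
    U: Sequence[Sequence[float]],
    b: Sequence[float],
    c: Sequence[float],
    xs: Sequence[Sequence[float]],
    sigma: Callable[[float], float] = math.tanh,
) -> List[float]:
    """h_{k+1} = Lambda h_k + U x_k + b,  y_k = c^T sigma(h_k), with Lambda = diag(lam)."""
    h = [0.0] * len(lam)  # assumed initial state h_0 = 0
    ys = []
    for x in xs:
        # Readout uses the current state h_k; the nonlinearity acts only here.
        ys.append(sum(ci * sigma(hi) for ci, hi in zip(c, h)))
        # Linear recurrence: a diagonal Lambda scales each state channel independently.
        h = [
            lam_i * h_i + sum(u_ij * x_j for u_ij, x_j in zip(U_i, x)) + b_i
            for lam_i, h_i, U_i, b_i in zip(lam, h, U, b)
        ]
    return ys
```

With one channel, $\Lambda = 0.5$, identity $\sigma$ and a unit impulse at $k = 0$, `ssm_discrete([0.5], [[1.0]], [0.0], [1.0], [[1.0], [0.0], [0.0], [0.0]], sigma=lambda z: z)` returns `[0.0, 1.0, 0.5, 0.25]`: after the impulse the output shrinks by a factor $\Lambda$ per step, i.e. the memory decays exponentially.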

Continuous-time case: $$\frac{dh_{t}}{dt} = \Lambda h_t+Ux_t+b$$

$$y_t = c^\top \sigma(h_t)$$
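One common way to simulate the continuous-time system (an assumption here, not necessarily what the repository does) is a zero-order hold: treat the input as constant over each step of length $\Delta$ and integrate exactly. For a diagonal $\Lambda$ with entries $\lambda_i$ this gives $h_i(t+\Delta) = e^{\lambda_i \Delta} h_i(t) + \frac{e^{\lambda_i \Delta} - 1}{\lambda_i} (Ux + b)_i$. The names `zoh_coeffs` and `ssm_continuous` are hypothetical.

```python
import math
from typing import Callable, List, Sequence, Tuple


def zoh_coeffs(lam: float, dt: float) -> Tuple[float, float]:
    """Exact one-step coefficients (a, g) for dh/dt = lam*h + v with v held constant over dt."""
    a = math.exp(lam * dt)
    # g = (exp(lam*dt) - 1) / lam; expm1 keeps precision for small lam*dt,
    # and the lam -> 0 limit of g is dt.
    g = math.expm1(lam * dt) / lam if lam != 0.0 else dt
    return a, g


def ssm_continuous(
    lam: Sequence[float],
    U: Sequence[Sequence[float]],
    b: Sequence[float],
    c: Sequence[float],
    xs: Sequence[Sequence[float]],
    dt: float,
    sigma: Callable[[float], float] = math.tanh,
) -> List[float]:
    """Sample y(t_k) = c^T sigma(h(t_k)) at t_k = k*dt, assuming h(0) = 0 and x constant on each step."""
    coeffs = [zoh_coeffs(l, dt) for l in lam]
    h = [0.0] * len(lam)
    ys = []
    for x in xs:
        ys.append(sum(ci * sigma(hi) for ci, hi in zip(c, h)))
        # Channels are decoupled because Lambda is diagonal, so update each in place.
        for i, (a, g) in enumerate(coeffs):
            drive = sum(u_ij * x_j for u_ij, x_j in zip(U[i], x)) + b[i]  # (U x + b)_i
            h[i] = a * h[i] + g * drive
    return ys
```

With $\lambda = -1$, a constant unit input and $\Delta = 0.1$, the sample at $t = 1$ (`ys[10]`) equals $1 - e^{-1}$ up to rounding, matching the exact solution $h(t) = 1 - e^{-t}$ of $\frac{dh}{dt} = -h + 1$ with $h(0) = 0$.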

Stable reparameterization

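Let $W$ denote the trainable parameters. Without reparameterization the parameters are used directly, $$\Lambda = W,$$ which is an unstable parameterization: nothing prevents training from pushing $\Lambda$ into the unstable region. Stable reparameterizations instead set $$\Lambda = -e^{W} \quad \text{or} \quad \Lambda = -\log(1+e^{W}).$$ Applied elementwise to the diagonal entries of $\Lambda$, both maps return strictly negative values for any real $W$, so the continuous-time dynamics stay stable throughout training.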
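The three parameterizations can be compared with a few scalar functions. This sketch assumes the maps act elementwise on a diagonal $\Lambda$; the function names and the step size `dt` are hypothetical. The softplus form is evaluated as $\max(w,0) + \log(1+e^{-|w|})$, an algebraically equivalent rewrite that avoids overflowing $e^{w}$ for large $w$.

```python
import math


def lam_identity(w: float) -> float:
    # No reparameterization: Lambda = W, which can become positive (unstable).
    return w


def lam_exp(w: float) -> float:
    # Stable: Lambda = -exp(W) < 0 for every real W.
    return -math.exp(w)


def lam_softplus(w: float) -> float:
    # Stable: Lambda = -log(1 + exp(W)) < 0, computed as
    # -(max(W, 0) + log1p(exp(-|W|))) to avoid overflow for large W.
    return -(max(w, 0.0) + math.log1p(math.exp(-abs(w))))


if __name__ == "__main__":
    dt = 0.1  # hypothetical step size for the discrete-time decay factor
    for w in [-2.0, 0.0, 2.0]:
        for name, f in [("identity", lam_identity), ("exp", lam_exp), ("softplus", lam_softplus)]:
            lam = f(w)
            # exp(lam * dt) < 1 means the state contracts; > 1 means it grows.
            print(f"w={w:+.1f} {name:>8}: Lambda={lam:+.4f}, decay factor={math.exp(lam * dt):.4f}")
```

For $W > 0$ the identity map gives a decay factor above 1, so the hidden state grows without bound, while both stable maps keep the factor below 1 for every $W$ printed.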

Installation

Pip

# clone project
git clone git@github.com:radarFudan/StableSSM.git
cd StableSSM

# [OPTIONAL] create conda environment
conda create -n StableSSM python=3.11
conda activate StableSSM

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements (run ONE of the following, matching your CUDA version)
pip --default-timeout=1000 install -U -r requirements.txt  # CUDA 12
pip --default-timeout=1000 install -U -r requirements_11.txt --index-url https://download.pytorch.org/whl/cu117  # CUDA 11.7

Conda

# clone project
git clone git@github.com:radarFudan/StableSSM.git
cd StableSSM

# create conda environment and install dependencies
conda env create -f environment.yaml -n StableSSM

# activate conda environment
conda activate StableSSM

Refs

Curse of memory phenomenon / definition of memory functions / concept of stable approximation

@inproceedings{
    wang2023statespace,
    title={State-space models with layer-wise nonlinearity are universal approximators with exponential decaying memory},
    author={Shida Wang and Beichen Xue},
    booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
    year={2023},
    url={https://openreview.net/forum?id=i0OmcF14Kf}
}
@inproceedings{
    wang2024stablessm,
    title={Stable{SSM}: Alleviating the Curse of Memory in State-space Models through Stable Reparameterization},
    author={Shida Wang and Qianxiao Li},
    booktitle={Forty-first International Conference on Machine Learning},
    year={2024},
    url={https://openreview.net/forum?id=nMN5hNZMQK}
}

Survey on sequence modelling from the approximation perspective

@article{JML-2-1,
    author = {Haotian Jiang and Qianxiao Li and Zhong Li and Shida Wang},
    title = {A Brief Survey on the Approximation Theory for Sequence Modelling},
    journal = {Journal of Machine Learning},
    year = {2023},
    volume = {2},
    number = {1},
    pages = {1--30},
    abstract = {We survey current developments in the approximation theory of sequence modelling in machine learning. Particular emphasis is placed on classifying existing results for various model architectures through the lens of classical approximation paradigms, and the insights one can gain from these results. We also outline some future research directions towards building a theory of sequence modelling.},
    issn = {2790-2048},
    doi = {10.4208/jml.221221},
    url = {http://global-sci.org/intro/article_detail/jml/21511.html}
}