
Unofficial PyTorch implementation of the paper "cosFormer: Rethinking Softmax In Attention".


cosFormer-PyTorch

An unofficial PyTorch implementation of the model proposed in the paper "cosFormer: Rethinking Softmax In Attention" (submitted to ICLR 2022).

[Figure: image taken from the paper.]


Paper Summary

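Like many others, this paper builds on recent work on linear attention, which is calculated as $\phi(Q)\left(\phi(K)^T V\right)$ instead of $\left(\phi(Q)\phi(K)^T\right)V$, where $\phi$ is a kernel function. This reduces the complexity from $\mathcal{O}(N^2 d)$ to $\mathcal{O}(N d^2)$, where $N$ is the sequence length and $d$ is the head dimension. The authors propose to extend this mechanism by including relative distance information in the Q, K product as $\phi(Q_i)\phi(K_j)^T \cos\left(\frac{\pi}{2} \times \frac{i - j}{M}\right)$, where $M$ is a constant no smaller than the sequence length. After expanding the trigonometric identity, the full equation (omitting the normalizing denominator) becomes:

$$Q^{\cos}\left((K^{\cos})^T V\right) + Q^{\sin}\left((K^{\sin})^T V\right)$$

where $Q_i^{\cos} = \phi(Q_i)\cos\left(\frac{\pi i}{2M}\right)$, $Q_i^{\sin} = \phi(Q_i)\sin\left(\frac{\pi i}{2M}\right)$, etc.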
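
The expansion step can be written out explicitly. The sketch below assumes the paper's notation as I understand it: $\phi$ is the kernel, $i$ and $j$ are the query and key positions, and $M$ is a constant at least as large as the sequence length. It only uses the angle-difference identity $\cos(a - b) = \cos a \cos b + \sin a \sin b$.

```latex
\begin{aligned}
\phi(Q_i)\,\phi(K_j)^T \cos\!\left(\frac{\pi}{2}\cdot\frac{i-j}{M}\right)
  &= \phi(Q_i)\,\phi(K_j)^T
     \left[\cos\frac{\pi i}{2M}\cos\frac{\pi j}{2M}
         + \sin\frac{\pi i}{2M}\sin\frac{\pi j}{2M}\right] \\
  &= Q_i^{\cos}\,(K_j^{\cos})^T + Q_i^{\sin}\,(K_j^{\sin})^T ,
\\[4pt]
\sum_j \left[\,Q_i^{\cos}(K_j^{\cos})^T + Q_i^{\sin}(K_j^{\sin})^T\right] V_j
  &= Q_i^{\cos} \sum_j (K_j^{\cos})^T V_j
   \;+\; Q_i^{\sin} \sum_j (K_j^{\sin})^T V_j .
\end{aligned}
```

Because the two sums over $j$ do not depend on $i$, they can be computed once and reused for every query position, which is where the linear cost in sequence length comes from.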

As the author of this repo possesses neither the time nor the ability, only the non-causal version of this approach is implemented.
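
To make the non-causal formulation concrete, here is a minimal single-head sketch that is independent of this repo's code. The function names, the choice `M = N`, the 1-based position indexing, and the ReLU kernel are assumptions made for illustration and are not taken from `models/kernel_transformer.py`. It checks numerically that the linear-time form matches the quadratic form that builds the full N x N weight matrix.

```python
import math
import torch


def cos_linear_attention(q, k, v, eps=1e-5):
    """Non-causal cosine-reweighted linear attention for one head.

    q, k: (N, d) raw queries/keys; v: (N, d_v). Cost is O(N * d * d_v).
    """
    q, k = torch.relu(q), torch.relu(k)      # kernel phi = ReLU (assumption)
    n = q.shape[0]
    m = n                                     # assumption: M equals the sequence length
    angle = (math.pi / 2) * torch.arange(1, n + 1, dtype=q.dtype) / m
    cos, sin = angle.cos()[:, None], angle.sin()[:, None]
    q_cos, q_sin = q * cos, q * sin
    k_cos, k_sin = k * cos, k * sin
    # Contract keys with values first, so no (N, N) matrix is ever formed.
    num = q_cos @ (k_cos.T @ v) + q_sin @ (k_sin.T @ v)
    den = q_cos @ k_cos.sum(0) + q_sin @ k_sin.sum(0)
    return num / (den[:, None] + eps)


def cos_quadratic_attention(q, k, v, eps=1e-5):
    """Reference O(N^2) version that applies cos(pi/2 * (i - j) / M) directly."""
    q, k = torch.relu(q), torch.relu(k)
    n = q.shape[0]
    idx = torch.arange(1, n + 1, dtype=q.dtype)
    weights = torch.cos((math.pi / 2) * (idx[:, None] - idx[None, :]) / n)
    scores = (q @ k.T) * weights
    return (scores @ v) / (scores.sum(-1, keepdim=True) + eps)


torch.manual_seed(0)
n, d = 16, 8
q = torch.randn(n, d, dtype=torch.float64)
k = torch.randn(n, d, dtype=torch.float64)
v = torch.randn(n, d, dtype=torch.float64)

out_linear = cos_linear_attention(q, k, v)
out_quadratic = cos_quadratic_attention(q, k, v)
max_diff = (out_linear - out_quadratic).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The two functions should agree up to floating-point rounding. A causal version would need running prefix sums over the key positions instead of the single global sums used here.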

Installation

$ git clone https://github.com/davidsvy/cosformer-pytorch
$ cd cosformer-pytorch
$ pip install -r requirements.txt

Usage

from models.kernel_transformer import Kernel_transformer
import torch

model = Kernel_transformer(
    # Linear attention args:
    use_cos=True,         # Whether to use the cosine reweighting mechanism proposed in the paper.
    kernel='relu',        # Kernel that approximates softmax. Available options are 'relu' and 'elu'.
    denom_eps=1e-5,       # Added to the denominator of linear attention for numerical stability.
    # If use_cos=True & kernel='relu' the model is equivalent to https://openreview.net/pdf?id=Bl8CQrx2Up4
    # If use_cos=False & kernel='elu' the model is equivalent to https://arxiv.org/pdf/2006.16236.pdf
    # Vanilla transformer args:
    d_model=512,
    n_heads=8, 
    n_layers=6,
    n_emb=20000, 
    ffn_ratio=4, 
    rezero=True,          # If True, use the ReZero architecture from https://arxiv.org/pdf/2003.04887.pdf, else the Pre-LN architecture from https://arxiv.org/pdf/2002.04745.pdf
    ln_eps=1e-5, 
    bias=False, 
    dropout=0.2, 
    max_len=1024, 
    xavier=True
)

input_ids = torch.randint(0, 20000, [4, 100])
lengths = torch.randint(1, 100, [4])
attention_mask = torch.arange(100)[None, :] < lengths[:, None]

output = model(
    input_ids=input_ids,
    lengths=lengths,
    attention_mask=attention_mask,
)
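
The `attention_mask` line above relies on broadcasting, which is easier to see on a tiny batch. The sketch below reproduces only that construction. The last two lines are a hypothetical illustration of how such a mask could be used to zero out kernel features at padded key positions; how the model in this repo actually consumes the mask is not shown here and may differ.

```python
import torch

lengths = torch.tensor([2, 4])   # true lengths of two sequences
max_len = 5                      # padded length of the batch

# (1, max_len) positions compared against (batch, 1) lengths -> (batch, max_len)
attention_mask = torch.arange(max_len)[None, :] < lengths[:, None]
print(attention_mask)
# True marks real tokens, False marks padding.

# Hypothetical use: remove padded keys from the linear-attention sums.
phi_k = torch.ones(2, max_len, 3)                  # stand-in for phi(K), shape (batch, N, d)
phi_k_masked = phi_k * attention_mask[:, :, None]  # padded rows become zero
```

With the keys zeroed this way, padded positions contribute nothing to either the numerator or the denominator of linear attention.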

Citations

@inproceedings{
anonymous2022cosformer,
title={cosFormer: Rethinking Softmax In Attention},
author={Anonymous},
booktitle={Submitted to The Tenth International Conference on Learning Representations },
year={2022},
url={https://openreview.net/forum?id=Bl8CQrx2Up4},
note={under review}
}