
An LM forked from my transformer-train-script repo that replaces attention with a novel idea called "matrix recurrent units."


mru-lm

How to Run

Use the command `python main.py --device=cuda --dataset=tiny_stories.py`. Set `--dataset=shakespeare_char` for the character-level Shakespeare dataset.

What is this?

Introduction

This is a project that replaces attention in a traditional GPT-2-based transformer with my idea, the linear-complexity matrix recurrent unit (MRU). This repo is forked from my repo transformer-train-script. Based on testing on the shakespeare_char toy dataset, the MRU seems to work well as a replacement for attention.

[Figure: loss plot, MRU-LM vs Transformer]

The above loss plot is from the first training attempt, using the independent-heads branch of this repo and of my other repo https://github.com/mikayahlevi/transformer-train-script.

Moving Forward

I have limited compute and experience with data science, so I haven't been able to test the LM on much other than the toy dataset. Firstly, I would like to test this on larger and more informative datasets. If anyone wants to help me with this, reach out to me at mikayahlevi@gmail.com or by any other means. Secondly, the MRU is still relatively slow compared to the theoretical number of operations it should take, so I would like to investigate writing a CUDA kernel or just optimizing the PyTorch code.

Explanation

General Idea

The idea of a matrix recurrent unit is dictated by the update rule $H_t = H_{t-1} X_t$ and $H_1 = X_1$, where $X$ and $H$ are $\mathbb{R}^{s \times d_o \times d_o}$ sequences of square matrices ($d_o$ will be clarified later). The primary difference between this and a traditional RNN is that no initial vector is passed through the linears; instead, the first state is a matrix, so the output is also a matrix. My motivation for coming up with this idea is based on the following reasons:

  • Matrix multiplication is associative but not commutative. The associativity means I can compute the cumulative matrix product using an (inclusive) parallel scan (see the sketch after this list). The lack of commutativity means that the order of the tokens is automatically incorporated into the MRU.
  • When you try to do this kind of scan on a traditional RNN, the number of operations scales cubically with the number of elements in the output state, meaning that limited information is retained relative to the amount of computation. On the other hand, if the states are matrices, the number of operations as a function of the number of elements in the output state is $((d_o)^2)^\frac{3}{2}$, where $(d_o)^2$ is the number of elements in the square $d_o \times d_o$ output matrix state. Some more info here: https://arxiv.org/abs/1709.04057.
  • When processing the tokens sequentially or in parallel with the Brent-Kung parallel scan, the cost of the network scales linearly with the sequence length, in contrast to attention, which scales quadratically.
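
As a concrete reference, here is a minimal sketch of the recurrence itself: a loop that takes a sequence of square matrices and returns the running matrix products. This is just an illustration (the tensor layout `(s, h, d_o, d_o)` is my assumption), not the code from this repo.

```python
import torch

def mru_recurrence(x: torch.Tensor) -> torch.Tensor:
    # x: (s, h, d_o, d_o) -- one square matrix per position and head.
    # Returns H with H[t] = x[0] @ x[1] @ ... @ x[t] (the MRU states).
    states = [x[0]]
    for t in range(1, x.shape[0]):
        states.append(states[-1] @ x[t])  # H_t = H_{t-1} X_t
    return torch.stack(states)

# Example: sequence length 6, 2 heads, 4x4 matrix states.
X = torch.randn(6, 2, 4, 4) / 4 ** 0.5
H = mru_recurrence(X)
print(H.shape)  # torch.Size([6, 2, 4, 4])
```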

Dimensionality and Construction

Dimensions and Computation Complexity

For the rest of this document, let's call the sequence length $s$, the number of heads $h$, the embedding size of the network $d_e$, and the state size of the network $d_s$. The head size, consequently, is $d_h = \frac{d_s}{h}$. The matrix state order, i.e. the width/height of the matrix states, is $d_o = \sqrt{d_h} = \sqrt{\frac{d_s}{h}}$. Lastly, the embedding chunk size is $d_c = \frac{d_e}{h d_o}$, so each head's slice of the embedding splits into $d_o$ chunks of size $d_c$.
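
For example, with made-up sizes (these particular numbers are assumptions for illustration, not the repo's defaults):

```python
# Hypothetical sizes, just to make the symbols concrete.
d_e, d_s, h = 512, 1024, 16    # embedding size, state size, number of heads

d_h = d_s // h                 # head size: 64
d_o = int(d_h ** 0.5)          # matrix state order: 8 (d_h must be a perfect square)
d_c = d_e // (h * d_o)         # embedding chunk size: 4 (d_e must divide evenly)
```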

The number of operations for the MRU itself is:

  • Using recurrence

$$ s h (d_o)^3 = s h (\frac{d_s}{h})^\frac{3}{2} $$

  • Using the Brent-Kung scan

$$ 2 s h (d_o)^3 = 2 s h (\frac{d_s}{h})^\frac{3}{2} $$

  • Using the Hillis-Steele scan

$$ \log_2(s) s h (d_o)^3 = \log_2(s) s h (\frac{d_s}{h})^\frac{3}{2} $$

The parallel scans take more computation, but they have the advantage of using parallel hardware more efficiently. While an RNN would take $s$ steps on a GPU with infinite cores, the Hillis-Steele scan only takes $\log_2(s)$ steps, and the Brent-Kung scan takes $2 \log_2(s)$. The scans are just the Brent-Kung and Hillis-Steele prefix-sum algorithms repurposed for matrix multiplication.
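
To illustrate, here is a rough Hillis-Steele-style inclusive scan over the sequence dimension with matrix multiplication as the combine operation. It is a sketch under the same `(s, h, d_o, d_o)` layout assumed earlier, not the repo's optimized implementation.

```python
import torch

def matmul_scan_hillis_steele(x: torch.Tensor) -> torch.Tensor:
    # x: (s, h, d_o, d_o). Returns the inclusive cumulative matrix product
    # along dim 0: out[t] = x[0] @ x[1] @ ... @ x[t].
    out = x
    offset = 1
    while offset < out.shape[0]:
        # Positions t >= offset absorb the partial product ending at t - offset.
        out = torch.cat([out[:offset], out[:-offset] @ out[offset:]], dim=0)
        offset *= 2
    return out

# Quick check against the sequential recurrence.
X = torch.randn(8, 2, 4, 4, dtype=torch.float64) / 2
ref = [X[0]]
for t in range(1, X.shape[0]):
    ref.append(ref[-1] @ X[t])
assert torch.allclose(matmul_scan_hillis_steele(X), torch.stack(ref))
```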

Restructuring the Vectors into Matrices and Back

The MRU should take in a sequence of vectors and return a sequence of vectors, like any other traditional operation in a neural network. For now I'll be ignoring the batch and sequence dimensions and only focusing on the last dimension. $X$ and $H$ are matrices, so the network somehow has to convert vectors to matrices and back. In this case we will call $x$ the input (not the same as $X$) and $y$ the output. $X$ is generated following this formula:

$$ X = \text{reshape}(x, h, d_o, d_c) W_{in} $$

$W_{in}$ is an $h \times d_c \times d_o$ tensor, which has the effect of essentially matrix-multiplying each head of $\text{reshape}(x, h, d_o, d_c)$ by a unique weight matrix, producing the square $d_o \times d_o$ matrices $X$. The reshaping splits each head's slice of the embedding into $d_o$ chunks of size $d_c$.

$y$ is simply generated by the reverse:

$$ y = \text{reshape}(H W_{out}, d_e) $$

Therefore, $W_{out}$ is an $h \times d_o \times d_c$ tensor, which likewise has the effect of matrix-multiplying each head by a unique weight matrix before the result is flattened back into a $d_e$-dimensional vector.
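
A rough PyTorch sketch of these two projections, with shapes following the definitions above (the einsum layout and initialization are my guesses, not necessarily how the repo implements it):

```python
import torch

s, d_e, d_s, h = 8, 512, 1024, 16           # assumed sizes
d_h = d_s // h                              # 64
d_o = int(d_h ** 0.5)                       # 8
d_c = d_e // (h * d_o)                      # 4

W_in = torch.randn(h, d_c, d_o) / d_c ** 0.5
W_out = torch.randn(h, d_o, d_c) / d_o ** 0.5

x = torch.randn(s, d_e)

# x -> X: chunk each embedding into (h, d_o, d_c), then map d_c -> d_o per head.
X = torch.einsum('shoc,hcp->shop', x.reshape(s, h, d_o, d_c), W_in)  # (s, h, d_o, d_o)

# ... the MRU recurrence/scan turns X into the states H (same shape) ...
H = X  # placeholder so the output projection below lines up

# H -> y: map d_o -> d_c per head, then flatten back to a d_e-dimensional vector.
y = torch.einsum('shop,hpc->shoc', H, W_out).reshape(s, d_e)         # (s, d_e)
```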

Comparison with Similar Projects

After finishing this project, I've been informed that it actually has quite a bit of overlap with DeltaNet (https://arxiv.org/abs/2102.11174) and RWKV7 (https://x.com/BlinkDL_AI/status/1833863117480280528). Note that I may be misunderstanding these other projects. The recurrence relation of RWKV7 and DeltaNet is almost a subset of the MRU with additional structure on $X$, except they also have a term that is added to (the equivalent of) $H_t$ at each timestep. Despite the overlap, the MRU still has two key differences.

  • RWKV7 and DeltaNet don't derive an efficient scan like I do in the next section. The paper Parallelizing Linear Transformers with the Delta Rule over Sequence Length (https://arxiv.org/pdf/2406.06484) does derive a less parallel (if I'm not mistaken) chunkwise form, though.
  • The MRU deconstructs the states to extract one output feature per state-matrix element by reshaping them. DeltaNet and RWKV, on the other hand, only extract the square root of the number of elements per state matrix by using the matrices as weights for a linear layer, leading to orders of magnitude more computation for an equivalent number of features.

Efficient Scan

For the MRU, I've derived an efficient algorithm using a parallel scan to compute it. Sorry for my most likely incorrect mathematical notation; I am not well versed in the math fields that this computation involves. Note that the $^T$ symbol refers to transposing the last two dimensions and $I$ refers to the identity matrix. The closed form ($1 \leq j \leq s$) for the MRU is

$$ H_j = \prod_{i=1}^{j} X_i $$

The forward pass can be computed using a parallel scan.

The backward pass for the MRU is much more complicated. $\frac{\partial F(H_j)}{\partial H_j}$ represents the output gradient of $H_j$, i.e. the partial derivative of the loss, through the rest of the network, with respect to $H_j$. The closed form for the partial derivative is

$$ \frac{\partial F(H_j)}{\partial X_i} = \begin{cases} \frac{\partial F(H_j)}{\partial H_j} & \text{if } j = i = 1 \\ \frac{\partial F(H_j)}{\partial H_j} \left(\prod_{k=2}^{j} X_k \right)^T & \text{if } j > i = 1 \\ \left(\prod_{k=1}^{i-1} X_k \right)^T \frac{\partial F(H_j)}{\partial H_j} & \text{if } j = i \neq 1 \\ \left(\prod_{k=1}^{i-1} X_k \right)^T \frac{\partial F(H_j)}{\partial H_j} \left(\prod_{k=i+1}^{j} X_k \right)^T & \text{if } j > i \neq 1 \\ 0 & \text{if } j < i \end{cases} $$

The gradient of $X_i$ is

$$ \nabla X_i = \sum_{j=1}^{s} \frac{\partial F(H_j)}{\partial X_i} $$

The expanded gradient of $X_i$ is

$$ \nabla X_i = \sum_{j=1}^{s} \begin{cases} \frac{\partial F(H_j)}{\partial H_j} & \text{if } j = i = 1 \\ \frac{\partial F(H_j)}{\partial H_j} \left(\prod_{k=2}^{j} X_k \right)^T & \text{if } j > i = 1 \\ \left(\prod_{k=1}^{i-1} X_k \right)^T \frac{\partial F(H_j)}{\partial H_j} & \text{if } j = i \neq 1 \\ \left(\prod_{k=1}^{i-1} X_k \right)^T \frac{\partial F(H_j)}{\partial H_j} \left(\prod_{k=i+1}^{j} X_k \right)^T & \text{if } j > i \neq 1 \\ 0 & \text{if } j < i \end{cases} $$
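
For example, with $s = 3$ and $i = 2$, this sum expands to

$$ \nabla X_2 = X_1^T \frac{\partial F(H_2)}{\partial H_2} + X_1^T \frac{\partial F(H_3)}{\partial H_3} X_3^T = X_1^T \left( \frac{\partial F(H_2)}{\partial H_2} + \frac{\partial F(H_3)}{\partial H_3} X_3^T \right) $$

and the common left factor $X_1^T = H_1^T$ is exactly what gets factored out as $A_2$ below.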

If we define $A_{i+1} = H_{i}^T$ and $A_1 = I$, then by factoring out $A_i$ the expression can be rewritten as:

$$ \nabla X_i = A_i \sum_{j=1}^{s} \begin{cases} \frac{\partial F(H_j)}{\partial H_j} & \text{if } j = i \\ \frac{\partial F(H_j)}{\partial H_j} \left(\prod_{k=i+1}^{j} X_k \right)^T & \text{if } j > i \\ 0 & \text{if } j < i \end{cases} = A_i \sum_{j=i}^{s} \begin{cases} \frac{\partial F(H_j)}{\partial H_j} & \text{if } j = i \\ \frac{\partial F(H_j)}{\partial H_j} \left(\prod_{k=i+1}^{j} X_k \right)^T & \text{if } j > i \end{cases} $$

I'll call the second part of the gradient a new variable, $B_i$:

$$ B_i = \sum_{j=i}^{s} \begin{cases} \frac{\partial F(H_j)}{\partial H_j} & \text{if } j = i \\ \frac{\partial F(H_j)}{\partial H_j} \left(\prod_{k=i+1}^{j} X_k \right)^T & \text{if } j > i \end{cases} $$

You can see $B_s = \frac{\partial F(H_s)}{\partial H_s}$. The recurrent form for $B$ is $B_i = \frac{\partial F(H_i)}{\partial H_i} + B_{i+1} X_{i+1}^T$. $B_i$ can also be found with this expression:

$$ \begin{bmatrix} 0 & 0 \\ B_i & I \end{bmatrix} = \begin{bmatrix} 0 & 0 \\ B_{i+1} & I \end{bmatrix} \begin{bmatrix} X_{i+1}^T & 0 \\ \frac{\partial F(H_i)}{\partial H_i} & I \end{bmatrix} = \begin{bmatrix} 0 & 0 \\ B_{s} & I \end{bmatrix} \begin{bmatrix} X_{s}^T & 0 \\ \frac{\partial F(H_{s-1})}{\partial H_{s-1}} & I \end{bmatrix} \begin{bmatrix} X_{s-1}^T & 0 \\ \frac{\partial F(H_{s-2})}{\partial H_{s-2}} & I \end{bmatrix} \ldots \begin{bmatrix} X_{i+2}^T & 0 \\ \frac{\partial F(H_{i+1})}{\partial H_{i+1}} & I \end{bmatrix} \begin{bmatrix} X_{i+1}^T & 0 \\ \frac{\partial F(H_{i})}{\partial H_{i}} & I \end{bmatrix} $$

If we let

$$ U_i = \begin{cases} X_{i+1}^T & \text{if } i \neq s \\ 0 & \text{if } i = s \end{cases} $$

and

$$ L_i = \frac{\partial F(H_i)}{\partial H_i} $$

then we can express the equation for $B_i$ as:

$$ \begin{bmatrix} 0 & 0 \\ B_i & I \end{bmatrix} = \prod_{k=0}^{s-i} \begin{bmatrix} U_{s-k} & 0 \\ L_{s-k} & I \end{bmatrix} $$

This product can be computed with a reverse parallel scan because matrix multiplication is associative.

Combining all of this, the final gradient for the input matrices is $\nabla X_i = A_i B_i$, which can be computed efficiently using a parallel scan together with the output of the forward pass.
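
To tie the derivation together, here is a rough single-head sequential reference for the backward pass. The function names are mine, and the actual implementation presumably parallelizes the $B$ recurrence with the reverse scan above; this sketch only checks the derived formula $\nabla X_i = A_i B_i$ against PyTorch autograd.

```python
import torch

def mru_forward(X: torch.Tensor) -> torch.Tensor:
    # X: (s, d_o, d_o) for a single head; H[t] = X[0] @ ... @ X[t].
    H = [X[0]]
    for t in range(1, X.shape[0]):
        H.append(H[-1] @ X[t])
    return torch.stack(H)

def mru_backward(X: torch.Tensor, H: torch.Tensor, dH: torch.Tensor) -> torch.Tensor:
    # dH[i] = dF(H_i)/dH_i. Returns dX where dX_i = A_i @ B_i with
    # A_1 = I, A_{i+1} = H_i^T, and B_i = dH_i + B_{i+1} @ X_{i+1}^T.
    s, d_o = X.shape[0], X.shape[-1]
    B = dH.clone()
    for i in range(s - 2, -1, -1):                       # reverse recurrence for B
        B[i] = dH[i] + B[i + 1] @ X[i + 1].transpose(-1, -2)
    A = torch.cat([torch.eye(d_o, dtype=X.dtype).unsqueeze(0),
                   H[:-1].transpose(-1, -2)])            # A_1 = I, A_{i+1} = H_i^T
    return A @ B

# Check the derived gradient against autograd on a tiny example.
X = torch.randn(6, 4, 4, dtype=torch.float64, requires_grad=True)
H = mru_forward(X)
H.sum().backward()                                       # so dF(H_i)/dH_i is all ones
dX = mru_backward(X.detach(), H.detach(), torch.ones_like(H))
assert torch.allclose(dX, X.grad)
```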