
The RWKV Language Model

Primary language: Python. License: BSD 2-Clause "Simplified" (BSD-2-Clause).

RWKV-LM

We propose the RWKV language model, with alternating time-mix and channel-mix layers:

\begin{align*}
\text{Time-mix :} && \text{TM}_{t,c} &&=&&\text{sigmoid}(\text{R}_{t,c}) &&\cdot&& &&\textstyle\sum_{u} &&\textbf{W}_{t,u,c} &&\cdot&& \text{softmax}_t(\text{K}_{u,c}) &&\cdot&& \text{V}_{u,c}\\
\text{Channel-mix :} && \text{CM}_{t,c} &&=&&\text{sigmoid}(\text{R}_{t,c}) &&\cdot&& &&\textstyle\sum_d &&\textbf{W}_{c,d} &&\cdot&& \text{gelu}(\text{K}_{t,d}) &&\cdot&& \text{V}_{t,d}
\end{align*}
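
As an illustration, the channel-mix formula can be written out index by index. This is a minimal pure-Python sketch, not the repository's implementation: the function name `channel_mix`, the nested-list layout, and the choice of the exact (erf-based) GELU are assumptions made for readability, and R, K, V, W are taken as already-computed arrays.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gelu(x):
    # Exact GELU via the Gaussian CDF (assumption: the source only says "gelu").
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def channel_mix(R, K, V, W):
    """R: [T][C]; K, V: [T][D]; W: [C][D]. Returns CM as [T][C]."""
    T, C, D = len(R), len(R[0]), len(K[0])
    out = []
    for t in range(T):
        row = []
        for c in range(C):
            # sum_d W[c][d] * gelu(K[t][d]) * V[t][d]
            s = sum(W[c][d] * gelu(K[t][d]) * V[t][d] for d in range(D))
            # The sigmoid(R) gate scales the mixed value into the output.
            row.append(sigmoid(R[t][c]) * s)
        out.append(row)
    return out
```

Each output position t only reads K and V at the same position t, which is why this layer mixes channels but not time steps.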

  • R, K, and V are generated by linear transforms of the input, and W is a parameter. The idea of RWKV is to decompose attention into R(target) * W(src, target) * K(src). We therefore call R the "receptance"; the sigmoid keeps it in the range 0 to 1.

  • The Time-mix is similar to AFT (https://arxiv.org/abs/2105.14103). There are two differences.

(1) We changed the normalization (denominator). For masked (causal) language models, we define:

\text{softmax}_t(\text{K}_{u,c}) = \frac{\exp(\text{K}_{u,c})}{\sum_{v \leq t}\exp(\text{K}_{v,c})}

(2) We decompose W_{t,u,c} and introduce multi-head W (here h is the corresponding head of c):

W_{t,u,c}=f_h(t-u)\cdot \alpha_h(u) \cdot \beta_h(t)

Moreover, we multiply the final output of the Time-mix layer by γ(t). The α, β, and γ factors are there because the context size is smaller when t is small, and these factors can compensate for that.
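
To make the pieces above concrete, here is a rough pure-Python sketch of the time-mix layer that combines the causal normalization, the decomposed W, and the γ(t) output factor. It is not the repository's code. The function name `time_mix`, the array layouts, the even split of channels into heads, and the restriction of the sum to u ≤ t are all assumptions made for illustration.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def time_mix(R, K, V, f, alpha, beta, gamma, n_head):
    """R, K, V: [T][C]. f: [H][T], indexed as f[h][t - u].
    alpha, beta: [H][T]. gamma: [T]. Returns TM as [T][C]."""
    T, C = len(R), len(R[0])
    head_size = C // n_head  # assumption: channels are split evenly across heads
    out = [[0.0] * C for _ in range(T)]
    for c in range(C):
        h = c // head_size   # head that channel c belongs to
        denom = 0.0
        for t in range(T):
            # Running denominator: sum over v <= t of exp(K[v][c]).
            denom += math.exp(K[t][c])
            acc = 0.0
            for u in range(t + 1):  # assumption: only u <= t contributes
                w = f[h][t - u] * alpha[h][u] * beta[h][t]
                acc += w * (math.exp(K[u][c]) / denom) * V[u][c]
            out[t][c] = gamma[t] * sigmoid(R[t][c]) * acc
    return out
```

The triple loop is written for clarity only; it follows the formula term by term and makes no attempt at numerical stability (large values of K would overflow `math.exp`) or speed.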


Time-shift mixing means explicitly using both half of the channels of this token and half of the channels of the previous token to generate all vectors.

I found that dividing by 2 and shifting by 1 works best for a Chinese LM. You may want to use more shift for an English char-level LM. Looking at the weights, I also found that you may want to use less mixing in higher layers.
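
A small sketch of what "divide by 2 and shift-1" could look like on a sequence of token vectors. The helper name `time_shift_mix`, the choice of the first half of the channels as the shifted half, and zero-filling at the start of the sequence are assumptions for illustration, not details stated above.

```python
def time_shift_mix(x, shift=1):
    """x: [T][C] token vectors. Returns [T][C] where the first C // 2 channels
    come from the token `shift` steps earlier and the rest from the current token."""
    T, C = len(x), len(x[0])
    half = C // 2
    out = []
    for t in range(T):
        if t >= shift:
            prev_half = x[t - shift][:half]
        else:
            prev_half = [0.0] * half  # assumption: no earlier token, so fill with zeros
        out.append(prev_half + x[t][half:])
    return out
```

For example, `time_shift_mix([[1, 2, 3, 4], [5, 6, 7, 8]])` gives `[[0.0, 0.0, 3, 4], [1, 2, 7, 8]]`: the second token now carries half of the first token's channels. The mixed vectors would then be fed to the linear transforms that produce R, K, and V.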

Here is my theory:

When you train a GPT, the hidden representation of a token has to accomplish two different objectives:

  1. Predict the next token. Sometimes this is easy (obvious next token).

  2. Collect all previous context info so later tokens can use it. This is always hard.

The time_shifted channels can focus on (2), so we get good propagation of info. It's like some kind of residual connection.

You can use time_shift in the usual QKV self-attention too. When I studied the weights, I found that V really likes the time_shifted channels, less so Q. That makes sense if you think about it.


P.S. There is another MHA_pro model in this repo with strong performance. Give it a try :)


We also propose a new sampling method (as in src/utils.py):

(1) Find the max probability p_max after softmax.

(2) Remove all entries whose probability is lower than 0.02 * pow(p_max, 2).

(3) Feel free to tune the 0.02 and 2 factors.
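
The steps above can be sketched in a few lines of plain Python. This is an illustrative version, not a copy of src/utils.py: the function names `filter_probs` and `sample`, the parameter names `factor` and `power`, and the renormalization after filtering are assumptions.

```python
import math
import random

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def filter_probs(probs, factor=0.02, power=2.0):
    """Zero out entries below factor * p_max ** power, then renormalize."""
    p_max = max(probs)
    threshold = factor * p_max ** power
    kept = [p if p >= threshold else 0.0 for p in probs]
    total = sum(kept)  # never zero: p_max itself always survives the cut
    return [p / total for p in kept]

def sample(logits, factor=0.02, power=2.0, rng=random):
    """Draw one token index from the filtered distribution."""
    probs = filter_probs(softmax(logits), factor, power)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Because the threshold scales with `p_max ** power`, a confident prediction (large p_max) prunes the tail aggressively, while a flat distribution keeps almost every candidate. For instance, with p_max = 0.5 the cutoff is 0.02 * 0.25 = 0.005.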


Character-level loss on the simplebooks-92 dataset (https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip):

[Figure: RWKV-vs-MHA]

Gray: usual MHA+Rotary+GeGLU - performance not as good.

Red: RWKV ("linear" attention) - VRAM friendly - considerably faster when the ctx window is long - good performance.

Black: MHA_pro (MHA with various tweaks & RWKV-type-FFN) - slow - needs more VRAM - good performance.

Parameter count: 17.2 vs 18.5 vs 18.5.

@software{peng_bo_2021_5196578,
  author       = {PENG Bo},
  title        = {BlinkDL/RWKV-LM: 0.01},
  month        = aug,
  year         = 2021,
  publisher    = {Zenodo},
  version      = {0.01},
  doi          = {10.5281/zenodo.5196577},
  url          = {https://doi.org/10.5281/zenodo.5196577}
}