The RWKV Language Model
RWKV v2: RNN with Transformer Performance
RWKV v2 is an RNN that can also be trained directly like a GPT transformer. You only need x_t, a_t, b_t of position t to compute the vectors for position t+1. Hence it can be 100x faster than GPT, and 100x more VRAM friendly.
See the release for a 27M-parameter model on enwik8 with 0.72 BPC (dev).
How it works
The a, b, c, d factors work together to build a time-decay curve: X, 1, W, W^2, W^3, ...
Write out the formulas for "token at pos 2" and "token at pos 3" and you will get the idea:
- a and b: EMAs of kv and k.
- c and d: a and b combined with self-attention.
kv / k is the memory mechanism. A token with a high k can be remembered for a long time if W is close to 1 in that channel.
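To make the a, b, c, d description concrete, here is a minimal single-channel sketch in plain Python. It is an assumption-heavy illustration, not the repo's kernel: the update order, the way X enters, and the omission of the receptance gate are all inferred from the time-decay curve X, 1, W, W^2, ... stated above, and the function names are made up for this example. k is assumed to be positive already (for example after an exp).

```python
def rwkv_v2_channel(kv, k, W, X):
    """Hypothetical single-channel sketch: returns c_t / d_t for every position t.

    kv, k : per-token sequences of floats (k assumed positive)
    W     : decay of this channel, between 0 and 1
    X     : weight given to the current token
    """
    a = 0.0  # decayed running sum of kv over past tokens
    b = 0.0  # decayed running sum of k over past tokens
    out = []
    for kv_t, k_t in zip(kv, k):
        # c and d: past state combined with the current token.
        # Effective weights over tokens t, t-1, t-2, ... are X, 1, W, W^2, ...
        c = a + X * kv_t
        d = b + X * k_t
        out.append(c / d)
        # Only a and b need to be carried to position t+1.
        a = W * a + kv_t
        b = W * b + k_t
    return out


def explicit_weights(t, W, X):
    """Weights on tokens t, t-1, ..., 0 implied by the recurrence above."""
    return [X] + [W ** i for i in range(t)]
```

Writing out position 2 by hand with W = 0.5 and X = 2 gives c = 2*kv_2 + 1*kv_1 + 0.5*kv_0, which is what the loop computes. With W close to 1, an old token with a large k keeps a large share of both the numerator and the denominator, so it stays "remembered".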
It also uses my SmallInitEmb trick https://github.com/BlinkDL/SmallInitEmb (applicable to all transformers), and a custom CUDA kernel https://github.com/BlinkDL/RWKV-CUDA .
I find it may help to keep the model at a mid-range lr for a long period, because in theory that is where most of the learning should happen. For example: decay from 6e-4 to 1e-4 over the first 15% of steps, stay at 1e-4 for 60% of steps (in practice I monitor the loss and decay the lr when it plateaus), then decay from 1e-4 to 1e-5 over the final 25% of steps.
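The example schedule can be sketched as a plain function of the step count. Two things here are assumptions: the shape of each decay (exponential interpolation is used, the text does not say linear or exponential), and the fixed 15% / 60% / 25% split (the loss-plateau monitoring mentioned above is not modeled). The function and parameter names are hypothetical.

```python
def lr_at(step, total_steps, lr_init=6e-4, lr_mid=1e-4, lr_final=1e-5,
          decay1_frac=0.15, hold_frac=0.60):
    """Three-phase lr: decay to lr_mid, hold at lr_mid, decay to lr_final."""
    p = step / total_steps  # training progress in [0, 1]
    if p < decay1_frac:
        # Phase 1: lr_init -> lr_mid (exponential interpolation, an assumption).
        return lr_init * (lr_mid / lr_init) ** (p / decay1_frac)
    if p < decay1_frac + hold_frac:
        # Phase 2: stay on the mid lr.
        return lr_mid
    # Phase 3: lr_mid -> lr_final over the remaining steps.
    q = (p - decay1_frac - hold_frac) / (1.0 - decay1_frac - hold_frac)
    return lr_mid * (lr_final / lr_mid) ** min(q, 1.0)
```

For a 1000-step run, `lr_at(0, 1000)` is 6e-4, any step from 150 up to 749 gives 1e-4, and the last step reaches 1e-5.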
The pseudocode (execution from top to bottom):
v1
We propose the RWKV language model, with alternating time-mix and channel-mix layers:
- The R, K, V are generated by linear transforms of the input, and W is a parameter. The idea of RWKV is to decompose attention into R(target) * W(src, target) * K(src). So we can call R the "receptance", and the sigmoid keeps it in the 0~1 range.
- The Time-mix is similar to AFT (https://arxiv.org/abs/2105.14103). There are two differences.
(1) We changed the normalization (denominator). For masked language models, we define:
(UPDATE: We are using the original AFT normalization in v2)
Initialize K and R matrices (and the output projection matrix) to ZERO for fast & stable convergence.
(2) We decompose W_{t,u,c} and introduce multi-head W (here h is the corresponding head of c):
Moreover, we multiply the final output of the Time-mix layer by γ(t). The α β γ factors are needed because the context size is smaller when t is small, and they compensate for this.
(UPDATE: We remove the α β γ factors in v2-RNN and restrict W to a simple form, so that the model can be rewritten as an RNN)
- The Channel-mix is similar to GeGLU (https://arxiv.org/abs/2002.05202) with an extra R factor. Initialize R and W matrices to ZERO for fast & stable convergence.
- Finally, we add extra token-shift (time-shift mixing) as in (https://github.com/BlinkDL/minGPT-tuned).
Token-shift (time-shift mixing)
The token-shift explicitly uses (half the channels of this token) & (half the channels of prev token) to generate all vectors (QKV, RWKV, ...).
self.time_shift = nn.ZeroPad2d((0,0,1,-1))
x = torch.cat([self.time_shift(x[:, :, :C//2]), x[:, :, C//2:]], dim = -1)
Dividing channels by 2 and shift-1 works great for char-level English and char-level Chinese LM.
However for BPE-level English LM, it's only effective if your embedding is large enough (at least 1024 - so the usual small L12-D768 model is not enough).
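For readers who want to see what the two PyTorch lines do without running a model, here is a plain-Python restatement for a single sequence (no batch dimension). It is only a sketch: the function name is made up, and x is a list of T tokens with C channels each, where C is assumed to be even.

```python
def token_shift(x):
    """x: list of T tokens, each a list of C channel values (C even).

    Returns a new T x C list where the first C//2 channels of each token come
    from the previous token (zeros at position 0) and the remaining channels
    come from the token itself.
    """
    half = len(x[0]) // 2
    out = []
    for t, cur in enumerate(x):
        # Shifted half: taken from token t-1, zero-padded at the start.
        prev_half = x[t - 1][:half] if t > 0 else [0.0] * half
        # Unshifted half: taken from token t.
        out.append(prev_half + cur[half:])
    return out
```

With x = [[1, 2, 3, 4], [5, 6, 7, 8]], the second token becomes [1, 2, 7, 8]: its first two channels carry information straight from the previous token.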
My theory on the effectiveness of token-shift:
When we train a GPT, the hidden representation of a token has to accomplish two different objectives:
1. Predict the next token. Sometimes this is easy (obvious next token).
2. Collect all previous context info, so later tokens can use it. This is always hard.
The shifted channels can focus on (2), so we have good propagation of info. It's like some kind of residual connection, or a small RNN inside the transformer.
You can use token-shift in usual QKV self-attention too. I looked at the weights, and found V really likes the shifted channels, less so for Q. Makes sense if you think about it. I also found you may want to use less mixing in higher layers.
p.s. There is a MHA_pro model in this repo with strong performance. Give it a try :)
The Head-QK Trick: learning to copy and avoid tokens
In a usual transformer, a small model has difficulty copying tokens (such as person names) in the context. We add extra Q & K to the final output such that the model can directly copy (or avoid) tokens in the context. Afterwards the model will teach itself NER (named entity recognition) if you look at the learned weights.
q = self.head_q(x)[:,:T,:]
k = self.head_k(x)[:,:T,:]
c = (q @ k.transpose(-2, -1)) * (1.0 / 256)
c = c.masked_fill(self.copy_mask[:T,:T] == 0, 0)
c = c @ F.one_hot(idx, num_classes = self.config.vocab_size).float()
x = self.head(x) + c
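The matrix product with the one-hot tensor is easier to read as a scatter-add: every earlier position s adds its q·k score to the logit of the token that actually appeared at s. The plain-Python sketch below shows that view for a single sequence. It assumes copy_mask is a lower-triangular causal mask that includes the diagonal, and the function name is hypothetical.

```python
def head_qk_bonus(q, k, idx, vocab_size, scale=1.0 / 256):
    """q, k: lists of T vectors; idx: list of T token ids.

    Returns a T x vocab_size list of logit bonuses, the loop form of
    (masked q @ k^T * scale) @ one_hot(idx).
    """
    T = len(idx)
    bonus = [[0.0] * vocab_size for _ in range(T)]
    for t in range(T):
        for s in range(t + 1):  # assumed causal mask: t only sees s <= t
            score = sum(qi * ki for qi, ki in zip(q[t], k[s])) * scale
            # Add the score to the logit of the token seen at position s.
            bonus[t][idx[s]] += score
    return bonus
```

A positive score pushes the model to copy the token at position s, a negative score pushes it to avoid that token. The result is added to the usual head output, as in the last line of the snippet above.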
The top-a Sampling method
We also propose a new sampling method called top-a (as in src/utils.py):
(1) Find the max probability p_max after softmax.
(2) Remove all entries whose probability is lower than 0.02 * pow(p_max, 2). So it's adaptive, hence "top-a".
(3) Feel free to tune the 0.02 and 2 factor. Tune 0.02 first.
The idea of top-a:
- If max_prob=0.9, then remove all tokens with prob < 0.0162 (so, removing most alternatives)
- If max_prob=0.5, then remove all tokens with prob < 0.0050 (so, allowing more choices)
- If max_prob=0.1, then remove all tokens with prob < 0.0002 (so, allowing lots of possibilities)
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), 2) * 0.02
logits[probs < limit] = -float('Inf')
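The same filter can be written without PyTorch for a single list of logits, which makes the three worked cases above easy to check by hand. This is a sketch, not the code in src/utils.py: the function name and the `ratio` / `power` parameter names are made up here for the 0.02 and 2 factors.

```python
import math


def top_a_filter(logits, ratio=0.02, power=2.0):
    """Return logits with every entry below ratio * p_max**power set to -inf."""
    # Numerically stable softmax.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    # Adaptive threshold: depends on how confident the model is.
    limit = ratio * max(probs) ** power
    return [l if p >= limit else float("-inf") for l, p in zip(logits, probs)]
```

For probabilities [0.9, 0.08, 0.015, 0.005] the limit is 0.0162, so only the first two tokens survive. For a flat distribution of ten tokens at 0.1 each, the limit is 0.0002 and nothing is removed.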
Performance
Character-level loss on simplebooks-92 dataset https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip
Gray: usual MHA+Rotary+GeGLU - performance not as good. 17.2M params.
Red: RWKV ("linear" attention) - VRAM friendly - quite a bit faster when ctx window is long - good performance. 16.6M params.
Green: MHA+Rotary+GeGLU+Token_shift. 17.2M params.
Blue: MHA_pro (MHA with various tweaks & RWKV-type-FFN) - slow - needs more VRAM - good performance. 16.6M params.
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = aug,
year = 2021,
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196577},
url = {https://doi.org/10.5281/zenodo.5196577}
}
Initialization
We use careful initialization for RWKV to get fast convergence - orthogonal matrices with proper scaling, and special time_w curves. Check model.py for details.
Some learned time_w examples: