karpathy/llama2.c

why not use key and value caches in model.py?

mvuthegoat opened this issue · 2 comments

I was just wondering why you didn't use caches to store the key and value tensors in the Transformer like Meta did.
Also, Meta uses a different generate function that takes advantage of these caches. Their model doesn't have to recalculate K and V over the growing seqlen dimension, like what you did here:

llama2.c/model.py

Lines 340 to 341 in 7ac65cb

# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
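For comparison, the pattern in Meta's code looks roughly like this (a simplified sketch rather than their actual implementation; RoPE, grouped-query attention, and the causal mask for the prompt pass are omitted, and the cache shapes are simplified):

import torch
import torch.nn as nn

class CachedAttention(nn.Module):
    def __init__(self, dim, n_heads, max_seq_len, max_batch_size=1):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)
        # preallocated caches holding K/V for every position seen so far
        self.cache_k = torch.zeros(max_batch_size, max_seq_len, n_heads, self.head_dim)
        self.cache_v = torch.zeros(max_batch_size, max_seq_len, n_heads, self.head_dim)

    def forward(self, x, start_pos):
        bsz, seqlen, _ = x.shape  # seqlen is 1 per step once past the prompt
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        k = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        v = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        # write only the new positions; older K/V are reused, not recomputed
        self.cache_k[:bsz, start_pos:start_pos + seqlen] = k
        self.cache_v[:bsz, start_pos:start_pos + seqlen] = v
        keys = self.cache_k[:bsz, :start_pos + seqlen]
        values = self.cache_v[:bsz, :start_pos + seqlen]
        q, keys, values = (t.transpose(1, 2) for t in (q, keys, values))
        scores = (q @ keys.transpose(-2, -1)) / self.head_dim ** 0.5
        out = (scores.softmax(dim=-1) @ values).transpose(1, 2).reshape(bsz, seqlen, -1)
        return self.wo(out)

After the prompt pass, each step only computes Q/K/V for the single new token and attends against the cached keys/values, instead of re-running the K/V projections over the whole growing sequence.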

Using caches is clearly more efficient imo, but I'm not sure why they weren't implemented. Am I missing something?

It just costs lines of code and some complexity that I wanted to avoid. The solution I am more inclined towards is to separate out training and inference, which is done in most serious projects. These two regimes are sufficiently different. So what's missing in this repo is a class usually called Engine, which efficiently serves a trained model (in PyTorch). It may be that our sample.py can go on to hold the Engine, and it would maintain the KV cache.
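Very roughly, something like the following (prefill and decode_one are made-up placeholders for a model-side API that exposes its KV cache; they do not exist in model.py today):

import torch

class Engine:
    def __init__(self, model):
        self.model = model.eval()

    @torch.inference_mode()
    def generate(self, idx, max_new_tokens, temperature=1.0):
        # run the full prompt once, filling the (hypothetical) per-layer KV caches
        logits, caches = self.model.prefill(idx)          # assumed API
        for _ in range(max_new_tokens):
            probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
            # feed only the newest token; past K/V come from the caches
            logits, caches = self.model.decode_one(idx_next, caches)  # assumed API
        return idx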

Theoretically a KV cache could be useful in training too, if X increments over a token sequence (Y is always the next token).
So from a sequence of length n, you get n-1 (X, Y) samples.
I think nanoGPT was like that, but I see llama2.c is not.
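For concreteness, a toy example (not code from the repo) of what that incremental setup looks like:

import torch

tokens = torch.tensor([5, 2, 7, 7, 1])  # toy sequence, n = 5
# X increments over the sequence, Y is always the next token: n-1 = 4 samples
samples = [(tokens[:k], tokens[k]) for k in range(1, len(tokens))]
for X, Y in samples:
    print(X.tolist(), "->", int(Y))
# [5] -> 2
# [5, 2] -> 7
# [5, 2, 7] -> 7
# [5, 2, 7, 7] -> 1

Since each X is the previous X plus one token, the K/V computed for the shared prefix could in principle be cached and reused across samples.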