berlino/gated_linear_attention

Tips for training from scratch?

luchris429 opened this issue · 10 comments

Hello,

I've been playing with this architecture on nanoGPT. While I can get other architectures to play nicely there (e.g. RMT), I'm really struggling to get GLA to perform well.

Do you have any tips or code for training? For example, do you have a repository you recommend or key hyperparameter differences to normal transformers?

Thanks!

~44M parameters. GLA is getting ~4 bpc (~2.7 nats), whereas RMT and Transformer can get ~1.4 bpc (~1.0 nats).
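(For reference on the units: nats = bpc × ln 2 ≈ 0.693 × bpc, which is where the parenthesized figures come from.)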

The performance gap is big enough, though, to suggest that something is clearly not working.

(I just launched another run with minimal differences in training setup -- you can already see significant divergence in training despite the similar parameter count.)

[Screenshot: training curves showing the divergence between runs (2024-03-04, 5:08 PM)]

Are there any potential sharp bits in installation or setup that you've seen?

If it's useful to other people:

Small repo to test: here.

Code diff for including GLA (with the exception of recent updates to nanoGPT): here
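For anyone wiring this up themselves, the integration is roughly a drop-in replacement of nanoGPT's CausalSelfAttention inside each block. A minimal sketch, assuming the fla package exposes a GatedLinearAttention layer that maps (batch, seq, d_model) to the same shape -- the constructor arguments and return type below are assumptions, so check the actual API:

```python
import torch
import torch.nn as nn

# Assumed import path; the flash-linear-attention package layout may differ by version.
from fla.layers import GatedLinearAttention


class GLABlock(nn.Module):
    """nanoGPT-style transformer block with GLA in place of causal self-attention (sketch)."""

    def __init__(self, d_model: int, n_head: int):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        # Constructor arguments here are an assumption; check the layer's real signature.
        self.attn = GatedLinearAttention(hidden_size=d_model, num_heads=n_head)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(            # nanoGPT-style FFN
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.attn(self.ln_1(x))
        if isinstance(out, tuple):  # some versions return (output, aux, cache)
            out = out[0]
        x = x + out
        x = x + self.mlp(self.ln_2(x))
        return x
```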

What is your Triton version? Can you pass this test?
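(For reference, a minimal way to check the installed version -- and `pip install --upgrade triton` is one way to bump it:)

```python
# Print the installed Triton version; GLA's fused Triton kernels may require a
# newer release than what ships with older PyTorch installs.
import triton

print(triton.__version__)
```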

Ahh that might be it! I'm on Triton 2.1.0

[Screenshot: test output on Triton 2.1.0 (2024-03-04, 6:35 PM)]

Will upgrade triton and get back to you.

Just upgraded -- is this diff more reasonable or still too far off?

[Screenshot: numerical diff after the Triton upgrade (2024-03-04, 6:39 PM)]

Looks normal! I think that after the Triton version update your GLA training will be fine.

Also, for smaller models I would recommend the parameter allocation used in RetNet, i.e., d_key = d_model, d_value = 2*d_model, and FFN expansion = 2.
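A minimal sketch of that allocation for a model around this size (the dict keys below are just illustrative names, not the flash-linear-attention config schema):

```python
# RetNet-style parameter allocation for a small GLA model, per the tip above.
# Key names are placeholders for whatever your model config expects.
d_model = 512

gla_small = {
    "d_model": d_model,
    "d_key": d_model,        # d_key   = d_model
    "d_value": 2 * d_model,  # d_value = 2 * d_model
    "ffn_dim": 2 * d_model,  # FFN expansion factor = 2
}
```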

@luchris429 Hi, please refer to this fix commit (sustcsonglin/flash-linear-attention@84a3940).
Sorry for wasting your time.

It's doing better than the Transformer now! Thanks so much @yzhangcs @sustcsonglin! Great work.

[Screenshot: training curves after the fix, with GLA ahead of the Transformer baseline (2024-03-04, 8:16 PM)]