Tips for training from scratch?
luchris429 opened this issue · 10 comments
Hello,
I've been playing with this architecture on nanoGPT. While I can get other architectures to play nicely there (e.g. RMT), I'm really struggling to get GLA to perform well.
Do you have any tips or code for training? For example, do you have a repository you recommend or key hyperparameter differences to normal transformers?
Thanks!
At ~44M parameters, GLA reaches ~4 bpc (~2.7 nats), whereas RMT and a standard Transformer reach ~1.4 bpc (~1.0 nats).
I think the performance gap is big enough to suggest that something is clearly not working, though.
(I just launched another run with minimal differences in training setup; I can already see significant divergence in training despite a similar parameter count.)
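For reference, bits per character and nats per character differ only by a factor of ln 2 (nats = bpc × ln 2). A minimal sketch of the conversion used for the numbers above; the helper names are illustrative, not from any library:

```python
import math

def bpc_to_nats(bpc: float) -> float:
    """Convert a loss in bits per character to nats per character."""
    return bpc * math.log(2)

def nats_to_bpc(nats: float) -> float:
    """Convert a loss in nats per character to bits per character."""
    return nats / math.log(2)

for bpc in (4.0, 1.4):
    # Prints the nats equivalent of each bpc figure quoted in the thread.
    print(f"{bpc} bpc = {bpc_to_nats(bpc):.2f} nats")
```

Since PyTorch's `cross_entropy` reports loss in nats, dividing a logged character-level loss by ln 2 gives bpc directly.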
What is your Triton version? Can you pass this test?
Looks normal! I think your GLA training will be fine after the Triton version update.
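One way to report the installed Triton version (along with PyTorch, since the two are usually pinned together) is to query package metadata, which works even if importing the package would fail. This is a minimal sketch; `installed_version` is a hypothetical helper name:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

for pkg in ("triton", "torch"):
    print(pkg, installed_version(pkg) or "not installed")
```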
Also, for smaller models I would recommend the parameter allocation used in RetNet, i.e., d_key = d_model, d_value = 2*d_model, and an FFN expansion ratio of 2.
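A back-of-the-envelope count shows why this allocation is convenient: it keeps the per-block weight count equal to a standard Transformer block (12·d_model²). The sketch below assumes q/k projections of size d_model→d_key, v and an output gate of size d_model→d_value, an output projection d_value→d_model, and a plain two-matrix FFN; it ignores biases, norms, and GLA's low-rank forget-gate projections. The function name and the exact layer layout are assumptions, not the flash-linear-attention implementation.

```python
def block_params(d_model: int, d_key: int, d_value: int,
                 ffn_expansion: int, has_output_gate: bool) -> int:
    """Approximate weight count of one block (assumed layout, no biases/norms)."""
    attn = 2 * d_model * d_key            # W_q and W_k
    attn += d_model * d_value             # W_v
    if has_output_gate:
        attn += d_model * d_value         # W_g, RetNet-style output gate (assumed)
    attn += d_value * d_model             # W_o
    d_ff = ffn_expansion * d_model
    ffn = 2 * d_model * d_ff              # FFN up and down projections (no GLU)
    return attn + ffn

d = 512
transformer = block_params(d, d_key=d, d_value=d, ffn_expansion=4, has_output_gate=False)
retnet_style = block_params(d, d_key=d, d_value=2 * d, ffn_expansion=2, has_output_gate=True)
print(transformer, retnet_style)  # 3145728 3145728, i.e. 12 * d**2 each
```

The wider value dimension and output gate move about 4·d_model² of weights from the FFN into the token-mixing layer, so the total budget stays roughly unchanged.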
@luchris429 Hi, please refer to this fix commit (sustcsonglin/flash-linear-attention@84a3940).
Sorry for wasting your time.
It's doing better than the Transformer now! Thanks so much, @yzhangcs and @sustcsonglin! Great work.