lucidrains/x-transformers

Simplifying Transformer Blocks (https://arxiv.org/abs/2311.01906)

Closed this issue · 9 comments

Would be nice to have this one here (https://arxiv.org/abs/2311.01906).

hmm, could probably do a separate repository for that, with some makeover (relative positions etc)

have you tried it? does it work?

@Froskekongen so there is research out there that suggests the parallel block architecture leads to instability at scale (paper out of salesforce). however, i'm game for the serial version if you let me know how it fares, share some successful experiments on your end, etc
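For anyone skimming the thread, here is a rough sketch of the serial vs. parallel distinction being discussed. It is a toy under stated assumptions, not code from x-transformers or the paper: `attn`, `ff` and `norm` are hypothetical stand-in callables on plain Python lists, and the parallel variant assumes a single shared pre-norm feeding both branches.

```python
# Toy sketch: sublayers are plain callables on a list of floats.
# attn, ff and norm are hypothetical stand-ins, not x-transformers APIs.

def add(a, b):
    # elementwise residual addition
    return [x + y for x, y in zip(a, b)]

def serial_block(x, attn, ff, norm):
    # pre-norm serial block: the feedforward reads the output
    # of the attention residual, so the two sublayers are chained
    x = add(x, attn(norm(x)))
    x = add(x, ff(norm(x)))
    return x

def parallel_block(x, attn, ff, norm):
    # parallel block: both branches read the same normed input
    # and their outputs are summed into a single residual update
    h = norm(x)
    return add(x, add(attn(h), ff(h)))

# stand-in sublayers, chosen only so the two wirings give different results
identity_norm = lambda v: v
toy_attn = lambda v: [2.0 * t for t in v]
toy_ff = lambda v: [t * t for t in v]

x = [1.0, 2.0]
serial_out = serial_block(x, toy_attn, toy_ff, identity_norm)
parallel_out = parallel_block(x, toy_attn, toy_ff, identity_norm)
```

The only difference is what the feedforward sees: in the serial block it depends on the attention output, in the parallel block it does not, which is what lets the two branches be computed at the same time.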

I think the serial version was the most interesting (Figure 1, top right). And I think you are right - probably easier with a separate repo for this since a lot of the content is about dealing with initialization and whatnot.

Will report if I find some time to experiment with it. Closing the issue for now.

or maybe you can get Bobby to make his repo pip installable?

ah, figure 1 top right is still a parallel block. i don't think they ever did experiments on a serial version. i don't know if i believe in parallel blocks anymore; i can link that salesforce paper later once i find it

Didn't PaLM do parallel blocks to great effect? "This approach is also used by PaLM (Chowdhery et al., 2022), where this technique sped up the largest model’s training by 15% without performance degradation." (ViT-22B paper)

@zaptrem yea.. so first you need to know some behind-the-scenes. the parallel block originated from the open source community. it was devised by a precocious college student, Ben Wang, for the training of GPT-J. It was then adopted by Brain for the training of PaLM, with a lot of the code probably taken verbatim from GPT-J (as it is in JAX). However, what you need to know is that Ben confided in me that during the tail end of training for GPT-J, he actually faced insurmountable instability. Luckily, it was near the end and the model was good enough, so he stopped training a bit early and just open sourced it. The rest is history. That bit never made it into the paper afaict.

If you read the PaLM paper, they actually documented this instability. In fact, they had a really hacky way of getting around it: rewinding to just a bit before each divergence and trying different batches. In other words, I have no doubt the parallel block has some performance benefits, but I don't think they are worth the cost of this instability.
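To make the "rewind and try different batches" workaround concrete, here is a minimal sketch of the general idea. It is an illustration under assumptions, not PaLM's actual training code: `train_with_rewind`, `step_fn`, `is_diverged` and `checkpoint_every` are all hypothetical names, and the policy of skipping every batch between the last checkpoint and the spike is one simple choice among several.

```python
import copy

def train_with_rewind(state, batches, step_fn, is_diverged, checkpoint_every=2):
    """Hypothetical loop: step_fn(state, batch) -> (new_state, loss)."""
    ckpt_index, ckpt_state = 0, copy.deepcopy(state)
    skipped = []
    i = 0
    while i < len(batches):
        if i % checkpoint_every == 0:
            # snapshot the state as it was *before* consuming batch i
            ckpt_index, ckpt_state = i, copy.deepcopy(state)
        new_state, loss = step_fn(state, batches[i])
        if is_diverged(loss):
            # rewind to the last snapshot and skip every batch from that
            # snapshot through the spike, then carry on with fresh data
            state = copy.deepcopy(ckpt_state)
            skipped.extend(range(ckpt_index, i + 1))
            ckpt_index = i + 1  # the restored state now sits at this position
        else:
            state = new_state
        i += 1
    return state, skipped

# toy stand-ins: the "model" is a running sum, and a huge batch is a "spike"
toy_step = lambda s, b: (s + b, abs(b))
toy_diverged = lambda loss: loss > 100

final_state, skipped = train_with_rewind(
    0, [1, 2, 3, 1000, 4, 5], toy_step, toy_diverged
)
```

In this toy run the spike at index 3 also throws away the good batch at index 2, because that batch was consumed after the last checkpoint. That lost progress, plus the manual babysitting, is part of why the workaround reads as hacky.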

@zaptrem who knows, maybe there is a solution if enough researchers work on the problem, but why do so when serial architecture already works so well? (llama)

anyways, just to show i have put some thought into this.

here's the salesforce write-up, where the author addresses the instability issue of parallel blocks directly: https://blog.salesforceairesearch.com/xgen/