google/paxml

What does USE_REPEATED_LAYER do?

abhinavgoel95 opened this issue · 1 comments

I was wondering if anyone knows the purpose of the USE_REPEATED_LAYER flag in c4.py. Thanks. :)

It basically makes use of the Repeat layer (https://github.com/google/praxis/blob/2e46886e5582e39a65a871439ccab29b40dffe93/praxis/layers/repeats.py#L64), which has some very nice features: nn.scan, which reduces your overall XLA graph size and thus compilation time, and nn.remat, which lowers device memory use by recomputing activations during the backward pass instead of storing them, at the cost of some extra compute time.
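To make the scan/remat idea concrete, here is a minimal sketch in plain JAX. It is not the praxis Repeat implementation: it uses `jax.lax.scan` and `jax.checkpoint` directly (the primitives that Flax's `nn.scan` and `nn.remat` lift to modules), and the toy `layer` function, `NUM_LAYERS`, and `DIM` are made up for illustration.

```python
import jax
import jax.numpy as jnp

NUM_LAYERS, DIM = 4, 8  # hypothetical sizes, chosen only for this sketch


def layer(x, w):
    # One toy "layer": a matmul followed by a nonlinearity.
    return jnp.tanh(x @ w)


def unrolled(x, ws):
    # Python loop: the layer is traced once per iteration,
    # so the traced graph grows with the number of layers.
    for i in range(NUM_LAYERS):
        x = layer(x, ws[i])
    return x


def repeated(x, ws):
    # jax.checkpoint (remat) recomputes the layer's activations in the
    # backward pass instead of keeping them in memory.
    body = jax.checkpoint(layer)

    def step(carry, w):
        return body(carry, w), None

    # jax.lax.scan traces the body once and loops over the leading
    # (layer) axis of the stacked weights.
    out, _ = jax.lax.scan(step, x, ws)
    return out


ws = jax.random.normal(jax.random.PRNGKey(0), (NUM_LAYERS, DIM, DIM)) * 0.1
x = jnp.ones((2, DIM))

# Both variants compute the same function.
same_output = bool(jnp.allclose(unrolled(x, ws), repeated(x, ws), atol=1e-5))

# The scanned version has far fewer top-level operations in its trace.
unrolled_eqns = len(jax.make_jaxpr(unrolled)(x, ws).jaxpr.eqns)
repeated_eqns = len(jax.make_jaxpr(repeated)(x, ws).jaxpr.eqns)

# Gradients still flow through scan + checkpoint as usual.
grads = jax.grad(lambda w: repeated(x, w).sum())(ws)
```

The trade-off is the one described above: the scanned, rematerialized stack compiles a single layer body and stores fewer activations, but spends extra compute re-running the forward pass of each layer during backprop.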