kingoflolz/mesh-transformer-jax

Can "slim_model.py" work with "d_model" as 768?

leejason opened this issue · 0 comments

I updated "6B_roto_256.json" with the following to try a smaller model:

"d_model": 768
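For context, the surrounding config would look roughly like this. Every field other than `d_model` is an assumption based on the repo's example configs, and the values shown are illustrative, not taken from this issue:

```json
{
  "layers": 12,
  "d_model": 768,
  "n_heads": 12,
  "n_vocab": 50400,
  "norm": "layernorm",
  "pe": "rotary",
  "seq": 2048,
  "cores_per_replica": 8,
  "per_replica_batch": 1
}
```

One thing worth double-checking (an assumption, not something confirmed here) is that related fields such as `n_heads` and `cores_per_replica` remain consistent with the reduced `d_model`, since the sharded checkpoint layout depends on those shapes.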

Pretraining works on a single TPU v3-8, but the model slimmed with "slim_model.py" produces gibberish output.

Why? Does "slim_model.py" only work with "d_model: 4096"? I don't think so, but I found no clue after tracing the source code for hours.

Thank you for shedding some light on this.