What is the purpose of the `substeps` hyperparameter?
baofff opened this issue · 1 comment
baofff commented
Thanks for the great work.
I noticed an interesting use of `jax.lax.scan` in your code. Applying it to `p_train_step` runs the training step `substeps` times in succession, and it does not seem to affect the training result. What is the benefit compared to normal training (i.e., without using `jax.lax.scan` and `p_train_step`)?
baofff commented
I've found the benefit. With this hyperparameter, JAX compiles `substeps` consecutive updates together into a single call, which makes training faster.
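For anyone else wondering, here is a minimal sketch of the idea, not the repository's actual code. It assumes a toy least-squares `train_step`, plain SGD, and `jax.jit` in place of the parallel `p_train_step`; the names `train_step`, `multi_step`, and the shapes are all hypothetical. It shows that scanning over `substeps` batches gives the same parameters as calling the step in a Python loop, while needing only one compiled call.

```python
import jax
import jax.numpy as jnp


def train_step(params, batch):
    """One SGD step on a toy least-squares loss (hypothetical stand-in)."""
    x, y = batch

    def loss_fn(w):
        return jnp.mean((x @ w - y) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(params)
    # Return (new carry, per-step output), the shape lax.scan expects.
    return params - 0.1 * grad, loss


@jax.jit
def multi_step(params, batches):
    # lax.scan threads `params` through train_step once per slice along the
    # leading axis of `batches`, so all substeps live in one compiled program.
    return jax.lax.scan(train_step, params, batches)


substeps, batch_size, dim = 4, 8, 3
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
xs = jax.random.normal(key_x, (substeps, batch_size, dim))
ys = jax.random.normal(key_y, (substeps, batch_size))
params0 = jnp.zeros(dim)

# A single call performs all `substeps` updates.
scanned_params, scanned_losses = multi_step(params0, (xs, ys))

# Reference: the same updates as an ordinary Python loop, one call per step.
looped_params = params0
for i in range(substeps):
    looped_params, _ = train_step(looped_params, (xs[i], ys[i]))
```

Both paths apply the same updates in the same order, so the result is unchanged; the scanned version just avoids returning to Python between updates.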