google-research/vdm

What is the purpose of the `substeps` hyperparameter?

baofff opened this issue · 1 comment

Thanks for the great work.
I found an interesting use of `jax.lax.scan` in your code. Wrapping `p_train_step` in it runs the training step `substeps` times in succession, and this doesn't seem to affect the training result. What is the benefit compared to normal training (i.e., calling the training step directly, without `jax.lax.scan`)?
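The observation that scanning does not change the result can be checked with a small sketch. It assumes a hypothetical toy linear-regression model trained with plain SGD on a single device (`loss_fn`, `train_step`, and `LEARNING_RATE` are illustrative names, not the VDM code). Scanning the step over a stack of `substeps` batches gives the same parameters and losses as a Python loop, up to floating-point rounding.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup, not the VDM training code.
LEARNING_RATE = 0.1


def loss_fn(params, batch):
    x, y = batch
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def train_step(params, batch):
    """One SGD update; returns (new_params, loss) so it matches jax.lax.scan's (carry, x) -> (carry, y)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss


def scanned_steps(params, batches):
    """Apply train_step once per slice along the leading axis of `batches`."""
    return jax.lax.scan(train_step, params, batches)


def looped_steps(params, batches):
    """The same updates written as an ordinary Python loop."""
    xs, ys = batches
    losses = []
    for i in range(xs.shape[0]):
        params, loss = train_step(params, (xs[i], ys[i]))
        losses.append(loss)
    return params, jnp.stack(losses)


substeps = 4
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
# Batches are stacked along a leading axis of size `substeps`.
xs = jax.random.normal(key_x, (substeps, 8, 3))
ys = jax.random.normal(key_y, (substeps, 8))
init = {"w": jnp.zeros(3), "b": jnp.zeros(())}

params_scan, losses_scan = scanned_steps(init, (xs, ys))
params_loop, losses_loop = looped_steps(init, (xs, ys))
```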

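I've found the benefit: `substeps` makes JAX compile several updates together into a single program, so Python dispatches one call per `substeps` updates instead of one call per update. This reduces per-step overhead and makes training faster.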
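A minimal sketch of that mechanism, assuming a single device and the same kind of hypothetical toy SGD step (not the VDM training code): wrapping `jax.lax.scan` in `jax.jit` turns `substeps` updates into one compiled XLA program. The `TRACE_COUNT` dictionary is an illustrative device only; it shows that the function is traced and compiled once and then reused across outer iterations.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy setup, not the VDM training code.
LEARNING_RATE = 0.1
TRACE_COUNT = {"multi_step": 0}  # incremented only while JAX traces multi_step


def train_step(params, batch):
    """One SGD update on a toy linear-regression loss."""
    x, y = batch

    def loss_fn(p):
        return jnp.mean((x @ p["w"] + p["b"] - y) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


@jax.jit
def multi_step(params, batches):
    """Compile `substeps` updates (the leading axis of `batches`) into one program."""
    TRACE_COUNT["multi_step"] += 1  # Python side effect: runs at trace time, not per call
    params, losses = jax.lax.scan(train_step, params, batches)
    return params, losses.mean()


substeps = 5
num_outer = 3
key = jax.random.PRNGKey(0)
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}
for _ in range(num_outer):
    key, kx, ky = jax.random.split(key, 3)
    batches = (
        jax.random.normal(kx, (substeps, 8, 3)),
        jax.random.normal(ky, (substeps, 8)),
    )
    # One call from Python performs `substeps` parameter updates inside compiled code.
    params, mean_loss = multi_step(params, batches)
```

Because `substeps` is read from the leading axis of `batches`, changing it changes the input shape and triggers one recompilation. Larger values amortize the per-call dispatch overhead further, at the cost of holding more batches in device memory at once.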