google-deepmind/mctx

`muzero_policy` search vs `gumbel_muzero_policy` search performance


I've seen others report that `muzero_policy` is slow, and I've run into the problem myself, so I wanted to add a bit more information. I'm not expecting a solution, but it might help the authors better understand the issue.

`muzero_policy` is drastically slower than `gumbel_muzero_policy`, and the problem stems from `body_fun` in `search.py`. I can't identify the exact line, but execution slows to a near standstill in this loop. On CPU I see very strange behavior: the first N calls to `muzero_policy` return almost instantly, then call N+1 stalls, and a search with only 100 simulations takes minutes to complete. The batch size seems to affect when the stall occurs.
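One possible explanation for the "fast, then suddenly stuck" pattern is JAX's asynchronous dispatch: a call can return before its result has been computed, so the wait only shows up later. I haven't confirmed that this is what happens here. A minimal sketch of a per-call timer that rules it out by waiting on each result before stopping the clock is below. `time_calls` and its `sync` parameter are hypothetical names for illustration, not part of mctx or JAX.

```python
import time
from typing import Any, Callable, List


def time_calls(
    fn: Callable[[], Any],
    num_calls: int,
    sync: Callable[[Any], Any] = lambda out: out,
) -> List[float]:
    """Returns the wall-clock duration in seconds of each call to `fn`.

    `sync` is applied to every result before the clock stops. With JAX,
    passing `jax.block_until_ready` here (an assumption about the setup)
    makes each measurement include the actual computation rather than
    only the time needed to enqueue it.
    """
    durations = []
    for _ in range(num_calls):
        start = time.perf_counter()
        out = fn()
        sync(out)  # Wait for the result so the cost is charged to this call.
        durations.append(time.perf_counter() - start)
    return durations
```

Wrapping the policy call in a zero-argument function and passing `sync=jax.block_until_ready` should show whether every call is slow or only the one where the stall appears. Timing the first call separately is also worthwhile, since it includes compilation.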

I wish the fix were as easy as switching to `gumbel_muzero_policy`, but I'm working with stochastic nodes, and we've observed significantly worse performance from Gumbel in that setting.

I've identified the root selection function as the culprit. Swapping MuZero's root selection function for Gumbel's root selection policy gives a ~100x speedup in my case.