Better search_value
fidlej opened this issue · 3 comments
Thanks for preparing the whole training script and the JAX environment.
I have a minor suggestion to improve the target for the value network.
When using sequential halving, the search explores several actions to find the best one.
The root value may then be underestimated, because the averaged value also includes the values of the explored actions, not only the best one.
See mctx_learning_demo/basic_tree_search.py, line 189 at commit e79a655.
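To make the underestimation concrete, here is a minimal sketch using a simplified model of the root value: a running mean of the root's own raw value estimate (counted as one visit) and the returns backed up through each child, approximated as visits times Q per child. The numbers are hypothetical, and for simplicity the sketch assumes the search ends up selecting the highest-Q action.

```python
import numpy as np


def root_value_estimate(raw_value, child_qvalues, child_visits):
    """Simplified MCTS root value.

    Averages the root's raw value estimate (one visit) with the returns
    backed up through each child, approximated here as visits * Q.
    """
    child_qvalues = np.asarray(child_qvalues, dtype=float)
    child_visits = np.asarray(child_visits, dtype=float)
    total = raw_value + np.sum(child_visits * child_qvalues)
    return total / (1.0 + np.sum(child_visits))


# Hypothetical numbers: sequential halving still spends visits on weaker actions.
q = [1.0, 0.2, -0.5, 0.0]
visits = [6, 3, 3, 2]
raw_value = 0.3

root_value = root_value_estimate(raw_value, q, visits)
selected_action = int(np.argmax(q))  # assume the search picks the best action
selected_q = q[selected_action]

print(f"root value {root_value:.2f} vs Q(selected) {selected_q:.2f}")
# prints: root value 0.36 vs Q(selected) 1.00
```

The visits spent on the weaker actions pull the averaged root value well below the Q-value of the action that is actually played.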
A better target for the value network would be the Q-value of the selected action:
search_value = policy_output.search_tree.qvalues(policy_output.action)
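In mctx, `Tree.qvalues(indices)` takes node indices and returns the Q-values of all children of those nodes, so the selected action still has to be gathered from the root's row. A minimal sketch of that gather, assuming the root Q-values are available as a `[B, num_actions]` array (for example the `qvalues` field of `policy_output.search_tree.summary()`; worth checking against the mctx version in use). NumPy is used here in place of `jax.numpy`; the indexing is the same.

```python
import numpy as np


def selected_action_qvalue(root_qvalues, actions):
    """Pick Q(root, a) for the action chosen in each batch element.

    root_qvalues: float array [B, num_actions], Q-values of the root's children.
    actions: int array [B], the action selected by the search.
    Returns a float array [B].
    """
    root_qvalues = np.asarray(root_qvalues)
    actions = np.asarray(actions)
    batch_indices = np.arange(root_qvalues.shape[0])
    # Advanced indexing pairs batch_indices[i] with actions[i].
    return root_qvalues[batch_indices, actions]


# Hypothetical batch of two root nodes with three actions each.
root_qvalues = np.array([[0.1, 0.9, 0.4],
                         [0.5, 0.2, 0.7]])
actions = np.array([1, 2])
search_value = selected_action_qvalue(root_qvalues, actions)
print(search_value)  # [0.9 0.7]
```

With JAX arrays, `jnp.take_along_axis(root_qvalues, actions[:, None], axis=1)[:, 0]` is an equivalent way to write the same gather.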
Hi Ivo,
Thanks for the feedback, and for releasing MCTX!
I tried your suggestion and, oddly, it significantly hurt performance in the specific example I tested (plot included for reference). Learning does still occur with the max q_value target, reaching a final performance of around 60 compared to 0.5 for random action selection, but it is eclipsed by using the node value. I suspect the issue is that the environment is stochastic, which I don't believe the code was really intended for; averaging over multiple actions is probably providing some heuristically useful smoothing in this case.
Perhaps I will add the option to use either as a configuration choice.
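Such a configuration choice could look roughly like the sketch below. The option names and the `value_target` helper are hypothetical, not the demo's actual config; inputs are assumed to be batched root statistics from the search.

```python
import numpy as np

# Hypothetical option names; the demo's actual config keys may differ.
VALUE_TARGETS = ("node_value", "max_qvalue", "selected_qvalue")


def value_target(mode, root_value, root_qvalues, actions):
    """Return the value-network target for each batch element.

    mode: one of VALUE_TARGETS.
    root_value: [B] averaged root node value from the search.
    root_qvalues: [B, num_actions] Q-values of the root's children.
    actions: [B] actions selected by the search.
    """
    root_qvalues = np.asarray(root_qvalues)
    if mode == "node_value":
        return np.asarray(root_value)
    if mode == "max_qvalue":
        return root_qvalues.max(axis=1)
    if mode == "selected_qvalue":
        batch_indices = np.arange(root_qvalues.shape[0])
        return root_qvalues[batch_indices, np.asarray(actions)]
    raise ValueError(f"unknown value target {mode!r}; expected one of {VALUE_TARGETS}")


# Hypothetical batch of two roots with two actions each.
root_value = np.array([0.36, 0.5])
root_qvalues = np.array([[1.0, 0.2], [0.1, 0.8]])
actions = np.array([0, 0])
for mode in VALUE_TARGETS:
    print(mode, value_target(mode, root_value, root_qvalues, actions))
```

Note that "max_qvalue" and "selected_qvalue" differ whenever the search selects an action other than the highest-Q one, as in the second batch element here.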
Thanks for the plot.
You are right that MCTS without chance nodes is not suitable for stochastic environments.
If you want to verify that the search works, you can use the same rng_key for the environment inside the search and outside the search.
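A minimal sketch of why sharing the key works as a check: if the only source of randomness is the key, the search's simulated step and the real step produce identical outcomes, so the environment is effectively deterministic from the search's point of view. The toy `stochastic_step` is hypothetical, and an integer NumPy seed stands in for a `jax.random` key.

```python
import numpy as np


def stochastic_step(state, action, seed):
    """Toy stochastic transition (hypothetical, not the demo's environment).

    The randomness comes only from `seed`, so the same (state, action, seed)
    always produces the same next state and reward.
    """
    rng = np.random.default_rng(seed)
    noise = rng.normal()
    next_state = state + action + noise
    reward = -abs(next_state)
    return next_state, reward


state, action = 0.0, 1

# The search's model simulates the step with the same seed the real
# environment will use, so its prediction matches what actually happens.
shared_seed = 42
predicted = stochastic_step(state, action, shared_seed)
actual = stochastic_step(state, action, shared_seed)

# With an independent seed for the real step, the outcome generally differs,
# so the search plans against a transition that does not occur.
actual_independent = stochastic_step(state, action, seed=7)
```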
Another problem can be bootstrapping from the value network combined with the large discount=1.0: the output bias can then keep increasing itself to unrealistic values. An alternative is to bootstrap from an older "target" network.
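A minimal sketch of bootstrapping from an older target network. With discount=1.0, any upward drift in V(s') feeds straight back into the target when it is computed with the parameters being trained; computing it with a frozen copy that is refreshed only every `target_update_period` updates breaks that loop. The linear value function, the one-step TD target, and the hard-update schedule are all illustrative assumptions, not the demo's code.

```python
import numpy as np


def value(params, features):
    """Hypothetical linear value function: V(s) = features @ w + b."""
    w, b = params
    return features @ w + b


def td_targets(target_params, rewards, next_features, dones, discount=1.0):
    """One-step targets r + discount * V_target(s'), using the frozen
    target parameters rather than the parameters being trained."""
    next_values = value(target_params, next_features)
    return rewards + discount * (1.0 - dones) * next_values


def train(params, batches, lr=0.1, target_update_period=10):
    """SGD on 0.5 * mean squared TD error with a periodic hard target update."""
    target_params = (params[0].copy(), params[1])
    for step, (features, rewards, next_features, dones) in enumerate(batches, start=1):
        targets = td_targets(target_params, rewards, next_features, dones)
        errors = value(params, features) - targets
        w, b = params
        # Gradient of 0.5 * mean(errors**2) with respect to w and b.
        w = w - lr * features.T @ errors / len(errors)
        b = b - lr * errors.mean()
        params = (w, b)
        if step % target_update_period == 0:
            # Hard update: copy the online parameters into the target.
            target_params = (w.copy(), b)
    return params, target_params


# Toy usage: every transition is terminal with reward 1, so the true value is 1.
features = np.ones((4, 2))
batch = (features, np.ones(4), np.ones((4, 2)), np.ones(4))
params, _ = train((np.zeros(2), 0.0), [batch] * 200)
```

A common variant replaces the hard copy with a slow exponential moving average of the online parameters (Polyak averaging).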
> If you want to verify that the search works, you can use the same rng_key for the environment inside the search and outside the search.
It turns out I was doing this, somewhat unintentionally. I meant to change the environment step function to take the random key as an argument instead of storing it in the env state, but I neglected to remove the env_state key, which overwrote the passed key:
mctx_learning_demo/jax_environments.py, lines 26 to 28 at commit ec60f9b.
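The referenced lines are not reproduced here, so the following is only a hypothetical illustration of the pattern described: a step function whose key argument is overwritten by a key stored in the state ignores its argument, which is why the search and the real environment ended up sharing randomness. `EnvState`, `step_buggy`, and `step_fixed` are made-up names, and an integer NumPy seed stands in for a `jax.random` key.

```python
from typing import NamedTuple

import numpy as np


class EnvState(NamedTuple):
    """Hypothetical environment state that still carries its own seed."""
    position: float
    seed: int


def step_buggy(env_state, action, seed):
    # Bug: the passed-in seed is immediately overwritten by the stored one,
    # so every caller draws the same noise regardless of the seed it passes.
    seed = env_state.seed
    noise = np.random.default_rng(seed).normal()
    return env_state._replace(position=env_state.position + action + noise)


def step_fixed(env_state, action, seed):
    # The randomness now comes only from the argument; env_state.seed is unused.
    noise = np.random.default_rng(seed).normal()
    return env_state._replace(position=env_state.position + action + noise)


s = EnvState(position=0.0, seed=0)
print(step_buggy(s, 1, seed=1) == step_buggy(s, 1, seed=2))  # True: argument ignored
print(step_fixed(s, 1, seed=1) == step_fixed(s, 1, seed=2))  # False: argument used
```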
If I change it to actually use the passed random key instead, the performance is much worse:
In any case, the right thing to do to make this a more useful demo is probably to apply it to a deterministic environment instead. Thanks again for your help.