Why is root.visit_count initialized to 0 and root_predicted_value not included in root node value?
dniku opened this issue · 0 comments
The MCTS implementation here works roughly like this (pseudocode):
def mcts(observation):
    root_predicted_value, stuff = model.initial_inference(observation)
    root = Node()
    root.expand(stuff)
    root.add_exploration_noise()
    for _ in range(num_simulations):
        leaf = find_unexpanded_leaf()  # here UCB formula depends on root.visit_count
        leaf_predicted_value, stuff = model.recurrent_inference(leaf.hidden_state)
        leaf.expand(stuff)
        value = leaf_predicted_value
        for node in reversed([root, ..., leaf]):
            node.value_sum += value
            node.visit_count += 1
            value = node.reward + discount * value

# ... later in store_search_statistics() ...
game_history.root_values.append(root.value_sum / root.visit_count)
Note that each call to expand() is followed by a backpropagation pass that updates the root value and the root visit count, except for the very first call, on the root itself. There are two consequences to this:
- When searching for an unexpanded leaf for the first time, the prior probabilities are effectively discarded: the UCB formula has the root visit count in the numerator, and that count is still 0, so the prior term vanishes for every child of the root:
Lines 391 to 393 in 23a1f69
- root_predicted_value does not affect the root value. The root value is eventually stored in game_history.root_values in store_search_statistics(), and it could be made more precise by taking root_predicted_value into account.
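To make the first point concrete, here is a minimal sketch of the prior (exploration) term of a MuZero-style pUCT score. The function name prior_score, the example priors, and the constants pb_c_base=19652 and pb_c_init=1.25 are my assumptions for illustration; the exact code in this repository may differ.

```python
import math


def prior_score(parent_visits, child_visits, prior,
                pb_c_base=19652, pb_c_init=1.25):
    """Exploration term of a MuZero-style pUCT score (assumed form)."""
    pb_c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    # The parent visit count sits in the numerator, under a square root.
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return pb_c * prior


# Hypothetical priors for three children of the root.
priors = [0.7, 0.2, 0.1]

# Root never visited: every child gets the same zero exploration term.
scores_unvisited_root = [prior_score(0, 0, p) for p in priors]

# Root visited once: the exploration term is proportional to the prior.
scores_visited_root = [prior_score(1, 0, p) for p in priors]

print(scores_unvisited_root)
print(scores_visited_root)
```

With a parent visit count of 0 all three scores are equal, so the first selection cannot be guided by the priors. With a count of 1 the scores follow the prior ordering.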
A potential fix would be to include
self.backpropagate([root], root_predicted_value, min_max_stats)
right after
Lines 303 to 309 in 23a1f69
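As a rough illustration of what this changes, the sketch below mirrors the backpropagation loop from the pseudocode above on a toy tree. Node, backpropagate, and root_value are simplified stand-ins that I made up for this example. They ignore min_max_stats and any multi-player sign handling, and every simulated path is just root plus one fresh leaf with zero reward.

```python
class Node:
    """Toy node holding only what backpropagation needs (hypothetical)."""

    def __init__(self, reward=0.0):
        self.reward = reward
        self.value_sum = 0.0
        self.visit_count = 0

    def value(self):
        # Guard against division by zero for a node that was never visited.
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def backpropagate(search_path, value, discount=1.0):
    """Same update as the pseudocode loop, without min_max_stats."""
    for node in reversed(search_path):
        node.value_sum += value
        node.visit_count += 1
        value = node.reward + discount * value


def root_value(root_predicted_value, leaf_values, include_root_prediction):
    """Return (root value, root visit count) after a simplified search."""
    root = Node()
    if include_root_prediction:
        # The proposed fix: count the root's own prediction as one visit.
        backpropagate([root], root_predicted_value)
    for leaf_value in leaf_values:
        # Simplification: each simulation reaches a fresh depth-1 leaf.
        backpropagate([root, Node()], leaf_value)
    return root.value(), root.visit_count


without_fix = root_value(1.0, [0.0, 0.0, 0.0], include_root_prediction=False)
with_fix = root_value(1.0, [0.0, 0.0, 0.0], include_root_prediction=True)
print(without_fix, with_fix)
```

In this toy setup the fix moves the root value from 0.0 to 0.25, and the root visit count is already 1 when the first simulation selects a child, which also addresses the first point.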