werner-duvaud/muzero-general

Why is root.visit_count initialized to 0 and root_predicted_value not included in root node value?

dniku opened this issue · 0 comments

The MCTS implementation here works roughly like this (pseudocode):

def mcts(observation):
    root_predicted_value, stuff = model.initial_inference(observation)
    root = Node()
    root.expand(stuff)
    root.add_exploration_noise()

    for _ in range(num_simulations):
        leaf = find_unexpanded_leaf()  # here UCB formula depends on root.visit_count
        leaf_predicted_value, stuff = model.recurrent_inference(leaf.hidden_state)
        leaf.expand(stuff)

        value = leaf_predicted_value
        for node in reversed([root, ..., leaf]):
            node.value_sum += value
            node.visit_count += 1
            value = node.reward + discount * value

# ... later in store_search_statistics() ...
game_history.root_values.append(root.value_sum / root.visit_count)

Note that every call to expand() is followed by a backup that updates the root's value sum and visit count, except for the very first call, on the root itself. There are two consequences to this:

  • When searching for an unexpanded leaf for the first time, prior probabilities are discarded: the UCB formula has the parent's visit count in the numerator, so with root.visit_count == 0 the prior term is zero for every child of the root:

    muzero-general/self_play.py

    Lines 391 to 393 in 23a1f69

    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
    prior_score = pb_c * child.prior
    This makes the first MCTS simulation less effective than it could be, and may harm performance when the simulation budget is limited. The MuZero paper's first author commented here that this is indeed a problem (Ctrl+F "When selecting among actions of the root, the root's visit count should already be 1.").
  • root_predicted_value does not affect the root value. The root value is eventually stored in game_history.root_values by store_search_statistics(), and it could be made more precise by taking root_predicted_value into account.
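
The first consequence can be checked numerically. The sketch below reimplements only the prior term of the UCB formula quoted above. prior_score() is a hypothetical helper, not a function from the repository, and the constants PB_C_BASE = 19652 and PB_C_INIT = 1.25 are assumed values (the defaults from the MuZero paper); the priors are made-up numbers.

```python
import math

# Assumed exploration constants (MuZero paper defaults); the repository
# reads these from its config, so the actual values may differ.
PB_C_BASE = 19652
PB_C_INIT = 1.25


def prior_score(parent_visits, child_visits, prior):
    """Prior term of the UCB score, mirroring the two lines quoted above."""
    pb_c = math.log((parent_visits + PB_C_BASE + 1) / PB_C_BASE) + PB_C_INIT
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return pb_c * prior


# Hypothetical priors for three root children, none visited yet.
priors = [0.7, 0.2, 0.1]

# Current behaviour: root.visit_count == 0 during the first simulation.
scores_root_0 = [prior_score(0, 0, p) for p in priors]

# Behaviour if the root had already been counted once.
scores_root_1 = [prior_score(1, 0, p) for p in priors]

print(scores_root_0)  # every entry is 0.0, so the priors play no role
print(scores_root_1)  # entries are ordered like the priors
```

With a root visit count of 0 all children tie at zero, so the first selection falls to whatever tie-breaking rule the selection code uses rather than to the policy prior.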

A potential fix would be to include

self.backpropagate([root], root_predicted_value, min_max_stats)

right after

muzero-general/self_play.py

Lines 303 to 309 in 23a1f69

root.expand(
    legal_actions,
    to_play,
    reward,
    policy_logits,
    hidden_state,
)
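
To illustrate the effect of the proposed backup on the stored root value, here is a minimal toy sketch. It does not use the repository's classes: Node, backup_root() and root_statistics() are hypothetical stand-ins, and the numbers are made up. Each entry in simulation_values stands for the discounted value that reaches the root at the end of one simulation.

```python
class Node:
    """Toy stand-in for a search node; only the statistics matter here."""

    def __init__(self):
        self.value_sum = 0.0
        self.visit_count = 0

    def value(self):
        # Mean of backed-up values; 0.0 if the node was never visited.
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def backup_root(root, value):
    """Apply one backup step to the root only."""
    root.value_sum += value
    root.visit_count += 1


def root_statistics(root_predicted_value, simulation_values, include_root_prediction):
    """Return (root value, root visit count) after all simulations."""
    root = Node()
    if include_root_prediction:
        # Proposed fix: count the root's own predicted value once,
        # right after the root is expanded.
        backup_root(root, root_predicted_value)
    for value in simulation_values:
        backup_root(root, value)
    return root.value(), root.visit_count


# Hypothetical numbers: the network predicts 1.0 at the root,
# and two simulations back up 0.0 and 0.5.
without_fix = root_statistics(1.0, [0.0, 0.5], include_root_prediction=False)
with_fix = root_statistics(1.0, [0.0, 0.5], include_root_prediction=True)

print(without_fix)  # (0.25, 2)
print(with_fix)     # (0.5, 3)
```

In this toy setting the extra backup also leaves the root with a visit count of 1 before the first simulation, which is the condition quoted from the MuZero author above.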