RNaD - MLP alternatives
frvls opened this issue · 1 comment
frvls commented
I read the research paper, and it seems that a well-thought-out network architecture was used for the Stratego game to get great results.
The current network is implemented like this in the RNaD codebase:
```python
def network(
    env_step: EnvStep
) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]:
  # Shared MLP torso over the raw observation.
  mlp_torso = hk.nets.MLP(
      self.config.policy_network_layers, activate_final=True
  )
  torso = mlp_torso(env_step.obs)
  # Policy head: one logit per distinct action.
  mlp_policy_head = hk.nets.MLP([self._game.num_distinct_actions()])
  logit = mlp_policy_head(torso)
  # Value head: a single scalar estimate.
  mlp_policy_value = hk.nets.MLP([1])
  v = mlp_policy_value(torso)
  # Mask out illegal actions before normalizing the policy.
  pi = _legal_policy(logit, env_step.legal)
  log_pi = legal_log_policy(logit, env_step.legal)
  return pi, v, log_pi, logit
```
Since plain MLP networks can become unstable as the model grows, are there any easy drop-in alternatives to this MLP structure for the RNaD algorithm (skip connections, etc.) that would give better results?
I was looking into this with Haiku, but I haven't been able to make any meaningful progress, so I would appreciate any help with this. Something along the lines of the residual torso sketched below is roughly what I had in mind.
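For concreteness, here is a minimal sketch of a residual MLP torso in Haiku that could replace the `hk.nets.MLP` torso above. The module name, hidden size, block count, and the pre-LayerNorm layout are my own illustrative choices, not anything taken from the paper or the OpenSpiel code:

```python
import haiku as hk
import jax
import jax.numpy as jnp


class ResidualMLPTorso(hk.Module):
  """MLP torso built from pre-LayerNorm residual blocks (illustrative)."""

  def __init__(self, hidden_size: int, num_blocks: int, name=None):
    super().__init__(name=name)
    self._hidden_size = hidden_size
    self._num_blocks = num_blocks

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    # Project the observation to the hidden width once, so every
    # residual addition happens between tensors of the same shape.
    x = hk.Linear(self._hidden_size)(x)
    for _ in range(self._num_blocks):
      block = hk.Sequential([
          hk.LayerNorm(axis=-1, create_scale=True, create_offset=True),
          hk.Linear(self._hidden_size),
          jax.nn.relu,
      ])
      x = x + block(x)  # skip connection around each block
    return x
```

Inside `network` I imagine replacing just the torso construction, e.g. `torso = ResidualMLPTorso(256, 4)(env_step.obs)` (the 256 and 4 are placeholder values), leaving the policy and value heads untouched. Is something like this a reasonable direction, or is more needed?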