Loading Trained AlphaZero Model
sr5434 opened this issue · 2 comments
sr5434 commented
Hey all! I trained AlphaZero on Kuhn Poker using the provided example. I am now trying to adapt the baseline loader to load my trained model, but it raises this error:
Traceback (most recent call last):
File "/Users/samir/Documents/Apps/alphaZero/test_model.py", line 130, in <module>
print(model(
^^^^^^
File "/Users/samir/Documents/Apps/alphaZero/test_model.py", line 19, in apply
(logits, value), _ = forward.apply(model_params, model_state, obs, is_eval=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/multi_transform.py", line 296, in apply_fn
return f.apply(params, state, None, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/transform.py", line 456, in apply_fn
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/Users/samir/Documents/Apps/alphaZero/test_model.py", line 13, in forward_fn
policy_out, value_out = net(x, is_training=not is_eval, test_local_stats=False)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/samir/Documents/Apps/alphaZero/test_model.py", line 82, in __call__
logits = hk.Linear(self.num_actions)(logits)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/module.py", line 458, in wrapped
out = f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/module.py", line 299, in run_interceptors
return bound_method(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/basic.py", line 179, in __call__
w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/haiku/_src/base.py", line 685, in get_parameter
raise ValueError(
ValueError: 'az_net/linear/w' with retrieved shape (14, 4) does not match shape=[2, 4] dtype=dtype('float32')
This is my code:
import pickle
from typing import NamedTuple
from pydantic import BaseModel
import jax
import jax.numpy as jnp
import haiku as hk
import pgx
def _make_az_baseline_model(model_args, model_params, model_state, num_actions=4):
    """Build an inference-only callable around a trained AlphaZero network.

    Args:
        model_args: Object exposing ``num_channels`` and ``num_layers``; both
            are forwarded to ``_create_az_model_v0``.
        model_params: Haiku parameters handed to ``forward.apply``.
        model_state: Haiku state handed to ``forward.apply``.
        num_actions: Width of the policy head. Defaults to 4, the value that
            used to be hard-coded in this function.

    Returns:
        A function ``apply(obs) -> (logits, value)`` that runs the network
        with ``is_eval=True`` and discards the state returned by Haiku.
    """

    def forward_fn(x, is_eval=False):
        net = _create_az_model_v0(
            num_actions=num_actions,
            num_channels=model_args.num_channels,
            num_layers=model_args.num_layers,
        )
        policy_out, value_out = net(
            x, is_training=not is_eval, test_local_stats=False
        )
        return policy_out, value_out

    forward = hk.without_apply_rng(hk.transform_with_state(forward_fn))

    def apply(obs):
        # NOTE(review): the network reshapes its input to
        # (obs.shape[0], obs.shape[1], 1), so obs presumably has to be a
        # rank-2 (batch, features) array -- confirm against the training code.
        (logits, value), _ = forward.apply(
            model_params, model_state, obs, is_eval=True
        )
        return logits, value

    return apply
def _create_az_model_v0(
    num_actions,
    num_channels: int = 128,
    num_layers: int = 6,
    resnet_v2: bool = True,
):
    """Construct the AlphaZero policy/value network as a Haiku module.

    The module classes are defined locally and an ``AZNet`` instance is
    returned. In this file the function is invoked from inside the
    transformed forward function.

    Args:
        num_actions: Output width of the policy head.
        num_channels: Channel count of the trunk convolutions and width of
            the hidden layer in the value head.
        num_layers: Number of residual blocks in the trunk.
        resnet_v2: Forwarded to ``AZNet`` (see the note in its ``__init__``).

    Returns:
        An ``AZNet`` instance whose call returns ``(logits, value)``.
    """

    # The order in which Haiku modules are created below is kept fixed on
    # purpose: auto-generated module names (and therefore checkpoint
    # parameter names such as 'az_net/linear/w') follow creation order.

    class BlockV2(hk.Module):
        """Pre-activation residual block: (BN -> ReLU -> 3x3 conv) twice."""

        def __init__(self, num_channels, name="BlockV2"):
            super().__init__(name=name)
            self.num_channels = num_channels

        def __call__(self, x, is_training, test_local_stats):
            def bn_relu(h):
                normed = hk.BatchNorm(
                    create_scale=True, create_offset=True, decay_rate=0.9
                )(h, is_training, test_local_stats)
                return jax.nn.relu(normed)

            shortcut = x
            out = bn_relu(x)
            out = hk.Conv2D(self.num_channels, kernel_shape=3)(out)
            out = bn_relu(out)
            out = hk.Conv2D(self.num_channels, kernel_shape=3)(out)
            return out + shortcut

    class AZNet(hk.Module):
        """AlphaZero NN architecture."""

        def __init__(
            self,
            num_actions,
            num_channels: int,
            num_layers: int,
            resnet_v2=True,
            name="az_net",
        ):
            super().__init__(name=name)
            self.num_actions = num_actions
            self.num_channels = num_channels
            self.num_layers = num_layers
            # NOTE: the resnet_v2 argument is accepted but not consulted;
            # the flag is always stored as True and BlockV2 is always used.
            self.resnet_v2 = True
            self.resnet_cls = BlockV2

        def __call__(self, x, is_training=False, test_local_stats=False):
            def bn_relu(h):
                normed = hk.BatchNorm(
                    create_scale=True, create_offset=True, decay_rate=0.9
                )(h, is_training, test_local_stats)
                return jax.nn.relu(normed)

            # Append a trailing axis of size 1 to the first two input axes.
            dim0, dim1 = x.shape[0], x.shape[1]
            h = x.reshape((dim0, dim1, 1)).astype(jnp.float32)

            # Trunk: stem convolution followed by the residual tower.
            h = hk.Conv2D(self.num_channels, kernel_shape=2)(h)
            for layer_idx in range(self.num_layers):
                block = self.resnet_cls(
                    self.num_channels, name=f"block_{layer_idx}"
                )
                h = block(h, is_training, test_local_stats)
            h = bn_relu(h)

            # Policy head: 1x1 conv to 2 channels, flatten, linear to actions.
            policy = hk.Conv2D(output_channels=2, kernel_shape=1)(h)
            policy = bn_relu(policy)
            policy = hk.Flatten()(policy)
            policy = hk.Linear(self.num_actions)(policy)

            # Value head: 1x1 conv to 1 channel, MLP, tanh, then flatten to 1-D.
            value = hk.Conv2D(output_channels=1, kernel_shape=1)(h)
            value = bn_relu(value)
            value = hk.Flatten()(value)
            value = jax.nn.relu(hk.Linear(self.num_channels)(value))
            value = jnp.tanh(hk.Linear(1)(value))
            value = value.reshape((-1,))

            return policy, value

    return AZNet(num_actions, num_channels, num_layers, resnet_v2)
class Config(BaseModel):
    """Run configuration, declared as a pydantic model.

    NOTE(review): presumably mirrors the training script's config so that the
    checkpoint's pickled "config" entry can be restored -- confirm. Within
    this file only ``env_id``, ``num_channels`` and ``num_layers`` are read.
    """

    # Environment id passed to pgx.make.
    env_id: pgx.EnvId = "kuhn_poker"
    seed: int = 0
    max_num_iters: int = 50000
    # network params
    num_channels: int = 128
    num_layers: int = 6
    resnet_v2: bool = True
    # selfplay params
    selfplay_batch_size: int = 1
    num_simulations: int = 32
    max_num_steps: int = 256
    # training params
    training_batch_size: int = 4096
    learning_rate: float = 0.001
    # eval params
    eval_interval: int = 5
class Sample(NamedTuple):
    """Tuple of four arrays: obs, policy_tgt, value_tgt and mask.

    Not referenced elsewhere in this snippet.
    NOTE(review): field meanings are inferred from their names only
    (observation, policy target, value target, mask) -- confirm against the
    training code.
    """

    obs: jnp.ndarray
    policy_tgt: jnp.ndarray
    value_tgt: jnp.ndarray
    mask: jnp.ndarray
# Module-level configuration instance built from the Config defaults.
config: Config = Config()
# Environment created from the configured id; nothing else in this snippet
# reads `env`.
env = pgx.make(config.env_id)
if __name__ == "__main__":
    # NOTE: pickle.load can execute arbitrary code while deserialising; only
    # load checkpoint files that you created yourself or otherwise trust.
    with open("/Users/samir/Documents/Apps/alphaZero/model.ckpt", "rb") as f:
        d = pickle.load(f)

    # d["config"] supplies the network hyper-parameters; d["model"] holds the
    # (params, state) pair for the Haiku network.
    model = _make_az_baseline_model(d["config"], d["model"][0], d["model"][1])

    # The network takes a *batch* of flat observations, i.e. shape
    # (batch, obs_dim). Previously a (7, 1) column was passed, which the
    # network's reshape reads as 7 samples with 1 feature each; the policy
    # head then flattens to 2 features, whereas the checkpoint's
    # 'az_net/linear/w' has shape (14, 4) = 7 observation entries x 2
    # policy-conv channels. A single row of 7 values (shape (1, 7)) matches
    # the shapes stored in the checkpoint.
    obs = jnp.array([[0., 0., 1., 1., 0., 1., 0.]])
    print(model(obs))
How do I fix this?
sotetsuk commented
Hi, thank you for your comment and very sorry for the late response 🙏
Sorry, but we do not provide support for users' own training issues 🙏 😭
It looks like a simple shape problem; you may want to check the output shape of each layer step by step.