google/brax

Trained model inference


I need to use the trained network in a separate test script, without the training-side code. The signature of inference_fn in my test code does not match the one from training, as shown below:

# training code
inference_fn
Out[10]: <function brax.training.agents.ppo.networks.make_inference_fn.<locals>.make_policy.<locals>.policy(observations: jax.Array, key_sample: jax.Array) -> Tuple[jax.Array, Mapping[str, Any]]>
# test code
inference_fn
Out[181]: <function brax.training.agents.ppo.networks.make_inference_fn.<locals>.make_policy(params: Tuple[Any, Any], deterministic: bool = False) -> brax.training.types.Policy>
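
For context, this difference is expected from the two signatures above: make_inference_fn(ppo_network) returns make_policy, and calling make_policy with the loaded params produces the policy(observations, key_sample) function seen in the training code. A minimal sketch (variable names here are illustrative):

# make_inference_fn(ppo_network) -> make_policy(params, deterministic=False)
# make_policy(params)            -> policy(observations, key_sample)
make_policy = ppo_networks.make_inference_fn(ppo_network)  # signature seen in the test code
policy = make_policy(params)                               # signature seen in the training code
act, extras = policy(obs, act_rng)                         # same call as in training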

Test code:

# Load the saved parameters
model_path = '/tmp/mjx_brax_quadruped_policy'
params = model.load_params(model_path)

# Get the observation/action sizes from the environment (ensure this matches the original environment from Code1)
observation_size = env.observation_size
action_size = env.action_size

# Re-create the policy and value networks with the exact same architecture
make_networks_factory = functools.partial(
   ppo_networks.make_ppo_networks,
   observation_size=observation_size,
   action_size=action_size,
   policy_hidden_layer_sizes=(128, 128, 128, 128)  # Match the architecture from Code1
)

# Initialize the PPO networks (policy and value networks)
ppo_modified = make_networks_factory()

# Now use `make_inference_fn` to create the inference function
inference_fn = ppo_networks.make_inference_fn(ppo_modified)

# JIT-compile the inference function
jit_inference_fn = jax.jit(inference_fn(params)) 
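
For reference, a minimal call sketch for the resulting function (the zero observation is a placeholder for illustration only; in practice it comes from state.obs):

rng = jax.random.PRNGKey(0)
obs = jax.numpy.zeros(observation_size)  # placeholder observation, illustrative only
act, _ = jit_inference_fn(obs, rng)      # the policy returns (action, extras)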

Hi @OmurAydogmus, what's the issue exactly? Does the "test" code not work?

Thank you, btaba. I think I found the mistake. How can we normalize the observations? normalize_observations=True was selected during training.

# Load Policies and Test
paramsTEST = model.load_params('/tmp/params')
 

ppoTEST = ppo.ppo_networks.make_ppo_networks(action_size=env.action_size, observation_size=env.observation_size)
make_inference = ppo.ppo_networks.make_inference_fn(ppoTEST)
inference_fnTEST = make_inference(paramsTEST)
 
env = envs.create(env_name=env_name, backend=backend)

jit_env_reset = jax.jit(env.reset)
jit_env_step = jax.jit(env.step)
jit_inference_fn = jax.jit(inference_fnTEST)

rollout = []
rng = jax.random.PRNGKey(seed=1)
state = jit_env_reset(rng=rng)
for _ in range(100):
  rollout.append(state.pipeline_state)
  act_rng, rng = jax.random.split(rng)
  act, _ = jit_inference_fn(state.obs, act_rng)
  state = jit_env_step(state, act)

media.show_video(env.render(rollout, camera='track'), fps=1.0 / env.dt)

I think it is okay. We need to define preprocess_observations_fn using normalization:

ppoTEST = ppo.ppo_networks.make_ppo_networks(
    action_size=env.action_size,
    observation_size=env.observation_size,
    preprocess_observations_fn=running_statistics.normalize,
)
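
For completeness, a sketch of the full reconstruction with normalization, assuming env is the environment created with envs.create as above. The hidden layer sizes are taken from the earlier snippet and must match what was used in training; the running_statistics import assumes brax's brax.training.acme module:

import jax
from brax.io import model
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks

# Load the parameters saved after training
params = model.load_params('/tmp/params')

# Rebuild the networks with the same architecture and observation preprocessor as training
ppo_network = ppo_networks.make_ppo_networks(
    observation_size=env.observation_size,
    action_size=env.action_size,
    policy_hidden_layer_sizes=(128, 128, 128, 128),           # must match training
    preprocess_observations_fn=running_statistics.normalize,  # matches normalize_observations=True
)
make_policy = ppo_networks.make_inference_fn(ppo_network)
jit_inference_fn = jax.jit(make_policy(params))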