waymo-research/waymax

Help required regarding Brax wrapper

vcharraut opened this issue

Hello, I'm wondering whether any Brax examples are available. I'm trying to run a pipeline with code that normally runs on a Brax environment with SAC, but porting it to Waymax is proving complicated.

I tried to get as much as I could out of the documentation, but there is one point I specifically don't understand. In Brax you are supposed to jit the environment's functions, such as reset and step, but this does not seem possible with Waymax because these functions take custom objects such as State.
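To narrow this down, here is a small standalone check. `jax.jit` can trace functions whose arguments are custom containers, as long as those containers are registered as JAX pytrees. As far as I can tell, Waymax's state types are chex dataclasses, which are pytrees, but I have not verified this for every type. The sketch uses a hypothetical `ToyState` NamedTuple (NamedTuples are pytrees out of the box) in place of `State`. Under jit, each field becomes its own tracer, so attribute access such as `state.log_trajectory` works.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for a simulator state; NamedTuples are JAX pytrees."""
    log_trajectory: jnp.ndarray  # shape (num_objects, num_timesteps)
    timestep: jnp.ndarray        # scalar int32


@jax.jit
def reset(state: ToyState) -> ToyState:
    # Inside jit, `state` is a ToyState whose fields are tracers, so the
    # usual attribute access and _replace both work.
    return state._replace(timestep=jnp.zeros((), jnp.int32))


@jax.jit
def step(state: ToyState, action: jnp.ndarray) -> ToyState:
    # Add one action value per object across all timesteps.
    new_traj = state.log_trajectory + action[:, None]
    return ToyState(log_trajectory=new_traj, timestep=state.timestep + 1)


state = ToyState(log_trajectory=jnp.zeros((2, 3)), timestep=jnp.array(5, jnp.int32))
state = reset(state)
state = step(state, jnp.ones(2))
```

If this works but the environment still fails, the problem is probably not the custom State type itself, but what is being passed where a State is expected.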

I get the following error:

Traceback (most recent call last):
  File "/home/o-vcharrau/Workspace/Frankenstein/frankenstein/waymax/main.py", line 581, in <module>
    train(dynamics_model=dynamics_model, env_config=env_config, scenarios=scenarios, args=args_, progress_fn=progress)
  File "/home/o-vcharrau/Workspace/Frankenstein/frankenstein/waymax/main.py", line 467, in train
    metrics = evaluator.run_evaluation(
              ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/o-vcharrau/Workspace/Frankenstein/frankenstein/brax/evaluate.py", line 64, in run_evaluation
    eval_state = self._generate_eval_unroll(policy_params, unroll_key)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/o-vcharrau/Workspace/Frankenstein/frankenstein/brax/evaluate.py", line 42, in generate_eval_unroll
    eval_first_state = eval_env.reset(reset_keys)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/o-vcharrau/Workspace/Frankenstein/venv/lib/python3.11/site-packages/brax/envs/wrappers/training.py", line 157, in reset
    reset_state = self.env.reset(rng)
                  ^^^^^^^^^^^^^^^^^^^
  File "/home/o-vcharrau/Workspace/Frankenstein/venv/lib/python3.11/site-packages/waymax/env/wrappers/brax_wrapper.py", line 101, in reset
    initial_state = self._wrapped_env.reset(state)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/o-vcharrau/Workspace/Frankenstein/venv/lib/python3.11/site-packages/waymax/env/base_environment.py", line 77, in reset
    self.config.max_num_objects, state.log_trajectory.num_objects
                                 ^^^^^^^^^^^^^^^^^^^^
AttributeError: DynamicJaxprTracer has no attribute log_trajectory
--------------------
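Reading the traceback more closely: Brax's training wrapper (training.py, line 157) calls `self.env.reset(rng)` with a PRNG key. Waymax's brax_wrapper.py (line 101) forwards that argument as `state` to base_environment.py, which then reads `state.log_trajectory`. So the traced key, not a simulator state, ends up there. Below is a minimal sketch of one possible adapter. It assumes a batch of scenarios has been preloaded and stacked along a leading axis, and it picks one scenario per key. `KeyToScenarioReset`, `ToyState`, and `env_reset` are hypothetical names, not Waymax or Brax API.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    """Hypothetical stand-in for Waymax's simulator state."""
    log_trajectory: jnp.ndarray


def env_reset(state: ToyState) -> ToyState:
    """Hypothetical reset that, like the one in the traceback, expects a state."""
    return ToyState(log_trajectory=state.log_trajectory * 1.0)


class KeyToScenarioReset:
    """Hypothetical adapter: key-based reset (what Brax's wrappers call)
    on top of a state-based reset (what the environment expects)."""

    def __init__(self, reset_fn: Callable[[ToyState], ToyState], scenarios: ToyState):
        # `scenarios` holds several states stacked along a leading axis.
        self._reset_fn = reset_fn
        self._scenarios = scenarios
        self._num = jax.tree_util.tree_leaves(scenarios)[0].shape[0]

    def reset(self, rng: jax.Array) -> ToyState:
        # Draw a scenario index from the key, then slice every leaf with it.
        idx = jax.random.randint(rng, (), 0, self._num)
        scenario = jax.tree_util.tree_map(lambda x: x[idx], self._scenarios)
        return self._reset_fn(scenario)


# Four toy scenarios; vmap over keys, roughly as a vmapped Brax reset would.
scenarios = ToyState(log_trajectory=jnp.arange(12.0).reshape(4, 3))
adapter = KeyToScenarioReset(env_reset, scenarios)
keys = jax.random.split(jax.random.PRNGKey(0), 8)
batch = jax.jit(jax.vmap(adapter.reset))(keys)
```

Sampling from a fixed, preloaded batch keeps reset a pure function of the key, which jit and vmap require. Advancing a Python-side data iterator cannot happen inside a jitted function.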

So I'm wondering whether I'm doing something wrong from the start. An example or tutorial covering this would be really helpful. Thank you!