[Bug Report] Error when attempting to create a dataset for MuJoCo Hopper-v4 and v5
Describe the bug
Error when attempting to create a dataset for MuJoCo Hopper-v4 and v5
Without the minari.DataCollector wrapper, the code works and trains successfully.
Code example
import gymnasium as gym
import minari
env = gym.make("Hopper-v5", disable_env_checker=True)
env = gym.wrappers.RecordEpisodeStatistics(env)
env = minari.DataCollector(env, record_infos=True, observation_space=env.observation_space, action_space=env.action_space)
obs, _ = env.reset()
action = env.action_space.sample()
next_obs, rewards, terminations, truncations, infos = env.step(action)
File /data1/deploy/Minari/minari/data_collector/data_collector.py:155, in DataCollector.step(self, action)
153 if not self._record_infos:
154 step_data["info"] = {}
--> 155 self._buffer = self._buffer.add_step_data(step_data)
157 if step_data["termination"] or step_data["truncation"]:
158 self._storage.update_episodes([self._buffer])
File /data1/deploy/Minari/minari/data_collector/episode_buffer.py:60, in EpisodeBuffer.add_step_data(self, step_data)
58 infos = jtu.tree_map(lambda x: [x], step_data["info"])
59 else:
---> 60 infos = jtu.tree_map(_append, step_data["info"], self.infos)
62 self.rewards.append(step_data["reward"])
63 self.terminations.append(step_data["termination"])
File /data1/conda/envs/python3.12/lib/python3.12/site-packages/jax/_src/tree_util.py:342, in tree_map(f, tree, is_leaf, *rest)
340 """Alias of :func:`jax.tree.map`."""
341 leaves, treedef = tree_flatten(tree, is_leaf)
--> 342 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
343 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
ValueError: Dict key mismatch; expected keys: ['reward_ctrl', 'reward_forward', 'reward_survive', 'x_position', 'x_velocity', 'z_distance_from_origin']; dict: {'x_position': np.float64(0.0012007819486893048), 'z_distance_from_origin': np.float64(-0.00042011979112821507)}.
System Info
Describe the characteristics of your environment:
Gymnasium installed from the latest main branch of the repo (version 1.0.0)
Describe how Minari was installed (pip, docker, source, ...):
Latest GitHub commit, installed with pip install -e . (version = '0.5.0')
Operating system:
RHEL8
Python version:
3.12
Additional context
JupyterLab
Checklist
- [X] I have checked that there is no similar issue in the repo (required)
It seems that the data_collector is not prepared for "infos" with different keys in env.reset() and env.step().
Yes, this is the case; we don't support it at the moment (see #191 (comment)).
Either disable info recording, or define a StepDataCallback that always returns the same keys.
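A minimal sketch of the second workaround, assuming the goal is simply to give every step the same info keys. The helper `normalize_info`, the key list `HOPPER_INFO_KEYS` and the NaN fill value are hypothetical choices for illustration; the commented-out `StepDataCallback` subclass only shows where such a helper could be called, so check the callback signature against your Minari version before relying on it.

```python
import math
from typing import Any, Iterable, Mapping

# Hypothetical fixed schema: the six keys Hopper-v5 returns from env.step(),
# taken from the "expected keys" list in the traceback above.
HOPPER_INFO_KEYS = (
    "reward_ctrl",
    "reward_forward",
    "reward_survive",
    "x_position",
    "x_velocity",
    "z_distance_from_origin",
)


def normalize_info(
    info: Mapping[str, Any],
    keys: Iterable[str] = HOPPER_INFO_KEYS,
    fill: Any = math.nan,
) -> dict:
    """Return a dict with exactly `keys`, in that order.

    Keys missing from `info` are set to `fill`; keys not listed in `keys`
    are dropped, so reset() and step() infos end up with the same structure.
    """
    return {key: info.get(key, fill) for key in keys}


# Possible wiring into Minari (sketch only, not executed here):
#
# class FixedInfoKeys(minari.StepDataCallback):
#     def __call__(self, env, obs, info, action=None, rew=None,
#                  terminated=None, truncated=None):
#         step_data = super().__call__(env, obs, info, action, rew,
#                                      terminated, truncated)
#         step_data["info"] = normalize_info(step_data["info"])
#         return step_data
#
# env = minari.DataCollector(env, step_data_callback=FixedInfoKeys,
#                            record_infos=True)

# Reset info as reported in the traceback (values shortened).
reset_info = {"x_position": 0.0012, "z_distance_from_origin": -0.0004}
print(normalize_info(reset_info))
```

Filling with NaN keeps every column numeric, which suits float-valued infos like Hopper's; extra entries such as the one added by RecordEpisodeStatistics are silently dropped with this approach.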
OK, you are right; after verifying again, env.reset() and env.step() do return different keys.
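A quick way to confirm the mismatch without going through Minari is to compare the two key sets directly. The sketch below hard-codes the info dicts so it runs without MuJoCo: the reset values are shortened from the traceback, and the step values are placeholders (only the step keys come from the traceback). With a live environment, `reset_info` would be the second element returned by `env.reset()` and `step_info` the fifth element returned by `env.step(action)`. The helper name `info_key_diff` is hypothetical.

```python
from typing import List, Mapping, Tuple


def info_key_diff(reset_info: Mapping, step_info: Mapping) -> Tuple[List[str], List[str]]:
    """Return (keys only in reset_info, keys only in step_info), each sorted."""
    only_reset = sorted(set(reset_info) - set(step_info))
    only_step = sorted(set(step_info) - set(reset_info))
    return only_reset, only_step


# Reset info from the traceback (values shortened).
reset_info = {"x_position": 0.0012, "z_distance_from_origin": -0.0004}
# Step info with the keys from the traceback; the values are placeholders.
step_info = {
    "reward_ctrl": 0.0,
    "reward_forward": 0.0,
    "reward_survive": 1.0,
    "x_position": 0.0012,
    "x_velocity": 0.0,
    "z_distance_from_origin": -0.0004,
}

only_reset, only_step = info_key_diff(reset_info, step_info)
print("only in reset():", only_reset)
print("only in step():", only_step)
```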
I vote for taking action on this, since arbitrary wrappers could add info entries at any step.
For instance, RecordEpisodeStatistics adds info entries at the final step, and there is also the "final_observation" entry in info.
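To see how such wrappers change the info schema mid-episode, one can scan an episode's infos and report the steps whose keys differ from the first step. The episode below is toy data: the `"episode"` entry only mimics the shape of the summary that RecordEpisodeStatistics adds on the final step (`"r"`, `"l"`, `"t"`), and `schema_changes` is a hypothetical helper.

```python
from typing import Dict, List, Mapping, Sequence


def schema_changes(infos: Sequence[Mapping]) -> Dict[int, List[str]]:
    """Map step index -> sorted keys added or removed relative to step 0."""
    if not infos:
        return {}
    base = set(infos[0])
    changes: Dict[int, List[str]] = {}
    for i, info in enumerate(infos[1:], start=1):
        diff = sorted(set(info) ^ base)  # symmetric difference of key sets
        if diff:
            changes[i] = diff
    return changes


# Toy three-step episode; only the last info carries an "episode" summary.
infos = [
    {"x_position": 0.0},
    {"x_position": 0.1},
    {"x_position": 0.2, "episode": {"r": 2.0, "l": 3, "t": 0.01}},
]
print(schema_changes(infos))
```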
Perhaps we could just save infos as an arbitrary pickled binary string in the storage and have Minari unpickle them on load?
In the end, infos are a list of dicts, which looks like the typical schemaless, unstructured data of document (e.g., JSON) databases.
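The pickling idea can be sketched with the standard library alone: each info dict becomes an opaque `bytes` value that a storage backend could hold (for example in a binary column) and restore on load. This only illustrates the proposal, not how Minari actually stores infos, and the helper names are hypothetical. Note that `pickle.loads` must never be used on untrusted data, and pickled payloads can depend on the Python and library versions that produced them.

```python
import pickle
from typing import Any, Dict, List


def encode_infos(infos: List[Dict[str, Any]]) -> List[bytes]:
    """Serialize each (possibly heterogeneous) info dict to an opaque blob."""
    return [pickle.dumps(info, protocol=pickle.HIGHEST_PROTOCOL) for info in infos]


def decode_infos(blobs: List[bytes]) -> List[Dict[str, Any]]:
    """Inverse of encode_infos. Only call this on data you trust."""
    return [pickle.loads(blob) for blob in blobs]


infos = [
    {"x_position": 0.0, "z_distance_from_origin": 0.0},  # reset()-style keys
    {"x_position": 0.1, "x_velocity": 0.5, "reward_survive": 1.0},  # step()-style keys
]
blobs = encode_infos(infos)
print(decode_infos(blobs) == infos)
```

The trade-off is that opaque blobs cannot be filtered or read column-wise, unlike a tabular layout.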
We are also using a tabular structure with PyArrow.
The way to add this feature, as I see it, is to pad the data. We already do something like that in the PyArrow storage.
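A minimal sketch of padding, assuming the simplest scheme: take the union of keys over an episode, turn the list of dicts into one column per key, fill missing entries with `None`, and keep a boolean mask recording which entries were actually present. The function name and the mask are illustrative assumptions; Minari's PyArrow storage may pad differently.

```python
from typing import Any, Dict, List, Mapping, Sequence, Tuple


def pad_infos(
    infos: Sequence[Mapping[str, Any]], fill: Any = None
) -> Tuple[Dict[str, List[Any]], Dict[str, List[bool]]]:
    """Convert heterogeneous info dicts into equal-length columns.

    Returns (columns, present): columns[key][t] is infos[t][key] or `fill`,
    and present[key][t] records whether the key existed at step t.
    """
    keys = sorted({key for info in infos for key in info})
    columns = {key: [info.get(key, fill) for info in infos] for key in keys}
    present = {key: [key in info for info in infos] for key in keys}
    return columns, present


infos = [
    {"x_position": 0.0, "z_distance_from_origin": 0.0},
    {"x_position": 0.1, "x_velocity": 0.5},
]
columns, present = pad_infos(infos)
print(columns)
print(present)
```

The mask matters because a fill value alone cannot distinguish "key absent" from "key present with value `None`".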
However, I am working on higher-priority items at the moment, so I don't have the bandwidth to work on this in the near future. A PR is welcome, but it is not a straightforward one. Otherwise, I suggest using a StepData callback.
OK, I will think it through and propose a solution, possibly as a PR. I consider usability a priority as well.
I will need help identifying the pieces of code and docs that need to be updated.
The new definition is that infos is a list of heterogeneous dicts, one list per episode.
It now works for both the Arrow and HDF5 backends.
Now I have to make it backward compatible with the current version, which accepts a homogeneous "infos" dict instead of a list of dicts.
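Backward compatibility largely comes down to converting between the two layouts. A sketch, assuming the existing format is a flat dict mapping each key to an equal-length per-step sequence (nested info dicts would need a recursive version), and the new format is a list with one dict per step. The helper names are hypothetical and not part of Minari's API.

```python
from typing import Any, Dict, List, Mapping, Sequence


def dict_to_list(infos: Mapping[str, Sequence[Any]]) -> List[Dict[str, Any]]:
    """{"k": [v0, v1, ...]} -> [{"k": v0}, {"k": v1}, ...]."""
    lengths = {len(values) for values in infos.values()}
    if len(lengths) > 1:
        raise ValueError(f"info columns have different lengths: {sorted(lengths)}")
    n_steps = lengths.pop() if lengths else 0
    return [{key: values[t] for key, values in infos.items()} for t in range(n_steps)]


def list_to_dict(infos: Sequence[Mapping[str, Any]]) -> Dict[str, List[Any]]:
    """Inverse of dict_to_list; only valid when every step has the same keys."""
    if not infos:
        return {}
    keys = list(infos[0])
    if any(set(info) != set(keys) for info in infos):
        raise ValueError("heterogeneous infos cannot be stored in dict format")
    return {key: [info[key] for info in infos] for key in keys}


homogeneous = {"x_position": [0.0, 0.1, 0.2], "x_velocity": [0.0, 0.5, 0.4]}
as_list = dict_to_list(homogeneous)
print(as_list)
print(list_to_dict(as_list) == homogeneous)
```

Going from dict to list always works, while going back only works for homogeneous episodes, which is why the list format is the more general of the two.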
Infos can be saved either as a dictionary of np.arrays or as a list of arbitrary dictionaries by setting the optional infos_format parameter. The default is the dictionary format (infos_format = None or "dict"):
from minari import DataCollector
import gymnasium as gym
env = gym.make('CartPole-v1')
env = DataCollector(env, record_infos=True, infos_format="list")