Farama-Foundation/D4RL

[Bug Report] Discrepancies between `get_dataset` and `qlearning_dataset`

vmoens opened this issue · 1 comment

I'm having a hard time figuring out how the dataset returned by qlearning_dataset is built.
As mentioned by @odelalleau in #182, the "terminals" key in some envs is never True.
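
For example, this can be checked directly on the raw dataset (maze2d-umaze-v1 is just an illustrative pick here, any maze2d task should behave the same):

>>> import gym, d4rl
>>> env = gym.make("maze2d-umaze-v1")
>>> env.get_dataset()["terminals"].any()  # no terminal flag is ever set for maze2d
False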

Moreover,

>>> dataset1 = env.get_dataset()
>>> dataset2 = d4rl.qlearning_dataset(env)
>>> dataset3 = d4rl.qlearning_dataset(env, terminate_on_end=True)
>>> # obs 91 matches
>>> dataset1["observations"][91]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> dataset2["observations"][91]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> dataset3["observations"][91]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> # obs 92 does not
>>> dataset1["observations"][92]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> dataset2["observations"][92]
array([ 2.9757538 ,  2.9996927 ,  2.674898  , -0.15739861], dtype=float32)
>>> dataset3["observations"][92]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)

I'm a bit puzzled by what this means. It looks like qlearning_dataset gives me a dataset where consecutive steps are always different, whereas qlearning_dataset(..., terminate_on_end=True) gives me something similar to get_dataset, where some consecutive steps are identical. What should we do with this?

When does one trajectory stop, and when does the next one start?

For what it's worth, here's my understanding of how this works:

  • qlearning_dataset(), by default, gets rid of timeouts by dropping the corresponding transitions (see the sketch after this list). The reason is that you would typically have a transition of the form (s, a, r, s', timeout=True) where s' is actually the first step of the next episode. Such a transition is thus "invalid" and it is thrown away. This is fine when you're doing 1-step Q-learning on the resulting dataset, but be careful if you intend to do multi-step Q-learning or anything else that looks further down the trajectory, because you will be switching between episodes with no way to know (the done flag won't be set there).
  • qlearning_dataset(terminate_on_end=True) will keep this invalid transition, but be aware that it will not set the done flag (contrary to what the docstring claims). So in general it's a pretty bad idea, except for datasets with fake timeouts like maze2d where the episode doesn't actually end on timeout.
  • One consequence of the above is that the done flag is set only when terminal == True. As you noticed, some datasets don't have any terminals (ex: maze2d, which is actually a single trajectory). One thing to be aware of is that next observation s' will be invalid when done == True since it will be the first state of the next episode (which in general does not matter in Q-Learning since we don't bootstrap when done == True, but if you're doing something else this may matter). Some datasets provide a next_observations field that can be used to access the last observation (both on timeout and terminal) but the qlearning_dataset() function doesn't use it.
  • It is important to realize that neither timeout == True nor terminal == True indicates with certainty that an episode has ended! I already gave the example of maze2d for the former, and a typical example of the latter is antmaze, where all states with a reward are marked with terminal == True even though the episode continues until a timeout is reached! The interpretation of terminal is thus "if I were to reach this state during evaluation, the episode would end" (antmaze finishes as soon as you reach a reward at test time), rather than "the episode has ended in the offline dataset" (otherwise you will end up with tons of 1-step episodes).
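
To make the first two points concrete, here is a rough sketch (not the actual library code; build_qlearning_like is a made-up name) of what the default behaviour amounts to when applied to the raw get_dataset() arrays:

import numpy as np

def build_qlearning_like(dataset, terminate_on_end=False):
    # dataset is the dict returned by env.get_dataset();
    # assumes it has a "timeouts" key (newer D4RL datasets do).
    obs, next_obs, actions, rewards, dones = [], [], [], [], []
    for i in range(dataset["rewards"].shape[0] - 1):
        timeout = bool(dataset["timeouts"][i])
        terminal = bool(dataset["terminals"][i])
        if timeout and not terminate_on_end:
            # s' would be the first state of the next episode, so the
            # whole (s, a, r, s') transition is dropped by default.
            continue
        obs.append(dataset["observations"][i])
        next_obs.append(dataset["observations"][i + 1])  # may cross episodes
        actions.append(dataset["actions"][i])
        rewards.append(dataset["rewards"][i])
        dones.append(terminal)  # timeouts never set the done flag
    return {
        "observations": np.array(obs),
        "actions": np.array(actions),
        "next_observations": np.array(next_obs),
        "rewards": np.array(rewards),
        "terminals": np.array(dones),
    }

The dropped timeout transitions are presumably why indices shift between get_dataset() and the default qlearning_dataset() in the example above (observation 92 differs), while terminate_on_end=True keeps every transition and so stays aligned with get_dataset().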

My recommendation: unless you're doing standard 1-step Q-learning, write your own function to build the dataset you need instead of relying on qlearning_dataset(), so you can decide exactly how to handle timeouts, terminals and the possibly-missing last observation, all in a dataset-dependent manner.
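
As a starting point, something along these lines (split_into_trajectories is a hypothetical helper, not part of D4RL) lets you decide per dataset what counts as an episode boundary:

def split_into_trajectories(dataset, end_on_timeout=True, end_on_terminal=True):
    # dataset is the dict returned by env.get_dataset()
    keys = ("observations", "actions", "rewards", "terminals", "timeouts")
    trajectories, start = [], 0
    n = dataset["rewards"].shape[0]
    for i in range(n):
        # decide per dataset which flags actually mark the end of a trajectory
        is_end = (end_on_timeout and bool(dataset["timeouts"][i])) or \
                 (end_on_terminal and bool(dataset["terminals"][i]))
        if is_end or i == n - 1:
            trajectories.append({k: dataset[k][start:i + 1] for k in keys})
            start = i + 1
    return trajectories

Following the caveats above, you would for instance pass end_on_terminal=False for antmaze (terminals there don't end the trajectory) and end_on_timeout=False for maze2d (a single long trajectory with fake timeouts).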