[Bug Report] Discrepancies between `get_dataset` and `qlearning_dataset`
vmoens opened this issue · 1 comment
I'm having a hard time figuring out how the `qlearning_dataset` is built.
As mentioned by @odelalleau in #182, the `"terminals"` key in some envs is never `True`.

Moreover,
>>> import d4rl
>>> # env: a D4RL Gym environment created beforehand (the exact task isn't shown here)
>>> dataset1 = env.get_dataset()
>>> dataset2 = d4rl.qlearning_dataset(env)
>>> dataset3 = d4rl.qlearning_dataset(env, terminate_on_end=True)
>>> # obs 91 matches
>>> dataset1["observations"][91]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> dataset2["observations"][91]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> dataset3["observations"][91]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> # obs 92 does not
>>> dataset1["observations"][92]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
>>> dataset2["observations"][92]
array([ 2.9757538 , 2.9996927 , 2.674898 , -0.15739861], dtype=float32)
>>> dataset3["observations"][92]
array([2.9198046 , 3.0007975 , 3.1118677 , 0.05080479], dtype=float32)
I'm a bit puzzled by what this means. It looks like `qlearning_dataset` gives me a dataset where each consecutive step is different, but `qlearning_dataset(..., terminate_on_end=True)` gives me something similar to `get_dataset`, where some consecutive steps are identical. What should we do with this? When does one trajectory stop, and when does the next one start?
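If it helps to pin down where the two indexings diverge, one check (reusing `dataset1` from above, and assuming it exposes a `timeouts` array) is to look at the raw flags around that index; if a timeout falls in this window, `get_dataset()` keeps the transition while `qlearning_dataset()` with the default `terminate_on_end=False` drops it and shifts every later index:
>>> # inspect the raw flags around the index where the datasets diverge
>>> dataset1["timeouts"][88:94]
>>> dataset1["terminals"][88:94]
>>> # number of timeout transitions dropped before index 92
>>> dataset1["timeouts"][:92].sum()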
For what it's worth, here's my takeaway of how this works:

- `qlearning_dataset()`, by default, gets rid of the timeouts by ignoring the corresponding transitions. The reason for this is that you would typically have a transition of the form `(s, a, r, s', timeout=True)` where `s'` is actually the first step of the next episode. Such a transition is thus "invalid" and is thrown away (see the sketch after this list). This is fine when you're doing 1-step Q-learning based on the resulting dataset, but be careful if you intend to do multi-step Q-learning or anything else that looks further down the trajectory, because it means you will be switching between episodes with no way to know (the `done` flag won't be set there).
- `qlearning_dataset(terminate_on_end=True)` will keep this invalid transition, but be aware that it will not set the `done` flag (contrary to what the docstring claims). So in general it's a pretty bad idea, except for datasets with fake timeouts like maze2d, where the episode doesn't actually end on timeout.
- One consequence of the above is that the `done` flag is set only when `terminal == True`. As you noticed, some datasets don't have any terminals (e.g. maze2d, which is actually a single trajectory). One thing to be aware of is that the next observation `s'` will be invalid when `done == True`, since it will be the first state of the next episode (this generally does not matter in Q-learning, since we don't bootstrap when `done == True`, but if you're doing something else it may). Some datasets provide a `next_observations` field that can be used to access the last observation (both on timeout and terminal), but the `qlearning_dataset()` function doesn't use it.
- It is important to realize that neither `timeout == True` nor `terminal == True` indicates with certainty that an episode has ended! I already gave the example of maze2d for the former, and a typical example of the latter is antmaze, where all states with a reward are marked with `terminal == True` even though the episode continues until a timeout is reached. The interpretation of `terminal` is thus "if I were to reach this state during evaluation, the episode would end" (antmaze finishes as soon as you reach a reward at test time), rather than "the episode has ended in the offline dataset" (otherwise you would end up with tons of 1-step episodes).
My recommendation is that unless you're doing standard 1-step Q-learning, you should write your own function to build the dataset you need instead of relying on `qlearning_dataset()`, so you can decide exactly how to handle timeouts, terminals and the last observation that may be missing, all of this in a dataset-dependent manner.
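As a starting point, here is a minimal sketch of such a builder (a hypothetical helper, assuming the same keys as above plus an optional `next_observations` field); it keeps both flags in each transition so downstream code can decide, per dataset, what actually counts as an episode boundary:

```python
def build_transitions(dataset):
    """Hypothetical helper: build (s, a, r, s', terminal, timeout) tuples,
    keeping both flags so multi-step methods can detect episode boundaries
    in whatever way a given dataset requires."""
    has_next = "next_observations" in dataset
    n = dataset["rewards"].shape[0]
    transitions = []
    for i in range(n - 1):
        terminal = bool(dataset["terminals"][i])
        timeout = bool(dataset["timeouts"][i])
        if has_next:
            next_obs = dataset["next_observations"][i]
        else:
            if timeout:
                # Without next_observations, s' at a timeout would be the first
                # state of the next episode, so the transition is unusable as-is.
                continue
            next_obs = dataset["observations"][i + 1]
        transitions.append((dataset["observations"][i], dataset["actions"][i],
                            float(dataset["rewards"][i]), next_obs, terminal, timeout))
    return transitions
```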