batch sampling: only last tokens?
Howuhh opened this issue · 1 comment
Howuhh commented
Hi! Thank you for the good paper.
According to the paper:
We feed the last K timesteps into Decision Transformer ...
Does this apply only at inference time, or are only the last K tokens of each episode also sampled during batch sampling for training? (The latter seems strange to me, but the sampling code is not entirely clear.)
kzl commented
During inference, DT operates only on the last K timesteps of the episode, see L112-115 here: decision_transformer.py.
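A minimal sketch of that inference-time truncation (hypothetical names; the real logic is in decision_transformer.py at the lines linked above):

```python
import numpy as np

K = 20  # context length (a common setting; the actual value depends on the config)

def truncate_context(states, actions, returns_to_go, k=K):
    """Keep only the last k timesteps of the running episode before a forward pass."""
    return states[-k:], actions[-k:], returns_to_go[-k:]

# Example: an episode that has grown to 50 timesteps is cut to its last K.
episode_states = np.arange(50)
s, a, r = truncate_context(episode_states, np.arange(50), np.arange(50))
# s now holds timesteps 30..49
```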
During training, DT operates on randomly sampled sequences of up to length K taken from anywhere in the episode; you can see in experiment.py that the start index si is sampled randomly.
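A rough sketch of that training-time sampling, assuming a random start index si as in experiment.py (variable names here are illustrative; the real code also left-pads sequences shorter than K, which this sketch omits):

```python
import numpy as np

K = 20  # max context length
rng = np.random.default_rng(0)

def sample_subsequence(episode, k=K, rng=rng):
    """Sample a subsequence of up to length k starting at a random index si."""
    si = rng.integers(0, len(episode))   # any position in the episode is a valid start
    return episode[si : si + k]          # may be shorter than k near the episode's end

ep = np.arange(100)          # a toy 100-step episode
seq = sample_subsequence(ep)
```

So a training batch mixes subsequences drawn from all parts of the episodes, not just their tails.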