thu-ml/tianshou

Buffer: fix discrepancy in slicing order

MischaPanch opened this issue · 7 comments

It doesn't make sense for buffer[:].obs to be so wildly different from buffer.obs[:]. buffer.obs[:] retrieves the full buffer, padded with zeros, while buffer[:].obs retrieves just the filled entries.

We should probably never retrieve the full buffer through slicing. There could be a new method like get_full_entry_sequence that is documented to retrieve the non-filled entries as well if it's actually needed.

While addressing this issue, buffer and collector related tests should be adjusted and improved.

I can look into this.

Currently this is the signature of ReplayBuffer's __getitem__(index: slice | int | list[int] | np.ndarray).

So the idea is that buffer[:].obs should not support full-buffer retrieval (filled and non-filled entries) when a slice is provided, because it can be computationally expensive?

What about the cases where the user passes a list[int] or np.ndarray containing all the indices? In practice that is the same as a slice. Should we allow retrieving the full buffer in this case?

I'm not sure when we should ever be retrieving the full buffer. Tbh I haven't given too much thought about the best way of resolving this, it just seems very confusing and arbitrary that buffer.obs[:] and buffer[:].obs would be semantically wildly different entities, right?

I have reviewed this issue again. The different way of retrieving with slicing seems arbitrary to me. The user can check the maximum size via the buffer.max_size attribute, no need to return the empty values as well. A special method to retrieve the full buffer (if anyone ever needs it!?) as you mentioned is more appropriate.

We should probably update buffer.obs[:] (and other slicing methods like start:stop) to retrieve only non-empty values like buffer[:].obs does.

Here are some examples of how indexing is currently different between the two:

In [52]: dummy_buf = ReplayBuffer(size=10)
    ...: for i in range(6):
    ...:     dummy_buf.add(
    ...:         Batch(obs=i, act=i, rew=i, terminated=0, truncated=0, done=0, obs_next=i + 1, info={}),
    ...:     )
    ...: 

In [53]: dummy_buf.obs[2:8]
Out[53]: array([2, 3, 4, 5, 0, 0])

In [54]: dummy_buf[2:8].obs
Out[54]: array([2, 3, 4, 5])

In [55]: dummy_buf.obs[:]
Out[55]: array([0, 1, 2, 3, 4, 5, 0, 0, 0, 0])

In [56]: dummy_buf[:].obs
Out[56]: array([0, 1, 2, 3, 4, 5])

In [57]: dummy_buf.obs[np.arange(10)]
Out[57]: array([0, 1, 2, 3, 4, 5, 0, 0, 0, 0])

In [58]: dummy_buf[np.arange(10)].obs
Out[58]: array([0, 1, 2, 3, 4, 5, 0, 0, 0, 0])
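
To make the pattern in the session above easier to reason about, here is a minimal toy model that reproduces the same outputs. It is only a sketch under assumptions: ToyBuffer, its _filled counter, and the way it resolves slices are hypothetical and are not taken from tianshou's implementation. Unlike the real buffer, its __getitem__ returns the obs array directly rather than a Batch.

```python
import numpy as np


class ToyBuffer:
    """Hypothetical stand-in for a replay buffer; not tianshou's actual code."""

    def __init__(self, size: int) -> None:
        # Storage is preallocated; slots that were never written stay 0.
        self.obs = np.zeros(size, dtype=int)
        self._filled = 0

    def add(self, obs: int) -> None:
        # No wrap-around handling: enough for this illustration.
        self.obs[self._filled] = obs
        self._filled += 1

    def __len__(self) -> int:
        return self._filled

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Assumption: a slice is resolved against the filled entries only.
            index = np.arange(len(self))[index]
        # Explicit indices (int, list, ndarray) go straight to the full storage.
        return self.obs[index]


buf = ToyBuffer(size=10)
for i in range(6):
    buf.add(i)

print(buf.obs[2:8])        # raw storage: may include unfilled slots
print(buf[2:8])            # slice is clipped to the filled entries
print(buf[np.arange(10)])  # explicit indices can reach unfilled slots
```

In this model the discrepancy comes from only one code path (the slice branch) knowing how many entries are filled, which is why an index array covering the whole storage behaves differently from the equivalent slice.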

Glad you agree with me on this ^^. I'm not sure whether the retrieval of a slice that includes empty values is used anywhere in the code. For me it's fine to remove it completely; however, many tests will need to be adjusted, as many of them currently rely on this somewhat odd retrieval mechanism.

We could live without the full retrieval method until someone actually needs it. It's good practice to keep the public interface as small as possible.

Created the issue accidentally, sry for that

image

Something is off here. As seen above, index is of type str, so Batch.__getitem__ returns immediately. How can we get access to the trailing [] indexing?

EDIT: dummy_buffer.obs is of type np.ndarray, so we have no direct access to that slicing step.
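
The point in the EDIT can be illustrated with a small, self-contained sketch. TracingBuffer and its seen list are hypothetical names made up for this example, and the property is a simplification of however the real buffer exposes obs. It only shows that once attribute access has returned a plain np.ndarray, the following [...] is handled by NumPy and never reaches the buffer's own __getitem__.

```python
import numpy as np


class TracingBuffer:
    """Hypothetical class that records which accesses reach the buffer itself."""

    def __init__(self, size: int) -> None:
        self._obs = np.zeros(size, dtype=int)
        self.seen = []  # log of everything the buffer object gets to observe

    @property
    def obs(self) -> np.ndarray:
        # The buffer sees the attribute access, but then hands out a raw array.
        self.seen.append("obs")
        return self._obs

    def __getitem__(self, index):
        # Only buffer[...] passes through here, so only here can the index be adjusted.
        self.seen.append(index)
        return self._obs[index]


tb = TracingBuffer(size=10)

_ = tb.obs[2:8]  # slicing is performed by np.ndarray, not by TracingBuffer
_ = tb[2:8]      # slicing is performed by TracingBuffer.__getitem__

print(tb.seen)
```

Under these assumptions, changing what buffer.obs[:] returns would require the attribute access to hand back something other than the raw storage array (for example a trimmed array or a wrapper object), since the buffer has no hook into the array's own indexing.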

Sorry, I'm afraid I don't understand what you are asking. We can have a call on Friday if you want to :)