pytorch/tensordict

[BUG] Bad Index Behavior after indexing an expanded tensordict

realquantumcookie opened this issue · 2 comments

Describe the bug

Hello tensordict team, I'm unsure where this bug comes from. It appeared after I called TensorDict.expand() on a tensordict, indexed the result with a 2D tensor, and then sliced that result.

To Reproduce

import torch
import tensordict

a = torch.arange(9).reshape(3, 3)
d_1 = tensordict.TensorDict({
    "a": a,
}, batch_size=(3, 3))
d_2 = d_1.unsqueeze(0).expand(1000, 3, 3)  # batch_size (1000, 3, 3)
idx = torch.arange(3).unsqueeze(0).expand(3, 3)  # construct a 2d array to index into d_2
d_2_1 = d_2[idx]  # this works; batch_size (3, 3, 3, 3)
print(d_2_1)
d_2_2 = d_2_1[0:]  # this errors

Traceback (most recent call last):
  File "/home/quantumcookie/miniconda3/envs/orangerl/lib/python3.8/site-packages/tensordict/base.py", line 264, in __getitem__
    return self._index_tensordict(index)
  File "/home/quantumcookie/miniconda3/envs/orangerl/lib/python3.8/site-packages/tensordict/_td.py", line 773, in _index_tensordict
    names = self._get_names_idx(index)
  File "/home/quantumcookie/miniconda3/envs/orangerl/lib/python3.8/site-packages/tensordict/_td.py", line 1306, in _get_names_idx
    names = [names[i] if i is not None else None for i in idx_to_take]
  File "/home/quantumcookie/miniconda3/envs/orangerl/lib/python3.8/site-packages/tensordict/_td.py", line 1306, in <listcomp>
    names = [names[i] if i is not None else None for i in idx_to_take]

Expected behavior

Slicing d_2_1 with 0: should work and return a tensordict with the same batch size as d_2_1, i.e. (3, 3, 3, 3).
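
For reference, the same sequence of operations on a plain torch.Tensor runs without error. This sketch uses only torch, not tensordict, and assumes that TensorDict indexing over batch dimensions is meant to follow torch.Tensor indexing semantics:

```python
import torch

# Plain-tensor mirror of the tensordict repro (no tensordict involved).
a = torch.arange(9).reshape(3, 3)
expanded = a.unsqueeze(0).expand(1000, 3, 3)     # view with shape (1000, 3, 3)
idx = torch.arange(3).unsqueeze(0).expand(3, 3)  # 2D integer index, shape (3, 3)

# Advanced indexing on dim 0 replaces that dim with idx's shape: (3, 3) + (3, 3).
indexed = expanded[idx]
# A full slice on dim 0 keeps every element, so the shape is unchanged.
sliced = indexed[0:]

print(indexed.shape)  # torch.Size([3, 3, 3, 3])
print(sliced.shape)   # torch.Size([3, 3, 3, 3])
```

Under that assumption, d_2_1[0:] would be expected to behave like indexed[0:] here and return a batch size of (3, 3, 3, 3).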

System info

torch==2.2.0
tensordict==0.3.0

Reason and Possible fixes

Unsure
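
As an unverified hypothesis only: the traceback above ends inside a list comprehension in _get_names_idx that looks up names[i] for each entry of idx_to_take. The toy below copies just that expression into a hypothetical helper, get_names_idx (not tensordict's real function), to show one way it can fail: any position in idx_to_take that is past the end of names raises IndexError. The traceback is cut off before the exception line, so whether this is what actually happens in 0.3.0 is not confirmed.

```python
def get_names_idx(names, idx_to_take):
    # Hypothetical stand-in: only this list comprehension is copied from the
    # traceback (tensordict/_td.py, line 1306); everything else is made up.
    return [names[i] if i is not None else None for i in idx_to_take]


# Hypothetical names for 3 batch dims.
names = ["x", "y", "z"]

# Positions that exist in `names` are looked up; None entries stay None.
print(get_names_idx(names, [0, None, 2]))  # ['x', None, 'z']

# A position past the end of `names` (here 3, as if a 4th batch dim were
# requested) makes the same expression raise IndexError.
try:
    get_names_idx(names, [0, 1, 2, 3])
except IndexError as err:
    print("IndexError:", err)  # IndexError: list index out of range
```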

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Some additional info:
After installing tensordict==0.2.1 instead, the same code appears to run without error, so this looks like a regression in 0.3.0.

Thanks for reporting! I'll look into it.