[BUG] Bad Index Behavior after indexing an expanded tensordict
realquantumcookie opened this issue · 2 comments
realquantumcookie commented
Describe the bug
Hello tensordict team, I'm unsure where this bug comes from. It happened after I called `TensorDict.expand()` on a tensordict, indexed it with a 2D tensor, and then tried to index the result again with a slice.
To Reproduce
```python
import torch
import tensordict

a = torch.arange(9).reshape(3, 3)
d_1 = tensordict.TensorDict({
    "a": a,
}, batch_size=(3, 3))
d_2 = d_1.unsqueeze(0).expand(1000, 3, 3)
idx = torch.arange(3).unsqueeze(0).expand(3, 3)  # 2D index tensor of shape (3, 3) used to index into d_2
d_2_1 = d_2[idx]  # this works
print(d_2_1)
d_2_2 = d_2_1[0:]  # this errors
```
```
Traceback (most recent call last):
  File "/home/quantumcookie/miniconda3/envs/orangerl/lib/python3.8/site-packages/tensordict/base.py", line 264, in __getitem__
    return self._index_tensordict(index)
  File "/home/quantumcookie/miniconda3/envs/orangerl/lib/python3.8/site-packages/tensordict/_td.py", line 773, in _index_tensordict
    names = self._get_names_idx(index)
  File "/home/quantumcookie/miniconda3/envs/orangerl/lib/python3.8/site-packages/tensordict/_td.py", line 1306, in _get_names_idx
    names = [names[i] if i is not None else None for i in idx_to_take]
  File "/home/quantumcookie/miniconda3/envs/orangerl/lib/python3.8/site-packages/tensordict/_td.py", line 1306, in <listcomp>
    names = [names[i] if i is not None else None for i in idx_to_take]
```
Expected behavior
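Slicing `d_2_1` with `0:` should succeed and return a tensordict with the same batch size as `d_2_1`, just as the same indexing would on a plain tensor.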
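For reference, here is the same sequence of operations on a plain `torch` tensor with the same shapes, as a minimal sketch of the expected behavior. It assumes that tensordict mirrors torch indexing semantics along its batch dimensions; the names `t_2`, `t_2_1`, and `t_2_2` are hypothetical stand-ins for `d_2`, `d_2_1`, and `d_2_2`.

```python
import torch

# Same data and shapes as the tensordict reproduction, but with a plain tensor.
a = torch.arange(9).reshape(3, 3)
t_2 = a.unsqueeze(0).expand(1000, 3, 3)  # expanded view, shape (1000, 3, 3)

# 2D index tensor: every row is [0, 1, 2]
idx = torch.arange(3).unsqueeze(0).expand(3, 3)

# Advanced indexing on dim 0 replaces that dim with idx.shape -> (3, 3, 3, 3)
t_2_1 = t_2[idx]
print(t_2_1.shape)

# A full slice keeps every entry, so the shape is unchanged
t_2_2 = t_2_1[0:]
print(t_2_2.shape)
```

By analogy, `d_2_1` would have `batch_size` `(3, 3, 3, 3)`, and `d_2_1[0:]` would be expected to return a tensordict with that same batch size instead of raising.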
System info
torch==2.2.0
tensordict==0.3.0
Reason and Possible fixes
Unsure
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
realquantumcookie commented
Some additional info:
I installed tensordict==0.2.1 instead, and the same code appears to run fine with that version.
vmoens commented
Thanks for reporting ! I'll look into it