[BUG] tensordict.TensorDict and tensordict.nn.make_tensordict can't handle dictionaries with non-string keys
Bhartendu-Kumar opened this issue · 4 comments
Describe the bug
The functions tensordict.TensorDict and tensordict.nn.make_tensordict both expect a dictionary to be passed.
If that dictionary has non-string keys, both raise an error: IndexError: tuple index out of range
To Reproduce
import torch
from tensordict import TensorDict
d = {1: torch.randn(2), 2: torch.randn(2)}
d = TensorDict(d, batch_size=2)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 236, in __init__
self.set(key, value, non_blocking=non_blocking)
File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/base.py", line 2315, in set
return self._set_tuple(
File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1615, in _set_tuple
td = self._get_str(key[0], None)
IndexError: tuple index out of range
from tensordict.nn import make_tensordict
d = {1: torch.randn(2), 2: torch.randn(2)}
d = make_tensordict(d)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/functional.py", line 379, in make_tensordict
return TensorDict.from_dict(kwargs, batch_size=batch_size, device=device)
File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1332, in from_dict
out = cls(
File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 236, in __init__
self.set(key, value, non_blocking=non_blocking)
File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/base.py", line 2315, in set
return self._set_tuple(
File "/home/lisa/anaconda3/envs/pytorch/lib/python3.8/site-packages/tensordict/_td.py", line 1615, in _set_tuple
td = self._get_str(key[0], None)
IndexError: tuple index out of range
Expected behavior
When the dictionary has string keys, the Python dictionary is converted to a TensorDict as expected, e.g.
d = {"1": torch.randn(2), "2": torch.randn(2)}
d = TensorDict(d, batch_size=2)
works correctly. But when the keys are not strings, e.g.
d = {1: torch.randn(2), 2: torch.randn(2)}
d = TensorDict(d, batch_size=2)
it raises the error above.
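As a stopgap, the keys can be converted to strings before the dictionary reaches TensorDict. The sketch below uses a hypothetical helper, `stringify_keys`, which is not part of tensordict. It only handles top-level keys and refuses inputs where two distinct keys (such as `1` and `"1"`) would map to the same string, so no value is silently dropped. Plain lists stand in for tensors to keep it self-contained.

```python
def stringify_keys(d):
    """Hypothetical helper: return a copy of `d` with every top-level key converted to str.

    Raises ValueError if two distinct keys map to the same string (e.g. 1 and "1"),
    since keeping only one of them would silently lose data.
    """
    out = {}
    for key, value in d.items():
        new_key = str(key)
        if new_key in out:
            raise ValueError(f"keys collide after str(): {new_key!r}")
        out[new_key] = value
    return out


# Lists stand in for tensors here; with tensordict one would instead call
# TensorDict(stringify_keys(d), batch_size=2).
d = {1: [0.1, 0.2], 2: [0.3, 0.4]}
print(stringify_keys(d))  # {'1': [0.1, 0.2], '2': [0.3, 0.4]}
```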
System info
Describe the characteristics of your environment:
- Describe how the library was installed (pip, source, ...):
python -m pip install tensordict==0.3.2
- Python version:
Python 3.8.13
- Versions of any other relevant libraries:
pytorch:2.2.2+cu121
import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
0.3.2 1.22.4 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:04:10)
[GCC 10.3.0] linux 2.2.2+cu121
Additional context
Reason and Possible fixes
I think that, at an abstract level, the code works in two steps:
- Step 1: Get the length of the keys of the given input dictionary.
- Step 2: Get the string keys and construct the TensorDict object from them.
Thus, the culprit might be TensorDict._set_tuple(self, key, value, inplace, validated, non_blocking) in tensordict/_td.py:
if len(key) == 1:
    return self._set_str(
When that branch is not taken, execution reaches line 1615:
td = self._get_str(key[0], None)
So what's happening is a lookup by string key where the key might not be a string. Since key[0] raises "tuple index out of range", key must be an empty tuple at that point, i.e. the integer key was never turned into a one-element tuple of strings.
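The traceback is consistent with the following simplified model, written in plain Python as an illustration rather than tensordict's actual code. A key is normalized to a tuple, a non-string key ends up as an empty tuple, the `len(key) == 1` branch is skipped, and `key[0]` then fails. The names `normalize_key` and `set_tuple` are hypothetical. The assumption that invalid keys become `()` is inferred from the `IndexError` on `key[0]`, not confirmed from the library source.

```python
def normalize_key(key):
    """Hypothetical stand-in for the key normalization done before _set_tuple.

    Assumption: a str becomes a 1-tuple, a tuple is kept as-is, and any other
    key type (e.g. int) silently becomes an empty tuple instead of raising.
    """
    if isinstance(key, str):
        return (key,)
    if isinstance(key, tuple):
        return key
    return ()


def set_tuple(store, key, value):
    """Simplified model of the branch structure shown in the traceback."""
    if len(key) == 1:
        # Plays the role of `return self._set_str(...)`.
        store[key[0]] = value
        return store
    # Plays the role of `td = self._get_str(key[0], None)`:
    # with key == () this line raises IndexError before anything is stored.
    sub = store.setdefault(key[0], {})
    set_tuple(sub, key[1:], value)
    return store


store = {}
set_tuple(store, normalize_key("a"), 1)
set_tuple(store, normalize_key(("b", "c")), 2)
print(store)  # {'a': 1, 'b': {'c': 2}}

try:
    set_tuple(store, normalize_key(1), 3)
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range
```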
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
Hello
Thanks for posting this!
TensorDict requires keys to be strings, tuples of strings, tuples of tuples of strings, etc.; no other key type is allowed.
The main reason is that tensordicts can also be indexed along the "shape" dimension, and allowing other key-types (e.g. ints) would lead to undefined behaviours.
Example
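import torch
from tensordict import TensorDict
data = TensorDict({"a": torch.arange(3)}, batch_size=[3])
data[1]  # returns a TensorDict with batch_size=[] whose "a" entry is tensor(1)
data = TensorDict({1: torch.arange(3)}, batch_size=[3])  # hypothetical: int keys are not allowed
data[1]  # should this take the second element along the batch dimension, or the '1' key?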
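To make the ambiguity concrete, here is a toy container (hypothetical, not tensordict's implementation) whose `__getitem__` dispatches on the index type in the same spirit: strings select an entry by key, integers select an element along the batch dimension. Under this scheme an integer key can never be reached by indexing.

```python
class ToyTensorDict:
    """Toy model (hypothetical): a dict of equal-length lists sharing one batch dimension."""

    def __init__(self, data):
        if len({len(v) for v in data.values()}) > 1:
            raise ValueError("all entries must share the same batch size")
        self.data = dict(data)

    def __getitem__(self, index):
        if isinstance(index, str):
            # A string selects one entry by key.
            return self.data[index]
        if isinstance(index, int):
            # An integer selects one element along the batch dimension of every entry.
            return {key: value[index] for key, value in self.data.items()}
        raise TypeError(f"unsupported index type: {type(index).__name__}")


data = ToyTensorDict({"a": [0, 1, 2]})
print(data["a"])  # [0, 1, 2]
print(data[1])    # {'a': 1}

# If int keys were allowed, the entry stored under the key 1 would be unreachable:
ambiguous = ToyTensorDict({1: [0, 1, 2]})
print(ambiguous[1])  # {1: 1}: the int is read as a batch index, not as the key 1
```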
That being said, we should probably catch this case and raise a clearer error for our users!
Hope that helps
Oh!
Makes sense. Thanks for the reply.
But the error
IndexError: tuple index out of range
is still not informative enough to tell that the problem is the type of the dictionary keys.
So I think there should be a check that prints an appropriate error message about the expected key types, rather than an index-out-of-range error.
For values, this already exists: when a value was anything other than a TensorDict, dictionary, scalar, or tensor, an explicit error said that the value's data type was outside this set.
So I think something similar for keys would be beneficial.
Should I go ahead and add type checking for this, if you confirm that keys can only be strings, tuples of strings, and so on?
Thanks
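The type check proposed in this reply could look roughly like the following sketch. It encodes the rule stated by the maintainer (keys are strings or, possibly nested, tuples of strings) and raises a TypeError naming the offending key before any IndexError can occur. The helper names `is_valid_key`, `check_key`, and `check_all_keys`, the choice to reject empty tuples, and the error wording are all hypothetical, not tensordict's API.

```python
def is_valid_key(key):
    """True if `key` is a str or a non-empty, possibly nested tuple of str."""
    if isinstance(key, str):
        return True
    if isinstance(key, tuple):
        # Empty tuples are rejected too: they would hit the same key[0] IndexError.
        return len(key) > 0 and all(is_valid_key(sub) for sub in key)
    return False


def check_key(key):
    """Raise a descriptive TypeError for an invalid key; return the key otherwise."""
    if not is_valid_key(key):
        raise TypeError(
            "TensorDict keys must be strings or (nested) tuples of strings, "
            f"got {key!r} of type {type(key).__name__}."
        )
    return key


def check_all_keys(d):
    """Validate every key of an input dict, recursing into nested dicts."""
    for key, value in d.items():
        check_key(key)
        if isinstance(value, dict):
            check_all_keys(value)


check_all_keys({"a": 1, ("b", "c"): 2, "nested": {"d": 3}})  # passes silently
try:
    check_all_keys({1: 0.5, 2: 0.7})
except TypeError as err:
    print(err)  # TensorDict keys must be strings or (nested) tuples of strings, got 1 of type int.
```

Raising TypeError rather than ValueError follows Python's usual convention for arguments of the wrong type.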