[BUG] torch.stack is not implemented for NonTensorStack and NonTensorData
jkrude opened this issue · 1 comment
Describe the bug
This is a follow-up to #831 going into more detail about `torch.stack`. `NonTensorData` and `NonTensorStack` cannot be stacked together, even though they hold data of compatible shapes. This is especially unpredictable due to the behavior of `NonTensorData._stack_non_tensor`, which returns a `NonTensorData` (rather than a `NonTensorStack`) if all elements are equal.
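The collapsing behavior described above can be illustrated with a small pure-Python sketch. This is not tensordict's actual implementation of `_stack_non_tensor`; the names and return values are purely illustrative of the idea that equal elements collapse into a single value:

```python
# Hypothetical sketch of the collapsing behavior of _stack_non_tensor;
# the real tensordict code differs, this only illustrates the idea.
def stack_non_tensor_sketch(items):
    """Collapse to a single value if all items are equal, else keep a stack."""
    first = items[0]
    if all(item == first for item in items):
        # All elements equal: one element can represent the whole stack.
        return ("NonTensorData", first)
    # Elements differ: every element must be kept.
    return ("NonTensorStack", list(items))

# Equal elements collapse to a single representative...
kind_a, value_a = stack_non_tensor_sketch(["a", "a"])
# ...while differing elements remain a stack.
kind_b, value_b = stack_non_tensor_sketch(["a", "b"])
```

The unpredictability reported here stems from exactly this branch: the result type depends on the *values* being stacked, not just their shapes.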
To Reproduce
Both `data` and `stack` below hold strings with batch size `torch.Size([2])`. Since `torch.stack` (implemented via `__torch_function__`) calls `NonTensorData._stack_non_tensor`, the `data` variable becomes a `NonTensorData` instead of a `NonTensorStack`, because all of its elements are equal.
```python
import torch
from tensordict import NonTensorData, NonTensorStack

data = torch.stack(
    [
        NonTensorData("a"),
        NonTensorData("a"),
    ]
)
stack = NonTensorStack(
    *(NonTensorData("b"), NonTensorData("b")),
)
assert torch.stack([data, stack], dim=1).batch_size == (2, 2)
```
The last line raises a `TypeError`: every `__torch_function__` handler returns `NotImplemented`, because `NonTensorData` fails the check `issubclass(t, (Tensor, cls, TensorDictBase))`:
```
E TypeError: Multiple dispatch failed for 'torch.stack'; all __torch_function__ handlers returned NotImplemented:
E
E - tensor subclass <class 'tensordict.tensorclass.NonTensorData'>
E - tensor subclass <class 'tensordict.tensorclass.NonTensorStack'>
E
E For more information, try re-running with TORCH_LOGS=not_implemented
```
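The failure mode can be sketched in plain Python. This is a minimal, hypothetical model of PyTorch's `__torch_function__` multiple dispatch, not tensordict's or PyTorch's actual code: each handler only accepts types it recognizes, and when every handler returns `NotImplemented`, dispatch raises `TypeError`:

```python
# Hypothetical stand-ins for NonTensorData / NonTensorStack; illustrative only.
class NonTensorDataLike:
    @classmethod
    def torch_function(cls, types):
        # Mirrors the failing check issubclass(t, (Tensor, cls, TensorDictBase)):
        # bail out with NotImplemented on any unrecognized type.
        if not all(issubclass(t, cls) for t in types):
            return NotImplemented
        return "stacked"

class NonTensorStackLike:
    @classmethod
    def torch_function(cls, types):
        if not all(issubclass(t, cls) for t in types):
            return NotImplemented
        return "stacked"

def stack_dispatch(args):
    """Try each argument's handler; fail if all return NotImplemented."""
    types = tuple(type(a) for a in args)
    for t in types:
        result = t.torch_function(types)
        if result is not NotImplemented:
            return result
    raise TypeError(
        "Multiple dispatch failed for 'torch.stack'; "
        "all __torch_function__ handlers returned NotImplemented"
    )
```

With a homogeneous list, some handler succeeds; mixing the two classes makes both handlers reject the call, reproducing the `TypeError` shape seen in the traceback above.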
Expected behavior
`torch.stack` returns a `NonTensorStack` with batch size `(2, 2)` in which `data` and `stack` are stacked together.
Reasons
The behavior of `NonTensorData._stack_non_tensor` should be transparent to all other functionality; in particular, `torch.stack` should work with any mix of `NonTensorData` and `NonTensorStack` as long as their batch sizes are compatible.
System info
```python
import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)
# 0.4.0 1.26.4 3.10.14 (main, Mar 21 2024, 11:21:31) [Clang 14.0.6 ] darwin 2.3.0
```
Installed using pip.
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
This code runs fine on the nightlies:

```python
import torch
from tensordict import tensorclass, NonTensorData, NonTensorStack

data = torch.stack(
    [
        NonTensorData("a"),
        NonTensorData("a"),
    ]
)
stack = NonTensorStack(
    *(NonTensorData("b"), NonTensorData("b")),
)
assert torch.stack([data, stack], dim=1).batch_size == (2, 2)
```
Feel free to reopen if needed