pytorch/tensordict

[BUG] torch.stack is not implemented for NonTensorStack and NonTensorData

jkrude opened this issue · 1 comment

Describe the bug

This is a follow-up to #831, going into more detail about torch.stack.
NonTensorData and NonTensorStack cannot be stacked together, even though they hold data of compatible shapes. This is especially unpredictable because NonTensorData._stack_non_tensor may return a NonTensorData rather than a NonTensorStack when all elements are equal.
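A minimal illustration of that collapsing behavior on 0.4.0, going through the public torch.stack as in the reproduction below:

import torch
from tensordict import NonTensorData

# Equal payloads are collapsed into a single NonTensorData...
same = torch.stack([NonTensorData("a"), NonTensorData("a")])
print(type(same).__name__)  # NonTensorData, batch_size torch.Size([2])

# ...while distinct payloads keep an explicit NonTensorStack.
diff = torch.stack([NonTensorData("a"), NonTensorData("b")])
print(type(diff).__name__)  # NonTensorStack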

To Reproduce

Both data and stack hold strings and have batch_size torch.Size([2]).
Because torch.stack (dispatched through __torch_function__) calls NonTensorData._stack_non_tensor, the data variable becomes a NonTensorData instead of a NonTensorStack, as all of its elements are equal.

import torch
from tensordict import NonTensorData, NonTensorStack

# torch.stack collapses the equal elements into a single NonTensorData.
data = torch.stack(
    [
        NonTensorData("a"),
        NonTensorData("a"),
    ]
)
# An explicit NonTensorStack with the same batch_size.
stack = NonTensorStack(
    NonTensorData("b"),
    NonTensorData("b"),
)
# Raises instead of returning a (2, 2) NonTensorStack.
assert torch.stack([data, stack], dim=1).batch_size == (2, 2)

The last line raises a TypeError: NonTensorData fails the __torch_function__ type check issubclass(t, (Tensor, cls, TensorDictBase)), so every handler returns NotImplemented (a generic sketch of this failure mode follows the trace below):

E       TypeError: Multiple dispatch failed for 'torch.stack'; all __torch_function__ handlers returned NotImplemented:
E       
E         - tensor subclass <class 'tensordict.tensorclass.NonTensorData'>
E         - tensor subclass <class 'tensordict.tensorclass.NonTensorStack'>
E       
E       For more information, try re-running with TORCH_LOGS=not_implemented
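For context, the dispatch failure itself is generic PyTorch behavior rather than anything tensordict-specific: whenever every __torch_function__ handler involved returns NotImplemented, torch raises exactly this TypeError. A minimal sketch with a dummy class (Opaque is hypothetical, not a tensordict type):

import torch

class Opaque:
    # A handler that rejects every torch function it is offered.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return NotImplemented

try:
    torch.stack([Opaque(), Opaque()])
except TypeError as err:
    print(err)  # Multiple dispatch failed for 'torch.stack'; ...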

Expected behavior

torch.stack returns a (2, 2) NonTensorStack in which data and stack are stacked together.
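Concretely, the assertion from the reproduction should pass, with the payloads interleaved along dim 1 (assuming NonTensorStack.tolist() returns the nested payloads):

result = torch.stack([data, stack], dim=1)
assert result.batch_size == (2, 2)
assert result.tolist() == [["a", "b"], ["a", "b"]]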

Reasons

The behavior of NonTensorData._stack_non_tensor should be transparent to the rest of the library; in particular, torch.stack should work on any mix of NonTensorData and NonTensorStack as long as their batch sizes are compatible.
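Until then, a possible workaround sketch on 0.4.0 (not verified; it only assumes the .data payload field and the batch_size attribute used above) is to re-expand the collapsed NonTensorData into an explicit NonTensorStack so both operands share a type:

import torch
from tensordict import NonTensorData, NonTensorStack

data = torch.stack([NonTensorData("a"), NonTensorData("a")])  # collapses to NonTensorData
# Rebuild the collapsed value as an explicit per-element stack.
data_as_stack = NonTensorStack(
    *(NonTensorData(data.data) for _ in range(data.batch_size[0]))
)
stack = NonTensorStack(NonTensorData("b"), NonTensorData("b"))
result = torch.stack([data_as_stack, stack], dim=1)
print(result.batch_size)  # expected: torch.Size([2, 2])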

System info

import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)

0.4.0 1.26.4 3.10.14 (main, Mar 21 2024, 11:21:31) [Clang 14.0.6 ] darwin 2.3.0
Installed using pip.

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

This code runs fine on nightlies

import torch
from tensordict import NonTensorData, NonTensorStack

data = torch.stack(
    [
        NonTensorData("a"),
        NonTensorData("a"),
    ]
)
stack = NonTensorStack(
    NonTensorData("b"),
    NonTensorData("b"),
)
assert torch.stack([data, stack], dim=1).batch_size == (2, 2)

Feel free to reopen if needed