pytorch/tensordict

Torch nested tensor, support for mean along the ragged dimension

leeb-m opened this issue · 2 comments

Motivation

I have a large tensor of shape batch × embedding dimension that consists of sentence embeddings; however, the sub-batches come from different documents. I'm excited by the possibility of using torch.nested to process those batches in a single kernel.

Solution

I'd like to split the tensor along the batch dimension, then transpose the splits so that the batch dimension of each segment becomes the ragged dimension of a nested tensor.

Then I want to run torch.mean to reduce the ragged dimension back to 1. (Ideally this would produce a normal tensor?)

I get the following error:
NotImplementedError: Could not run 'aten::mean.dim' with arguments from the 'NestedTensorCPU' backend.

I'm not on a CUDA system at present, but I note that mean is not listed as a supported operation in the documentation.

Alternatives

I'm open to suggestions for workarounds. sum works, so I tried making a nested tensor with the ragged lengths and then using div. However, I hit these issues:

  1. div can only divide a nested tensor by another nested tensor
  2. Even then, the div fails with RuntimeError: div does not support broadcasting when given a NestedTensor
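One workaround that sidesteps div entirely, assuming the end goal is an ordinary (non-nested) tensor of per-segment means: unbind the nested tensor into its constituents, reduce each one along its ragged dimension, and stack the results. This gives up the single-kernel benefit, but it only uses operations that nested tensors do support. A minimal sketch (note mean needs a floating dtype, so the inputs are floats here):

```python
import torch

batch1 = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
batch2 = torch.tensor([[1., 2., 3.]])
batch3 = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])

nested = torch.nested.nested_tensor([b.t() for b in (batch1, batch2, batch3)])

# Reduce each constituent along its ragged dimension, then stack the
# now equally-sized results into a regular (3, 3) tensor.
means = torch.stack([t.mean(dim=1) for t in nested.unbind()])
print(means)
# tensor([[2.5000, 3.5000, 4.5000],
#         [1.0000, 2.0000, 3.0000],
#         [4.0000, 5.0000, 6.0000]])
```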

Additional context

import torch

print(torch.__version__)

batch1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
batch2 = torch.tensor([[1, 2, 3]])
batch3 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

nested = torch.nested.nested_tensor([batch1.permute(1, 0),
                                     batch2.permute(1, 0),
                                     batch3.permute(1, 0)])

# boom
print(nested.mean(2, keepdim=True))

sums = nested.sum(2, keepdim=True)  # yay! that works

count = torch.nested.nested_tensor([torch.tensor([2]),
                                    torch.tensor([1]),
                                    torch.tensor([3])])

# RuntimeError: div does not support broadcasting when given a NestedTensor
print(sums.div(count))
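For completeness, another way to get a plain-tensor mean without nested division: pad the ragged dimension with zeros via torch.nested.to_padded_tensor, sum over it, and divide by the per-segment lengths as an ordinary broadcasting tensor. A sketch mirroring the repro above (float inputs assumed, and the lengths 2/1/3 hard-coded as in the count tensor):

```python
import torch

batch1 = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
batch2 = torch.tensor([[1., 2., 3.]])
batch3 = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])

nested = torch.nested.nested_tensor([b.t() for b in (batch1, batch2, batch3)])

# Zero-padding leaves the sums unchanged, so summing the padded tensor
# and dividing by the true ragged lengths recovers the mean.
padded = torch.nested.to_padded_tensor(nested, 0.0)   # shape (3, 3, 3)
lengths = torch.tensor([2., 1., 3.]).view(-1, 1)       # broadcast over embeddings
means = padded.sum(dim=2) / lengths
print(means)
# tensor([[2.5000, 3.5000, 4.5000],
#         [1.0000, 2.0000, 3.0000],
#         [4.0000, 5.0000, 6.0000]])
```

Unlike the div-based attempt, the division here involves only regular tensors, so normal broadcasting rules apply.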

Checklist

  • I have checked that there is no similar issue in the repo (required)

This involves pytorch rather than tensordict, so I'd encourage you to resubmit this issue on https://github.com/pytorch/pytorch. Don't hesitate to tag me there!

Doh I don't know how I missed that!