Torch nested tensor, support for mean along the ragged dimension
leeb-m opened this issue · 2 comments
Motivation
I have a large tensor of shape batch x embedding dimension that consists of sentence embeddings; however, there are sub-batches that came from different documents. I'm excited by the possibility of using torch.nested to process those batches in a single kernel.
Solution
I'd like to split the tensor along the batch dimension, then transpose the splits, so that the batch dimension of the segments becomes the ragged dimension in a nested tensor.
Then I want to run torch.mean to reduce the ragged dimension back to 1. (Ideally this would produce a normal tensor?)
I get the following error:
NotImplementedError: Could not run 'aten::mean.dim' with arguments from the 'NestedTensorCPU' backend.
I'm not on a CUDA system at present, but I note that this is not listed as a supported operation in the documentation.
Alternatives
I'm open to suggestions for workarounds. Sum works, and I've tried making a nested tensor with the ragged lengths and then using div. However I hit these issues:
- Can only divide nested by another nested tensor
- Then the div fails with RuntimeError: div does not support broadcasting when given a NestedTensor
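One possible workaround (a sketch, not an official API recommendation): since aten::mean.dim isn't implemented for the nested backend, unbind the nested tensor into its regular-tensor components, reduce each one, and stack the results into an ordinary dense tensor. The float dtypes here are an assumption so that mean is defined; the original repro uses integer tensors.

```python
import torch

# Same segments as the repro, but as floats so mean() is valid.
batch1 = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
batch2 = torch.tensor([[1., 2., 3.]])
batch3 = torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])

# Transpose each segment so its batch dimension becomes the ragged dim.
nested = torch.nested.nested_tensor([batch1.permute(1, 0),
                                     batch2.permute(1, 0),
                                     batch3.permute(1, 0)])

# unbind() is supported on nested tensors; each component is a regular
# tensor, so mean along its ragged (last) dimension works, and the
# per-segment means stack into a normal (num_segments, embed_dim) tensor.
means = torch.stack([t.mean(dim=1) for t in nested.unbind()])
print(means)
```

This gives up the single-kernel benefit (it loops in Python over segments), but it sidesteps both the missing mean.dim kernel and the nested-tensor div restrictions.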
Additional context
import torch
print(torch.__version__)
batch1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
batch2 = torch.tensor([[1, 2, 3]])
batch3 = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
nested = torch.nested.nested_tensor([batch1.permute(1, 0),
                                     batch2.permute(1, 0),
                                     batch3.permute(1, 0)])
# boom
print(nested.mean(2, keepdim=True))
sum = nested.sum(2, keepdim=True) # yay! that works
count = torch.nested.nested_tensor([torch.tensor([2]),
                                    torch.tensor([1]),
                                    torch.tensor([3])])
print(sum.div(count))
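For completeness, here is a second workaround sketch that keeps everything dense: convert the nested tensor to a zero-padded tensor, sum along the padded ragged dimension, and divide by the true segment lengths held in a regular tensor. This avoids the "can only divide nested by nested" and broadcasting restrictions entirely. The float dtypes are an assumption so the division is meaningful.

```python
import torch

# Segments transposed so the ragged (original batch) dim is last.
segments = [torch.tensor([[1., 2., 3.], [4., 5., 6.]]).permute(1, 0),
            torch.tensor([[1., 2., 3.]]).permute(1, 0),
            torch.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]).permute(1, 0)]
nested = torch.nested.nested_tensor(segments)

# True ragged lengths, kept as an ordinary tensor (no nested div needed).
lengths = torch.tensor([t.shape[1] for t in segments])  # 2, 1, 3

# Zero padding contributes nothing to the sum, so sum/length == mean.
padded = torch.nested.to_padded_tensor(nested, 0.0)  # (3, embed_dim, max_len)
means = padded.sum(dim=2) / lengths.unsqueeze(1)     # broadcast over embed_dim
print(means)
```

Padding costs memory proportional to the longest segment, but the whole computation runs as regular dense kernels.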
Checklist
- I have checked that there is no similar issue in the repo (required)
This does not directly involve tensordict but pytorch, so I'd encourage you to resubmit this issue on https://github.com/pytorch/pytorch. Do not hesitate to tag me there!
Doh, I don't know how I missed that!