[Bug] SumBatchLinearOperator fails for high-order tensor
lmao14 opened this issue · 0 comments
lmao14 commented
🐛 Bug
To reproduce
**Code snippet to reproduce**
import torch
import gpytorch
import linear_operator
kern = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([4, 3]),),
batch_shape=torch.Size([4, 3]))
X = torch.randn([2, 5])
kxx = kern(X)
print(kxx.shape)
print(kxx.to_dense().sum(0).shape)
print(kxx.sum(0).to_dense().shape)
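Output:
torch.Size([4, 3, 2, 2])
torch.Size([3, 2, 2])
torch.Size([4, 5, 5])
The last shape is wrong. Summing the LinearOperator over dim 0 should give the same shape as summing the dense tensor, i.e. torch.Size([3, 2, 2]). With three batch dimensions, the same call raises an error instead: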
kern = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([5, 4, 3]),),
batch_shape=torch.Size([5, 4, 3]))
X = torch.randn([2, 5])
kxx = kern(X)
print(kxx.sum(0).to_dense().shape)
**Stack trace/error message**
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[65], line 5
3 X = torch.randn([2, 5])
4 kxx = kern(X)
----> 5 print(kxx.sum(0).to_dense().shape)
File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:2517, in LinearOperator.sum(self, dim)
2515 # Otherwise: it's a batch dimension
2516 elif dim < self.dim():
-> 2517 return self._sum_batch(dim)
2518 else:
2519 raise ValueError("Invalid dim ({}) for LinearOperator of size {}".format(orig_dim, self.shape))
File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:861, in LinearOperator._sum_batch(self, dim)
850 """
851 Sum the LinearOperator across a batch dimension (supplied as a positive number).
852
(...)
857 :param dim: The (positive valued) dimension to sum
858 """
859 from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator
--> 861 return SumBatchLinearOperator(self, block_dim=dim)
File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
43 else:
44 new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)
File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
43 else:
44 new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)
File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/block_linear_operator.py:50, in BlockLinearOperator.__init__(self, base_linear_op, block_dim)
48 if block_dim != -3:
49 positive_block_dim = base_linear_op.dim() + block_dim
---> 50 base_linear_op = base_linear_op._permute_batch(
51 *range(positive_block_dim),
52 *range(positive_block_dim + 1, base_linear_op.dim() - 2),
53 positive_block_dim,
54 )
55 super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
56 self.base_linear_op = base_linear_op
File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:248, in LinearOperator._permute_batch(self, *dims)
246 if torch.is_tensor(component):
247 extra_dims = range(len(dims), component.dim())
--> 248 components.append(component.permute(*dims, *extra_dims))
249 elif isinstance(component, LinearOperator):
250 components.append(component._permute_batch(*dims))
RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3
System information
Please complete the following information:
- LinearOperator Version 0.5.3
- PyTorch Version 2.0.1