[Bug] solve_triangular(Tensor, LinearOperator) not supported
Turakar opened this issue · 2 comments
🐛 Bug
Passing a LinearOperator as the right-hand side argument to solve() does not work when the solve is routed through torch.linalg.solve_triangular().
To reproduce
Code snippet to reproduce
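import torch
from linear_operator import to_linear_operator

# a_root is upper triangular with a nonzero diagonal, so a = a_root @ a_root.T is positive definite
a_root = torch.tensor([[1.0, 0.5, 2.0], [0.0, 0.1, 0.2], [0.0, 0.0, 2.5]])
a = a_root @ a_root.T
b = torch.tensor([[3.0, 1.0], [4.0, 4.0], [1.0, 3.0]])

torch.linalg.solve(a, b)  # works: both arguments are Tensors
torch.linalg.solve(to_linear_operator(a), b)  # works: LinearOperator lhs, Tensor rhs
torch.linalg.solve(to_linear_operator(a), to_linear_operator(b))  # raises NotImplementedError (traceback below)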
Stack trace/error message
Traceback (most recent call last):
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/triangular_linear_operator.py", line 68, in _cholesky_solve
res = self._tensor._cholesky_solve(rhs=rhs, upper=upper)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/dense_linear_operator.py", line 31, in _cholesky_solve
return torch.cholesky_solve(rhs, self.to_dense(), upper=upper)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2785, in __torch_function__
raise NotImplementedError(f"torch.{name}({arg_classes}, {kwarg_classes}) is not implemented.")
NotImplementedError: torch.cholesky_solve(DenseLinearOperator, Tensor, upper=bool) is not implemented.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/path/to/virtualenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-29-e6878d4d649f>", line 1, in <module>
torch.linalg.solve(to_linear_operator(a), to_linear_operator(b))
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2789, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2205, in solve
return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/functions/_solve.py", line 53, in forward
solves = _solve(linear_op, right_tensor)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/functions/_solve.py", line 17, in _solve
return linear_op.cholesky()._cholesky_solve(rhs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/triangular_linear_operator.py", line 76, in _cholesky_solve
w = self.solve(rhs)
^^^^^^^^^^^^^^^
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/triangular_linear_operator.py", line 181, in solve
res = torch.linalg.solve_triangular(self.to_dense(), right_tensor, upper=self.upper)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2775, in __torch_function__
raise NotImplementedError(f"torch.{name}({arg_classes}, {kwarg_classes}) is not implemented.")
NotImplementedError: torch.linalg.solve_triangular(Tensor, DenseLinearOperator, upper=bool) is not implemented.
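The traceback shows the call chain: _solve factors the operator with cholesky() and calls _cholesky_solve; TriangularLinearOperator first tries torch.cholesky_solve and, when that raises, falls back to solve(), which calls torch.linalg.solve_triangular on the densified factor but passes the right-hand side through unchanged (a DenseLinearOperator here). The sketch below is only a plain-Tensor illustration of the math on that path, not linear_operator's actual code: a Cholesky factorization followed by forward and back substitution with torch.linalg.solve_triangular, which succeeds because both operands are ordinary Tensors. The function name cholesky_path_solve is hypothetical, and float64 is an assumption chosen so the comparison with torch.linalg.solve is tight.

```python
import torch


def cholesky_path_solve(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: solve a @ x = b for symmetric positive definite a."""
    # Factor a = l @ l.T with l lower triangular.
    l = torch.linalg.cholesky(a)
    # Forward substitution: solve l @ y = b.
    y = torch.linalg.solve_triangular(l, b, upper=False)
    # Back substitution: solve l.T @ x = y (l.mT is upper triangular).
    return torch.linalg.solve_triangular(l.mT, y, upper=True)


# Same matrices as the repro, in float64 for a tight numerical comparison.
a_root = torch.tensor(
    [[1.0, 0.5, 2.0], [0.0, 0.1, 0.2], [0.0, 0.0, 2.5]], dtype=torch.float64
)
a = a_root @ a_root.T
b = torch.tensor([[3.0, 1.0], [4.0, 4.0], [1.0, 3.0]], dtype=torch.float64)

x = cholesky_path_solve(a, b)
print(torch.allclose(x, torch.linalg.solve(a, b)))
```

Passing a LinearOperator instead of a Tensor as b on this path is exactly what the second NotImplementedError rejects.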
Expected Behavior
The solve succeeds and returns the same result as torch.linalg.solve(a, b).
System information
- linear_operator 0.3.0
- pytorch 1.13.1
- Fedora Linux 37
Additional context
NNVariationalStrategy in GPyTorch can make such a call (here).
Or am I misunderstanding the library? It seems to me like this would be an easy fix by just calling to_dense() on the right-hand side.
> It seems to me like this would be an easy fix by just calling to_dense() on the right-hand side.
I think this is the correct fix. You shouldn't be able to call solve with a LinearOperator as the right-hand side (at least at the moment). This is a bug in NNVariationalStrategy.
I'm closing the issue here. @Turakar, could you open the issue in GPyTorch? (Or better yet, would you be able to post a PR doing exactly what you said, i.e. calling .to_dense() on the right-hand side?)
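A minimal sketch of the agreed fix, assuming the call site can simply densify the right-hand side before solving. The helper name solve_with_dense_rhs is hypothetical and is not part of linear_operator or GPyTorch; it relies only on to_dense(), which LinearOperator provides (it appears in the traceback above), and on torch.linalg.solve accepting a LinearOperator left-hand side, as the second call in the repro shows.

```python
import torch


def solve_with_dense_rhs(lhs, rhs):
    """Hypothetical helper: solve lhs @ x = rhs with a plain-Tensor rhs."""
    # If rhs is a LinearOperator (or any non-Tensor exposing to_dense()),
    # materialize it as an ordinary Tensor first. Plain Tensors are left
    # untouched, so Tensor.to_dense() is never called on a strided Tensor.
    if not isinstance(rhs, torch.Tensor) and hasattr(rhs, "to_dense"):
        rhs = rhs.to_dense()
    # lhs may be a Tensor or a LinearOperator; torch.linalg.solve handles both.
    return torch.linalg.solve(lhs, rhs)
```

With the tensors from the repro, solve_with_dense_rhs(to_linear_operator(a), to_linear_operator(b)) should then return the same result as torch.linalg.solve(a, b). In GPyTorch itself the equivalent change would just be appending .to_dense() to the right-hand side at the NNVariationalStrategy call site.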