cornellius-gp/linear_operator

[Bug] solve_triangular(Tensor, LinearOperator) not supported

Turakar opened this issue · 2 comments

๐Ÿ› Bug

Passing a LinearOperator as the right-hand side argument to solve() fails when the solve ends up being routed through torch.linalg.solve_triangular().

To reproduce

Code snippet to reproduce

import torch
from linear_operator import to_linear_operator

a_root = torch.tensor([[1.0, 0.5, 2.0], [0.0, 0.1, 0.2], [0.0, 0.0, 2.5]])
a = a_root @ a_root.T
b = torch.tensor([[3.0, 1.0], [4.0, 4.0], [1.0, 3.0]])
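# Dense left-hand side, dense right-hand side
torch.linalg.solve(a, b)
# LinearOperator left-hand side, dense right-hand side
torch.linalg.solve(to_linear_operator(a), b)
# LinearOperator right-hand side: raises the NotImplementedError shown below
torch.linalg.solve(to_linear_operator(a), to_linear_operator(b))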

Stack trace/error message

Traceback (most recent call last):
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/triangular_linear_operator.py", line 68, in _cholesky_solve
    res = self._tensor._cholesky_solve(rhs=rhs, upper=upper)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/dense_linear_operator.py", line 31, in _cholesky_solve
    return torch.cholesky_solve(rhs, self.to_dense(), upper=upper)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2785, in __torch_function__
    raise NotImplementedError(f"torch.{name}({arg_classes}, {kwarg_classes}) is not implemented.")
NotImplementedError: torch.cholesky_solve(DenseLinearOperator, Tensor, upper=bool) is not implemented.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/path/to/virtualenv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3460, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-29-e6878d4d649f>", line 1, in <module>
    torch.linalg.solve(to_linear_operator(a), to_linear_operator(b))
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2789, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2205, in solve
    return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/functions/_solve.py", line 53, in forward
    solves = _solve(linear_op, right_tensor)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/functions/_solve.py", line 17, in _solve
    return linear_op.cholesky()._cholesky_solve(rhs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/triangular_linear_operator.py", line 76, in _cholesky_solve
    w = self.solve(rhs)
        ^^^^^^^^^^^^^^^
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/triangular_linear_operator.py", line 181, in solve
    res = torch.linalg.solve_triangular(self.to_dense(), right_tensor, upper=self.upper)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/path/to/virtualenv/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2775, in __torch_function__
    raise NotImplementedError(f"torch.{name}({arg_classes}, {kwarg_classes}) is not implemented.")
NotImplementedError: torch.linalg.solve_triangular(Tensor, DenseLinearOperator, upper=bool) is not implemented.

Expected Behavior

A solution is found.

System information

Please complete the following information:

  • linear_operator 0.3.0
  • pytorch 1.13.1
  • Fedora Linux 37

Additional context

NNVariationalStrategy in GPyTorch can end up making such a call (here).

Or am I misunderstanding the library? It seems to me like this would be an easy fix by just calling to_dense() on the right hand side.

> It seems to me like this would be an easy fix by just calling to_dense() on the right hand side.

I think this is the correct fix. You shouldn't be able to call solve with a LinearOperator as the right-hand side (at least at the moment). This is a bug in NNVariationalStrategy.

I'm closing the issue here; @Turakar could you open the issue in GPyTorch? (Or better yet, would you be able to post a PR doing exactly what you said: calling .to_dense() on the right-hand side?)