This library implements a generic structure for abstract linear operators and enables a number of standard operations on them:
- Arithmetic: `A + B`, `A - B`, `-A`, and `A @ B` all work exactly as expected to combine linear operators.
- Indexing: `A[k:ell, m:n]` works as expected.
- Solves: `Ax = b` can be solved with CG for positive semidefinite (PSD) matrices, MINRES for symmetric matrices, LSQR (to be implemented), or LSMR (to be implemented).
- Trace estimation: the trace of a square matrix can be estimated via Hutch++ or Hutchinson's estimator.
- Diamond-Boyd stochastic equilibration.
- Randomized Nyström preconditioning.
- Automatic adjoint operator generation.
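The solvers listed above only need products with `A` (and, for LSQR and LSMR, with its adjoint). That is what makes them usable on abstract operators. As a rough illustration, here is a minimal conjugate gradient loop that reaches the matrix only through a `matvec` callable. It is a standalone NumPy sketch, not this library's torch implementation, and `conjugate_gradient` is a hypothetical name.

```python
from __future__ import annotations

from typing import Callable

import numpy as np


def conjugate_gradient(
    matvec: Callable[[np.ndarray], np.ndarray],
    b: np.ndarray,
    x0: np.ndarray | None = None,
    tol: float = 1e-10,
    max_iter: int = 1000,
) -> np.ndarray:
    """Hypothetical sketch: solve A x = b for symmetric positive definite A, given only v -> A v."""
    x = np.zeros_like(b, dtype=float) if x0 is None else np.array(x0, dtype=float)
    r = b - matvec(x)      # current residual
    p = r.copy()           # current search direction
    rs_old = r @ r
    b_norm = np.linalg.norm(b)
    for _ in range(max_iter):
        if np.sqrt(rs_old) <= tol * b_norm:
            break
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)       # exact line search along p
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p   # next A-conjugate direction
        rs_old = rs_new
    return x


# Usage: A is touched only through the lambda; it is never factored.
rng = np.random.default_rng(0)
G = rng.standard_normal((50, 50))
A = G @ G.T + 50 * np.eye(50)   # symmetric positive definite
b = rng.standard_normal(50)
x = conjugate_gradient(lambda v: A @ v, b)
print(np.linalg.norm(A @ x - b))  # small compared with np.linalg.norm(b)
```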
In the public API of the library, every `LinearOperator` has the following properties and methods:
```python
from __future__ import annotations

from typing import Literal, overload

import torch


class LinearOperator:
    # Properties
    shape: tuple[int, int]
    T: LinearOperator
    supports_operator_matrix: bool
    device: torch.device

    # Matrix multiply: a tensor operand yields a tensor,
    # a LinearOperator operand yields a new LinearOperator.
    @overload
    def __matmul__(self, b: torch.Tensor) -> torch.Tensor: ...
    @overload
    def __matmul__(self, b: LinearOperator) -> LinearOperator: ...
    @overload
    def __rmatmul__(self, b: torch.Tensor) -> torch.Tensor: ...
    @overload
    def __rmatmul__(self, b: LinearOperator) -> LinearOperator: ...

    # Linear solve methods
    # Solves (I + lambda_ * A^T A) x = b.
    def solve_I_p_lambda_AT_A_x_eq_b(
        self,
        lambda_: float,
        b: torch.Tensor,
        x0: torch.Tensor | None = None,
        *,
        precondition: None | Literal['nystrom'],
        hot=False,
    ) -> torch.Tensor: ...
    # Solves A x = b.
    def solve_A_x_eq_b(
        self,
        b: torch.Tensor,
        x0: torch.Tensor | None = None,
    ) -> torch.Tensor: ...

    # Transformations on LinearOperator
    def __mul__(self, c: float) -> LinearOperator: ...
    def __rmul__(self, c: float) -> LinearOperator: ...
    def __truediv__(self, c: float) -> LinearOperator: ...
    def __pow__(self, k: int) -> LinearOperator: ...
    def __add__(self, c: LinearOperator) -> LinearOperator: ...
    def __sub__(self, c: LinearOperator) -> LinearOperator: ...
    def __neg__(self) -> LinearOperator: ...
    def __pos__(self) -> LinearOperator: ...
    def __getitem__(self, key) -> LinearOperator: ...
```
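The name `solve_I_p_lambda_AT_A_x_eq_b` reads as the system `(I + λ AᵀA) x = b`, whose matrix is symmetric positive definite for λ ≥ 0, so CG applies. The sketch below shows one way to pair that with randomized Nyström preconditioning, following the construction of Frangella, Tropp and Udell: approximate `AᵀA ≈ U diag(λ̂) Uᵀ` from a few products, then precondition with `P⁻¹ = U diag((1 + λ λ̂_ℓ) / (1 + λ λ̂)) Uᵀ + (I - U Uᵀ)`, where `λ̂_ℓ` is the smallest retained eigenvalue. It is written in NumPy with hypothetical function names, and whether the library's `precondition` option performs exactly these steps is an assumption.

```python
from __future__ import annotations

from typing import Callable

import numpy as np

Array = np.ndarray


def nystrom_approx(matmat: Callable[[Array], Array], n: int, rank: int,
                   rng: np.random.Generator) -> tuple[Array, Array]:
    """Hypothetical sketch: randomized Nystrom approximation M ~= U diag(lam) U^T of PSD M."""
    Omega, _ = np.linalg.qr(rng.standard_normal((n, rank)))       # orthonormal test matrix
    Y = matmat(Omega)                                              # M @ Omega
    nu = np.sqrt(n) * np.finfo(float).eps * np.linalg.norm(Y, 2)   # small stabilizing shift
    Y_nu = Y + nu * Omega
    core = Omega.T @ Y_nu
    L = np.linalg.cholesky((core + core.T) / 2)                    # core = L L^T
    B = np.linalg.solve(L, Y_nu.T).T                               # B = Y_nu L^{-T}
    U, S, _ = np.linalg.svd(B, full_matrices=False)
    return U, np.maximum(S**2 - nu, 0.0)                           # undo the shift


def pcg(matvec, b, precond, tol=1e-10, max_iter=1000):
    """Preconditioned CG; returns the solution and the number of iterations used."""
    x = np.zeros_like(b)
    r = b.copy()                 # residual for x = 0
    z = precond(r)
    p = z.copy()
    rz = r @ z
    iters = 0
    while iters < max_iter and np.linalg.norm(r) > tol * np.linalg.norm(b):
        Ap = matvec(p)
        alpha = rz / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        z = precond(r)
        rz_new = r @ z
        p = z + (rz_new / rz) * p
        rz = rz_new
        iters += 1
    return x, iters


def solve_I_plus_lambda_AtA(A: Array, lam: float, b: Array, rank: int = 50, seed: int = 0):
    """Solve (I + lam * A^T A) x = b using only products with A and A^T."""
    n = A.shape[1]
    matvec = lambda v: v + lam * (A.T @ (A @ v))
    U, lam_hat = nystrom_approx(lambda X: A.T @ (A @ X), n, rank,
                                np.random.default_rng(seed))
    # P^{-1} = U diag((1 + lam*lam_hat[-1]) / (1 + lam*lam_hat)) U^T + (I - U U^T)
    scale = (1 + lam * lam_hat[-1]) / (1 + lam * lam_hat)

    def precond(r):
        Utr = U.T @ r
        return U @ (scale * Utr) + (r - U @ Utr)

    return pcg(matvec, b, precond)


# Usage: A has a fast-decaying spectrum, where a low-rank preconditioner pays off.
rng = np.random.default_rng(1)
m, n = 200, 100
U0, _ = np.linalg.qr(rng.standard_normal((m, n)))
V0, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = U0 @ np.diag(10.0 ** (2 - 0.08 * np.arange(n))) @ V0.T
b = rng.standard_normal(n)
x, iters = solve_I_plus_lambda_AtA(A, lam=1.0, b=b)
K = np.eye(n) + A.T @ A
_, cg_iters = pcg(lambda v: K @ v, b, precond=lambda r: r)   # unpreconditioned baseline
print(iters, cg_iters)
```

The preconditioner costs `rank` products with `AᵀA` up front, so it is worthwhile mainly when the same system is solved repeatedly or when the spectrum of `AᵀA` decays quickly.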
The following functions are available in the root of the library:
```python
def operator_matrix_product(A: LinearOperator, M: torch.Tensor) -> torch.Tensor: ...
def aslinearoperator(A: torch.Tensor | LinearOperator) -> LinearOperator: ...
def vstack(ops: list[LinearOperator] | tuple[LinearOperator, ...]) -> LinearOperator: ...
def hstack(ops: list[LinearOperator] | tuple[LinearOperator, ...]) -> LinearOperator: ...

# To be implemented:
def bmat(ops: list[list[LinearOperator]]) -> LinearOperator: ...  # optimizes out ZeroOperator blocks
```
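Stacking can stay matrix-free: a vertical stack concatenates the blocks' outputs, and its adjoint splits the input and sums the blocks' adjoints, while a horizontal stack is the adjoint of a vertical stack of adjoints. The NumPy sketch below uses a hypothetical minimal `Op` type (not the library's `LinearOperator`) and hypothetical `vstack_op`/`hstack_op` helpers, and assembles a KKT operator from blocks assuming the common layout `[[H, Aᵀ], [A, 0]]` (an assumption; the layout used by `KKTOperator` is not specified here).

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Sequence

import numpy as np


@dataclass
class Op:
    """Hypothetical minimal operator: a shape plus forward and adjoint matvecs."""
    shape: tuple[int, int]
    mv: Callable[[np.ndarray], np.ndarray]    # v -> A v
    rmv: Callable[[np.ndarray], np.ndarray]   # u -> A^T u

    @property
    def T(self) -> Op:
        return Op((self.shape[1], self.shape[0]), self.rmv, self.mv)


def dense(M: np.ndarray) -> Op:
    return Op(M.shape, lambda v: M @ v, lambda u: M.T @ u)


def zero(m: int, n: int) -> Op:
    return Op((m, n), lambda v: np.zeros(m), lambda u: np.zeros(n))


def vstack_op(ops: Sequence[Op]) -> Op:
    """[A; B; ...]: stack the outputs; the adjoint sums the blocks' adjoints."""
    n = ops[0].shape[1]
    if any(op.shape[1] != n for op in ops):
        raise ValueError("all blocks must have the same number of columns")
    splits = np.cumsum([op.shape[0] for op in ops])[:-1]

    def mv(v):
        return np.concatenate([op.mv(v) for op in ops])

    def rmv(u):
        return sum(op.rmv(part) for op, part in zip(ops, np.split(u, splits)))

    return Op((int(sum(op.shape[0] for op in ops)), n), mv, rmv)


def hstack_op(ops: Sequence[Op]) -> Op:
    """[A, B, ...] is the adjoint of [A^T; B^T; ...]."""
    return vstack_op([op.T for op in ops]).T


# Usage: a KKT operator in the assumed layout [[H, A^T], [A, 0]].
rng = np.random.default_rng(0)
H = rng.standard_normal((4, 4))
H = H + H.T
A = rng.standard_normal((2, 4))
K = vstack_op([hstack_op([dense(H), dense(A).T]), hstack_op([dense(A), zero(2, 2)])])
K_dense = np.block([[H, A.T], [A, np.zeros((2, 2))]])
v = rng.standard_normal(6)
print(np.allclose(K.mv(v), K_dense @ v), np.allclose(K.rmv(v), K_dense.T @ v))  # True True
```

A `bmat`-style helper could skip zero blocks entirely instead of calling their (trivial) matvecs.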
The following functions are available in `linops.trace` for trace estimation:

```python
def hutchpp(A: lo.LinearOperator, m: int) -> float: ...
def hutchinson(A: lo.LinearOperator, m: int) -> float: ...
```
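For reference, Hutchinson's estimator averages `gᵀAg` over random sign vectors `g`. Hutch++ first captures the dominant part of `A` exactly with a small randomized range finder, then applies Hutchinson only to the deflated remainder, which typically reduces the variance for matrices with decaying spectra. The NumPy sketch below follows that description, with `m` counting matrix-vector products. The names and the callable-based signatures are hypothetical and differ from `linops.trace`.

```python
from __future__ import annotations

from typing import Callable

import numpy as np

MatMat = Callable[[np.ndarray], np.ndarray]  # X -> A @ X for an n x n operator A


def hutchinson(matmat: MatMat, n: int, m: int, rng: np.random.Generator) -> float:
    """Hypothetical sketch: average of g^T A g over m random sign vectors g (m products)."""
    G = rng.choice([-1.0, 1.0], size=(n, m))
    return float(np.sum(G * matmat(G)) / m)


def hutchpp(matmat: MatMat, n: int, m: int, rng: np.random.Generator) -> float:
    """Hypothetical sketch of Hutch++ using m products in total (m >= 3)."""
    k = m // 3
    S = rng.choice([-1.0, 1.0], size=(n, k))
    G = rng.choice([-1.0, 1.0], size=(n, m - 2 * k))
    Q, _ = np.linalg.qr(matmat(S))               # k products: basis for range(A S)
    exact_part = np.trace(Q.T @ matmat(Q))       # k products: tr(Q^T A Q)
    G_perp = G - Q @ (Q.T @ G)                   # (I - Q Q^T) G
    residual_part = np.sum(G_perp * matmat(G_perp)) / G.shape[1]  # m - 2k products
    return float(exact_part + residual_part)


# Usage: a PSD matrix with eigenvalues 1/i^2 in a random basis.
rng = np.random.default_rng(0)
n = 200
Q0, _ = np.linalg.qr(rng.standard_normal((n, n)))
A = Q0 @ np.diag(1.0 / np.arange(1, n + 1) ** 2) @ Q0.T
exact = np.trace(A)
est_h = hutchinson(lambda X: A @ X, n, m=60, rng=rng)
est_pp = hutchpp(lambda X: A @ X, n, m=60, rng=rng)
print(exact, est_h, est_pp)
```

With the same budget of 60 products, the Hutch++ estimate is expected to land much closer to the exact trace on a matrix like this one.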
`linops.equilibration` contains `equilibrate` and `symmetric_equilibrate`. Their public API is not finalized; if you wish to use them, we recommend reading the source code.
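The core primitive behind matrix-free equilibration is that, for a random sign vector `s`, `E[(A s)_i^2]` equals the squared norm of row `i` of `A`. This means row and column norms can be estimated from products with `A` and `Aᵀ` alone. The sketch below uses that estimate inside a simple alternating row/column rescaling. It is a Sinkhorn-style stand-in, not the Diamond-Boyd algorithm itself and not this module's API. It assumes a square operator with no zero rows or columns, and all names are hypothetical.

```python
from __future__ import annotations

from typing import Callable

import numpy as np

MatMat = Callable[[np.ndarray], np.ndarray]


def estimate_sq_row_norms(matmat: MatMat, n_cols: int, probes: int,
                          rng: np.random.Generator) -> np.ndarray:
    """Estimate ||row i of M||^2 as the mean of (M s)_i^2 over random sign vectors s."""
    S = rng.choice([-1.0, 1.0], size=(n_cols, probes))
    return np.mean(matmat(S) ** 2, axis=1)


def equilibrate_sketch(A_mm: MatMat, AT_mm: MatMat, n: int, iters: int = 10,
                       probes: int = 400, seed: int = 0) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical sketch: find d, e so diag(d) A diag(e) has roughly unit row/column norms.

    A is square (n x n), accessed only via A_mm(X) = A @ X and AT_mm(Y) = A.T @ Y.
    Assumes A has no zero rows or columns.
    """
    rng = np.random.default_rng(seed)
    d, e = np.ones(n), np.ones(n)
    for _ in range(iters):
        # Rows of diag(d) A diag(e), estimated from products only; rescale to unit norm.
        rows = estimate_sq_row_norms(lambda X: d[:, None] * A_mm(e[:, None] * X), n, probes, rng)
        d = d / np.sqrt(rows)
        # Columns of diag(d) A diag(e) are the rows of diag(e) A^T diag(d).
        cols = estimate_sq_row_norms(lambda Y: e[:, None] * AT_mm(d[:, None] * Y), n, probes, rng)
        e = e / np.sqrt(cols)
    return d, e


# Usage: a badly scaled matrix; the scaled matrix is formed here only to check the result.
rng = np.random.default_rng(1)
n = 50
B = rng.standard_normal((n, n))
A = np.diag(np.logspace(-3, 3, n)) @ B @ np.diag(np.logspace(3, -3, n))
d, e = equilibrate_sketch(lambda X: A @ X, lambda Y: A.T @ Y, n)
S = d[:, None] * A * e[None, :]
row, col = np.linalg.norm(S, axis=1), np.linalg.norm(S, axis=0)
print(row.max() / row.min(), col.max() / col.min())  # both close to 1
```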
Linear operators can be constructed in the following ways:
- Creating a subclass of `LinearOperator`.
- Calling one of the following constructors:
  - `IdentityOperator(n: int)`
  - `DiagonalOperator(diag: torch.Tensor)`, where `diag` is a 1D torch tensor.
  - `MatrixOperator(M: torch.Tensor)`, where `M` is a 2D torch tensor.
  - `SelectionOperator(shape: tuple[int, int], idxs: slice | list[int | slice])`
  - `KKTOperator(H: LinearOperator, A: LinearOperator)`, where `H` is a square `LinearOperator` and `A` is a `LinearOperator`.
  - `VectorJacobianOperator(f: torch.Tensor, x: torch.Tensor)`, where `f` is the output of the function being differentiated (tracked by torch autograd) and `x` is the vector on which `requires_grad_()` was called.
  - `ZeroOperator(shape: tuple[int, int])`
- Combining operators via `A + B`, `A - B`, and `A @ B` for linear operators `A` and `B`; `hstack` and `vstack`; and `A`, `c A`, `A / c`, `v * A`, and `A / v` for a scalar `c` and a vector `v`.
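Combinations like these can stay lazy: each operation returns a new operator whose matvec calls into its operands, and the adjoint follows from `(AB)ᵀ = BᵀAᵀ` and `(A + B)ᵀ = Aᵀ + Bᵀ`. Here is a minimal NumPy sketch of that idea, using a hypothetical `LazyOp` class rather than the library's `LinearOperator`, and covering only real scalars:

```python
from __future__ import annotations

from typing import Callable

import numpy as np

MatVec = Callable[[np.ndarray], np.ndarray]


class LazyOp:
    """Hypothetical minimal lazy operator: stores forward and adjoint matvecs, never a matrix."""

    def __init__(self, shape: tuple[int, int], mv: MatVec, rmv: MatVec):
        self.shape = shape
        self._mv = mv      # v -> A v
        self._rmv = rmv    # u -> A^T u

    @classmethod
    def from_dense(cls, M: np.ndarray) -> LazyOp:
        return cls(M.shape, lambda v: M @ v, lambda u: M.T @ u)

    @property
    def T(self) -> LazyOp:
        return LazyOp((self.shape[1], self.shape[0]), self._rmv, self._mv)

    def __matmul__(self, other):
        if isinstance(other, LazyOp):  # (A @ B) v = A (B v);  (AB)^T u = B^T (A^T u)
            if self.shape[1] != other.shape[0]:
                raise ValueError("inner dimensions do not match")
            return LazyOp((self.shape[0], other.shape[1]),
                          lambda v: self._mv(other._mv(v)),
                          lambda u: other._rmv(self._rmv(u)))
        return self._mv(other)  # apply to a vector

    def __add__(self, other: LazyOp) -> LazyOp:
        if self.shape != other.shape:
            raise ValueError("shapes do not match")
        return LazyOp(self.shape, lambda v: self._mv(v) + other._mv(v),
                      lambda u: self._rmv(u) + other._rmv(u))

    def __mul__(self, c: float) -> LazyOp:  # real scalar c, so (cA)^T = c A^T
        return LazyOp(self.shape, lambda v: c * self._mv(v), lambda u: c * self._rmv(u))

    __rmul__ = __mul__

    def __neg__(self) -> LazyOp:
        return -1.0 * self

    def __sub__(self, other: LazyOp) -> LazyOp:
        return self + (-other)

    def __truediv__(self, c: float) -> LazyOp:
        return self * (1.0 / c)


# Usage: the expression below is evaluated only when applied to a vector.
rng = np.random.default_rng(0)
A_d, B_d, C_d = rng.standard_normal((3, 4)), rng.standard_normal((4, 2)), rng.standard_normal((3, 2))
A, B, C = (LazyOp.from_dense(M) for M in (A_d, B_d, C_d))
expr = 2.0 * (A @ B) - C / 4.0
v, u = rng.standard_normal(2), rng.standard_normal(3)
dense_expr = 2.0 * A_d @ B_d - C_d / 4.0
print(np.allclose(expr @ v, dense_expr @ v), np.allclose(expr.T @ u, dense_expr.T @ u))  # True True
```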
To implement a `LinearOperator`, the following are mandatory:
- Set `_shape: tuple[int, int]` to the shape of the operator.
- Set `device` appropriately if the operator requires vectors to be on a particular device.
- Implement a method `def _matmul_impl(self, v: torch.Tensor) -> torch.Tensor: ...` that computes your matrix-vector product.
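To make this contract concrete, here is a NumPy sketch with a hypothetical `SketchLinearOperator` base class standing in for `LinearOperator`. A subclass sets `_shape` and implements `_matmul_impl`, and the base class handles `@`. The column-by-column fallback for matrix inputs when `supports_operator_matrix` is false is an assumption about how such a flag could be used, not a description of the library's internals. The example operator is the 1D second-difference matrix, which is applied without ever being stored.

```python
from __future__ import annotations

import numpy as np


class SketchLinearOperator:
    """Hypothetical stand-in for a LinearOperator-style base class (NumPy, not torch)."""

    _shape: tuple[int, int]
    supports_operator_matrix: bool = False

    @property
    def shape(self) -> tuple[int, int]:
        return self._shape

    def _matmul_impl(self, v: np.ndarray) -> np.ndarray:
        raise NotImplementedError

    def __matmul__(self, x: np.ndarray) -> np.ndarray:
        if x.ndim == 1 or self.supports_operator_matrix:
            return self._matmul_impl(x)
        # Assumed fallback: one matrix-vector product per column.
        return np.stack([self._matmul_impl(col) for col in x.T], axis=1)


class SecondDifference(SketchLinearOperator):
    """The n x n tridiagonal matrix with 2 on the diagonal and -1 beside it, never stored."""

    supports_operator_matrix = True  # the slicing below acts on rows, so 2D input works

    def __init__(self, n: int):
        self._shape = (n, n)

    def _matmul_impl(self, v: np.ndarray) -> np.ndarray:
        out = 2.0 * v
        out[1:] -= v[:-1]   # subtract the left neighbor
        out[:-1] -= v[1:]   # subtract the right neighbor
        return out


T = SecondDifference(5)
print(T.shape)         # (5, 5)
print(T @ np.ones(5))  # [1. 0. 0. 0. 1.]
```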
The following are recommended to improve performance:
- If your `_matmul_impl` method handles matrix inputs correctly (applying the operator to every column), set `supports_operator_matrix: bool` to `True`.
- If it is possible to describe the adjoint operator, set `_adjoint: LinearOperator` to point to the adjoint of your operator. If you do not, one will be generated automatically by differentiating through your `_matmul_impl`.
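Differentiating through `_matmul_impl` works because, for the linear map `f(v) = A v`, the vector-Jacobian product with `u` is exactly `Aᵀu`. When you supply `_adjoint` by hand instead, the dot-product test `⟨Av, u⟩ = ⟨v, Aᵀu⟩` on random vectors is a cheap way to check it. The NumPy sketch below implements that test, plus a hypothetical `adjoint_by_probing` helper that builds `Aᵀ` from `n` products with basis vectors. This helper is a small-operator stand-in for autodiff, not what the library does.

```python
from __future__ import annotations

from typing import Callable

import numpy as np

MatVec = Callable[[np.ndarray], np.ndarray]


def dot_product_test(matvec: MatVec, rmatvec: MatVec, shape: tuple[int, int],
                     trials: int = 5, rtol: float = 1e-10, seed: int = 0) -> bool:
    """Return True if <A v, u> == <v, A^T u> (up to rtol) for random v and u."""
    rng = np.random.default_rng(seed)
    m, n = shape
    for _ in range(trials):
        v, u = rng.standard_normal(n), rng.standard_normal(m)
        Av, ATu = matvec(v), rmatvec(u)
        scale = np.linalg.norm(Av) * np.linalg.norm(u) + np.linalg.norm(v) * np.linalg.norm(ATu)
        if abs(Av @ u - v @ ATu) > rtol * scale:
            return False
    return True


def adjoint_by_probing(matvec: MatVec, shape: tuple[int, int]) -> MatVec:
    """Hypothetical helper: u -> A^T u via (A^T u)_j = (A e_j) . u; costs n matvecs up front."""
    m, n = shape
    cols = np.column_stack([matvec(e) for e in np.eye(n)])  # column j is A e_j
    return lambda u: cols.T @ u


rng = np.random.default_rng(0)
M = rng.standard_normal((4, 3))
matvec = lambda v: M @ v
print(dot_product_test(matvec, adjoint_by_probing(matvec, M.shape), M.shape))  # True
print(dot_product_test(matvec, lambda u: 2 * (M.T @ u), M.shape))              # False
```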
Where possible, we also suggest overriding any other methods with specialized implementations.