wilson-labs/cola

[Feature Request] Low Rank Dispatch rules (Woodbury identity, Trace rules, etc)

Opened this issue · 0 comments

mfinzi commented

🚀 Feature Request

Dispatch rules for Product[Dense,Dense] or Sum[Product[Dense,Dense], Diagonal].

Examples:

Woodbury identity

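The Woodbury matrix identity expresses the inverse of a diagonal-plus-low-rank matrix in terms of $D^{-1}$ and the inverse of a small $k \times k$ matrix, where $k$ is the inner dimension of $UV$: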
$(D + UV)^{-1} = D^{-1} - D^{-1}U(I + VD^{-1}U)^{-1}VD^{-1}$
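
As a sanity check on the identity, here is a minimal NumPy sketch, not CoLA code. The function name `woodbury_inverse` and the representation of $D$ as a 1-D array `d` are assumptions made for illustration. A real dispatch rule would presumably return a lazy operator rather than a dense matrix.

```python
import numpy as np

def woodbury_inverse(d, U, V):
    """Dense inverse of diag(d) + U @ V via the Woodbury identity.

    d: (n,) diagonal entries, U: (n, k), V: (k, n), with k << n.
    Hypothetical helper for illustration only.
    """
    d_inv = 1.0 / d
    DinvU = d_inv[:, None] * U                     # D^{-1} U, shape (n, k)
    VDinv = V * d_inv[None, :]                     # V D^{-1}, shape (k, n)
    small = np.eye(U.shape[1]) + V @ DinvU         # I + V D^{-1} U, shape (k, k)
    # Only a k x k system is solved, never an n x n one.
    correction = DinvU @ np.linalg.solve(small, VDinv)
    return np.diag(d_inv) - correction

rng = np.random.default_rng(0)
n, k = 50, 3
d = rng.uniform(1.0, 2.0, size=n)
U = rng.normal(size=(n, k))
V = U.T  # chosen only so the test matrix is well conditioned; the identity does not require it

direct = np.linalg.inv(np.diag(d) + U @ V)
print(np.allclose(woodbury_inverse(d, U, V), direct))
```

The only linear solve involves the $k \times k$ matrix $I + VD^{-1}U$, which is where the saving over a generic dense inverse comes from.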

Cyclic trace property

The cyclic trace property is:
$\text{Tr}(UV) = \text{Tr}(VU)$
The idea is that for a generic Product[LinearOperator,LinearOperator] where $U$ and $V$ are not square, we can reorder the factors so that the trace is taken over the smaller of the two products. If both factors are dense, we can go further and skip the matrix product entirely: $\text{Tr}(UV) = \sum_{i,j} U_{ij} V_{ji}$, the sum of the elementwise product of $U$ and $V^T$.
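
The dense shortcut can be illustrated with a small NumPy sketch. The helper name `trace_of_product` is hypothetical and not part of CoLA. It assumes `U` has shape `(n, k)` and `V` has shape `(k, n)`.

```python
import numpy as np

def trace_of_product(U, V):
    """Tr(U @ V) without forming the product.

    Uses Tr(UV) = sum_{i,j} U[i, j] * V[j, i], which costs O(n k)
    instead of the O(n^2 k) needed to build the (n, n) matrix U @ V.
    Hypothetical helper for illustration only.
    """
    return np.sum(U * V.T)

rng = np.random.default_rng(0)
n, k = 200, 4
U = rng.normal(size=(n, k))
V = rng.normal(size=(k, n))

fast = trace_of_product(U, V)
via_large = np.trace(U @ V)   # trace of the (n, n) product
via_small = np.trace(V @ U)   # trace of the (k, k) product, by cyclicity
print(np.isclose(fast, via_large), np.isclose(fast, via_small))
```

For non-dense operators the elementwise form is not available, but the rearrangement to the $k \times k$ product `V @ U` still applies.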

Pitch

Introduce rules such as:

@dispatch
def inv(A: Sum[Product[Dense,Dense], Diagonal], **kwargs):
    ...

@dispatch(cond=product_faster_if_rearranged)
def trace(A: Product):
    ...

Additional context

Plum-dispatch can work a little differently than one would expect for parametric types.
Some things need to be spelled out more explicitly, and changes may even need to be made to cola-plum-dispatch.