cornellius-gp/linear_operator

[Bug] object has no attribute _differentiable_kwargs

jaghili opened this issue

๐Ÿ› Bug

  • Install torch==2.0.1
  • Install linear_operator 0.5.3 with pip

To reproduce

I took the following snippet from the README:

import linear_operator
import torch

class DiagLinearOperator(linear_operator.LinearOperator):
    r"""
    A LinearOperator representing a diagonal matrix.
    """
    def __init__(self, diag):
        # diag: the vector that defines the diagonal of the matrix
        self.diag = diag

    def _matmul(self, v):
        return self.diag.unsqueeze(-1) * v

    def _size(self):
        return torch.Size([*self.diag.shape, self.diag.size(-1)])

    def _transpose_nonbatch(self):
        return self  # Diagonal matrices are symmetric

    # this function is optional, but it will accelerate computation
    def logdet(self):
        return self.diag.log().sum(dim=-1)
# ...

D = DiagLinearOperator(torch.tensor([1., 2., 3.]))
# Represents the matrix
#   [[1., 0., 0.],
#    [0., 2., 0.],
#    [0., 0., 3.]]
torch.matmul(D, torch.tensor([4., 5., 6.]))
# Returns [4., 10., 18.]

Stack trace/error message

Traceback (most recent call last):
  File "/home/jagh/codes/ng/src/a.py", line 31, in <module>
    torch.matmul(D, torch.tensor([4., 5., 6.]))
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2970, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 1839, in matmul
    return Matmul.apply(self.representation_tree(), other, *self.representation())
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py", line 2072, in representation_tree
    return LinearOperatorRepresentationTree(self)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jagh/.conda/envs/trot/lib/python3.11/site-packages/linear_operator/operators/linear_operator_representation_tree.py", line 8, in __init__
    self._differentiable_kwarg_names = linear_op._differentiable_kwargs.keys()
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'DiagLinearOperator' object has no attribute '_differentiable_kwargs'

Expected Behavior

The snippet should return [4., 10., 18.].

Additional context

I added self._differentiable_kwargs = { some dict }, which seems to bypass the problem, but then I get another error about self._nondifferentiable_kwargs, which I don't know how to set up. Did I miss something?
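A likely cause, offered as a hedged reading of the traceback: LinearOperatorRepresentationTree reads linear_op._differentiable_kwargs, and in linear_operator that attribute (together with _nondifferentiable_kwargs) appears to be created by LinearOperator.__init__, which the README's DiagLinearOperator.__init__ never calls. If that is right, the fix is presumably to call super().__init__(diag) at the top of DiagLinearOperator.__init__ instead of setting the private dictionaries by hand. The sketch below reproduces the pattern without torch or linear_operator. MiniLinearOperator, BrokenDiag and FixedDiag are hypothetical names, and the kwargs split (lists standing in for tensors) is an assumption, not the library's actual code.

```python
# Hypothetical, dependency-free stand-in for linear_operator.LinearOperator.
# ASSUMPTION: as the traceback suggests for the real class, the base __init__
# is the only place that creates _args, _differentiable_kwargs and
# _nondifferentiable_kwargs.
class MiniLinearOperator:
    def __init__(self, *args, **kwargs):
        self._args = args
        self._differentiable_kwargs = {}
        self._nondifferentiable_kwargs = {}
        for name, val in sorted(kwargs.items()):
            # ASSUMPTION: lists stand in for tensors (differentiable values).
            if isinstance(val, list):
                self._differentiable_kwargs[name] = val
            else:
                self._nondifferentiable_kwargs[name] = val

    def representation_tree(self):
        # Reads the attributes in the same order as the failing line of the
        # traceback, so a skipped base __init__ fails on _differentiable_kwargs.
        diff_names = list(self._differentiable_kwargs)
        nondiff = dict(self._nondifferentiable_kwargs)
        return type(self), self._args, diff_names, nondiff

    def matmul(self, v):
        self.representation_tree()  # AttributeError if base __init__ never ran
        return self._matmul(v)


class BrokenDiag(MiniLinearOperator):
    # Same shape as the README snippet: the base __init__ is never called.
    def __init__(self, diag):
        self.diag = diag

    def _matmul(self, v):
        return [d * x for d, x in zip(self.diag, v)]


class FixedDiag(MiniLinearOperator):
    # Registers the defining data with the base class before storing it.
    def __init__(self, diag):
        super().__init__(diag)
        self.diag = diag

    def _matmul(self, v):
        return [d * x for d, x in zip(self.diag, v)]


try:
    BrokenDiag([1.0, 2.0, 3.0]).matmul([4.0, 5.0, 6.0])
except AttributeError as err:
    print(err)  # 'BrokenDiag' object has no attribute '_differentiable_kwargs'

print(FixedDiag([1.0, 2.0, 3.0]).matmul([4.0, 5.0, 6.0]))  # [4.0, 10.0, 18.0]
```

The same pattern would explain the follow-up error in the workaround: after _differentiable_kwargs is assigned manually, the next attribute read (_nondifferentiable_kwargs) fails in the same way, whereas the base constructor creates both together.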