[Bug] AttributeError: 'DenseLinearOperator' object has no attribute 'astype' when applying cholesky
adam-hartshorne opened this issue · 3 comments
I have just tried to upgrade from commit 74406c9 to the latest commit on the main branch, and I now get the following error. I am using the lower_cholesky function provided by GPJax, https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/lower_cholesky.py, and K is of type
<15x15 Sum[cola.ops.operators.Dense, cola.ops.operators.Product[cola.ops.operators.ScalarMul, cola.ops.operators.Identity]] with dtype=float64>
    l_zz = lower_cholesky(K)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/plum/function.py", line 438, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/media/adam/shared_drive/PycharmProjects/Process_Shape_Datasets/lower_cholesky.py", line 38, in lower_cholesky
    return cola.ops.Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 184, in to_dense
    return self @ self.xnp.eye(self.shape[-1], self.shape[-1], dtype=self.dtype, device=self.device)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in _matmat
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in <genexpr>
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 25, in _matmat
    return self.xnp.cast(self.A, dtype) @ self.xnp.cast(X, dtype)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py", line 168, in cast
    return array.astype(dtype)
AttributeError: 'DenseLinearOperator' object has no attribute 'astype'. Did you mean: 'dtype'?
Normally I would provide an MVE, but I am not sure how best to reproduce this. Let me know how I can help if needed.
Hmm. It looks to me like in this case the .A attribute of a Dense LinearOperator is itself a LinearOperator, whereas it is expected to be an array. (Before this commit it may simply not have failed explicitly, while still operating in an unintended way.)
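To make that failure mode concrete, here is a minimal NumPy sketch. ToyOperator, ToyDense and ToySum are hypothetical stand-ins, not cola's real classes; they only mimic the frames visible in the traceback (to_dense multiplies by an identity, Sum adds the products of its terms, Dense casts its stored .A with .astype). Wrapping an array once works, while wrapping an operator again is accepted at construction time and only fails when the operator is materialised.

```python
import numpy as np


class ToyOperator:
    """Hypothetical stand-in for a cola LinearOperator (not the real cola API)."""

    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype

    def __matmul__(self, X):
        return self._matmat(X)

    def __add__(self, other):
        return ToySum(self, other)

    def to_dense(self):
        # Materialise the operator by applying it to the identity,
        # mirroring the to_dense frame in the traceback.
        return self @ np.eye(self.shape[-1], dtype=self.dtype)


class ToyDense(ToyOperator):
    """Hypothetical Dense-like wrapper; it does not validate its input."""

    def __init__(self, A):
        self.A = A
        super().__init__(A.shape, A.dtype)

    def _matmat(self, X):
        # Like the Dense._matmat frame: cast the stored array, then multiply.
        # If self.A is itself an operator it has no .astype, so this raises.
        return self.A.astype(self.dtype) @ X.astype(self.dtype)


class ToySum(ToyOperator):
    """Hypothetical Sum: applies each term and adds the results."""

    def __init__(self, *Ms):
        self.Ms = Ms
        super().__init__(Ms[0].shape, Ms[0].dtype)

    def _matmat(self, X):
        return sum(M @ X for M in self.Ms)


M = np.array([[2.0, 0.5], [0.5, 1.0]])
noise_term = ToyDense(0.1 * np.eye(2))

# Wrapping the array once works: the sum materialises to M + 0.1 * I.
ok_dense = (ToyDense(M) + noise_term).to_dense()

# Wrapping an operator again is accepted here, but fails lazily in to_dense().
bad = ToyDense(ToyDense(M)) + noise_term
try:
    bad.to_dense()
    error_message = ""
except AttributeError as err:
    error_message = str(err)
print(error_message)
```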
Can you link the code used to produce the linear operator that enters into this example?
(Later we can add some better input validation for the basic linear operators to help raise these errors earlier)
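A sketch of the kind of early input validation mentioned above, assuming NumPy arrays are the only accepted input. ValidatedDense is a hypothetical class, not cola's Dense; a real check would also need to accept the JAX or PyTorch array types of whichever backend is in use.

```python
import numpy as np


class ValidatedDense:
    """Hypothetical Dense-like wrapper that validates its input eagerly.

    Only a sketch of the idea, not cola code.
    """

    def __init__(self, A):
        # Reject anything that is not a plain array, e.g. an operator
        # that has already been wrapped.
        if not isinstance(A, np.ndarray):
            raise TypeError(
                f"Dense expects an array, got {type(A).__name__}; "
                "pass the operator directly instead of wrapping it again"
            )
        if A.ndim != 2:
            raise ValueError(f"Dense expects a 2-D array, got ndim={A.ndim}")
        self.A = A
        self.shape = A.shape
        self.dtype = A.dtype

    def __matmul__(self, X):
        return self.A @ X


inner = ValidatedDense(np.eye(3))
try:
    ValidatedDense(inner)  # double wrapping is now rejected immediately
    error_message = ""
except TypeError as err:
    error_message = str(err)
print(error_message)
```

The error now points at the line that does the double wrapping, rather than surfacing later inside to_dense.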
After some further investigation, I found the cause. The GPJax function kernel.gram(...) now returns a linear operator of type PSD(Dense(M)), which I was then "wrapping" in another Dense (previously this function did not return a cola LinearOperator). That double wrapping caused the error:
k_zz = kernel.gram(Z)
k_zz_cola = cola.ops.Dense(k_zz)  # <-- this was the root cause of the error
noise = jnp.exp(self.log_beta) + self.jitter
K = cola.PSD(k_zz_cola + noise * cola.ops.I_like(k_zz_cola))
l_zz = lower_cholesky(K)
Removing the unnecessary extra cola.ops.Dense fixes the error:
k_zz = kernel.gram(Z)
noise = jnp.exp(self.log_beta) + self.jitter
K = cola.PSD(k_zz + noise * cola.ops.I_like(k_zz))
l_zz = lower_cholesky(K)
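For reference, the corrected snippet computes the lower Cholesky factor of K = k_zz + (exp(log_beta) + jitter) I. A dense NumPy sketch of that computation is below; lower_cholesky_dense is a hypothetical helper, the jitter default is arbitrary, and the RBF gram matrix is made-up data standing in for kernel.gram(Z).

```python
import numpy as np


def lower_cholesky_dense(k_zz, log_beta, jitter=1e-6):
    """Hypothetical dense helper: Cholesky of k_zz + (exp(log_beta) + jitter) * I.

    The jitter default is an illustrative choice, not a GPJax value.
    """
    noise = np.exp(log_beta) + jitter
    K = k_zz + noise * np.eye(k_zz.shape[0], dtype=k_zz.dtype)
    # np.linalg.cholesky returns the lower-triangular L with K = L @ L.T
    return np.linalg.cholesky(K)


# Illustrative RBF gram matrix on five 1-D inducing points (lengthscale 0.2).
Z = np.linspace(0.0, 1.0, 5)[:, None]
sq_dists = (Z - Z.T) ** 2
k_zz = np.exp(-0.5 * sq_dists / 0.2**2)

L = lower_cholesky_dense(k_zz, log_beta=np.log(0.01))
print(np.allclose(L @ L.T, k_zz + (0.01 + 1e-6) * np.eye(5)))  # prints True
```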
Yeah, in general cola.lazify will be safer if you're not sure whether the object is already a linear operator.
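The idea behind that advice can be sketched with hypothetical toy classes, assuming cola.lazify wraps plain arrays and returns existing linear operators unchanged, as the comment implies. toy_lazify is not the real cola function.

```python
import numpy as np


class ToyOperator:
    """Hypothetical base type standing in for cola's LinearOperator."""


class ToyDense(ToyOperator):
    """Hypothetical dense wrapper around a plain array."""

    def __init__(self, A):
        self.A = np.asarray(A)


def toy_lazify(obj):
    """Hypothetical sketch: wrap plain arrays, pass operators through."""
    if isinstance(obj, ToyOperator):
        return obj  # already an operator: no second wrapper
    return ToyDense(obj)


M = np.eye(2)
once = toy_lazify(M)
twice = toy_lazify(once)  # safe even if we are unsure what kernel.gram returns
print(twice is once, isinstance(once.A, np.ndarray))  # prints True True
```

Because the helper is idempotent, calling it on the output of kernel.gram(...) is harmless whether that output is an array or an operator.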