bug: AttributeError: 'DenseLinearOperator' object has no attribute 'astype' when applying cholesky
adam-hartshorne opened this issue · 2 comments
Cross post from CoLA, wilson-labs/cola#65.
I have just tried to upgrade from commit wilson-labs/cola@74406c9 to the latest commit on the main branch, and I now get the following error. I am using the lower_cholesky function provided by GPJax, https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/lower_cholesky.py, and K is of type
<15x15 Sum[cola.ops.operators.Dense, cola.ops.operators.Product[cola.ops.operators.ScalarMul, cola.ops.operators.Identity]] with dtype=float64>
    l_zz = lower_cholesky(K)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/plum/function.py", line 438, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/media/adam/shared_drive/PycharmProjects/Process_Shape_Datasets/lower_cholesky.py", line 38, in lower_cholesky
    return cola.ops.Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 184, in to_dense
    return self @ self.xnp.eye(self.shape[-1], self.shape[-1], dtype=self.dtype, device=self.device)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in _matmat
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in <genexpr>
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 25, in _matmat
    return self.xnp.cast(self.A, dtype) @ self.xnp.cast(X, dtype)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py", line 168, in cast
    return array.astype(dtype)
AttributeError: 'DenseLinearOperator' object has no attribute 'astype'. Did you mean: 'dtype'?
We do not support experimental commits from third-party libraries, @adam-hartshorne. Unless there is a problem with a PyPI CoLA release that aligns with our installation requirements, please bring up such matters as discussions in the future.
My apologies. I have now fixed the problem, which was caused by my code still wrapping the return value of kernel.gram in a cola.ops.Dense, even though that function already returns a PSD(Dense(M)).
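For readers hitting the same error, here is a minimal sketch of the failure mode, assuming only what the traceback shows: the dense operator's matmul casts its stored matrix with `.astype`, so an operator stored where an array was expected fails there. `DenseOp`, `cast` and `as_dense_op` are hypothetical stand-ins written with NumPy for illustration; they are not CoLA's or GPJax's API.

```python
import numpy as np


def cast(array, dtype):
    # Mirrors the failing frame in the traceback: assumes `array` is a real
    # array that exposes `.astype`, not a linear-operator wrapper.
    return array.astype(dtype)


class DenseOp:
    """Hypothetical stand-in for a dense linear operator (NOT CoLA's class)."""

    def __init__(self, A):
        self.A = A              # expected to be an array, but nothing enforces it
        self.dtype = A.dtype    # an operator also has .dtype, so this still succeeds
        self.shape = A.shape

    def __matmul__(self, X):
        # Like the Dense _matmat frame: cast the stored matrix, then multiply.
        return cast(self.A, self.dtype) @ cast(X, self.dtype)


def as_dense_op(M):
    """Hypothetical guard: wrap M only if it is not already a DenseOp."""
    return M if isinstance(M, DenseOp) else DenseOp(M)


M = np.array([[4.0, 2.0], [2.0, 3.0]])
K = DenseOp(M)  # plays the role of the operator kernel.gram already returns

# Double wrapping: the inner DenseOp reaches cast() and has no .astype.
try:
    DenseOp(K) @ np.eye(2)
    error_message = None
except AttributeError as err:
    error_message = str(err)

# Fix: reuse the existing operator instead of re-wrapping it.
dense = as_dense_op(K) @ np.eye(2)
L = np.linalg.cholesky(dense)  # lower-triangular factor, as lower_cholesky computes
print(error_message)
print(np.allclose(L @ L.T, M))
```

In the actual code the fix is simply to pass the result of kernel.gram on without wrapping it again; the `isinstance` guard is just one hypothetical way to make wrapping idempotent if the same helper may receive either an array or an operator.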