JaxGaussianProcesses/GPJax

bug: AttributeError: 'DenseLinearOperator' object has no attribute 'astype' when applying cholesky

adam-hartshorne opened this issue · 2 comments

Cross post from CoLA, wilson-labs/cola#65.

I have just tried to upgrade CoLA from commit wilson-labs/cola@74406c9 to the latest commit on the main branch, and I now get the following error. I am using the lower_cholesky function provided by GPJax (https://github.com/JaxGaussianProcesses/GPJax/blob/main/gpjax/lower_cholesky.py), and K is of type
<15x15 Sum[cola.ops.operators.Dense, cola.ops.operators.Product[cola.ops.operators.ScalarMul, cola.ops.operators.Identity]] with dtype=float64>
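Read as a matrix, that composite type is K = M + c·I: a dense matrix plus a scalar multiple of the identity. Treating c as a small jitter or noise term is an assumption based on the operator names. The sketch below uses plain NumPy in place of CoLA to show what a lower-Cholesky routine computes on such a matrix; `build_k`, `lower_cholesky_dense` and the jitter value are hypothetical names chosen for illustration.

```python
import numpy as np

def build_k(n: int, jitter: float, seed: int = 0) -> np.ndarray:
    """Hypothetical stand-in for Sum[Dense, Product[ScalarMul, Identity]]:
    a dense positive-definite matrix M plus jitter * I, as a plain array."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, n))
    M = X @ X.T                      # the Dense part (PSD by construction)
    return M + jitter * np.eye(n)    # the ScalarMul * Identity part

def lower_cholesky_dense(K: np.ndarray) -> np.ndarray:
    """Densify-then-factor, analogous to jnp.linalg.cholesky(A.to_dense())
    in the traceback; returns the lower-triangular factor L with K = L L^T."""
    return np.linalg.cholesky(K)

K = build_k(15, jitter=1e-6)
L = lower_cholesky_dense(K)
print(np.allclose(L @ L.T, K), np.allclose(L, np.tril(L)))
```

Because the factorisation runs on a materialised array, every operator inside the sum must be able to produce a dense matrix; that densification step is where the traceback below fails.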

    l_zz = lower_cholesky(K)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/plum/function.py", line 438, in __call__
    return _convert(method(*args, **kw_args), return_type)
  File "/media/adam/shared_drive/PycharmProjects/Process_Shape_Datasets/lower_cholesky.py", line 38, in lower_cholesky
    return cola.ops.Triangular(jnp.linalg.cholesky(A.to_dense()), lower=True)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 184, in to_dense
    return self @ self.xnp.eye(self.shape[-1], self.shape[-1], dtype=self.dtype, device=self.device)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in _matmat
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 159, in <genexpr>
    return sum(M @ v for M in self.Ms)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operator_base.py", line 211, in __matmul__
    return self._matmat(X)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/ops/operators.py", line 25, in _matmat
    return self.xnp.cast(self.A, dtype) @ self.xnp.cast(X, dtype)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/cola/jax_fns.py", line 168, in cast
    return array.astype(dtype)
AttributeError: 'DenseLinearOperator' object has no attribute 'astype'. Did you mean: 'dtype'?
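The traceback shows the chain: `to_dense` multiplies the operator by an identity matrix, the `Sum` forwards that product to each term, and the `Dense` term's `_matmat` calls `cast` on its stored `A`, which in turn calls `A.astype(dtype)`. That only works if `A` is a raw array. The sketch below reproduces the mechanism with a hypothetical `MockDense` class (not CoLA's real class) whose `__matmul__` and `to_dense` mirror those frames.

```python
import numpy as np

class MockDense:
    """Hypothetical stand-in for cola.ops.Dense: stores A and, like the
    _matmat frame in the traceback, casts A with .astype before multiplying."""

    def __init__(self, A):
        self.A = A
        self.shape = A.shape
        self.dtype = A.dtype

    def __matmul__(self, X):
        # Mirrors xnp.cast(self.A, dtype) -> array.astype(dtype)
        return self.A.astype(self.dtype) @ X.astype(self.dtype)

    def to_dense(self):
        # Mirrors operator_base.to_dense: multiply by an identity matrix
        return self @ np.eye(self.shape[-1], dtype=self.dtype)

M = 2.0 * np.eye(3)
ok = MockDense(M).to_dense()      # A is an ndarray, so .astype exists

bad = MockDense(MockDense(M))     # A is itself an operator (double wrap)
message = None
try:
    bad.to_dense()
except AttributeError as err:
    message = str(err)            # the operator has no .astype method
print(message)
```

The outer wrapper happily copies `shape` and `dtype` from the inner operator, so nothing fails until the first matrix product, which is why the error surfaces deep inside `to_dense` rather than where the wrapping happens.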

We do not support experimental commits from third-party libraries, @adam-hartshorne. Unless there is a problem with a PyPI CoLA release that aligns with our installation requirements, please raise such matters as discussions rather than issues in the future.

My apologies. I have now fixed the problem. It was caused by my code still wrapping the return value of kernel.gram in cola.ops.Dense, even though that function already returns a PSD(Dense(M)).
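The simplest fix is the one described: use the return value of kernel.gram directly instead of wrapping it again. If a code path must accept either raw arrays or operators, a guard that wraps only raw arrays avoids the double wrap. The sketch below uses hypothetical `MockOperator`/`MockDense` classes and a hypothetical `as_operator` helper; with real CoLA the analogous check would presumably be `isinstance(x, cola.ops.LinearOperator)`, assuming your CoLA version exports that base class there.

```python
import numpy as np

class MockOperator:
    """Hypothetical base class standing in for cola.ops.LinearOperator."""

class MockDense(MockOperator):
    """Hypothetical stand-in for cola.ops.Dense that only accepts raw arrays."""

    def __init__(self, A):
        # Fail fast at construction instead of deep inside a later matmul
        if not isinstance(A, np.ndarray):
            raise TypeError(f"expected an ndarray, got {type(A).__name__}")
        self.A = A

def as_operator(x):
    """Wrap raw arrays; return existing operators unchanged so an
    already-wrapped gram matrix is never wrapped a second time."""
    if isinstance(x, MockOperator):
        return x
    return MockDense(np.asarray(x))

gram = MockDense(np.eye(3))       # like kernel.gram returning an operator
K = as_operator(gram)             # passed through, no double wrapping
K_raw = as_operator(np.eye(3))    # a raw array still gets wrapped
```

The type check in the mock constructor is a design choice: it turns the late, confusing `astype` error into an immediate `TypeError` at the point where the wrong object is passed.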