AttributeError when importing cr.sparse
I installed cr-sparse according to the installation instructions and tried to follow one of the tutorials, but ran into an AttributeError when trying to import cr.sparse. See below for the stack trace. I suspect a recent JAX update that isn't yet reflected in cr-sparse.
Package versions:
Jax 0.4.18
cr-sparse 0.3.2
System:
miniconda + VS Code
macOS Ventura 13.4, Apple M2
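The version details listed above can also be collected programmatically. This is a minimal sketch using only the standard library. It assumes the PyPI distribution names jax, jaxlib, cr-sparse, cr-nimble and cr-wavelets. Any package that is not installed is reported as None.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for the packages involved in this report.
PACKAGES = ["jax", "jaxlib", "cr-sparse", "cr-nimble", "cr-wavelets"]


def environment_report(packages=PACKAGES):
    """Return installed versions (None if missing) plus basic system info."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    report["python"] = platform.python_version()
    report["platform"] = platform.platform()
    # Typically 'arm64' for a native Python on Apple silicon.
    report["machine"] = platform.machine()
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```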
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/Users/joey/USC/lidar_group/shadowOQS/dynamical_shadow_tomography/cs_testing.ipynb Cell 1 line 9
import matplotlib.pyplot as plt
# import cr.sparse as crs
--> import cr.sparse.dict as crdict
import cr.sparse.pursuit as pursuit
File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/__init__.py:37
31 # Evaluation Tools
33 from cr.sparse._src.tools.performance import (
34 RecoveryPerformance
35 )
---> 37 from cr.sparse._src.tools.trials_at_fixed_m_n import (
38 RecoveryTrialsAtFixed_M_N
39 )
File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/_src/tools/trials_at_fixed_m_n.py:25
22 import jax.numpy as jnp
24 import cr.sparse as crs
---> 25 from cr.sparse import pursuit
26 import cr.sparse.data as crdata
27 import cr.sparse.dict as crdict
File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/pursuit/__init__.py:26
18 # pylint: disable=W0611
20 from cr.sparse._src.pursuit.util import (
21 abs_max_idx,
22 gram_chol_update,
23 largest_indices
24 )
---> 26 from cr.sparse._src.pursuit.defs import (
27 RecoverySolution
28 )
File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/_src/pursuit/defs.py:24
20 from cr.nimble.dsp import build_signal_from_indices_and_values
21 norm = jnp.linalg.norm
23 @dataclass
---> 24 class SingleRecoverySolution:
25 signals: jnp.DeviceArray = None
26 representations : jnp.DeviceArray = None
File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/_src/pursuit/defs.py:25, in SingleRecoverySolution()
23 @dataclass
24 class SingleRecoverySolution:
---> 25 signals: jnp.DeviceArray = None
26 representations : jnp.DeviceArray = None
27 residuals : jnp.DeviceArray = None
File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
51 warnings.warn(message, DeprecationWarning, stacklevel=2)
52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'
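The last frames show where the error comes from. jax.numpy routes unknown attribute names through a module-level __getattr__ in jax/_src/deprecations.py. That function warns and returns a value for names that are only deprecated, and raises AttributeError for anything else. Because cr-sparse 0.3.2 uses jnp.DeviceArray as a dataclass field annotation, and the Python 3.11 used here evaluates class-body annotations when the class is created, the lookup fails as soon as defs.py is imported. The sketch below reproduces that pattern with a hypothetical stand-in module. The names fake_jnp, _deprecated and old_name are illustrative only and are not part of JAX.

```python
import types
import warnings

# Hypothetical stand-in for jax.numpy (not a real JAX module).
fake_jnp = types.ModuleType("fake_jnp")

# Deprecated-but-still-available names: name -> (warning message, value).
# "old_name" is purely illustrative. "DeviceArray" is deliberately absent,
# mirroring a name that has been removed outright.
_deprecated = {
    "old_name": ("old_name is deprecated; use new_name instead.", "new_value"),
}


def _module_getattr(name):
    # Module-level __getattr__ (PEP 562): only called when normal lookup fails.
    if name in _deprecated:
        message, value = _deprecated[name]
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return value
    raise AttributeError(f"module {fake_jnp.__name__!r} has no attribute {name!r}")


# Stored in the module's __dict__, which is where PEP 562 looks for it.
fake_jnp.__getattr__ = _module_getattr

error_message = None
try:
    # Same lookup an annotation like `signals: jnp.DeviceArray = None`
    # performs when the class body is executed at import time.
    fake_jnp.DeviceArray
except AttributeError as exc:
    error_message = str(exc)

print(error_message)  # module 'fake_jnp' has no attribute 'DeviceArray'
```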
@joeybarreto let me check it out. There may have been recent changes in JAX. I will fix it and get back to you.
I can see that JAX now has a unified jax.Array type, which subsumes DeviceArray, starting with version 0.4.1. Details are at https://jax.readthedocs.io/en/latest/jax_array_migration.html
The code will need to be updated accordingly.
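Updating the code accordingly could look like this minimal sketch for the dataclass in the traceback, where the jnp.DeviceArray annotations become jax.Array. The field names come from the traceback. Two pieces are assumptions added so the sketch runs anywhere: the Optional wrapper, and the fallback to object when JAX cannot be imported. This is not necessarily the exact change made in CR-Sparse 0.4.0.

```python
from dataclasses import dataclass
from typing import Optional

try:
    # JAX >= 0.4.1 exposes the unified array type as jax.Array.
    from jax import Array
except ImportError:
    # Assumption for illustration only: lets the sketch run without JAX.
    Array = object


@dataclass
class SingleRecoverySolution:
    # Field names from the traceback; jnp.DeviceArray replaced by jax.Array.
    signals: Optional[Array] = None
    representations: Optional[Array] = None
    residuals: Optional[Array] = None
```

With JAX 0.4.1 or later installed, an array such as jax.numpy.zeros(4) is an instance of jax.Array, so it can be stored in these fields and checked with isinstance.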
@joeybarreto I have updated the code (and the dependencies cr-nimble and cr-wavelets) to work with JAX 0.4.x. If you work with the latest code from the main branch (clone the repository, then run pip install -e . inside it), it should work well. I am still running more tests before I release CR-Sparse 0.4.0. Let me know if it works for you; I don't have a MacBook with an M2 to verify.
Thanks for the quick reply @shailesh1729. Indeed, I can now import cr, cr.wavelets, and cr.nimble without any errors using the latest code in main.
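An import check like the one above can be scripted as a small smoke test. This is a minimal sketch that uses only the module names mentioned in this thread. It catches AttributeError as well as ImportError, because the original failure surfaced as an AttributeError raised during import rather than as a missing module.

```python
import importlib

# Module names taken from this thread.
MODULES = ["cr.sparse", "cr.sparse.dict", "cr.sparse.pursuit", "cr.nimble", "cr.wavelets"]


def check_imports(modules=MODULES):
    """Map each module name to 'ok' or a short 'ExceptionType: message' string."""
    results = {}
    for name in modules:
        try:
            importlib.import_module(name)
            results[name] = "ok"
        except (ImportError, AttributeError) as exc:
            # ModuleNotFoundError is a subclass of ImportError.
            results[name] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    for module, status in check_imports().items():
        print(f"{module}: {status}")
```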
I have now tested everything on Ubuntu 22.04.3 LTS with JAX 0.4.18, CUDA 12.2 and a Quadro RTX 5000. All tests are passing. I have released CR-Sparse 0.4.0, so you should be good now. Thanks.