carnotresearch/cr-sparse

Attribute error when importing cr.sparse


I installed cr-sparse following the installation instructions and tried one of the tutorials, but ran into an AttributeError when importing cr.sparse. See below for the stack trace. I suspect a recent JAX update that isn't reflected in cr-sparse yet.

Package versions:
Jax 0.4.18
cr-sparse 0.3.2

System:
miniconda + VS Code
macOS Ventura 13.4, Apple M2

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/Users/joey/USC/lidar_group/shadowOQS/dynamical_shadow_tomography/cs_testing.ipynb Cell 1 line 9

import matplotlib.pyplot as plt
# import cr.sparse as crs
--> import cr.sparse.dict as crdict
import cr.sparse.pursuit as pursuit

File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/__init__.py:37
     31 # Evaluation Tools
     33 from cr.sparse._src.tools.performance import (
     34     RecoveryPerformance
     35 )
---> 37 from cr.sparse._src.tools.trials_at_fixed_m_n import (
     38     RecoveryTrialsAtFixed_M_N
     39 )

File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/_src/tools/trials_at_fixed_m_n.py:25
     22 import jax.numpy as jnp
     24 import cr.sparse as crs
---> 25 from cr.sparse import pursuit
     26 import cr.sparse.data as crdata
     27 import cr.sparse.dict as crdict

File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/pursuit/__init__.py:26
     18 # pylint: disable=W0611
     20 from cr.sparse._src.pursuit.util import (
     21     abs_max_idx,
     22     gram_chol_update,
     23     largest_indices
     24 )
---> 26 from cr.sparse._src.pursuit.defs import (
     27     RecoverySolution
     28 )

File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/_src/pursuit/defs.py:24
     20 from cr.nimble.dsp import build_signal_from_indices_and_values
     21 norm = jnp.linalg.norm
     23 @dataclass
---> 24 class SingleRecoverySolution:
     25     signals: jnp.DeviceArray = None
     26     representations : jnp.DeviceArray = None

File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/cr/sparse/_src/pursuit/defs.py:25, in SingleRecoverySolution()
     23 @dataclass
     24 class SingleRecoverySolution:
---> 25     signals: jnp.DeviceArray = None
     26     representations : jnp.DeviceArray = None
     27     residuals : jnp.DeviceArray =  None

File ~/miniconda3/envs/sparse/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     51   warnings.warn(message, DeprecationWarning, stacklevel=2)
     52   return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.numpy' has no attribute 'DeviceArray'

@joeybarreto let me check it out. There may have been recent changes in JAX. I will fix it and get back to you.

I can see that JAX now has a unified jax.Array type, which subsumes DeviceArray, starting with version 0.4.1. Details at https://jax.readthedocs.io/en/latest/jax_array_migration.html
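A quick way to check what the installed JAX exposes is sketched below. It is a minimal diagnostic, not part of cr-sparse, and the helper name array_type_report is hypothetical. Because JAX's deprecation shim raises AttributeError for removed names, hasattr() returns False for jnp.DeviceArray once it is gone.

```python
import jax
import jax.numpy as jnp


def array_type_report() -> dict:
    """Hypothetical helper: summarize which array type names this JAX exposes."""
    has_array = hasattr(jax, "Array")
    return {
        "jax_version": jax.__version__,
        # Unified array type, available from JAX 0.4.1 onward.
        "has_jax_Array": has_array,
        # hasattr() swallows the AttributeError raised for removed names,
        # so this is False once jnp.DeviceArray has been removed.
        "has_jnp_DeviceArray": hasattr(jnp, "DeviceArray"),
        # Arrays created through jax.numpy should be jax.Array instances.
        "zeros_is_jax_Array": has_array and isinstance(jnp.zeros(2), jax.Array),
    }


print(array_type_report())
```

On the reporter's setup (JAX 0.4.18), the traceback implies that has_jnp_DeviceArray would be False.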

Will need to update the code accordingly.
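The kind of change needed can be sketched as follows: annotate the dataclass fields with jax.Array instead of jnp.DeviceArray. This is an illustration under assumptions, not the actual cr-sparse commit; it reproduces only the three fields visible in the traceback (signals, representations, residuals). The error surfaced at class definition because @dataclass evaluates field annotations when the class is created.

```python
from dataclasses import dataclass
from typing import Optional

import jax
import jax.numpy as jnp


@dataclass
class SingleRecoverySolution:
    """Sketch: only the fields shown in the traceback, annotated with jax.Array."""
    signals: Optional[jax.Array] = None
    representations: Optional[jax.Array] = None
    residuals: Optional[jax.Array] = None


# jax.numpy functions return jax.Array instances in JAX 0.4.1 and later.
sol = SingleRecoverySolution(signals=jnp.zeros(4))
print(type(sol.signals), sol.representations)
```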

@joeybarreto I have updated the code (and the dependencies cr-nimble and cr-wavelets) to work with JAX 0.4.x. If you use the latest code from the main branch [clone it, then pip install -e .], it should work. I am running more tests before releasing CR-Sparse 0.4.0. Let me know if it works for you; I don't have a MacBook with an M2 to verify on.

Thanks for the quick reply @shailesh1729. Indeed, I can now import cr, cr.wavelets, and cr.nimble without any errors using the latest code in main.

I have now tested everything on Ubuntu 22.04.3 LTS with JAX 0.4.18 and CUDA 12.2 on a Quadro RTX 5000. All tests pass. I have released CR-Sparse 0.4.0, so you should be good now. Thanks.