simple_cnn baseline doesn't work with jax nightly
sycamoreoak opened this issue
```
python3 ./simple_cnn.py
Traceback (most recent call last):
  File "~/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 21, in <module>
    from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax' (~/.local/lib/python3.10/site-packages/jax/__init__.py)
```
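A guarded import is one way to keep this line working on JAX releases that expose `linear_util` in different places. This is a minimal sketch, not Flax's actual fix: the helper `import_first` is hypothetical, and the usage comment assumes `jax.extend.linear_util` is importable on the newer releases while `jax.linear_util` only exists on older ones.

```python
import importlib
from types import ModuleType


def import_first(*module_names: str) -> ModuleType:
    """Import and return the first module in module_names that exists.

    Hypothetical helper: tries each dotted module path in order and
    collects the errors so the final ImportError explains every miss.
    """
    errors = []
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ImportError as exc:  # ModuleNotFoundError is a subclass
            errors.append(f"{name}: {exc}")
    raise ImportError("no candidate module could be imported:\n" + "\n".join(errors))


# Assumed usage in flax/core/axes_scan.py (newer location tried first):
# lu = import_first("jax.extend.linear_util", "jax.linear_util")
```

Trying the newer location first means the fallback is only exercised on older JAX versions, so the common path on nightly builds does a single import.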
If I change the linear_util import line to:

```python
from jax.extend import linear_util as lu
```

then I get:
```
  File "~/.local/lib/python3.10/site-packages/flax/struct.py", line 141, in dataclass
    if tuple(map(int, jax.version.__version__.split('.'))) >= (0, 3, 1):
ValueError: invalid literal for int() with base 10: 'dev20240528'
```
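The check in `flax/struct.py` calls `int()` on every dot-separated piece of the version string, so a nightly suffix such as `dev20240528` breaks it. The sketch below reproduces the failure and shows one tolerant alternative that parses only the leading numeric release components. The full nightly string `0.4.29.dev20240528` is an assumption (the traceback only shows the `dev20240528` piece), and `release_tuple` is a hypothetical helper, not part of Flax or JAX.

```python
import re
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Return the leading numeric release components of a version string.

    Stops at the first component that does not start with a digit (such as
    'dev20240528'), and keeps only the leading digits of a component like
    '30rc1', so nightly and pre-release strings no longer raise ValueError.
    """
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # e.g. 'dev20240528': end of the release segment
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # e.g. '30rc1': keep 30, drop the suffix
    return tuple(parts)


# Hypothetical nightly string; the traceback only shows the 'dev20240528' part.
nightly = "0.4.29.dev20240528"

try:
    tuple(map(int, nightly.split(".")))  # the parse used in flax/struct.py
except ValueError as exc:
    print(exc)  # invalid literal for int() with base 10: 'dev20240528'

print(release_tuple(nightly))               # (0, 4, 29)
print(release_tuple(nightly) >= (0, 3, 1))  # True
```

This treats a `.dev` build as equal to its final release, which is fine for a `>= (0, 3, 1)` gate but not for exact ordering; `packaging.version.Version` implements full PEP 440 ordering if that matters.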