google-research/scenic

simple_cnn baseline doesn't work with jax nightly

sycamoreoak opened this issue · 0 comments

python3 ./simple_cnn.py 
Traceback (most recent call last):
  File "~/.local/lib/python3.10/site-packages/flax/core/axes_scan.py", line 21, in <module>
    from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax' (~/.local/lib/python3.10/site-packages/jax/__init__.py)

If I change the linear_util import in flax/core/axes_scan.py to this:

from jax.extend import linear_util as lu

then I get:

  File "~/.local/lib/python3.10/site-packages/flax/struct.py", line 141, in dataclass
    if tuple(map(int, jax.version.__version__.split('.'))) >= (0, 3, 1):
ValueError: invalid literal for int() with base 10: 'dev20240528'
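
The second failure appears to come from calling int() on every dot-separated piece of the version string, which breaks on the nightly suffix. Below is a minimal sketch of a more tolerant comparison. It is not the actual flax code or fix; `parse_release` is a hypothetical helper name, and the full version string `0.4.29.dev20240528` is an assumed example (only the `dev20240528` piece appears in the traceback above).

```python
import re


def parse_release(version: str) -> tuple[int, ...]:
    """Return the leading numeric components of a version string.

    Hypothetical helper: parsing stops at the first dot-separated piece
    that does not start with a digit (e.g. 'dev20240528').
    """
    parts = []
    for piece in version.split('.'):
        match = re.match(r'\d+', piece)
        if match is None:
            # Non-numeric suffix such as 'dev20240528': stop parsing here.
            break
        parts.append(int(match.group()))
    return tuple(parts)


# Assumed example of a nightly version string, not taken from the traceback.
nightly = '0.4.29.dev20240528'
is_new_enough = parse_release(nightly) >= (0, 3, 1)
```

With this kind of parsing the `>= (0, 3, 1)` check would compare only the numeric release components, so a `devYYYYMMDD` suffix would no longer raise ValueError.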