google-deepmind/chex

ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)

maming109 opened this issue · 2 comments

/tmp/ipykernel_34/2874194604.py:15: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display
from IPython.core.display import display, HTML

ImportError Traceback (most recent call last)
Cell In[27], line 20
18 # Import model definition from big_vision
19 from big_vision.models.proj.paligemma import paligemma
---> 20 from big_vision.trainers.proj.paligemma import predict_fns
22 # Import big vision utilities
23 import big_vision.datasets.jsonl

File /kaggle/working/big_vision_repo/big_vision/trainers/proj/paligemma/predict_fns.py:20
17 import functools
19 from big_vision.pp import registry
---> 20 import big_vision.utils as u
21 import einops
22 import jax

File /kaggle/working/big_vision_repo/big_vision/utils.py:38
36 import flax.jax_utils as flax_utils
37 import jax
---> 38 from jax.experimental.array_serialization import serialization as array_serial
39 import jax.numpy as jnp
40 import ml_collections as mlc

File /opt/conda/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py:36
34 from jax._src import sharding
35 from jax._src import sharding_impls
---> 36 from jax._src.layout import Layout, DeviceLocalLayout as DLL
37 from jax._src import typing
38 from jax._src import util

ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)

I have the same error. What can we do?

Hi there. I don't really understand this issue: it doesn't look like you're importing chex here. Could you provide a step-by-step reproduction?
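
For a reproduction, the installed package versions would be a useful starting point, since the failing import in the traceback is inside JAX itself (`serialization.py` importing from `jax._src.layout`). Below is a minimal sketch for collecting them, assuming Python 3.8+ so that the standard-library `importlib.metadata` is available. The helper name `installed_versions` and the default package list are illustrative only and are not part of chex or JAX.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "chex", "flax")):
    """Return a dict mapping each package name to its installed version.

    Packages that are not installed map to None. The default package
    list is an assumption based on the modules seen in the traceback.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print one "package: version" line per package for pasting into the issue.
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in the same notebook kernel that raises the `ImportError`, and pasting the output alongside the exact install commands used, would show which versions are actually being imported.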