AttributeError: module 'jax' has no attribute 'linear_uti
ohjeyy93 opened this issue · 4 comments
Hello Colabfold team,
I started getting this error,
"AttributeError: module 'jax' has no attribute 'linear_util'"
I saw a previous solution that involved installing different jax and jaxlib versions. I tried older versions (0.4.14 to 0.4.18, 0.4.23, 0.4.25, 0.4.26, 0.4.33, 0.4.35, etc.), but changing to those versions didn't fix the problem.
By the way, colabfold seems to work fine for a while, but then it sometimes starts throwing this error on its own. The colabfold version also sometimes changes by itself. Is colabfold updating itself automatically?
jax.linear_util has been moved to jax.extend.linear_util since jax 0.4.24. Please check this thread to fix your issue.
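For code that has to run against both older and newer jax releases, one option is to try the new module path first and fall back to the old one. This is only a sketch: the helper name first_importable is made up for illustration, and it only catches ImportError, so a mismatched jax/jaxlib pair that fails with a different exception will still surface.

```python
import importlib


def first_importable(candidates):
    """Return the first module in `candidates` that imports, or None."""
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError:
            # Covers ModuleNotFoundError too (package or submodule missing).
            continue
    return None


# New location first, legacy location second.
LINEAR_UTIL_CANDIDATES = ("jax.extend.linear_util", "jax.linear_util")

lu = first_importable(LINEAR_UTIL_CANDIDATES)
print("linear_util resolved to:", lu.__name__ if lu else None)
```

If lu comes back as None, neither path is importable in the interpreter that ran the snippet, which usually means jax itself is missing from that environment.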
I see. The problem was that running
colabfold/localcolabfold/colabfold-conda/bin/pip install --upgrade "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
didn't actually update the colabfold conda environment.
I had to manually activate the colabfold conda environment and install jax/jaxlib 0.4.23 there. Thank you.
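A quick way to confirm which environment a pip install actually changed is to ask the interpreter itself. The sketch below uses only the standard library; the helper name installed_versions is made up for illustration. It should be run with the Python binary of the environment in question (presumably the python next to the pip shown above, which is an assumption about the install layout).

```python
import sys
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in packages:
        try:
            result[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            result[name] = None
    return result


# Shows which interpreter and environment prefix are really in use.
print("interpreter:", sys.executable)
print("prefix:     ", sys.prefix)
for package, version in installed_versions(["jax", "jaxlib", "chex"]).items():
    print(f"{package}: {version or 'not installed'}")
```

If the printed prefix is not the colabfold conda directory, or the versions differ from what was just installed, the install went into a different environment.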
Now localcolabfold uses jax[cuda12]==0.4.28 in the current installer script. Did you install your localcolabfold with it or update with the latest update script?
Using the latest installation script, I encountered the same problem. My fix was to uninstall jax, jaxlib, and chex, then install jax/jaxlib 0.4.3 and chex 0.1.7. It is currently running normally.
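Taking the 0.4.24 cutoff mentioned earlier in this thread at face value, a small check can tell whether a given jax version string is expected to still expose jax.linear_util. This is a rough sketch: the names parse_version and has_legacy_linear_util are made up for illustration, and only the leading three numeric components of the version are compared.

```python
import re

# Cutoff taken from the comment above: linear_util moved as of jax 0.4.24.
LINEAR_UTIL_MOVED_IN = (0, 4, 24)


def parse_version(text):
    """Parse the leading 'X.Y.Z' of a version string into a tuple of ints."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.groups())


def has_legacy_linear_util(jax_version):
    """True if this jax version predates the move to jax.extend.linear_util."""
    return parse_version(jax_version) < LINEAR_UTIL_MOVED_IN


for version in ("0.4.3", "0.4.23", "0.4.28"):
    print(version, has_legacy_linear_util(version))
```

Comparing tuples of integers avoids the trap of comparing version strings lexically, where "0.4.3" would sort after "0.4.24".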