ModuleNotFoundError: No module named 'jax.extend' (related to #209, #210)
ZubuNoShoshinsha opened this issue · 9 comments
Hello,
My question is related to #209 and #210.
My environment:
WSL2
OS: Ubuntu 22.04.4
GCC: 11.4.0
CUDA: 12.1
GPU: RTX 4090
LocalColabFold Ver. 1.5.5
As instructed in #209, I checked whether the GPU was recognized, and it was not.
So, again following #209, I downgraded jax and jaxlib to:
jax 0.4.7
jaxlib 0.4.7+cuda11.cudnn82
And then I checked again using
$ /path/to/your/localcolabfold/colabfold-conda/bin/python3.10
import jax
print(jax.local_devices()[0].platform)
and "gpu" was returned.
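Because this check has to run inside the colabfold-conda environment rather than the system Python, it can be handy to script it. A minimal sketch, assuming only that the interpreter path you pass points at the LocalColabFold environment (the path in the example is the same placeholder as above); `jax_platform` is a hypothetical helper, not part of LocalColabFold:

```python
import subprocess
from typing import Optional

# Same one-liner as the manual check above, run inside the target interpreter.
CHECK = "import jax; print(jax.local_devices()[0].platform)"


def jax_platform(python_exe: str, timeout: float = 120.0) -> Optional[str]:
    """Return the platform of the first local JAX device ("gpu", "cpu", ...)
    as reported by `python_exe`, or None if the check could not be run."""
    try:
        result = subprocess.run(
            [python_exe, "-c", CHECK],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except (OSError, subprocess.TimeoutExpired):
        # Interpreter path missing or not executable, or the import hung.
        return None
    if result.returncode != 0:
        # e.g. jax is not installed in that environment; see result.stderr.
        return None
    lines = result.stdout.strip().splitlines()
    # Warnings usually go to stderr, so the last stdout line should be the platform.
    return lines[-1] if lines else None


if __name__ == "__main__":
    # Placeholder path, as in the manual check above; adjust to your install.
    print(jax_platform("/path/to/your/localcolabfold/colabfold-conda/bin/python3.10"))
```

"gpu" is the answer you want; "cpu" or None means JAX in that environment cannot use the GPU.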
Then I ran LocalColabFold, but it stopped with the following error:
2024-04-01 15:14:35,452 Running colabfold 1.5.5 (61df3b853140ca79dbdf64349824beb14364ebfd)
2024-04-01 15:14:36,006 Running on GPU
Traceback (most recent call last):
  File "/mnt/d/Alphafold/localcolabfold/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 2037, in main
    run(
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/batch.py", line 1292, in run
    from colabfold.alphafold.models import load_models_and_params
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/colabfold/alphafold/models.py", line 4, in <module>
    import haiku
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/__init__.py", line 20, in <module>
    from haiku import experimental
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/experimental/__init__.py", line 34, in <module>
    from haiku._src.dot import abstract_to_dot
  File "/mnt/d/AlphaFold/localcolabfold/colabfold-conda/lib/python3.10/site-packages/haiku/_src/dot.py", line 29, in <module>
    from jax.extend import linear_util as lu
ModuleNotFoundError: No module named 'jax.extend'
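The traceback shows the failure chain: dm-haiku's haiku/_src/dot.py runs `from jax.extend import linear_util as lu`, and the installed jax (0.4.7 here) has no `jax.extend` module. A quick way to confirm this from the colabfold-conda interpreter is a small probe like the one below (a sketch; `has_module` is a hypothetical helper name, built on the standard `importlib.util.find_spec`):

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if the module `name` (e.g. "jax.extend") can be found.

    For a dotted name, find_spec imports the parent package ("jax") first,
    but does not import the submodule itself.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError):
        # ImportError (incl. ModuleNotFoundError): the parent package is missing.
        # ValueError: the module is already imported but has no __spec__.
        return False
```

Run with the colabfold-conda python, `has_module("jax.extend")` would be expected to return False in the jax 0.4.7 environment described above, matching the error.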
Any instructions for solving this issue would be helpful.
I suspect the issue is that your dm-haiku version is 0.0.11 or later. In my environment:
$ localcolabfold/colabfold-conda/bin/python3.10 -m pip list
jax 0.4.23
jaxlib 0.4.23+cuda11.cudnn86
chex 0.1.85
dm-haiku 0.0.10
If CUDA 12.1 is installed, these versions should be fine.
Please downgrade dm-haiku to version 0.0.10; otherwise, you may encounter the error "ModuleNotFoundError: No module named 'jax.extend'".
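To check an environment against this pin, a small version comparison is enough. This is a sketch under two assumptions taken from this thread rather than from dm-haiku's changelog: 0.0.10 is the last release that does not need `jax.extend`, and only the numeric release part of a version matters. `release_tuple`, `installed_version`, and `haiku_is_pinned` are hypothetical helper names:

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Assumption from this thread: dm-haiku 0.0.11+ may import jax.extend,
# which older jax releases such as 0.4.7 do not provide.
HAIKU_SAFE_MAX = (0, 0, 10)


def release_tuple(version: str) -> Tuple[int, ...]:
    """Turn "0.4.23+cuda11.cudnn86" into (0, 4, 23): drop the local label
    after "+", then keep the leading integer of each dot-separated part."""
    public = version.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at non-numeric parts such as "dev0"
        parts.append(int(match.group()))
    return tuple(parts)


def installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution (e.g. "dm-haiku"), or None if absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def haiku_is_pinned(haiku_version: Optional[str]) -> bool:
    """True if dm-haiku is installed at 0.0.10 or lower."""
    return haiku_version is not None and release_tuple(haiku_version) <= HAIKU_SAFE_MAX
```

For example, `haiku_is_pinned(installed_version("dm-haiku"))` run in the colabfold-conda environment would be False for the 0.0.12 reported below and True after the downgrade.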
Thank you for your suggestion (I only just noticed your response).
My dm-haiku was actually 0.0.12, so I downgraded it to 0.0.10 as you suggested
and ran LocalColabFold 1.5.5 again.
My environment is now:
jax 0.4.7
jaxlib 0.4.7+cuda11.cudnn82
chex 0.1.82
dm-haiku 0.0.10
The "ModuleNotFoundError: No module named 'jax.extend'" error is gone, but now a new message appears and the program stops:
"Could not predict ProteinA. Not Enough GPU memory? FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details."
Are there any other suggestions for solving this?
If you are using WSL2, did you turn on the settings shown in https://github.com/YoshitakaMo/localcolabfold?tab=readme-ov-file#for-wsl2-in-windows?
Unfortunately, I can't figure out the cause because I don't have a WSL2 environment.
Yes, I did.
I restarted WSL2 and tried again, but it still didn't work.
Thank you, though.
I also wonder about something.
When I downgraded dm-haiku, pip said:
"colabfold 1.5.5 requires dm-haiku<0.0.13,>=0.0.12, but you have dm-haiku 0.0.10 which is incompatible."
Will ColabFold still run correctly with this version?
Finally, I think I have found the solution.
I downgraded nvidia-cudnn-cu11 from 9.0.0.312 to 8.5.0.96 with this command:
pip install --upgrade nvidia-cudnn-cu11==8.5.0.96
I ran LocalColabFold again, and it ran very smoothly on the GPU.
I was astonished.
Thank you.
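This fix fits a cuDNN version mismatch: jaxlib 0.4.7+cuda11.cudnn82 was built against cuDNN 8.2, while nvidia-cudnn-cu11 9.0.0.312 supplies cuDNN 9, which would explain "DNN library initialization failed". As far as we know, XLA has historically required the runtime cuDNN to have the same major version as the build and an equal or higher minor version. A minimal sketch of that check, assuming the "cudnnXY" label means major X (one digit) and minor Y; the helper names are hypothetical:

```python
import re
from typing import Optional, Tuple


def jaxlib_cudnn_build(jaxlib_version: str) -> Optional[Tuple[int, int]]:
    """cuDNN version a CUDA jaxlib wheel was built against, read from its
    local label, e.g. "0.4.7+cuda11.cudnn82" -> (8, 2)."""
    match = re.search(r"cudnn(\d)(\d+)", jaxlib_version)
    if match is None:
        return None  # CPU-only wheel or an unrecognized label
    return int(match.group(1)), int(match.group(2))


def cudnn_runtime_ok(build: Tuple[int, int], runtime_version: str) -> bool:
    """Same major version and equal or higher minor version than the build.

    Assumes `runtime_version` has at least "major.minor", e.g. "8.5.0.96".
    """
    major, minor = (int(x) for x in runtime_version.split(".")[:2])
    return major == build[0] and minor >= build[1]
```

With the versions from this comment, `cudnn_runtime_ok((8, 2), "9.0.0.312")` is False and `cudnn_runtime_ok((8, 2), "8.5.0.96")` is True. In a live environment the inputs would come from `importlib.metadata.version("jaxlib")` and `importlib.metadata.version("nvidia-cudnn-cu11")`.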
jax 0.4.23
jaxlib 0.4.23+cuda11.cudnn86
chex 0.1.85
dm-haiku 0.0.10
Requirement already satisfied: torch==1.13.1 in /usr/local/lib/python3.10/dist-packages (1.13.1)
Requirement already satisfied: transformers==4.24.0 in /usr/local/lib/python3.10/dist-packages (4.24.0)
Collecting diffusers==0.3.0
Using cached diffusers-0.3.0-py3-none-any.whl (153 kB)
Collecting jax==0.4.23
Downloading jax-0.4.23-py3-none-any.whl (1.7 MB)
โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ 1.7/1.7 MB 6.3 MB/s eta 0:00:00
ERROR: Could not find a version that satisfies the requirement jaxlib==0.4.23+cuda11.cudnn86 (from versions: 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23, 0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28)
ERROR: No matching distribution found for jaxlib==0.4.23+cuda11.cudnn86
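This error follows from pip's version matching: the requested version carries a PEP 440 local label (+cuda11.cudnn86), and a "==" specifier with a local label only matches candidates with exactly that label. Every version pip found on its index (the list in the error) is a plain release without a label, so nothing matches. The sketch below is a simplified model of that rule for illustration only (no version normalization or zero padding); `split_local` and `exact_matches` are hypothetical, not pip's code:

```python
from typing import Iterable, List, Tuple


def split_local(version: str) -> Tuple[str, str]:
    """Split "0.4.23+cuda11.cudnn86" into ("0.4.23", "cuda11.cudnn86")."""
    public, _, local = version.partition("+")
    return public, local


def exact_matches(requested: str, candidates: Iterable[str]) -> List[str]:
    """Simplified "==" matching: public parts compared as strings; if the
    request has no local label, candidates' labels are ignored, otherwise
    the label must match exactly."""
    req_public, req_local = split_local(requested)
    out = []
    for cand in candidates:
        public, local = split_local(cand)
        if public != req_public:
            continue
        if req_local and local != req_local:
            continue
        out.append(cand)
    return out
```

Fed the versions listed in the error, `exact_matches("0.4.23+cuda11.cudnn86", ...)` returns an empty list, while `exact_matches("0.4.23", ...)` returns ["0.4.23"].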
I updated the installer and updater script for Linux two days ago, because JAX 0.4.23 no longer seems suitable for CUDA 12 and cuDNN 9. Please update CUDA to 12.4 and cuDNN to 9, and use the latest updater script.