GPU requirements file results in an error
PootieT commented
If I create a new conda environment (miniconda in my case) and run
micromamba create -f gpu_requirement.yml --prefix ./envs
I was able to install all packages, but when running the training script with python3 -m coh.coh_train_llama, I ran into this error:
cannot import name 'PartitionSpec' from 'jax.sharding'
I was able to get past it by changing the jax/jaxlib versions in gpu_requirements.yml to 0.4.1. Posting this in case anyone else is stuck:
dependencies:
- python=3.8
- pip
- numpy
- scipy
- numba
- h5py
- matplotlib
- scikit-learn
- jupyter
- tqdm
- pytorch-cpu=1.13.0
- jax=0.4.1 # <- change this and line below from 0.3.25 to 0.4.1
- jaxlib=0.4.1=*cuda*
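For anyone hitting the same error, a quick pre-flight check can confirm that the installed jax and jaxlib are at least 0.4.1 before launching the training script. This is a minimal sketch and not part of the repo: the 0.4.1 threshold is the version reported to work above, not an officially documented minimum, and the helper names (parse_version, is_new_enough, check_installed) are hypothetical.

```python
# Hypothetical pre-flight check (not part of the repo): verify that jax/jaxlib are
# at least the version reported to work in this thread before running training.
from importlib import metadata

# Assumption: 0.4.1 is the version that fixed the import error in this thread,
# not an officially documented minimum.
MIN_VERSION = (0, 4, 1)


def parse_version(text: str) -> tuple:
    """Convert '0.3.25' or '0.4.1.dev0' into a tuple of leading ints, e.g. (0, 3, 25)."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component such as 'dev0'
        parts.append(int(digits))
    return tuple(parts)


def is_new_enough(installed: str, minimum: tuple = MIN_VERSION) -> bool:
    """Tuple comparison orders versions numerically: (0, 3, 25) < (0, 4, 1)."""
    return parse_version(installed) >= minimum


def check_installed(package: str) -> str:
    """Return a one-line status message for an installed distribution."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return f"{package} is not installed"
    if is_new_enough(installed):
        return f"{package} {installed} OK"
    wanted = ".".join(str(n) for n in MIN_VERSION)
    return f"{package} {installed} is older than {wanted}; bump it in the environment file"


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib"):
        print(check_installed(pkg))
```

Run as a script, it prints one status line each for jax and jaxlib; an environment still on 0.3.25 is reported as older than 0.4.1.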
gzliyu commented
Hi! I also got this error. Did you fix this?
Thanks
PootieT commented
Unfortunately, I have resorted to rewriting my own code in PyTorch.