forhaoliu/chain-of-hindsight

GPU requirements result in an import error


I created a new conda-style environment (micromamba in my case) by running:

micromamba create -f gpu_requirement.yml --prefix ./envs

All packages installed fine, but when I run the training script (python3 -m coh.coh_train_llama), I get this error:

cannot import name 'PartitionSpec' from 'jax.sharding'

I was able to get past it by changing the jax and jaxlib versions in gpu_requirements.yml to 0.4.1. In case anyone else is stuck, here is the relevant section:

dependencies:
    - python=3.8
    - pip
    - numpy
    - scipy
    - numba
    - h5py
    - matplotlib
    - scikit-learn
    - jupyter
    - tqdm
    - pytorch-cpu=1.13.0
    - jax=0.4.1                           # <- change this and the line below from 0.3.25 to 0.4.1
    - jaxlib=0.4.1=*cuda*
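If you want to confirm that the pinned versions actually ended up in the environment before launching training, a small check like the one below may help. It is only a sketch and not part of the repository: the helper names (parse_version, meets_minimum, can_import_partitionspec) are hypothetical, and the 0.4.1 threshold is simply the version reported to work above, not an official minimum.

```python
# Sketch: check whether the active environment has jax/jaxlib new enough to provide
# jax.sharding.PartitionSpec. Helper names are hypothetical, not from chain-of-hindsight.
from importlib import metadata

# Assumption: 0.4.1 is the version reported to work in this thread, not an official bound.
MIN_VERSION = (0, 4, 1)


def parse_version(version: str) -> tuple:
    """Return the leading numeric components, e.g. '0.4.1+cuda11.cudnn86' -> (0, 4, 1)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component (e.g. 'dev', 'cudnn86')
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(version: str, minimum: tuple = MIN_VERSION) -> bool:
    """Compare versions as integer tuples so that 0.4.10 counts as newer than 0.4.9."""
    return parse_version(version) >= minimum


def can_import_partitionspec() -> bool:
    """True if 'from jax.sharding import PartitionSpec' (the failing import) succeeds."""
    try:
        from jax.sharding import PartitionSpec  # noqa: F401
    except ImportError:  # also covers ModuleNotFoundError when jax is not installed
        return False
    return True


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib"):
        try:
            ver = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            print(f"{pkg}: not installed")
            continue
        status = "ok" if meets_minimum(ver) else "older than 0.4.1"
        print(f"{pkg} {ver}: {status}")
    print("jax.sharding.PartitionSpec importable:", can_import_partitionspec())
```

Run it with python3 inside the activated environment. If PartitionSpec is still not importable after editing the yml, the environment is probably still using the older jax, so recreating it from scratch may be worth trying.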

Hi! I'm getting this error too. Did you manage to fix it?
Thanks

Unfortunately, I ended up rewriting my own code in PyTorch.