This is the official repository for LLM-RL.
# Fetch the repository into ./LLM_RL and work from the checkout root.
git clone https://github.com/Sea-Snell/LLM_RL.git LLM_RL
cd LLM_RL
Install with conda (CPU, GPU, or TPU). Follow only the one section below that matches your hardware.
Install with conda (CPU):
# Build the conda environment described by environment.yml in the current
# directory, then switch into it (env name assumed to be LLM_RL, as set in
# environment.yml — confirm if you rename it).
conda env create --file environment.yml
conda activate LLM_RL
# Refresh pip before installing this package in editable (development) mode.
python -m pip install --upgrade pip
python -m pip install --editable .
Install with conda (GPU):
# Build and activate the conda environment described by environment.yml
# (env name assumed to be LLM_RL, as set in environment.yml).
conda env create -f environment.yml
conda activate LLM_RL
python -m pip install --upgrade pip
# Quote the package spec: unquoted, the shell would try to glob-expand the
# `*` wildcards (and zsh aborts with "no matches found"). conda must receive
# them literally to select a CUDA-enabled jaxlib build.
conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
python -m pip install -e .
Install with conda (TPU):
# Build and activate the conda environment described by environment.yml
# (env name assumed to be LLM_RL, as set in environment.yml).
conda env create -f environment.yml
conda activate LLM_RL
python -m pip install --upgrade pip
# Quote "jax[tpu]": square brackets are a glob character class, so unquoted
# the shell may substitute a matching filename (bash) or abort with
# "no matches found" (zsh) instead of passing the pip extra through.
python -m pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python -m pip install -e .
# Clone JaxSeq2 outside the LLM_RL checkout (here: the home directory) so the
# two repositories are not nested. Keep the conda environment active so the
# editable install lands in it.
cd "$HOME"
git clone https://github.com/Sea-Snell/JaxSeq2.git JaxSeq2
cd JaxSeq2
python -m pip install --editable .