Primary language: Python. License: MIT.

LLM_RL

This is the official repository for LLM-RL.

installation

1. pull from GitHub

git clone https://github.com/Sea-Snell/LLM_RL.git
cd LLM_RL

2. install dependencies

Install with conda, using the variant below that matches your hardware (cpu, gpu, or tpu).

install with conda (cpu):

conda env create -f environment.yml
conda activate LLM_RL
python -m pip install --upgrade pip
python -m pip install -e .

install with conda (gpu):

conda env create -f environment.yml
conda activate LLM_RL
python -m pip install --upgrade pip
conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
python -m pip install -e .

install with conda (tpu):

conda env create -f environment.yml
conda activate LLM_RL
python -m pip install --upgrade pip
python -m pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
python -m pip install -e .
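
After any of the three installs above, a quick check can confirm which backend JAX actually picked up. This is a minimal sketch and not part of the repository: `describe_jax_backend` is a hypothetical helper name, and it relies only on `jax.default_backend()` and `jax.device_count()`.

```python
import importlib


def describe_jax_backend():
    """Return (backend_name, device_count), or None if JAX cannot be imported."""
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return None
    # default_backend() reports the platform JAX selected, e.g. "cpu", "gpu" or "tpu".
    return jax.default_backend(), jax.device_count()


if __name__ == "__main__":
    info = describe_jax_backend()
    if info is None:
        print("JAX is not importable in this environment")
    else:
        backend, count = info
        print(f"backend={backend} devices={count}")
```

If a gpu or tpu install reports `cpu` here, the accelerator build of jaxlib was probably not picked up, and the matching install block is worth re-running inside the activated `LLM_RL` environment.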

3. install JaxSeq2

# navigate to a different directory
cd ~/
git clone https://github.com/Sea-Snell/JaxSeq2.git
cd JaxSeq2
python -m pip install -e .
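
Once both editable installs have run, it can help to confirm that Python sees the two packages. The sketch below is an illustration only: `missing_modules` is a hypothetical helper, and the import names `LLM_RL` and `JaxSeq` are assumptions that may need adjusting to whatever the packages actually expose.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed import names; change them if the packages use different ones.
    missing = missing_modules(["LLM_RL", "JaxSeq"])
    print("missing:", missing if missing else "none")
```

Run it from a directory other than the cloned repositories, so that a local folder with the same name is not mistaken for an installed package.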

4. install JaxSeq2