/expt-utils

conda install pytorch torchvision -c pytorch
conda install -c anaconda cudnn=8.2.1 cudatoolkit=11.3
pip install jax flax ml_collections optax tensorflow
pip install -U "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
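After running the commands above, a quick sanity check can confirm that every package is importable and that PyTorch and JAX both see the GPU. The script below is a minimal sketch, not part of this repository. It assumes it is run with the `python` of the same conda environment. The names `EXPECTED_MODULES`, `missing_modules`, and `report_accelerators` are hypothetical helpers introduced only for illustration.

```python
import importlib.util

# Hypothetical list: import names for the packages installed above.
# Here they happen to match the pip/conda package names.
EXPECTED_MODULES = [
    "torch",
    "torchvision",
    "jax",
    "flax",
    "ml_collections",
    "optax",
    "tensorflow",
]


def missing_modules(names):
    """Return the subset of `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def report_accelerators():
    """Print whether PyTorch sees CUDA and which devices JAX picked up."""
    if importlib.util.find_spec("torch") is not None:
        import torch
        print("torch", torch.__version__, "CUDA available:", torch.cuda.is_available())
    if importlib.util.find_spec("jax") is not None:
        import jax
        # With the CUDA-enabled jaxlib installed, the default backend should be "gpu".
        print("jax", jax.__version__, "backend:", jax.default_backend(),
              "devices:", jax.devices())


if __name__ == "__main__":
    missing = missing_modules(EXPECTED_MODULES)
    if missing:
        print("missing:", ", ".join(missing))
    else:
        print("all expected packages found")
    report_accelerators()
```

If JAX reports only a CPU device while PyTorch reports CUDA as available, the CUDA-enabled `jaxlib` from the last command is the likely thing to revisit, since the other packages do not depend on it.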