choderalab/openmmtools

Openmmtools without Jax?

sef43 opened this issue · 3 comments

sef43 commented

Hello, how can I install the latest openmmtools without the JAX dependency?

I just want to use the OpenMM test systems, but I run into an error about the jaxlib version. I think openmmtools is importing pymbar, which imports JAX?
I don't believe JAX functionality should be needed for the testsystems.
Combined with the PyTorch package needed for OpenMM-Torch, this becomes a heavy conda install. Here is a Colab notebook that demonstrates the problem when run on the T4 GPU runtime:
https://colab.research.google.com/drive/1O28OGU3HG2d04xKVLPwGYNl-mpEPs0XM?usp=sharing

Ideally I would not need to install JAX at all.

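```
# Colab cell: "!" runs the line as a shell command; the quotes keep the shell from globbing the spec
!mamba install -c conda-forge openmm-torch openmmtools "pytorch=*=cuda*"

from openmmtools.testsystems import AlanineDipeptideVacuum

# Get the system of alanine dipeptide
ala2 = AlanineDipeptideVacuum(constraints=None)
```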

Error:

```
Warning on use of the timeseries module: If the inherent timescales of the system are long compared to those being analyzed, this statistical inefficiency may be an underestimate.  The estimate presumes the use of many statistically independent samples.  Tests should be performed to assess whether this condition is satisfied.   Be cautious in the interpretation of the data.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-d5191668587d> in <cell line: 1>()
----> 1 from openmmtools.testsystems import AlanineDipeptideVacuum
      2 
      3 # Get the system of alanine dipeptide
      4 ala2 = AlanineDipeptideVacuum(constraints=None)

... 12 frames omitted ...
/usr/local/lib/python3.10/site-packages/jax/_src/lib/__init__.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version)
     81     msg = (f'jaxlib is version {jaxlib_version}, but this version '
     82            f'of jax requires version >= {minimum_jaxlib_version}.')
---> 83     raise RuntimeError(msg)
     84 
     85   if _jaxlib_version > _jax_version:

RuntimeError: jaxlib is version 0.1.75, but this version of jax requires version >= 0.3.7.
```
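The RuntimeError comes from a minimum-version check that jax runs against the installed jaxlib at import time. The sketch below imitates that kind of check to show why jaxlib 0.1.75 fails a 0.3.7 minimum. It assumes plain numeric X.Y.Z release strings, and `parse_version` and `version_error` are hypothetical helpers, not jax's actual internals.

```python
from __future__ import annotations

from typing import Optional


def parse_version(v: str) -> tuple[int, ...]:
    """Turn a plain dotted release string such as '0.3.7' into a comparable tuple."""
    return tuple(int(part) for part in v.split("."))


def version_error(installed: str, minimum: str, name: str = "jaxlib") -> Optional[str]:
    """Return an error message if `installed` is older than `minimum`, else None."""
    if parse_version(installed) < parse_version(minimum):
        return f"{name} is version {installed}, but version >= {minimum} is required."
    return None


# The versions reported in the traceback: the check fails.
msg = version_error("0.1.75", "0.3.7")
if msg is not None:
    print(msg)
```

Comparing integer tuples rather than raw strings keeps a release such as 0.10.0 ordered after 0.9.0. Pre-release tags like `rc1` are not handled by this sketch.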

ijpulidos commented

@sef43 openmmtools shouldn't be pulling in jax as a dependency; I think it is coming from pymbar. That's the only dependency I can think of that could be pulling in jax.
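One way to confirm which package drags in jax is to import the module in a fresh interpreter and then inspect `sys.modules`. The helper below is a hypothetical sketch (`transitive_imports` is not part of openmmtools or pymbar). Running the import in a separate process keeps modules already loaded in the notebook from skewing the result.

```python
from __future__ import annotations

import json
import subprocess
import sys


def transitive_imports(module: str, watch: list[str]) -> list[str]:
    """Import `module` in a fresh interpreter; return the `watch` packages it loaded."""
    code = (
        "import importlib, json, sys\n"
        f"importlib.import_module({module!r})\n"
        f"print(json.dumps([m for m in {watch!r} if m in sys.modules]))\n"
    )
    result = subprocess.run(
        [sys.executable, "-c", code], capture_output=True, text=True
    )
    if result.returncode != 0:
        # The import itself failed (e.g. the jaxlib RuntimeError); report the last stderr line.
        lines = result.stderr.strip().splitlines()
        raise ImportError(lines[-1] if lines else f"failed to import {module}")
    # The imported module may print warnings to stdout, so the JSON list is the last line.
    return json.loads(result.stdout.strip().splitlines()[-1])
```

For example, `transitive_imports("openmmtools.testsystems", ["pymbar", "jax"])` reports which of those packages the test-systems import loads. In an environment with the mismatched jaxlib, it should raise an ImportError carrying the jaxlib RuntimeError message instead.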

A way to install pymbar without jax is described at https://github.com/choderalab/pymbar. Another option, I think, is simply to pin pymbar=3 when you create your environment. I hope that helps.
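After creating the environment with the pin (for example, by adding `pymbar=3` to the `mamba install` line), you can check that the solver actually honoured it before rerunning the notebook. This sketch uses only the standard library's `importlib.metadata`. It assumes the conda package ships Python distribution metadata, and the helper names are hypothetical.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def major_of(v: str) -> int:
    """Major component of a dotted version string, e.g. '3.1.1' -> 3."""
    return int(v.split(".")[0])


def installed_major(dist: str) -> Optional[int]:
    """Major version of an installed distribution, or None if it is not installed."""
    try:
        return major_of(version(dist))
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Expect 3 if the environment was created with pymbar=3 pinned.
    print("pymbar major version:", installed_major("pymbar"))
```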

sef43 commented

Thanks @ijpulidos, telling mamba to install pymbar=3 works. I should also note that this problem only happens on Google Colab, so conda is definitely to blame for it.