dMap_JAX

Diffusion maps using JAX

Primary language: Python. License: MIT.

Diffusion maps using JAX (dMap_JAX).

Diffusion maps (dMaps) are a nonlinear manifold learning technique. They can be used to learn the intrinsic low-dimensional manifold underlying high-dimensional molecular simulation data.

dMaps are obtained from the eigenvalue decomposition of the adjoint of the right-stochastic Markov transition matrix built from the high-dimensional data. This transition matrix is generated by applying a kernel to the pairwise similarity matrix and then normalizing each row of the resulting kernel matrix by its row sum.
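The construction described above can be sketched in a few lines of JAX. This is a minimal illustration under stated assumptions, not the dMap_JAX implementation: the function name `diffusion_map`, the kernel convention exp(-d^2 / (2 * epsilon)) and the bandwidth `epsilon` are assumptions. Instead of diagonalizing the non-symmetric transition matrix M = D^-1 K directly, the sketch diagonalizes the symmetric matrix D^-1/2 K D^-1/2, which is similar to M (same eigenvalues), and maps its eigenvectors back to right eigenvectors of M.

```python
import jax.numpy as jnp


def diffusion_map(dist, epsilon, n_evecs=3):
    """Standard diffusion map from a symmetric (N, N) distance matrix.

    Hypothetical helper. Returns the leading `n_evecs` eigenvalues of
    M = D^-1 K in descending order and the matching right eigenvectors
    of M as columns.
    """
    # Gaussian kernel on pairwise distances (bandwidth convention assumed).
    K = jnp.exp(-dist**2 / (2.0 * epsilon))
    # Row sums d_i = sum_j K_ij; M = D^-1 K is right-stochastic.
    d = K.sum(axis=1)
    d_inv_sqrt = 1.0 / jnp.sqrt(d)
    # S = D^-1/2 K D^-1/2 is symmetric and similar to M (same eigenvalues).
    S = d_inv_sqrt[:, None] * K * d_inv_sqrt[None, :]
    # eigh returns ascending eigenvalues; reverse to descending order.
    evals, phi = jnp.linalg.eigh(S)
    evals, phi = evals[::-1], phi[:, ::-1]
    # Map eigenvectors of S back to right eigenvectors of M: psi = D^-1/2 phi.
    psi = d_inv_sqrt[:, None] * phi
    return evals[:n_evecs], psi[:, :n_evecs]
```

The leading eigenvalue is 1 and its eigenvector is constant (the trivial mode), so low-dimensional diffusion coordinates are usually taken as `evals[k] * psi[:, k]` for k = 1, 2, and so on.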

This Python code supports the calculation of standard and adaptive dMaps using the root mean square deviation (RMSD) between structures as the pairwise distance metric, combined with a Gaussian kernel. Both the RMSD and the dMap calculations can be performed on GPUs.
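For the adaptive variant, one common choice is the self-tuning kernel of Zelnik-Manor and Perona, which gives each point its own bandwidth: the distance to its k-th nearest neighbour. Whether the adaptive dMaps in dMap_JAX use exactly this scheme is an assumption here, and the function name `adaptive_transition_matrix` and the default `k=7` are illustrative only. The sketch takes a precomputed distance matrix (for example, pairwise RMSD) and returns the kernel together with its row-normalized transition matrix.

```python
import jax.numpy as jnp


def adaptive_transition_matrix(dist, k=7):
    """Locally scaled Gaussian kernel and its row-normalized transition matrix.

    Hypothetical helper; assumes `dist` is a symmetric (N, N) distance matrix
    of distinct points with k < N, so every local bandwidth is positive.
    """
    # Local bandwidth sigma_i = distance to the k-th nearest neighbour.
    # After sorting each row, column 0 holds the zero self-distance.
    sigma = jnp.sort(dist, axis=1)[:, k]
    # K_ij = exp(-d_ij^2 / (sigma_i * sigma_j)); symmetric because dist is.
    K = jnp.exp(-dist**2 / (sigma[:, None] * sigma[None, :]))
    # Right-stochastic transition matrix: each row divided by its row sum.
    M = K / K.sum(axis=1, keepdims=True)
    return K, M
```

The eigendecomposition then proceeds as in the standard case; only the kernel changes, so that sparse regions of conformational space get wider kernels than densely sampled ones.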


Figure: Comparison of the performance of RMSD matrix calculations with dMap_JAX and MDAnalysis.

Example usage can be found in the notebooks.


Installation

conda env create -f environment.yml

conda activate djax

Tested JAX version: 0.4.2. The jaxlib wheel below is built for CUDA 11 with cuDNN 8.6.

python -m pip install "jax==0.4.2"

pip install --upgrade "jaxlib==0.4.2+cuda11.cudnn86" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
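After installation, a quick way to confirm that JAX picked up the CUDA-enabled jaxlib wheel is to query the installed version, the default backend and the visible devices. This check is a suggestion and is not part of the repository; on a CPU-only install, `jax.default_backend()` reports `cpu`.

```python
import jax

# Installed JAX version (0.4.2 is the version listed as tested above).
print("jax version:", jax.__version__)
# "gpu" when the CUDA-enabled jaxlib is in use, "cpu" otherwise.
print("default backend:", jax.default_backend())
# Devices that jit-compiled RMSD and dMap code will run on.
print("devices:", jax.devices())
```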


Features

  • Batchwise RMSD matrix calculations on GPUs.
  • Standard and adaptive diffusion maps on GPUs.
  • Helper functions to plot eigenvalue spectrum and 2D dMaps.
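The batchwise RMSD matrix calculation listed above can be illustrated with a sketch that superimposes every pair of frames using the Kabsch algorithm (optimal rotation via SVD), vectorizes over frames with `jax.vmap`, and processes rows in blocks to bound GPU memory. The names `kabsch_rmsd`, `rmsd_matrix` and `batch_size` are hypothetical, and the superposition method used inside dMap_JAX may differ. Input is assumed to be an array of shape (n_frames, n_atoms, 3) with the same atom ordering in every frame.

```python
import jax
import jax.numpy as jnp


def kabsch_rmsd(x, y):
    """RMSD between two (n_atoms, 3) conformations after optimal superposition."""
    # Remove translation by centring both structures on their centroids.
    x = x - x.mean(axis=0)
    y = y - y.mean(axis=0)
    # Kabsch: SVD of the 3x3 cross-covariance matrix (singular values descending).
    u, s, vt = jnp.linalg.svd(x.T @ y)
    # If the best orthogonal fit is a reflection, flip the smallest singular
    # value so that only proper rotations (det = +1) are allowed.
    d = jnp.sign(jnp.linalg.det(u @ vt))
    s = s.at[-1].set(s[-1] * d)
    # Minimal mean squared deviation: (|x|^2 + |y|^2 - 2 * max trace(R H)) / n_atoms.
    msd = (jnp.sum(x**2) + jnp.sum(y**2) - 2.0 * jnp.sum(s)) / x.shape[0]
    return jnp.sqrt(jnp.maximum(msd, 0.0))


# One frame against many frames, then a block of frames against all frames.
_one_vs_many = jax.vmap(kabsch_rmsd, in_axes=(None, 0))
_block_vs_all = jax.jit(jax.vmap(_one_vs_many, in_axes=(0, None)))


def rmsd_matrix(traj, batch_size=256):
    """(N, N) RMSD matrix for traj of shape (N, n_atoms, 3), built in row blocks."""
    n = traj.shape[0]
    blocks = [_block_vs_all(traj[i:i + batch_size], traj)
              for i in range(0, n, batch_size)]
    return jnp.concatenate(blocks, axis=0)
```

Each block evaluates `batch_size` x N small SVDs at once, so `batch_size` trades device memory for fewer launches; the final N x N matrix is still materialized in full.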

Acknowledgements

  • Max Topel, Andrew L. Ferguson