Diffusion maps using JAX (dMap_JAX).
Diffusion maps (dMaps) is a nonlinear manifold learning technique that can be used to learn the intrinsic low-dimensional manifold underlying high-dimensional molecular simulation data.
dMaps are obtained from the eigendecomposition of a right-stochastic Markov transition matrix. This matrix is built by passing the pairwise distances between data points through a kernel to form a similarity matrix, then normalizing each row by its row sum. In practice, the eigendecomposition is performed on the symmetric adjoint of the transition matrix, which has the same eigenvalues and whose eigenvectors are related to those of the transition matrix by a diagonal rescaling.
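The construction described above can be sketched in a few lines. This is a minimal NumPy illustration, not the dMap_JAX implementation: the kernel convention K_ij = exp(-d_ij^2 / (2ε)), the bandwidth value ε = 0.05, and the function names `diffusion_map` and `embed` are assumptions chosen for illustration. The functions avoid in-place updates, so swapping `numpy` for `jax.numpy` should need few changes.

```python
import numpy as np


def diffusion_map(dist, epsilon):
    """Standard diffusion map from a symmetric (N, N) pairwise distance matrix.

    Returns all eigenvalues (descending) and the right eigenvectors psi of the
    transition matrix M = D^{-1} K as columns.
    """
    # Assumed Gaussian kernel convention: K_ij = exp(-d_ij^2 / (2 * epsilon)).
    K = np.exp(-dist**2 / (2.0 * epsilon))
    d = K.sum(axis=1)  # row sums, i.e. the diagonal of D
    # Symmetric adjoint Ms = D^{-1/2} K D^{-1/2}; it is similar to M, so same eigenvalues.
    Ms = K / np.sqrt(np.outer(d, d))
    evals, phi = np.linalg.eigh(Ms)  # ascending order
    evals, phi = evals[::-1], phi[:, ::-1]  # reorder to descending
    # Right eigenvectors of M are recovered as psi = D^{-1/2} phi.
    psi = phi / np.sqrt(d)[:, None]
    return evals, psi


def embed(evals, psi, n_dims=2):
    """dMap coordinates: nontrivial eigenvectors scaled by their eigenvalues."""
    # Index 0 is the trivial pair (eigenvalue 1, constant eigenvector), so skip it.
    return psi[:, 1:n_dims + 1] * evals[1:n_dims + 1]


# Usage on synthetic points along an arc (illustrative data only).
rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.0, 1.0, 50))
X = np.column_stack([np.cos(3 * t), np.sin(3 * t)])
dist = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
evals, psi = diffusion_map(dist, epsilon=0.05)
coords = embed(evals, psi)
print(evals[:3])
```

Working with the symmetric adjoint lets a symmetric eigensolver (`eigh`) be used, which is faster and numerically better behaved than a general eigensolver on the non-symmetric transition matrix.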
This Python code computes standard and adaptive dMaps using root mean square deviation (RMSD) as the pairwise distance metric with a Gaussian kernel. Both the RMSD and the dMap calculations can be performed on GPUs.
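The pairwise RMSD matrix can be illustrated with the following NumPy sketch. It assumes RMSD is taken after optimal superposition via the Kabsch (SVD) algorithm and that frames are processed in blocks of rows to bound memory; dMap_JAX may use a different alignment method or batching scheme, and `rmsd_block`, `rmsd_matrix`, and `batch_size` are hypothetical names.

```python
import numpy as np


def rmsd_block(A, B):
    """Kabsch-aligned RMSD between every frame of A and every frame of B.

    A: (n_a, n_atoms, 3), B: (n_b, n_atoms, 3), both already centered.
    Returns an (n_a, n_b) array.
    """
    # Covariance H_ij = A_i^T B_j for all frame pairs, shape (n_a, n_b, 3, 3).
    H = np.einsum("iak,jal->ijkl", A, B)
    U, S, Vt = np.linalg.svd(H)
    # Flip the smallest singular value's sign when the best fit would be a reflection.
    d = np.sign(np.linalg.det(U) * np.linalg.det(Vt))
    trace = S[..., 0] + S[..., 1] + d * S[..., 2]
    sq_a = np.sum(A**2, axis=(1, 2))
    sq_b = np.sum(B**2, axis=(1, 2))
    msd = (sq_a[:, None] + sq_b[None, :] - 2.0 * trace) / A.shape[1]
    # Clip tiny negative values caused by round-off before the square root.
    return np.sqrt(np.clip(msd, 0.0, None))


def rmsd_matrix(coords, batch_size=256):
    """Full (N, N) RMSD matrix, computed one block of rows at a time."""
    # Remove each frame's centroid (translation) once, up front.
    X = coords - coords.mean(axis=1, keepdims=True)
    blocks = [rmsd_block(X[i:i + batch_size], X)
              for i in range(0, X.shape[0], batch_size)]
    return np.concatenate(blocks, axis=0)


# Usage on a synthetic trajectory: 10 frames of 20 atoms (illustrative data only).
rng = np.random.default_rng(1)
traj = rng.normal(size=(10, 20, 3))
c, s = np.cos(0.7), np.sin(0.7)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
traj[1] = traj[0] @ R.T + 5.0  # frame 1 is a rotated, translated copy of frame 0
D = rmsd_matrix(traj, batch_size=4)
print(D[0, 1])
```

Only one block of the (N, N, 3, 3) covariance tensor exists at any time, so peak memory scales with `batch_size * N` rather than `N * N`.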
[Figure: Performance comparison of RMSD matrix calculations with dMap_JAX and MDAnalysis.]
Example usage can be found in the notebooks.
Installation
conda env create -f environment.yml
conda activate djax
Tested JAX version: 0.4.2. To install JAX and the CUDA 11 build of jaxlib:
python -m pip install "jax==0.4.2"
pip install --upgrade "jaxlib==0.4.2+cuda11.cudnn86" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
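After installing, it can help to confirm that JAX actually picks up the GPU. A minimal check, relying only on the public `jax.default_backend()` function (the helper name `jax_backend` is hypothetical):

```python
def jax_backend():
    """Return JAX's default backend name (e.g. "gpu" or "cpu"), or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


print("JAX backend:", jax_backend())
```

A result of `cpu` usually means the CUDA-enabled jaxlib wheel was not installed or does not match the local CUDA/cuDNN installation.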
Features
- Batchwise RMSD matrix calculations on GPUs.
- Standard and adaptive diffusion maps on GPUs.
- Helper functions to plot the eigenvalue spectrum and 2D dMap embeddings.
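Adaptive dMaps replace the single global kernel bandwidth with point-dependent bandwidths, so that densely and sparsely sampled regions are treated comparably. The sketch below uses local scaling (σ_i set to the distance to the k-th nearest neighbor, K_ij = exp(-d_ij^2 / (2 σ_i σ_j))), which is one common adaptive scheme; it is an assumption, not necessarily the scheme dMap_JAX implements, and `k=7` and the function names are illustrative.

```python
import numpy as np


def local_bandwidths(dist, k=7):
    """Per-point bandwidth sigma_i = distance to the k-th nearest neighbor.

    Column 0 of each sorted row is the zero self-distance, so index k skips it.
    Assumes no duplicate points (otherwise sigma_i could be 0).
    """
    return np.sort(dist, axis=1)[:, k]


def adaptive_kernel(dist, k=7):
    """Locally scaled Gaussian kernel K_ij = exp(-d_ij^2 / (2 sigma_i sigma_j))."""
    sigma = local_bandwidths(dist, k)
    return np.exp(-dist**2 / (2.0 * np.outer(sigma, sigma)))


def transition_matrix(K):
    """Row-normalize a kernel into a right-stochastic Markov matrix M = D^{-1} K."""
    return K / K.sum(axis=1, keepdims=True)


# Usage: a dense cluster next to a sparse one (illustrative data only).
rng = np.random.default_rng(2)
X = np.concatenate([rng.normal(0.0, 0.1, size=(40, 2)),
                    rng.normal(3.0, 1.0, size=(40, 2))])
dist = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
sigma = local_bandwidths(dist, k=7)
K = adaptive_kernel(dist, k=7)
M = transition_matrix(K)
print(sigma[:40].mean(), sigma[40:].mean())  # dense cluster gets smaller bandwidths
```

Because σ_i σ_j is symmetric in i and j, the adaptive kernel stays symmetric, so the resulting transition matrix can be decomposed through its symmetric adjoint exactly as in the standard dMap.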
Acknowledgements
- Max Topel, Andrew L. Ferguson