GSVAE

Git repo for the paper "Physics-constrained predictive molecular latent space discovery with graph scattering variational autoencoder."

Primary language: Python. License: MIT.

Predictive Molecular Graph Latent Space Discovery

This repository contains a PyTorch implementation of the Graph Scattering Variational AutoEncoder (GSVAE), a molecular generative model based on variational inference and graph theory. The encoder is built on the scattering transform with adaptive spectral filters, which helps the model generalize when training data are limited. The decoder is a one-shot graph generative model that conditions atom types on the molecular topology.
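As a rough illustration of what "one-shot" decoding means, the NumPy sketch below maps a latent vector to the bond types of every atom pair at once and then predicts each atom type from the resulting topology. It is not the GSVAE decoder: the random linear maps, the sizes (`MAX_ATOMS`, `N_BOND_TYPES`, `N_ATOM_TYPES`), and the function `one_shot_decode` are hypothetical stand-ins for the trained PyTorch network.

```python
import numpy as np

# Hypothetical sizes for illustration only (not the repository's settings).
MAX_ATOMS = 9      # maximum number of nodes per molecule
N_BOND_TYPES = 4   # no bond, single, double, triple
N_ATOM_TYPES = 5   # e.g. ghost node, C, N, O, F


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def one_shot_decode(z, params):
    """Decode a latent vector z into (bond-type matrix, atom types) in one pass."""
    # 1) Topology: logits for every atom pair at once, symmetrized so A[i, j] == A[j, i].
    logits = (params["W_adj"] @ z).reshape(MAX_ATOMS, MAX_ATOMS, N_BOND_TYPES)
    logits = 0.5 * (logits + logits.transpose(1, 0, 2))
    bond_probs = softmax(logits)
    bonds = bond_probs.argmax(axis=-1)       # most likely bond type per pair
    np.fill_diagonal(bonds, 0)               # no self-bonds
    # 2) Atom types conditioned on the topology: each node sees its bond-type counts.
    node_feat = np.eye(N_BOND_TYPES)[bonds].sum(axis=1)   # (MAX_ATOMS, N_BOND_TYPES)
    atom_logits = node_feat @ params["W_atom"] + params["W_z"] @ z
    atom_types = softmax(atom_logits).argmax(axis=-1)
    return bonds, atom_types


# Usage with random (untrained) weights and a 30-dimensional latent vector.
rng = np.random.default_rng(0)
z_dim = 30
params = {
    "W_adj": rng.normal(size=(MAX_ATOMS * MAX_ATOMS * N_BOND_TYPES, z_dim)),
    "W_atom": rng.normal(size=(N_BOND_TYPES, N_ATOM_TYPES)),
    "W_z": rng.normal(size=(N_ATOM_TYPES, z_dim)),
}
bonds, atom_types = one_shot_decode(rng.normal(size=z_dim), params)
```

Decoding the topology first and the atom types second is what lets each atom-type prediction depend on the bonds its node already has.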

Physical constraints are implemented to encourage energetically stable and valid molecules.

To account for training on a limited-size dataset, a Bayesian formalism is adopted to capture the uncertainties in the predictive estimates of the molecular properties.

Getting Started

Dependencies

This implementation requires:

  • Python (>= 3.5)
  • SciPy (>= 1.4.1)
  • PyParsing (>= 1.1)
  • PyTorch (>= 1.5.0)
  • RDKit (>= 2019.09.3)
  • NumPy (>= 1.18.1)
  • Seaborn (>= 0.9.0)
  • scikit-learn (>= 0.22.1)
  • Matplotlib (>= 3.1.1)
  • chainer-chemistry (>=0.6.0)

Installation

After downloading the code, install the dependencies by running

pip install -r requirements.txt

Note that the functions in utils.py require the RDKit package (installation instructions: https://www.rdkit.org/docs/Install.html).

Data

Data samples are generated through data_gen.py, which also performs classic bootstrapping. The script accepts the following arguments:

optional arguments:
  --data_size           Total size of the training + test dataset (default: 100000)
  --N                   Size of the training set. This only affects the bootstrapping (default: 600)
  --n_samples           Number of bootstrap samples (default: 1, no bootstrap)
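For reference, classic bootstrapping resamples the N training molecules with replacement to form each replicate. The sketch below shows only that resampling step on indices; `classic_bootstrap` is a hypothetical helper, and how data_gen.py selects the training subset from the full dataset and writes its files is not shown.

```python
import numpy as np


def classic_bootstrap(n_train, n_samples, seed=0):
    """Return n_samples index arrays, each drawn with replacement from range(n_train).

    Mirrors the documented default: n_samples=1 means no bootstrap, so the
    original training indices are returned unchanged (assumed behavior).
    """
    if n_samples == 1:
        return [np.arange(n_train)]
    rng = np.random.default_rng(seed)
    return [rng.integers(0, n_train, size=n_train) for _ in range(n_samples)]


# Three bootstrap replicates of a 600-molecule training set (--N 600 --n_samples 3).
replicates = classic_bootstrap(600, 3)
```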

Bayesian bootstrapping, in contrast, is performed in main.py. To generate the data, run:

cd data
python3 data_gen.py
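Unlike the classic bootstrap, the Bayesian bootstrap keeps every training molecule and instead reweights the molecules with weights drawn from a flat Dirichlet distribution. The sketch below assumes that each replicate acts as a per-sample weighting of the training loss; the function names and the use of the `--BB_samples` index as a random seed are hypothetical, not taken from main.py.

```python
import numpy as np


def bayesian_bootstrap_weights(n_train, seed):
    """One Bayesian bootstrap replicate: weights w ~ Dirichlet(1, ..., 1) over the data."""
    rng = np.random.default_rng(seed)
    return rng.dirichlet(np.ones(n_train))


def weighted_loss(per_sample_losses, weights):
    """Weighted training loss; scaling by n keeps uniform weights equal to the plain sum."""
    return len(weights) * float(np.dot(weights, per_sample_losses))


# Hypothetical: use the --BB_samples index as the seed of the replicate.
w = bayesian_bootstrap_weights(600, seed=3)
```

Because the weights are continuous, no molecule is ever dropped from a replicate, which matters when the training set holds only a few hundred molecules.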

Run

Training

The model is trained with main.py, which accepts the following arguments:

optional arguments:
  --epochs              number of epochs to train (default: 1900)
  --batch_number        number of batches per epoch (default: 25)
  --gpu_mode            accelerate the script using GPU (default: 1)
  --z_dim               latent space dimensionality (default: 30)
  --seed                random seed (default: 1400)
  --loadtrainedmodel    path to trained model
  --mu_reg_1            regularization parameter for ghost nodes and valence constraint (default: 0)
  --mu_reg_2            regularization parameter for connectivity constraint (default: 0)
  --mu_reg_3            regularization parameter for 3-member cycle constraint (default: 0)
  --mu_reg_4            regularization parameter for cycle with triple bond constraint (default: 0)
  --N_vis               number of test data for visualization (default: 3000)
  --log_interval        number of epochs between visualizations (default: 200)
  --mol_vis             visualize sampled molecules (default: 0)
  --n_samples           number of generated samples from molecular space (default: 10000)
  --wlt_scales          number of wavelet scales (default: 12)
  --scat_layers         number of scattering layers (default: 4)
  --database            name of the training database (default: 'QM9')
  --datafile            name and location of the training file in data folder (default: 'QM9_0.data')
  --BB_samples          index for Bayesian bootstrap sample (default: 0)
  --N                   number of training data (default: 600)
  --res                 path for storing the results (default: 'results/')
  --y_id                index for target property in the conditional design (default: None, unconditional design)
  --y_target            target property value in the conditional design (default: None, unconditional design)

After generating the data, run

python3 main.py

to train the base model. To run the constrained model, set the regularization parameters mu_reg_1, mu_reg_2, mu_reg_3, and mu_reg_4 to positive values and tune them based on the output statistics.
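The four regularization parameters scale penalty terms that are added to the training loss, so leaving all of them at 0 recovers the base model. As a hedged illustration of the kind of term mu_reg_1 weights, the sketch below penalizes atoms whose expected valence exceeds the maximum allowed for their expected atom type; a ghost node is given a maximum valence of 0, so any bond to it is penalized. The atom vocabulary, the squared-hinge form, and the names `valence_penalty` and `constrained_loss` are assumptions, not the formulation used in the paper.

```python
import numpy as np

# Assumed atom vocabulary: ghost (padding) node, C, N, O, F, with their max valences.
MAX_VALENCE = np.array([0.0, 4.0, 3.0, 2.0, 1.0])
BOND_ORDER = np.array([0.0, 1.0, 2.0, 3.0])  # none, single, double, triple


def valence_penalty(bond_probs, atom_probs):
    """Squared hinge on atoms whose expected valence exceeds their expected maximum.

    bond_probs: (n, n, 4) probabilities over bond types for each atom pair.
    atom_probs: (n, 5) probabilities over atom types for each node.
    """
    expected_valence = (bond_probs @ BOND_ORDER).sum(axis=1)   # (n,)
    allowed_valence = atom_probs @ MAX_VALENCE                  # (n,)
    excess = np.maximum(expected_valence - allowed_valence, 0.0)
    return float(np.sum(excess ** 2))


def constrained_loss(base_loss, penalties, mus):
    """base_loss + sum_i mu_i * penalty_i; all mu_i = 0 gives back the base loss."""
    return base_loss + sum(mu * p for mu, p in zip(mus, penalties))
```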

Conditional design

This code performs conditional design by setting a target property value for the sampled molecules. Set the property ID with the argument --y_id (0: PSA, 1: MolWt, 2: LogP) and the target value with --y_target.
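For example, `python3 main.py --y_id 2 --y_target 1.5` requests samples with a LogP of 1.5 (an arbitrary value). One simple way to check the outcome, sketched below under the assumption that the property values of the generated molecules have already been computed (for instance with RDKit descriptors), is to measure their deviation from the target. `design_report` and its tolerance argument are hypothetical.

```python
import numpy as np

# Property IDs accepted by --y_id, as listed in this README.
PROPERTY_NAMES = {0: "PSA", 1: "MolWt", 2: "LogP"}


def design_report(values, y_id, y_target, tol):
    """Summarize how closely sampled property values match the design target.

    values: property values of the generated molecules (assumed precomputed).
    tol: hypothetical tolerance for counting a sample as on-target.
    """
    values = np.asarray(values, dtype=float)
    error = np.abs(values - y_target)
    return {
        "property": PROPERTY_NAMES[y_id],
        "mean_abs_error": float(error.mean()),
        "fraction_within_tol": float((error <= tol).mean()),
    }


report = design_report([1.0, 1.5, 3.0], y_id=2, y_target=1.5, tol=0.5)
```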

Quantifying uncertainties

To perform uncertainty quantification (UQ), use utils.py, which accepts the following arguments:

optional arguments:
  --BB_samples          number of samples for uncertainty quantification (default: 0)
  --N                   number of training data (default: 600)
  --database            name of the training database (default: 'QM9')
  --sample_file         predictive samples directory (default: 'BB_600')
  --gpu_mode            accelerate the script using GPU (default: 0)

To compute confidence intervals, use the following example script:

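ITR=25      # number of Bayesian bootstrap samples
DIR=B_200   # name of the results/samples directory
N=200       # number of training data

# Train one model per Bayesian bootstrap sample
for i in $(seq 1 "${ITR}"); do
    python3 main.py --N "$N" --BB_samples "$i" --res results/"${DIR}"
done

# Collect the generated samples (-p: no error if the directories already exist)
mkdir -p data/samples/"${DIR}"
mv results/"${DIR}"/*/samples_*.data data/samples/"${DIR}"

# Compute the confidence intervals
python3 utils.py --BB_samples "$ITR" --N "$N" --sample_file "${DIR}"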
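Each run of main.py in the loop trains on a different Bayesian bootstrap sample, so the spread of the resulting predictive estimates reflects the uncertainty caused by the limited training data. The sketch below assumes the per-replicate estimates (for example, a histogram of a molecular property over the generated samples) are stacked into one array and reports an equal-tailed percentile interval; utils.py may compute its intervals differently, and `percentile_interval` is a hypothetical helper.

```python
import numpy as np


def percentile_interval(estimates, level=0.95):
    """Mean and equal-tailed percentile interval across bootstrap replicates.

    estimates: array of shape (n_replicates, ...) holding one predictive
    estimate per bootstrap model.
    """
    estimates = np.asarray(estimates, dtype=float)
    alpha = (1.0 - level) / 2.0
    lower = np.quantile(estimates, alpha, axis=0)
    upper = np.quantile(estimates, 1.0 - alpha, axis=0)
    return estimates.mean(axis=0), lower, upper


# 25 replicates (ITR=25) of a 10-bin property histogram, random here for illustration.
rng = np.random.default_rng(0)
mean, lower, upper = percentile_interval(rng.random((25, 10)))
```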

Filters

You can run filter.py on its own to perform the scattering transform and visualize the graph filters. The filter.py script accepts the following arguments:

optional arguments:
  --gpu_mode            accelerate the script using GPU (default: 0)
  --wlt_scales          number of wavelet scales (default: 12)
  --scat_layers         number of scattering layers (default: 4)
  --N                   number of training data (default: 600)
  --database            name of the training database (default: 'QM9')
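To give a feel for what --wlt_scales and --scat_layers control, the NumPy sketch below builds a bank of spectral graph wavelets on the normalized Laplacian, cascades wavelet filtering and absolute value over several layers, and averages every path over the nodes. It uses a fixed kernel g(x) = x * exp(-x) on dyadic scales rather than the adaptive spectral filters of GSVAE, and it is an assumption whether --scat_layers counts the zeroth (input) layer; `wavelet_filters` and `scattering` are hypothetical names.

```python
import numpy as np


def wavelet_filters(adj, n_scales):
    """Spectral wavelets psi_j = U g(s_j * Lambda) U^T on the normalized Laplacian."""
    deg = adj.sum(axis=1)
    d_inv_sqrt = np.where(deg > 0, 1.0 / np.sqrt(np.maximum(deg, 1e-12)), 0.0)
    lap = np.eye(len(adj)) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    eigval, eigvec = np.linalg.eigh(lap)
    scales = 2.0 ** np.arange(n_scales)           # dyadic scales (assumed)
    # Fixed band-pass kernel g(x) = x * exp(-x), evaluated at s * lambda.
    return [eigvec @ np.diag(s * eigval * np.exp(-s * eigval)) @ eigvec.T for s in scales]


def scattering(adj, x, n_scales, n_layers):
    """Cascade |psi_j @ .| over n_layers and average each path over the nodes."""
    filters = wavelet_filters(adj, n_scales)
    layer = [x]                                   # zeroth layer: the input node signal
    coeffs = [x.mean(axis=0)]
    for _ in range(n_layers):
        layer = [np.abs(psi @ u) for u in layer for psi in filters]
        coeffs.extend(u.mean(axis=0) for u in layer)
    return np.concatenate(coeffs)


# Usage: a 4-node path graph with 2 features per node.
adj = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 0]], dtype=float)
x = np.random.default_rng(0).normal(size=(4, 2))
coeffs = scattering(adj, x, n_scales=3, n_layers=2)   # (1 + 3 + 9) paths * 2 features = 26 values
```

Because each coefficient is a node average of a function of the Laplacian applied to the signal, the output does not depend on how the atoms are numbered. The number of paths grows as wlt_scales to the power of the layer index, so the defaults (12 scales, 4 layers) produce many coefficients per feature.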

Citation

If you use this code, in whole or in part, please cite:

@article{shervani2020physics,
  title={Physics-Constrained Predictive Molecular Latent Space Discovery with Graph Scattering Variational Autoencoder},
  author={Shervani-Tabar, Navid and Zabaras, Nicholas},
  journal={arXiv preprint arXiv:2009.13878},
  year={2020}
}

Questions

For any questions or comments regarding this work, feel free to submit an issue here or contact Navid Shervani-Tabar (nshervan@nd.edu). In the email title, please use "Regarding GSVAE paper".