
Learning Joint 2D & 3D Diffusion Models for Complete Molecule Generation

Primary language: Python. License: MIT.

JODO


This repository contains the implementation of Learning Joint 2D & 3D Diffusion Models for Complete Molecule Generation.

JODO represents each molecule as a 3D point cloud together with a 2D bonding graph:
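As a rough illustration of this joint representation, the sketch below stores one molecule as per-atom element symbols, an N x 3 list of coordinates (the point cloud) and a symmetric N x N bond-order matrix (the bonding graph). The class name `JointMolecule`, its fields and the plain-list layout are hypothetical assumptions for illustration; they are not the tensor format used inside the repository.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class JointMolecule:
    """Hypothetical container: a 3D point cloud plus a 2D bonding graph.

    atom_types: element symbol per atom (length N)
    positions:  one (x, y, z) coordinate per atom, in angstroms (N x 3)
    bonds:      symmetric N x N matrix of bond orders (0 = no bond)
    """
    atom_types: List[str]
    positions: List[List[float]]
    bonds: List[List[int]]

    def __post_init__(self):
        n = len(self.atom_types)
        # The point cloud must have exactly one 3D coordinate per atom.
        if len(self.positions) != n or any(len(p) != 3 for p in self.positions):
            raise ValueError("positions must be an N x 3 list")
        # The bonding graph must be a square N x N matrix.
        if len(self.bonds) != n or any(len(row) != n for row in self.bonds):
            raise ValueError("bonds must be an N x N matrix")
        for i in range(n):
            if self.bonds[i][i] != 0:
                raise ValueError("self-bonds are not allowed")
            for j in range(i + 1, n):
                if self.bonds[i][j] != self.bonds[j][i]:
                    raise ValueError("bond matrix must be symmetric")

    def edges(self):
        """Return the 2D graph as (i, j, bond_order) triples with i < j."""
        n = len(self.atom_types)
        return [(i, j, self.bonds[i][j])
                for i in range(n) for j in range(i + 1, n)
                if self.bonds[i][j] > 0]


# Water: one oxygen bonded to two hydrogens.
water = JointMolecule(
    atom_types=["O", "H", "H"],
    positions=[[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
    bonds=[[0, 1, 1], [1, 0, 0], [1, 0, 0]],
)
print(water.edges())  # [(0, 1, 1), (0, 2, 1)]
```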

The generative diffusion process:
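The config files are named `vpsde_*`, which suggests a variance-preserving SDE. The sketch below shows the forward (noising) marginal of a VP-SDE with the linear schedule beta(t) = beta_min + t * (beta_max - beta_min) from the score-SDE literature, applied at the same time step to both the coordinates and the bond features. The schedule, the default values beta_min = 0.1 and beta_max = 20, and the flattened inputs are assumptions; JODO's actual noise schedule and feature encoding may differ.

```python
import math
import random


def vp_marginal(t, beta_min=0.1, beta_max=20.0):
    """Mean coefficient and std of the VP-SDE marginal p(x_t | x_0), t in [0, 1].

    Assumes the linear schedule beta(t) = beta_min + t * (beta_max - beta_min).
    """
    log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean_coeff = math.exp(log_mean_coeff)
    std = math.sqrt(1.0 - math.exp(2.0 * log_mean_coeff))
    return mean_coeff, std


def perturb(values, t, rng):
    """Sample x_t = mean_coeff * x_0 + std * eps with eps ~ N(0, 1), elementwise."""
    mean_coeff, std = vp_marginal(t)
    return [mean_coeff * v + std * rng.gauss(0.0, 1.0) for v in values]


rng = random.Random(0)
positions = [0.0, 0.0, 0.0, 0.96, 0.0, 0.0]  # flattened coordinates of two atoms
bond_entries = [1.0]                          # upper-triangle bond feature(s)
t = 0.5
# Joint diffusion: both modalities are noised at the same time step t.
noisy_positions = perturb(positions, t, rng)
noisy_bonds = perturb(bond_entries, t, rng)
```

"Variance preserving" refers to mean_coeff² + std² = 1 at every t: at t = 0 the data is untouched, and at t = 1 it is almost pure unit Gaussian noise.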


Visualization of molecules generated by JODO trained on the GEOM-Drugs dataset:

Visualization of molecules generated by JODO trained on the QM9 dataset with explicit hydrogen atoms:


Dependencies

Dataset

We recommend using our processed dataset files provided here.

Download datasets:

# 718MB
wget https://zenodo.org/record/7966493/files/data.zip
unzip data.zip

If you want to construct the GEOM-Drugs dataset from scratch:

  • The raw GEOM dataset is available here.
  • Download rdkit_folder.tar.gz and unpack it.
  • Run python build_geom_dataset.py --data_dir YOUR_DATA_PATH.

Generated Molecules

We provide pickled lists of 10,000 molecules generated by JODO on each dataset in ./rdkit_mols. Molecules are saved as RDKit Mol objects, so you can load a list directly and run further analysis on it.

# Example: load molecules generated by JODO trained on the GEOM-Drugs dataset.
import pickle
with open('rdkit_mols/geom_jodo_ancestral_ckpt_35.pkl', 'rb') as f:
    mol_list = pickle.load(f)  # list of RDKit Mol objects (RDKit must be installed to unpickle)
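As one possible follow-up analysis, the hypothetical helper below counts element frequencies across a loaded list. It relies only on the RDKit methods `Mol.GetAtoms()` and `Atom.GetSymbol()`; skipping `None` entries is a defensive assumption, not something the provided pickles are known to contain.

```python
from collections import Counter


def element_counts(mols):
    """Count element symbols over a list of RDKit Mol objects (hypothetical helper)."""
    counts = Counter()
    for mol in mols:
        if mol is None:  # defensive: skip entries that could not be built
            continue
        counts.update(atom.GetSymbol() for atom in mol.GetAtoms())
    return counts


# Usage with the list loaded above:
# counts = element_counts(mol_list)
# print(counts.most_common(5))
```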

Evaluation

We construct a comprehensive evaluation pipeline for molecule generation, including 2D molecular graph metrics, 3D geometry metrics, and substructure geometry alignment metrics.

  • For 3D geometry metrics, we follow https://github.com/ehoogeboom/e3_diffusion_for_molecules in using a distance lookup table to predict bonds, and we report the same stability metrics for 3D geometry comparisons.
  • However, the 3D stability metrics can be gamed in some situations: some methods achieve a high stability ratio yet perform poorly on FCD and alignment MMD, which indicates poor generation quality. This is more pronounced on GEOM-Drugs, whose molecules have more atypical interatomic distances.
  • We therefore recommend interpreting the stability metric with caution and combining it with other metrics when evaluating molecular quality.
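The sketch below illustrates the general idea of distance-based bond prediction and atom stability: a bond order is assigned by comparing the interatomic distance with single, double and triple bond thresholds plus a margin, and an atom counts as stable when its summed bond orders equal its allowed valence (a molecule is stable when all of its atoms are). The reference lengths, margins and valences below are approximate textbook values chosen for illustration; they are not the table used by e3_diffusion_for_molecules or by this repository.

```python
import math

# Illustrative reference bond lengths in picometres (approximate, NOT the real table).
SINGLE = {("C", "C"): 154, ("C", "H"): 109, ("C", "O"): 143, ("O", "H"): 96}
DOUBLE = {("C", "C"): 134, ("C", "O"): 120}
TRIPLE = {("C", "C"): 120}
MARGINS = (10, 5, 3)  # tolerance (pm) added to single/double/triple thresholds, assumed
VALENCE = {"H": 1, "C": 4, "O": 2}  # allowed valence per element, assumed


def _lookup(table, a, b):
    """Symmetric lookup: (a, b) and (b, a) share one entry."""
    return table.get((a, b), table.get((b, a)))


def bond_order(a, b, dist_angstrom):
    """Predict a bond order (0-3) for elements a, b from their distance in angstroms."""
    d = 100.0 * dist_angstrom  # angstrom -> picometre
    single = _lookup(SINGLE, a, b)
    if single is None or d >= single + MARGINS[0]:
        return 0
    double = _lookup(DOUBLE, a, b)
    if double is None or d >= double + MARGINS[1]:
        return 1
    triple = _lookup(TRIPLE, a, b)
    if triple is None or d >= triple + MARGINS[2]:
        return 2
    return 3


def atom_stability(symbols, coords):
    """Return (fraction of stable atoms, whether every atom is stable)."""
    n = len(symbols)
    valence = [0] * n
    for i in range(n):
        for j in range(i + 1, n):
            order = bond_order(symbols[i], symbols[j], math.dist(coords[i], coords[j]))
            valence[i] += order
            valence[j] += order
    stable = [valence[i] == VALENCE[symbols[i]] for i in range(n)]
    return sum(stable) / n, all(stable)
```

Because this check only looks at distance thresholds and valence counts, a generator can in principle score well on it while producing unrealistic overall structures, which is consistent with the caution above.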

To evaluate your own models with our pipeline, save the generated molecules as a pickled list of RDKit Mol objects and run eval_rdkit_pkl.py on that file.

Take QM9 as an example:

# Molecules with 3D positions and atom types, without bonds
python eval_rdkit_pkl.py --dataset_name qm9 --type 3D --root_path YOUR_DATASET_PATH --pkl_path YOUR_MOL_PATH

# Molecules with atom and bond types, without 3D positions
python eval_rdkit_pkl.py --dataset_name qm9 --type 2D --root_path YOUR_DATASET_PATH --pkl_path YOUR_MOL_PATH

# Molecules with atom types, bond types and 3D positions
python eval_rdkit_pkl.py --dataset_name qm9 --type both --sub_geometry=True --root_path YOUR_DATASET_PATH --pkl_path YOUR_MOL_PATH
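To prepare your own molecules for the commands above, the hypothetical helpers below pickle a plain list of molecules for `--pkl_path` and guess a suitable `--type` value from what a molecule contains, using RDKit's `Mol.GetNumConformers()` (3D positions present) and `Mol.GetNumBonds()` (bonds present). The mapping to `--type` is an assumption based on the three comments above, not logic taken from eval_rdkit_pkl.py.

```python
import pickle


def infer_eval_type(mol):
    """Guess the --type flag for eval_rdkit_pkl.py (hypothetical helper)."""
    has_3d = mol.GetNumConformers() > 0  # 3D positions attached?
    has_bonds = mol.GetNumBonds() > 0    # bonds present?
    if has_3d and has_bonds:
        return "both"
    if has_3d:
        return "3D"
    if has_bonds:
        return "2D"
    raise ValueError("molecule has neither 3D coordinates nor bonds")


def save_mol_list(mols, path):
    """Pickle molecules as a plain Python list for --pkl_path (hypothetical helper)."""
    with open(path, "wb") as f:
        pickle.dump(list(mols), f)
```

For example, `save_mol_list(my_mols, 'my_mols.pkl')` followed by passing `--type` with the value of `infer_eval_type(my_mols[0])` and `--pkl_path my_mols.pkl`, where `my_mols` is your own list of RDKit Mol objects.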

Checkpoint

Our checkpoints are provided here.

Download checkpoints:

# Unconditional Generation: QM9, GEOM-Drugs (2.8GB)
wget https://zenodo.org/record/8002902/files/exp_uncond.zip
unzip exp_uncond.zip

# Conditional Generation: single quantum property on QM9 (3.1GB)
wget https://zenodo.org/record/8002902/files/exp_cond.zip 
unzip exp_cond.zip

# Conditional Generation: multi properties (1.6GB)
wget https://zenodo.org/record/8002902/files/exp_cond_multi.zip 
unzip exp_cond_multi.zip

# Molecular Graph Generation: ZINC250k, MOSES (3.9GB)
wget https://zenodo.org/record/8002902/files/exp_2d.zip 
unzip exp_2d.zip

Unconditional Generation

QM9 Training Example:

CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_qm9_jodo
  • Select GPUs with CUDA_VISIBLE_DEVICES; multiple GPUs are supported (e.g. CUDA_VISIBLE_DEVICES=0,1).

QM9 Sampling Example:

# sample from our pretrained checkpoint
CUDA_VISIBLE_DEVICES=2 python main.py --config configs/vpsde_qm9_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_qm9_jodo --config.eval.ckpts '30' --config.eval.batch_size 2500 --config.sampling.steps 1000
  • Set --config.eval.batch_size to control GPU memory usage.
  • Set the number of sampling iterations via --config.sampling.steps; good results can be obtained with anywhere from 1000 down to 50 steps.
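To see what changing the step count means, the sketch below builds an evenly spaced reverse-time grid from t_max down to a small eps, the usual convention in score-SDE samplers. Using exactly this grid, t_max = 1 and eps = 1e-3 in JODO is an assumption. With one network evaluation per step, 50 steps needs 20x fewer evaluations than 1000 steps, at the cost of a roughly 20x larger step size.

```python
def sampling_times(n_steps, t_max=1.0, eps=1e-3):
    """Evenly spaced time points from t_max down to eps for a reverse-time sampler."""
    if n_steps < 2:
        raise ValueError("need at least two steps")
    dt = (t_max - eps) / (n_steps - 1)  # step size shrinks as n_steps grows
    return [t_max - i * dt for i in range(n_steps)]


fine = sampling_times(1000)   # step size about 0.001
coarse = sampling_times(50)   # step size about 0.0204
```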

GEOM-Drugs Training Example:

# Base
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_base --config.model.n_layers 6 --config.model.nf 128

# Medium
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_media

# Large
CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_large --config.model.nf 384 --config.training.n_iters 1500000
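The three variants differ only in their `--config.<path> <value>` overrides. That flag style resembles ml_collections `config_flags` (an assumption based on the syntax). The stdlib sketch below shows the effect of such dotted overrides on a nested config; the placeholder base config is hypothetical, since the real defaults live in configs/vpsde_geom_uncond_jodo.py and are not reproduced here.

```python
import copy

# Per-variant overrides copied from the commands above; "medium" keeps the defaults.
GEOM_VARIANTS = {
    "base": {"model.n_layers": 6, "model.nf": 128},
    "medium": {},
    "large": {"model.nf": 384, "training.n_iters": 1500000},
}

# Hypothetical stand-in for the real config; values are unknown here, hence None.
PLACEHOLDER_BASE = {"model": {"nf": None, "n_layers": None}, "training": {"n_iters": None}}


def apply_overrides(config, overrides):
    """Return a copy of a nested dict with dotted-key overrides applied."""
    cfg = copy.deepcopy(config)
    for dotted, value in overrides.items():
        *parents, leaf = dotted.split(".")
        node = cfg
        for key in parents:
            node = node[key]  # fail loudly on unknown sections
        if leaf not in node:
            raise KeyError(dotted)  # fail loudly on unknown fields
        node[leaf] = value
    return cfg


def to_cli_flags(overrides):
    """Render overrides in the --config.<path> <value> form used above."""
    return " ".join(f"--config.{k} {v}" for k, v in overrides.items())
```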

GEOM-Drugs Sampling Example:

# Base
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_geom_jodo_base --config.model.n_layers 6 --config.model.nf 128 --config.eval.ckpts '30' --config.eval.batch_size 800 --config.sampling.steps 1000

# Medium
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_geom_jodo_media --config.eval.ckpts '30' --config.eval.batch_size 1000 --config.sampling.steps 1000

# Large
CUDA_VISIBLE_DEVICES=0,1 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode eval --workdir exp_uncond/vpsde_geom_jodo_large --config.model.nf 384 --config.eval.ckpts '30' --config.eval.batch_size 500 --config.sampling.steps 1000

The simplified DGT (DGT_concat_sim), without the extra attention heads, also achieves relatively good performance:

# QM9 Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_qm9_jodo_sim --config.model.name DGT_concat_sim

# GEOM-Drugs Medium Training
CUDA_VISIBLE_DEVICES=2,3 python main.py --config configs/vpsde_geom_uncond_jodo.py --mode train --workdir exp_uncond/vpsde_geom_jodo_media_sim --config.model.name DGT_concat_sim

Conditional Generation

# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_gap --config.cond_property gap
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_homo --config.cond_property homo
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_lumo --config.cond_property lumo
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_mu --config.cond_property mu
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_Cv --config.cond_property Cv
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode train --workdir exp_cond/vpsde_qm9_cond_jodo_alpha --config.cond_property alpha

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_gap --config.cond_property gap --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_homo --config.cond_property homo --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_lumo --config.cond_property lumo --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_mu --config.cond_property mu --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_Cv --config.cond_property Cv --config.eval.ckpts '40'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_jodo.py --mode eval --workdir exp_cond/vpsde_qm9_cond_jodo_alpha --config.cond_property alpha --config.eval.ckpts '40'
  • Choose the conditioning property (alpha, gap, homo, lumo, mu, or Cv) with --config.cond_property.

Multi-property conditional generation:

# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode train --workdir exp_cond_multi/vpsde_qm9_cond_jodo_Cv_mu --config.cond_property1 Cv --config.cond_property2 mu
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode train --workdir exp_cond_multi/vpsde_qm9_cond_jodo_gap_mu --config.cond_property1 gap --config.cond_property2 mu
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode train --workdir exp_cond_multi/vpsde_qm9_cond_jodo_alpha_mu --config.cond_property1 alpha --config.cond_property2 mu

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode eval --workdir exp_cond_multi/vpsde_qm9_cond_jodo_Cv_mu --config.cond_property1 Cv --config.cond_property2 mu --config.eval.ckpts '50'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode eval --workdir exp_cond_multi/vpsde_qm9_cond_jodo_gap_mu --config.cond_property1 gap --config.cond_property2 mu --config.eval.ckpts '50'
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_cond_multi_jodo.py --mode eval --workdir exp_cond_multi/vpsde_qm9_cond_jodo_alpha_mu --config.cond_property1 alpha --config.cond_property2 mu --config.eval.ckpts '50'
  • Set the two conditioning properties via --config.cond_property1 and --config.cond_property2.
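The single- and multi-property commands differ only in the config file, experiment directory and property flags. The hypothetical helper below rebuilds them from property names, which can be convenient when scripting many runs; the property lists and checkpoint ids (40 and 50) are taken from the commands above.

```python
SINGLE_PROPS = ["alpha", "gap", "homo", "lumo", "mu", "Cv"]
MULTI_PAIRS = [("Cv", "mu"), ("gap", "mu"), ("alpha", "mu")]


def cond_command(mode, props, ckpt=None, gpu="0"):
    """Build a main.py command for one- or two-property conditional generation."""
    if len(props) == 1:
        config, exp_dir = "configs/vpsde_qm9_cond_jodo.py", "exp_cond"
        prop_flags = [f"--config.cond_property {props[0]}"]
    elif len(props) == 2:
        config, exp_dir = "configs/vpsde_qm9_cond_multi_jodo.py", "exp_cond_multi"
        prop_flags = [f"--config.cond_property{i} {p}" for i, p in enumerate(props, 1)]
    else:
        raise ValueError("one or two properties are supported")
    workdir = f"{exp_dir}/vpsde_qm9_cond_jodo_{'_'.join(props)}"
    parts = [f"CUDA_VISIBLE_DEVICES={gpu} python main.py --config {config}",
             f"--mode {mode} --workdir {workdir}", *prop_flags]
    if ckpt is not None:  # sampling selects a checkpoint
        parts.append(f"--config.eval.ckpts '{ckpt}'")
    return " ".join(parts)


# Print every training and sampling command listed above.
for p in SINGLE_PROPS:
    print(cond_command("train", [p]))
    print(cond_command("eval", [p], ckpt=40))
for pair in MULTI_PAIRS:
    print(cond_command("train", list(pair)))
    print(cond_command("eval", list(pair), ckpt=50))
```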

Molecular Graph Generation

ZINC250k:

# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_zinc_2d_jodo.py --mode train --workdir exp_2d/vpsde_zinc_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_zinc_2d_jodo.py --mode eval --workdir exp_2d/vpsde_zinc_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000 --config.eval.ckpts '5'
  • You can train a smaller model with --config.model.nf 256 --config.model.n_heads 16 --config.model.n_layers 8.

MOSES:

# Training
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_moses_2d_jodo.py --mode train --workdir exp_2d/vpsde_moses_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000

# Sampling
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_moses_2d_jodo.py --mode eval --workdir exp_2d/vpsde_moses_2d_jodo --config.model.nf 1024 --config.model.n_heads 64 --config.model.n_layers 6 --config.training.snapshot_freq 300000 --config.eval.ckpts '4'

Training CDGS on QM9 and GEOM-Drugs:

# QM9
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_qm9_2d_cdgs.py --mode train --workdir exp_2d/vpsde_qm9_2d_cdgs

# GEOM-Drugs
CUDA_VISIBLE_DEVICES=0 python main.py --config configs/vpsde_geom_2d_cdgs.py --mode train --workdir exp_2d/vpsde_geom_2d_cdgs

Citation

@article{huang2023learning,
  title={Learning Joint 2D \& 3D Diffusion Models for Complete Molecule Generation},
  author={Huang, Han and Sun, Leilei and Du, Bowen and Lv, Weifeng},
  journal={arXiv preprint arXiv:2305.12347},
  year={2023}
}

@article{huang2023conditional,
  title={Conditional Diffusion Based on Discrete Graph Structures for Molecular Graph Generation},
  author={Huang, Han and Sun, Leilei and Du, Bowen and Lv, Weifeng},
  journal={arXiv preprint arXiv:2301.00427},
  year={2023}