Graph Diffusion Policy Optimization

This paper introduces $\textit{graph diffusion policy optimization}$ (GDPO), a novel approach that uses reinforcement learning to optimize graph diffusion models for arbitrary (e.g., non-differentiable) objectives. GDPO is built on an $\textit{eager policy gradient}$ tailored to graph diffusion models, derived through careful analysis and offering improved performance. Experimental results show that GDPO achieves state-of-the-art performance on a variety of graph generation tasks with complex and diverse objectives.

Install Dependencies

conda create --name GDPO --file spec-list.txt
conda activate GDPO
pip install -r requirements.txt

If you still run into missing dependencies, refer to DiGress and install any additional packages it requires.

In the following steps, make sure you have activated the GDPO environment.

conda activate GDPO
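Because conda sets the `CONDA_DEFAULT_ENV` environment variable to the name of the active environment, your own driver scripts can check up front that they are running inside GDPO. The helper below is a hypothetical convenience, not part of the repository:

```python
import os
from typing import Mapping, Optional


def gdpo_env_active(environ: Optional[Mapping[str, str]] = None, name: str = "GDPO") -> bool:
    """Return True if the active conda environment is called `name`.

    Hypothetical helper (not part of the GDPO repository). It relies on conda
    setting CONDA_DEFAULT_ENV to the active environment's name on activation.
    """
    # Fall back to the real process environment when no mapping is given.
    env = os.environ if environ is None else environ
    return env.get("CONDA_DEFAULT_ENV") == name
```

Calling `gdpo_env_active()` at the top of a script and exiting early when it returns False catches the most common setup mistake before any long job starts.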

Prepare Datasets

After downloading the ZINC250k and MOSES datasets, unzip them into the "dataset" folder so that their paths are "./dataset/zinc" and "./dataset/moses", respectively.

The Planar and SBM datasets are downloaded automatically during training.
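Before launching a run, it can help to confirm the folder layout described above. The following is a minimal sketch; `missing_dataset_dirs` is a hypothetical helper, and it only checks the two manually downloaded datasets since Planar and SBM are fetched automatically:

```python
from pathlib import Path
from typing import List

# Folders this README expects after unzipping; Planar and SBM are omitted
# because they are downloaded automatically during training.
EXPECTED_DATASET_DIRS = ("dataset/zinc", "dataset/moses")


def missing_dataset_dirs(repo_root: str = ".") -> List[str]:
    """Return the expected dataset folders that do not exist under repo_root."""
    root = Path(repo_root)
    return [d for d in EXPECTED_DATASET_DIRS if not (root / d).is_dir()]
```

An empty list means both datasets are where the training scripts expect them.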

Prepare Pretrained Models

After downloading the pretrained models, place them in the "./pretrained" folder.

To train your own pretrained models, prepare the corresponding dataset and config file (located in "./configs/experiment"), then run:

bash run_train.sh

To fine-tune from your own pretrained model, set the "resume" field in the corresponding config file in "./configs/experiment" to the path of that model (usually in the "outputs" folder).
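The "resume" edit can also be scripted. Below is a minimal sketch assuming the field appears as a `resume:` line in the YAML config, possibly indented under a parent key; `set_resume` is a hypothetical helper, and the paths in the usage note are placeholders rather than files shipped with GDPO:

```python
import re

# Matches a (possibly indented) "resume:" line; [ \t]* avoids swallowing blank lines,
# and [^\r\n]* stops before any line ending so CRLF files keep their "\r".
_RESUME_LINE = re.compile(r"^([ \t]*)resume:[^\r\n]*", flags=re.MULTILINE)


def set_resume(config_text: str, checkpoint_path: str) -> str:
    """Point every `resume:` entry in a YAML config at checkpoint_path.

    Hypothetical helper: a plain-text edit that keeps each line's indentation.
    It does not parse YAML, so an inline comment on that line is dropped.
    """
    updated, count = _RESUME_LINE.subn(
        lambda m: f"{m.group(1)}resume: {checkpoint_path}", config_text
    )
    if count == 0:
        raise ValueError("no 'resume:' field found in the config")
    return updated
```

For example, read a config from "./configs/experiment", pass its text and your checkpoint path (e.g. one under "outputs") to `set_resume`, and write the result back. A text edit was chosen over a YAML round-trip so that the rest of the file, including comments and key order, is left untouched.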

Run the toy experiments

In the paper, we designed a toy experiment that can be run without preparing any pretrained models.

The following command will run GDPO with 8 nodes.

bash run_ppo_toy.sh

To change the number of nodes or the training method (for example, to run DDPO with 4 nodes), modify the corresponding parameters in the "run_ppo_toy.sh" script.

The evaluation results are written to the ".log" files named "evaluation".
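When several runs accumulate, a small search helps locate those logs. This is a sketch only: `find_evaluation_logs` is hypothetical, and the search root is an assumption you should point at whichever output folder your run writes to:

```python
from pathlib import Path
from typing import List


def find_evaluation_logs(root: str = ".") -> List[Path]:
    """Recursively collect .log files whose names contain "evaluation", newest first.

    Hypothetical helper: the default root is an assumption, not a path
    defined by the GDPO scripts.
    """
    base = Path(root)
    if not base.is_dir():
        return []
    logs = [p for p in base.rglob("*.log") if "evaluation" in p.name]
    # Sort by modification time so the most recent run comes first.
    return sorted(logs, key=lambda p: p.stat().st_mtime, reverse=True)
```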

Fine-tune with GDPO

Fine-tuning is divided into two parts: general graphs (Planar and SBM) and molecular graphs (ZINC250k and MOSES). For easier reproduction, we also provide the corresponding final checkpoints. Note that intermediate models saved during fine-tuning are stored in the "./multirun" folder.
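To pick up the most recent intermediate model from "./multirun", a sketch like the following may help. It assumes checkpoints use the ".ckpt" suffix (PyTorch Lightning's default); `latest_checkpoint` is a hypothetical helper, so adjust `pattern` if your files are named differently:

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(root: str = "./multirun", pattern: str = "*.ckpt") -> Optional[Path]:
    """Return the most recently modified checkpoint under root, or None if there is none.

    Hypothetical helper: the ".ckpt" suffix is an assumption about how
    intermediate models are saved.
    """
    base = Path(root)
    if not base.is_dir():
        return None
    candidates = [p for p in base.rglob(pattern) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)
```

Selecting by modification time rather than by directory name avoids depending on how the run folders inside "./multirun" are named.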

Planar and SBM

# fine-tune on Planar with GDPO
bash run_ppo_planar.sh

# fine-tune on SBM with GDPO
bash run_ppo_sbm.sh

Final model checkpoints:

ZINC250k and MOSES

ZINC250k

# fine-tune on ZINC250k with GDPO

bash run_ppo_prop.sh

By default, this command fine-tunes on ZINC250k with 5ht1b as the target protein. To target a different protein, edit "+experiment=zinc_ppo_5ht1b.yaml" in "run_ppo_prop.sh" and replace "5ht1b" with the name of the new protein. For example, "+experiment=zinc_ppo_parp1.yaml" fine-tunes with parp1 as the target.
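The protein swap is a single string substitution, so it can be automated when sweeping several targets. A minimal sketch, assuming "run_ppo_prop.sh" contains the override in exactly the "+experiment=zinc_ppo_<protein>.yaml" form shown above; `set_target_protein` is a hypothetical helper:

```python
import re

# The leading "+" must be escaped; the protein name is letters, digits or underscores.
_OVERRIDE = re.compile(r"\+experiment=zinc_ppo_[A-Za-z0-9_]+\.yaml")


def set_target_protein(script_text: str, protein: str) -> str:
    """Replace the target protein in every "+experiment=zinc_ppo_<name>.yaml" override.

    Hypothetical helper: assumes the script uses exactly this override pattern.
    """
    updated, count = _OVERRIDE.subn(
        lambda _m: f"+experiment=zinc_ppo_{protein}.yaml", script_text
    )
    if count == 0:
        raise ValueError("no zinc_ppo experiment override found")
    return updated
```

Reading "run_ppo_prop.sh", applying `set_target_protein(text, "parp1")`, and writing the result to a copy of the script reproduces the manual edit described above.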

We recommend running at least four different seeds for each protein (i.e., running "bash run_ppo_prop.sh" four or more times) to reduce the effect of randomness.
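Once the seeds have finished, results are typically reported as a mean with a spread. The sketch below assumes you have already extracted one number per seed for a given metric from each run's evaluation log; `summarize_seeds` is a hypothetical helper and does not parse the logs itself:

```python
import statistics
from typing import Dict, Sequence


def summarize_seeds(scores: Sequence[float]) -> Dict[str, float]:
    """Mean and sample standard deviation of one metric across seeds.

    Hypothetical helper: the per-seed values must be collected separately.
    At least two seeds are needed for a standard deviation.
    """
    if len(scores) < 2:
        raise ValueError("need at least two seeds")
    return {
        "n": float(len(scores)),
        "mean": statistics.mean(scores),
        "std": statistics.stdev(scores),  # sample (n - 1) standard deviation
    }
```

The sample standard deviation is used because the seeds are a small sample of possible runs rather than the full population.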

Final model checkpoints:

MOSES

# fine-tune on MOSES with GDPO

bash run_ppo_moses.sh

Fine-tuning on MOSES follows essentially the same process as on ZINC250k.

Final model checkpoints:

Evaluation

General Graph Evaluation

For Planar and SBM, edit line 353 of "main_generate.py" to set "test_method" to "evalgeneral".

Then set the "test_only" field in the "planar_test.yaml" and "sbm_test.yaml" files in "./configs/experiment" to the checkpoint path of the fine-tuned model.

Finally, run the following command:

# test model on Planar
bash run_test_graph.sh

To test on SBM instead, change the "dataset" and "experiment" variables in "run_test_graph.sh".

Molecular Graph Evaluation

For ZINC250k and MOSES, edit line 353 of "main_generate.py" to set "test_method" to "evalproperty".

Then set the "test_only" field in the corresponding config files in "./configs/experiment" to the checkpoint path of the fine-tuned model.
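The value of "test_method" depends only on the dataset family. As a compact summary, here is a small lookup sketch; the dataset keys are hypothetical labels, and in the repository the value is still set by editing line 353 of "main_generate.py":

```python
# Summary of the evaluation modes described in this section.
# The keys are illustrative labels, not identifiers used by the GDPO code.
TEST_METHOD_BY_DATASET = {
    "planar": "evalgeneral",
    "sbm": "evalgeneral",
    "zinc250k": "evalproperty",
    "moses": "evalproperty",
}


def eval_method_for(dataset: str) -> str:
    """Return the "test_method" value for a dataset label (case-insensitive)."""
    try:
        return TEST_METHOD_BY_DATASET[dataset.lower()]
    except KeyError:
        raise ValueError(f"unknown dataset: {dataset!r}") from None
```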

Finally, run the following command:

# test model on ZINC250k or MOSES
bash run_test.sh

To evaluate a different model or target protein, modify the configuration in "run_test.sh".

Note that the final model checkpoints we provide are not in a format that PyTorch Lightning can load directly. To test them, first load each checkpoint into the model following lines 329-344 of "main_generate.py", then run the evaluation.
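Lines 329-344 of "main_generate.py" remain the reference for loading the provided checkpoints. As a rough sketch of the general pattern only, assuming each file stores either a plain `state_dict` or a dict wrapping one under a "state_dict" key (an assumption; inspect your file first), loading weights into an already constructed model looks like this. `load_weights` is a hypothetical helper:

```python
import torch
from torch import nn


def load_weights(model: nn.Module, path: str) -> nn.Module:
    """Load raw weights into an already-constructed model and switch it to eval mode.

    Sketch only: assumes the file holds a state_dict, possibly nested under
    "state_dict"; follow lines 329-344 of main_generate.py for the real format.
    """
    # map_location="cpu" lets a GPU-saved checkpoint load on any machine.
    state = torch.load(path, map_location="cpu")
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    model.load_state_dict(state)
    return model.eval()  # nn.Module.eval() returns the module itself
```

If `load_state_dict` reports missing or unexpected keys, the stored key names differ from the model's (for example, by a prefix), which is exactly the kind of adjustment the loading code in "main_generate.py" should be consulted for.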