FS-GNNTR: Few-shot Learning with Transformers via Graph Embeddings for Molecular Property Prediction

In this work, we propose FS-GNNTR, a few-shot GNN-Transformer architecture, to address the low-data problem in molecular property prediction. The model is shown to provide strong performance gains over simple graph-based methods when predicting molecular properties from few-shot data.

The GNN-Transformer network learns deep representations from graph-level embeddings. First, a GNN module encodes the structural information of molecular graphs as a set of node and edge features. These node and edge embeddings are then converted into graph embeddings by neighborhood aggregation. Finally, a vision Transformer encoder exploits the contextual information of these embedding vectors to propagate deep representations across attention layers.
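The pipeline described above can be illustrated with a small, dependency-free sketch. It is a toy approximation, not the repository's implementation: the function names (aggregate_neighbors, readout, self_attention) are hypothetical, mean aggregation and mean pooling are assumed choices, edge features are omitted, and the attention step uses identity projections and treats the graph embeddings as a token sequence purely for illustration.

```python
import math

def aggregate_neighbors(node_feats, edges):
    """One round of mean neighborhood aggregation over an undirected graph."""
    n, dim = len(node_feats), len(node_feats[0])
    neighbors = [[] for _ in range(n)]
    for u, v in edges:
        neighbors[u].append(v)
        neighbors[v].append(u)
    updated = []
    for i in range(n):
        # Each node averages its own features with those of its neighbors.
        group = [node_feats[i]] + [node_feats[j] for j in neighbors[i]]
        updated.append([sum(f[d] for f in group) / len(group) for d in range(dim)])
    return updated

def readout(node_feats):
    """Mean-pool node embeddings into a single graph-level embedding."""
    n, dim = len(node_feats), len(node_feats[0])
    return [sum(f[d] for f in node_feats) / n for d in range(dim)]

def self_attention(tokens):
    """Single-head scaled dot-product self-attention (identity projections)."""
    dim = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) for k in tokens]
        top = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append([sum(w * v[d] for w, v in zip(weights, tokens)) for d in range(dim)])
    return out

# Two toy "molecular graphs": (node feature vectors, undirected edge list).
graph_a = ([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [(0, 1), (1, 2)])
graph_b = ([[0.0, 2.0], [2.0, 0.0]], [(0, 1)])

# GNN stage: node embeddings -> graph embeddings.
graph_embeddings = [readout(aggregate_neighbors(x, e)) for x, e in (graph_a, graph_b)]

# Transformer stage: attention mixes information across the embedding vectors.
contextual = self_attention(graph_embeddings)
```

In practice, the GNN stage would be built with torch-geometric layers and the attention stage with a trained Transformer encoder; the sketch only shows how data flows between the two stages.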


A two-module meta-learning framework is used to optimize model parameters across tasks and to adapt quickly to new molecular properties from few-shot data.
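The idea of optimizing across tasks and then adapting on a few examples can be sketched with a first-order, MAML-style loop. This is an assumed, simplified scheme and not necessarily the authors' optimization procedure: the model is a single scalar weight fitted by regression rather than a molecular classifier, and all names (adapt, meta_train, make_task) and learning rates are hypothetical.

```python
def mse(w, data):
    """Mean squared error of the toy model y_hat = w * x."""
    return sum((w * x - y) ** 2 for x, y in data) / len(data)

def mse_grad(w, data):
    """Gradient of the mean squared error with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in data) / len(data)

def adapt(w, support, inner_lr=0.1, steps=1):
    """Inner loop: a few gradient steps on a task's support set."""
    for _ in range(steps):
        w = w - inner_lr * mse_grad(w, support)
    return w

def meta_train(tasks, w=0.0, inner_lr=0.1, outer_lr=0.05, epochs=200):
    """Outer loop: update the shared initialization from query-set gradients."""
    for _ in range(epochs):
        meta_grad = 0.0
        for support, query in tasks:
            w_task = adapt(w, support, inner_lr)
            # First-order approximation: ignore the gradient through adapt().
            meta_grad += mse_grad(w_task, query)
        w -= outer_lr * meta_grad / len(tasks)
    return w

def make_task(slope):
    """A toy task: learn y = slope * x from two support and two query points."""
    support = [(x, slope * x) for x in (1.0, 2.0)]
    query = [(x, slope * x) for x in (0.5, 1.5)]
    return support, query

# Meta-train on three tasks, then adapt to an unseen one from its support set.
tasks = [make_task(s) for s in (1.0, 2.0, 3.0)]
w_meta = meta_train(tasks)

new_support, new_query = make_task(2.5)
w_new = adapt(w_meta, new_support)
```

After a single inner step on the new task's support set, the adapted weight fits the query set better than the shared initialization does, which is the behavior few-shot adaptation relies on.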


Extensive experiments on two real multi-property prediction datasets, Tox21 and SIDER, demonstrate the predictive power and stable performance of the proposed model when inferring task-specific molecular properties.

This repository provides the source code and datasets for the proposed work.

Article Link: https://doi.org/10.1016/j.eswa.2023.120005

Contact Information: for any questions about this work, contact uc2015241578@student.uc.pt or luistorres@dei.uc.pt.

Data Availability and Pre-Processing

The Tox21 and SIDER datasets are downloaded from the data archive (chem_dataset.zip) provided by Hu et al. (2020).

The raw data is pre-processed, and SMILES strings are converted into molecular graphs using the rdkit.Chem module.
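The SMILES-to-graph conversion can be illustrated with a minimal RDKit sketch. The helper name smiles_to_graph and the chosen features (atomic number per atom, bond order per edge) are assumptions made for illustration; the repository's own pre-processing may extract different or additional features.

```python
from rdkit import Chem

def smiles_to_graph(smiles):
    """Convert a SMILES string into simple node and edge lists."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    # One node per atom, described here only by its atomic number.
    atomic_numbers = [atom.GetAtomicNum() for atom in mol.GetAtoms()]

    # Each chemical bond becomes two directed edges so messages flow both ways.
    edge_index, bond_orders = [], []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        order = bond.GetBondTypeAsDouble()
        edge_index += [(i, j), (j, i)]
        bond_orders += [order, order]

    return atomic_numbers, edge_index, bond_orders

# Ethanol: two carbons and one oxygen joined by two single bonds.
atoms, edges, orders = smiles_to_graph("CCO")
```

The resulting lists can be wrapped in tensors (for example, a torch-geometric data object) before being passed to a GNN.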

The implementation is based on "Strategies for Pre-training Graph Neural Networks" (Hu et al., 2020).

Python Packages

We used the following Python packages for core development. We tested on Python 3.7.

- torch = 1.10.1
- torch-cluster = 1.5.9
- torch-geometric = 2.0.4
- torch-scatter = 2.0.9
- torch-sparse = 0.6.12
- torch-spline-conv = 1.2.1
- torchvision = 0.10.0
- vit-pytorch = 0.35.8
- scikit-learn = 1.0.2
- seaborn = 0.11.2
- scipy = 1.8.0
- numpy = 1.21.5
- tqdm = 4.50.0
- tensorflow = 2.8.0
- keras = 2.8.0
- tsnecuda = 3.0.1
- tqdm = 4.62.3
- matplotlib = 3.5.1
- pandas = 1.4.1
- networkx = 2.7.1
- rdkit
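To confirm that a local environment matches the pinned versions, a small check along the following lines may help. It is only a sketch: the helper check_versions and the PINNED subset are illustrative, and builds with local version suffixes (such as CUDA-specific torch wheels) would be reported as mismatches by this exact-match comparison.

```python
try:
    from importlib import metadata  # available from Python 3.8
except ImportError:  # Python 3.7 needs the importlib_metadata backport
    import importlib_metadata as metadata

# A subset of the pinned versions listed above (illustrative selection).
PINNED = {
    "torch": "1.10.1",
    "torch-geometric": "2.0.4",
    "vit-pytorch": "0.35.8",
    "scikit-learn": "1.0.2",
    "numpy": "1.21.5",
}

def check_versions(pinned):
    """Compare installed package versions against the expected ones."""
    report = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if installed == expected:
            report[name] = "ok"
        else:
            report[name] = f"mismatch (installed {installed}, expected {expected})"
    return report

for name, status in check_versions(PINNED).items():
    print(f"{name}: {status}")
```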


