Practical AI Project: ConfGf-SEGNN_Tox21

The Tox21 project aims to investigate whether new technologies can accelerate toxicity testing. The Tox21 dataset [1] contains experimental in vitro assay data on over 12,000 compounds. The main goal of this Practical AI project is to evaluate the SEGNN model for toxicity prediction on 3D Tox21 molecular data generated by the ConfGF model. Graph neural networks (GNNs) have become a state-of-the-art method for several types of biological tasks, and both ConfGF and SEGNN are GNNs. The key idea of a GNN is to generate node embeddings from local network neighborhoods: each node aggregates information from its neighbors using neural networks.
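To make the neighborhood-aggregation idea concrete, here is a minimal NumPy sketch of one generic message-passing layer with mean aggregation and a ReLU. The function name message_passing_step, the weight matrices w_self and w_neigh, and the choice of mean aggregation are illustrative assumptions; this is not the ConfGF or SEGNN layer, only the shared pattern both build on.

```python
import numpy as np


def message_passing_step(x, edge_index, w_self, w_neigh):
    """One generic GNN layer: mean-aggregate neighbor features, then combine with the node's own features."""
    src, dst = edge_index                      # an edge j -> i is stored as src=j, dst=i
    summed = np.zeros_like(x)
    np.add.at(summed, dst, x[src])             # sum incoming neighbor features per node
    deg = np.bincount(dst, minlength=x.shape[0]).reshape(-1, 1)
    neigh_mean = summed / np.maximum(deg, 1)   # mean; isolated nodes keep a zero vector
    return np.maximum(0.0, x @ w_self + neigh_mean @ w_neigh)  # ReLU


# Path graph 0 - 1 - 2, stored with both edge directions.
edge_index = np.array([[0, 1, 1, 2],
                       [1, 0, 2, 1]])
x = np.eye(3)                                  # one-hot node features
h = message_passing_step(x, edge_index, np.eye(3), np.eye(3))
# Node 1 averages the one-hot features of nodes 0 and 2, so h[1] is [0.5, 1.0, 0.5].
```

Stacking k such layers lets each node see its k-hop neighborhood; SEGNN replaces the plain matrix products with equivariant tensor products.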

2 Methods

2.1 ConfGF

Molecular conformation generation is a fundamental problem in computational chemistry: finding a set of 3D positions for all atoms, which determines the bond lengths and bond angles that define the molecular structure. A single molecule can adopt many such conformations, and they are not equivalent in their chemical and physical properties, because the properties of a molecule depend on its geometrical structure. ConfGF directly estimates the gradient fields of the log density of atomic coordinates and uses Langevin dynamics on those fields to generate stable conformations. It generates conformations in a single stage and preserves the roto-translation equivariance of conformations.
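The ConfGF paper obtains coordinate gradients from distance gradients through the chain rule, which is what makes the sampler roto-translation equivariant. The sketch below, assuming NumPy and a toy Gaussian distance score in place of ConfGF's trained score network, shows that chain rule and a plain Langevin update x <- x + (step/2) * score + sqrt(step) * noise with a single noise level (ConfGF itself anneals over several levels). The functions coordinate_score and langevin_sample and all parameter values are hypothetical.

```python
import numpy as np


def coordinate_score(pos, edge_index, dist_score):
    """Chain rule from distance scores s_ij (gradient of log p w.r.t. d_ij) to coordinate scores:
    score_i = sum over edges (i, j) of s_ij * (x_i - x_j) / d_ij."""
    src, dst = edge_index
    diff = pos[src] - pos[dst]                          # (E, 3) relative positions
    dist = np.linalg.norm(diff, axis=1, keepdims=True)  # (E, 1) interatomic distances
    score = np.zeros_like(pos)
    np.add.at(score, src, dist_score[:, None] * diff / dist)
    return score


def langevin_sample(pos, edge_index, target_dist, sigma=0.1, step_size=1e-3, n_steps=500, seed=0):
    """Plain Langevin dynamics: x <- x + (step / 2) * score + sqrt(step) * noise.
    A toy Gaussian distance score s(d) = -(d - d0) / sigma**2 stands in for a trained network."""
    rng = np.random.default_rng(seed)
    src, dst = edge_index
    for _ in range(n_steps):
        d = np.linalg.norm(pos[src] - pos[dst], axis=1)
        dist_score = -(d - target_dist) / sigma**2
        noise = rng.standard_normal(pos.shape)
        pos = pos + 0.5 * step_size * coordinate_score(pos, edge_index, dist_score) + np.sqrt(step_size) * noise
    return pos


# Two atoms started 3.0 apart are pulled towards the toy target distance of 1.5.
bond = np.array([[0, 1], [1, 0]])                       # both directions of one bond
sample = langevin_sample(np.array([[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]]), bond, target_dist=1.5)
final_dist = np.linalg.norm(sample[0] - sample[1])      # fluctuates around 1.5

# Equivariance check: rotating the input rotates the scores; translating it leaves them unchanged.
rng = np.random.default_rng(1)
pos = rng.standard_normal((4, 3))
edges = np.array([(i, j) for i in range(4) for j in range(4) if i != j]).T
dist_score = rng.standard_normal(edges.shape[1])        # distances are invariant, so the scores are too
theta = 0.7
rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                [np.sin(theta), np.cos(theta), 0.0],
                [0.0, 0.0, 1.0]])
score = coordinate_score(pos, edges, dist_score)
score_rotated = coordinate_score(pos @ rot.T, edges, dist_score)   # equals score @ rot.T
score_shifted = coordinate_score(pos + 1.0, edges, dist_score)     # equals score
```

Because the score is built only from distances and relative vectors, any rotation or translation of the starting conformation is carried through the whole sampling trajectory.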

2.2 3D Conformation Dataset Generation Workflow

The ConfGF model takes a 2D molecular graph as input and outputs 3D conformations. First, the Tox21 SMILES from the SDF data were converted into randomized SMILES, which were then transformed into 2D molecular graphs and passed to the ConfGF solver; the solver applies Langevin dynamics to compute 3D positions. Generating the resulting 28 GB 3D Tox21 dataset took several days on 4 GPUs. The generated 3D dataset includes the atom type, bond edge index, edge index, edge length, edge name, edge order, edge type, a flag for whether the edge is a bond, the number of generated positions, the number of trajectory positions, the positions, the generated positions, the position trajectory, the number of generated edges, the number of recovered edges, the RDKit molecule (rdmol), and the SMILES string.
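The edge order and bond fields suggest an extended graph: ConfGF connects not only bonded atoms but also atoms two and three bonds apart, so that bond angles and torsions are constrained as well. A minimal pure-Python sketch, assuming edge order means the number of bonds on the shortest path and a maximum order of 3; the function extend_graph and the chain molecule are hypothetical, and ConfGF's own preprocessing may differ in details.

```python
from collections import deque


def extend_graph(num_atoms, bonds, max_order=3):
    """Map each directed atom pair (i, j) within max_order bonds of each other to its hop count."""
    neighbors = [[] for _ in range(num_atoms)]
    for i, j in bonds:
        neighbors[i].append(j)
        neighbors[j].append(i)

    edge_order = {}
    for start in range(num_atoms):
        hops = {start: 0}
        queue = deque([start])
        while queue:                                  # breadth-first search from `start`
            u = queue.popleft()
            if hops[u] == max_order:
                continue                              # stop expanding at max_order hops
            for v in neighbors[u]:
                if v not in hops:
                    hops[v] = hops[u] + 1
                    queue.append(v)
        for v, h in hops.items():
            if h > 0:
                edge_order[(start, v)] = h
    return edge_order


# Five heavy atoms in a chain: 0-1-2-3-4.
order = extend_graph(5, [(0, 1), (1, 2), (2, 3), (3, 4)])
is_bond = {pair: h == 1 for pair, h in order.items()}   # order-1 edges are chemical bonds
# (0, 3) is a 3-hop (torsion) edge; (0, 4) is 4 hops apart and is left out.
```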

2.3 TOX21 Dataset

The TOX21 3D dataset class is built on PyTorch Geometric's InMemoryDataset. Around 53 atom types occur in the dataset, and each atom type is assigned a type index, for example H as 0, C as 1, and N as 2. The type index, atomic number, aromaticity, and sp, sp2, and sp3 hybridization flags are incorporated into the node features x. Each sample contains the features x, positions, edge index, edge attributes, node attributes, edge distances, and targets (labels), among other fields. Edge attributes, node attributes, and edge distances are computed from the edge index, positions, and attribute irreps using e3nn's spherical harmonics. Inside the TOX21 dataset class, the data are partitioned into training, validation, and test sets, plus cross-validation training and test sets, according to the annotations in the Tox21 CSV file.
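A sketch of how one sample's node features and geometric attributes could be assembled, assuming NumPy, a one-hot encoding of the type index, and node attributes taken as the mean of the attributes of incoming edges (an assumption, since the exact aggregation is not stated above). The four-entry ATOM_TYPES table, the function names, and the restriction to degrees l <= 1 are illustrative; the real pipeline uses about 53 atom types and e3nn's spherical harmonics up to the configured degree.

```python
import numpy as np

# Hypothetical, truncated vocabulary for illustration only; the real dataset has about 53 atom types.
ATOM_TYPES = {"H": 0, "C": 1, "N": 2, "O": 3}
HYBRIDIZATIONS = ("SP", "SP2", "SP3")


def node_features(symbol, atomic_number, is_aromatic, hybridization):
    """Assumed layout: one-hot atom type, atomic number, aromatic flag, one-hot sp/sp2/sp3."""
    one_hot = np.zeros(len(ATOM_TYPES))
    one_hot[ATOM_TYPES[symbol]] = 1.0
    hyb = np.array([float(hybridization == h) for h in HYBRIDIZATIONS])
    return np.concatenate([one_hot, [float(atomic_number), float(is_aromatic)], hyb])


def geometric_attributes(pos, edge_index):
    """Edge attributes, node attributes and edge distances from positions, using l <= 1 only."""
    src, dst = edge_index
    rel = pos[src] - pos[dst]
    dist = np.linalg.norm(rel, axis=1, keepdims=True)
    # Real spherical harmonics in "component" normalization: Y_0 = 1, Y_1 = sqrt(3) * unit vector.
    # Components are in x, y, z order here; e3nn's own function may order them differently.
    edge_attr = np.concatenate([np.ones_like(dist), np.sqrt(3.0) * rel / dist], axis=1)
    # Node attribute (assumed): mean of the attributes of the edges pointing into each node.
    node_attr = np.zeros((pos.shape[0], edge_attr.shape[1]))
    np.add.at(node_attr, dst, edge_attr)
    deg = np.bincount(dst, minlength=pos.shape[0]).reshape(-1, 1)
    node_attr /= np.maximum(deg, 1)
    return edge_attr, node_attr, dist


feat = node_features("C", 6, True, "SP2")    # aromatic sp2 carbon
pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
edge_index = np.array([[0, 1], [1, 0]])
edge_attr, node_attr, dist = geometric_attributes(pos, edge_index)
```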

2.4 SEGNN

SEGNN is a steerable E(3)-equivariant graph neural network built on non-linear group convolutions. The basic idea is to represent and process node features with steerable vectors and their equivariant transformations. In steerable message passing, the Clebsch-Gordan tensor product steers the message and update functions with geometric information such as 3D positions [4].

Spherical Harmonics and Steerable Vectors

Spherical harmonics $Y^{(\ell)}_m$ are indexed by two integers, the azimuthal quantum number $\ell \geq 0$ and the magnetic quantum number $m$, which must satisfy $-\ell \leq m \leq \ell$. Combining these numbers produces a variety of steerable vector spaces: $Y^{(\ell)} = (Y^{(\ell)}_{-\ell}, \dots, Y^{(\ell)}_{\ell})$ represents the specific set of $2\ell + 1$ spherical harmonics of degree $\ell$, which spans a type-$\ell$ steerable vector space. $D^{(\ell)}(g)$ denotes the Wigner-D matrices, through which a group element $g$ acts on a type-$\ell$ vector space of any dimension $2\ell + 1$; for the spherical harmonics, $Y^{(\ell)}(R\mathbf{x}) = D^{(\ell)}(R)\, Y^{(\ell)}(\mathbf{x})$ for every rotation $R$.
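Steerability can be checked numerically. The NumPy sketch below, assuming the standard real degree-2 spherical harmonics with their shared constant dropped, fits the 5 x 5 matrix that maps Y(v) to Y(Rv) for a random rotation R and confirms that it is orthogonal and also works for points outside the fit, so it acts as the type-2 Wigner-D matrix in this basis. It also counts the dimension of an irreps string as the sum of multiplicity times (2l + 1). The helper names are hypothetical, and irreps_dim only handles strings with explicit multiplicities.

```python
import numpy as np


def unit_vectors(rng, n):
    v = rng.standard_normal((n, 3))
    return v / np.linalg.norm(v, axis=1, keepdims=True)


def real_sh_l2(v):
    """The 5 real spherical harmonics of degree 2 (m = -2..2) for unit vectors v of shape (N, 3).
    The shared constant 1 / (4 sqrt(pi)) is dropped; the set stays orthonormal up to that factor."""
    x, y, z = v[:, 0], v[:, 1], v[:, 2]
    c = np.sqrt(15.0)
    return np.stack([2 * c * x * y, 2 * c * y * z, np.sqrt(5.0) * (3 * z**2 - 1),
                     2 * c * x * z, c * (x**2 - y**2)], axis=1)


def random_rotation(rng):
    q, r = np.linalg.qr(rng.standard_normal((3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]          # flip one axis so that det = +1 (a proper rotation)
    return q


def estimate_wigner_d_l2(rot, rng, n_points=50):
    """Fit the 5x5 matrix D with Y(R v) = D Y(v); it exists because degree-2 harmonics are steerable."""
    v = unit_vectors(rng, n_points)
    m, *_ = np.linalg.lstsq(real_sh_l2(v), real_sh_l2(v @ rot.T), rcond=None)
    return m.T


def irreps_dim(irreps):
    """Dimension of an e3nn-style irreps string with explicit multiplicities, e.g. '8x0e+8x1o'."""
    total = 0
    for term in irreps.split("+"):
        mul, ir = term.strip().split("x")
        total += int(mul) * (2 * int(ir[:-1]) + 1)   # a type-l block has 2l + 1 components
    return total


rng = np.random.default_rng(0)
R = random_rotation(rng)
D = estimate_wigner_d_l2(R, rng)
fresh = unit_vectors(rng, 10)                  # points not used in the fit
```

For example, spherical harmonics up to degree 3 give the irreps 1x0e+1x1o+1x2e+1x3o, which have (3 + 1)^2 = 16 components.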

Message Passing Framework

The workhorse of SEGNN message passing is the Clebsch-Gordan (CG) tensor product, which steers the message and update functions with geometric information such as positions [4]. It is a bilinear operator $\otimes_{cg}$ that combines two steerable input vectors of types $\ell_1$ and $\ell_2$ and returns another steerable vector of type $\ell$. Let $\tilde{\mathbf{h}}^{(\ell)} \in V_\ell = \mathbb{R}^{2\ell+1}$ denote a steerable vector of type $\ell$ and $h^{(\ell)}_m$ its components. The CG tensor product is given by

$$\left(\tilde{\mathbf{h}}^{(\ell_1)} \otimes^{\mathbf{w}}_{cg} \tilde{\mathbf{h}}^{(\ell_2)}\right)^{(\ell)}_{m} = w_{\ell_1 \ell_2 \ell} \sum_{m_1=-\ell_1}^{\ell_1} \sum_{m_2=-\ell_2}^{\ell_2} C^{(\ell, m)}_{(\ell_1, m_1)(\ell_2, m_2)}\, h^{(\ell_1)}_{m_1} h^{(\ell_2)}_{m_2}$$

where $C^{(\ell, m)}_{(\ell_1, m_1)(\ell_2, m_2)}$ are the Clebsch-Gordan coefficients and $w_{\ell_1 \ell_2 \ell}$ is a learnable weight for each combination of types.

Each SEGNN message-passing layer contains two message layers and two update layers. The first message layer computes the CG tensor product with the input irreps as input dimension, the hidden irreps as output dimension, and the node attribute irreps as the second input, followed by a Swish gate (SiLU on the scalar features and sigmoid gating of the higher-degree features). The second message layer does the same with the hidden irreps as both input and output irreps and the edge attribute irreps as the second input. The first update layer computes a CG tensor product with a Swish gate over the corresponding input irreps, hidden irreps, node attribute irreps, and edge attribute irreps. The second update layer computes the CG tensor product of those irreps without an activation.
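For degree-1 inputs in a Cartesian basis, the Clebsch-Gordan paths reduce to familiar operations: 1 x 1 -> 0 is proportional to the dot product and 1 x 1 -> 1 to the cross product. The NumPy sketch below builds a toy message from a scalar and a vector node feature and a vector attribute, applies a Swish gate (SiLU on the scalar, a sigmoid-activated extra scalar multiplying the vector), and checks equivariance under a rotation. The weights, path names, and functions are hypothetical; real SEGNN layers use e3nn tensor products over full irreps with many channels and higher degrees.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def toy_tp_message(h_s, h_v, e_v, w):
    """Weighted CG-style products of a node feature (scalar h_s, vector h_v) with an attribute vector e_v.
    Each path l1 x l2 -> l has its own weight, playing the role of w_{l1 l2 l}."""
    dot = np.dot(h_v, e_v)                                   # 1 x 1 -> 0 (up to normalization)
    out_scalar = w["0x0->0"] * h_s + w["1x1->0"] * dot       # 0 x 0 -> 0 uses the attribute's l=0 part (= 1)
    gate_scalar = w["gate"] * dot                            # extra scalar used only as a gate
    out_vector = w["0x1->1"] * h_s * e_v + w["1x1->1"] * np.cross(h_v, e_v)  # cross: 1 x 1 -> 1
    return out_scalar, gate_scalar, out_vector


def swish_gate(scalar, gate, vector):
    """SiLU on the scalar channel; the vector is scaled by sigmoid(gate), which keeps it equivariant."""
    return silu(scalar), sigmoid(gate) * vector


w = {"0x0->0": 0.5, "1x1->0": 1.0, "gate": -0.3, "0x1->1": 0.8, "1x1->1": 0.2}  # hypothetical weights
h_s, h_v, e_v = 0.7, np.array([0.2, -1.0, 0.5]), np.array([1.0, 0.3, -0.4])
theta = 1.1
rot = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                [np.sin(theta), np.cos(theta), 0.0],
                [0.0, 0.0, 1.0]])
s, v = swish_gate(*toy_tp_message(h_s, h_v, e_v, w))
s_rot, v_rot = swish_gate(*toy_tp_message(h_s, rot @ h_v, rot @ e_v, w))
# Rotating the inputs leaves s unchanged and rotates v: s_rot == s and v_rot == rot @ v.
```

The gate is why SEGNN applies nonlinearities only to scalars: applying SiLU directly to each vector component would break equivariance, whereas scaling a vector by an invariant scalar does not.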

3 Experiments

The models come from the SEGNN library, which offers two SEGNN variants: one with a single embedding layer (an O3 tensor product), and one with two embedding layers (an O3 tensor product with a Swish gate, followed by an O3 tensor product). Both have a configurable number of message-passing layers and can predict at the graph level or the node level; this experiment uses graph-level prediction. Graph prediction includes two pre-pooling layers and two post-pooling layers. The input irreps have 59 dimensions, and the output irreps have 12, one per Tox21 toxicity task.

Training

Training consisted of 5-fold cross-validation followed by a final training run with separate validation and test sets. The loss is binary cross-entropy with logits (BCEWithLogitsLoss), optimized with Adam and a MultiStepLR learning-rate scheduler. Graph data containing the node features x, positions, edge index, and so on are fed into the SEGNN model. Because the Tox21 dataset has missing labels, both the model output and the processed targets are masked so that missing entries are excluded, and the loss is computed only on the masked outputs and targets. The model was trained for 20 to 60 epochs.

Hyperparameter Search

The hyperparameters include the maximum degree $\ell_{\max}$ of the spherical harmonics used for the attributes, which determines the edge attribute irreps and node attribute irreps, and the $\ell_{\max}$ of the hidden irreps, which together with the node attribute irreps and the hidden feature irreps determines the hidden irreps. The best result came from $\ell_{\max} = 3$ for the attributes, $\ell_{\max} = 2$ for the hidden irreps (together with the hidden feature irreps), 5 message-passing layers, a weight decay of 1e-8, and a learning rate of 1e-4.

4 Results

The metric is the area under the ROC curve (ROC AUC). The new model reaches an average ROC AUC of 0.71 on the test set, and the best ROC AUC of 0.74 was achieved in cross-validation by both the old and new models.
– New model: the model with two embedding layers. The first embedding layer is an O3 tensor product (Clebsch-Gordan tensor product) with a Swish gate (SiLU or sigmoid activation), and the second is an O3 tensor product layer.
– Old model: the model with a single Clebsch-Gordan embedding layer.
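Both the training loss and the evaluation metric have to skip missing Tox21 labels. A NumPy sketch, assuming missing labels are stored as NaN and that the reported average ROC AUC is the mean of per-task AUCs: masked_bce_with_logits computes what torch.nn.BCEWithLogitsLoss (default mean reduction) gives on the masked entries, and roc_auc uses the rank (Mann-Whitney) form, which matches scikit-learn's roc_auc_score. Function names and the small arrays are illustrative.

```python
import numpy as np


def masked_bce_with_logits(logits, targets):
    """Mean binary cross-entropy over observed labels only; NaN marks a missing Tox21 label."""
    mask = ~np.isnan(targets)
    z, y = logits[mask], targets[mask]
    # Numerically stable form of -[y log sigmoid(z) + (1 - y) log(1 - sigmoid(z))].
    loss = np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))
    return float(loss.mean())


def roc_auc(scores, labels):
    """Probability that a random positive is scored above a random negative (ties count 1/2)."""
    pos, neg = scores[labels == 1], scores[labels == 0]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


def mean_task_auc(logits, targets):
    """Average ROC AUC over tasks, each computed on that task's observed labels only."""
    aucs = []
    for task in range(targets.shape[1]):
        mask = ~np.isnan(targets[:, task])
        y = targets[mask, task]
        if len(np.unique(y)) < 2:
            continue                      # AUC is undefined without both classes
        aucs.append(roc_auc(logits[mask, task], y))
    return float(np.mean(aucs))


# 4 molecules x 2 tasks; NaN = missing label.
targets = np.array([[1.0, np.nan], [0.0, 1.0], [np.nan, 0.0], [1.0, 0.0]])
logits = np.array([[2.0, 0.5], [-1.0, 0.3], [0.0, 0.8], [0.5, -0.2]])
loss = masked_bce_with_logits(logits, targets)
auc = mean_task_auc(logits, targets)
```

On these arrays, task 0 ranks both positives above its negative (AUC 1.0) and task 1 ranks its positive between the two negatives (AUC 0.5), so the mean is 0.75.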