Prototype for a Category Theory-based GNN Library

Primary language: Python. License: MIT.

CatGNN 🐱


CatGNN is a prototype of a category theory-based GNN library. It implements the ideas of "Graph Neural Networks are Dynamic Programmers" (Dudzik and Veličković, 2022) and was submitted as coursework for the L45: Representation Learning on Graphs and Networks course at Cambridge.

The goal of CatGNN is to provide a generic GNN template built on a new set of primitives from category theory and abstract algebra. As in PyTorch Geometric, the user only needs to implement these primitives to define any MPNN.


At some point CatGNN should become a Python package installable through pip and conda. Until then, you can follow the instructions below.

CatGNN was developed on Python 3.9.12, but should be fine on Python 3.7+.

To test on GPU, upload the gpu_tests.ipynb notebook to Google Colab, upload the source files to Google Drive, mount Google Drive in the notebook and follow the instructions there.

To test locally on CPU, you can use environment.yml with conda as follows:

$ git clone https://github.com/KaroliShp/CatGNN.git
$ cd CatGNN
$ conda env create -f environment.yml --name catgnn
$ conda activate catgnn
$ pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+cpu.html
$ export PYTHONPATH="$PWD"

To test locally on GPU, replace cpu with your CUDA version tag for the relevant packages (for example, in the wheel URL above).
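As a small illustration of what "changing cpu to your CUDA version" means for the wheel URL, the hypothetical helper below (not part of CatGNN) builds the find-links URL from a PyTorch version and a compute-platform tag. The tag cu113 is only an example; which tag applies depends on your own PyTorch installation.

```python
def pyg_wheel_index(torch_version: str, compute_platform: str) -> str:
    """Build the find-links URL used with `pip install torch-scatter torch-sparse -f ...`.

    Hypothetical helper for illustration only; it just formats the URL pattern
    that appears in the install instructions above.
    """
    return f"https://data.pyg.org/whl/torch-{torch_version}+{compute_platform}.html"


# CPU-only URL, identical to the one used in the instructions above.
cpu_url = pyg_wheel_index("1.11.0", "cpu")
# Example GPU URL, assuming a CUDA 11.3 build of PyTorch 1.11.0.
gpu_url = pyg_wheel_index("1.11.0", "cu113")
print(cpu_url)
print(gpu_url)
```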

Implementation details

For an in-depth explanation of the library, refer to the mini-project report, report.pdf. For benchmarking, see the gpu_tests.ipynb notebook. You can test most of the code (as far as custom user layers are testable) using pytest. We generally use Python Black for code formatting. For development history, see the Issues and Projects tabs.

Basic example

We can implement a basic Message-Passing GNN (MPNN) layer that applies a simple linear transformation to sender features and uses standard pullback & pushforward operators as follows:

import torch
import torch_scatter

from catgnn.integral_transform.mpnn_2 import BaseMPNNLayer_2

# Custom layer must extend one of the base classes (BaseMPNNLayer_2)
class BasicMPNNLayer(BaseMPNNLayer_2):
    def __init__(self, in_dim, out_dim):
        # Initialise the base class before attaching the linear layer
        super().__init__()
        self.mlp = torch.nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, V, E, X):
        # Perform integral transform by passing V, E and X
        return self.transform_backwards(V, E, X, kernel_factor=False)

    def define_pullback(self, f):
        def pullback(E):
            # s*(e) = f(s(e)) : E -> V
            return f(self.s(E))

        return pullback

    def define_kernel(self, pullback):
        def kernel(E):
            # k(e) : E -> R
            return self.mlp(pullback(E))

        return kernel

    def define_pushforward(self, kernel):
        def pushforward(V):
            # t_*(v) : V -> N[R]
            E, bag_indices = self.t_1(V)
            return kernel(E), bag_indices

        return pushforward

    def define_aggregator(self, pushforward):
        def aggregator(V):
            # \oplus : V -> R
            edge_messages, bag_indices = pushforward(V)
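            # Sum the messages in each receiver's bag; scatter_add works along
            # the last dimension, so transpose to (features, edges) and back
            aggregated = torch_scatter.scatter_add(
                edge_messages.T, bag_indices.repeat(edge_messages.T.shape[0], 1)
            ).T
            return aggregated[V]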

        return aggregator

    def update(self, X, output):
        return output
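To make the roles of the four primitives concrete, here is a dependency-free sketch of the same pullback, kernel, pushforward and aggregator pipeline on a toy graph. It does not use CatGNN or PyTorch: the graph, the feature values and the scalar weight (standing in for the linear layer) are all assumptions chosen for illustration, and the function names only mirror the methods above.

```python
# Toy directed graph: each edge is a (sender, receiver) pair.
edges = [(0, 1), (0, 2), (1, 2), (2, 0)]
num_nodes = 3

# Scalar node features f : V -> R (illustrative values).
features = [1.0, 2.0, 3.0]


def s(e):
    """Sender node of edge index e."""
    return edges[e][0]


def t(e):
    """Receiver node of edge index e."""
    return edges[e][1]


def pullback(f):
    # s^*(f)(e) = f(s(e)): lift node features onto edges via the sender map.
    return lambda e: f(s(e))


def kernel(edge_fn, weight=2.0):
    # Transform each edge message; a fixed scalar weight stands in for the MLP.
    return lambda e: weight * edge_fn(e)


def pushforward(edge_fn):
    # t_*(g)(v): the bag (multiset) of messages on edges whose receiver is v.
    return lambda v: [edge_fn(e) for e in range(len(edges)) if t(e) == v]


def aggregator(bag_fn):
    # Reduce each bag to a single value per node by summation.
    return lambda v: sum(bag_fn(v))


def f(v):
    return features[v]


# Compose the four steps, then evaluate the result at every node.
layer = aggregator(pushforward(kernel(pullback(f))))
output = [layer(v) for v in range(num_nodes)]
print(output)  # [6.0, 2.0, 6.0]
```

Node 2 receives messages from nodes 0 and 1 (2.0 + 4.0), while nodes 0 and 1 each receive a single message, which gives the printed result. Swapping sum for max or another reduction in aggregator changes only the last step and leaves the other three untouched.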