
TGLite - A Framework for Temporal GNNs

TGLite is a lightweight framework that provides core abstractions and building blocks for practitioners and researchers to implement efficient TGNN models. TGNNs (Temporal Graph Neural Networks) learn node embeddings for graphs that change dynamically over time by jointly aggregating structural and temporal information from neighboring nodes. TGLite centers on an abstraction called a TBlock, which represents the temporal graph dependencies of a batch when aggregating from neighbors and explicitly captures temporal details such as edge timestamps. On top of TBlocks, TGLite provides composable operators and optimizations. Compared to the prior state-of-the-art TGL framework, TGLite reduces end-to-end training time by up to 3x.

Figure: End-to-end training epoch time comparison on an Nvidia A100 GPU.

Installation

See our documentation for instructions on how to install the TGLite binaries, as well as examples and references for supported functionality. To install from source or set up a local development environment, see the Building from source section, which also explains how to run the examples.
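If the prebuilt binaries are published on PyPI, installation is a single command (the package name tglite below is an assumption; check the documentation for the exact name and supported PyTorch/CUDA versions):

pip install tglite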

Getting Started

TGLite is currently designed to be used with PyTorch as a training backend, typically with GPU devices. A TGNN model can be defined and trained in the usual way using PyTorch, with the computations constructed using a mix of PyTorch functions and operators/optimizations from TGLite. Below is a simple example (not a real network architecture, just for demonstration purposes):

import torch
import tglite as tg

class TGNN(torch.nn.Module):
    def __init__(self, ctx: tg.TContext, dim_node=100, dim_time=100):
        super().__init__()
        self.ctx = ctx
        self.linear = torch.nn.Linear(dim_node + dim_time, dim_node)
        # sample up to 10 of the most recent temporal neighbors per node
        self.sampler = tg.TSampler(num_nbrs=10, strategy='recent')
        self.encoder = tg.nn.TimeEncode(dim_time)

    def forward(self, batch: tg.TBatch):
        blk = batch.block(self.ctx)       # build TBlock of temporal dependencies
        blk = tg.op.dedup(blk)            # optimization: remove duplicate node/time pairs
        blk = self.sampler.sample(blk)    # attach sampled neighbors to the block
        blk.srcdata['h'] = blk.srcfeat()  # load raw features of source nodes
        return tg.op.aggregate(blk, self.compute, key='h')

    def compute(self, blk: tg.TBlock):
        feats = self.encoder(blk.time_deltas())            # encode edge time gaps
        feats = torch.cat([blk.srcdata['h'], feats], dim=1)
        embeds = self.linear(feats)
        embeds = tg.op.edge_reduce(blk, embeds, op='sum')  # sum over neighbors
        return torch.relu(embeds)

graph = tg.from_csv(...)
ctx = tg.TContext(graph)
model = TGNN(ctx)
train(model)
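The train() call above is left undefined. Below is a minimal sketch of what such a loop might look like; it assumes the caller supplies an iterable of (tg.TBatch, label) pairs and uses a stand-in readout purely for demonstration, since neither is part of TGLite's API:

import torch

def train(model, batches, epochs=10, lr=1e-4):
    # Hypothetical training loop: `batches` is assumed to yield
    # (tg.TBatch, label tensor) pairs; producing them depends on the
    # dataset and task and is not shown here.
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    for _ in range(epochs):
        for batch, labels in batches:
            opt.zero_grad()
            embeds = model(batch)          # [num_nodes, dim_node]
            logits = embeds.sum(dim=1)     # stand-in readout, demo only
            loss = loss_fn(logits, labels.float())
            loss.backward()
            opt.step()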

The example model first constructs the graph dependencies (a TBlock) for the nodes in the current batch of edges. The dedup() optimization is applied before sampling 10 recent neighbors. Node embeddings are then computed by simply combining node and time features, applying a linear layer, and summing across neighbors. More complex computations and aggregations, such as the temporal self-attention commonly used in TGNNs, can be defined using the provided building blocks, as sketched below.
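For instance, the kind of temporal self-attention used by models like TGAT or TGN can be expressed in plain PyTorch. The sketch below is independent of TGLite's block layout and assumes neighbor features and time encodings have already been gathered into dense tensors; all names and shapes here are illustrative, not TGLite APIs:

import torch

class TemporalAttention(torch.nn.Module):
    # Pure-PyTorch sketch of temporal self-attention: each target node
    # attends over its sampled neighbors, with time encodings concatenated
    # onto both queries and keys/values.
    def __init__(self, dim_node=100, dim_time=100, num_heads=2):
        super().__init__()
        dim = dim_node + dim_time  # must be divisible by num_heads
        self.attn = torch.nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, node_feat, node_time, nbr_feat, nbr_time):
        # node_feat: [B, dim_node], node_time: [B, dim_time]
        # nbr_feat:  [B, K, dim_node], nbr_time: [B, K, dim_time]
        query = torch.cat([node_feat, node_time], dim=-1).unsqueeze(1)  # [B, 1, D]
        keyval = torch.cat([nbr_feat, nbr_time], dim=-1)                # [B, K, D]
        out, _ = self.attn(query, keyval, keyval)
        return out.squeeze(1)  # [B, D] attended embedding per target node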

Publication

If you find TGLite useful, please consider citing the following publication:

@inproceedings{wang2024tglite,
  author = {Wang, Yufeng and Mendis, Charith},
  title = {TGLite: A Lightweight Programming Framework for Continuous-Time Temporal Graph Neural Networks},
  year = {2024},
  booktitle = {Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2},
  doi = {10.1145/3620665.3640414}
}