/torch-gauge

A light-weight PyTorch extension for equivariant deep learning

Primary LanguagePythonGNU General Public License v3.0GPL-3.0

Torch-Gauge

A light-weight PyTorch extension for efficient equivariant deep learning.

About

Torch-Gauge is a library to boost geometric learning on physical data structures with Lie-group symmetry and beyond two-body interactions. The library is designed to be specifically optimized for training and inference on GPUs, with mini-batch training utilities natively supported.

Usage

Torch-Gauge uses padded Verlet list as the core for manipulating relational data, which enables compact and highly customizable implementation of geometric learning models.

As an illustration, SchNet 's interaction module can be concisely implemented:

import torch
from torch.nn import Linear, Parameter
from torch_gauge.verlet_list import VerletList
from torch_gauge.nn import SSP

class SchNetLayer(torch.nn.Module):
  def __init__(self, num_features):
    super().__init__()
    _nf = num_features
    self.gamma = 10.0
    self.rbf_centers = Parameter(torch.linspace(0.1, 30.1, 300), requires_grad=False)
    self.cfconv = torch.nn.Sequential(
        Linear(300, _nf, bias=False), SSP(), Linear(_nf, _nf, bias=False), SSP()
    )
    self.pre_conv = Linear(_nf, _nf)
    self.post_conv = torch.nn.Sequential(Linear(_nf, _nf), SSP(), Linear(_nf, _nf))

  def forward(self, vl: VerletList, l: int):
    xyz = vl.ndata["xyz"]
    pre_conv = self.pre_conv(vl.ndata[f"atomic_{l}"])
    d_ij = (vl.query_src(xyz) - xyz.unsqueeze(1)).norm(dim=2, keepdim=True)
    filters = self.cfconv(
      torch.exp(-self.gamma * (d_ij - self.rbf_centers.view(1, 1, -1)).pow(2))
    )
    conv_out = (filters * vl.query_src(pre_conv) * vl.edge_mask.unsqueeze(2)).sum(1)
    vl.ndata[f"atomic_{l+1}"] = vl.ndata[f"atomic_{l}"] + self.post_conv(conv_out)
    return vl.ndata[f"atomic_{l+1}"]

With the SphericalTensor support we can feasibly extend it to SE(3)-equivariant convolutions:

import torch
from torch_gauge.nn import SSP, IELin
from torch.nn import Linear, Parameter
from torch_gauge.verlet_list import VerletList
from torch_gauge.o3.spherical import SphericalTensor
from torch_gauge.o3.rsh import RSHxyz
from torch_gauge.geometric import poly_env
from torch_gauge.o3.clebsch_gordan import LeviCivitaCoupler

class SE3Layer(torch.nn.Module):
  def __init__(self, num_features):
    super().__init__()
    _nf = num_features
    self.rbf_freqs = Parameter(torch.arange(16), requires_grad=False)
    self.rsh_mod = RSHxyz(max_l=1)
    self.coupler = LeviCivitaCoupler(torch.LongTensor([_nf, _nf]))
    self.filter_gen = torch.nn.Sequential(
        Linear(16, _nf, bias=False), SSP(), Linear(_nf, _nf, bias=False)
    )
    self.pre_conv = IELin([_nf, _nf], [_nf, _nf])
    self.post_conv = torch.nn.ModuleList(
      [IELin([2*_nf, 3*_nf], [_nf, _nf]), SSP(), IELin([_nf, _nf], [_nf, _nf])]
    )

  def forward(self, vl: VerletList, l: int):
    r_ij = vl.query_src(vl.ndata["xyz"]) - vl.ndata["xyz"].unsqueeze(1)
    d_ij = r_ij.norm(dim=2, keepdim=True)
    r_ij = r_ij / d_ij
    feat_in: SphericalTensor = vl.ndata[f"atomic_{l}"]
    pre_conv = vl.query_src(self.pre_conv(feat_in))
    filters_radial = self.filter_gen(
      torch.sin(d_ij / 5 * self.rbf_freqs.view(1, 1, -1)) / d_ij * poly_env(d_ij/5)
    )
    filters = pre_conv.self_like(self.rsh_mod(r_ij).ten.unsqueeze(-1).mul(filters_radial).flatten(2, 3))
    coupling_out = self.post_conv[0](self.coupler(pre_conv, filters, overlap_out=False))
    conv_out = feat_in.self_like(coupling_out.ten.mul(vl.edge_mask.unsqueeze(2)).sum(1))
    conv_out = conv_out.scalar_mul(self.post_conv[1](conv_out.invariant()))
    conv_out.ten = self.post_conv[2](conv_out).ten
    vl.ndata[f"atomic_{l+1}"] = feat_in + conv_out
    return vl.ndata[f"atomic_{l+1}"]

Torch-Gauge's Verlet list also supports generating higher-order views (triplets, quartets, etc.) to streamline the implemention of models with specialized non-2body message passing algorithms.

Setups

From pip

pip install torch-gauge

From Source

Once PyTorch is installed, running

pip install versioneer && versioneer install
pip install -e .

Project Structure

  • torch_gauge High-level python functions
    • torch_gauge/o3 O(3) group algebra functionals and data structures
    • torch_gauge/nn.py Tensorial neural network building blocks
    • torch_gauge/verlet_list.py Verlet neighbor-list operations for representing relational data
    • torch_gauge/geometric.py contains geometric and algebra operations with autograd
    • torch_gauge/models contains exemplary implementations of GNN variants and descriptors
    • torch_gauge/tests contains pytests

Test

make test

Questions?

Please submit a question/issue or email me.