escnn is a PyTorch extension for equivariant deep learning. escnn is the successor of the e2cnn library, which only supported planar isometries. In contrast, escnn supports steerable CNNs equivariant to both 2D and 3D isometries, as well as equivariant MLPs.
Documentation | Paper ICLR 22 | MSc Thesis Gabriele | e2cnn library |
Paper NeurIPS 19 | PhD Thesis Maurice | e2cnn experiments |
If you prefer JAX, check out escnn_jax, a fork of our library!
Equivariant neural networks guarantee a specified transformation behavior of their feature spaces under transformations of their input.
For instance, classical convolutional neural networks (CNNs) are by design equivariant to translations of their input.
This means that a translation of an image leads to a corresponding translation of the network's feature maps.
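As a small illustration of this property (a plain NumPy sketch, independent of escnn), the snippet below checks that a 2D cross-correlation, the operation computed by a convolutional layer, commutes with translations. It assumes periodic boundary conditions so that translations can be implemented exactly with `np.roll`; the helper `circular_correlate` is hypothetical and only for illustration.

```python
import numpy as np


def circular_correlate(x, k):
    """2D cross-correlation of x with kernel k, using periodic boundaries."""
    y = np.zeros_like(x, dtype=float)
    for a in range(k.shape[0]):
        for b in range(k.shape[1]):
            # np.roll(x, (-a, -b))[i, j] == x[(i + a) % n, (j + b) % n]
            y += k[a, b] * np.roll(x, shift=(-a, -b), axis=(0, 1))
    return y


rng = np.random.default_rng(0)
image = rng.standard_normal((16, 16))
kernel = rng.standard_normal((3, 3))
shift = (5, -3)

# Translate then convolve vs. convolve then translate
lhs = circular_correlate(np.roll(image, shift, axis=(0, 1)), kernel)
rhs = np.roll(circular_correlate(image, kernel), shift, axis=(0, 1))
print(np.allclose(lhs, rhs))  # True
```

With zero padding instead of periodic boundaries, the equality would only hold away from the image borders.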
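This package provides implementations of neural network modules which are equivariant under all isometries E(2) of the image plane and all isometries E(3) of 3D space, that is, under translations, rotations and reflections.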
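The feature spaces of E(n)-equivariant steerable CNNs are defined as spaces of feature fields, which are characterized by their transformation law under rotations and reflections. Typical examples are scalar fields (e.g. gray-scale images) and vector fields (e.g. optical flow).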
Instead of a number of channels, the user has to specify the field types and their multiplicities in order to define a feature space.
Given specified input and output feature spaces, our R2Conv and R3Conv modules instantiate the most general convolutional mapping between them.
Our library provides many other equivariant operations to process feature fields, including nonlinearities, mappings to produce invariant features, batch normalization and dropout.
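In theory, feature fields are defined on continuous space R^n. In practice, they are either sampled on a pixel grid or given as a point cloud. escnn represents feature fields as GeometricTensor objects, which wrap a torch.Tensor together with the corresponding transformation law.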
All equivariant operations perform dynamic type checking in order to guarantee a geometrically sound processing of the feature fields.
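To parameterize steerable kernel spaces equivariant to an arbitrary compact group G, escnn relies on the general steerable kernel solution derived in our ICLR 2022 paper. Many previously proposed equivariant architectures can be implemented as special cases, including: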
- Group Equivariant Convolutional Networks
- Harmonic Networks: Deep Translation and Rotation Equivariance
- Steerable CNNs
- Rotation equivariant vector field networks
- Learning Steerable Filters for Rotation Equivariant CNNs
- HexaConv
- Roto-Translation Covariant Convolutional Networks for Medical Image Analysis
- 3D Steerable CNNs
- Tensor Field Networks
- Cormorant: Covariant Molecular Neural Networks
- 3D GCNNs for Pulmonary Nodule Detection
For more details, we refer to our ICLR 2022 paper A Program to Build E(N)-Equivariant Steerable CNNs and our NeurIPS 2019 paper General E(2)-Equivariant Steerable CNNs.
The library is structured into four subpackages with different high-level features:
Component | Description |
---|---|
escnn.group | implements basic concepts of group and representation theory |
escnn.kernels | solves for spaces of equivariant convolution kernels |
escnn.gspaces | defines the Euclidean spaces and their symmetries |
escnn.nn | contains equivariant modules to build deep neural networks |
WARNING: escnn.kernels underwent a major refactoring in version 1.0.0 and is not compatible with previous versions of the library. These changes do not affect the interface provided in the rest of the library, but the weights of a network trained with a previous version might not load correctly in a newly instantiated model. If you need backward compatibility, we recommend using version v0.1.9.
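Since E(n)-steerable CNNs are equivariant under rotations and reflections, their inference is independent of the choice of image orientation. When the feature maps of an E(2)-steerable CNN are rotated back into a comoving reference frame, they stay invariant as the input image rotates; this invariance validates the rotational equivariance of E(2)-steerable CNNs empirically.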
For comparison, we show a feature map response of a conventional CNN for different image orientations below.
Since conventional CNNs are not equivariant under rotations, the response varies randomly with the image orientation. This prevents CNNs from automatically generalizing learned patterns between different reference frames.
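A tiny NumPy sketch (independent of escnn) makes this concrete for rotations by 90 degrees. For scalar input and output fields and the group C4, the steerability constraint reduces to requiring the kernel to be invariant under 90-degree rotations; a generic learned kernel violates it, while a kernel averaged over its four rotations satisfies it. Averaging over the group is used here only for illustration; escnn instead parameterizes kernels with a precomputed steerable basis. The helper names are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def correlate_valid(x, k):
    """2D cross-correlation without padding, as in a conv layer."""
    windows = sliding_window_view(x, k.shape)  # windows[i, j] == x[i:i+m, j:j+m]
    return np.einsum("ijab,ab->ij", windows, k)


def commutes_with_rot90(k, x):
    """Does rotating the input by 90 degrees rotate the output by 90 degrees?"""
    lhs = correlate_valid(np.rot90(x), k)
    rhs = np.rot90(correlate_valid(x, k))
    return np.allclose(lhs, rhs)


rng = np.random.default_rng(0)
image = rng.standard_normal((15, 15))
generic_kernel = rng.standard_normal((3, 3))

# Averaging over the four 90-degree rotations yields a rotation-invariant kernel
symmetric_kernel = sum(np.rot90(generic_kernel, r) for r in range(4)) / 4

print(commutes_with_rot90(generic_kernel, image))    # False
print(commutes_with_rot90(symmetric_kernel, image))  # True
```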
model | Rotated ModelNet10 (test accuracy, %) |
---|---|
CNN baseline | 82.5 ± 1.4 |
SO(2)-CNN | 86.9 ± 1.9 |
Octa-CNN | 89.7 ± 0.6 |
Ico-CNN | 90.0 ± 0.6 |
SO(3)-CNN | 89.5 ± 1.0 |
All models share approximately the same architecture and width. For more details we refer to our paper.
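This library also supports, as a special case, the E(2)-steerable CNNs implemented in our previous e2cnn library. The following table reports representative results in the 2D setting from that work: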
model | CIFAR-10 | CIFAR-100 | STL-10 |
---|---|---|---|
CNN baseline | 2.6 ± 0.1 | 17.1 ± 0.3 | 12.74 ± 0.23 |
E(2)-CNN * | 2.39 ± 0.11 | 15.55 ± 0.13 | 10.57 ± 0.70 |
E(2)-CNN | 2.05 ± 0.03 | 14.30 ± 0.09 | 9.80 ± 0.40 |
While using the same training setup (no further hyperparameter tuning) used for the CNN baselines, the equivariant models achieve significantly better results (values are test errors in percent). For a fair comparison, the models without * are designed such that the number of parameters of the baseline is approximately preserved while models with * preserve the number of channels, and hence compute. For more details we refer to our previous e2cnn paper.
escnn is easy to use since it provides a high-level user interface that abstracts away most intricacies of group and representation theory. The following code snippet shows how to perform an equivariant convolution from an RGB image to 10 regular feature fields (corresponding to a group convolution).
```python
from escnn import gspaces # 1
from escnn import nn # 2
import torch # 3
# 4
r2_act = gspaces.rot2dOnR2(N=8) # 5
feat_type_in = nn.FieldType(r2_act, 3*[r2_act.trivial_repr]) # 6
feat_type_out = nn.FieldType(r2_act, 10*[r2_act.regular_repr]) # 7
# 8
conv = nn.R2Conv(feat_type_in, feat_type_out, kernel_size=5) # 9
relu = nn.ReLU(feat_type_out) # 10
# 11
x = torch.randn(16, 3, 32, 32) # 12
x = feat_type_in(x) # 13
# 14
y = relu(conv(x)) # 15
```
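Line 5 specifies the symmetry group action on the image plane under consideration, in this case the cyclic group C8, which describes discrete rotations by multiples of 2π/8 (45 degrees). Line 6 specifies the input feature field types: the three color channels of an RGB image are identified as three independent scalar fields, which transform under the trivial representation of C8. Similarly, line 7 specifies that the output feature space consists of 10 feature fields which transform under the regular representation of C8. Line 9 instantiates the C8-equivariant convolution by passing the input and output field types as well as the kernel size to the constructor. Line 10 instantiates an equivariant ReLU nonlinearity, which operates on the output fields and is therefore given the output field type.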
Lines 12 and 13 generate a random minibatch of RGB images and wrap them into an nn.GeometricTensor to associate them with their correct field type feat_type_in.
The equivariant modules process the geometric tensor in line 15. Each module checks whether the geometric tensor passed to it satisfies the expected transformation law.
Because the parameters no longer need to be updated at test time, any equivariant network can be converted after training into a pure PyTorch model with no additional computational overhead compared to conventional CNNs. The code currently supports the automatic conversion of a few commonly used modules through the .export() method; check the documentation for more details.
To get started, we provide some examples and tutorials:
- The introductory tutorial introduces the basic functionality of the library.
- A second tutorial goes through building and training an equivariant model on the rotated MNIST dataset.
- Note that escnn also supports equivariant MLPs; see these examples.
- Check also the tutorial on Steerable CNNs using our library in the Deep Learning 2 course at the University of Amsterdam.
More complex 2D equivariant Wide ResNet models are implemented in e2wrn.py. To try a model that is equivariant under reflections, run:
```bash
cd examples
python e2wrn.py
```
A version of the same model that is simultaneously equivariant under reflections and rotations by multiples of 90 degrees can be run via:
```bash
python e2wrn.py --rot90
```
You can find more examples in the examples folder. For instance, se3_3Dcnn.py implements a 3D CNN equivariant to rotations and translations in 3D. You can try it with:
```bash
cd examples
python se3_3Dcnn.py
```
If you want to better understand the theory behind equivariant and steerable neural networks, you can check these references:
- Erik Bekkers' lectures on Geometric Deep Learning in the Deep Learning 2 course at the University of Amsterdam
- The course material also includes a tutorial on group convolution and another about Steerable CNNs, using this library.
- Gabriele's MSc thesis provides a brief overview of the essential mathematical ingredients needed to understand Steerable CNNs.
- Maurice's PhD thesis develops the representation theory of steerable CNNs, deriving the most prominent layers and explaining the gauge theoretic viewpoint.
The library is based on Python 3.7 and depends on the following packages:
torch>=1.3
numpy
scipy
lie_learn
joblib
py3nj
Optional:
torch-geometric
pymanopt>=1.0.0
autograd
WARNING: py3nj enables a fast computation of Clebsch-Gordan coefficients. If this package is not installed, our library relies on a numerical method to estimate them. This numerical method is not guaranteed to return the same coefficients computed by py3nj (they can differ by a sign). For this reason, models built with and without py3nj might not be compatible.
To successfully install py3nj, you may need a Fortran compiler installed in your environment.
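Since models built with and without py3nj might not be compatible, it can help to record which case applies in a given environment. The sketch below only checks whether py3nj is importable, using the standard library; the helper name `has_py3nj` is hypothetical and not part of escnn.

```python
import importlib.util


def has_py3nj() -> bool:
    """Return True if the optional py3nj package can be imported."""
    return importlib.util.find_spec("py3nj") is not None


if has_py3nj():
    print("py3nj found: Clebsch-Gordan coefficients come from py3nj")
else:
    print("py3nj not found: escnn falls back to a numerical estimate")
```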
You can install the latest release with
```bash
pip install escnn
```
or install the latest version directly from this repository with
```bash
pip install git+https://github.com/QUVA-Lab/escnn
```
Would you like to contribute to escnn? That's great!
Then, check the instructions in CONTRIBUTING.md and help us to improve the library!
Do you have any doubts? Do you have an idea you would like to discuss? Feel free to open a new thread in Discussions!
The development of this library was part of the work done for our papers A Program to Build E(N)-Equivariant Steerable CNNs and General E(2)-Equivariant Steerable CNNs. Please cite these works if you use our code:
```bibtex
@inproceedings{cesa2022a,
title={A Program to Build {E(N)}-Equivariant Steerable {CNN}s },
author={Gabriele Cesa and Leon Lang and Maurice Weiler},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=WE4qe9xlnQw}
}
@inproceedings{e2cnn,
title={{General E(2)-Equivariant Steerable CNNs}},
author={Weiler, Maurice and Cesa, Gabriele},
booktitle={Conference on Neural Information Processing Systems (NeurIPS)},
year={2019},
url={https://arxiv.org/abs/1911.08251}
}
```
Feel free to contact us.
escnn is distributed under the BSD Clear license. See the LICENSE file.