The group E(3) is the group of 3-dimensional rotations, translations and mirror reflections. This library aims to create E(3)-equivariant convolutional neural networks.
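To make "equivariant" concrete, here is a minimal sketch in plain Python that does not use e3nn. It takes a toy function (each point's displacement from the centroid of its cloud) and checks numerically that transforming the input by a rotation or mirror plus a translation gives the same result as transforming the output by the rotation or mirror alone. All function names here are hypothetical and chosen for illustration only.

```python
import math

def centroid_offsets(points):
    """Map each point to its displacement from the cloud's centroid."""
    n = len(points)
    center = [sum(p[i] for p in points) / n for i in range(3)]
    return [[p[i] - center[i] for i in range(3)] for p in points]

def apply(matrix, v):
    """Multiply a 3x3 matrix (list of rows) with a 3-vector."""
    return [sum(matrix[i][j] * v[j] for j in range(3)) for i in range(3)]

def rotation_z(angle):
    """Rotation by `angle` radians around the z axis."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]

# Mirror through the plane x = 0
MIRROR_X = [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

def max_equivariance_error(points, matrix, shift):
    """Largest component of f(R x + t) - R f(x) over the whole cloud."""
    moved = [[a + t for a, t in zip(apply(matrix, p), shift)] for p in points]
    lhs = centroid_offsets(moved)                               # f(R x + t)
    rhs = [apply(matrix, v) for v in centroid_offsets(points)]  # R f(x)
    return max(abs(a - b) for u, v in zip(lhs, rhs) for a, b in zip(u, v))

points = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.5, 0.5, 3.0]]
shift = [1.0, -2.0, 0.5]

rotation_error = max_equivariance_error(points, rotation_z(0.7), shift)
mirror_error = max_equivariance_error(points, MIRROR_X, shift)
print(rotation_error, mirror_error)  # both at floating-point rounding level
```

The translation drops out because the output is a vector rather than a position. The same holds for the vector fields produced by an equivariant layer.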
The code is separated in two parts: e3nn/image for voxel data and e3nn/point for point data. An example of each follows, voxels first.
import torch
from e3nn.image.convolution import SE3Convolution
size = 32 # space size
scalar_field = torch.randn(1, 1, size, size, size) # [batch, _, x, y, z]
Rs_in = [(1, 0)] # 1 scalar field
Rs_out = [(1, 1)] # 1 vector field
conv = SE3Convolution(Rs_in, Rs_out, size=5)
# conv.weight.size() == [2] (2 radial degrees of freedom)
vector_field = conv(scalar_field) # [batch, vector component, x, y, z]
# vector_field.size() == [1, 3, 28, 28, 28]
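The output size quoted in the comment above can be reproduced by hand. The sketch below assumes that each entry (multiplicity, l) of an Rs list contributes multiplicity × (2l + 1) channels, and that the convolution is unpadded with stride 1, which is consistent with the 28 shown above. Both helpers are hypothetical and are not part of the e3nn API.

```python
def representation_dim(Rs):
    """Number of channels for Rs = [(multiplicity, l), ...].

    A field of order l has 2l + 1 components (1 for a scalar, 3 for a vector).
    """
    return sum(mul * (2 * l + 1) for mul, l in Rs)

def valid_output_size(input_size, kernel_size):
    """Spatial size after an unpadded, stride-1 convolution (assumed here)."""
    return input_size - kernel_size + 1

Rs_in = [(1, 0)]   # 1 scalar field
Rs_out = [(1, 1)]  # 1 vector field

side = valid_output_size(32, 5)
output_shape = [1, representation_dim(Rs_out), side, side, side]
print(output_shape)  # [1, 3, 28, 28, 28]
```

The same count applies to larger representations: one field for each l from 0 to 9 needs 1 + 3 + ... + 19 = 100 channels.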
from functools import partial
import torch
from e3nn.point.radial import CosineBasisModel
from e3nn.point.kernel import Kernel
from e3nn.point.operations import Convolution
from e3nn.util.plot import plot_sh_signal
import matplotlib.pyplot as plt
# Radial model: R -> R^d
# Projection on cos^2 basis functions followed by a fully connected network
RadialModel = partial(CosineBasisModel, max_radius=3.0, number_of_basis=3, h=100, L=1, act=torch.relu)
# Kernel: composed of a radial part that contains the learned parameters
# and an angular part given by the spherical harmonics and the Clebsch-Gordan coefficients
K = partial(Kernel, RadialModel=RadialModel)
# Use the kernel to define a convolution operation
C = partial(Convolution, K)
Rs_in = [(1, 0)] # one scalar
Rs_out = [(1, l) for l in range(10)] # one field for each order l = 0, ..., 9
conv = C(Rs_in, Rs_out)
n = 3 # number of points
features = torch.ones(1, n, 1)
geometry = torch.randn(1, n, 3)
features = conv(features, geometry)
plt.figure(figsize=(4, 4))
plot_sh_signal(features[:, 0], n=50)
plt.gca().view_init(azim=0, elev=45)
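The radial model in the example above is described as a projection on cos^2 basis functions followed by a fully connected network. The following plain-Python sketch illustrates only the projection step. It assumes the bumps are centred at evenly spaced radii between 0 and max_radius and vanish beyond one spacing from their centre; the exact centres and widths used by CosineBasisModel may differ, and cos2_basis is a hypothetical name.

```python
import math

def cos2_basis(r, max_radius=3.0, number_of_basis=3):
    """Evaluate cos^2 bumps centred at evenly spaced radii in [0, max_radius].

    Assumption: bump k is centred at k * step and is zero once the radius
    is a full step or more away from that centre.
    """
    step = max_radius / (number_of_basis - 1)
    values = []
    for k in range(number_of_basis):
        x = (r - k * step) / step  # signed distance to centre k, in units of step
        values.append(math.cos(math.pi / 2 * x) ** 2 if abs(x) < 1 else 0.0)
    return values

for r in (0.0, 0.75, 1.5, 2.4, 3.0):
    coefficients = cos2_basis(r)
    # Neighbouring bumps overlap so that the coefficients sum to 1 on [0, max_radius]
    print(r, [round(v, 3) for v in coefficients], round(sum(coefficients), 3))
```

With this spacing, neighbouring bumps satisfy cos^2 + sin^2 = 1, so a radius is encoded smoothly by at most two non-zero coefficients. In the example above, a small fully connected network then maps these coefficients to the radial weights of the kernel.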
- e3nn: contains the library
  - e3nn/SO3.py: defines all the needed mathematical functions
  - e3nn/image: contains the code specific to voxels
  - e3nn/point: contains the code specific to points
  - e3nn/non_linearities: non-linearities working for both point and voxel code
- examples: simple scripts and experiments
- install pytorch
- pip install git+https://github.com/AMLab-Amsterdam/lie_learn
- pip install git+https://github.com/e3nn/e3nn
Alternatively, install from a local clone of the repository with
python setup.py install
@misc{mario_geiger_2019_3348277,
author = {Mario Geiger and
Tess Smidt and
Wouter Boomsma and
Maurice Weiler and
Michał Tyszkiewicz and
Jes Frellsen and
Benjamin K. Miller},
title = {mariogeiger/e3nn: Point cloud support},
month = jul,
year = 2019,
doi = {10.5281/zenodo.3348277},
url = {https://doi.org/10.5281/zenodo.3348277}
}