Equivariant Mesh Attention Networks
This repository contains the code to reproduce the experiments of Equivariant Mesh Attention Networks published in Transactions on Machine Learning Research (TMLR - 08/2022).
Running experiments
The instructions provided below assume that the python command is run from the ./eman directory:
FAUST experiments
python experiments/faust_direct.py --model RelTanEMAN --seed 1 --epochs 1 -equiv_bias
TOSCA experiments
python experiments/tosca_direct.py --model RelTanEMAN --seed 1 --epochs 1 -equiv_bias -null_isolated
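To repeat these runs over several seeds, the commands can also be generated programmatically. The sketch below is not part of the repository: the helper names build_command and run_sweep and the seed list are hypothetical, and it only reuses the flags that appear in the two commands above (--model, --seed, --epochs, -equiv_bias, -null_isolated). It uses sys.executable so that the same interpreter (for example, the one from the eman conda environment) launches each run.

```python
import subprocess
import sys
from typing import List


# Hypothetical helper: rebuilds the FAUST/TOSCA commands shown in this README.
# Only flags that appear in the README are used.
def build_command(dataset: str, seed: int, epochs: int = 1,
                  model: str = "RelTanEMAN") -> List[str]:
    if dataset not in ("faust", "tosca"):
        raise ValueError(f"unknown dataset: {dataset!r}")
    cmd = [sys.executable, f"experiments/{dataset}_direct.py",
           "--model", model, "--seed", str(seed), "--epochs", str(epochs),
           "-equiv_bias"]
    if dataset == "tosca":
        # The TOSCA command in the README additionally passes -null_isolated
        cmd.append("-null_isolated")
    return cmd


def run_sweep(dataset: str, seeds: List[int], dry_run: bool = True) -> List[List[str]]:
    """Build one command per seed; execute them only if dry_run is False."""
    commands = [build_command(dataset, s) for s in seeds]
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # Must be launched from ./eman, like the commands above
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    run_sweep("tosca", seeds=[1, 2, 3])  # dry run: only prints the commands
```

Keeping dry_run=True by default lets you inspect the generated command lines before launching long training jobs.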
Installation instructions
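Run the commands below to create a new conda environment and install all dependencies. Uncomment exactly one of the two PyTorch installation lines, depending on whether you want GPU (CUDA 11.3) or CPU-only support: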
conda create --name eman python=3.7
conda activate eman
# GPU installation
# conda install pytorch=1.11 cudatoolkit=11.3 -c pytorch
# CPU installation
# conda install pytorch=1.11 cpuonly -c pytorch
conda install pyg=2.0.3 -c pyg
pip install wandb pytorch-ignite openmesh opt_einsum trimesh
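After installation, a quick sanity check is to confirm that every dependency can be found by the interpreter of the eman environment. This is a minimal sketch, not a script shipped with the repository. It assumes the usual import names for the packages above (pyg is imported as torch_geometric and pytorch-ignite as ignite), and it only checks that the modules are importable, not that their versions match the pinned ones.

```python
import importlib.util
from typing import Dict, List

# Assumed import names of the dependencies installed above.
# Note: conda package "pyg" -> torch_geometric, pip package "pytorch-ignite" -> ignite.
REQUIRED_MODULES: List[str] = ["torch", "torch_geometric", "wandb", "ignite",
                               "openmesh", "opt_einsum", "trimesh"]


def check_dependencies(modules: List[str] = REQUIRED_MODULES) -> Dict[str, bool]:
    """Return {module_name: importable?} without actually importing the modules."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in check_dependencies().items():
        print(f"{name:16s} {'OK' if ok else 'MISSING'}")
```

Using importlib.util.find_spec avoids importing heavy packages such as torch just to test whether they are installed.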
Project structure
eman
│ README.md
│ LICENSE
│
└───data
│ │ FAUST/raw/MPI-FAUST.zip # Download from http://faust.is.tue.mpg.de/
│ │ TOSCA # Automatically downloaded on first experiment
│
└───eman # Implementation of Equivariant Mesh Attention Networks
│ └───nn
│ └───tests
│ └───transform
│ └───utils
│
└───experiments
│ │ faust_direct.py
│ │ tosca_direct.py
│ │ paths.json # Specify dataset locations (default: "./eman/data")
│ │ ...
│
└───gem_cnn # Implementation of Gauge Equivariant Mesh CNNs (GEM-CNN)
│ └───nn
│ └───tests
│ └───transform
│ └───utils
│
└───spiralnet # Implementation of SpiralNet++
│ │ spiralconv.py
│ └───spiralnet.utils
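Since the FAUST archive has to be downloaded by hand while TOSCA is fetched automatically, it can help to check the data layout before launching faust_direct.py. The sketch below is hypothetical and not part of the repository: the function name missing_data is illustrative, and it assumes the default layout from the tree above, with data/ directly under the top-level eman directory. If you changed the dataset location in experiments/paths.json, adjust the repo_root argument (or the path) accordingly.

```python
from pathlib import Path
from typing import List

# Location of the manually downloaded FAUST archive, taken from the project
# tree above (assumed relative to the top-level eman directory).
FAUST_ARCHIVE = Path("data") / "FAUST" / "raw" / "MPI-FAUST.zip"


def missing_data(repo_root: str = ".") -> List[str]:
    """Return a list of problems with the dataset layout (empty if everything is in place).

    TOSCA is not checked because it is downloaded automatically on the first run.
    """
    problems = []
    archive = Path(repo_root) / FAUST_ARCHIVE
    if not archive.is_file():
        problems.append(
            f"FAUST archive not found at {archive}; download it from "
            "http://faust.is.tue.mpg.de/ before running experiments/faust_direct.py"
        )
    return problems


if __name__ == "__main__":
    for msg in missing_data():
        print(msg)
```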
Citation
Please use the following snippet to cite this work:
@article{basu2022equivariant,
title={{Equivariant Mesh Attention Networks}},
author={Basu, Sourya and Gallego-Posada, Jose and Vigan\`o, Francesco and Rowbottom, James and Cohen, Taco},
year={2022},
month={08},
journal={Transactions on Machine Learning Research}
}
Acknowledgements
- The code for Gauge Equivariant Mesh CNNs is taken from the official GEM-CNN implementation.
- The code for SpiralNet++ comparison is taken from the official SpiralNet++ implementation.