The aim is to train an equivariant neural network to segment the cerebellum in a brain. It has to distinguish between the left and the right cerebellum. To do so, the network outputs an odd scalar value for each voxel (a scalar that flips sign under reflections): zero for the background, one for the left cerebellum and minus one for the right cerebellum.
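A minimal NumPy sketch of this target encoding and of why an odd output is needed. It assumes boolean left/right masks are already available. The integer codes `BACKGROUND`/`LEFT`/`RIGHT`, the 0.5 decoding threshold and the choice of axis 0 as the left-right axis are illustrative assumptions, not taken from the project.

```python
import numpy as np

# Hypothetical integer codes for a decoded segmentation (assumption, not from the project).
BACKGROUND, LEFT, RIGHT = 0, 1, 2


def to_odd_target(left_mask: np.ndarray, right_mask: np.ndarray) -> np.ndarray:
    """Encode boolean masks as an odd scalar target: 0 background, +1 left, -1 right."""
    return left_mask.astype(np.float32) - right_mask.astype(np.float32)


def decode(pred: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Turn a per-voxel odd prediction back into integer labels (threshold is an assumption)."""
    labels = np.full(pred.shape, BACKGROUND, dtype=np.int8)
    labels[pred > threshold] = LEFT
    labels[pred < -threshold] = RIGHT
    return labels


def mirror_output(pred: np.ndarray, axis: int = 0, odd: bool = True) -> np.ndarray:
    """How an equivariant network's output transforms when its input is mirrored.

    The field is flipped along the mirror axis (assumed to be axis 0 here); an odd
    scalar additionally changes sign, an even scalar does not.
    """
    flipped = np.flip(pred, axis=axis)
    return -flipped if odd else flipped
```

With an odd output, a mirrored left cerebellum (which looks like a right one and now sits on the other side) decodes as `RIGHT`; with an even output it would still decode as `LEFT`, which is why a purely scalar O(3)-equivariant output could not tell the two sides apart.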
[Figures: train prediction; test prediction]
We took two brains from the Mindboggle dataset. The files data/x1.nii.gz and data/x2.nii.gz contain the MRI data of the two brains, and the files data/y1.nii.gz and data/y2.nii.gz contain their labels.
We use the data with index 1 for training and index 2 for testing.
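One possible way to load a training or test pair with nibabel (installed in the dependencies below). The helpers `load_volume` and `load_split` are hypothetical, rounding the labels to integers is an assumption about how the data/y*.nii.gz files are stored, and unet_odd.py may load the files differently.

```python
from typing import Tuple

import numpy as np


def load_volume(path: str) -> np.ndarray:
    """Read a NIfTI file (e.g. data/x1.nii.gz) into a float32 array."""
    import nibabel as nib  # imported lazily so this module loads without nibabel

    return np.asarray(nib.load(path).get_fdata(), dtype=np.float32)


def load_split(index: int, data_dir: str = "data") -> Tuple[np.ndarray, np.ndarray]:
    """Return (mri, labels) for brain `index`: 1 is used for training, 2 for testing."""
    x = load_volume(f"{data_dir}/x{index}.nii.gz")
    # get_fdata() always returns floats, so labels are rounded back to integers (assumption).
    y = np.rint(load_volume(f"{data_dir}/y{index}.nii.gz")).astype(np.int64)
    if x.shape != y.shape:
        raise ValueError(f"image {x.shape} and labels {y.shape} are on different grids")
    return x, y
```

For example, `x_train, y_train = load_split(1)` and `x_test, y_test = load_split(2)`.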
This project is based on e3nn-jax.
To install the dependencies:
pip install --upgrade pip
pip install --upgrade nibabel
pip install --upgrade "jax[cpu]" # change this to get cuda support!
pip install --upgrade dm-haiku
pip install --upgrade optax
pip install e3nn-jax==0.4.2 # last version tested
Make sure you run the code on a computer with a GPU; otherwise the code will not even compile.
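A quick way to check this before launching training, assuming a JAX install in which CUDA devices report the platform name "gpu". `has_gpu` is a hypothetical helper, not part of the project.

```python
def has_gpu() -> bool:
    """Return True if JAX can see at least one GPU device."""
    import jax  # lazy import: works with either the CPU or the CUDA build of jaxlib

    # jax.devices() lists the devices of the default backend, which is the GPU
    # backend whenever a CUDA-enabled jaxlib finds a GPU.
    return any(device.platform == "gpu" for device in jax.devices())
```

If `has_gpu()` returns False, the "jax[cpu]" line in the install commands above has most likely not been replaced with a CUDA-enabled install.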
# wandb login # optional
python unet_odd.py
Prediction of the cerebellum on a test brain (data/x2.nii.gz) made by an O(3)-equivariant network trained for 2000 steps (8 hours on a Tesla V100 PCIe 32GB) on a single brain (data/x1.nii.gz).
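Averaged over the whole run, and assuming the 8 hours are spread evenly over the 2000 steps (the time may also include compilation and any evaluation), this corresponds to roughly:

```latex
\frac{8\,\text{h} \times 3600\,\text{s/h}}{2000\ \text{steps}}
  = \frac{28\,800\,\text{s}}{2000\ \text{steps}}
  = 14.4\,\text{s per step}
```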
Using group convolution makes the model 3x faster (on a V100 GPU).
We can also see that the group convolution model spends proportionally more time on non-convolution operations.
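A speed comparison like this one is commonly measured by timing a jitted function only after it has been compiled, and by blocking on its outputs, because JAX dispatches work asynchronously. The sketch below shows that pattern; `time_step` is a hypothetical helper, and the benchmark actually used for the 3x figure is not described here.

```python
import time


def _block(tree):
    """Wait until every array in a (possibly nested) JAX output is computed."""
    import jax

    return jax.tree_util.tree_map(lambda leaf: leaf.block_until_ready(), tree)


def time_step(fn, *args, warmup: int = 1, repeats: int = 5) -> float:
    """Average wall-clock seconds per call of a jitted JAX function.

    The warm-up calls trigger XLA compilation and are excluded from the timing.
    """
    for _ in range(warmup):
        _block(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        _block(fn(*args))
    return (time.perf_counter() - start) / repeats
```

Calling it as `time_step(jax.jit(apply_fn), params, x)` for each model variant (with `apply_fn`, `params` and `x` standing in for the hypothetical model function, its parameters and an input volume) gives per-step times whose ratio can be compared.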