
Primary language: Python. License: MIT.


Torch'em

Deep-learning based semantic and instance segmentation for 3D electron microscopy and other bioimage analysis problems, built on PyTorch. Any feedback is highly appreciated; just open an issue!

Highlights:

  • Functional API with sensible defaults to train a state-of-the-art segmentation model with a few lines of code.
  • Differentiable augmentations on GPU and CPU thanks to kornia.
  • Off-the-shelf logging with tensorboard or wandb.
  • Export trained models to bioimage.io model format with one function call to deploy them in ilastik or deepimageJ.

Design:

  • All parameters are specified in code, no configuration files.
  • No callback logic; to extend the core functionality, inherit from trainer.DefaultTrainer instead.
  • All data-loading is lazy to support training on large data-sets.
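The lazy-loading idea can be illustrated with a memory-mapped array: only the patch requested for a training sample is read from disk, never the whole volume. The sketch below is a hypothetical stand-in, not torch_em's own dataset class. It assumes the data is stored as a raw binary file readable with `numpy.memmap`, and it implements `__len__`/`__getitem__` so it could in principle be wrapped by a PyTorch `DataLoader`. The fixed length of 100 samples per epoch is an arbitrary choice for illustration.

```python
import numpy as np


class LazyPatchDataset:
    """Hypothetical sketch: sample random patches from a large on-disk array.

    Only the requested patch is copied into memory; the full array stays on disk.
    """

    def __init__(self, path, shape, patch_shape, dtype="float32", seed=None):
        # np.memmap maps the file without reading its contents up front
        self.data = np.memmap(path, dtype=dtype, mode="r", shape=shape)
        self.patch_shape = patch_shape
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        # samples per "epoch" is a free parameter in this sketch
        return 100

    def __getitem__(self, index):
        # pick a random corner so that the patch fits inside the array
        corner = [
            self.rng.integers(0, size - psize + 1)
            for size, psize in zip(self.data.shape, self.patch_shape)
        ]
        bb = tuple(slice(c, c + p) for c, p in zip(corner, self.patch_shape))
        # np.array(...) copies just this slice from disk into RAM
        return np.array(self.data[bb])
```

Sampling random patches instead of iterating over the whole array keeps memory use independent of the data size, which is what makes training on large volumes feasible.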

Example:

# train a 2d U-Net for foreground and boundary segmentation of nuclei
# using data from https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip

import torch
import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets import get_dsb_loader

model = UNet2d(in_channels=1, out_channels=2)

# transform to go from instance segmentation labels
# to foreground/background and boundary channel
label_transform = torch_em.transform.BoundaryTransform(
    add_binary_target=True, ndim=2
)

# training and validation data loader
data_path = "./dsb"  # the training data will be downloaded and saved here
train_loader = get_dsb_loader(
    data_path, 
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="train",
    download=True,
    label_transform=label_transform
)
val_loader = get_dsb_loader(
    data_path, 
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="test",
    label_transform=label_transform
)

# the trainer object that handles the training details
# the model checkpoints will be saved in "checkpoints/dsb-boundary-model"
# the tensorboard logs will be saved in "logs/dsb-boundary-model"
trainer = torch_em.default_segmentation_trainer(
    name="dsb-boundary-model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU without a GPU
)
trainer.fit(iterations=5000)

# export bioimage.io model format
from glob import glob
import imageio
from torch_em.util import export_bioimageio_model

# load one of the images to use as a reference image
# and crop it to a shape that is guaranteed to fit the network
test_im = imageio.imread(glob(f"{data_path}/test/images/*.tif")[0])[:256, :256]

export_bioimageio_model("./checkpoints/dsb-boundary-model", "./bioimageio-model", test_im)
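To see what the label transform in the example produces, or why the test image is cropped before export, it can help to reproduce both steps in plain NumPy. This is a sketch, not torch_em's implementation, and both helper names are hypothetical. `boundary_target` approximates the target of `BoundaryTransform(add_binary_target=True, ndim=2)` as a foreground channel followed by a boundary channel, marking a pixel as boundary when any of its 4-neighbours carries a different label id; the channel order and the neighbourhood are assumptions. `crop_to_multiple` assumes a U-Net with 4 downsampling levels, so each spatial dimension must be divisible by 2**4 = 16; the factor for a different network depth will differ.

```python
import numpy as np


def boundary_target(labels: np.ndarray) -> np.ndarray:
    """Hypothetical helper: turn 2D instance labels (H, W) into a (2, H, W) target.

    Channel 0 is the foreground mask (label > 0); channel 1 marks pixels
    where at least one 4-neighbour has a different label id.
    """
    # pad with edge values so the image border is not counted as a boundary
    padded = np.pad(labels, 1, mode="edge")
    center = padded[1:-1, 1:-1]
    neighbours = (
        padded[:-2, 1:-1],  # up
        padded[2:, 1:-1],   # down
        padded[1:-1, :-2],  # left
        padded[1:-1, 2:],   # right
    )
    boundaries = np.zeros(labels.shape, dtype=bool)
    for nb in neighbours:
        boundaries |= nb != center
    foreground = labels > 0
    return np.stack([foreground, boundaries]).astype(np.float32)


def crop_to_multiple(image: np.ndarray, factor: int = 16) -> np.ndarray:
    """Hypothetical helper: crop the first two axes down to a multiple of `factor`."""
    h, w = image.shape[:2]
    return image[: h - h % factor, : w - w % factor]


if __name__ == "__main__":
    labels = np.zeros((5, 5), dtype=np.int32)
    labels[1:4, 1:4] = 1  # a single 3x3 object
    target = boundary_target(labels)
    print(target.shape)          # (2, 5, 5)
    print(int(target[1].sum()))  # 20
    print(crop_to_multiple(np.zeros((300, 250))).shape)  # (288, 240)
```

Note that boundaries are marked on both sides of an object edge (inside and outside), which gives a thick boundary channel.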

For a more in-depth example, check out one of the example notebooks:

Installation

From conda

You can install torch_em from conda-forge:

conda install -c conda-forge torch_em

Please check out pytorch.org for instructions on installing a PyTorch version compatible with your system.
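After installation, a short check like the one below can confirm that PyTorch and torch_em are importable and whether the installed PyTorch build can see a GPU. It is a sketch: the function name is hypothetical, and it relies only on the standard library plus `torch.cuda.is_available()`.

```python
import importlib
import importlib.util


def check_environment():
    """Hypothetical helper: report which packages are importable and if CUDA works."""
    report = {
        "torch": importlib.util.find_spec("torch") is not None,
        "torch_em": importlib.util.find_spec("torch_em") is not None,
        "cuda": False,
    }
    if report["torch"]:
        # import lazily so the check also runs where torch is missing
        torch = importlib.import_module("torch")
        report["cuda"] = bool(torch.cuda.is_available())
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'yes' if ok else 'no'}")
```

If `cuda` reports `no` on a machine with a GPU, the installed PyTorch build most likely does not match the system's CUDA setup.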

From source

It's recommended to set up a conda environment for using torch_em. Two conda environment files are provided: environment_cpu.yaml for a CPU-only set-up and environment_gpu.yaml for a GPU set-up. If you want to use the GPU version, make sure to set the correct CUDA version for your system in the environment file by modifying this line.

You can set up a conda environment using one of these files like this:

conda env create -f <ENV>.yaml -n <ENV_NAME>
conda activate <ENV_NAME>
pip install -e .

where <ENV>.yaml is either environment_cpu.yaml or environment_gpu.yaml, and <ENV_NAME> is the name you choose for the environment.

Features

  • Training of 2d U-Nets and 3d U-Nets for various segmentation tasks.
  • Random forest based domain adaptation from Shallow2Deep.
  • Training models for embedding prediction with sparse instance labels from SPOCO.

Command Line Scripts

A command line interface for training, prediction and conversion to the bioimage.io modelzoo format will be installed with torch_em:

  • torch_em.train_unet_2d: train a 2D U-Net.
  • torch_em.train_unet_3d: train a 3D U-Net.
  • torch_em.predict: run prediction with a trained model.
  • torch_em.predict_with_tiling: run prediction with tiling.
  • torch_em.export_bioimageio_model: export a model to the modelzoo format.

For more details run <COMMAND> -h for any of these commands. The folder scripts/cli contains some examples for how to use the CLI.

Note: this functionality was recently added and is not fully tested.
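To check which of the commands listed above ended up on your PATH, and to read a command's help text from Python, a small sketch like the following can be used. It relies only on the standard library (`shutil.which`, `subprocess.run`) and on the `-h` flag mentioned above; the helper names are hypothetical.

```python
import shutil
import subprocess

# the command names listed in this README
CLI_COMMANDS = [
    "torch_em.train_unet_2d",
    "torch_em.train_unet_3d",
    "torch_em.predict",
    "torch_em.predict_with_tiling",
    "torch_em.export_bioimageio_model",
]


def installed_commands(commands=CLI_COMMANDS):
    """Map each command name to True if an executable with that name is found."""
    return {cmd: shutil.which(cmd) is not None for cmd in commands}


def show_help(command):
    """Return the output of `<command> -h` (stdout, or stderr if stdout is empty)."""
    result = subprocess.run([command, "-h"], capture_output=True, text=True, check=False)
    return result.stdout or result.stderr


if __name__ == "__main__":
    for cmd, ok in installed_commands().items():
        print(f"{cmd}: {'found' if ok else 'not found'}")
```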