/cell-tracker-gnn

[ECCV 2022] Official PyTorch implementation of the paper - Graph Neural Network for Cell Tracking in Microscopy Videos


Graph Neural Network for Cell Tracking in Microscopy Videos (ECCV 2022)


Official Implementation: Graph Neural Network for Cell Tracking in Microscopy Videos




Preliminaries

Our implementation integrates the PyTorch Lightning, PyG, and Hydra libraries:

PyTorch Lightning is a lightweight PyTorch wrapper for high-performance AI research.

PyG (PyTorch Geometric) is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.

Hydra is an open-source Python framework that simplifies the development of research and other complex applications.

If you are not familiar with PyTorch, PyTorch Lightning, PyG, and Hydra, we highly recommend reading about them before starting.

We use an older version of a publicly available deep learning template (Template).

Set up conda virtual environment

Install dependencies on a Linux environment: we provide the conda environment setup files. If you are not familiar with conda, please read about it before starting.
# Enter the code folder
cd cell-tracker-gnn

# create conda environment python=3.8 pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 faiss-gpu pytorch-lightning==1.4.9
conda create --name cell-tracking-challenge --file requirements-conda.txt
conda activate cell-tracking-challenge

# install other requirements
pip install -r requirements.txt
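
After installation, it can help to confirm that the environment matches the versions named in the setup comment above. The following is a minimal sketch, not part of the repository: the helper name `check_versions` and the `EXPECTED` table are assumptions for illustration, with pins copied from that comment.

```python
from importlib import metadata

# Pins copied from the comment in the setup commands above (assumption:
# your requirements files still use these versions).
EXPECTED = {
    "torch": "1.8.0",
    "torchvision": "0.9.0",
    "pytorch-lightning": "1.4.9",
}


def check_versions(expected):
    """Return {package: (installed_version_or_None, matches_expected)}."""
    report = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        # Local build tags such as "1.8.0+cu111" still count as a match.
        ok = installed is not None and installed.split("+")[0] == wanted
        report[name] = (installed, ok)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_versions(EXPECTED).items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name}: installed={installed} expected={EXPECTED[name]} [{status}]")
```

Run it inside the activated conda environment; a `MISMATCH` line points at a package worth reinstalling.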

Structure

The directory structure of our implementation looks like this:
├── configs                 <- Hydra configuration files
│   ├── callbacks               <- Callbacks configs
│   ├── datamodule              <- Datamodule configs
│   ├── feat_extract            <- Feature extraction configs
│   ├── logger                  <- Logger configs
│   ├── metric_learning         <- Metric learning configs
│   ├── model                   <- Model configs
│   ├── trainer                 <- Trainer configs
│   │
│   ├── config.yaml             <- Main project configuration file
│
├── data                    <- Project data
│
├── logs                    <- Logs generated by Hydra and PyTorch Lightning loggers
│
├── outputs                 <- Outputs generated by Hydra and tensorboard loggers when training deep metric learning model
│
│
├── src
│   ├── callbacks                   <- Lightning callbacks
│   ├── datamodules                 <- Lightning datamodules and dataset files used
│   │   ├── datasets                             <- Graph Dataset implementation
│   │   │   └── graph_dataset.py                          <- Graph Dataset implementation
│   │   ├── extract_features                     <- Extract  features used for graph
│   │   │   ├── preprocess_seq2graph_2d.py                <- Extract  features for 2d dataset with full segmentation
│   │   │   ├── preprocess_seq2graph_3D.py                <- Extract  features for 3d dataset
│   │   │   └── preprocess_seq2graph_patch_based.py       <- Extract  features for 2d dataset with markers annotations
│   │   ├── celltrack_datamodule.py              <- Lightning datamodule implementing the train, valid and test split using separate sequences for each
│   │   └── celltrack_datamodule_mulSeq.py       <- Lightning datamodule implementing the train, valid and test split using combined sequences for each
│   │
│   ├── metrics                     <- Lightning metrics used to track performance
│   ├── models                      <- Lightning model + PyTorch models +  PyTorch Geometric model
│   │   ├── modules                             <- models implementation
│   │   │   ├── celltrack_model.py                        <- complete model implementation
│   │   │   ├── edge_mpnn.py                              <- Edge-oriented message passing implementation
│   │   │   ├── mlp.py                                    <- multilayer perceptron implementation
│   │   │   └── pdn_conv.py                               <- PDN-Conv implementation
│   │   └── celltrack_plmodel.py                <- Lightning model implementing training routine
│   ├── utils                   <- Utility scripts
│   │   └── utils.py                            <- Utils features
│   │
│   └── train.py                                <- Training pipeline
│
├── src_metric_learning
│   ├── Data               <- Data modules - datasets and sampler
│   │   ├── dataset_2D.py       <- Implementation of 2D dataset
│   │   ├── dataset_3D.py       <- Implementation of 3D dataset
│   │   └── sampler.py          <- Implementation of sampler used for batch construction
│   ├── modules                 <- PyTorch models
│   │   ├── resnet_2d           <- Implementation of ResNet for 2D dataset
│   │   │   ├── resnet.py             <- Final models
│   │   │   └── utils_resnet.py       <- Implementation of multiple ResNet blocks and models
│   │   ├── resnet_3d           <- Implementation of ResNet for 3D dataset
│   │   │   ├── resnet.py             <- Final models
│   └── └── └── utils_resnet.py       <- Implementation of multiple ResNet blocks and models
│
├── LICENSE                 <- Attribution-NonCommercial 4.0 International
├── README.md               <- All information
│
├── requirements.txt              <- File for installing python dependencies (specification of dependencies)
├── requirements-conda.txt        <- File for conda environment creation (specification of dependencies)
│
├── run.py                             <- Run training of the complete model with any pipeline configuration of 'configs/config.yaml'
├── run_feat_extract.py                <- Run feature extraction pipeline 'configs/feat_extract/feat_extract.yaml' configuration file
└── run_train_metric_learning.py       <- Run training of any settings using 'configs/metric_learning/...' configuration files

Data

Data Structure

The recommended data directory should look like this:

├── data                    <- Project data
│   ├── CTC                 <- Cell tracking challenge data
│   │   ├── Training                             <- Training Split
│   │   │   ├── Fluo-N2DH-SIM+                        <- Fluo-N2DH-SIM+ Dataset
│   │   │   │   ├── 01                                    <- Sequence 01
│   │   │   │   ├── 01_GT                                 <- Sequence 01 GT
│   │   │   │   │   ├── TRA                                   <- Tracking GT
│   │   │   │   │   └── SEG                                   <- Tracking SEG (Not used)
│   │   │   │   ├── 02                                    <- Sequence 02
│   │   │   │   ├── 02_GT                                 <- Sequence 02 GT
│   │   │   │   │   ├── TRA                                   <- Tracking GT
│   │   │   │   │   └── SEG                                   <- Tracking SEG (Not used)
│   │   │   ├── PhC-C2DH-U373                             <- PhC-C2DH-U373 Dataset
│   │   │   │   ├── 01                                    <- Sequence 01
│   │   │   │   ├── 01_GT                                 <- Sequence 01 GT
│   │   │   │   │   ├── TRA                                   <- Tracking GT
│   │   │   │   │   └── SEG                                   <- Tracking SEG (Not used)
│   │   │   │   ├── 01_ST                                 <- Sequence 01 Silver GT
│   │   │   │   └── └── SEG                                   <- Tracking SEG
│   │   │   .
│   │   │   .
│   │   │   .
│   │   ├── Test                             <- Test Split
│   │   │   ├── Fluo-N2DH-SIM+                        <- Fluo-N2DH-SIM+ Dataset
│   │   │   │   ├── 01                                    <- Sequence 01
│   │   │   │   ├── 02                                    <- Sequence 02
│   │   │   ├── PhC-C2DH-U373                             <- PhC-C2DH-U373 Dataset
│   │   │   │   ├── 01                                    <- Sequence 01
│   │   │   │   ├── 02                                    <- Sequence 02
│   │   │   .
│   │   │   .
│   │   │   .
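
Before running any pipeline stage, it may be useful to check that a dataset follows the recommended layout. The sketch below is an illustration only and is not part of the repository; the function name `missing_training_dirs` and its parameters are hypothetical, and it checks only the Training split folders shown in the tree above.

```python
from pathlib import Path


def missing_training_dirs(data_root, dataset, sequences, gt_subdirs=("TRA",)):
    """List expected-but-absent folders for one dataset in the Training split.

    data_root is the 'data' folder from the tree above; for every sequence we
    expect an image folder ('01') and one folder per ground-truth subdir
    ('01_GT/TRA', optionally '01_GT/SEG').
    """
    base = Path(data_root) / "CTC" / "Training" / dataset
    missing = []
    for seq in sequences:
        expected = [base / seq]
        expected += [base / f"{seq}_GT" / sub for sub in gt_subdirs]
        missing += [str(p) for p in expected if not p.is_dir()]
    return missing


if __name__ == "__main__":
    for path in missing_training_dirs("./data", "Fluo-N2DH-SIM+", ["01", "02"]):
        print("missing:", path)
```

An empty result means every expected folder was found; pass `gt_subdirs=("TRA", "SEG")` to also require the segmentation ground truth.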

Download Datasets

Training code

Overview

Our code consists of three run files located in the project's root directory (run.py, run_feat_extract.py, and run_train_metric_learning.py), which divide the project into three parts: 'complete model', 'feature extraction', and 'metric learning', respectively. An overview of each follows:

  • Metric Learning: trains a feature-extraction model using the PyTorch Metric Learning library, built from separate source code (see src_metric_learning in the Structure section). Before running this part, we generate CSV files containing the relevant information about the cells; the metric learning datasets use these files during training.
  • Feature Extraction: after training a discriminative model, we use it to extract the features that are later used to build our graphs.
  • Complete Model: once all the required data is ready, we use it to train the graph neural network model presented in the main paper.

Command lines Summary

We summarize all the command lines needed to produce a run. An explanation of each variable is provided in the Dive Into Details section below.

export CUDA_VISIBLE_DEVICES=0 # select GPU number

# run feat_extract for metric learning -
# please ensure that your target is correct in the 'configs/feat_extract/feat_extract.yaml' file.
python run_feat_extract.py params.input_images=<image_dir> params.input_masks=<masks_dir> params.input_seg=<masks_dir> params.output_csv=<save_output> params.basic=True params.sequences=[<str_sequences>,<str_sequences>,...] params.seg_dir=<seg_dir_template>

# run metric learning training -
python run_train_metric_learning.py dataset.kwargs.data_dir_img=<image_directory> dataset.kwargs.data_dir_mask=<data_dir_mask> dataset.kwargs.dir_csv=<dir_csv>
# output 'all_params.pth' is generated at the end; it is the input_model for the next command line

# run feat_extract for cell tracking training -
python run_feat_extract.py params.input_images=<image_dir> params.input_masks=<masks_dir> params.input_seg=<masks_dir> params.input_model=<path_to_all_params_produced_in_metric_learning> params.output_csv=<save_output> params.basic=False params.sequences=[<str_sequences>,<str_sequences>,...] params.seg_dir=<seg_dir_template>

# cell tracking training run
python run.py datamodule.dataset_params.main_path=<csv_home_directory> datamodule.dataset_params.exp_name="<name>_<2D/3D>"

For example, if your data structure is organized as recommended, you can run training for Fluo-N2DH-SIM+ dataset with the following command lines:

export CUDA_VISIBLE_DEVICES=0 # select GPU number

# run feat_extract for metric learning -
python run_feat_extract.py params.input_images="./data/CTC/Training/Fluo-N2DH-SIM+" params.input_masks="./data/CTC/Training/Fluo-N2DH-SIM+" params.input_seg="./data/CTC/Training/Fluo-N2DH-SIM+" params.output_csv="./data/basic_features/" params.sequences=['01','02']  params.seg_dir='_GT/TRA' params.basic=True

# run metric learning training -
python run_train_metric_learning.py dataset.kwargs.data_dir_img="./data/CTC/Training/Fluo-N2DH-SIM+" dataset.kwargs.data_dir_mask="./data/CTC/Training/Fluo-N2DH-SIM+" dataset.kwargs.dir_csv="./data/basic_features/Fluo-N2DH-SIM+" dataset.kwargs.subdir_mask='GT/TRA'
# output 'all_params.pth' is generated at the end; it is the input_model for the next command line

# run feat_extract for cell tracking training -
python run_feat_extract.py params.input_images="./data/CTC/Training/Fluo-N2DH-SIM+" params.input_masks="./data/CTC/Training/Fluo-N2DH-SIM+" params.input_seg="./data/CTC/Training/Fluo-N2DH-SIM+" params.output_csv="./data/ct_features/" params.sequences=['01','02']  params.seg_dir='_GT/TRA' params.basic=False params.input_model=<path_to_all_params_produced_in_metric_learning>

# cell tracking training run
python run.py datamodule.dataset_params.main_path="./data/ct_features/Fluo-N2DH-SIM+" datamodule.dataset_params.exp_name="2D_SIM" datamodule.dataset_params.drop_feat=[]
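
The two run_feat_extract.py calls above differ only in 'basic', 'output_csv', and the presence of 'input_model'. The following sketch makes that difference explicit by generating both command strings from one function. The helper `feat_extract_cmd` is hypothetical and only formats text using the override names shown above; it does not run anything.

```python
def feat_extract_cmd(dataset_dir, output_csv, sequences, seg_dir, basic, input_model=None):
    """Format a run_feat_extract.py command line with Hydra overrides.

    The override keys are the ones used in the commands above; this helper
    only builds the string so it can be printed and pasted into a shell.
    """
    seqs = "[" + ",".join(f"'{s}'" for s in sequences) + "]"
    parts = [
        "python run_feat_extract.py",
        f'params.input_images="{dataset_dir}"',
        f'params.input_masks="{dataset_dir}"',
        f'params.input_seg="{dataset_dir}"',
        f'params.output_csv="{output_csv}"',
        f"params.sequences={seqs}",
        f"params.seg_dir='{seg_dir}'",
        f"params.basic={basic}",
    ]
    # The trained metric learning model is only needed for the second pass.
    if input_model is not None:
        parts.append(f'params.input_model="{input_model}"')
    return " ".join(parts)


if __name__ == "__main__":
    data = "./data/CTC/Training/Fluo-N2DH-SIM+"
    # Pass 1: basic features for metric learning.
    print(feat_extract_cmd(data, "./data/basic_features/", ["01", "02"], "_GT/TRA", basic=True))
    # Pass 2: full features for cell tracking (path below is a placeholder).
    print(feat_extract_cmd(data, "./data/ct_features/", ["01", "02"], "_GT/TRA",
                           basic=False, input_model="<path_to_all_params>"))
```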

Dive Into Details

We provide details on how to run any aspect of our code, from metric learning to our full model performing cell tracking, and extracting features in between.

Run Metric Learning

  1. Before running training, we generate CSV files containing the relevant information about the cells. We do so using the run_feat_extract.py file and the corresponding configuration file located in configs/feat_extract/feat_extract.yaml:
defaults:
    - params: params.yaml # do not change
_target_: src.datamodules.extract_features.<choose_seq2graph_file> # options - preprocess_seq2graph_2d/preprocess_seq2graph_3D/preprocess_seq2graph_patch_based

The params configuration is located in configs/feat_extract/params/params.yaml:

input_images: #Please/insert/path/to/image_frames
input_masks: #Please/insert/path/to/image_masks/corresponds/image_frames
input_seg: #Please/insert/path/to/segmentation_mask/corresponds/image_frames
input_model: #Please/insert/path/of/metric_learning/feature_extractor_model
output_csv: #Please/insert/path/to/save/features
basic:  # !! Most important now- should be True !! -options True/False
sequences: # example: ['01', '02']
seg_dir: <choose_seg_dir_template> # options '_GT/SEG'/'_ST/SEG'

An explanation of each variable is detailed in the comments.

At this stage, the 'basic' parameter is the most important one: it should be set to True, indicating that the basic features used for metric learning are extracted. The 'seg_dir' variable is needed because the Cell Tracking Challenge (CTC) organizes its folders according to a fixed template. For example, for sequence 1, the '01' folder holds the images, the '01_GT/TRA' folder holds the marker annotations, and the '01_GT/SEG' folder holds the segmentation annotations. For silver ground truth segmentation, the folder is '01_ST/SEG'. We apply this convention to all datasets, including those that are not part of the CTC.
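
The folder template described above amounts to appending 'seg_dir' to the sequence name. The sketch below illustrates that rule under the assumption that the annotation folder is exactly `<sequence><seg_dir>`, as in the '01_GT/TRA' example; the function `sequence_dirs` is hypothetical and not taken from the repository.

```python
from pathlib import PurePosixPath


def sequence_dirs(dataset_dir, sequence, seg_dir):
    """Return (image_dir, annotation_dir) for one sequence.

    Assumes the CTC template from the text: images live in '<sequence>' and
    annotations in '<sequence><seg_dir>', e.g. '01' + '_GT/TRA'.
    """
    root = PurePosixPath(dataset_dir)
    return root / sequence, root / f"{sequence}{seg_dir}"


if __name__ == "__main__":
    for template in ("_GT/TRA", "_GT/SEG", "_ST/SEG"):
        images, annotations = sequence_dirs("data/CTC/Training/PhC-C2DH-U373", "01", template)
        print(images, "->", annotations)
```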

After setting all paths, run run_feat_extract.py to extract features; the CSV files are saved to the 'output_csv' folder (pay attention to the log, which indicates where the files were saved).

Now that you are familiar with all the relevant variables, here is a command line that produces the corresponding run while overriding the discussed variables.

# run feat_extract for metric learning -
# please ensure that your target is correct in the 'configs/feat_extract/feat_extract.yaml' file.
python run_feat_extract.py params.input_images=<image_dir> params.input_masks=<masks_dir> params.input_seg=<masks_dir> params.output_csv=<save_output> params.basic=True params.sequences=[<str_sequences>,<str_sequences>,...] params.seg_dir=<seg_dir_template>
  2. After generating the required CSVs, the next step is to train a discriminative model to extract features using run_train_metric_learning.py and the corresponding configuration files in configs/metric_learning:
  • config_2D.yaml: hyperparameters for training. Here we also set 'exp_name', the name of the folder where outputs are saved, and choose between two settings for 2D datasets: those with markers (dataset/dataset_2D_patch_based.yaml) and those with segmentations (dataset/dataset_2D.yaml).

  • config_3D.yaml: hyperparameters for training. Here we also set 'exp_name', the name of the folder where outputs are saved. It works with the segmentation setting for 3D datasets (dataset/dataset_3D.yaml).

  • The default configuration in run_train_metric_learning.py is for 2D datasets. To change it, set 'config_name' to config_3D.yaml in the following line of run_train_metric_learning.py: @hydra.main(config_path="configs/metric_learning/", config_name="config_2D.yaml").

  • dataset_**.yaml: the important variables to set here are the paths to the images, masks, and CSVs produced in step 1 above.

    data_dir_img: #Please/insert/path/to/images_directory
    data_dir_mask: #Please/insert/path/to/segmentation_mask/corresponds/images
    subdir_mask:  # options '_GT/SEG'/'_ST/SEG'/'GT/TRA'
    dir_csv: #Please/insert/path/to/saved_basic_CSV

If you work with markers, set 'subdir_mask' to 'GT/TRA'. If you have full segmentation, set '_GT/SEG' for GT or '_ST/SEG' for silver GT.

A command line to produce the corresponding run with the override of the discussed variables is provided:

python run_train_metric_learning.py dataset.kwargs.data_dir_img=<image_directory> dataset.kwargs.data_dir_mask=<data_dir_mask> dataset.kwargs.dir_csv=<dir_csv>
  3. Now you have set everything up and are ready to run run_train_metric_learning.py.
At the end of the run, our code wraps the best checkpoints and saves them with metadata in a dictionary file called "/outputs/<date_time>/<time>/all_params.pth" in the project directory. This dictionary is required for extracting the learned features, back in configs/feat_extract/params/params.yaml:

input_images: #Please/insert/path/to/image_frames
input_masks: #Please/insert/path/to/image_masks/corresponds/image_frames
input_seg: #Please/insert/path/to/segmentation_mask/corresponds/image_frames
input_model: #Please/insert/path/of/metric_learning/feature_extractor_model
output_csv: #Please/insert/path/to/save/features
basic:  # !! Most important now- should be True !! -options True/False
sequences: # example: ['01', '02']
seg_dir: <choose_seg_dir_template> # options '_GT/SEG'/'_ST/SEG'


The 'basic' variable should be set to False, and 'input_model' is the saved dictionary file ('all_params.pth') logged at the end of metric learning training. You should now run run_feat_extract.py again to extract features, both spatio-temporal and deep metric learning features.
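
Since the output folder name depends on the run's date and time, locating the newest 'all_params.pth' by hand can be tedious. A minimal sketch of one way to find it follows; the helper `latest_all_params` is hypothetical, and it assumes that "newest" can be judged by file modification time. It searches recursively, so it does not depend on the exact nesting of the date and time folders.

```python
from pathlib import Path


def latest_all_params(outputs_dir="outputs"):
    """Return the most recently modified all_params.pth under outputs_dir, or None."""
    candidates = list(Path(outputs_dir).rglob("all_params.pth"))
    if not candidates:
        return None
    # Assumption: the newest file (by modification time) is the run you want.
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    found = latest_all_params("outputs")
    if found is None:
        print("no all_params.pth found; run metric learning training first")
    else:
        # This path is the value to pass as params.input_model.
        print(f"params.input_model={found}")
```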


A command line to produce the corresponding run with an override of the discussed variables is provided:
# run feat_extract for cell tracking  -
# please ensure that your target is correct in the 'configs/feat_extract/feat_extract.yaml' file.
python run_feat_extract.py params.input_images=<image_dir> params.input_masks=<masks_dir> params.input_seg=<masks_dir> params.input_model=<path_to_all_params_produced_in_metric_learning> params.output_csv=<save_output> params.basic=False params.sequences=[<str_sequences>,<str_sequences>,...] params.seg_dir=<seg_dir_template>

Run Training of the full model Cell Tracking by GNN

Our main file is run.py with the configuration configs/config.yaml. The main project config contains the default training configuration:

defaults:
  - trainer: default_trainer.yaml # do not change
  - model: celltrack_model_patch_based.yaml # can be changed
  - datamodule: datamodule_multiSequence.yaml # can be changed
  - callbacks: default_callbacks.yaml  # do not change
  - logger: many_loggers.yaml  # do not change

As mentioned in the comments, you can change only the model and datamodule configurations, which are located in the configs/model and configs/datamodule folders, respectively. model is provided with 3 main options: celltrack_model_2d.yaml, celltrack_model_3d.yaml, and celltrack_model_patch_based.yaml, for a 2D dataset with segmentation, a 3D dataset with segmentation, and a 2D dataset with markers, respectively. The only difference between them is the input feature dimension.

datamodule is provided with 2 main options:

  1. Run with separate sequences for train/validation/test using datamodule_sepSequences.yaml
  2. Run with a combination of sequences for train/validation/test using datamodule_multiSequence.yaml. This configuration is used to train the final model for the CTC (with 2 combinations of the provided sequences for training and validation).

In each setting, you should change the directory in the "main_path" variable and the 'dirs_path' sub-directories of the main path. If you do not want to run with the patch-based setting (marker annotations) and instead want the segmentation annotation setting, comment out the strings in the "drop_feat" argument, or simply override them by adding datamodule.dataset_params.drop_feat=[] to the run command line.

model is provided with 3 main options:

  1. Run with 2D + segmentation using celltrack_model_2d.yaml
  2. Run with 2D + markers using celltrack_model_patch_based.yaml
  3. Run with 3D + segmentation using celltrack_model_3d.yaml

These configurations require no changes. Just select your preferred setting in the main config file configs/config.yaml.
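
The choices described in this section can be summarized as a small lookup. This is a sketch for orientation only: the function `training_setup` is hypothetical, and applying the `drop_feat=[]` override to every segmentation setting (including 3D) is an assumption drawn from the paragraph on "drop_feat" above.

```python
# Model config file for each (modality, annotation type), as listed above.
MODEL_CONFIGS = {
    ("2D", "segmentation"): "celltrack_model_2d.yaml",
    ("2D", "markers"): "celltrack_model_patch_based.yaml",
    ("3D", "segmentation"): "celltrack_model_3d.yaml",
}


def training_setup(modality, annotation):
    """Return (model_config_file, extra_command_line_overrides)."""
    key = (modality.upper(), annotation.lower())
    if key not in MODEL_CONFIGS:
        raise ValueError(f"unsupported setting: {key}")
    overrides = []
    # Assumption: segmentation settings keep all features, so the list of
    # dropped features is overridden with an empty list.
    if key[1] == "segmentation":
        overrides.append("datamodule.dataset_params.drop_feat=[]")
    return MODEL_CONFIGS[key], overrides


if __name__ == "__main__":
    config_file, overrides = training_setup("2D", "segmentation")
    print("set in configs/config.yaml -> model:", config_file)
    print("append to the run.py command line:", " ".join(overrides))
```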

Now that you are familiar with all the relevant variables, here is a command line that produces the corresponding run while overriding the discussed variables.

export CUDA_VISIBLE_DEVICES=0 # select GPU number
# training run
python run.py datamodule.dataset_params.main_path=<csv_home_directory> datamodule.dataset_params.exp_name="<name>_<2D/3D>"

At the end of training, a run is made to extract the validation set scores on the edges of the graph for the best checkpoint. Messages with the performance and the best checkpoint path are logged. The precision, recall, and accuracy scores achieved by our method on edge classification are approximately 99% (or higher), and the scores are logged at this stage along with other information.

A summary of all required command lines is provided in the Command lines Summary section above.
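
For readers unfamiliar with these scores, the sketch below shows how precision, recall, and accuracy are defined for binary edge classification (1 = the edge links the same cell across frames, 0 = it does not). It uses the standard textbook definitions in plain Python and is not the metric code used in this repository; the function name `edge_scores` is hypothetical.

```python
def edge_scores(y_true, y_pred):
    """Return (precision, recall, accuracy) for binary edge labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # correctly kept edges
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # wrongly kept edges
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # missed edges
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # correctly rejected edges
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    accuracy = (tp + tn) / len(pairs) if pairs else 0.0
    return precision, recall, accuracy


if __name__ == "__main__":
    labels = [1, 1, 0, 0, 1]
    predictions = [1, 0, 0, 1, 1]
    print(edge_scores(labels, predictions))
```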

Logs Formats

Hydra creates a new working directory for every executed training run (metric_learning/cell_tracking).
By default, logs are split into two main directories, logs and outputs, which correspond to cell_tracking and metric_learning, respectively:
│
├── logs                  # Logs generated by Hydra and PyTorch Lightning loggers in the cell tracking model training
│   ├── runs                    # Folder for logs generated from single runs of the full model
│   │   ├── 2021-02-15              # Date of executing run
│   │   │   ├── 16-50-49                # Hour of executing run
│   │   │   │   ├── .hydra                  # Hydra logs
│   │   │   │   ├── wandb                   # Weights&Biases logs
│   │   │   │   ├── checkpoints             # Training checkpoints
│   │   │   │   └── ...                     # Any other thing saved during training
│   │   │   ├── ...
│   │   │   └── ...
│   │   ├── ...
│   └── └── ...
│   │
├── outputs                     # Outputs generated by Hydra and tensorboard loggers when training deep metric learning model
│   ├── runs                    # Folder for logs generated from single runs
│   │   ├── 2021-02-15              # Date of executing run
│   │   │   ├── 16-50-49                # Hour of executing run
│   │   │   │   ├── .hydra                  # Hydra logs
│   │   │   │   ├── logs_<exp_name>         # Any other thing saved during training - included checkpoints and logs
│   │   │   │   └── all_params.pth         # A dictionary consisting of all the relevant information (model state dicts and other parameters that are used for feature extraction)
│   │   │   ├── ...
│   │   │   └── ...
│   │   ├── ...
│   └── └── ...
│
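
Because run folders are named by date and time, they can be ordered by parsing those names. The following sketch lists runs chronologically and collects their checkpoint files. It assumes the `runs/<date>/<time>/checkpoints` layout shown in the tree above and `.ckpt` checkpoint files; the helpers `list_runs` and `checkpoints_of` are hypothetical.

```python
from datetime import datetime
from pathlib import Path


def list_runs(root):
    """Return [(timestamp, run_dir)] for every runs/<date>/<time> folder, oldest first."""
    runs = []
    for run_dir in Path(root, "runs").glob("*/*"):
        if not run_dir.is_dir():
            continue
        try:
            stamp = datetime.strptime(
                f"{run_dir.parent.name} {run_dir.name}", "%Y-%m-%d %H-%M-%S"
            )
        except ValueError:
            continue  # skip folders that do not follow the date/time naming
        runs.append((stamp, run_dir))
    return sorted(runs, key=lambda item: item[0])


def checkpoints_of(run_dir):
    """Return the checkpoint files saved by one run, sorted by name."""
    return sorted(Path(run_dir, "checkpoints").glob("*.ckpt"))


if __name__ == "__main__":
    for stamp, run_dir in list_runs("logs"):
        names = [p.name for p in checkpoints_of(run_dir)]
        print(stamp.isoformat(), run_dir, names)
```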



Evaluation code

To run evaluation, we provide an example script (submitted to the CTC) and all the relevant files for running our code in the src/inference folder. Details below:

SEQUENCE="01"
FOV="0"
DATASET="${PWD}/../Fluo-N2DH-SIM+" # path/to/dataset/dir
CODE_TRA="${PWD}" # path/to/inference_of_tracking/algorithm(ours)
MODEL_METRIC_LEARNING="${PWD}/parameters/Features_Models/Fluo-N2DH-SIM+/all_params.pth" # path/to/MODEL_METRIC_LEARNING/parameters
MODEL_PYTORCH_LIGHTNING="${PWD}/parameters/Tracking_Models/Fluo-N2DH-SIM+/checkpoints/epoch=132.ckpt" # path/to/tracking_model/parameters
CODE_SEG="${PWD}/seg_code/" # path/to/seg/algorithm
SEG_MODEL="${PWD}/parameters/Seg_Models/Fluo-N2DH-SIM+/" # path/to/seg_model/parameters
MODALITY="2D"  # dataset modality

# seg prediction
python ${CODE_SEG}/Inference2D.py --gpu_id 0 --model_path ${SEG_MODEL} --sequence_path "${DATASET}/${SEQUENCE}" --output_path "${DATASET}/${SEQUENCE}_SEG_RES" --edge_dist 3 --edge_thresh=0.3 --min_cell_size 100 --max_cell_size 1000000 --fov 0 --centers_sigmoid_threshold 0.8 --min_center_size 10 --pre_sequence_frames 4 --data_format NCHW --save_intermediate --save_intermediate_path ${DATASET}/${SEQUENCE}_SEG_intermediate

# cleanup
rm -r "${DATASET}/${SEQUENCE}_SEG_intermediate"

# Finish segmentation - start tracking

# our model needs CSVs, so let's create from image and segmentation.
python ${CODE_TRA}/preprocess_seq2graph_clean.py -cs 20 -ii "${DATASET}/${SEQUENCE}" -iseg "${DATASET}/${SEQUENCE}_SEG_RES" -im "${MODEL_METRIC_LEARNING}" -oc "${DATASET}/${SEQUENCE}_CSV"

# run the prediction
python ${CODE_TRA}/inference_clean.py -mp "${MODEL_PYTORCH_LIGHTNING}" -ns "${SEQUENCE}" -oc "${DATASET}"

# postprocess
python ${CODE_TRA}/postprocess_clean.py -modality "${MODALITY}" -iseg "${DATASET}/${SEQUENCE}_SEG_RES" -oi "${DATASET}/${SEQUENCE}_RES_inference"

rm -r "${DATASET}/${SEQUENCE}_CSV" "${DATASET}/${SEQUENCE}_RES_inference" "${DATASET}/${SEQUENCE}_SEG_RES"

You should create a script like the one above, with the parameters pointing to your trained models (how to produce them is explained above). Each variable is explained in the comments. Please refer to the main paper for the segmentation algorithms used. For the evaluation methodology of the challenge, see http://celltrackingchallenge.net/evaluation-methodology/, which also provides the command-line software packages that implement the TRA measure (publicly available at the link).

Google Colab notebook example

Open In Colab

Pretrained Models

The software and pretrained models submitted to the Cell Tracking Challenge are available at the Releases.

Citation

If you find either the code or the paper useful for your research, cite our paper:

@inproceedings{ben2022graph,
title={Graph Neural Network for Cell Tracking in Microscopy Videos},
author={Ben-Haim, Tal and Riklin-Raviv, Tammy},
booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
year={2022},
}