kreshuklab/plant-seg

Re-train all PlantSeg models to eliminate hallucination during inference

We have historically addressed the tiling artifact by using larger patch sizes. Recently, @Buglakova re-highlighted this issue in Issue #190, prompting further investigation in the PlantSeg and pytorch-3dunet projects. See the related PR #220 for PlantSeg and PR #113 for pytorch-3dunet.
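For reference, here is a minimal sketch (not the actual PlantSeg/pytorch-3dunet code; `predict` is a placeholder for the network forward pass and the patch/halo sizes are made up for illustration) of how halo-based tiled prediction avoids the tiling artifact: each patch is predicted with extra context on every side, and that halo is cropped away before stitching, so every voxel in the output was predicted with full spatial context.

```python
import numpy as np

def predict_tiled(volume, predict, patch=(80, 160, 160), halo=(8, 16, 16)):
    """Predict `volume` tile by tile, cropping the halo before stitching."""
    out = np.zeros(volume.shape, dtype=np.float32)
    for z in range(0, volume.shape[0], patch[0]):
        for y in range(0, volume.shape[1], patch[1]):
            for x in range(0, volume.shape[2], patch[2]):
                origin = (z, y, x)
                # grow the patch by the halo, clipped to the volume bounds
                starts = [max(o - h, 0) for o, h in zip(origin, halo)]
                stops = [min(o + p + h, d)
                         for o, p, h, d in zip(origin, patch, halo, volume.shape)]
                block = volume[tuple(slice(a, b) for a, b in zip(starts, stops))]
                pred = predict(block)  # network forward pass on the haloed patch
                # keep only the inner (halo-free) region of the prediction
                inner = tuple(slice(o - a, o - a + min(p, d - o))
                              for o, a, p, d in zip(origin, starts, patch, volume.shape))
                out[z:z + patch[0], y:y + patch[1], x:x + patch[2]] = pred[inner]
    return out
```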

Findings:
After adjusting the halo implementation, it's clear that nuclear models trained with batch normalization (with the default track_running_stats=True) avoid prediction hallucinations on new datasets, unlike those trained with group normalization.
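To make the mechanism concrete, here is a minimal plain-PyTorch sketch (not PlantSeg code; the shapes and layer sizes are arbitrary) of why the two normalization layers behave differently on out-of-distribution patches: in eval mode, `BatchNorm3d` with `track_running_stats=True` reuses the running statistics accumulated during training, while `GroupNorm` always normalizes each patch by its own statistics, so a near-empty patch has its noise rescaled to unit variance.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 16, 16, 16) * 1e-3  # near-empty patch: almost no signal

bn = nn.BatchNorm3d(8, track_running_stats=True)  # PyTorch default
gn = nn.GroupNorm(num_groups=4, num_channels=8)

bn.train()
for _ in range(10):
    bn(torch.randn(2, 8, 16, 16, 16))  # accumulate running stats on "training" data
bn.eval()

print(bn(x).abs().mean())  # stays tiny: normalized with the frozen training stats
print(gn(x).abs().mean())  # ~0.8: per-patch stats rescale the noise to unit variance
```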

Previous Assumptions:
We assumed group norm performed better under certain conditions based on earlier hyperparameter tuning with datasets from similar distributions. However, this might need reevaluation as batch norm has proven more stable across diverse datasets.

Action Required:
I propose re-training the models previously trained with group norm using batch norm and releasing updated model versions. This is critical to ensure consistency and reliability in our semantic segmentation tasks.
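As a starting point, a hedged sketch of the change, assuming the pytorch-3dunet model factory and its `layer_order` convention ('g' = GroupNorm, 'b' = BatchNorm, 'c' = conv, 'r' = ReLU); the channel and feature-map settings below are placeholders, not the settings of the released PlantSeg models, which should be taken from each model's original training config.

```python
# Hypothetical re-training change; import path and keyword arguments follow
# the pytorch-3dunet API as I understand it.
from pytorch3dunet.unet3d.model import UNet3D

# previously released nuclear models: group norm in each conv block
# model = UNet3D(in_channels=1, out_channels=1, layer_order='gcr', f_maps=32)

# proposed re-training: batch norm (track_running_stats=True by default)
model = UNet3D(in_channels=1, out_channels=1, layer_order='bcr', f_maps=32)
```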

Visual Evidence:

  • Left: Prediction from the group-norm model trained on an ovules nuclear dataset.
  • Middle: Raw image of a mouse embryo (brightness increased to highlight the absence of signal).
  • Right: Prediction from the batch-norm model trained on an ovules nuclear dataset.

Comparison of model predictions

@wolny, could you assist in identifying the training datasets used for each model to facilitate these updates? I'll do the training after my CBB Seminar talk tomorrow.