
Multiclass Brain Segmentation

My code snippets from the UNET-based multiclass brain segmentation model built for the Fetal Brain Tissue Annotation and Segmentation Challenge (FeTA), MICCAI 2021.

As the challenge is still ongoing, the repository contains only the less "custom made" elements, such as the data loader and the training script.

Current Results

The model was trained on CPU to check whether there were any problems with the code.

  • Single class segmentation: Epochs: 1, Images: 2560, Batch: 16, Jaccard: 0.42
    (Result figure: brain image at the bottom, red area - label/mask, orange spots - model prediction.)

  • Multiclass segmentation: Epochs: 1, Images: 2560, Batch: 16, Jaccard: 0.28

Training Script (working on it)

As stated before, the repository contains only some parts of the code, but with all of them in place you could run the training with:

python train5.py LEARNING_RATE BATCH_SIZE EPOCHS TRAIN_AMOUNT VALID_AMOUNT SAVE_MODEL_NAME DATA_PATH

where LEARNING_RATE is the optimizer learning rate, BATCH_SIZE is the number of images that go into one training step, EPOCHS is the number of training epochs, TRAIN_AMOUNT is the number of training patients to load, VALID_AMOUNT is the number of validation patients to load, SAVE_MODEL_NAME is where to save the trained model, and DATA_PATH is what goes to the Data Loader object.
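
A minimal sketch of how these positional arguments could be read at the top of train5.py; the parsing code itself is not part of the published snippets, so the lines below only mirror the call signature above:

import sys

# Positional arguments, in the same order as the command above
LEARNING_RATE   = float(sys.argv[1])
BATCH_SIZE      = int(sys.argv[2])
EPOCHS          = int(sys.argv[3])
TRAIN_AMOUNT    = int(sys.argv[4])
VALID_AMOUNT    = int(sys.argv[5])
SAVE_MODEL_NAME = sys.argv[6]
DATA_PATH       = sys.argv[7]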

Resolved problems:

  • RAM management: made the data loader object iterable so the data can be loaded patient by patient during training, which resulted in a significant decrease in RAM usage.
  • VRAM management: metric variables have to be converted with float() so they do not drag the gradient graph along with them, which resulted in a significant decrease in VRAM usage (see the sketch below).
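
A minimal sketch of the VRAM fix, assuming a typical PyTorch training loop; the tiny model, loss, and random data below are illustrative stand-ins, not the repository's UNET:

import torch
import torch.nn as nn

# Stand-ins for the real UNET, loss and optimizer
model = nn.Conv2d(1, 8, kernel_size=3, padding=1)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

running_loss = 0.0
for _ in range(4):
    images = torch.randn(2, 1, 64, 64)
    masks = torch.randint(0, 8, (2, 64, 64))
    preds = model(images)
    loss = loss_fn(preds, masks)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # float() (or loss.item()) detaches the scalar from the autograd graph;
    # accumulating the raw tensor keeps the whole graph alive and leaks VRAM
    running_loss += float(loss)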

To do:

  • Cross Validation
  • Data Augmentation

Data Loader

You can create an iterable data object by passing the data folder path to the MRIDataset class:

data = MRIDataset(path)
for patient in data:
  print(patient.shape)

Each iteration through the data object yields the next 3D MRI image as a NumPy array of shape 256 x 256 x 256.
The object length is the number of subject folders (directories whose names start with 'sub') inside the given path.

def __len__(self):
    # Count only the subject directories (names starting with 'sub') directly under the data folder
    return len([dir for dir in next(os.walk(self.folderpath))[1] if dir.startswith('sub')])
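
For reference, a minimal sketch of what such an iterable dataset could look like; the file layout, the T2w.nii.gz file name, and the use of nibabel are assumptions for illustration, not the repository's actual implementation:

import os
import nibabel as nib  # assumed NIfTI reader

class MRIDataset:
    def __init__(self, folderpath):
        self.folderpath = folderpath

    def __len__(self):
        return len([dir for dir in next(os.walk(self.folderpath))[1] if dir.startswith('sub')])

    def __iter__(self):
        # Yield one 256 x 256 x 256 volume at a time to keep RAM usage low
        for subject in sorted(next(os.walk(self.folderpath))[1]):
            if not subject.startswith('sub'):
                continue
            # Hypothetical file name inside each subject folder
            volume_path = os.path.join(self.folderpath, subject, 'T2w.nii.gz')
            yield nib.load(volume_path).get_fdata()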

Utils

We use the utility functions to prepare the loaded data for the UNET train/valid pass.

import numpy as np
from torch.utils.data import DataLoader

def prepare_one(x, y, batch):
    # Insert a channel axis and reorder so the slice axis comes first:
    # each sample becomes one single-channel 2D slice of shape (1, H, W)
    x = np.expand_dims(x, 1)
    x = np.moveaxis(x, 3, 0)
    x = np.moveaxis(x, 1, 2)
    y = np.expand_dims(y, 1)
    y = np.moveaxis(y, 3, 0)
    y = np.moveaxis(y, 1, 2)
    # Pair each image slice with its mask slice and let torch batch them
    zipped = list(zip(x, y))
    ready = DataLoader(zipped, batch_size=batch)
    return ready

Takes images as x, masks as y, and the number of image/mask pairs to pack into one training/validation batch. Returns a torch.utils.data.DataLoader object.
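
A hedged usage sketch: random arrays stand in for one loaded MRI volume and its annotation mask, only to show the shapes that come out of prepare_one:

volume = np.random.rand(256, 256, 256)             # stand-in for one loaded MRI volume
labels = np.random.randint(0, 8, (256, 256, 256))  # stand-in for its annotation mask

loader = prepare_one(volume, labels, batch=16)
for x_batch, y_batch in loader:
    print(x_batch.shape, y_batch.shape)  # torch.Size([16, 1, 256, 256]) for both
    break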

def mask_dim(current):
    current = np.concatenate([np.where(current == i, 1, 0) for i in range(1,8)], 1)
    return current

Takes a mask array with a single channel axis and turns the label values 1-7 into seven binary channels (value 1 where the class is present), so a 1 x 256 x 256 mask becomes 7 x 256 x 256 per sample.
Example (showing values 1 and 2 only):

input mask:
[[1 2 1],
 [0 1 2],
 [1 2 0]]

mask_dim output, channel for value 1:
[[1 0 1],
 [0 1 0],
 [1 0 0]]

mask_dim output, channel for value 2:
[[0 1 0],
 [0 0 1],
 [0 1 0]]
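
A quick runnable check of the example above. Note that mask_dim concatenates along axis 1, so here the toy mask is given the leading batch and channel axes it has after prepare_one (shape 1 x 1 x 3 x 3); this layout is an assumption for the check, not a statement about the repository's pipeline:

import numpy as np

def mask_dim(current):
    current = np.concatenate([np.where(current == i, 1, 0) for i in range(1, 8)], 1)
    return current

mask = np.array([[[[1, 2, 1],
                   [0, 1, 2],
                   [1, 2, 0]]]])   # shape (1, 1, 3, 3)
out = mask_dim(mask)
print(out.shape)   # (1, 7, 3, 3)
print(out[0, 0])   # channel for value 1: [[1 0 1], [0 1 0], [1 0 0]]
print(out[0, 1])   # channel for value 2: [[0 1 0], [0 0 1], [0 1 0]]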