BloodAxe/pytorch-toolbelt

Tiled inference potentially generates wrong multi-class predictions

yurithefury opened this issue · 2 comments

๐Ÿ› Bug

I believe the current implementation of tiled inference can produce erroneous predictions. If I understand it correctly, your tiled inference approach accumulates the predictions for each pixel and then divides them by a norm_mask (the accumulated weight, i.e. the number of predictions covering each pixel when all weights are 1). This works well for the binary case, but not for multi-class classification. For example, if I have 4 classes to predict and run tiled inference with your sliding-window approach (e.g. tile_size=128, tile_step=64), a pixel can receive a mix of predictions (e.g. 1, 4, 4, 4). With equal weights, the final value for this pixel after dividing by norm_mask is 3.25, i.e. class 3 after rounding, a class that no tile predicted. Wouldn't it be more appropriate to take the mode of all predictions for this pixel, giving a final prediction of 4?
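The arithmetic behind this example can be checked in a few lines of NumPy. This is a sketch of a single pixel under the assumption of equal tile weights (with "pyramid" weights the average would shift), comparing what the accumulate-and-divide step computes with the majority vote proposed above:

```python
import numpy as np

# Hard labels that four overlapping tiles predicted for the same pixel.
labels = np.array([1, 4, 4, 4])

# Accumulating labels and dividing by norm_mask (equal weights assumed)
# is just the arithmetic mean of the class indices.
mean_label = labels.mean()  # 3.25, i.e. class 3 after rounding

# The proposal in this report: a majority vote (mode) over the same labels.
mode_label = np.bincount(labels).argmax()  # 4

print(mean_label, mode_label)
```

The mean depends on the numeric order of the class indices, which carries no meaning for categorical labels; the mode does not.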

To Reproduce

Steps to reproduce the behavior:

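    import torch

    class TileMerger:  # excerpt from pytorch_toolbelt.inference.tiles
        def integrate_batch(self, batch, crop_coords):
            for tile, (x, y, tile_width, tile_height) in zip(batch, crop_coords):
                self.image[:, y : y + tile_height, x : x + tile_width] += tile * self.weight  # weighted predictions
                self.norm_mask[:, y : y + tile_height, x : x + tile_width] += self.weight  # accumulated weights

        def merge(self) -> torch.Tensor:
            return self.image / self.norm_mask  # per-pixel weighted average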
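A minimal NumPy sketch of the accumulate-and-normalize loop quoted above makes the failure visible. `NumpyTileMergerSketch` is a hypothetical stand-in, not the pytorch-toolbelt class; it assumes uniform weights and a single-channel `norm_mask`. Feeding it hard labels from two overlapping tiles that predict class 0 and class 2 yields class 1 in the overlap, a label neither tile produced:

```python
import numpy as np


class NumpyTileMergerSketch:
    """Hypothetical NumPy stand-in for the accumulate/normalize loop above.

    Not the pytorch-toolbelt TileMerger; shapes follow the quoted code.
    """

    def __init__(self, image_shape, channels, weight):
        height, width = image_shape
        self.weight = weight  # per-tile weight map, shape [tile_h, tile_w]
        self.image = np.zeros((channels, height, width))
        # Assumption: a single-channel normalization mask broadcast over C.
        self.norm_mask = np.zeros((1, height, width))

    def integrate_batch(self, batch, crop_coords):
        # batch: [N, C, tile_h, tile_w]; crop_coords: (x, y, width, height) per tile
        for tile, (x, y, tile_width, tile_height) in zip(batch, crop_coords):
            self.image[:, y:y + tile_height, x:x + tile_width] += tile * self.weight
            self.norm_mask[:, y:y + tile_height, x:x + tile_width] += self.weight

    def merge(self):
        # Per-pixel weighted average of every tile prediction covering it
        return self.image / self.norm_mask


# Two 1x4 tiles overlapping by 2 pixels on a 1x6 image, channels=1 (hard labels).
merger = NumpyTileMergerSketch(image_shape=(1, 6), channels=1, weight=np.ones((1, 4)))
tiles = np.stack([
    np.full((1, 1, 4), 0.0),  # tile A predicts class 0 everywhere
    np.full((1, 1, 4), 2.0),  # tile B predicts class 2 everywhere
])
merger.integrate_batch(tiles, [(0, 0, 4, 1), (2, 0, 4, 1)])
merged = merger.merge()
print(merged[0, 0])  # [0. 0. 1. 1. 2. 2.]
```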

Environment

  • pytorch-toolbelt version: 0.6.2
  • PyTorch version: 2.0.0
  • Python version: 3.10
  • OS: Windows 11

Hi. The intended use case for sliding-window inference is to accumulate probabilities (after softmax/sigmoid) in an accumulator buffer of shape [C, H, W], not hard labels.

Of course, if you accumulate hard labels as in your example, the averaged output will obviously be wrong.

So instead of applying argmax and then accumulating the resulting labels, what you want to do is the following:

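    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    from pytorch_toolbelt.inference.tiles import ImageSlicer, CudaTileMerger
    from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, to_numpy

    model = ...  # segmentation model on the GPU, returning logits of shape [B, num_classes, H, W]
    num_classes = ...

    # Random 8-bit test image (np.random.random(...).astype(np.uint8) would be all zeros)
    image = np.random.randint(0, 256, size=(5000, 6000, 3), dtype=np.uint8)

    tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid")
    tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]  # HWC -> CHW

    merger = CudaTileMerger(tiler.target_shape, num_classes, tiler.weight)  # <-- Note num_classes here
    model.eval()
    with torch.no_grad():  # no autograd graph is needed for inference
        for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True):
            tiles_batch = tiles_batch.float().cuda()  # apply your model's input normalization here
            # Accumulate per-class probabilities, not argmax labels
            predictions = model(tiles_batch).softmax(dim=1)
            merger.integrate_batch(predictions, coords_batch)

    merged = to_numpy(merger.merge().argmax(dim=0))  # <-- Note argmax over averaged predictions
    predicted_labels = tiler.crop_to_orignal_size(merged)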
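At a single pixel, `merger.merge().argmax(dim=0)` amounts to averaging the per-class probabilities from every tile that covers the pixel and then picking the most probable class. The NumPy sketch below assumes equal tile weights and uses made-up softmax outputs for the four-tile scenario from the report (class indices 0 to 4, so the labels 1 and 4 can be used directly; class 0 is unused):

```python
import numpy as np

# Hypothetical softmax outputs for one pixel from four overlapping tiles.
probs = np.array([
    [0.05, 0.60, 0.05, 0.05, 0.25],  # tile 1 favours class 1
    [0.05, 0.10, 0.05, 0.10, 0.70],  # tiles 2-4 favour class 4
    [0.05, 0.15, 0.05, 0.05, 0.70],
    [0.05, 0.20, 0.05, 0.10, 0.60],
])

# The hard labels each tile would have produced on its own.
tile_labels = probs.argmax(axis=1)  # [1 4 4 4], the labels from the report

# What the [C, H, W] accumulator divided by norm_mask holds at this pixel.
mean_probs = probs.mean(axis=0)

# argmax over the averaged probabilities gives the final label.
label = int(mean_probs.argmax())
print(label)  # 4
```

Unlike averaging class indices, averaging probabilities does not depend on the numeric order of the labels, and it weighs each tile's prediction by its confidence rather than counting a bare vote.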

That makes a lot of sense. Thank you for the explanation. I completely missed the channels argument in TileMerger.