LIVIAETS/SizeLoss_WSS

About 3d Size

MENG1996 opened this issue · 2 comments

Thanks for your work. I would like to ask a detailed question: does the 3D size segmentation in the article also use a 2D network? And does each batch contain all slices of a 3D volume?

Yes, exactly. I kept the same 2D network and fed all slices for a single patient at once (on ACDC this was around 10-15 images at a time).

The only difference, then, is when computing the size: one does not compute 15 independent sizes (one per image), but a single size for the whole volume. The difference can be seen here, where only the c axis remains in the einsum:

# fns
import torch
from torch import Tensor

def soft_size(a: Tensor) -> Tensor:
    # One size per image: sum over the spatial axes (w, h),
    # keeping the batch (b) and class (c) axes
    return torch.einsum("bcwh->bc", a)[..., None]


def batch_soft_size(a: Tensor) -> Tensor:
    # One size per volume: sum over batch and spatial axes,
    # keeping only the class (c) axis
    return torch.einsum("bcwh->c", a)[..., None]
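To make the distinction concrete, here is a small sketch (the tensor shapes are made up for illustration: 15 slices, 2 classes, 4x4 pixels) showing that `batch_soft_size` collapses the per-slice sizes of `soft_size` into one size per class for the whole volume:

```python
import torch
from torch import Tensor

def soft_size(a: Tensor) -> Tensor:
    # One size per image: shape (b, c, 1)
    return torch.einsum("bcwh->bc", a)[..., None]

def batch_soft_size(a: Tensor) -> Tensor:
    # One size per volume: shape (c, 1)
    return torch.einsum("bcwh->c", a)[..., None]

# Fake batch of 15 softmax-probability slices, 2 classes, 4x4 pixels
probs = torch.rand(15, 2, 4, 4)

per_slice = soft_size(probs)        # shape (15, 2, 1): 15 independent sizes
per_volume = batch_soft_size(probs) # shape (2, 1): one size for the volume

# The volume size is simply the per-slice sizes summed over the batch axis
assert torch.allclose(per_volume, per_slice.sum(dim=0))
```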

Thank you for your reply