bigmb/Unet-Segmentation-Pytorch-Nest-of-Unets

Question about 'data_transform'

iWeisskohl opened this issue · 8 comments

Hi, I have some doubts about the 'data_transform' function. As you suggested, the input image should be a 3-channel image and the input label a 1-channel image, but I find you use the same data_transform function

```python
data_transform = torchvision.transforms.Compose([
    # torchvision.transforms.Resize((128,128)),
    # torchvision.transforms.CenterCrop(96),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
```

for both the input image and the input label during training, and another function

```python
data_transform = torchvision.transforms.Compose([
    # torchvision.transforms.Resize((128,128)),
    # torchvision.transforms.CenterCrop(96),
    torchvision.transforms.Grayscale(),
    # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
```

for the input image and input label when calculating the Dice score.

The code raises a shape error with those functions when I run them, so I am wondering whether there is a mistake in the definition and use of the data_transform function?
Thanks in advance! Have a nice day!

bigmb commented

What's the error?
It's different for the input image and the label. You can find the data transform in the 3d_to_2d.py code.

Yes, I used the 3d_to_2d.py code to get the data and ran pytorch_run.py. The error is:

```
RuntimeError: output with shape [1, 128, 128] doesn't match the broadcast shape [3, 128, 128]
```

reported at pytorch_run.py line 311:

```python
s_tb = data_transform(im_tb)        # line 310
s_label = data_transform(im_label)  # line 311
```

because I find you use the same transform function for inputs with different channel counts, defined as:
```python
data_transform = torchvision.transforms.Compose([
    # torchvision.transforms.Resize((128,128)),
    # torchvision.transforms.CenterCrop(96),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
```

So I changed the data_transform definition for the input label (because it is single-channel, it can't be normalized with the three-channel settings), and the problem was solved.
Actually, there is no data transform in 2d_from_3d.py, but there is a transforms.Compose in data_loader.py. So I am wondering: if you already transform the data when loading it, is it necessary to re-transform it again in pytorch_run.py (e.g. lines 310 and 311)?

Sorry to disturb you again; I really appreciate your nice work. Waiting for your reply.

bigmb commented

I did the data transformation in a different Jupyter file because I had to test different configurations.
But yes, the data_transform for input images and labels should be different, as images are 3-channel and labels are 1-channel.

Are you facing any errors now?
And if there is any change, just send me a pull request.

Hi, I am having the same problem: the data transform for the mask at line 295 uses the same transform as the 3-channel image data,

```python
s_label = data_transform(im_label)
```

giving the error below:

```
RuntimeError: output with shape [1, 1020, 1020] doesn't match the broadcast shape [3, 1020, 1020]
```

Could you let me know what was changed to get it to run in your case?

Just to answer my own question -

I added a new function for the binary mask transforms:

```python
data_transform_mask = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.5], std=[0.5])])
```

and applied this at line 295:

```python
s_label = data_transform_mask(im_label)
```
bigmb commented

Done.
Let me know if you need some help.

Hi, I just changed the data_transform function for the label image to a one-channel definition (similar to what @c-arthurs did), but I am not sure whether the change is reasonable, because I am not sure whether applying a data transform to the ground-truth image will influence the segmentation results.