mazurowski-lab/segmentation-guided-diffusion

Problems in inputting 3 and 5 channel images


First, thank you for your work on building a clean model for medical image generation.

I'm having trouble feeding my 5-channel cell images into the model.

The error is:
RuntimeError: Given groups=1, weight of size [128, 5, 3, 3], expected input[1, 3, 256, 256] to have 5 channels, but got 3 channels instead

It is raised at this line in training.py:
noise_pred = model(sample=noisy_images, timestep=timesteps, return_dict=False)[0]

Just before the model call, I checked the shapes: noisy_images is torch.Size([5, 256, 256]) and timesteps is torch.Size([5]).

The same problem occurs with 3-channel images, where the error is:

RuntimeError: Given groups=1, weight of size [128, 3, 3, 3], expected input[1, 2, 256, 256] to have 3 channels, but got 2 channels instead

Whenever the input goes into the model, its channel count changes for some reason.
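The message comes from the model's first convolution, which has a fixed number of input channels. As a minimal sketch (a standalone nn.Conv2d standing in for the UNet's input layer, not the model itself), the same check can be reproduced in isolation:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the UNet's input convolution: 5 channels in, 128 out,
# so its weight has shape [128, 5, 3, 3], matching the first error message.
conv_in = nn.Conv2d(in_channels=5, out_channels=128, kernel_size=3, padding=1)


def try_forward(x: torch.Tensor) -> str:
    """Return "ok" if conv_in accepts x, otherwise the RuntimeError message."""
    try:
        with torch.no_grad():
            conv_in(x)
        return "ok"
    except RuntimeError as err:
        return str(err)


# One 3-channel image in (N, C, H, W) layout fails the channel check.
print(try_forward(torch.randn(1, 3, 256, 256)))
# The same batch with 5 channels passes.
print(try_forward(torch.randn(1, 5, 256, 256)))
```

In other words, whatever reaches the model must already have in_channels entries along dim 1; the error reports the shape that actually arrived, which points at the data loading rather than at the model definition.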

For reference, here is the model definition for the 5-channel case:

import diffusers

model = diffusers.UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=5,  # the number of input channels, 3 for RGB images
    out_channels=5,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D"
    ),
)

Thanks so much for any help with this.

Hi, I'm happy to see that you're trying out our model on cell images! I think that this bug may be due to the code being designed for one-channel images.

I'd like the code to support channel counts beyond 1 and 3. One complication is that the images are loaded with PIL, so could you tell me what file type your images are? I want to make sure PIL loads them correctly (or decide whether to use something other than PIL).
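For context on why PIL is a sticking point: Pillow maps NumPy arrays onto a fixed set of image modes (for 8-bit data, "L", "RGB", "RGBA", and a few others), so a 5-channel array has no matching mode. A small check, assuming Pillow and NumPy are installed:

```python
import numpy as np
from PIL import Image


def pil_mode_for(channels: int) -> str:
    """Wrap an (H, W, channels) uint8 array with Pillow; return its mode or the error text."""
    arr = np.zeros((8, 8, channels), dtype=np.uint8)
    try:
        return Image.fromarray(arr).mode
    except TypeError as err:
        return f"TypeError: {err}"


for c in (3, 4, 5):
    print(c, pil_mode_for(c))
```

3 and 4 channels map to "RGB" and "RGBA", while 5 channels is rejected, which is why a non-PIL loading path is needed for arbitrary channel counts.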

Thanks!

Thanks for the quick response.

I tested 1-channel images and they work well; thank you again for the framework. My images are currently stored as NumPy arrays.

Hi,

I modified the code so that, if the channel count set with the --num_img_channels argument of main.py is not 1 or 3, it tries to load your images as NumPy arrays rather than through PIL.

It should convert the arrays directly into torch tensors, applying the same resizing and normalization as usual. Note that it assumes your NumPy arrays are saved as (N_channels, H, W), and it still assumes that your segmentations (if you're using a segmentation-guided model) are saved as image files.
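If your arrays are stored channels-last, as (H, W, C), they can be transposed into the (N_channels, H, W) layout described above before saving. A minimal sketch, using a hypothetical file name and random data in place of real cell images:

```python
import os
import tempfile

import numpy as np


def save_channels_first(image_hwc: np.ndarray, path: str) -> None:
    """Save an (H, W, C) array as (C, H, W), the layout the NumPy loader expects."""
    np.save(path, np.transpose(image_hwc, (2, 0, 1)))


# Hypothetical 5-channel cell image stored channels-last.
img_hwc = np.random.randint(0, 256, size=(256, 256, 5), dtype=np.uint8)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "cell_0001.npy")  # hypothetical file name
    save_channels_first(img_hwc, path)
    loaded = np.load(path)  # read fully into memory, so it outlives the temp dir

print(loaded.shape)  # (5, 256, 256)
```

np.transpose only reorders axes, so each saved channel plane is identical to the corresponding [:, :, c] slice of the original.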

Could you test the most recent commit and see if it works with your setup?

Thank you again for taking the time to modify the code.

I tried the code with my NumPy data, and overall it works fairly well after I modified a couple of things:

  1. On line 221 of main.py, preprocess(F.interpolate(torch.tensor(np.load(image)).unsqueeze(0), size=(config.image_size, config.image_size))) for image in examples["image"], F.interpolate did not work on integer tensors directly and raised this error:

RuntimeError: "compute_indices_weights_nearest" not implemented for 'Int'

My solution was to add .float() right after the unsqueeze, so the tensor is cast before interpolation: preprocess(F.interpolate(torch.tensor(np.load(image)).unsqueeze(0).float(), size=(config.image_size, config.image_size))) for image in examples["image"]

  2. On the same line in main.py, the unsqueeze(0) was required for F.interpolate to work (it expects a leading batch dimension), but it also left an extra dimension on each image, which caused an error: the batched input should have been (8, 5, 256, 256), but it had an extra dimension, (8, 5, 1, 256, 256).

(I didn't save the error message for this one.)

My solution was to squeeze the extra dimension back out right after the preprocess call.
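The two fixes in the list above (cast to float before interpolating, then remove the added batch dimension) can be combined into one helper. This is a sketch, not the repository's code: the helper name is hypothetical, and the repo's own preprocess transform is left out (in the fix described above it runs before the squeeze).

```python
import numpy as np
import torch
import torch.nn.functional as F


def load_resized(arr: np.ndarray, image_size: int) -> torch.Tensor:
    """Resize a (C, H, W) integer array to a float tensor of shape (C, image_size, image_size)."""
    x = torch.tensor(arr)         # (C, H, W), integer dtype
    x = x.unsqueeze(0).float()    # (1, C, H, W): add a batch dim and cast before interpolating
    x = F.interpolate(x, size=(image_size, image_size))  # default mode is "nearest"
    return x.squeeze(0)           # drop the batch dim again -> (C, image_size, image_size)


# Hypothetical batch of eight 5-channel integer images at an arbitrary original size.
rng = np.random.default_rng(0)
arrays = [rng.integers(0, 256, size=(5, 300, 300), dtype=np.int32) for _ in range(8)]
batch = torch.stack([load_resized(a, 256) for a in arrays])
print(batch.shape, batch.dtype)  # torch.Size([8, 5, 256, 256]) torch.float32
```

Squeezing inside the per-image helper keeps every item at (C, H, W), so stacking produces the 4D (N, C, H, W) batch the UNet expects.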

I'm not using the segmentation-guided model yet, but I will be soon, so I'll post an update if I run into problems with it.

Thank you for the code. It's also very easy to read, since it doesn't have any messy lines.

If my solution seems bad, or if you have a better idea for fixing this, I'd really appreciate hearing it.

Thanks for debugging this! Your solution is good; I hadn't considered those errors. I'll go ahead and add your fixes as commit 9f532bb (or let me know if you'd rather add them as a PR instead).

Closing for now since your specific issue seems resolved, but please add a new issue if anything else comes up :)

First, thank you for your work on building a clean model for medical image generation.
I'm having trouble feeding my 3-channel cell images into the model.
I set class_conditional: bool = True because my segmentation input is a 1-channel semantic segmentation. Is that correct?
The error is:
IndexError: The shape of the mask [50, 3, 256, 256] at index 1 does not match the shape of the indexed tensor [50, 1, 256, 256] at index 1
It is raised by these lines, where segs has 3 channels but seg has only 1:
segs = torch.zeros(imgs_shape).to(device)
segs[segs == 0] = seg[segs == 0]
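The IndexError is a boolean-mask shape rule: a mask used to index a tensor must match that tensor's shape in each dimension it covers. A minimal reproduction, with a hypothetical helper mirroring the two lines above and random data standing in for real masks; the 1-channel buffer only illustrates that matching shapes index cleanly, not what the repository's intended fix is:

```python
import torch


def masked_copy(buffer_shape, seg: torch.Tensor) -> torch.Tensor:
    """Mirror the two reported lines: fill a zero buffer with seg wherever the buffer is 0."""
    segs = torch.zeros(buffer_shape)
    segs[segs == 0] = seg[segs == 0]
    return segs


# Hypothetical shapes matching the report: 50 one-channel masks, images with 3 channels.
seg = torch.randint(0, 4, (50, 1, 256, 256)).float()

try:
    masked_copy((50, 3, 256, 256), seg)  # buffer sized like the 3-channel images
except IndexError as err:
    print("IndexError:", err)

# A buffer with the mask's own channel count indexes cleanly.
out = masked_copy(seg.shape, seg)
print(out.shape)  # torch.Size([50, 1, 256, 256])
```

Because the buffer starts as all zeros, the mask is all True and the copy simply reproduces seg; the failure comes entirely from the channel dimension (3 versus 1) at index 1.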
When I set class_conditional: bool = False instead, I get the same kind of channel-mismatch error as in the original report:
RuntimeError: Given groups=1, weight of size [128, 2, 3, 3], expected input[13, 4, 256, 256] to have 2 channels, but it has 4 channels instead.
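The channel counts in this second error are consistent with the segmentation mask being concatenated to the noisy image along the channel axis (an assumption here, not confirmed in this thread): a model built for 1 image channel plus 1 mask channel expects 2 input channels, while 3 image channels plus 1 mask channel give 4. A sketch with hypothetical input convolutions and a reduced spatial size to keep it fast:

```python
import torch
import torch.nn as nn

# Reduced spatial size for speed; the channel check does not depend on H or W.
batch, size = 13, 32
noisy_images = torch.randn(batch, 3, size, size)  # three-channel images
seg = torch.zeros(batch, 1, size, size)           # one-channel mask

# Assumed layout: mask channels appended after the image channels.
model_input = torch.cat((noisy_images, seg), dim=1)  # (13, 4, 32, 32)

# Hypothetical input convolutions for two configurations.
conv_1img = nn.Conv2d(1 + 1, 128, kernel_size=3, padding=1)  # weight [128, 2, 3, 3]
conv_3img = nn.Conv2d(3 + 1, 128, kernel_size=3, padding=1)  # weight [128, 4, 3, 3]

with torch.no_grad():
    try:
        conv_1img(model_input)
    except RuntimeError as err:
        print("1-image-channel config:", err)
    print("3-image-channel config:", tuple(conv_3img(model_input).shape))
```

If that reading is right, the error suggests a model configured for one image channel is receiving three-channel images.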