jakeret/unet

Making the code work for 6-channel inputs

ShreyaPandita01 opened this issue · 4 comments

As per the code, if the number of input channels is > 3, it only uses the first 3 channels.
My input consists of 6 channels: the first three are RGB and the other three are additional inputs specific to my use case.
How can I make the model use all 6 channels?
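For reference, a 6-channel input like the one described can be assembled by concatenating the RGB array and the three extra channels along the last axis. This is a minimal NumPy sketch. It assumes channels-last arrays shaped [bs, nx, ny, channels], matching the `to_rgb` docstring quoted in this thread. `stack_channels` is a hypothetical helper, not part of jakeret/unet.

```python
import numpy as np


def stack_channels(rgb: np.ndarray, extra: np.ndarray) -> np.ndarray:
    """Concatenate an RGB batch and extra per-pixel inputs along the channel axis.

    Hypothetical helper; both inputs are assumed channels-last: [bs, nx, ny, c].
    """
    # Batch and spatial dimensions must agree; only the channel count may differ.
    if rgb.shape[:-1] != extra.shape[:-1]:
        raise ValueError(f"batch/spatial shapes differ: {rgb.shape} vs {extra.shape}")
    return np.concatenate([rgb, extra], axis=-1).astype(np.float32)


# Example: 2 images of 64x64 with 3 RGB channels plus 3 extra channels.
rgb = np.random.rand(2, 64, 64, 3)
extra = np.random.rand(2, 64, 64, 3)
x = stack_channels(rgb, extra)
print(x.shape)  # (2, 64, 64, 6)
```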

Hi @ShreyaPandita01, are you referring to plotting in TensorBoard, or are you receiving an exception?
As for the implementation, it should accept input tensors with an arbitrary number of channels.
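The reason the channel count is not fixed: the first convolution's kernel has an input-channel dimension that is taken from the data (in Keras, for example, a `Conv2D` kernel has shape `(kernel_h, kernel_w, in_channels, filters)`). The NumPy sketch below illustrates the idea with a 1x1 convolution, i.e. a per-pixel linear map over channels. `conv1x1` is a hypothetical helper for illustration only and is not how the library implements its layers.

```python
import numpy as np


def conv1x1(x: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """1x1 convolution: mix channels independently at every pixel.

    x:      [bs, nx, ny, in_channels]
    kernel: [in_channels, filters]
    returns [bs, nx, ny, filters]
    """
    if x.shape[-1] != kernel.shape[0]:
        raise ValueError("kernel input channels must match the data")
    return np.einsum("bxyc,cf->bxyf", x, kernel)


rng = np.random.default_rng(0)
for in_channels in (1, 3, 6):
    x = rng.random((2, 8, 8, in_channels))
    # The kernel's first dimension is derived from the data, so any count works.
    kernel = rng.random((in_channels, 16))
    print(in_channels, conv1x1(x, kernel).shape)
# 1 (2, 8, 8, 16)
# 3 (2, 8, 8, 16)
# 6 (2, 8, 8, 16)
```

Only the first layer's weights change shape; everything after it sees `filters` channels regardless of the input.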

I am talking about this function in utils.py:

import numpy as np

def to_rgb(img: np.array):
    """
    Converts the given array into a RGB image and normalizes the values to [0, 1).
    If the number of channels is less than 3, the array is tiled such that it has 3 channels.
    If the number of channels is greater than 3, only the first 3 channels are used
    :param img: the array to convert [bs, nx, ny, channels]
    :returns img: the rgb image [bs, nx, ny, 3]
    """
    ...  # implementation omitted here; see utils.py in jakeret/unet
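To make the docstring concrete, here is a minimal NumPy re-implementation of the documented behaviour. It is not the actual utils.py source, whose body is not shown in this thread. `to_rgb_sketch` is a hypothetical name. Scaling the whole array with a single global min/max is an assumption, and it yields [0, 1] rather than the half-open [0, 1) the docstring mentions.

```python
import numpy as np


def to_rgb_sketch(img: np.ndarray) -> np.ndarray:
    """Sketch of the documented to_rgb behaviour (not the library source).

    img: [bs, nx, ny, channels]; returns [bs, nx, ny, 3] scaled to [0, 1].
    """
    img = img.astype(np.float32)
    channels = img.shape[-1]
    if channels < 3:
        # Tile the channels, then keep 3 (1 -> c0,c0,c0; 2 -> c0,c1,c0).
        reps = -(-3 // channels)  # ceiling division
        img = np.tile(img, (1, 1, 1, reps))[..., :3]
    else:
        # Channels beyond the third are dropped here, which is why a
        # 6-channel input only shows its RGB part.
        img = img[..., :3]
    # Global min-max normalisation (assumption of this sketch).
    img = img - img.min()
    peak = img.max()
    if peak > 0:
        img = img / peak
    return img


x = np.random.rand(2, 16, 16, 6)
print(to_rgb_sketch(x).shape)  # (2, 16, 16, 3)
```

With a [bs, nx, ny, 6] input only `img[..., :3]` survives, so the extra channels never appear in the TensorBoard image even though the model itself receives them.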

This function is used for the visualization in TensorBoard and indeed uses the first 3 channels. I'm unsure how we could visualize your 6 channels. Do you have a suggestion?
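One possible answer to the visualization question, offered as a sketch rather than anything the library provides: split the channels into groups of three and render each group as its own RGB-like image. `channel_groups_to_rgb` is hypothetical. It normalizes each group independently and zero-pads a final group that has fewer than 3 channels.

```python
import numpy as np


def channel_groups_to_rgb(img: np.ndarray) -> list:
    """Split [bs, nx, ny, channels] into a list of [bs, nx, ny, 3] images."""
    images = []
    for start in range(0, img.shape[-1], 3):
        # Take up to 3 consecutive channels (the last group may be shorter).
        group = img[..., start:start + 3].astype(np.float32)
        # Min-max normalise each group on its own so faint extra inputs stay visible.
        group -= group.min()
        peak = group.max()
        if peak > 0:
            group /= peak
        # Zero-pad a short final group so every image still has 3 channels.
        missing = 3 - group.shape[-1]
        if missing:
            pad = np.zeros(group.shape[:-1] + (missing,), dtype=np.float32)
            group = np.concatenate([group, pad], axis=-1)
        images.append(group)
    return images


x = np.random.rand(2, 32, 32, 6)
rgb_part, extra_part = channel_groups_to_rgb(x)
print(rgb_part.shape, extra_part.shape)  # (2, 32, 32, 3) (2, 32, 32, 3)
```

For a 6-channel input this returns two arrays, which could each be logged under their own tag (for example with separate `tf.summary.image` calls), so the RGB part and the extra inputs are shown side by side.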

Ah, I see. Let me try to make it work for my use case, and if I succeed I will open a PR. Meanwhile I am closing this issue. Thanks for all the help!