Beckschen/TransUNet

What are the expected input and output channels?

ankit-rl4 opened this issue · 1 comment

Hi, can I train this on the BreaKHis dataset for breast cancer images? It has 3-channel input and grayscale masks.
I trained the model with it, but it returns a 3-channel output, so I had to convert that to grayscale. I am not sure if that's correct; the output I got was like this:

[screenshot of the incorrect prediction]

which is wrong. I am not sure where I made the mistake. I am a learner.
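
For context, here is roughly how I construct the model (a sketch based on the repo's `train.py`; the `n_classes` value is exactly what I am unsure about for a binary mask):

```python
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg

config_vit = CONFIGS_ViT_seg['R50-ViT-B_16']
config_vit.n_classes = 2              # background + tumor -> 2 output channels
config_vit.n_skip = 3
config_vit.patches.grid = (14, 14)    # img_size / patch_size = 224 / 16
net = ViT_seg(config_vit, img_size=224, num_classes=config_vit.n_classes).cuda()
# net(x) returns [B, n_classes, H, W]: the channel count follows n_classes
```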

I have made a change here to type-cast `x` to float:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):
    # Conv2d with weight standardization (from the repo's ResNet backbone)
    def forward(self, x):
        # change: cast the input to float so raw uint8 images don't break F.conv2d
        x = x.float()

        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
```
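
A quick sanity check of the cast (a minimal sketch; the `uint8` tensor simulates an image that was never converted to float by a transform):

```python
conv = StdConv2d(3, 64, kernel_size=3, padding=1)
x = torch.randint(0, 256, (1, 3, 224, 224), dtype=torch.uint8)  # raw image-like tensor
out = conv(x)  # only works because forward() casts x to float first
print(out.shape)  # torch.Size([1, 64, 224, 224])
```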

And this is the class I use for data loading:

```python
import os
import numpy as np
import cv2
from PIL import Image
from torch.utils.data import Dataset

class BreakhisDataset(Dataset):
    def __init__(self, root_dir, split, img_size, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.split = split
        self.img_size = (img_size, img_size)  # e.g. 224 -> (224, 224)
        self.image_paths, self.mask_paths = load_dataset(self.root_dir, self.split)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]
        image = np.array(Image.open(img_path).convert('RGB'))
        mask = np.array(Image.open(mask_path).convert('L'), dtype=np.float32)
        mask[mask == 255.0] = 1.0  # binarize: 255 -> class 1, 0 stays class 0
        image = cv2.resize(image, self.img_size, interpolation=cv2.INTER_LINEAR)
        # nearest-neighbor keeps the mask binary; linear interpolation would
        # create fractional label values along the edges
        mask = cv2.resize(mask, self.img_size, interpolation=cv2.INTER_NEAREST)

        if self.transform is not None:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask
```

```python
def load_dataset(data_dir, split='train'):
    images_dir = os.path.join(data_dir, split, 'images')
    masks_dir = os.path.join(data_dir, split, 'masks')
    image_paths = sorted(os.path.join(images_dir, f) for f in os.listdir(images_dir) if f.endswith('.png'))
    mask_paths = sorted(os.path.join(masks_dir, f) for f in os.listdir(masks_dir) if f.endswith('.png'))
    return image_paths, mask_paths
```
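
And this is roughly how I wire it into a `DataLoader` (a sketch; the root path is made up, and `ToTensor` is the standard torchvision transform):

```python
from torch.utils.data import DataLoader
from torchvision import transforms

train_ds = BreakhisDataset(root_dir='data/breakhis', split='train',
                           img_size=224, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2)

images, masks = next(iter(train_loader))
print(images.shape)  # torch.Size([2, 3, 224, 224])
print(masks.shape)   # torch.Size([2, 1, 224, 224]) -- single-channel mask
```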

And I convert the prediction to grayscale like this:

```python
import numpy as np
import torch
from scipy.ndimage import zoom

def test_single_volume(image, label, net, classes, patch_size=[224, 224],
                       test_save_path=None, case=None, z_spacing=1):
    image, label = image.squeeze(0).cpu().detach().numpy(), label.squeeze(0).cpu().detach().numpy()
    if len(image.shape) == 3:
        prediction = np.zeros_like(label)
        print(image.shape)
        for ind in range(image.shape[0]):
            slice = image[ind, :, :]
            x, y = slice.shape[0], slice.shape[1]
            if x != patch_size[0] or y != patch_size[1]:
                slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3)  # previously order=0
            input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
            net.eval()
            with torch.no_grad():
                outputs = net(input)
                out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
                out = out.cpu().detach().numpy()
                if x != patch_size[0] or y != patch_size[1]:
                    pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
                else:
                    pred = out
                # grayscale conversion -- note: pred is already a 2-D (H, W) class map
                # here, so this mean over axis=1 collapses it to a 1-D array
                gray_img = np.mean(pred, axis=1)
                gray_img = gray_img / 255.0
                prediction[ind] = gray_img
```
I am trying to figure out what's wrong. The output I get has 3 channels; I'm not sure if it's supposed to be that way, since I trained on 3-channel input with grayscale masks.
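
One thing I noticed while reading the code: `torch.argmax` over `dim=1` already removes the channel dimension, so `pred` is a 2-D `(H, W)` map of class indices. If that is right, the grayscale step isn't needed at all and the end of the loop could just be (my guess):

```python
if x != patch_size[0] or y != patch_size[1]:
    pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
else:
    pred = out
prediction[ind] = pred  # already single-channel; no mean over channels, no /255
```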

I have the same problem as you, did you resolve it? It makes no sense that the input needs to be [2, 2, H, W]; I have no idea where the second "2" comes from. I tried setting [2, 0, H, W] to 1 where the label is black and [2, 1, H, W] to 0 where the label is white, but after it passed through the one_hot function it became [2, 2, 2, H, W]. I have no idea about this part of the code.
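
If the `one_hot` function is the repo's `DiceLoss._one_hot_encoder`, I think it expects the label as `[B, H, W]` integer class indices and adds the class axis itself, which would explain the extra dimension. A minimal sketch of that behavior (the encoder below mirrors what I believe the repo does):

```python
import torch

def one_hot_encoder(label, n_classes):
    # one boolean mask per class, stacked along a new dim=1 channel axis
    return torch.cat([(label == i).unsqueeze(1) for i in range(n_classes)], dim=1).float()

label = torch.randint(0, 2, (2, 224, 224))        # expected: [B, H, W] class indices
print(one_hot_encoder(label, 2).shape)            # torch.Size([2, 2, 224, 224])

expanded = label.unsqueeze(1).repeat(1, 2, 1, 1)  # pre-expanded to [2, 2, H, W]
print(one_hot_encoder(expanded, 2).shape)         # torch.Size([2, 2, 2, 224, 224])
```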