Project-MONAI/GenerativeModels

Question on input transform in `RadImageNetPerceptualSimilarity`

function2-llx opened this issue · 4 comments

Dear developers,

I'm trying to use RadImageNetPerceptualSimilarity from this library. In the code (shown below), before the input is fed into the network, its channels are reordered from "RGB" to "BGR" and the mean is subtracted.

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    We expect that the input is normalised between [0, 1]. Given the preprocessing performed during the training at
    https://github.com/BMEII-AI/RadImageNet, we make sure that the input and target have 3 channels, reorder it from
    'RGB' to 'BGR', and then remove the mean components of each input data channel. The outputs are normalised
    across the channels, and we obtain the mean from the spatial dimensions (similar approach to the lpips package).
    """
    # If input has just 1 channel, repeat channel to have 3 channels
    if input.shape[1] == 1 and target.shape[1] == 1:
        input = input.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)
    # Change order from 'RGB' to 'BGR'
    input = input[:, [2, 1, 0], ...]
    target = target[:, [2, 1, 0], ...]
    # Subtract mean used during training
    input = subtract_mean(input)
    target = subtract_mean(target)
    # Get model outputs
    outs_input = self.model.forward(input)
    outs_target = self.model.forward(target)
    # Normalise through the channels
    feats_input = normalize_tensor(outs_input)
    feats_target = normalize_tensor(outs_target)
    results = (feats_input - feats_target) ** 2
    results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
    return results
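The `forward` method relies on `normalize_tensor` and `spatial_average`, which are not shown here. Below is a minimal NumPy sketch of what they do, assuming lpips-style definitions (each spatial feature vector scaled to unit L2 norm across channels, with a small `eps = 1e-10`; then the mean over H and W). The library's own helpers may differ in detail, `perceptual_distance` is a hypothetical name, and the random arrays stand in for the outputs of `self.model`.

```python
import numpy as np


def normalize_tensor(x: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    """Scale each spatial feature vector (axis 1 = channels) to unit L2 norm."""
    norm = np.sqrt(np.sum(x ** 2, axis=1, keepdims=True))
    return x / (norm + eps)


def spatial_average(x: np.ndarray, keepdim: bool = True) -> np.ndarray:
    """Average over the spatial axes (H, W) of an NCHW array."""
    return x.mean(axis=(2, 3), keepdims=keepdim)


def perceptual_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: lpips-style distance between two NCHW feature maps."""
    diff = (normalize_tensor(feats_a) - normalize_tensor(feats_b)) ** 2
    # Sum over channels, then average over H and W -> shape (N, 1, 1, 1)
    return spatial_average(diff.sum(axis=1, keepdims=True), keepdim=True)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 8, 4, 4))  # stand-in for model features of the input
b = rng.standard_normal((2, 8, 4, 4))  # stand-in for model features of the target

same = perceptual_distance(a, a)          # identical features give zero distance
scaled = perceptual_distance(a, 3.0 * a)  # channel normalisation removes overall scale
other = perceptual_distance(a, b)
print(same.ravel(), scaled.ravel(), other.ravel())
```

Because each feature vector is normalised before the difference is taken, the per-location squared distance is bounded by 4, and a global rescaling of the features has (almost) no effect on the result.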

According to a PyTorch demo provided in the official repository (link), I understand the reason for the channel reordering: the demo reads images with cv2.imread, which returns channels in "BGR" order according to the OpenCV docs. However, the mean values being subtracted are essentially the torchvision ones, just reordered to match the channel order.

import torch


def torchvision_zscore_norm(x: torch.Tensor) -> torch.Tensor:
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    x[:, 0, :, :] = (x[:, 0, :, :] - mean[0]) / std[0]
    x[:, 1, :, :] = (x[:, 1, :, :] - mean[1]) / std[1]
    x[:, 2, :, :] = (x[:, 2, :, :] - mean[2]) / std[2]
    return x


def subtract_mean(x: torch.Tensor) -> torch.Tensor:
    mean = [0.406, 0.456, 0.485]
    x[:, 0, :, :] -= mean[0]
    x[:, 1, :, :] -= mean[1]
    x[:, 2, :, :] -= mean[2]
    return x
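A quick NumPy check of that observation: flipping RGB to BGR and then subtracting the reversed list `[0.406, 0.456, 0.485]` gives exactly the same result as subtracting the torchvision RGB means first and flipping afterwards. Note also that `subtract_mean` does not divide by the std, unlike `torchvision_zscore_norm`. The sketch returns a new array rather than modifying `x` in place as both functions above do; `subtract_mean_bgr` is a hypothetical name.

```python
import numpy as np

# ImageNet channel means as listed in torchvision_zscore_norm (RGB order)
IMAGENET_MEAN_RGB = np.array([0.485, 0.456, 0.406])


def subtract_mean_bgr(x_bgr: np.ndarray) -> np.ndarray:
    """Mirror of subtract_mean: BGR-ordered means, without mutating the input."""
    mean_bgr = IMAGENET_MEAN_RGB[::-1]  # [0.406, 0.456, 0.485]
    return x_bgr - mean_bgr.reshape(1, 3, 1, 1)


rng = np.random.default_rng(0)
x_rgb = rng.uniform(0.0, 1.0, size=(2, 3, 5, 5))  # NCHW input in [0, 1]

# Path used by the library: reorder RGB -> BGR, then subtract BGR-ordered means
lib = subtract_mean_bgr(x_rgb[:, [2, 1, 0], ...])

# Equivalent path: subtract RGB-ordered means, then reorder
alt = (x_rgb - IMAGENET_MEAN_RGB.reshape(1, 3, 1, 1))[:, [2, 1, 0], ...]

print(np.array_equal(lib, alt))
```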

This is the part I don't understand. The following code from the official demo shows that it rescales the intensity to [-1, 1].

    import cv2
    import torch
    from torch.utils.data import Dataset
    from torchvision import transforms

    class createDataset(Dataset):
        def __init__(self, dataframe, transform=None):
            self.dataframe = dataframe
            self.transform = transforms.Compose([transforms.ToTensor()])

        def __len__(self):
            return self.dataframe.shape[0]

        def __getitem__(self, index):
            image = self.dataframe.iloc[index]["img_dir"]
            image = cv2.imread(image)  # uint8 array in BGR channel order
            image = (image-127.5)*2 / 255  # maps [0, 255] to [-1, 1]
            image = cv2.resize(image,(224,224))
            #image = np.transpose(image,(2,0,1))
            if self.transform is not None:
                image = self.transform(image)
            label = self.dataframe.iloc[index]["label"]
            return {"image": image , "label": torch.tensor(label, dtype=torch.long)}
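To make the mismatch concrete, this NumPy sketch maps the darkest, mid-grey and brightest 8-bit values through both recipes: the demo's `(image - 127.5) * 2 / 255` and the library's "scale to [0, 1], then subtract the BGR-ordered ImageNet means". Only the arithmetic from the two snippets above is used; channel order and resizing are ignored.

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])  # darkest, mid-grey, brightest 8-bit values

# Official PyTorch demo: (image - 127.5) * 2 / 255 maps [0, 255] to [-1, 1]
demo = (pixels - 127.5) * 2 / 255

# Library: user scales to [0, 1], then per-channel ImageNet means (BGR order) are subtracted
mean_bgr = np.array([0.406, 0.456, 0.485])
library = pixels[:, None] / 255.0 - mean_bgr[None, :]  # rows: pixel values; columns: B, G, R

print(demo)
print(library)
```

The demo produces inputs of width 2 centred on mid-grey, while the library produces inputs of width 1, shifted by a slightly different offset in each channel, so the two recipes cannot both match the distribution the network was trained on.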

Moreover, according to BMEII-AI/RadImageNet#1 (comment):

The mean and standard deviation derived from the RadImageNet are 0.223 and 0.203 respectively.

Since the input distribution is critical for the perceptual loss to work as expected, I would much appreciate it if this could be clarified. Thanks!

Hi,

I think the model was trained on grayscale data and the expected input is grayscale, so the channel ordering should not matter.
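A small NumPy check of this point, mirroring the repeat and reorder steps in `forward` (`np.repeat` stands in for torch's `Tensor.repeat` on a single-channel input): once a grayscale image is copied into three identical channels, the RGB-to-BGR flip is a no-op. The per-channel mean subtraction that follows does, however, make the three channels differ by small constant offsets.

```python
import numpy as np

rng = np.random.default_rng(0)
gray = rng.uniform(0.0, 1.0, size=(1, 1, 4, 4))  # single-channel NCHW image in [0, 1]

# Same steps as the forward method: repeat to 3 channels, then reorder RGB -> BGR
rgb = np.repeat(gray, 3, axis=1)
bgr = rgb[:, [2, 1, 0], ...]
print(np.array_equal(rgb, bgr))  # identical channels, so reordering changes nothing

# Subtracting per-channel means afterwards shifts each channel by a different constant
centred = bgr - np.array([0.406, 0.456, 0.485]).reshape(1, 3, 1, 1)
print(centred[0, :, 0, 0])
```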

The mean and standard deviation derived from the RadImageNet are 0.223 and 0.203 respectively.

But yes, it does look like we should be using these values rather than those in the subtract_mean function here, which appear to be taken from ImageNet:

def subtract_mean(x: torch.Tensor) -> torch.Tensor:
    mean = [0.406, 0.456, 0.485]
    x[:, 0, :, :] -= mean[0]
    x[:, 1, :, :] -= mean[1]
    x[:, 2, :, :] -= mean[2]
    return x
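To make the proposed change concrete, here is a hypothetical variant (the name `subtract_radimagenet_mean` and the `divide_by_std` flag are illustrative, not part of the library) that subtracts the single RadImageNet mean of 0.223 from every channel, optionally dividing by the quoted std of 0.203. It is for comparison only; the TensorFlow training pipeline discussed further down in this thread uses ImageNet means.

```python
import numpy as np

# Statistics quoted from BMEII-AI/RadImageNet#1 (one value, since the images are grayscale)
RADIMAGENET_MEAN = 0.223
RADIMAGENET_STD = 0.203
IMAGENET_MEAN_BGR = np.array([0.406, 0.456, 0.485])


def subtract_radimagenet_mean(x: np.ndarray, divide_by_std: bool = False) -> np.ndarray:
    """Hypothetical alternative to subtract_mean: one mean for all three channels."""
    out = x - RADIMAGENET_MEAN
    if divide_by_std:
        out = out / RADIMAGENET_STD
    return out


x = np.full((1, 3, 2, 2), 0.5)  # a flat mid-grey NCHW image in [0, 1]
current = x - IMAGENET_MEAN_BGR.reshape(1, 3, 1, 1)  # what subtract_mean does
proposed = subtract_radimagenet_mean(x)
zscored = subtract_radimagenet_mean(x, divide_by_std=True)
print(current[0, :, 0, 0])   # a different offset per channel
print(proposed[0, :, 0, 0])  # the same offset in every channel
```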

I will update this soon. Just tagging @Warvito, who implemented it, in case he has any thoughts too.

@marksgraham Thank you for your reply. Yes, the model seems to have been trained only on grayscale data, as mentioned in the paper (Discussion section):

In our study, the RadImageNet database contained only grayscale medical images, while natural world images use three red-green-blue channels. Pretraining on grayscale images can allow the training of more generalizable low-level filters in the initial layers of the network.

But I'm still wondering whether mean subtraction is necessary, since the official demo only rescales the data from [0, 255] to [-1, 1].

Ah, I see. Taking breast_train.py from their repo as an example, they use this image generator:

train_data_generator = ImageDataGenerator(
                                 rescale=1./255,
                                 preprocessing_function=preprocess_input,
                                 rotation_range=10,
                                 width_shift_range=0.1,
                                 height_shift_range=0.1,
                                 shear_range=0.1,
                                 zoom_range=0.1,
                                 horizontal_flip=True,
                                 fill_mode='nearest')

and preprocess_input is defined here. Given that it defaults to 'caffe' mode, it looks like they scale the inputs to [0, 1] and then subtract the ImageNet means. This is also what we are doing: we require users to scale inputs to [0, 1], and then we subtract the means. So I think our preprocessing matches that used in RadImageNet :)
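The sketch below checks this claim numerically. It does not call Keras: `caffe_preprocess` is our re-implementation of the documented 'caffe' behaviour (flip RGB to BGR, subtract the ImageNet means `[103.939, 116.779, 123.68]` on the [0, 255] scale, no scaling), and it assumes, as we read the Keras source, that `ImageDataGenerator.standardize` applies `preprocessing_function` before `rescale`. Under those assumptions the TensorFlow pipeline and the library agree to within about 0.002, the rounding of the means; the BGR flip inside 'caffe' mode would also explain the channel reordering in this library.

```python
import numpy as np

# Per-channel means used by 'caffe' mode, in BGR order, on the [0, 255] scale
CAFFE_MEAN_BGR = np.array([103.939, 116.779, 123.68])
# Means used by the library's subtract_mean, in BGR order, on the [0, 1] scale
LIBRARY_MEAN_BGR = np.array([0.406, 0.456, 0.485])


def caffe_preprocess(x_rgb_255: np.ndarray) -> np.ndarray:
    """Re-implementation (not the Keras function) of 'caffe' mode on a channels-last
    image: flip RGB -> BGR and subtract the ImageNet means, without any scaling."""
    return x_rgb_255[..., ::-1] - CAFFE_MEAN_BGR


def radimagenet_tf_pipeline(x_rgb_255: np.ndarray) -> np.ndarray:
    # Assumption: preprocessing_function runs first, then rescale=1./255 is applied
    return caffe_preprocess(x_rgb_255) * (1.0 / 255)


def library_pipeline(x_rgb_255: np.ndarray) -> np.ndarray:
    # The user scales to [0, 1]; the library flips RGB -> BGR and subtracts the means
    x01 = x_rgb_255 / 255.0
    return x01[..., ::-1] - LIBRARY_MEAN_BGR


rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(8, 8, 3)).astype(np.float64)  # HWC, RGB, [0, 255]
diff = np.abs(radimagenet_tf_pipeline(img) - library_pipeline(img))
print(diff.max())  # small: only the rounding of the means differs
```

The order of operations matters here: if the 1/255 rescaling were applied before the 'caffe' mean subtraction, the [0, 255]-scale means would push every input to roughly -100, so the agreement depends on the assumed ordering inside `ImageDataGenerator`.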

@marksgraham Thanks for the information. Now I understand that the preprocessing in this library matches the TensorFlow implementation of RadImageNet. I guess it may be the official PyTorch demo of RadImageNet that does not match the corresponding TensorFlow implementation.