khdlr/HED-UNet

I have some questions about testing.

fatfatfatmouse opened this issue · 6 comments

Hello, I am a beginner studying segmentation. After carefully reading the code and the paper you provided, I have some questions:

**1.** I am not sure which part of the code implements the edge-detection function. I noticed that the HEDUNet in hed_unet.py seems to implement only the segmentation part. Can you tell me where the code handles the HED edge detection task?

**2.** When I train on my dataset, the result of "showexample" is very nice, and SegAcc and EdgeAcc are very high. But when I use the trained model to test images, the results look terrible. I am not sure whether there is a problem with my code, which is as follows:

```python
import numpy as np
import torch

from hed_unet import HEDUNet  # adjust to wherever the model lives in this repo

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HEDUNet(input_channels=3, base_channels=16, stack_height=5, batch_norm=True)
model = model.to(device)
model.load_state_dict(torch.load(path_model, map_location=device))
model.eval()

# HWC uint8 image -> 1CHW float tensor scaled to [-1, 1)
image = np.expand_dims(image.transpose(2, 0, 1), 0)
image = torch.from_numpy(image).to(device)
image = (image.to(torch.float) / 127.) - 1.

with torch.no_grad():
    res, _ = model(image)
seg_pred, edge_pred = torch.sigmoid(res.squeeze(0))

# binarize the segmentation channel
pred = seg_pred.cpu().numpy()
pred = np.where(pred > 0.5, 255, 0).astype(np.uint8)
pred = pred.reshape((ImgLength, ImgLength))
predicts.append(pred)

# scale the edge channel to [0, 255]
pred = (edge_pred.cpu().numpy() * 255).astype(np.uint8)
pred = pred.reshape((ImgLength, ImgLength))
predicts_edge.append(pred)
```

**3.** I suspect that the reason is "full_forward". Do I need to use "full_forward" when testing? But if it is called, how do I fill in parameters like target and metrics?

khdlr commented

Hi there! Happy to answer your questions 👍

  1. In fact, the model in hed_unet.py performs both segmentation and edge detection. The output is a BCHW tensor with two channels, where the first channel contains the result of the binary segmentation and the second one the result of the edge detection (a short sketch follows this list).
  2. From glancing over your code, it should work. It is hard to debug issues like this from afar. Possible reasons could be different pre-processing between training and inference time, or simply an overfitting model.
  3. full_forward really is just a wrapper around the actual model's forward function that also takes care of metrics calculation, deep supervision, and other things. It is not necessary to use it during inference.
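
For reference, splitting the two output channels could look like this. This is only a minimal sketch: `res` here is the raw logit tensor returned by the model, so sigmoid still has to be applied, as in your code above:

```python
import torch

# res: BCHW logit tensor from the model, with C == 2
res, _ = model(image)
seg_logits = res[:, 0]    # channel 0: segmentation
edge_logits = res[:, 1]   # channel 1: edge detection
seg_prob = torch.sigmoid(seg_logits)
edge_prob = torch.sigmoid(edge_logits)
```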

Thank you for your answer! After thinking about it, I have some new ideas:
1. After the get_pyramid function, the target contains both the original masks and the edges, so the loss function in full_forward guides the model to learn to predict edges as well. I am not sure whether this conclusion is correct, and I hope to get your confirmation (I sketched my understanding after this list).
2. I noticed the code `trn_dataset = Augment(Subset(trnval, trn_indices))` in train.py. Do I also need to use this data augmentation operation during testing?
3. Should the BN layer be "closed" consistently during training and testing?
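
To make point 1 concrete, here is how I imagine the edge targets being derived from the masks. This is just a sketch of my understanding, not your exact code, and the function name is made up:

```python
import torch.nn.functional as F

def edge_targets_from_mask(mask):
    """mask: B1HW float tensor in {0, 1}; returns 1 on boundary pixels."""
    dilated = F.max_pool2d(mask, kernel_size=3, stride=1, padding=1)
    eroded = -F.max_pool2d(-mask, kernel_size=3, stride=1, padding=1)
    # morphological gradient: nonzero exactly along the mask boundary
    return (dilated - eroded).clamp(0., 1.)
```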

khdlr commented
  1. Right, in this case the get_pyramid function calculates the edge targets from the segmentation targets. However, for other use-cases the edge targets may already be known.
  2. Augmentation is just a useful trick for making the model generalize better. It's generally not used during evaluation, though there are some studies on test-time augmentation (a minimal sketch follows this list).
  3. What exactly do you mean by "closed"?
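
For illustration, one simple form of test-time augmentation averages predictions over horizontal flips. This is only a sketch of the general idea, assuming the model returns a (prediction, auxiliary) tuple as above:

```python
import torch

@torch.no_grad()
def predict_hflip_tta(model, image):
    # average sigmoid outputs over the identity and a horizontal flip
    out, _ = model(image)
    out_flipped, _ = model(torch.flip(image, dims=[-1]))
    out_flipped = torch.flip(out_flipped, dims=[-1])  # undo the flip on the prediction
    return (torch.sigmoid(out) + torch.sigmoid(out_flipped)) / 2
```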

I'm sorry for my unclear description.

  1. In config.yml, batch_norm is True, but in hed_unet.py batch_norm is False.
    I checked the code and found that the BN setting from the config is not used.
  2. So I trained the model with `batch_norm: False`, but I test it like this:
    model = HEDUNet(input_channels=3, base_channels=16, stack_height=5, batch_norm=True)
  3. I suspect this difference is causing my poor test results. I'm sure my model is not overfitting, because the losses look correct.

khdlr commented
  1. This just means that if the argument batch_norm is not specified when creating the HED-UNet, it will default to False (no batch norm). The batch_norm argument from config.yml is indeed passed to the model constructor as part of model_args: https://github.com/khdlr/HED-UNet/blob/master/train.py#L192 (see the sketch after this list).
  2. The load_state_dict method will raise an error when the saved weights were trained without batch norm and the current model has it enabled. I guess (1) explains this.
  3. Again, I believe the error must be somewhere else, but it's hard to say where from my end.
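
To keep training and inference consistent, you could rebuild the model from the same config.yml before loading the weights. A rough sketch, assuming model_args in config.yml holds the constructor kwargs and `path_model` points to your checkpoint:

```python
import torch
import yaml

from hed_unet import HEDUNet  # adjust to wherever the model lives in this repo

with open('config.yml') as f:
    config = yaml.safe_load(f)

# build the model with the same arguments that were used for training
model = HEDUNet(**config['model_args'])

path_model = 'checkpoint.pt'  # hypothetical path to your saved weights
# strict=True (the default) raises if checkpoint and model disagree,
# e.g. when only one of them was built with batch norm
model.load_state_dict(torch.load(path_model, map_location='cpu'))
model.eval()
```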

I really appreciate that you took the time to help me, and I wish you a good day!