I have some questions about testing.
fatfatfatmouse opened this issue · 6 comments
Hello, I am a beginner studying segmentation. After carefully reading the code and paper you provided, I have some questions:
**1.** I am not sure which part of the code implements the edge detection. The `HEDUNet` in `hed_unet.py` seems to handle only the segmentation part. Can you tell me where the code performs the HED edge detection task?
**2.** When I train on my dataset, the output of `showexample` looks very good, and SegAcc and EdgeAcc are both high. But when I use the trained model to test images, the results look terrible. I am not sure whether there is a problem with my code. My code is as follows:
```python
import numpy as np
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = HEDUNet(input_channels=3, base_channels=16, stack_height=5, batch_norm=True)
model = model.to(device)
model.load_state_dict(torch.load(path_model, map_location=device))
model.eval()

# HWC uint8 image -> NCHW float tensor, scaled to roughly [-1, 1]
image = np.expand_dims(image.transpose(2, 0, 1), 0)
image = torch.from_numpy(image).to(device)
image = (image.to(torch.float) / 127.) - 1.

with torch.no_grad():                        # no gradients needed at test time
    res, _ = model(image)                    # res: [1, 2, H, W] logits
    # drop the batch dim, apply sigmoid, unpack the two channels
    seg_pred, edge_pred = torch.sigmoid(res.squeeze())

# segmentation: threshold at 0.5, save as a binary uint8 mask
pred = seg_pred.cpu().numpy()                # move to CPU before converting to numpy
pred[pred > 0.5] = 255
pred[pred <= 0.5] = 0
pred = pred.astype(np.uint8)
pred = pred.reshape((ImgLength, ImgLength))
predicts.append(pred)

# edges: scale the probabilities to [0, 255]
pred = edge_pred.cpu().numpy() * 255
pred = pred.astype(np.uint8)
pred = pred.reshape((ImgLength, ImgLength))
predicts_edge.append(pred)
```
**3.** I suspect the reason is `full_forward`. Do I need to use `full_forward` when testing? If so, how should I fill in parameters like `target` and `metrics`?
Hi there! Happy to answer your questions 👍
- In fact, the model in `hed_unet.py` performs both segmentation and edge detection. The output is a [BCHW]-tensor with two channels, where the first channel contains the result of the binary segmentation, and the second one the result of the edge detection.
- From glancing over your code, it should work. It is hard to debug issues like this from afar. Possible reasons could be different pre-processing between training and inference time, or simply an overfitting model.
- `full_forward` really is just a wrapper around the actual model's forward function that also takes care of metrics calculation, deep supervision, and other things. It is not necessary to use it during inference; see the sketch after this list.
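
Something like this is all you need at inference time (a minimal sketch, reusing the `model` and `image` variables from your snippet above):

```python
import torch

model.eval()
with torch.no_grad():                       # no gradient tracking at inference
    prediction, _ = model(image)            # prediction: [B, 2, H, W] logits
    seg_prob = torch.sigmoid(prediction[:, 0])    # channel 0: segmentation
    edge_prob = torch.sigmoid(prediction[:, 1])   # channel 1: edges
```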
Thank you for your answer! After thinking it over, I have some new questions:
1. After the `get_pyramid` function, the target contains both the original masks and the edges, so the loss in `full_forward` guides the model to learn to predict edges as well. I am not sure whether this conclusion is correct and would appreciate your confirmation.
2. I noticed the line `trn_dataset = Augment(Subset(trnval, trn_indices))` in train.py. Do I also need to apply this data augmentation during testing?
3. Should the BN layer be closed consistently during training and testing?
- Right, in this case the `get_pyramid` function calculates the edge targets from the segmentation targets (see the sketch after this list for the general idea). However, for other use-cases the edge targets may already be known.
- Augmentation is just a useful trick for making the model generalize better. It's generally not used during evaluation (though there are some studies on test-time augmentation).
- What exactly do you mean by "closed"?
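
To illustrate the first point, here is a minimal sketch of one common way to derive edge targets from a binary mask (not necessarily the exact implementation in `get_pyramid`):

```python
import torch
import torch.nn.functional as F

def edges_from_mask(mask):
    # mask: float [B, 1, H, W] tensor with values in {0, 1}.
    # A pixel is an edge pixel if a 3x3 window around it contains both
    # foreground and background (a "morphological gradient").
    dilated = F.max_pool2d(mask, 3, stride=1, padding=1)              # grow fg by 1 px
    eroded = 1.0 - F.max_pool2d(1.0 - mask, 3, stride=1, padding=1)   # shrink fg by 1 px
    return dilated - eroded   # 1 near the mask boundary, 0 elsewhere
```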
I'm sorry for my poor description.
- In `config.yml`, `batch_norm` is `True`, but in `hed_unet.py` the default for `batch_norm` is `False`. I checked the code and found that the BN setting from the config is not used.
- So I trained the model with `batch_norm: False`, but tested it like this:
  `model = HEDUNet(input_channels=3, base_channels=16, stack_height=5, batch_norm=True)`
- I suspect this difference is causing my poor test results. I'm sure my model is not overfitting because the losses look correct.
- This just means that if the argument `batch_norm` is not specified when creating the HED-UNet, it will default to `False` (no batch norm). The `batch_norm` argument from `config.yml` is indeed passed to the model constructor as part of `model_args`: https://github.com/khdlr/HED-UNet/blob/master/train.py#L192
- The `load_state_dict` method will raise an error when the saved weights were trained without batch norm and the current model has it enabled (see the sketch below). I guess (1) explains this.
- Again, I believe the error must be somewhere else, but it's hard to say where from my end.
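
For instance, the mismatch you describe would fail loudly rather than silently produce bad predictions (a sketch; adjust the import and checkpoint path to your setup):

```python
import torch
from hed_unet import HEDUNet   # adjust to where hed_unet.py lives in your checkout

path_model = 'checkpoint.pt'   # placeholder: weights trained with batch_norm=False

model = HEDUNet(input_channels=3, base_channels=16, stack_height=5, batch_norm=True)
try:
    model.load_state_dict(torch.load(path_model, map_location='cpu'))
except RuntimeError as err:
    # load_state_dict is strict by default and reports the missing
    # batch-norm keys (running_mean, running_var, ...)
    print(err)
```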
I really appreciate that you took the time to help me. Have a good day!