
One issue in

souxun2015 opened this issue · 2 comments

Nice code for memseg; it's concise and clear.

I found one issue in the focal loss part.
In the training script, the input to focal_loss has already been passed through a softmax, but the focal loss applies softmax to it a second time internally.

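```python
import torch.nn.functional as F

# training part
# predict
outputs = model(inputs)
outputs = F.softmax(outputs, dim=1)  # first softmax: outputs are now probabilities
l1_loss = l1_criterion(outputs[:, 1, :], masks)
focal_loss = focal_criterion(outputs, masks)  # probabilities, not logits, go into the focal loss
loss = (l1_weight * l1_loss) + (focal_weight * focal_loss)
```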

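```python
import torch.nn.functional as F

# inside the focal loss
logpt = F.log_softmax(input, dim=1)  # second softmax: `input` is already a probability map
logpt = logpt.gather(1, target)
logpt = logpt.view(-1)
pt = logpt.exp()
```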
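To see why the double softmax matters, here is a framework-free sketch (plain Python rather than PyTorch, with hypothetical logits for a single pixel whose target is class 0). It compares the intended p_t, taken directly from the softmax output, with what `log_softmax(...).exp()` yields when it is fed probabilities instead of logits.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical 2-class logits for one pixel; the target class is 0.
logits = [4.0, -2.0]

# What the training loop passes to focal_criterion (already softmaxed).
probs = softmax(logits)

# Intended p_t: the target-class probability itself.
pt_once = probs[0]

# What log_softmax(probs).exp() actually computes: a second softmax.
pt_twice = softmax(probs)[0]

print(round(pt_once, 4), round(pt_twice, 4))
```

With two classes the probabilities differ by at most 1, so the second softmax caps p_t at e/(1+e) ≈ 0.731. Even a perfectly confident, correct pixel therefore keeps a modulating factor (1 − p_t)^γ of at least about 0.072 for γ = 2, which weakens the focal term's down-weighting of easy pixels.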

BTW, I forked your repo, and I made some changes based on the paper and some code snippets released by the author.
You can have a look at my branch.

The main changes are:

  1. add label_smoothing
  2. change the selection of samples from the memory bank
  4. and

If you agree with these changes, I can open a pull request.

Hi, @souxun2015

I apologize for the delay in responding.
Thank you for sharing your modifications.

I ran new experiments applying only your focal loss suggestion.
I removed the log softmax and added label smoothing to the focal loss.
The results are as follows:

| Target | AUROC-image (%) | AUROC-pixel (%) | AUPRO-pixel (%) |
|---|---|---|---|
| leather | 100 | 99.2 | 99.7 |
| pill | 98.83 | 98.26 | 98.49 |
| carpet | 99.36 | 97.47 | 97.46 |
| hazelnut | 100 | 98.58 | 99.18 |
| tile | 100 | 99.53 | 99.47 |
| cable | 74.78 | 69.56 | 72.56 |
| toothbrush | 100 | 99.55 | 99.11 |
| transistor | 94.5 | 84.15 | 87.85 |
| zipper | 99.87 | 98.33 | 97.8 |
| metal_nut | 99.17 | 95.07 | 97.09 |
| grid | 100 | 99.05 | 98.97 |
| bottle | 100 | 98.94 | 98.77 |
| capsule | 91.74 | 97.26 | 97.1 |
| screw | 93.63 | 95.66 | 95.16 |
| wood | 99.82 | 96.75 | 98.03 |
| **Average** | 96.78 | 95.16 | 95.78 |

I don't think the other modifications are needed.
If you don't mind, I'd like to push an update that fixes only the focal loss part.
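Below is a minimal sketch of what the fixed loss could look like. It assumes the model output has already been passed through softmax, as in the training loop above, so it takes the log of the probabilities directly instead of calling log_softmax. It also assumes label smoothing toward a uniform distribution over the two classes. The function name `focal_loss_from_probs` and its `gamma`, `smoothing` and `eps` parameters are hypothetical. The per-pixel loop is plain Python for clarity rather than vectorized PyTorch, and this is not necessarily how the repository's updated focal loss is written.

```python
import math

def focal_loss_from_probs(probs, targets, gamma=2.0, smoothing=0.1, eps=1e-7):
    """Hypothetical focal loss on already-softmaxed outputs, with label smoothing.

    probs:   list of [p_normal, p_anomaly] pairs, one per pixel (each sums to 1)
    targets: list of 0/1 mask labels, one per pixel
    Returns the mean loss over pixels.
    """
    num_classes = 2
    total = 0.0
    for p, y in zip(probs, targets):
        # Label smoothing: mix the one-hot target with a uniform distribution.
        soft = [smoothing / num_classes] * num_classes
        soft[y] += 1.0 - smoothing
        for c in range(num_classes):
            # Clamp so log() stays finite for probabilities of exactly 0 or 1.
            pc = min(max(p[c], eps), 1.0 - eps)
            # Log of the probability directly: no second softmax.
            total -= soft[c] * (1.0 - pc) ** gamma * math.log(pc)
    return total / len(targets)
```

With `gamma=0` and `smoothing=0` the expression reduces to ordinary cross-entropy on the target-class probability, which makes a quick sanity check.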

Best regards,

Hi, @TooTouch
It's great to hear that you've run new experiments using the focal loss with label smoothing. The other changes are dataset-specific hyperparameters, so there's no need to spend time modifying them.
Happy holiday!