MichaelFan01/STDC-Seg

How should I understand OhemCELoss?

feiyuhuahuo opened this issue · 0 comments

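```python
def forward(self, logits, labels):
    # logits: (N, C, H, W) raw class scores; labels: (N, H, W) class indices
    N, C, H, W = logits.size()
    # self.criteria must return one loss per pixel
    # (e.g. CrossEntropyLoss(reduction='none')); flatten to a 1-D vector
    loss = self.criteria(logits, labels).view(-1)
    # hardest (largest-loss) pixels first
    loss, _ = torch.sort(loss, descending=True)
    if loss[self.n_min] > self.thresh:
        # at least n_min + 1 pixels exceed thresh: keep all of them
        loss = loss[loss > self.thresh]
    else:
        # fewer hard pixels than that: keep the n_min largest losses anyway
        loss = loss[:self.n_min]
    return torch.mean(loss)
```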
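To see what the two branches do, here is a minimal, self-contained sketch that applies the same selection rule to a flat vector of per-pixel losses. The helper name `ohem_select` is hypothetical. Unlike the quoted `forward`, it falls back to the plain mean when `n_min` is not a valid index into the loss vector, which would otherwise raise an index error on small inputs.

```python
import torch


def ohem_select(loss: torch.Tensor, thresh: float, n_min: int) -> torch.Tensor:
    """Hypothetical helper: online hard example mining over per-pixel losses.

    Assumes n_min >= 1. Always averages over at least n_min pixels: either
    every loss above `thresh` (when more than n_min exceed it) or the n_min
    largest losses.
    """
    loss, _ = torch.sort(loss.reshape(-1), descending=True)
    if n_min >= loss.numel():
        # Not enough pixels to mine from: average everything.
        return loss.mean()
    if loss[n_min] > thresh:
        # The (n_min + 1)-th largest loss is above thresh, so more than
        # n_min pixels are "hard": use all of them.
        kept = loss[loss > thresh]
    else:
        # Too few hard pixels: fall back to the n_min hardest ones.
        kept = loss[:n_min]
    return kept.mean()
```

For losses `[0.1, 2.0, 0.5, 1.5, 0.05, 3.0]` with `thresh=0.4` and `n_min=2`, the third-largest loss (1.5) exceeds the threshold, so all four losses above 0.4 are averaged (1.75). With `thresh=1.8` and `n_min=3`, only two losses exceed the threshold, so the three largest are averaged instead. In other words, `thresh` decides which pixels count as hard, and `n_min` is a floor on how many pixels contribute to the gradient.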

What does this code do? And how should thresh and n_min be chosen?
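One convention seen in BiSeNet-style training code is to specify `thresh` as a probability `p` and convert it to a loss threshold `-log(p)`. This works because per-pixel cross-entropy equals `-log` of the softmax probability assigned to the true class, so "loss > -log(p)" means "the model gives the correct class less than probability `p`". `n_min` is then often set to a fixed fraction of the pixels in a batch. Whether this repository does exactly this is an assumption here; check its training script for the actual values. The names `OhemCELossSketch` and `default_n_min`, the default `p = 0.7`, and the 1/16 fraction below are illustrative assumptions.

```python
import math

import torch
import torch.nn as nn


class OhemCELossSketch(nn.Module):
    """Hypothetical OHEM cross-entropy loss; the threshold is given as a probability."""

    def __init__(self, prob_thresh: float = 0.7, n_min: int = 16, ignore_index: int = 255):
        super().__init__()
        # CE = -log(p_true), so "p_true < prob_thresh" <=> "CE > -log(prob_thresh)".
        self.thresh = -math.log(prob_thresh)
        self.n_min = n_min
        # reduction="none" keeps one loss per pixel; ignored pixels get loss 0.
        self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction="none")

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        loss = self.criteria(logits, labels).reshape(-1)
        loss, _ = torch.sort(loss, descending=True)
        if self.n_min >= loss.numel():
            return loss.mean()  # batch too small to mine: plain mean
        if loss[self.n_min] > self.thresh:
            loss = loss[loss > self.thresh]  # every hard pixel
        else:
            loss = loss[: self.n_min]  # at least n_min hardest pixels
        return loss.mean()


def default_n_min(batch_size: int, height: int, width: int, fraction: int = 16) -> int:
    """Hypothetical rule: always train on at least 1/fraction of the batch's pixels."""
    return batch_size * height * width // fraction
```

With `p = 0.7` the loss threshold is `-ln 0.7`, about 0.357. Lowering `p` raises that threshold, so fewer pixels count as hard. Raising `n_min` makes the loss behave more like plain cross-entropy averaged over the batch, while lowering it concentrates training on a smaller set of the hardest pixels.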