How to understand OhemCELoss?
feiyuhuahuo commented
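```python
import torch

def forward(self, logits, labels):
    N, C, H, W = logits.size()
    # self.criteria must return unreduced per-pixel losses (N x H x W); flatten to N*H*W
    loss = self.criteria(logits, labels).view(-1)
    # Hardest pixels (largest losses) first
    loss, _ = torch.sort(loss, descending=True)
    if loss[self.n_min] > self.thresh:
        # More than n_min pixels exceed thresh: keep all of them
        loss = loss[loss > self.thresh]
    else:
        # Otherwise keep only the n_min hardest pixels
        loss = loss[:self.n_min]
    return torch.mean(loss)
```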
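The branch in `forward` can be restated on a plain list of floats. The sketch below uses a hypothetical helper `ohem_select` (not part of the original code) and, like `loss[self.n_min]`, assumes there are more than `n_min` pixel losses. Because the list is sorted in descending order, `ranked[n_min] > thresh` means at least `n_min + 1` pixels exceed `thresh`; in that case every such pixel is kept, otherwise the `n_min` hardest pixels are kept.

```python
def ohem_select(pixel_losses, thresh, n_min):
    """Hypothetical mirror of the selection step in forward().

    Returns (mean of kept losses, number of kept pixels).
    Assumes len(pixel_losses) > n_min, as the original indexing does.
    """
    # Sort per-pixel losses from hardest (largest) to easiest.
    ranked = sorted(pixel_losses, reverse=True)
    if ranked[n_min] > thresh:
        # At least n_min + 1 pixels exceed thresh: keep every one of them.
        kept = [x for x in ranked if x > thresh]
    else:
        # Too few hard pixels: fall back to the n_min hardest ones,
        # even if some of them are below thresh.
        kept = ranked[:n_min]
    return sum(kept) / len(kept), len(kept)


losses = [0.3, 2.0, 0.01, 1.2, 0.1, 0.9, 0.05, 1.5]
print(ohem_select(losses, thresh=0.5, n_min=2))  # keeps the 4 pixels above 0.5
print(ohem_select(losses, thresh=0.5, n_min=5))  # only 4 exceed 0.5, so keeps the top 5
```

Either way at least `n_min` pixels contribute to the mean, so a batch with few hard pixels still produces a usable gradient, while easy pixels are dropped whenever enough hard ones exist.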
What does this code do? And how should `thresh` and `n_min` be chosen?
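Neither value is defined in the snippet above, so the following is only a sketch under assumptions. `thresh` is compared against cross-entropy values, so it is in loss units; since per-pixel cross-entropy equals `-log(p_true)`, one common convention is to pick a probability threshold `p` for the correct class and set `thresh = -log(p)`. `n_min` is a pixel count and must stay below `N*H*W`, otherwise `loss[self.n_min]` raises an `IndexError`; expressing it as a fraction of the pixels in a batch is one option. The helper names and the values `p = 0.7` and `1/16` below are illustrative assumptions, not taken from the original code.

```python
import math


def loss_threshold_from_prob(p):
    """Hypothetical helper: convert a true-class probability threshold to a CE threshold.

    Per-pixel cross-entropy is -log(p_true), so 'loss > thresh' is the same
    as 'p_true < p' when thresh = -log(p).
    """
    if not 0.0 < p < 1.0:
        raise ValueError("p must be in (0, 1)")
    return -math.log(p)


def min_kept_pixels(batch_size, height, width, fraction):
    """Hypothetical helper: keep at least `fraction` of the pixels in a batch.

    Clamped to [1, total - 1] so that loss[n_min] is always a valid index.
    """
    total = batch_size * height * width
    return min(total - 1, max(1, int(total * fraction)))


thresh = loss_threshold_from_prob(0.7)        # illustrative p = 0.7
n_min = min_kept_pixels(8, 512, 512, 1 / 16)  # illustrative: 1/16 of the batch's pixels
print(thresh, n_min)
```

Raising `p` lowers `thresh`, so more pixels count as hard and more are kept; raising `n_min` makes the fallback branch keep more pixels when few are hard.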