PyTorch implementation of MS-Loss (Multi-Similarity Loss)
iGuaZi opened this issue · 1 comment
iGuaZi commented
class MultiSimilarityLoss(nn.Module):
    """Multi-Similarity (MS) loss with hard-pair mining.

    For every anchor ``i`` in the batch, positives are the *other* samples
    sharing its label (the anchor itself is excluded) and negatives are the
    samples with a different label. Pairs are mined against each other:

    * a positive ``j`` is kept if ``S_ij < max(S_negatives) + eps``;
    * a negative ``k`` is kept if ``S_ik > min(S_positives) - eps``.

    The per-anchor loss is::

        1/alpha * log(1 + sum_{kept pos} exp(-alpha * (S - lamb)))
      + 1/beta  * log(1 + sum_{kept neg} exp( beta * (S - lamb)))

    The returned value is the sum of per-anchor losses divided by the batch
    size. An empty mined set contributes 0 to its term. Anchors with no
    positive or no negative in the batch are skipped (they contribute 0),
    because their mining thresholds are undefined.

    Args:
        configer: Accepted for interface compatibility; not used.
        is_norm: If true, L2-normalise embeddings before taking dot products,
            so ``S`` is cosine similarity.
        eps: Mining margin.
        lamb: Similarity offset inside the exponentials.
        alpha: Scale of the positive-pair term.
        beta: Scale of the negative-pair term.
    """

    def __init__(self, configer=None, *, is_norm=True, eps=0.1, lamb=1,
                 alpha=2, beta=50):
        super().__init__()
        self.is_norm = is_norm
        self.eps = eps
        self.lamb = lamb
        self.alpha = alpha
        self.beta = beta

    @staticmethod
    def _log1p_sum_exp(x: torch.Tensor) -> torch.Tensor:
        """Return ``log(1 + sum(exp(x)))`` for a 1-D tensor without overflow.

        Computed as ``logsumexp([0, *x])``, so an empty ``x`` yields 0.
        """
        return torch.logsumexp(torch.cat([x.new_zeros(1), x]), dim=0)

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the batch MS loss.

        Args:
            inputs: Embeddings, one row per sample (indexed as ``(n, d)``).
            targets: Labels; flattened, expected to hold one label per row of
                ``inputs``.

        Returns:
            A scalar tensor. When no anchor has both a positive and a negative
            (including an empty batch), returns a zero that is still connected
            to the autograd graph of ``inputs``.
        """
        n = inputs.size(0)
        if self.is_norm:
            # Clamp so an all-zero embedding becomes zero instead of 0/0 = NaN.
            norms = torch.norm(inputs, dim=1, keepdim=True).clamp_min(1e-12)
            inputs = inputs / norms
        sim_mat = inputs.matmul(inputs.t())

        labels = targets.reshape(-1).to(sim_mat.device)
        same = labels.unsqueeze(0) == labels.unsqueeze(1)
        # An anchor paired with itself is not a positive pair.
        not_self = ~torch.eye(n, dtype=torch.bool, device=sim_mat.device)
        pos_mask = same & not_self
        neg_mask = ~same

        losses = []
        for i in range(n):
            row = sim_mat[i]
            pos = row[pos_mask[i]]
            neg = row[neg_mask[i]]
            # Mining needs both the hardest positive and the hardest negative.
            if pos.numel() == 0 or neg.numel() == 0:
                continue
            hard_pos = pos[pos < neg.max() + self.eps]
            hard_neg = neg[neg > pos.min() - self.eps]
            pos_loss = self._log1p_sum_exp(
                -self.alpha * (hard_pos - self.lamb)) / self.alpha
            neg_loss = self._log1p_sum_exp(
                self.beta * (hard_neg - self.lamb)) / self.beta
            losses.append(pos_loss + neg_loss)

        if not losses:
            # Keep the graph so callers can still call .backward().
            return sim_mat.sum() * 0.0
        return torch.stack(losses).sum() / n