Different way to implement triplet loss
deep0learning commented
Hi,
Could you please explain how you implement the triplet loss? I implement it as shown below, but I do not understand your implementation.
```python
import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    """
    Triplet loss
    Takes embeddings of an anchor sample, a positive sample and a negative sample
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        # Squared Euclidean distances per sample; uncomment .pow(.5) for plain distances
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        # Hinge: penalize triplets where the negative is not at least `margin` farther away
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()
```
kerryliu28 commented
@deep0learning
Actually, it computes the squared distances between each anchor and its positives or negatives, then picks the triplets following the batch-hard rule.
You can create fake inputs and targets and play with the code; then you'll figure it out.
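For anyone who wants to try this, here is a minimal sketch of what a batch-hard triplet loss on squared distances could look like, run on fake inputs and targets. It is not the repository's actual code: the function name `batch_hard_triplet_loss`, the default margin, and the toy tensors are all assumptions made for illustration. It assumes "batch-hard" means that, for each anchor, the farthest same-label sample and the closest different-label sample in the batch are used.

```python
import torch
import torch.nn.functional as F


def batch_hard_triplet_loss(embeddings, targets, margin=1.0):
    """Hypothetical batch-hard triplet loss on squared Euclidean distances."""
    # Pairwise squared distances between all embeddings in the batch, shape (N, N)
    diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
    dist = diff.pow(2).sum(-1)

    # same[i, j] is True when samples i and j share a label (the diagonal is True)
    same = targets.unsqueeze(0) == targets.unsqueeze(1)

    # Hardest positive: the farthest sample with the same label as the anchor
    hardest_pos = dist.masked_fill(~same, float("-inf")).max(dim=1).values
    # Hardest negative: the closest sample with a different label
    hardest_neg = dist.masked_fill(same, float("inf")).min(dim=1).values

    # Same hinge as in the question, applied to the mined triplets
    return F.relu(hardest_pos - hardest_neg + margin).mean()


# Fake inputs and targets: four 2-D embeddings, two classes
embeddings = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [4.0, 0.0]])
targets = torch.tensor([0, 0, 1, 1])

loss = batch_hard_triplet_loss(embeddings, targets, margin=1.0)
print(loss.item())
```

The main difference from the `TripletLoss` class in the question is that the triplets are not passed in explicitly: the function receives one batch of embeddings with labels and mines the hardest positive and negative for each anchor itself.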