jiyanggao/Video-Person-ReID

Different Way to implement Triplet loss


Hi,
Could you please explain a little how you implemented the triplet loss? I implemented it as shown below, but I do not understand your implementation.
```python
import torch.nn as nn
import torch.nn.functional as F


class TripletLoss(nn.Module):
    """
    Triplet loss
    Takes embeddings of an anchor sample, a positive sample and a negative sample
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        # Squared Euclidean distance per triplet (uncomment .pow(.5) for plain Euclidean)
        distance_positive = (anchor - positive).pow(2).sum(1)  # .pow(.5)
        distance_negative = (anchor - negative).pow(2).sum(1)  # .pow(.5)
        # Hinge: penalize triplets whose negative is not at least `margin` farther than the positive
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()
```

@deep0learning
It computes the squared distances between each anchor and the positives and negatives in the batch, then selects triplets following the batch-hard rule: for every anchor, the hardest (farthest) positive and the hardest (closest) negative.
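To contrast with the version in the question, here is a minimal sketch of batch-hard mining in PyTorch. It was written for this thread and is not the repository's code; `BatchHardTripletLoss` and its `squared` flag are hypothetical names, and whether the repository uses squared or plain Euclidean distance is not confirmed here. The key difference is the input: instead of receiving pre-formed `(anchor, positive, negative)` triplets, it receives a batch of embeddings plus identity labels and forms one triplet per anchor inside the batch.

```python
import torch
import torch.nn as nn


class BatchHardTripletLoss(nn.Module):
    """Hypothetical batch-hard triplet loss sketch (not the repository's code)."""

    def __init__(self, margin=0.3, squared=True):
        super().__init__()
        self.margin = margin
        self.squared = squared  # assumption: squared distances, as described in the reply

    def forward(self, embeddings, labels):
        # embeddings: (N, D) feature vectors, labels: (N,) identity ids
        # Pairwise squared Euclidean distances via ||a||^2 + ||b||^2 - 2 a.b
        sq_norm = embeddings.pow(2).sum(dim=1, keepdim=True)  # (N, 1)
        dist = sq_norm + sq_norm.t() - 2.0 * embeddings @ embeddings.t()
        dist = dist.clamp(min=0.0)  # remove tiny negative values caused by rounding
        if not self.squared:
            dist = dist.clamp(min=1e-12).sqrt()

        # same_id[i, j] is True when samples i and j share an identity
        same_id = labels.unsqueeze(0) == labels.unsqueeze(1)

        # Hardest positive: farthest sample with the same identity
        # (the anchor itself is included at distance ~0, so it only wins if no other positive exists)
        dist_ap = dist.masked_fill(~same_id, float("-inf")).max(dim=1).values
        # Hardest negative: closest sample with a different identity
        dist_an = dist.masked_fill(same_id, float("inf")).min(dim=1).values

        # Hinge on the margin, averaged over all anchors in the batch
        return torch.relu(dist_ap - dist_an + self.margin).mean()
```

Because the triplets are mined from the batch, this only works well if each batch contains several samples per identity (e.g. P identities with K images each).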

You can create fake inputs and targets and experiment with the code; that should make it clear.
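Following that suggestion, here is a tiny hand-checkable example. The four 2-D embeddings and their labels are made up for illustration; the snippet prints which sample each anchor would use as its hardest positive and hardest negative under squared Euclidean distance.

```python
import torch

# Four fake 2-D embeddings for two hypothetical identities (0 and 1)
embeddings = torch.tensor([[0.0, 0.0],
                           [0.0, 1.0],
                           [3.0, 0.0],
                           [1.0, 1.0]])
labels = torch.tensor([0, 0, 1, 1])

# Squared Euclidean distance between every pair of samples, shape (4, 4)
dist = ((embeddings[:, None, :] - embeddings[None, :, :]) ** 2).sum(dim=-1)

same_id = labels.unsqueeze(0) == labels.unsqueeze(1)
# Index of the farthest same-identity sample and of the closest other-identity sample
hardest_pos = dist.masked_fill(~same_id, float("-inf")).argmax(dim=1)
hardest_neg = dist.masked_fill(same_id, float("inf")).argmin(dim=1)

for i in range(len(labels)):
    print(f"anchor {i}: hardest positive = {hardest_pos[i].item()}, "
          f"hardest negative = {hardest_neg[i].item()}")
```

For anchor 0, the hardest positive is sample 1 (squared distance 1) and the hardest negative is sample 3 (squared distance 2); sample 2 is also a negative, but at distance 9 it is ignored for this anchor.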