sqrt() of the margin in Triplet loss
dinarkino opened this issue · 2 comments
dinarkino commented
Thank you for the work! Could you please explain why sqrt() is applied to the margin in the triplet loss? Is the sqrt needed there?
# original paper/code doesn't sqrt() the distances, we do, so sqrt() the margin, I think :D
criterion = nn.TripletMarginLoss(margin=opt.margin**0.5,
                                 p=2, reduction='sum').to(device)
Nanne commented
I guess the comment is a bit vague. I think it has to do with the original code using squared L2 distance, whereas this code base uses plain (non-squared) L2 distance. So instead of using the same margin, I use the sqrt of that margin so that it lines up with the difference in distance function.
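To make the difference concrete, here is a rough sketch in plain Python (no PyTorch) that compares the two hinge formulations on a single triplet. The function names, the margin of 0.1, and the example distances are illustrative assumptions, not values taken from this code base or from the original code.

```python
import math


def triplet_loss_squared(d_ap, d_an, margin):
    """Hinge on squared L2 distances (the formulation attributed to the original code)."""
    return max(d_ap ** 2 - d_an ** 2 + margin, 0.0)


def triplet_loss_plain(d_ap, d_an, margin):
    """Hinge on plain L2 distances, with the square root taken of the margin."""
    return max(d_ap - d_an + math.sqrt(margin), 0.0)


margin = 0.1           # illustrative value, assumed for this sketch
d_ap, d_an = 1.0, 1.2  # anchor-positive and anchor-negative L2 distances

loss_sq = triplet_loss_squared(d_ap, d_an, margin)
loss_plain = triplet_loss_plain(d_ap, d_an, margin)

print(f"squared distances, margin m:     {loss_sq:.4f}")
print(f"plain distances, margin sqrt(m): {loss_plain:.4f}")
```

For this triplet the squared formulation is already satisfied while the plain one is still violated, so taking the sqrt of the margin matches the scale of the distances but does not make the two losses equivalent. The two hinges switch off at exactly the same point only when the anchor-positive distance is zero.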
It probably shouldn't be hardcoded to take the sqrt, but I don't think it's a major factor. It's worthwhile to experiment with nonetheless.
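One possible way to avoid the hardcoding, sketched here as an assumption rather than as anything that exists in this code base, is a small helper that makes the square root optional. The name effective_margin and the take_sqrt switch are hypothetical.

```python
import math


def effective_margin(margin, take_sqrt=True):
    """Return the margin to hand to the loss.

    take_sqrt=True mirrors the current behaviour (margin ** 0.5);
    take_sqrt=False passes the margin through unchanged.
    """
    if margin < 0:
        raise ValueError("margin must be non-negative")
    return math.sqrt(margin) if take_sqrt else margin


# Hypothetical call site, assuming a boolean option is added next to opt.margin:
# criterion = nn.TripletMarginLoss(margin=effective_margin(opt.margin, take_sqrt),
#                                  p=2, reduction='sum').to(device)
print(effective_margin(0.25))                   # sqrt applied
print(effective_margin(0.25, take_sqrt=False))  # margin unchanged
```

With a switch like this, both settings can be compared in an experiment without editing the training script each time.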
dinarkino commented
Ok, I see, thank you for the answer!