这段挖掘 hard 负样本的方法不太懂，还请大佬赐教。
Opened this issue · 0 comments
SapereAudo commented
class HardNegativeMining(tf.keras.layers.Layer):
    """Keep only the positive(s) and the hardest negatives of each row of logits.

    "Hard" negatives are the negative candidates that currently receive the
    highest logits, i.e. the ones the model is most likely to confuse with
    the positive. The layer returns a column subset of ``logits``/``labels``;
    any loss computed on that subset only sees those columns.
    """

    def __init__(self, num_hard_negatives: int, **kwargs):
        """Create the layer.

        Args:
            num_hard_negatives: Number of highest-scoring negatives to keep
                per row, in addition to the positive.
            **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self._num_hard_negatives = num_hard_negatives

    def call(
        self, logits: tf.Tensor, labels: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        """Select, per row, the positive column(s) plus the top-scoring negatives.

        Args:
            logits: Candidate scores. Assumed 2-D ``[batch, num_candidates]``
                (``top_k`` ranks along the last axis while the cap below reads
                axis 1) -- TODO confirm with callers.
            labels: Same shape as ``logits``; entries ``> 0`` mark positives.
                Presumably one-hot (one positive per row) -- confirm.

        Returns:
            ``(logits, labels)`` restricted to
            ``min(num_hard_negatives + 1, num_candidates)`` columns per row.
            Column order within a row is not guaranteed, because ``top_k`` is
            called with ``sorted=False``.

        Note:
            Every positive in a row occupies one of the ``k + 1`` slots. A row
            with several positives therefore keeps fewer negatives, and a row
            with more than ``k + 1`` positives drops some of them.
        """
        # k hard negatives + 1 slot for the positive, capped at the number of
        # columns actually present (top_k rejects k larger than the last axis).
        num_sampled = tf.minimum(self._num_hard_negatives + 1, tf.shape(logits)[1])

        # Trick: one top_k call selects "the positive + the k hardest negatives".
        # Every positive is forced to the largest representable score, so it is
        # always among the top k + 1. The remaining k slots then go to the
        # highest-scoring negatives. This avoids a boolean mask, whose output
        # size would depend on the data.
        # tf.where is used instead of `logits + labels * MAX_FLOAT` because it
        # accepts integer labels and cannot overflow to inf. It also still
        # protects a positive whose logit is -inf or NaN.
        top_score = tf.fill(
            tf.shape(logits), tf.constant(logits.dtype.max, dtype=logits.dtype)
        )
        ranking_scores = tf.where(labels > 0, top_score, logits)
        _, indices = tf.nn.top_k(ranking_scores, k=num_sampled, sorted=False)

        # Gather the chosen columns from the ORIGINAL tensors, so the positive
        # keeps its real logit. The boosted score was used only for ranking.
        # NOTE(review): _gather_elements_along_row is defined elsewhere; it
        # presumably does a per-row gather of `indices` -- confirm.
        logits = _gather_elements_along_row(logits, indices)
        labels = _gather_elements_along_row(labels, indices)
        return logits, labels