LongmaoTeamTf/deep_recommenders

这段挖掘 hard 负样本（hard negative）的方法不太懂，还请大佬赐教。

Opened this issue · 0 comments

class HardNegativeMining(tf.keras.layers.Layer):
    """Keep only the positive candidate(s) and the hardest negatives per row.

    "Hard" negatives are the negative candidates with the highest logits,
    i.e. the ones the model currently scores most like a positive.
    Presumably the output feeds a softmax/ranking loss downstream, so the
    loss concentrates on the negatives the model confuses most.

    Args:
        num_hard_negatives: How many of the highest-scoring negatives to keep
            per row, in addition to the positive.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_hard_negatives: int, **kwargs):
        super(HardNegativeMining, self).__init__(**kwargs)

        self._num_hard_negatives = num_hard_negatives

    def get_config(self) -> dict:
        """Return the layer config so Keras can serialize and rebuild it."""
        config = super(HardNegativeMining, self).get_config()
        config.update({"num_hard_negatives": self._num_hard_negatives})
        return config

    def call(self, logits: tf.Tensor, labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
        """Select, per row, the positive column(s) plus the top-scoring negatives.

        Args:
            logits: Candidate scores, indexed by candidate on axis 1
                (presumably shape (batch, num_candidates) -- TODO confirm
                against callers).
            labels: Same shape as ``logits``; presumably 1 for the positive
                candidate and 0 elsewhere (one-hot) -- TODO confirm. Integer
                or bool labels are accepted: they are cast to
                ``logits.dtype`` only for the selection step.

        Returns:
            ``(logits, labels)``, each gathered down to
            ``min(num_hard_negatives + 1, num_candidates)`` columns per row.
            ``top_k`` runs with ``sorted=False``, so the positive is NOT
            guaranteed to be in column 0. Both outputs are gathered with the
            same indices, so each logit stays paired with its label.
        """
        # k = hard negatives + 1 slot for the positive, capped at the number
        # of candidates so top_k never asks for more columns than exist.
        num_sampled = tf.minimum(self._num_hard_negatives + 1, tf.shape(logits)[1])

        # The trick: add MAX_FLOAT wherever labels is 1, pushing every
        # positive column far above any real logit (assumes MAX_FLOAT dwarfs
        # real logit values). A single top_k then returns the positive(s)
        # plus the highest-scoring negatives, with no boolean_mask needed to
        # exclude positives first. Each positive in a row consumes one of the
        # k slots, leaving k - (#positives) slots for negatives.
        # The cast lets integer/bool labels work; it is a no-op for floats.
        boosted = logits + tf.cast(labels, logits.dtype) * MAX_FLOAT
        _, indices = tf.nn.top_k(boosted, k=num_sampled, sorted=False)

        # Gather the same columns from both tensors so logits and labels stay
        # aligned. The helper is defined elsewhere; presumably it computes
        # out[i, j] = x[i, indices[i, j]] -- confirm.
        logits = _gather_elements_along_row(logits, indices)
        labels = _gather_elements_along_row(labels, indices)

        return logits, labels