THUwangcy/ReChorus

Question about the evaluate method

Closed this issue · 2 comments

Hi Wang,

Recently, I was reading the source code of ReChorus, and something in BaseRunner.py confused me.

It's about evaluate_method in BaseRunner.py on line 48.

def evaluate_method(predictions: np.ndarray, topk: list, metrics: list) -> Dict[str, float]:
    """
    :param predictions: (-1, n_candidates) shape, the first column is the score for ground-truth item
    :param topk: top-K value list
    :param metrics: metric string list
    :return: a result dict, the keys are metric@topk
    """
    evaluations = dict()
    sort_idx = (-predictions).argsort(axis=1)  
    gt_rank = np.argwhere(sort_idx == 0)[:, 1] + 1 
    for k in topk:
        hit = (gt_rank <= k)
        for metric in metrics:
            key = '{}@{}'.format(metric, k)
            if metric == 'HR':
                evaluations[key] = hit.mean()
            elif metric == 'NDCG':
                evaluations[key] = (hit / np.log2(gt_rank + 1)).mean()
            else:
                raise ValueError('Undefined evaluation metric: {}.'.format(metric))
    return evaluations

As the docstring says, the first column holds the score of the ground-truth item. sort_idx contains the column indices sorted by score in descending order. I read sort_idx == 0 as selecting the highest-scoring record.

My confusion is with the line hit = (gt_rank <= k). As I understand it, gt_rank refers to the item with the highest score, which is not necessarily the ground-truth item and could also be a randomly sampled item. Could you please explain this?

For example, when predictions is [[0.6, 0.7, 0.1, 0.2]], the score of the ground_truth item is 0.6.

sort_idx will be [[1, 0, 3, 2]], which lists the column indices in descending order of score.

Since the ground_truth item always sits at index 0, sort_idx == 0 locates the position of the ground_truth item in the sorted order, not the position of the top-scoring item or of a randomly sampled item. Here index 0 appears at position 1 (0-based), so gt_rank is 2.
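
The example above can be checked directly. The following is a minimal sketch that reuses the two ranking lines from evaluate_method on the same predictions array; the variable gt_rank_by_count is not part of ReChorus and is added only as an illustrative cross-check, assuming there are no tied scores.

```python
import numpy as np

# One user, four candidates; column 0 is the ground-truth item (score 0.6).
predictions = np.array([[0.6, 0.7, 0.1, 0.2]])

# Column indices ordered from highest to lowest score.
sort_idx = (-predictions).argsort(axis=1)

# Position of column 0 (the ground truth) in each sorted row, as a 1-based rank.
gt_rank = np.argwhere(sort_idx == 0)[:, 1] + 1

# Hypothetical cross-check (not in ReChorus): with no ties, the rank equals
# 1 + the number of candidates that score higher than the ground truth.
gt_rank_by_count = (predictions > predictions[:, :1]).sum(axis=1) + 1

print(sort_idx.tolist())          # [[1, 0, 3, 2]]
print(gt_rank.tolist())           # [2]
print(gt_rank_by_count.tolist())  # [2]

# HR@k and NDCG@k computed as in evaluate_method.
for k in (1, 2):
    hit = gt_rank <= k
    print(k, hit.mean(), (hit / np.log2(gt_rank + 1)).mean())
```

With k = 1 the ground truth (rank 2) is a miss, so both metrics are 0. With k = 2 it is a hit, so HR@2 is 1 and NDCG@2 is 1 / log2(3), roughly 0.63.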

Thank you very much for the explanation. I didn't know enough about this function before, but now I understand it.