How to calculate average precision for a multilabel classification problem.

Step-by-step explanation with a simple example.


Let's consider a multilabel classification problem with 4 classes (A, B, C, D) and 5 samples (S1, S2, S3, S4, S5).
The ground truth labels for each sample are given below:

Targets

Sample A B C D
S1 1.00 0.00 1.00 0.00
S2 0.00 1.00 0.00 0.00
S3 1.00 1.00 1.00 0.00
S4 0.00 0.00 0.00 1.00
S5 1.00 1.00 0.00 0.00

Now, let's assume that we have a classifier that predicts the following probabilities for each class and each sample:

Predictions

Sample A B C D
S1 0.80 0.20 0.65 0.90
S2 0.30 0.20 0.40 0.85
S3 0.20 0.70 0.45 0.85
S4 0.10 0.30 0.70 0.95
S5 0.70 0.60 0.45 0.80

To calculate the mean average precision (mAP), we compute the precision-recall curve for each class, take the area under each curve (the per-class average precision, AP), and then average those areas across classes.
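
In symbols, for C classes:

$$ \text{mAP} = \frac{1}{C} \sum_{c=1}^{C} AP_c $$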

Here are the steps to compute the precision-recall curve for class A:


1. Sort the samples by their predicted probability for class A in descending order (a quick code sketch follows the list):

S1:  0.80  0.20  0.65  0.90  (A, C)
S5:  0.70  0.60  0.45  0.80  (A, B)
S2:  0.30  0.20  0.40  0.85  (B)
S3:  0.20  0.70  0.45  0.85  (A, B, C)
S4:  0.10  0.30  0.70  0.95  (D)
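
This sorting step can be reproduced in a couple of lines of NumPy (a minimal sketch; the arrays are just the class-A columns of the tables above):

import numpy as np

# Class-A probabilities and targets, in sample order S1..S5.
probs_a = np.array([0.80, 0.30, 0.20, 0.10, 0.70])
targets_a = np.array([1, 0, 1, 0, 1])

# Indices that sort the probabilities in descending order.
order = np.argsort(-probs_a)
print(order + 1)          # [1 5 2 3 4]  -> S1, S5, S2, S3, S4
print(targets_a[order])   # [1 1 0 1 0]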

2. Compute precision and recall for each threshold:

$$ P = {TP \over TP + FP} \quad\quad\quad R = {TP \over TP + FN} $$

Class A - Table:

Sample  probability  target
S1      0.80         1.00
S5      0.70         1.00
S2      0.30         0.00
S3      0.20         1.00
S4      0.10         0.00
Threshold   TP              FP          FN          TN          P      R
0.80        1 (S1)          0           2 (S3, S5)  2 (S2, S4)  1.00   0.33
0.70        2 (S1, S5)      0           1 (S3)      2 (S2, S4)  1.00   0.66
0.30        2 (S1, S5)      1 (S2)      1 (S3)      1 (S4)      0.66   0.66
0.20        3 (S1, S3, S5)  1 (S2)      0           1 (S4)      0.75   1.00
0.10        3 (S1, S3, S5)  2 (S2, S4)  0           0           0.60   1.00
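
The same table can be generated programmatically; here is a minimal sketch that uses every distinct predicted probability as a threshold (the >= convention matches the counts above):

import numpy as np

probs_a = np.array([0.80, 0.30, 0.20, 0.10, 0.70])
targets_a = np.array([1, 0, 1, 0, 1])

# Evaluate precision and recall at each distinct threshold, highest first.
for t in np.unique(probs_a)[::-1]:
    pred_pos = probs_a >= t
    tp = np.sum(pred_pos & (targets_a == 1))   # true positives
    fp = np.sum(pred_pos & (targets_a == 0))   # false positives
    fn = np.sum(~pred_pos & (targets_a == 1))  # false negatives
    print(f'threshold={t:.2f}  P={tp / (tp + fp):.2f}  R={tp / (tp + fn):.2f}')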

3. Compute the area under the PR curve

  • For each achievable recall, take the maximum precision observed at that recall or above (the interpolated precision), and join the resulting (recall, precision) points to create the curve.

Possible recalls: 0.33, 0.66, 1.00
Max precisions: 1.00, 1.00, 0.75



AP_class_A = (1.00 * 0.6667) + (0.75 * (1 - 0.6667)) ≈ 0.9167
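
In code, the interpolation and the area computation might look like this (a sketch; recalls are kept as exact fractions so the rounding in the tables does not accumulate):

import numpy as np

# Precision/recall pairs for class A, one per threshold (from the table above).
precision = np.array([1.00, 1.00, 2/3, 0.75, 0.60])
recall    = np.array([1/3,  2/3,  2/3, 1.00, 1.00])

ap, prev_r = 0.0, 0.0
for r in np.unique(recall):                  # achievable recalls, ascending
    p_interp = precision[recall >= r].max()  # max precision at recall >= r
    ap += (r - prev_r) * p_interp
    prev_r = r
print(f'{ap:.4f}')  # 0.9167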

4. Repeat for each class, then average the results


Class B

Threshold   P      R
0.70        1.00   0.33
0.60        1.00   0.66
0.30        0.66   0.66
0.20        0.60   1.00

AP_class_B = (1.00 * 0.6667) + (0.60 * (1 - 0.6667)) ≈ 0.8667

Class C

Threshold   P      R
0.70        0.00   0.00
0.65        0.50   0.50
0.45        0.50   1.00
0.40        0.40   1.00

(At threshold 0.70 the only predicted positive is S4, a false positive, so both precision and recall are 0. The precision never exceeds 0.50 at any achievable recall, so the interpolated curve is flat at 0.50.)

AP_class_C = 0.50 * 1.00 = 0.50

Class D

Threshold   P      R
0.95        1.00   1.00
0.90        0.50   1.00
0.85        0.25   1.00
0.80        0.20   1.00

AP_class_D = 1.00 * 1.00 = 1.00

Finally we have:

     A       B       C       D
AP = 0.9167, 0.8667, 0.5000, 1.0000
mAP = 0.8208
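
Putting the four steps together, a compact NumPy implementation (a sketch of the procedure above, not a reference implementation) reproduces these numbers:

import numpy as np

preds = np.array([[0.80, 0.20, 0.65, 0.90],
                  [0.30, 0.20, 0.40, 0.85],
                  [0.20, 0.70, 0.45, 0.85],
                  [0.10, 0.30, 0.70, 0.95],
                  [0.70, 0.60, 0.45, 0.80]])

targets = np.array([[1, 0, 1, 0],
                    [0, 1, 0, 0],
                    [1, 1, 1, 0],
                    [0, 0, 0, 1],
                    [1, 1, 0, 0]])

def average_precision(probs, targ):
    # Steps 1-2: precision/recall at each distinct threshold, highest first.
    precision, recall = [], []
    for t in np.unique(probs)[::-1]:
        pred_pos = probs >= t
        tp = np.sum(pred_pos & (targ == 1))
        precision.append(tp / pred_pos.sum())
        recall.append(tp / targ.sum())
    precision, recall = np.array(precision), np.array(recall)
    # Step 3: area under the interpolated PR curve.
    ap, prev_r = 0.0, 0.0
    for r in np.unique(recall):
        ap += (r - prev_r) * precision[recall >= r].max()
        prev_r = r
    return ap

# Step 4: one AP per class, then the mean.
aps = np.array([average_precision(preds[:, c], targets[:, c]) for c in range(4)])
print(f'AP:  {np.round(aps, 4)}')  # [0.9167 0.8667 0.5    1.    ]
print(f'mAP: {aps.mean():.4f}')    # mAP: 0.8208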

Let's check against the TorchMetrics and scikit-learn implementations of average precision.


TorchMetrics

import torch
from torchmetrics.classification import MultilabelAveragePrecision

# average=None returns one AP per class; thresholds=None computes the
# exact (non-binned) value from all distinct prediction scores.
metric = MultilabelAveragePrecision(num_labels=4, average=None, thresholds=None)

pred = torch.tensor([[0.8, 0.20, 0.65, 0.90],
                     [0.3, 0.20, 0.40, 0.85],
                     [0.2, 0.70, 0.45, 0.85],
                     [0.1, 0.30, 0.70, 0.95],
                     [0.7, 0.60, 0.45, 0.80]])

targ = torch.tensor([[1., 0., 1., 0.],
                     [0., 1., 0., 0.],
                     [1., 1., 1., 0.],
                     [0., 0., 0., 1.],
                     [1., 1., 0., 0.]]).type(torch.int)

r = metric(pred, targ)
print(f'AP: {r}')
print(f'mAP: {torch.mean(r).item():.4f}')
out:
AP: tensor([0.9167, 0.8667, 0.5000, 1.0000])
mAP: 0.8208

Scikit-learn

import torch
import numpy as np
from sklearn.metrics import average_precision_score

pred = torch.tensor([[0.8, 0.20, 0.65, 0.90],
                     [0.3, 0.20, 0.40, 0.85],
                     [0.2, 0.70, 0.45, 0.85],
                     [0.1, 0.30, 0.70, 0.95],
                     [0.7, 0.60, 0.45, 0.80]])

targ = torch.tensor([[1., 0., 1., 0.],
                     [0., 1., 0., 0.],
                     [1., 1., 1., 0.],
                     [0., 0., 0., 1.],
                     [1., 1., 0., 0.]]).type(torch.int)

# average=None returns one AP per class, as a NumPy array.
r = average_precision_score(targ, pred, average=None)

print(f'AP: {r}')
print(f'mAP: {np.mean(r).item():.4f}')
out:
AP: [0.91666667 0.86666667 0.5        1.        ]
mAP: 0.8208