How to plot Precision-Recall curves on cross-modal hashing retrieval task?
Shen-Qiu opened this issue · 2 comments
Shen-Qiu commented
I have plotted two Precision-Recall curves on the result of Flickr-25K. Some settings as followed: bit=16, using vgg19 features. The codes and result are provided in SHARING.
These curves look a bit strange because some of their starting points don't show a downward trend as those curves shown in DCMH and SSAH.
Is there a problem with my code?
Shen-Qiu commented
# -*- coding: utf-8 -*-
"""
Plot precision-recall curve on the result of MIR-FLICKR-25K
"""
import numpy as np
import matplotlib.pyplot as plt
def calc_hammingDist(B1, B2):
q = B2.shape[1]
disH = 0.5 * (q - np.dot(B1, B2.transpose()))
return disH
def calc_similarity(label_1, label_2):
return (np.dot(label_1, label_2.transpose()) > 0).astype(np.float32)
def calc_map(qB, rB, query_L, retrieval_L):
# qB: {-1,+1}^{mxq}
# rB: {-1,+1}^{nxq}
# query_L: {0,1}^{mxl}
# retrieval_L: {0,1}^{nxl}
num_query = query_L.shape[0]
map = 0
for iter in range(num_query):
gnd = (np.dot(query_L[iter, :], retrieval_L.transpose()) > 0).astype(np.float32)
tsum = int(np.sum(gnd))
if tsum == 0:
continue
hamm = calc_hammingDist(qB[iter, :], rB)
ind = np.argsort(hamm)
gnd = gnd[ind]
count = np.linspace(1, tsum, tsum)
tindex = np.asarray(np.where(gnd == 1)) + 1.0
map = map + np.mean(count / (tindex))
map = map / num_query
return map
def cal_Precision_Recall_Curve(qB, rB, query_L, retrieval_L):
S = calc_similarity(query_L, retrieval_L)
dist = calc_hammingDist(qB, rB)
num = qB.shape[0] # the number of input instances
precision = np.zeros((num, bits + 1))
recall = np.zeros((num, bits + 1))
for i in range(num):
relevant = set(np.where(S[i, :] == 1)[0])
retrieved = set()
for bit in range(bits + 1):
retrieved = set(np.where(dist[i, :] == bit)[0]) | retrieved
ret_rel = len(retrieved & relevant)
#print('bit : {0}, Precision: {1:.4f}, Recall: {2:.4f}'.format(bit,
# ret_rel / len(retrieved), ret_rel / len(relevant)))
recall[i, bit] = ret_rel / len(relevant)
if len(retrieved) == 0:
continue
precision[i, bit] = ret_rel / len(retrieved)
return recall.mean(axis=0), precision.mean(axis=0)
result = np.load('./result_16bits_VGG19.npz')
#qBX = result['qBX'][0:1, :] # image query, just for one instance
qBX = result['qBX'] # image query
qBY = result['qBY'] # text query
rBX = result['rBX'] # image retrieval
rBY = result['rBY'] # text retrieval
#query_L = result['query_L'][0:1, :] # query label, just for one instance
query_L = result['query_L'] # query label
retrieval_L = result['retrieval_L'] # retrieval label
mapi2t = result['mapi2t']
mapt2i = result['mapt2i']
print('mapi2t: {0:.4f}'.format(mapi2t))
print('mapt2i: {0:.4f}'.format(mapt2i))
bits = result['bit']
#calc_map(qBX, rBY, query_L, retrieval_L)
recall, precision = cal_Precision_Recall_Curve(qBX, rBY, query_L, retrieval_L)
fig = plt.figure(1)
ax = fig.add_subplot(121)
ax.scatter(recall, precision)
ax.plot(recall, precision)
ax.set(xlim = [0, 1], ylim = [0.5, 1])
plt.title(r'Image->Text')
plt.xlabel('Recall')
plt.ylabel('Precision')
#plt.plot()
# Text -> Image
recall, precision = cal_Precision_Recall_Curve(qBY, rBX, query_L, retrieval_L)
ax = fig.add_subplot(122)
ax.scatter(recall, precision)
ax.plot(recall, precision)
ax.set(xlim = [0, 1], ylim = [0.5, 1])
plt.title(r'Text->Image')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.plot()
anan1030 commented
Hi, you need to exclude the zeros in precision when doing precision.mean(axis=0), you can use np.average() for weighted average.