lelan-li/SSAH

How to plot Precision-Recall curves on cross-modal hashing retrieval task?

Shen-Qiu opened this issue · 2 comments

P-R_curve
I have plotted two Precision-Recall curves on the result of Flickr-25K. Some settings as followed: bit=16, using vgg19 features. The codes and result are provided in SHARING.

These curves look a bit strange because some of their starting points don't show a downward trend as those curves shown in DCMH and SSAH.

Is there a problem with my code?

# -*- coding: utf-8 -*-
"""
    Plot precision-recall curve on the result of MIR-FLICKR-25K
"""
import numpy as np
import matplotlib.pyplot as plt 

def calc_hammingDist(B1, B2):
    q = B2.shape[1]
    disH = 0.5 * (q - np.dot(B1, B2.transpose()))
    return disH


def calc_similarity(label_1, label_2):
    return (np.dot(label_1, label_2.transpose()) > 0).astype(np.float32)


def calc_map(qB, rB, query_L, retrieval_L):
	# qB: {-1,+1}^{mxq}
	# rB: {-1,+1}^{nxq}
	# query_L: {0,1}^{mxl}
	# retrieval_L: {0,1}^{nxl}
	num_query = query_L.shape[0]
	map = 0
	for iter in range(num_query):
		gnd = (np.dot(query_L[iter, :], retrieval_L.transpose()) > 0).astype(np.float32)
		tsum = int(np.sum(gnd))
		if tsum == 0:
			continue
		hamm = calc_hammingDist(qB[iter, :], rB)
		ind = np.argsort(hamm)
		gnd = gnd[ind]
		count = np.linspace(1, tsum, tsum)

		tindex = np.asarray(np.where(gnd == 1)) + 1.0
		map = map + np.mean(count / (tindex))
	map = map / num_query
	return map


def cal_Precision_Recall_Curve(qB, rB, query_L, retrieval_L):
    S = calc_similarity(query_L, retrieval_L)
    dist = calc_hammingDist(qB, rB)
    num = qB.shape[0] # the number of input instances
    
    precision = np.zeros((num, bits + 1))
    recall = np.zeros((num, bits + 1))
    for i in range(num):
        relevant = set(np.where(S[i, :] == 1)[0])
        retrieved = set()
        for bit in range(bits + 1):
            retrieved = set(np.where(dist[i, :] == bit)[0]) | retrieved
            ret_rel = len(retrieved & relevant)
            #print('bit : {0}, Precision: {1:.4f}, Recall: {2:.4f}'.format(bit, 
            #      ret_rel / len(retrieved), ret_rel / len(relevant)))
            recall[i, bit] = ret_rel / len(relevant)
            if len(retrieved) == 0:
                continue
            precision[i, bit] = ret_rel / len(retrieved)
    
    return recall.mean(axis=0), precision.mean(axis=0)


result = np.load('./result_16bits_VGG19.npz')
#qBX = result['qBX'][0:1, :] # image query, just for one instance
qBX = result['qBX'] # image query
qBY = result['qBY'] # text query 
rBX = result['rBX'] # image retrieval 
rBY = result['rBY'] # text retrieval 
#query_L = result['query_L'][0:1, :] # query label, just for one instance
query_L = result['query_L'] # query label
retrieval_L = result['retrieval_L'] # retrieval label

mapi2t = result['mapi2t']
mapt2i = result['mapt2i']
print('mapi2t: {0:.4f}'.format(mapi2t))
print('mapt2i: {0:.4f}'.format(mapt2i))

bits = result['bit']
#calc_map(qBX, rBY, query_L, retrieval_L)
recall, precision = cal_Precision_Recall_Curve(qBX, rBY, query_L, retrieval_L)
fig = plt.figure(1)
ax = fig.add_subplot(121)
ax.scatter(recall, precision)
ax.plot(recall, precision)
ax.set(xlim = [0, 1], ylim = [0.5, 1])
plt.title(r'Image->Text')
plt.xlabel('Recall')
plt.ylabel('Precision')
#plt.plot()

# Text -> Image
recall, precision = cal_Precision_Recall_Curve(qBY, rBX, query_L, retrieval_L)
ax = fig.add_subplot(122)
ax.scatter(recall, precision)
ax.plot(recall, precision)
ax.set(xlim = [0, 1], ylim = [0.5, 1])
plt.title(r'Text->Image')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.plot()

Hi, you need to exclude the zeros in precision when doing precision.mean(axis=0), you can use np.average() for weighted average.