reiinakano/scikit-plot

ValueError: Found input variables with inconsistent numbers of samples

AntonioAntovski opened this issue · 3 comments

I'm trying to plot the ROC curve, but I get `ValueError: Found input variables with inconsistent numbers of samples`.
Here's the code I use:

```python
import matplotlib.pyplot as plt
import scikitplot as skplt

skplt.metrics.plot_roc(labels_test.values, pred_w2v_cnn.values)
plt.show()
```

Both `labels_test.values` and `pred_w2v_cnn.values` have the same length, and both are of type `np.ndarray`. I'd be thankful if anyone could help me solve this problem.

It will be easier to debug if you could post a minimal, reproducible code sample that shows the error.

Here's the code I use:

```python
import matplotlib.pyplot as plt
import pandas as pd
import scikitplot as skplt

labels_test = test.label
pred_w2v_cnn = pd.read_csv("predicted_word2vec_cnn.csv", sep=',', header=0, names=['index', 0, 1, 2, 3, 4, 5, 6])

# test_labels = labels_test.values.reshape((len(labels_test.values), 1))

skplt.metrics.plot_roc(labels_test.values, pred_w2v_cnn.values)
plt.show()
```

Shape of `test_labels`: `(143455,)`
Shape of `pred_w2v_cnn`: `(143455, 8)`

I tried reshaping `test_labels` to `(143455, 1)`, but that didn't work either.

@AntonioAntovski `plot_roc` is built on sklearn's `roc_curve`, which checks the shapes of its inputs. You probably shouldn't pass the `index` column: your labels have 7 classes, but the prediction-probability array you pass has 8 columns, which is why the error is raised. Drop that column and plot again.
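A minimal sketch of dropping that column, assuming the first CSV column is a saved row index and the remaining seven columns are per-class probabilities in label order. The DataFrame here is a small synthetic stand-in for `predicted_word2vec_cnn.csv` (the real file isn't available), and `y_probas` is just an illustrative name.

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for predicted_word2vec_cnn.csv (assumption): an 'index'
# column followed by one probability column per class, classes 0..6.
rng = np.random.default_rng(0)
probs = rng.random((5, 7))
probs /= probs.sum(axis=1, keepdims=True)  # each row sums to 1
pred_w2v_cnn = pd.DataFrame(probs, columns=list(range(7)))
pred_w2v_cnn.insert(0, "index", np.arange(5))
print(pred_w2v_cnn.shape)  # (5, 8): one column more than there are classes

# Keep only the per-class probability columns before plotting.
y_probas = pred_w2v_cnn.drop(columns=["index"]).values
print(y_probas.shape)  # (5, 7)
```

With the real data the call would then be `skplt.metrics.plot_roc(labels_test.values, y_probas)`. Alternatively, passing `index_col=0` to `pd.read_csv` should load that first column as the DataFrame index rather than as data, so `.values` contains only the probability columns.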

To test this, here is a self-contained example:

```python
import numpy as np
import matplotlib.pyplot as plt
import scikitplot as skplt
from collections import Counter
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

iris = load_iris()
x, y = iris.data, iris.target
lr = LogisticRegression(max_iter=1000)  # extra iterations so the default solver converges
lr.fit(x, y)

# Model's predicted class probabilities: shape (150, 3), one column per class.
prob = lr.predict_proba(x)
# Random array with one column too many: shape (150, 4).
tmp = np.random.random((len(y), 4))

# This works: 3 classes in y and 3 probability columns.
skplt.metrics.plot_roc(y, prob)

print('Different Classes Count res: ', Counter(y))

# This fails: y has 3 classes but 'tmp' has 4 columns, so plot_roc raises
# ValueError: Found input variables with inconsistent numbers of samples
try:
    skplt.metrics.plot_roc(y, tmp)
except ValueError as err:
    print('ValueError:', err)

plt.show()
```
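Why does the error talk about *samples* when both arrays have the same number of rows? As far as I can tell from scikit-plot's source, `plot_roc` also draws a micro-averaged curve: it one-hot encodes `y_true` into an `n_samples × n_classes` matrix, flattens it, and passes it to `roc_curve` together with the flattened `y_probas`. In the iris example that is 150 × 3 = 450 labels against 150 × 4 = 600 scores. Below is a hypothetical pre-flight check (`roc_input_problems` is my own name, not a scikit-plot function) that reports the mismatch in plainer terms; it assumes column `i` of `y_probas` corresponds to the `i`-th value of `np.unique(y_true)`.

```python
import numpy as np


def roc_input_problems(y_true, y_probas):
    """Hypothetical helper: list reasons plot_roc may reject these inputs.

    Assumes column i of y_probas holds the probability of the i-th class
    in np.unique(y_true). Returns an empty list when the shapes line up.
    """
    y_true = np.ravel(y_true)  # tolerate (n, 1) label arrays
    y_probas = np.asarray(y_probas)
    problems = []
    if y_probas.ndim != 2:
        problems.append(f"y_probas should be 2-D, got shape {y_probas.shape}")
        return problems
    if y_probas.shape[0] != y_true.shape[0]:
        problems.append(
            f"{y_true.shape[0]} labels but {y_probas.shape[0]} probability rows"
        )
    n_classes = np.unique(y_true).size
    if y_probas.shape[1] != n_classes:
        problems.append(
            f"y_true has {n_classes} classes but y_probas has "
            f"{y_probas.shape[1]} columns"
        )
    return problems


y = np.repeat([0, 1, 2], 50)       # 150 labels, 3 classes (like iris)
good = np.full((150, 3), 1 / 3)
bad = np.random.random((150, 4))   # one extra column, like 'tmp' above

print(roc_input_problems(y, good))  # []
print(roc_input_problems(y, bad))
# Flattened sizes compared by the micro-average: labels vs scores.
print(y.size * np.unique(y).size, bad.size)  # 450 600
```

Running the check before plotting turns the generic "inconsistent numbers of samples" message into one that points directly at the extra column.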