RaRe-Technologies/movie-plots-by-genre

Possible mistake or misunderstanding in plot_confusion_matrix

Closed this issue · 0 comments

Hey there!
I think there is a mistake in the notebook.
First, I tried to understand why this line was hardcoded:
my_tags = ['sci-fi' , 'action', 'comedy', 'fantasy', 'animation', 'romance']

It is used in this function:

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(my_tags))
    target_names = my_tags
    plt.xticks(tick_marks, target_names, rotation=45)
    plt.yticks(tick_marks, target_names)
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

cm comes from this function:

def evaluate_prediction(predictions, target, title="Confusion matrix"):
    print('accuracy %s' % accuracy_score(target, predictions))
    cm = confusion_matrix(target, predictions)
    print('confusion matrix\n %s' % cm)
    print('(row=expected, col=predicted)')
    
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plot_confusion_matrix(cm_normalized, title + ' Normalized')

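For reference, the normalization line divides each row of cm by that row's total. Entry (i, j) then becomes the fraction of movies with true tag i that were predicted as tag j, and the diagonal is the per-class recall. Here is a small numpy sketch with a made-up 3x3 matrix (the numbers are hypothetical, not taken from the notebook):

```python
import numpy as np

# Hypothetical confusion matrix: rows = true class, columns = predicted class
cm = np.array([[8, 2, 0],
               [1, 6, 3],
               [0, 0, 5]])

# Shape (3, 1): number of true examples in each class
row_totals = cm.sum(axis=1)[:, np.newaxis]

# Broadcasting divides every row by its own total, so each row sums to 1
cm_normalized = cm.astype('float') / row_totals

print(cm_normalized)
```

This step is correct regardless of label order. The problem described in this issue is only about which genre names end up on the axes.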
The scikit-learn documentation for confusion_matrix says:

labels : array, shape = [n_classes], optional
List of labels to index the matrix. This may be used to reorder or select a subset of labels. If none is given, those that appear at least once in y_true or y_pred are used in sorted order.
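Here is a quick check of what this means for the notebook. The sketch uses toy data: the y_true/y_pred lists are hypothetical, and only my_tags comes from the notebook. It uses sklearn.utils.multiclass.unique_labels to show the default sorted order, then compares that order with the hardcoded tick labels:

```python
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

# Hardcoded tick-label order from the notebook
my_tags = ['sci-fi', 'action', 'comedy', 'fantasy', 'animation', 'romance']

# Hypothetical toy data: one movie per genre, plus one sci-fi movie misclassified as action
y_true = ['sci-fi', 'action', 'comedy', 'fantasy', 'animation', 'romance', 'sci-fi']
y_pred = ['sci-fi', 'action', 'comedy', 'fantasy', 'animation', 'romance', 'action']

# Without labels=..., rows/columns follow the sorted union of y_true and y_pred
default_order = [str(t) for t in unique_labels(y_true, y_pred)]
cm_default = confusion_matrix(y_true, y_pred)

# Row i of cm_default belongs to default_order[i], but the plot names it my_tags[i]
for shown, actual in zip(my_tags, default_order):
    if shown != actual:
        print("row labelled %-9s actually holds %s" % (shown, actual))

# Passing labels= makes the matrix follow the same order as the tick labels
cm_fixed = confusion_matrix(y_true, y_pred, labels=my_tags)
```

With these six tags, four of the six rows get the wrong name on the plot. Only comedy and fantasy happen to be in the same position in both orders.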

I made some changes to the code. First, I take the tag list from the dataframe:
my_tags = df.tag.unique()
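Note that Series.unique() returns the tags in order of first appearance, not in sorted order. That is fine here, because the same array is passed as labels= and used for the tick labels, so both follow one order. Here is a small sketch with a hypothetical dataframe:

```python
import pandas as pd

# Hypothetical dataframe with a 'tag' column, as in the notebook
df = pd.DataFrame({'tag': ['comedy', 'sci-fi', 'comedy', 'action', 'sci-fi']})

my_tags = df.tag.unique()      # order of first appearance, not sorted
sorted_tags = sorted(my_tags)  # the order confusion_matrix uses when labels= is omitted

print(list(my_tags))  # ['comedy', 'sci-fi', 'action']
print(sorted_tags)    # ['action', 'comedy', 'sci-fi']
```

Using sorted(df.tag.unique()) instead would produce the same order confusion_matrix uses by default.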

Then I changed the evaluate_prediction function (only adding the labels parameter to confusion_matrix):

def evaluate_prediction(predictions, target, title="Confusion matrix"):
    print('accuracy %s' % accuracy_score(target, predictions))
    cm = confusion_matrix(target, predictions, labels=my_tags)  # rows/columns now follow my_tags, the same order as the tick labels
    print('confusion matrix\n %s' % cm)
    print('(row=expected, col=predicted)')
    cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
    plot_confusion_matrix(cm_normalized, title + ' Normalized')
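One possible variant (a sketch, not the notebook's code) passes the label list into both functions explicitly instead of reading the global my_tags. That way the matrix order and the tick labels cannot drift apart. It also assumes you want rows for tags that never occur in target to show as 0 rather than NaN. The Agg backend is there only so the sketch runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix


def plot_confusion_matrix(cm, labels, title="Confusion matrix", cmap=plt.cm.Blues):
    """Draw cm with tick labels from the same list that ordered its rows/columns."""
    plt.figure()
    plt.imshow(cm, interpolation="nearest", cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(labels))
    plt.xticks(tick_marks, labels, rotation=45)
    plt.yticks(tick_marks, labels)
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.tight_layout()


def evaluate_prediction(predictions, target, labels, title="Confusion matrix"):
    print("accuracy %s" % accuracy_score(target, predictions))
    # labels= fixes the row/column order; the same list is reused for the ticks
    cm = confusion_matrix(target, predictions, labels=labels)
    print("confusion matrix\n %s" % cm)
    print("(row=expected, col=predicted)")
    row_sums = cm.sum(axis=1, keepdims=True)
    # Rows for labels absent from target would divide by zero; leave them at 0 instead
    cm_normalized = np.divide(cm, row_sums, out=np.zeros(cm.shape), where=row_sums != 0)
    plot_confusion_matrix(cm_normalized, labels, title + " Normalized")
    return cm_normalized
```

In the notebook this would be called as evaluate_prediction(predictions, target, my_tags).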

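With this change my matrix differs from yours, so I think the hardcoded line was incorrect. Without labels=..., confusion_matrix orders its rows and columns alphabetically (action, animation, comedy, fantasy, romance, sci-fi), but plot_confusion_matrix names them in the hardcoded my_tags order, so most tick labels end up on the wrong genre.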

P.S. Sorry for my poor English.