Skip to content

Commit bcec62f

Browse files
committed
Pushing the docs for revision for branch: master, commit 97c47d96718084ae3eb828b567753a102c2d8437
1 parent 37355e7 commit bcec62f

File tree

782 files changed

+2801
-2759
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

782 files changed

+2801
-2759
lines changed

dev/_downloads/plot_confusion_matrix.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
print(__doc__)
2828

29+
import itertools
2930
import numpy as np
3031
import matplotlib.pyplot as plt
3132

@@ -37,6 +38,7 @@
3738
iris = datasets.load_iris()
3839
X = iris.data
3940
y = iris.target
41+
class_names = iris.target_names
4042

4143
# Split the data into a training set and a test set
4244
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
@@ -47,32 +49,51 @@
4749
y_pred = classifier.fit(X_train, y_train).predict(X_test)
4850

4951

50-
def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues):
52+
def plot_confusion_matrix(cm, classes,
53+
normalize=False,
54+
title='Confusion matrix',
55+
cmap=plt.cm.Blues):
56+
"""
57+
This function prints and plots the confusion matrix.
58+
Normalization can be applied by setting `normalize=True`.
59+
"""
5160
plt.imshow(cm, interpolation='nearest', cmap=cmap)
5261
plt.title(title)
5362
plt.colorbar()
54-
tick_marks = np.arange(len(iris.target_names))
55-
plt.xticks(tick_marks, iris.target_names, rotation=45)
56-
plt.yticks(tick_marks, iris.target_names)
63+
tick_marks = np.arange(len(classes))
64+
plt.xticks(tick_marks, classes, rotation=45)
65+
plt.yticks(tick_marks, classes)
66+
67+
if normalize:
68+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
69+
print("Normalized confusion matrix")
70+
else:
71+
print('Confusion matrix, without normalization')
72+
73+
print(cm)
74+
75+
thresh = cm.max() / 2.
76+
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
77+
plt.text(j, i, cm[i, j],
78+
horizontalalignment="center",
79+
color="white" if cm[i, j] > thresh else "black")
80+
5781
plt.tight_layout()
5882
plt.ylabel('True label')
5983
plt.xlabel('Predicted label')
6084

61-
6285
# Compute confusion matrix
63-
cm = confusion_matrix(y_test, y_pred)
86+
cnf_matrix = confusion_matrix(y_test, y_pred)
6487
np.set_printoptions(precision=2)
65-
print('Confusion matrix, without normalization')
66-
print(cm)
88+
89+
# Plot non-normalized confusion matrix
6790
plt.figure()
68-
plot_confusion_matrix(cm)
91+
plot_confusion_matrix(cnf_matrix, classes=class_names,
92+
title='Confusion matrix, without normalization')
6993

70-
# Normalize the confusion matrix by row (i.e by the number of samples
71-
# in each class)
72-
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
73-
print('Normalized confusion matrix')
74-
print(cm_normalized)
94+
# Plot normalized confusion matrix
7595
plt.figure()
76-
plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')
96+
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
97+
title='Normalized confusion matrix')
7798

7899
plt.show()
20 Bytes
20 Bytes
-132 Bytes
-132 Bytes
-94 Bytes
-94 Bytes
-247 Bytes
-247 Bytes
-92 Bytes

0 commit comments

Comments
 (0)