Skip to content

Commit aa0bbdc

Browse files
committed
Pushing the docs to dev/ for branch: master, commit 0a1ee74a14ed8fe94bb0c7c10c9e3d99db9cd2b8
1 parent e9a71de commit aa0bbdc

File tree

1,065 files changed

+3479
-3424
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,065 files changed

+3479
-3424
lines changed
699 Bytes
Binary file not shown.
678 Bytes
Binary file not shown.

dev/_downloads/plot_confusion_matrix.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
},
2727
"outputs": [],
2828
"source": [
29-
"print(__doc__)\n\nimport itertools\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom sklearn import svm, datasets\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.metrics import confusion_matrix\n\n# import some data to play with\niris = datasets.load_iris()\nX = iris.data\ny = iris.target\nclass_names = iris.target_names\n\n# Split the data into a training set and a test set\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n\n# Run classifier, using a model that is too regularized (C too low) to see\n# the impact on the results\nclassifier = svm.SVC(kernel='linear', C=0.01)\ny_pred = classifier.fit(X_train, y_train).predict(X_test)\n\n\ndef plot_confusion_matrix(cm, classes,\n normalize=False,\n title='Confusion matrix',\n cmap=plt.cm.Blues):\n \"\"\"\n This function prints and plots the confusion matrix.\n Normalization can be applied by setting `normalize=True`.\n \"\"\"\n if normalize:\n cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n print(\"Normalized confusion matrix\")\n else:\n print('Confusion matrix, without normalization')\n\n print(cm)\n\n plt.imshow(cm, interpolation='nearest', cmap=cmap)\n plt.title(title)\n plt.colorbar()\n tick_marks = np.arange(len(classes))\n plt.xticks(tick_marks, classes, rotation=45)\n plt.yticks(tick_marks, classes)\n\n fmt = '.2f' if normalize else 'd'\n thresh = cm.max() / 2.\n for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n plt.text(j, i, format(cm[i, j], fmt),\n horizontalalignment=\"center\",\n color=\"white\" if cm[i, j] > thresh else \"black\")\n\n plt.ylabel('True label')\n plt.xlabel('Predicted label')\n plt.tight_layout()\n\n\n# Compute confusion matrix\ncnf_matrix = confusion_matrix(y_test, y_pred)\nnp.set_printoptions(precision=2)\n\n# Plot non-normalized confusion matrix\nplt.figure()\nplot_confusion_matrix(cnf_matrix, classes=class_names,\n title='Confusion matrix, without normalization')\n\n# Plot normalized confusion matrix\nplt.figure()\nplot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,\n title='Normalized confusion matrix')\n\nplt.show()"
29+
"print(__doc__)\n\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nfrom sklearn import svm, datasets\nfrom sklearn.model_selection import train_test_split\nfrom sklearn.metrics import confusion_matrix\nfrom sklearn.utils.multiclass import unique_labels\n\n# import some data to play with\niris = datasets.load_iris()\nX = iris.data\ny = iris.target\nclass_names = iris.target_names\n\n# Split the data into a training set and a test set\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)\n\n# Run classifier, using a model that is too regularized (C too low) to see\n# the impact on the results\nclassifier = svm.SVC(kernel='linear', C=0.01)\ny_pred = classifier.fit(X_train, y_train).predict(X_test)\n\n\ndef plot_confusion_matrix(y_true, y_pred, classes,\n normalize=False,\n title=None,\n cmap=plt.cm.Blues):\n \"\"\"\n This function prints and plots the confusion matrix.\n Normalization can be applied by setting `normalize=True`.\n \"\"\"\n if not title:\n if normalize:\n title = 'Normalized confusion matrix'\n else:\n title = 'Confusion matrix, without normalization'\n\n # Compute confusion matrix\n cm = confusion_matrix(y_true, y_pred)\n # Only use the labels that appear in the data\n classes = classes[unique_labels(y_true, y_pred)]\n if normalize:\n cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n print(\"Normalized confusion matrix\")\n else:\n print('Confusion matrix, without normalization')\n\n print(cm)\n\n fig, ax = plt.subplots()\n im = ax.imshow(cm, interpolation='nearest', cmap=cmap)\n ax.figure.colorbar(im, ax=ax)\n # We want to show all ticks...\n ax.set(xticks=np.arange(cm.shape[1]),\n yticks=np.arange(cm.shape[0]),\n # ... and label them with the respective list entries\n xticklabels=classes, yticklabels=classes,\n title=title,\n ylabel='True label',\n xlabel='Predicted label')\n\n # Rotate the tick labels and set their alignment.\n plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\",\n rotation_mode=\"anchor\")\n\n # Loop over data dimensions and create text annotations.\n fmt = '.2f' if normalize else 'd'\n thresh = cm.max() / 2.\n for i in range(cm.shape[0]):\n for j in range(cm.shape[1]):\n ax.text(j, i, format(cm[i, j], fmt),\n ha=\"center\", va=\"center\",\n color=\"white\" if cm[i, j] > thresh else \"black\")\n fig.tight_layout()\n return ax\n\n\nnp.set_printoptions(precision=2)\n\n# Plot non-normalized confusion matrix\nplot_confusion_matrix(y_test, y_pred, classes=class_names,\n title='Confusion matrix, without normalization')\n\n# Plot normalized confusion matrix\nplot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=True,\n title='Normalized confusion matrix')\n\nplt.show()"
3030
]
3131
}
3232
],

dev/_downloads/plot_confusion_matrix.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@
2626

2727
print(__doc__)
2828

29-
import itertools
3029
import numpy as np
3130
import matplotlib.pyplot as plt
3231

3332
from sklearn import svm, datasets
3433
from sklearn.model_selection import train_test_split
3534
from sklearn.metrics import confusion_matrix
35+
from sklearn.utils.multiclass import unique_labels
3636

3737
# import some data to play with
3838
iris = datasets.load_iris()
@@ -49,14 +49,24 @@
4949
y_pred = classifier.fit(X_train, y_train).predict(X_test)
5050

5151

52-
def plot_confusion_matrix(cm, classes,
52+
def plot_confusion_matrix(y_true, y_pred, classes,
5353
normalize=False,
54-
title='Confusion matrix',
54+
title=None,
5555
cmap=plt.cm.Blues):
5656
"""
5757
This function prints and plots the confusion matrix.
5858
Normalization can be applied by setting `normalize=True`.
5959
"""
60+
if not title:
61+
if normalize:
62+
title = 'Normalized confusion matrix'
63+
else:
64+
title = 'Confusion matrix, without normalization'
65+
66+
# Compute confusion matrix
67+
cm = confusion_matrix(y_true, y_pred)
68+
# Only use the labels that appear in the data
69+
classes = classes[unique_labels(y_true, y_pred)]
6070
if normalize:
6171
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
6272
print("Normalized confusion matrix")
@@ -65,37 +75,42 @@ def plot_confusion_matrix(cm, classes,
6575

6676
print(cm)
6777

68-
plt.imshow(cm, interpolation='nearest', cmap=cmap)
69-
plt.title(title)
70-
plt.colorbar()
71-
tick_marks = np.arange(len(classes))
72-
plt.xticks(tick_marks, classes, rotation=45)
73-
plt.yticks(tick_marks, classes)
74-
78+
fig, ax = plt.subplots()
79+
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
80+
ax.figure.colorbar(im, ax=ax)
81+
# We want to show all ticks...
82+
ax.set(xticks=np.arange(cm.shape[1]),
83+
yticks=np.arange(cm.shape[0]),
84+
# ... and label them with the respective list entries
85+
xticklabels=classes, yticklabels=classes,
86+
title=title,
87+
ylabel='True label',
88+
xlabel='Predicted label')
89+
90+
# Rotate the tick labels and set their alignment.
91+
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
92+
rotation_mode="anchor")
93+
94+
# Loop over data dimensions and create text annotations.
7595
fmt = '.2f' if normalize else 'd'
7696
thresh = cm.max() / 2.
77-
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
78-
plt.text(j, i, format(cm[i, j], fmt),
79-
horizontalalignment="center",
80-
color="white" if cm[i, j] > thresh else "black")
81-
82-
plt.ylabel('True label')
83-
plt.xlabel('Predicted label')
84-
plt.tight_layout()
97+
for i in range(cm.shape[0]):
98+
for j in range(cm.shape[1]):
99+
ax.text(j, i, format(cm[i, j], fmt),
100+
ha="center", va="center",
101+
color="white" if cm[i, j] > thresh else "black")
102+
fig.tight_layout()
103+
return ax
85104

86105

87-
# Compute confusion matrix
88-
cnf_matrix = confusion_matrix(y_test, y_pred)
89106
np.set_printoptions(precision=2)
90107

91108
# Plot non-normalized confusion matrix
92-
plt.figure()
93-
plot_confusion_matrix(cnf_matrix, classes=class_names,
109+
plot_confusion_matrix(y_test, y_pred, classes=class_names,
94110
title='Confusion matrix, without normalization')
95111

96112
# Plot normalized confusion matrix
97-
plt.figure()
98-
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
113+
plot_confusion_matrix(y_test, y_pred, classes=class_names, normalize=True,
99114
title='Normalized confusion matrix')
100115

101116
plt.show()

dev/_downloads/scikit-learn-docs.pdf

21.3 KB
Binary file not shown.

dev/_images/iris.png

0 Bytes

0 commit comments

Comments
 (0)