Recognizing hand-written digits
================================

This example shows how scikit-learn can be used to recognize images of
hand-written digits, from 0-9.
"""

print(__doc__)

# Author: Gael Varoquaux <gael dot varoquaux at normalesup dot org>

# Standard scientific Python imports
import matplotlib.pyplot as plt

# Import datasets, classifiers and performance metrics
from sklearn import datasets, svm, metrics
from sklearn.model_selection import train_test_split

###############################################################################
# Digits dataset
# --------------
#
# The digits dataset consists of 8x8
# pixel images of digits. The ``images`` attribute of the dataset stores
# 8x8 arrays of grayscale values for each image. We will use these arrays to
# visualize the first 4 images. The ``target`` attribute of the dataset stores
# the digit each image represents and this is included in the title of the 4
# plots below.
#
# Note: if we were working from image files (e.g., 'png' files), we would load
# them using :func:`matplotlib.pyplot.imread`.

digits = datasets.load_digits()
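
###############################################################################
# As a quick sanity check, we can look at the shapes of the arrays described
# above: ``images`` holds one 8x8 grayscale array per sample, while ``target``
# holds the corresponding digit label.

print(digits.images.shape)  # one 8x8 array per sample
print(digits.target.shape)  # one label (0-9) per sample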

_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    ax.set_title('Training: %i' % label)

###############################################################################
# Classification
# --------------
#
# To apply a classifier on this data, we need to flatten the images, turning
# each 2-D array of grayscale values from shape ``(8, 8)`` into shape
# ``(64,)``. Subsequently, the entire dataset will be of shape
# ``(n_samples, n_features)``, where ``n_samples`` is the number of images and
# ``n_features`` is the total number of pixels in each image.
#
# We can then split the data into train and test subsets and fit a support
# vector classifier on the train samples. The fitted classifier can
# subsequently be used to predict the value of the digit for the samples
# in the test subset.

# Flatten the images
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
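
###############################################################################
# Note that ``load_digits`` also exposes this flattened representation
# directly as ``digits.data``, so the reshape above should give an array of
# the same shape.

print(data.shape)         # (n_samples, 64)
print(digits.data.shape)  # the pre-flattened version shipped with the dataset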

# Create a classifier: a support vector classifier
clf = svm.SVC(gamma=0.001)

# Split data into 50% train and 50% test subsets
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.5, shuffle=False)
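
###############################################################################
# With ``test_size=0.5`` and ``shuffle=False``, the first half of the images
# becomes the train subset and the second half the test subset; we can verify
# the resulting sizes below.

print(X_train.shape, y_train.shape)  # first half of the flattened images
print(X_test.shape, y_test.shape)    # second half, held out for evaluation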

# Learn the digits on the train subset
clf.fit(X_train, y_train)
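
###############################################################################
# ``gamma=0.001`` above is a fixed choice of the SVC kernel coefficient. As a
# rough sketch of how such a value could be selected (the candidate values
# below are only illustrative), a small grid search with cross-validation can
# be run on the train subset.

from sklearn.model_selection import GridSearchCV

param_grid = {'gamma': [0.01, 0.001, 0.0001]}
search = GridSearchCV(svm.SVC(), param_grid, cv=5)
search.fit(X_train, y_train)
print("Best gamma found by the grid search:", search.best_params_)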

# Predict the value of the digit on the test subset
predicted = clf.predict(X_test)
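
###############################################################################
# As a quick single-number summary, we can compute the accuracy on the test
# subset; the classification report further below breaks this down per class.

print("Test accuracy:", metrics.accuracy_score(y_test, predicted))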

###############################################################################
# Below we visualize the first 4 test samples and show their predicted
# digit value in the title. The flattened test vectors in ``X_test`` are
# reshaped back to 8x8 arrays for plotting.

_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, prediction in zip(axes, X_test.reshape(-1, 8, 8), predicted):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    ax.set_title(f'Prediction: {prediction}')

###############################################################################
# :func:`~sklearn.metrics.classification_report` builds a text report showing
# the main classification metrics.

print(f"Classification report for classifier {clf}:\n"
      f"{metrics.classification_report(y_test, predicted)}\n")
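
###############################################################################
# If the individual numbers are needed programmatically, the same report can
# also be requested as a dictionary via ``output_dict=True``.

report = metrics.classification_report(y_test, predicted, output_dict=True)
print(report['macro avg'])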

###############################################################################
# We can also plot a :ref:`confusion matrix <confusion_matrix>` of the
# true digit values and the predicted digit values.

disp = metrics.plot_confusion_matrix(clf, X_test, y_test)
disp.figure_.suptitle("Confusion Matrix")
print(f"Confusion matrix:\n{disp.confusion_matrix}")
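
###############################################################################
# The matrix printed above is taken from the display object; it can equally
# be computed directly from the true and predicted labels, without plotting.

print(metrics.confusion_matrix(y_test, predicted))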

plt.show()