Example of Precision-Recall metric to evaluate classifier output quality.

Precision-Recall is a useful measure of success of prediction when the
classes are very imbalanced. In information retrieval, precision is a
measure of result relevancy, while recall is a measure of how many truly
relevant results are returned.
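For example, a query for which 8 of the 10 returned documents are relevant,
out of 20 relevant documents in total, has a precision of 0.8 and a recall
of 0.4.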

The precision-recall curve shows the tradeoff between precision and
recall for different thresholds. A high area under the curve represents
both high recall and high precision, where high precision relates to a
low false positive rate, and high recall relates to a low false negative
rate. High scores for both show that the classifier is returning accurate
results (high precision), as well as returning a majority of all positive
results (high recall).

A system with high recall but low precision returns many results, but most of
its predicted labels are incorrect when compared to the training labels. A

:math:`F1 = 2\\frac{P \\times R}{P+R}`
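
For instance, a classifier with precision :math:`P = 0.8` and recall
:math:`R = 0.5` has :math:`F1 = 2\\frac{0.8 \\times 0.5}{0.8 + 0.5} \\approx 0.62`.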

Note that the precision may not decrease with recall. The
definition of precision (:math:`\\frac{T_p}{T_p + F_p}`) shows that lowering
the threshold of a classifier may increase the denominator, by increasing the
number of results returned. If the threshold was previously set too high, the

The relationship between recall and precision can be observed in the
stairstep area of the plot - at the edges of these steps a small change
in the threshold considerably reduces precision, with only a minor gain in
recall.

**Average precision** summarizes such a plot as the weighted mean of precisions
achieved at each threshold, with the increase in recall from the previous
threshold used as the weight:

:math:`\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n`

where :math:`P_n` and :math:`R_n` are the precision and recall at the
nth threshold. A pair :math:`(R_k, P_k)` is referred to as an
*operating point*.
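
As an illustration, a curve passing through the operating points
:math:`(0.5, 1.0)`, :math:`(0.8, 0.75)` and :math:`(1.0, 0.6)` (taking
:math:`R_0 = 0`) has
:math:`\\text{AP} = 0.5 \\times 1.0 + 0.3 \\times 0.75 + 0.2 \\times 0.6 = 0.845`.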

Precision-recall curves are typically used in binary classification to study
the output of a classifier. In order to extend the precision-recall curve and
average precision to multi-class or multi-label classification, it is necessary
to binarize the output. One curve can be drawn per label, but one can also draw
a precision-recall curve by considering each element of the label indicator
:func:`sklearn.metrics.precision_score`,
:func:`sklearn.metrics.f1_score`
"""
from __future__ import print_function

###############################################################################
# In binary classification settings
# --------------------------------------------------------
#
# Create simple data
# ..................
#
# Try to differentiate the first two classes of the iris data
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
import numpy as np

iris = datasets.load_iris()
X = iris.data
y = iris.target

# Add noisy features
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]

# Limit to the first two classes, and split into training and test
X_train, X_test, y_train, y_test = train_test_split(X[y < 2], y[y < 2],
                                                    test_size=.5,
                                                    random_state=random_state)

# Create a simple classifier
classifier = svm.LinearSVC(random_state=random_state)
classifier.fit(X_train, y_train)
y_score = classifier.decision_function(X_test)
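# Note: precision_recall_curve and average_precision_score expect continuous
# scores (decision function values or probability estimates) rather than hard
# class predictions, which is why decision_function is used here.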

###############################################################################
# Compute the average precision score
# ...................................
from sklearn.metrics import average_precision_score
average_precision = average_precision_score(y_test, y_score)

print('Average precision-recall score: {0:0.2f}'.format(
      average_precision))
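
# As a quick sanity check (not part of the original example), the precision
# and recall at the single operating point given by the classifier's default
# decision threshold can be computed directly:
from sklearn.metrics import precision_score, recall_score
y_pred = classifier.predict(X_test)
print('Precision at the default threshold: {0:0.2f}'.format(
      precision_score(y_test, y_pred)))
print('Recall at the default threshold: {0:0.2f}'.format(
      recall_score(y_test, y_pred)))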

###############################################################################
# Plot the Precision-Recall curve
# ................................
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

precision, recall, _ = precision_recall_curve(y_test, y_score)

plt.step(recall, precision, color='b', alpha=0.2,
         where='post')
plt.fill_between(recall, precision, step='post', alpha=0.2,
                 color='b')

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title('2-class Precision-Recall curve: AUC={0:0.2f}'.format(
          average_precision))

###############################################################################
# In multi-label settings
# ------------------------
#
# Create multi-label data, fit, and predict
# ...........................................
#
# We create a multi-label dataset, to illustrate precision-recall in
# multi-label settings

from sklearn.preprocessing import label_binarize

# Use label_binarize to create multi-label-like settings
Y = label_binarize(y, classes=[0, 1, 2])
n_classes = Y.shape[1]
# Split into training and test
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.5,
                                                    random_state=random_state)

# We use OneVsRestClassifier for multi-label prediction
from sklearn.multiclass import OneVsRestClassifier

# Run classifier
classifier = OneVsRestClassifier(svm.LinearSVC(random_state=random_state))
classifier.fit(X_train, Y_train)
y_score = classifier.decision_function(X_test)

###############################################################################
# The average precision score in multi-label settings
# ....................................................
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score

# For each class
precision = dict()
recall = dict()
average_precision = dict()
for i in range(n_classes):
    precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i],
                                                        y_score[:, i])
    average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])

# A "micro-average": quantifying score on all classes jointly
precision["micro"], recall["micro"], _ = precision_recall_curve(Y_test.ravel(),
    y_score.ravel())
average_precision["micro"] = average_precision_score(Y_test, y_score,
                                                     average="micro")
print('Average precision score, micro-averaged over all classes: {0:0.2f}'
      .format(average_precision["micro"]))

###############################################################################
# Plot the micro-averaged Precision-Recall curve
# ...............................................
#

plt.figure()
plt.step(recall['micro'], precision['micro'], color='b', alpha=0.2,
         where='post')
plt.fill_between(recall["micro"], precision["micro"], step='post', alpha=0.2,
                 color='b')

plt.xlabel('Recall')
plt.ylabel('Precision')
plt.ylim([0.0, 1.05])
plt.xlim([0.0, 1.0])
plt.title(
    'Average precision score, micro-averaged over all classes: AUC={0:0.2f}'
    .format(average_precision["micro"]))

###############################################################################
# Plot Precision-Recall curve for each class and iso-f1 curves
# .............................................................
#
from itertools import cycle
# setup plot details
colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])

plt.figure(figsize=(7, 8))
f_scores = np.linspace(0.2, 0.8, num=4)
lines = []
labels = []

lines.append(l)
labels.append('iso-f1 curves')
l, = plt.plot(recall["micro"], precision["micro"], color='gold', lw=2)
lines.append(l)
labels.append('micro-average Precision-recall (area = {0:0.2f})'
              ''.format(average_precision["micro"]))

for i, color in zip(range(n_classes), colors):
    l, = plt.plot(recall[i], precision[i], color=color, lw=2)
    lines.append(l)
    labels.append('Precision-recall for class {0} (area = {1:0.2f})'
                  ''.format(i, average_precision[i]))

fig = plt.gcf()
fig.subplots_adjust(bottom=0.25)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Extension of Precision-Recall curve to multi-class')
plt.legend(lines, labels, loc=(0, -.38), prop=dict(size=14))

plt.show()