"""
=============================================================
Receiver Operating Characteristic (ROC) with cross validation
=============================================================

This example presents how to estimate and visualize the variance of the Receiver
Operating Characteristic (ROC) metric using cross-validation.

ROC curves typically feature true positive rate (TPR) on the Y axis, and false
positive rate (FPR) on the X axis. This means that the top left corner of the
plot is the "ideal" point - a FPR of zero, and a TPR of one. This is not very
realistic, but it does mean that a larger Area Under the Curve (AUC) is usually
better. The "steepness" of ROC curves is also important, since it is ideal to
maximize the TPR while minimizing the FPR.

This example shows the ROC response of different datasets, created from K-fold
cross-validation. Taking all of these curves, it is possible to calculate the
mean AUC, and see the variance of the curve when the training set is split into
different subsets. This roughly shows how the classifier output is affected by
changes in the training data, and how different the splits generated by K-fold
cross-validation are from one another.

.. note::

    See :ref:`sphx_glr_auto_examples_model_selection_plot_roc.py` for a
    complement of the present example explaining the averaging strategies to
    generalize the metrics for multiclass classifiers.

"""
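
# %%
# As a minimal sketch of the definitions above, we can trace a toy ROC curve by
# hand. The four labels and scores below are purely illustrative values: each
# threshold swept over the scores yields one (FPR, TPR) point, and the area
# under those points is the AUC.
from sklearn.metrics import auc, roc_curve

y_true = [0, 0, 1, 1]  # two negative and two positive samples
y_score = [0.1, 0.4, 0.35, 0.8]  # hypothetical classifier scores

fpr, tpr, thresholds = roc_curve(y_true, y_score)
print(f"FPR: {fpr}, TPR: {tpr}, AUC: {auc(fpr, tpr):.2f}")  # AUC = 0.75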

# %%
# Load and prepare data
# =====================
#
# We import the :ref:`iris_dataset` which contains 3 classes, each one
# corresponding to a type of iris plant. One class is linearly separable from
# the other 2; the latter are **not** linearly separable from each other.
#
# In the following we binarize the dataset by dropping the "virginica" class
# (`class_id=2`). This means that the "versicolor" class (`class_id=1`) is
# regarded as the positive class and "setosa" as the negative class
# (`class_id=0`).

import numpy as np

from sklearn.datasets import load_iris

iris = load_iris()
target_names = iris.target_names
X, y = iris.data, iris.target
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape

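# %%
# A quick sanity check of the binarized labels (this printout is purely
# illustrative): only "setosa" (negative class) and "versicolor" (positive
# class) remain, with 50 samples each.
for name, count in zip(target_names[:2], np.bincount(y)):
    print(name, count)
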
# %%
# We also add noisy features to make the problem harder.
random_state = np.random.RandomState(0)
X = np.concatenate([X, random_state.randn(n_samples, 200 * n_features)], axis=1)
56 | 50 |
|
57 | 51 | # %%
|
58 | 52 | # Classification and ROC analysis
|
59 | 53 | # -------------------------------
#
# Here we run a :class:`~sklearn.svm.SVC` classifier with cross-validation and
# plot the ROC curves fold-wise. Notice that the baseline to define the chance
# level (dashed ROC curve) is a classifier that would always predict the most
# frequent class; the closing cell below illustrates this baseline.

import matplotlib.pyplot as plt

from sklearn import svm
from sklearn.metrics import auc
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import StratifiedKFold

cv = StratifiedKFold(n_splits=6)
classifier = svm.SVC(kernel="linear", probability=True, random_state=random_state)

tprs = []
aucs = []
# Common grid of 100 evenly spaced FPR values; each fold's TPR curve is
# interpolated onto this grid so that the curves can be averaged point-wise.
mean_fpr = np.linspace(0, 1, 100)

fig, ax = plt.subplots(figsize=(6, 6))
for fold, (train, test) in enumerate(cv.split(X, y)):
    classifier.fit(X[train], y[train])
    viz = RocCurveDisplay.from_estimator(
        classifier,
        X[test],
        y[test],
        name=f"ROC fold {fold}",
        alpha=0.3,
        lw=1,
        ax=ax,
    )
    # Interpolate this fold's ROC onto the common FPR grid and pin the first
    # point to 0 so that the fold curves can be averaged below.
    interp_tpr = np.interp(mean_fpr, viz.fpr, viz.tpr)
    interp_tpr[0] = 0.0
    tprs.append(interp_tpr)
    aucs.append(viz.roc_auc)
ax.plot([0, 1], [0, 1], "k--", label="chance level (AUC = 0.5)")

mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
ax.plot(
    mean_fpr,
    mean_tpr,
    color="b",
    label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (mean_auc, std_auc),
    lw=2,
    alpha=0.8,
)

std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
ax.fill_between(
    mean_fpr,
    tprs_lower,
    tprs_upper,
    color="grey",
    alpha=0.2,
    label=r"$\pm$ 1 std. dev.",
)

ax.set(
    xlim=[-0.05, 1.05],
    ylim=[-0.05, 1.05],
    xlabel="False Positive Rate",
    ylabel="True Positive Rate",
    title=f"Mean ROC curve with variability\n(Positive label '{target_names[1]}')",
)
ax.axis("square")
ax.legend(loc="lower right")
plt.show()
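
# %%
# To illustrate the chance-level baseline mentioned above, the minimal sketch
# below scores a classifier that always predicts the most frequent class. Its
# constant probability estimates trace the diagonal ROC curve, i.e. an AUC of
# 0.5.
from sklearn.dummy import DummyClassifier
from sklearn.metrics import roc_auc_score

# The dummy classifier ignores the features and always predicts the majority
# class, so its scores carry no ranking information.
dummy = DummyClassifier(strategy="most_frequent").fit(X, y)
chance_auc = roc_auc_score(y, dummy.predict_proba(X)[:, 1])
print(f"Most-frequent baseline AUC: {chance_auc:.1f}")  # 0.5 by construction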