|
| 1 | +""" |
| 2 | +Visualizing cross-validation behavior in scikit-learn |
| 3 | +===================================================== |
| 4 | +
|
| 5 | +Choosing the right cross-validation object is a crucial part of fitting a |
| 6 | +model properly. There are many ways to split data into training and test |
| 7 | +sets in order to avoid model overfitting, to standardize the number of |
| 8 | +groups in test sets, etc. |
| 9 | +
|
| 10 | +This example visualizes the behavior of several common scikit-learn objects |
| 11 | +for comparison. |
| 12 | +""" |
| 13 | + |
| 14 | +from sklearn.model_selection import (TimeSeriesSplit, KFold, ShuffleSplit, |
| 15 | + StratifiedKFold, GroupShuffleSplit, |
| 16 | + GroupKFold, StratifiedShuffleSplit) |
| 17 | +import numpy as np |
| 18 | +import matplotlib.pyplot as plt |
| 19 | +from matplotlib.patches import Patch |
# Seed NumPy's global RNG so the randomly generated data below is reproducible.
np.random.seed(1338)
# Qualitative colormap for the class/group bands; diverging map for train/test.
cmap_data = plt.cm.Paired
cmap_cv = plt.cm.coolwarm
# Number of CV splits visualized in every example below.
n_splits = 4
| 24 | + |
| 25 | +############################################################################### |
| 26 | +# Visualize our data |
| 27 | +# ------------------ |
| 28 | +# |
| 29 | +# First, we must understand the structure of our data. It has 100 randomly |
| 30 | +# generated input datapoints, 3 classes split unevenly across datapoints, |
| 31 | +# and 10 "groups" split evenly across datapoints. |
| 32 | +# |
| 33 | +# As we'll see, some cross-validation objects do specific things with |
| 34 | +# labeled data, others behave differently with grouped data, and others |
| 35 | +# do not use this information. |
| 36 | +# |
| 37 | +# To begin, we'll visualize our data. |
| 38 | + |
# Generate the class/group data.
# Fix: use ``n_points`` consistently instead of re-hard-coding ``100`` in the
# expressions below, so changing the dataset size requires editing one line.
n_points = 100
X = np.random.randn(n_points, 10)

# Class labels: three classes covering 10% / 30% / 60% of the samples,
# laid out as contiguous runs (0s, then 1s, then 2s).
percentiles_classes = [.1, .3, .6]
y = np.hstack([[ii] * int(n_points * perc)
               for ii, perc in enumerate(percentiles_classes)])

# Ten evenly sized groups of consecutive samples (10 samples per group).
groups = np.hstack([[ii] * 10 for ii in range(10)])
| 49 | + |
| 50 | + |
def visualize_groups(classes, groups, name):
    """Plot the group and class label of every sample as colored bands.

    Parameters
    ----------
    classes : array-like of shape (n_samples,)
        Class label of each sample.
    groups : array-like of shape (n_samples,)
        Group label of each sample.
    name : str
        Label used in the figure title.
    """
    fig, ax = plt.subplots()
    # Bottom band: group membership of each sample index.
    ax.scatter(range(len(groups)), [.5] * len(groups), c=groups, marker='_',
               lw=50, cmap=cmap_data)
    # Top band: class label of each sample index.
    ax.scatter(range(len(groups)), [3.5] * len(groups), c=classes, marker='_',
               lw=50, cmap=cmap_data)
    ax.set(ylim=[-1, 5], yticks=[.5, 3.5],
           yticklabels=['Data\ngroup', 'Data\nclass'], xlabel="Sample index")
    # Fix: ``name`` was previously accepted but never used; show it in the
    # title so callers can distinguish different datasets.
    ax.set_title('Data structure ({})'.format(name), fontsize=15)


visualize_groups(y, groups, 'no groups')
| 63 | + |
| 64 | +############################################################################### |
| 65 | +# Define a function to visualize cross-validation behavior |
| 66 | +# -------------------------------------------------------- |
| 67 | +# |
| 68 | +# We'll define a function that lets us visualize the behavior of each |
| 69 | +# cross-validation object. We'll perform 4 splits of the data. On each |
| 70 | +# split, we'll visualize the indices chosen for the training set |
| 71 | +# (in blue) and the test set (in red). |
| 72 | + |
| 73 | + |
def plot_cv_indices(cv, X, y, group, ax, n_splits, lw=10):
    """Create a sample plot for indices of a cross-validation object.

    Parameters
    ----------
    cv : cross-validation generator
        Any object exposing ``split(X, y, groups)``.
    X : array-like of shape (n_samples, n_features)
        Input data.
    y : array-like of shape (n_samples,)
        Class label of each sample.
    group : array-like of shape (n_samples,)
        Group label of each sample.
    ax : matplotlib Axes
        Axes to draw on.
    n_splits : int
        Number of CV splits to display.
    lw : int, default=10
        Line width of the marker bands.

    Returns
    -------
    ax : matplotlib Axes
        The axes that were drawn on.
    """
    n_samples = len(X)

    # Generate the training/testing visualizations for each CV split.
    for ii, (tr, tt) in enumerate(cv.split(X=X, y=y, groups=group)):
        # NaN marks samples in neither set (possible with e.g. ShuffleSplit);
        # 0 marks training samples and 1 marks test samples.
        indices = np.array([np.nan] * n_samples)
        indices[tt] = 1
        indices[tr] = 0

        # One horizontal band of '_' markers per split, colored train/test.
        ax.scatter(range(n_samples), [ii + .5] * n_samples,
                   c=indices, marker='_', lw=lw, cmap=cmap_cv,
                   vmin=-.2, vmax=1.2)

    # Plot the data classes and groups as two extra bands below the splits.
    # Fix: position them from ``n_splits`` rather than the leaked loop
    # variable ``ii``, so the bands land correctly even if ``cv`` yields a
    # different number of splits than expected.
    ax.scatter(range(n_samples), [n_splits + .5] * n_samples,
               c=y, marker='_', lw=lw, cmap=cmap_data)
    ax.scatter(range(n_samples), [n_splits + 1.5] * n_samples,
               c=group, marker='_', lw=lw, cmap=cmap_data)

    # Formatting. Fix: derive the x-limit from the number of samples instead
    # of the hard-coded 100 so the plot works for any dataset size.
    yticklabels = list(range(n_splits)) + ['class', 'group']
    ax.set(yticks=np.arange(n_splits + 2) + .5, yticklabels=yticklabels,
           xlabel='Sample index', ylabel="CV iteration",
           ylim=[n_splits + 2.2, -.2], xlim=[0, n_samples])
    ax.set_title('{}'.format(type(cv).__name__), fontsize=15)
    return ax
| 103 | + |
| 104 | + |
###############################################################################
# Let's see how it looks for the `KFold` cross-validation object:

cv = KFold(n_splits)
fig, ax = plt.subplots()
plot_cv_indices(cv, X, y, groups, ax, n_splits)
| 111 | + |
###############################################################################
# As you can see, by default the KFold cross-validation iterator does not
# take either datapoint class or group into consideration. We can change this
# by using the ``StratifiedKFold`` like so.

cv = StratifiedKFold(n_splits)
fig, ax = plt.subplots()
plot_cv_indices(cv, X, y, groups, ax, n_splits)
| 120 | + |
| 121 | +############################################################################### |
| 122 | +# In this case, the cross-validation retained the same ratio of classes across |
| 123 | +# each CV split. Next we'll visualize this behavior for a number of CV |
| 124 | +# iterators. |
| 125 | +# |
| 126 | +# Visualize cross-validation indices for many CV objects |
| 127 | +# ------------------------------------------------------ |
| 128 | +# |
| 129 | +# Let's visually compare the cross validation behavior for many |
| 130 | +# scikit-learn cross-validation objects. Below we will loop through several |
| 131 | +# common cross-validation objects, visualizing the behavior of each. |
| 132 | +# |
| 133 | +# Note how some use the group/class information while others do not. |
| 134 | + |
cvs = [KFold, GroupKFold, ShuffleSplit, StratifiedKFold,
       GroupShuffleSplit, StratifiedShuffleSplit, TimeSeriesSplit]


for cv_class in cvs:
    fig, ax = plt.subplots(figsize=(6, 3))
    plot_cv_indices(cv_class(n_splits=n_splits), X, y, groups, ax, n_splits)

    # Proxy patches colored like the train/test bands, so the legend can be
    # drawn without touching the scatter artists themselves.
    legend_handles = [Patch(color=cmap_cv(.8)), Patch(color=cmap_cv(.02))]
    ax.legend(legend_handles, ['Testing set', 'Training set'], loc=(1.02, .8))
    # Make the legend fit
    plt.tight_layout()
    fig.subplots_adjust(right=.7)
plt.show()
0 commit comments