|
4 | 4 | "cell_type": "markdown",
|
5 | 5 | "metadata": {},
|
6 | 6 | "source": [
|
7 |
| - "\n# Permutation Importance with Multicollinear or Correlated Features\n\nIn this example, we compute the permutation importance on the Wisconsin\nbreast cancer dataset using :func:`~sklearn.inspection.permutation_importance`.\nThe :class:`~sklearn.ensemble.RandomForestClassifier` can easily get about 97%\naccuracy on a test dataset. Because this dataset contains multicollinear\nfeatures, the permutation importance will show that none of the features are\nimportant. One approach to handling multicollinearity is by performing\nhierarchical clustering on the features' Spearman rank-order correlations,\npicking a threshold, and keeping a single feature from each cluster.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>See also\n `sphx_glr_auto_examples_inspection_plot_permutation_importance.py`</p></div>\n" |
| 7 | + "\n# Permutation Importance with Multicollinear or Correlated Features\n\nIn this example, we compute the\n:func:`~sklearn.inspection.permutation_importance` of the features to a trained\n:class:`~sklearn.ensemble.RandomForestClassifier` using the\n`breast_cancer_dataset`. The model can easily get about 97% accuracy on a\ntest dataset. Because this dataset contains multicollinear features, the\npermutation importance shows that none of the features are important, in\ncontradiction with the high test accuracy.\n\nWe demo a possible approach to handling multicollinearity, which consists of\nhierarchical clustering on the features' Spearman rank-order correlations,\npicking a threshold, and keeping a single feature from each cluster.\n\n<div class=\"alert alert-info\"><h4>Note</h4><p>See also\n `sphx_glr_auto_examples_inspection_plot_permutation_importance.py`</p></div>\n" |
| 8 | + ] |
| 9 | + }, |
| 10 | + { |
| 11 | + "cell_type": "markdown", |
| 12 | + "metadata": {}, |
| 13 | + "source": [ |
| 14 | + "## Random Forest Feature Importance on Breast Cancer Data\n\nFirst, we define a function to ease the plotting:\n\n" |
| 15 | + ] |
| 16 | + }, |
| 17 | + { |
| 18 | + "cell_type": "code", |
| 19 | + "execution_count": null, |
| 20 | + "metadata": { |
| 21 | + "collapsed": false |
| 22 | + }, |
| 23 | + "outputs": [], |
| 24 | + "source": [ |
| 25 | + "from sklearn.inspection import permutation_importance\n\n\ndef plot_permutation_importance(clf, X, y, ax):\n result = permutation_importance(clf, X, y, n_repeats=10, random_state=42, n_jobs=2)\n perm_sorted_idx = result.importances_mean.argsort()\n\n ax.boxplot(\n result.importances[perm_sorted_idx].T,\n vert=False,\n labels=X.columns[perm_sorted_idx],\n )\n ax.axvline(x=0, color=\"k\", linestyle=\"--\")\n return ax" |
| 26 | + ] |
| 27 | + }, |
| 28 | + { |
| 29 | + "cell_type": "markdown", |
| 30 | + "metadata": {}, |
| 31 | + "source": [ |
| 32 | + "We then train a :class:`~sklearn.ensemble.RandomForestClassifier` on the\n`breast_cancer_dataset` and evaluate its accuracy on a test set:\n\n" |
| 33 | + ] |
| 34 | + }, |
| 35 | + { |
| 36 | + "cell_type": "code", |
| 37 | + "execution_count": null, |
| 38 | + "metadata": { |
| 39 | + "collapsed": false |
| 40 | + }, |
| 41 | + "outputs": [], |
| 42 | + "source": [ |
| 43 | + "from sklearn.datasets import load_breast_cancer\nfrom sklearn.ensemble import RandomForestClassifier\nfrom sklearn.model_selection import train_test_split\n\nX, y = load_breast_cancer(return_X_y=True, as_frame=True)\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n\nclf = RandomForestClassifier(n_estimators=100, random_state=42)\nclf.fit(X_train, y_train)\nprint(f\"Baseline accuracy on test data: {clf.score(X_test, y_test):.2}\")" |
| 44 | + ] |
| 45 | + }, |
| 46 | + { |
| 47 | + "cell_type": "markdown", |
| 48 | + "metadata": {}, |
| 49 | + "source": [ |
| 50 | + "Next, we plot the tree based feature importance and the permutation\nimportance. The permutation importance is calculated on the training set to\nshow how much the model relies on each feature during training.\n\n" |
8 | 51 | ]
|
9 | 52 | },
|
10 | 53 | {
|
|
15 | 58 | },
|
16 | 59 | "outputs": [],
|
17 | 60 | "source": [
|
18 |
| - "from collections import defaultdict\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom scipy.cluster import hierarchy\nfrom scipy.spatial.distance import squareform\nfrom scipy.stats import spearmanr\n\nfrom sklearn.datasets import load_breast_cancer\nfrom sklearn.ensemble import RandomForestClassifier\nfrom sklearn.inspection import permutation_importance\nfrom sklearn.model_selection import train_test_split" |
| 61 | + "import matplotlib.pyplot as plt\nimport numpy as np\nimport pandas as pd\n\nmdi_importances = pd.Series(clf.feature_importances_, index=X_train.columns)\ntree_importance_sorted_idx = np.argsort(clf.feature_importances_)\ntree_indices = np.arange(0, len(clf.feature_importances_)) + 0.5\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\nmdi_importances.sort_values().plot.barh(ax=ax1)\nax1.set_xlabel(\"Gini importance\")\nplot_permutation_importance(clf, X_train, y_train, ax2)\nax2.set_xlabel(\"Decrease in accuracy score\")\nfig.suptitle(\n \"Impurity-based vs. permutation importances on multicollinear features (train set)\"\n)\n_ = fig.tight_layout()" |
19 | 62 | ]
|
20 | 63 | },
|
21 | 64 | {
|
22 | 65 | "cell_type": "markdown",
|
23 | 66 | "metadata": {},
|
24 | 67 | "source": [
|
25 |
| - "## Random Forest Feature Importance on Breast Cancer Data\nFirst, we train a random forest on the breast cancer dataset and evaluate\nits accuracy on a test set:\n\n" |
| 68 | + "The plot on the left shows the Gini importance of the model. As the\nscikit-learn implementation of\n:class:`~sklearn.ensemble.RandomForestClassifier` uses a random subsets of\n$\\sqrt{n_\\text{features}}$ features at each split, it is able to dilute\nthe dominance of any single correlated feature. As a result, the individual\nfeature importance may be distributed more evenly among the correlated\nfeatures. Since the features have large cardinality and the classifier is\nnon-overfitted, we can relatively trust those values.\n\nThe permutation importance on the right plot shows that permuting a feature\ndrops the accuracy by at most `0.012`, which would suggest that none of the\nfeatures are important. This is in contradiction with the high test accuracy\ncomputed as baseline: some feature must be important.\n\nSimilarly, the change in accuracy score computed on the test set appears to be\ndriven by chance:\n\n" |
26 | 69 | ]
|
27 | 70 | },
|
28 | 71 | {
|
|
33 | 76 | },
|
34 | 77 | "outputs": [],
|
35 | 78 | "source": [
|
36 |
| - "data = load_breast_cancer()\nX, y = data.data, data.target\nX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)\n\nclf = RandomForestClassifier(n_estimators=100, random_state=42)\nclf.fit(X_train, y_train)\nprint(\"Accuracy on test data: {:.2f}\".format(clf.score(X_test, y_test)))" |
| 79 | + "fig, ax = plt.subplots(figsize=(7, 6))\nplot_permutation_importance(clf, X_test, y_test, ax)\nax.set_title(\"Permutation Importances on multicollinear features\\n(test set)\")\nax.set_xlabel(\"Decrease in accuracy score\")\n_ = ax.figure.tight_layout()" |
37 | 80 | ]
|
38 | 81 | },
|
39 | 82 | {
|
40 | 83 | "cell_type": "markdown",
|
41 | 84 | "metadata": {},
|
42 | 85 | "source": [
|
43 |
| - "Next, we plot the tree based feature importance and the permutation\nimportance. The permutation importance plot shows that permuting a feature\ndrops the accuracy by at most `0.012`, which would suggest that none of the\nfeatures are important. This is in contradiction with the high test accuracy\ncomputed above: some feature must be important. The permutation importance\nis calculated on the training set to show how much the model relies on each\nfeature during training.\n\n" |
| 86 | + "Nevertheless, one can still compute a meaningful permutation importance in the\npresence of correlated features, as demonstrated in the following section.\n\n## Handling Multicollinear Features\nWhen features are collinear, permuting one feature has little effect on the\nmodels performance because it can get the same information from a correlated\nfeature. Note that this is not the case for all predictive models and depends\non their underlying implementation.\n\nOne way to handle multicollinear features is by performing hierarchical\nclustering on the Spearman rank-order correlations, picking a threshold, and\nkeeping a single feature from each cluster. First, we plot a heatmap of the\ncorrelated features:\n\n" |
44 | 87 | ]
|
45 | 88 | },
|
46 | 89 | {
|
|
51 | 94 | },
|
52 | 95 | "outputs": [],
|
53 | 96 | "source": [
|
54 |
| - "result = permutation_importance(clf, X_train, y_train, n_repeats=10, random_state=42)\nperm_sorted_idx = result.importances_mean.argsort()\n\ntree_importance_sorted_idx = np.argsort(clf.feature_importances_)\ntree_indices = np.arange(0, len(clf.feature_importances_)) + 0.5\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\nax1.barh(tree_indices, clf.feature_importances_[tree_importance_sorted_idx], height=0.7)\nax1.set_yticks(tree_indices)\nax1.set_yticklabels(data.feature_names[tree_importance_sorted_idx])\nax1.set_ylim((0, len(clf.feature_importances_)))\nax2.boxplot(\n result.importances[perm_sorted_idx].T,\n vert=False,\n labels=data.feature_names[perm_sorted_idx],\n)\nfig.tight_layout()\nplt.show()" |
| 97 | + "from scipy.cluster import hierarchy\nfrom scipy.spatial.distance import squareform\nfrom scipy.stats import spearmanr\n\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\ncorr = spearmanr(X).correlation\n\n# Ensure the correlation matrix is symmetric\ncorr = (corr + corr.T) / 2\nnp.fill_diagonal(corr, 1)\n\n# We convert the correlation matrix to a distance matrix before performing\n# hierarchical clustering using Ward's linkage.\ndistance_matrix = 1 - np.abs(corr)\ndist_linkage = hierarchy.ward(squareform(distance_matrix))\ndendro = hierarchy.dendrogram(\n dist_linkage, labels=X.columns.to_list(), ax=ax1, leaf_rotation=90\n)\ndendro_idx = np.arange(0, len(dendro[\"ivl\"]))\n\nax2.imshow(corr[dendro[\"leaves\"], :][:, dendro[\"leaves\"]])\nax2.set_xticks(dendro_idx)\nax2.set_yticks(dendro_idx)\nax2.set_xticklabels(dendro[\"ivl\"], rotation=\"vertical\")\nax2.set_yticklabels(dendro[\"ivl\"])\n_ = fig.tight_layout()" |
55 | 98 | ]
|
56 | 99 | },
|
57 | 100 | {
|
58 | 101 | "cell_type": "markdown",
|
59 | 102 | "metadata": {},
|
60 | 103 | "source": [
|
61 |
| - "## Handling Multicollinear Features\nWhen features are collinear, permutating one feature will have little\neffect on the models performance because it can get the same information\nfrom a correlated feature. One way to handle multicollinear features is by\nperforming hierarchical clustering on the Spearman rank-order correlations,\npicking a threshold, and keeping a single feature from each cluster. First,\nwe plot a heatmap of the correlated features:\n\n" |
| 104 | + "Next, we manually pick a threshold by visual inspection of the dendrogram to\ngroup our features into clusters and choose a feature from each cluster to\nkeep, select those features from our dataset, and train a new random forest.\nThe test accuracy of the new random forest did not change much compared to the\nrandom forest trained on the complete dataset.\n\n" |
62 | 105 | ]
|
63 | 106 | },
|
64 | 107 | {
|
|
69 | 112 | },
|
70 | 113 | "outputs": [],
|
71 | 114 | "source": [
|
72 |
| - "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))\ncorr = spearmanr(X).correlation\n\n# Ensure the correlation matrix is symmetric\ncorr = (corr + corr.T) / 2\nnp.fill_diagonal(corr, 1)\n\n# We convert the correlation matrix to a distance matrix before performing\n# hierarchical clustering using Ward's linkage.\ndistance_matrix = 1 - np.abs(corr)\ndist_linkage = hierarchy.ward(squareform(distance_matrix))\ndendro = hierarchy.dendrogram(\n dist_linkage, labels=data.feature_names.tolist(), ax=ax1, leaf_rotation=90\n)\ndendro_idx = np.arange(0, len(dendro[\"ivl\"]))\n\nax2.imshow(corr[dendro[\"leaves\"], :][:, dendro[\"leaves\"]])\nax2.set_xticks(dendro_idx)\nax2.set_yticks(dendro_idx)\nax2.set_xticklabels(dendro[\"ivl\"], rotation=\"vertical\")\nax2.set_yticklabels(dendro[\"ivl\"])\nfig.tight_layout()\nplt.show()" |
| 115 | + "from collections import defaultdict\n\ncluster_ids = hierarchy.fcluster(dist_linkage, 1, criterion=\"distance\")\ncluster_id_to_feature_ids = defaultdict(list)\nfor idx, cluster_id in enumerate(cluster_ids):\n cluster_id_to_feature_ids[cluster_id].append(idx)\nselected_features = [v[0] for v in cluster_id_to_feature_ids.values()]\nselected_features_names = X.columns[selected_features]\n\nX_train_sel = X_train[selected_features_names]\nX_test_sel = X_test[selected_features_names]\n\nclf_sel = RandomForestClassifier(n_estimators=100, random_state=42)\nclf_sel.fit(X_train_sel, y_train)\nprint(\n \"Baseline accuracy on test data with features removed:\"\n f\" {clf_sel.score(X_test_sel, y_test):.2}\"\n)" |
73 | 116 | ]
|
74 | 117 | },
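| | + { |
| | + "cell_type": "markdown", |
| | + "metadata": {}, |
| | + "source": [ |
| | + "As a quick sketch, one can also count how many clusters each candidate\nthreshold would yield before settling on a value. The loop below reuses\n`dist_linkage` from the clustering cell above; the candidate thresholds are\narbitrary values chosen for illustration.\n\n" |
| | + ] |
| | + }, |
| | + { |
| | + "cell_type": "code", |
| | + "execution_count": null, |
| | + "metadata": { |
| | + "collapsed": false |
| | + }, |
| | + "outputs": [], |
| | + "source": [ |
| | + "# Count the clusters produced by a few candidate distance thresholds.\n# The threshold values are arbitrary; adjust them to the dendrogram above.\nfor t in [0.5, 1.0, 1.5, 2.0]:\n    cluster_labels = hierarchy.fcluster(dist_linkage, t, criterion=\"distance\")\n    print(f\"threshold={t}: {len(np.unique(cluster_labels))} clusters\")" |
| | + ] |
| | + }, |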
|
75 | 118 | {
|
76 | 119 | "cell_type": "markdown",
|
77 | 120 | "metadata": {},
|
78 | 121 | "source": [
|
79 |
| - "Next, we manually pick a threshold by visual inspection of the dendrogram\nto group our features into clusters and choose a feature from each cluster to\nkeep, select those features from our dataset, and train a new random forest.\nThe test accuracy of the new random forest did not change much compared to\nthe random forest trained on the complete dataset.\n\n" |
| 122 | + "We can finally explore the permutation importance of the selected subset of\nfeatures:\n\n" |
80 | 123 | ]
|
81 | 124 | },
|
82 | 125 | {
|
|
87 | 130 | },
|
88 | 131 | "outputs": [],
|
89 | 132 | "source": [
|
90 |
| - "cluster_ids = hierarchy.fcluster(dist_linkage, 1, criterion=\"distance\")\ncluster_id_to_feature_ids = defaultdict(list)\nfor idx, cluster_id in enumerate(cluster_ids):\n cluster_id_to_feature_ids[cluster_id].append(idx)\nselected_features = [v[0] for v in cluster_id_to_feature_ids.values()]\n\nX_train_sel = X_train[:, selected_features]\nX_test_sel = X_test[:, selected_features]\n\nclf_sel = RandomForestClassifier(n_estimators=100, random_state=42)\nclf_sel.fit(X_train_sel, y_train)\nprint(\n \"Accuracy on test data with features removed: {:.2f}\".format(\n clf_sel.score(X_test_sel, y_test)\n )\n)" |
| 133 | + "fig, ax = plt.subplots(figsize=(7, 6))\nplot_permutation_importance(clf_sel, X_test_sel, y_test, ax)\nax.set_title(\"Permutation Importances on selected subset of features\\n(test set)\")\nax.set_xlabel(\"Decrease in accuracy score\")\nax.figure.tight_layout()\nplt.show()" |
91 | 134 | ]
|
92 | 135 | }
|
93 | 136 | ],
|
|