|
6 | 6 | Effect of transforming the targets in regression model
|
7 | 7 | ======================================================
|
8 | 8 |
|
9 |
| -In this example, we give an overview of the |
10 |
| -:class:`sklearn.compose.TransformedTargetRegressor`. Two examples |
11 |
| -illustrate the benefit of transforming the targets before learning a linear |
| 9 | +In this example, we give an overview of |
| 10 | +:class:`~sklearn.compose.TransformedTargetRegressor`. We use two examples |
| 11 | +to illustrate the benefit of transforming the targets before learning a linear |
12 | 12 | regression model. The first example uses synthetic data while the second
|
13 |
| -example is based on the Boston housing data set. |
14 |
| -
|
| 13 | +example is based on the Ames housing data set. |
15 | 14 | """
|
16 | 15 |
|
17 | 16 | # Author: Guillaume Lemaitre <[email protected]>
|
18 | 17 | # License: BSD 3 clause
|
19 | 18 |
|
20 |
| - |
21 | 19 | import numpy as np
|
22 | 20 | import matplotlib
|
23 | 21 | import matplotlib.pyplot as plt
|
24 | 22 | from distutils.version import LooseVersion
|
25 | 23 |
|
26 |
| -print(__doc__) |
27 |
| - |
28 |
| -############################################################################### |
29 |
| -# Synthetic example |
30 |
| -############################################################################### |
31 |
| - |
32 | 24 | from sklearn.datasets import make_regression
|
33 | 25 | from sklearn.model_selection import train_test_split
|
34 | 26 | from sklearn.linear_model import RidgeCV
|
35 | 27 | from sklearn.compose import TransformedTargetRegressor
|
36 | 28 | from sklearn.metrics import median_absolute_error, r2_score
|
37 | 29 |
|
| 30 | +############################################################################### |
| 31 | +# Synthetic example |
| 32 | +############################################################################## |
38 | 33 |
|
39 | 34 | # `normed` is being deprecated in favor of `density` in histograms
|
40 | 35 | if LooseVersion(matplotlib.__version__) >= '2.1':
|
|
43 | 38 | density_param = {'normed': True}
|
44 | 39 |
|
45 | 40 | ###############################################################################
|
46 |
| -# A synthetic random regression problem is generated. The targets ``y`` are |
47 |
| -# modified by: (i) translating all targets such that all entries are |
48 |
| -# non-negative and (ii) applying an exponential function to obtain non-linear |
49 |
| -# targets which cannot be fitted using a simple linear model. |
| 41 | +# A synthetic random regression dataset is generated. The targets ``y`` are |
| 42 | +# modified by: |
| 43 | +# |
| 44 | +# 1. translating all targets such that all entries are |
| 45 | +# non-negative (by adding the absolute value of the lowest ``y``) and |
| 46 | +# 2. applying an exponential function to obtain non-linear |
| 47 | +# targets which cannot be fitted using a simple linear model. |
50 | 48 | #
|
51 | 49 | # Therefore, a logarithmic (`np.log1p`) and an exponential function
|
52 | 50 | # (`np.expm1`) will be used to transform the targets before training a linear
|
53 | 51 | # regression model and using it for prediction.
|
54 | 52 |
|
55 | 53 | X, y = make_regression(n_samples=10000, noise=100, random_state=0)
|
56 |
| -y = np.exp((y + abs(y.min())) / 200) |
| 54 | +y = np.expm1((y + abs(y.min())) / 200) |
57 | 55 | y_trans = np.log1p(y)
|
58 | 56 |
|
59 | 57 | ###############################################################################
|
60 |
| -# The following illustrate the probability density functions of the target |
| 58 | +# Below we plot the probability density functions of the target |
61 | 59 | # before and after applying the logarithmic functions.
|
62 | 60 |
|
63 | 61 | f, (ax0, ax1) = plt.subplots(1, 2)
|
|
73 | 71 | ax1.set_xlabel('Target')
|
74 | 72 | ax1.set_title('Transformed target distribution')
|
75 | 73 |
|
76 |
| -f.suptitle("Synthetic data", y=0.035) |
| 74 | +f.suptitle("Synthetic data", y=0.06, x=0.53) |
77 | 75 | f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
|
78 | 76 |
|
79 | 77 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
|
80 | 78 |
|
81 | 79 | ###############################################################################
|
82 | 80 | # At first, a linear model will be applied on the original targets. Due to the
|
83 |
| -# non-linearity, the model trained will not be precise during the |
| 81 | +# non-linearity, the model trained will not be precise during |
84 | 82 | # prediction. Subsequently, a logarithmic function is used to linearize the
|
85 | 83 | # targets, allowing better prediction even with a similar linear model as
|
86 | 84 | # reported by the median absolute error (MAE).
|
87 | 85 |
|
88 | 86 | f, (ax0, ax1) = plt.subplots(1, 2, sharey=True)
|
89 |
| - |
| 87 | +# Use linear model |
90 | 88 | regr = RidgeCV()
|
91 | 89 | regr.fit(X_train, y_train)
|
92 | 90 | y_pred = regr.predict(X_test)
|
93 |
| - |
| 91 | +# Plot results |
94 | 92 | ax0.scatter(y_test, y_pred)
|
95 | 93 | ax0.plot([0, 2000], [0, 2000], '--k')
|
96 | 94 | ax0.set_ylabel('Target predicted')
|
|
100 | 98 | r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
|
101 | 99 | ax0.set_xlim([0, 2000])
|
102 | 100 | ax0.set_ylim([0, 2000])
|
103 |
| - |
| 101 | +# Transform targets and use same linear model |
104 | 102 | regr_trans = TransformedTargetRegressor(regressor=RidgeCV(),
|
105 | 103 | func=np.log1p,
|
106 | 104 | inverse_func=np.expm1)
|
|
125 | 123 | ###############################################################################
|
126 | 124 |
|
127 | 125 | ###############################################################################
|
128 |
| -# In a similar manner, the boston housing data set is used to show the impact |
| 126 | +# In a similar manner, the Ames housing data set is used to show the impact |
129 | 127 | # of transforming the targets before learning a model. In this example, the
|
130 |
| -# targets to be predicted corresponds to the weighted distances to the five |
131 |
| -# Boston employment centers. |
| 128 | +# target to be predicted is the selling price of each house. |
132 | 129 |
|
133 |
| -from sklearn.datasets import load_boston |
| 130 | +from sklearn.datasets import fetch_openml |
134 | 131 | from sklearn.preprocessing import QuantileTransformer, quantile_transform
|
135 | 132 |
|
136 |
| -dataset = load_boston() |
137 |
| -target = np.array(dataset.feature_names) == "DIS" |
138 |
| -X = dataset.data[:, np.logical_not(target)] |
139 |
| -y = dataset.data[:, target].squeeze() |
140 |
| -y_trans = quantile_transform(dataset.data[:, target], |
141 |
| - n_quantiles=300, |
| 133 | +ames = fetch_openml(name="house_prices", as_frame=True) |
| 134 | +# Keep only numeric columns |
| 135 | +X = ames.data.select_dtypes(np.number) |
| 136 | +# Remove columns with NaN or Inf values |
| 137 | +X = X.drop(columns=['LotFrontage', 'GarageYrBlt', 'MasVnrArea']) |
| 138 | +y = ames.target |
| 139 | +y_trans = quantile_transform(y.to_frame(), |
| 140 | + n_quantiles=900, |
142 | 141 | output_distribution='normal',
|
143 | 142 | copy=True).squeeze()
|
144 | 143 |
|
145 | 144 | ###############################################################################
|
146 |
| -# A :class:`sklearn.preprocessing.QuantileTransformer` is used such that the |
147 |
| -# targets follows a normal distribution before applying a |
148 |
| -# :class:`sklearn.linear_model.RidgeCV` model. |
| 145 | +# A :class:`~sklearn.preprocessing.QuantileTransformer` is used to normalize |
| 146 | +# the target distribution before applying a |
| 147 | +# :class:`~sklearn.linear_model.RidgeCV` model. |
149 | 148 |
|
150 | 149 | f, (ax0, ax1) = plt.subplots(1, 2)
|
151 | 150 |
|
152 | 151 | ax0.hist(y, bins=100, **density_param)
|
153 | 152 | ax0.set_ylabel('Probability')
|
154 | 153 | ax0.set_xlabel('Target')
|
155 |
| -ax0.set_title('Target distribution') |
| 154 | +ax0.text(s='Target distribution', x=1.2e5, y=9.8e-6, fontsize=12) |
| 155 | +ax0.ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) |
156 | 156 |
|
157 | 157 | ax1.hist(y_trans, bins=100, **density_param)
|
158 | 158 | ax1.set_ylabel('Probability')
|
159 | 159 | ax1.set_xlabel('Target')
|
160 |
| -ax1.set_title('Transformed target distribution') |
| 160 | +ax1.text(s='Transformed target distribution', x=-6.8, y=0.479, fontsize=12) |
161 | 161 |
|
162 |
| -f.suptitle("Boston housing data: distance to employment centers", y=0.035) |
| 162 | +f.suptitle("Ames housing data: selling price", y=0.04) |
163 | 163 | f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95])
|
164 | 164 |
|
165 | 165 | X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=1)
|
166 | 166 |
|
167 | 167 | ###############################################################################
|
168 | 168 | # The effect of the transformer is weaker than on the synthetic data. However,
|
169 |
| -# the transform induces a decrease of the MAE. |
| 169 | +# the transformation results in an increase in :math:`R^2` and large decrease |
| 170 | +# of the MAE. The residual plot (predicted target - true target vs predicted |
| 171 | +# target) without target transformation takes on a curved, 'reverse smile' |
| 172 | +# shape due to residual values that vary depending on the value of predicted |
| 173 | +# target. With target transformation, the shape is more linear indicating |
| 174 | +# better model fit. |
170 | 175 |
|
171 |
| -f, (ax0, ax1) = plt.subplots(1, 2, sharey=True) |
| 176 | +f, (ax0, ax1) = plt.subplots(2, 2, sharey='row', figsize=(6.5, 8)) |
172 | 177 |
|
173 | 178 | regr = RidgeCV()
|
174 | 179 | regr.fit(X_train, y_train)
|
175 | 180 | y_pred = regr.predict(X_test)
|
176 | 181 |
|
177 |
| -ax0.scatter(y_test, y_pred) |
178 |
| -ax0.plot([0, 10], [0, 10], '--k') |
179 |
| -ax0.set_ylabel('Target predicted') |
180 |
| -ax0.set_xlabel('True Target') |
181 |
| -ax0.set_title('Ridge regression \n without target transformation') |
182 |
| -ax0.text(1, 9, r'$R^2$=%.2f, MAE=%.2f' % ( |
| 182 | +ax0[0].scatter(y_pred, y_test, s=8) |
| 183 | +ax0[0].plot([0, 7e5], [0, 7e5], '--k') |
| 184 | +ax0[0].set_ylabel('True target') |
| 185 | +ax0[0].set_xlabel('Predicted target') |
| 186 | +ax0[0].text(s='Ridge regression \n without target transformation', x=-5e4, |
| 187 | + y=8e5, fontsize=12, multialignment='center') |
| 188 | +ax0[0].text(3e4, 64e4, r'$R^2$=%.2f, MAE=%.2f' % ( |
183 | 189 | r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
|
184 |
| -ax0.set_xlim([0, 10]) |
185 |
| -ax0.set_ylim([0, 10]) |
| 190 | +ax0[0].set_xlim([0, 7e5]) |
| 191 | +ax0[0].set_ylim([0, 7e5]) |
| 192 | +ax0[0].ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) |
| 193 | + |
| 194 | +ax1[0].scatter(y_pred, (y_pred - y_test), s=8) |
| 195 | +ax1[0].set_ylabel('Residual') |
| 196 | +ax1[0].set_xlabel('Predicted target') |
| 197 | +ax1[0].ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) |
186 | 198 |
|
187 | 199 | regr_trans = TransformedTargetRegressor(
|
188 | 200 | regressor=RidgeCV(),
|
189 |
| - transformer=QuantileTransformer(n_quantiles=300, |
| 201 | + transformer=QuantileTransformer(n_quantiles=900, |
190 | 202 | output_distribution='normal'))
|
191 | 203 | regr_trans.fit(X_train, y_train)
|
192 | 204 | y_pred = regr_trans.predict(X_test)
|
193 | 205 |
|
194 |
| -ax1.scatter(y_test, y_pred) |
195 |
| -ax1.plot([0, 10], [0, 10], '--k') |
196 |
| -ax1.set_ylabel('Target predicted') |
197 |
| -ax1.set_xlabel('True Target') |
198 |
| -ax1.set_title('Ridge regression \n with target transformation') |
199 |
| -ax1.text(1, 9, r'$R^2$=%.2f, MAE=%.2f' % ( |
| 206 | +ax0[1].scatter(y_pred, y_test, s=8) |
| 207 | +ax0[1].plot([0, 7e5], [0, 7e5], '--k') |
| 208 | +ax0[1].set_ylabel('True target') |
| 209 | +ax0[1].set_xlabel('Predicted target') |
| 210 | +ax0[1].text(s='Ridge regression \n with target transformation', x=-5e4, |
| 211 | + y=8e5, fontsize=12, multialignment='center') |
| 212 | +ax0[1].text(3e4, 64e4, r'$R^2$=%.2f, MAE=%.2f' % ( |
200 | 213 | r2_score(y_test, y_pred), median_absolute_error(y_test, y_pred)))
|
201 |
| -ax1.set_xlim([0, 10]) |
202 |
| -ax1.set_ylim([0, 10]) |
| 214 | +ax0[1].set_xlim([0, 7e5]) |
| 215 | +ax0[1].set_ylim([0, 7e5]) |
| 216 | +ax0[1].ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) |
203 | 217 |
|
204 |
| -f.suptitle("Boston housing data: distance to employment centers", y=0.035) |
205 |
| -f.tight_layout(rect=[0.05, 0.05, 0.95, 0.95]) |
| 218 | +ax1[1].scatter(y_pred, (y_pred - y_test), s=8) |
| 219 | +ax1[1].set_ylabel('Residual') |
| 220 | +ax1[1].set_xlabel('Predicted target') |
| 221 | +ax1[1].ticklabel_format(axis="both", style="sci", scilimits=(0, 0)) |
| 222 | + |
| 223 | +f.suptitle("Ames housing data: selling price", y=0.035) |
206 | 224 |
|
207 | 225 | plt.show()
|
0 commit comments