Skip to content

Commit c39227e

Browse files
xhluluxhlulu
authored andcommitted
ML Docs: Updated last ML regression section for clarity
1 parent 5c95a2a commit c39227e

File tree

1 file changed

+43
-27
lines changed

1 file changed

+43
-27
lines changed

doc/python/ml-regression.md

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -213,16 +213,16 @@ fig.show()
213213
## Prediction Error Plots
214214

215215

216-
### Simple Prediction Error
216+
### Simple actual vs predicted plot
217217

218218
```python
219219
import plotly.express as px
220220
import plotly.graph_objects as go
221221
from sklearn.linear_model import LinearRegression
222222

223223
df = px.data.iris()
224-
X = df.loc[train_idx, ['sepal_width', 'sepal_length']]
225-
y = df.loc[train_idx, 'petal_width']
224+
X = df[['sepal_width', 'sepal_length']]
225+
y = df['petal_width']
226226

227227
# Condition the model on sepal width and length, predict the petal width
228228
model = LinearRegression()
@@ -238,7 +238,7 @@ fig.add_shape(
238238
fig.show()
239239
```
240240

241-
### Augmented Prediction Error analysis using `plotly.express`
241+
### Augmented prediction error analysis using `plotly.express`
242242

243243
```python
244244
import plotly.express as px
@@ -276,7 +276,7 @@ fig.add_shape(
276276
fig.show()
277277
```
278278

279-
## Residual Plots
279+
## Residual plots
280280

281281
Just like prediction error plots, it's easy to visualize your prediction residuals in just a few lines of codes using `plotly.express` built-in capabilities.
282282

@@ -312,28 +312,34 @@ fig = px.scatter(
312312
fig.show()
313313
```
314314

315-
## Grid Search Visualization using `px` facets
315+
## Grid search visualization using `px.density_heatmap` and `px.box`
316+
317+
In this example, we show how to visualize the results of a grid search on a `DecisionTreeRegressor`. The first plot shows how to visualize the score of each model parameter on individual splits (grouped using facets). The second plot aggregates the results of all splits such that each box represents a single model.
316318

317319
```python
320+
import numpy as np
318321
import pandas as pd
319322
import plotly.express as px
320323
import plotly.graph_objects as go
321324
from sklearn.model_selection import GridSearchCV
322325
from sklearn.tree import DecisionTreeRegressor
323326

324-
N_FOLD = 5
327+
N_FOLD = 6
325328

329+
# Load and shuffle dataframe
326330
df = px.data.iris()
327-
X = df.loc[train_idx, ['sepal_width', 'sepal_length']]
328-
y = df.loc[train_idx, 'petal_width']
331+
df = df.sample(frac=1, random_state=0)
332+
333+
X = df[['sepal_width', 'sepal_length']]
334+
y = df['petal_width']
329335

336+
# Define and fit the grid
330337
model = DecisionTreeRegressor()
331338
param_grid = {
332339
'criterion': ['mse', 'friedman_mse', 'mae'],
333340
'max_depth': range(2, 5)
334341
}
335342
grid = GridSearchCV(model, param_grid, cv=N_FOLD)
336-
337343
grid.fit(X, y)
338344
grid_df = pd.DataFrame(grid.cv_results_)
339345

@@ -344,32 +350,42 @@ melted = (
344350
.rename(columns=lambda col: col.replace('param_', ''))
345351
.melt(
346352
value_vars=[f'split{i}_test_score' for i in range(N_FOLD)],
347-
id_vars=['rank_test_score', 'mean_test_score',
348-
'mean_fit_time', 'criterion', 'max_depth']
353+
id_vars=['mean_test_score', 'mean_fit_time', 'criterion', 'max_depth'],
354+
var_name="cv_split",
355+
value_name="r_squared"
349356
)
350357
)
351358

352-
# Convert R-Squared measure to %
353-
melted[['value', 'mean_test_score']] *= 100
354-
355359
# Format the variable names for simplicity
356-
melted['variable'] = (
357-
melted['variable']
360+
melted['cv_split'] = (
361+
melted['cv_split']
358362
.str.replace('_test_score', '')
359363
.str.replace('split', '')
360364
)
361365

362-
px.bar(
363-
melted, x='variable', y='value',
364-
color='mean_test_score',
365-
facet_row='max_depth',
366-
facet_col='criterion',
367-
title='Test Scores of Grid Search',
368-
hover_data=['mean_fit_time', 'rank_test_score'],
369-
labels={'variable': 'cv_split',
370-
'value': 'r_squared',
371-
'mean_test_score': "mean_r_squared"}
366+
# Single function call to plot each figure
367+
fig_hmap = px.density_heatmap(
368+
melted, x="max_depth", y='criterion',
369+
histfunc="sum", z="r_squared",
370+
title='Grid search results on individual fold',
371+
hover_data=['mean_fit_time'],
372+
facet_col="cv_split", facet_col_wrap=3,
373+
labels={'mean_test_score': "mean_r_squared"}
372374
)
375+
376+
fig_box = px.box(
377+
melted, x='max_depth', y='r_squared',
378+
title='Grid search results ',
379+
hover_data=['mean_fit_time'],
380+
points='all',
381+
color="criterion",
382+
hover_name='cv_split',
383+
labels={'mean_test_score': "mean_r_squared"}
384+
)
385+
386+
# Display
387+
fig_hmap.show()
388+
fig_box.show()
373389
```
374390

375391
### Reference

0 commit comments

Comments
 (0)