Commit e92d340

Author: xhlulu
ML Docs: Added annotations after each section of regression notebook
1 parent cf42003 commit e92d340

1 file changed: +18 -7 lines changed

doc/python/ml-regression.md

Lines changed: 18 additions & 7 deletions
@@ -78,6 +78,8 @@ fig.show()
 
 ## Model generalization on unseen data
 
+Easily color your plot based on a predefined data split.
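
As a rough illustration of what a "predefined data split" could look like before it is used for coloring, here is a minimal sketch that uses NumPy only. The helper name `split_labels` and the toy indices are hypothetical and are not part of the notebook.

```python
import numpy as np


def split_labels(n_rows, test_idx):
    """Return one 'train' or 'test' label per row (hypothetical helper)."""
    labels = np.full(n_rows, "train", dtype=object)
    # Rows whose positions appear in test_idx are marked as test samples.
    labels[np.asarray(test_idx, dtype=int)] = "test"
    return labels


# Toy example: 6 rows, rows 1 and 4 held out for testing.
labels = split_labels(6, [1, 4])
print(labels.tolist())
```

A label array like this could then be supplied as the `color` argument of `px.scatter`, much as the notebook later uses a `'split'` column.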
+
 ```python
 import numpy as np
 import plotly.express as px
@@ -106,6 +108,8 @@ fig.show()
 
 ## Comparing different kNN models parameters
 
+Compare the performance of two different models on the same dataset. This can be easily combined with discrete color legends from `px`.
+
 ```python
 import numpy as np
 import plotly.express as px
@@ -114,14 +118,16 @@ from sklearn.neighbors import KNeighborsRegressor
 
 df = px.data.tips()
 X = df.total_bill.values.reshape(-1, 1)
+x_range = np.linspace(X.min(), X.max(), 100)
 
+# Model #1
 knn_dist = KNeighborsRegressor(10, weights='distance')
-knn_uni = KNeighborsRegressor(10, weights='uniform')
 knn_dist.fit(X, df.tip)
-knn_uni.fit(X, df.tip)
-
-x_range = np.linspace(X.min(), X.max(), 100)
 y_dist = knn_dist.predict(x_range.reshape(-1, 1))
+
+# Model #2
+knn_uni = KNeighborsRegressor(10, weights='uniform')
+knn_uni.fit(X, df.tip)
 y_uni = knn_uni.predict(x_range.reshape(-1, 1))
 
 fig = px.scatter(df, x='total_bill', y='tip', color='sex', opacity=0.65)
@@ -132,6 +138,8 @@ fig.show()
 
 ## 3D regression surface with `px.scatter_3d` and `go.Surface`
 
+Visualize the decision plane of your model whenever you have more than one variable in your `X`.
+
 ```python
 import numpy as np
 import plotly.express as px
@@ -229,7 +237,7 @@ model = LinearRegression()
 model.fit(X, y)
 y_pred = model.predict(X)
 
-fig = px.scatter(x=y, y=y_pred, labels={'x': 'y true', 'y': 'y pred'})
+fig = px.scatter(x=y_pred, y=y, labels={'x': 'prediction', 'y': 'actual'})
 fig.add_shape(
     type="line", line=dict(dash='dash'),
     x0=y.min(), y0=y.min(),
@@ -238,7 +246,9 @@ fig.add_shape(
 fig.show()
 ```
 
-### Augmented prediction error analysis using `plotly.express`
+### Enhanced prediction error analysis using `plotly.express`
+
+Add marginal histograms to quickly diagnose any prediction bias your model might have. The built-in `OLS` functionality lets you visualize how well your model generalizes by comparing it with the theoretical optimal fit (black dotted line).
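
The comparison described above can be sketched numerically: fit a least-squares line of the actual values against the predictions and see how far it departs from the ideal line `y = x` (slope 1, intercept 0). This is a minimal sketch under stated assumptions: the function name `calibration_line` and the synthetic values are hypothetical, and it uses `np.polyfit` for the fit, which is not necessarily what `plotly.express` does internally.

```python
import numpy as np


def calibration_line(predicted, actual):
    """Fit actual ~ slope * predicted + intercept by ordinary least squares."""
    slope, intercept = np.polyfit(predicted, actual, deg=1)
    return float(slope), float(intercept)


# Synthetic values for illustration only (not taken from the notebook).
actual = np.linspace(0.1, 2.5, 50)
unbiased = actual.copy()   # hypothetical model with perfect predictions
shifted = actual + 0.3     # hypothetical model that over-predicts by 0.3

# A fit close to slope 1 and intercept 0 matches the ideal line y = x.
print(calibration_line(unbiased, actual))
# A constant over-prediction shows up as a non-zero intercept.
print(calibration_line(shifted, actual))
```

In this toy setup the second fit should keep a slope near 1 but have an intercept near -0.3, which is the kind of systematic bias the trendline and marginal histograms are meant to make visible.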
 
 ```python
 import plotly.express as px
@@ -254,6 +264,7 @@ df['split'] = 'train'
 df.loc[test_idx, 'split'] = 'test'
 
 X = df[['sepal_width', 'sepal_length']]
+y = df['petal_width']
 X_train = df.loc[train_idx, ['sepal_width', 'sepal_length']]
 y_train = df.loc[train_idx, 'petal_width']
 
@@ -263,7 +274,7 @@ model.fit(X_train, y_train)
 df['prediction'] = model.predict(X)
 
 fig = px.scatter(
-    df, x='petal_width', y='prediction',
+    df, x='prediction', y='petal_width',
     marginal_x='histogram', marginal_y='histogram',
     color='split', trendline='ols'
 )
