Skip to content

Commit 0e8b5d6

Browse files
xhluxhlulu
authored and committed
ML Docs: Update knn and regression based on Emma's reviews
1 parent 7ab73cb commit 0e8b5d6

File tree

2 files changed

+14
-13
lines changed

2 files changed

+14
-13
lines changed

doc/python/ml-knn.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ fig = px.scatter(
105105
X_test, x=0, y=1,
106106
color=y_score, color_continuous_scale='RdBu',
107107
symbol=y_test, symbol_map={'0': 'square-dot', '1': 'circle-dot'},
108-
labels={'symbol': 'Label', 'color': 'Score'}
108+
labels={'symbol': 'label', 'color': 'score of <br>first class'}
109109
)
110110
fig.update_traces(marker_size=12, marker_line_width=1.5)
111111
fig.update_layout(legend_orientation='h')

doc/python/ml-regression.md

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ jupyter:
3939
### Ordinary Least Square (OLS) with `plotly.express`
4040

4141

42-
This example shows how to use `plotly.express`'s `trendline` parameter to train a simply Ordinary Least Square (OLS) for predicting the tips servers will receive based on the value of the total bill.
42+
This example shows how to use `plotly.express`'s `trendline` parameter to train a simple Ordinary Least Squares (OLS) model for predicting the tips waiters will receive based on the value of the total bill.
4343

4444
```python
4545
import plotly.express as px
@@ -88,7 +88,7 @@ from sklearn.linear_model import LinearRegression
8888
from sklearn.model_selection import train_test_split
8989

9090
df = px.data.tips()
91-
X = df.total_bill.values.reshape(-1, 1)
91+
X = df.total_bill[:, None]
9292
X_train, X_test, y_train, y_test = train_test_split(X, df.tip, random_state=0)
9393

9494
model = LinearRegression()
@@ -162,8 +162,8 @@ X = df.total_bill.values.reshape(-1, 1)
162162
x_range = np.linspace(X.min(), X.max(), 100).reshape(-1, 1)
163163

164164
fig = px.scatter(df, x='total_bill', y='tip', opacity=0.65)
165-
for n_features in [1, 2, 3, 4]:
166-
poly = PolynomialFeatures(n_features)
165+
for degree in [1, 2, 3, 4]:
166+
poly = PolynomialFeatures(degree)
167167
poly.fit(X)
168168
X_poly = poly.transform(X)
169169
x_range_poly = poly.transform(x_range)
@@ -180,13 +180,13 @@ fig.show()
180180

181181
## 3D regression surface with `px.scatter_3d` and `go.Surface`
182182

183-
Visualize the decision plane of your model whenever you have more than one variable in your input data.
183+
Visualize the decision plane of your model whenever you have more than one variable in your input data. Here, we will use [`sklearn.svm.SVR`](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVR.html), which is a Support Vector Machine (SVM) model specifically designed for regression.
184184

185185
```python
186186
import numpy as np
187187
import plotly.express as px
188188
import plotly.graph_objects as go
189-
from sklearn.neighbors import KNeighborsRegressor
189+
from sklearn.svm import SVR
190190

191191
mesh_size = .02
192192
margin = 0
@@ -197,8 +197,8 @@ X = df[['sepal_width', 'sepal_length']]
197197
y = df['petal_width']
198198

199199
# Condition the model on sepal width and length, predict the petal width
200-
knn = KNeighborsRegressor(10, weights='distance')
201-
knn.fit(X, y)
200+
model = SVR(C=1.)
201+
model.fit(X, y)
202202

203203
# Create a mesh grid on which we will run our model
204204
x_min, x_max = X.sepal_width.min() - margin, X.sepal_width.max() + margin
@@ -207,8 +207,8 @@ xrange = np.arange(x_min, x_max, mesh_size)
207207
yrange = np.arange(y_min, y_max, mesh_size)
208208
xx, yy = np.meshgrid(xrange, yrange)
209209

210-
# Run kNN
211-
pred = knn.predict(np.c_[xx.ravel(), yy.ravel()])
210+
# Run model
211+
pred = model.predict(np.c_[xx.ravel(), yy.ravel()])
212212
pred = pred.reshape(xx.shape)
213213

214214
# Generate the plot
@@ -271,7 +271,7 @@ model = LinearRegression()
271271
model.fit(X, y)
272272
y_pred = model.predict(X)
273273

274-
fig = px.scatter(x=y_pred, y=y, labels={'x': 'prediction', 'y': 'actual'})
274+
fig = px.scatter(x=y, y=y_pred, labels={'x': 'ground truth', 'y': 'prediction'})
275275
fig.add_shape(
276276
type="line", line=dict(dash='dash'),
277277
x0=y.min(), y0=y.min(),
@@ -308,10 +308,11 @@ model.fit(X_train, y_train)
308308
df['prediction'] = model.predict(X)
309309

310310
fig = px.scatter(
311-
df, x='prediction', y='petal_width',
311+
df, x='petal_width', y='prediction',
312312
marginal_x='histogram', marginal_y='histogram',
313313
color='split', trendline='ols'
314314
)
315+
fig.update_traces(histnorm='probability', selector={'type':'histogram'})
315316
fig.add_shape(
316317
type="line", line=dict(dash='dash'),
317318
x0=y.min(), y0=y.min(),

0 commit comments

Comments
 (0)