@@ -78,6 +78,8 @@ fig.show()
78
78
79
79
## Model generalization on unseen data
80
80
81
+ Easily color your plot based on a predefined data split.
82
+
81
83
``` python
82
84
import numpy as np
83
85
import plotly.express as px
@@ -106,6 +108,8 @@ fig.show()
106
108
107
109
## Comparing different kNN models parameters
108
110
111
+ Compare the performance of two different models on the same dataset. This can be easily combined with discrete color legends from ` px ` .
112
+
109
113
``` python
110
114
import numpy as np
111
115
import plotly.express as px
@@ -114,14 +118,16 @@ from sklearn.neighbors import KNeighborsRegressor
114
118
115
119
df = px.data.tips()
116
120
X = df.total_bill.values.reshape(- 1 , 1 )
121
+ x_range = np.linspace(X.min(), X.max(), 100 )
117
122
123
+ # Model #1
118
124
knn_dist = KNeighborsRegressor(10 , weights = ' distance' )
119
- knn_uni = KNeighborsRegressor(10 , weights = ' uniform' )
120
125
knn_dist.fit(X, df.tip)
121
- knn_uni.fit(X, df.tip)
122
-
123
- x_range = np.linspace(X.min(), X.max(), 100 )
124
126
y_dist = knn_dist.predict(x_range.reshape(- 1 , 1 ))
127
+
128
+ # Model #2
129
+ knn_uni = KNeighborsRegressor(10 , weights = ' uniform' )
130
+ knn_uni.fit(X, df.tip)
125
131
y_uni = knn_uni.predict(x_range.reshape(- 1 , 1 ))
126
132
127
133
fig = px.scatter(df, x = ' total_bill' , y = ' tip' , color = ' sex' , opacity = 0.65 )
@@ -132,6 +138,8 @@ fig.show()
132
138
133
139
## 3D regression surface with ` px.scatter_3d ` and ` go.Surface `
134
140
141
+ Visualize the regression surface of your model whenever you have more than one variable in your ` X ` .
142
+
135
143
``` python
136
144
import numpy as np
137
145
import plotly.express as px
@@ -229,7 +237,7 @@ model = LinearRegression()
229
237
model.fit(X, y)
230
238
y_pred = model.predict(X)
231
239
232
- fig = px.scatter(x = y , y = y_pred , labels = {' x' : ' y true ' , ' y' : ' y pred ' })
240
+ fig = px.scatter(x = y_pred , y = y , labels = {' x' : ' prediction ' , ' y' : ' actual ' })
233
241
fig.add_shape(
234
242
type = " line" , line = dict (dash = ' dash' ),
235
243
x0 = y.min(), y0 = y.min(),
@@ -238,7 +246,9 @@ fig.add_shape(
238
246
fig.show()
239
247
```
240
248
241
- ### Augmented prediction error analysis using ` plotly.express `
249
+ ### Enhanced prediction error analysis using ` plotly.express `
250
+
251
+ Add marginal histograms to quickly diagnose any prediction bias your model might have. The built-in ` OLS ` functionality lets you visualize how well your model generalizes by comparing it with the theoretical optimal fit (black dotted line).
242
252
243
253
``` python
244
254
import plotly.express as px
@@ -254,6 +264,7 @@ df['split'] = 'train'
254
264
df.loc[test_idx, ' split' ] = ' test'
255
265
256
266
X = df[[' sepal_width' , ' sepal_length' ]]
267
+ y = df[' petal_width' ]
257
268
X_train = df.loc[train_idx, [' sepal_width' , ' sepal_length' ]]
258
269
y_train = df.loc[train_idx, ' petal_width' ]
259
270
@@ -263,7 +274,7 @@ model.fit(X_train, y_train)
263
274
df[' prediction' ] = model.predict(X)
264
275
265
276
fig = px.scatter(
266
- df, x = ' petal_width ' , y = ' prediction ' ,
277
+ df, x = ' prediction ' , y = ' petal_width ' ,
267
278
marginal_x = ' histogram' , marginal_y = ' histogram' ,
268
279
color = ' split' , trendline = ' ols'
269
280
)
0 commit comments