---
jupyter:
  jupytext:
+   formats: ipynb,md
    notebook_metadata_filter: all
    text_representation:
      extension: .md
@@ -20,27 +21,64 @@ jupyter:
    name: python
    nbconvert_exporter: python
    pygments_lexer: ipython3
-   version: 3.6.10
+   version: 3.7.6
  plotly:
    description: How to visualize k-Nearest Neighbors (kNN) created using scikit-learn
      in Python with Plotly.
    display_as: basic
    language: python
    layout: base
-   name: k-Nearest Neighbors
+   name: K-Nearest Neighbors (kNN) Classification
    order: 1
    page_type: example_index
    permalink: python/knn/
    redirect_from: python/machine-learning-tutorials/
    thumbnail: thumbnail/line-and-scatter.jpg
---
- ## K-Nearest Neighbors (kNN) Classification
+ ## Basic Binary Classification with `plotly.express`

- How to visualize K-Nearest Neighbors (kNN) classification using scikit-learn.
+ ```python
+ import numpy as np
+ import plotly.express as px
+ import plotly.graph_objects as go
+ from sklearn.datasets import make_moons
+ from sklearn.neighbors import KNeighborsClassifier
+
+ X, y = make_moons(noise=0.3, random_state=0)
+ X_test, _ = make_moons(noise=0.3, random_state=1)
+
+ clf = KNeighborsClassifier(15)
+ clf.fit(X, y.astype(str))  # Fit on training set (string labels give a discrete color scale)
+ y_pred = clf.predict(X_test)  # Predict on new data
+
+ fig = px.scatter(x=X_test[:, 0], y=X_test[:, 1], color=y_pred, labels={'color': 'predicted'})
+ fig.update_traces(marker_size=10)
+ fig.show()
+ ```
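For reference, a minimal sketch of one possible follow-up, assuming `fig`, `X`, `y`, and `go` from the block above are still in scope: overlay the training points as a second trace so the predictions on `X_test` can be compared against the data the classifier was fit on.

```python
# Sketch: overlay the training data on the prediction figure from the previous cell.
# Assumes fig, X, y and go are still defined.
fig.add_trace(
    go.Scatter(
        x=X[:, 0], y=X[:, 1],
        mode='markers',
        marker_symbol='x',
        marker_color='black',
        name='training data'
    )
)
fig.show()
```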
+ ## Visualize Binary Prediction Scores
+
+ ```python
+ import numpy as np
+ import plotly.express as px
+ import plotly.graph_objects as go
+ from sklearn.datasets import make_classification
+ from sklearn.neighbors import KNeighborsClassifier
+
+ X, y = make_classification(n_features=2, n_redundant=0, random_state=0)
+ X_test, _ = make_classification(n_features=2, n_redundant=0, random_state=1)
+
+ clf = KNeighborsClassifier(15)
+ clf.fit(X, y)  # Fit on training set
+ y_score = clf.predict_proba(X_test)[:, 1]  # Predicted probability of class 1 on new data
+
+ fig = px.scatter(x=X_test[:, 0], y=X_test[:, 1], color=y_score, labels={'color': 'score'})
+ fig.update_traces(marker_size=10)
+ fig.show()
+ ```
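Since `y_score` is a class-1 probability, one optional variation (a sketch reusing `X_test` and `y_score` from above) is to use a diverging colorscale centered on the 0.5 decision threshold, so points on either side of the boundary separate visually.

```python
# Sketch: same scatter as above with a diverging colorscale centered at 0.5.
# Assumes X_test and y_score are still defined from the previous cell.
fig = px.scatter(
    x=X_test[:, 0], y=X_test[:, 1],
    color=y_score,
    color_continuous_scale='RdBu',
    color_continuous_midpoint=0.5,
    labels={'color': 'score'}
)
fig.update_traces(marker_size=10)
fig.show()
```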
- ### Binary Probability Estimates with `go.Contour`
+ ## Probability Estimates with `go.Contour`

```python
import numpy as np
@@ -68,20 +106,22 @@ Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
Z = Z.reshape(xx.shape)

fig = px.scatter(X, x=0, y=1, color=y.astype(str), labels={'0': '', '1': ''})
+ fig.update_traces(marker_size=10, marker_line_width=1)
fig.add_trace(
    go.Contour(
        x=xrange,
        y=yrange,
        z=Z,
        showscale=False,
        colorscale=['Blue', 'Red'],
-       opacity=0.4
+       opacity=0.4,
+       name='Confidence'
    )
)
fig.show()
```
- ### Multi-class classification with `px.data` and `go.Heatmap`
+ ## Multi-class prediction confidence with `go.Heatmap`

```python
import numpy as np
@@ -92,6 +132,7 @@ from sklearn.neighbors import KNeighborsClassifier
mesh_size = .02
margin = 1

+ # We will use the iris data, which is included in px
df = px.data.iris()
X = df[['sepal_length', 'sepal_width']]
y = df.species_id
@@ -134,29 +175,66 @@ fig.add_trace(
fig.show()
```
- ### Visualizing kNN Regression
+ ## 3D Classification with `px.scatter_3d`
+
+ ```python
+ import numpy as np
+ import plotly.express as px
+ import plotly.graph_objects as go
+ from sklearn.neighbors import KNeighborsClassifier
+ from sklearn.model_selection import train_test_split
+
+ df = px.data.iris()
+ features = ["sepal_width", "sepal_length", "petal_width"]
+
+ X = df[features]
+ y = df.species
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
+
+ # Create classifier, run predictions on the test set
+ clf = KNeighborsClassifier(15, weights='distance')
+ clf.fit(X_train, y_train)
+ y_pred = clf.predict(X_test)
+ y_score = clf.predict_proba(X_test)
+ y_score = np.around(y_score.max(axis=1), 4)
+
+ fig = px.scatter_3d(
+     X_test,
+     x='sepal_length',
+     y='sepal_width',
+     z='petal_width',
+     symbol=y_pred,
+     color=y_score,
+     labels={'symbol': 'prediction', 'color': 'score'}
+ )
+ fig.update_layout(legend=dict(x=0, y=0))
+ fig.show()
+ ```
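If a quick quantitative read is useful next to the 3D view, a short sketch (reusing `fig`, `y_pred`, and `y_test` from the block above) can report the held-out accuracy in the figure title.

```python
# Sketch: show held-out accuracy in the figure title.
# Assumes fig, y_pred and y_test exist from the previous cell.
accuracy = (y_pred == y_test).mean()
fig.update_layout(title=f"kNN (k=15, distance weights), test accuracy: {accuracy:.2f}")
fig.show()
```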
+ ## High Dimension Visualization with `px.scatter_matrix`
+
+ If you need to visualize classifications that go beyond 3D, you can use the [scatter plot matrix](https://plot.ly/python/splom/).
```python
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
- from sklearn.neighbors import KNeighborsRegressor
+ from sklearn.neighbors import KNeighborsClassifier
+ from sklearn.model_selection import train_test_split

- df = px.data.tips()
- X = df.total_bill.values.reshape(-1, 1)
+ df = px.data.iris()
+ features = ["sepal_width", "sepal_length", "petal_width", "petal_length"]

- knn_dist = KNeighborsRegressor(10, weights='distance')
- knn_uni = KNeighborsRegressor(10, weights='uniform')
- knn_dist.fit(X, df.tip)
- knn_uni.fit(X, df.tip)
+ X = df[features]
+ y = df.species
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

- x_range = np.linspace(X.min(), X.max(), 100)
- y_dist = knn_dist.predict(x_range.reshape(-1, 1))
- y_uni = knn_uni.predict(x_range.reshape(-1, 1))
+ # Create classifier, run predictions on the test set
+ clf = KNeighborsClassifier(15, weights='distance')
+ clf.fit(X_train, y_train)
+ y_pred = clf.predict(X_test)

- fig = px.scatter(df, x='total_bill', y='tip', color='sex', opacity=0.65)
- fig.add_traces(go.Scatter(x=x_range, y=y_uni, name='Weights: Uniform'))
- fig.add_traces(go.Scatter(x=x_range, y=y_dist, name='Weights: Distance'))
+ fig = px.scatter_matrix(X_test, dimensions=features, color=y_pred, labels={'color': 'prediction'})
fig.show()
```
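A related sketch, assuming `X_test`, `y_test`, `y_pred`, and `features` from the block above: color the same scatter-plot matrix by whether each held-out sample was classified correctly, which highlights where in feature space the model struggles.

```python
# Sketch: color the scatter-plot matrix by prediction correctness.
# Assumes X_test, y_test, y_pred and features are still defined.
correct = np.where(y_pred == y_test, 'correct', 'misclassified')
fig = px.scatter_matrix(X_test, dimensions=features, color=correct,
                        labels={'color': 'result'})
fig.show()
```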
@@ -166,8 +244,10 @@ Learn more about `px`, `go.Contour`, and `go.Heatmap` here:
* https://plot.ly/python/plotly-express/
* https://plot.ly/python/heatmaps/
* https://plot.ly/python/contour-plots/
+ * https://plot.ly/python/3d-scatter-plots/
+ * https://plot.ly/python/splom/

This tutorial was inspired by amazing examples from the official scikit-learn docs:
- * https://scikit-learn.org/stable/auto_examples/neighbors/plot_regression.html
* https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
* https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html
+ * https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html