Skip to content

Commit 3e4b220

Browse files
xhluluxhlulu
authored andcommitted
Create 2 basic sections, 2 advanced sections
1 parent 38f92ca commit 3e4b220

File tree

1 file changed

+102
-22
lines changed

1 file changed

+102
-22
lines changed

doc/python/ml-knn.md

Lines changed: 102 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
---
22
jupyter:
33
jupytext:
4+
formats: ipynb,md
45
notebook_metadata_filter: all
56
text_representation:
67
extension: .md
@@ -20,27 +21,64 @@ jupyter:
2021
name: python
2122
nbconvert_exporter: python
2223
pygments_lexer: ipython3
23-
version: 3.6.10
24+
version: 3.7.6
2425
plotly:
2526
description: How to visualize k-Nearest Neighbors (kNN) created using scikit-learn
2627
in Python with Plotly.
2728
display_as: basic
2829
language: python
2930
layout: base
30-
name: k-Nearest Neighbors
31+
name: K-Nearest Neighbors (kNN) Classification
3132
order: 1
3233
page_type: example_index
3334
permalink: python/knn/
3435
redirect_from: python/machine-learning-tutorials/
3536
thumbnail: thumbnail/line-and-scatter.jpg
3637
---
3738

38-
## K-Nearest Neighbors (kNN) Classification
39+
## Basic Binary Classification with `plotly.express`
3940

40-
How to visualize K-Nearest Neighbors (kNN) classification using scikit-learn.
41+
```python
42+
import numpy as np
43+
import plotly.express as px
44+
import plotly.graph_objects as go
45+
from sklearn.datasets import make_moons
46+
from sklearn.neighbors import KNeighborsClassifier
47+
48+
X, y = make_moons(noise=0.3, random_state=0)
49+
X_test, _ = make_moons(noise=0.3, random_state=1)
50+
51+
clf = KNeighborsClassifier(15)
52+
clf.fit(X, y.astype(str)) # Fit on training set
53+
y_pred = clf.predict(X_test) # Predict on new data
54+
55+
fig = px.scatter(x=X_test[:, 0], y=X_test[:, 1], color=y_pred, labels={'color': 'predicted'})
56+
fig.update_traces(marker_size=10)
57+
fig.show()
58+
```
4159

60+
## Visualize Binary Prediction Scores
61+
62+
```python
63+
import numpy as np
64+
import plotly.express as px
65+
import plotly.graph_objects as go
66+
from sklearn.datasets import make_classification
67+
from sklearn.neighbors import KNeighborsClassifier
68+
69+
X, y = make_classification(n_features=2, n_redundant=0, random_state=0)
70+
X_test, _ = make_classification(n_features=2, n_redundant=0, random_state=1)
71+
72+
clf = KNeighborsClassifier(15)
73+
clf.fit(X, y) # Fit on training set
74+
y_score = clf.predict_proba(X_test)[:, 1] # Predict on new data
75+
76+
fig = px.scatter(x=X_test[:, 0], y=X_test[:, 1], color=y_score, labels={'color': 'score'})
77+
fig.update_traces(marker_size=10)
78+
fig.show()
79+
```
4280

43-
### Binary Probability Estimates with `go.Contour`
81+
## Probability Estimates with `go.Contour`
4482

4583
```python
4684
import numpy as np
@@ -68,20 +106,22 @@ Z = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:, 1]
68106
Z = Z.reshape(xx.shape)
69107

70108
fig = px.scatter(X, x=0, y=1, color=y.astype(str), labels={'0':'', '1':''})
109+
fig.update_traces(marker_size=10, marker_line_width=1)
71110
fig.add_trace(
72111
go.Contour(
73112
x=xrange,
74113
y=yrange,
75114
z=Z,
76115
showscale=False,
77116
colorscale=['Blue', 'Red'],
78-
opacity=0.4
117+
opacity=0.4,
118+
name='Confidence'
79119
)
80120
)
81121
fig.show()
82122
```
83123

84-
### Multi-class classification with `px.data` and `go.Heatmap`
124+
## Multi-class prediction confidence with `go.Heatmap`
85125

86126
```python
87127
import numpy as np
@@ -92,6 +132,7 @@ from sklearn.neighbors import KNeighborsClassifier
92132
mesh_size = .02
93133
margin = 1
94134

135+
# We will use the iris data, which is included in px
95136
df = px.data.iris()
96137
X = df[['sepal_length', 'sepal_width']]
97138
y = df.species_id
@@ -134,29 +175,66 @@ fig.add_trace(
134175
fig.show()
135176
```
136177

137-
### Visualizing kNN Regression
178+
## 3D Classification with `px.scatter_3d`
179+
180+
```python
181+
import numpy as np
182+
import plotly.express as px
183+
import plotly.graph_objects as go
184+
from sklearn.neighbors import KNeighborsClassifier
185+
from sklearn.model_selection import train_test_split
186+
187+
df = px.data.iris()
188+
features = ["sepal_width", "sepal_length", "petal_width"]
189+
190+
X = df[features]
191+
y = df.species
192+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
193+
194+
# Create classifier, run predictions on grid
195+
clf = KNeighborsClassifier(15, weights='distance')
196+
clf.fit(X_train, y_train)
197+
y_pred = clf.predict(X_test)
198+
y_score = clf.predict_proba(X_test)
199+
y_score = np.around(y_score.max(axis=1), 4)
200+
201+
fig = px.scatter_3d(
202+
X_test,
203+
x='sepal_length',
204+
y='sepal_width',
205+
z='petal_width',
206+
symbol=y_pred,
207+
color=y_score,
208+
labels={'symbol': 'prediction', 'color': 'score'}
209+
)
210+
fig.update_layout(legend=dict(x=0, y=0))
211+
fig.show()
212+
```
213+
214+
## High Dimension Visualization with `px.scatter_matrix`
215+
216+
If you need to visualize classifications that go beyond 3D, you can use the [scatter plot matrix](https://plot.ly/python/splom/).
138217

139218
```python
140219
import numpy as np
141220
import plotly.express as px
142221
import plotly.graph_objects as go
143-
from sklearn.neighbors import KNeighborsRegressor
222+
from sklearn.neighbors import KNeighborsClassifier
223+
from sklearn.model_selection import train_test_split
144224

145-
df = px.data.tips()
146-
X = df.total_bill.values.reshape(-1, 1)
225+
df = px.data.iris()
226+
features = ["sepal_width", "sepal_length", "petal_width", "petal_length"]
147227

148-
knn_dist = KNeighborsRegressor(10, weights='distance')
149-
knn_uni = KNeighborsRegressor(10, weights='uniform')
150-
knn_dist.fit(X, df.tip)
151-
knn_uni.fit(X, df.tip)
228+
X = df[features]
229+
y = df.species
230+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
152231

153-
x_range = np.linspace(X.min(), X.max(), 100)
154-
y_dist = knn_dist.predict(x_range.reshape(-1, 1))
155-
y_uni = knn_uni.predict(x_range.reshape(-1, 1))
232+
# Create classifier, run predictions on grid
233+
clf = KNeighborsClassifier(15, weights='distance')
234+
clf.fit(X_train, y_train)
235+
y_pred = clf.predict(X_test)
156236

157-
fig = px.scatter(df, x='total_bill', y='tip', color='sex', opacity=0.65)
158-
fig.add_traces(go.Scatter(x=x_range, y=y_uni, name='Weights: Uniform'))
159-
fig.add_traces(go.Scatter(x=x_range, y=y_dist, name='Weights: Distance'))
237+
fig = px.scatter_matrix(X_test, dimensions=features, color=y_pred, labels={'color': 'prediction'})
160238
fig.show()
161239
```
162240

@@ -166,8 +244,10 @@ Learn more about `px`, `go.Contour`, and `go.Heatmap` here:
166244
* https://plot.ly/python/plotly-express/
167245
* https://plot.ly/python/heatmaps/
168246
* https://plot.ly/python/contour-plots/
247+
* https://plot.ly/python/3d-scatter-plots/
248+
* https://plot.ly/python/splom/
169249

170250
This tutorial was inspired by amazing examples from the official scikit-learn docs:
171-
* https://scikit-learn.org/stable/auto_examples/neighbors/plot_regression.html
172251
* https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html
173252
* https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html
253+
* https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html

0 commit comments

Comments
 (0)