Skip to content

Commit 756ce7f

Browse files
xhluxhlulu
authored andcommitted
ML Docs: Add t-SNE/UMAP notebook (read todo)
TODO: Add thumbnail, references, description of sections
1 parent 2b3b9e3 commit 756ce7f

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed

doc/python/tsne-umap-projections.md

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
---
2+
jupyter:
3+
jupytext:
4+
notebook_metadata_filter: all
5+
text_representation:
6+
extension: .md
7+
format_name: markdown
8+
format_version: '1.1'
9+
jupytext_version: 1.1.1
10+
kernelspec:
11+
display_name: Python 3
12+
language: python
13+
name: python3
14+
language_info:
15+
codemirror_mode:
16+
name: ipython
17+
version: 3
18+
file_extension: .py
19+
mimetype: text/x-python
20+
name: python
21+
nbconvert_exporter: python
22+
pygments_lexer: ipython3
23+
version: 3.7.6
24+
plotly:
25+
description: Visualize scikit-learn's k-Nearest Neighbors (kNN) classification
26+
in Python with Plotly.
27+
display_as: ai_ml
28+
language: python
29+
layout: base
30+
name: t-SNE and UMAP projections
31+
order: 1
32+
page_type: example_index
33+
permalink: python/t-sne-and-umap-projections/
34+
thumbnail: thumbnail/tsne-umap-projections.png
35+
---
36+
37+
## Basic t-SNE projections
38+
39+
40+
### Visualizing high-dimensional data with `px.scatter_matrix`
41+
42+
```python
43+
import plotly.express as px
44+
45+
df = px.data.iris()
46+
features = ["sepal_width", "sepal_length", "petal_width", "petal_length"]
47+
fig = px.scatter_matrix(df, dimensions=features, color="species")
48+
fig.show()
49+
```
50+
51+
### Project data into 2D with t-SNE and `px.scatter`
52+
53+
```python
54+
from sklearn.manifold import TSNE
55+
import plotly.express as px
56+
57+
df = px.data.iris()
58+
59+
features = df.loc[:, :'petal_width']
60+
61+
tsne = TSNE(n_components=2, random_state=0)
62+
projections = tsne.fit_transform(features)
63+
64+
fig = px.scatter(
65+
projections, x=0, y=1,
66+
color=df.species, labels={'color': 'species'}
67+
)
68+
fig.show()
69+
```
70+
71+
### Project data into 3D with t-SNE and `px.scatter_3d`
72+
73+
```python
74+
from sklearn.manifold import TSNE
75+
import plotly.express as px
76+
77+
df = px.data.iris()
78+
79+
features = df.loc[:, :'petal_width']
80+
81+
tsne = TSNE(n_components=3, random_state=0)
82+
projections = tsne.fit_transform(features, )
83+
84+
fig = px.scatter_3d(
85+
projections, x=0, y=1, z=2,
86+
color=df.species, labels={'color': 'species'}
87+
)
88+
fig.update_traces(marker_size=8)
89+
fig.show()
90+
```
91+
92+
## Projections with UMAP
93+
94+
Just like t-SNE, [UMAP](https://umap-learn.readthedocs.io/en/latest/index.html) is a dimensionality reduction specifically designed for visualizing complex data in low dimensions (2D or 3D). As the number of data points increase, [UMAP becomes more time efficient](https://umap-learn.readthedocs.io/en/latest/benchmarking.html) compared to TSNE.
95+
96+
In the example below, we see how easy it is to use UMAP as a drop-in replacement for scikit-learn's `manifold.TSNE`.
97+
98+
```python
99+
from umap import UMAP
100+
import plotly.express as px
101+
102+
df = px.data.iris()
103+
104+
features = df.loc[:, :'petal_width']
105+
106+
umap_2d = UMAP(n_components=2, init='random', random_state=0)
107+
umap_3d = UMAP(n_components=3, init='random', random_state=0)
108+
109+
proj_2d = umap_2d.fit_transform(features)
110+
proj_3d = umap_3d.fit_transform(features)
111+
112+
fig_2d = px.scatter(
113+
proj_2d, x=0, y=1,
114+
color=df.species, labels={'color': 'species'}
115+
)
116+
fig_3d = px.scatter_3d(
117+
proj_3d, x=0, y=1, z=2,
118+
color=df.species, labels={'color': 'species'}
119+
)
120+
fig_3d.update_traces(marker_size=5)
121+
122+
fig_2d.show()
123+
fig_3d.show()
124+
```
125+
126+
## Visualizing image datasets
127+
128+
In the following example, we show how to visualize large image datasets using UMAP. Here, we use [`load_digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html), a subset of the famous MNIST dataset that was downsized to 8x8 and flattened to 64 dimensions.
129+
130+
```python
131+
import plotly.express as px
132+
from sklearn.datasets import load_digits
133+
from umap import UMAP
134+
135+
digits = load_digits()
136+
137+
umap_2d = UMAP(random_state=0)
138+
umap_2d.fit(digits.data)
139+
140+
projections = umap_2d.transform(digits.data)
141+
142+
fig = px.scatter(
143+
projections, x=0, y=1,
144+
color=digits.target.astype(str), labels={'color': 'digit'}
145+
)
146+
fig.show()
147+
```
148+
149+
### Reference

0 commit comments

Comments
 (0)