Skip to content

Commit 5406b87

Browse files
committed
Pushing the docs for revision for branch: master, commit 88b96a762d6ac1cf759cc33bf60ecfa7d5f5e0c5
1 parent b422c0f commit 5406b87

File tree

858 files changed

+3002
-3016
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

858 files changed

+3002
-3016
lines changed

dev/_downloads/plot_gmm_classifier.py renamed to dev/_downloads/plot_gmm_covariances.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
"""
2-
==================
3-
GMM classification
4-
==================
2+
===============
3+
GMM covariances
4+
===============
55
6-
Demonstration of Gaussian mixture models for classification.
6+
Demonstration of several covariances types for Gaussian mixture models.
77
88
See :ref:`gmm` for more information on the estimator.
99
10-
Plots predicted labels on both training and held out test data using a
11-
variety of GMM classifiers on the iris dataset.
10+
Although GMM are often used for clustering, we can compare the obtained
11+
clusters with the actual classes from the dataset. We initialize the means
12+
of the Gaussians with the means of the classes from the training set to make
13+
this comparison valid.
1214
13-
Compares GMMs with spherical, diagonal, full, and tied covariance
14-
matrices in increasing order of performance. Although one would
15+
We plot predicted labels on both training and held out test data using a
16+
variety of GMM covariance types on the iris dataset.
17+
We compare GMMs with spherical, diagonal, full, and tied covariance
18+
matrices in increasing order of performance. Although one would
1519
expect full covariance to perform best in general, it is prone to
1620
overfitting on small datasets and does not generalize well to held out
1721
test data.
@@ -39,6 +43,8 @@
3943

4044

4145
colors = ['navy', 'turquoise', 'darkorange']
46+
47+
4248
def make_ellipses(gmm, ax):
4349
for n, color in enumerate(colors):
4450
v, w = np.linalg.eigh(gmm._get_covars()[n][:2, :2])
@@ -69,28 +75,29 @@ def make_ellipses(gmm, ax):
6975
n_classes = len(np.unique(y_train))
7076

7177
# Try GMMs using different types of covariances.
72-
classifiers = dict((covar_type, GMM(n_components=n_classes,
73-
covariance_type=covar_type, init_params='wc', n_iter=20))
74-
for covar_type in ['spherical', 'diag', 'tied', 'full'])
78+
estimators = dict((covar_type,
79+
GMM(n_components=n_classes, covariance_type=covar_type,
80+
init_params='wc', n_iter=20))
81+
for covar_type in ['spherical', 'diag', 'tied', 'full'])
7582

76-
n_classifiers = len(classifiers)
83+
n_estimators = len(estimators)
7784

78-
plt.figure(figsize=(3 * n_classifiers / 2, 6))
85+
plt.figure(figsize=(3 * n_estimators / 2, 6))
7986
plt.subplots_adjust(bottom=.01, top=0.95, hspace=.15, wspace=.05,
8087
left=.01, right=.99)
8188

8289

83-
for index, (name, classifier) in enumerate(classifiers.items()):
90+
for index, (name, estimator) in enumerate(estimators.items()):
8491
# Since we have class labels for the training data, we can
8592
# initialize the GMM parameters in a supervised manner.
86-
classifier.means_ = np.array([X_train[y_train == i].mean(axis=0)
93+
estimator.means_ = np.array([X_train[y_train == i].mean(axis=0)
8794
for i in xrange(n_classes)])
8895

8996
# Train the other parameters using the EM algorithm.
90-
classifier.fit(X_train)
97+
estimator.fit(X_train)
9198

92-
h = plt.subplot(2, n_classifiers / 2, index + 1)
93-
make_ellipses(classifier, h)
99+
h = plt.subplot(2, n_estimators / 2, index + 1)
100+
make_ellipses(estimator, h)
94101

95102
for n, color in enumerate(colors):
96103
data = iris.data[iris.target == n]
@@ -99,15 +106,14 @@ def make_ellipses(gmm, ax):
99106
# Plot the test data with crosses
100107
for n, color in enumerate(colors):
101108
data = X_test[y_test == n]
102-
print(color)
103109
plt.scatter(data[:, 0], data[:, 1], marker='x', color=color)
104110

105-
y_train_pred = classifier.predict(X_train)
111+
y_train_pred = estimator.predict(X_train)
106112
train_accuracy = np.mean(y_train_pred.ravel() == y_train.ravel()) * 100
107113
plt.text(0.05, 0.9, 'Train accuracy: %.1f' % train_accuracy,
108114
transform=h.transAxes)
109115

110-
y_test_pred = classifier.predict(X_test)
116+
y_test_pred = estimator.predict(X_test)
111117
test_accuracy = np.mean(y_test_pred.ravel() == y_test.ravel()) * 100
112118
plt.text(0.05, 0.8, 'Test accuracy: %.1f' % test_accuracy,
113119
transform=h.transAxes)
97 Bytes
97 Bytes
188 Bytes
188 Bytes
410 Bytes
410 Bytes
267 Bytes
267 Bytes
-63 Bytes

0 commit comments

Comments
 (0)