1
1
"""
2
- ==================
3
- GMM classification
4
- ==================
2
+ ===============
3
+ GMM covariances
4
+ ===============
5
5
6
- Demonstration of Gaussian mixture models for classification .
6
+ Demonstration of several covariances types for Gaussian mixture models .
7
7
8
8
See :ref:`gmm` for more information on the estimator.
9
9
10
- Plots predicted labels on both training and held out test data using a
11
- variety of GMM classifiers on the iris dataset.
10
+ Although GMM are often used for clustering, we can compare the obtained
11
+ clusters with the actual classes from the dataset. We initialize the means
12
+ of the Gaussians with the means of the classes from the training set to make
13
+ this comparison valid.
12
14
13
- Compares GMMs with spherical, diagonal, full, and tied covariance
14
- matrices in increasing order of performance. Although one would
15
+ We plot predicted labels on both training and held out test data using a
16
+ variety of GMM covariance types on the iris dataset.
17
+ We compare GMMs with spherical, diagonal, full, and tied covariance
18
+ matrices in increasing order of performance. Although one would
15
19
expect full covariance to perform best in general, it is prone to
16
20
overfitting on small datasets and does not generalize well to held out
17
21
test data.
39
43
40
44
41
45
colors = ['navy' , 'turquoise' , 'darkorange' ]
46
+
47
+
42
48
def make_ellipses (gmm , ax ):
43
49
for n , color in enumerate (colors ):
44
50
v , w = np .linalg .eigh (gmm ._get_covars ()[n ][:2 , :2 ])
@@ -69,28 +75,29 @@ def make_ellipses(gmm, ax):
69
75
n_classes = len (np .unique (y_train ))
70
76
71
77
# Try GMMs using different types of covariances.
72
- classifiers = dict ((covar_type , GMM (n_components = n_classes ,
73
- covariance_type = covar_type , init_params = 'wc' , n_iter = 20 ))
74
- for covar_type in ['spherical' , 'diag' , 'tied' , 'full' ])
78
+ estimators = dict ((covar_type ,
79
+ GMM (n_components = n_classes , covariance_type = covar_type ,
80
+ init_params = 'wc' , n_iter = 20 ))
81
+ for covar_type in ['spherical' , 'diag' , 'tied' , 'full' ])
75
82
76
- n_classifiers = len (classifiers )
83
+ n_estimators = len (estimators )
77
84
78
- plt .figure (figsize = (3 * n_classifiers / 2 , 6 ))
85
+ plt .figure (figsize = (3 * n_estimators / 2 , 6 ))
79
86
plt .subplots_adjust (bottom = .01 , top = 0.95 , hspace = .15 , wspace = .05 ,
80
87
left = .01 , right = .99 )
81
88
82
89
83
- for index , (name , classifier ) in enumerate (classifiers .items ()):
90
+ for index , (name , estimator ) in enumerate (estimators .items ()):
84
91
# Since we have class labels for the training data, we can
85
92
# initialize the GMM parameters in a supervised manner.
86
- classifier .means_ = np .array ([X_train [y_train == i ].mean (axis = 0 )
93
+ estimator .means_ = np .array ([X_train [y_train == i ].mean (axis = 0 )
87
94
for i in xrange (n_classes )])
88
95
89
96
# Train the other parameters using the EM algorithm.
90
- classifier .fit (X_train )
97
+ estimator .fit (X_train )
91
98
92
- h = plt .subplot (2 , n_classifiers / 2 , index + 1 )
93
- make_ellipses (classifier , h )
99
+ h = plt .subplot (2 , n_estimators / 2 , index + 1 )
100
+ make_ellipses (estimator , h )
94
101
95
102
for n , color in enumerate (colors ):
96
103
data = iris .data [iris .target == n ]
@@ -99,15 +106,14 @@ def make_ellipses(gmm, ax):
99
106
# Plot the test data with crosses
100
107
for n , color in enumerate (colors ):
101
108
data = X_test [y_test == n ]
102
- print (color )
103
109
plt .scatter (data [:, 0 ], data [:, 1 ], marker = 'x' , color = color )
104
110
105
- y_train_pred = classifier .predict (X_train )
111
+ y_train_pred = estimator .predict (X_train )
106
112
train_accuracy = np .mean (y_train_pred .ravel () == y_train .ravel ()) * 100
107
113
plt .text (0.05 , 0.9 , 'Train accuracy: %.1f' % train_accuracy ,
108
114
transform = h .transAxes )
109
115
110
- y_test_pred = classifier .predict (X_test )
116
+ y_test_pred = estimator .predict (X_test )
111
117
test_accuracy = np .mean (y_test_pred .ravel () == y_test .ravel ()) * 100
112
118
plt .text (0.05 , 0.8 , 'Test accuracy: %.1f' % test_accuracy ,
113
119
transform = h .transAxes )
0 commit comments