"# Author: Ron Weiss <
[email protected]>, Gael Varoquaux\n# Modified by Thierry Guillemot <
[email protected]>\n# License: BSD 3 clause\n\nimport matplotlib as mpl\nimport matplotlib.pyplot as plt\n\nimport numpy as np\n\nfrom sklearn import datasets\nfrom sklearn.mixture import GaussianMixture\nfrom sklearn.model_selection import StratifiedKFold\n\nprint(__doc__)\n\ncolors = ['navy', 'turquoise', 'darkorange']\n\n\ndef make_ellipses(gmm, ax):\n for n, color in enumerate(colors):\n if gmm.covariance_type == 'full':\n covariances = gmm.covariances_[n][:2, :2]\n elif gmm.covariance_type == 'tied':\n covariances = gmm.covariances_[:2, :2]\n elif gmm.covariance_type == 'diag':\n covariances = np.diag(gmm.covariances_[n][:2])\n elif gmm.covariance_type == 'spherical':\n covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]\n v, w = np.linalg.eigh(covariances)\n u = w[0] / np.linalg.norm(w[0])\n angle = np.arctan2(u[1], u[0])\n angle = 180 * angle / np.pi # convert to degrees\n v = 2. * np.sqrt(2.) * np.sqrt(v)\n ell = mpl.patches.Ellipse(gmm.means_[n, :2], v[0], v[1],\n 180 + angle, color=color)\n ell.set_clip_box(ax.bbox)\n ell.set_alpha(0.5)\n ax.add_artist(ell)\n ax.set_aspect('equal', 'datalim')\n\niris = datasets.load_iris()\n\n# Break up the dataset into non-overlapping training (75%) and testing\n# (25%) sets.\nskf = StratifiedKFold(n_splits=4)\n# Only take the first fold.\ntrain_index, test_index = next(iter(skf.split(iris.data, iris.target)))\n\n\nX_train = iris.data[train_index]\ny_train = iris.target[train_index]\nX_test = iris.data[test_index]\ny_test = iris.target[test_index]\n\nn_classes = len(np.unique(y_train))\n\n# Try GMMs using different types of covariances.\nestimators = dict((cov_type, GaussianMixture(n_components=n_classes,\n covariance_type=cov_type, max_iter=20, random_state=0))\n for cov_type in ['spherical', 'diag', 'tied', 'full'])\n\nn_estimators = len(estimators)\n\nplt.figure(figsize=(3 * n_estimators // 2, 6))\nplt.subplots_adjust(bottom=.01, top=0.95, hspace=.15, wspace=.05,\n left=.01, right=.99)\n\n\nfor index, (name, estimator) in enumerate(estimators.items()):\n # Since we have class labels for the training data, we can\n # initialize the GMM parameters in a supervised manner.\n estimator.means_init = np.array([X_train[y_train == i].mean(axis=0)\n for i in range(n_classes)])\n\n # Train the other parameters using the EM algorithm.\n estimator.fit(X_train)\n\n h = plt.subplot(2, n_estimators // 2, index + 1)\n make_ellipses(estimator, h)\n\n for n, color in enumerate(colors):\n data = iris.data[iris.target == n]\n plt.scatter(data[:, 0], data[:, 1], s=0.8, color=color,\n label=iris.target_names[n])\n # Plot the test data with crosses\n for n, color in enumerate(colors):\n data = X_test[y_test == n]\n plt.scatter(data[:, 0], data[:, 1], marker='x', color=color)\n\n y_train_pred = estimator.predict(X_train)\n train_accuracy = np.mean(y_train_pred.ravel() == y_train.ravel()) * 100\n plt.text(0.05, 0.9, 'Train accuracy: %.1f' % train_accuracy,\n transform=h.transAxes)\n\n y_test_pred = estimator.predict(X_test)\n test_accuracy = np.mean(y_test_pred.ravel() == y_test.ravel()) * 100\n plt.text(0.05, 0.8, 'Test accuracy: %.1f' % test_accuracy,\n transform=h.transAxes)\n\n plt.xticks(())\n plt.yticks(())\n plt.title(name)\n\nplt.legend(scatterpoints=1, loc='lower right', prop=dict(size=12))\n\n\nplt.show()"
0 commit comments