"# Author: Jaques Grobler <
[email protected]>\n# License: BSD 3 clause\n\nfrom time import time\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom matplotlib.ticker import NullFormatter\nfrom sklearn import manifold\nfrom sklearn.utils import check_random_state\n\n# Unused but required import for doing 3d projections with matplotlib < 3.2\nimport mpl_toolkits.mplot3d # noqa: F401\nimport warnings\n\n# Variables for manifold learning.\nn_neighbors = 10\nn_samples = 1000\n\n# Create our sphere.\nrandom_state = check_random_state(0)\np = random_state.rand(n_samples) * (2 * np.pi - 0.55)\nt = random_state.rand(n_samples) * np.pi\n\n# Sever the poles from the sphere.\nindices = (t < (np.pi - (np.pi / 8))) & (t > ((np.pi / 8)))\ncolors = p[indices]\nx, y, z = (\n np.sin(t[indices]) * np.cos(p[indices]),\n np.sin(t[indices]) * np.sin(p[indices]),\n np.cos(t[indices]),\n)\n\n# Plot our dataset.\nfig = plt.figure(figsize=(15, 8))\nplt.suptitle(\n \"Manifold Learning with %i points, %i neighbors\" % (1000, n_neighbors), fontsize=14\n)\n\nax = fig.add_subplot(251, projection=\"3d\")\nax.scatter(x, y, z, c=p[indices], cmap=plt.cm.rainbow)\nax.view_init(40, -10)\n\nsphere_data = np.array([x, y, z]).T\n\n# Perform Locally Linear Embedding Manifold learning\nmethods = [\"standard\", \"ltsa\", \"hessian\", \"modified\"]\nlabels = [\"LLE\", \"LTSA\", \"Hessian LLE\", \"Modified LLE\"]\n\nfor i, method in enumerate(methods):\n t0 = time()\n trans_data = (\n manifold.LocallyLinearEmbedding(\n n_neighbors=n_neighbors, n_components=2, method=method\n )\n .fit_transform(sphere_data)\n .T\n )\n t1 = time()\n print(\"%s: %.2g sec\" % (methods[i], t1 - t0))\n\n ax = fig.add_subplot(252 + i)\n plt.scatter(trans_data[0], trans_data[1], c=colors, cmap=plt.cm.rainbow)\n plt.title(\"%s (%.2g sec)\" % (labels[i], t1 - t0))\n ax.xaxis.set_major_formatter(NullFormatter())\n ax.yaxis.set_major_formatter(NullFormatter())\n plt.axis(\"tight\")\n\n# Perform Isomap Manifold learning.\nt0 = time()\ntrans_data = (\n manifold.Isomap(n_neighbors=n_neighbors, n_components=2)\n .fit_transform(sphere_data)\n .T\n)\nt1 = time()\nprint(\"%s: %.2g sec\" % (\"ISO\", t1 - t0))\n\nax = fig.add_subplot(257)\nplt.scatter(trans_data[0], trans_data[1], c=colors, cmap=plt.cm.rainbow)\nplt.title(\"%s (%.2g sec)\" % (\"Isomap\", t1 - t0))\nax.xaxis.set_major_formatter(NullFormatter())\nax.yaxis.set_major_formatter(NullFormatter())\nplt.axis(\"tight\")\n\n# Perform Multi-dimensional scaling.\nt0 = time()\nmds = manifold.MDS(2, max_iter=100, n_init=1)\ntrans_data = mds.fit_transform(sphere_data).T\nt1 = time()\nprint(\"MDS: %.2g sec\" % (t1 - t0))\n\nax = fig.add_subplot(258)\nplt.scatter(trans_data[0], trans_data[1], c=colors, cmap=plt.cm.rainbow)\nplt.title(\"MDS (%.2g sec)\" % (t1 - t0))\nax.xaxis.set_major_formatter(NullFormatter())\nax.yaxis.set_major_formatter(NullFormatter())\nplt.axis(\"tight\")\n\n# Perform Spectral Embedding.\nt0 = time()\nse = manifold.SpectralEmbedding(n_components=2, n_neighbors=n_neighbors)\ntrans_data = se.fit_transform(sphere_data).T\nt1 = time()\nprint(\"Spectral Embedding: %.2g sec\" % (t1 - t0))\n\nax = fig.add_subplot(259)\nplt.scatter(trans_data[0], trans_data[1], c=colors, cmap=plt.cm.rainbow)\nplt.title(\"Spectral Embedding (%.2g sec)\" % (t1 - t0))\nax.xaxis.set_major_formatter(NullFormatter())\nax.yaxis.set_major_formatter(NullFormatter())\nplt.axis(\"tight\")\n\n# Perform t-distributed stochastic neighbor embedding.\n# TODO(1.2) Remove warning handling.\nwith warnings.catch_warnings():\n warnings.filterwarnings(\n \"ignore\", message=\"The PCA initialization\", category=FutureWarning\n )\n t0 = time()\n tsne = manifold.TSNE(\n n_components=2, init=\"pca\", random_state=0, learning_rate=\"auto\"\n )\n trans_data = tsne.fit_transform(sphere_data).T\n t1 = time()\nprint(\"t-SNE: %.2g sec\" % (t1 - t0))\n\nax = fig.add_subplot(2, 5, 10)\nplt.scatter(trans_data[0], trans_data[1], c=colors, cmap=plt.cm.rainbow)\nplt.title(\"t-SNE (%.2g sec)\" % (t1 - t0))\nax.xaxis.set_major_formatter(NullFormatter())\nax.yaxis.set_major_formatter(NullFormatter())\nplt.axis(\"tight\")\n\nplt.show()"
0 commit comments