"""Compare several clustering algorithms on toy 2-D datasets.

Six synthetic datasets are generated, each one is standardized, and eleven
clustering estimators are fitted on it. Every (dataset, algorithm) pair is
drawn as one panel of a grid figure: points are colored by predicted label
and the fit time is printed in the lower-right corner of the panel.
"""

import time
import warnings

import numpy as np
import matplotlib.pyplot as plt

from sklearn import cluster, datasets, mixture
from sklearn.neighbors import kneighbors_graph
from sklearn.preprocessing import StandardScaler
from itertools import cycle, islice

np.random.seed(0)

# ============
# Generate datasets. We choose the size big enough to see the scalability
# of the algorithms, but not too big to avoid too long running times
# ============
n_samples = 500
noisy_circles = datasets.make_circles(n_samples=n_samples, factor=0.5, noise=0.05)
noisy_moons = datasets.make_moons(n_samples=n_samples, noise=0.05)
blobs = datasets.make_blobs(n_samples=n_samples, random_state=8)
no_structure = np.random.rand(n_samples, 2), None

# Anisotropically distributed data
random_state = 170
X, y = datasets.make_blobs(n_samples=n_samples, random_state=random_state)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X, transformation)
aniso = (X_aniso, y)

# blobs with varied variances
varied = datasets.make_blobs(
    n_samples=n_samples, cluster_std=[1.0, 2.5, 0.5], random_state=random_state
)

# ============
# Set up cluster parameters
# ============
plt.figure(figsize=(9 * 2 + 3, 13))
plt.subplots_adjust(
    left=0.02, right=0.98, bottom=0.001, top=0.95, wspace=0.05, hspace=0.01
)

plot_num = 1

# Parameters used for every dataset unless a dataset-specific entry
# overrides them.
default_base = {
    "quantile": 0.3,
    "eps": 0.3,
    "damping": 0.9,
    "preference": -200,
    "n_neighbors": 3,
    "n_clusters": 3,
    "min_samples": 7,
    "xi": 0.05,
    "min_cluster_size": 0.1,
    "allow_single_cluster": True,
    "hdbscan_min_cluster_size": 15,
    "hdbscan_min_samples": 3,
}

# (dataset, parameter overrides) pairs, one per figure row.
# Named ``dataset_configs`` rather than ``datasets`` so the imported
# ``sklearn.datasets`` module is not shadowed.
dataset_configs = [
    (
        noisy_circles,
        {
            "damping": 0.77,
            "preference": -240,
            "quantile": 0.2,
            "n_clusters": 2,
            "min_samples": 7,
            "xi": 0.08,
        },
    ),
    (
        noisy_moons,
        {
            "damping": 0.75,
            "preference": -220,
            "n_clusters": 2,
            "min_samples": 7,
            "xi": 0.1,
        },
    ),
    (
        varied,
        {
            "eps": 0.18,
            "n_neighbors": 2,
            "min_samples": 7,
            "xi": 0.01,
            "min_cluster_size": 0.2,
        },
    ),
    (
        aniso,
        {
            "eps": 0.15,
            "n_neighbors": 2,
            "min_samples": 7,
            "xi": 0.1,
            "min_cluster_size": 0.2,
        },
    ),
    (blobs, {"min_samples": 7, "xi": 0.1, "min_cluster_size": 0.2}),
    (no_structure, {}),
]

# Colors cycled over the cluster labels; defined once because the list does
# not depend on the dataset or the algorithm.
CLUSTER_PALETTE = [
    "#377eb8",
    "#ff7f00",
    "#4daf4a",
    "#f781bf",
    "#a65628",
    "#984ea3",
    "#999999",
    "#e41a1c",
    "#dede00",
]

for i_dataset, (dataset, algo_params) in enumerate(dataset_configs):
    # update parameters with dataset-specific values
    params = default_base.copy()
    params.update(algo_params)

    # the second tuple element is not used for fitting or plotting
    X, _ = dataset

    # normalize dataset for easier parameter selection
    X = StandardScaler().fit_transform(X)

    # estimate bandwidth for mean shift
    bandwidth = cluster.estimate_bandwidth(X, quantile=params["quantile"])

    # connectivity matrix for structured Ward
    connectivity = kneighbors_graph(
        X, n_neighbors=params["n_neighbors"], include_self=False
    )
    # make connectivity symmetric
    connectivity = 0.5 * (connectivity + connectivity.T)

    # ============
    # Create cluster objects
    # ============
    ms = cluster.MeanShift(bandwidth=bandwidth, bin_seeding=True)
    two_means = cluster.MiniBatchKMeans(n_clusters=params["n_clusters"], n_init="auto")
    ward = cluster.AgglomerativeClustering(
        n_clusters=params["n_clusters"], linkage="ward", connectivity=connectivity
    )
    spectral = cluster.SpectralClustering(
        n_clusters=params["n_clusters"],
        eigen_solver="arpack",
        affinity="nearest_neighbors",
    )
    dbscan = cluster.DBSCAN(eps=params["eps"])
    hdbscan = cluster.HDBSCAN(
        min_samples=params["hdbscan_min_samples"],
        min_cluster_size=params["hdbscan_min_cluster_size"],
        allow_single_cluster=params["allow_single_cluster"],
    )
    optics = cluster.OPTICS(
        min_samples=params["min_samples"],
        xi=params["xi"],
        min_cluster_size=params["min_cluster_size"],
    )
    affinity_propagation = cluster.AffinityPropagation(
        damping=params["damping"], preference=params["preference"], random_state=0
    )
    average_linkage = cluster.AgglomerativeClustering(
        linkage="average",
        metric="cityblock",
        n_clusters=params["n_clusters"],
        connectivity=connectivity,
    )
    birch = cluster.Birch(n_clusters=params["n_clusters"])
    gmm = mixture.GaussianMixture(
        n_components=params["n_clusters"], covariance_type="full"
    )

    # (panel title, estimator) pairs, one per figure column
    clustering_algorithms = (
        ("MiniBatch\nKMeans", two_means),
        ("Affinity\nPropagation", affinity_propagation),
        ("MeanShift", ms),
        ("Spectral\nClustering", spectral),
        ("Ward", ward),
        ("Agglomerative\nClustering", average_linkage),
        ("DBSCAN", dbscan),
        ("HDBSCAN", hdbscan),
        ("OPTICS", optics),
        ("BIRCH", birch),
        ("Gaussian\nMixture", gmm),
    )

    for name, algorithm in clustering_algorithms:
        # perf_counter is monotonic and high-resolution, so the measured
        # interval is not affected by system clock adjustments
        t0 = time.perf_counter()

        # catch warnings related to kneighbors_graph
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="the number of connected components of the "
                + "connectivity matrix is [0-9]{1,2}"
                + " > 1. Completing it to avoid stopping the tree early.",
                category=UserWarning,
            )
            warnings.filterwarnings(
                "ignore",
                message="Graph is not fully connected, spectral embedding"
                + " may not work as expected.",
                category=UserWarning,
            )
            algorithm.fit(X)

        t1 = time.perf_counter()

        # use labels_ when the fitted estimator exposes it, otherwise
        # fall back to predict()
        if hasattr(algorithm, "labels_"):
            y_pred = algorithm.labels_.astype(int)
        else:
            y_pred = algorithm.predict(X)

        plt.subplot(len(dataset_configs), len(clustering_algorithms), plot_num)
        # column titles are only drawn on the first row
        if i_dataset == 0:
            plt.title(name, size=18)

        # one palette color per label 0..max(label), repeating the palette
        # when there are more labels than colors
        n_labels = int(y_pred.max()) + 1
        colors = np.array(list(islice(cycle(CLUSTER_PALETTE), n_labels)))
        # add black color for outliers (if any): a label of -1 indexes the
        # last entry of the array
        colors = np.append(colors, ["#000000"])
        plt.scatter(X[:, 0], X[:, 1], s=10, color=colors[y_pred])

        plt.xlim(-2.5, 2.5)
        plt.ylim(-2.5, 2.5)
        plt.xticks(())
        plt.yticks(())
        # elapsed fit time with the leading zero stripped, e.g. ".03s"
        plt.text(
            0.99,
            0.01,
            f"{t1 - t0:.2f}s".lstrip("0"),
            transform=plt.gca().transAxes,
            size=15,
            horizontalalignment="right",
        )
        plot_num += 1

plt.show()
0 commit comments