- "print(__doc__)\n\nfrom time import time\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport scipy as sp\n\nfrom sklearn.decomposition import MiniBatchDictionaryLearning\nfrom sklearn.feature_extraction.image import extract_patches_2d\nfrom sklearn.feature_extraction.image import reconstruct_from_patches_2d\n\n\ntry: # SciPy >= 0.16 have face in misc\n from scipy.misc import face\n face = face(gray=True)\nexcept ImportError:\n face = sp.face(gray=True)\n\n# Convert from uint8 representation with values between 0 and 255 to\n# a floating point representation with values between 0 and 1.\nface = face / 255.\n\n# downsample for higher speed\nface = face[::2, ::2] + face[1::2, ::2] + face[::2, 1::2] + face[1::2, 1::2]\nface /= 4.0\nheight, width = face.shape\n\n# Distort the right half of the image\nprint('Distorting image...')\ndistorted = face.copy()\ndistorted[:, width // 2:] += 0.075 * np.random.randn(height, width // 2)\n\n# Extract all reference patches from the left half of the image\nprint('Extracting reference patches...')\nt0 = time()\npatch_size = (7, 7)\ndata = extract_patches_2d(distorted[:, :width // 2], patch_size)\ndata = data.reshape(data.shape[0], -1)\ndata -= np.mean(data, axis=0)\ndata /= np.std(data, axis=0)\nprint('done in %.2fs.' % (time() - t0))\n\n# #############################################################################\n# Learn the dictionary from reference patches\n\nprint('Learning the dictionary...')\nt0 = time()\ndico = MiniBatchDictionaryLearning(n_components=100, alpha=1, n_iter=500)\nV = dico.fit(data).components_\ndt = time() - t0\nprint('done in %.2fs.' % dt)\n\nplt.figure(figsize=(4.2, 4))\nfor i, comp in enumerate(V[:100]):\n plt.subplot(10, 10, i + 1)\n plt.imshow(comp.reshape(patch_size), cmap=plt.cm.gray_r,\n interpolation='nearest')\n plt.xticks(())\n plt.yticks(())\nplt.suptitle('Dictionary learned from face patches\\n' +\n 'Train time %.1fs on %d patches' % (dt, len(data)),\n fontsize=16)\nplt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)\n\n\n# #############################################################################\n# Display the distorted image\n\ndef show_with_diff(image, reference, title):\n \"\"\"Helper function to display denoising\"\"\"\n plt.figure(figsize=(5, 3.3))\n plt.subplot(1, 2, 1)\n plt.title('Image')\n plt.imshow(image, vmin=0, vmax=1, cmap=plt.cm.gray,\n interpolation='nearest')\n plt.xticks(())\n plt.yticks(())\n plt.subplot(1, 2, 2)\n difference = image - reference\n\n plt.title('Difference (norm: %.2f)' % np.sqrt(np.sum(difference ** 2)))\n plt.imshow(difference, vmin=-0.5, vmax=0.5, cmap=plt.cm.PuOr,\n interpolation='nearest')\n plt.xticks(())\n plt.yticks(())\n plt.suptitle(title, size=16)\n plt.subplots_adjust(0.02, 0.02, 0.98, 0.79, 0.02, 0.2)\n\nshow_with_diff(distorted, face, 'Distorted image')\n\n# #############################################################################\n# Extract noisy patches and reconstruct them using the dictionary\n\nprint('Extracting noisy patches... ')\nt0 = time()\ndata = extract_patches_2d(distorted[:, width // 2:], patch_size)\ndata = data.reshape(data.shape[0], -1)\nintercept = np.mean(data, axis=0)\ndata -= intercept\nprint('done in %.2fs.' % (time() - t0))\n\ntransform_algorithms = [\n ('Orthogonal Matching Pursuit\\n1 atom', 'omp',\n {'transform_n_nonzero_coefs': 1}),\n ('Orthogonal Matching Pursuit\\n2 atoms', 'omp',\n {'transform_n_nonzero_coefs': 2}),\n ('Least-angle regression\\n5 atoms', 'lars',\n {'transform_n_nonzero_coefs': 5}),\n ('Thresholding\\n alpha=0.1', 'threshold', {'transform_alpha': .1})]\n\nreconstructions = {}\nfor title, transform_algorithm, kwargs in transform_algorithms:\n print(title + '...')\n reconstructions[title] = face.copy()\n t0 = time()\n dico.set_params(transform_algorithm=transform_algorithm, **kwargs)\n code = dico.transform(data)\n patches = np.dot(code, V)\n\n patches += intercept\n patches = patches.reshape(len(data), *patch_size)\n if transform_algorithm == 'threshold':\n patches -= patches.min()\n patches /= patches.max()\n reconstructions[title][:, width // 2:] = reconstruct_from_patches_2d(\n patches, (height, width // 2))\n dt = time() - t0\n print('done in %.2fs.' % dt)\n show_with_diff(reconstructions[title], face,\n title + ' (time: %.1fs)' % dt)\n\nplt.show()"
0 commit comments