"""
======================================================
Classification of text documents using sparse features
======================================================

This is an example showing how scikit-learn can be used to classify documents
by topics using a bag-of-words approach. This example uses a scipy.sparse
matrix to store the features and demonstrates various classifiers that can
efficiently handle sparse matrices.

The dataset used in this example is the 20 newsgroups dataset. It will be
automatically downloaded, then cached.

The bar plot indicates the f1-score, training time (normalized) and test time
(normalized) of each classifier.

"""

# Author: Peter Prettenhofer <[email protected]>
#         Olivier Grisel <[email protected]>
#         Mathieu Blondel <[email protected]>
#         Lars Buitinck <[email protected]>
# License: BSD 3 clause

from __future__ import print_function

import logging
import numpy as np
from optparse import OptionParser
import sys
from time import time
import matplotlib.pyplot as plt

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.feature_selection import SelectKBest, chi2
from sklearn.linear_model import RidgeClassifier
from sklearn.svm import LinearSVC
from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import Perceptron
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.naive_bayes import BernoulliNB, MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import NearestCentroid
from sklearn.utils.extmath import density
from sklearn import metrics


# Display progress logs on stdout
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s %(levelname)s %(message)s')


# parse command line arguments
op = OptionParser()
op.add_option("--report",
              action="store_true", dest="print_report",
              help="Print a detailed classification report.")
op.add_option("--chi2_select",
              action="store", type="int", dest="select_chi2",
              help="Select some number of features using a chi-squared test")
op.add_option("--confusion_matrix",
              action="store_true", dest="print_cm",
              help="Print the confusion matrix.")
op.add_option("--top10",
              action="store_true", dest="print_top10",
              help="Print ten most discriminative terms per class"
                   " for every classifier.")
op.add_option("--all_categories",
              action="store_true", dest="all_categories",
              help="Whether to use all categories or not.")
op.add_option("--use_hashing",
              action="store_true",
              help="Use a hashing vectorizer.")
op.add_option("--n_features",
              action="store", type="int", default=2 ** 16,
              help="n_features when using the hashing vectorizer.")
op.add_option("--filtered",
              action="store_true",
              help="Remove newsgroup information that is easily overfit: "
                   "headers, signatures, and quoting.")

(opts, args) = op.parse_args()
if len(args) > 0:
    op.error("this script takes no arguments.")
    sys.exit(1)

print(__doc__)
op.print_help()
print()


###############################################################################
# Load some categories from the training set
if opts.all_categories:
    categories = None
else:
    categories = [
        'alt.atheism',
        'talk.religion.misc',
        'comp.graphics',
        'sci.space',
    ]

if opts.filtered:
    remove = ('headers', 'footers', 'quotes')
else:
    remove = ()

print("Loading 20 newsgroups dataset for categories:")
print(categories if categories else "all")

data_train = fetch_20newsgroups(subset='train', categories=categories,
                                shuffle=True, random_state=42,
                                remove=remove)

data_test = fetch_20newsgroups(subset='test', categories=categories,
                               shuffle=True, random_state=42,
                               remove=remove)
print('data loaded')

categories = data_train.target_names    # for the case categories == None


def size_mb(docs):
    return sum(len(s.encode('utf-8')) for s in docs) / 1e6

data_train_size_mb = size_mb(data_train.data)
data_test_size_mb = size_mb(data_test.data)

print("%d documents - %0.3fMB (training set)" % (
    len(data_train.data), data_train_size_mb))
print("%d documents - %0.3fMB (test set)" % (
    len(data_test.data), data_test_size_mb))
print("%d categories" % len(categories))
print()

# the 20 newsgroups dataset comes with a predefined train/test split
y_train, y_test = data_train.target, data_test.target

print("Extracting features from the training dataset using a sparse vectorizer")
t0 = time()
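# Note: HashingVectorizer is stateless (no vocabulary to fit) and its memory
# use is bounded by n_features, but hashed features cannot be mapped back to
# token strings. TfidfVectorizer learns its vocabulary and IDF weights from
# the data, so it is fit on the training set only.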
if opts.use_hashing:
    vectorizer = HashingVectorizer(stop_words='english', non_negative=True,
                                   n_features=opts.n_features)
    X_train = vectorizer.transform(data_train.data)
else:
    vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5,
                                 stop_words='english')
    X_train = vectorizer.fit_transform(data_train.data)
duration = time() - t0
print("done in %fs at %0.3fMB/s" % (duration, data_train_size_mb / duration))
print("n_samples: %d, n_features: %d" % X_train.shape)
print()

print("Extracting features from the test dataset using the same vectorizer")
t0 = time()
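# reuse the already-fitted vectorizer: calling fit_transform here would leak
# the test set's vocabulary and document frequencies into the features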
X_test = vectorizer.transform(data_test.data)
duration = time() - t0
print("done in %fs at %0.3fMB/s" % (duration, data_test_size_mb / duration))
print("n_samples: %d, n_features: %d" % X_test.shape)
print()

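# chi2 requires non-negative feature values, which both TF-IDF weights and
# the hashing vectorizer (thanks to non_negative=True) provide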
if opts.select_chi2:
    print("Extracting %d best features by a chi-squared test" %
          opts.select_chi2)
    t0 = time()
    ch2 = SelectKBest(chi2, k=opts.select_chi2)
    X_train = ch2.fit_transform(X_train, y_train)
    X_test = ch2.transform(X_test)
    print("done in %fs" % (time() - t0))
    print()


def trim(s):
    """Trim string to fit on terminal (assuming 80-column display)"""
    return s if len(s) <= 80 else s[:77] + "..."


# mapping from integer feature name to original token string
if opts.use_hashing:
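    # feature hashing is one-way: column indices cannot be mapped back to
    # the original token strings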
    feature_names = None
else:
    feature_names = np.asarray(vectorizer.get_feature_names())


###############################################################################
# Benchmark classifiers
def benchmark(clf):
    print('_' * 80)
    print("Training: ")
    print(clf)
    t0 = time()
    clf.fit(X_train, y_train)
    train_time = time() - t0
    print("train time: %0.3fs" % train_time)

    t0 = time()
    pred = clf.predict(X_test)
    test_time = time() - t0
    print("test time: %0.3fs" % test_time)

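    # f1_score combines per-class precision and recall; for multiclass
    # targets it returns an average over the classes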
    score = metrics.f1_score(y_test, pred)
    print("f1-score: %0.3f" % score)

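    # linear models expose their learned weights as coef_ (one row per
    # class); density reports the fraction of nonzero coefficients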
    if hasattr(clf, 'coef_'):
        print("dimensionality: %d" % clf.coef_.shape[1])
        print("density: %f" % density(clf.coef_))

        if opts.print_top10 and feature_names is not None:
            print("top 10 keywords per class:")
            for i, category in enumerate(categories):
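                # argsort is ascending, so the last ten indices point at
                # the largest weights, i.e. the strongest terms per class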
                top10 = np.argsort(clf.coef_[i])[-10:]
                print(trim("%s: %s"
                      % (category, " ".join(feature_names[top10]))))
        print()

    if opts.print_report:
        print("classification report:")
        print(metrics.classification_report(y_test, pred,
                                            target_names=categories))

    if opts.print_cm:
        print("confusion matrix:")
        print(metrics.confusion_matrix(y_test, pred))

    print()
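    # keep only the class name (strip the parameter repr) for plot labels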
    clf_descr = str(clf).split('(')[0]
    return clf_descr, score, train_time, test_time


results = []
for clf, name in (
        (RidgeClassifier(tol=1e-2, solver="lsqr"), "Ridge Classifier"),
        (Perceptron(n_iter=50), "Perceptron"),
        (PassiveAggressiveClassifier(n_iter=50), "Passive-Aggressive"),
        (KNeighborsClassifier(n_neighbors=10), "kNN")):
    print('=' * 80)
    print(name)
    results.append(benchmark(clf))

for penalty in ["l2", "l1"]:
    print('=' * 80)
    print("%s penalty" % penalty.upper())
    # Train Liblinear model
    results.append(benchmark(LinearSVC(loss='l2', penalty=penalty,
                                       dual=False, tol=1e-3)))

    # Train SGD model
    results.append(benchmark(SGDClassifier(alpha=.0001, n_iter=50,
                                           penalty=penalty)))

# Train SGD with Elastic Net penalty
print('=' * 80)
print("Elastic-Net penalty")
results.append(benchmark(SGDClassifier(alpha=.0001, n_iter=50,
                                       penalty="elasticnet")))

# Train NearestCentroid without threshold
print('=' * 80)
print("NearestCentroid (aka Rocchio classifier)")
results.append(benchmark(NearestCentroid()))

# Train sparse Naive Bayes classifiers
print('=' * 80)
print("Naive Bayes")
results.append(benchmark(MultinomialNB(alpha=.01)))
results.append(benchmark(BernoulliNB(alpha=.01)))


class L1LinearSVC(LinearSVC):

    def fit(self, X, y):
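        # An L1-penalized LinearSVC drives most coefficients to exactly
        # zero; its transform()/fit_transform() (available in this
        # scikit-learn version) keeps only the nonzero-weight features,
        # acting as embedded feature selection.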
        # The smaller C, the stronger the regularization.
        # The more regularization, the more sparsity.
        self.transformer_ = LinearSVC(penalty="l1",
                                      dual=False, tol=1e-3)
        X = self.transformer_.fit_transform(X, y)
        return LinearSVC.fit(self, X, y)

    def predict(self, X):
        X = self.transformer_.transform(X)
        return LinearSVC.predict(self, X)

print('=' * 80)
print("LinearSVC with L1-based feature selection")
results.append(benchmark(L1LinearSVC()))


# make some plots

indices = np.arange(len(results))

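# transpose the list of (name, score, train_time, test_time) tuples into
# four parallel lists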
results = [[x[i] for x in results] for i in range(4)]

clf_names, score, training_time, test_time = results
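# normalize timings by the slowest run so they share the [0, 1] axis
# with the score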
training_time = np.array(training_time) / np.max(training_time)
test_time = np.array(test_time) / np.max(test_time)

plt.figure(figsize=(12, 8))
plt.title("Score")
plt.barh(indices, score, .2, label="score", color='r')
plt.barh(indices + .3, training_time, .2, label="training time", color='g')
plt.barh(indices + .6, test_time, .2, label="test time", color='b')
plt.yticks(())
plt.legend(loc='best')
plt.subplots_adjust(left=.25)
plt.subplots_adjust(top=.95)
plt.subplots_adjust(bottom=.05)

for i, c in zip(indices, clf_names):
    plt.text(-.3, i, c)

plt.show()