"# Author: Matt Terry <
[email protected]>\n#\n# License: BSD 3 clause\nfrom __future__ import print_function\n\nimport numpy as np\n\nfrom sklearn.base import BaseEstimator, TransformerMixin\nfrom sklearn.datasets import fetch_20newsgroups\nfrom sklearn.datasets.twenty_newsgroups import strip_newsgroup_footer\nfrom sklearn.datasets.twenty_newsgroups import strip_newsgroup_quoting\nfrom sklearn.decomposition import TruncatedSVD\nfrom sklearn.feature_extraction import DictVectorizer\nfrom sklearn.feature_extraction.text import TfidfVectorizer\nfrom sklearn.metrics import classification_report\nfrom sklearn.pipeline import FeatureUnion\nfrom sklearn.pipeline import Pipeline\nfrom sklearn.svm import SVC\n\n\nclass ItemSelector(BaseEstimator, TransformerMixin):\n \"\"\"For data grouped by feature, select subset of data at a provided key.\n\n The data is expected to be stored in a 2D data structure, where the first\n index is over features and the second is over samples. i.e.\n\n >> len(data[key]) == n_samples\n\n Please note that this is the opposite convention to scikit-learn feature\n matrixes (where the first index corresponds to sample).\n\n ItemSelector only requires that the collection implement getitem\n (data[key]). Examples include: a dict of lists, 2D numpy array, Pandas\n DataFrame, numpy record array, etc.\n\n >> data = {'a': [1, 5, 2, 5, 2, 8],\n 'b': [9, 4, 1, 4, 1, 3]}\n >> ds = ItemSelector(key='a')\n >> data['a'] == ds.transform(data)\n\n ItemSelector is not designed to handle data grouped by sample. (e.g. a\n list of dicts). If your data is structured this way, consider a\n transformer along the lines of `sklearn.feature_extraction.DictVectorizer`.\n\n Parameters\n ----------\n key : hashable, required\n The key corresponding to the desired value in a mappable.\n \"\"\"\n def __init__(self, key):\n self.key = key\n\n def fit(self, x, y=None):\n return self\n\n def transform(self, data_dict):\n return data_dict[self.key]\n\n\nclass TextStats(BaseEstimator, TransformerMixin):\n \"\"\"Extract features from each document for DictVectorizer\"\"\"\n\n def fit(self, x, y=None):\n return self\n\n def transform(self, posts):\n return [{'length': len(text),\n 'num_sentences': text.count('.')}\n for text in posts]\n\n\nclass SubjectBodyExtractor(BaseEstimator, TransformerMixin):\n \"\"\"Extract the subject & body from a usenet post in a single pass.\n\n Takes a sequence of strings and produces a dict of sequences. Keys are\n `subject` and `body`.\n \"\"\"\n def fit(self, x, y=None):\n return self\n\n def transform(self, posts):\n features = np.recarray(shape=(len(posts),),\n dtype=[('subject', object), ('body', object)])\n for i, text in enumerate(posts):\n headers, _, bod = text.partition('\\n\\n')\n bod = strip_newsgroup_footer(bod)\n bod = strip_newsgroup_quoting(bod)\n features['body'][i] = bod\n\n prefix = 'Subject:'\n sub = ''\n for line in headers.split('\\n'):\n if line.startswith(prefix):\n sub = line[len(prefix):]\n break\n features['subject'][i] = sub\n\n return features\n\n\npipeline = Pipeline([\n # Extract the subject & body\n ('subjectbody', SubjectBodyExtractor()),\n\n # Use FeatureUnion to combine the features from subject and body\n ('union', FeatureUnion(\n transformer_list=[\n\n # Pipeline for pulling features from the post's subject line\n ('subject', Pipeline([\n ('selector', ItemSelector(key='subject')),\n ('tfidf', TfidfVectorizer(min_df=50)),\n ])),\n\n # Pipeline for standard bag-of-words model for body\n ('body_bow', Pipeline([\n ('selector', ItemSelector(key='body')),\n ('tfidf', TfidfVectorizer()),\n ('best', TruncatedSVD(n_components=50)),\n ])),\n\n # Pipeline for pulling ad hoc features from post's body\n ('body_stats', Pipeline([\n ('selector', ItemSelector(key='body')),\n ('stats', TextStats()), # returns a list of dicts\n ('vect', DictVectorizer()), # list of dicts -> feature matrix\n ])),\n\n ],\n\n # weight components in FeatureUnion\n transformer_weights={\n 'subject': 0.8,\n 'body_bow': 0.5,\n 'body_stats': 1.0,\n },\n )),\n\n # Use a SVC classifier on the combined features\n ('svc', SVC(kernel='linear')),\n])\n\n# limit the list of categories to make running this example faster.\ncategories = ['alt.atheism', 'talk.religion.misc']\ntrain = fetch_20newsgroups(random_state=1,\n subset='train',\n categories=categories,\n )\ntest = fetch_20newsgroups(random_state=1,\n subset='test',\n categories=categories,\n )\n\npipeline.fit(train.data, train.target)\ny = pipeline.predict(test.data)\nprint(classification_report(y, test.target))"
0 commit comments