"""
=========================================
Understanding the decision tree structure
=========================================

The decision tree structure can be analysed to gain further insight into the
relation between the features and the target to predict. In this example, we
show how to retrieve:

- the binary tree structure;
- the depth of each node and whether or not it is a leaf;
- the nodes that were reached by a sample using the ``decision_path`` method;
- the leaf that was reached by a sample using the ``apply`` method;
- the rules that were used to predict a sample;
- the decision path shared by a group of samples.

"""
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

# The decision estimator has an attribute called tree_ which stores the entire
# tree structure and allows access to low-level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of
# each array holds information about the node `i`. Node 0 is the tree's root.
# NOTE: some of the arrays only apply to either leaves or split nodes; in that
# case the values for nodes of the other type are arbitrary!
#
# Among those arrays, we have:
# - left_child, id of the left child of the node
# - right_child, id of the right child of the node
# - feature, index of the feature used for splitting the node
# - threshold, threshold value at the node
#

# Using those arrays, we can parse the tree structure:

n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

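# As a quick illustration (not part of the original walk-through), leaves can
# be spotted directly in these arrays: scikit-learn stores -1 (its internal
# TREE_LEAF sentinel) as the child ids of a leaf, and the corresponding
# ``feature`` / ``threshold`` entries are placeholders that should not be read.
leaf_mask = children_left == -1
print("leaf node ids:", np.where(leaf_mask)[0])
print("placeholder 'feature' values at leaves:", feature[leaf_mask])
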
# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1

    # If we have a test node (split node), push both children on the stack
    if children_left[node_id] != children_right[node_id]:
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True

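# A quick consistency check (a sketch, assuming ``tree_.max_depth`` reports
# the depth of the deepest node): the manual traversal above should agree
# with the depth stored on the fitted tree.
print("computed max depth: %d, tree_.max_depth: %d"
      % (node_depth.max(), estimator.tree_.max_depth))
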
print("The binary tree structure has %s nodes and has "
      "the following tree structure:"
      % n_nodes)
for i in range(n_nodes):
    if is_leaves[i]:
        print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
    else:
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s."
              % (node_depth[i] * "\t",
                 i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))
print()
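
# As an optional cross-check (a sketch, not part of the original example), the
# same structure can be exported in Graphviz "dot" format with
# ``sklearn.tree.export_graphviz``; passing ``out_file=None`` to get the dot
# source back as a string assumes a reasonably recent scikit-learn version.
from sklearn.tree import export_graphviz

dot_source = export_graphviz(estimator, out_file=None)
print(dot_source.splitlines()[0])  # e.g. the "digraph Tree {" header line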

# First let's retrieve the decision path of each sample. The decision_path
# method returns an indicator matrix: a non-zero element at position (i, j)
# indicates that sample i goes through node j.

node_indicator = estimator.decision_path(X_test)

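# A small illustration (a sketch, assuming the indicator is returned as a
# scipy sparse matrix, which is how current scikit-learn versions implement
# it): its shape is (n_samples, n_nodes), and the non-zero columns of a row
# are exactly the nodes visited by that sample.
print("indicator matrix shape (n_samples, n_nodes):", node_indicator.shape)
print("nodes visited by sample 0:", node_indicator[0].nonzero()[1])
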
# Similarly, we can also have the leaf ids reached by each sample.

leave_id = estimator.apply(X_test)

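# As an extra illustration (a sketch, not in the original walk-through), the
# leaf ids can be aggregated to see how the test samples spread over the
# leaves of the fitted tree:
for leaf in np.unique(leave_id):
    print("leaf node %d receives %d test samples"
          % (leaf, np.sum(leave_id == leaf)))
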
# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's do it for a single sample.

sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]
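# ``indices`` and ``indptr`` are the raw CSR buffers of the sparse indicator
# matrix: the slice above selects the column (node) ids of the non-zero
# entries in row ``sample_id``, i.e. the nodes traversed by that sample.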

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
    # skip the leaf node that terminates the path: it carries no test
    if leave_id[sample_id] == node_id:
        continue

    if X_test[sample_id, feature[node_id]] <= threshold[node_id]:
        threshold_sign = "<="
    else:
        threshold_sign = ">"

    print("decision id node %s : (X_test[%s, %s] (= %s) %s %s)"
          % (node_id,
             sample_id,
             feature[node_id],
             X_test[sample_id, feature[node_id]],
             threshold_sign,
             threshold[node_id]))

# For a group of samples, we can retrieve the nodes that the samples have in
# common within the tree.
sample_ids = [0, 1]
common_nodes = (node_indicator.toarray()[sample_ids].sum(axis=0) ==
                len(sample_ids))

common_node_id = np.arange(n_nodes)[common_nodes]
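
# An equivalent, sparse-friendly variant (a sketch; it avoids densifying the
# indicator matrix with ``toarray()``, which may matter for large trees or
# many samples):
common_nodes_sparse = np.asarray(
    node_indicator[sample_ids].sum(axis=0) == len(sample_ids)).ravel()
assert np.array_equal(common_nodes, common_nodes_sparse)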

print("\nThe following samples %s share the node(s) %s in the tree"
      % (sample_ids, common_node_id))
print("It is %s %% of all nodes." % (100 * len(common_node_id) / n_nodes,))
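
# A possible follow-up (a sketch): list which features are tested at the
# shared split nodes of the common path.
for node_id in common_node_id:
    if children_left[node_id] != children_right[node_id]:  # split node
        print("shared split node %d tests feature %d against threshold %.3f"
              % (node_id, feature[node_id], threshold[node_id]))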