 
 """
 import numpy as np
+from matplotlib import pyplot as plt
 
 from sklearn.model_selection import train_test_split
 from sklearn.datasets import load_iris
 from sklearn.tree import DecisionTreeClassifier
+from sklearn import tree
+
+##############################################################################
+# Train tree classifier
+# ---------------------
+# First, we fit a :class:`~sklearn.tree.DecisionTreeClassifier` using the
+# :func:`~sklearn.datasets.load_iris` dataset.
 
 iris = load_iris()
 X = iris.data
 y = iris.target
 X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
 
-estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
-estimator.fit(X_train, y_train)
+clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
+clf.fit(X_train, y_train)
 
-# The decision estimator has an attribute called tree_ which stores the entire
-# tree structure and allows access to low level attributes. The binary tree
-# tree_ is represented as a number of parallel arrays. The i-th element of each
-# array holds information about the node `i`. Node 0 is the tree's root. NOTE:
-# Some of the arrays only apply to either leaves or split nodes, resp. In this
-# case the values of nodes of the other type are arbitrary!
+##############################################################################
+# Tree structure
+# --------------
 #
-# Among those arrays, we have:
-# - left_child, id of the left child of the node
-# - right_child, id of the right child of the node
-# - feature, feature used for splitting the node
-# - threshold, threshold value at the node
+# The decision classifier has an attribute called ``tree_`` which allows access
+# to low level attributes such as ``node_count``, the total number of nodes,
+# and ``max_depth``, the maximal depth of the tree. It also stores the
+# entire binary tree structure, represented as a number of parallel arrays. The
+# i-th element of each array holds information about the node ``i``. Node 0 is
+# the tree's root. Some of the arrays only apply to either leaves or split
+# nodes. In this case the values of the nodes of the other type are arbitrary.
+# For example, the arrays ``feature`` and ``threshold`` only apply to split
+# nodes. The values for leaf nodes in these arrays are therefore arbitrary.
 #
+# Among these arrays, we have:
+#
+# - ``children_left[i]``: id of the left child of node ``i`` or -1 if leaf
+#   node
+# - ``children_right[i]``: id of the right child of node ``i`` or -1 if leaf
+#   node
+# - ``feature[i]``: feature used for splitting node ``i``
+# - ``threshold[i]``: threshold value at node ``i``
+# - ``n_node_samples[i]``: the number of training samples reaching node
+#   ``i``
+# - ``impurity[i]``: the impurity at node ``i``
+#
+# Using the arrays, we can traverse the tree structure to compute various
+# properties. Below, we will compute the depth of each node and whether or not
+# it is a leaf.
 
-# Using those arrays, we can parse the tree structure:
-
-n_nodes = estimator.tree_.node_count
-children_left = estimator.tree_.children_left
-children_right = estimator.tree_.children_right
-feature = estimator.tree_.feature
-threshold = estimator.tree_.threshold
-
+n_nodes = clf.tree_.node_count
+children_left = clf.tree_.children_left
+children_right = clf.tree_.children_right
+feature = clf.tree_.feature
+threshold = clf.tree_.threshold
 
-# The tree structure can be traversed to compute various properties such
-# as the depth of each node and whether or not it is a leaf.
 node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
 is_leaves = np.zeros(shape=n_nodes, dtype=bool)
-stack = [(0, -1)]  # seed is the root node id and its parent depth
+stack = [(0, 0)]  # start with the root node id (0) and its depth (0)
 while len(stack) > 0:
-    node_id, parent_depth = stack.pop()
-    node_depth[node_id] = parent_depth + 1
-
-    # If we have a test node
-    if (children_left[node_id] != children_right[node_id]):
-        stack.append((children_left[node_id], parent_depth + 1))
-        stack.append((children_right[node_id], parent_depth + 1))
+    # `pop` ensures each node is only visited once
+    node_id, depth = stack.pop()
+    node_depth[node_id] = depth
+
+    # If the left and right child of a node are not the same we have a split
+    # node
+    is_split_node = children_left[node_id] != children_right[node_id]
+    # If a split node, append left and right children and depth to `stack`
+    # so we can loop through them
+    if is_split_node:
+        stack.append((children_left[node_id], depth + 1))
+        stack.append((children_right[node_id], depth + 1))
     else:
         is_leaves[node_id] = True
 
-print("The binary tree structure has %s nodes and has "
-      "the following tree structure:"
-      % n_nodes)
+print("The binary tree structure has {n} nodes and has "
+      "the following tree structure:\n".format(n=n_nodes))
 for i in range(n_nodes):
     if is_leaves[i]:
-        print("%snode=%s leaf node." % (node_depth[i] * "\t", i))
+        print("{space}node={node} is a leaf node.".format(
+            space=node_depth[i] * "\t", node=i))
     else:
-        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
-              "node %s."
-              % (node_depth[i] * "\t",
-                 i,
-                 children_left[i],
-                 feature[i],
-                 threshold[i],
-                 children_right[i],
-                 ))
-print()
-
-# First let's retrieve the decision path of each sample. The decision_path
-# method allows to retrieve the node indicator functions. A non zero element of
-# indicator matrix at the position (i, j) indicates that the sample i goes
-# through the node j.
-
-node_indicator = estimator.decision_path(X_test)
-
-# Similarly, we can also have the leaves ids reached by each sample.
-
-leave_id = estimator.apply(X_test)
+        print("{space}node={node} is a split node: "
+              "go to node {left} if X[:, {feature}] <= {threshold} "
+              "else to node {right}.".format(
+                  space=node_depth[i] * "\t",
+                  node=i,
+                  left=children_left[i],
+                  feature=feature[i],
+                  threshold=threshold[i],
+                  right=children_right[i]))
+
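As a quick sanity check on the traversal above, the depth and leaf counts derived from the parallel arrays can be compared with the estimator's own accessors. This is only an illustrative sketch, assuming the fitted ``clf`` and the arrays computed above:

    # the manual traversal should agree with the estimator's own accessors
    assert node_depth.max() == clf.get_depth()
    assert is_leaves.sum() == clf.get_n_leaves()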
+##############################################################################
+# We can compare the above output to the plot of the decision tree.
+
+tree.plot_tree(clf)
+plt.show()
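A purely textual rendering of the same tree can also be produced with :func:`sklearn.tree.export_text`. A minimal sketch, assuming the fitted ``clf`` and the ``iris`` dataset loaded above:

    from sklearn.tree import export_text

    # text report of the fitted tree, using the original feature names
    print(export_text(clf, feature_names=iris.feature_names))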
+
+##############################################################################
+# Decision path
+# -------------
+#
+# We can also retrieve the decision path of samples of interest. The
+# ``decision_path`` method outputs an indicator matrix that allows us to
+# retrieve the nodes the samples of interest traverse. A non-zero
+# element in the indicator matrix at position ``(i, j)`` indicates that
+# sample ``i`` goes through node ``j``. Or, for one sample ``i``, the
+# positions of the non-zero elements in row ``i`` of the indicator matrix
+# designate the ids of the nodes that sample goes through.
+#
+# The leaf ids reached by samples of interest can be obtained with the
+# ``apply`` method. This returns an array of the node ids of the leaves
+# reached by each sample of interest. Using the leaf ids and the
+# ``decision_path`` we can obtain the splitting conditions that were used to
+# predict a sample or a group of samples. First, let's do it for one sample.
+# Note that ``node_indicator`` is a sparse matrix.
 
-# Now, it's possible to get the tests that were used to predict a sample or
-# a group of samples. First, let's make it for the sample.
+node_indicator = clf.decision_path(X_test)
+leaf_id = clf.apply(X_test)
 
 sample_id = 0
+# obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id`
 node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                     node_indicator.indptr[sample_id + 1]]
 
-print('Rules used to predict sample %s: ' % sample_id)
+print('Rules used to predict sample {id}:\n'.format(id=sample_id))
 for node_id in node_index:
-    if leave_id[sample_id] == node_id:
+    # continue to the next node if it is a leaf node
+    if leaf_id[sample_id] == node_id:
         continue
 
+    # check if value of the split feature for sample 0 is below threshold
     if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
         threshold_sign = "<="
    else:
         threshold_sign = ">"
 
-    print("decision id node %s : (X_test[%s, %s] (= %s) %s %s)"
-          % (node_id,
-             sample_id,
-             feature[node_id],
-             X_test[sample_id, feature[node_id]],
-             threshold_sign,
-             threshold[node_id]))
+    print("decision node {node} : (X_test[{sample}, {feature}] = {value}) "
+          "{inequality} {threshold}".format(
+              node=node_id,
+              sample=sample_id,
+              feature=feature[node_id],
+              value=X_test[sample_id, feature[node_id]],
+              inequality=threshold_sign,
+              threshold=threshold[node_id]))
+
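The leaf a sample ends up in also determines its predicted class: the prediction is the majority class of that leaf, stored in the ``value`` array of ``tree_``. A small sketch checking this for the sample above, assuming ``clf``, ``leaf_id``, and ``sample_id`` as defined earlier:

    # per-class values stored at the leaf reached by the sample
    leaf_values = clf.tree_.value[leaf_id[sample_id]]
    # the majority class at the leaf matches the estimator's prediction
    assert clf.classes_[np.argmax(leaf_values)] == clf.predict(X_test[[sample_id]])[0]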
+##############################################################################
+# For a group of samples, we can determine the common nodes the samples go
+# through.
 
-# For a group of samples, we have the following common node.
 sample_ids = [0, 1]
+# boolean array indicating the nodes both samples go through
 common_nodes = (node_indicator.toarray()[sample_ids].sum(axis=0) ==
                 len(sample_ids))
-
+# obtain node ids using position in array
 common_node_id = np.arange(n_nodes)[common_nodes]
 
-print("\nThe following samples %s share the node %s in the tree"
-      % (sample_ids, common_node_id))
-print("It is %s %% of all nodes." % (100 * len(common_node_id) / n_nodes,))
+print("\nThe following samples {samples} share the node(s) {nodes} in the "
+      "tree.".format(samples=sample_ids, nodes=common_node_id))
+print("This is {prop}% of all nodes.".format(
+    prop=100 * len(common_node_id) / n_nodes))
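The indicator matrix can also be aggregated over all test samples, for example to count how many of them pass through each node. A short sketch, assuming ``node_indicator`` from above:

    # summing the sparse indicator matrix over rows gives one count per node
    samples_per_node = np.asarray(node_indicator.sum(axis=0)).ravel().astype(int)
    for node_id, count in enumerate(samples_per_node):
        print("node {node} is reached by {count} test samples".format(
            node=node_id, count=count))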