Commit f3ac505

Pushing the docs to dev/ for branch: master, commit d46663c4812fda90c5cc33478cdba49201df1ae6
1 parent 1f83375 commit f3ac505

File tree: 1,211 files changed, +4190 −3791 lines

dev/_downloads/ae34d7cb4f9af673651750533e89c8cc/plot_unveil_tree_structure.py

Lines changed: 118 additions & 72 deletions
"""
import numpy as np
from matplotlib import pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

##############################################################################
# Train tree classifier
# ---------------------
# First, we fit a :class:`~sklearn.tree.DecisionTreeClassifier` using the
# :func:`~sklearn.datasets.load_iris` dataset.

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)

##############################################################################
# Tree structure
# --------------
#
# The decision tree classifier has an attribute called ``tree_`` which allows
# access to low level attributes such as ``node_count``, the total number of
# nodes, and ``max_depth``, the maximal depth of the tree. It also stores the
# entire binary tree structure, represented as a number of parallel arrays.
# The i-th element of each array holds information about node ``i``. Node 0
# is the tree's root. Some of the arrays only apply to either leaves or split
# nodes. In this case the values of the nodes of the other type are
# arbitrary. For example, the arrays ``feature`` and ``threshold`` only apply
# to split nodes; the values for leaf nodes in these arrays are therefore
# arbitrary.
#
# Among these arrays, we have:
#
# - ``children_left[i]``: id of the left child of node ``i`` or -1 if leaf
#   node
# - ``children_right[i]``: id of the right child of node ``i`` or -1 if leaf
#   node
# - ``feature[i]``: feature used for splitting node ``i``
# - ``threshold[i]``: threshold value at node ``i``
# - ``n_node_samples[i]``: the number of training samples reaching node ``i``
# - ``impurity[i]``: the impurity at node ``i``
#
# Using these arrays, we can traverse the tree structure to compute various
# properties. Below, we will compute the depth of each node and whether or
# not it is a leaf.

n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold
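
# The remaining arrays from the list above can be read off ``clf.tree_`` in
# the same way; they are shown here only as an illustrative aside and are not
# used in the rest of the example:
n_node_samples = clf.tree_.n_node_samples  # training samples reaching node i
impurity = clf.tree_.impurity  # impurity at node i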

node_depth = np.zeros(shape=n_nodes, dtype=np.int64)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, 0)]  # start with the root node id (0) and its depth (0)
while len(stack) > 0:
    # `pop` ensures each node is only visited once
    node_id, depth = stack.pop()
    node_depth[node_id] = depth

    # If the left and right child of a node are not the same we have a split
    # node
    is_split_node = children_left[node_id] != children_right[node_id]
    # If a split node, append the left and right children and their depth to
    # `stack` so we can loop through them
    if is_split_node:
        stack.append((children_left[node_id], depth + 1))
        stack.append((children_right[node_id], depth + 1))
    else:
        is_leaves[node_id] = True

print("The binary tree structure has {n} nodes and has "
      "the following tree structure:\n".format(n=n_nodes))
for i in range(n_nodes):
    if is_leaves[i]:
        print("{space}node={node} is a leaf node.".format(
            space=node_depth[i] * "\t", node=i))
    else:
        print("{space}node={node} is a split node: "
              "go to node {left} if X[:, {feature}] <= {threshold} "
              "else to node {right}.".format(
                  space=node_depth[i] * "\t",
                  node=i,
                  left=children_left[i],
                  feature=feature[i],
                  threshold=threshold[i],
                  right=children_right[i]))

##############################################################################
# We can compare the above output to the plot of the decision tree.

tree.plot_tree(clf)
plt.show()
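
##############################################################################
# As a brief aside, the same structure can also be rendered as plain text.
# This is a minimal sketch assuming a scikit-learn version that provides
# :func:`~sklearn.tree.export_text` (added in 0.21):

print(tree.export_text(clf, feature_names=iris["feature_names"]))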

##############################################################################
# Decision path
# -------------
#
# We can also retrieve the decision path of samples of interest. The
# ``decision_path`` method outputs an indicator matrix that allows us to
# retrieve the nodes the samples of interest traverse through. A non-zero
# element in the indicator matrix at position ``(i, j)`` indicates that
# sample ``i`` goes through node ``j``. Equivalently, for one sample ``i``,
# the positions of the non-zero elements in row ``i`` of the indicator matrix
# designate the ids of the nodes that sample goes through.
#
# The leaf ids reached by samples of interest can be obtained with the
# ``apply`` method. This returns an array of the node ids of the leaves
# reached by each sample of interest. Using the leaf ids and the
# ``decision_path`` we can obtain the splitting conditions that were used to
# predict a sample or a group of samples. First, let's do it for one sample.
# Note that ``node_indicator`` is a sparse matrix.

node_indicator = clf.decision_path(X_test)
leaf_id = clf.apply(X_test)

sample_id = 0
# obtain ids of the nodes `sample_id` goes through, i.e., row `sample_id`
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]
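
# For illustration, a minimal cross-check of the row description above: the
# positions of the non-zero entries in the dense form of row `sample_id`
# designate the same node ids.
dense_row = node_indicator.toarray()[sample_id]
assert set(np.nonzero(dense_row)[0]) == set(node_index)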

print('Rules used to predict sample {id}:\n'.format(id=sample_id))
for node_id in node_index:
    # continue to the next node if it is a leaf node
    if leaf_id[sample_id] == node_id:
        continue

    # check if value of the split feature for sample 0 is below threshold
    if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
        threshold_sign = "<="
    else:
        threshold_sign = ">"

    print("decision node {node} : (X_test[{sample}, {feature}] = {value}) "
          "{inequality} {threshold}".format(
              node=node_id,
              sample=sample_id,
              feature=feature[node_id],
              value=X_test[sample_id, feature[node_id]],
              inequality=threshold_sign,
              threshold=threshold[node_id]))

##############################################################################
# For a group of samples, we can determine the common nodes the samples go
# through.

sample_ids = [0, 1]
# boolean array indicating the nodes both samples go through
common_nodes = (node_indicator.toarray()[sample_ids].sum(axis=0) ==
                len(sample_ids))
# obtain node ids using position in array
common_node_id = np.arange(n_nodes)[common_nodes]

print("\nThe following samples {samples} share the node(s) {nodes} in the "
      "tree.".format(samples=sample_ids, nodes=common_node_id))
print("This is {prop}% of all nodes.".format(
    prop=100 * len(common_node_id) / n_nodes))
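
##############################################################################
# For large trees it may help to avoid densifying the indicator matrix with
# ``toarray()``. A minimal sparse sketch of the same computation (our
# variation, giving the same result):

common_nodes_sparse = np.asarray(
    node_indicator[sample_ids].sum(axis=0)).ravel() == len(sample_ids)
assert (common_nodes_sparse == common_nodes).all()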
