How to display the graphical decision tree for this scikit-learn decision tree script?

Asked by GayatriJaiteley in Data Science on Nov 4, 2019
Answered by Gayatri Jaiteley

from sklearn import tree

clf = tree.DecisionTreeClassifier()

# [height, weight, shoe_size]
X = [[181, 80, 44], [177, 70, 43], [160, 60, 38], [154, 54, 37], [166, 65, 40],
     [190, 90, 47], [175, 64, 39],
     [177, 70, 40], [159, 55, 37], [171, 75, 42], [181, 85, 43]]

Y = ['male', 'male', 'female', 'female', 'male', 'male', 'female', 'female',
     'female', 'male', 'male']

clf = clf.fit(X, Y)

prediction = clf.predict([[160, 60, 22]])
print(prediction)

The above script works fine, but it does not display the tree graphically. How can I do that?

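To display the tree graphically, export the fitted classifier to Graphviz DOT format with tree.export_graphviz and render it with the graphviz Python package. Rendering also requires the Graphviz executables to be installed on the system. The example below renders the tree: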

import graphviz
from sklearn import tree

clf = tree.DecisionTreeClassifier()

# [height, weight, shoe_size]
X = [[181, 80, 44], [177, 70, 43], [160, 60, 38], [154, 54, 37], [166, 65, 40],
     [190, 90, 47], [175, 64, 39],
     [177, 70, 40], [159, 55, 37], [171, 75, 42], [181, 85, 43]]

Y = ['male', 'male', 'female', 'female', 'male', 'male', 'female', 'female',
     'female', 'male', 'male']

clf = clf.fit(X, Y)

prediction = clf.predict([[160, 60, 22]])
print(prediction)

# Export the fitted tree as a DOT-format string (out_file=None returns it)
dot_data = tree.export_graphviz(clf, out_file=None)

# Wrap the DOT source and render it to a file named "gender"
graph = graphviz.Source(dot_data)
graph.render("gender")

The last line renders the decision tree to a PDF file, gender.pdf, in the current working directory.
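By default the exported nodes are labelled with generic names such as X[0]. As a minimal sketch of how the output could be made easier to read, the variant below passes feature and class names to export_graphviz, and also prints a plain-text view with tree.export_text for cases where Graphviz is not installed. The feature names and the random_state value are assumptions added for illustration; they are not part of the original script.

```python
from sklearn import tree

# Same training data as in the answer: [height, weight, shoe_size]
X = [[181, 80, 44], [177, 70, 43], [160, 60, 38], [154, 54, 37], [166, 65, 40],
     [190, 90, 47], [175, 64, 39],
     [177, 70, 40], [159, 55, 37], [171, 75, 42], [181, 85, 43]]
Y = ['male', 'male', 'female', 'female', 'male', 'male', 'female', 'female',
     'female', 'male', 'male']

# Assumed names for the three columns, used only to label the nodes
feature_names = ['height', 'weight', 'shoe_size']

# random_state is an assumption, set so the fitted tree is reproducible
clf = tree.DecisionTreeClassifier(random_state=0).fit(X, Y)

# DOT source with readable labels; class_names must follow the order
# of clf.classes_, which is sorted: ['female', 'male']
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=feature_names,
    class_names=list(clf.classes_),
    filled=True,    # colour nodes by majority class
    rounded=True,   # rounded node boxes
)

# Plain-text rendering of the same tree; needs no Graphviz install
text_tree = tree.export_text(clf, feature_names=feature_names)
print(text_tree)
```

The dot_data string can be passed to graphviz.Source and rendered exactly as in the answer above.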


