Displaying a scikit-learn decision tree figure in a Jupyter notebook

13,197

Solution 1

The graphviz library can render your decision tree. With view=True, render opens the rendered image straight away, so there is no separate step of exporting a picture and loading it back. Note that render always writes the output file (image.pdf by default), whether or not it is opened. You can use it as follows:

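import graphviz
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier

# trainX is assumed to be a pandas DataFrame of features, trainY the labels
clf = DecisionTreeClassifier()
clf.fit(trainX, trainY)
columns = list(trainX.columns)

# out_file=None returns the DOT source as a string instead of writing a file
dot_data = tree.export_graphviz(clf, out_file=None, feature_names=columns, class_names=True)
graph = graphviz.Source(dot_data)
graph.render("image", view=True)  # writes image.pdf and opens it

# Optionally keep the DOT source; the classifiers/ directory must already exist
with open("classifiers/classifier.txt", "w") as f:
    f.write(dot_data)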

Because of view=True, the graph opens as soon as it is rendered. If you only want to save the graph without opening it, use view=False.

Hope this helps

Solution 2

You can show the tree directly using IPython.display:

import graphviz
from IPython.display import display
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor, export_graphviz

# Generate a simple dataset
X, y = make_regression(n_features=2, n_informative=2, random_state=0)
clf = DecisionTreeRegressor(random_state=0, max_depth=2)
clf.fit(X, y)

# Visualize the tree inline in the notebook
display(graphviz.Source(export_graphviz(clf)))
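
If the default rendering is hard to read, export_graphviz accepts a few styling options. The sketch below is one possible variation on the snippet above, not part of the original answer: it builds the DOT source with named features, filled nodes and rounded corners. The feature names x0 and x1 are placeholders chosen for this example.

```python
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor, export_graphviz

# Same toy regression problem as in the answer above
X, y = make_regression(n_features=2, n_informative=2, random_state=0)
reg = DecisionTreeRegressor(random_state=0, max_depth=2).fit(X, y)

# out_file=None makes export_graphviz return the DOT source as a string.
# "x0" and "x1" are placeholder feature names for this illustration.
dot_data = export_graphviz(
    reg,
    out_file=None,
    feature_names=["x0", "x1"],
    filled=True,     # colour nodes by predicted value
    rounded=True,    # rounded node boxes
    precision=2,     # digits shown for thresholds and values
)

# Show the first line of the generated DOT source
print(dot_data.splitlines()[0])
```

In a notebook, passing dot_data to graphviz.Source and making that the last expression of the cell (or calling display on it) should render the styled tree inline, provided the Graphviz binaries are installed.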

Solution 3

As of scikit-learn version 0.21 (released in May 2019), decision trees can be plotted with matplotlib using scikit-learn's tree.plot_tree, without relying on graphviz.

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

X, y = load_iris(return_X_y=True)

# Make an instance of the Model
clf = DecisionTreeClassifier(max_depth = 5)

# Train the model on the data
clf.fit(X, y)

fn=['sepal length (cm)','sepal width (cm)','petal length (cm)','petal width (cm)']
cn=['setosa', 'versicolor', 'virginica']

# Setting dpi = 300 to make image clearer than default
fig, axes = plt.subplots(nrows = 1,ncols = 1,figsize = (4,4), dpi=300)

tree.plot_tree(clf,
           feature_names = fn, 
           class_names=cn,
           filled = True);

# You can save your plot if you want
#fig.savefig('imagename.png')

The tree figure is rendered directly in your Jupyter notebook.

The code was adapted from this post.
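
For deeper trees the full plot can get crowded. As a hedged variation on the code above (not from the original post), plot_tree also takes max_depth and fontsize arguments, and the feature and class names can be read from the dataset object instead of being typed by hand. The figure size and the file name in the comment are arbitrary choices for this sketch.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=5, random_state=0).fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(10, 6))

# plot_tree draws onto the given axes and returns the node annotations
annotations = plot_tree(
    clf,
    max_depth=2,                          # draw only the top levels of the tree
    feature_names=iris.feature_names,     # names taken from the dataset
    class_names=list(iris.target_names),
    filled=True,
    fontsize=8,
    ax=ax,
)

# Optional: save the figure as well as showing it inline
# fig.savefig("tree_top_levels.png", dpi=300)
```

Passing ax explicitly keeps the tree on the figure you created, which matters once a notebook cell contains more than one plot.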

Solution 4

There are 4 methods which I'm aware of for plotting the scikit-learn decision tree:

  • print the text representation of the tree with sklearn.tree.export_text method
  • plot with sklearn.tree.plot_tree method (matplotlib needed)
  • plot with sklearn.tree.export_graphviz method (graphviz needed)
  • plot with dtreeviz package (dtreeviz and graphviz needed)

You can find a comparison of different visualization of sklearn decision tree with code snippets in this blog post: link.
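
The first option in the list needs neither matplotlib nor graphviz. A minimal sketch, assuming the iris dataset as example data (the original answer does not show this code):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# export_text returns the tree as an indented plain-text string
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```

Because the result is an ordinary string, it displays in any notebook cell and can be written to a log or text file unchanged.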

When using a Jupyter notebook, remember to end the cell with the variable that holds the plot so that it is displayed. Example for dtreeviz:

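from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from dtreeviz.trees import dtreeviz  # pre-2.0 dtreeviz import path; newer releases use a different entry point

# Fit a small classifier on iris so the snippet is self-contained
iris = load_iris()
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(max_depth=3).fit(X, y)

viz = dtreeviz(clf, X, y,
               target_name="target",
               feature_names=iris.feature_names,
               class_names=list(iris.target_names))

viz  # last expression in the cell, so the notebook displays the tree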
Author: Jürgen Erhardt

Updated on June 17, 2022

Comments

  • Jürgen Erhardt, almost 2 years

    I am currently creating a machine learning jupyter notebook as a small project and wanted to display my decision trees. However, all options I can find are to export the graphics and then load a picture, which is rather complicated.

    Therefore, I wanted to ask whether there is a way to display my decision trees directly without exporting and loading graphics.

  • Laksh Matai, over 5 years
    I'm glad I could help you :)