Visualising the decision tree in sklearn

Solution 1

I assigned the fitted tree to a variable and added plt.show(). This works for me.

# Jupyter-only magic; omit this line in a plain Python script
%matplotlib inline
from sklearn import tree
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()
x = cancer.data
y = cancer.target
clf = DecisionTreeClassifier(max_depth=1000)
x_train, x_test, y_train, y_test = train_test_split(x, y)

# fit() returns the fitted estimator itself, so keep it in a variable
clf = clf.fit(x_train, y_train)
tree.plot_tree(clf)  # requires scikit-learn >= 0.21
plt.show()
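For a plain script, or when the default plot is too cramped to read, the code above can be adapted to draw onto an explicitly sized figure, label nodes with the dataset's feature and class names, and save the result to a file. This is a sketch that assumes scikit-learn >= 0.21. The figure size, `max_depth=3`, `random_state=0` and the file name `tree.png` are arbitrary choices made for illustration. In a notebook, drop the `matplotlib.use("Agg")` line so the plot still displays inline.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs without a display
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

cancer = load_breast_cancer()
x_train, x_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, random_state=0
)

# A shallow tree (max_depth=3 is an arbitrary choice) keeps the plot readable
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(x_train, y_train)

fig, ax = plt.subplots(figsize=(16, 8))
annotations = tree.plot_tree(
    clf,
    feature_names=list(cancer.feature_names),
    class_names=list(cancer.target_names),
    filled=True,   # shade nodes by their majority class
    rounded=True,  # draw node boxes with rounded corners
    ax=ax,
)
fig.savefig("tree.png")  # hypothetical output file name
print(f"Drew {len(annotations)} annotations; test accuracy {clf.score(x_test, y_test):.3f}")
```

Passing `ax=` lets you control the figure size, which is usually what makes a deeper tree legible.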

But I recommend using Graphviz; it's much more flexible.
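Here is a minimal sketch of the Graphviz route. It uses `sklearn.tree.export_graphviz`, which returns the tree as DOT source text when `out_file=None`. Turning the DOT text into an image needs the Graphviz software, which is not part of scikit-learn. The file name `tree.dot` and the depth limit are illustrative assumptions. In a notebook, the optional `graphviz` Python package can display the same string with `graphviz.Source(dot_data)`.

```python
from sklearn import tree
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

cancer = load_breast_cancer()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(cancer.data, cancer.target)

# With out_file=None, export_graphviz returns the DOT source as a string
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=list(cancer.feature_names),
    class_names=list(cancer.target_names),
    filled=True,
    rounded=True,
)

# Save the DOT text; render it from a shell with Graphviz installed:
#   dot -Tpng tree.dot -o tree.png
with open("tree.dot", "w") as f:  # hypothetical file name
    f.write(dot_data)

print(dot_data.splitlines()[0])
```

Because the output is plain DOT text, you can restyle nodes or edges, or render to SVG or PDF, without touching the model code.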

Solution 2

Upgrade the scikit-learn package. Note that its name on PyPI is scikit-learn; sklearn is only the import name:

pip install --upgrade scikit-learn
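To confirm which version pip actually installed, you can read the distribution metadata without importing the package. This sketch assumes Python 3.8 or later, where `importlib.metadata` is in the standard library. It queries the distribution name `scikit-learn`, not the import name `sklearn`.

```python
from importlib.metadata import PackageNotFoundError, version

# The distribution on PyPI is "scikit-learn"; "sklearn" is only the import name
try:
    installed = version("scikit-learn")
except PackageNotFoundError:
    installed = None  # scikit-learn is not installed in this environment

print(f"scikit-learn installed: {installed}")
```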

Solution 3

This is because plot_tree is new in scikit-learn version 0.21, as indicated in the documentation. Check whether your installed version is recent enough by running this:

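import re
import sklearn

print(sklearn.__version__)

# Compare (major, minor) as integers; slicing the string would break for 1.x releases
major, minor = map(int, re.match(r"(\d+)\.(\d+)", sklearn.__version__).groups())
assert (major, minor) >= (0, 21), 'sklearn version insufficient.'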

If the assertion fails, you need to upgrade scikit-learn:

pip install --upgrade scikit-learn
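An alternative to parsing version strings is to test for the function directly and fall back when it is missing. In this sketch, the fallback assumes an older release in which `export_graphviz` already accepts `out_file=None` and returns a string. The dataset and depth are arbitrary choices.

```python
from sklearn import tree
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

cancer = load_breast_cancer()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(cancer.data, cancer.target)

if hasattr(tree, "plot_tree"):
    # scikit-learn >= 0.21: draw the tree with matplotlib
    import matplotlib
    matplotlib.use("Agg")  # non-interactive backend for scripts
    tree.plot_tree(clf)
    used = "plot_tree"
else:
    # Older releases: produce DOT text to render with Graphviz instead
    dot_data = tree.export_graphviz(clf, out_file=None)
    print(f"DOT source has {len(dot_data)} characters")
    used = "export_graphviz"

print(f"Visualised with {used}")
```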
Author: Roshan

Updated on July 17, 2021

Comments

  • Roshan (almost 3 years ago)

    When I try to visualise the tree, I get the error below.

    I have imported the required libraries, as shown. Could this be related to Jupyter Notebook?

    from sklearn import tree
    import matplotlib.pyplot
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.model_selection import train_test_split
    from sklearn.datasets import load_breast_cancer
    cancer=load_breast_cancer()
    x=cancer.data
    y=cancer.target
    clf=DecisionTreeClassifier(max_depth=1000)
    x_train,x_test,y_train,y_test=train_test_split(x,y)
    clf=clf.fit(x_train,y_train)
    tree.plot_tree(clf)
    

    AttributeError: module 'sklearn.tree' has no attribute 'plot_tree'

    • Anna Yashina (over 4 years ago)
      Make sure your matplotlib version is >= 1.5, and try saving your fitted model to a variable before passing it to the plot function.
    • Roshan (over 4 years ago)
      My matplotlib version is 3.0.3.
    • Josmoor98 (over 4 years ago)
      plot_tree is new in version 0.21. Maybe check your scikit-learn version.
    • Roshan (over 4 years ago)
      Thanks a lot, it was my sklearn version.
  • Anna Yashina (over 4 years ago)
    The current version is 0.21.3. You can upgrade by running !pip install -U scikit-learn, or conda upgrade scikit-learn if you are using Anaconda.
  • Roshan (over 4 years ago)
    Thank you so much, it works now that I have upgraded.
  • Nicolas Gervais (about 4 years ago)
    This is not an answer. You are just saying "I don't have this problem".