Visualising the decision tree in sklearn
Solution 1
I assigned the fitted tree to an object and added plt.show(). This works for me.
%matplotlib inline
from sklearn import tree
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
# Load the data and split it into training and test sets
cancer = load_breast_cancer()
x = cancer.data
y = cancer.target
x_train, x_test, y_train, y_test = train_test_split(x, y)
# fit() returns the fitted estimator itself (not a figure), so keep it in a variable
clf = DecisionTreeClassifier(max_depth=1000)
clf = clf.fit(x_train, y_train)
# Draw the fitted tree, then render the current matplotlib figure
tree.plot_tree(clf)
plt.show()
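If the default plot of a large tree is too cramped to read, plot_tree can limit the drawn depth and label the nodes. A minimal sketch, assuming you only want the top few levels saved to an image: the file name tree.png, the figure size, random_state=0 and max_depth=3 are arbitrary illustrative choices, not part of the answer above.

```python
import matplotlib
matplotlib.use("Agg")  # file-only backend; drop this line in Jupyter to see the plot inline
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

cancer = load_breast_cancer()
x_train, x_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, random_state=0
)

# fit() returns the fitted estimator, so clf is the trained tree
clf = tree.DecisionTreeClassifier(random_state=0).fit(x_train, y_train)

fig, ax = plt.subplots(figsize=(20, 10))
annotations = tree.plot_tree(
    clf,
    max_depth=3,                               # draw only the top 3 levels
    feature_names=list(cancer.feature_names),  # label splits with column names
    class_names=list(cancer.target_names),     # label leaves with class names
    filled=True,                               # colour nodes by majority class
    ax=ax,
)
fig.savefig("tree.png")
plt.close(fig)
```

Deeper levels are still in the model; max_depth here only controls how much of the tree is drawn.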
That said, I recommend using graphviz; it's much more flexible.
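A minimal sketch of the graphviz route, assuming you only need the DOT source from scikit-learn: tree.export_graphviz with out_file=None returns the tree as a DOT string, and producing that string does not require Graphviz itself. The file name tree.dot is illustrative, and the rendering commands in the comments assume the Graphviz command-line tools or the separate graphviz Python package are installed.

```python
from sklearn import tree
from sklearn.datasets import load_breast_cancer

cancer = load_breast_cancer()
clf = tree.DecisionTreeClassifier(max_depth=3, random_state=0)
clf.fit(cancer.data, cancer.target)

# out_file=None makes export_graphviz return the DOT source as a string
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=list(cancer.feature_names),
    class_names=list(cancer.target_names),
    filled=True,
    rounded=True,
)

with open("tree.dot", "w") as f:
    f.write(dot_data)
# Render with the Graphviz CLI, e.g.:  dot -Tpng tree.dot -o tree.png
# or in Jupyter with the graphviz package:  graphviz.Source(dot_data)
```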
Solution 2
Upgrade the scikit-learn package (on PyPI it is named scikit-learn, not sklearn):
pip install --upgrade scikit-learn
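The library is published on PyPI as scikit-learn, while sklearn is only its import name (the separate sklearn PyPI package is a deprecated placeholder). A small sketch, assuming Python 3.8+ for importlib.metadata, prints the installed scikit-learn distribution next to the version the imported module reports. If the number does not change after upgrading, pip most likely installed into a different environment than the one your interpreter or notebook kernel uses.

```python
from importlib.metadata import version

import sklearn

# Distribution name on PyPI vs. module name used in `import`
installed = version("scikit-learn")
print("scikit-learn distribution:", installed)
print("sklearn module:", sklearn.__version__)
```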
Solution 3
This is because plot_tree is new in sklearn version 0.21, as indicated in the documentation. Check whether you have a sufficient version by running this:
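import re
import sklearn
print(sklearn.__version__)
# Compare (major, minor) as integers; slicing the string breaks for 1.x releases
major, minor = map(int, re.match(r"(\d+)\.(\d+)", sklearn.__version__).groups())
assert (major, minor) >= (0, 21), 'sklearn version insufficient.'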
If the assertion fails, you need to upgrade scikit-learn:
pip install --upgrade scikit-learn
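Instead of parsing the version string, you can also test for the function directly. A minimal sketch, where the helper name draw_tree and the output file names are hypothetical: it uses plot_tree when the installed scikit-learn provides it and otherwise falls back to export_graphviz, which older releases already include.

```python
from sklearn import tree
from sklearn.datasets import load_breast_cancer


def draw_tree(clf):
    """Hypothetical helper: plot with matplotlib if possible, else write DOT source."""
    if hasattr(tree, "plot_tree"):  # available from scikit-learn 0.21 onwards
        import matplotlib.pyplot as plt
        tree.plot_tree(clf)
        plt.savefig("tree.png")
        plt.close()
        return "plot_tree"
    # Older releases: write DOT output that graphviz can render later
    with open("tree.dot", "w") as f:
        f.write(tree.export_graphviz(clf, out_file=None))
    return "export_graphviz"


cancer = load_breast_cancer()
clf = tree.DecisionTreeClassifier(max_depth=3).fit(cancer.data, cancer.target)
method = draw_tree(clf)
print("drew tree with", method)
```

Feature detection keeps working for any future version number, which is why it can be preferable to comparing version strings.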
Author: Roshan
Updated on July 17, 2021
Comments
Roshan (almost 3 years ago):
When I try to visualise the tree, I get this error.
The code shows the required libraries being imported. Is there a known reason this happens in Jupyter Notebook?
from sklearn import tree
import matplotlib.pyplot
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
cancer = load_breast_cancer()
x = cancer.data
y = cancer.target
clf = DecisionTreeClassifier(max_depth=1000)
x_train, x_test, y_train, y_test = train_test_split(x, y)
clf = clf.fit(x_train, y_train)
tree.plot_tree(clf.fit(x_train, y_train))
AttributeError: module 'sklearn.tree' has no attribute 'plot_tree'
PV8 (over 4 years ago): Check it out here: scikit-learn.org/stable/modules/tree.html
Anna Yashina (over 4 years ago): Make sure your matplotlib version is >= 1.5, and try saving your fitted model to an object before passing it to the plot function.
Roshan (over 4 years ago): My matplotlib version is 3.0.3.
Josmoor98 (over 4 years ago): plot_tree is new in version 0.21. Maybe check your scikit-learn version.
Roshan (over 4 years ago): Thanks a lot, it was my sklearn version.
Anna Yashina (over 4 years ago): The current version is 0.21.3. You can upgrade by running
!pip install -U scikit-learn
or
conda upgrade scikit-learn
if you are using Anaconda.
Roshan (over 4 years ago): Thank you so much, it works after upgrading.
Nicolas Gervais (about 4 years ago): This is not an answer. You are just saying "I don't have this problem".