DTreeViz crashes if the decision tree was built with objects that are interpretable as numbers (e.g. numeric strings)
lgi1sgm opened this issue
`sklearn.tree.DecisionTreeClassifier` casts the parameter `X` to `dtype=np.float32` (see the documentation), so fitting works with the data provided in the example, where `feature_2` holds numeric strings. DTreeViz does not perform this conversion and crashes in a call to `np.linspace` in the function `get_split_node_heights()`.
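The root cause can be reproduced with NumPy alone. This is a sketch under an assumption not verified against the dtreeviz source: that `get_split_node_heights()` passes the minimum and maximum of a feature column straight to `np.linspace`. `linspace_over_column` is a hypothetical stand-in for that call, not a dtreeviz function.

```python
import numpy as np

# Values of 'feature_2' as stored in the pandas object column: strings, not numbers.
raw = np.array(['1', '1', '4'], dtype=object)


def linspace_over_column(values, num=3):
    """Hypothetical stand-in for the np.linspace call in get_split_node_heights().

    Returns the evenly spaced array, or the exception raised by np.linspace.
    """
    try:
        return np.linspace(values.min(), values.max(), num=num)
    except (TypeError, ValueError) as exc:
        return exc


# Strings: min()/max() return '1' and '4', and np.linspace rejects string endpoints.
print(repr(linspace_over_column(raw)))

# After the same float32 cast that sklearn applies, the call succeeds.
print(linspace_over_column(raw.astype(np.float32)))
```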
Example
```python
#!/usr/bin/env python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text
import dtreeviz

# 'feature_2' holds strings that are interpretable as numbers (object dtype).
X = pd.DataFrame({'feature_1': [10, 2, 5], 'feature_2': ['1', '1', '4']})
y = pd.Series([0, 0, 1], name='label')
# X, y = make_blobs(n_samples=10, n_features=2, centers=3)

# fit() succeeds because sklearn casts X to float32 internally.
d = DecisionTreeClassifier()
d.fit(X, y)
print(export_text(d))

dtreeviz_model = dtreeviz.model(d, X_train=X, y_train=y)
dtreeviz_render = dtreeviz_model.view()  # this will crash
```
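A possible workaround until DTreeViz handles this case, assuming the crash is caused only by the string values: convert `X` to numbers yourself before fitting, mirroring the `float32` cast sklearn performs internally, so that the tree and DTreeViz see the same numeric data. `X_numeric` is just an illustrative name.

```python
import numpy as np
import pandas as pd

X = pd.DataFrame({'feature_1': [10, 2, 5], 'feature_2': ['1', '1', '4']})

# Cast every column to float32, the dtype DecisionTreeClassifier converts X to
# internally; the numeric strings in 'feature_2' become real floats.
X_numeric = X.astype(np.float32)

print(X_numeric.dtypes)
```

Then fit with `d.fit(X_numeric, y)` and build the visualization with `dtreeviz.model(d, X_train=X_numeric, y_train=y)`. This has not been checked against every DTreeViz version. Columns holding genuinely non-numeric strings still need a proper encoding, since `astype` raises a `ValueError` for them.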