parrt/dtreeviz

DTreeViz crashes if the decision tree was built with objects that are interpretable as numbers


sklearn.tree.DecisionTreeClassifier casts the parameter X to dtype=np.float32 (see the documentation), so fitting works with the data provided in the example.

DTreeViz, however, does not perform this conversion and crashes in a call to np.linspace in the function get_split_node_heights().
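To illustrate the kind of failure involved, here is a minimal sketch of np.linspace being called with string endpoints. It assumes that the endpoints reaching np.linspace are raw values taken from a string-valued column such as feature_2 in the example; I have not traced the exact arguments inside get_split_node_heights(), so the variable names and values here are hypothetical.

```python
import numpy as np

# Assumption: the endpoints are strings taken from an uncast column.
lo, hi = '1', '4'

try:
    np.linspace(lo, hi, 5)
    failed = False
except Exception as exc:  # the exact exception type may differ between NumPy versions
    failed = True
    print(type(exc).__name__, exc)

# The same call works once the endpoints are cast to float.
ticks = np.linspace(float(lo), float(hi), 4)
print(ticks)
```

This suggests the crash could be avoided by casting X to a numeric dtype before such values are computed.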

Example

#!/usr/bin/env python

import pandas as pd
from sklearn.tree import DecisionTreeClassifier, export_text
import dtreeviz


# feature_2 holds strings that can be parsed as numbers
X = pd.DataFrame({'feature_1': [10, 2, 5], 'feature_2': ['1', '1', '4']})
y = pd.Series([0, 0, 1], name='label')

# X, y = make_blobs(n_samples=10, n_features=2, centers=3)

d = DecisionTreeClassifier()
d.fit(X, y)

print(export_text(d))

dtreeviz_model = dtreeviz.model(d, X_train=X, y_train=y)

dtreeviz_render = dtreeviz_model.view()  # this will crash
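
Possible workaround

Until this is handled inside the library, a workaround on the user side might be to cast X explicitly before handing it to dtreeviz. The following is only a sketch: the name X_numeric is mine, and I am assuming that mirroring the float32 cast done by scikit-learn is enough to avoid the crash. It uses pandas.to_numeric, which raises a ValueError if a value cannot be parsed as a number.

```python
import numpy as np
import pandas as pd

X = pd.DataFrame({'feature_1': [10, 2, 5], 'feature_2': ['1', '1', '4']})

# Parse every column as numbers, then cast to float32 like scikit-learn does.
# X_numeric is a hypothetical name used only for this sketch.
X_numeric = X.apply(pd.to_numeric).astype(np.float32)

print(X_numeric.dtypes)
```

X_numeric would then be passed both to d.fit() and as X_train to dtreeviz.model() in place of X.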