andosa/treeinterpreter

UnboundLocalError on line_shape variable while using ExtraTreeClassifier

karanranawat opened this issue

ISSUE
The following error trace was encountered while running the ti._predict_forest function on an ExtraTreeClassifier model. The same call works correctly for RandomForestClassifier.

[Screenshot of the error trace; image not preserved]

POTENTIAL CAUSE
The following code block initializes the line_shape variable:

```python
# Excerpt from treeinterpreter; np, model, X, values and paths
# are defined earlier in the enclosing function.

# reshape if squeezed into a single float
if len(values.shape) == 0:
    values = np.array([values])
if isinstance(model, DecisionTreeRegressor):
    biases = np.full(X.shape[0], values[paths[0][0]])
    line_shape = X.shape[1]
elif isinstance(model, DecisionTreeClassifier):
    # scikit stores category counts, we turn them into probabilities
    normalizer = values.sum(axis=1)[:, np.newaxis]
    normalizer[normalizer == 0.0] = 1.0
    values /= normalizer
    # ... (the quoted snippet ends here)
```

I see that line_shape is initialized in branches specific to DecisionTreeRegressor and DecisionTreeClassifier.
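To illustrate how an if/elif chain like the quoted one can produce this error, here is a minimal pure-Python sketch. The classes DecisionTreeRegressor and DecisionTreeClassifier below are hypothetical empty stand-ins (not scikit-learn's classes), line_shape_for is a hypothetical, simplified mirror of the branch logic, and the classifier-branch value is an assumption because the quoted snippet ends before that assignment. When a model matches neither isinstance check, line_shape is never assigned inside the function, so reading it raises UnboundLocalError.

```python
# Hypothetical stand-ins for the scikit-learn classes named in the snippet;
# they exist only so the isinstance checks below have something to test.
class DecisionTreeRegressor:
    pass


class DecisionTreeClassifier:
    pass


def line_shape_for(model, n_features, n_classes):
    """Hypothetical, simplified mirror of the if/elif chain in the snippet."""
    if isinstance(model, DecisionTreeRegressor):
        line_shape = n_features
    elif isinstance(model, DecisionTreeClassifier):
        # Assumed value: the quoted snippet ends before this assignment.
        line_shape = (n_features, n_classes)
    # No else branch: for any other type, line_shape is never assigned,
    # so reading it here raises UnboundLocalError.
    return line_shape


class SomeOtherModel:
    pass


try:
    line_shape_for(SomeOtherModel(), n_features=4, n_classes=3)
except UnboundLocalError as exc:
    print(type(exc).__name__)  # prints "UnboundLocalError"
```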

Do we need to add a branch specific to ExtraTreeClassifier?
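For context on that question: in scikit-learn, ExtraTreeClassifier is declared as a subclass of DecisionTreeClassifier, so isinstance(model, DecisionTreeClassifier) is already true for it, and a genuine ExtraTreeClassifier should take the existing classifier branch. The error may therefore indicate that the object reaching this code is of some other type. The sketch below uses hypothetical stand-in classes that mirror that hierarchy, plus a hypothetical tree_kind function with an explicit else branch; the else guard is a suggested pattern, not existing treeinterpreter code. It raises a TypeError naming the unexpected type instead of leaving a variable unassigned.

```python
# Hypothetical stand-ins mirroring scikit-learn's class hierarchy,
# where ExtraTreeClassifier subclasses DecisionTreeClassifier.
class DecisionTreeRegressor:
    pass


class DecisionTreeClassifier:
    pass


class ExtraTreeClassifier(DecisionTreeClassifier):
    pass


def tree_kind(model):
    """Hypothetical dispatch with an explicit fallback branch."""
    if isinstance(model, DecisionTreeRegressor):
        return "regressor"
    elif isinstance(model, DecisionTreeClassifier):
        # Subclasses such as ExtraTreeClassifier also land here
        return "classifier"
    else:
        # Fail loudly with the offending type instead of leaving
        # a variable unassigned further down
        raise TypeError(f"Unsupported tree type: {type(model).__name__}")


print(tree_kind(ExtraTreeClassifier()))  # prints "classifier"
```

With this pattern, no ExtraTreeClassifier-specific branch is needed for a true subclass, and any unsupported model type produces an error message that names the type directly.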