tensorflow/decision-forests

tfdf.model_plotter.plot_model() is broken for GradientBoostedTreesModel and CartModel

fhossfel opened this issue · 5 comments

I am using tfdf 0.2.4 and can successfully train a model and plot it using the plot_model() function.

import tensorflow_decision_forests as tfdf

# Train and evaluate a Random Forest on the prepared train_ds / test_ds datasets.
model = tfdf.keras.RandomForestModel()
model.fit(train_ds)
model.compile(metrics=["accuracy"])
evaluation = model.evaluate(test_ds)

# Render the first tree of the forest to an HTML file.
with open("model.html", "w") as html_file:
    html_file.write(tfdf.model_plotter.plot_model(model, tree_idx=0, max_depth=10))

For my current task I get a decision tree graph consisting of two decision nodes and three outputs (leaves). The key line in the generated HTML file seems to be this one:

display_tree({"margin": 10, "node_x_size": 160, "node_y_size": 28, "node_x_offset": 180, "node_y_offset": 33, "font_size": 10, "edge_rounding": 20, "node_padding": 2, "show_plot_bounding_box": false}, {"value": {"type": "PROBABILITY", "distribution": [0.006622516556291391, 0.695364238410596, 0.2781456953642384, 0.019867549668874173], "num_examples": 151.0}, "condition": {"type": "CATEGORICAL_IS_IN", "attribute": "product_group", "mask": ["DIY"]}, "children": [{"value": {"type": "PROBABILITY", "distribution": [0.0, 0.0, 1.0, 0.0], "num_examples": 42.0}}, {"value": {"type": "PROBABILITY", "distribution": [0.009174311926605505, 0.963302752293578, 0.0, 0.027522935779816515], "num_examples": 109.0}, "condition": {"type": "NUMERICAL_IS_HIGHER_THAN", "attribute": "height", "threshold": 42.0}, "children": [{"value": {"type": "PROBABILITY", "distribution": [0.0, 1.0, 0.0, 0.0], "num_examples": 103.0}}, {"value": {"type": "PROBABILITY", "distribution": [0.16666666666666666, 0.3333333333333333, 0.0, 0.5], "num_examples": 6.0}}]}]}, "#tree_plot_24de9183c1d54e6b8c963d372b714bc0")

If I use exactly the same code but replace the RandomForestModel with a GradientBoostedTreesModel, I only get one decision node and two outputs:

display_tree({"margin": 10, "node_x_size": 160, "node_y_size": 28, "node_x_offset": 180, "node_y_offset": 33, "font_size": 10, "edge_rounding": 20, "node_padding": 2, "show_plot_bounding_box": false}, {"value": {"type": "REGRESSION", "value": -0.09703703969717026, "num_examples": 135.0, "standard_deviation": 0.08574694002066838}, "condition": {"type": "NUMERICAL_IS_HIGHER_THAN", "attribute": "length", "threshold": 227.0}, "children": [{"value": {"type": "REGRESSION", "value": -0.020000001415610313, "num_examples": 5.0, "standard_deviation": 0.4}}, {"value": {"type": "REGRESSION", "value": -0.10000000149011612, "num_examples": 130.0, "standard_deviation": 0.0}}]}, "#tree_plot_73421ac8ea9a47a88761b7441afab47c")

This can't be right, since the inferences of the GradientBoostedTreesModel are perfect (100% correct, thanks!), and that requires taking more features into account than just the length of the classified object.

The model summary is below. (I have replaced some sensitive feature names.) I am not really an expert, but if I read the summary correctly, the decision tree should have a depth of 5 and 26 to 27 nodes. On the other hand, I would have expected more nodes to show for the RandomForestModel, too. ¯\_(ツ)_/¯
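As a side note, here is a minimal sketch (assuming the TF-DF inspector API available in 0.2.4) of how the structure of the first tree can be dumped as plain text, independent of the HTML plot:

# Sketch: dump tree 0 of the model as text via the inspector API.
inspector = model.make_inspector()
print("Number of trees:", inspector.num_trees())
tree = inspector.extract_tree(tree_idx=0)
print(tree)  # textual representation of the tree structure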

If there is any additional information I can provide please let me know.

Model: "gradient_boosted_trees_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
=================================================================
Total params: 1
Trainable params: 0
Non-trainable params: 1
_________________________________________________________________
Type: "GRADIENT_BOOSTED_TREES"
Task: CLASSIFICATION
Label: "__LABEL"

Input Features (11):
	parcel_count
	ft_ot_text
	girth
	height
	length
	product_group
	tipping_risk
	shipping_mode
	volume
	weight
	width

No weights

Variable Importance: MEAN_MIN_DEPTH:
    1.  "parcel_count"  3.890381 ################
    2.       "__LABEL"  3.890381 ################
    3. "shipping_mode"  3.889688 ###############
    4.         "girth"  3.541476 #############
    5.    "ft_ot_text"  3.516287 #############
    6.         "width"  3.287958 ###########
    7.        "volume"  3.184331 ##########
    8.        "length"  3.039927 #########
    9.        "height"  2.885094 ########
   10.  "tipping_risk"  2.538273 ######
   11.        "weight"  2.267620 ####
   12. "product_group"  1.719362 

Variable Importance: NUM_AS_ROOT:
    1. "product_group" 616.000000 ################
    2.        "height" 183.000000 ####
    3.        "weight" 172.000000 ####
    4.        "length" 117.000000 ##
    5.         "width"  47.000000 
    6.        "volume"  41.000000 
    7.  "tipping_risk"  24.000000 

Variable Importance: NUM_NODES:
    1.        "weight" 2592.000000 ################
    2.  "tipping_risk" 2367.000000 ##############
    3.        "volume" 1318.000000 ########
    4.        "height" 1271.000000 #######
    5. "product_group" 1195.000000 #######
    6.         "width" 1062.000000 ######
    7.         "girth"  968.000000 #####
    8.    "ft_ot_text"  730.000000 ####
    9.        "length"  689.000000 ####
   10. "shipping_mode"    5.000000 

Variable Importance: SUM_SCORE:
    1. "product_group" 212.827222 ################
    2.        "height"  17.159601 #
    3.        "weight"   3.552953 
    4.  "tipping_risk"   2.266512 
    5.        "length"   1.447021 
    6.        "volume"   0.999544 
    7.         "girth"   0.891605 
    8.         "width"   0.525099 
    9.    "ft_ot_text"   0.106717 
   10. "shipping_mode"   0.000000 



Loss: MULTINOMIAL_LOG_LIKELIHOOD
Validation loss value: 2.87221e-06
Number of trees per iteration: 4
Node format: NOT_SET
Number of trees: 1200
Total number of nodes: 25594

Number of nodes by tree:
Count: 1200 Average: 21.3283 StdDev: 3.18991
Min: 3 Max: 27 Ignored: 0
----------------------------------------------
[  3,  4)   2   0.17%   0.17%
[  4,  5)   0   0.00%   0.17%
[  5,  6)   2   0.17%   0.33%
[  6,  8)   0   0.00%   0.33%
[  8,  9)   0   0.00%   0.33%
[  9, 10)   0   0.00%   0.33%
[ 10, 11)   0   0.00%   0.33%
[ 11, 13)   8   0.67%   1.00%
[ 13, 14)  12   1.00%   2.00%
[ 14, 15)   0   0.00%   2.00%
[ 15, 16)  21   1.75%   3.75% #
[ 16, 18)  73   6.08%   9.83% ##
[ 18, 19)   0   0.00%   9.83%
[ 19, 20) 262  21.83%  31.67% #######
[ 20, 21)   0   0.00%  31.67%
[ 21, 23) 372  31.00%  62.67% ##########
[ 23, 24) 199  16.58%  79.25% #####
[ 24, 25)   0   0.00%  79.25%
[ 25, 26) 156  13.00%  92.25% ####
[ 26, 27]  93   7.75% 100.00% ###

Depth by leafs:
Count: 13397 Average: 3.9155 StdDev: 1.0663
Min: 1 Max: 5 Ignored: 0
----------------------------------------------
[ 1, 2)  178   1.33%   1.33%
[ 2, 3) 1354  10.11%  11.44% ###
[ 3, 4) 3100  23.14%  34.57% ######
[ 4, 5) 3555  26.54%  61.11% #######
[ 5, 5] 5210  38.89% 100.00% ##########

Number of training obs by leaf:
Count: 13397 Average: 12.0923 StdDev: 18.5167
Min: 5 Max: 130 Ignored: 0
----------------------------------------------
[   5,  11) 11675  87.15%  87.15% ##########
[  11,  17)   419   3.13%  90.27%
[  17,  23)    42   0.31%  90.59%
[  23,  30)    40   0.30%  90.89%
[  30,  36)    63   0.47%  91.36%
[  36,  42)     7   0.05%  91.41%
[  42,  49)     1   0.01%  91.42%
[  49,  55)    40   0.30%  91.71%
[  55,  61)   158   1.18%  92.89%
[  61,  68)   320   2.39%  95.28%
[  68,  74)    53   0.40%  95.68%
[  74,  80)   306   2.28%  97.96%
[  80,  86)   226   1.69%  99.65%
[  86,  93)    27   0.20%  99.85%
[  93,  99)    16   0.12%  99.97%
[  99, 105)     2   0.01%  99.99%
[ 105, 112)     1   0.01%  99.99%
[ 112, 118)     0   0.00%  99.99%
[ 118, 124)     0   0.00%  99.99%
[ 124, 130]     1   0.01% 100.00%

Attribute in nodes:
	2592 : weight [NUMERICAL]
	2367 : tipping_risk [NUMERICAL]
	1318 : volume [NUMERICAL]
	1271 : height [NUMERICAL]
	1195 : product_group [CATEGORICAL]
	1062 : width [NUMERICAL]
	968 : girth [NUMERICAL]
	730 : ft_ot_text [CATEGORICAL]
	689 : length [NUMERICAL]
	5 : shipping_mode [CATEGORICAL]

Attribute in nodes with depth <= 0:
	616 : product_group [CATEGORICAL]
	183 : height [NUMERICAL]
	172 : weight [NUMERICAL]
	117 : length [NUMERICAL]
	47 : width [NUMERICAL]
	41 : volume [NUMERICAL]
	24 : tipping_risk [NUMERICAL]

Attribute in nodes with depth <= 1:
	709 : weight [NUMERICAL]
	627 : product_group [CATEGORICAL]
	468 : height [NUMERICAL]
	457 : tipping_risk [NUMERICAL]
	378 : length [NUMERICAL]
	314 : volume [NUMERICAL]
	218 : width [NUMERICAL]
	156 : girth [NUMERICAL]
	95 : ft_ot_text [CATEGORICAL]

Attribute in nodes with depth <= 2:
	1550 : weight [NUMERICAL]
	1225 : tipping_risk [NUMERICAL]
	767 : volume [NUMERICAL]
	741 : product_group [CATEGORICAL]
	675 : height [NUMERICAL]
	479 : length [NUMERICAL]
	437 : width [NUMERICAL]
	361 : girth [NUMERICAL]
	277 : ft_ot_text [CATEGORICAL]

Attribute in nodes with depth <= 3:
	2342 : weight [NUMERICAL]
	1860 : tipping_risk [NUMERICAL]
	1077 : volume [NUMERICAL]
	937 : product_group [CATEGORICAL]
	927 : height [NUMERICAL]
	778 : girth [NUMERICAL]
	734 : width [NUMERICAL]
	601 : length [NUMERICAL]
	336 : ft_ot_text [CATEGORICAL]

Attribute in nodes with depth <= 5:
	2592 : weight [NUMERICAL]
	2367 : tipping_risk [NUMERICAL]
	1318 : volume [NUMERICAL]
	1271 : height [NUMERICAL]
	1195 : product_group [CATEGORICAL]
	1062 : width [NUMERICAL]
	968 : girth [NUMERICAL]
	730 : ft_ot_text [CATEGORICAL]
	689 : length [NUMERICAL]
	5 : shipping_mode [CATEGORICAL]

Condition type in nodes:
	10267 : HigherCondition
	1930 : ContainsBitmapCondition
Condition type in nodes with depth <= 0:
	616 : ContainsBitmapCondition
	584 : HigherCondition
Condition type in nodes with depth <= 1:
	2700 : HigherCondition
	722 : ContainsBitmapCondition
Condition type in nodes with depth <= 2:
	5494 : HigherCondition
	1018 : ContainsBitmapCondition
Condition type in nodes with depth <= 3:
	8319 : HigherCondition
	1273 : ContainsBitmapCondition
Condition type in nodes with depth <= 5:
	10267 : HigherCondition
	1930 : ContainsBitmapCondition


CartModel has a similar problem of showing only one decision node, but at least the mouseover is working.

display_tree({"margin": 10, "node_x_size": 160, "node_y_size": 28, "node_x_offset": 180, "node_y_offset": 33, "font_size": 10, "edge_rounding": 20, "node_padding": 2, "show_plot_bounding_box": false}, {"value": {"type": "PROBABILITY", "distribution": [0.007407407407407408, 0.7333333333333333, 0.24444444444444444, 0.014814814814814815], "num_examples": 135.0}, "condition": {"type": "CATEGORICAL_IS_IN", "attribute": "product_Group", "mask": ["DIY"]}, "children": [{"value": {"type": "PROBABILITY", "distribution": [0.0, 0.0, 1.0, 0.0], "num_examples": 33.0}}, {"value": {"type": "PROBABILITY", "distribution": [0.00980392156862745, 0.9705882352941176, 0.0, 0.0196078431372549], "num_examples": 102.0}}]}, "#tree_plot_e7010c332612435caae222c9a1230050")
rstz commented

Hi,
I'm not sure I correctly understand the problem just yet, but let me summarize what I think is going on.

The GradientBoostedTrees model you're building reports "Number of trees: 1200", i.e. it consists of 1200 trees. You inspect the first tree of this collection using tfdf.model_plotter.plot_model(model, tree_idx=0, max_depth=10) (that is what tree_idx does). This tree alone might not be great, but that is expected: it is all 1200 trees together that give the great performance, not any single tree.
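To make this concrete, here is a minimal sketch (reusing the plot_model call from above; the file names are just illustrative) that writes the first few trees of the ensemble to separate HTML files for comparison:

# Sketch: each tree_idx selects a different tree of the ensemble.
inspector = model.make_inspector()
print("Number of trees:", inspector.num_trees())
for i in range(3):
    with open(f"gbt_tree_{i}.html", "w") as html_file:
        html_file.write(tfdf.model_plotter.plot_model(model, tree_idx=i, max_depth=10))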

For CART, there is indeed just a single tree - but for most problems, CART models do not perform as well as Random Forests or Gradient Boosted Trees.
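For illustration, a minimal sketch (assuming the same train_ds as above) of training a CART model; since it consists of a single tree, tree_idx=0 plots the entire model:

# Sketch: CART trains exactly one tree, so tree_idx=0 is the whole model.
cart = tfdf.keras.CartModel()
cart.fit(train_ds)
with open("cart.html", "w") as html_file:
    html_file.write(tfdf.model_plotter.plot_model(cart, tree_idx=0, max_depth=10))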

fhossfel commented

Ahh, okay. I did not read the manual properly and misinterpreted the tree_idx parameter.

I had noticed that the class distribution bars are missing for the gradient boosted trees. Is that intentional?

rstz commented

Can you please clarify what you mean with "missing class distribution bars"?

rstz commented

Closing this as stale