grf-labs/policytree

Hybrid policy tree prints different leaf labels than predicted leaf labels


Description of the bug
Hybrid policy tree grows redundant leaf nodes that no training data falls into. Also, predicting node IDs returns the IDs of non-leaf nodes.

Steps to reproduce

library(policytree)
set.seed(2)
n <- 25  # observations
p <- 5   # covariates
d <- 2   # actions
X <- matrix(runif(n * p), n, p)  # covariate matrix
Y <- matrix(rnorm(n * d), n, d)  # reward matrix, one column per action
tree <- hybrid_policy_tree(X, Y)
tree
# policy_tree object 
# Tree depth:  3 
# Actions:  1 2 
# Variable splits: 
# (1) split_variable: X1  split_value: 0.405282 
#   (2) split_variable: X2  split_value: 0.164642 
#     (4) split_variable: X1  split_value: 0.225825 
#       (6) * action: 1 
#       (7) * action: 2 
#     (5) split_variable: X2  split_value: 0.667226 
#       (8) * action: 2 
#       (9) * action: 1 
#   (3) split_variable: X3  split_value: 0.613953 
#     (10) split_variable: X3  split_value: 0.275701 
#       (12) * action: 1 
#       (13) * action: 2 
#     (11) split_variable: X5  split_value: 0.453377 
#       (14) * action: 1 
#       (15) * action: 2
preds <- predict(tree, X, type = "node.id")
table(preds)
# 8  9 10 11 12 13 14 15 
# 2  1  3  5  5  3  4  2 

When predicting on the training data:

  • no training data falls into leaves 6 and 7
  • non-leaf nodes 10 and 11 are predicted
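Both symptoms can be checked mechanically by comparing the leaf IDs shown by print() with the IDs returned by predict(). The sketch below is only an illustration and not part of policytree: check_leaf_labels is a hypothetical helper, and its inputs are transcribed by hand from the output above rather than produced by a live call.

```r
# Hypothetical helper (not part of policytree): compare node IDs returned by
# predict(..., type = "node.id") with the leaf IDs shown by print().
check_leaf_labels <- function(pred_ids, printed_leaves) {
  ids <- sort(unique(pred_ids))
  list(
    empty_leaves = setdiff(printed_leaves, ids),  # printed leaves with no observations
    non_leaf_ids = setdiff(ids, printed_leaves)   # predicted IDs not printed as leaves
  )
}

# Leaf IDs from the printed tree above (transcribed by hand).
printed_leaves <- c(6, 7, 8, 9, 12, 13, 14, 15)
# Rebuild the predicted IDs from the reported table(preds) counts (n = 25).
preds_reported <- rep(c(8, 9, 10, 11, 12, 13, 14, 15),
                      times = c(2, 1, 3, 5, 5, 3, 4, 2))

res <- check_leaf_labels(preds_reported, printed_leaves)
print(res$empty_leaves)  # [1] 6 7
print(res$non_leaf_ids)  # [1] 10 11
```

With a fitted tree, pred_ids would simply be the output of predict(tree, X, type = "node.id").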

policytree version

> packageVersion("policytree")
[1] ‘1.2.0’

Thanks for reporting this @jarkki. Hopefully this is just an indexing bug in the tree representation that only affects the leaf IDs; point predictions should be correct per this test. I won't have time to look closer until a bit later. Thanks again.

So yes, the only problem was that hybrid_policy_tree used different terminal node labels than the ones it printed. If you only care about subgroups, preds <- predict(tree, X, type = "node.id") will work correctly.

The printed label / predicted label mismatch is fixed in #153.

Thank you so much for fixing this! My use case was to represent the tree as a table with the columns

leaf.id | rules | CATE | N obs

The steps to build the table were:

  1. Predict the leaf of each observation and calculate the local CATE for each leaf
  2. Traverse the tree and gather all splitting rules that lead to each leaf
  3. Join the rules and CATEs by leaf index
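The three steps can be sketched in base R as follows. This is an illustration under stated assumptions, not the reporter's code: the node table is a hypothetical stand-in for the fitted tree (not policytree's internal format), the split direction (x <= split value goes left) is assumed, and CATE is taken as the per-leaf mean difference of a hypothetical two-column score matrix Gamma. The column "N obs" is written N.obs to keep it a valid R name. In real use, the leaf IDs would come from predict(tree, X, type = "node.id").

```r
# ASSUMPTION: hypothetical node table for a depth-2 tree, numbered
# level-first; leaves have NA in the split columns.
nodes <- data.frame(
  id        = 1:7,
  split_var = c("X1", "X2", "X3", NA, NA, NA, NA),
  split_val = c(0.5, 0.3, 0.6, NA, NA, NA, NA),
  left      = c(2, 4, 6, NA, NA, NA, NA),
  right     = c(3, 5, 7, NA, NA, NA, NA),
  stringsAsFactors = FALSE
)

# Step 2: depth-first walk collecting the conditions on the path to each leaf.
# ASSUMPTION: observations with x <= split_val go to the left child.
collect_rules <- function(nodes, id = 1, path = character(0)) {
  node <- nodes[nodes$id == id, ]
  if (is.na(node$split_var)) {
    return(data.frame(leaf.id = id, rules = paste(path, collapse = " & "),
                      stringsAsFactors = FALSE))
  }
  cond <- paste(node$split_var, c("<=", ">"), node$split_val)
  rbind(collect_rules(nodes, node$left, c(path, cond[1])),
        collect_rules(nodes, node$right, c(path, cond[2])))
}

# Step 1: per-leaf CATE and observation count.
# ASSUMPTION: Gamma is an n x 2 matrix of scores for actions 1 and 2,
# and CATE is the mean of Gamma[, 2] - Gamma[, 1] within each leaf.
leaf_stats <- function(leaf, Gamma) {
  effect <- Gamma[, 2] - Gamma[, 1]
  ids <- sort(unique(leaf))
  key <- as.character(ids)
  data.frame(leaf.id = ids,
             CATE = as.numeric(tapply(effect, leaf, mean)[key]),
             N.obs = as.integer(table(leaf)[key]))
}

# Step 3: join rules and statistics by leaf ID, keeping empty leaves.
build_leaf_table <- function(nodes, leaf, Gamma) {
  merge(collect_rules(nodes), leaf_stats(leaf, Gamma),
        by = "leaf.id", all.x = TRUE)
}

# Toy inputs: leaf IDs per observation and a score matrix.
leaf  <- c(4, 4, 5, 7, 7, 7)
Gamma <- cbind(c(1, 2, 0, 1, 1, 1), c(2, 4, 1, 0, 1, 2))
tab <- build_leaf_table(nodes, leaf, Gamma)
print(tab)
```

Because the join uses all.x = TRUE, a leaf that no observation falls into (leaf 6 in the toy data) still appears, with NA for CATE and N.obs. This also makes it easy to spot empty leaves like the ones described in the report, and it only gives correct results when the printed and predicted leaf labels agree.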

The leaf indices now match and this works as expected.

The library is really amazing, thanks for developing it!

Just a heads-up @jarkki: I changed the fix for this in #156 to instead rename the hybrid tree's leaf IDs so that they are numbered level-first, the same as in the rest of policytree. Now policy_tree and hybrid_policy_tree will give you the same node "labels".
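Assuming "level-first" means breadth-first (level-order) numbering, the root is 1, the next level is 2 and 3, the level after that is 4 to 7, and so on, so a complete depth-3 tree has leaves 8 to 15. The sketch below is not the code from #156; relabel_level_first is a hypothetical helper that relabels a tree given as a parent-to-children list, here using the structure printed in the original report.

```r
# Relabel tree nodes in level-first (breadth-first) order.
# ASSUMPTION: the tree is a hypothetical named list mapping each internal
# node's old ID to c(left_child, right_child); leaves have no entry.
relabel_level_first <- function(root, children) {
  queue <- root
  visit <- c()
  while (length(queue) > 0) {
    node <- queue[1]
    queue <- queue[-1]
    visit <- c(visit, node)
    kids <- children[[as.character(node)]]  # NULL for leaves
    queue <- c(queue, kids)
  }
  # Named vector: names are old IDs, values are new level-first IDs.
  setNames(seq_along(visit), visit)
}

# Node structure from the printed tree in the original report.
children <- list(`1` = c(2, 3), `2` = c(4, 5), `3` = c(10, 11),
                 `4` = c(6, 7), `5` = c(8, 9),
                 `10` = c(12, 13), `11` = c(14, 15))

new_ids <- relabel_level_first(1, children)
# Old printed leaf IDs and their level-first labels.
print(new_ids[c("6", "7", "8", "9", "12", "13", "14", "15")])
```

Under this relabeling, the printed leaves 6, 7, 8, 9, 12, 13, 14 and 15 become 8 through 15, while the old internal nodes 10 and 11 become 6 and 7.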