Iyar Lin 16 July, 2020
The segmenTree
package implements a decision tree based algorithm for
exploration of heterogeneous treatment effects in Randomized Controlled
Trial data. It supports continuous and binary response variables, and
binary treatments. It uses the rpart package as the back-end with a
predefined method on top.
You can install and load the package from github using:
pacman::p_load_gh("IyarLin/segmenTree")
Let’s generate a dataset with a binary response:
# Vary the seed, n and effect_size below to get a sense of the model's
# performance sensitivity.
set.seed(1)
effect_size <- 0.25

# True response probability as a function of treatment Tr and covariates
# X1, X2, X3. The treatment effect enters through the Tr * X1^3 term, so
# the lift varies with X1 (a heterogeneous treatment effect).
p_x <- function(Tr, X1, X2, X3){
lp <- 2*X1 + 0.2*X2 + as.numeric(X3)/6
effect_size + (effect_size/0.125)*Tr*X1^3 +
(1 - 2*effect_size)*exp(lp)/(1+exp(lp))
}
n <- 10000
Tr <- rbinom(n, 1, 0.3)  # binary treatment assignment, P(treated) = 0.3
X1 <- runif(n, -0.5, 0.5)
X2 <- rnorm(n)
X3 <- factor(sample(LETTERS[1:3], size = n, replace = TRUE))
p <- p_x(Tr, X1, X2, X3)
# rbinom() is vectorized over prob — no need for sapply() here.
y <- rbinom(n, 1, p)
# segmenTree expects the response and treatment indicator bound together
# in a two-column matrix; I() keeps data.frame() from unpacking it.
y_mat <- cbind(y, Tr)
dat <- data.frame(y = I(y_mat), X1, X2, X3)
Next, let’s import the lift method and use rpart to fit a segment tree:
# Import the predefined rpart "lift" method and fit the segment tree.
# cp = 0 lets the tree grow fully (pruned later); minbucket = 1000
# guarantees segments large enough for stable lift estimates.
# x = TRUE stores the model matrix in the fit, which segmenTree's
# utility functions need later on.
lift_method <- import_lift_method()
segment_tree <- rpart(y ~ ., data = dat,
method = lift_method,
control = rpart.control(cp = 0, minbucket = 1000),
x = TRUE)
One way of exploring the resulting tree is by printing the rpart object:
# Explicitly print the fitted rpart object (equivalent to the
# auto-printing a bare top-level expression triggers).
print(segment_tree)
## n= 10000
##
## node), split, n, deviance, yval
## * denotes terminal node
##
## 1) root 10000 1.000000e+16 -0.001051900
## 2) X1< -0.4005301 1003 1.012054e+12 -0.150025300 *
## 3) X1>=-0.4005301 8997 6.552256e+15 0.013425090
## 6) X1< 0.4016537 7983 4.061295e+15 -0.002719803
## 12) X1< 0.2939269 6939 2.318396e+15 -0.019487100
## 24) X1< -0.2554171 1467 4.631487e+12 -0.078092060 *
## 25) X1>=-0.2554171 5472 8.965703e+14 -0.004552848
## 50) X1< 0.1890516 4430 3.851367e+14 -0.014668450
## 100) X3=B,C 2992 8.013945e+13 0.008535898 *
## 101) X3=A 1438 4.275979e+12 -0.060894750 *
## 51) X1>=0.1890516 1042 1.178883e+12 0.043870860 *
## 13) X1>=0.2939269 1044 1.187960e+12 0.096902740 *
## 7) X1>=0.4016537 1014 1.057187e+12 0.152712600 *
The segmenTree
package also contains a utility function to extract the
resulting segments:
# Extract the tree's terminal-node segments together with confidence
# intervals for the lift in each segment; alpha = 0.15 presumably sets
# the interval level — TODO confirm against segmenTree's docs.
segments <- extract_segments(segment_tree, alpha = 0.15)
print(segments)
## segment n lift lift_lower lift_upper
## 7 X1>=0.4017,X1<0.499855161411688 1014 0.152712621 0.109679301 0.19574594
## 6 X1>=0.2939,X1<0.4017 1044 0.096902744 0.052632159 0.14117333
## 5 X1>=0.1891,X1<0.2939 1042 0.043870858 -0.004383936 0.09212565
## 3 X1>=-0.2554,X1<0.1891, X3={B,C} 2992 0.008535898 -0.019895457 0.03696725
## 4 X1>=-0.2554,X1<0.1891, X3={A} 1438 -0.060894751 -0.101565692 -0.02022381
## 2 X1>=-0.4005,X1<-0.2554 1467 -0.078092056 -0.118100169 -0.03808394
## 1 X1>=-0.49972922494635,X1<-0.4005 1003 -0.150025305 -0.196506492 -0.10354412
Next, let’s compare how the segment tree predictions (black) compare with the true treatment effect (red):
# Compare the tree's predicted lift (black) against the true treatment
# effect (red), both plotted against X1.
tau <- predict(segment_tree, dat)
p_treat <- p_x(rep(1, n), X1, X2, X3)  # P(y = 1 | treated)
p_cont <- p_x(rep(0, n), X1, X2, X3)   # P(y = 1 | control)
cate <- p_treat - p_cont               # true conditional average treatment effect
# range() returns c(min, max) over all its arguments — tidier than the
# manual c(min(...), max(...)) construction.
y_lim <- range(tau, cate)
plot(range(dat$X1), y_lim, type = "n", main = "segmenTree",
xlab = "X1", ylab = "true (red) vs predicted (black) lift")
points(dat$X1, cate, col = "red")
points(dat$X1, tau)
We can see that the algorithm predicts most of the curve well, but did some over-fitting in the middle section.
The segmenTree::prune_tree
function uses cross validation to find the
optimal cp
hyper parameter value.
Below we can see the effective lift and error as a function of the cp
parameter:
# Cross-validate to find the optimal cp, then prune the tree to it.
# Extract the optimal cp in one step instead of keeping the full
# prune_tree() result around.
optimal_cp_cv <- prune_tree(segment_tree)$optimal_cp
pruned_segment_tree <- prune(segment_tree, cp = optimal_cp_cv)
The optimal cp
maximizes the ratio between effective lift and error.
Let’s check the new pruned tree predictions:
# Predictions from the pruned tree vs the true treatment effect.
tau <- predict(pruned_segment_tree, dat)
# range() replaces the manual c(min(...), max(...)) construction.
y_lim <- range(tau, cate)
plot(range(dat$X1), y_lim, type = "n", main = "segmenTree",
xlab = "X1", ylab = "true (red) vs predicted (black) lift")
points(dat$X1, cate, col = "red")
points(dat$X1, tau)
Looks like the tree got pruned a little too much.
One other way we can reduce over-fitting is by adjusting how the leaf size is weighted.
# Re-import the lift method with leaf size weighted linearly
# (f_n = identity function), another way to curb over-fitting.
lift_method <- import_lift_method(f_n = function(x) x)
weighted_segment_tree <- rpart(y ~ ., data = dat,
method = lift_method,
control = rpart.control(cp = 0, minbucket = 1000),
x = TRUE)
Let’s see the resulting predictions:
# Predictions from the size-weighted tree vs the true treatment effect.
tau <- predict(weighted_segment_tree, dat)
# range() replaces the manual c(min(...), max(...)) construction.
y_lim <- range(tau, cate)
plot(range(dat$X1), y_lim, type = "n", main = "segmenTree",
xlab = "X1", ylab = "true (red) vs predicted (black) lift")
points(dat$X1, cate, col = "red")
points(dat$X1, tau)
To demonstrate why we even need a specialized algorithm we’ll compare the segmenTree model with 2 other approaches that use plain ML models:
- Model jointly the treatment and covariates. Predict: tau(x)=f(x,T=1)-f(x,T=0)
- Train a model on the treatment units f_{T=1} and a separate model f_{T=0} on the control units and predict: tau(x) = f_{T=1}(x) - f_{T=0}(x). This is also called the “Two model approach”.
We’ll use a decision tree in both of the above.
# Side-by-side comparison: segmenTree vs two plain-ML baselines.
par(mfrow = c(1, 3))
# Panel 1: the pruned segmenTree model
tau <- predict(pruned_segment_tree, dat)
p_treat <- p_x(rep(1, n), X1, X2, X3)  # P(y = 1 | treated)
p_cont <- p_x(rep(0, n), X1, X2, X3)   # P(y = 1 | control)
cate <- p_treat - p_cont               # true conditional average treatment effect
# range() replaces the manual c(min(...), max(...)) construction.
y_lim <- range(tau, cate)
plot(range(dat$X1), y_lim, type = "n", main = "segmenTree",
xlab = "X1", ylab = "true (red) vs predicted (black) lift")
points(dat$X1, cate, col = "red")
points(dat$X1, tau)
# Panel 2: a single tree fit jointly on treatment and covariates;
# lift estimated as f(x, Tr = 1) - f(x, Tr = 0).
dat2 <- data.frame(y, X1, X2, X3, Tr)
fit2 <- rpart(y ~ ., data = dat2)
# Counterfactual copies: everyone treated / everyone in control.
dat2_treat <- dat2; dat2_cont <- dat2
dat2_treat$Tr <- 1L; dat2_cont$Tr <- 0L
tau2 <- predict(fit2, dat2_treat) - predict(fit2, dat2_cont)
# range() replaces the manual c(min(...), max(...)) construction.
y_lim <- range(tau2, cate)
plot(range(dat$X1), y_lim, type = "n", main = "regular model",
xlab = "X1", ylab = "")
points(dat$X1, cate, col = "red")
points(dat$X1, tau2)
# Panel 3: the "Two model approach" — train a model on the treatment
# units and a separate model on the control units, then predict
# lift as the difference of the two models' predictions.
# Drop the Tr column by name rather than by position (-5) so the code
# survives column reordering.
dat3_treat <- dat2[dat2$Tr == 1, names(dat2) != "Tr"]
dat3_cont <- dat2[dat2$Tr == 0, names(dat2) != "Tr"]
fit3_treat <- rpart(y ~ ., data = dat3_treat)
fit3_cont <- rpart(y ~ ., data = dat3_cont)
# NOTE(review): the original re-assigned dat2_treat/dat2_cont here, but
# those copies were never used (predictions below run on dat2) — removed
# as dead code.
tau3 <- predict(fit3_treat, dat2) - predict(fit3_cont, dat2)
# range() replaces the manual c(min(...), max(...)) construction.
y_lim <- range(tau3, cate)
plot(range(dat$X1), y_lim, type = "n", main = "Two models",
xlab = "X1", ylab = "")
points(dat$X1, cate, col = "red")
points(dat$X1, tau3)