vip package: Variable interactions
HanLum opened this issue · 2 comments
HanLum commented
I'm trying to use vip to plot variable interaction heatmaps using the vivid package (EIX doesn't like my model). Vivid relies on the vint function which no longer exists in vip - is there an alternative way of calculating variable interactions for this purpose? Thank you @bgreenwell
bgreenwell commented
Hi @HanLum, thanks for posting the issue here! We took vint()
out of vip in the last version since it was a bit hacky in its implementation and also seemed a bit out of scope. Nonetheless, here's the original function, which should still work (perhaps with a minor tweak or two). Let me know if you run into issuess!
#' Interaction effects
#'
#' Quantify the strength of two-way interaction effects using a simple
#' \emph{feature importance ranking measure} (FIRM) approach. For details, see
#' \href{https://arxiv.org/abs/1805.04755}{Greenwell et al. (2018)}.
#'
#' @param object A fitted model object (e.g., a \code{"randomForest"} object).
#'
#' @param feature_names Character string giving the names of the two features of
#' interest.
#'
#' @param parallel Logical indicating whether or not to run \code{partial} in
#' parallel using a backend provided by the \code{foreach} package. Default is
#' \code{FALSE}.
#'
#' @param paropts List containing additional options to be passed on to
#' \code{\link[foreach]{foreach}} when \code{parallel = TRUE}.
#'
#' @param ... Additional optional arguments to be passed on to
#' \code{\link[pdp]{partial}}.
#'
#' @details This function quantifies the strength of interaction between
#' features $X_1$ and $X_2$ by measuring the change in variance along slices of
#' the partial dependence of $X_1$ and $X_2$ on the target $Y$. See
#' \href{https://arxiv.org/abs/1805.04755}{Greenwell et al. (2018)} for
#' details and examples.
#'
#' @references
#' Greenwell, B. M., Boehmke, B. C., and McCarthy, A. J.: A Simple
#' and Effective Model-Based Variable Importance Measure. arXiv preprint
#' arXiv:1805.04755 (2018).
#'
#' @examples
#' \dontrun{
#' #
#' # The Friedman 1 benchmark problem
#' #
#'
#' # Load required packages
#' library(gbm)
#' library(ggplot2)
#' library(mlbench)
#'
#' # Simulate training data
#' trn <- vip::gen_friedman(500, seed = 101) # ?vip::gen_friedman
#'
#' #
#' # NOTE: The only interaction that actually occurs in the model from which
#' # these data are generated is between x.1 and x.2!
#' #
#'
#' # Fit a GBM to the training data
#' set.seed(102) # for reproducibility
#' fit <- gbm(y ~ ., data = trn, distribution = "gaussian", n.trees = 1000,
#' interaction.depth = 2, shrinkage = 0.01, bag.fraction = 0.8,
#' cv.folds = 5)
#' best_iter <- gbm.perf(fit, plot.it = FALSE, method = "cv")
#'
#' # Quantify relative interaction strength
#' all_pairs <- combn(paste0("x", 1:10), m = 2)
#' res <- NULL
#' for (i in seq_along(all_pairs[1, ])) {
#' interact <- vint(fit, feature_names = all_pairs[, i], n.trees = best_iter)
#' res <- rbind(res, interact)
#' }
#'
#' # Plot top 20 results
#' top_20 <- res[1L:20L, ]
#' ggplot(top_20, aes(x = reorder(Variables, Interaction), y = Interaction)) +
#' geom_col() +
#' coord_flip() +
#' xlab("") +
#' ylab("Interaction strength")
#' }
vint <- function(object, feature_names, progress = "none", parallel = FALSE,
paropts = NULL, ...) {
# warning("This function is experimental, use at your own risk!", call. = FALSE)
# FIXME: Should we force `chull = FALSE` in the call to `pdp::partial()`?
all.pairs <- utils::combn(feature_names, m = 2)
ints <- plyr::aaply(
all.pairs, .margins = 2, .progress = progress, .parallel = parallel,
.paropts = paropts,
.fun = function(x) {
pd <- pdp::partial(object, pred.var = x, ...)
mean(c(
stats::sd(tapply(pd$yhat, INDEX = pd[[x[1L]]], FUN = stats::sd)),
stats::sd(tapply(pd$yhat, INDEX = pd[[x[2L]]], FUN = stats::sd))
))
})
ints <- data.frame(
"Variables" = paste0(all.pairs[1L, ], "*", all.pairs[2L, ]),
"Interaction" = ints
)
ints <- ints[order(ints["Interaction"], decreasing = TRUE), ]
tibble::as_tibble(ints)
}
Quick example:
# Load required packages
library(gbm)
library(ggplot2)
library(mlbench)
# Simulate training data
trn <- vip::gen_friedman(500, seed = 101) # ?vip::gen_friedman
#
# NOTE: The only interaction that actually occurs in the model from which
# these data are generated is between x.1 and x.2!
#
# Fit a GBM to the training data
set.seed(102) # for reproducibility
fit <- gbm(y ~ ., data = trn, distribution = "gaussian", n.trees = 1000,
interaction.depth = 2, shrinkage = 0.01, bag.fraction = 0.8,
cv.folds = 5)
best_iter <- gbm.perf(fit, plot.it = FALSE, method = "cv")
# Quantify relative interaction strength
all_pairs <- combn(paste0("x", 1:10), m = 2)
res <- NULL
for (i in seq_along(all_pairs[1, ])) {
interact <- vint(fit, feature_names = all_pairs[, i], n.trees = best_iter)
res <- rbind(res, interact)
}
print(res)
# # A tibble: 45 × 2
# Variables Interaction
# <chr> <dbl>
# 1 x1*x2 4.98e- 1
# 2 x1*x3 3.60e- 3
# 3 x1*x4 8.61e- 2
# 4 x1*x5 7.69e- 2
# 5 x1*x6 0
# 6 x1*x7 0
# 7 x1*x8 0
# 8 x1*x9 0
# 9 x1*x10 1.06e-16
# 10 x2*x3 6.94e- 3
# # ℹ 35 more rows
# # ℹ Use `print(n = ...)` to see more rows
HanLum commented
Perfect, thank you so much @bgreenwell! 😊