bgreenwell/fastshap

pred_wrapper with multinomial classification

gagirob opened this issue · 2 comments

Hi @bgreenwell,

I am using "explain" with categorical dependent variables, predicted with random forests via nestcv.train. I found in one of the nestedcv vignettes (https://cran.r-project.org/web/packages/nestedcv/vignettes/nestedcv_shap.html) that with multinomial classification I need to use pred_train_class1, pred_train_class2, etc. as pred_wrapper. This works for pred_train_class1, pred_train_class2 and pred_train_class3, but not for pred_train_class4 and beyond (I have 8 categories in total).

This is the code I am using:

ctrl <- trainControl(method = "cv", number = n_inner_folds, seeds = seeds, classProbs = TRUE, summaryFunction = mnLogLoss, allowParallel = FALSE)
ncv_boruta <- nestedcv::nestcv.train(y = response, x = data, method = "rf", savePredictions = "final", n_outer_folds = n_outer_folds, outer_train_predict = TRUE, n_inner_folds = n_inner_folds, filterFUN = boruta_filter, filter_options = list(select = c("Confirmed", "Tentative"), maxRuns = maxRuns), cv.cores = n_outer_folds, ntree = ntree, maximize = FALSE, tuneGrid = tg, balance = sampling, trControl = ctrl)

nsim <- 100
sh <- list()
set.seed(123)
sh[[1]] <- explain(ncv_boruta, X=data, pred_wrapper = pred_train_class1, nsim = nsim)
sh[[2]] <- explain(ncv_boruta, X=data, pred_wrapper = pred_train_class2, nsim = nsim)
sh[[3]] <- explain(ncv_boruta, X=data, pred_wrapper = pred_train_class3, nsim = nsim)
sh[[4]] <- explain(ncv_boruta, X=data, pred_wrapper = pred_train_class4, nsim = nsim)

And this is what I get:

sh[[4]] <- explain(ncv_boruta, X=data, pred_wrapper = pred_train_class4, nsim = nsim)
Error in explain.default(ncv_boruta, X = data, pred_wrapper = pred_train_class4, :
object 'pred_train_class4' not found

Thank you.

Hi @gagirob,

I can answer the question about the missing pred_train_class4 object (I'm the author of the nestedcv package). The source code for pred_train_class3 is as follows:

pred_train_class3 <- function(x, newdata) {
  predict(x, newdata, type="prob")[,3]
}

I provided the first 3 classes as this is a common use case. It's straightforward to make pred_train_class4 as follows:

pred_train_class4 <- function(x, newdata) {
  predict(x, newdata, type="prob")[,4]
}

Following the same pattern and changing only the column index, you can make the prediction wrappers for your classes 5-8.
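Rather than writing each wrapper by hand, one option is a small factory function that builds the wrapper for any class column. This is only a sketch: make_pred_class and pred_wrappers are hypothetical names, not part of nestedcv or fastshap, and the toy_model object is a stand-in used purely so the snippet runs on its own. It assumes, as in the wrappers above, that predict(x, newdata, type="prob") returns one column per class.

```r
# Hypothetical helper (not part of nestedcv or fastshap): returns a prediction
# wrapper that extracts the predicted probability in class column k.
make_pred_class <- function(k) {
  force(k)  # evaluate k now so each closure keeps its own column index
  function(x, newdata) {
    predict(x, newdata, type = "prob")[, k]
  }
}

# One wrapper per class; 8 is the number of categories in the question above.
pred_wrappers <- lapply(1:8, make_pred_class)

# Toy stand-in model (an assumption, for illustration only): its predict method
# returns an n x 8 probability matrix in which column j is constant at j/36.
toy <- structure(list(), class = "toy_model")
predict.toy_model <- function(object, newdata, type = "prob", ...) {
  matrix(rep(1:8 / 36, each = nrow(newdata)), nrow = nrow(newdata))
}

# The fourth wrapper returns the class 4 probability for each row of newdata.
p4 <- pred_wrappers[[4]](toy, data.frame(a = 1:3))
```

With the objects from the original post, something like sh <- lapply(pred_wrappers, function(pw) explain(ncv_boruta, X = data, pred_wrapper = pw, nsim = nsim)) should then fill the whole list in one call, though I have not run this against that model.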

Bw, Myles

Thanks for posting a response @myles-lewis!