pred_wrapper with multinomial classification
gagirob opened this issue · 2 comments
Hi @bgreenwell,
I am using "explain" with categorical dependent variables predicted with random forests via nestcv.train. I found in one of the nestedcv vignettes (https://cran.r-project.org/web/packages/nestedcv/vignettes/nestedcv_shap.html) that for multinomial classification I need to use pred_train_class1, pred_train_class2, etc. as the pred_wrapper. This works for pred_train_class1, pred_train_class2 and pred_train_class3, but not for pred_train_class4 and beyond (I have 8 categories in total).
This is the code I am using:
ctrl <- trainControl(method = "cv", number = n_inner_folds, seeds = seeds, classProbs = TRUE,
                     summaryFunction = mnLogLoss, allowParallel = FALSE)
ncv_boruta <- nestedcv::nestcv.train(y = response, x = data, method = "rf", savePredictions = "final",
  n_outer_folds = n_outer_folds, outer_train_predict = TRUE, n_inner_folds = n_inner_folds,
  filterFUN = boruta_filter,
  filter_options = list(select = c("Confirmed", "Tentative"), maxRuns = maxRuns),
  cv.cores = n_outer_folds, ntree = ntree, maximize = FALSE, tuneGrid = tg,
  balance = sampling, trControl = ctrl)
nsim <- 100
sh <- list()
set.seed(123)
sh[[1]] <- explain(ncv_boruta, X = data, pred_wrapper = pred_train_class1, nsim = nsim)
sh[[2]] <- explain(ncv_boruta, X = data, pred_wrapper = pred_train_class2, nsim = nsim)
sh[[3]] <- explain(ncv_boruta, X = data, pred_wrapper = pred_train_class3, nsim = nsim)
sh[[4]] <- explain(ncv_boruta, X = data, pred_wrapper = pred_train_class4, nsim = nsim)
And this is what I get:
sh[[4]] <- explain(ncv_boruta, X=data, pred_wrapper = pred_train_class4, nsim = nsim)
Error in explain.default(ncv_boruta, X = data, pred_wrapper = pred_train_class4, :
object 'pred_train_class4' not found
Thank you.
Hi @gagirob,
I can answer the question about the missing pred_train_class4 object (I'm the author of the nestedcv package). The source code for pred_train_class3 is as follows:
pred_train_class3 <- function(x, newdata) {
predict(x, newdata, type="prob")[,3]
}
I only provided wrappers for the first 3 classes, as this is a common use case. It's straightforward to make pred_train_class4 as follows:
pred_train_class4 <- function(x, newdata) {
predict(x, newdata, type="prob")[,4]
}
In the same way, you can make the prediction wrappers you need for classes 5-8.
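Rather than writing eight near-identical functions by hand, one option is a small function factory that builds each wrapper from a class index. This is a sketch only: `make_pred_class`, `pred_wrappers`, `toy_model` and `predict.toy_model` are hypothetical names, not part of nestedcv or fastshap. The toy model stands in for `ncv_boruta` purely so the snippet runs on its own; the wrappers themselves have the same body as `pred_train_class4` above.

```r
# Hypothetical helper (not part of nestedcv or fastshap): returns a
# pred_wrapper-style function giving the predicted probability of class k.
# k can be a column index (1, 2, ...) or a column (class) name.
make_pred_class <- function(k) {
  force(k)  # evaluate k now so each wrapper keeps its own class
  function(x, newdata) {
    predict(x, newdata, type = "prob")[, k]
  }
}

# One wrapper per class (8 classes in this thread)
pred_wrappers <- lapply(1:8, make_pred_class)

# Toy stand-in for a fitted classifier, used only to show what a wrapper
# returns: an 8-column matrix of row-normalised "probabilities".
toy_model <- structure(list(), class = "toy_model")
predict.toy_model <- function(object, newdata, type = "prob", ...) {
  p <- matrix(seq_len(nrow(newdata) * 8), nrow = nrow(newdata), ncol = 8)
  p / rowSums(p)
}

newdata <- data.frame(a = 1:3)
pred_wrappers[[4]](toy_model, newdata)  # column 4 of the probability matrix
```

With the real model, each wrapper would be passed in the same way as before, e.g. `sh[[k]] <- explain(ncv_boruta, X = data, pred_wrapper = pred_wrappers[[k]], nsim = nsim)` inside `for (k in 1:8)`. Building the list from class labels instead, e.g. `lapply(levels(response), make_pred_class)`, assumes the probability columns are named after the outcome's factor levels; it is worth confirming that with `colnames()` on a `type = "prob"` prediction first.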
Bw, Myles
Thanks for posting a response @myles-lewis!