catboost and predict threshold
Opened this issue · 1 comment
pecto2020 commented
Hi,
When a catboost model is used through tidymodels/treesnip, `predict()` does not seem to use the default probability threshold of 0.5 for class predictions, but some other cut-off. Does catboost apply class weights during training? If so, how can I change them through tidymodels/treesnip? Below is a comparison between catboost and a random forest.
Thanks
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
library(mlbench)
library(catboost)
library(treesnip)
# Register both packages as dependencies of the boost_tree/catboost engine,
# one set_dependency() call per package, in the same order as before.
walk(c("catboost", "treesnip"), function(dep_pkg) {
  set_dependency("boost_tree", eng = "catboost", pkg = dep_pkg)
})
# Data ----
# Load the Pima Indians diabetes data from mlbench and keep a copy to split.
data(PimaIndiansDiabetes)
diabetes_orig <- PimaIndiansDiabetes

# Seed the RNG so the 75% / 25% split below is reproducible.
set.seed(123)
diabetes_split <- initial_split(diabetes_orig, prop = 0.75)
print(diabetes_split)
#> <Analysis/Assess/Total>
#> <576/192/768>

# Training (analysis) and test (assessment) partitions.
diabetes_train <- training(diabetes_split)
diabetes_test <- testing(diabetes_split)
# Random forest (ranger) ----
# Model specification.
trees_spec <- rand_forest() %>%
  set_mode("classification") %>%
  set_engine("ranger")

# Fit on the training data.
trees_fit <- trees_spec %>%
  fit(diabetes ~ ., data = diabetes_train)

# Class predictions, class probabilities and the observed outcome, side by side.
trees_pred <- predict(trees_fit, diabetes_test) %>%
  bind_cols(predict(trees_fit, diabetes_test, type = "prob")) %>%
  bind_cols(diabetes_test %>% select(diabetes))

# Metrics for the engine's own class predictions.
# `event_level = "second"` makes "pos" the event of interest (assumes the
# outcome levels are c("neg", "pos"), consistent with `.pred_pos` above).
# NOTE(review): the original sens() call used `trut` and `event_levels`;
# `event_levels` is not a yardstick argument and was absorbed by `...`, so
# sensitivity was computed for the first level ("neg"), not "pos".
trees_perf <- trees_pred %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    trees_pred %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# Re-derive the class predictions with an explicit 0.5 threshold on P(pos).
# The factor is built with the outcome's own levels so that a level that is
# never predicted is kept: yardstick requires `truth` and `estimate` to have
# identical levels, which as.factor() on the strings would not guarantee.
trees_05 <- trees_pred %>%
  mutate(
    .pred_class = factor(
      if_else(.pred_pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Same metrics for the thresholded predictions.
trees_perf_05 <- trees_05 %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    trees_05 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# NOTE(review): the printed results below come from the original run, whose
# `sens` rows measured the "neg" level; re-run to refresh them.
trees_perf
#> # A tibble: 2 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.823
#> 2 sens binary 0.856
trees_perf_05
#> # A tibble: 2 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.823
#> 2 sens binary 0.856
# Catboost (treesnip) ----
# Model specification.
catboost_spec <- boost_tree(tree_depth = 10) %>%
  set_mode("classification") %>%
  set_engine("catboost", nthread = 4)

# Fit on the training data.
catboost_fit <- catboost_spec %>%
  fit(diabetes ~ ., data = diabetes_train)

# Class predictions, class probabilities and the observed outcome, side by side.
catboost_pred <- predict(catboost_fit, diabetes_test) %>%
  bind_cols(predict(catboost_fit, diabetes_test, type = "prob")) %>%
  bind_cols(diabetes_test %>% select(diabetes))

# Metrics for the engine's own class predictions.
# `event_level = "second"` makes "pos" the event of interest (assumes the
# outcome levels are c("neg", "pos"), consistent with `.pred_pos` above).
# NOTE(review): the original sens() call used `event_levels`, which is not a
# yardstick argument and was absorbed by `...`; sensitivity was therefore
# computed for the first level ("neg"), not "pos".
catboost_perf <- catboost_pred %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    catboost_pred %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# Re-derive the class predictions with an explicit 0.5 threshold on P(pos),
# to compare against the engine's own `.pred_class`.
# The factor is built with the outcome's own levels so that a level that is
# never predicted is kept: yardstick requires `truth` and `estimate` to have
# identical levels, which as.factor() on the strings would not guarantee.
catboost_05 <- catboost_pred %>%
  mutate(
    .pred_class = factor(
      if_else(.pred_pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

# Same metrics for the thresholded predictions.
catboost_perf_05 <- catboost_05 %>%
  roc_auc(truth = diabetes, .pred_pos, event_level = "second") %>%
  bind_rows(
    catboost_05 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )

# NOTE(review): the printed results below come from the original run, whose
# `sens` rows measured the "neg" level; re-run to refresh them.
catboost_perf
#> # A tibble: 2 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.801
#> 2 sens binary 1
catboost_perf_05
#> # A tibble: 2 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.801
#> 2 sens binary 0.992
Created on 2022-02-02 by the reprex package (v2.0.1)
pecto2020 commented
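Notably, using catboost through caret seems to work as expected: the class predictions from `predict()` give the same metrics as applying a 0.5 threshold to the predicted probabilities.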
library(mlbench)
library(catboost)
library(caret)
#> Loading required package: ggplot2
#> Loading required package: lattice
library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#> method from
#> required_pkgs.model_spec parsnip
# Data ----
# Load the Pima Indians diabetes data from mlbench and keep a copy to split.
data(PimaIndiansDiabetes)
diabetes_orig <- PimaIndiansDiabetes

# Seed the RNG so the 75% / 25% split below is reproducible.
set.seed(123)
diabetes_split <- initial_split(diabetes_orig, prop = 0.75)
print(diabetes_split)
#> <Analysis/Assess/Total>
#> <576/192/768>

# Training (analysis) and test (assessment) partitions.
diabetes_train <- training(diabetes_split)
diabetes_test <- testing(diabetes_split)
# Caret ----
# 3-fold cross-validation; class probabilities are requested because
# twoClassSummary (which reports ROC) needs them, and hold-out predictions
# are kept on the fitted object.
fitControl <- trainControl(
  method = "cv",
  number = 3,
  classProbs = TRUE,
  summaryFunction = twoClassSummary,
  savePredictions = TRUE
)

# Tune catboost through its caret interface, choosing the final model by ROC.
# `tuneLength = 3` sets the granularity of the automatically generated grid.
model <- train(
  x = select(diabetes_train, -diabetes),
  y = diabetes_train[["diabetes"]],
  method = catboost.caret,
  trControl = fitControl,
  tuneLength = 3,
  metric = "ROC"
)
# Class predictions, class probabilities (columns named after the outcome
# levels, e.g. `neg` / `pos`) and the observed outcome, side by side.
# predict.train() returns a bare factor, so it is placed straight into a
# `.pred_class` column rather than renaming as_tibble()'s implicit `value`.
preds1 <- tibble(.pred_class = predict(model, diabetes_test)) %>%
  bind_cols(predict(model, diabetes_test, type = "prob")) %>%
  bind_cols(diabetes_test %>% select(diabetes))

# Metrics for caret's own class predictions; "second" makes "pos" the event
# (assumes outcome levels c("neg", "pos")).
# NOTE(review): the original sens() call used `event_levels`, which yardstick
# does not recognise (it was absorbed by `...`), so the sensitivity shown
# below was computed for "neg". Re-run to refresh the printed output.
preds1 %>%
  roc_auc(truth = diabetes, pos, event_level = "second") %>%
  bind_rows(
    preds1 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )
#> # A tibble: 2 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.821
#> 2 sens binary 0.848

# Re-derive the class predictions with an explicit 0.5 threshold on P(pos).
# Using the outcome's levels keeps both levels even if one is never
# predicted, as yardstick requires matching truth/estimate levels.
preds1_05 <- preds1 %>%
  mutate(
    .pred_class = factor(
      if_else(pos > 0.5, "pos", "neg"),
      levels = levels(diabetes)
    )
  )

preds1_05 %>%
  roc_auc(truth = diabetes, pos, event_level = "second") %>%
  bind_rows(
    preds1_05 %>%
      sens(truth = diabetes, .pred_class, event_level = "second")
  )
#> # A tibble: 2 x 3
#> .metric .estimator .estimate
#> <chr> <chr> <dbl>
#> 1 roc_auc binary 0.821
#> 2 sens binary 0.848
Created on 2022-02-02 by the reprex package (v2.0.1)