Bug in createFolds: fails when y is numeric but all values of y are the same
mikeblazanin opened this issue · 0 comments
- Start a new R session
- Install the latest version of caret:
update.packages(oldPkgs="caret", ask=FALSE)
- Write a minimal reproducible example
- Do not use parallel processing in the code (unless you are certain that the issue is about parallel processing).
- Run sessionInfo()
Minimal, reproducible example:
custommethod <- list(
  library = NULL, type = "Regression", prob = NULL,
  fit = function(x, y, wts, param, lev = NULL, last, weights, classprobs, ...) {
    return(stats::runmed(x = y, k = param$k))
  },
  parameters = data.frame(parameter = "k", class = "numeric", label = "k"),
  grid = function(x, y, len, search) {
    return(data.frame(k = seq(from = 1, by = 2, length.out = len)))
  },
  predict = function(modelFit, newdata, preProc = NULL, submodels = NULL) {
    return(rep(NA, length(newdata)))
  })
library(caret)
set.seed(1)
data <- data.frame(x = 1:100, y = rep(0.05, 100))
train(x = data.frame(x = data$x), y = data$y, method = custommethod, trControl = trainControl(method = "cv"))
Error in cut.default(y, breaks, include.lowest = TRUE) :
invalid number of intervals
The error appears to come from this section of createFolds:
if(is.numeric(y)) {
  cuts <- floor(length(y)/k)
  if(cuts < 2) cuts <- 2
  if(cuts > 5) cuts <- 5
  breaks <- unique(quantile(y, probs = seq(0, 1, length = cuts)))
  y <- cut(y, breaks, include.lowest = TRUE)
}
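To see what this section computes for a constant outcome, it can be traced by hand in base R. This is only a sketch: the fold count k = 10 is an assumption (the default for method = "cv"), and the variable names probs and q are introduced here for illustration.

```r
# Trace the numeric branch by hand for a constant outcome.
y <- rep(0.05, 100)
k <- 10  # assumed fold count (default for method = "cv")

cuts <- floor(length(y) / k)
if (cuts < 2) cuts <- 2
if (cuts > 5) cuts <- 5   # capped at 5 for this example

probs <- seq(0, 1, length.out = cuts)
q <- quantile(y, probs = probs)  # every quantile equals the constant value
breaks <- unique(q)              # duplicates removed, one value remains

length(q)
length(breaks)
```

With only one value left in breaks, the following call to cut() no longer receives a vector of cut points.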
When y is numeric but has no variation, breaks collapses to the single value of y. However, cut() interprets a single value not as the point where a cut should be made, but as the number of intervals to create. When the y value is 2 or greater, this silently produces a number of intervals given by the y value (truncated to an integer) rather than by cuts. When the y value is less than 2, as with 0.05 here, cut() stops with "invalid number of intervals", because a length-one breaks must be at least 2.
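The behavior of cut() with a length-one breaks can be checked in base R without caret. The sketch below shows both cases (a value below 2 and a value of 2 or more), followed by one possible guard. The helper stratify_numeric() is hypothetical and is not caret's code; it assumes that placing all observations in a single stratum is acceptable when y has no variation, and whether the rest of createFolds handles a single stratum is not verified here.

```r
# Case 1: a length-one 'breaks' below 2 is rejected as an interval count.
y_small <- rep(0.05, 100)
msg <- tryCatch(
  cut(y_small, 0.05, include.lowest = TRUE),
  error = function(e) conditionMessage(e)
)
msg

# Case 2: a length-one 'breaks' of 2 or more is used as an interval count,
# so the constant value 7 yields 7 intervals regardless of 'cuts'.
y_big <- rep(7, 100)
f_big <- cut(y_big, 7, include.lowest = TRUE)
nlevels(f_big)

# Hypothetical guard (not caret's implementation): only call cut() when
# at least two distinct break points remain.
stratify_numeric <- function(y, cuts) {
  breaks <- unique(quantile(y, probs = seq(0, 1, length.out = cuts)))
  if (length(breaks) < 2) {
    # No variation to stratify on: treat all observations as one stratum.
    return(factor(rep("all", length(y))))
  }
  cut(y, breaks, include.lowest = TRUE)
}

table(stratify_numeric(y_small, 5))
nlevels(stratify_numeric(1:100, 5))
```

The guard tests length(breaks) rather than the variance of y, so it also covers any other input whose quantiles all coincide.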
Session Info:
> sessionInfo()
R version 4.3.2 (2023-10-31 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 11 x64 (build 22621)
Matrix products: default
locale:
[1] LC_COLLATE=English_United States.utf8 LC_CTYPE=English_United States.utf8
[3] LC_MONETARY=English_United States.utf8 LC_NUMERIC=C
[5] LC_TIME=English_United States.utf8
time zone: America/New_York
tzcode source: internal
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] caret_6.0-94 lattice_0.21-9 ggplot2_3.4.4
loaded via a namespace (and not attached):
[1] future_1.33.1 utf8_1.2.4 generics_0.1.3 class_7.3-22
[5] stringi_1.8.3 pROC_1.18.5 listenv_0.9.1 digest_0.6.34
[9] magrittr_2.0.3 timechange_0.3.0 evaluate_0.23 grid_4.3.2
[13] iterators_1.0.14 fastmap_1.1.1 foreach_1.5.2 plyr_1.8.9
[17] Matrix_1.6-1.1 ModelMetrics_1.2.2.2 nnet_7.3-19 survival_3.5-7
[21] purrr_1.0.2 fansi_1.0.6 scales_1.3.0 codetools_0.2-19
[25] lava_1.7.3 cli_3.6.2 rlang_1.1.3 hardhat_1.3.1
[29] parallelly_1.37.1 future.apply_1.11.1 munsell_0.5.0 splines_4.3.2
[33] withr_3.0.0 yaml_2.3.8 prodlim_2023.08.28 parallel_4.3.2
[37] tools_4.3.2 reshape2_1.4.4 dplyr_1.1.4 colorspace_2.1-0
[41] recipes_1.0.10 globals_0.16.2 vctrs_0.6.5 R6_2.5.1
[45] rpart_4.1.21 stats4_4.3.2 lubridate_1.9.3 lifecycle_1.0.4
[49] stringr_1.5.1 MASS_7.3-60 pkgconfig_2.0.3 pillar_1.9.0
[53] gtable_0.3.4 glue_1.7.0 data.table_1.15.0 Rcpp_1.0.12
[57] xfun_0.41 tibble_3.2.1 tidyselect_1.2.0 rstudioapi_0.15.0
[61] knitr_1.45 htmltools_0.5.7 nlme_3.1-163 rmarkdown_2.25
[65] ipred_0.9-14 timeDate_4032.109 gower_1.0.1 compiler_4.3.2