Feature idea - provide custom validation sets for early stopping
dfsnow opened this issue · 4 comments
Thanks for creating this excellent package. I created a similar fork of treesnip but am planning to replace it with {bonsai} in all our production models.
One feature that I think would be incredibly useful in {bonsai} is the ability to provide custom validation sets for early stopping (instead of using a random split of the training data). This would have a few potential benefits:
- More training data. In many cases, you'll already have a validation set held out from a classic `train`/`validate`/`test` split. Currently, {bonsai} further splits the `train` data into a training subset and a validation set specifically for early stopping. Instead, it would be ideal to be able to pass the `validate` set directly, so that all of `train` would be used for training.
- Ability to do more complex cross-validation. Certain cross-validation techniques (rolling origin, spatial, etc.) don't rely on a random sample of the training data and instead use some sort of partitioning (by time or geography). Allowing custom validation data would let users apply the "correct" validation set for early stopping when using these more complex methods.
- Better integration with tidymodels. Tidymodels supports k-fold and other types of cross-validation. Using the validation set created for each fold rather than splitting a separate validation set specifically for early stopping would be much simpler.
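To illustrate the partitioning point, here is a sketch of a rolling-origin resample built with {rsample}'s `rolling_origin()` on toy daily data (the data and window sizes are arbitrary choices for illustration). In every split the assessment rows come strictly after the analysis rows. That assessment set is the natural candidate for early stopping, whereas a random subsample of the analysis rows would mix past and future observations.

```r
library(rsample)

# Toy time-ordered data: 30 daily observations (illustrative only)
set.seed(42)
df <- data.frame(
  day = 1:30,
  y = cumsum(rnorm(30))
)

# Each resample trains on a window of past rows and assesses on the rows
# that immediately follow that window.
folds <- rolling_origin(df, initial = 15, assess = 5, cumulative = FALSE, skip = 4)

for (split in folds$splits) {
  train_days <- analysis(split)$day
  valid_days <- assessment(split)$day
  cat(sprintf("train %2d-%2d | validate %2d-%2d\n",
              min(train_days), max(train_days),
              min(valid_days), max(valid_days)))
}
```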
Let me know if this is out of scope for this project. If not, I'm happy to contribute if needed.
Thanks for the issue! I'm on board. :)
Related to tidymodels/parsnip#760, and tidymodels/parsnip#765.
My response to the analogous parsnip issues reflects where my thinking is with this in bonsai as well:
This is an interesting idea and one that we ought to consider. xgboost's and lightgbm's interfaces for validation sets allow for a lot of user control, but we'd need to think carefully about what a tidymodels-esque interface would feel like here.
This won't be at the top of our to-do list for now, but we'll leave this open as a possible future extension. :)
Great! Thanks for the quick response. It looks like there's already a PR in {parsnip} for exactly this (tidymodels/parsnip#771). I'll wait for that to be merged, and then I'm happy to assist with any further work needed to integrate it into {bonsai}.
Whenever you or others here pick this up @simonpcouch, @ me if you need any help with how to do this in {lightgbm}.
There is a LightGBM-y way to create validation sets that is slightly different from "just subset rows". See https://lightgbm.readthedocs.io/en/latest/R/reference/lgb.Dataset.create.valid.html.
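For reference, a minimal sketch of that approach using the {lightgbm} R package directly (toy data; the 140/60 row split and parameter values are arbitrary choices for illustration, not what bonsai does). `lgb.Dataset.create.valid()` builds the validation Dataset with the training Dataset as its reference, so both share the same feature binning, and the result is passed to `lgb.train()` through `valids`, where it is monitored for early stopping.

```r
library(lightgbm)

# Toy regression data (illustrative only)
set.seed(1)
n <- 200
x <- matrix(rnorm(n * 3), ncol = 3)
y <- x[, 1] + rnorm(n, sd = 0.1)

# Assumed split: first 140 rows for training, last 60 held out
trn <- 1:140
val <- 141:200

dtrain <- lgb.Dataset(x[trn, ], label = y[trn])
# Validation Dataset built against dtrain, so it reuses dtrain's bin mappers
dvalid <- lgb.Dataset.create.valid(dtrain, x[val, ], label = y[val])

fit <- lgb.train(
  params = list(objective = "regression", learning_rate = 0.1),
  data = dtrain,
  nrounds = 100L,
  valids = list(validation = dvalid),  # evaluated each round for early stopping
  early_stopping_rounds = 5L,
  verbose = -1L
)
```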
Hi,
I wrote a simple patch that adds an alternative way to specify a custom validation set through the `validation` parameter.
With this code and bonsai v0.3.0, users can provide either:
- `validation = 0.3`: a random holdout of the training rows (the default, current behavior)
- `validation = c(0.7, 0.9)`: the alternative, which holds out the contiguous block of rows from 70% to 90% of the way through the training set
Here is the code that replaces the internal `process_data()` function; run it after bonsai 0.3.0 has been loaded.
Hope it is useful.
Regards
```r
utils::assignInNamespace(
  x = "process_data",
  ns = "bonsai",
  value = function(args, x, y, weights, validation, missing_validation) {
    #                                          trn_index      | val_index
    #                                        ----------------------------------------
    #  needs_validation &  missing_validation | 1:n            | 1:n
    #  needs_validation & !missing_validation | sample(1:n, m) | setdiff(1:n, trn_index)
    # !needs_validation &  missing_validation | 1:n            | NULL
    # !needs_validation & !missing_validation | sample(1:n, m) | setdiff(1:n, trn_index)
    #
    # When `validation` is a range c(lower, upper), val_index is instead the
    # contiguous block of rows floor(n * lower) + 1 to floor(n * upper), and
    # trn_index is every remaining row.

    n <- nrow(x)
    needs_validation <- !is.null(args$params$early_stopping_round)

    if (!needs_validation) {
      # If early_stopping_round isn't set, clear it from arguments actually
      # passed to LightGBM.
      args$params$early_stopping_round <- NULL
    }

    if (missing_validation) {
      trn_index <- 1:n
      if (needs_validation) {
        val_index <- trn_index
      } else {
        val_index <- NULL
      }
    } else if (length(validation) == 2) {
      # Validation range as proportion bounds c(lower, upper). Rows are taken
      # by position, so the data must already be in the intended order
      # (e.g. sorted by time).
      l <- floor(n * validation[1]) + 1
      h <- floor(n * validation[2])
      if (validation[1] < 0 || validation[2] > 1 || l > h) {
        stop("`validation` = c(lower, upper) must satisfy 0 <= lower < upper <= 1 ",
             "and select at least one row.", call. = FALSE)
      }
      val_index <- l:h
      trn_index <- setdiff(1:n, val_index)
    } else {
      # Validation proportion as a scalar (default method): random holdout.
      m <- min(floor(n * (1 - validation)) + 1, n - 1)
      trn_index <- sample(1:n, size = max(m, 2))
      val_index <- setdiff(1:n, trn_index)
    }

    # Training Dataset
    data_args <-
      c(
        list(
          data = bonsai:::prepare_df_lgbm(x[trn_index, , drop = FALSE]),
          label = y[trn_index],
          categorical_feature = bonsai:::categorical_columns(x[trn_index, , drop = FALSE]),
          params = c(list(feature_pre_filter = FALSE), args$params),
          weight = weights[trn_index]
        ),
        args$main_args_dataset
      )

    args$main_args_train$data <-
      rlang::eval_bare(
        rlang::call2("lgb.Dataset", !!!data_args, .ns = "lightgbm")
      )

    if (!is.null(val_index)) {
      # Validation Dataset, built with the same parameters as the training one
      # (params flattened with c(), not nested inside list())
      valids_args <-
        c(
          list(
            data = bonsai:::prepare_df_lgbm(x[val_index, , drop = FALSE]),
            label = y[val_index],
            categorical_feature = bonsai:::categorical_columns(x[val_index, , drop = FALSE]),
            params = c(list(feature_pre_filter = FALSE), args$params),
            weight = weights[val_index]
          ),
          args$main_args_dataset
        )

      args$main_args_train$valids <-
        list(
          validation =
            rlang::eval_bare(
              rlang::call2("lgb.Dataset", !!!valids_args, .ns = "lightgbm")
            )
        )
    }

    args
  }
)
```
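To check which rows each form of `validation` selects without fitting a model, the index logic from the patch above can be pulled out into a standalone helper. `validation_indices()` is a hypothetical name used only for this sketch; it is not part of bonsai, and it skips the patch's range checks.

```r
# Hypothetical helper (not part of bonsai) mirroring the index logic of the
# patched process_data(), so the two `validation` forms can be inspected.
validation_indices <- function(n, validation) {
  if (length(validation) == 2) {
    # Contiguous block of row positions between the two proportions
    l <- floor(n * validation[1]) + 1
    h <- floor(n * validation[2])
    val_index <- l:h
    trn_index <- setdiff(seq_len(n), val_index)
  } else {
    # Random holdout (the scalar/default form)
    m <- min(floor(n * (1 - validation)) + 1, n - 1)
    trn_index <- sample(seq_len(n), size = max(m, 2))
    val_index <- setdiff(seq_len(n), trn_index)
  }
  list(train = trn_index, validation = val_index)
}

idx <- validation_indices(10, c(0.7, 0.9))
idx$validation  # 8 9
idx$train       # 1 2 3 4 5 6 7 10
```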