Support Weight and Base_Margin in genDMatrix
SixiangHu commented
It is quite common to set a weight and a base_margin for each observation before training an xgboost model.
xgboost's DMatrix already supports this when reading a .libsvm file, by loading the extra information from additional files that accompany it:
https://xgboost.readthedocs.io/en/latest/tutorials/input_format.html#embedding-additional-information-inside-libsvm-file
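Here is a toy illustration of that layout in base R. It assumes, as the proposed code below does, that the companion files are named by appending `.weight` and `.base_margin` to the data file name and hold one value per line in the same row order. The data values are made up:

```r
# Made-up data: two rows in LIBSVM format ("label index:value ...")
data_file = tempfile(fileext = ".libsvm")
writeLines(c("1 0:0.5 1:2", "0 0:1.5 1:3"), data_file)

# Companion files (naming assumed): line k holds the value for row k
writeLines(c("1", "0.5"), paste0(data_file, ".weight"))
writeLines(c("0.1", "-0.2"), paste0(data_file, ".base_margin"))

rows    = readLines(data_file)
weights = as.numeric(readLines(paste0(data_file, ".weight")))
margins = as.numeric(readLines(paste0(data_file, ".base_margin")))
# Every companion file must line up with the data file, row for row
stopifnot(length(weights) == length(rows), length(margins) == length(rows))
```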
I have checked the current genDMatrix code and found that both parameters can be added quite easily.
The concerns here are:
- it would only support the single-node version
- it would make genDMatrix even slower (one extra write per row for each additional file)
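On the speed concern, the companion files do not depend on the per-row feature formatting. They could therefore be written with one vectorized call each instead of one `cat()` per row. This is a minimal sketch; `write_sidecar` is a hypothetical helper, not part of xgboost or the existing code:

```r
# Hypothetical helper: write one value per line in a single call.
# as.character() produces the same text that paste() writes per row.
write_sidecar = function(values, file, suffix){
  if(is.null(values)) return (invisible(NULL))
  path = paste0(file, suffix)
  writeLines(as.character(values), path)
  invisible(path)
}

# Inside genDMatrix this would replace the fw/fb connections, e.g.:
# write_sidecar(weight, file, ".weight")
# write_sidecar(base_margin, file, ".base_margin")
```

The row-by-row loop over `df_X` would likely still dominate the run time. This sketch only keeps the two new files from adding to it.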
Are there any other concerns that kept this from being added as a feature? The modified function follows; new lines are marked `# additional code:`.
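```r
genDMatrix = function(df_y, df_X,
                      file = tempfile(pattern = "DMatrix", fileext = ".libsvm"),
                      weight = NULL,
                      base_margin = NULL){
  n = nrow(df_X)
  # additional code: weight / base_margin must supply one value per row
  if(!is.null(weight) && length(weight) != n){
    stop("`weight` must have one value per row of `df_X`")
  }
  if(!is.null(base_margin) && length(base_margin) != n){
    stop("`base_margin` must have one value per row of `df_X`")
  }

  # Number of LIBSVM columns an input column expands to (one per factor level)
  col2len = function(x){
    col = df_X[[x]]
    if(is.factor(col)){
      return (length(levels(col)))
    }
    return (1)
  }
  col_len = sapply(names(df_X), FUN = col2len)
  # 0-based LIBSVM index of the first column in each input column's block
  col_offset = (cumsum(col_len) - col_len)

  # Position inside the block: the factor level (0-based), or 0 for numerics
  factor2pos = function(x){
    if(is.na(x)){
      return (NA)
    }
    if(is.factor(x)){
      return (as.integer(x) - 1)
    }
    return (0)
  }
  # Value to write: 1 for the active factor level, the raw value otherwise
  format_cell = function(x){
    if(is.na(x)){
      return (NA)
    }
    if(is.factor(x)){
      return (1)
    }
    return (x)
  }

  fp = file(file, "w")
  # additional code: companion files, one value per line in row order
  if(!is.null(weight)) {
    file_weight = paste0(file, ".weight")
    fw = file(file_weight, "w")
  }
  if(!is.null(base_margin)) {
    file_base_margin = paste0(file, ".base_margin")
    fb = file(file_base_margin, "w")
  }
  # seq_len() avoids iterating over 1:0 when df_X has no rows
  for(i in seq_len(n)){
    # drop = FALSE keeps a one-column df_X as a data frame
    row = df_X[i, , drop = FALSE]
    cell_offset = (col_offset + sapply(row, FUN = factor2pos))
    cell_value = sapply(row, FUN = format_cell)
    y_value = df_y[i]
    # NA cells are omitted, so they are treated as missing entries
    X_values = paste(na.omit(cell_offset), na.omit(cell_value), sep = ":", collapse = " ")
    cat(paste(y_value, " ", X_values, "\n", sep = ""), file = fp)
    # additional code:
    if(!is.null(weight)) cat(paste(weight[i], "\n", sep = ""), file = fw)
    if(!is.null(base_margin)) cat(paste(base_margin[i], "\n", sep = ""), file = fb)
  }
  # Close (and flush) every file before xgboost reads them
  close(fp)
  # additional code:
  if(!is.null(weight)) close(fw)
  if(!is.null(base_margin)) close(fb)
  return (xgboost::xgb.DMatrix(file))
}
```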
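A companion file with the wrong number of lines would pair values with the wrong rows, so a cheap check after writing may be worth adding. This is a sketch in base R; `check_sidecars` is a hypothetical helper name, and the file suffixes are the ones used in the code above:

```r
# Hypothetical helper: confirm each existing companion file has exactly
# one line per row of the LIBSVM data file.
check_sidecars = function(file, suffixes = c(".weight", ".base_margin")){
  n_rows = length(readLines(file))
  for(s in suffixes){
    path = paste0(file, s)
    if(!file.exists(path)) next  # companion files are optional
    n_side = length(readLines(path))
    if(n_side != n_rows){
      stop(sprintf("%s has %d lines, expected %d", basename(path), n_side, n_rows))
    }
  }
  invisible(TRUE)
}
```

It could be called as `check_sidecars(file)` just before `xgboost::xgb.DMatrix(file)`.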