support cross_entropy as objective for lightgbm
panlanfeng opened this issue · 7 comments
Hello,
Thanks for the great work on this project.
I was wondering whether support for the cross entropy objective is on your roadmap.
I have a use case where I need to use numeric probability labels in [0, 1], and I got the following error message. Could you help take a look? Thanks!
Jun 30, 2021 3:56:41 AM org.jpmml.lightgbm.Main run
INFO: Loading GBDT..
Jun 30, 2021 3:56:41 AM org.jpmml.lightgbm.Main run
SEVERE: Failed to load GBDT
java.lang.IllegalArgumentException: cross_entropy
at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:529)
at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
at org.jpmml.lightgbm.Main.run(Main.java:137)
at org.jpmml.lightgbm.Main.main(Main.java:127)
Exception in thread "main" java.lang.IllegalArgumentException: cross_entropy
at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:529)
at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
at org.jpmml.lightgbm.Main.run(Main.java:137)
at org.jpmml.lightgbm.Main.main(Main.java:127)
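For context on this use case: LightGBM's cross_entropy objective accepts real-valued labels anywhere in [0, 1], not only 0/1. The sketch below shows the textbook logistic cross-entropy that such labels plug into, assuming an unweighted loss on the raw boosted score; the class and method names are hypothetical illustrations, not LightGBM or JPMML-LightGBM code.

```java
// Hypothetical illustration: logistic cross-entropy with soft labels y in [0, 1].
final class SoftLabelCrossEntropy {

    // Logistic link: maps a raw boosted score to a probability.
    static double probability(double rawScore) {
        return 1.0 / (1.0 + Math.exp(-rawScore));
    }

    // Per-sample loss: -(y * ln(p) + (1 - y) * ln(1 - p)).
    static double loss(double label, double rawScore) {
        if (label < 0.0 || label > 1.0) {
            throw new IllegalArgumentException("Label must be in [0, 1]: " + label);
        }
        double p = probability(rawScore);
        return -(label * Math.log(p) + (1.0 - label) * Math.log(1.0 - p));
    }

    // First derivative of the loss with respect to the raw score: p - y.
    static double gradient(double label, double rawScore) {
        return probability(rawScore) - label;
    }

    // Second derivative of the loss with respect to the raw score: p * (1 - p).
    static double hessian(double rawScore) {
        double p = probability(rawScore);
        return p * (1.0 - p);
    }

    public static void main(String[] args) {
        // A soft label of 0.3 scored at raw value 0.0, i.e. p = 0.5.
        System.out.println(loss(0.3, 0.0));
        System.out.println(gradient(0.3, 0.0));
        System.out.println(hessian(0.0));
    }
}
```

The gradient p - y is well defined for any label in [0, 1], which is why fractional labels need no special handling in training; only the conversion side has to know which output transform to apply.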
Does the cross_entropy objective function (aka xentropy) also use the sigmoid function for calculating probabilities?
If it does, then try simply inserting cross_entropy here:
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.9/src/main/java/org/jpmml/lightgbm/GBDT.java#L523
Something like this:
switch (objective) {
    // BinaryLogloss
    case "binary":
    case "cross_entropy":
        return new BinomialLogisticRegression(average_output, config.getDouble("sigmoid"));
}
If you rebuild the project and redo the conversion, does the PMML model make correct predictions?
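One way to answer that question is to score the same rows with LightGBM and with the converted PMML model, then compare the probabilities element-wise. A minimal sketch of the comparison step only, assuming both sets of predictions have already been exported as arrays; the class name and the 1e-6 tolerance are illustrative choices, not values from this thread.

```java
// Hypothetical check: do two prediction vectors agree within a tolerance?
final class PredictionCheck {

    // Returns the index of the first mismatch, or -1 if all predictions agree.
    static int firstMismatch(double[] expected, double[] actual, double tolerance) {
        if (expected.length != actual.length) {
            throw new IllegalArgumentException(
                "Length mismatch: " + expected.length + " vs " + actual.length);
        }
        for (int i = 0; i < expected.length; i++) {
            // Written as !(diff <= tol) so that a NaN prediction counts as a mismatch.
            if (!(Math.abs(expected[i] - actual[i]) <= tolerance)) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        double[] lightgbm = {0.12, 0.57, 0.91}; // illustrative values, not real output
        double[] pmml = {0.12, 0.57, 0.42};     // third row deliberately differs
        System.out.println(firstMismatch(lightgbm, pmml, 1e-6)); // 2
    }
}
```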
It looks like cross_entropy does not use a sigmoid parameter. I made the change you suggested and got the following error when converting:
Jul 01, 2021 12:41:01 AM org.jpmml.lightgbm.Main run
INFO: Loading GBDT..
Jul 01, 2021 12:41:01 AM org.jpmml.lightgbm.Main run
SEVERE: Failed to load GBDT
java.lang.IllegalArgumentException: sigmoid
at org.jpmml.lightgbm.Section.get(Section.java:106)
at org.jpmml.lightgbm.Section.get(Section.java:100)
at org.jpmml.lightgbm.Section.getDouble(Section.java:74)
at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:525)
at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
at org.jpmml.lightgbm.Main.run(Main.java:137)
at org.jpmml.lightgbm.Main.main(Main.java:127)
Exception in thread "main" java.lang.IllegalArgumentException: sigmoid
at org.jpmml.lightgbm.Section.get(Section.java:106)
at org.jpmml.lightgbm.Section.get(Section.java:100)
at org.jpmml.lightgbm.Section.getDouble(Section.java:74)
at org.jpmml.lightgbm.GBDT.loadObjectiveFunction(GBDT.java:525)
at org.jpmml.lightgbm.GBDT.load(GBDT.java:103)
at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:51)
at org.jpmml.lightgbm.LightGBMUtil.loadGBDT(LightGBMUtil.java:43)
at org.jpmml.lightgbm.Main.run(Main.java:137)
at org.jpmml.lightgbm.Main.main(Main.java:127)
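This second trace fails one step later than the first: the objective name is now accepted, but reading the sigmoid value throws, presumably because a model trained with cross_entropy stores no sigmoid entry. A minimal sketch of that strict-lookup behavior, assuming a plain key/value map like the one suggested by Section.get and Section.getDouble in the trace; ConfigSection and its methods are hypothetical stand-ins, not the JPMML-LightGBM API.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for a parsed key/value section of a model file.
final class ConfigSection {

    private final Map<String, String> values = new HashMap<>();

    ConfigSection put(String key, String value) {
        values.put(key, value);
        return this;
    }

    // Strict lookup: a missing key is reported by its name only.
    String get(String key) {
        String value = values.get(key);
        if (value == null) {
            throw new IllegalArgumentException(key);
        }
        return value;
    }

    double getDouble(String key) {
        return Double.parseDouble(get(key));
    }

    public static void main(String[] args) {
        ConfigSection binary = new ConfigSection().put("sigmoid", "1");
        System.out.println(binary.getDouble("sigmoid")); // 1.0

        ConfigSection crossEntropy = new ConfigSection(); // no "sigmoid" entry
        try {
            crossEntropy.getDouble("sigmoid");
        } catch (IllegalArgumentException e) {
            System.out.println(e); // java.lang.IllegalArgumentException: sigmoid
        }
    }
}
```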
According to this line, the cross entropy objective computes the probability directly instead of calling a sigmoid function, and it does not take a sigmoid parameter the way binary classification does.
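Written as formulas, assuming the standard LightGBM output transforms, with f(x) the raw boosted score and σ the binary objective's sigmoid parameter, the two objectives differ only in the coefficient:

```latex
p_{\text{binary}}(x) = \frac{1}{1 + e^{-\sigma f(x)}}, \qquad
p_{\text{xentropy}}(x) = \frac{1}{1 + e^{-f(x)}} = p_{\text{binary}}(x)\Big|_{\sigma = 1}
```

So the cross_entropy output can reuse the binary logistic transformation, provided σ is fixed at 1 rather than read from the model file.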
I was able to make it generate the correct score after making the following change to
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.9/src/main/java/org/jpmml/lightgbm/GBDT.java#L523
case "cross_entropy":
    // cross_entropy has no sigmoid parameter; its logistic coefficient is fixed at 1.0
    return new BinomialLogisticRegression(average_output, 1.0);
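The patch works because the cross_entropy branch supplies the coefficient itself instead of looking it up. A self-contained sketch of that dispatch, assuming a plain Map<String, String> in place of the real config section; ObjectiveDispatch, sigmoidCoefficient and probability are hypothetical names, and only the objective names and the sigmoid key come from this thread. The default branch mirrors the original failure, where an unsupported objective was reported by name.

```java
import java.util.Map;

// Hypothetical sketch of choosing the logistic coefficient per objective.
final class ObjectiveDispatch {

    // Returns the coefficient applied to the raw score before the logistic function.
    static double sigmoidCoefficient(String objective, Map<String, String> config) {
        switch (objective) {
            case "binary": {
                // Binary logloss carries its coefficient under the "sigmoid" key.
                String value = config.get("sigmoid");
                if (value == null) {
                    throw new IllegalArgumentException("sigmoid");
                }
                return Double.parseDouble(value);
            }
            case "cross_entropy":
                // No "sigmoid" key is read here; the coefficient is fixed at 1.0.
                return 1.0;
            default:
                // Unsupported objectives are reported by name only.
                throw new IllegalArgumentException(objective);
        }
    }

    // Logistic transform of a raw score with the chosen coefficient.
    static double probability(double rawScore, double coefficient) {
        return 1.0 / (1.0 + Math.exp(-coefficient * rawScore));
    }

    public static void main(String[] args) {
        double c = sigmoidCoefficient("cross_entropy", Map.of());
        System.out.println(c);                   // 1.0
        System.out.println(probability(0.0, c)); // 0.5
    }
}
```

Keeping the fixed coefficient inside the dispatch, rather than inventing a default sigmoid entry in the config, means a binary model that genuinely lacks its sigmoid value still fails loudly.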
I can make a CR for this change if it looks OK to you.
Yes, that appears to be the solution. There is no need for an explicit sigmoid parameter, because the coefficient is hard-coded as 1.
As for the CR, it is not needed: I'll implement proper cross_entropy support, with test cases, for the next release myself.
In the meantime, you can keep using your patched codebase.
Thanks!
I was also wondering whether it would be possible to add this cross entropy support to the older 1.2.* versions as well.
I ask because our team is still using version 1.2.*.
It is OK if there is no such plan.
As for the 1.2.* versions: I'll check whether the 1.2.X development branch has the same API that is being "touched" here. If it does, I'll implement the change in 1.2.X and then merge it forward to 1.3.X.
The fix is available both in JPMML-LightGBM 1.2.15 and 1.3.10.