How to get the output leaf indices of every tree in a LightGBM/XGBoost model
While making predictions, I want to get the output leaf indices of every tree in my PMML LightGBM/XGBoost model. Any index format is OK, including one-hot/label-encoded/tree node idx.
The PMML model is generated in Python (scikit-learn) using sklearn2pmml or JPMML-LightGBM.
Essentially, my goal is the same as LGBMClassifier.predict(data, pred_leaf=True) in Python.
How can I do that in Java using JPMML-Evaluator?
I want to get the output leaf indices of every tree... Any index format is OK, including one-hot/label-encoded/tree node idx.
In PMML representation, tree nodes are identified by the Node@id attribute:
http://dmg.org/pmml/v4-4-1/TreeModel.html#xsdElement_Node
This is an optional attribute; if missing, the PMML engine shall assign "virtual" 1-based integer identifiers.
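The thread only describes these identifiers as "1-based". As a stdlib-only sketch, assuming depth-first pre-order numbering (the order in which Node elements appear in the file), virtual identifiers could be derived as below. IdNode and VirtualNodeIds are hypothetical names, not JPMML classes, and keeping an explicit Node@id when one is present is likewise an assumption.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for a PMML tree node; NOT org.dmg.pmml.tree.Node.
class IdNode {
    final String explicitId; // value of Node@id, or null when the attribute is absent
    final List<IdNode> children;

    IdNode(String explicitId, IdNode... children) {
        this.explicitId = explicitId;
        this.children = Arrays.asList(children);
    }
}

class VirtualNodeIds {

    // Numbers all nodes in depth-first pre-order (an assumed order). A node keeps its
    // explicit Node@id if it has one; otherwise its 1-based traversal position is used.
    static Map<IdNode, String> assign(IdNode root) {
        List<IdNode> preOrder = new ArrayList<>();
        collect(root, preOrder);
        Map<IdNode, String> ids = new IdentityHashMap<>();
        for (int i = 0; i < preOrder.size(); i++) {
            IdNode node = preOrder.get(i);
            ids.put(node, node.explicitId != null ? node.explicitId : String.valueOf(i + 1));
        }
        return ids;
    }

    private static void collect(IdNode node, List<IdNode> out) {
        out.add(node);
        for (IdNode child : node.children) {
            collect(child, out);
        }
    }
}
```

For a tree shaped like the Segment id="500" tree posted later in this thread (a root with three children, the first of which has two children and the second one), the root gets 1 and the last top-level child gets 7.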
How can I do that in Java using JPMML-Evaluator?
The results from tree-based models (decision trees, decision tree ensembles) typically implement the org.jpmml.evaluator.tree.HasDecisionPath marker interface:
https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/HasDecisionPath.java
This marker interface, possibly in combination with the model-level org.jpmml.evaluator.tree.HasNodeRegistry marker interface, should provide all the information needed to achieve custom application goals:
https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/HasNodeRegistry.java
The situation is a bit more complicated with tree ensemble models (XGBoost, LightGBM, GBT, Random Forest), because the prediction result is "layered": the o.j.e.tree.HasDecisionPath object is wrapped inside an org.jpmml.evaluator.mining.HasSegmentation object:
https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/mining/HasSegmentation.java
The internal structure of the o.j.e.mining.HasSegmentation object depends on the mining function. For regression-type decision tree ensembles it's simpler (booster only); for classification-type decision tree ensembles it's more complex (booster followed by a boosted score normalizer).
TLDR: Use the following approach:
1. Start with simple decision trees. For example, any of my DecisionTreeAudit.pmml models.
2. Make a prediction using a simple decision tree, and cast its target value to an org.jpmml.evaluator.tree.HasDecisionPath object.
3. Extract HasDecisionPath#getNode() and process it.
4. Move to a more complex example. For example, some regression-type XGBoost or LightGBM model such as my XGBoostAuto.pmml or LightGBMAuto.pmml models.
5. Make a prediction using it, and cast the target value to an org.jpmml.evaluator.mining.HasSegmentation object.
6. Extract individual segment targets, and process them.
7. Move to an even more complex example. For example, classification-type XGBoost or LightGBM models.
8. Make a prediction using it, and cast the target value to o.j.e.mining.HasSegmentation. Extract the partial result corresponding to the booster component, and process it according to steps five and six above.
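Steps five and six boil down to visiting the segments in order and keeping one identifier per tree, which yields a row shaped like one row of pred_leaf=True output (the values need not match LightGBM's). Once those identifiers are collected as strings, turning them into the label-encoded or one-hot formats mentioned in the question is plain array work. The sketch below assumes the identifiers are 1-based "virtual" integers and that the number of possible winning nodes per tree (nodeCounts) is known; native or non-numeric Node@id values would need a per-tree lookup table instead. LeafEncoding is a hypothetical helper, not part of JPMML.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper; assumes each tree reports a 1-based integer identifier.
class LeafEncoding {

    // Label encoding: per-tree 1-based identifiers (strings) -> 0-based indices,
    // keeping the segment (tree) order.
    static int[] labelEncode(List<String> entityIds) {
        int[] indices = new int[entityIds.size()];
        for (int t = 0; t < indices.length; t++) {
            indices[t] = Integer.parseInt(entityIds.get(t)) - 1;
        }
        return indices;
    }

    // One-hot encoding: one block of nodeCounts[t] slots per tree, with a single 1
    // at the position of that tree's winning node.
    static int[] oneHot(int[] indices, int[] nodeCounts) {
        if (indices.length != nodeCounts.length) {
            throw new IllegalArgumentException("Expected one node count per tree");
        }
        int[] vector = new int[Arrays.stream(nodeCounts).sum()];
        int offset = 0;
        for (int t = 0; t < indices.length; t++) {
            if (indices[t] < 0 || indices[t] >= nodeCounts[t]) {
                throw new IllegalArgumentException("Index out of range for tree " + t);
            }
            vector[offset + indices[t]] = 1;
            offset += nodeCounts[t];
        }
        return vector;
    }
}
```

For example, identifiers "1", "7" and "3" from three 7-node trees become [0, 6, 2], and the one-hot vector has length 21 with ones at positions 0, 13 and 16.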
See also the following two sample projects about dealing with decision tree ensemble (RF) models:
Closing this issue, as the provided guidance should be sufficient to continue on your own. Feel free to ask clarifying/follow-up questions if necessary.
Thanks for your detailed reply and guidance. Over the past few days I've tried to implement the 'node-id extraction' function following your guidance, and I've run into some new problems.
PS1: For now, I simply followed your project https://github.com/vruusmann/rf_feature_impact to get the leaf node ids. The code blocks below are taken from this project without any modification except for my custom data, model and System.out.println() calls. Thanks for the good reference.
PS2: All the code below works with JPMML-Evaluator 1.4.15. When I use the latest 1.5.16, there seem to be many breaking changes affecting your rf_feature_impact project. I also tried to reproduce the whole process under 1.5.16, but the targetValue became a Double (regression model) or ProbabilityDistribution (classification model) object and couldn't be cast to HasSegmentation or HasDecisionPath by:
HasSegmentation hasSegmentation = (HasSegmentation)targetValue;
HasDecisionPath hasDecisionPath = (HasDecisionPath)targetValue;
We may discuss this version problem later.
I've arrived at your step 8: using one data sample, I made a prediction, got my LGB classification model's org.jpmml.evaluator.mining.HasSegmentation output, and extracted the target value of each SegmentResult:
results = evaluator.evaluate(arguments);
Object targetValue = results.get(targetField.getName());
HasSegmentation hasSegmentation = (HasSegmentation)targetValue;
Collection<? extends SegmentResult> segmentResults = hasSegmentation.getSegmentResults();
for(SegmentResult segmentResult : segmentResults){
    Object segmentTargetValue = segmentResult.getTargetValue();
    // ... (processing of segmentTargetValue omitted)
}
And I checked the details of this segmentTargetValue in your function computeFeatureContributions:
static
private List<Contribution> computeFeatureContributions(String segmentId, Number weight, Object targetValue, String targetClass){
    HasDecisionPath hasDecisionPath = (HasDecisionPath)targetValue; // Here targetValue == segmentTargetValue
    System.out.println("segmentId: " + segmentId);
    System.out.println("targetValue: " + targetValue);
    System.out.println("targetValue type: " + targetValue.getClass().toString());
    System.out.println("hasDecisionPath node: " + hasDecisionPath.getNode()); // Gets the winning node.
    // ... (remainder of the method unchanged)
}
The printed result (an example from my LGB model's 500th tree):
>>> segmentId: 500
>>> targetValue: {result=-3.811698813155931E-4, entityId=1}
>>> targetValue type: class org.jpmml.evaluator.tree.TreeModelEvaluator$1
>>> hasDecisionPath node: org.dmg.pmml.tree.ComplexNode@60e9df3c
I want to get the output leaf indices of every tree from my PMML LightGBM/XGBoost model
targetValue has an entityId element, which ranges from 1 to 7, consistent with my LGB model's num_leaves=7. This entityId looks like what I want. However, when I used the same sample data to make a prediction in Python with lightgbm.Booster.predict(data, pred_leaf=True), the leaf node ids were completely different from what I got in Java (the predicted probability values are the same). I checked the ids from Python, and they follow this kind of tree index order:
In PMML representation, tree nodes are identified by the Node@id attribute:
http://dmg.org/pmml/v4-4-1/TreeModel.html#xsdElement_Node
This is an optional attribute; if missing, the PMML engine shall assign "virtual" 1-based integer identifiers.
My PMML model file doesn't contain the Node@id attribute. I wonder whether the entityId I got in Java is exactly the "virtual" 1-based integer identifier you mentioned. Can I trust this entityId and use it to correctly represent the predicted output leaves?
Another problem: if entityId is my goal, how can I get it from targetValue?
targetValue is of type org.jpmml.evaluator.tree.TreeModelEvaluator$1, and it doesn't have methods like get() or getId(). The getNode() method of hasDecisionPath only returns an org.dmg.pmml.tree.ComplexNode with a 'strange' id 60e9df3c.
(I'm not sure whether this is a silly question, as I'm a Java rookie who started learning Java with exactly this project...)
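The @60e9df3c suffix is not a PMML node identifier. When a class does not override toString(), Object.toString() returns the fully-qualified class name, "@", and the object's hash code in hexadecimal, which can differ from run to run. The printed ComplexNode value has exactly that shape, so it says nothing about which node was reached; the per-node identifier is the entityId value visible in the printed target value. The class PlainNode below is a hypothetical stand-in used only to demonstrate the default format.

```java
// Hypothetical class that, like the node class in the printout, does not override toString().
class PlainNode {
}

class DefaultToString {

    // Rebuilds the default Object.toString() format:
    // fully-qualified class name + "@" + hexadecimal hash code.
    static String expected(Object obj) {
        return obj.getClass().getName() + "@" + Integer.toHexString(obj.hashCode());
    }
}
```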
For reference, here is part of my PMML model file, showing some basic info and the structure of my 500th tree segment:
<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<PMML xmlns="http://www.dmg.org/PMML-4_3" xmlns:data="http://jpmml.org/jpmml-model/InlineTable" version="4.3">
<Header>
<Application name="JPMML-LightGBM" version="1.3-SNAPSHOT"/>
<Timestamp>2021-11-29T11:29:25Z</Timestamp>
</Header>
<DataDictionary>
<DataField name="_target" optype="categorical" dataType="integer">
<Value value="0"/>
<Value value="1"/>
</DataField>
<DataField name="feature_001" optype="continuous" dataType="double">
<Interval closure="closedClosed" leftMargin="0.0" rightMargin="11.232960820144106"/>
<Value value="NaN" property="missing"/>
</DataField>
......
<Segment id="500">
<True/>
<TreeModel functionName="regression" noTrueChildStrategy="returnLastPrediction">
<MiningSchema>
<MiningField name="feature_002"/>
<MiningField name="feature_003"/>
<MiningField name="feature_004"/>
<MiningField name="feature_005"/>
<MiningField name="feature_006"/>
<MiningField name="feature_007"/>
</MiningSchema>
<Node score="-3.811698813155931E-4">
<True/>
<Node score="-8.988185833947195E-4">
<SimplePredicate field="feature_007" operator="greaterThan" value="0.5638774799243619"/>
<Node score="-0.006818797571779276">
<SimplePredicate field="feature_006" operator="greaterThan" value="48.50000000000001"/>
</Node>
<Node score="0.001541678356026245">
<SimplePredicate field="feature_002" operator="greaterThan" value="0.8183625000000001"/>
</Node>
</Node>
<Node score="-0.006749477204381504">
<SimplePredicate field="feature_005" operator="greaterThan" value="0.025786767554062152"/>
<Node score="0.009428389314363324">
<SimplePredicate field="feature_003" operator="greaterThan" value="894.5000000000001"/>
</Node>
</Node>
<Node score="0.010359079888978103">
<SimplePredicate field="feature_004" operator="greaterThan" value="8.500000000000002"/>
</Node>
</Node>
</TreeModel>
</Segment>
......
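The printout for tree 500 appears consistent with this excerpt: the reported result -3.811698813155931E-4 is exactly the score of the root Node, and entityId=1 is the identifier one would expect for the root. That suggests none of the root's three children matched, so with noTrueChildStrategy="returnLastPrediction" evaluation stopped at the root of the compacted tree. The tree also has exactly seven Node elements, which would explain the 1 to 7 range of entityId. A stdlib-only sketch that replays this, assuming PMML's "first child whose predicate is true wins" rule, hand-assigned 1-based pre-order ids, and no missing values; ScoredNode and TreeTrace are hypothetical names, not JPMML classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for a compacted PMML tree node: an optional "field > threshold"
// predicate (field == null means <True/>), a score, a hand-assigned id and children.
class ScoredNode {
    final int id;
    final String field;
    final double threshold;
    final double score;
    final List<ScoredNode> children;

    ScoredNode(int id, String field, double threshold, double score, ScoredNode... children) {
        this.id = id;
        this.field = field;
        this.threshold = threshold;
        this.score = score;
        this.children = Arrays.asList(children);
    }

    // Assumes every referenced feature is present in the row (no missing-value handling).
    boolean matches(Map<String, Double> row) {
        return field == null || row.get(field) > threshold;
    }
}

class TreeTrace {

    // First true child wins; if no child matches, evaluation stops at the current node
    // (noTrueChildStrategy="returnLastPrediction"). The last element is the winning node.
    static List<ScoredNode> decisionPath(ScoredNode root, Map<String, Double> row) {
        List<ScoredNode> path = new ArrayList<>();
        ScoredNode current = root;
        path.add(current);
        while (true) {
            ScoredNode next = null;
            for (ScoredNode child : current.children) {
                if (child.matches(row)) {
                    next = child;
                    break;
                }
            }
            if (next == null) {
                return path;
            }
            path.add(next);
            current = next;
        }
    }

    // Segment id="500" from the excerpt above, with pre-order ids assigned by hand.
    static ScoredNode segment500() {
        return new ScoredNode(1, null, Double.NaN, -3.811698813155931E-4,
            new ScoredNode(2, "feature_007", 0.5638774799243619, -8.988185833947195E-4,
                new ScoredNode(3, "feature_006", 48.50000000000001, -0.006818797571779276),
                new ScoredNode(4, "feature_002", 0.8183625000000001, 0.001541678356026245)),
            new ScoredNode(5, "feature_005", 0.025786767554062152, -0.006749477204381504,
                new ScoredNode(6, "feature_003", 894.5000000000001, 0.009428389314363324)),
            new ScoredNode(7, "feature_004", 8.500000000000002, 0.010359079888978103));
    }
}
```

With all six features at 0.0 the path is just the root (id 1, score -3.811698813155931E-4); setting feature_007 = 0.9 and feature_006 = 50.0 ends at id 3. Because the root itself can be the winning node here, these identifiers describe the compacted PMML tree rather than LightGBM's own leaves, which may be why they differ from pred_leaf output.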
Version information
[Java side]
(1) Java: 1.8
(2) jpmml-evaluator: 1.4.15 / 1.5.16
(3) PMML model converter: JPMML-LightGBM 1.3
(4) PMML version: 4.3 (the originally converted PMML model was 4.4; I manually changed the PMML header to 4.3, as 4.4 is unsupported in JPMML-Evaluator 1.4.15)
[Python side]
(5) Python: 3.7
(6) LightGBM: 3.2.1
Please feel free to tell me if you need more info about my program. Thanks.
All the code below works with JPMML-Evaluator 1.4.15. When I use the latest 1.5.16, there seem to be many breaking changes.
The most important change between the 1.4.X and 1.5.X development branches is that 1.5.X contains several decision tree evaluator implementations, and uses the most "lightweight" implementation that can do the job.
The 1.4.X-compatible decision tree evaluator is org.jpmml.evaluator.tree.ComplexTreeModelEvaluator:
https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/ComplexTreeModelEvaluator.java
It returns o.j.e.tree.HasDecisionPath-compatible result values in all cases:
https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/ComplexTreeModelEvaluator.java#L283-L423
The newer, lightweight tree evaluator is org.jpmml.evaluator.tree.SimpleTreeModelEvaluator:
https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/SimpleTreeModelEvaluator.java
As you already observed, it returns java.lang.Number for regression cases, and java.lang.String for voting-style classification cases:
https://github.com/jpmml/jpmml-evaluator/blob/1.5.16/pmml-evaluator/src/main/java/org/jpmml/evaluator/tree/SimpleTreeModelEvaluator.java#L93-L100
Most decision tree evaluation tasks are fully served by o.j.e.tree.SimpleTreeModelEvaluator. It creates less garbage, and is significantly more performant.
However, you want to access extra information that is not available when using o.j.e.tree.SimpleTreeModelEvaluator. The solution is therefore to manually force the activation of o.j.e.tree.ComplexTreeModelEvaluator.
This can be achieved using the org.jpmml.evaluator.ModelEvaluatorBuilder#setExtraResultFeatures(Set<org.dmg.pmml.ResultFeature>) method. Since you're interested in node identifiers, you'd need to indicate org.dmg.pmml.ResultFeature#ENTITY_ID there. Something like this:
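import java.io.File;
import java.util.EnumSet;
import org.dmg.pmml.ResultFeature;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.LoadingModelEvaluatorBuilder;
LoadingModelEvaluatorBuilder evaluatorBuilder = new LoadingModelEvaluatorBuilder();
// THIS! Request entity IDs as an extra result feature
evaluatorBuilder.setExtraResultFeatures(EnumSet.of(ResultFeature.ENTITY_ID));
evaluatorBuilder.load(new File("model.pmml")); // placeholder path; load() throws checked exceptions
Evaluator evaluator = evaluatorBuilder.build();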
The resulting Evaluator will now do its best to return target values that implement the org.jpmml.evaluator.HasEntityId marker interface (org.jpmml.evaluator.tree.HasDecisionPath is one of its sub-marker interfaces).
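Because the result type depends on which evaluator implementation was picked, it may be safer to test for the marker interface than to cast blindly, as the failed casts under 1.5.16 illustrate. A stdlib-only sketch: EntityIdCarrier is a hypothetical stand-in for org.jpmml.evaluator.HasEntityId (whose getEntityId() returns a String); in real code the instanceof check would target the JPMML interface itself.

```java
import java.util.Optional;

// Hypothetical stand-in mimicking org.jpmml.evaluator.HasEntityId#getEntityId();
// it is NOT the real JPMML interface.
interface EntityIdCarrier {
    String getEntityId();
}

class EntityIds {

    // Reads the entity ID only when the target value actually carries one. Plain values
    // (e.g. a java.lang.Double from a lightweight evaluator) yield an empty Optional
    // instead of the ClassCastException that a blind cast would throw.
    static Optional<String> entityIdOf(Object targetValue) {
        if (targetValue instanceof EntityIdCarrier) {
            return Optional.ofNullable(((EntityIdCarrier) targetValue).getEntityId());
        }
        return Optional.empty();
    }
}
```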
My pmml model file doesn't contain the Node@id attribute.
The JPMML-LightGBM library initializes the Node@id attribute with native LightGBM identifier values:
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/Tree.java#L121
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/Tree.java#L272
Node identifiers may get "erased" during decision tree compaction, as implemented by the org.jpmml.lightgbm.visitors.TreeModelCompactor visitor class.
They are required to be present initially:
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/visitors/TreeModelCompactor.java#L33
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/visitors/TreeModelCompactor.java#L37-L39
But they get "erased":
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/visitors/TreeModelCompactor.java#L77
Decision tree compaction is active by default. If you are interested in preserving LightGBM decision trees in their native layout, then you should disable it by setting the org.jpmml.lightgbm.HasLightGBMOptions#OPTION_COMPACT option to false:
https://github.com/jpmml/jpmml-lightgbm/blob/1.3.12/src/main/java/org/jpmml/lightgbm/HasLightGBMOptions.java#L29
For example, if you're converting LightGBM models using the SkLearn2PMML package, then you can toggle this option using the sklearn2pmml.pipeline.PMMLPipeline.configure(**pmml_options) method:
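from lightgbm import LGBMClassifier
from sklearn2pmml import sklearn2pmml
from sklearn2pmml.pipeline import PMMLPipeline

pipeline = PMMLPipeline([
	("classifier", LGBMClassifier())
])
pipeline.fit(X, y) # X, y: your training data
# THIS! Disable decision tree compaction to keep the native tree layout
pipeline.configure(compact = False)
sklearn2pmml(pipeline, "pipeline.pmml")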
Exactly the same applies to XGBoost models: you need to turn off decision tree compaction, which is active by default.
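After re-exporting with compaction disabled, one quick way to check whether Node@id values survived is to read them straight from the PMML file with the JDK's DOM API, without JPMML. The sketch below (NodeIdLister is a hypothetical helper) returns one list of node identifiers per TreeModel segment, in document order; wrapper segments, such as the nested booster and the score normalizer of a classification model, are skipped because they do not contain a TreeModel directly. Where Node@id is missing it falls back to the 1-based pre-order position, which is an assumption about how "virtual" identifiers are numbered; for the posted segment 500 that gives 1 to 7. It does not map these ids to LightGBM's pred_leaf indices, since that correspondence depends on how JPMML-LightGBM fills Node@id.

```java
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

class NodeIdLister {

    // Returns one list of node identifiers per TreeModel, in document order.
    static List<List<String>> nodeIdsPerTree(String pmmlXml) {
        Document doc;
        try {
            // Namespace-unaware parsing: matches unprefixed element names such as <Node>,
            // as written in PMML files that use the default PMML namespace.
            doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
                .parse(new ByteArrayInputStream(pmmlXml.getBytes(StandardCharsets.UTF_8)));
        } catch (Exception e) {
            throw new IllegalArgumentException("Could not parse PMML", e);
        }
        List<List<String>> trees = new ArrayList<>();
        NodeList segments = doc.getElementsByTagName("Segment");
        for (int s = 0; s < segments.getLength(); s++) {
            Element treeModel = directChild((Element) segments.item(s), "TreeModel");
            if (treeModel == null) {
                continue; // wrapper segment, e.g. a nested booster or a score normalizer
            }
            // getElementsByTagName lists descendants in pre-order (document order).
            NodeList nodes = treeModel.getElementsByTagName("Node");
            List<String> ids = new ArrayList<>();
            for (int i = 0; i < nodes.getLength(); i++) {
                String id = ((Element) nodes.item(i)).getAttribute("id"); // "" if absent
                // Assumed fallback: 1-based pre-order position as a "virtual" identifier.
                ids.add(id.isEmpty() ? String.valueOf(i + 1) : id);
            }
            trees.add(ids);
        }
        return trees;
    }

    private static Element directChild(Element parent, String tagName) {
        for (org.w3c.dom.Node child = parent.getFirstChild(); child != null; child = child.getNextSibling()) {
            if (child instanceof Element && ((Element) child).getTagName().equals(tagName)) {
                return (Element) child;
            }
        }
        return null;
    }
}
```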