tidymodels/tidypredict

number of trees and case when statements

Closed this issue · 6 comments

Thanks for developing this package. Maybe I'm missing a nuance, but shouldn't tidypredict_sql() on a ranger model generate a SQL statement with one CASE WHEN expression per tree, followed by an aggregation function that averages the predictions of the individual trees?

In the example below, I'm training a model with 100 trees.

library(ranger)
model <- ranger::ranger(Species ~ ., data = iris, num.trees = 100)
library(tidypredict)
tidypredict_sql(model, dbplyr::simulate_mssql())

The output does not represent the structure of those 100 trees, though. It contains only a single CASE expression:

<SQL> CASE
WHEN ((`Petal.Width` < 0.75)) THEN ('setosa')
WHEN ((`Petal.Length` >= 5.05 AND `Petal.Width` >= 0.75 AND `Petal.Width` < 1.75)) THEN ('virginica')
WHEN ((`Petal.Width` >= 1.75 AND `Petal.Width` >= 0.75 AND `Petal.Length` < 4.9)) THEN ('virginica')
WHEN ((`Petal.Length` >= 4.9 AND `Petal.Width` >= 1.75 AND `Petal.Width` >= 0.75)) THEN ('virginica')
WHEN ((`Petal.Width` >= 0.75 AND `Sepal.Length` < 4.95 AND `Petal.Length` < 5.05 AND `Petal.Width` < 1.75)) THEN ('virginica')
WHEN ((`Sepal.Length` >= 4.95 AND `Petal.Width` >= 0.75 AND `Petal.Length` < 4.95 AND `Petal.Length` < 5.05 AND `Petal.Width` < 1.75)) THEN ('versicolor')
WHEN ((`Petal.Length` >= 4.95 AND `Sepal.Length` >= 4.95 AND `Petal.Width` >= 0.75 AND `Sepal.Length` < 6.35 AND `Petal.Length` < 5.05 AND `Petal.Width` < 1.75)) THEN ('virginica')
WHEN ((`Sepal.Length` >= 6.35 AND `Petal.Length` >= 4.95 AND `Sepal.Length` >= 4.95 AND `Petal.Width` >= 0.75 AND `Petal.Length` < 5.05 AND `Petal.Width` < 1.75)) THEN ('versicolor')
END

Hi, I'm using ranger::treeInfo() to obtain all of the paths. Not sure if there's anywhere else to look for more paths within the ranger object.

I have a helper function called tidypredict_test(), and it seems that only one prediction differs between tidypredict and the native ranger predict() method:

library(ranger)
library(tidypredict)
model <- ranger::ranger(Species ~ ., data = iris, num.trees = 100)
treeInfo(model)
#>    nodeID leftChild rightChild splitvarID splitvarName splitval terminal
#> 1       0         1          2          3 Petal.Length     2.60    FALSE
#> 2       1        NA         NA         NA         <NA>       NA     TRUE
#> 3       2         3          4          4  Petal.Width     1.65    FALSE
#> 4       3         5          6          2  Sepal.Width     2.25    FALSE
#> 5       4         7          8          1 Sepal.Length     5.95    FALSE
#> 6       5         9         10          3 Petal.Length     4.75    FALSE
#> 7       6        11         12          3 Petal.Length     5.35    FALSE
#> 8       7        13         14          2  Sepal.Width     3.00    FALSE
#> 9       8        NA         NA         NA         <NA>       NA     TRUE
#> 10      9        NA         NA         NA         <NA>       NA     TRUE
#> 11     10        NA         NA         NA         <NA>       NA     TRUE
#> 12     11        NA         NA         NA         <NA>       NA     TRUE
#> 13     12        NA         NA         NA         <NA>       NA     TRUE
#> 14     13        NA         NA         NA         <NA>       NA     TRUE
#> 15     14        NA         NA         NA         <NA>       NA     TRUE
#>    prediction
#> 1        <NA>
#> 2      setosa
#> 3        <NA>
#> 4        <NA>
#> 5        <NA>
#> 6        <NA>
#> 7        <NA>
#> 8        <NA>
#> 9   virginica
#> 10 versicolor
#> 11  virginica
#> 12 versicolor
#> 13  virginica
#> 14  virginica
#> 15 versicolor

Created on 2018-09-18 by the reprex package (v0.2.0).

Our split values don't match, I guess because our seeds differ; it looks like ranger uses the random seed.

I see what is happening. ranger::treeInfo() defaults to the first tree, i.e. ranger::treeInfo(model, tree = 1). In other words, your generated SQL covers only the first tree in the forest.
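To illustrate the point, here is a minimal sketch of collecting the paths for every tree rather than just the first. It only assumes that the fitted ranger object exposes the tree count as model$num.trees; the all_trees name and the fixed seed are illustrative choices, not part of tidypredict.

```r
library(ranger)

set.seed(42)  # assumption: fix the seed so the forest is reproducible
model <- ranger(Species ~ ., data = iris, num.trees = 100)

# treeInfo() describes a single tree (tree = 1 by default),
# so request each tree explicitly.
all_trees <- lapply(
  seq_len(model$num.trees),
  function(i) treeInfo(model, tree = i)
)

length(all_trees)  # one data frame per tree

# Total number of terminal nodes (paths) across the whole forest
sum(vapply(all_trees, function(t) sum(t$terminal), integer(1)))
```

Each element of all_trees has the same columns as the treeInfo() output shown above, so whatever turns one tree into a CASE expression could in principle be applied to each element in turn.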

Ah! Great catch, I'll work on a fix.

Also, @stasSajin , thanks for looking into this!

Hi @stasSajin , I have a good first draft of the solution in this branch: https://github.com/edgararuiz/tidypredict/tree/earth-fixes

It should work for categorical and continuous outcomes. tidypredict_sql() and tidypredict_fit() will return a list with as many items as there are trees. The key is tidypredict_to_column(): that function will now create a new query for each tree. It then groups the results and uses either the average (continuous) or the most prevalent answer (categorical).
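As a rough illustration of the aggregation step described here (not the actual tidypredict implementation), the sketch below combines per-tree predictions in base R. The votes and num data frames hold made-up values, and the majority() helper is hypothetical.

```r
# Hypothetical per-tree class predictions: 3 rows scored by 5 trees.
votes <- data.frame(
  row  = rep(1:3, each = 5),
  tree = rep(1:5, times = 3),
  pred = c("setosa", "setosa", "setosa", "versicolor", "setosa",
           "virginica", "versicolor", "virginica", "virginica", "versicolor",
           "versicolor", "versicolor", "versicolor", "versicolor", "virginica"),
  stringsAsFactors = FALSE
)

# Categorical outcome: keep the most prevalent answer per row.
majority <- function(x) names(which.max(table(x)))
class_pred <- tapply(votes$pred, votes$row, majority)

# Hypothetical per-tree numeric predictions: 2 rows scored by 3 trees.
num <- data.frame(
  row  = rep(1:2, each = 3),
  pred = c(1, 2, 3, 10, 20, 30)
)

# Continuous outcome: average the per-tree predictions per row.
avg_pred <- tapply(num$pred, num$row, mean)

class_pred
avg_pred
```

Note that which.max() breaks ties by taking the first level in alphabetical order; a real implementation may resolve ties differently.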
