gee-community/geemap

Cannot use a locally trained sklearn random forest regression model as an ee.Classifier object

alekusi opened this issue · 3 comments

Imports

import ee
import geemap
import pandas as pd
import numpy as np

from geemap import ml
from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV
from sklearn import ensemble
from sklearn import metrics

Description

I am trying to convert a locally trained random forest regression model to an ee.Classifier with the ml.strings_to_classifier function. The code creates the object, but extracting its information with getInfo() raises an EEException. The error says that parsing expected 5 values but got only 3. The same error occurs when I try to predict data with the classifier object.

I can use ee-data and functions so there should be no problem with the connection or authentication.

I am using Python 3.8.9 with geemap 0.22.1, earthengine_api 0.1.355 and scikit_learn 1.2.2. My Windows system currently uses the following delimiters:

Decimal symbol: "."
Digit grouping symbol: " " (white space)
List separator: ","

Could this cause the issue? The problem seems to be related to parsing, and the default digit grouping delimiter in Windows is ",".
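On the delimiter question, a quick local check may help rule it out. The sketch below assumes the tree strings are assembled with ordinary Python string formatting, which is an assumption about geemap's internals and is not confirmed in this thread. Python's fixed-point format specifiers ignore the operating system's regional settings; only the "n" presentation type is locale-aware.

```python
# Illustrative values only; they are not taken from the model in this issue.
values = [12345.678901, 0.177952]

# Fixed-point formatting always uses "." as the decimal symbol and inserts
# no digit grouping, whatever the Windows regional settings are.
formatted = [f"{v:.6f}" for v in values]
print(formatted)
```

If the output shows "." as the decimal symbol and no grouping characters, the Windows regional settings are unlikely to be involved.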

What I Did

variables = ["a", "b", "c"]

# df is assumed to be a pandas DataFrame holding the predictor columns and the target column "y"
xdata = df[variables]
ydata = df["y"]

X_train, X_test, y_train, y_test = train_test_split(
    xdata, ydata, test_size=0.3, random_state=42
)

rf_fit = ensemble.RandomForestRegressor(n_estimators=500, 
                                        max_depth=80, 
                                        max_samples=0.632,
                                        min_samples_leaf=5, 
                                        max_features=5,
                                        max_leaf_nodes=3,
                                        random_state=0).fit(X_train,y_train)

trees = ml.rf_to_strings(rf_fit, variables, output_mode="REGRESSION")
ee_classifier = ml.strings_to_classifier(trees)
ee_classifier.getInfo()

HttpError                                 Traceback (most recent call last)
File c:\Users\Python38\Portable Python-3.8.9 x64\App\Python\lib\site-packages\ee\data.py:346, in _execute_cloud_call(call, num_retries)
    345 try:
--> 346   return call.execute(num_retries=num_retries)
    347 except googleapiclient.errors.HttpError as e:

File c:\Users\Python38\Portable Python-3.8.9 x64\App\Python\lib\site-packages\googleapiclient\_helpers.py:130, in positional.<locals>.positional_decorator.<locals>.positional_wrapper(*args, **kwargs)
    129         logger.warning(message)
--> 130 return wrapped(*args, **kwargs)

File c:\Users\Python38\Portable Python-3.8.9 x64\App\Python\lib\site-packages\googleapiclient\http.py:938, in HttpRequest.execute(self, http, num_retries)
    937 if resp.status >= 300:
--> 938     raise HttpError(resp, content, uri=self.uri)
    939 return self.postproc(resp, content)

HttpError: 

During handling of the above exception, another exception occurred:

EEException                               Traceback (most recent call last)
Cell In[30], line 3
      1 ee_classifier = ml.strings_to_classifier(trees)
      2 #ee_classifier = ml.csv_to_classifier("D:/2023/test/trees.csv")
----> 3 ee_classifier.getInfo()

File c:\Users\Python38\Portable Python-3.8.9 x64\App\Python\lib\site-packages\ee\computedobject.py:96, in ComputedObject.getInfo(self)
     90 def getInfo(self):
     91   """Fetch and return information about this object.
     92 
     93   Returns:
     94     The object can evaluate to anything.
     95   """
---> 96   return data.computeValue(self)

File c:\Users\Python38\Portable Python-3.8.9 x64\App\Python\lib\site-packages\ee\data.py:955, in computeValue(obj)
    952 body = {'expression': serializer.encode(obj, for_cloud_api=True)}
    953 _maybe_populate_workload_tag(body)
--> 955 return _execute_cloud_call(
    956     _get_cloud_projects()
    957     .value()
    958     .compute(body=body, project=_get_projects_path(), prettyPrint=False)
    959 )['result']

File c:\Users\Python38\Portable Python-3.8.9 x64\App\Python\lib\site-packages\ee\data.py:348, in _execute_cloud_call(call, num_retries)
    346   return call.execute(num_retries=num_retries)
    347 except googleapiclient.errors.HttpError as e:
--> 348   raise _translate_cloud_exception(e)

EEException: Classifier.decisionTreeEnsemble: Error parsing line 9: expected 5, got 3.

Did you try out the example? https://geemap.org/notebooks/46_local_rf_training/

Yes, I followed the notebook documentation, except that I am using regression instead of classification. I printed the classifier objects for comparison:

Scikit-learn trained classifier:

ee_classifier = ml.strings_to_classifier(trees)
print(ee_classifier)
ee.Classifier({
  "functionInvocationValue": {
    "functionName": "Classifier.decisionTreeEnsemble",
    "arguments": {
      "treeStrings": {
        "constantValue": [
          "1) root 55 9999 9999 (147.25973519669316)\n  2) STR <= 1.677856 6 8.7500 -20.0 *\n  3) STR > 1.677856 55 60.2441 -3.885714\n    6) NIR <= 0.177952 13 30.6582 3.357143 *\n    7) NIR > 0.177952 36 18.5482 -3.3125 *\n",
          "1) root 53 9999 9999 (137.64400828644222)\n  2) SWIR1 <= 0.186536 53 65.6663 -3.928571\n    4) TCGreenness <= 0.063946 53 65.6663 -3.928571 *\n  3) SWIR1 > 0.186536 8 4.8100 -20.7 *\n    6) TCGreenness > 0.063946 39 14.8299 -1.851852 *\n",
          "1) root 49 9999 9999 (229.40906560236297)\n  2) Red <= 0.045711 27 36.3765 1.111111 *\n  3) Red > 0.045711 49 59.9243 -2.7\n    6) EVI <= 0.302120 6 40.0000 -0.666667 *\n    7) EVI > 0.302120 16 39.9136 -8.92 *\n",
          "1) root 51 9999 9999 (174.50701997203825)\n  2) SWIR1 <= 0.190012 51 60.9755 -3.714286\n    4) SWIR2 <= 0.079533 51 60.9755 -3.714286 *\n  3) SWIR1 > 0.190012 5 1.3469 -19.285714 *\n    6) SWIR2 > 0.079533 14 48.5744 -5.882353 *\n",
          "1) root 51 9999 9999 (155.8207175984602)\n  2) STR <= 1.795543 5 1.4844 -20.375 *\n  3) STR > 1.795543 51 67.7063 -3.471429\n    6) TCWetness <= -0.018154 29 25.7173 0.974359 *\n    7) TCWetness > -0.018154 17 26.2873 -5.130435 *\n",
          "1) root 51 9999 9999 (169.9831235827664)\n  2) SWIR1 <= 0.168349 51 73.5967 -3.657143\n    4) SAVI <= 0.270244 51 73.5967 -3.657143 *\n  3) SWIR1 > 0.168349 7 13.6100 -20.3 *\n    6) SAVI > 0.270244 32 12.4178 -2.733333 *\n",
          "1) root 50 9999 9999 (141.9237703477213)\n  2) SWIR1 <= 0.195392 50 70.1029 -3.8\n    4) SWIR1 <= 0.146592 50 70.1029 -3.8 *\n  3) SWIR1 > 0.195392 6 3.5556 -21.0 *\n    6) SWIR1 > 0.146592 11 14.8235 -6.0 *\n",
          "1) root 52 9999 9999 (156.71310708220403)\n  2) SWIR2 <= 0.110803 52 66.5420 -2.828571\n    4) NIR <= 0.166068 52 66.5420 -2.828571 *\n  3) SWIR2 > 0.110803 7 15.8025 -18.555556 *\n    6) NIR > 0.166068 32 22.1150 -2.276596 *\n",
          "1) root 52 9999 9999 (134.13911537253946)\n  2) SWIR1 <= 0.171396 52 50.5420 -1.828571\n    4) Green <= 0.048944 52 50.5420 -1.828571 *\n  3) SWIR1 > 0.171396 5 28.2400 -19.6 *\n    6) Green > 0.048944 16 3.6224 -4.24 *\n",
          "1) root 52 9999 9999 (174.68272458649182)\n  2) SWIR1 <= 0.168349 52 64.5692 -2.128571\n    4) EVI <= 0.252694 52 64.5692 -2.128571 *\n  3) SWIR1 > 0.168349 6 19.7284 -16.777778 *\n    6) EVI > 0.252694 33 15.6938 -2.043478 *\n"
        ]
      }
    }
  }
})
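The tree strings printed above can also be inspected locally before they are sent to Earth Engine. The following is a minimal sketch, not part of geemap: `check_tree` is a hypothetical helper, and it assumes that node numbers follow the binary-heap convention (node n has children 2n and 2n+1) and that a trailing `*` marks a leaf. Under those assumptions it reports nodes whose parent is missing or marked as a leaf, and split nodes that lack a child. Whether such inconsistencies are related to the parsing error is not established here.

```python
import re

# Matches the leading "<node id>)" of each line in a tree string.
NODE_ID = re.compile(r"^\s*(\d+)\)")


def check_tree(tree_string):
    """Return a list of structural problems found in one tree string.

    Assumes children of node n are numbered 2n and 2n+1 and that a
    trailing '*' marks a leaf (assumptions, see the text above).
    """
    nodes = {}  # node id -> True if the line ends with '*'
    for line in tree_string.splitlines():
        match = NODE_ID.match(line)
        if match is None:
            continue  # skip blank or unrecognised lines
        nodes[int(match.group(1))] = line.rstrip().endswith("*")

    problems = []
    for node_id, is_leaf in sorted(nodes.items()):
        if node_id > 1:
            parent = node_id // 2
            if parent not in nodes:
                problems.append(f"node {node_id}: parent {parent} is missing")
            elif nodes[parent]:
                problems.append(f"node {node_id}: parent {parent} is marked as a leaf")
        if not is_leaf:
            for child in (2 * node_id, 2 * node_id + 1):
                if child not in nodes:
                    problems.append(f"node {node_id}: child {child} is missing")
    return problems


# The first two tree strings from the printed classifier above.
first_tree = (
    "1) root 55 9999 9999 (147.25973519669316)\n"
    "  2) STR <= 1.677856 6 8.7500 -20.0 *\n"
    "  3) STR > 1.677856 55 60.2441 -3.885714\n"
    "    6) NIR <= 0.177952 13 30.6582 3.357143 *\n"
    "    7) NIR > 0.177952 36 18.5482 -3.3125 *\n"
)
second_tree = (
    "1) root 53 9999 9999 (137.64400828644222)\n"
    "  2) SWIR1 <= 0.186536 53 65.6663 -3.928571\n"
    "    4) TCGreenness <= 0.063946 53 65.6663 -3.928571 *\n"
    "  3) SWIR1 > 0.186536 8 4.8100 -20.7 *\n"
    "    6) TCGreenness > 0.063946 39 14.8299 -1.851852 *\n"
)

first_problems = check_tree(first_tree)
second_problems = check_tree(second_tree)
for problem in second_problems:
    print(problem)
```

Running the same check over every entry in `trees` would show how many of the exported strings deviate from the assumed numbering.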

GEE's implementation of smile random forest:

smileclassifier = (
    ee.Classifier.smileRandomForest(200, 5, 5, 0.632, 4, 0)
    .setOutputMode("REGRESSION")
    .train(features=features, classProperty="WT", inputProperties=variables)
)

smiletrees = ee.List(smileclassifier.explain().get("trees"))

smileclassifier2 = ee.Classifier.decisionTreeEnsemble(smiletrees)
print(smileclassifier2)
ee.Classifier({
  "functionInvocationValue": {
    "functionName": "Classifier.decisionTreeEnsemble",
    "arguments": {
      "treeStrings": {
        "functionInvocationValue": {
          "functionName": "Dictionary.get",
          "arguments": {
            "dictionary": {
              "functionInvocationValue": {
                "functionName": "Classifier.explain",
                "arguments": {
                  "classifier": {
                    "functionInvocationValue": {
                      "functionName": "Classifier.train",
                      "arguments": {
                        "classProperty": {
                          "constantValue": "WT"
                        },
                        "classifier": {
                          "functionInvocationValue": {
                            "functionName": "Classifier.setOutputMode",
                            "arguments": {
                              "classifier": {
                                "functionInvocationValue": {
                                  "functionName": "Classifier.smileRandomForest",
                                  "arguments": {
                                    "bagFraction": {
                                      "constantValue": 0.632
                                    },
                                    "maxNodes": {
                                      "constantValue": 4
                                    },
                                    "minLeafPopulation": {
                                      "constantValue": 5
                                    },
                                    "numberOfTrees": {
                                      "constantValue": 200
                                    },
                                    "seed": {
                                      "constantValue": 0
                                    },
                                    "variablesPerSplit": {
                                      "constantValue": 5
                                    }
                                  }
                                }
                              },
                              "mode": {
                                "constantValue": "REGRESSION"
                              }
                            }
                          }
                        },
                        "features": {
                          "functionInvocationValue": {
                            "functionName": "Collection.loadTable",
                            "arguments": {
                              "tableId": {
                                "constantValue": "projects/ee-xx/assets/Traindata/SAT_WT_train"
                              }
                            }
                          }
                        },
                        "inputProperties": {
                          "constantValue": [
                            "Blue",
                            "Green",
                            "Red",
                            "NIR",
                            "SWIR1",
                            "SWIR2",
                            "STR",
                            "NDVI",
                            "EVI",
                            "SAVI",
                            "TCGreenness",
                            "NDWI",
                            "MNDWI",
                            "NDMI",
                            "NDMI2",
                            "TCWetness",
                            "TCAngle"
                          ]
                        }
                      }
                    }
                  }
                }
              }
            },
            "key": {
              "constantValue": "trees"
            }
          }
        }
      }
    }
  }
})

I still get "Classifier.decisionTreeEnsemble: Error parsing line 9: expected 5, got 3" when I try to use the sklearn-based classifier, while the smile classifier works fine. Because smile RF offers only limited parameter tuning options, I would really like to use a locally trained regression model, but this currently does not work.

giswqs commented

Unfortunately, regression classifier is not supported. That's a GEE limitation rather than a geemap limitation.