tensorflow/decision-forests

DistributedGradientBoostedTreesModel does not support Ranking task

JackGammack opened this issue · 1 comment

The documentation suggests that the ranking task can be used with this model, but there is no warning or failure until training time. The resulting error message does not make clear that the ranking task is unavailable for this model, and I couldn't find any documentation stating this.

Are there plans to add support for distributed ranking models? I suspect there may be limitations because examples from the same ranking_group can end up on different workers, while NDCG has to be computed over each complete group.
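To illustrate why the grouping matters, here is a small sketch comparing NDCG computed over one complete ranking group with the average of NDCG values computed over two halves of that group, as if the halves had landed on different workers. The scores, relevances, and two-worker split are made up, and this uses the linear-gain DCG variant; it is not how TF-DF / YDF actually computes NDCG or shards data.

```python
import math


def dcg(relevances):
    # Linear-gain DCG: sum of rel / log2(rank + 1), with 1-based ranks.
    return sum(rel / math.log2(i + 2) for i, rel in enumerate(relevances))


def ndcg(scores, relevances):
    # Order the group's items by predicted score, highest first.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    ranked = [relevances[i] for i in order]
    # Normalize by the DCG of the ideal (relevance-sorted) ordering.
    ideal = dcg(sorted(relevances, reverse=True))
    return dcg(ranked) / ideal if ideal > 0 else 0.0


# One hypothetical ranking group with four items.
scores = [0.9, 0.8, 0.1, 0.2]
relevances = [0, 3, 1, 2]

full = ndcg(scores, relevances)

# The same group split across two hypothetical workers, two items each.
worker_a = ndcg(scores[:2], relevances[:2])
worker_b = ndcg(scores[2:], relevances[2:])
split_mean = (worker_a + worker_b) / 2

print(f"full group NDCG: {full:.3f}")                # prints 0.698
print(f"mean of per-worker NDCG: {split_mean:.3f}")  # prints 0.815
```

The per-worker values disagree with the full-group value, so a distributed ranking learner would need every example of a group on the same worker (or an extra exchange step) to get the metric and its gradients right.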

Minimal example

import tensorflow as tf
import tensorflow_decision_forests as tfdf

strategy = tf.distribute.experimental.ParameterServerStrategy(...)

with strategy.scope():
    model = tfdf.keras.DistributedGradientBoostedTreesModel(
        task=tfdf.keras.Task.RANKING,
        ranking_group="group",
    )

# train_input_pattern: path or glob of the TFRecord training shards (defined elsewhere).
model.fit_on_dataset_path(
    train_path=train_input_pattern,
    label_key="label",
    weight_key="sample_weight",
    dataset_format="tfrecord+tfe",
)

The error message is below. Changing the task to regression makes the model train successfully.

  File "/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/keras/core.py", line 1942, in fit_on_dataset_path
    tf_core.train_on_file_dataset(
  File "/opt/conda/lib/python3.10/site-packages/tensorflow_decision_forests/tensorflow/core.py", line 779, in train_on_file_dataset
    training_op.SimpleMLCheckStatus(process_id=process_id) == 1
  File "/opt/conda/lib/python3.10/site-packages/tensorflow/python/util/tf_export.py", line 403, in wrapper
    return f(**kwargs)
  File "<string>", line 1373, in simple_ml_check_status
  File "/opt/conda/lib/python3.10/site-packages/tensorflow/python/framework/ops.py", line 5883, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.UnknownError: {{function_node __wrapped__SimpleMLCheckStatus_device_/job:chief/replica:0/task:0/device:CPU:0}} TensorFlow: INVALID_ARGUMENT: Worker #0: INVALID_ARGUMENT: Not supported task [Op:SimpleMLCheckStatus] name:
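Until the model rejects the task itself, a caller could fail fast before any workers start. The sketch below is hypothetical: `validate_distributed_task` and the local `Task` enum are illustrative stand-ins, not TF-DF APIs, and only RANKING is rejected because that is the one task this thread confirms is unsupported in distributed training.

```python
from enum import Enum


class Task(Enum):
    # Hypothetical stand-in mirroring task names exposed by tfdf.keras.Task.
    CLASSIFICATION = 1
    REGRESSION = 2
    RANKING = 3


# Only RANKING is listed: this thread confirms it is unsupported
# by DistributedGradientBoostedTreesModel.
UNSUPPORTED_DISTRIBUTED_TASKS = {Task.RANKING}


def validate_distributed_task(task):
    """Raise before training if the task is known to be unsupported."""
    if task in UNSUPPORTED_DISTRIBUTED_TASKS:
        raise ValueError(
            f"Task {task.name} is not supported by "
            "DistributedGradientBoostedTreesModel; train ranking models "
            "with the non-distributed GradientBoostedTreesModel instead."
        )


validate_distributed_task(Task.REGRESSION)  # passes silently
```

Calling it with `Task.RANKING` raises a `ValueError` at model-construction time instead of an `UnknownError` from a worker mid-training.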

Hi Jack,

You're right: ranking is not currently available in distributed training. I've improved the error message and updated the documentation about this at https://ydf.readthedocs.org.

Our team is always happy to implement missing features in TF-DF / Yggdrasil Decision Forests, but our resources are limited, and we have to prioritize based on impact, among other factors. If you have a strong use case for this feature, please contact us at decision-forests-contact@google.com.