tensorflow/tfx

TFX component never completes even though Vertex AI custom job succeeds / fails

clee421 opened this issue · 8 comments

If the bug is related to a specific library below, please raise an issue in the
respective repo directly:

TensorFlow Data Validation Repo

TensorFlow Model Analysis Repo

TensorFlow Transform Repo

TensorFlow Serving Repo

System information

  • Have I specified the code to reproduce the issue (Yes, No): Yes
  • Environment in which the code is executed (e.g., Local(Linux/MacOS/Windows),
    Interactive Notebook, Google Cloud, etc): GCP GKE Pod
  • TensorFlow version: 2.13.0
  • TFX Version: 1.14.0
  • Python version: 3.8
  • Python dependencies (from pip freeze output): N/A (I can provide the dependencies if it's deemed applicable)

Describe the current behavior
I have a pipeline that wraps runner.start_cloud_training and runs a custom job on Vertex AI, which either succeeds or fails. The TFX component continues to hang and never completes, regardless of the custom job's completion.

Describe the expected behavior
I would expect the TFX component to complete when the Vertex AI custom job completes.

Standalone code to reproduce the issue

I've debugged this as best I could, and here are my findings.

I believe the issue is this line:

while client.get_job_state(response) not in client.JOB_STATES_COMPLETED:
    # ...

The loop never exits because the return value of client.get_job_state(response) is not an enum, so it never matches the entries in client.JOB_STATES_COMPLETED.

Here is the script I used to test and validate my hypothesis:

from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
from google.cloud.aiplatform_v1.types.job_state import JobState

_VERTEX_ENDPOINT_SUFFIX = '-aiplatform.googleapis.com'

_VERTEX_JOB_STATE_SUCCEEDED = JobState.JOB_STATE_SUCCEEDED
_VERTEX_JOB_STATE_FAILED = JobState.JOB_STATE_FAILED
_VERTEX_JOB_STATE_CANCELLED = JobState.JOB_STATE_CANCELLED
JOB_STATES_COMPLETED = (_VERTEX_JOB_STATE_SUCCEEDED, _VERTEX_JOB_STATE_FAILED,
                        _VERTEX_JOB_STATE_CANCELLED)

"""
The client used by TFX is the VertexJobClient
https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py#L288

which underneath uses the JobServiceClient
https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py#L312


The polling in TFX
https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py#L312

checks if
`while client.get_job_state(response) not in client.JOB_STATES_COMPLETED`

which does not seem to be correct, since when I replicate it below I get
"4 is not in list (<JobState.JOB_STATE_SUCCEEDED: 4>, <JobState.JOB_STATE_FAILED: 5>, <JobState.JOB_STATE_CANCELLED: 7>)"

The completed state is found here
https://github.com/tensorflow/tfx/blob/master/tfx/extensions/google_cloud_ai_platform/training_clients.py#L291
"""
def main():
    # Placeholder values; replace with your own region, project, and custom job id.
    vertex_region = "my-region"
    project = "my-project"
    custom_job_id = "my-id"

    client = JobServiceClient(
        client_options=dict(api_endpoint=vertex_region + _VERTEX_ENDPOINT_SUFFIX))

    custom_job = client.get_custom_job(
        name=f"projects/{project}/locations/{vertex_region}/customJobs/{custom_job_id}")
    print(custom_job)

    # Note: the client is from aiplatform_v1beta1 while JobState above is from aiplatform_v1.
    if custom_job.state not in JOB_STATES_COMPLETED:
        print(f"{custom_job.state} is not in list {JOB_STATES_COMPLETED}")
    else:
        print(f"{custom_job.state} is in list {JOB_STATES_COMPLETED}")


if __name__ == "__main__":
    main()

When running the snippet above I get the following output:

4 is not in list (<JobState.JOB_STATE_SUCCEEDED: 4>, <JobState.JOB_STATE_FAILED: 5>, <JobState.JOB_STATE_CANCELLED: 7>)

Converting the number to an enum with JobState(custom_job.state) fixes the problem.
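
To make that concrete, here is a minimal sketch of the membership check with that coercion applied, reusing custom_job and JOB_STATES_COMPLETED from the script above:

# Coerce the raw state value into the JobState enum before the membership
# check, so the comparison is an enum member against enum members.
state = JobState(custom_job.state)
if state in JOB_STATES_COMPLETED:
    print(f"{state} is in list {JOB_STATES_COMPLETED}")
else:
    print(f"{state} is not in list {JOB_STATES_COMPLETED}")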

I hope this helps; I would be more than happy to provide more information!


Following up here. I was able to copy the runner.py file over and make the modifications for my pipeline. The TFX component no longer hangs.
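
For reference, a hedged sketch of what that modification looks like in the copied-over polling loop, applying the JobState(...) coercion from my earlier comment (the actual code in the copied runner/training client may differ):

# Sketch only: coerce the polled state into the enum before comparing it
# against the completed-state tuple, so the loop can terminate.
while JobState(client.get_job_state(response)) not in client.JOB_STATES_COMPLETED:
    time.sleep(30)  # wait before polling the custom job state again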

@clee421, as per your previous comment, if this issue is resolved for you, please close this issue. Thank you!

@singhniraj08 Well, the bug is still there. I'm copying your TFX file over as a workaround; I would still need a proper fix.

Hi, @clee421.

Thanks for investigating and sharing the details. It helped a lot in understanding your problem.

While looking at your example code, I found that there is a version incompatibility between JobServiceClient and JobState. Specifically, the JobServiceClient is from v1beta1 while JobState is from v1. Because there is a separate JobState in v1beta1, you have to import it with from google.cloud.aiplatform_v1beta1.types.job_state import JobState if you want to use the v1beta1 client.
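
For illustration, a minimal sketch of the version-matched setup being described, with both the client and the enum coming from v1beta1:

# Client and enum from the same API version, so the state values returned by
# the client and the members of the completed tuple are the same type.
from google.cloud.aiplatform_v1beta1.services.job_service import JobServiceClient
from google.cloud.aiplatform_v1beta1.types.job_state import JobState

JOB_STATES_COMPLETED = (JobState.JOB_STATE_SUCCEEDED,
                        JobState.JOB_STATE_FAILED,
                        JobState.JOB_STATE_CANCELLED)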

From my own experiment, I got the desired result, 4 is in list (<JobState.JOB_STATE_SUCCEEDED: 4>, <JobState.JOB_STATE_FAILED: 5>, <JobState.JOB_STATE_CANCELLED: 7>), after changing the type to the one from v1beta1.

Additionally, tfx.extensions.google_cloud_ai_platform.Trainer, which internally uses runner.start_cloud_training, also works as expected.

I have a pipeline that wraps runner.start_cloud_training and runs a custom job on Vertex AI, which either succeeds or fails. The TFX component continues to hang and never completes, regardless of the custom job's completion.

I don't have a clear idea of how you wrapped runner.start_cloud_training, but if you mixed up the client and the type from different versions, please make them match first.

If this doesn't work, please let me know.

Specifically, the JobServiceClient is from v1beta1 while JobState is from v1. Because there is a separate JobState in v1beta1, you have to import it with from google.cloud.aiplatform_v1beta1.types.job_state import JobState if you want to use the v1beta1 client.

This poses a problem for us, then. We specifically monkey-patch JobServiceClient because v1beta1 has the persistent resource ID, which v1 does not.

I see. Please try with aiplatform_v1beta1.types.job_state, and let us know if it works.

@briron Thanks for the suggestion. I was able to test in an isolated environment with the v1beta1 JobState and confirm that there is no bug.
