Skip to content

Commit

Permalink
Support refreshing service connector credentials in the Vertex step operator to support long-running jobs (#2198)
Browse files Browse the repository at this point in the history

* Refresh creds on Vertex step operator job polling

* Fix docstring error
  • Loading branch information
stefannica authored Jan 11, 2024
1 parent 250e65c commit a1824a1
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions src/zenml/integrations/gcp/step_operators/vertex_step_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast

from google.api_core.exceptions import ServerError
from google.cloud import aiplatform

from zenml import __version__
Expand Down Expand Up @@ -186,7 +187,6 @@ def launch(
Raises:
RuntimeError: If the run fails.
ConnectionError: If the run fails due to a connection error.
"""
resource_settings = info.config.resource_settings
if resource_settings.cpu_count or resource_settings.memory:
Expand Down Expand Up @@ -300,27 +300,30 @@ def launch(
try:
response = client.get_custom_job(name=job_id)
retry_count = 0
# Handle transient connection error.
except ConnectionError as err:
# Handle transient connection errors and credential expiration by
# recreating the Python API client.
except (ConnectionError, ServerError) as err:
if retry_count < CONNECTION_ERROR_RETRY_LIMIT:
retry_count += 1
logger.warning(
"ConnectionError (%s) encountered when polling job: "
"%s. Trying to recreate the API client.",
err,
job_id,
f"Error encountered when polling job "
f"{job_id}: {err}\nRetrying...",
)
# This call will refresh the credentials if they expired.
credentials, project_id = self._get_authentication()
# Recreate the Python API client.
client = aiplatform.gapic.JobServiceClient(
client_options=client_options
credentials=credentials, client_options=client_options
)
else:
logger.error(
logger.exception(
"Request failed after %s retries.",
CONNECTION_ERROR_RETRY_LIMIT,
)
raise

raise RuntimeError(
f"Request failed after {CONNECTION_ERROR_RETRY_LIMIT} "
f"retries: {err}"
)
if response.state in VERTEX_JOB_STATES_FAILED:
err_msg = (
"Job '{}' did not succeed. Detailed response {}.".format(
Expand Down

0 comments on commit a1824a1

Please sign in to comment.