From a1824a1f3aaea0fd509500c31bb6869c3840a760 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Thu, 11 Jan 2024 20:39:39 +0100 Subject: [PATCH] Support refreshing service connector credentials in the Vertex step operator to support long-running jobs (#2198) * Refresh creds on Vertex step operator job polling * Fix docstring error --- .../step_operators/vertex_step_operator.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py b/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py index 4428de7478e..18ca23edce3 100644 --- a/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py +++ b/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py @@ -21,6 +21,7 @@ import time from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast +from google.api_core.exceptions import ServerError from google.cloud import aiplatform from zenml import __version__ @@ -186,7 +187,6 @@ def launch( Raises: RuntimeError: If the run fails. - ConnectionError: If the run fails due to a connection error. """ resource_settings = info.config.resource_settings if resource_settings.cpu_count or resource_settings.memory: @@ -300,27 +300,30 @@ def launch( try: response = client.get_custom_job(name=job_id) retry_count = 0 - # Handle transient connection error. - except ConnectionError as err: + # Handle transient connection errors and credential expiration by + # recreating the Python API client. + except (ConnectionError, ServerError) as err: if retry_count < CONNECTION_ERROR_RETRY_LIMIT: retry_count += 1 logger.warning( - "ConnectionError (%s) encountered when polling job: " - "%s. Trying to recreate the API client.", - err, - job_id, + f"Error encountered when polling job " + f"{job_id}: {err}\nRetrying...", ) + # This call will refresh the credentials if they expired. + credentials, project_id = self._get_authentication() # Recreate the Python API client. client = aiplatform.gapic.JobServiceClient( - client_options=client_options + credentials=credentials, client_options=client_options ) else: - logger.error( + logger.exception( "Request failed after %s retries.", CONNECTION_ERROR_RETRY_LIMIT, ) - raise - + raise RuntimeError( + f"Request failed after {CONNECTION_ERROR_RETRY_LIMIT} " + f"retries: {err}" + ) if response.state in VERTEX_JOB_STATES_FAILED: err_msg = ( "Job '{}' did not succeed. Detailed response {}.".format(