From d9c2a5e3a5ce530342384c749b893e98acc909c4 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Wed, 26 Jun 2024 10:01:01 +0200 Subject: [PATCH] Extract project ID from GCP service connector SA credentials (#2708) * Detect project ID mismatch in GCP service connector SA credentials * Detect project ID from service account creds * Fix docstrings --------- Co-authored-by: Safoine El Khabich <34200873+safoinme@users.noreply.github.com> Co-authored-by: Hamza Tahir --- .../gcp/google_credentials_mixin.py | 2 +- .../gcp_service_connector.py | 79 +++++++++++++++---- 2 files changed, 65 insertions(+), 16 deletions(-) diff --git a/src/zenml/integrations/gcp/google_credentials_mixin.py b/src/zenml/integrations/gcp/google_credentials_mixin.py index ebde36834d9..3bf6e25de08 100644 --- a/src/zenml/integrations/gcp/google_credentials_mixin.py +++ b/src/zenml/integrations/gcp/google_credentials_mixin.py @@ -85,7 +85,7 @@ def _get_authentication(self) -> Tuple["Credentials", str]: "trying to use the linked connector, but got " f"{type(credentials)}." ) - return credentials, connector.config.project_id + return credentials, connector.config.gcp_project_id if self.config.service_account_path: credentials, project_id = load_credentials_from_file( diff --git a/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py b/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py index 87c0b1ef82c..4b592fcbbd6 100644 --- a/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py +++ b/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py @@ -351,24 +351,72 @@ class GCPOAuth2Token(AuthenticationConfig): class GCPBaseConfig(AuthenticationConfig): """GCP base configuration.""" + @property + def gcp_project_id(self) -> str: + """Get the GCP project ID. + + This method must be implemented by subclasses to ensure that the GCP + project ID is always available. + + Raises: + NotImplementedError: If the method is not implemented. + """ + raise NotImplementedError + + +class GCPBaseProjectIDConfig(GCPBaseConfig): + """GCP base configuration with included project ID.""" + project_id: str = Field( title="GCP Project ID where the target resource is located.", ) + @property + def gcp_project_id(self) -> str: + """Get the GCP project ID. + + Returns: + The GCP project ID. + """ + return self.project_id -class GCPUserAccountConfig(GCPBaseConfig, GCPUserAccountCredentials): + +class GCPUserAccountConfig(GCPBaseProjectIDConfig, GCPUserAccountCredentials): """GCP user account configuration.""" class GCPServiceAccountConfig(GCPBaseConfig, GCPServiceAccountCredentials): """GCP service account configuration.""" + _project_id: Optional[str] = None + + @property + def gcp_project_id(self) -> str: + """Get the GCP project ID. + + When a service account JSON is provided, the project ID can be extracted + from it instead of being provided explicitly. -class GCPExternalAccountConfig(GCPBaseConfig, GCPExternalAccountCredentials): + Returns: + The GCP project ID. + """ + if self._project_id is None: + self._project_id = json.loads( + self.service_account_json.get_secret_value() + )["project_id"] + # Guaranteed by the field validator + assert self._project_id is not None + + return self._project_id + + +class GCPExternalAccountConfig( + GCPBaseProjectIDConfig, GCPExternalAccountCredentials +): """GCP external account configuration.""" -class GCPOAuth2TokenConfig(GCPBaseConfig, GCPOAuth2Token): +class GCPOAuth2TokenConfig(GCPBaseProjectIDConfig, GCPOAuth2Token): """GCP OAuth 2.0 configuration.""" service_account_email: Optional[str] = Field( @@ -540,7 +588,7 @@ def _get_security_credentials( configured project has to be the same as the project of the attached service account. """, - config_class=GCPBaseConfig, + config_class=GCPBaseProjectIDConfig, ), AuthenticationMethodModel( name="GCP User Account", @@ -1006,6 +1054,7 @@ def _authenticate( # service account authentication) assert isinstance(cfg, GCPServiceAccountConfig) + credentials = ( gcp_service_account.Credentials.from_service_account_info( json.loads( @@ -1115,7 +1164,7 @@ def _parse_gcr_resource_id( # # We need to extract the project ID and registry ID from # the provided resource ID - config_project_id = self.config.project_id + config_project_id = self.config.gcp_project_id project_id: Optional[str] = None # A GCR repository URI uses one of several hostnames (gcr.io, us.gcr.io, # eu.gcr.io, asia.gcr.io etc.) and the project ID is the first part of @@ -1219,9 +1268,9 @@ def _get_default_resource_id(self, resource_type: str) -> str: authorized. """ if resource_type == GCP_RESOURCE_TYPE: - return self.config.project_id + return self.config.gcp_project_id elif resource_type == DOCKER_REGISTRY_RESOURCE_TYPE: - return f"gcr.io/{self.config.project_id}" + return f"gcr.io/{self.config.gcp_project_id}" raise RuntimeError( f"Default resource ID not supported for '{resource_type}' resource " @@ -1278,7 +1327,7 @@ def _connect_to_resource( # Create an GCS client for the bucket client = storage.Client( - project=self.config.project_id, credentials=credentials + project=self.config.gcp_project_id, credentials=credentials ) return client @@ -1384,7 +1433,7 @@ def _configure_local_client( "config", "set", "project", - self.config.project_id, + self.config.gcp_project_id, ], check=True, stderr=subprocess.STDOUT, @@ -1488,7 +1537,7 @@ def _auto_configure( ) if auth_method == GCPAuthenticationMethods.IMPLICIT: - auth_config = GCPBaseConfig( + auth_config = GCPBaseProjectIDConfig( project_id=project_id, ) elif auth_method == GCPAuthenticationMethods.OAUTH2_TOKEN: @@ -1697,7 +1746,7 @@ def _verify( if resource_type == GCS_RESOURCE_TYPE: gcs_client = storage.Client( - project=self.config.project_id, credentials=credentials + project=self.config.gcp_project_id, credentials=credentials ) if not resource_id: # List all GCS buckets @@ -1736,7 +1785,7 @@ def _verify( # List all GKE clusters try: clusters = gke_client.list_clusters( - parent=f"projects/{self.config.project_id}/locations/-" + parent=f"projects/{self.config.gcp_project_id}/locations/-" ) cluster_names = [cluster.name for cluster in clusters.clusters] except google.api_core.exceptions.GoogleAPIError as e: @@ -1810,7 +1859,7 @@ def _get_connector_client( # object auth_method: str = GCPAuthenticationMethods.OAUTH2_TOKEN config: GCPBaseConfig = GCPOAuth2TokenConfig( - project_id=self.config.project_id, + project_id=self.config.gcp_project_id, token=credentials.token, service_account_email=credentials.signer_email if hasattr(credentials, "signer_email") @@ -1884,7 +1933,7 @@ def _get_connector_client( # List all GKE clusters try: clusters = gke_client.list_clusters( - parent=f"projects/{self.config.project_id}/locations/-" + parent=f"projects/{self.config.gcp_project_id}/locations/-" ) cluster_map = { cluster.name: cluster for cluster in clusters.clusters @@ -1928,7 +1977,7 @@ def _get_connector_client( auth_method=KubernetesAuthenticationMethods.TOKEN, resource_type=resource_type, config=KubernetesTokenConfig( - cluster_name=f"gke_{self.config.project_id}_{cluster_name}", + cluster_name=f"gke_{self.config.gcp_project_id}_{cluster_name}", certificate_authority=cluster_ca_cert, server=f"https://{cluster_server}", token=bearer_token,