From ac0a945299e907359f4b5ce67c0c6f104fe7bccf Mon Sep 17 00:00:00 2001 From: Andres Torres Date: Thu, 18 Jul 2024 13:45:34 -0600 Subject: [PATCH] Add support for "Key-pair" authentication to Snowflake integration (#5079) Co-authored-by: Adam Sachs --- .github/workflows/backend_checks.yml | 2 + CHANGELOG.md | 1 + noxfiles/run_infrastructure.py | 2 + noxfiles/setup_tests_nox.py | 4 + .../connection_secrets_snowflake.py | 50 ++++++++++-- .../api/service/connectors/sql_connector.py | 30 ++++++- tests/fixtures/snowflake_fixtures.py | 78 ++++++++++++++++++- .../test_connection_config_endpoints.py | 2 + .../test_connection_template_endpoints.py | 15 +++- tests/ops/integration_test_config.toml | 2 + 10 files changed, 176 insertions(+), 10 deletions(-) diff --git a/.github/workflows/backend_checks.yml b/.github/workflows/backend_checks.yml index c5a5c63270..dc961c00c8 100644 --- a/.github/workflows/backend_checks.yml +++ b/.github/workflows/backend_checks.yml @@ -361,6 +361,8 @@ jobs: SNOWFLAKE_TEST_ACCOUNT_IDENTIFIER: ${{ secrets.SNOWFLAKE_TEST_ACCOUNT_IDENTIFIER }} SNOWFLAKE_TEST_DATABASE_NAME: ${{ secrets.SNOWFLAKE_TEST_DATABASE_NAME }} SNOWFLAKE_TEST_PASSWORD: ${{ secrets.SNOWFLAKE_TEST_PASSWORD }} + SNOWFLAKE_TEST_PRIVATE_KEY: ${{ secrets.SNOWFLAKE_TEST_PRIVATE_KEY }} + SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE: ${{ secrets.SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE }} SNOWFLAKE_TEST_SCHEMA_NAME: ${{ secrets.SNOWFLAKE_TEST_SCHEMA_NAME }} SNOWFLAKE_TEST_USER_LOGIN_NAME: ${{ secrets.SNOWFLAKE_TEST_USER_LOGIN_NAME }} SNOWFLAKE_TEST_WAREHOUSE_NAME: ${{ secrets.SNOWFLAKE_TEST_WAREHOUSE_NAME }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a077e995e..b5e94ac096 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The types of changes are: - Added erasure support for Alchemer integration [#4925](https://github.com/ethyca/fides/pull/4925) - Added new columns and action buttons to discovery monitors table [#5068](https://github.com/ethyca/fides/pull/5068) - Added field to 
exclude databases on MonitorConfig [#5080](https://github.com/ethyca/fides/pull/5080) +- Added key pair authentication for the Snowflake integration [#5079](https://github.com/ethyca/fides/pull/5079) ### Changed - Updated the sample dataset for the Amplitude integration [#5063](https://github.com/ethyca/fides/pull/5063) diff --git a/noxfiles/run_infrastructure.py b/noxfiles/run_infrastructure.py index 38411afed3..049fe87b51 100644 --- a/noxfiles/run_infrastructure.py +++ b/noxfiles/run_infrastructure.py @@ -26,6 +26,8 @@ "SNOWFLAKE_TEST_ACCOUNT_IDENTIFIER", "SNOWFLAKE_TEST_USER_LOGIN_NAME", "SNOWFLAKE_TEST_PASSWORD", + "SNOWFLAKE_TEST_PRIVATE_KEY", + "SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE", "SNOWFLAKE_TEST_WAREHOUSE_NAME", "SNOWFLAKE_TEST_DATABASE_NAME", "SNOWFLAKE_TEST_SCHEMA_NAME", diff --git a/noxfiles/setup_tests_nox.py b/noxfiles/setup_tests_nox.py index 63c1ea0a16..11b493d6fa 100644 --- a/noxfiles/setup_tests_nox.py +++ b/noxfiles/setup_tests_nox.py @@ -174,6 +174,10 @@ def pytest_ops( "-e", "SNOWFLAKE_TEST_PASSWORD", "-e", + "SNOWFLAKE_TEST_PRIVATE_KEY", + "-e", + "SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE", + "-e", "SNOWFLAKE_TEST_WAREHOUSE_NAME", "-e", "SNOWFLAKE_TEST_DATABASE_NAME", diff --git a/src/fides/api/schemas/connection_configuration/connection_secrets_snowflake.py b/src/fides/api/schemas/connection_configuration/connection_secrets_snowflake.py index c4de601237..0564e3bc8d 100644 --- a/src/fides/api/schemas/connection_configuration/connection_secrets_snowflake.py +++ b/src/fides/api/schemas/connection_configuration/connection_secrets_snowflake.py @@ -1,6 +1,6 @@ from typing import List, Optional -from pydantic import Field +from pydantic import Field, root_validator from fides.api.schemas.base_class import NoValidationSchema from fides.api.schemas.connection_configuration.connection_secrets import ( @@ -8,6 +8,14 @@ ) +def format_private_key(raw_key: str) -> str: + # Split the key into parts and remove spaces from the key body + parts = 
raw_key.split("-----") + body = parts[2].replace(" ", "\n") + # Reassemble the key + return f"-----{parts[1]}-----{body}-----{parts[3]}-----" + + class SnowflakeSchema(ConnectionConfigSecretsSchema): """Schema to validate the secrets needed to connect to Snowflake""" @@ -19,9 +27,22 @@ class SnowflakeSchema(ConnectionConfigSecretsSchema): title="Username", description="The user account used to authenticate and access the database.", ) - password: str = Field( + password: Optional[str] = Field( title="Password", - description="The password used to authenticate and access the database.", + description="The password used to authenticate and access the database. You can use a password or a private key, but not both.", + default=None, + sensitive=True, + ) + private_key: Optional[str] = Field( + title="Private key", + description="The private key used to authenticate and access the database. If a `private_key_passphrase` is also provided, it is assumed to be encrypted; otherwise, it is assumed to be unencrypted.", + default=None, + sensitive=True, + ) + private_key_passphrase: Optional[str] = Field( + title="Passphrase", + description="The passphrase used for the encrypted private key.", + default=None, sensitive=True, ) warehouse_name: str = Field( @@ -37,20 +58,39 @@ class SnowflakeSchema(ConnectionConfigSecretsSchema): description="The name of the Snowflake schema within the selected database.", ) role_name: Optional[str] = Field( - None, title="Role", + default=None, description="The Snowflake role to assume for the session, if different than Username.", ) _required_components: List[str] = [ "account_identifier", "user_login_name", - "password", "warehouse_name", "database_name", "schema_name", ] + @root_validator() + def validate_private_key_and_password(cls, values: dict) -> dict: + private_key: str = values.get("private_key", "") + + if values.get("password") and private_key: + raise ValueError( + "Cannot provide both password and private key at the same time." 
+ ) + + if not any([values.get("password"), private_key]): + raise ValueError("Must provide either a password or a private key.") + + if private_key: + try: + values["private_key"] = format_private_key(private_key) + except IndexError: + raise ValueError("Invalid private key format") + + return values + class SnowflakeDocsSchema(SnowflakeSchema, NoValidationSchema): """Snowflake Secrets Schema for API Docs""" diff --git a/src/fides/api/service/connectors/sql_connector.py b/src/fides/api/service/connectors/sql_connector.py index add91eb854..20abfa93af 100644 --- a/src/fides/api/service/connectors/sql_connector.py +++ b/src/fides/api/service/connectors/sql_connector.py @@ -1,6 +1,6 @@ import io from abc import abstractmethod -from typing import Any, Dict, List, Optional, Type +from typing import Any, Dict, List, Optional, Type, Union from urllib.parse import quote_plus import paramiko @@ -8,6 +8,8 @@ import pymysql import sshtunnel # type: ignore from aiohttp.client_exceptions import ClientResponseError +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization from google.cloud.sql.connector import Connector from google.oauth2 import service_account from loguru import logger @@ -202,8 +204,13 @@ def create_client(self) -> Engine: uri, hide_parameters=self.hide_parameters, echo=not self.hide_parameters, + connect_args=self.get_connect_args(), ) + def get_connect_args(self) -> Dict[str, Any]: + """Get connection arguments for the engine""" + return {} + def set_schema(self, connection: Connection) -> None: """Optionally override to set the schema for a given database that persists through the entire session""" @@ -557,6 +564,27 @@ def build_uri(self) -> str: url: str = Snowflake_URL(**kwargs) return url + def get_connect_args(self) -> Dict[str, Any]: + """Get connection arguments for the engine""" + config = self.secrets_schema(**self.configuration.secrets or {}) + connect_args: Dict[str, Union[str, bytes]] = {} 
+ if config.private_key: + config.private_key = config.private_key.replace("\\n", "\n") + connect_args["private_key"] = config.private_key + if config.private_key_passphrase: + private_key_encoded = serialization.load_pem_private_key( + config.private_key.encode(), + password=config.private_key_passphrase.encode(), + backend=default_backend(), + ) + private_key = private_key_encoded.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + connect_args["private_key"] = private_key + return connect_args + def query_config(self, node: ExecutionNode) -> SQLQueryConfig: """Query wrapper corresponding to the input execution_node.""" return SnowflakeQueryConfig(node) diff --git a/tests/fixtures/snowflake_fixtures.py b/tests/fixtures/snowflake_fixtures.py index e18ca768d7..49829671d6 100644 --- a/tests/fixtures/snowflake_fixtures.py +++ b/tests/fixtures/snowflake_fixtures.py @@ -94,12 +94,86 @@ def snowflake_connection_config( connection_config.delete(db) -@pytest.fixture +@pytest.fixture(scope="function") +def snowflake_connection_config_with_keypair( + db: Session, + integration_config: Dict[str, str], + snowflake_connection_config_without_secrets: ConnectionConfig, +) -> Generator: + """ + Returns a Snowflake ConnectionConfig with secrets attached if secrets are present + in the configuration. 
+ """ + connection_config = snowflake_connection_config_without_secrets + + account_identifier = integration_config.get("snowflake", {}).get( + "account_identifier" + ) or os.environ.get("SNOWFLAKE_TEST_ACCOUNT_IDENTIFIER") + user_login_name = integration_config.get("snowflake", {}).get( + "user_login_name" + ) or os.environ.get("SNOWFLAKE_TEST_USER_LOGIN_NAME") + private_key = integration_config.get("snowflake", {}).get( + "private_key" + ) or os.environ.get("SNOWFLAKE_TEST_PRIVATE_KEY") + private_key_passphrase = integration_config.get("snowflake", {}).get( + "private_key_passphrase" + ) or os.environ.get("SNOWFLAKE_TEST_PRIVATE_KEY_PASSPHRASE") + warehouse_name = integration_config.get("snowflake", {}).get( + "warehouse_name" + ) or os.environ.get("SNOWFLAKE_TEST_WAREHOUSE_NAME") + database_name = integration_config.get("snowflake", {}).get( + "database_name" + ) or os.environ.get("SNOWFLAKE_TEST_DATABASE_NAME") + schema_name = integration_config.get("snowflake", {}).get( + "schema_name" + ) or os.environ.get("SNOWFLAKE_TEST_SCHEMA_NAME") + + if all( + [ + account_identifier, + user_login_name, + private_key, + private_key_passphrase, + warehouse_name, + database_name, + schema_name, + ] + ): + schema = SnowflakeSchema( + account_identifier=account_identifier, + user_login_name=user_login_name, + private_key=private_key, + private_key_passphrase=private_key_passphrase, + warehouse_name=warehouse_name, + database_name=database_name, + schema_name=schema_name, + ) + connection_config.secrets = schema.dict() + connection_config.save(db=db) + + yield connection_config + connection_config.delete(db) + + +@pytest.fixture( + params=[ + "snowflake_connection_config", + "snowflake_connection_config_with_keypair", + ] +) def snowflake_example_test_dataset_config( + request, snowflake_connection_config: ConnectionConfig, + snowflake_connection_config_with_keypair: ConnectionConfig, db: Session, example_datasets: List[Dict], ) -> Generator: + + if request.param == 
"snowflake_connection_config": + config: ConnectionConfig = snowflake_connection_config + elif request.param == "snowflake_connection_config_with_keypair": + config: ConnectionConfig = snowflake_connection_config_with_keypair + dataset = example_datasets[2] fides_key = dataset["fides_key"] @@ -108,7 +182,7 @@ def snowflake_example_test_dataset_config( dataset_config = DatasetConfig.create( db=db, data={ - "connection_config_id": snowflake_connection_config.id, + "connection_config_id": config.id, "fides_key": fides_key, "ctl_dataset_id": ctl_dataset.id, }, diff --git a/tests/ops/api/v1/endpoints/test_connection_config_endpoints.py b/tests/ops/api/v1/endpoints/test_connection_config_endpoints.py index d84a7c6a4d..3db7f58d0e 100644 --- a/tests/ops/api/v1/endpoints/test_connection_config_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_connection_config_endpoints.py @@ -1683,6 +1683,8 @@ def test_put_connection_config_snowflake_secrets( "password": "test_password", "account_identifier": "flso2222test", "database_name": "test", + "private_key": None, + "private_key_passphrase": None, "schema_name": "schema", "warehouse_name": "warehouse", "role_name": None, diff --git a/tests/ops/api/v1/endpoints/test_connection_template_endpoints.py b/tests/ops/api/v1/endpoints/test_connection_template_endpoints.py index 3869a1eb7b..ba80e43a5b 100644 --- a/tests/ops/api/v1/endpoints/test_connection_template_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_connection_template_endpoints.py @@ -1203,8 +1203,20 @@ def test_get_connection_secret_schema_snowflake( }, "password": { "title": "Password", - "description": "The password used to authenticate and access the database.", + "description": "The password used to authenticate and access the database. You can use a password or a private key, but not both.", + "sensitive": True, + "type": "string", + }, + "private_key": { + "description": "The private key used to authenticate and access the database. 
If a `private_key_passphrase` is also provided, it is assumed to be encrypted; otherwise, it is assumed to be unencrypted.", + "sensitive": True, + "title": "Private key", + "type": "string", + }, + "private_key_passphrase": { + "description": "The passphrase used for the encrypted private key.", "sensitive": True, + "title": "Passphrase", "type": "string", }, "warehouse_name": { @@ -1231,7 +1243,6 @@ def test_get_connection_secret_schema_snowflake( "required": [ "account_identifier", "user_login_name", - "password", "warehouse_name", "database_name", "schema_name", diff --git a/tests/ops/integration_test_config.toml b/tests/ops/integration_test_config.toml index d92c94dc57..e934714a0e 100644 --- a/tests/ops/integration_test_config.toml +++ b/tests/ops/integration_test_config.toml @@ -46,6 +46,8 @@ db_schema="" account_identifier="" user_login_name="" password="" +private_key="" +private_key_passphrase="" warehouse_name="" database_name="" schema_name=""