diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py index f77239398e..5d0f08c2f5 100644 --- a/sdk/python/feast/infra/materialization/snowflake_engine.py +++ b/sdk/python/feast/infra/materialization/snowflake_engine.py @@ -70,6 +70,9 @@ class SnowflakeMaterializationEngineConfig(FeastConfigBaseModel): private_key: Optional[str] = None """ Snowflake private key file path""" + private_key_content: Optional[bytes] = None + """ Snowflake private key stored as bytes""" + private_key_passphrase: Optional[str] = None """ Snowflake private key file passphrase""" diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index 96552ff87e..ada6c99c98 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -107,6 +107,9 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel): private_key: Optional[str] = None """ Snowflake private key file path""" + private_key_content: Optional[bytes] = None + """ Snowflake private key stored as bytes""" + private_key_passphrase: Optional[str] = None """ Snowflake private key file passphrase""" diff --git a/sdk/python/feast/infra/online_stores/snowflake.py b/sdk/python/feast/infra/online_stores/snowflake.py index fef804a377..6f39bdd0f6 100644 --- a/sdk/python/feast/infra/online_stores/snowflake.py +++ b/sdk/python/feast/infra/online_stores/snowflake.py @@ -53,6 +53,9 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel): private_key: Optional[str] = None """ Snowflake private key file path""" + private_key_content: Optional[bytes] = None + """ Snowflake private key stored as bytes""" + private_key_passphrase: Optional[str] = None """ Snowflake private key file passphrase""" diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index f2bc09e7e4..ac4f52dc06 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -96,6 +96,9 @@ class SnowflakeRegistryConfig(RegistryConfig): private_key: Optional[str] = None """ Snowflake private key file path""" + private_key_content: Optional[bytes] = None + """ Snowflake private key stored as bytes""" + private_key_passphrase: Optional[str] = None """ Snowflake private key file passphrase""" diff --git a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py index dd965c4bed..b9035b40db 100644 --- a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py @@ -84,9 +84,11 @@ def __enter__(self): # https://docs.snowflake.com/en/user-guide/python-connector-example.html#using-key-pair-authentication-key-pair-rotation # https://docs.snowflake.com/en/user-guide/key-pair-auth.html#configuring-key-pair-authentication - if "private_key" in kwargs: + if "private_key" in kwargs or "private_key_content" in kwargs: kwargs["private_key"] = parse_private_key_path( - kwargs["private_key"], kwargs["private_key_passphrase"] + kwargs.get("private_key_passphrase"), + kwargs.get("private_key"), + kwargs.get("private_key_content"), ) try: @@ -510,13 +512,27 @@ def chunk_helper(lst: pd.DataFrame, n: int) -> Iterator[Tuple[int, pd.DataFrame] yield int(i / n), lst[i : i + n] -def parse_private_key_path(key_path: str, private_key_passphrase: str) -> bytes: - with open(key_path, "rb") as key: +def parse_private_key_path( + private_key_passphrase: str, + key_path: Optional[str] = None, + private_key_content: Optional[bytes] = None, +) -> bytes: + """Returns snowflake pkb by parsing and reading either from key path or private_key_content as byte string.""" + if private_key_content: p_key = serialization.load_pem_private_key( - key.read(), + private_key_content, password=private_key_passphrase.encode(), backend=default_backend(), ) + elif key_path: + with open(key_path, "rb") as key: + p_key = serialization.load_pem_private_key( + key.read(), + password=private_key_passphrase.encode(), + backend=default_backend(), + ) + else: + raise ValueError("Please provide key_path or private_key_content.") pkb = p_key.private_bytes( encoding=serialization.Encoding.DER,