From 5d0cf951795cfefc790f908bf92ee0663e0bcf2e Mon Sep 17 00:00:00 2001 From: Miles Adkins Date: Thu, 18 Aug 2022 15:59:02 -0500 Subject: [PATCH] feat: Add Snowflake materialization engine Signed-off-by: Miles Adkins --- .../workflows/pr_local_integration_tests.yml | 2 +- .../reference/batch-materialization/README.md | 2 + .../batch-materialization/snowflake.md | 28 + docs/reference/offline-stores/snowflake.md | 2 +- .../driver-stats-on-snowflake.md | 12 +- .../source/feast.infra.materialization.rst | 8 + sdk/python/docs/source/feast.infra.utils.rst | 8 - .../infra/materialization/snowflake_engine.py | 492 ++++++++++++++++++ .../infra/offline_stores/offline_store.py | 6 + .../feast/infra/offline_stores/snowflake.py | 29 +- .../infra/offline_stores/snowflake_source.py | 110 +++- .../feast/infra/online_stores/snowflake.py | 45 +- .../utils/{ => snowflake}/snowflake_utils.py | 78 ++- .../utils/snowflake/snowpark/__init__.py | 0 .../snowflake_python_udfs_creation.sql | 71 +++ .../snowflake_python_udfs_deletion.sql | 17 + .../snowflake/snowpark/snowflake_udfs.py | 261 ++++++++++ sdk/python/feast/repo_config.py | 3 + .../feast/templates/snowflake/bootstrap.py | 4 +- .../feast/templates/snowflake/driver_repo.py | 58 +++ .../templates/snowflake/feature_store.yaml | 28 + sdk/python/feast/type_map.py | 85 ++- sdk/python/setup.cfg | 2 +- .../universal/data_sources/snowflake.py | 2 +- .../materialization/test_snowflake.py | 125 +++++ 25 files changed, 1376 insertions(+), 102 deletions(-) create mode 100644 docs/reference/batch-materialization/snowflake.md create mode 100644 sdk/python/feast/infra/materialization/snowflake_engine.py rename sdk/python/feast/infra/utils/{ => snowflake}/snowflake_utils.py (90%) create mode 100644 sdk/python/feast/infra/utils/snowflake/snowpark/__init__.py create mode 100644 sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_python_udfs_creation.sql create mode 100644 sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_python_udfs_deletion.sql create mode 100644 sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_udfs.py create mode 100644 sdk/python/feast/templates/snowflake/driver_repo.py create mode 100644 sdk/python/feast/templates/snowflake/feature_store.yaml create mode 100644 sdk/python/tests/integration/materialization/test_snowflake.py diff --git a/.github/workflows/pr_local_integration_tests.yml b/.github/workflows/pr_local_integration_tests.yml index d697f8be32..4705771911 100644 --- a/.github/workflows/pr_local_integration_tests.yml +++ b/.github/workflows/pr_local_integration_tests.yml @@ -75,4 +75,4 @@ jobs: IS_TEST: "True" FEAST_LOCAL_ONLINE_CONTAINER: "True" FEAST_IS_LOCAL_TEST: "True" - run: pytest -n 8 --cov=./ --cov-report=xml --color=yes --integration -k "not gcs_registry and not s3_registry and not test_lambda_materialization" sdk/python/tests + run: pytest -n 8 --cov=./ --cov-report=xml --color=yes --integration -k "not gcs_registry and not s3_registry and not test_lambda_materialization and not test_snowflake_materialization" sdk/python/tests diff --git a/docs/reference/batch-materialization/README.md b/docs/reference/batch-materialization/README.md index 6e8fd60611..50640bce49 100644 --- a/docs/reference/batch-materialization/README.md +++ b/docs/reference/batch-materialization/README.md @@ -2,4 +2,6 @@ Please see [Batch Materialization Engine](../../getting-started/architecture-and-components/batch-materialization-engine.md) for an explanation of batch materialization engines. 
+{% page-ref page="snowflake.md" %}
+
 {% page-ref page="bytewax.md" %}
diff --git a/docs/reference/batch-materialization/snowflake.md b/docs/reference/batch-materialization/snowflake.md
new file mode 100644
index 0000000000..c2fa441d6d
--- /dev/null
+++ b/docs/reference/batch-materialization/snowflake.md
@@ -0,0 +1,28 @@
+# Snowflake
+
+## Description
+
+The [Snowflake](https://trial.snowflake.com) batch materialization engine provides a highly scalable and parallel execution engine that uses a Snowflake Warehouse for batch materialization operations (`materialize` and `materialize-incremental`) when using a `SnowflakeSource`.
+
+The engine requires no additional configuration beyond the standard Snowflake login and context details. It leverages custom Python UDFs (deployed automatically for you) to serialize your offline store data into your online serving tables.
+
+When using all three options together, `snowflake.offline`, `snowflake.engine`, and `snowflake.online`, you get unlimited scale and performance along with unified governance and data security.
+
+## Example
+
+{% code title="feature_store.yaml" %}
+```yaml
+...
+offline_store:
+  type: snowflake.offline
+...
+batch_engine:
+  type: snowflake.engine
+  account: snowflake_deployment.us-east-1
+  user: user_login
+  password: user_password
+  role: sysadmin
+  warehouse: demo_wh
+  database: FEAST
+```
+{% endcode %}
diff --git a/docs/reference/offline-stores/snowflake.md b/docs/reference/offline-stores/snowflake.md
index b3b58fe786..e40ad7cd7a 100644
--- a/docs/reference/offline-stores/snowflake.md
+++ b/docs/reference/offline-stores/snowflake.md
@@ -46,7 +46,7 @@ Below is a matrix indicating which functionality is supported by `SnowflakeRetri
 | export to dataframe | yes |
 | export to arrow table | yes |
 | export to arrow batches | no |
-| export to SQL | no |
+| export to SQL | yes |
 | export to data lake (S3, GCS, etc.) | yes |
 | export to data warehouse | yes |
 | export as Spark dataframe | no |
diff --git a/docs/tutorials/tutorials-overview/driver-stats-on-snowflake.md b/docs/tutorials/tutorials-overview/driver-stats-on-snowflake.md
index 306ae2f59b..a425248b76 100644
--- a/docs/tutorials/tutorials-overview/driver-stats-on-snowflake.md
+++ b/docs/tutorials/tutorials-overview/driver-stats-on-snowflake.md
@@ -6,7 +6,7 @@ description: >-
 # Drivers stats on Snowflake
 
 In the steps below, we will set up a sample Feast project that leverages Snowflake
-as an offline store.
+as an offline store + materialization engine + online store.
 
 Starting with data in a Snowflake table, we will register that table to the feature store and define features associated with the columns in that table. From there, we will generate historical training data based on those feature definitions and then materialize the latest feature values into the online store. Lastly, we will retrieve the materialized feature values.
@@ -46,7 +46,7 @@ The following files will automatically be created in your project folder:
 #### Inspect `feature_store.yaml`
 
-Here you will see the information that you entered. This template will use Snowflake as an offline store and SQLite as the online store. The main thing to remember is by default, Snowflake objects have ALL CAPS names unless lower case was specified.
+Here you will see the information that you entered. This template will use Snowflake as the offline store, materialization engine, and the online store.
The main thing to remember is by default, Snowflake objects have ALL CAPS names unless lower case was specified. {% code title="feature_store.yaml" %} ```yaml @@ -61,6 +61,14 @@ offline_store: role: ROLE_NAME #case sensitive warehouse: WAREHOUSE_NAME #case sensitive database: DATABASE_NAME #case cap sensitive +batch_engine: + type: snowflake.engine + account: SNOWFLAKE_DEPLOYMENT_URL #drop .snowflakecomputing.com + user: USERNAME + password: PASSWORD + role: ROLE_NAME #case sensitive + warehouse: WAREHOUSE_NAME #case sensitive + database: DATABASE_NAME #case cap sensitive online_store: type: snowflake.online account: SNOWFLAKE_DEPLOYMENT_URL #drop .snowflakecomputing.com diff --git a/sdk/python/docs/source/feast.infra.materialization.rst b/sdk/python/docs/source/feast.infra.materialization.rst index ff3e1cf135..6e526c367c 100644 --- a/sdk/python/docs/source/feast.infra.materialization.rst +++ b/sdk/python/docs/source/feast.infra.materialization.rst @@ -28,6 +28,14 @@ feast.infra.materialization.local\_engine module :undoc-members: :show-inheritance: +feast.infra.materialization.snowflake\_engine module +---------------------------------------------------- + +.. automodule:: feast.infra.materialization.snowflake_engine + :members: + :undoc-members: + :show-inheritance: + Module contents --------------- diff --git a/sdk/python/docs/source/feast.infra.utils.rst b/sdk/python/docs/source/feast.infra.utils.rst index ffada49797..e4116e7a17 100644 --- a/sdk/python/docs/source/feast.infra.utils.rst +++ b/sdk/python/docs/source/feast.infra.utils.rst @@ -28,14 +28,6 @@ feast.infra.utils.hbase\_utils module :undoc-members: :show-inheritance: -feast.infra.utils.snowflake\_utils module ------------------------------------------ - -.. automodule:: feast.infra.utils.snowflake_utils - :members: - :undoc-members: - :show-inheritance: - Module contents --------------- diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py new file mode 100644 index 0000000000..91840bf7e1 --- /dev/null +++ b/sdk/python/feast/infra/materialization/snowflake_engine.py @@ -0,0 +1,492 @@ +import os +import shutil +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Callable, List, Literal, Optional, Sequence, Union + +import click +import pandas as pd +from colorama import Fore, Style +from pydantic import Field, StrictStr +from tqdm import tqdm + +import feast +from feast.batch_feature_view import BatchFeatureView +from feast.entity import Entity +from feast.feature_view import FeatureView +from feast.infra.materialization.batch_materialization_engine import ( + BatchMaterializationEngine, + MaterializationJob, + MaterializationJobStatus, + MaterializationTask, +) +from feast.infra.offline_stores.offline_store import OfflineStore +from feast.infra.online_stores.online_store import OnlineStore +from feast.infra.registry.base_registry import BaseRegistry +from feast.infra.utils.snowflake.snowflake_utils import ( + _run_snowflake_field_mapping, + assert_snowflake_feature_names, + get_snowflake_conn, + get_snowflake_materialization_config, + package_snowpark_zip, +) +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.repo_config import FeastConfigBaseModel, RepoConfig +from feast.stream_feature_view import StreamFeatureView +from feast.type_map import _convert_value_name_to_snowflake_udf 
+from feast.utils import _coerce_datetime, _get_column_names + + +class SnowflakeMaterializationEngineConfig(FeastConfigBaseModel): + """Batch Materialization Engine config for Snowflake Snowpark Python UDFs""" + + type: Literal["snowflake.engine"] = "snowflake.engine" + """ Type selector""" + + config_path: Optional[str] = ( + Path(os.environ["HOME"]) / ".snowsql/config" + ).__str__() + """ Snowflake config path -- absolute path required (Cant use ~)""" + + account: Optional[str] = None + """ Snowflake deployment identifier -- drop .snowflakecomputing.com""" + + user: Optional[str] = None + """ Snowflake user name """ + + password: Optional[str] = None + """ Snowflake password """ + + role: Optional[str] = None + """ Snowflake role name""" + + warehouse: Optional[str] = None + """ Snowflake warehouse name """ + + authenticator: Optional[str] = None + """ Snowflake authenticator name """ + + database: StrictStr + """ Snowflake database name """ + + schema_: Optional[str] = Field("PUBLIC", alias="schema") + """ Snowflake schema name """ + + class Config: + allow_population_by_field_name = True + + +@dataclass +class SnowflakeMaterializationJob(MaterializationJob): + def __init__( + self, + job_id: str, + status: MaterializationJobStatus, + error: Optional[BaseException] = None, + ) -> None: + super().__init__() + self._job_id: str = job_id + self._status: MaterializationJobStatus = status + self._error: Optional[BaseException] = error + + def status(self) -> MaterializationJobStatus: + return self._status + + def error(self) -> Optional[BaseException]: + return self._error + + def should_be_retried(self) -> bool: + return False + + def job_id(self) -> str: + return self._job_id + + def url(self) -> Optional[str]: + return None + + +class SnowflakeMaterializationEngine(BatchMaterializationEngine): + def update( + self, + project: str, + views_to_delete: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView] + ], + views_to_keep: Sequence[ + Union[BatchFeatureView, StreamFeatureView, FeatureView] + ], + entities_to_delete: Sequence[Entity], + entities_to_keep: Sequence[Entity], + ): + click.echo( + f"Deploying materialization functions for {Style.BRIGHT + Fore.GREEN}{project}{Style.RESET_ALL}" + ) + click.echo() + + conn_config = get_snowflake_materialization_config(self.repo_config) + + stage_context = f'"{conn_config.database}"."{conn_config.schema_}"' + stage_path = f'{stage_context}."feast_{project}"' + with get_snowflake_conn(conn_config) as conn: + cur = conn.cursor() + + # if the stage already exists, + # assumes that the materialization functions have been deployed + cur.execute(f"SHOW STAGES IN {stage_context}") + stage_list = pd.DataFrame( + cur.fetchall(), columns=[column.name for column in cur.description] + ) + if f"feast_{project}" in stage_list["name"].tolist(): + click.echo( + f"Materialization functions for {Style.BRIGHT + Fore.GREEN}{project}{Style.RESET_ALL} already exists" + ) + click.echo() + return None + + cur.execute(f"CREATE STAGE {stage_path}") + + copy_path, zip_path = package_snowpark_zip(project) + cur.execute(f"PUT file://{zip_path} @{stage_path}") + shutil.rmtree(copy_path) + + # Execute snowflake python udf creation functions + sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/snowpark/snowflake_python_udfs_creation.sql" + with open(sql_function_file, "r") as file: + sqlFile = file.read() + + sqlCommands = sqlFile.split(";") + for command in sqlCommands: + command = command.replace("STAGE_HOLDER", f"{stage_path}") + 
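+                # PROJECT_NAME is the second placeholder in the packaged SQL template;
+                # substituting it gives every project its own set of serialization UDFs
+                # (e.g. feast_<project>_snowflake_varchar_to_string_proto).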
command = command.replace("PROJECT_NAME", f"{project}") + cur.execute(command) + + return None + + def teardown_infra( + self, + project: str, + fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]], + entities: Sequence[Entity], + ): + conn_config = get_snowflake_materialization_config(self.repo_config) + + stage_path = ( + f'"{conn_config.database}"."{conn_config.schema_}"."feast_{project}"' + ) + with get_snowflake_conn(conn_config) as conn: + cur = conn.cursor() + + cur.execute(f"DROP STAGE IF EXISTS {stage_path}") + + # Execute snowflake python udf deletion functions + sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/snowpark/snowflake_python_udfs_deletion.sql" + with open(sql_function_file, "r") as file: + sqlFile = file.read() + + sqlCommands = sqlFile.split(";") + for command in sqlCommands: + command = command.replace("PROJECT_NAME", f"{project}") + cur.execute(command) + + return None + + def __init__( + self, + *, + repo_config: RepoConfig, + offline_store: OfflineStore, + online_store: OnlineStore, + **kwargs, + ): + assert ( + repo_config.offline_store.type == "snowflake.offline" + ), "To use SnowflakeMaterializationEngine, you must use Snowflake as an offline store." + + super().__init__( + repo_config=repo_config, + offline_store=offline_store, + online_store=online_store, + **kwargs, + ) + + def materialize( + self, registry, tasks: List[MaterializationTask] + ) -> List[MaterializationJob]: + return [ + self._materialize_one( + registry, + task.feature_view, + task.start_time, + task.end_time, + task.project, + task.tqdm_builder, + ) + for task in tasks + ] + + def _materialize_one( + self, + registry: BaseRegistry, + feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView], + start_date: datetime, + end_date: datetime, + project: str, + tqdm_builder: Callable[[int], tqdm], + ): + assert isinstance(feature_view, BatchFeatureView) or isinstance( + feature_view, FeatureView + ), "Snowflake can only materialize FeatureView & BatchFeatureView feature view types." 
+ + repo_config = self.repo_config + + entities = [] + for entity_name in feature_view.entities: + entities.append(registry.get_entity(entity_name, project)) + + ( + join_key_columns, + feature_name_columns, + timestamp_field, + created_timestamp_column, + ) = _get_column_names(feature_view, entities) + + job_id = f"{feature_view.name}-{start_date}-{end_date}" + + try: + offline_job = self.offline_store.pull_latest_from_table_or_query( + config=repo_config, + data_source=feature_view.batch_source, + join_key_columns=join_key_columns, + feature_name_columns=feature_name_columns, + timestamp_field=timestamp_field, + created_timestamp_column=created_timestamp_column, + start_date=start_date, + end_date=end_date, + ) + + fv_latest_values_sql = offline_job.to_sql() + + if feature_view.batch_source.field_mapping is not None: + fv_latest_mapped_values_sql = _run_snowflake_field_mapping( + fv_latest_values_sql, feature_view.batch_source.field_mapping + ) + + fv_to_proto_sql = self.generate_snowflake_materialization_query( + repo_config, + fv_latest_mapped_values_sql, + feature_view, + project, + ) + + if repo_config.online_store.type == "snowflake.online": + self.materialize_to_snowflake_online_store( + repo_config, + fv_to_proto_sql, + feature_view, + project, + ) + else: + self.materialize_to_external_online_store( + repo_config, + fv_to_proto_sql, + feature_view, + tqdm_builder, + ) + + return SnowflakeMaterializationJob( + job_id=job_id, status=MaterializationJobStatus.SUCCEEDED + ) + except BaseException as e: + return SnowflakeMaterializationJob( + job_id=job_id, status=MaterializationJobStatus.ERROR, error=e + ) + + def generate_snowflake_materialization_query( + self, + repo_config: RepoConfig, + fv_latest_mapped_values_sql: str, + feature_view: Union[BatchFeatureView, FeatureView], + project: str, + ) -> str: + + if feature_view.batch_source.created_timestamp_column: + fv_created_str = f',"{feature_view.batch_source.created_timestamp_column}"' + else: + fv_created_str = None + + join_keys = [entity.name for entity in feature_view.entity_columns] + join_keys_type = [ + entity.dtype.to_value_type().name for entity in feature_view.entity_columns + ] + + entity_names = "ARRAY_CONSTRUCT('" + "', '".join(join_keys) + "')" + entity_data = 'ARRAY_CONSTRUCT("' + '", "'.join(join_keys) + '")' + entity_types = "ARRAY_CONSTRUCT('" + "', '".join(join_keys_type) + "')" + + """ + Generate the SQL that maps the feature given ValueType to the correct python + UDF serialization function. 
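+
+        Illustrative example only (assuming a project named "my_project" and a DOUBLE
+        feature "conv_rate"): the generated column expression would look like
+            FEAST_MY_PROJECT_SNOWFLAKE_FLOAT_TO_DOUBLE_PROTO("conv_rate") AS "conv_rate"
+        while UNIX_TIMESTAMP features are first converted with
+        DATE_PART(EPOCH_NANOSECOND, ...) before being passed to their UDF.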
+ """ + feature_sql_list = [] + for feature in feature_view.features: + feature_value_type_name = feature.dtype.to_value_type().name + + feature_sql = _convert_value_name_to_snowflake_udf( + feature_value_type_name, project + ) + + if feature_value_type_name == "UNIX_TIMESTAMP": + feature_sql = f'{feature_sql}(DATE_PART(EPOCH_NANOSECOND, "{feature.name}")) AS "{feature.name}"' + else: + feature_sql = f'{feature_sql}("{feature.name}") AS "{feature.name}"' + + feature_sql_list.append(feature_sql) + + features_str = ",\n".join(feature_sql_list) + + if repo_config.online_store.type == "snowflake.online": + serial_func = f"feast_{project}_serialize_entity_keys" + else: + serial_func = f"feast_{project}_entity_key_proto_to_string" + + fv_to_proto_sql = f""" + SELECT + {serial_func.upper()}({entity_names}, {entity_data}, {entity_types}) AS "entity_key", + {features_str}, + "{feature_view.batch_source.timestamp_field}" + {fv_created_str if fv_created_str else ''} + FROM ( + {fv_latest_mapped_values_sql} + ) + """ + + return fv_to_proto_sql + + def materialize_to_snowflake_online_store( + self, + repo_config: RepoConfig, + materialization_sql: str, + feature_view: Union[BatchFeatureView, FeatureView], + project: str, + ) -> None: + assert_snowflake_feature_names(feature_view) + + conn_config = get_snowflake_materialization_config(repo_config) + + online_table = f"""{repo_config.online_store.database}"."{repo_config.online_store.schema_}"."[online-transient] {project}_{feature_view.name}""" + + feature_names_str = '", "'.join( + [feature.name for feature in feature_view.features] + ) + + if feature_view.batch_source.created_timestamp_column: + fv_created_str = f',"{feature_view.batch_source.created_timestamp_column}"' + else: + fv_created_str = None + + fv_to_online = f""" + MERGE INTO "{online_table}" online_table + USING ( + SELECT + "entity_key" || TO_BINARY("feature_name", 'UTF-8') AS "entity_feature_key", + "entity_key", + "feature_name", + "feature_value" AS "value", + "{feature_view.batch_source.timestamp_field}" AS "event_ts" + {fv_created_str + ' AS "created_ts"' if fv_created_str else ''} + FROM ( + {materialization_sql} + ) + UNPIVOT("feature_value" FOR "feature_name" IN ("{feature_names_str}")) + ) AS latest_values ON online_table."entity_feature_key" = latest_values."entity_feature_key" + WHEN MATCHED THEN + UPDATE SET + online_table."entity_key" = latest_values."entity_key", + online_table."feature_name" = latest_values."feature_name", + online_table."value" = latest_values."value", + online_table."event_ts" = latest_values."event_ts" + {',online_table."created_ts" = latest_values."created_ts"' if fv_created_str else ''} + WHEN NOT MATCHED THEN + INSERT ("entity_feature_key", "entity_key", "feature_name", "value", "event_ts" {', "created_ts"' if fv_created_str else ''}) + VALUES ( + latest_values."entity_feature_key", + latest_values."entity_key", + latest_values."feature_name", + latest_values."value", + latest_values."event_ts" + {',latest_values."created_ts"' if fv_created_str else ''} + ) + """ + + with get_snowflake_conn(conn_config) as conn: + cur = conn.cursor() + cur.execute(fv_to_online) + + query_id = cur.sfqid + click.echo( + f"Snowflake Query ID: {Style.BRIGHT + Fore.GREEN}{query_id}{Style.RESET_ALL}" + ) + return None + + def materialize_to_external_online_store( + self, + repo_config: RepoConfig, + materialization_sql: str, + feature_view: Union[StreamFeatureView, FeatureView], + tqdm_builder: Callable[[int], tqdm], + ) -> None: + conn_config = 
get_snowflake_materialization_config(repo_config) + + feature_names = [feature.name for feature in feature_view.features] + + with get_snowflake_conn(conn_config) as conn: + cur = conn.cursor() + cur.execute(materialization_sql) + for i, df in enumerate(cur.fetch_pandas_batches()): + click.echo(f"Snowflake: Processing ResultSet Batch #{i+1}") + + entity_keys = ( + df["entity_key"].apply(EntityKeyProto.FromString).to_numpy() + ) + + for feature in feature_names: + df[feature] = df[feature].apply(ValueProto.FromString) + + features = df[feature_names].to_dict("records") + + event_timestamps = [ + _coerce_datetime(val) + for val in pd.to_datetime( + df[feature_view.batch_source.timestamp_field] + ) + ] + + if feature_view.batch_source.created_timestamp_column: + created_timestamps = [ + _coerce_datetime(val) + for val in pd.to_datetime( + df[feature_view.batch_source.created_timestamp_column] + ) + ] + else: + created_timestamps = [None] * df.shape[0] + + rows_to_write = list( + zip( + entity_keys, + features, + event_timestamps, + created_timestamps, + ) + ) + + with tqdm_builder(len(rows_to_write)) as pbar: + self.online_store.online_write_batch( + repo_config, + feature_view, + rows_to_write, + lambda x: pbar.update(x), + ) + return None diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index 741b97e2fd..8eb391b941 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -140,6 +140,12 @@ def to_arrow( return pyarrow.Table.from_pandas(features_df) + def to_sql(self) -> str: + """ + Return RetrievalJob generated SQL statement if applicable. + """ + pass + @abstractmethod def _to_df_internal(self) -> pd.DataFrame: """ diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index 98db97b179..8fde09b9e1 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -20,7 +20,7 @@ import pandas as pd import pyarrow import pyarrow as pa -from pydantic import Field +from pydantic import Field, StrictStr from pydantic.typing import Literal from pytz import utc @@ -41,7 +41,7 @@ SnowflakeSource, ) from feast.infra.registry.base_registry import BaseRegistry -from feast.infra.utils.snowflake_utils import ( +from feast.infra.utils.snowflake.snowflake_utils import ( execute_snowflake_statement, get_snowflake_conn, write_pandas, @@ -85,15 +85,15 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel): warehouse: Optional[str] = None """ Snowflake warehouse name """ - database: Optional[str] = None + authenticator: Optional[str] = None + """ Snowflake authenticator name """ + + database: StrictStr """ Snowflake database name """ - schema_: Optional[str] = Field(None, alias="schema") + schema_: Optional[str] = Field("PUBLIC", alias="schema") """ Snowflake schema name """ - authenticator: Optional[str] = None - """ Snowflake authenticator name """ - storage_integration_name: Optional[str] = None """ Storage integration name in snowflake """ @@ -120,9 +120,9 @@ def pull_latest_from_table_or_query( assert isinstance(data_source, SnowflakeSource) assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig) - from_expression = ( - data_source.get_table_query_string() - ) # returns schema.table as a string + from_expression = data_source.get_table_query_string() + if not data_source.database and data_source.table: + from_expression = 
f'"{config.offline_store.database}"."{config.offline_store.schema_}".{from_expression}' if join_key_columns: partition_by_join_key_string = '"' + '", "'.join(join_key_columns) + '"' @@ -148,6 +148,9 @@ def pull_latest_from_table_or_query( snowflake_conn = get_snowflake_conn(config.offline_store) + start_date = start_date.astimezone(tz=utc) + end_date = end_date.astimezone(tz=utc) + query = f""" SELECT {field_string} @@ -156,7 +159,7 @@ def pull_latest_from_table_or_query( SELECT {field_string}, ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS "_feast_row" FROM {from_expression} - WHERE "{timestamp_field}" BETWEEN TO_TIMESTAMP_NTZ({start_date.timestamp()}) AND TO_TIMESTAMP_NTZ({end_date.timestamp()}) + WHERE "{timestamp_field}" BETWEEN TIMESTAMP '{start_date}' AND TIMESTAMP '{end_date}' ) WHERE "_feast_row" = 1 """ @@ -181,7 +184,10 @@ def pull_all_from_table_or_query( end_date: datetime, ) -> RetrievalJob: assert isinstance(data_source, SnowflakeSource) + from_expression = data_source.get_table_query_string() + if not data_source.database and data_source.table: + from_expression = f'"{config.offline_store.database}"."{config.offline_store.schema_}".{from_expression}' field_string = ( '"' @@ -533,6 +539,7 @@ def _upload_entity_df( if isinstance(entity_df, pd.DataFrame): # Write the data from the DataFrame to the table + # Known issues with following entity data types: BINARY write_pandas( snowflake_conn, entity_df, diff --git a/sdk/python/feast/infra/offline_stores/snowflake_source.py b/sdk/python/feast/infra/offline_stores/snowflake_source.py index df0aef2ade..b19f35eab8 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake_source.py +++ b/sdk/python/feast/infra/offline_stores/snowflake_source.py @@ -4,7 +4,7 @@ from feast import type_map from feast.data_source import DataSource -from feast.errors import DataSourceNoNameException +from feast.errors import DataSourceNoNameException, DataSourceNotFoundException from feast.feature_logging import LoggingDestination from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.protos.feast.core.FeatureService_pb2 import ( @@ -59,6 +59,8 @@ def __init__( """ if table is None and query is None: raise ValueError('No "table" or "query" argument provided.') + if table and query: + raise ValueError('Both "table" and "query" argument provided.') # The default Snowflake schema is named "PUBLIC". 
_schema = "PUBLIC" if (database and table and not schema) else schema @@ -195,7 +197,7 @@ def get_table_query_string(self) -> str: @staticmethod def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: - return type_map.snowflake_python_type_to_feast_value_type + return type_map.snowflake_type_to_feast_value_type def get_table_column_names_and_types( self, config: RepoConfig @@ -208,7 +210,7 @@ def get_table_column_names_and_types( """ from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig - from feast.infra.utils.snowflake_utils import ( + from feast.infra.utils.snowflake.snowflake_utils import ( execute_snowflake_statement, get_snowflake_conn, ) @@ -217,20 +219,100 @@ def get_table_column_names_and_types( snowflake_conn = get_snowflake_conn(config.offline_store) - if self.database and self.table: - query = f'SELECT * FROM "{self.database}"."{self.schema}"."{self.table}" LIMIT 1' - elif self.table: - query = f'SELECT * FROM "{self.table}" LIMIT 1' + query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5" + + result_cur = execute_snowflake_statement(snowflake_conn, query) + + metadata = [ + { + "column_name": column.name, + "type_code": column.type_code, + "precision": column.precision, + "scale": column.scale, + "is_nullable": column.is_nullable, + "snowflake_type": None, + } + for column in result_cur.description + ] + + for row in metadata: + if row["type_code"] == 0: + if row["scale"] == 0: + if row["precision"] <= 9: # max precision size to ensure INT32 + row["snowflake_type"] = "NUMBER32" + elif row["precision"] <= 18: # max precision size to ensure INT64 + row["snowflake_type"] = "NUMBER64" + else: + column = row["column_name"] + query = f'SELECT MAX("{column}") AS "{column}" FROM {self.get_table_query_string()}' + + result = execute_snowflake_statement( + snowflake_conn, query + ).fetch_pandas_all() + + if ( + result.dtypes[column].name + in python_int_to_snowflake_type_map + ): + row["snowflake_type"] = python_int_to_snowflake_type_map[ + result.dtypes[column].name + ] + else: + raise NotImplementedError( + "Numbers larger than INT64 are not supported" + ) + else: + raise NotImplementedError( + "The following Snowflake Data Type is not supported: DECIMAL -- Convert to DOUBLE" + ) + elif row["type_code"] in [3, 5, 9, 10, 12]: + error = snowflake_unsupported_map[row["type_code"]] + raise NotImplementedError( + f"The following Snowflake Data Type is not supported: {error}" + ) + elif row["type_code"] in [1, 2, 4, 6, 7, 8, 11, 13]: + row["snowflake_type"] = snowflake_type_code_map[row["type_code"]] + else: + raise NotImplementedError( + f"The following Snowflake Column is not supported: {row['column_name']} (type_code: {row['type_code']})" + ) + + if not result_cur.fetch_pandas_all().empty: + return [ + (column["column_name"], column["snowflake_type"]) for column in metadata + ] else: - query = f"SELECT * FROM ({self.query}) LIMIT 1" + raise DataSourceNotFoundException( + "The following source:\n" + query + "\n ... is empty" + ) - result = execute_snowflake_statement(snowflake_conn, query).fetch_pandas_all() - if not result.empty: - metadata = result.dtypes.apply(str) - return list(zip(metadata.index, metadata)) - else: - raise ValueError("The following source:\n" + query + "\n ... 
is empty") +snowflake_type_code_map = { + 0: "NUMBER", + 1: "DOUBLE", + 2: "VARCHAR", + 4: "TIMESTAMP", + 6: "TIMESTAMP_LTZ", + 7: "TIMESTAMP_TZ", + 8: "TIMESTAMP_NTZ", + 11: "BINARY", + 13: "BOOLEAN", +} + +snowflake_unsupported_map = { + 3: "DATE -- Convert to TIMESTAMP", + 5: "VARIANT -- Try converting to VARCHAR", + 9: "OBJECT -- Try converting to VARCHAR", + 10: "ARRAY -- Try converting to VARCHAR", + 12: "TIME -- Try converting to VARCHAR", +} + +python_int_to_snowflake_type_map = { + "int64": "NUMBER64", + "int32": "NUMBER32", + "int16": "NUMBER32", + "int8": "NUMBER32", +} class SnowflakeOptions: diff --git a/sdk/python/feast/infra/online_stores/snowflake.py b/sdk/python/feast/infra/online_stores/snowflake.py index 88b141981f..96a9a749a9 100644 --- a/sdk/python/feast/infra/online_stores/snowflake.py +++ b/sdk/python/feast/infra/online_stores/snowflake.py @@ -7,14 +7,17 @@ import pandas as pd import pytz -from pydantic import Field +from pydantic import Field, StrictStr from pydantic.schema import Literal from feast.entity import Entity from feast.feature_view import FeatureView from feast.infra.key_encoding_utils import serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.utils.snowflake_utils import get_snowflake_conn, write_pandas_binary +from feast.infra.utils.snowflake.snowflake_utils import ( + get_snowflake_conn, + write_pandas_binary, +) from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -47,15 +50,15 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel): warehouse: Optional[str] = None """ Snowflake warehouse name """ - database: Optional[str] = None + authenticator: Optional[str] = None + """ Snowflake authenticator name """ + + database: StrictStr """ Snowflake database name """ schema_: Optional[str] = Field("PUBLIC", alias="schema") """ Snowflake schema name """ - authenticator: Optional[str] = None - """ Snowflake authenticator name """ - class Config: allow_population_by_field_name = True @@ -91,20 +94,11 @@ def online_write_batch( if created_ts is not None: created_ts = _to_naive_utc(created_ts) - entity_key_serialization_version = ( - config.entity_key_serialization_version - if config.entity_key_serialization_version - else 2 - ) for j, (feature_name, val) in enumerate(values.items()): df.loc[j, "entity_feature_key"] = serialize_entity_key( - entity_key, - entity_key_serialization_version, + entity_key, 2 ) + bytes(feature_name, encoding="utf-8") - df.loc[j, "entity_key"] = serialize_entity_key( - entity_key, - entity_key_serialization_version, - ) + df.loc[j, "entity_key"] = serialize_entity_key(entity_key, 2) df.loc[j, "feature_name"] = feature_name df.loc[j, "value"] = val.SerializeToString() df.loc[j, "event_ts"] = timestamp @@ -118,7 +112,9 @@ def online_write_batch( # This combines both the data upload plus the overwrite in the same transaction with get_snowflake_conn(config.online_store, autocommit=False) as conn: write_pandas_binary( - conn, agg_df, f"[online-transient] {config.project}_{table.name}" + conn, + agg_df, + f"[online-transient] {config.project}_{table.name}", ) # special function for writing binary to snowflake query = f""" @@ -159,18 +155,12 @@ def online_read( result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] - entity_key_serialization_version = ( - config.entity_key_serialization_version - if 
config.entity_key_serialization_version - else 2 - ) - entity_fetch_str = ",".join( [ ( "TO_BINARY(" + hexlify( - serialize_entity_key(combo[0], entity_key_serialization_version) + serialize_entity_key(combo[0], 2) + bytes(combo[1], encoding="utf-8") ).__str__()[1:] + ")" @@ -197,10 +187,7 @@ def online_read( ) for entity_key in entity_keys: - entity_key_bin = serialize_entity_key( - entity_key, - entity_key_serialization_version, - ) + entity_key_bin = serialize_entity_key(entity_key, 2) res = {} res_ts = None for index, row in df[df["entity_key"] == entity_key_bin].iterrows(): diff --git a/sdk/python/feast/infra/utils/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py similarity index 90% rename from sdk/python/feast/infra/utils/snowflake_utils.py rename to sdk/python/feast/infra/utils/snowflake/snowflake_utils.py index 6b3500b605..485808c122 100644 --- a/sdk/python/feast/infra/utils/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py @@ -1,6 +1,7 @@ import configparser import os import random +import shutil import string from logging import getLogger from pathlib import Path @@ -18,7 +19,10 @@ wait_exponential, ) +import feast from feast.errors import SnowflakeIncompleteConfig, SnowflakeQueryUnknownError +from feast.feature_view import FeatureView +from feast.repo_config import RepoConfig try: import snowflake.connector @@ -36,6 +40,16 @@ logger = getLogger(__name__) +def assert_snowflake_feature_names(feature_view: FeatureView) -> None: + for feature in feature_view.features: + assert feature.name not in [ + "entity_key", + "feature_name", + "feature_value", + ], f"Feature Name: {feature.name} is a protected name to ensure query stability" + return None + + def execute_snowflake_statement(conn: SnowflakeConnection, query) -> SnowflakeCursor: cursor = conn.cursor().execute(query) if cursor is None: @@ -44,10 +58,12 @@ def execute_snowflake_statement(conn: SnowflakeConnection, query) -> SnowflakeCu def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection: - assert config.type in ["snowflake.offline", "snowflake.online"] + assert config.type in ["snowflake.offline", "snowflake.engine", "snowflake.online"] if config.type == "snowflake.offline": config_header = "connections.feast_offline_store" + if config.type == "snowflake.engine": + config_header = "connections.feast_batch_engine" elif config.type == "snowflake.online": config_header = "connections.feast_online_store" @@ -60,19 +76,13 @@ def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection: if config_reader.has_section(config_header): kwargs = dict(config_reader[config_header]) - if "schema" in kwargs: - kwargs["schema_"] = kwargs.pop("schema") - kwargs.update((k, v) for k, v in config_dict.items() if v is not None) for k, v in kwargs.items(): if k in ["role", "warehouse", "database", "schema_"]: kwargs[k] = f'"{v}"' - if "schema_" in kwargs: - kwargs["schema"] = kwargs.pop("schema_") - else: - kwargs["schema"] = '"PUBLIC"' + kwargs["schema"] = kwargs.pop("schema_") # https://docs.snowflake.com/en/user-guide/python-connector-example.html#using-key-pair-authentication-key-pair-rotation # https://docs.snowflake.com/en/user-guide/key-pair-auth.html#configuring-key-pair-authentication @@ -95,6 +105,58 @@ def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection: raise SnowflakeIncompleteConfig(e) +# Determine which set of credentials to use when using snowflake materialization engine +# if snowflake.online -- this requires the online role 
to have access to source tables +# if else -- this requires the offline role to have access to source tables +def get_snowflake_materialization_config(repo_config: RepoConfig): + if repo_config.batch_engine.account: + conn_config = repo_config.batch_engine + elif repo_config.online_store.type == "snowflake.online": + conn_config = repo_config.online_store + else: + conn_config = repo_config.offline_store + return conn_config + + +def package_snowpark_zip(project_name) -> Tuple[str, str]: + path = os.path.dirname(feast.__file__) + copy_path = path + f"/snowflake_feast_{project_name}" + + if os.path.exists(copy_path): + shutil.rmtree(copy_path) + + copy_files = [ + "/infra/utils/snowflake/snowpark/snowflake_udfs.py", + "/infra/key_encoding_utils.py", + "/type_map.py", + "/value_type.py", + "/protos/feast/types/Value_pb2.py", + "/protos/feast/types/EntityKey_pb2.py", + ] + + package_path = copy_path + "/feast" + for feast_file in copy_files: + idx = feast_file.rfind("/") + if idx > -1: + Path(package_path + feast_file[:idx]).mkdir(parents=True, exist_ok=True) + feast_file = shutil.copy(path + feast_file, package_path + feast_file[:idx]) + else: + feast_file = shutil.copy(path + feast_file, package_path + feast_file) + + zip_path = shutil.make_archive(package_path, "zip", copy_path) + + return copy_path, zip_path + + +def _run_snowflake_field_mapping(snowflake_job_sql: str, field_mapping: dict) -> str: + snowflake_mapped_sql = snowflake_job_sql + for key in field_mapping.keys(): + snowflake_mapped_sql = snowflake_mapped_sql.replace( + f'"{key}"', f'"{key}" AS "{field_mapping[key]}"', 1 + ) + return snowflake_mapped_sql + + # TO DO -- sfc-gh-madkins # Remove dependency on write_pandas function by falling back to native snowflake python connector # Current issue is datetime[ns] types are read incorrectly in Snowflake, need to coerce to datetime[ns, UTC] diff --git a/sdk/python/feast/infra/utils/snowflake/snowpark/__init__.py b/sdk/python/feast/infra/utils/snowflake/snowpark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_python_udfs_creation.sql b/sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_python_udfs_creation.sql new file mode 100644 index 0000000000..3a1b6977b2 --- /dev/null +++ b/sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_python_udfs_creation.sql @@ -0,0 +1,71 @@ +CREATE FUNCTION IF NOT EXISTS feast_PROJECT_NAME_snowflake_binary_to_bytes_proto(df BINARY) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_binary_to_bytes_proto' + IMPORTS = ('@STAGE_HOLDER/feast.zip'); + +CREATE FUNCTION IF NOT EXISTS feast_PROJECT_NAME_snowflake_varchar_to_string_proto(df VARCHAR) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_varchar_to_string_proto' + IMPORTS = ('@STAGE_HOLDER/feast.zip'); + +CREATE FUNCTION IF NOT EXISTS feast_PROJECT_NAME_snowflake_number_to_int32_proto(df NUMBER) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_number_to_int32_proto' + IMPORTS = ('@STAGE_HOLDER/feast.zip'); + +CREATE FUNCTION IF NOT EXISTS feast_PROJECT_NAME_snowflake_number_to_int64_proto(df NUMBER) + RETURNS BINARY + LANGUAGE 
PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_number_to_int64_proto' + IMPORTS = ('@STAGE_HOLDER/feast.zip'); + +CREATE FUNCTION IF NOT EXISTS feast_PROJECT_NAME_snowflake_float_to_double_proto(df DOUBLE) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_float_to_double_proto' + IMPORTS = ('@STAGE_HOLDER/feast.zip'); + +CREATE FUNCTION IF NOT EXISTS feast_PROJECT_NAME_snowflake_boolean_to_bool_proto(df BOOLEAN) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_boolean_to_bool_boolean_proto' + IMPORTS = ('@STAGE_HOLDER/feast.zip'); + +CREATE FUNCTION IF NOT EXISTS feast_PROJECT_NAME_snowflake_timestamp_to_unix_timestamp_proto(df NUMBER) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_timestamp_to_unix_timestamp_proto' + IMPORTS = ('@STAGE_HOLDER/feast.zip'); + +CREATE FUNCTION IF NOT EXISTS feast_PROJECT_NAME_serialize_entity_keys(names ARRAY, data ARRAY, types ARRAY) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_serialize_entity_keys' + IMPORTS = ('@STAGE_HOLDER/feast.zip'); + +CREATE FUNCTION IF NOT EXISTS feast_PROJECT_NAME_entity_key_proto_to_string(names ARRAY, data ARRAY, types ARRAY) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_entity_key_proto_to_string' + IMPORTS = ('@STAGE_HOLDER/feast.zip'); diff --git a/sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_python_udfs_deletion.sql b/sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_python_udfs_deletion.sql new file mode 100644 index 0000000000..bf8188dbbe --- /dev/null +++ b/sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_python_udfs_deletion.sql @@ -0,0 +1,17 @@ +DROP FUNCTION IF EXISTS feast_PROJECT_NAME_snowflake_binary_to_bytes_proto(BINARY); + +DROP FUNCTION IF EXISTS feast_PROJECT_NAME_snowflake_varchar_to_string_proto(VARCHAR); + +DROP FUNCTION IF EXISTS feast_PROJECT_NAME_snowflake_number_to_int32_proto(NUMBER); + +DROP FUNCTION IF EXISTS feast_PROJECT_NAME_snowflake_number_to_int64_proto(NUMBER); + +DROP FUNCTION IF EXISTS feast_PROJECT_NAME_snowflake_float_to_double_proto(DOUBLE); + +DROP FUNCTION IF EXISTS feast_PROJECT_NAME_snowflake_boolean_to_bool_proto(BOOLEAN); + +DROP FUNCTION IF EXISTS feast_PROJECT_NAME_snowflake_timestamp_to_unix_timestamp_proto(NUMBER); + +DROP FUNCTION IF EXISTS feast_PROJECT_NAME_serialize_entity_keys(ARRAY, ARRAY, ARRAY); + +DROP FUNCTION IF EXISTS feast_PROJECT_NAME_entity_key_proto_to_string(ARRAY, ARRAY, ARRAY); diff --git a/sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_udfs.py b/sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_udfs.py new file mode 100644 index 0000000000..7fde4dd3a1 --- /dev/null +++ b/sdk/python/feast/infra/utils/snowflake/snowpark/snowflake_udfs.py @@ -0,0 +1,261 @@ +from binascii import unhexlify + +import pandas +from _snowflake import vectorized + +from feast.infra.key_encoding_utils import 
serialize_entity_key +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto +from feast.type_map import ( + _convert_value_type_str_to_value_type, + python_values_to_proto_values, +) +from feast.value_type import ValueType + +""" +CREATE OR REPLACE FUNCTION feast_snowflake_binary_to_bytes_proto(df BINARY) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_binary_to_bytes_proto' + IMPORTS = ('@feast_stage/feast.zip'); +""" +# ValueType.BYTES = 1 +@vectorized(input=pandas.DataFrame) +def feast_snowflake_binary_to_bytes_proto(df): + df = list( + map( + ValueProto.SerializeToString, + python_values_to_proto_values(df[0].to_numpy(), ValueType.BYTES), + ) + ) + return df + + +""" +CREATE OR REPLACE FUNCTION feast_snowflake_varchar_to_string_proto(df VARCHAR) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_varchar_to_string_proto' + IMPORTS = ('@feast_stage/feast.zip'); +""" +# ValueType.STRING = 2 +@vectorized(input=pandas.DataFrame) +def feast_snowflake_varchar_to_string_proto(df): + df = list( + map( + ValueProto.SerializeToString, + python_values_to_proto_values(df[0].to_numpy(), ValueType.STRING), + ) + ) + return df + + +""" +CREATE OR REPLACE FUNCTION feast_snowflake_number_to_int32_proto(df NUMBER) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_number_to_int32_proto' + IMPORTS = ('@feast_stage/feast.zip'); +""" +# ValueType.INT32 = 3 +@vectorized(input=pandas.DataFrame) +def feast_snowflake_number_to_int32_proto(df): + df = list( + map( + ValueProto.SerializeToString, + python_values_to_proto_values(df[0].to_numpy(), ValueType.INT32), + ) + ) + return df + + +""" +CREATE OR REPLACE FUNCTION feast_snowflake_number_to_int64_proto(df NUMBER) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_number_to_int64_proto' + IMPORTS = ('@feast_stage/feast.zip'); +""" +# ValueType.INT64 = 4 +@vectorized(input=pandas.DataFrame) +def feast_snowflake_number_to_int64_proto(df): + df = list( + map( + ValueProto.SerializeToString, + python_values_to_proto_values(df[0].to_numpy(), ValueType.INT64), + ) + ) + return df + + +# All floating-point numbers stored as double +# https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#data-types-for-floating-point-numbers +""" +CREATE OR REPLACE FUNCTION feast_snowflake_float_to_double_proto(df DOUBLE) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_float_to_double_proto' + IMPORTS = ('@feast_stage/feast.zip'); +""" +# ValueType.FLOAT = 5 & ValueType.DOUBLE = 6 +@vectorized(input=pandas.DataFrame) +def feast_snowflake_float_to_double_proto(df): + df = list( + map( + ValueProto.SerializeToString, + python_values_to_proto_values(df[0].to_numpy(), ValueType.DOUBLE), + ) + ) + return df + + +""" +CREATE OR REPLACE FUNCTION feast_snowflake_boolean_to_bool_proto(df BOOLEAN) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + 
PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_boolean_to_bool_boolean_proto' + IMPORTS = ('@feast_stage/feast.zip'); +""" +# ValueType.BOOL = 7 +@vectorized(input=pandas.DataFrame) +def feast_snowflake_boolean_to_bool_boolean_proto(df): + df = list( + map( + ValueProto.SerializeToString, + python_values_to_proto_values(df[0].to_numpy(), ValueType.BOOL), + ) + ) + return df + + +""" +CREATE OR REPLACE FUNCTION feast_snowflake_timestamp_to_unix_timestamp_proto(df NUMBER) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_snowflake_timestamp_to_unix_timestamp_proto' + IMPORTS = ('@feast_stage/feast.zip'); +""" +# ValueType.UNIX_TIMESTAMP = 8 +@vectorized(input=pandas.DataFrame) +def feast_snowflake_timestamp_to_unix_timestamp_proto(df): + + df = list( + map( + ValueProto.SerializeToString, + python_values_to_proto_values( + pandas.to_datetime(df[0], unit="ns").to_numpy(), + ValueType.UNIX_TIMESTAMP, + ), + ) + ) + return df + + +""" +CREATE OR REPLACE FUNCTION feast_serialize_entity_keys(names ARRAY, data ARRAY, types ARRAY) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_serialize_entity_keys' + IMPORTS = ('@feast_stage/feast.zip') +""" +# converts 1 to n many entity keys to a single binary for lookups +@vectorized(input=pandas.DataFrame) +def feast_serialize_entity_keys(df): + join_keys = create_entity_dict(df[0].values[0], df[2].values[0]) + + df = pandas.DataFrame.from_dict( + dict(zip(df[1].index, df[1].values)), orient="index", columns=df[0].values[0] + ) + + proto_values_by_column = {} + for column, value_type in list(join_keys.items()): + # BINARY is converted to a hex string, we need to convert back + if value_type == ValueType.BYTES: + proto_values = python_values_to_proto_values( + list(map(unhexlify, df[column].tolist())), value_type + ) + else: + proto_values = python_values_to_proto_values( + df[column].to_numpy(), value_type + ) + + proto_values_by_column.update({column: proto_values}) + + serialized_entity_keys = [ + serialize_entity_key( + EntityKeyProto( + join_keys=join_keys, + entity_values=[proto_values_by_column[k][idx] for k in join_keys], + ), + entity_key_serialization_version=2, + ) + for idx in range(df.shape[0]) + ] + return serialized_entity_keys + + +""" +CREATE OR REPLACE FUNCTION feast_entity_key_proto_to_string(names ARRAY, data ARRAY, types ARRAY) + RETURNS BINARY + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ('protobuf', 'pandas') + HANDLER = 'feast.infra.utils.snowflake.snowpark.snowflake_udfs.feast_entity_key_proto_to_string' + IMPORTS = ('@feast_stage/feast.zip') +""" +# converts 1 to n many entity keys to a single binary for lookups +@vectorized(input=pandas.DataFrame) +def feast_entity_key_proto_to_string(df): + join_keys = create_entity_dict(df[0].values[0], df[2].values[0]) + + df = pandas.DataFrame.from_dict( + dict(zip(df[1].index, df[1].values)), orient="index", columns=df[0].values[0] + ) + + proto_values_by_column = {} + for column, value_type in list(join_keys.items()): + # BINARY is converted to a hex string, we need to convert back + if value_type == ValueType.BYTES: + proto_values = python_values_to_proto_values( + list(map(unhexlify, df[column].tolist())), value_type + ) + else: + proto_values = python_values_to_proto_values( + 
df[column].to_numpy(), value_type + ) + + proto_values_by_column.update({column: proto_values}) + + serialized_entity_keys = [ + EntityKeyProto( + join_keys=join_keys, + entity_values=[proto_values_by_column[k][idx] for k in join_keys], + ).SerializeToString() + for idx in range(df.shape[0]) + ] + return serialized_entity_keys + + +def create_entity_dict(names, types): + return dict( + zip( + names, + [_convert_value_type_str_to_value_type(type_str) for type_str in types], + ) + ) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 8af434b0ca..118c1ca872 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -36,6 +36,7 @@ # - first party and third party implementations can use the same class loading code path. BATCH_ENGINE_CLASS_FOR_TYPE = { "local": "feast.infra.materialization.LocalMaterializationEngine", + "snowflake.engine": "feast.infra.materialization.snowflake_engine.SnowflakeMaterializationEngine", "lambda": "feast.infra.materialization.lambda.lambda_engine.LambdaMaterializationEngine", "bytewax": "feast.infra.materialization.contrib.bytewax.bytewax_materialization_engine.BytewaxMaterializationEngine", } @@ -190,6 +191,8 @@ def __init__(self, **data: Any): self._batch_engine_config = data["batch_engine"] elif "batch_engine_config" in data: self._batch_engine_config = data["batch_engine_config"] + elif self._offline_config == "snowflake.offline": + self._batch_engine_config = "snowflake.engine" else: # Defaults to using local in-process materialization engine. self._batch_engine_config = "local" diff --git a/sdk/python/feast/templates/snowflake/bootstrap.py b/sdk/python/feast/templates/snowflake/bootstrap.py index f9dda8a20a..3cad2eea1e 100644 --- a/sdk/python/feast/templates/snowflake/bootstrap.py +++ b/sdk/python/feast/templates/snowflake/bootstrap.py @@ -2,7 +2,7 @@ import snowflake.connector from feast.file_utils import replace_str_in_file -from feast.infra.utils.snowflake_utils import write_pandas +from feast.infra.utils.snowflake.snowflake_utils import write_pandas def bootstrap(): @@ -38,7 +38,7 @@ def bootstrap(): snowflake_database = click.prompt("Snowflake Database Name (Case Sensitive):") config_file = repo_path / "feature_store.yaml" - for i in range(2): + for i in range(3): replace_str_in_file( config_file, "SNOWFLAKE_DEPLOYMENT_URL", snowflake_deployment_url ) diff --git a/sdk/python/feast/templates/snowflake/driver_repo.py b/sdk/python/feast/templates/snowflake/driver_repo.py new file mode 100644 index 0000000000..5453e44795 --- /dev/null +++ b/sdk/python/feast/templates/snowflake/driver_repo.py @@ -0,0 +1,58 @@ +from datetime import timedelta + +import yaml + +from feast import BatchFeatureView, Entity, FeatureService, SnowflakeSource + +# Define an entity for the driver. Entities can be thought of as primary keys used to +# retrieve features. Entities are also used to join multiple tables/views during the +# construction of feature vectors +driver = Entity( + # Name of the entity. Must be unique within a project + name="driver", + # The join keys of an entity describe the storage level field/column on which + # features can be looked up. The join keys are also used to join feature + # tables/views when building feature vectors + join_keys=["driver_id"], +) + +# Indicates a data source from which feature values can be retrieved. Sources are queried when building training +# datasets or materializing features into an online store. 
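+# NOTE: the table referenced below is assumed to hold the demo driver stats data
+# uploaded by this template's bootstrap.py during `feast init`; point it at your
+# own Snowflake table when adapting this example.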
+project_name = yaml.safe_load(open("feature_store.yaml"))["project"] + +driver_stats_source = SnowflakeSource( + # The Snowflake table where features can be found + database=yaml.safe_load(open("feature_store.yaml"))["offline_store"]["database"], + table=f"{project_name}_feast_driver_hourly_stats", + # The event timestamp is used for point-in-time joins and for ensuring only + # features within the TTL are returned + timestamp_field="event_timestamp", + # The (optional) created timestamp is used to ensure there are no duplicate + # feature rows in the offline store or when building training datasets + created_timestamp_column="created", +) + +# Feature views are a grouping based on how features are stored in either the +# online or offline store. +driver_stats_fv = BatchFeatureView( + # The unique name of this feature view. Two feature views in a single + # project cannot have the same name + name="driver_hourly_stats", + # The list of entities specifies the keys required for joining or looking + # up features from this feature view. The reference provided in this field + # correspond to the name of a defined entity (or entities) + entities=[driver], + # The timedelta is the maximum age that each feature value may have + # relative to its lookup time. For historical features (used in training), + # TTL is relative to each timestamp provided in the entity dataframe. + # TTL also allows for eviction of keys from online stores and limits the + # amount of historical scanning required for historical feature values + # during retrieval + ttl=timedelta(weeks=52), + # Batch sources are used to find feature values. In the case of this feature + # view we will query a source table on Snowflake for driver statistics + # features + source=driver_stats_source, +) + +driver_stats_fs = FeatureService(name="driver_activity", features=[driver_stats_fv]) diff --git a/sdk/python/feast/templates/snowflake/feature_store.yaml b/sdk/python/feast/templates/snowflake/feature_store.yaml new file mode 100644 index 0000000000..104e6394c6 --- /dev/null +++ b/sdk/python/feast/templates/snowflake/feature_store.yaml @@ -0,0 +1,28 @@ +project: my_project +registry: registry.db +provider: local +offline_store: + type: snowflake.offline + account: SNOWFLAKE_DEPLOYMENT_URL + user: SNOWFLAKE_USER + password: SNOWFLAKE_PASSWORD + role: SNOWFLAKE_ROLE + warehouse: SNOWFLAKE_WAREHOUSE + database: SNOWFLAKE_DATABASE +batch_engine: + type: snowflake.engine + account: SNOWFLAKE_DEPLOYMENT_URL + user: SNOWFLAKE_USER + password: SNOWFLAKE_PASSWORD + role: SNOWFLAKE_ROLE + warehouse: SNOWFLAKE_WAREHOUSE + database: SNOWFLAKE_DATABASE +online_store: + type: snowflake.online + account: SNOWFLAKE_DEPLOYMENT_URL + user: SNOWFLAKE_USER + password: SNOWFLAKE_PASSWORD + role: SNOWFLAKE_ROLE + warehouse: SNOWFLAKE_WAREHOUSE + database: SNOWFLAKE_DATABASE +entity_key_serialization_version: 2 diff --git a/sdk/python/feast/type_map.py b/sdk/python/feast/type_map.py index a9dc4e25da..f8292b9c0d 100644 --- a/sdk/python/feast/type_map.py +++ b/sdk/python/feast/type_map.py @@ -15,6 +15,7 @@ from collections import defaultdict from datetime import datetime, timezone from typing import ( + TYPE_CHECKING, Any, Dict, Iterator, @@ -31,7 +32,6 @@ import numpy as np import pandas as pd -import pyarrow from google.protobuf.timestamp_pb2 import Timestamp from feast.protos.feast.types.Value_pb2 import ( @@ -46,6 +46,10 @@ from feast.protos.feast.types.Value_pb2 import Value as ProtoValue from feast.value_type import ListType, ValueType +if TYPE_CHECKING: + 
import pyarrow
+
+
 # null timestamps get converted to -9223372036854775808
 NULL_TIMESTAMP_INT_VALUE = np.datetime64("NaT").astype(int)
 
@@ -228,6 +232,30 @@ def python_values_to_feast_value_type(
     return inferred_dtype
 
 
+def _convert_value_type_str_to_value_type(type_str: str) -> ValueType:
+    type_map = {
+        "UNKNOWN": ValueType.UNKNOWN,
+        "BYTES": ValueType.BYTES,
+        "STRING": ValueType.STRING,
+        "INT32": ValueType.INT32,
+        "INT64": ValueType.INT64,
+        "DOUBLE": ValueType.DOUBLE,
+        "FLOAT": ValueType.FLOAT,
+        "BOOL": ValueType.BOOL,
+        "NULL": ValueType.NULL,
+        "UNIX_TIMESTAMP": ValueType.UNIX_TIMESTAMP,
+        "BYTES_LIST": ValueType.BYTES_LIST,
+        "STRING_LIST": ValueType.STRING_LIST,
+        "INT32_LIST": ValueType.INT32_LIST,
+        "INT64_LIST": ValueType.INT64_LIST,
+        "DOUBLE_LIST": ValueType.DOUBLE_LIST,
+        "FLOAT_LIST": ValueType.FLOAT_LIST,
+        "BOOL_LIST": ValueType.BOOL_LIST,
+        "UNIX_TIMESTAMP_LIST": ValueType.UNIX_TIMESTAMP_LIST,
+    }
+    return type_map[type_str]
+
+
 def _type_err(item, dtype):
     raise TypeError(f'Value "{item}" is of type {type(item)} not of type {dtype}')
 
@@ -525,30 +553,37 @@ def redshift_to_feast_value_type(redshift_type_as_str: str) -> ValueType:
     return type_map[redshift_type_as_str.lower()]
 
 
-def snowflake_python_type_to_feast_value_type(
-    snowflake_python_type_as_str: str,
-) -> ValueType:
-
+def snowflake_type_to_feast_value_type(snowflake_type: str) -> ValueType:
     type_map = {
-        "str": ValueType.STRING,
-        "float64": ValueType.DOUBLE,
-        "int64": ValueType.INT64,
-        "uint64": ValueType.INT64,
-        "int32": ValueType.INT32,
-        "uint32": ValueType.INT32,
-        "int16": ValueType.INT32,
-        "uint16": ValueType.INT32,
-        "uint8": ValueType.INT32,
-        "int8": ValueType.INT32,
-        "datetime64[ns]": ValueType.UNIX_TIMESTAMP,
-        "object": ValueType.STRING,
-        "bool": ValueType.BOOL,
+        "BINARY": ValueType.BYTES,
+        "VARCHAR": ValueType.STRING,
+        "NUMBER32": ValueType.INT32,
+        "NUMBER64": ValueType.INT64,
+        "DOUBLE": ValueType.DOUBLE,
+        "BOOLEAN": ValueType.BOOL,
+        "TIMESTAMP": ValueType.UNIX_TIMESTAMP,
+        "TIMESTAMP_TZ": ValueType.UNIX_TIMESTAMP,
+        "TIMESTAMP_LTZ": ValueType.UNIX_TIMESTAMP,
+        "TIMESTAMP_NTZ": ValueType.UNIX_TIMESTAMP,
     }
-
-    return type_map[snowflake_python_type_as_str.lower()]
+    return type_map[snowflake_type]
+
+
+def _convert_value_name_to_snowflake_udf(value_name: str, project_name: str) -> str:
+    name_map = {
+        "BYTES": f"feast_{project_name}_snowflake_binary_to_bytes_proto",
+        "STRING": f"feast_{project_name}_snowflake_varchar_to_string_proto",
+        "INT32": f"feast_{project_name}_snowflake_number_to_int32_proto",
+        "INT64": f"feast_{project_name}_snowflake_number_to_int64_proto",
+        "DOUBLE": f"feast_{project_name}_snowflake_float_to_double_proto",
+        "FLOAT": f"feast_{project_name}_snowflake_float_to_double_proto",
+        "BOOL": f"feast_{project_name}_snowflake_boolean_to_bool_proto",
+        "UNIX_TIMESTAMP": f"feast_{project_name}_snowflake_timestamp_to_unix_timestamp_proto",
+    }
+    return name_map[value_name].upper()
 
-def pa_to_redshift_value_type(pa_type: pyarrow.DataType) -> str:
+def pa_to_redshift_value_type(pa_type: "pyarrow.DataType") -> str:
     # PyArrow types: https://arrow.apache.org/docs/python/api/datatypes.html
     # Redshift type: https://docs.aws.amazon.com/redshift/latest/dg/c_Supported_data_types.html
     pa_type_as_str = str(pa_type).lower()
@@ -728,7 +763,9 @@ def pg_type_to_feast_value_type(type_str: str) -> ValueType:
     return value
 
 
-def feast_value_type_to_pa(feast_type: ValueType) -> pyarrow.DataType:
+def feast_value_type_to_pa(feast_type: ValueType) -> "pyarrow.DataType":
+    import pyarrow
+
     type_map = {
         ValueType.INT32: pyarrow.int32(),
         ValueType.INT64: pyarrow.int64(),
@@ -814,7 +851,7 @@ def athena_to_feast_value_type(athena_type_as_str: str) -> ValueType:
     return type_map[athena_type_as_str.lower()]
 
 
-def pa_to_athena_value_type(pa_type: pyarrow.DataType) -> str:
+def pa_to_athena_value_type(pa_type: "pyarrow.DataType") -> str:
     # PyArrow types: https://arrow.apache.org/docs/python/api/datatypes.html
     # Type names from https://docs.aws.amazon.com/athena/latest/ug/data-types.html
     pa_type_as_str = str(pa_type).lower()
diff --git a/sdk/python/setup.cfg b/sdk/python/setup.cfg
index ebb933f69d..d934249d69 100644
--- a/sdk/python/setup.cfg
+++ b/sdk/python/setup.cfg
@@ -14,7 +14,7 @@ ignore = E203, E266, E501, W503, C901
 max-line-length = 88
 max-complexity = 20
 select = B,C,E,F,W,T4
-exclude = .git,__pycache__,docs/conf.py,dist,feast/protos,feast/embedded_go/lib
+exclude = .git,__pycache__,docs/conf.py,dist,feast/protos,feast/embedded_go/lib,feast/infra/utils/snowflake/snowpark/snowflake_udfs.py
 
 [mypy]
 files=feast,tests
diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py
index b5fc2448d4..2af30ef40e 100644
--- a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py
+++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py
@@ -12,7 +12,7 @@
     SavedDatasetSnowflakeStorage,
     SnowflakeLoggingDestination,
 )
-from feast.infra.utils.snowflake_utils import get_snowflake_conn, write_pandas
+from feast.infra.utils.snowflake.snowflake_utils import get_snowflake_conn, write_pandas
 from feast.repo_config import FeastConfigBaseModel
 from tests.integration.feature_repos.universal.data_source_creator import (
     DataSourceCreator,
diff --git a/sdk/python/tests/integration/materialization/test_snowflake.py b/sdk/python/tests/integration/materialization/test_snowflake.py
new file mode 100644
index 0000000000..b11652c989
--- /dev/null
+++ b/sdk/python/tests/integration/materialization/test_snowflake.py
@@ -0,0 +1,125 @@
+import os
+from datetime import timedelta
+
+import pytest
+
+from feast.entity import Entity
+from feast.feature_view import FeatureView
+from tests.data.data_creator import create_basic_driver_dataset
+from tests.integration.feature_repos.integration_test_repo_config import (
+    IntegrationTestRepoConfig,
+)
+from tests.integration.feature_repos.repo_configuration import (
+    construct_test_environment,
+)
+from tests.integration.feature_repos.universal.data_sources.snowflake import (
+    SnowflakeDataSourceCreator,
+)
+from tests.utils.e2e_test_validation import validate_offline_online_store_consistency
+
+SNOWFLAKE_ENGINE_CONFIG = {
+    "type": "snowflake.engine",
+    "account": os.getenv("SNOWFLAKE_CI_DEPLOYMENT", ""),
+    "user": os.getenv("SNOWFLAKE_CI_USER", ""),
+    "password": os.getenv("SNOWFLAKE_CI_PASSWORD", ""),
+    "role": os.getenv("SNOWFLAKE_CI_ROLE", ""),
+    "warehouse": os.getenv("SNOWFLAKE_CI_WAREHOUSE", ""),
+    "database": "FEAST",
+    "schema": "MATERIALIZATION",
+}
+
+SNOWFLAKE_ONLINE_CONFIG = {
+    "type": "snowflake.online",
+    "account": os.getenv("SNOWFLAKE_CI_DEPLOYMENT", ""),
+    
"user": os.getenv("SNOWFLAKE_CI_USER", ""), + "password": os.getenv("SNOWFLAKE_CI_PASSWORD", ""), + "role": os.getenv("SNOWFLAKE_CI_ROLE", ""), + "warehouse": os.getenv("SNOWFLAKE_CI_WAREHOUSE", ""), + "database": "FEAST", + "schema": "ONLINE", +} + + +@pytest.mark.integration +def test_snowflake_materialization_consistency_internal(): + snowflake_config = IntegrationTestRepoConfig( + online_store=SNOWFLAKE_ONLINE_CONFIG, + offline_store_creator=SnowflakeDataSourceCreator, + batch_engine=SNOWFLAKE_ENGINE_CONFIG, + ) + snowflake_environment = construct_test_environment(snowflake_config, None) + + df = create_basic_driver_dataset() + ds = snowflake_environment.data_source_creator.create_data_source( + df, + snowflake_environment.feature_store.project, + field_mapping={"ts_1": "ts"}, + ) + + fs = snowflake_environment.feature_store + driver = Entity( + name="driver_id", + join_keys=["driver_id"], + ) + + driver_stats_fv = FeatureView( + name="driver_hourly_stats", + entities=[driver], + ttl=timedelta(weeks=52), + source=ds, + ) + + try: + fs.apply([driver, driver_stats_fv]) + + # materialization is run in two steps and + # we use timestamp from generated dataframe as a split point + split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1) + + print(f"Split datetime: {split_dt}") + + validate_offline_online_store_consistency(fs, driver_stats_fv, split_dt) + finally: + fs.teardown() + + +@pytest.mark.integration +def test_snowflake_materialization_consistency_external(): + snowflake_config = IntegrationTestRepoConfig( + offline_store_creator=SnowflakeDataSourceCreator, + batch_engine=SNOWFLAKE_ENGINE_CONFIG, + ) + snowflake_environment = construct_test_environment(snowflake_config, None) + + df = create_basic_driver_dataset() + ds = snowflake_environment.data_source_creator.create_data_source( + df, + snowflake_environment.feature_store.project, + field_mapping={"ts_1": "ts"}, + ) + + fs = snowflake_environment.feature_store + driver = Entity( + name="driver_id", + join_keys=["driver_id"], + ) + + driver_stats_fv = FeatureView( + name="driver_hourly_stats", + entities=[driver], + ttl=timedelta(weeks=52), + source=ds, + ) + + try: + fs.apply([driver, driver_stats_fv]) + + # materialization is run in two steps and + # we use timestamp from generated dataframe as a split point + split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1) + + print(f"Split datetime: {split_dt}") + + validate_offline_online_store_consistency(fs, driver_stats_fv, split_dt) + finally: + fs.teardown()