diff --git a/docs/tutorials/azure/notebooks/src/score.py b/docs/tutorials/azure/notebooks/src/score.py index 93b248240d..7def7d2d2a 100644 --- a/docs/tutorials/azure/notebooks/src/score.py +++ b/docs/tutorials/azure/notebooks/src/score.py @@ -6,9 +6,11 @@ import json import joblib from feast import FeatureStore, RepoConfig -from feast.infra.registry.registry import RegistryConfig +from feast.repo_config import RegistryConfig -from feast.infra.offline_stores.contrib.mssql_offline_store.mssql import MsSqlServerOfflineStoreConfig +from feast.infra.offline_stores.contrib.mssql_offline_store.mssql import ( + MsSqlServerOfflineStoreConfig, +) from feast.infra.online_stores.redis import RedisOnlineStoreConfig, RedisOnlineStore @@ -73,4 +75,4 @@ def run(raw_data): y_hat = model.predict(data) return y_hat.tolist() else: - return 0.0 \ No newline at end of file + return 0.0 diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index 042a3622a9..57d04c2700 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -132,6 +132,11 @@ def __init__(self): super().__init__("Provider is not set, but is required") +class FeastRegistryNotSetError(Exception): + def __init__(self): + super().__init__("Registry is not set, but is required") + + class FeastFeatureServerTypeSetError(Exception): def __init__(self, feature_server_type: str): super().__init__( @@ -146,6 +151,13 @@ def __init__(self, feature_server_type: str): ) +class FeastRegistryTypeInvalidError(Exception): + def __init__(self, registry_type: str): + super().__init__( + f"Feature server type was set to {registry_type}, but this type is invalid" + ) + + class FeastModuleImportError(Exception): def __init__(self, module_name: str, class_name: str): super().__init__( diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 55d66e185c..f8978e9e02 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -164,9 +164,15 @@ def __init__( self.repo_path, utils.get_default_yaml_file_path(self.repo_path) ) - registry_config = self.config.get_registry_config() + registry_config = self.config.registry if registry_config.registry_type == "sql": self._registry = SqlRegistry(registry_config, self.config.project, None) + elif registry_config.registry_type == "snowflake.registry": + from feast.infra.registry.snowflake import SnowflakeRegistry + + self._registry = SnowflakeRegistry( + registry_config, self.config.project, None + ) else: r = Registry(self.config.project, registry_config, repo_path=self.repo_path) r._initialize_registry(self.config.project) @@ -209,7 +215,7 @@ def refresh_registry(self): greater than 0, then once the cache becomes stale (more time than the TTL has passed), a new cache will be downloaded synchronously, which may increase latencies if the triggering method is get_online_features(). """ - registry_config = self.config.get_registry_config() + registry_config = self.config.registry registry = Registry( self.config.project, registry_config, repo_path=self.repo_path ) diff --git a/sdk/python/feast/infra/materialization/snowflake_engine.py b/sdk/python/feast/infra/materialization/snowflake_engine.py index d8fc5f5611..3b183f97e6 100644 --- a/sdk/python/feast/infra/materialization/snowflake_engine.py +++ b/sdk/python/feast/infra/materialization/snowflake_engine.py @@ -132,6 +132,10 @@ def update( # if the stage already exists, # assumes that the materialization functions have been deployed if f"feast_{project}" in stage_list["name"].tolist(): + click.echo( + f"Materialization functions for {Style.BRIGHT + Fore.GREEN}{project}{Style.RESET_ALL} already detected." + ) + click.echo() return None click.echo( diff --git a/sdk/python/feast/infra/registry/proto_registry_utils.py b/sdk/python/feast/infra/registry/proto_registry_utils.py index f43805cd9b..2a275703db 100644 --- a/sdk/python/feast/infra/registry/proto_registry_utils.py +++ b/sdk/python/feast/infra/registry/proto_registry_utils.py @@ -9,7 +9,6 @@ EntityNotFoundException, FeatureServiceNotFoundException, FeatureViewNotFoundException, - OnDemandFeatureViewNotFoundException, SavedDatasetNotFound, ValidationReferenceNotFound, ) @@ -98,7 +97,7 @@ def get_on_demand_feature_view( and on_demand_feature_view.spec.name == name ): return OnDemandFeatureView.from_proto(on_demand_feature_view) - raise OnDemandFeatureViewNotFoundException(name, project=project) + raise FeatureViewNotFoundException(name, project=project) def get_data_source( @@ -138,10 +137,6 @@ def get_validation_reference( raise ValidationReferenceNotFound(name, project=project) -def list_validation_references(registry_proto: RegistryProto): - return registry_proto.validation_references - - def list_feature_services( registry_proto: RegistryProto, project: str, allow_cache: bool = False ) -> List[FeatureService]: @@ -215,13 +210,25 @@ def list_data_sources(registry_proto: RegistryProto, project: str) -> List[DataS def list_saved_datasets( - registry_proto: RegistryProto, project: str, allow_cache: bool = False + registry_proto: RegistryProto, project: str ) -> List[SavedDataset]: - return [ - SavedDataset.from_proto(saved_dataset) - for saved_dataset in registry_proto.saved_datasets - if saved_dataset.spec.project == project - ] + saved_datasets = [] + for saved_dataset in registry_proto.saved_datasets: + if saved_dataset.project == project: + saved_datasets.append(SavedDataset.from_proto(saved_dataset)) + return saved_datasets + + +def list_validation_references( + registry_proto: RegistryProto, project: str +) -> List[ValidationReference]: + validation_references = [] + for validation_reference in registry_proto.validation_references: + if validation_reference.project == project: + validation_references.append( + ValidationReference.from_proto(validation_reference) + ) + return validation_references def list_project_metadata( diff --git a/sdk/python/feast/infra/registry/registry.py b/sdk/python/feast/infra/registry/registry.py index c6552da0c8..d2cf6a54ec 100644 --- a/sdk/python/feast/infra/registry/registry.py +++ b/sdk/python/feast/infra/registry/registry.py @@ -174,6 +174,10 @@ def __new__( from feast.infra.registry.sql import SqlRegistry return SqlRegistry(registry_config, project, repo_path) + elif registry_config and registry_config.registry_type == "snowflake.registry": + from feast.infra.registry.snowflake import SnowflakeRegistry + + return SnowflakeRegistry(registry_config, project, repo_path) else: return super(Registry, cls).__new__(cls) @@ -731,7 +735,7 @@ def list_validation_references( registry_proto = self._get_registry_proto( project=project, allow_cache=allow_cache ) - return proto_registry_utils.list_validation_references(registry_proto) + return proto_registry_utils.list_validation_references(registry_proto, project) def delete_validation_reference(self, name: str, project: str, commit: bool = True): registry_proto = self._prepare_registry_for_changes(project) diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py new file mode 100644 index 0000000000..07709db696 --- /dev/null +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -0,0 +1,1096 @@ +import os +import uuid +from binascii import hexlify +from datetime import datetime, timedelta +from enum import Enum +from threading import Lock +from typing import Any, Callable, List, Optional, Set, Union + +from pydantic import Field, StrictStr +from pydantic.schema import Literal + +import feast +from feast import usage +from feast.base_feature_view import BaseFeatureView +from feast.data_source import DataSource +from feast.entity import Entity +from feast.errors import ( + DataSourceObjectNotFoundException, + EntityNotFoundException, + FeatureServiceNotFoundException, + FeatureViewNotFoundException, + SavedDatasetNotFound, + ValidationReferenceNotFound, +) +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.infra.registry import proto_registry_utils +from feast.infra.registry.base_registry import BaseRegistry +from feast.infra.utils.snowflake.snowflake_utils import ( + execute_snowflake_statement, + get_snowflake_conn, +) +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.project_metadata import ProjectMetadata +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto +from feast.protos.feast.core.FeatureService_pb2 import ( + FeatureService as FeatureServiceProto, +) +from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto +from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto +from feast.protos.feast.core.OnDemandFeatureView_pb2 import ( + OnDemandFeatureView as OnDemandFeatureViewProto, +) +from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto +from feast.protos.feast.core.RequestFeatureView_pb2 import ( + RequestFeatureView as RequestFeatureViewProto, +) +from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto +from feast.protos.feast.core.StreamFeatureView_pb2 import ( + StreamFeatureView as StreamFeatureViewProto, +) +from feast.protos.feast.core.ValidationProfile_pb2 import ( + ValidationReference as ValidationReferenceProto, +) +from feast.repo_config import RegistryConfig +from feast.request_feature_view import RequestFeatureView +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView + + +class FeastMetadataKeys(Enum): + LAST_UPDATED_TIMESTAMP = "last_updated_timestamp" + PROJECT_UUID = "project_uuid" + + +class SnowflakeRegistryConfig(RegistryConfig): + """Registry config for Snowflake""" + + registry_type: Literal["snowflake.registry"] = "snowflake.registry" + """ Registry type selector """ + + type: Literal["snowflake.registry"] = "snowflake.registry" + """ Registry type selector """ + + config_path: Optional[str] = os.path.expanduser("~/.snowsql/config") + """ Snowflake config path -- absolute path required (Cant use ~) """ + + account: Optional[str] = None + """ Snowflake deployment identifier -- drop .snowflakecomputing.com """ + + user: Optional[str] = None + """ Snowflake user name """ + + password: Optional[str] = None + """ Snowflake password """ + + role: Optional[str] = None + """ Snowflake role name """ + + warehouse: Optional[str] = None + """ Snowflake warehouse name """ + + authenticator: Optional[str] = None + """ Snowflake authenticator name """ + + database: StrictStr + """ Snowflake database name """ + + schema_: Optional[str] = Field("PUBLIC", alias="schema") + """ Snowflake schema name """ + + class Config: + allow_population_by_field_name = True + + +class SnowflakeRegistry(BaseRegistry): + def __init__( + self, + registry_config, + project: str, + repo_path, + ): + assert registry_config is not None and isinstance( + registry_config, SnowflakeRegistryConfig + ), "SnowflakeRegistry needs a valid registry_config, a path does not work" + + self.registry_config = registry_config + self.registry_path = ( + f'"{self.registry_config.database}"."{self.registry_config.schema_}"' + ) + + with get_snowflake_conn(self.registry_config) as conn: + sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql" + with open(sql_function_file, "r") as file: + sqlFile = file.read() + + sqlCommands = sqlFile.split(";") + for command in sqlCommands: + query = command.replace("REGISTRY_PATH", f"{self.registry_path}") + execute_snowflake_statement(conn, query) + + self.cached_registry_proto = self.proto() + proto_registry_utils.init_project_metadata(self.cached_registry_proto, project) + self.cached_registry_proto_created = datetime.utcnow() + self._refresh_lock = Lock() + self.cached_registry_proto_ttl = timedelta( + seconds=registry_config.cache_ttl_seconds + if registry_config.cache_ttl_seconds is not None + else 0 + ) + self.project = project + + def refresh(self, project: Optional[str] = None): + if project: + project_metadata = proto_registry_utils.get_project_metadata( + registry_proto=self.cached_registry_proto, project=project + ) + if project_metadata: + usage.set_current_project_uuid(project_metadata.project_uuid) + else: + proto_registry_utils.init_project_metadata( + self.cached_registry_proto, project + ) + self.cached_registry_proto = self.proto() + self.cached_registry_proto_created = datetime.utcnow() + + def _refresh_cached_registry_if_necessary(self): + with self._refresh_lock: + expired = ( + self.cached_registry_proto is None + or self.cached_registry_proto_created is None + ) or ( + self.cached_registry_proto_ttl.total_seconds() + > 0 # 0 ttl means infinity + and ( + datetime.utcnow() + > ( + self.cached_registry_proto_created + + self.cached_registry_proto_ttl + ) + ) + ) + + if expired: + self.refresh() + + def teardown(self): + with get_snowflake_conn(self.registry_config) as conn: + sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql" + with open(sql_function_file, "r") as file: + sqlFile = file.read() + + sqlCommands = sqlFile.split(";") + for command in sqlCommands: + query = command.replace("REGISTRY_PATH", f"{self.registry_path}") + execute_snowflake_statement(conn, query) + + # apply operations + def apply_data_source( + self, data_source: DataSource, project: str, commit: bool = True + ): + return self._apply_object( + "DATA_SOURCES", + project, + "DATA_SOURCE_NAME", + data_source, + "DATA_SOURCE_PROTO", + ) + + def apply_entity(self, entity: Entity, project: str, commit: bool = True): + return self._apply_object( + "ENTITIES", project, "ENTITY_NAME", entity, "ENTITY_PROTO" + ) + + def apply_feature_service( + self, feature_service: FeatureService, project: str, commit: bool = True + ): + return self._apply_object( + "FEATURE_SERVICES", + project, + "FEATURE_SERVICE_NAME", + feature_service, + "FEATURE_SERVICE_PROTO", + ) + + def apply_feature_view( + self, feature_view: BaseFeatureView, project: str, commit: bool = True + ): + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1] + return self._apply_object( + fv_table_str, + project, + f"{fv_column_name}_NAME", + feature_view, + f"{fv_column_name}_PROTO", + ) + + def apply_saved_dataset( + self, + saved_dataset: SavedDataset, + project: str, + commit: bool = True, + ): + return self._apply_object( + "SAVED_DATASETS", + project, + "SAVED_DATASET_NAME", + saved_dataset, + "SAVED_DATASET_PROTO", + ) + + def apply_validation_reference( + self, + validation_reference: ValidationReference, + project: str, + commit: bool = True, + ): + return self._apply_object( + "VALIDATION_REFERENCES", + project, + "VALIDATION_REFERENCE_NAME", + validation_reference, + "VALIDATION_REFERENCE_PROTO", + ) + + def update_infra(self, infra: Infra, project: str, commit: bool = True): + self._apply_object( + "MANAGED_INFRA", + project, + "INFRA_NAME", + infra, + "INFRA_PROTO", + name="infra_obj", + ) + + def _apply_object( + self, + table: str, + project: str, + id_field_name: str, + obj: Any, + proto_field_name: str, + name: Optional[str] = None, + ): + self._maybe_init_project_metadata(project) + + name = name or (obj.name if hasattr(obj, "name") else None) + assert name, f"name needs to be provided for {obj}" + + update_datetime = datetime.utcnow() + if hasattr(obj, "last_updated_timestamp"): + obj.last_updated_timestamp = update_datetime + + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + SELECT + project_id + FROM + {self.registry_path}."{table}" + WHERE + project_id = '{project}' + AND {id_field_name.lower()} = '{name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + proto = hexlify(obj.to_proto().SerializeToString()).__str__()[1:] + query = f""" + UPDATE {self.registry_path}."{table}" + SET + {proto_field_name} = TO_BINARY({proto}), + last_updated_timestamp = CURRENT_TIMESTAMP() + WHERE + {id_field_name.lower()} = '{name}' + """ + execute_snowflake_statement(conn, query) + + else: + obj_proto = obj.to_proto() + + if hasattr(obj_proto, "meta") and hasattr( + obj_proto.meta, "created_timestamp" + ): + obj_proto.meta.created_timestamp.FromDatetime(update_datetime) + + proto = hexlify(obj_proto.SerializeToString()).__str__()[1:] + if table == "FEATURE_VIEWS": + query = f""" + INSERT INTO {self.registry_path}."{table}" + VALUES + ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '', '') + """ + elif "_FEATURE_VIEWS" in table: + query = f""" + INSERT INTO {self.registry_path}."{table}" + VALUES + ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto}), '') + """ + else: + query = f""" + INSERT INTO {self.registry_path}."{table}" + VALUES + ('{name}', '{project}', CURRENT_TIMESTAMP(), TO_BINARY({proto})) + """ + execute_snowflake_statement(conn, query) + + self._set_last_updated_metadata(update_datetime, project) + + # delete operations + def delete_data_source(self, name: str, project: str, commit: bool = True): + return self._delete_object( + "DATA_SOURCES", + name, + project, + "DATA_SOURCE_NAME", + DataSourceObjectNotFoundException, + ) + + def delete_entity(self, name: str, project: str, commit: bool = True): + return self._delete_object( + "ENTITIES", name, project, "ENTITY_NAME", EntityNotFoundException + ) + + def delete_feature_service(self, name: str, project: str, commit: bool = True): + return self._delete_object( + "FEATURE_SERVICES", + name, + project, + "FEATURE_SERVICE_NAME", + FeatureServiceNotFoundException, + ) + + # can you have featureviews with the same name + def delete_feature_view(self, name: str, project: str, commit: bool = True): + deleted_count = 0 + for table in { + "FEATURE_VIEWS", + "REQUEST_FEATURE_VIEWS", + "ON_DEMAND_FEATURE_VIEWS", + "STREAM_FEATURE_VIEWS", + }: + deleted_count += self._delete_object( + table, name, project, "FEATURE_VIEW_NAME", None + ) + if deleted_count == 0: + raise FeatureViewNotFoundException(name, project) + + def delete_saved_dataset(self, name: str, project: str, allow_cache: bool = False): + self._delete_object( + "SAVED_DATASETS", + name, + project, + "SAVED_DATASET_NAME", + SavedDatasetNotFound, + ) + + def delete_validation_reference(self, name: str, project: str, commit: bool = True): + self._delete_object( + "VALIDATION_REFERENCES", + name, + project, + "VALIDATION_REFERENCE_NAME", + ValidationReferenceNotFound, + ) + + def _delete_object( + self, + table: str, + name: str, + project: str, + id_field_name: str, + not_found_exception: Optional[Callable], + ): + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + DELETE FROM {self.registry_path}."{table}" + WHERE + project_id = '{project}' + AND {id_field_name.lower()} = '{name}' + """ + cursor = execute_snowflake_statement(conn, query) + + if cursor.rowcount < 1 and not_found_exception: + raise not_found_exception(name, project) + self._set_last_updated_metadata(datetime.utcnow(), project) + + return cursor.rowcount + + # get operations + def get_data_source( + self, name: str, project: str, allow_cache: bool = False + ) -> DataSource: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_data_source( + self.cached_registry_proto, name, project + ) + return self._get_object( + "DATA_SOURCES", + name, + project, + DataSourceProto, + DataSource, + "DATA_SOURCE_NAME", + "DATA_SOURCE_PROTO", + DataSourceObjectNotFoundException, + ) + + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_entity( + self.cached_registry_proto, name, project + ) + return self._get_object( + "ENTITIES", + name, + project, + EntityProto, + Entity, + "ENTITY_NAME", + "ENTITY_PROTO", + EntityNotFoundException, + ) + + def get_feature_service( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureService: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_feature_service( + self.cached_registry_proto, name, project + ) + return self._get_object( + "FEATURE_SERVICES", + name, + project, + FeatureServiceProto, + FeatureService, + "FEATURE_SERVICE_NAME", + "FEATURE_SERVICE_PROTO", + FeatureServiceNotFoundException, + ) + + def get_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_object( + "FEATURE_VIEWS", + name, + project, + FeatureViewProto, + FeatureView, + "FEATURE_VIEW_NAME", + "FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + + def get_infra(self, project: str, allow_cache: bool = False) -> Infra: + infra_object = self._get_object( + "MANAGED_INFRA", + "infra_obj", + project, + InfraProto, + Infra, + "INFRA_NAME", + "INFRA_PROTO", + None, + ) + infra_object = infra_object or InfraProto() + return Infra.from_proto(infra_object) + + def get_on_demand_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> OnDemandFeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_on_demand_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_object( + "ON_DEMAND_FEATURE_VIEWS", + name, + project, + OnDemandFeatureViewProto, + OnDemandFeatureView, + "ON_DEMAND_FEATURE_VIEW_NAME", + "ON_DEMAND_FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + + def get_request_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> RequestFeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_request_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_object( + "REQUEST_FEATURE_VIEWS", + name, + project, + RequestFeatureViewProto, + RequestFeatureView, + "REQUEST_FEATURE_VIEW_NAME", + "REQUEST_FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + + def get_saved_dataset( + self, name: str, project: str, allow_cache: bool = False + ) -> SavedDataset: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_saved_dataset( + self.cached_registry_proto, name, project + ) + return self._get_object( + "SAVED_DATASETS", + name, + project, + SavedDatasetProto, + SavedDataset, + "SAVED_DATASET_NAME", + "SAVED_DATASET_PROTO", + SavedDatasetNotFound, + ) + + def get_stream_feature_view( + self, name: str, project: str, allow_cache: bool = False + ): + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_stream_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_object( + "STREAM_FEATURE_VIEWS", + name, + project, + StreamFeatureViewProto, + StreamFeatureView, + "STREAM_FEATURE_VIEW_NAME", + "STREAM_FEATURE_VIEW_PROTO", + FeatureViewNotFoundException, + ) + + def get_validation_reference( + self, name: str, project: str, allow_cache: bool = False + ) -> ValidationReference: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_validation_reference( + self.cached_registry_proto, name, project + ) + return self._get_object( + "VALIDATION_REFERENCES", + name, + project, + ValidationReferenceProto, + ValidationReference, + "VALIDATION_REFERENCE_NAME", + "VALIDATION_REFERENCE_PROTO", + ValidationReferenceNotFound, + ) + + def _get_object( + self, + table: str, + name: str, + project: str, + proto_class: Any, + python_class: Any, + id_field_name: str, + proto_field_name: str, + not_found_exception: Optional[Callable], + ): + self._maybe_init_project_metadata(project) + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + SELECT + {proto_field_name} + FROM + {self.registry_path}."{table}" + WHERE + project_id = '{project}' + AND {id_field_name.lower()} = '{name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + _proto = proto_class.FromString(df.squeeze()) + return python_class.from_proto(_proto) + elif not_found_exception: + raise not_found_exception(name, project) + else: + return None + + # list operations + def list_data_sources( + self, project: str, allow_cache: bool = False + ) -> List[DataSource]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_data_sources( + self.cached_registry_proto, project + ) + return self._list_objects( + "DATA_SOURCES", project, DataSourceProto, DataSource, "DATA_SOURCE_PROTO" + ) + + def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_entities( + self.cached_registry_proto, project + ) + return self._list_objects( + "ENTITIES", project, EntityProto, Entity, "ENTITY_PROTO" + ) + + def list_feature_services( + self, project: str, allow_cache: bool = False + ) -> List[FeatureService]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_feature_services( + self.cached_registry_proto, project + ) + return self._list_objects( + "FEATURE_SERVICES", + project, + FeatureServiceProto, + FeatureService, + "FEATURE_SERVICE_PROTO", + ) + + def list_feature_views( + self, project: str, allow_cache: bool = False + ) -> List[FeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_feature_views( + self.cached_registry_proto, project + ) + return self._list_objects( + "FEATURE_VIEWS", + project, + FeatureViewProto, + FeatureView, + "FEATURE_VIEW_PROTO", + ) + + def list_on_demand_feature_views( + self, project: str, allow_cache: bool = False + ) -> List[OnDemandFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_on_demand_feature_views( + self.cached_registry_proto, project + ) + return self._list_objects( + "ON_DEMAND_FEATURE_VIEWS", + project, + OnDemandFeatureViewProto, + OnDemandFeatureView, + "ON_DEMAND_FEATURE_VIEW_PROTO", + ) + + def list_request_feature_views( + self, project: str, allow_cache: bool = False + ) -> List[RequestFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_request_feature_views( + self.cached_registry_proto, project + ) + return self._list_objects( + "REQUEST_FEATURE_VIEWS", + project, + RequestFeatureViewProto, + RequestFeatureView, + "REQUEST_FEATURE_VIEW_PROTO", + ) + + def list_saved_datasets( + self, project: str, allow_cache: bool = False + ) -> List[SavedDataset]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_saved_datasets( + self.cached_registry_proto, project + ) + return self._list_objects( + "SAVED_DATASETS", + project, + SavedDatasetProto, + SavedDataset, + "SAVED_DATASET_PROTO", + ) + + def list_stream_feature_views( + self, project: str, allow_cache: bool = False + ) -> List[StreamFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_stream_feature_views( + self.cached_registry_proto, project + ) + return self._list_objects( + "STREAM_FEATURE_VIEWS", + project, + StreamFeatureViewProto, + StreamFeatureView, + "STREAM_FEATURE_VIEW_PROTO", + ) + + def list_validation_references( + self, project: str, allow_cache: bool = False + ) -> List[ValidationReference]: + return self._list_objects( + "VALIDATION_REFERENCES", + project, + ValidationReferenceProto, + ValidationReference, + "VALIDATION_REFERENCE_PROTO", + ) + + def _list_objects( + self, + table: str, + project: str, + proto_class: Any, + python_class: Any, + proto_field_name: str, + ): + self._maybe_init_project_metadata(project) + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + SELECT + {proto_field_name} + FROM + {self.registry_path}."{table}" + WHERE + project_id = '{project}' + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + return [ + python_class.from_proto( + proto_class.FromString(row[1][proto_field_name]) + ) + for row in df.iterrows() + ] + return [] + + def apply_materialization( + self, + feature_view: FeatureView, + project: str, + start_date: datetime, + end_date: datetime, + commit: bool = True, + ): + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1] + python_class, proto_class = self._infer_fv_classes(feature_view) + + if python_class in {RequestFeatureView, OnDemandFeatureView}: + raise ValueError( + f"Cannot apply materialization for feature {feature_view.name} of type {python_class}" + ) + fv: Union[FeatureView, StreamFeatureView] = self._get_object( + fv_table_str, + feature_view.name, + project, + proto_class, + python_class, + f"{fv_column_name}_NAME", + f"{fv_column_name}_PROTO", + FeatureViewNotFoundException, + ) + fv.materialization_intervals.append((start_date, end_date)) + self._apply_object( + fv_table_str, + project, + f"{fv_column_name}_NAME", + fv, + f"{fv_column_name}_PROTO", + ) + + def list_project_metadata( + self, project: str, allow_cache: bool = False + ) -> List[ProjectMetadata]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_project_metadata( + self.cached_registry_proto, project + ) + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + SELECT + metadata_key, + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + project_metadata = ProjectMetadata(project_name=project) + for row in df.iterrows(): + if row[1]["METADATA_KEY"] == FeastMetadataKeys.PROJECT_UUID.value: + project_metadata.project_uuid = row[1]["METADATA_VALUE"] + break + # TODO(adchia): Add other project metadata in a structured way + return [project_metadata] + return [] + + def apply_user_metadata( + self, + project: str, + feature_view: BaseFeatureView, + metadata_bytes: Optional[bytes], + ): + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1].lower() + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + SELECT + project_id + FROM + {self.registry_path}."{fv_table_str}" + WHERE + project_id = '{project}' + AND {fv_column_name}_name = '{feature_view.name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + if metadata_bytes: + metadata_hex = hexlify(metadata_bytes).__str__()[1:] + else: + metadata_hex = "''" + query = f""" + UPDATE {self.registry_path}."{fv_table_str}" + SET + user_metadata = TO_BINARY({metadata_hex}), + last_updated_timestamp = CURRENT_TIMESTAMP() + WHERE + project_id = '{project}' + AND {fv_column_name}_name = '{feature_view.name}' + """ + execute_snowflake_statement(conn, query) + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def get_user_metadata( + self, project: str, feature_view: BaseFeatureView + ) -> Optional[bytes]: + fv_table_str = self._infer_fv_table(feature_view) + fv_column_name = fv_table_str[:-1].lower() + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + SELECT + user_metadata + FROM + {self.registry_path}."{fv_table_str}" + WHERE + {fv_column_name}_name = '{feature_view.name}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + return df.squeeze() + else: + raise FeatureViewNotFoundException(feature_view.name, project=project) + + def proto(self) -> RegistryProto: + r = RegistryProto() + last_updated_timestamps = [] + projects = self._get_all_projects() + for project in projects: + for lister, registry_proto_field in [ + (self.list_entities, r.entities), + (self.list_feature_views, r.feature_views), + (self.list_data_sources, r.data_sources), + (self.list_on_demand_feature_views, r.on_demand_feature_views), + (self.list_request_feature_views, r.request_feature_views), + (self.list_stream_feature_views, r.stream_feature_views), + (self.list_feature_services, r.feature_services), + (self.list_saved_datasets, r.saved_datasets), + (self.list_validation_references, r.validation_references), + (self.list_project_metadata, r.project_metadata), + ]: + objs: List[Any] = lister(project) # type: ignore + if objs: + obj_protos = [obj.to_proto() for obj in objs] + for obj_proto in obj_protos: + if "spec" in obj_proto.DESCRIPTOR.fields_by_name: + obj_proto.spec.project = project + else: + obj_proto.project = project + registry_proto_field.extend(obj_protos) + + # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, + # the registry proto only has a single infra field, which we're currently setting as the "last" project. + r.infra.CopyFrom(self.get_infra(project).to_proto()) + last_updated_timestamps.append(self._get_last_updated_metadata(project)) + + if last_updated_timestamps: + r.last_updated.FromDatetime(max(last_updated_timestamps)) + + return r + + def _get_all_projects(self) -> Set[str]: + projects = set() + + base_tables = [ + "DATA_SOURCES", + "ENTITIES", + "FEATURE_VIEWS", + "ON_DEMAND_FEATURE_VIEWS", + "REQUEST_FEATURE_VIEWS", + "STREAM_FEATURE_VIEWS", + ] + + with get_snowflake_conn(self.registry_config) as conn: + for table in base_tables: + query = ( + f'SELECT DISTINCT project_id FROM {self.registry_path}."{table}"' + ) + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + for row in df.iterrows(): + projects.add(row[1]["PROJECT_ID"]) + + return projects + + def _get_last_updated_metadata(self, project: str): + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + SELECT + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if df.empty: + return None + + return datetime.utcfromtimestamp(int(df.squeeze())) + + def _infer_fv_classes(self, feature_view): + if isinstance(feature_view, StreamFeatureView): + python_class, proto_class = StreamFeatureView, StreamFeatureViewProto + elif isinstance(feature_view, FeatureView): + python_class, proto_class = FeatureView, FeatureViewProto + elif isinstance(feature_view, OnDemandFeatureView): + python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto + elif isinstance(feature_view, RequestFeatureView): + python_class, proto_class = RequestFeatureView, RequestFeatureViewProto + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return python_class, proto_class + + def _infer_fv_table(self, feature_view) -> str: + if isinstance(feature_view, StreamFeatureView): + table = "STREAM_FEATURE_VIEWS" + elif isinstance(feature_view, FeatureView): + table = "FEATURE_VIEWS" + elif isinstance(feature_view, OnDemandFeatureView): + table = "ON_DEMAND_FEATURE_VIEWS" + elif isinstance(feature_view, RequestFeatureView): + table = "REQUEST_FEATURE_VIEWS" + else: + raise ValueError(f"Unexpected feature view type: {type(feature_view)}") + return table + + def _maybe_init_project_metadata(self, project): + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + SELECT + metadata_value + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.PROJECT_UUID.value}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + if not df.empty: + usage.set_current_project_uuid(df.squeeze()) + else: + new_project_uuid = f"{uuid.uuid4()}" + query = f""" + INSERT INTO {self.registry_path}."FEAST_METADATA" + VALUES + ('{project}', '{FeastMetadataKeys.PROJECT_UUID.value}', '{new_project_uuid}', CURRENT_TIMESTAMP()) + """ + execute_snowflake_statement(conn, query) + + usage.set_current_project_uuid(new_project_uuid) + + def _set_last_updated_metadata(self, last_updated: datetime, project: str): + with get_snowflake_conn(self.registry_config) as conn: + query = f""" + SELECT + project_id + FROM + {self.registry_path}."FEAST_METADATA" + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' + LIMIT 1 + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + update_time = int(last_updated.timestamp()) + if not df.empty: + query = f""" + UPDATE {self.registry_path}."FEAST_METADATA" + SET + project_id = '{project}', + metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', + metadata_value = '{update_time}', + last_updated_timestamp = CURRENT_TIMESTAMP() + WHERE + project_id = '{project}' + AND metadata_key = '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}' + """ + execute_snowflake_statement(conn, query) + + else: + query = f""" + INSERT INTO {self.registry_path}."FEAST_METADATA" + VALUES + ('{project}', '{FeastMetadataKeys.LAST_UPDATED_TIMESTAMP.value}', '{update_time}', CURRENT_TIMESTAMP()) + """ + execute_snowflake_statement(conn, query) + + def commit(self): + pass diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index de21e3c056..628b6d1e65 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -429,7 +429,7 @@ def list_validation_references( if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_validation_references( - self.cached_registry_proto + self.cached_registry_proto, project ) return self._list_objects( table=validation_references, diff --git a/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_creation.sql b/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_creation.sql new file mode 100644 index 0000000000..4b53d6bb3f --- /dev/null +++ b/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_creation.sql @@ -0,0 +1,92 @@ +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."DATA_SOURCES" ( + data_source_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + data_source_proto BINARY NOT NULL, + PRIMARY KEY (data_source_name, project_id) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."ENTITIES" ( + entity_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + entity_proto BINARY NOT NULL, + PRIMARY KEY (entity_name, project_id) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."FEAST_METADATA" ( + project_id VARCHAR, + metadata_key VARCHAR, + metadata_value VARCHAR NOT NULL, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + PRIMARY KEY (project_id, metadata_key) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."FEATURE_SERVICES" ( + feature_service_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + feature_service_proto BINARY NOT NULL, + PRIMARY KEY (feature_service_name, project_id) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."FEATURE_VIEWS" ( + feature_view_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + feature_view_proto BINARY NOT NULL, + materialized_intervals BINARY, + user_metadata BINARY, + PRIMARY KEY (feature_view_name, project_id) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."MANAGED_INFRA" ( + infra_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + infra_proto BINARY NOT NULL, + PRIMARY KEY (infra_name, project_id) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."ON_DEMAND_FEATURE_VIEWS" ( + on_demand_feature_view_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + on_demand_feature_view_proto BINARY NOT NULL, + user_metadata BINARY, + PRIMARY KEY (on_demand_feature_view_name, project_id) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."REQUEST_FEATURE_VIEWS" ( + request_feature_view_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + request_feature_view_proto BINARY NOT NULL, + user_metadata BINARY, + PRIMARY KEY (request_feature_view_name, project_id) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."SAVED_DATASETS" ( + saved_dataset_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + saved_dataset_proto BINARY NOT NULL, + PRIMARY KEY (saved_dataset_name, project_id) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."STREAM_FEATURE_VIEWS" ( + stream_feature_view_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + stream_feature_view_proto BINARY NOT NULL, + user_metadata BINARY, + PRIMARY KEY (stream_feature_view_name, project_id) +); + +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."VALIDATION_REFERENCES" ( + validation_reference_name VARCHAR, + project_id VARCHAR, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + validation_reference_proto BINARY NOT NULL, + PRIMARY KEY (validation_reference_name, project_id) +) diff --git a/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_deletion.sql b/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_deletion.sql new file mode 100644 index 0000000000..7f5c1991ea --- /dev/null +++ b/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_deletion.sql @@ -0,0 +1,21 @@ +DROP TABLE IF EXISTS REGISTRY_PATH."DATA_SOURCES"; + +DROP TABLE IF EXISTS REGISTRY_PATH."ENTITIES"; + +DROP TABLE IF EXISTS REGISTRY_PATH."FEAST_METADATA"; + +DROP TABLE IF EXISTS REGISTRY_PATH."FEATURE_SERVICES"; + +DROP TABLE IF EXISTS REGISTRY_PATH."FEATURE_VIEWS"; + +DROP TABLE IF EXISTS REGISTRY_PATH."MANAGED_INFRA"; + +DROP TABLE IF EXISTS REGISTRY_PATH."ON_DEMAND_FEATURE_VIEWS"; + +DROP TABLE IF EXISTS REGISTRY_PATH."REQUEST_FEATURE_VIEWS"; + +DROP TABLE IF EXISTS REGISTRY_PATH."SAVED_DATASETS"; + +DROP TABLE IF EXISTS REGISTRY_PATH."STREAM_FEATURE_VIEWS"; + +DROP TABLE IF EXISTS REGISTRY_PATH."VALIDATION_REFERENCES" diff --git a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py index a5d2b05d45..8023980eac 100644 --- a/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py +++ b/sdk/python/feast/infra/utils/snowflake/snowflake_utils.py @@ -58,9 +58,16 @@ def execute_snowflake_statement(conn: SnowflakeConnection, query) -> SnowflakeCu def get_snowflake_conn(config, autocommit=True) -> SnowflakeConnection: - assert config.type in ["snowflake.offline", "snowflake.engine", "snowflake.online"] + assert config.type in [ + "snowflake.registry", + "snowflake.offline", + "snowflake.engine", + "snowflake.online", + ] - if config.type == "snowflake.offline": + if config.type == "snowflake.registry": + config_header = "connections.feast_registry" + elif config.type == "snowflake.offline": config_header = "connections.feast_offline_store" if config.type == "snowflake.engine": config_header = "connections.feast_batch_engine" diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 28847294b3..5430d557bb 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -15,7 +15,7 @@ validator, ) from pydantic.error_wrappers import ErrorWrapper -from pydantic.typing import Dict, Optional, Union +from pydantic.typing import Dict, Optional from feast.errors import ( FeastFeatureServerTypeInvalidError, @@ -23,6 +23,8 @@ FeastOfflineStoreInvalidName, FeastOnlineStoreInvalidName, FeastProviderNotSetError, + FeastRegistryNotSetError, + FeastRegistryTypeInvalidError, ) from feast.importer import import_class from feast.usage import log_exceptions @@ -34,6 +36,12 @@ # These dict exists so that: # - existing values for the online store type in featurestore.yaml files continue to work in a backwards compatible way # - first party and third party implementations can use the same class loading code path. +REGISTRY_CLASS_FOR_TYPE = { + "file": "feast.infra.registry.registry.Registry", + "sql": "feast.infra.registry.sql.SqlRegistry", + "snowflake.registry": "feast.infra.registry.snowflake.SnowflakeRegistry", +} + BATCH_ENGINE_CLASS_FOR_TYPE = { "local": "feast.infra.materialization.local_engine.LocalMaterializationEngine", "snowflake.engine": "feast.infra.materialization.snowflake_engine.SnowflakeMaterializationEngine", @@ -101,14 +109,15 @@ class RegistryConfig(FeastBaseModel): """Metadata Store Configuration. Configuration that relates to reading from and writing to the Feast registry.""" registry_type: StrictStr = "file" - """ str: Provider name or a class name that implements RegistryStore. - If specified, registry_store_type should be redundant.""" + """ str: Provider name or a class name that implements Registry.""" registry_store_type: Optional[StrictStr] """ str: Provider name or a class name that implements RegistryStore. """ - path: StrictStr - """ str: Path to metadata store. Can be a local path, or remote object storage path, e.g. a GCS URI """ + path: StrictStr = "" + """ str: Path to metadata store. + If registry_type is 'file', then an be a local path, or remote object storage path, e.g. a GCS URI + If registry_type is 'sql', then this is a database URL as expected by SQLAlchemy """ cache_ttl_seconds: StrictInt = 600 """int: The cache TTL is the amount of time registry state will be cached in memory. If this TTL is exceeded then @@ -123,9 +132,6 @@ class RegistryConfig(FeastBaseModel): class RepoConfig(FeastBaseModel): """Repo config. Typically loaded from `feature_store.yaml`""" - registry: Union[StrictStr, RegistryConfig] = "data/registry.db" - """ str: Path to metadata store. Can be a local path, or remote object storage path, e.g. a GCS URI """ - project: StrictStr """ str: Feast project id. This can be any alphanumeric string up to 16 characters. You can have multiple independent feature repositories deployed to the same cloud @@ -135,6 +141,14 @@ class RepoConfig(FeastBaseModel): provider: StrictStr """ str: local or gcp or aws """ + _registry_config: Any = Field(alias="registry", default="data/registry.db") + """ Configures the registry. + Can be: + 1. str: a path to a file based registry (a local path, or remote object storage path, e.g. a GCS URI) + 2. RegistryConfig: A fully specified file based registry or SQL based registry + 3. SnowflakeRegistryConfig: Using a Snowflake table to store the registry + """ + _online_config: Any = Field(alias="online_store") """ OnlineStoreConfig: Online store configuration (optional depending on provider) """ @@ -175,6 +189,11 @@ class RepoConfig(FeastBaseModel): def __init__(self, **data: Any): super().__init__(**data) + self._registry = None + if "registry" not in data: + raise FeastRegistryNotSetError() + self._registry_config = data["registry"] + self._offline_store = None if "offline_store" in data: self._offline_config = data["offline_store"] @@ -223,11 +242,25 @@ def __init__(self, **data: Any): RuntimeWarning, ) - def get_registry_config(self): - if isinstance(self.registry, str): - return RegistryConfig(path=self.registry) - else: - return self.registry + @property + def registry(self): + if not self._registry: + if isinstance(self._registry_config, Dict): + if "registry_type" in self._registry_config: + self._registry = get_registry_config_from_type( + self._registry_config["registry_type"] + )(**self._registry_config) + else: + # This may be a custom registry store, which does not need a 'registry_type' + self._registry = RegistryConfig(**self._registry_config) + elif isinstance(self._registry_config, str): + # User passed in just a path to file registry + self._registry = get_registry_config_from_type("file")( + path=self._registry_config + ) + elif self._registry_config: + self._registry = self._registry_config + return self._registry @property def offline_store(self): @@ -457,6 +490,16 @@ def get_data_source_class_from_type(data_source_type: str): return import_class(module_name, config_class_name, "DataSource") +def get_registry_config_from_type(registry_type: str): + # We do not support custom registry's right now + if registry_type not in REGISTRY_CLASS_FOR_TYPE: + raise FeastRegistryTypeInvalidError(registry_type) + registry_type = REGISTRY_CLASS_FOR_TYPE[registry_type] + module_name, registry_class_type = registry_type.rsplit(".", 1) + config_class_name = f"{registry_class_type}Config" + return import_class(module_name, config_class_name, config_class_name) + + def get_batch_engine_config_from_type(batch_engine_type: str): if batch_engine_type in BATCH_ENGINE_CLASS_FOR_TYPE: batch_engine_type = BATCH_ENGINE_CLASS_FOR_TYPE[batch_engine_type] diff --git a/sdk/python/feast/repo_operations.py b/sdk/python/feast/repo_operations.py index 03162e7507..a66edc86cd 100644 --- a/sdk/python/feast/repo_operations.py +++ b/sdk/python/feast/repo_operations.py @@ -347,7 +347,7 @@ def teardown(repo_config: RepoConfig, repo_path: Path): @log_exceptions_and_usage def registry_dump(repo_config: RepoConfig, repo_path: Path) -> str: """For debugging only: output contents of the metadata registry""" - registry_config = repo_config.get_registry_config() + registry_config = repo_config.registry project = repo_config.project registry = Registry(project, registry_config=registry_config, repo_path=repo_path) registry_dict = registry.to_dict(project=project) diff --git a/sdk/python/tests/integration/registration/test_inference.py b/sdk/python/tests/integration/registration/test_inference.py index 17bb09933e..9f490d7f4e 100644 --- a/sdk/python/tests/integration/registration/test_inference.py +++ b/sdk/python/tests/integration/registration/test_inference.py @@ -20,7 +20,10 @@ def test_update_file_data_source_with_inferred_event_timestamp_col(simple_datase update_data_sources_with_inferred_event_timestamp_col( data_sources, RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + registry="test.pb", + entity_key_serialization_version=2, ), ) actual_event_timestamp_cols = [ @@ -35,7 +38,10 @@ def test_update_file_data_source_with_inferred_event_timestamp_col(simple_datase update_data_sources_with_inferred_event_timestamp_col( [file_source], RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + registry="test.pb", + entity_key_serialization_version=2, ), ) @@ -53,7 +59,10 @@ def test_update_data_sources_with_inferred_event_timestamp_col(universal_data_so update_data_sources_with_inferred_event_timestamp_col( data_sources_copy.values(), RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + registry="test.pb", + entity_key_serialization_version=2, ), ) actual_event_timestamp_cols = [ diff --git a/sdk/python/tests/unit/cli/test_cli.py b/sdk/python/tests/unit/cli/test_cli.py index 25a1dfed34..d15e1d1616 100644 --- a/sdk/python/tests/unit/cli/test_cli.py +++ b/sdk/python/tests/unit/cli/test_cli.py @@ -122,6 +122,7 @@ def setup_third_party_provider_repo(provider_name: str): type: sqlite offline_store: type: file + entity_key_serialization_version: 2 """ ) ) @@ -159,6 +160,7 @@ def setup_third_party_registry_store_repo( type: sqlite offline_store: type: file + entity_key_serialization_version: 2 """ ) ) diff --git a/sdk/python/tests/unit/infra/test_inference_unit_tests.py b/sdk/python/tests/unit/infra/test_inference_unit_tests.py index 46a131e1b5..a108d397bd 100644 --- a/sdk/python/tests/unit/infra/test_inference_unit_tests.py +++ b/sdk/python/tests/unit/infra/test_inference_unit_tests.py @@ -194,7 +194,10 @@ def test_feature_view_inference_respects_basic_inference(): [feature_view_1], [entity1], RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", ), ) assert len(feature_view_1.schema) == 2 @@ -209,7 +212,10 @@ def test_feature_view_inference_respects_basic_inference(): [feature_view_2], [entity1, entity2], RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", ), ) assert len(feature_view_2.schema) == 3 @@ -240,7 +246,10 @@ def test_feature_view_inference_on_entity_value_types(): [feature_view_1], [entity1], RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", ), ) @@ -310,7 +319,10 @@ def test_feature_view_inference_on_entity_columns(simple_dataset_1): [feature_view_1], [entity1], RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", ), ) @@ -345,7 +357,10 @@ def test_feature_view_inference_on_feature_columns(simple_dataset_1): [feature_view_1], [entity1], RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", ), ) @@ -397,7 +412,10 @@ def test_update_feature_services_with_inferred_features(simple_dataset_1): [feature_view_1, feature_view_2], [entity1], RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", ), ) feature_service.infer_features( @@ -454,7 +472,10 @@ def test_update_feature_services_with_specified_features(simple_dataset_1): [feature_view_1, feature_view_2], [entity1], RepoConfig( - provider="local", project="test", entity_key_serialization_version=2 + provider="local", + project="test", + entity_key_serialization_version=2, + registry="dummy_registry.pb", ), ) assert len(feature_view_1.features) == 1 diff --git a/sdk/python/tests/unit/online_store/test_online_retrieval.py b/sdk/python/tests/unit/online_store/test_online_retrieval.py index 6f96e7b5d9..926c7226fc 100644 --- a/sdk/python/tests/unit/online_store/test_online_retrieval.py +++ b/sdk/python/tests/unit/online_store/test_online_retrieval.py @@ -137,7 +137,7 @@ def test_online() -> None: fs_fast_ttl = FeatureStore( config=RepoConfig( registry=RegistryConfig( - path=store.config.registry, cache_ttl_seconds=cache_ttl + path=store.config.registry.path, cache_ttl_seconds=cache_ttl ), online_store=store.config.online_store, project=store.project, @@ -161,7 +161,7 @@ def test_online() -> None: assert result["trips"] == [7] # Rename the registry.db so that it cant be used for refreshes - os.rename(store.config.registry, store.config.registry + "_fake") + os.rename(store.config.registry.path, store.config.registry.path + "_fake") # Wait for registry to expire time.sleep(cache_ttl) @@ -180,7 +180,7 @@ def test_online() -> None: ).to_dict() # Restore registry.db so that we can see if it actually reloads registry - os.rename(store.config.registry + "_fake", store.config.registry) + os.rename(store.config.registry.path + "_fake", store.config.registry.path) # Test if registry is actually reloaded and whether results return result = fs_fast_ttl.get_online_features( @@ -200,7 +200,7 @@ def test_online() -> None: fs_infinite_ttl = FeatureStore( config=RepoConfig( registry=RegistryConfig( - path=store.config.registry, cache_ttl_seconds=0 + path=store.config.registry.path, cache_ttl_seconds=0 ), online_store=store.config.online_store, project=store.project, @@ -227,7 +227,7 @@ def test_online() -> None: time.sleep(2) # Rename the registry.db so that it cant be used for refreshes - os.rename(store.config.registry, store.config.registry + "_fake") + os.rename(store.config.registry.path, store.config.registry.path + "_fake") # TTL is infinite so this method should use registry cache result = fs_infinite_ttl.get_online_features( @@ -248,7 +248,7 @@ def test_online() -> None: fs_infinite_ttl.refresh_registry() # Restore registry.db so that teardown works - os.rename(store.config.registry + "_fake", store.config.registry) + os.rename(store.config.registry.path + "_fake", store.config.registry.path) def test_online_to_df():