diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py new file mode 100644 index 0000000000..4c408b0a46 --- /dev/null +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -0,0 +1,342 @@ +import logging +from abc import abstractmethod +from datetime import datetime, timedelta +from threading import Lock +from typing import List, Optional + +from feast import usage +from feast.data_source import DataSource +from feast.entity import Entity +from feast.feature_service import FeatureService +from feast.feature_view import FeatureView +from feast.infra.infra_object import Infra +from feast.infra.registry import proto_registry_utils +from feast.infra.registry.base_registry import BaseRegistry +from feast.on_demand_feature_view import OnDemandFeatureView +from feast.project_metadata import ProjectMetadata +from feast.request_feature_view import RequestFeatureView +from feast.saved_dataset import SavedDataset, ValidationReference +from feast.stream_feature_view import StreamFeatureView + +logger = logging.getLogger(__name__) + + +class CachingRegistry(BaseRegistry): + def __init__( + self, + project: str, + cache_ttl_seconds: int, + ): + self.cached_registry_proto = self.proto() + proto_registry_utils.init_project_metadata(self.cached_registry_proto, project) + self.cached_registry_proto_created = datetime.utcnow() + self._refresh_lock = Lock() + self.cached_registry_proto_ttl = timedelta( + seconds=cache_ttl_seconds if cache_ttl_seconds is not None else 0 + ) + + @abstractmethod + def _get_data_source(self, name: str, project: str) -> DataSource: + pass + + def get_data_source( + self, name: str, project: str, allow_cache: bool = False + ) -> DataSource: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_data_source( + self.cached_registry_proto, name, project + ) + return self._get_data_source(name, project) + + @abstractmethod + def _list_data_sources(self, project: str) -> List[DataSource]: + pass + + def list_data_sources( + self, project: str, allow_cache: bool = False + ) -> List[DataSource]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_data_sources( + self.cached_registry_proto, project + ) + return self._list_data_sources(project) + + @abstractmethod + def _get_entity(self, name: str, project: str) -> Entity: + pass + + def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_entity( + self.cached_registry_proto, name, project + ) + return self._get_entity(name, project) + + @abstractmethod + def _list_entities(self, project: str) -> List[Entity]: + pass + + def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_entities( + self.cached_registry_proto, project + ) + return self._list_entities(project) + + @abstractmethod + def _get_feature_view(self, name: str, project: str) -> FeatureView: + pass + + def get_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_feature_view(name, project) + + @abstractmethod + def _list_feature_views(self, project: str) -> List[FeatureView]: + pass + + def list_feature_views( + self, project: str, allow_cache: bool = False + ) -> List[FeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_feature_views( + self.cached_registry_proto, project + ) + return self._list_feature_views(project) + + @abstractmethod + def _get_on_demand_feature_view( + self, name: str, project: str + ) -> OnDemandFeatureView: + pass + + def get_on_demand_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> OnDemandFeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_on_demand_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_on_demand_feature_view(name, project) + + @abstractmethod + def _list_on_demand_feature_views(self, project: str) -> List[OnDemandFeatureView]: + pass + + def list_on_demand_feature_views( + self, project: str, allow_cache: bool = False + ) -> List[OnDemandFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_on_demand_feature_views( + self.cached_registry_proto, project + ) + return self._list_on_demand_feature_views(project) + + @abstractmethod + def _get_request_feature_view(self, name: str, project: str) -> RequestFeatureView: + pass + + def get_request_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> RequestFeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_request_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_request_feature_view(name, project) + + @abstractmethod + def _list_request_feature_views(self, project: str) -> List[RequestFeatureView]: + pass + + def list_request_feature_views( + self, project: str, allow_cache: bool = False + ) -> List[RequestFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_request_feature_views( + self.cached_registry_proto, project + ) + return self._list_request_feature_views(project) + + @abstractmethod + def _get_stream_feature_view(self, name: str, project: str) -> StreamFeatureView: + pass + + def get_stream_feature_view( + self, name: str, project: str, allow_cache: bool = False + ) -> StreamFeatureView: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_stream_feature_view( + self.cached_registry_proto, name, project + ) + return self._get_stream_feature_view(name, project) + + @abstractmethod + def _list_stream_feature_views(self, project: str) -> List[StreamFeatureView]: + pass + + def list_stream_feature_views( + self, project: str, allow_cache: bool = False + ) -> List[StreamFeatureView]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_stream_feature_views( + self.cached_registry_proto, project + ) + return self._list_stream_feature_views(project) + + @abstractmethod + def _get_feature_service(self, name: str, project: str) -> FeatureService: + pass + + def get_feature_service( + self, name: str, project: str, allow_cache: bool = False + ) -> FeatureService: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_feature_service( + self.cached_registry_proto, name, project + ) + return self._get_feature_service(name, project) + + @abstractmethod + def _list_feature_services(self, project: str) -> List[FeatureService]: + pass + + def list_feature_services( + self, project: str, allow_cache: bool = False + ) -> List[FeatureService]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_feature_services( + self.cached_registry_proto, project + ) + return self._list_feature_services(project) + + @abstractmethod + def _get_saved_dataset(self, name: str, project: str) -> SavedDataset: + pass + + def get_saved_dataset( + self, name: str, project: str, allow_cache: bool = False + ) -> SavedDataset: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_saved_dataset( + self.cached_registry_proto, name, project + ) + return self._get_saved_dataset(name, project) + + @abstractmethod + def _list_saved_datasets(self, project: str) -> List[SavedDataset]: + pass + + def list_saved_datasets( + self, project: str, allow_cache: bool = False + ) -> List[SavedDataset]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_saved_datasets( + self.cached_registry_proto, project + ) + return self._list_saved_datasets(project) + + @abstractmethod + def _get_validation_reference(self, name: str, project: str) -> ValidationReference: + pass + + def get_validation_reference( + self, name: str, project: str, allow_cache: bool = False + ) -> ValidationReference: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_validation_reference( + self.cached_registry_proto, name, project + ) + return self._get_validation_reference(name, project) + + @abstractmethod + def _list_validation_references(self, project: str) -> List[ValidationReference]: + pass + + def list_validation_references( + self, project: str, allow_cache: bool = False + ) -> List[ValidationReference]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_validation_references( + self.cached_registry_proto, project + ) + return self._list_validation_references(project) + + @abstractmethod + def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: + pass + + def list_project_metadata( + self, project: str, allow_cache: bool = False + ) -> List[ProjectMetadata]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_project_metadata( + self.cached_registry_proto, project + ) + return self._list_project_metadata(project) + + @abstractmethod + def _get_infra(self, project: str) -> Infra: + pass + + def get_infra(self, project: str, allow_cache: bool = False) -> Infra: + return self._get_infra(project) + + def refresh(self, project: Optional[str] = None): + if project: + project_metadata = proto_registry_utils.get_project_metadata( + registry_proto=self.cached_registry_proto, project=project + ) + if project_metadata: + usage.set_current_project_uuid(project_metadata.project_uuid) + else: + proto_registry_utils.init_project_metadata( + self.cached_registry_proto, project + ) + self.cached_registry_proto = self.proto() + self.cached_registry_proto_created = datetime.utcnow() + + def _refresh_cached_registry_if_necessary(self): + with self._refresh_lock: + expired = ( + self.cached_registry_proto is None + or self.cached_registry_proto_created is None + ) or ( + self.cached_registry_proto_ttl.total_seconds() + > 0 # 0 ttl means infinity + and ( + datetime.utcnow() + > ( + self.cached_registry_proto_created + + self.cached_registry_proto_ttl + ) + ) + ) + + if expired: + logger.info("Registry cache expired, so refreshing") + self.refresh() diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index 1e5c2a8725..597c9b8513 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -1,9 +1,8 @@ import logging import uuid -from datetime import datetime, timedelta +from datetime import datetime from enum import Enum from pathlib import Path -from threading import Lock from typing import Any, Callable, Dict, List, Optional, Set, Union from pydantic import StrictStr @@ -37,8 +36,7 @@ from feast.feature_service import FeatureService from feast.feature_view import FeatureView from feast.infra.infra_object import Infra -from feast.infra.registry import proto_registry_utils -from feast.infra.registry.base_registry import BaseRegistry +from feast.infra.registry.caching_registry import CachingRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.project_metadata import ProjectMetadata from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto @@ -194,7 +192,7 @@ class SqlRegistryConfig(RegistryConfig): """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """ -class SqlRegistry(BaseRegistry): +class SqlRegistry(CachingRegistry): def __init__( self, registry_config: Optional[Union[RegistryConfig, SqlRegistryConfig]], @@ -202,20 +200,14 @@ def __init__( repo_path: Optional[Path], ): assert registry_config is not None, "SqlRegistry needs a valid registry_config" + self.engine: Engine = create_engine( registry_config.path, **registry_config.sqlalchemy_config_kwargs ) metadata.create_all(self.engine) - self.cached_registry_proto = self.proto() - proto_registry_utils.init_project_metadata(self.cached_registry_proto, project) - self.cached_registry_proto_created = datetime.utcnow() - self._refresh_lock = Lock() - self.cached_registry_proto_ttl = timedelta( - seconds=registry_config.cache_ttl_seconds - if registry_config.cache_ttl_seconds is not None - else 0 + super().__init__( + project=project, cache_ttl_seconds=registry_config.cache_ttl_seconds ) - self.project = project def teardown(self): for t in { @@ -232,49 +224,7 @@ def teardown(self): stmt = delete(t) conn.execute(stmt) - def refresh(self, project: Optional[str] = None): - if project: - project_metadata = proto_registry_utils.get_project_metadata( - registry_proto=self.cached_registry_proto, project=project - ) - if project_metadata: - usage.set_current_project_uuid(project_metadata.project_uuid) - else: - proto_registry_utils.init_project_metadata( - self.cached_registry_proto, project - ) - self.cached_registry_proto = self.proto() - self.cached_registry_proto_created = datetime.utcnow() - - def _refresh_cached_registry_if_necessary(self): - with self._refresh_lock: - expired = ( - self.cached_registry_proto is None - or self.cached_registry_proto_created is None - ) or ( - self.cached_registry_proto_ttl.total_seconds() - > 0 # 0 ttl means infinity - and ( - datetime.utcnow() - > ( - self.cached_registry_proto_created - + self.cached_registry_proto_ttl - ) - ) - ) - - if expired: - logger.info("Registry cache expired, so refreshing") - self.refresh() - - def get_stream_feature_view( - self, name: str, project: str, allow_cache: bool = False - ): - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_stream_feature_view( - self.cached_registry_proto, name, project - ) + def _get_stream_feature_view(self, name: str, project: str): return self._get_object( table=stream_feature_views, name=name, @@ -286,14 +236,7 @@ def get_stream_feature_view( not_found_exception=FeatureViewNotFoundException, ) - def list_stream_feature_views( - self, project: str, allow_cache: bool = False - ) -> List[StreamFeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_stream_feature_views( - self.cached_registry_proto, project - ) + def _list_stream_feature_views(self, project: str) -> List[StreamFeatureView]: return self._list_objects( stream_feature_views, project, @@ -311,12 +254,7 @@ def apply_entity(self, entity: Entity, project: str, commit: bool = True): proto_field_name="entity_proto", ) - def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Entity: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_entity( - self.cached_registry_proto, name, project - ) + def _get_entity(self, name: str, project: str) -> Entity: return self._get_object( table=entities, name=name, @@ -328,14 +266,7 @@ def get_entity(self, name: str, project: str, allow_cache: bool = False) -> Enti not_found_exception=EntityNotFoundException, ) - def get_feature_view( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureView: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_feature_view( - self.cached_registry_proto, name, project - ) + def _get_feature_view(self, name: str, project: str) -> FeatureView: return self._get_object( table=feature_views, name=name, @@ -347,14 +278,9 @@ def get_feature_view( not_found_exception=FeatureViewNotFoundException, ) - def get_on_demand_feature_view( - self, name: str, project: str, allow_cache: bool = False + def _get_on_demand_feature_view( + self, name: str, project: str ) -> OnDemandFeatureView: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_on_demand_feature_view( - self.cached_registry_proto, name, project - ) return self._get_object( table=on_demand_feature_views, name=name, @@ -366,14 +292,7 @@ def get_on_demand_feature_view( not_found_exception=FeatureViewNotFoundException, ) - def get_request_feature_view( - self, name: str, project: str, allow_cache: bool = False - ): - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_request_feature_view( - self.cached_registry_proto, name, project - ) + def _get_request_feature_view(self, name: str, project: str): return self._get_object( table=request_feature_views, name=name, @@ -385,14 +304,7 @@ def get_request_feature_view( not_found_exception=FeatureViewNotFoundException, ) - def get_feature_service( - self, name: str, project: str, allow_cache: bool = False - ) -> FeatureService: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_feature_service( - self.cached_registry_proto, name, project - ) + def _get_feature_service(self, name: str, project: str) -> FeatureService: return self._get_object( table=feature_services, name=name, @@ -404,14 +316,7 @@ def get_feature_service( not_found_exception=FeatureServiceNotFoundException, ) - def get_saved_dataset( - self, name: str, project: str, allow_cache: bool = False - ) -> SavedDataset: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_saved_dataset( - self.cached_registry_proto, name, project - ) + def _get_saved_dataset(self, name: str, project: str) -> SavedDataset: return self._get_object( table=saved_datasets, name=name, @@ -423,14 +328,7 @@ def get_saved_dataset( not_found_exception=SavedDatasetNotFound, ) - def get_validation_reference( - self, name: str, project: str, allow_cache: bool = False - ) -> ValidationReference: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_validation_reference( - self.cached_registry_proto, name, project - ) + def _get_validation_reference(self, name: str, project: str) -> ValidationReference: return self._get_object( table=validation_references, name=name, @@ -442,14 +340,7 @@ def get_validation_reference( not_found_exception=ValidationReferenceNotFound, ) - def list_validation_references( - self, project: str, allow_cache: bool = False - ) -> List[ValidationReference]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_validation_references( - self.cached_registry_proto, project - ) + def _list_validation_references(self, project: str) -> List[ValidationReference]: return self._list_objects( table=validation_references, project=project, @@ -458,12 +349,7 @@ def list_validation_references( proto_field_name="validation_reference_proto", ) - def list_entities(self, project: str, allow_cache: bool = False) -> List[Entity]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_entities( - self.cached_registry_proto, project - ) + def _list_entities(self, project: str) -> List[Entity]: return self._list_objects( entities, project, EntityProto, Entity, "entity_proto" ) @@ -496,14 +382,7 @@ def delete_feature_service(self, name: str, project: str, commit: bool = True): FeatureServiceNotFoundException, ) - def get_data_source( - self, name: str, project: str, allow_cache: bool = False - ) -> DataSource: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.get_data_source( - self.cached_registry_proto, name, project - ) + def _get_data_source(self, name: str, project: str) -> DataSource: return self._get_object( table=data_sources, name=name, @@ -515,14 +394,7 @@ def get_data_source( not_found_exception=DataSourceObjectNotFoundException, ) - def list_data_sources( - self, project: str, allow_cache: bool = False - ) -> List[DataSource]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_data_sources( - self.cached_registry_proto, project - ) + def _list_data_sources(self, project: str) -> List[DataSource]: return self._list_objects( data_sources, project, DataSourceProto, DataSource, "data_source_proto" ) @@ -564,14 +436,7 @@ def delete_data_source(self, name: str, project: str, commit: bool = True): if rows.rowcount < 1: raise DataSourceObjectNotFoundException(name, project) - def list_feature_services( - self, project: str, allow_cache: bool = False - ) -> List[FeatureService]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_feature_services( - self.cached_registry_proto, project - ) + def _list_feature_services(self, project: str) -> List[FeatureService]: return self._list_objects( feature_services, project, @@ -580,26 +445,12 @@ def list_feature_services( "feature_service_proto", ) - def list_feature_views( - self, project: str, allow_cache: bool = False - ) -> List[FeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_feature_views( - self.cached_registry_proto, project - ) + def _list_feature_views(self, project: str) -> List[FeatureView]: return self._list_objects( feature_views, project, FeatureViewProto, FeatureView, "feature_view_proto" ) - def list_saved_datasets( - self, project: str, allow_cache: bool = False - ) -> List[SavedDataset]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_saved_datasets( - self.cached_registry_proto, project - ) + def _list_saved_datasets(self, project: str) -> List[SavedDataset]: return self._list_objects( saved_datasets, project, @@ -608,14 +459,7 @@ def list_saved_datasets( "saved_dataset_proto", ) - def list_request_feature_views( - self, project: str, allow_cache: bool = False - ) -> List[RequestFeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_request_feature_views( - self.cached_registry_proto, project - ) + def _list_request_feature_views(self, project: str) -> List[RequestFeatureView]: return self._list_objects( request_feature_views, project, @@ -624,14 +468,7 @@ def list_request_feature_views( "feature_view_proto", ) - def list_on_demand_feature_views( - self, project: str, allow_cache: bool = False - ) -> List[OnDemandFeatureView]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_on_demand_feature_views( - self.cached_registry_proto, project - ) + def _list_on_demand_feature_views(self, project: str) -> List[OnDemandFeatureView]: return self._list_objects( on_demand_feature_views, project, @@ -640,14 +477,7 @@ def list_on_demand_feature_views( "feature_view_proto", ) - def list_project_metadata( - self, project: str, allow_cache: bool = False - ) -> List[ProjectMetadata]: - if allow_cache: - self._refresh_cached_registry_if_necessary() - return proto_registry_utils.list_project_metadata( - self.cached_registry_proto, project - ) + def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: with self.engine.connect() as conn: stmt = select(feast_metadata).where( feast_metadata.c.project_id == project, @@ -740,7 +570,7 @@ def update_infra(self, infra: Infra, project: str, commit: bool = True): name="infra_obj", ) - def get_infra(self, project: str, allow_cache: bool = False) -> Infra: + def _get_infra(self, project: str) -> Infra: infra_object = self._get_object( table=managed_infra, name="infra_obj",