diff --git a/docs/getting-started/quickstart.md b/docs/getting-started/quickstart.md index ffc01c9d6e..7169989e7e 100644 --- a/docs/getting-started/quickstart.md +++ b/docs/getting-started/quickstart.md @@ -103,12 +103,17 @@ from feast import ( FeatureView, Field, FileSource, + Project, PushSource, RequestSource, ) from feast.on_demand_feature_view import on_demand_feature_view from feast.types import Float32, Float64, Int64 +# Define a project for the feature repo +project = Project(name="my_project", description="A project for driver statistics") + + # Define an entity for the driver. You can think of an entity as a primary key used to # fetch features. driver = Entity(name="driver", join_keys=["driver_id"]) diff --git a/protos/feast/core/Permission.proto b/protos/feast/core/Permission.proto index 57958d3d81..400f70a11b 100644 --- a/protos/feast/core/Permission.proto +++ b/protos/feast/core/Permission.proto @@ -45,6 +45,7 @@ message PermissionSpec { VALIDATION_REFERENCE = 7; SAVED_DATASET = 8; PERMISSION = 9; + PROJECT = 10; } repeated Type types = 3; diff --git a/protos/feast/core/Project.proto b/protos/feast/core/Project.proto new file mode 100644 index 0000000000..08e8b38f23 --- /dev/null +++ b/protos/feast/core/Project.proto @@ -0,0 +1,52 @@ +// +// * Copyright 2020 The Feast Authors +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * https://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// + +syntax = "proto3"; + +package feast.core; +option java_package = "feast.proto.core"; +option java_outer_classname = "ProjectProto"; +option go_package = "github.com/feast-dev/feast/go/protos/feast/core"; + +import "google/protobuf/timestamp.proto"; + +message Project { + // User-specified specifications of this entity. + ProjectSpec spec = 1; + // System-populated metadata for this entity. + ProjectMeta meta = 2; +} + +message ProjectSpec { + // Name of the Project + string name = 1; + + // Description of the Project + string description = 2; + + // User defined metadata + map tags = 3; + + // Owner of the Project + string owner = 4; +} + +message ProjectMeta { + // Time when the Project is created + google.protobuf.Timestamp created_timestamp = 1; + // Time when the Project is last updated with registry changes (Apply stage) + google.protobuf.Timestamp last_updated_timestamp = 2; +} diff --git a/protos/feast/core/Registry.proto b/protos/feast/core/Registry.proto index b4f1ffb0a3..45ecd2c173 100644 --- a/protos/feast/core/Registry.proto +++ b/protos/feast/core/Registry.proto @@ -33,8 +33,9 @@ import "feast/core/SavedDataset.proto"; import "feast/core/ValidationProfile.proto"; import "google/protobuf/timestamp.proto"; import "feast/core/Permission.proto"; +import "feast/core/Project.proto"; -// Next id: 17 +// Next id: 18 message Registry { repeated Entity entities = 1; repeated FeatureTable feature_tables = 2; @@ -47,12 +48,13 @@ message Registry { repeated ValidationReference validation_references = 13; Infra infra = 10; // Tracking metadata of Feast by project - repeated ProjectMetadata project_metadata = 15; + repeated ProjectMetadata project_metadata = 15 [deprecated = true]; string registry_schema_version = 3; // to support migrations; incremented when schema is changed string version_id = 4; // version id, random string generated on each update of the data; now used only for debugging purposes google.protobuf.Timestamp last_updated = 5; repeated Permission permissions = 16; + repeated Project projects = 17; } message ProjectMetadata { diff --git a/protos/feast/registry/RegistryServer.proto b/protos/feast/registry/RegistryServer.proto index 928354077b..3ad64b5b34 100644 --- a/protos/feast/registry/RegistryServer.proto +++ b/protos/feast/registry/RegistryServer.proto @@ -15,6 +15,7 @@ import "feast/core/SavedDataset.proto"; import "feast/core/ValidationProfile.proto"; import "feast/core/InfraObject.proto"; import "feast/core/Permission.proto"; +import "feast/core/Project.proto"; service RegistryServer{ // Entity RPCs @@ -67,6 +68,12 @@ service RegistryServer{ rpc ListPermissions (ListPermissionsRequest) returns (ListPermissionsResponse) {} rpc DeletePermission (DeletePermissionRequest) returns (google.protobuf.Empty) {} + // Project RPCs + rpc ApplyProject (ApplyProjectRequest) returns (google.protobuf.Empty) {} + rpc GetProject (GetProjectRequest) returns (feast.core.Project) {} + rpc ListProjects (ListProjectsRequest) returns (ListProjectsResponse) {} + rpc DeleteProject (DeleteProjectRequest) returns (google.protobuf.Empty) {} + rpc ApplyMaterialization (ApplyMaterializationRequest) returns (google.protobuf.Empty) {} rpc ListProjectMetadata (ListProjectMetadataRequest) returns (ListProjectMetadataResponse) {} rpc UpdateInfra (UpdateInfraRequest) returns (google.protobuf.Empty) {} @@ -356,3 +363,29 @@ message DeletePermissionRequest { string project = 2; bool commit = 3; } + +// Projects + +message ApplyProjectRequest { + feast.core.Project project = 1; + bool commit = 2; +} + +message GetProjectRequest { + string name = 1; + bool allow_cache = 2; +} + +message ListProjectsRequest { + bool allow_cache = 1; + map tags = 2; +} + +message ListProjectsResponse { + repeated feast.core.Project projects = 1; +} + +message DeleteProjectRequest { + string name = 1; + bool commit = 2; +} diff --git a/sdk/python/feast/__init__.py b/sdk/python/feast/__init__.py index 52734bc71e..71122b7047 100644 --- a/sdk/python/feast/__init__.py +++ b/sdk/python/feast/__init__.py @@ -18,6 +18,7 @@ from .feature_view import FeatureView from .field import Field from .on_demand_feature_view import OnDemandFeatureView +from .project import Project from .repo_config import RepoConfig from .stream_feature_view import StreamFeatureView from .value_type import ValueType @@ -49,4 +50,5 @@ "PushSource", "RequestSource", "AthenaSource", + "Project", ] diff --git a/sdk/python/feast/cli.py b/sdk/python/feast/cli.py index ec90b31151..499788101e 100644 --- a/sdk/python/feast/cli.py +++ b/sdk/python/feast/cli.py @@ -254,6 +254,79 @@ def data_source_list(ctx: click.Context, tags: list[str]): print(tabulate(table, headers=["NAME", "CLASS"], tablefmt="plain")) +@cli.group(name="projects") +def projects_cmd(): + """ + Access projects + """ + pass + + +@projects_cmd.command("describe") +@click.argument("name", type=click.STRING) +@click.pass_context +def project_describe(ctx: click.Context, name: str): + """ + Describe a project + """ + store = create_feature_store(ctx) + + try: + project = store.get_project(name) + except FeastObjectNotFoundException as e: + print(e) + exit(1) + + print( + yaml.dump( + yaml.safe_load(str(project)), default_flow_style=False, sort_keys=False + ) + ) + + +@projects_cmd.command("current_project") +@click.pass_context +def project_current(ctx: click.Context): + """ + Returns the current project configured with FeatureStore object + """ + store = create_feature_store(ctx) + + try: + project = store.get_project(name=None) + except FeastObjectNotFoundException as e: + print(e) + exit(1) + + print( + yaml.dump( + yaml.safe_load(str(project)), default_flow_style=False, sort_keys=False + ) + ) + + +@projects_cmd.command(name="list") +@tagsOption +@click.pass_context +def project_list(ctx: click.Context, tags: list[str]): + """ + List all projects + """ + store = create_feature_store(ctx) + table = [] + tags_filter = utils.tags_list_to_dict(tags) + for project in store.list_projects(tags=tags_filter): + table.append([project.name, project.description, project.tags, project.owner]) + + from tabulate import tabulate + + print( + tabulate( + table, headers=["NAME", "DESCRIPTION", "TAGS", "OWNER"], tablefmt="plain" + ) + ) + + @cli.group(name="entities") def entities_cmd(): """ diff --git a/sdk/python/feast/diff/registry_diff.py b/sdk/python/feast/diff/registry_diff.py index 6235025adc..272c4590d8 100644 --- a/sdk/python/feast/diff/registry_diff.py +++ b/sdk/python/feast/diff/registry_diff.py @@ -11,6 +11,7 @@ from feast.infra.registry.base_registry import BaseRegistry from feast.infra.registry.registry import FEAST_OBJECT_TYPES, FeastObjectType from feast.permissions.permission import Permission +from feast.project import Project from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto from feast.protos.feast.core.FeatureService_pb2 import ( @@ -371,6 +372,11 @@ def apply_diff_to_registry( TransitionType.CREATE, TransitionType.UPDATE, ]: + if feast_object_diff.feast_object_type == FeastObjectType.PROJECT: + registry.apply_project( + cast(Project, feast_object_diff.new_feast_object), + commit=False, + ) if feast_object_diff.feast_object_type == FeastObjectType.DATA_SOURCE: registry.apply_data_source( cast(DataSource, feast_object_diff.new_feast_object), diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index fd5955fd98..4dbb220c1e 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -480,6 +480,16 @@ def __init__(self, name, project=None): super().__init__(f"Permission {name} does not exist") +class ProjectNotFoundException(FeastError): + def __init__(self, project): + super().__init__(f"Project {project} does not exist in registry") + + +class ProjectObjectNotFoundException(FeastObjectNotFoundException): + def __init__(self, name, project=None): + super().__init__(f"Project {name} does not exist") + + class ZeroRowsQueryResult(FeastError): def __init__(self, query: str): super().__init__(f"This query returned zero rows:\n{query}") diff --git a/sdk/python/feast/feast_object.py b/sdk/python/feast/feast_object.py index dfe29b7128..63fa1e913b 100644 --- a/sdk/python/feast/feast_object.py +++ b/sdk/python/feast/feast_object.py @@ -1,5 +1,8 @@ from typing import Union, get_args +from feast.project import Project +from feast.protos.feast.core.Project_pb2 import ProjectSpec + from .batch_feature_view import BatchFeatureView from .data_source import DataSource from .entity import Entity @@ -23,6 +26,7 @@ # Convenience type representing all Feast objects FeastObject = Union[ + Project, FeatureView, OnDemandFeatureView, BatchFeatureView, @@ -36,6 +40,7 @@ ] FeastObjectSpecProto = Union[ + ProjectSpec, FeatureViewSpec, OnDemandFeatureViewSpec, StreamFeatureViewSpec, diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index a03706e56f..27b6eade5b 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -60,11 +60,7 @@ ) from feast.feast_object import FeastObject from feast.feature_service import FeatureService -from feast.feature_view import ( - DUMMY_ENTITY, - DUMMY_ENTITY_NAME, - FeatureView, -) +from feast.feature_view import DUMMY_ENTITY, DUMMY_ENTITY_NAME, FeatureView from feast.inference import ( update_data_sources_with_inferred_event_timestamp_col, update_feature_views_with_inferred_features_and_entities, @@ -77,6 +73,7 @@ from feast.on_demand_feature_view import OnDemandFeatureView from feast.online_response import OnlineResponse from feast.permissions.permission import Permission +from feast.project import Project from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto from feast.protos.feast.serving.ServingService_pb2 import ( FieldStatus, @@ -162,14 +159,12 @@ def __init__( registry_config, self.config.project, None, self.config.auth_config ) else: - r = Registry( + self._registry = Registry( self.config.project, registry_config, repo_path=self.repo_path, auth_config=self.config.auth_config, ) - r._initialize_registry(self.config.project) - self._registry = r self._provider = get_provider(self.config) @@ -205,16 +200,8 @@ def refresh_registry(self): greater than 0, then once the cache becomes stale (more time than the TTL has passed), a new cache will be downloaded synchronously, which may increase latencies if the triggering method is get_online_features(). """ - registry_config = self.config.registry - registry = Registry( - self.config.project, - registry_config, - repo_path=self.repo_path, - auth_config=self.config.auth_config, - ) - registry.refresh(self.config.project) - self._registry = registry + self._registry.refresh(self.project) def list_entities( self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None @@ -740,6 +727,7 @@ def plan( ... source=driver_hourly_stats, ... ) >>> registry_diff, infra_diff, new_infra = fs.plan(RepoContents( + ... projects=[Project(name="project")], ... data_sources=[driver_hourly_stats], ... feature_views=[driver_hourly_stats_view], ... on_demand_feature_views=list(), @@ -802,6 +790,7 @@ def _apply_diffs( def apply( self, objects: Union[ + Project, DataSource, Entity, FeatureView, @@ -862,6 +851,9 @@ def apply( objects_to_delete = [] # Separate all objects into entities, feature services, and different feature view types. + projects_to_update = [ob for ob in objects if isinstance(ob, Project)] + if len(projects_to_update) > 1: + raise ValueError("Only one project can be applied at a time.") entities_to_update = [ob for ob in objects if isinstance(ob, Entity)] views_to_update = [ ob @@ -924,6 +916,8 @@ def apply( ) # Add all objects to the registry and update the provider's infrastructure. + for project in projects_to_update: + self._registry.apply_project(project, commit=False) for ds in data_sources_to_update: self._registry.apply_data_source(ds, project=self.project, commit=False) for view in itertools.chain(views_to_update, odfvs_to_update, sfvs_to_update): @@ -1990,6 +1984,36 @@ def get_permission(self, name: str) -> Permission: """ return self._registry.get_permission(name, self.project) + def list_projects( + self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None + ) -> List[Project]: + """ + Retrieves the list of projects from the registry. + + Args: + allow_cache: Whether to allow returning projects from a cached registry. + tags: Filter by tags. + + Returns: + A list of projects. + """ + return self._registry.list_projects(allow_cache=allow_cache, tags=tags) + + def get_project(self, name: Optional[str]) -> Project: + """ + Retrieves a project from the registry. + + Args: + name: Name of the project. + + Returns: + The specified project. + + Raises: + ProjectObjectNotFoundException: The project could not be found. + """ + return self._registry.get_project(name or self.project) + def list_saved_datasets( self, allow_cache: bool = False, tags: Optional[dict[str, str]] = None ) -> List[SavedDataset]: diff --git a/sdk/python/feast/feature_view.py b/sdk/python/feast/feature_view.py index 1a85a4b90c..dd01078e20 100644 --- a/sdk/python/feast/feature_view.py +++ b/sdk/python/feast/feature_view.py @@ -423,7 +423,7 @@ def from_proto(cls, feature_view_proto: FeatureViewProto): if len(feature_view.entities) != len(feature_view.entity_columns): warnings.warn( - f"There are some mismatches in your feature view's registered entities. Please check if you have applied your entities correctly." + f"There are some mismatches in your feature view: {feature_view.name} registered entities. Please check if you have applied your entities correctly." f"Entities: {feature_view.entities} vs Entity Columns: {feature_view.entity_columns}" ) diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index 5f65d8da8b..8a7e299516 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -109,7 +109,7 @@ def online_read( result_tuples.append((event_ts, feature_values_dict)) return result_tuples else: - error_msg = f"Unable to retrieve the online store data using feature server API. Error_code={response.status_code}, error_message={response.reason}" + error_msg = f"Unable to retrieve the online store data using feature server API. Error_code={response.status_code}, error_message={response.text}" logger.error(error_msg) raise RuntimeError(error_msg) diff --git a/sdk/python/feast/infra/registry/base_registry.py b/sdk/python/feast/infra/registry/base_registry.py index 33adb6b7c9..f5040d9752 100644 --- a/sdk/python/feast/infra/registry/base_registry.py +++ b/sdk/python/feast/infra/registry/base_registry.py @@ -29,6 +29,7 @@ from feast.infra.infra_object import Infra from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.permission import Permission +from feast.project import Project from feast.project_metadata import ProjectMetadata from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto from feast.protos.feast.core.FeatureService_pb2 import ( @@ -39,6 +40,7 @@ OnDemandFeatureView as OnDemandFeatureViewProto, ) from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto from feast.protos.feast.core.StreamFeatureView_pb2 import ( @@ -663,6 +665,71 @@ def list_permissions( """ raise NotImplementedError + @abstractmethod + def apply_project( + self, + project: Project, + commit: bool = True, + ): + """ + Registers a project with Feast + + Args: + project: A project that will be registered + commit: Whether to immediately commit to the registry + """ + raise NotImplementedError + + @abstractmethod + def delete_project( + self, + name: str, + commit: bool = True, + ): + """ + Deletes a project or raises an ProjectNotFoundException exception if not found. + + Args: + project: Feast project name that needs to be deleted + commit: Whether the change should be persisted immediately + """ + raise NotImplementedError + + @abstractmethod + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + """ + Retrieves a project. + + Args: + name: Feast project name + allow_cache: Whether to allow returning this permission from a cached registry + + Returns: + Returns either the specified project, or raises ProjectObjectNotFoundException exception if none is found + """ + raise NotImplementedError + + @abstractmethod + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + """ + Retrieve a list of projects from the registry + + Args: + allow_cache: Whether to allow returning permissions from a cached registry + + Returns: + List of project + """ + raise NotImplementedError + @abstractmethod def proto(self) -> RegistryProto: """ @@ -814,4 +881,6 @@ def deserialize_registry_values(serialized_proto, feast_obj_type) -> Any: return FeatureServiceProto.FromString(serialized_proto) if feast_obj_type == Permission: return PermissionProto.FromString(serialized_proto) + if feast_obj_type == Project: + return ProjectProto.FromString(serialized_proto) return None diff --git a/sdk/python/feast/infra/registry/caching_registry.py b/sdk/python/feast/infra/registry/caching_registry.py index 611d67de96..c04a62552b 100644 --- a/sdk/python/feast/infra/registry/caching_registry.py +++ b/sdk/python/feast/infra/registry/caching_registry.py @@ -1,6 +1,7 @@ import atexit import logging import threading +import warnings from abc import abstractmethod from datetime import timedelta from threading import Lock @@ -15,6 +16,7 @@ from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.permission import Permission +from feast.project import Project from feast.project_metadata import ProjectMetadata from feast.saved_dataset import SavedDataset, ValidationReference from feast.stream_feature_view import StreamFeatureView @@ -26,7 +28,6 @@ class CachingRegistry(BaseRegistry): def __init__(self, project: str, cache_ttl_seconds: int, cache_mode: str): self.cached_registry_proto = self.proto() - proto_registry_utils.init_project_metadata(self.cached_registry_proto, project) self.cached_registry_proto_created = _utc_now() self._refresh_lock = Lock() self.cached_registry_proto_ttl = timedelta( @@ -308,6 +309,10 @@ def _list_project_metadata(self, project: str) -> List[ProjectMetadata]: def list_project_metadata( self, project: str, allow_cache: bool = False ) -> List[ProjectMetadata]: + warnings.warn( + "list_project_metadata is deprecated and will be removed in a future version. Use list_projects() and get_project() methods instead.", + DeprecationWarning, + ) if allow_cache: self._refresh_cached_registry_if_necessary() return proto_registry_utils.list_project_metadata( @@ -355,15 +360,35 @@ def list_permissions( ) return self._list_permissions(project, tags) + @abstractmethod + def _get_project(self, name: str) -> Project: + pass + + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_project(self.cached_registry_proto, name) + return self._get_project(name) + + @abstractmethod + def _list_projects(self, tags: Optional[dict[str, str]]) -> List[Project]: + pass + + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_projects(self.cached_registry_proto, tags) + return self._list_projects(tags) + def refresh(self, project: Optional[str] = None): - if project: - project_metadata = proto_registry_utils.get_project_metadata( - registry_proto=self.cached_registry_proto, project=project - ) - if not project_metadata: - proto_registry_utils.init_project_metadata( - self.cached_registry_proto, project - ) self.cached_registry_proto = self.proto() self.cached_registry_proto_created = _utc_now() @@ -395,7 +420,7 @@ def _start_thread_async_refresh(self, cache_ttl_seconds): self.registry_refresh_thread = threading.Timer( cache_ttl_seconds, self._start_thread_async_refresh, [cache_ttl_seconds] ) - self.registry_refresh_thread.setDaemon(True) + self.registry_refresh_thread.daemon = True self.registry_refresh_thread.start() def _exit_handler(self): diff --git a/sdk/python/feast/infra/registry/proto_registry_utils.py b/sdk/python/feast/infra/registry/proto_registry_utils.py index f67808aab5..b0413fd77e 100644 --- a/sdk/python/feast/infra/registry/proto_registry_utils.py +++ b/sdk/python/feast/infra/registry/proto_registry_utils.py @@ -1,4 +1,3 @@ -import uuid from functools import wraps from typing import List, Optional @@ -11,6 +10,7 @@ FeatureServiceNotFoundException, FeatureViewNotFoundException, PermissionObjectNotFoundException, + ProjectObjectNotFoundException, SavedDatasetNotFound, ValidationReferenceNotFound, ) @@ -18,6 +18,7 @@ from feast.feature_view import FeatureView from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.permission import Permission +from feast.project import Project from feast.project_metadata import ProjectMetadata from feast.protos.feast.core.Registry_pb2 import ProjectMetadata as ProjectMetadataProto from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto @@ -69,13 +70,6 @@ def wrapper( return wrapper -def init_project_metadata(cached_registry_proto: RegistryProto, project: str): - new_project_uuid = f"{uuid.uuid4()}" - cached_registry_proto.project_metadata.append( - ProjectMetadata(project_name=project, project_uuid=new_project_uuid).to_proto() - ) - - def get_project_metadata( registry_proto: Optional[RegistryProto], project: str ) -> Optional[ProjectMetadataProto]: @@ -316,3 +310,21 @@ def get_permission( ): return Permission.from_proto(permission_proto) raise PermissionObjectNotFoundException(name=name, project=project) + + +def list_projects( + registry_proto: RegistryProto, + tags: Optional[dict[str, str]], +) -> List[Project]: + projects = [] + for project_proto in registry_proto.projects: + if utils.has_all_tags(project_proto.spec.tags, tags): + projects.append(Project.from_proto(project_proto)) + return projects + + +def get_project(registry_proto: RegistryProto, name: str) -> Project: + for projects_proto in registry_proto.projects: + if projects_proto.spec.name == name: + return Project.from_proto(projects_proto) + raise ProjectObjectNotFoundException(name=name) diff --git a/sdk/python/feast/infra/registry/registry.py b/sdk/python/feast/infra/registry/registry.py index 366f3aacaa..634d6fa7ac 100644 --- a/sdk/python/feast/infra/registry/registry.py +++ b/sdk/python/feast/infra/registry/registry.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from enum import Enum from pathlib import Path from threading import Lock @@ -32,6 +32,8 @@ FeatureServiceNotFoundException, FeatureViewNotFoundException, PermissionNotFoundException, + ProjectNotFoundException, + ProjectObjectNotFoundException, ValidationReferenceNotFound, ) from feast.feature_service import FeatureService @@ -44,6 +46,7 @@ from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.auth_model import AuthConfig, NoAuthConfig from feast.permissions.permission import Permission +from feast.project import Project from feast.project_metadata import ProjectMetadata from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.repo_config import RegistryConfig @@ -70,6 +73,7 @@ class FeastObjectType(Enum): + PROJECT = "project" DATA_SOURCE = "data source" ENTITY = "entity" FEATURE_VIEW = "feature view" @@ -83,6 +87,11 @@ def get_objects_from_registry( registry: "BaseRegistry", project: str ) -> Dict["FeastObjectType", List[Any]]: return { + FeastObjectType.PROJECT: [ + project_obj + for project_obj in registry.list_projects() + if project_obj.name == project + ], FeastObjectType.DATA_SOURCE: registry.list_data_sources(project=project), FeastObjectType.ENTITY: registry.list_entities(project=project), FeastObjectType.FEATURE_VIEW: registry.list_feature_views(project=project), @@ -103,6 +112,7 @@ def get_objects_from_repo_contents( repo_contents: RepoContents, ) -> Dict["FeastObjectType", List[Any]]: return { + FeastObjectType.PROJECT: repo_contents.projects, FeastObjectType.DATA_SOURCE: repo_contents.data_sources, FeastObjectType.ENTITY: repo_contents.entities, FeastObjectType.FEATURE_VIEW: repo_contents.feature_views, @@ -157,34 +167,10 @@ def get_user_metadata( # The cached_registry_proto object is used for both reads and writes. In particular, # all write operations refresh the cache and modify it in memory; the write must # then be persisted to the underlying RegistryStore with a call to commit(). - cached_registry_proto: Optional[RegistryProto] = None - cached_registry_proto_created: Optional[datetime] = None + cached_registry_proto: RegistryProto + cached_registry_proto_created: datetime cached_registry_proto_ttl: timedelta - def __new__( - cls, - project: str, - registry_config: Optional[RegistryConfig], - repo_path: Optional[Path], - auth_config: AuthConfig = NoAuthConfig(), - ): - # We override __new__ so that we can inspect registry_config and create a SqlRegistry without callers - # needing to make any changes. - if registry_config and registry_config.registry_type == "sql": - from feast.infra.registry.sql import SqlRegistry - - return SqlRegistry(registry_config, project, repo_path) - elif registry_config and registry_config.registry_type == "snowflake.registry": - from feast.infra.registry.snowflake import SnowflakeRegistry - - return SnowflakeRegistry(registry_config, project, repo_path) - elif registry_config and registry_config.registry_type == "remote": - from feast.infra.registry.remote import RemoteRegistry - - return RemoteRegistry(registry_config, project, repo_path, auth_config) - else: - return super(Registry, cls).__new__(cls) - def __init__( self, project: str, @@ -204,6 +190,17 @@ def __init__( self._refresh_lock = Lock() self._auth_config = auth_config + registry_proto = RegistryProto() + registry_proto.registry_schema_version = REGISTRY_SCHEMA_VERSION + self.cached_registry_proto = registry_proto + self.cached_registry_proto_created = _utc_now() + + self.purge_feast_metadata = ( + registry_config.purge_feast_metadata + if registry_config is not None + else False + ) + if registry_config: registry_store_type = registry_config.registry_store_type registry_path = registry_config.path @@ -214,11 +211,52 @@ def __init__( self._registry_store = cls(registry_config, repo_path) self.cached_registry_proto_ttl = timedelta( - seconds=registry_config.cache_ttl_seconds - if registry_config.cache_ttl_seconds is not None - else 0 + seconds=( + registry_config.cache_ttl_seconds + if registry_config.cache_ttl_seconds is not None + else 0 + ) ) + try: + registry_proto = self._registry_store.get_registry_proto() + self.cached_registry_proto = registry_proto + self.cached_registry_proto_created = _utc_now() + # Sync feast_metadata to projects table + # when purge_feast_metadata is set to True, Delete data from + # feast_metadata table and list_project_metadata will not return any data + self._sync_feast_metadata_to_projects_table() + except FileNotFoundError: + logger.info("Registry file not found. Creating new registry.") + finally: + self.commit() + + def _sync_feast_metadata_to_projects_table(self): + """ + Sync feast_metadata to projects table + """ + feast_metadata_projects = [] + projects_set = [] + # List of project in project_metadata + for project_metadata in self.cached_registry_proto.project_metadata: + project = ProjectMetadata.from_proto(project_metadata) + feast_metadata_projects.append(project.project_name) + if len(feast_metadata_projects) > 0: + # List of project in projects + for project_metadata in self.cached_registry_proto.projects: + project = Project.from_proto(project_metadata) + projects_set.append(project.name) + + # Find object in feast_metadata_projects but not in projects + projects_to_sync = set(feast_metadata_projects) - set(projects_set) + # Sync feast_metadata to projects table + for project_name in projects_to_sync: + project = Project(name=project_name) + self.cached_registry_proto.projects.append(project.to_proto()) + + if self.purge_feast_metadata: + self.cached_registry_proto.project_metadata = [] + def clone(self) -> "Registry": new_registry = Registry("project", None, None, self._auth_config) new_registry.cached_registry_proto_ttl = timedelta(seconds=0) @@ -231,16 +269,6 @@ def clone(self) -> "Registry": new_registry._registry_store = NoopRegistryStore() return new_registry - def _initialize_registry(self, project: str): - """Explicitly initializes the registry with an empty proto if it doesn't exist.""" - try: - self._get_registry_proto(project=project) - except FileNotFoundError: - registry_proto = RegistryProto() - registry_proto.registry_schema_version = REGISTRY_SCHEMA_VERSION - proto_registry_utils.init_project_metadata(registry_proto, project) - self._registry_store.update_registry_proto(registry_proto) - def update_infra(self, infra: Infra, project: str, commit: bool = True): self._prepare_registry_for_changes(project) assert self.cached_registry_proto @@ -320,7 +348,7 @@ def apply_data_source( data_source_proto.data_source_class_type = ( f"{data_source.__class__.__module__}.{data_source.__class__.__name__}" ) - registry.data_sources.append(data_source_proto) + self.cached_registry_proto.data_sources.append(data_source_proto) if commit: self.commit() @@ -363,7 +391,7 @@ def apply_feature_service( feature_service_proto = feature_service.to_proto() feature_service_proto.spec.project = project del registry.feature_services[idx] - registry.feature_services.append(feature_service_proto) + self.cached_registry_proto.feature_services.append(feature_service_proto) if commit: self.commit() @@ -773,15 +801,16 @@ def list_validation_references( ) def delete_validation_reference(self, name: str, project: str, commit: bool = True): - registry_proto = self._prepare_registry_for_changes(project) + self._prepare_registry_for_changes(project) + assert self.cached_registry_proto for idx, existing_validation_reference in enumerate( - registry_proto.validation_references + self.cached_registry_proto.validation_references ): if ( existing_validation_reference.name == name and existing_validation_reference.project == project ): - del registry_proto.validation_references[idx] + del self.cached_registry_proto.validation_references[idx] if commit: self.commit() return @@ -811,37 +840,36 @@ def teardown(self): def proto(self) -> RegistryProto: return self.cached_registry_proto or RegistryProto() - def _prepare_registry_for_changes(self, project: str): + def _prepare_registry_for_changes(self, project_name: str): """Prepares the Registry for changes by refreshing the cache if necessary.""" + + assert self.cached_registry_proto is not None + try: - self._get_registry_proto(project=project, allow_cache=True) - if ( - proto_registry_utils.get_project_metadata( - self.cached_registry_proto, project - ) - is None - ): - # Project metadata not initialized yet. Try pulling without cache - self._get_registry_proto(project=project, allow_cache=False) - except FileNotFoundError: - registry_proto = RegistryProto() - registry_proto.registry_schema_version = REGISTRY_SCHEMA_VERSION + # Check if the project exists in the registry cache + self.get_project(name=project_name, allow_cache=True) + return self.cached_registry_proto + except ProjectObjectNotFoundException: + # If the project does not exist in cache, refresh cache from store + registry_proto = self._registry_store.get_registry_proto() self.cached_registry_proto = registry_proto self.cached_registry_proto_created = _utc_now() - # Initialize project metadata if needed - assert self.cached_registry_proto - if ( - proto_registry_utils.get_project_metadata( - self.cached_registry_proto, project - ) - is None - ): - proto_registry_utils.init_project_metadata( - self.cached_registry_proto, project - ) + try: + # Check if the project exists in the registry cache after refresh from store + self.get_project(name=project_name) + except ProjectObjectNotFoundException: + # If the project still does not exist, create it + project_proto = Project(name=project_name).to_proto() + self.cached_registry_proto.projects.append(project_proto) + if not self.purge_feast_metadata: + project_metadata_proto = ProjectMetadata( + project_name=project_name + ).to_proto() + self.cached_registry_proto.project_metadata.append( + project_metadata_proto + ) self.commit() - return self.cached_registry_proto def _get_registry_proto( @@ -856,10 +884,7 @@ def _get_registry_proto( Returns: Returns a RegistryProto object which represents the state of the registry """ with self._refresh_lock: - expired = ( - self.cached_registry_proto is None - or self.cached_registry_proto_created is None - ) or ( + expired = (self.cached_registry_proto_created is None) or ( self.cached_registry_proto_ttl.total_seconds() > 0 # 0 ttl means infinity and ( @@ -871,33 +896,12 @@ def _get_registry_proto( ) ) - if project: - old_project_metadata = proto_registry_utils.get_project_metadata( - registry_proto=self.cached_registry_proto, project=project - ) - - if allow_cache and not expired and old_project_metadata is not None: - assert isinstance(self.cached_registry_proto, RegistryProto) - return self.cached_registry_proto - elif allow_cache and not expired: - assert isinstance(self.cached_registry_proto, RegistryProto) + if allow_cache and not expired: return self.cached_registry_proto - logger.info("Registry cache expired, so refreshing") registry_proto = self._registry_store.get_registry_proto() self.cached_registry_proto = registry_proto self.cached_registry_proto_created = _utc_now() - - if not project: - return registry_proto - - project_metadata = proto_registry_utils.get_project_metadata( - registry_proto=registry_proto, project=project - ) - if not project_metadata: - proto_registry_utils.init_project_metadata(registry_proto, project) - self.commit() - return registry_proto def _check_conflicting_feature_view_names(self, feature_view: BaseFeatureView): @@ -960,7 +964,7 @@ def apply_permission( permission_proto = permission.to_proto() permission_proto.spec.project = project - registry.permissions.append(permission_proto) + self.cached_registry_proto.permissions.append(permission_proto) if commit: self.commit() @@ -978,3 +982,91 @@ def delete_permission(self, name: str, project: str, commit: bool = True): self.commit() return raise PermissionNotFoundException(name, project) + + def apply_project( + self, + project: Project, + commit: bool = True, + ): + registry = self.cached_registry_proto + + for idx, existing_project_proto in enumerate(registry.projects): + if existing_project_proto.spec.name == project.name: + project.created_timestamp = ( + existing_project_proto.meta.created_timestamp.ToDatetime().replace( + tzinfo=timezone.utc + ) + ) + del registry.projects[idx] + + project_proto = project.to_proto() + self.cached_registry_proto.projects.append(project_proto) + if commit: + self.commit() + + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + registry_proto = self._get_registry_proto(project=name, allow_cache=allow_cache) + return proto_registry_utils.get_project(registry_proto, name) + + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + registry_proto = self._get_registry_proto(project=None, allow_cache=allow_cache) + return proto_registry_utils.list_projects( + registry_proto=registry_proto, tags=tags + ) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + assert self.cached_registry_proto + + for idx, project_proto in enumerate(self.cached_registry_proto.projects): + if project_proto.spec.name == name: + list_validation_references = self.list_validation_references(name) + for validation_reference in list_validation_references: + self.delete_validation_reference(validation_reference.name, name) + + list_saved_datasets = self.list_saved_datasets(name) + for saved_dataset in list_saved_datasets: + self.delete_saved_dataset(saved_dataset.name, name) + + list_feature_services = self.list_feature_services(name) + for feature_service in list_feature_services: + self.delete_feature_service(feature_service.name, name) + + list_on_demand_feature_views = self.list_on_demand_feature_views(name) + for on_demand_feature_view in list_on_demand_feature_views: + self.delete_feature_view(on_demand_feature_view.name, name) + + list_stream_feature_views = self.list_stream_feature_views(name) + for stream_feature_view in list_stream_feature_views: + self.delete_feature_view(stream_feature_view.name, name) + + list_feature_views = self.list_feature_views(name) + for feature_view in list_feature_views: + self.delete_feature_view(feature_view.name, name) + + list_data_sources = self.list_data_sources(name) + for data_source in list_data_sources: + self.delete_data_source(data_source.name, name) + + list_entities = self.list_entities(name) + for entity in list_entities: + self.delete_entity(entity.name, name) + list_permissions = self.list_permissions(name) + for permission in list_permissions: + self.delete_permission(permission.name, name) + del self.cached_registry_proto.projects[idx] + if commit: + self.commit() + return + raise ProjectNotFoundException(name) diff --git a/sdk/python/feast/infra/registry/remote.py b/sdk/python/feast/infra/registry/remote.py index 618628bc07..ba25ef7dbe 100644 --- a/sdk/python/feast/infra/registry/remote.py +++ b/sdk/python/feast/infra/registry/remote.py @@ -16,14 +16,12 @@ from feast.infra.registry.base_registry import BaseRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.auth.auth_type import AuthType -from feast.permissions.auth_model import ( - AuthConfig, - NoAuthConfig, -) +from feast.permissions.auth_model import AuthConfig, NoAuthConfig from feast.permissions.client.grpc_client_auth_interceptor import ( GrpcClientAuthHeaderInterceptor, ) from feast.permissions.permission import Permission +from feast.project import Project from feast.project_metadata import ProjectMetadata from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc @@ -50,11 +48,18 @@ def __init__( auth_config: AuthConfig = NoAuthConfig(), ): self.auth_config = auth_config - channel = grpc.insecure_channel(registry_config.path) + self.channel = grpc.insecure_channel(registry_config.path) if self.auth_config.type != AuthType.NONE.value: auth_header_interceptor = GrpcClientAuthHeaderInterceptor(auth_config) - channel = grpc.intercept_channel(channel, auth_header_interceptor) - self.stub = RegistryServer_pb2_grpc.RegistryServerStub(channel) + self.channel = grpc.intercept_channel(self.channel, auth_header_interceptor) + self.stub = RegistryServer_pb2_grpc.RegistryServerStub(self.channel) + + def close(self): + if self.channel: + self.channel.close() + + def __del__(self): + self.close() def apply_entity(self, entity: Entity, project: str, commit: bool = True): request = RegistryServer_pb2.ApplyEntityRequest( @@ -173,15 +178,17 @@ def apply_feature_view( arg_name = "on_demand_feature_view" request = RegistryServer_pb2.ApplyFeatureViewRequest( - feature_view=feature_view.to_proto() - if arg_name == "feature_view" - else None, - stream_feature_view=feature_view.to_proto() - if arg_name == "stream_feature_view" - else None, - on_demand_feature_view=feature_view.to_proto() - if arg_name == "on_demand_feature_view" - else None, + feature_view=( + feature_view.to_proto() if arg_name == "feature_view" else None + ), + stream_feature_view=( + feature_view.to_proto() if arg_name == "stream_feature_view" else None + ), + on_demand_feature_view=( + feature_view.to_proto() + if arg_name == "on_demand_feature_view" + else None + ), project=project, commit=commit, ) @@ -450,6 +457,49 @@ def list_permissions( Permission.from_proto(permission) for permission in response.permissions ] + def apply_project( + self, + project: Project, + commit: bool = True, + ): + project_proto = project.to_proto() + + request = RegistryServer_pb2.ApplyProjectRequest( + project=project_proto, commit=commit + ) + self.stub.ApplyProject(request) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + request = RegistryServer_pb2.DeleteProjectRequest(name=name, commit=commit) + self.stub.DeleteProject(request) + + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + request = RegistryServer_pb2.GetProjectRequest( + name=name, allow_cache=allow_cache + ) + response = self.stub.GetProject(request) + + return Project.from_proto(response) + + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + request = RegistryServer_pb2.ListProjectsRequest( + allow_cache=allow_cache, tags=tags + ) + response = self.stub.ListProjects(request) + return [Project.from_proto(project) for project in response.projects] + def proto(self) -> RegistryProto: return self.stub.Proto(Empty()) diff --git a/sdk/python/feast/infra/registry/snowflake.py b/sdk/python/feast/infra/registry/snowflake.py index 801b90afe3..accfa42e12 100644 --- a/sdk/python/feast/infra/registry/snowflake.py +++ b/sdk/python/feast/infra/registry/snowflake.py @@ -5,7 +5,7 @@ from datetime import datetime, timedelta, timezone from enum import Enum from threading import Lock -from typing import Any, Callable, List, Literal, Optional, Set, Union +from typing import Any, Callable, List, Literal, Optional, Union from pydantic import ConfigDict, Field, StrictStr @@ -19,6 +19,8 @@ FeatureServiceNotFoundException, FeatureViewNotFoundException, PermissionNotFoundException, + ProjectNotFoundException, + ProjectObjectNotFoundException, SavedDatasetNotFound, ValidationReferenceNotFound, ) @@ -33,6 +35,7 @@ ) from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.permission import Permission +from feast.project import Project from feast.project_metadata import ProjectMetadata from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto @@ -45,6 +48,7 @@ OnDemandFeatureView as OnDemandFeatureViewProto, ) from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto from feast.protos.feast.core.StreamFeatureView_pb2 import ( @@ -138,26 +142,57 @@ def __init__( query = command.replace("REGISTRY_PATH", f"{self.registry_path}") execute_snowflake_statement(conn, query) + self.purge_feast_metadata = registry_config.purge_feast_metadata + self._sync_feast_metadata_to_projects_table() + if not self.purge_feast_metadata: + self._maybe_init_project_metadata(project) + self.cached_registry_proto = self.proto() - proto_registry_utils.init_project_metadata(self.cached_registry_proto, project) self.cached_registry_proto_created = _utc_now() self._refresh_lock = Lock() self.cached_registry_proto_ttl = timedelta( - seconds=registry_config.cache_ttl_seconds - if registry_config.cache_ttl_seconds is not None - else 0 + seconds=( + registry_config.cache_ttl_seconds + if registry_config.cache_ttl_seconds is not None + else 0 + ) ) self.project = project - def refresh(self, project: Optional[str] = None): - if project: - project_metadata = proto_registry_utils.get_project_metadata( - registry_proto=self.cached_registry_proto, project=project + def _sync_feast_metadata_to_projects_table(self): + feast_metadata_projects: set = [] + projects_set: set = [] + + with GetSnowflakeConnection(self.registry_config) as conn: + query = ( + f'SELECT DISTINCT project_id FROM {self.registry_path}."FEAST_METADATA"' ) - if not project_metadata: - proto_registry_utils.init_project_metadata( - self.cached_registry_proto, project - ) + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + for row in df.iterrows(): + feast_metadata_projects.add(row[1]["PROJECT_ID"]) + + if len(feast_metadata_projects) > 0: + with GetSnowflakeConnection(self.registry_config) as conn: + query = f'SELECT project_id FROM {self.registry_path}."PROJECTS"' + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + + for row in df.iterrows(): + projects_set.add(row[1]["PROJECT_ID"]) + + # Find object in feast_metadata_projects but not in projects + projects_to_sync = set(feast_metadata_projects) - set(projects_set) + for project_name in projects_to_sync: + self.apply_project(Project(name=project_name), commit=True) + + if self.purge_feast_metadata: + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + DELETE FROM {self.registry_path}."FEAST_METADATA" + """ + execute_snowflake_statement(conn, query) + + def refresh(self, project: Optional[str] = None): self.cached_registry_proto = self.proto() self.cached_registry_proto_created = _utc_now() @@ -271,6 +306,17 @@ def update_infra(self, infra: Infra, project: str, commit: bool = True): name="infra_obj", ) + def _initialize_project_if_not_exists(self, project_name: str): + try: + self.get_project(project_name, allow_cache=True) + return + except ProjectObjectNotFoundException: + try: + self.get_project(project_name, allow_cache=False) + return + except ProjectObjectNotFoundException: + self.apply_project(Project(name=project_name), commit=True) + def _apply_object( self, table: str, @@ -280,7 +326,11 @@ def _apply_object( proto_field_name: str, name: Optional[str] = None, ): - self._maybe_init_project_metadata(project) + if not self.purge_feast_metadata: + self._maybe_init_project_metadata(project) + # Initialize project is necessary because FeatureStore object can apply objects individually without "feast apply" cli option + if not isinstance(obj, Project): + self._initialize_project_if_not_exists(project_name=project) name = name or (obj.name if hasattr(obj, "name") else None) assert name, f"name needs to be provided for {obj}" @@ -343,7 +393,13 @@ def _apply_object( """ execute_snowflake_statement(conn, query) - self._set_last_updated_metadata(update_datetime, project) + if not isinstance(obj, Project): + self.apply_project( + self.get_project(name=project, allow_cache=False), commit=True + ) + + if not self.purge_feast_metadata: + self._set_last_updated_metadata(update_datetime, project) def apply_permission( self, permission: Permission, project: str, commit: bool = True @@ -620,7 +676,6 @@ def _get_object( proto_field_name: str, not_found_exception: Optional[Callable], ): - self._maybe_init_project_metadata(project) with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT @@ -821,7 +876,6 @@ def _list_objects( proto_field_name: str, tags: Optional[dict[str, str]] = None, ): - self._maybe_init_project_metadata(project) with GetSnowflakeConnection(self.registry_config) as conn: query = f""" SELECT @@ -992,8 +1046,27 @@ def get_user_metadata( def proto(self) -> RegistryProto: r = RegistryProto() last_updated_timestamps = [] - projects = self._get_all_projects() - for project in projects: + + def process_project(project: Project): + nonlocal r, last_updated_timestamps + project_name = project.name + last_updated_timestamp = project.last_updated_timestamp + + try: + cached_project = self.get_project(project_name, True) + except ProjectObjectNotFoundException: + cached_project = None + + allow_cache = False + + if cached_project is not None: + allow_cache = ( + last_updated_timestamp <= cached_project.last_updated_timestamp + ) + + r.projects.extend([project.to_proto()]) + last_updated_timestamps.append(last_updated_timestamp) + for lister, registry_proto_field in [ (self.list_entities, r.entities), (self.list_feature_views, r.feature_views), @@ -1003,53 +1076,31 @@ def proto(self) -> RegistryProto: (self.list_feature_services, r.feature_services), (self.list_saved_datasets, r.saved_datasets), (self.list_validation_references, r.validation_references), - (self.list_project_metadata, r.project_metadata), (self.list_permissions, r.permissions), ]: - objs: List[Any] = lister(project) # type: ignore + objs: List[Any] = lister(project_name, allow_cache) # type: ignore if objs: obj_protos = [obj.to_proto() for obj in objs] for obj_proto in obj_protos: if "spec" in obj_proto.DESCRIPTOR.fields_by_name: - obj_proto.spec.project = project + obj_proto.spec.project = project_name else: - obj_proto.project = project + obj_proto.project = project_name registry_proto_field.extend(obj_protos) # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, # the registry proto only has a single infra field, which we're currently setting as the "last" project. - r.infra.CopyFrom(self.get_infra(project).to_proto()) - last_updated_timestamps.append(self._get_last_updated_metadata(project)) + r.infra.CopyFrom(self.get_infra(project_name).to_proto()) + + projects_list = self.list_projects(allow_cache=False) + for project in projects_list: + process_project(project) if last_updated_timestamps: r.last_updated.FromDatetime(max(last_updated_timestamps)) return r - def _get_all_projects(self) -> Set[str]: - projects = set() - - base_tables = [ - "DATA_SOURCES", - "ENTITIES", - "FEATURE_VIEWS", - "ON_DEMAND_FEATURE_VIEWS", - "STREAM_FEATURE_VIEWS", - "PERMISSIONS", - ] - - with GetSnowflakeConnection(self.registry_config) as conn: - for table in base_tables: - query = ( - f'SELECT DISTINCT project_id FROM {self.registry_path}."{table}"' - ) - df = execute_snowflake_statement(conn, query).fetch_pandas_all() - - for row in df.iterrows(): - projects.add(row[1]["PROJECT_ID"]) - - return projects - def _get_last_updated_metadata(self, project: str): with GetSnowflakeConnection(self.registry_config) as conn: query = f""" @@ -1153,3 +1204,98 @@ def _set_last_updated_metadata(self, last_updated: datetime, project: str): def commit(self): pass + + def apply_project( + self, + project: Project, + commit: bool = True, + ): + return self._apply_object( + "PROJECTS", project.name, "project_name", project, "project_proto" + ) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + project = self.get_project(name, allow_cache=False) + if project: + with GetSnowflakeConnection(self.registry_config) as conn: + for table in { + "MANAGED_INFRA", + "SAVED_DATASETS", + "VALIDATION_REFERENCES", + "FEATURE_SERVICES", + "FEATURE_VIEWS", + "ON_DEMAND_FEATURE_VIEWS", + "STREAM_FEATURE_VIEWS", + "DATA_SOURCES", + "ENTITIES", + "PERMISSIONS", + "PROJECTS", + }: + query = f""" + DELETE FROM {self.registry_path}."{table}" + WHERE + project_id = '{project}' + """ + execute_snowflake_statement(conn, query) + return + + raise ProjectNotFoundException(name) + + def _get_project( + self, + name: str, + ) -> Project: + return self._get_object( + table="PROJECTS", + name=name, + project=name, + proto_class=ProjectProto, + python_class=Project, + id_field_name="project_name", + proto_field_name="project_proto", + not_found_exception=ProjectObjectNotFoundException, + ) + + def get_project( + self, + name: str, + allow_cache: bool = False, + ) -> Project: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.get_project(self.cached_registry_proto, name) + return self._get_project(name) + + def _list_projects( + self, + tags: Optional[dict[str, str]], + ) -> List[Project]: + with GetSnowflakeConnection(self.registry_config) as conn: + query = f""" + SELECT project_proto FROM {self.registry_path}."PROJECTS" + """ + df = execute_snowflake_statement(conn, query).fetch_pandas_all() + if not df.empty: + objects = [] + for row in df.iterrows(): + obj = Project.from_proto( + ProjectProto.FromString(row[1]["project_proto"]) + ) + if has_all_tags(obj.tags, tags): + objects.append(obj) + return objects + return [] + + def list_projects( + self, + allow_cache: bool = False, + tags: Optional[dict[str, str]] = None, + ) -> List[Project]: + if allow_cache: + self._refresh_cached_registry_if_necessary() + return proto_registry_utils.list_projects(self.cached_registry_proto, tags) + return self._list_projects(tags) diff --git a/sdk/python/feast/infra/registry/sql.py b/sdk/python/feast/infra/registry/sql.py index 90c6e82e7d..2b4a58266c 100644 --- a/sdk/python/feast/infra/registry/sql.py +++ b/sdk/python/feast/infra/registry/sql.py @@ -1,11 +1,12 @@ import logging import uuid +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Union -from pydantic import StrictStr +from pydantic import StrictInt, StrictStr from sqlalchemy import ( # type: ignore BigInteger, Column, @@ -31,6 +32,8 @@ FeatureServiceNotFoundException, FeatureViewNotFoundException, PermissionNotFoundException, + ProjectNotFoundException, + ProjectObjectNotFoundException, SavedDatasetNotFound, ValidationReferenceNotFound, ) @@ -40,6 +43,7 @@ from feast.infra.registry.caching_registry import CachingRegistry from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.permission import Permission +from feast.project import Project from feast.project_metadata import ProjectMetadata from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.protos.feast.core.Entity_pb2 import Entity as EntityProto @@ -52,6 +56,7 @@ OnDemandFeatureView as OnDemandFeatureViewProto, ) from feast.protos.feast.core.Permission_pb2 import Permission as PermissionProto +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto from feast.protos.feast.core.StreamFeatureView_pb2 import ( @@ -67,11 +72,21 @@ metadata = MetaData() + +projects = Table( + "projects", + metadata, + Column("project_id", String(255), primary_key=True), + Column("project_name", String(255), nullable=False), + Column("last_updated_timestamp", BigInteger, nullable=False), + Column("project_proto", LargeBinary, nullable=False), +) + entities = Table( "entities", metadata, - Column("entity_name", String(50), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("entity_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("entity_proto", LargeBinary, nullable=False), ) @@ -80,7 +95,7 @@ "data_sources", metadata, Column("data_source_name", String(255), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("data_source_proto", LargeBinary, nullable=False), ) @@ -88,8 +103,8 @@ feature_views = Table( "feature_views", metadata, - Column("feature_view_name", String(50), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("feature_view_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("materialized_intervals", LargeBinary, nullable=True), Column("feature_view_proto", LargeBinary, nullable=False), @@ -99,8 +114,8 @@ stream_feature_views = Table( "stream_feature_views", metadata, - Column("feature_view_name", String(50), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("feature_view_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("feature_view_proto", LargeBinary, nullable=False), Column("user_metadata", LargeBinary, nullable=True), @@ -109,8 +124,8 @@ on_demand_feature_views = Table( "on_demand_feature_views", metadata, - Column("feature_view_name", String(50), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("feature_view_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("feature_view_proto", LargeBinary, nullable=False), Column("user_metadata", LargeBinary, nullable=True), @@ -119,8 +134,8 @@ feature_services = Table( "feature_services", metadata, - Column("feature_service_name", String(50), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("feature_service_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("feature_service_proto", LargeBinary, nullable=False), ) @@ -128,8 +143,8 @@ saved_datasets = Table( "saved_datasets", metadata, - Column("saved_dataset_name", String(50), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("saved_dataset_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("saved_dataset_proto", LargeBinary, nullable=False), ) @@ -137,8 +152,8 @@ validation_references = Table( "validation_references", metadata, - Column("validation_reference_name", String(50), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("validation_reference_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("validation_reference_proto", LargeBinary, nullable=False), ) @@ -146,8 +161,8 @@ managed_infra = Table( "managed_infra", metadata, - Column("infra_name", String(50), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("infra_name", String(255), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("infra_proto", LargeBinary, nullable=False), ) @@ -156,7 +171,7 @@ "permissions", metadata, Column("permission_name", String(255), primary_key=True), - Column("project_id", String(50), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("last_updated_timestamp", BigInteger, nullable=False), Column("permission_proto", LargeBinary, nullable=False), ) @@ -170,7 +185,7 @@ class FeastMetadataKeys(Enum): feast_metadata = Table( "feast_metadata", metadata, - Column("project_id", String(50), primary_key=True), + Column("project_id", String(255), primary_key=True), Column("metadata_key", String(50), primary_key=True), Column("metadata_value", String(50), nullable=False), Column("last_updated_timestamp", BigInteger, nullable=False), @@ -190,26 +205,75 @@ class SqlRegistryConfig(RegistryConfig): sqlalchemy_config_kwargs: Dict[str, Any] = {"echo": False} """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """ + cache_mode: StrictStr = "sync" + """ str: Cache mode type, Possible options are sync and thread(asynchronous caching using threading library)""" + + thread_pool_executor_worker_count: StrictInt = 0 + """ int: Number of worker threads to use for asynchronous caching in SQL Registry. If set to 0, it doesn't use ThreadPoolExecutor. """ + class SqlRegistry(CachingRegistry): def __init__( self, - registry_config: Optional[Union[RegistryConfig, SqlRegistryConfig]], + registry_config, project: str, repo_path: Optional[Path], ): - assert registry_config is not None, "SqlRegistry needs a valid registry_config" + assert registry_config is not None and isinstance( + registry_config, SqlRegistryConfig + ), "SqlRegistry needs a valid registry_config" self.engine: Engine = create_engine( registry_config.path, **registry_config.sqlalchemy_config_kwargs ) + self.thread_pool_executor_worker_count = ( + registry_config.thread_pool_executor_worker_count + ) metadata.create_all(self.engine) + self.purge_feast_metadata = registry_config.purge_feast_metadata + # Sync feast_metadata to projects table + # when purge_feast_metadata is set to True, Delete data from + # feast_metadata table and list_project_metadata will not return any data + self._sync_feast_metadata_to_projects_table() + if not self.purge_feast_metadata: + self._maybe_init_project_metadata(project) super().__init__( project=project, cache_ttl_seconds=registry_config.cache_ttl_seconds, cache_mode=registry_config.cache_mode, ) + def _sync_feast_metadata_to_projects_table(self): + feast_metadata_projects: set = [] + projects_set: set = [] + with self.engine.begin() as conn: + stmt = select(feast_metadata).where( + feast_metadata.c.metadata_key == FeastMetadataKeys.PROJECT_UUID.value + ) + rows = conn.execute(stmt).all() + for row in rows: + feast_metadata_projects.append(row._mapping["project_id"]) + + if len(feast_metadata_projects) > 0: + with self.engine.begin() as conn: + stmt = select(projects) + rows = conn.execute(stmt).all() + for row in rows: + projects_set.append(row._mapping["project_id"]) + + # Find object in feast_metadata_projects but not in projects + projects_to_sync = set(feast_metadata_projects) - set(projects_set) + for project_name in projects_to_sync: + self.apply_project(Project(name=project_name), commit=True) + + if self.purge_feast_metadata: + with self.engine.begin() as conn: + for project_name in feast_metadata_projects: + stmt = delete(feast_metadata).where( + feast_metadata.c.project_id == project_name + ) + conn.execute(stmt) + def teardown(self): for t in { entities, @@ -673,8 +737,27 @@ def get_user_metadata( def proto(self) -> RegistryProto: r = RegistryProto() last_updated_timestamps = [] - projects = self._get_all_projects() - for project in projects: + + def process_project(project: Project): + nonlocal r, last_updated_timestamps + project_name = project.name + last_updated_timestamp = project.last_updated_timestamp + + try: + cached_project = self.get_project(project_name, True) + except ProjectObjectNotFoundException: + cached_project = None + + allow_cache = False + + if cached_project is not None: + allow_cache = ( + last_updated_timestamp <= cached_project.last_updated_timestamp + ) + + r.projects.extend([project.to_proto()]) + last_updated_timestamps.append(last_updated_timestamp) + for lister, registry_proto_field in [ (self.list_entities, r.entities), (self.list_feature_views, r.feature_views), @@ -684,23 +767,31 @@ def proto(self) -> RegistryProto: (self.list_feature_services, r.feature_services), (self.list_saved_datasets, r.saved_datasets), (self.list_validation_references, r.validation_references), - (self.list_project_metadata, r.project_metadata), (self.list_permissions, r.permissions), ]: - objs: List[Any] = lister(project) # type: ignore + objs: List[Any] = lister(project_name, allow_cache) # type: ignore if objs: obj_protos = [obj.to_proto() for obj in objs] for obj_proto in obj_protos: if "spec" in obj_proto.DESCRIPTOR.fields_by_name: - obj_proto.spec.project = project + obj_proto.spec.project = project_name else: - obj_proto.project = project + obj_proto.project = project_name registry_proto_field.extend(obj_protos) # This is suuuper jank. Because of https://github.com/feast-dev/feast/issues/2783, # the registry proto only has a single infra field, which we're currently setting as the "last" project. - r.infra.CopyFrom(self.get_infra(project).to_proto()) - last_updated_timestamps.append(self._get_last_updated_metadata(project)) + r.infra.CopyFrom(self.get_infra(project_name).to_proto()) + + projects_list = self.list_projects(allow_cache=False) + if self.thread_pool_executor_worker_count == 0: + for project in projects_list: + process_project(project) + else: + with ThreadPoolExecutor( + max_workers=self.thread_pool_executor_worker_count + ) as executor: + executor.map(process_project, projects_list) if last_updated_timestamps: r.last_updated.FromDatetime(max(last_updated_timestamps)) @@ -711,6 +802,17 @@ def commit(self): # This method is a no-op since we're always writing values eagerly to the db. pass + def _initialize_project_if_not_exists(self, project_name: str): + try: + self.get_project(project_name, allow_cache=True) + return + except ProjectObjectNotFoundException: + try: + self.get_project(project_name, allow_cache=False) + return + except ProjectObjectNotFoundException: + self.apply_project(Project(name=project_name), commit=True) + def _apply_object( self, table: Table, @@ -720,8 +822,11 @@ def _apply_object( proto_field_name: str, name: Optional[str] = None, ): - self._maybe_init_project_metadata(project) - + if not self.purge_feast_metadata: + self._maybe_init_project_metadata(project) + # Initialize project is necessary because FeatureStore object can apply objects individually without "feast apply" cli option + if not isinstance(obj, Project): + self._initialize_project_if_not_exists(project_name=project) name = name or (obj.name if hasattr(obj, "name") else None) assert name, f"name needs to be provided for {obj}" @@ -742,12 +847,15 @@ def _apply_object( "feature_view_proto", "feature_service_proto", "permission_proto", + "project_proto", ]: deserialized_proto = self.deserialize_registry_values( row._mapping[proto_field_name], type(obj) ) obj.created_timestamp = ( - deserialized_proto.meta.created_timestamp.ToDatetime() + deserialized_proto.meta.created_timestamp.ToDatetime().replace( + tzinfo=timezone.utc + ) ) if isinstance(obj, (FeatureView, StreamFeatureView)): obj.update_materialization_intervals( @@ -789,7 +897,12 @@ def _apply_object( ) conn.execute(insert_stmt) - self._set_last_updated_metadata(update_datetime, project) + if not isinstance(obj, Project): + self.apply_project( + self.get_project(name=project, allow_cache=False), commit=True + ) + if not self.purge_feast_metadata: + self._set_last_updated_metadata(update_datetime, project) def _maybe_init_project_metadata(self, project): # Initialize project metadata if needed @@ -827,7 +940,11 @@ def _delete_object( rows = conn.execute(stmt) if rows.rowcount < 1 and not_found_exception: raise not_found_exception(name, project) - self._set_last_updated_metadata(_utc_now(), project) + self.apply_project( + self.get_project(name=project, allow_cache=False), commit=True + ) + if not self.purge_feast_metadata: + self._set_last_updated_metadata(_utc_now(), project) return rows.rowcount @@ -842,8 +959,6 @@ def _get_object( proto_field_name: str, not_found_exception: Optional[Callable], ): - self._maybe_init_project_metadata(project) - with self.engine.begin() as conn: stmt = select(table).where( getattr(table.c, id_field_name) == name, table.c.project_id == project @@ -866,7 +981,6 @@ def _list_objects( proto_field_name: str, tags: Optional[dict[str, str]] = None, ): - self._maybe_init_project_metadata(project) with self.engine.begin() as conn: stmt = select(table).where(table.c.project_id == project) rows = conn.execute(stmt).all() @@ -929,24 +1043,6 @@ def _get_last_updated_metadata(self, project: str): return datetime.fromtimestamp(update_time, tz=timezone.utc) - def _get_all_projects(self) -> Set[str]: - projects = set() - with self.engine.begin() as conn: - for table in { - entities, - data_sources, - feature_views, - on_demand_feature_views, - stream_feature_views, - permissions, - }: - stmt = select(table) - rows = conn.execute(stmt).all() - for row in rows: - projects.add(row._mapping["project_id"]) - - return projects - def _get_permission(self, name: str, project: str) -> Permission: return self._get_object( table=permissions, @@ -987,3 +1083,72 @@ def delete_permission(self, name: str, project: str, commit: bool = True): rows = conn.execute(stmt) if rows.rowcount < 1: raise PermissionNotFoundException(name, project) + + def _list_projects( + self, + tags: Optional[dict[str, str]], + ) -> List[Project]: + with self.engine.begin() as conn: + stmt = select(projects) + rows = conn.execute(stmt).all() + if rows: + objects = [] + for row in rows: + obj = Project.from_proto( + ProjectProto.FromString(row._mapping["project_proto"]) + ) + if utils.has_all_tags(obj.tags, tags): + objects.append(obj) + return objects + return [] + + def _get_project( + self, + name: str, + ) -> Project: + return self._get_object( + table=projects, + name=name, + project=name, + proto_class=ProjectProto, + python_class=Project, + id_field_name="project_name", + proto_field_name="project_proto", + not_found_exception=ProjectObjectNotFoundException, + ) + + def apply_project( + self, + project: Project, + commit: bool = True, + ): + return self._apply_object( + projects, project.name, "project_name", project, "project_proto" + ) + + def delete_project( + self, + name: str, + commit: bool = True, + ): + project = self.get_project(name, allow_cache=False) + if project: + with self.engine.begin() as conn: + for t in { + managed_infra, + saved_datasets, + validation_references, + feature_services, + feature_views, + on_demand_feature_views, + stream_feature_views, + data_sources, + entities, + permissions, + projects, + }: + stmt = delete(t).where(t.c.project_id == name) + conn.execute(stmt) + return + + raise ProjectNotFoundException(name) diff --git a/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_creation.sql b/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_creation.sql index 021d175b4e..fc13332e4b 100644 --- a/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_creation.sql +++ b/sdk/python/feast/infra/utils/snowflake/registry/snowflake_table_creation.sql @@ -1,3 +1,11 @@ +CREATE TABLE IF NOT EXISTS REGISTRY_PATH."PROJECTS" ( + project_id VARCHAR, + project_name VARCHAR NOT NULL, + last_updated_timestamp TIMESTAMP_LTZ NOT NULL, + project_proto BINARY NOT NULL, + PRIMARY KEY (project_id) +); + CREATE TABLE IF NOT EXISTS REGISTRY_PATH."DATA_SOURCES" ( data_source_name VARCHAR, project_id VARCHAR, diff --git a/sdk/python/feast/permissions/permission.py b/sdk/python/feast/permissions/permission.py index 1117a3ee82..9046abbfa9 100644 --- a/sdk/python/feast/permissions/permission.py +++ b/sdk/python/feast/permissions/permission.py @@ -256,6 +256,7 @@ def get_type_class_from_permission_type(permission_type: str): _PERMISSION_TYPES = { + "PROJECT": "feast.project.Project", "FEATURE_VIEW": "feast.feature_view.FeatureView", "ON_DEMAND_FEATURE_VIEW": "feast.on_demand_feature_view.OnDemandFeatureView", "BATCH_FEATURE_VIEW": "feast.batch_feature_view.BatchFeatureView", diff --git a/sdk/python/feast/permissions/security_manager.py b/sdk/python/feast/permissions/security_manager.py index c00a3d8853..cb8cafd5b9 100644 --- a/sdk/python/feast/permissions/security_manager.py +++ b/sdk/python/feast/permissions/security_manager.py @@ -10,6 +10,7 @@ from feast.permissions.enforcer import enforce_policy from feast.permissions.permission import Permission from feast.permissions.user import User +from feast.project import Project logger = logging.getLogger(__name__) @@ -88,7 +89,9 @@ def assert_permissions( def assert_permissions_to_update( resource: FeastObject, - getter: Callable[[str, str, bool], FeastObject], + getter: Union[ + Callable[[str, str, bool], FeastObject], Callable[[str, bool], FeastObject] + ], project: str, allow_cache: bool = True, ) -> FeastObject: @@ -117,11 +120,17 @@ def assert_permissions_to_update( actions = [AuthzedAction.DESCRIBE, AuthzedAction.UPDATE] try: - existing_resource = getter( - name=resource.name, - project=project, - allow_cache=allow_cache, - ) # type: ignore[call-arg] + if isinstance(resource, Project): + existing_resource = getter( + name=resource.name, + allow_cache=allow_cache, + ) # type: ignore[call-arg] + else: + existing_resource = getter( + name=resource.name, + project=project, + allow_cache=allow_cache, + ) # type: ignore[call-arg] assert_permissions(resource=existing_resource, actions=actions) except FeastObjectNotFoundException: actions = [AuthzedAction.CREATE] diff --git a/sdk/python/feast/project.py b/sdk/python/feast/project.py new file mode 100644 index 0000000000..d9ec45dcc9 --- /dev/null +++ b/sdk/python/feast/project.py @@ -0,0 +1,175 @@ +# Copyright 2019 The Feast Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from datetime import datetime, timezone +from typing import Dict, Optional + +from google.protobuf.json_format import MessageToJson +from typeguard import typechecked + +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto +from feast.protos.feast.core.Project_pb2 import ProjectMeta as ProjectMetaProto +from feast.protos.feast.core.Project_pb2 import ProjectSpec as ProjectSpecProto +from feast.utils import _utc_now + + +@typechecked +class Project: + """ + Project is a collection of Feast Objects. Projects provide complete isolation of + feature stores at the infrastructure level. + + Attributes: + name: The unique name of the project. + description: A human-readable description. + tags: A dictionary of key-value pairs to store arbitrary metadata. + owner: The owner of the project, typically the email of the primary maintainer. + created_timestamp: The time when the entity was created. + last_updated_timestamp: The time when the entity was last updated. + """ + + name: str + description: str + tags: Dict[str, str] + owner: str + created_timestamp: datetime + last_updated_timestamp: datetime + + def __init__( + self, + *, + name: str, + description: str = "", + tags: Optional[Dict[str, str]] = None, + owner: str = "", + created_timestamp: Optional[datetime] = None, + last_updated_timestamp: Optional[datetime] = None, + ): + """ + Creates Project object. + + Args: + name: The unique name of the project. + description (optional): A human-readable description. + tags (optional): A dictionary of key-value pairs to store arbitrary metadata. + owner (optional): The owner of the project, typically the email of the primary maintainer. + created_timestamp (optional): The time when the project was created. Defaults to + last_updated_timestamp (optional): The time when the project was last updated. + + Raises: + ValueError: Parameters are specified incorrectly. + """ + self.name = name + self.description = description + self.tags = tags if tags is not None else {} + self.owner = owner + updated_time = _utc_now() + self.created_timestamp = created_timestamp or updated_time + self.last_updated_timestamp = last_updated_timestamp or updated_time + + def __hash__(self) -> int: + return hash((self.name)) + + def __eq__(self, other): + if not isinstance(other, Project): + raise TypeError("Comparisons should only involve Project class objects.") + + if ( + self.name != other.name + or self.description != other.description + or self.tags != other.tags + or self.owner != other.owner + or self.created_timestamp != other.created_timestamp + or self.last_updated_timestamp != other.last_updated_timestamp + ): + return False + + return True + + def __str__(self): + return str(MessageToJson(self.to_proto())) + + def __lt__(self, other): + return self.name < other.name + + def is_valid(self): + """ + Validates the state of this project locally. + + Raises: + ValueError: The project does not have a name or does not have a type. + """ + if not self.name: + raise ValueError("The project does not have a name.") + + from feast.repo_operations import is_valid_name + + if not is_valid_name(self.name): + raise ValueError( + f"Project name, {self.name}, should only have " + f"alphanumerical values and underscores but not start with an underscore." + ) + + @classmethod + def from_proto(cls, project_proto: ProjectProto): + """ + Creates a project from a protobuf representation of an project. + + Args: + entity_proto: A protobuf representation of an project. + + Returns: + An Entity object based on the entity protobuf. + """ + project = cls( + name=project_proto.spec.name, + description=project_proto.spec.description, + tags=dict(project_proto.spec.tags), + owner=project_proto.spec.owner, + ) + if project_proto.meta.HasField("created_timestamp"): + project.created_timestamp = ( + project_proto.meta.created_timestamp.ToDatetime().replace( + tzinfo=timezone.utc + ) + ) + if project_proto.meta.HasField("last_updated_timestamp"): + project.last_updated_timestamp = ( + project_proto.meta.last_updated_timestamp.ToDatetime().replace( + tzinfo=timezone.utc + ) + ) + + return project + + def to_proto(self) -> ProjectProto: + """ + Converts an project object to its protobuf representation. + + Returns: + An ProjectProto protobuf. + """ + meta = ProjectMetaProto() + if self.created_timestamp: + meta.created_timestamp.FromDatetime(self.created_timestamp) + if self.last_updated_timestamp: + meta.last_updated_timestamp.FromDatetime(self.last_updated_timestamp) + + spec = ProjectSpecProto( + name=self.name, + description=self.description, + tags=self.tags, + owner=self.owner, + ) + + return ProjectProto(spec=spec, meta=meta) diff --git a/sdk/python/feast/registry_server.py b/sdk/python/feast/registry_server.py index 40475aa580..2661f25882 100644 --- a/sdk/python/feast/registry_server.py +++ b/sdk/python/feast/registry_server.py @@ -32,6 +32,7 @@ init_security_manager, str_to_auth_manager_type, ) +from feast.project import Project from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc from feast.saved_dataset import SavedDataset, ValidationReference from feast.stream_feature_view import StreamFeatureView @@ -624,6 +625,58 @@ def DeletePermission( ) return Empty() + def ApplyProject(self, request: RegistryServer_pb2.ApplyProjectRequest, context): + project = cast( + Project, + assert_permissions_to_update( + resource=Project.from_proto(request.project), + getter=self.proxied_registry.get_project, + project=Project.from_proto(request.project).name, + ), + ) + self.proxied_registry.apply_project( + project=project, + commit=request.commit, + ) + return Empty() + + def GetProject(self, request: RegistryServer_pb2.GetProjectRequest, context): + project = self.proxied_registry.get_project( + name=request.name, allow_cache=request.allow_cache + ) + assert_permissions( + resource=project, + actions=[AuthzedAction.DESCRIBE], + ) + return project.to_proto() + + def ListProjects(self, request: RegistryServer_pb2.ListProjectsRequest, context): + return RegistryServer_pb2.ListProjectsResponse( + projects=[ + project.to_proto() + for project in permitted_resources( + resources=cast( + list[FeastObject], + self.proxied_registry.list_projects( + allow_cache=request.allow_cache + ), + ), + actions=AuthzedAction.DESCRIBE, + ) + ] + ) + + def DeleteProject(self, request: RegistryServer_pb2.DeleteProjectRequest, context): + assert_permissions( + resource=self.proxied_registry.get_project( + name=request.name, + ), + actions=[AuthzedAction.DELETE], + ) + + self.proxied_registry.delete_project(name=request.name, commit=request.commit) + return Empty() + def Commit(self, request, context): self.proxied_registry.commit() return Empty() diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 52372f2987..bf0bde6fcb 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -9,6 +9,7 @@ BaseModel, ConfigDict, Field, + StrictBool, StrictInt, StrictStr, ValidationError, @@ -132,11 +133,10 @@ class RegistryConfig(FeastBaseModel): s3_additional_kwargs: Optional[Dict[str, str]] = None """ Dict[str, str]: Extra arguments to pass to boto3 when writing the registry file to S3. """ - sqlalchemy_config_kwargs: Dict[str, Any] = {} - """ Dict[str, Any]: Extra arguments to pass to SQLAlchemy.create_engine. """ - - cache_mode: StrictStr = "sync" - """ str: Cache mode type, Possible options are sync and thread(asynchronous caching using threading library)""" + purge_feast_metadata: StrictBool = False + """ bool: Stops using feast_metadata table and delete data from feast_metadata table. + Once this is set to True, it cannot be reverted back to False. Reverting back to False will + only reset the project but not all the projects""" @field_validator("path") def validate_path(cls, path: str, values: ValidationInfo) -> str: diff --git a/sdk/python/feast/repo_contents.py b/sdk/python/feast/repo_contents.py index 9893d5be4e..d65f6ac7bb 100644 --- a/sdk/python/feast/repo_contents.py +++ b/sdk/python/feast/repo_contents.py @@ -19,6 +19,7 @@ from feast.feature_view import FeatureView from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.permission import Permission +from feast.project import Project from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto from feast.stream_feature_view import StreamFeatureView @@ -28,6 +29,7 @@ class RepoContents(NamedTuple): Represents the objects in a Feast feature repo. """ + projects: List[Project] data_sources: List[DataSource] feature_views: List[FeatureView] on_demand_feature_views: List[OnDemandFeatureView] @@ -38,6 +40,7 @@ class RepoContents(NamedTuple): def to_registry_proto(self) -> RegistryProto: registry_proto = RegistryProto() + registry_proto.projects.extend([e.to_proto() for e in self.projects]) registry_proto.data_sources.extend([e.to_proto() for e in self.data_sources]) registry_proto.entities.extend([e.to_proto() for e in self.entities]) registry_proto.feature_views.extend( diff --git a/sdk/python/feast/repo_operations.py b/sdk/python/feast/repo_operations.py index cb27568957..6629768375 100644 --- a/sdk/python/feast/repo_operations.py +++ b/sdk/python/feast/repo_operations.py @@ -1,6 +1,7 @@ import base64 import importlib import json +import logging import os import random import re @@ -24,14 +25,18 @@ from feast.feature_store import FeatureStore from feast.feature_view import DUMMY_ENTITY, FeatureView from feast.file_utils import replace_str_in_file +from feast.infra.registry.base_registry import BaseRegistry from feast.infra.registry.registry import FEAST_OBJECT_TYPES, FeastObjectType, Registry from feast.names import adjectives, animals from feast.on_demand_feature_view import OnDemandFeatureView from feast.permissions.permission import Permission +from feast.project import Project from feast.repo_config import RepoConfig from feast.repo_contents import RepoContents from feast.stream_feature_view import StreamFeatureView +logger = logging.getLogger(__name__) + def py_path_to_module(path: Path) -> str: return ( @@ -115,6 +120,7 @@ def parse_repo(repo_root: Path) -> RepoContents: not result in duplicates, but defining two equal objects will. """ res = RepoContents( + projects=[], data_sources=[], entities=[], feature_views=[], @@ -207,6 +213,8 @@ def parse_repo(repo_root: Path) -> RepoContents: (obj is p) for p in res.permissions ): res.permissions.append(obj) + elif isinstance(obj, Project) and not any((obj is p) for p in res.projects): + res.projects.append(obj) res.entities.append(DUMMY_ENTITY) return res @@ -214,33 +222,57 @@ def parse_repo(repo_root: Path) -> RepoContents: def plan(repo_config: RepoConfig, repo_path: Path, skip_source_validation: bool): os.chdir(repo_path) - project, registry, repo, store = _prepare_registry_and_repo(repo_config, repo_path) - - if not skip_source_validation: - provider = store._get_provider() - data_sources = [t.batch_source for t in repo.feature_views] - # Make sure the data source used by this feature view is supported by Feast - for data_source in data_sources: - provider.validate_data_source(store.config, data_source) + repo = _get_repo_contents(repo_path, repo_config.project) + for project in repo.projects: + repo_config.project = project.name + store, registry = _get_store_and_registry(repo_config) + # TODO: When we support multiple projects in a single repo, we should filter repo contents by project + if not skip_source_validation: + provider = store._get_provider() + data_sources = [t.batch_source for t in repo.feature_views] + # Make sure the data source used by this feature view is supported by Feast + for data_source in data_sources: + provider.validate_data_source(store.config, data_source) + + registry_diff, infra_diff, _ = store.plan(repo) + click.echo(registry_diff.to_string()) + click.echo(infra_diff.to_string()) - registry_diff, infra_diff, _ = store.plan(repo) - click.echo(registry_diff.to_string()) - click.echo(infra_diff.to_string()) +def _get_repo_contents(repo_path, project_name: Optional[str] = None): + sys.dont_write_bytecode = True + repo = parse_repo(repo_path) -def _prepare_registry_and_repo(repo_config, repo_path): - store = FeatureStore(config=repo_config) - project = store.project - if not is_valid_name(project): + if len(repo.projects) < 1: + if project_name: + print( + f"No project found in the repository. Using project name {project_name} defined in feature_store.yaml" + ) + repo.projects.append(Project(name=project_name)) + else: + print( + "No project found in the repository. Either define Project in repository or define a project in feature_store.yaml" + ) + sys.exit(1) + elif len(repo.projects) == 1: + if repo.projects[0].name != project_name: + print( + "Project object name should match with the project name defined in feature_store.yaml" + ) + sys.exit(1) + else: print( - f"{project} is not valid. Project name should only have " - f"alphanumerical values and underscores but not start with an underscore." + "Multiple projects found in the repository. Currently no support for multiple projects" ) sys.exit(1) + + return repo + + +def _get_store_and_registry(repo_config): + store = FeatureStore(config=repo_config) registry = store.registry - sys.dont_write_bytecode = True - repo = parse_repo(repo_path) - return project, registry, repo, store + return store, registry def extract_objects_for_apply_delete(project, registry, repo): @@ -289,8 +321,8 @@ def extract_objects_for_apply_delete(project, registry, repo): def apply_total_with_repo_instance( store: FeatureStore, - project: str, - registry: Registry, + project_name: str, + registry: BaseRegistry, repo: RepoContents, skip_source_validation: bool, ): @@ -307,7 +339,7 @@ def apply_total_with_repo_instance( all_to_delete, views_to_keep, views_to_delete, - ) = extract_objects_for_apply_delete(project, registry, repo) + ) = extract_objects_for_apply_delete(project_name, registry, repo) if store._should_use_plan(): registry_diff, infra_diff, new_infra = store.plan(repo) @@ -357,10 +389,21 @@ def create_feature_store( def apply_total(repo_config: RepoConfig, repo_path: Path, skip_source_validation: bool): os.chdir(repo_path) - project, registry, repo, store = _prepare_registry_and_repo(repo_config, repo_path) - apply_total_with_repo_instance( - store, project, registry, repo, skip_source_validation - ) + repo = _get_repo_contents(repo_path, repo_config.project) + for project in repo.projects: + repo_config.project = project.name + store, registry = _get_store_and_registry(repo_config) + if not is_valid_name(project.name): + print( + f"{project.name} is not valid. Project name should only have " + f"alphanumerical values and underscores but not start with an underscore." + ) + sys.exit(1) + # TODO: When we support multiple projects in a single repo, we should filter repo contents by project. Currently there is no way to associate Feast objects to project. + print(f"Applying changes for project {project.name}") + apply_total_with_repo_instance( + store, project.name, registry, repo, skip_source_validation + ) def teardown(repo_config: RepoConfig, repo_path: Optional[str]): diff --git a/sdk/python/feast/templates/local/bootstrap.py b/sdk/python/feast/templates/local/bootstrap.py index ee2847c19c..e2c1efdbc4 100644 --- a/sdk/python/feast/templates/local/bootstrap.py +++ b/sdk/python/feast/templates/local/bootstrap.py @@ -10,6 +10,7 @@ def bootstrap(): from feast.driver_test_data import create_driver_hourly_stats_df repo_path = pathlib.Path(__file__).parent.absolute() / "feature_repo" + project_name = pathlib.Path(__file__).parent.absolute().name data_path = repo_path / "data" data_path.mkdir(exist_ok=True) @@ -23,6 +24,7 @@ def bootstrap(): driver_df.to_parquet(path=str(driver_stats_path), allow_truncated_timestamps=True) example_py_file = repo_path / "example_repo.py" + replace_str_in_file(example_py_file, "%PROJECT_NAME%", str(project_name)) replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path)) replace_str_in_file(example_py_file, "%LOGGING_PATH%", str(data_path)) diff --git a/sdk/python/feast/templates/local/feature_repo/example_repo.py b/sdk/python/feast/templates/local/feature_repo/example_repo.py index debe9d45e9..e2fd0a891c 100644 --- a/sdk/python/feast/templates/local/feature_repo/example_repo.py +++ b/sdk/python/feast/templates/local/feature_repo/example_repo.py @@ -10,6 +10,7 @@ FeatureView, Field, FileSource, + Project, PushSource, RequestSource, ) @@ -18,6 +19,9 @@ from feast.on_demand_feature_view import on_demand_feature_view from feast.types import Float32, Float64, Int64 +# Define a project for the feature repo +project = Project(name="%PROJECT_NAME%", description="A project for driver statistics") + # Define an entity for the driver. You can think of an entity as a primary key used to # fetch features. driver = Entity(name="driver", join_keys=["driver_id"]) diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index 5e70da074c..a9bb9ba9c4 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -35,8 +35,8 @@ create_basic_driver_dataset, # noqa: E402 create_document_dataset, ) -from tests.integration.feature_repos.integration_test_repo_config import ( - IntegrationTestRepoConfig, # noqa: E402 +from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402 + IntegrationTestRepoConfig, ) from tests.integration.feature_repos.repo_configuration import ( # noqa: E402 AVAILABLE_OFFLINE_STORES, @@ -48,8 +48,8 @@ construct_universal_feature_views, construct_universal_test_data, ) -from tests.integration.feature_repos.universal.data_sources.file import ( - FileDataSourceCreator, # noqa: E402 +from tests.integration.feature_repos.universal.data_sources.file import ( # noqa: E402 + FileDataSourceCreator, ) from tests.integration.feature_repos.universal.entities import ( # noqa: E402 customer, @@ -451,15 +451,20 @@ def is_integration_test(all_markers_from_module): @pytest.fixture( scope="module", params=[ - dedent(""" + dedent( + """ auth: type: no_auth - """), - dedent(""" + """ + ), + dedent( + """ auth: type: kubernetes - """), - dedent(""" + """ + ), + dedent( + """ auth: type: oidc client_id: feast-integration-client @@ -467,7 +472,8 @@ def is_integration_test(all_markers_from_module): username: reader_writer password: password auth_discovery_url: KEYCLOAK_URL_PLACE_HOLDER/realms/master/.well-known/openid-configuration - """), + """ + ), ], ) def auth_config(request, is_integration_test): diff --git a/sdk/python/tests/example_repos/example_feature_repo_with_project_1.py b/sdk/python/tests/example_repos/example_feature_repo_with_project_1.py new file mode 100644 index 0000000000..ad04d7ae66 --- /dev/null +++ b/sdk/python/tests/example_repos/example_feature_repo_with_project_1.py @@ -0,0 +1,151 @@ +from datetime import timedelta + +import pandas as pd + +from feast import Entity, FeatureService, FeatureView, Field, FileSource, PushSource +from feast.on_demand_feature_view import on_demand_feature_view +from feast.project import Project +from feast.types import Array, Float32, Int64, String +from tests.integration.feature_repos.universal.feature_views import TAGS + +# Note that file source paths are not validated, so there doesn't actually need to be any data +# at the paths for these file sources. Since these paths are effectively fake, this example +# feature repo should not be used for historical retrieval. +project = Project( + name="test_universal_cli_with_project_4567", + description="test_universal_cli_with_project_4567 description", + tags={"application": "integration"}, + owner="test@test.com", +) + +driver_locations_source = FileSource( + path="data/driver_locations.parquet", + timestamp_field="event_timestamp", + created_timestamp_column="created_timestamp", +) + +customer_profile_source = FileSource( + name="customer_profile_source", + path="data/customer_profiles.parquet", + timestamp_field="event_timestamp", +) + +customer_driver_combined_source = FileSource( + path="data/customer_driver_combined.parquet", + timestamp_field="event_timestamp", +) + +driver_locations_push_source = PushSource( + name="driver_locations_push", + batch_source=driver_locations_source, +) + +rag_documents_source = FileSource( + name="rag_documents_source", + path="data/rag_documents.parquet", + timestamp_field="event_timestamp", +) + +driver = Entity( + name="driver", # The name is derived from this argument, not object name. + join_keys=["driver_id"], + description="driver id", + tags=TAGS, +) + +customer = Entity( + name="customer", # The name is derived from this argument, not object name. + join_keys=["customer_id"], + tags=TAGS, +) + +item = Entity( + name="item_id", # The name is derived from this argument, not object name. + join_keys=["item_id"], +) + +driver_locations = FeatureView( + name="driver_locations", + entities=[driver], + ttl=timedelta(days=1), + schema=[ + Field(name="lat", dtype=Float32), + Field(name="lon", dtype=String), + Field(name="driver_id", dtype=Int64), + ], + online=True, + source=driver_locations_source, + tags={}, +) + +pushed_driver_locations = FeatureView( + name="pushed_driver_locations", + entities=[driver], + ttl=timedelta(days=1), + schema=[ + Field(name="driver_lat", dtype=Float32), + Field(name="driver_long", dtype=String), + Field(name="driver_id", dtype=Int64), + ], + online=True, + source=driver_locations_push_source, + tags={}, +) + +customer_profile = FeatureView( + name="customer_profile", + entities=[customer], + ttl=timedelta(days=1), + schema=[ + Field(name="avg_orders_day", dtype=Float32), + Field(name="name", dtype=String), + Field(name="age", dtype=Int64), + Field(name="customer_id", dtype=String), + ], + online=True, + source=customer_profile_source, + tags={}, +) + +customer_driver_combined = FeatureView( + name="customer_driver_combined", + entities=[customer, driver], + ttl=timedelta(days=1), + schema=[ + Field(name="trips", dtype=Int64), + Field(name="driver_id", dtype=Int64), + Field(name="customer_id", dtype=String), + ], + online=True, + source=customer_driver_combined_source, + tags={}, +) + +document_embeddings = FeatureView( + name="document_embeddings", + entities=[item], + schema=[ + Field(name="Embeddings", dtype=Array(Float32)), + Field(name="item_id", dtype=String), + ], + source=rag_documents_source, + ttl=timedelta(hours=24), +) + + +@on_demand_feature_view( + sources=[customer_profile], + schema=[Field(name="on_demand_age", dtype=Int64)], + mode="pandas", +) +def customer_profile_pandas_odfv(inputs: pd.DataFrame) -> pd.DataFrame: + outputs = pd.DataFrame() + outputs["on_demand_age"] = inputs["age"] + 1 + return outputs + + +all_drivers_feature_service = FeatureService( + name="driver_locations_service", + features=[driver_locations], + tags=TAGS, +) diff --git a/sdk/python/tests/integration/online_store/test_remote_online_store.py b/sdk/python/tests/integration/online_store/test_remote_online_store.py index f74fb14a86..d8c92077db 100644 --- a/sdk/python/tests/integration/online_store/test_remote_online_store.py +++ b/sdk/python/tests/integration/online_store/test_remote_online_store.py @@ -187,9 +187,6 @@ def _create_remote_client_feature_store( auth_config=auth_config, ) - result = runner.run(["--chdir", repo_path, "apply"], cwd=temp_dir) - assert result.returncode == 0 - return FeatureStore(repo_path=repo_path) diff --git a/sdk/python/tests/integration/registration/test_universal_cli.py b/sdk/python/tests/integration/registration/test_universal_cli.py index 5c238da24d..735f71407f 100644 --- a/sdk/python/tests/integration/registration/test_universal_cli.py +++ b/sdk/python/tests/integration/registration/test_universal_cli.py @@ -52,7 +52,9 @@ def test_universal_cli(): for key, value in registry_dict.items() } - # entity & feature view list commands should succeed + # project, entity & feature view list commands should succeed + result = runner.run(["projects", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) result = runner.run(["entities", "list"], cwd=repo_path) assertpy.assert_that(result.returncode).is_equal_to(0) result = runner.run(["feature-views", "list"], cwd=repo_path) @@ -71,6 +73,10 @@ def test_universal_cli(): assertpy.assert_that(result.returncode).is_equal_to(0) # entity & feature view describe commands should succeed when objects exist + result = runner.run(["projects", "describe", project], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["projects", "current_project"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) result = runner.run(["entities", "describe", "driver"], cwd=repo_path) assertpy.assert_that(result.returncode).is_equal_to(0) result = runner.run( @@ -89,8 +95,132 @@ def test_universal_cli(): ) assertpy.assert_that(result.returncode).is_equal_to(0) assertpy.assert_that(fs.list_data_sources()).is_length(5) + assertpy.assert_that(fs.list_projects()).is_length(1) # entity & feature view describe commands should fail when objects don't exist + result = runner.run(["projects", "describe", "foo"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(1) + result = runner.run(["entities", "describe", "foo"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(1) + result = runner.run(["feature-views", "describe", "foo"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(1) + result = runner.run(["feature-services", "describe", "foo"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(1) + result = runner.run(["data-sources", "describe", "foo"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(1) + result = runner.run(["permissions", "describe", "foo"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(1) + + # Doing another apply should be a no op, and should not cause errors + result = runner.run(["apply"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + basic_rw_test( + FeatureStore(repo_path=str(repo_path), config=None), + view_name="driver_locations", + ) + + # Confirm that registry contents have not changed. + registry_dict = fs.registry.to_dict(project=project) + assertpy.assert_that(registry_specs).is_equal_to( + { + key: [fco["spec"] if "spec" in fco else fco for fco in value] + for key, value in registry_dict.items() + } + ) + + result = runner.run(["teardown"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + finally: + runner.run(["teardown"], cwd=repo_path) + + +@pytest.mark.integration +def test_universal_cli_with_project(): + project = "test_universal_cli_with_project_4567" + runner = CliRunner() + + with tempfile.TemporaryDirectory() as repo_dir_name: + try: + repo_path = Path(repo_dir_name) + feature_store_yaml = make_feature_store_yaml( + project, + repo_path, + FileDataSourceCreator("project"), + "local", + {"type": "sqlite"}, + ) + + repo_config = repo_path / "feature_store.yaml" + + repo_config.write_text(dedent(feature_store_yaml)) + + repo_example = repo_path / "example.py" + repo_example.write_text( + get_example_repo("example_feature_repo_with_project_1.py") + ) + result = runner.run(["apply"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + + # Store registry contents, to be compared later. + fs = FeatureStore(repo_path=str(repo_path)) + registry_dict = fs.registry.to_dict(project=project) + # Save only the specs, not the metadata. + registry_specs = { + key: [fco["spec"] if "spec" in fco else fco for fco in value] + for key, value in registry_dict.items() + } + + # entity & feature view list commands should succeed + result = runner.run(["projects", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["entities", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["feature-views", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["feature-services", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["data-sources", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["permissions", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + + # entity & feature view describe commands should succeed when objects exist + result = runner.run(["projects", "describe", project], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["projects", "current_project"], cwd=repo_path) + print(result.returncode) + print("result: ", result) + print("result.stdout: ", result.stdout) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["entities", "describe", "driver"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run( + ["feature-views", "describe", "driver_locations"], cwd=repo_path + ) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run( + ["feature-services", "describe", "driver_locations_service"], + cwd=repo_path, + ) + assertpy.assert_that(result.returncode).is_equal_to(0) + assertpy.assert_that(fs.list_feature_views()).is_length(5) + result = runner.run( + ["data-sources", "describe", "customer_profile_source"], + cwd=repo_path, + ) + assertpy.assert_that(result.returncode).is_equal_to(0) + assertpy.assert_that(fs.list_data_sources()).is_length(5) + + projects_list = fs.list_projects() + assertpy.assert_that(projects_list).is_length(1) + assertpy.assert_that(projects_list[0].name).is_equal_to(project) + assertpy.assert_that(projects_list[0].description).is_equal_to( + "test_universal_cli_with_project_4567 description" + ) + + # entity & feature view describe commands should fail when objects don't exist + result = runner.run(["projects", "describe", "foo"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(1) result = runner.run(["entities", "describe", "foo"], cwd=repo_path) assertpy.assert_that(result.returncode).is_equal_to(1) result = runner.run(["feature-views", "describe", "foo"], cwd=repo_path) @@ -161,6 +291,12 @@ def test_odfv_apply() -> None: assertpy.assert_that(result.returncode).is_equal_to(0) # entity & feature view list commands should succeed + result = runner.run(["projects", "describe", project], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["projects", "current_project"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["projects", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) result = runner.run(["entities", "list"], cwd=repo_path) assertpy.assert_that(result.returncode).is_equal_to(0) result = runner.run(["on-demand-feature-views", "list"], cwd=repo_path) @@ -192,7 +328,14 @@ def test_nullable_online_store(test_nullable_online_store) -> None: repo_example = repo_path / "example.py" repo_example.write_text(get_example_repo("empty_feature_repo.py")) + result = runner.run(["apply"], cwd=repo_path) assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["projects", "describe", project], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["projects", "current_project"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) + result = runner.run(["projects", "list"], cwd=repo_path) + assertpy.assert_that(result.returncode).is_equal_to(0) finally: runner.run(["teardown"], cwd=repo_path) diff --git a/sdk/python/tests/integration/registration/test_universal_registry.py b/sdk/python/tests/integration/registration/test_universal_registry.py index c528cee4a8..20f1f5ef0a 100644 --- a/sdk/python/tests/integration/registration/test_universal_registry.py +++ b/sdk/python/tests/integration/registration/test_universal_registry.py @@ -38,11 +38,12 @@ from feast.infra.online_stores.sqlite import SqliteTable from feast.infra.registry.registry import Registry from feast.infra.registry.remote import RemoteRegistry, RemoteRegistryConfig -from feast.infra.registry.sql import SqlRegistry +from feast.infra.registry.sql import SqlRegistry, SqlRegistryConfig from feast.on_demand_feature_view import on_demand_feature_view from feast.permissions.action import AuthzedAction from feast.permissions.permission import Permission from feast.permissions.policy import RoleBasedPolicy +from feast.project import Project from feast.protos.feast.registry import RegistryServer_pb2, RegistryServer_pb2_grpc from feast.registry_server import RegistryServer from feast.repo_config import RegistryConfig @@ -91,7 +92,7 @@ def s3_registry() -> Registry: return Registry("project", registry_config, None) -@pytest.fixture(scope="session") +@pytest.fixture(scope="function") def minio_registry() -> Registry: bucket_name = "test-bucket" @@ -158,7 +159,7 @@ def pg_registry_async(): container.start() - registry_config = _given_registry_config_for_pg_sql(container, 2, "thread") + registry_config = _given_registry_config_for_pg_sql(container, 2, "thread", 3) yield SqlRegistry(registry_config, "project", None) @@ -166,7 +167,11 @@ def pg_registry_async(): def _given_registry_config_for_pg_sql( - container, cache_ttl_seconds=2, cache_mode="sync" + container, + cache_ttl_seconds=2, + cache_mode="sync", + thread_pool_executor_worker_count=0, + purge_feast_metadata=False, ): log_string_to_wait_for = "database system is ready to accept connections" waited = wait_for_logs( @@ -179,7 +184,7 @@ def _given_registry_config_for_pg_sql( container_port = container.get_exposed_port(5432) container_host = container.get_container_host_ip() - return RegistryConfig( + return SqlRegistryConfig( registry_type="sql", cache_ttl_seconds=cache_ttl_seconds, cache_mode=cache_mode, @@ -187,6 +192,8 @@ def _given_registry_config_for_pg_sql( # to understand that we are using psycopg3. path=f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{container_host}:{container_port}/{POSTGRES_DB}", sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True}, + thread_pool_executor_worker_count=thread_pool_executor_worker_count, + purge_feast_metadata=purge_feast_metadata, ) @@ -207,14 +214,20 @@ def mysql_registry_async(): container = MySqlContainer("mysql:latest") container.start() - registry_config = _given_registry_config_for_mysql(container, 2, "thread") + registry_config = _given_registry_config_for_mysql(container, 2, "thread", 3) yield SqlRegistry(registry_config, "project", None) container.stop() -def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode="sync"): +def _given_registry_config_for_mysql( + container, + cache_ttl_seconds=2, + cache_mode="sync", + thread_pool_executor_worker_count=0, + purge_feast_metadata=False, +): import sqlalchemy engine = sqlalchemy.create_engine( @@ -222,18 +235,20 @@ def _given_registry_config_for_mysql(container, cache_ttl_seconds=2, cache_mode= ) engine.connect() - return RegistryConfig( + return SqlRegistryConfig( registry_type="sql", path=container.get_connection_url(), cache_ttl_seconds=cache_ttl_seconds, cache_mode=cache_mode, sqlalchemy_config_kwargs={"echo": False, "pool_pre_ping": True}, + thread_pool_executor_worker_count=thread_pool_executor_worker_count, + purge_feast_metadata=purge_feast_metadata, ) @pytest.fixture(scope="session") def sqlite_registry(): - registry_config = RegistryConfig( + registry_config = SqlRegistryConfig( registry_type="sql", path="sqlite://", ) @@ -250,7 +265,11 @@ def __init__(self, service, servicer): ) def unary_unary( - self, method: str, request_serializer=None, response_deserializer=None + self, + method: str, + request_serializer=None, + response_deserializer=None, + _registered_method=None, ): method_name = method.split("/")[-1] method_descriptor = self.service.methods_by_name[method_name] @@ -347,9 +366,11 @@ def test_apply_entity_success(test_registry): project_uuid = project_metadata[0].project_uuid assert len(project_metadata[0].project_uuid) == 36 assert_project_uuid(project, project_uuid, test_registry) + assert_project(project, test_registry) entities = test_registry.list_entities(project, tags=entity.tags) assert_project_uuid(project, project_uuid, test_registry) + assert_project(project, test_registry) entity = entities[0] assert ( @@ -386,11 +407,12 @@ def test_apply_entity_success(test_registry): updated_entity.created_timestamp is not None and updated_entity.created_timestamp == entity.created_timestamp ) - test_registry.delete_entity("driver_car_id", project) assert_project_uuid(project, project_uuid, test_registry) + assert_project(project, test_registry) entities = test_registry.list_entities(project) assert_project_uuid(project, project_uuid, test_registry) + assert_project(project, test_registry) assert len(entities) == 0 test_registry.teardown() @@ -402,6 +424,14 @@ def assert_project_uuid(project, project_uuid, test_registry): assert project_metadata[0].project_uuid == project_uuid +def assert_project(project_name, test_registry, allow_cache=False): + project_obj = test_registry.list_projects(allow_cache=allow_cache) + assert len(project_obj) == 1 + assert project_obj[0].name == "project" + project_obj = test_registry.get_project(name=project_name, allow_cache=allow_cache) + assert project_obj.name == "project" + + @pytest.mark.integration @pytest.mark.parametrize( "test_registry", @@ -725,9 +755,10 @@ def simple_udf(x: int): project = "project" # Register Feature Views - test_registry.apply_feature_view(odfv1, project) - test_registry.apply_feature_view(fv1, project) - test_registry.apply_feature_view(sfv, project) + test_registry.apply_feature_view(odfv1, project, False) + test_registry.apply_feature_view(fv1, project, False) + test_registry.apply_feature_view(sfv, project, False) + test_registry.commit() # Modify odfv by changing a single feature dtype @on_demand_feature_view( @@ -1283,6 +1314,10 @@ def test_commit(): project_uuid = project_metadata.project_uuid assert len(project_uuid) == 36 validate_project_uuid(project_uuid, test_registry) + assert len(test_registry.cached_registry_proto.projects) == 1 + project_obj = test_registry.cached_registry_proto.projects[0] + assert project == Project.from_proto(project_obj).name + assert_project(project, test_registry, True) # Retrieving the entity should still succeed entities = test_registry.list_entities(project, allow_cache=True, tags=entity.tags) @@ -1295,6 +1330,7 @@ def test_commit(): and entity.tags["team"] == "matchmaking" ) validate_project_uuid(project_uuid, test_registry) + assert_project(project, test_registry, True) entity = test_registry.get_entity("driver_car_id", project, allow_cache=True) assert ( @@ -1304,6 +1340,7 @@ def test_commit(): and entity.tags["team"] == "matchmaking" ) validate_project_uuid(project_uuid, test_registry) + assert_project(project, test_registry, True) # Create new registry that points to the same store registry_with_same_store = Registry("project", registry_config, None) @@ -1312,6 +1349,7 @@ def test_commit(): entities = registry_with_same_store.list_entities(project) assert len(entities) == 0 validate_project_uuid(project_uuid, registry_with_same_store) + assert_project(project, test_registry, True) # commit from the original registry test_registry.commit() @@ -1330,6 +1368,7 @@ def test_commit(): and entity.tags["team"] == "matchmaking" ) validate_project_uuid(project_uuid, registry_with_same_store) + assert_project(project, test_registry) entity = test_registry.get_entity("driver_car_id", project) assert ( @@ -1371,6 +1410,7 @@ def test_apply_permission_success(test_registry): project_uuid = project_metadata[0].project_uuid assert len(project_metadata[0].project_uuid) == 36 assert_project_uuid(project, project_uuid, test_registry) + assert_project(project, test_registry) permissions = test_registry.list_permissions(project) assert_project_uuid(project, project_uuid, test_registry) @@ -1483,5 +1523,194 @@ def test_apply_permission_success(test_registry): permissions = test_registry.list_permissions(project) assert_project_uuid(project, project_uuid, test_registry) assert len(permissions) == 0 + assert_project(project, test_registry) + + test_registry.teardown() + + +@pytest.mark.integration +@pytest.mark.parametrize("test_registry", all_fixtures) +def test_apply_project_success(test_registry): + project = Project( + name="project", + description="Project description", + tags={"team": "project team"}, + owner="owner@mail.com", + ) + + # Register Project + test_registry.apply_project(project) + assert_project(project.name, test_registry, False) + + projects_list = test_registry.list_projects(tags=project.tags) + + assert_project(projects_list[0].name, test_registry) + + project_get = test_registry.get_project("project") + assert ( + project_get.name == project.name + and project_get.description == project.description + and project_get.tags == project.tags + and project_get.owner == project.owner + ) + + # Update project + updated_project = Project( + name=project.name, + description="New Project Description", + tags={"team": "matchmaking", "app": "feast"}, + ) + test_registry.apply_project(updated_project) + + updated_project_get = test_registry.get_project(project.name) + + # The created_timestamp for the entity should be set to the created_timestamp value stored from the previous apply + assert ( + updated_project_get.created_timestamp is not None + and updated_project_get.created_timestamp == project_get.created_timestamp + ) + + assert ( + updated_project_get.created_timestamp + < updated_project_get.last_updated_timestamp + ) + + entity = Entity( + name="driver_car_id", + description="Car driver id", + tags={"team": "matchmaking"}, + ) + + test_registry.apply_entity(entity, project.name) + entities = test_registry.list_entities(project.name) + assert len(entities) == 1 + + test_registry.delete_project(project.name, commit=False) + + test_registry.commit() + + entities = test_registry.list_entities(project.name, False) + assert len(entities) == 0 + projects_list = test_registry.list_projects() + assert len(projects_list) == 0 + + test_registry.refresh(project.name) + + test_registry.teardown() + + +@pytest.fixture +def local_registry_purge_feast_metadata() -> Registry: + fd, registry_path = mkstemp() + registry_config = RegistryConfig( + path=registry_path, cache_ttl_seconds=600, purge_feast_metadata=True + ) + return Registry("project", registry_config, None) + + +@pytest.fixture(scope="function") +def pg_registry_purge_feast_metadata(): + container = ( + DockerContainer("postgres:latest") + .with_exposed_ports(5432) + .with_env("POSTGRES_USER", POSTGRES_USER) + .with_env("POSTGRES_PASSWORD", POSTGRES_PASSWORD) + .with_env("POSTGRES_DB", POSTGRES_DB) + ) + + container.start() + + registry_config = _given_registry_config_for_pg_sql(container, 2, "thread", 3, True) + + yield SqlRegistry(registry_config, "project", None) + + container.stop() + + +@pytest.fixture(scope="function") +def mysql_registry_purge_feast_metadata(): + container = MySqlContainer("mysql:latest") + container.start() + + registry_config = _given_registry_config_for_mysql(container, 2, "thread", 3, True) + + yield SqlRegistry(registry_config, "project", None) + + container.stop() + + +purge_feast_metadata_fixtures = [ + lazy_fixture("local_registry_purge_feast_metadata"), + pytest.param( + lazy_fixture("pg_registry_purge_feast_metadata"), + marks=pytest.mark.xdist_group(name="pg_registry_purge_feast_metadata"), + ), + pytest.param( + lazy_fixture("mysql_registry_purge_feast_metadata"), + marks=pytest.mark.xdist_group(name="mysql_registry_purge_feast_metadata"), + ), +] + + +@pytest.mark.integration +@pytest.mark.parametrize("test_registry", purge_feast_metadata_fixtures) +def test_apply_entity_success_with_purge_feast_metadata(test_registry): + entity = Entity( + name="driver_car_id", + description="Car driver id", + tags={"team": "matchmaking"}, + ) + + project = "project" + + # Register Entity + test_registry.apply_entity(entity, project) + project_metadata = test_registry.list_project_metadata(project=project) + assert len(project_metadata) == 0 + assert_project(project, test_registry) + + entities = test_registry.list_entities(project, tags=entity.tags) + assert_project(project, test_registry) + + entity = entities[0] + assert ( + len(entities) == 1 + and entity.name == "driver_car_id" + and entity.description == "Car driver id" + and "team" in entity.tags + and entity.tags["team"] == "matchmaking" + ) + + entity = test_registry.get_entity("driver_car_id", project) + assert ( + entity.name == "driver_car_id" + and entity.description == "Car driver id" + and "team" in entity.tags + and entity.tags["team"] == "matchmaking" + ) + + # After the first apply, the created_timestamp should be the same as the last_update_timestamp. + assert entity.created_timestamp == entity.last_updated_timestamp + + # Update entity + updated_entity = Entity( + name="driver_car_id", + description="Car driver Id", + tags={"team": "matchmaking"}, + ) + test_registry.apply_entity(updated_entity, project) + + updated_entity = test_registry.get_entity("driver_car_id", project) + + # The created_timestamp for the entity should be set to the created_timestamp value stored from the previous apply + assert ( + updated_entity.created_timestamp is not None + and updated_entity.created_timestamp == entity.created_timestamp + ) + test_registry.delete_entity("driver_car_id", project) + assert_project(project, test_registry) + entities = test_registry.list_entities(project) + assert_project(project, test_registry) + assert len(entities) == 0 test_registry.teardown() diff --git a/sdk/python/tests/unit/permissions/auth/conftest.py b/sdk/python/tests/unit/permissions/auth/conftest.py index ea6e2e4311..5a29f8ec78 100644 --- a/sdk/python/tests/unit/permissions/auth/conftest.py +++ b/sdk/python/tests/unit/permissions/auth/conftest.py @@ -8,6 +8,7 @@ read_fv_perm, read_odfv_perm, read_permissions_perm, + read_projects_perm, read_sfv_perm, ) from tests.unit.permissions.auth.test_token_parser import _CLIENT_ID @@ -90,6 +91,7 @@ def oidc_config() -> OidcAuthConfig: read_fv_perm, read_odfv_perm, read_sfv_perm, + read_projects_perm, ], ], ) diff --git a/sdk/python/tests/unit/permissions/auth/server/test_auth_registry_server.py b/sdk/python/tests/unit/permissions/auth/server/test_auth_registry_server.py index 9e9bc1473e..c72b1aa1e2 100644 --- a/sdk/python/tests/unit/permissions/auth/server/test_auth_registry_server.py +++ b/sdk/python/tests/unit/permissions/auth/server/test_auth_registry_server.py @@ -5,9 +5,7 @@ import pytest import yaml -from feast import ( - FeatureStore, -) +from feast import FeatureStore from feast.errors import ( EntityNotFoundException, FeastPermissionError, @@ -23,6 +21,7 @@ read_fv_perm, read_odfv_perm, read_permissions_perm, + read_projects_perm, read_sfv_perm, ) from tests.utils.auth_permissions_util import get_remote_registry_store @@ -50,7 +49,11 @@ def start_registry_server( assertpy.assert_that(server_port).is_not_equal_to(0) print(f"Starting Registry at {server_port}") - server = start_server(feature_store, server_port, wait_for_termination=False) + server = start_server( + feature_store, + server_port, + wait_for_termination=False, + ) print("Waiting server availability") wait_retry_backoff( lambda: (None, check_port_open("localhost", server_port)), @@ -179,6 +182,7 @@ def _test_list_permissions( read_fv_perm, read_odfv_perm, read_sfv_perm, + read_projects_perm, ], permissions, ): @@ -191,6 +195,7 @@ def _test_list_permissions( read_fv_perm, read_odfv_perm, read_sfv_perm, + read_projects_perm, ] ) ) diff --git a/sdk/python/tests/unit/permissions/auth/server/test_utils.py b/sdk/python/tests/unit/permissions/auth/server/test_utils.py index 5d781919a0..32b4fd8f98 100644 --- a/sdk/python/tests/unit/permissions/auth/server/test_utils.py +++ b/sdk/python/tests/unit/permissions/auth/server/test_utils.py @@ -6,6 +6,7 @@ from feast.permissions.permission import Permission from feast.permissions.policy import RoleBasedPolicy from feast.permissions.server.utils import AuthManagerType, str_to_auth_manager_type +from feast.project import Project read_permissions_perm = Permission( name="read_permissions_perm", @@ -14,6 +15,13 @@ actions=[AuthzedAction.DESCRIBE], ) +read_projects_perm = Permission( + name="read_projects_perm", + types=Project, + policy=RoleBasedPolicy(roles=["reader"]), + actions=[AuthzedAction.DESCRIBE], +) + read_entities_perm = Permission( name="read_entities_perm", types=Entity, diff --git a/sdk/python/tests/unit/test_on_demand_feature_view.py b/sdk/python/tests/unit/test_on_demand_feature_view.py index d9cc5dee50..6073891aba 100644 --- a/sdk/python/tests/unit/test_on_demand_feature_view.py +++ b/sdk/python/tests/unit/test_on_demand_feature_view.py @@ -251,11 +251,9 @@ def test_from_proto_backwards_compatible_udf(): proto.spec.feature_transformation.user_defined_function.body_text ) - # And now we're going to null the feature_transformation proto object before reserializing the entire proto - # proto.spec.user_defined_function.body_text = on_demand_feature_view.transformation.udf_string - proto.spec.feature_transformation.user_defined_function.name = "" - proto.spec.feature_transformation.user_defined_function.body = b"" - proto.spec.feature_transformation.user_defined_function.body_text = "" + # For objects that are already registered, feature_transformation and mode is not set + proto.spec.feature_transformation.Clear() + proto.spec.ClearField("mode") # And now we expect the to get the same object back under feature_transformation reserialized_proto = OnDemandFeatureView.from_proto(proto) diff --git a/sdk/python/tests/unit/test_project.py b/sdk/python/tests/unit/test_project.py new file mode 100644 index 0000000000..f15aef2972 --- /dev/null +++ b/sdk/python/tests/unit/test_project.py @@ -0,0 +1,122 @@ +import unittest +from datetime import datetime, timezone + +from feast.project import Project +from feast.protos.feast.core.Project_pb2 import Project as ProjectProto +from feast.protos.feast.core.Project_pb2 import ProjectMeta as ProjectMetaProto +from feast.protos.feast.core.Project_pb2 import ProjectSpec as ProjectSpecProto + + +class TestProject(unittest.TestCase): + def setUp(self): + self.project_name = "test_project" + self.description = "Test project description" + self.tags = {"env": "test"} + self.owner = "test_owner" + self.created_timestamp = datetime.now(tz=timezone.utc) + self.last_updated_timestamp = datetime.now(tz=timezone.utc) + + def test_initialization(self): + project = Project( + name=self.project_name, + description=self.description, + tags=self.tags, + owner=self.owner, + created_timestamp=self.created_timestamp, + last_updated_timestamp=self.last_updated_timestamp, + ) + self.assertEqual(project.name, self.project_name) + self.assertEqual(project.description, self.description) + self.assertEqual(project.tags, self.tags) + self.assertEqual(project.owner, self.owner) + self.assertEqual(project.created_timestamp, self.created_timestamp) + self.assertEqual(project.last_updated_timestamp, self.last_updated_timestamp) + + def test_equality(self): + project1 = Project(name=self.project_name) + project2 = Project(name=self.project_name) + project3 = Project(name="different_project") + self.assertTrue( + project1.name == project2.name + and project1.description == project2.description + and project1.tags == project2.tags + and project1.owner == project2.owner + ) + self.assertFalse( + project1.name == project3.name + and project1.description == project3.description + and project1.tags == project3.tags + and project1.owner == project3.owner + ) + + def test_is_valid(self): + project = Project(name=self.project_name) + project.is_valid() + with self.assertRaises(ValueError): + invalid_project = Project(name="") + invalid_project.is_valid() + + def test_from_proto(self): + meta = ProjectMetaProto() + meta.created_timestamp.FromDatetime(self.created_timestamp) + meta.last_updated_timestamp.FromDatetime(self.last_updated_timestamp) + project_proto = ProjectProto( + spec=ProjectSpecProto( + name=self.project_name, + description=self.description, + tags=self.tags, + owner=self.owner, + ), + meta=meta, + ) + project = Project.from_proto(project_proto) + self.assertEqual(project.name, self.project_name) + self.assertEqual(project.description, self.description) + self.assertEqual(project.tags, self.tags) + self.assertEqual(project.owner, self.owner) + self.assertEqual(project.created_timestamp, self.created_timestamp) + self.assertEqual(project.last_updated_timestamp, self.last_updated_timestamp) + + def test_to_proto(self): + project = Project( + name=self.project_name, + description=self.description, + tags=self.tags, + owner=self.owner, + created_timestamp=self.created_timestamp, + last_updated_timestamp=self.last_updated_timestamp, + ) + project_proto = project.to_proto() + self.assertEqual(project_proto.spec.name, self.project_name) + self.assertEqual(project_proto.spec.description, self.description) + self.assertEqual(project_proto.spec.tags, self.tags) + self.assertEqual(project_proto.spec.owner, self.owner) + self.assertEqual( + project_proto.meta.created_timestamp.ToDatetime().replace( + tzinfo=timezone.utc + ), + self.created_timestamp, + ) + self.assertEqual( + project_proto.meta.last_updated_timestamp.ToDatetime().replace( + tzinfo=timezone.utc + ), + self.last_updated_timestamp, + ) + + def test_to_proto_and_back(self): + project = Project( + name=self.project_name, + description=self.description, + tags=self.tags, + owner=self.owner, + created_timestamp=self.created_timestamp, + last_updated_timestamp=self.last_updated_timestamp, + ) + project_proto = project.to_proto() + project_from_proto = Project.from_proto(project_proto) + self.assertEqual(project, project_from_proto) + + +if __name__ == "__main__": + unittest.main()