diff --git a/sdk/python/feast/feast_object.py b/sdk/python/feast/feast_object.py index 0ac0446f5f..38109f5d8c 100644 --- a/sdk/python/feast/feast_object.py +++ b/sdk/python/feast/feast_object.py @@ -1,5 +1,6 @@ from typing import Union +from .batch_feature_view import BatchFeatureView from .data_source import DataSource from .entity import Entity from .feature_service import FeatureService @@ -16,12 +17,15 @@ ) from .request_feature_view import RequestFeatureView from .saved_dataset import ValidationReference +from .stream_feature_view import StreamFeatureView # Convenience type representing all Feast objects FeastObject = Union[ FeatureView, OnDemandFeatureView, RequestFeatureView, + BatchFeatureView, + StreamFeatureView, Entity, FeatureService, DataSource, diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index ea13c3a8db..a3d9de6c26 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -775,6 +775,7 @@ def apply( FeatureView, OnDemandFeatureView, RequestFeatureView, + BatchFeatureView, StreamFeatureView, FeatureService, ValidationReference, @@ -834,9 +835,9 @@ def apply( ob for ob in objects if ( - isinstance(ob, FeatureView) + # BFVs are not handled separately from FVs right now. + (isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView)) and not isinstance(ob, StreamFeatureView) - and not isinstance(ob, BatchFeatureView) ) ] sfvs_to_update = [ob for ob in objects if isinstance(ob, StreamFeatureView)] @@ -919,13 +920,18 @@ def apply( validation_references, project=self.project, commit=False ) + entities_to_delete = [] + views_to_delete = [] + sfvs_to_delete = [] if not partial: # Delete all registry objects that should not exist. entities_to_delete = [ ob for ob in objects_to_delete if isinstance(ob, Entity) ] views_to_delete = [ - ob for ob in objects_to_delete if isinstance(ob, FeatureView) + ob + for ob in objects_to_delete + if isinstance(ob, FeatureView) or isinstance(ob, BatchFeatureView) ] request_views_to_delete = [ ob for ob in objects_to_delete if isinstance(ob, RequestFeatureView) @@ -979,10 +985,13 @@ def apply( validation_references.name, project=self.project, commit=False ) + tables_to_delete: List[FeatureView] = views_to_delete + sfvs_to_delete if not partial else [] # type: ignore + tables_to_keep: List[FeatureView] = views_to_update + sfvs_to_update # type: ignore + self._get_provider().update_infra( project=self.project, - tables_to_delete=views_to_delete + sfvs_to_delete if not partial else [], - tables_to_keep=views_to_update + sfvs_to_update, + tables_to_delete=tables_to_delete, + tables_to_keep=tables_to_keep, entities_to_delete=entities_to_delete if not partial else [], entities_to_keep=entities_to_update, partial=partial, diff --git a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py index b2da58c4c0..2cced75eb2 100644 --- a/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py +++ b/sdk/python/tests/unit/local_feast_tests/test_local_feature_store.py @@ -4,6 +4,7 @@ import pytest from pytest_lazyfixture import lazy_fixture +from feast import BatchFeatureView from feast.aggregation import Aggregation from feast.data_format import AvroFormat, ParquetFormat from feast.data_source import KafkaSource @@ -78,14 +79,29 @@ def test_apply_feature_view(test_feature_store): ttl=timedelta(minutes=5), ) + bfv = BatchFeatureView( + name="batch_feature_view", + schema=[ + Field(name="fs1_my_feature_1", dtype=Int64), + Field(name="fs1_my_feature_2", dtype=String), + Field(name="fs1_my_feature_3", dtype=Array(String)), + Field(name="fs1_my_feature_4", dtype=Array(Bytes)), + Field(name="entity_id", dtype=Int64), + ], + entities=[entity], + tags={"team": "matchmaking"}, + source=batch_source, + ttl=timedelta(minutes=5), + ) + # Register Feature View - test_feature_store.apply([entity, fv1]) + test_feature_store.apply([entity, fv1, bfv]) feature_views = test_feature_store.list_feature_views() # List Feature Views assert ( - len(feature_views) == 1 + len(feature_views) == 2 and feature_views[0].name == "my_feature_view_1" and feature_views[0].features[0].name == "fs1_my_feature_1" and feature_views[0].features[0].dtype == Int64