diff --git a/protos/feast/core/Transformation.proto b/protos/feast/core/Transformation.proto
index 36f1e691fe..5cb53e690f 100644
--- a/protos/feast/core/Transformation.proto
+++ b/protos/feast/core/Transformation.proto
@@ -29,4 +29,5 @@ message FeatureTransformationV2 {
 
 message SubstraitTransformationV2 {
   bytes substrait_plan = 1;
+  bytes ibis_function = 2;
 }
diff --git a/sdk/python/feast/infra/offline_stores/ibis.py b/sdk/python/feast/infra/offline_stores/ibis.py
index f9c6b2d20b..6e0729d6d1 100644
--- a/sdk/python/feast/infra/offline_stores/ibis.py
+++ b/sdk/python/feast/infra/offline_stores/ibis.py
@@ -193,9 +193,15 @@ def read_fv(
         event_timestamp_col=event_timestamp_col,
     )
 
+    odfvs = OnDemandFeatureView.get_requested_odfvs(feature_refs, project, registry)
+
+    substrait_odfvs = [fv for fv in odfvs if fv.mode == "substrait"]
+    for odfv in substrait_odfvs:
+        res = odfv.transform_ibis(res, full_feature_names)
+
     return IbisRetrievalJob(
         res,
-        OnDemandFeatureView.get_requested_odfvs(feature_refs, project, registry),
+        [fv for fv in odfvs if fv.mode != "substrait"],
         full_feature_names,
         metadata=RetrievalMetadata(
             features=feature_refs,
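For context on the read_fv() change above: substrait-mode ODFVs are now folded into the ibis expression before the retrieval job materializes anything, while pandas-mode ODFVs keep flowing through IbisRetrievalJob as before. A minimal, self-contained sketch of the kind of deferred mutation transform_ibis performs (the toy table and column names are illustrative, not taken from the diff, and it assumes an ibis backend such as DuckDB is installed):

import ibis

# Toy stand-in for the retrieval expression `res` built inside read_fv();
# the real expression comes from the offline store, this memtable is made up.
res = ibis.memtable({"driver_id": [1001, 1002], "conv_rate": [0.25, 0.50]})

# A substrait-mode ODFV's user function mutates the expression lazily, so the
# derived column is added to the plan and nothing runs until materialization.
res = res.mutate(conv_rate_plus_100=res["conv_rate"] + 100)

print(res.execute())  # execution happens only here, on ibis' default backend

diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py
index cfb322fb2d..b532fa651a 100644
--- a/sdk/python/feast/on_demand_feature_view.py
+++ b/sdk/python/feast/on_demand_feature_view.py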
@@ -392,6 +392,53 @@ def get_request_data_schema(self) -> Dict[str, ValueType]:
     def _get_projected_feature_name(self, feature: str) -> str:
         return f"{self.projection.name_to_use()}__{feature}"
 
+    def transform_ibis(
+        self,
+        ibis_table,
+        full_feature_names: bool = False,
+    ):
+        from ibis.expr.types import Table
+
+        if not isinstance(ibis_table, Table):
+            raise TypeError("transform_ibis only accepts ibis.expr.types.Table")
+
+        assert type(self.feature_transformation) == SubstraitTransformation
+
+        columns_to_cleanup = []
+        for source_fv_projection in self.source_feature_view_projections.values():
+            for feature in source_fv_projection.features:
+                full_feature_ref = f"{source_fv_projection.name}__{feature.name}"
+                if full_feature_ref in ibis_table.columns:
+                    # Make sure the partial feature name is always present
+                    ibis_table = ibis_table.mutate(
+                        **{feature.name: ibis_table[full_feature_ref]}
+                    )
+                    columns_to_cleanup.append(feature.name)
+                elif feature.name in ibis_table.columns:
+                    ibis_table = ibis_table.mutate(
+                        **{full_feature_ref: ibis_table[feature.name]}
+                    )
+                    columns_to_cleanup.append(full_feature_ref)
+
+        transformed_table = self.feature_transformation.transform_ibis(ibis_table)
+
+        transformed_table = transformed_table.drop(*columns_to_cleanup)
+
+        rename_columns: Dict[str, str] = {}
+        for feature in self.features:
+            short_name = feature.name
+            long_name = self._get_projected_feature_name(feature.name)
+            if short_name in transformed_table.columns and full_feature_names:
+                rename_columns[short_name] = long_name
+            elif not full_feature_names:
+                rename_columns[long_name] = short_name
+
+        for rename_from, rename_to in rename_columns.items():
+            if rename_from in transformed_table.columns:
+                transformed_table = transformed_table.rename(**{rename_to: rename_from})
+
+        return transformed_table
+
     def transform_arrow(
         self,
         pa_table: pyarrow.Table,
@@ -419,7 +466,7 @@ def transform_arrow(
                     columns_to_cleanup.append(full_feature_ref)
 
         df_with_transformed_features: pyarrow.Table = (
-            self.feature_transformation.transform_arrow(pa_table)
+            self.feature_transformation.transform_arrow(pa_table, self.features)
         )
 
         # Work out whether the correct columns names are used.
@@ -438,7 +485,7 @@ def transform_arrow(
         # Cleanup extra columns used for transformation
         for col in columns_to_cleanup:
             if col in df_with_transformed_features.column_names:
-                df_with_transformed_features = df_with_transformed_features.dtop(col)
+                df_with_transformed_features = df_with_transformed_features.drop(col)
         return df_with_transformed_features.rename_columns(
             [
                 rename_columns.get(c, c)
@@ -487,7 +534,9 @@ def get_transformed_features_df(
                 rename_columns[long_name] = short_name
 
         # Cleanup extra columns used for transformation
-        df_with_features.drop(columns=columns_to_cleanup, inplace=True)
+        df_with_transformed_features = df_with_transformed_features[
+            [f.name for f in self.features]
+        ]
         return df_with_transformed_features.rename(columns=rename_columns)
 
     def get_transformed_features_dict(
diff --git a/sdk/python/feast/transformation/pandas_transformation.py b/sdk/python/feast/transformation/pandas_transformation.py
index 28f3c22b9f..7e706810cb 100644
--- a/sdk/python/feast/transformation/pandas_transformation.py
+++ b/sdk/python/feast/transformation/pandas_transformation.py
@@ -27,7 +27,9 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
         self.udf = udf
         self.udf_string = udf_string
 
-    def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
+    def transform_arrow(
+        self, pa_table: pyarrow.Table, features: List[Field]
+    ) -> pyarrow.Table:
         if not isinstance(pa_table, pyarrow.Table):
             raise TypeError(
                 f"pa_table should be type pyarrow.Table but got {type(pa_table).__name__}"
diff --git a/sdk/python/feast/transformation/python_transformation.py b/sdk/python/feast/transformation/python_transformation.py
index 1245fc52ed..ec950a24f3 100644
--- a/sdk/python/feast/transformation/python_transformation.py
+++ b/sdk/python/feast/transformation/python_transformation.py
@@ -25,7 +25,9 @@ def __init__(self, udf: FunctionType, udf_string: str = ""):
         self.udf = udf
         self.udf_string = udf_string
 
-    def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table:
+    def transform_arrow(
+        self, pa_table: pyarrow.Table, features: List[Field]
+    ) -> pyarrow.Table:
         raise Exception(
             'OnDemandFeatureView mode "python" not supported for offline processing.'
         )
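The features parameter threaded through transform_arrow above exists so that SubstraitTransformation (next file) can trim the substrait query output down to the ODFV's declared schema; PandasTransformation and PythonTransformation only grow the parameter for signature compatibility. The extra step amounts to a plain column projection. A small sketch with a toy pyarrow table (the column and feature names are invented for illustration):

import pyarrow as pa

# Toy output of a substrait plan: input columns ride along next to the
# derived feature column.
table = pa.table(
    {"conv_rate": [0.25], "val_to_add": [1], "conv_rate_plus_100": [100.25]}
)

# Stand-in for `[f.name for f in features]`, where `features` are feast Fields.
declared = ["conv_rate_plus_100"]

table = table.select(declared)  # same call transform_arrow now makes
print(table.column_names)       # ['conv_rate_plus_100']
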
""" self.substrait_plan = substrait_plan + self.ibis_function = ibis_function def transform(self, df: pd.DataFrame) -> pd.DataFrame: def table_provider(names, schema: pyarrow.Schema): @@ -34,13 +38,22 @@ def table_provider(names, schema: pyarrow.Schema): ).read_all() return table.to_pandas() - def transform_arrow(self, pa_table: pyarrow.Table) -> pyarrow.Table: + def transform_ibis(self, table): + return self.ibis_function(table) + + def transform_arrow( + self, pa_table: pyarrow.Table, features: List[Field] = [] + ) -> pyarrow.Table: def table_provider(names, schema: pyarrow.Schema): return pa_table.select(schema.names) table: pyarrow.Table = pyarrow.substrait.run_query( self.substrait_plan, table_provider=table_provider ).read_all() + + if features: + table = table.select([f.name for f in features]) + return table def infer_features(self, random_input: Dict[str, List[Any]]) -> List[Field]: @@ -55,6 +68,7 @@ def infer_features(self, random_input: Dict[str, List[Any]]) -> List[Field]: ), ) for f, dt in zip(output_df.columns, output_df.dtypes) + if f not in random_input ] def __eq__(self, other): @@ -66,10 +80,17 @@ def __eq__(self, other): if not super().__eq__(other): return False - return self.substrait_plan == other.substrait_plan + return ( + self.substrait_plan == other.substrait_plan + and self.ibis_function.__code__.co_code + == other.ibis_function.__code__.co_code + ) def to_proto(self) -> SubstraitTransformationProto: - return SubstraitTransformationProto(substrait_plan=self.substrait_plan) + return SubstraitTransformationProto( + substrait_plan=self.substrait_plan, + ibis_function=dill.dumps(self.ibis_function, recurse=True), + ) @classmethod def from_proto( @@ -77,7 +98,8 @@ def from_proto( substrait_transformation_proto: SubstraitTransformationProto, ): return SubstraitTransformation( - substrait_plan=substrait_transformation_proto.substrait_plan + substrait_plan=substrait_transformation_proto.substrait_plan, + ibis_function=dill.loads(substrait_transformation_proto.ibis_function), ) @classmethod @@ -91,7 +113,7 @@ def from_ibis(cls, user_function, sources): input_fields = [] for s in sources: - fields = s.projection.features if isinstance(s, FeatureView) else s.features + fields = s.projection.features if isinstance(s, FeatureView) else s.schema input_fields.extend( [ @@ -108,5 +130,6 @@ def from_ibis(cls, user_function, sources): expr = user_function(ibis.table(input_fields, "t")) return SubstraitTransformation( - substrait_plan=compiler.compile(expr).SerializeToString() + substrait_plan=compiler.compile(expr).SerializeToString(), + ibis_function=user_function, ) diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 6eb5204161..98fc696e75 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -343,17 +343,23 @@ def values(self): def construct_universal_feature_views( data_sources: UniversalDataSources, with_odfv: bool = True, + use_substrait_odfv: bool = False, ) -> UniversalFeatureViews: driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver) driver_hourly_stats_base_feature_view = ( create_driver_hourly_stats_batch_feature_view(data_sources.driver) ) + return UniversalFeatureViews( customer=create_customer_daily_profile_feature_view(data_sources.customer), global_fv=create_global_stats_feature_view(data_sources.global_ds), driver=driver_hourly_stats, 
diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py
index 6eb5204161..98fc696e75 100644
--- a/sdk/python/tests/integration/feature_repos/repo_configuration.py
+++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py
@@ -343,17 +343,23 @@ def values(self):
 def construct_universal_feature_views(
     data_sources: UniversalDataSources,
     with_odfv: bool = True,
+    use_substrait_odfv: bool = False,
 ) -> UniversalFeatureViews:
     driver_hourly_stats = create_driver_hourly_stats_feature_view(data_sources.driver)
     driver_hourly_stats_base_feature_view = (
         create_driver_hourly_stats_batch_feature_view(data_sources.driver)
     )
+
     return UniversalFeatureViews(
         customer=create_customer_daily_profile_feature_view(data_sources.customer),
         global_fv=create_global_stats_feature_view(data_sources.global_ds),
         driver=driver_hourly_stats,
         driver_odfv=conv_rate_plus_100_feature_view(
-            [driver_hourly_stats_base_feature_view, create_conv_rate_request_source()]
+            [
+                driver_hourly_stats_base_feature_view[["conv_rate"]],
+                create_conv_rate_request_source(),
+            ],
+            use_substrait_odfv=use_substrait_odfv,
         )
         if with_odfv
         else None,
diff --git a/sdk/python/tests/integration/feature_repos/universal/feature_views.py b/sdk/python/tests/integration/feature_repos/universal/feature_views.py
index 48f6e27b8a..2a0a9d1bd0 100644
--- a/sdk/python/tests/integration/feature_repos/universal/feature_views.py
+++ b/sdk/python/tests/integration/feature_repos/universal/feature_views.py
@@ -3,6 +3,7 @@
 
 import numpy as np
 import pandas as pd
+from ibis.expr.types.relations import Table
 
 from feast import (
     BatchFeatureView,
@@ -15,7 +16,7 @@
 )
 from feast.data_source import DataSource, RequestSource
 from feast.feature_view_projection import FeatureViewProjection
-from feast.on_demand_feature_view import PandasTransformation
+from feast.on_demand_feature_view import PandasTransformation, SubstraitTransformation
 from feast.types import Array, FeastType, Float32, Float64, Int32, Int64
 from tests.integration.feature_repos.universal.entities import (
     customer,
@@ -56,10 +57,22 @@ def conv_rate_plus_100(features_df: pd.DataFrame) -> pd.DataFrame:
     return df
 
 
+def conv_rate_plus_100_ibis(features_table: Table) -> Table:
+    return features_table.mutate(
+        conv_rate_plus_100=features_table["conv_rate"] + 100,
+        conv_rate_plus_val_to_add=features_table["conv_rate"]
+        + features_table["val_to_add"],
+        conv_rate_plus_100_rounded=(features_table["conv_rate"] + 100)
+        .round(digits=0)
+        .cast("int32"),
+    )
+
+
 def conv_rate_plus_100_feature_view(
     sources: List[Union[FeatureView, RequestSource, FeatureViewProjection]],
     infer_features: bool = False,
     features: Optional[List[Field]] = None,
+    use_substrait_odfv: bool = False,
 ) -> OnDemandFeatureView:
     # Test that positional arguments and Features still work for ODFVs.
     _features = features or [
@@ -73,7 +86,10 @@ def conv_rate_plus_100_feature_view(
         sources=sources,
         feature_transformation=PandasTransformation(
             udf=conv_rate_plus_100, udf_string="raw udf source"
-        ),
+        )
+        if not use_substrait_odfv
+        else SubstraitTransformation.from_ibis(conv_rate_plus_100_ibis, sources),
+        mode="pandas" if not use_substrait_odfv else "substrait",
     )
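The use_substrait_odfv switch above builds the test ODFV from conv_rate_plus_100_ibis instead of the pandas UDF. Outside the test harness the same thing can be declared with the decorator, as the unit test further down does. A hedged, self-contained sketch (the RequestSource and feature names here are invented for the example, and it assumes feast's ibis/substrait extras are installed):

from feast import Field, RequestSource
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import Float64
from ibis.expr.types import Table

# A request-only source keeps the example self-contained (no offline store needed).
conv_rate_input = RequestSource(
    name="conv_rate_input",
    schema=[Field(name="conv_rate", dtype=Float64)],
)

@on_demand_feature_view(
    sources=[conv_rate_input],
    schema=[Field(name="conv_rate_plus_100", dtype=Float64)],
    mode="substrait",
)
def conv_rate_plus_100_substrait(inputs: Table) -> Table:
    # The decorator compiles this ibis expression to a substrait plan and,
    # with this change, also keeps the function itself for transform_ibis.
    return inputs.mutate(conv_rate_plus_100=inputs["conv_rate"] + 100)
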
diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
index 7e106b3e2a..958b829a60 100644
--- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
+++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
@@ -41,12 +41,19 @@
 
 @pytest.mark.integration
 @pytest.mark.universal_offline_stores
 @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: f"full:{v}")
-def test_historical_features(environment, universal_data_sources, full_feature_names):
+@pytest.mark.parametrize(
+    "use_substrait_odfv", [True, False], ids=lambda v: f"substrait:{v}"
+)
+def test_historical_features(
+    environment, universal_data_sources, full_feature_names, use_substrait_odfv
+):
     store = environment.feature_store
 
     (entities, datasets, data_sources) = universal_data_sources
-    feature_views = construct_universal_feature_views(data_sources)
+    feature_views = construct_universal_feature_views(
+        data_sources, use_substrait_odfv=use_substrait_odfv
+    )
 
     entity_df_with_request_data = datasets.entity_df.copy(deep=True)
     entity_df_with_request_data["val_to_add"] = [
diff --git a/sdk/python/tests/unit/test_substrait_transformation.py b/sdk/python/tests/unit/test_substrait_transformation.py
index 28ab68c70b..351651cfda 100644
--- a/sdk/python/tests/unit/test_substrait_transformation.py
+++ b/sdk/python/tests/unit/test_substrait_transformation.py
@@ -75,10 +75,8 @@ def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
         mode="substrait",
     )
     def substrait_view(inputs: Table) -> Table:
-        return inputs.select(
-            (inputs["conv_rate"] + inputs["acc_rate"]).name(
-                "conv_rate_plus_acc_substrait"
-            )
+        return inputs.mutate(
+            conv_rate_plus_acc_substrait=inputs["conv_rate"] + inputs["acc_rate"]
         )
 
     store.apply(