Skip to content

Commit

Permalink
fix: Substrait ODFVs for online (#4064)
Browse files Browse the repository at this point in the history
* fix substrait odfvs for online, add tests

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>

* fix formatting

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>

* change odfv substrait test dates relative to start_date and end_date

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>

* force tests rerun

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>

---------

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>
  • Loading branch information
tokoko authored Apr 2, 2024
1 parent d82d1ec commit 26391b0
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 21 deletions.
9 changes: 4 additions & 5 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2037,11 +2037,10 @@ def _augment_response_with_on_demand_transforms(

proto_values = []
for selected_feature in selected_subset:
if odfv.mode in ["python", "pandas"]:
feature_vector = transformed_features[selected_feature]
proto_values.append(
python_values_to_proto_values(feature_vector, ValueType.UNKNOWN)
)
feature_vector = transformed_features[selected_feature]
proto_values.append(
python_values_to_proto_values(feature_vector, ValueType.UNKNOWN)
)

odfv_result_names |= set(selected_subset)

Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/offline_stores/offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def to_arrow(
features_df = self._to_df_internal(timeout=timeout)
if self.on_demand_feature_views:
for odfv in self.on_demand_feature_views:
if odfv.mode != "pandas":
if odfv.mode not in {"pandas", "substrait"}:
raise Exception(
f'OnDemandFeatureView mode "{odfv.mode}" not supported for offline processing.'
)
Expand Down
4 changes: 3 additions & 1 deletion sdk/python/feast/on_demand_feature_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,9 @@ def get_transformed_features(
return self.get_transformed_features_dict(
feature_dict=features,
)
elif self.mode == "pandas" and isinstance(features, pd.DataFrame):
elif self.mode in {"pandas", "substrait"} and isinstance(
features, pd.DataFrame
):
return self.get_transformed_features_df(
df_with_features=features,
full_feature_names=full_feature_names,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_ibis_pandas_parity():
@on_demand_feature_view(
sources=[driver_stats_fv],
schema=[Field(name="conv_rate_plus_acc", dtype=Float64)],
mode="pandas",
)
def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame()
Expand All @@ -84,30 +85,50 @@ def substrait_view(inputs: Table) -> Table:
[driver, driver_stats_source, driver_stats_fv, substrait_view, pandas_view]
)

store.materialize(
start_date=start_date,
end_date=end_date,
)

entity_df = pd.DataFrame.from_dict(
{
# entity's join key -> entity values
"driver_id": [1001, 1002, 1003],
# "event_timestamp" (reserved key) -> timestamps
"event_timestamp": [
datetime(2021, 4, 12, 10, 59, 42),
datetime(2021, 4, 12, 8, 12, 10),
datetime(2021, 4, 12, 16, 40, 26),
start_date + timedelta(days=4),
start_date + timedelta(days=5),
start_date + timedelta(days=6),
],
}
)

requested_features = [
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
"substrait_view:conv_rate_plus_acc_substrait",
"pandas_view:conv_rate_plus_acc",
]

training_df = store.get_historical_features(
entity_df=entity_df,
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"driver_hourly_stats:avg_daily_trips",
"substrait_view:conv_rate_plus_acc_substrait",
"pandas_view:conv_rate_plus_acc",
],
).to_df()
entity_df=entity_df, features=requested_features
)

assert training_df.to_df()["conv_rate_plus_acc"].equals(
training_df.to_df()["conv_rate_plus_acc_substrait"]
)

assert training_df.to_arrow()["conv_rate_plus_acc"].equals(
training_df.to_arrow()["conv_rate_plus_acc_substrait"]
)

online_response = store.get_online_features(
features=requested_features,
entity_rows=[{"driver_id": 1001}, {"driver_id": 1002}, {"driver_id": 1003}],
)

assert training_df["conv_rate_plus_acc"].equals(
training_df["conv_rate_plus_acc_substrait"]
assert (
online_response.to_dict()["conv_rate_plus_acc"]
== online_response.to_dict()["conv_rate_plus_acc_substrait"]
)

0 comments on commit 26391b0

Please sign in to comment.