diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py
index 9131bd3ed8..53dbe03368 100644
--- a/sdk/python/feast/infra/offline_stores/bigquery.py
+++ b/sdk/python/feast/infra/offline_stores/bigquery.py
@@ -1,7 +1,7 @@
 import time
 from dataclasses import asdict, dataclass
 from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union
 
 import pandas
 import pyarrow
@@ -130,9 +130,9 @@ class FeatureViewQueryContext:
     table_ref: str
     event_timestamp_column: str
     created_timestamp_column: str
-    field_mapping: Dict[str, str]
     query: str
     table_subquery: str
+    entity_selections: List[str]
 
 
 def _upload_entity_df_into_bigquery(project, entity_df) -> str:
@@ -178,9 +178,17 @@ def get_feature_view_query_context(
     query_context = []
     for feature_view, features in feature_views_to_feature_map.items():
         join_keys = []
+        entity_selections = []
+        reverse_field_mapping = {
+            v: k for k, v in feature_view.input.field_mapping.items()
+        }
         for entity_name in feature_view.entities:
             entity = registry.get_entity(entity_name, project)
             join_keys.append(entity.join_key)
+            join_key_column = reverse_field_mapping.get(
+                entity.join_key, entity.join_key
+            )
+            entity_selections.append(f"{join_key_column} AS {entity.join_key}")
 
         if isinstance(feature_view.ttl, timedelta):
             ttl_seconds = int(feature_view.ttl.total_seconds())
@@ -189,18 +197,25 @@ def get_feature_view_query_context(
 
         assert isinstance(feature_view.input, BigQuerySource)
 
+        event_timestamp_column = feature_view.input.event_timestamp_column
+        created_timestamp_column = feature_view.input.created_timestamp_column
+
         context = FeatureViewQueryContext(
             name=feature_view.name,
             ttl=ttl_seconds,
             entities=join_keys,
             features=features,
             table_ref=feature_view.input.table_ref,
-            event_timestamp_column=feature_view.input.event_timestamp_column,
-            created_timestamp_column=feature_view.input.created_timestamp_column,
+            event_timestamp_column=reverse_field_mapping.get(
+                event_timestamp_column, event_timestamp_column
+            ),
+            created_timestamp_column=reverse_field_mapping.get(
+                created_timestamp_column, created_timestamp_column
+            ),
             # TODO: Make created column optional and not hardcoded
-            field_mapping=feature_view.input.field_mapping,
             query=feature_view.input.query,
             table_subquery=feature_view.input.get_table_query_string(),
+            entity_selections=entity_selections,
         )
         query_context.append(context)
     return query_context
@@ -267,7 +282,7 @@ def build_point_in_time_query(
         {{ featureview.event_timestamp_column }} as event_timestamp,
         {{ featureview.event_timestamp_column }} as {{ featureview.name }}_feature_timestamp,
         {{ featureview.created_timestamp_column }} as created_timestamp,
-        {{ featureview.entities | join(', ')}},
+        {{ featureview.entity_selections | join(', ')}},
         false AS is_entity_table
     FROM {{ featureview.table_subquery }} WHERE {{ featureview.event_timestamp_column }} <= '{{ max_timestamp }}'
     {% if featureview.ttl == 0 %}{% else %}AND {{ featureview.event_timestamp_column }} >= Timestamp_sub(TIMESTAMP '{{ min_timestamp }}', interval {{ featureview.ttl }} second){% endif %}
@@ -308,7 +323,7 @@ def build_point_in_time_query(
     SELECT
         {{ featureview.event_timestamp_column }} as {{ featureview.name }}_feature_timestamp,
         {{ featureview.created_timestamp_column }} as created_timestamp,
-        {{ featureview.entities | join(', ')}},
+        {{ featureview.entity_selections | join(', ')}},
         {% for feature in featureview.features %}
         {{ feature }} as {{ featureview.name }}__{{ feature }}{% if loop.last %}{% else %}, {% endif %}
        {% endfor %}
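
A note on the BigQuery change above: `field_mapping` maps raw source column names to the names the feature view exposes, while the generated point-in-time SQL must read the raw columns. Inverting the mapping once and pre-rendering `raw_column AS join_key` selections keeps that lookup out of the Jinja template. A minimal standalone sketch of the inversion, using hypothetical sample values (the real code reads them from the FeatureView and the registry):

    # Hypothetical mapping: raw source column -> name exposed by the feature view.
    field_mapping = {"ts_1": "ts", "id": "driver_id"}
    reverse_field_mapping = {v: k for k, v in field_mapping.items()}

    # Entity join keys are feature view names; map them back to raw columns.
    join_keys = ["driver_id"]
    entity_selections = [
        f"{reverse_field_mapping.get(key, key)} AS {key}" for key in join_keys
    ]
    assert entity_selections == ["id AS driver_id"]

    # Timestamp columns get the same reverse lookup, falling back to themselves.
    event_timestamp_column = reverse_field_mapping.get("ts", "ts")
    assert event_timestamp_column == "ts_1"

The fallback via `dict.get(key, key)` means unmapped columns pass through unchanged, so sources without a field mapping behave exactly as before.
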
diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py
index ffe67e804d..b5e0201827 100644
--- a/sdk/python/feast/infra/offline_stores/file.py
+++ b/sdk/python/feast/infra/offline_stores/file.py
@@ -11,6 +11,7 @@
 from feast.infra.provider import (
     ENTITY_DF_EVENT_TIMESTAMP_COL,
     _get_requested_feature_views_to_features_dict,
+    _run_field_mapping,
 )
 from feast.registry import Registry
 from feast.repo_config import RepoConfig
@@ -55,6 +56,10 @@ def get_historical_features(
 
         # Create lazy function that is only called from the RetrievalJob object
         def evaluate_historical_retrieval():
+            # Make sure all event timestamp fields are tz-aware. We default tz-naive fields to UTC
+            entity_df[ENTITY_DF_EVENT_TIMESTAMP_COL] = entity_df[
+                ENTITY_DF_EVENT_TIMESTAMP_COL
+            ].apply(lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc))
 
             # Sort entity dataframe prior to join, and create a copy to prevent modifying the original
             entity_df_with_features = entity_df.sort_values(
                 ENTITY_DF_EVENT_TIMESTAMP_COL
@@ -65,10 +70,29 @@ def evaluate_historical_retrieval():
                 event_timestamp_column = feature_view.input.event_timestamp_column
                 created_timestamp_column = feature_view.input.created_timestamp_column
 
-                # Read dataframe to join to entity dataframe
-                df_to_join = pd.read_parquet(feature_view.input.path).sort_values(
+                # Read offline parquet data in pyarrow format
+                table = pyarrow.parquet.read_table(feature_view.input.path)
+
+                # Rename columns by the field mapping dictionary if it exists
+                if feature_view.input.field_mapping is not None:
+                    table = _run_field_mapping(table, feature_view.input.field_mapping)
+
+                # Convert pyarrow table to pandas dataframe
+                df_to_join = table.to_pandas()
+
+                # Make sure all timestamp fields are tz-aware. We default tz-naive fields to UTC
+                df_to_join[event_timestamp_column] = df_to_join[
                     event_timestamp_column
-                )
+                ].apply(lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc))
+                if created_timestamp_column:
+                    df_to_join[created_timestamp_column] = df_to_join[
+                        created_timestamp_column
+                    ].apply(
+                        lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc)
+                    )
+
+                # Sort dataframe by the event timestamp column
+                df_to_join = df_to_join.sort_values(event_timestamp_column)
 
                 # Build a list of all the features we should select from this source
                 feature_names = []
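
The file store path above does two things worth calling out: source columns are renamed by the field mapping before anything else touches them, and all timestamps are normalized to tz-aware UTC so the point-in-time join compares like with like. A minimal standalone sketch of both steps, assuming `_run_field_mapping` (imported from `feast.infra.provider`, not shown in this diff) simply renames pyarrow columns along the mapping:

    import pandas as pd
    import pyarrow as pa
    import pytz

    # Plausible stand-in for feast.infra.provider._run_field_mapping (an
    # assumption, not the actual Feast implementation): rename raw columns to
    # their feature view names, leaving unmapped columns untouched.
    def run_field_mapping_sketch(table: pa.Table, field_mapping: dict) -> pa.Table:
        return table.rename_columns(
            [field_mapping.get(col, col) for col in table.column_names]
        )

    table = pa.table({"id": [1, 3], "ts_1": ["2021-04-12 10:00:00"] * 2})
    table = run_field_mapping_sketch(table, {"ts_1": "ts", "id": "driver_id"})
    assert table.column_names == ["driver_id", "ts"]

    # tz-naive timestamps are assumed to be UTC; tz-aware ones pass through.
    df = table.to_pandas()
    df["ts"] = pd.to_datetime(df["ts"]).apply(
        lambda x: x if x.tz is not None else x.replace(tzinfo=pytz.utc)
    )
    assert all(x.tz is not None for x in df["ts"])
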
diff --git a/sdk/python/tests/test_materialize.py b/sdk/python/tests/test_offline_online_store_consistency.py
similarity index 73%
rename from sdk/python/tests/test_materialize.py
rename to sdk/python/tests/test_offline_online_store_consistency.py
index cc9eadbdb7..102e9f4fed 100644
--- a/sdk/python/tests/test_materialize.py
+++ b/sdk/python/tests/test_offline_online_store_consistency.py
@@ -47,7 +47,7 @@ def create_dataset() -> pd.DataFrame:
 def get_feature_view(data_source: Union[FileSource, BigQuerySource]) -> FeatureView:
     return FeatureView(
         name="test_bq_correctness",
-        entities=["driver_id"],
+        entities=["driver"],
         features=[Feature("value", ValueType.FLOAT)],
         ttl=timedelta(days=5),
         input=data_source,
@@ -83,20 +83,20 @@ def prep_bq_fs_and_fv(
         event_timestamp_column="ts",
         created_timestamp_column="created_ts",
         date_partition_column="",
-        field_mapping={"ts_1": "ts", "id": "driver_ident"},
+        field_mapping={"ts_1": "ts", "id": "driver_id"},
     )
     fv = get_feature_view(bigquery_source)
     e = Entity(
-        name="driver_id",
+        name="driver",
         description="id for driver",
-        join_key="driver_ident",
+        join_key="driver_id",
         value_type=ValueType.INT32,
     )
 
     with tempfile.TemporaryDirectory() as repo_dir_name:
         config = RepoConfig(
             registry=str(Path(repo_dir_name) / "registry.db"),
-            project=f"test_bq_correctness_{uuid.uuid4()}",
+            project=f"test_bq_correctness_{str(uuid.uuid4()).replace('-', '')}",
             provider="gcp",
         )
         fs = FeatureStore(config=config)
@@ -121,7 +121,10 @@ def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
         )
         fv = get_feature_view(file_source)
         e = Entity(
-            name="driver_id", description="id for driver", value_type=ValueType.INT32
+            name="driver",
+            description="id for driver",
+            join_key="driver_id",
+            value_type=ValueType.INT32,
         )
         with tempfile.TemporaryDirectory() as repo_dir_name, tempfile.TemporaryDirectory() as data_dir_name:
             config = RepoConfig(
@@ -138,7 +141,34 @@ def prep_local_fs_and_fv() -> Iterator[Tuple[FeatureStore, FeatureView]]:
             yield fs, fv
 
 
-def run_materialization_test(fs: FeatureStore, fv: FeatureView) -> None:
+# Checks that both offline & online store values are as expected
+def check_offline_and_online_features(
+    fs: FeatureStore,
+    fv: FeatureView,
+    driver_id: int,
+    event_timestamp: datetime,
+    expected_value: float,
+) -> None:
+    # Check online store
+    response_dict = fs.get_online_features(
+        [f"{fv.name}:value"], [{"driver": driver_id}]
+    ).to_dict()
+    assert abs(response_dict[f"{fv.name}__value"][0] - expected_value) < 1e-6
+
+    # Check offline store
+    df = fs.get_historical_features(
+        entity_df=pd.DataFrame.from_dict(
+            {"driver_id": [driver_id], "event_timestamp": [event_timestamp]}
+        ),
+        feature_refs=[f"{fv.name}:value"],
+    ).to_df()
+
+    assert abs(df.to_dict()[f"{fv.name}__value"][0] - expected_value) < 1e-6
+
+
+def run_offline_online_store_consistency_test(
+    fs: FeatureStore, fv: FeatureView
+) -> None:
     now = datetime.utcnow()
     # Run materialize()
     # use both tz-naive & tz-aware timestamps to test that they're both correctly handled
@@ -147,38 +177,33 @@ def run_materialization_test(fs: FeatureStore, fv: FeatureView) -> None:
     fs.materialize(feature_views=[fv.name], start_date=start_date, end_date=end_date)
 
     # check result of materialize()
-    response_dict = fs.get_online_features(
-        [f"{fv.name}:value"], [{"driver_id": 1}]
-    ).to_dict()
-    assert abs(response_dict[f"{fv.name}__value"][0] - 0.3) < 1e-6
+    check_offline_and_online_features(
+        fs=fs, fv=fv, driver_id=1, event_timestamp=end_date, expected_value=0.3
+    )
 
     # check prior value for materialize_incremental()
-    response_dict = fs.get_online_features(
-        [f"{fv.name}:value"], [{"driver_id": 3}]
-    ).to_dict()
-    assert abs(response_dict[f"{fv.name}__value"][0] - 4) < 1e-6
+    check_offline_and_online_features(
+        fs=fs, fv=fv, driver_id=3, event_timestamp=end_date, expected_value=4
+    )
 
     # run materialize_incremental()
-    fs.materialize_incremental(
-        feature_views=[fv.name], end_date=now - timedelta(seconds=0),
-    )
+    fs.materialize_incremental(feature_views=[fv.name], end_date=now)
 
     # check result of materialize_incremental()
-    response_dict = fs.get_online_features(
-        [f"{fv.name}:value"], [{"driver_id": 3}]
-    ).to_dict()
-    assert abs(response_dict[f"{fv.name}__value"][0] - 5) < 1e-6
+    check_offline_and_online_features(
+        fs=fs, fv=fv, driver_id=3, event_timestamp=now, expected_value=5
+    )
 
 
 @pytest.mark.integration
 @pytest.mark.parametrize(
     "bq_source_type", ["query", "table"],
 )
-def test_bq_materialization(bq_source_type: str):
+def test_bq_offline_online_store_consistency(bq_source_type: str):
     with prep_bq_fs_and_fv(bq_source_type) as (fs, fv):
-        run_materialization_test(fs, fv)
+        run_offline_online_store_consistency_test(fs, fv)
 
 
-def test_local_materialization():
+def test_local_offline_online_store_consistency():
     with prep_local_fs_and_fv() as (fs, fv):
-        run_materialization_test(fs, fv)
+        run_offline_online_store_consistency_test(fs, fv)
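
One detail the renamed tests lean on: the entity name ("driver", used when requesting online features) is now distinct from its join key ("driver_id", the column in the offline data and the entity dataframe). A quick illustrative sketch of that split, mirroring the Entity construction in the tests (a standalone check, not part of the diff):

    from feast import Entity, ValueType

    # name is what feature store APIs reference; join_key is the storage column.
    e = Entity(
        name="driver",
        description="id for driver",
        join_key="driver_id",
        value_type=ValueType.INT32,
    )
    assert e.name == "driver"
    assert e.join_key == "driver_id"

This is why check_offline_and_online_features keys online reads by {"driver": driver_id} but builds the offline entity dataframe with a "driver_id" column.
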