Add entity column validations when getting historical features from bigquery (#1614)

* Add entity column validations when getting historical features from bigquery

Signed-off-by: Achal Shah <achals@gmail.com>

* make format

Signed-off-by: Achal Shah <achals@gmail.com>

* Remove wrong file

Signed-off-by: Achal Shah <achals@gmail.com>

* Add tests

Signed-off-by: Achal Shah <achals@gmail.com>
achals authored Jun 7, 2021
1 parent e712782 commit 17231d0
Showing 3 changed files with 102 additions and 2 deletions.
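In effect, after this commit a historical retrieval against the BigQuery offline store fails fast when the entity dataframe lacks a required join key or the event timestamp column. A minimal sketch of the new behavior, assuming a hypothetical feature repo; the entity dataframe, join key, and feature reference below are illustrative, not from this diff:

    import pandas as pd

    from feast import FeatureStore
    from feast.errors import FeastEntityDFMissingColumnsError

    store = FeatureStore(repo_path=".")  # hypothetical feature repo
    # Missing the "customer_id" join key and the "event_timestamp" column:
    entity_df = pd.DataFrame({"customer": [1001]})

    try:
        store.get_historical_features(
            entity_df=entity_df,
            feature_refs=["customer_profile:current_balance"],  # illustrative ref
        )
    except FeastEntityDFMissingColumnsError as e:
        print(e)  # names the expected columns and which ones are missing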
8 changes: 8 additions & 0 deletions sdk/python/feast/errors.py
@@ -67,3 +67,11 @@ def __init__(self, offline_store_name: str, data_source_name: str):
super().__init__(
f"Offline Store '{offline_store_name}' does not support data source '{data_source_name}'"
)


class FeastEntityDFMissingColumnsError(Exception):
def __init__(self, expected, missing):
super().__init__(
f"The entity dataframe you have provided must contain columns {expected}, "
f"but {missing} were missing."
)
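For illustration only (not part of the diff), the new error renders like this; the column names are made up and set ordering may vary:

    from feast.errors import FeastEntityDFMissingColumnsError

    err = FeastEntityDFMissingColumnsError(
        expected={"customer_id", "event_timestamp"}, missing={"customer_id"}
    )
    print(err)
    # The entity dataframe you have provided must contain columns
    # {'customer_id', 'event_timestamp'}, but {'customer_id'} were missing.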
53 changes: 52 additions & 1 deletion sdk/python/feast/infra/offline_stores/bigquery.py
@@ -1,12 +1,13 @@
import time
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from typing import List, Optional, Union
from typing import List, Optional, Set, Union

import pandas
import pyarrow
from jinja2 import BaseLoader, Environment

from feast import errors
from feast.data_source import BigQuerySource, DataSource
from feast.errors import FeastProviderLoginError
from feast.feature_view import FeatureView
@@ -87,13 +88,18 @@ def get_historical_features(

client = _get_bigquery_client()

expected_join_keys = _get_join_keys(project, feature_views, registry)

if type(entity_df) is str:
entity_df_job = client.query(entity_df)
entity_df_result = entity_df_job.result() # also starts job

entity_df_event_timestamp_col = _infer_event_timestamp_from_bigquery_query(
entity_df_result
)
_assert_expected_columns_in_bigquery(
expected_join_keys, entity_df_event_timestamp_col, entity_df_result
)

entity_df_sql_table = f"`{entity_df_job.destination.project}.{entity_df_job.destination.dataset_id}.{entity_df_job.destination.table_id}`"
elif isinstance(entity_df, pandas.DataFrame):
@@ -103,6 +109,10 @@ def get_historical_features(

assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)

_assert_expected_columns_in_dataframe(
expected_join_keys, entity_df_event_timestamp_col, entity_df
)

table_id = _upload_entity_df_into_bigquery(
config.project, config.offline_store.dataset, entity_df, client
)
@@ -132,6 +142,47 @@ def get_historical_features(
return job


def _assert_expected_columns_in_dataframe(
join_keys: Set[str], entity_df_event_timestamp_col: str, entity_df: pandas.DataFrame
):
entity_df_columns = set(entity_df.columns.values)
expected_columns = join_keys.copy()
expected_columns.add(entity_df_event_timestamp_col)

missing_keys = expected_columns - entity_df_columns

if len(missing_keys) != 0:
raise errors.FeastEntityDFMissingColumnsError(expected_columns, missing_keys)


def _assert_expected_columns_in_bigquery(
join_keys: Set[str], entity_df_event_timestamp_col: str, entity_df_result
):
entity_columns = set()
for schema_field in entity_df_result.schema:
entity_columns.add(schema_field.name)

expected_columns = join_keys.copy()
expected_columns.add(entity_df_event_timestamp_col)

missing_keys = expected_columns - entity_columns

if len(missing_keys) != 0:
raise errors.FeastEntityDFMissingColumnsError(expected_columns, missing_keys)


def _get_join_keys(
project: str, feature_views: List[FeatureView], registry: Registry
) -> Set[str]:
join_keys = set()
for feature_view in feature_views:
entities = feature_view.entities
for entity_name in entities:
entity = registry.get_entity(entity_name, project)
join_keys.add(entity.join_key)
return join_keys


def _infer_event_timestamp_from_bigquery_query(entity_df_result) -> str:
if any(
schema_field.name == DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
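The expected column set is built by _get_join_keys, which walks every requested feature view's entities and collects their join_key values from the registry, plus the event timestamp column. A hedged sketch of the dataframe validation path, calling the private helper directly (internal API, may change; the join key and column names are illustrative):

    import pandas as pd

    from feast.errors import FeastEntityDFMissingColumnsError
    from feast.infra.offline_stores.bigquery import (
        _assert_expected_columns_in_dataframe,
    )

    entity_df = pd.DataFrame(
        {"customer": [1001], "event_timestamp": [pd.Timestamp.utcnow()]}
    )
    try:
        # The expected join key "customer_id" is absent ("customer" does not count):
        _assert_expected_columns_in_dataframe(
            {"customer_id"}, "event_timestamp", entity_df
        )
    except FeastEntityDFMissingColumnsError as e:
        print(e)  # reports that {'customer_id'} is missing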
43 changes: 42 additions & 1 deletion sdk/python/tests/test_historical_retrieval.py
@@ -14,7 +14,7 @@
from pytz import utc

import feast.driver_test_data as driver_data
from feast import utils
from feast import errors, utils
from feast.data_source import BigQuerySource, FileSource
from feast.entity import Entity
from feast.feature import Feature
@@ -454,6 +454,30 @@ def test_historical_features_from_bigquery_sources(
check_dtype=False,
)

timestamp_column = (
"e_ts"
if infer_event_timestamp_col
else DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
)

entity_df_query_with_invalid_join_key = (
f"select order_id, driver_id, customer_id as customer, "
f"order_is_success, {timestamp_column}, FROM {gcp_project}.{table_id}"
)
# Rename the join key; this should now raise an error.
assertpy.assert_that(store.get_historical_features).raises(
errors.FeastEntityDFMissingColumnsError
).when_called_with(
entity_df=entity_df_query_with_invalid_join_key,
feature_refs=[
"driver_stats:conv_rate",
"driver_stats:avg_daily_trips",
"customer_profile:current_balance",
"customer_profile:avg_passenger_count",
"customer_profile:lifetime_trip_count",
],
)

job_from_df = store.get_historical_features(
entity_df=orders_df,
feature_refs=[
@@ -465,6 +489,23 @@
],
)

# Rename the join key; this should now raise an error.
orders_df_with_invalid_join_key = orders_df.rename(
{"customer_id": "customer"}, axis="columns"
)
assertpy.assert_that(store.get_historical_features).raises(
errors.FeastEntityDFMissingColumnsError
).when_called_with(
entity_df=orders_df_with_invalid_join_key,
feature_refs=[
"driver_stats:conv_rate",
"driver_stats:avg_daily_trips",
"customer_profile:current_balance",
"customer_profile:avg_passenger_count",
"customer_profile:lifetime_trip_count",
],
)

# Make sure that custom dataset name is being used from the offline_store config
if provider_type == "gcp_custom_offline_config":
assertpy.assert_that(job_from_df.query).contains("foo.entity_df")
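The new tests use assertpy's deferred-call pattern: raises(...).when_called_with(...) invokes the callable with the given arguments and asserts the expected exception type. A self-contained toy example of that pattern, unrelated to Feast:

    import assertpy

    def divide(a: float, b: float) -> float:
        return a / b

    # Asserts that divide(1, 0) throws ZeroDivisionError:
    assertpy.assert_that(divide).raises(ZeroDivisionError).when_called_with(1, 0)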
