From f652bd4fd31f02d07d2506b837bf74d7a0fe1cbb Mon Sep 17 00:00:00 2001
From: tokoko
Date: Thu, 21 Mar 2024 04:15:07 +0000
Subject: [PATCH] feat: refactor ibis point-in-time-join

Signed-off-by: tokoko
---
 .../contrib/ibis_offline_store/ibis.py        | 234 +++++++++-----
 .../requirements/py3.10-ci-requirements.txt   |  29 ++-
 .../requirements/py3.10-requirements.txt      |   4 +-
 .../requirements/py3.9-ci-requirements.txt    |  29 ++-
 .../requirements/py3.9-requirements.txt       |   4 +-
 .../unit/infra/offline_stores/test_ibis.py    |  88 +++++++
 setup.py                                      |   1 +
 7 files changed, 251 insertions(+), 138 deletions(-)
 create mode 100644 sdk/python/tests/unit/infra/offline_stores/test_ibis.py

diff --git a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py
index 72e0d970c6..cb35cc083e 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py
@@ -72,112 +72,6 @@ def _get_entity_df_event_timestamp_range(
 
         return entity_df_event_timestamp_range
 
-    @staticmethod
-    def _get_historical_features_one(
-        feature_view: FeatureView,
-        entity_table: Table,
-        feature_refs: List[str],
-        full_feature_names: bool,
-        timestamp_range: Tuple,
-        acc_table: Table,
-        event_timestamp_col: str,
-    ) -> Table:
-        fv_table: Table = ibis.read_parquet(feature_view.batch_source.name)
-
-        for old_name, new_name in feature_view.batch_source.field_mapping.items():
-            if old_name in fv_table.columns:
-                fv_table = fv_table.rename({new_name: old_name})
-
-        timestamp_field = feature_view.batch_source.timestamp_field
-
-        # TODO mutate only if tz-naive
-        fv_table = fv_table.mutate(
-            **{
-                timestamp_field: fv_table[timestamp_field].cast(
-                    dt.Timestamp(timezone="UTC")
-                )
-            }
-        )
-
-        full_name_prefix = feature_view.projection.name_alias or feature_view.name
-
-        feature_refs = [
-            fr.split(":")[1]
-            for fr in feature_refs
-            if fr.startswith(f"{full_name_prefix}:")
-        ]
-
-        timestamp_range_start_minus_ttl = (
-            timestamp_range[0] - feature_view.ttl
-            if feature_view.ttl and feature_view.ttl > timedelta(0, 0, 0, 0, 0, 0, 0)
-            else timestamp_range[0]
-        )
-
-        timestamp_range_start_minus_ttl = ibis.literal(
-            timestamp_range_start_minus_ttl.strftime("%Y-%m-%d %H:%M:%S.%f")
-        ).cast(dt.Timestamp(timezone="UTC"))
-
-        timestamp_range_end = ibis.literal(
-            timestamp_range[1].strftime("%Y-%m-%d %H:%M:%S.%f")
-        ).cast(dt.Timestamp(timezone="UTC"))
-
-        fv_table = fv_table.filter(
-            ibis.and_(
-                fv_table[timestamp_field] <= timestamp_range_end,
-                fv_table[timestamp_field] >= timestamp_range_start_minus_ttl,
-            )
-        )
-
-        # join_key_map = feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns}
-        # predicates = [fv_table[k] == entity_table[v] for k, v in join_key_map.items()]
-
-        if feature_view.projection.join_key_map:
-            predicates = [
-                fv_table[k] == entity_table[v]
-                for k, v in feature_view.projection.join_key_map.items()
-            ]
-        else:
-            predicates = [
-                fv_table[e.name] == entity_table[e.name]
-                for e in feature_view.entity_columns
-            ]
-
-        predicates.append(
-            fv_table[timestamp_field] <= entity_table[event_timestamp_col]
-        )
-
-        fv_table = fv_table.inner_join(
-            entity_table, predicates, lname="", rname="{name}_y"
-        )
-
-        fv_table = (
-            fv_table.group_by(by="entity_row_id")
-            .order_by(ibis.desc(fv_table[timestamp_field]))
-            .mutate(rn=ibis.row_number())
-        )
-
-        fv_table = fv_table.filter(fv_table["rn"] == ibis.literal(0))
-
-        select_cols = ["entity_row_id"]
-        select_cols.extend(feature_refs)
-        fv_table = fv_table.select(select_cols)
-
-        if full_feature_names:
-            fv_table = fv_table.rename(
-                {f"{full_name_prefix}__{feature}": feature for feature in feature_refs}
-            )
-
-        acc_table = acc_table.left_join(
-            fv_table,
-            predicates=[fv_table.entity_row_id == acc_table.entity_row_id],
-            lname="",
-            rname="{name}_yyyy",
-        )
-
-        acc_table = acc_table.drop(s.endswith("_yyyy"))
-
-        return acc_table
-
     @staticmethod
     def _to_utc(entity_df: pd.DataFrame, event_timestamp_col):
         entity_df_event_timestamp = entity_df.loc[
@@ -228,9 +122,11 @@ def get_historical_features(
             entity_schema=entity_schema,
         )
 
+        # TODO get range with ibis
         timestamp_range = IbisOfflineStore._get_entity_df_event_timestamp_range(
             entity_df, event_timestamp_col
         )
+
         entity_df = IbisOfflineStore._to_utc(entity_df, event_timestamp_col)
 
         entity_table = ibis.memtable(entity_df)
@@ -238,20 +134,55 @@ def get_historical_features(
             entity_table, feature_views, event_timestamp_col
         )
 
-        res: Table = entity_table
+        def read_fv(feature_view, feature_refs, full_feature_names):
+            fv_table: Table = ibis.read_parquet(feature_view.batch_source.name)
 
-        for fv in feature_views:
-            res = IbisOfflineStore._get_historical_features_one(
-                fv,
-                entity_table,
+            for old_name, new_name in feature_view.batch_source.field_mapping.items():
+                if old_name in fv_table.columns:
+                    fv_table = fv_table.rename({new_name: old_name})
+
+            timestamp_field = feature_view.batch_source.timestamp_field
+
+            # TODO mutate only if tz-naive
+            fv_table = fv_table.mutate(
+                **{
+                    timestamp_field: fv_table[timestamp_field].cast(
+                        dt.Timestamp(timezone="UTC")
+                    )
+                }
+            )
+
+            full_name_prefix = feature_view.projection.name_alias or feature_view.name
+
+            feature_refs = [
+                fr.split(":")[1]
+                for fr in feature_refs
+                if fr.startswith(f"{full_name_prefix}:")
+            ]
+
+            if full_feature_names:
+                fv_table = fv_table.rename(
+                    {f"{full_name_prefix}__{feature}": feature for feature in feature_refs}
+                )
+
+                feature_refs = [f"{full_name_prefix}__{feature}" for feature in feature_refs]
+
+            return (
+                fv_table,
+                feature_view.batch_source.timestamp_field,
+                feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns},
                 feature_refs,
-                full_feature_names,
-                timestamp_range,
-                res,
-                event_timestamp_col,
+                feature_view.ttl
             )
 
-        res = res.drop("entity_row_id")
+        res = point_in_time_join(
+            entity_table=entity_table,
+            feature_tables=[
+                read_fv(feature_view, feature_refs, full_feature_names)
+                for feature_view in feature_views
+            ],
+            event_timestamp_col=event_timestamp_col
+        )
 
         return IbisRetrievalJob(
             res,
@@ -285,6 +216,10 @@ def pull_all_from_table_or_query(
 
         table = table.select(*fields)
 
+        # TODO get rid of this fix
+        if '__log_date' in table.columns:
+            table = table.drop('__log_date')
+
         table = table.filter(
             ibis.and_(
                 table[timestamp_field] >= ibis.literal(start_date),
@@ -320,6 +255,7 @@ def write_logged_features(
         else:
             kwargs = {}
 
+        # TODO always write to directory
         table.to_parquet(
             f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs
         )
@@ -405,3 +341,69 @@ def persist(
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
         return self._metadata
+
+
+def point_in_time_join(
+    entity_table: Table,
+    feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]],
+    event_timestamp_col='event_timestamp'
+):
+    # Build a synthetic row id from the event timestamp and all join keys
+    all_entities = [event_timestamp_col]
+    for feature_table, timestamp_field, join_key_map, _, _ in feature_tables:
+        all_entities.extend(join_key_map.values())
+
+    r = ibis.literal("")
+
+    for e in set(all_entities):
+        r = r.concat(entity_table[e].cast("string"))  # type: ignore
+
+    entity_table = entity_table.mutate(entity_row_id=r)
+
+    acc_table = entity_table
+
+    for feature_table, timestamp_field, join_key_map, feature_refs, ttl in feature_tables:
+        # Match on the join keys, using only feature rows at or before the entity timestamp
+        predicates = [feature_table[k] == entity_table[v] for k, v in join_key_map.items()]
+
+        predicates.append(
+            feature_table[timestamp_field] <= entity_table[event_timestamp_col],
+        )
+
+        # If a ttl is set, also drop feature rows older than the ttl window
+        if ttl:
+            predicates.append(
+                feature_table[timestamp_field] >= entity_table[event_timestamp_col] - ibis.literal(ttl)
+            )
+
+        feature_table = feature_table.inner_join(
+            entity_table, predicates, lname="", rname="{name}_y"
+        )
+
+        feature_table = feature_table.drop(s.endswith("_y"))
+
+        # Keep only the most recent matching feature row per entity row
+        feature_table = (
+            feature_table.group_by(by="entity_row_id")
+            .order_by(ibis.desc(feature_table[timestamp_field]))
+            .mutate(rn=ibis.row_number())
+        )
+
+        feature_table = feature_table.filter(feature_table["rn"] == ibis.literal(0)).drop("rn")
+
+        select_cols = ["entity_row_id"]
+        select_cols.extend(feature_refs)
+        feature_table = feature_table.select(select_cols)
+
+        acc_table = acc_table.left_join(
+            feature_table,
+            predicates=[feature_table.entity_row_id == acc_table.entity_row_id],
+            lname="",
+            rname="{name}_yyyy",
+        )
+
+        acc_table = acc_table.drop(s.endswith("_yyyy"))
+
+    acc_table = acc_table.drop('entity_row_id')
+
+    return acc_table
\ No newline at end of file
diff --git a/sdk/python/requirements/py3.10-ci-requirements.txt b/sdk/python/requirements/py3.10-ci-requirements.txt
index 8f0ef90d77..737271eee1 100644
--- a/sdk/python/requirements/py3.10-ci-requirements.txt
+++ b/sdk/python/requirements/py3.10-ci-requirements.txt
@@ -61,11 +61,11 @@ black==22.12.0
     # via feast (setup.py)
 bleach==6.1.0
     # via nbconvert
-boto3==1.34.65
+boto3==1.34.67
     # via
     #   feast (setup.py)
     #   moto
-botocore==1.34.65
+botocore==1.34.67
     # via
     #   boto3
     #   moto
@@ -82,7 +82,7 @@ cachecontrol==0.14.0
     # via firebase-admin
 cachetools==5.3.3
     # via google-auth
-cassandra-driver==3.29.0
+cassandra-driver==3.29.1
     # via feast (setup.py)
 certifi==2024.2.2
     # via
@@ -164,6 +164,12 @@ docker==7.0.0
     #   testcontainers
 docutils==0.19
     # via sphinx
+duckdb==0.10.1
+    # via
+    #   duckdb-engine
+    #   ibis-framework
+duckdb-engine==0.11.2
+    # via ibis-framework
 entrypoints==0.4
     # via altair
 exceptiongroup==1.2.0
@@ -213,7 +219,7 @@ google-api-core[grpc]==2.17.1
     #   google-cloud-storage
 google-api-python-client==2.122.0
     # via firebase-admin
-google-auth==2.28.2
+google-auth==2.29.0
     # via
     #   google-api-core
     #   google-api-python-client
@@ -258,7 +264,7 @@ googleapis-common-protos[grpc]==1.63.0
     #   google-api-core
     #   grpc-google-iam-v1
     #   grpcio-status
-great-expectations==0.18.11
+great-expectations==0.18.12
     # via feast (setup.py)
 greenlet==3.0.3
     # via sqlalchemy
@@ -310,7 +316,7 @@ httpx==0.27.0
     # via
     #   feast (setup.py)
     #   jupyterlab
-ibis-framework==8.0.0
+ibis-framework[duckdb]==8.0.0
     # via
     #   feast (setup.py)
     #   ibis-substrait
@@ -331,7 +337,7 @@ importlib-metadata==6.11.0
     # via
     #   dask
     #   feast (setup.py)
-importlib-resources==6.3.1
+importlib-resources==6.3.2
     # via feast (setup.py)
 iniconfig==2.0.0
     # via pytest
@@ -459,7 +465,7 @@ moreorless==0.4.0
     # via bowler
 moto==4.2.14
     # via feast (setup.py)
-msal==1.27.0
+msal==1.28.0
     # via
     #   azure-identity
     #   msal-extensions
@@ -844,8 +850,13 @@ sphinxcontrib-serializinghtml==1.1.10
     # via sphinx
 sqlalchemy[mypy]==1.4.52
     # via
+    #   duckdb-engine
     #   feast (setup.py)
+    #   ibis-framework
     #   sqlalchemy
+    #   sqlalchemy-views
+sqlalchemy-views==0.3.2
+    # via ibis-framework
 sqlalchemy2-stubs==0.0.2a38
     # via sqlalchemy
 sqlglot==20.11.0
@@ -984,7 +995,7 @@ urllib3==1.26.18
     #   requests
     #   responses
     #   rockset
-uvicorn[standard]==0.28.0
+uvicorn[standard]==0.29.0
     # via feast (setup.py)
 uvloop==0.19.0
     # via uvicorn
diff --git a/sdk/python/requirements/py3.10-requirements.txt b/sdk/python/requirements/py3.10-requirements.txt
index e17a588538..240f43b57e 100644
--- a/sdk/python/requirements/py3.10-requirements.txt
+++ b/sdk/python/requirements/py3.10-requirements.txt
@@ -62,7 +62,7 @@ importlib-metadata==6.11.0
     # via
     #   dask
     #   feast (setup.py)
-importlib-resources==6.3.1
+importlib-resources==6.3.2
     # via feast (setup.py)
 jinja2==3.1.3
     # via feast (setup.py)
@@ -176,7 +176,7 @@ tzdata==2024.1
     # via pandas
 urllib3==2.2.1
     # via requests
-uvicorn[standard]==0.28.0
+uvicorn[standard]==0.29.0
     # via feast (setup.py)
 uvloop==0.19.0
     # via uvicorn
diff --git a/sdk/python/requirements/py3.9-ci-requirements.txt b/sdk/python/requirements/py3.9-ci-requirements.txt
index dc96554431..f2585a7978 100644
--- a/sdk/python/requirements/py3.9-ci-requirements.txt
+++ b/sdk/python/requirements/py3.9-ci-requirements.txt
@@ -61,11 +61,11 @@ black==22.12.0
     # via feast (setup.py)
 bleach==6.1.0
     # via nbconvert
-boto3==1.34.65
+boto3==1.34.67
     # via
     #   feast (setup.py)
     #   moto
-botocore==1.34.65
+botocore==1.34.67
     # via
     #   boto3
     #   moto
@@ -82,7 +82,7 @@ cachecontrol==0.14.0
     # via firebase-admin
 cachetools==5.3.3
     # via google-auth
-cassandra-driver==3.29.0
+cassandra-driver==3.29.1
     # via feast (setup.py)
 certifi==2024.2.2
     # via
@@ -164,6 +164,12 @@ docker==7.0.0
     #   testcontainers
 docutils==0.19
     # via sphinx
+duckdb==0.10.1
+    # via
+    #   duckdb-engine
+    #   ibis-framework
+duckdb-engine==0.11.2
+    # via ibis-framework
 entrypoints==0.4
     # via altair
 exceptiongroup==1.2.0
@@ -213,7 +219,7 @@ google-api-core[grpc]==2.17.1
     #   google-cloud-storage
 google-api-python-client==2.122.0
     # via firebase-admin
-google-auth==2.28.2
+google-auth==2.29.0
     # via
     #   google-api-core
     #   google-api-python-client
@@ -258,7 +264,7 @@ googleapis-common-protos[grpc]==1.63.0
     #   google-api-core
     #   grpc-google-iam-v1
     #   grpcio-status
-great-expectations==0.18.11
+great-expectations==0.18.12
     # via feast (setup.py)
 greenlet==3.0.3
     # via sqlalchemy
@@ -310,7 +316,7 @@ httpx==0.27.0
     # via
     #   feast (setup.py)
     #   jupyterlab
-ibis-framework==8.0.0
+ibis-framework[duckdb]==8.0.0
     # via
     #   feast (setup.py)
     #   ibis-substrait
@@ -339,7 +345,7 @@ importlib-metadata==6.11.0
     #   nbconvert
     #   sphinx
     #   typeguard
-importlib-resources==6.3.1
+importlib-resources==6.3.2
     # via feast (setup.py)
 iniconfig==2.0.0
     # via pytest
@@ -467,7 +473,7 @@ moreorless==0.4.0
     # via bowler
 moto==4.2.14
     # via feast (setup.py)
-msal==1.27.0
+msal==1.28.0
     # via
     #   azure-identity
     #   msal-extensions
@@ -854,8 +860,13 @@ sphinxcontrib-serializinghtml==1.1.10
     # via sphinx
 sqlalchemy[mypy]==1.4.52
     # via
+    #   duckdb-engine
     #   feast (setup.py)
+    #   ibis-framework
     #   sqlalchemy
+    #   sqlalchemy-views
+sqlalchemy-views==0.3.2
+    # via ibis-framework
 sqlalchemy2-stubs==0.0.2a38
     # via sqlalchemy
 sqlglot==20.11.0
@@ -998,7 +1009,7 @@ urllib3==1.26.18
     #   responses
     #   rockset
     #   snowflake-connector-python
-uvicorn[standard]==0.28.0
+uvicorn[standard]==0.29.0
     # via feast (setup.py)
 uvloop==0.19.0
     # via uvicorn
diff --git a/sdk/python/requirements/py3.9-requirements.txt b/sdk/python/requirements/py3.9-requirements.txt
index f2228ade02..43b0191ed4 100644
--- a/sdk/python/requirements/py3.9-requirements.txt
+++ b/sdk/python/requirements/py3.9-requirements.txt
@@ -63,7 +63,7 @@ importlib-metadata==6.11.0
     #   dask
     #   feast (setup.py)
     #   typeguard
-importlib-resources==6.3.1
+importlib-resources==6.3.2
     # via feast (setup.py)
 jinja2==3.1.3
     # via feast (setup.py)
@@ -178,7 +178,7 @@ tzdata==2024.1
     # via pandas
 urllib3==2.2.1
     # via requests
-uvicorn[standard]==0.28.0
+uvicorn[standard]==0.29.0
     # via feast (setup.py)
 uvloop==0.19.0
     # via uvicorn
diff --git a/sdk/python/tests/unit/infra/offline_stores/test_ibis.py b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py
new file mode 100644
index 0000000000..a73d4451a5
--- /dev/null
+++ b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py
@@ -0,0 +1,88 @@
+from datetime import datetime, timedelta
+import ibis
+import pyarrow as pa
+from typing import List, Tuple, Dict
+from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import point_in_time_join
+from operator import itemgetter
+
+def pa_datetime(year, month, day):
+    return pa.scalar(datetime(year, month, day), type=pa.timestamp('s', tz='UTC'))
+
+def customer_table():
+    return pa.Table.from_arrays(
+        arrays=[
+            pa.array([1, 1, 2]),
+            pa.array([pa_datetime(2024, 1, 1), pa_datetime(2024, 1, 2), pa_datetime(2024, 1, 1)])
+        ],
+        names=['customer_id', 'event_timestamp']
+    )
+
+def features_table_1():
+    return pa.Table.from_arrays(
+        arrays=[
+            pa.array([1, 1, 1, 2]),
+            pa.array([pa_datetime(2023, 12, 31), pa_datetime(2024, 1, 2), pa_datetime(2024, 1, 3), pa_datetime(2023, 1, 3)]),
+            pa.array([11, 22, 33, 22])
+        ],
+        names=['customer_id', 'event_timestamp', 'feature1']
+    )
+
+# Brute-force point-in-time join used as a reference implementation
+def point_in_time_join_brute(
+    entity_table: pa.Table,
+    feature_tables: List[Tuple[pa.Table, str, Dict[str, str], List[str], timedelta]],
+    event_timestamp_col='event_timestamp'
+):
+    ret_fields = [entity_table.schema.field(n) for n in entity_table.schema.names]
+
+    ret = entity_table.to_pydict()
+    batch_dict = entity_table.to_pydict()
+
+    for i, row_timestamp in enumerate(batch_dict[event_timestamp_col]):
+        for feature_table, timestamp_key, join_key_map, feature_refs, ttl in feature_tables:
+            if i == 0:
+                ret_fields.extend([feature_table.schema.field(f) for f in feature_table.schema.names if f in feature_refs])
+
+            def check_equality(ft_dict, batch_dict, x, y):
+                return all([ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()])
+
+            ft_dict = feature_table.to_pydict()
+            found_matches = [
+                (j, ft_dict[timestamp_key][j]) for j in range(feature_table.num_rows)
+                if check_equality(ft_dict, batch_dict, j, i) and
+                ft_dict[timestamp_key][j] <= row_timestamp and
+                ft_dict[timestamp_key][j] >= row_timestamp - ttl
+            ]
+
+            index_found = max(found_matches, key=itemgetter(1))[0] if found_matches else None
+            for col in ft_dict.keys():
+                if col not in feature_refs:
+                    continue
+
+                if col not in ret:
+                    ret[col] = []
+
+                if index_found is not None:
+                    ret[col].append(ft_dict[col][index_found])
+                else:
+                    ret[col].append(None)
+
+    return pa.Table.from_pydict(ret, schema=pa.schema(ret_fields))
+
+
+def test_point_in_time_join():
+    expected = point_in_time_join_brute(
+        customer_table(),
+        feature_tables=[
+            (features_table_1(), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10))
+        ]
+    )
+
+    actual = point_in_time_join(
+        ibis.memtable(customer_table()),
+        feature_tables=[
+            (ibis.memtable(features_table_1()), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10))
+        ]
+    ).to_pyarrow()
+
+    assert actual.equals(expected)
diff --git a/setup.py b/setup.py
index b32d03ed77..ca89b09bf6 100644
--- a/setup.py
+++ b/setup.py
@@ -211,6 +211,7 @@
     + HAZELCAST_REQUIRED
     + IBIS_REQUIRED
     + GRPCIO_REQUIRED
+    + DUCKDB_REQUIRED
 )
 
 DOCS_REQUIRED = CI_REQUIRED
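
Usage sketch (illustrative, not part of the patch): a minimal example of calling
the new point_in_time_join helper directly, mirroring the unit test above. The
table contents and the 10-day TTL are assumptions chosen for illustration.

    from datetime import timedelta

    import ibis
    import pandas as pd

    from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import (
        point_in_time_join,
    )

    # Entity rows: the (customer, timestamp) pairs we want features for.
    entity_df = pd.DataFrame(
        {
            "customer_id": [1, 2],
            "event_timestamp": pd.to_datetime(["2024-01-01", "2024-01-01"], utc=True),
        }
    )

    # Feature rows, each valid from its own event timestamp onwards.
    features_df = pd.DataFrame(
        {
            "customer_id": [1, 1, 2],
            "event_timestamp": pd.to_datetime(
                ["2023-12-31", "2024-01-02", "2023-01-03"], utc=True
            ),
            "feature1": [11, 22, 33],
        }
    )

    # Each feature table is a (table, timestamp_field, join_key_map,
    # feature_refs, ttl) tuple, matching the signature added in this patch.
    joined = point_in_time_join(
        entity_table=ibis.memtable(entity_df),
        feature_tables=[
            (
                ibis.memtable(features_df),
                "event_timestamp",
                {"customer_id": "customer_id"},
                ["feature1"],
                timedelta(days=10),
            )
        ],
    )

    # Customer 1 resolves to feature1=11 (the latest row at or before
    # 2024-01-01 within the TTL); customer 2 gets NULL because its only
    # feature row is older than the 10-day TTL window.
    print(joined.to_pandas())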