diff --git a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py index cb35cc083e..8787d70158 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py @@ -162,26 +162,32 @@ def read_fv(feature_view, feature_refs, full_feature_names): if full_feature_names: fv_table = fv_table.rename( - {f"{full_name_prefix}__{feature}": feature for feature in feature_refs} + { + f"{full_name_prefix}__{feature}": feature + for feature in feature_refs + } ) - feature_refs = [f"{full_name_prefix}__{feature}" for feature in feature_refs] + feature_refs = [ + f"{full_name_prefix}__{feature}" for feature in feature_refs + ] return ( fv_table, feature_view.batch_source.timestamp_field, - feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns}, + feature_view.projection.join_key_map + or {e.name: e.name for e in feature_view.entity_columns}, feature_refs, - feature_view.ttl + feature_view.ttl, ) res = point_in_time_join( entity_table=entity_table, - feature_tables=[ + feature_tables=[ read_fv(feature_view, feature_refs, full_feature_names) for feature_view in feature_views ], - event_timestamp_col=event_timestamp_col + event_timestamp_col=event_timestamp_col, ) return IbisRetrievalJob( @@ -217,8 +223,8 @@ def pull_all_from_table_or_query( table = table.select(*fields) # TODO get rid of this fix - if '__log_date' in table.columns: - table = table.drop('__log_date') + if "__log_date" in table.columns: + table = table.drop("__log_date") table = table.filter( ibis.and_( @@ -255,7 +261,7 @@ def write_logged_features( else: kwargs = {} - #TODO always write to directory + # TODO always write to directory table.to_parquet( f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs ) @@ -346,9 +352,9 @@ def metadata(self) -> Optional[RetrievalMetadata]: def point_in_time_join( entity_table: Table, feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]], - event_timestamp_col = 'event_timestamp' + event_timestamp_col="event_timestamp", ): - #TODO handle ttl + # TODO handle ttl all_entities = [event_timestamp_col] for feature_table, timestamp_field, join_key_map, _, _ in feature_tables: all_entities.extend(join_key_map.values()) @@ -362,8 +368,16 @@ def point_in_time_join( acc_table = entity_table - for feature_table, timestamp_field, join_key_map, feature_refs, ttl in feature_tables: - predicates = [feature_table[k] == entity_table[v] for k, v in join_key_map.items()] + for ( + feature_table, + timestamp_field, + join_key_map, + feature_refs, + ttl, + ) in feature_tables: + predicates = [ + feature_table[k] == entity_table[v] for k, v in join_key_map.items() + ] predicates.append( feature_table[timestamp_field] <= entity_table[event_timestamp_col], @@ -371,7 +385,8 @@ def point_in_time_join( if ttl: predicates.append( - feature_table[timestamp_field] >= entity_table[event_timestamp_col] - ibis.literal(ttl) + feature_table[timestamp_field] + >= entity_table[event_timestamp_col] - ibis.literal(ttl) ) feature_table = feature_table.inner_join( @@ -386,7 +401,9 @@ def point_in_time_join( .mutate(rn=ibis.row_number()) ) - feature_table = feature_table.filter(feature_table["rn"] == ibis.literal(0)).drop("rn") + feature_table = feature_table.filter( + feature_table["rn"] == ibis.literal(0) + ).drop("rn") select_cols = ["entity_row_id"] select_cols.extend(feature_refs) @@ -401,6 +418,6 @@ def point_in_time_join( acc_table = acc_table.drop(s.endswith("_yyyy")) - acc_table = acc_table.drop('entity_row_id') + acc_table = acc_table.drop("entity_row_id") - return acc_table \ No newline at end of file + return acc_table diff --git a/sdk/python/tests/unit/infra/offline_stores/test_ibis.py b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py index a73d4451a5..5f105e2af7 100644 --- a/sdk/python/tests/unit/infra/offline_stores/test_ibis.py +++ b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py @@ -1,67 +1,105 @@ from datetime import datetime, timedelta +from typing import Dict, List, Tuple + import ibis import pyarrow as pa -from typing import List, Tuple, Dict -from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import point_in_time_join -from pprint import pprint + +from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import ( + point_in_time_join, +) + def pa_datetime(year, month, day): - return pa.scalar(datetime(year, month, day), type=pa.timestamp('s', tz='UTC')) + return pa.scalar(datetime(year, month, day), type=pa.timestamp("s", tz="UTC")) + def customer_table(): return pa.Table.from_arrays( arrays=[ pa.array([1, 1, 2]), - pa.array([pa_datetime(2024, 1, 1),pa_datetime(2024, 1, 2),pa_datetime(2024, 1, 1)]) + pa.array( + [ + pa_datetime(2024, 1, 1), + pa_datetime(2024, 1, 2), + pa_datetime(2024, 1, 1), + ] + ), ], - names=['customer_id', 'event_timestamp'] + names=["customer_id", "event_timestamp"], ) + def features_table_1(): return pa.Table.from_arrays( arrays=[ pa.array([1, 1, 1, 2]), - pa.array([pa_datetime(2023, 12, 31), pa_datetime(2024, 1, 2), pa_datetime(2024, 1, 3), pa_datetime(2023, 1, 3)]), - pa.array([11, 22, 33, 22]) - ], - names=['customer_id', 'event_timestamp', 'feature1'] + pa.array( + [ + pa_datetime(2023, 12, 31), + pa_datetime(2024, 1, 2), + pa_datetime(2024, 1, 3), + pa_datetime(2023, 1, 3), + ] + ), + pa.array([11, 22, 33, 22]), + ], + names=["customer_id", "event_timestamp", "feature1"], ) + def point_in_time_join_brute( entity_table: pa.Table, feature_tables: List[Tuple[pa.Table, str, Dict[str, str], List[str], timedelta]], - event_timestamp_col = 'event_timestamp' + event_timestamp_col="event_timestamp", ): ret_fields = [entity_table.schema.field(n) for n in entity_table.schema.names] from operator import itemgetter + ret = entity_table.to_pydict() batch_dict = entity_table.to_pydict() for i, row_timestmap in enumerate(batch_dict[event_timestamp_col]): - for feature_table, timestamp_key, join_key_map, feature_refs, ttl in feature_tables: + for ( + feature_table, + timestamp_key, + join_key_map, + feature_refs, + ttl, + ) in feature_tables: if i == 0: - ret_fields.extend([feature_table.schema.field(f) for f in feature_table.schema.names if f not in join_key_map.values() and f != timestamp_key]) + ret_fields.extend( + [ + feature_table.schema.field(f) + for f in feature_table.schema.names + if f not in join_key_map.values() and f != timestamp_key + ] + ) def check_equality(ft_dict, batch_dict, x, y): - return all([ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()]) + return all( + [ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()] + ) ft_dict = feature_table.to_pydict() found_matches = [ - (j, ft_dict[timestamp_key][j]) for j in range(entity_table.num_rows) - if check_equality(ft_dict, batch_dict, j, i) and - ft_dict[timestamp_key][j] <= row_timestmap and - ft_dict[timestamp_key][j] >= row_timestmap - ttl + (j, ft_dict[timestamp_key][j]) + for j in range(entity_table.num_rows) + if check_equality(ft_dict, batch_dict, j, i) + and ft_dict[timestamp_key][j] <= row_timestmap + and ft_dict[timestamp_key][j] >= row_timestmap - ttl ] - index_found = max(found_matches, key=itemgetter(1))[0] if found_matches else None + index_found = ( + max(found_matches, key=itemgetter(1))[0] if found_matches else None + ) for col in ft_dict.keys(): if col not in feature_refs: continue if col not in ret: ret[col] = [] - + if index_found is not None: ret[col].append(ft_dict[col][index_found]) else: @@ -74,15 +112,27 @@ def test_point_in_time_join(): expected = point_in_time_join_brute( customer_table(), feature_tables=[ - (features_table_1(), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10)) - ] + ( + features_table_1(), + "event_timestamp", + {"customer_id": "customer_id"}, + ["feature1"], + timedelta(days=10), + ) + ], ) actual = point_in_time_join( ibis.memtable(customer_table()), feature_tables=[ - (ibis.memtable(features_table_1()), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10)) - ] + ( + ibis.memtable(features_table_1()), + "event_timestamp", + {"customer_id": "customer_id"}, + ["feature1"], + timedelta(days=10), + ) + ], ).to_pyarrow() assert actual.equals(expected)