Skip to content

Commit

Permalink
fix formatting, linting
Browse files Browse the repository at this point in the history
Signed-off-by: tokoko <togurg14@freeuni.edu.ge>
  • Loading branch information
tokoko committed Mar 21, 2024
1 parent f652bd4 commit fcfe305
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -162,26 +162,32 @@ def read_fv(feature_view, feature_refs, full_feature_names):

if full_feature_names:
fv_table = fv_table.rename(
{f"{full_name_prefix}__{feature}": feature for feature in feature_refs}
{
f"{full_name_prefix}__{feature}": feature
for feature in feature_refs
}
)

feature_refs = [f"{full_name_prefix}__{feature}" for feature in feature_refs]
feature_refs = [
f"{full_name_prefix}__{feature}" for feature in feature_refs
]

return (
fv_table,
feature_view.batch_source.timestamp_field,
feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns},
feature_view.projection.join_key_map
or {e.name: e.name for e in feature_view.entity_columns},
feature_refs,
feature_view.ttl
feature_view.ttl,
)

res = point_in_time_join(
entity_table=entity_table,
feature_tables=[
feature_tables=[
read_fv(feature_view, feature_refs, full_feature_names)
for feature_view in feature_views
],
event_timestamp_col=event_timestamp_col
event_timestamp_col=event_timestamp_col,
)

return IbisRetrievalJob(
Expand Down Expand Up @@ -217,8 +223,8 @@ def pull_all_from_table_or_query(
table = table.select(*fields)

# TODO get rid of this fix
if '__log_date' in table.columns:
table = table.drop('__log_date')
if "__log_date" in table.columns:
table = table.drop("__log_date")

table = table.filter(
ibis.and_(
Expand Down Expand Up @@ -255,7 +261,7 @@ def write_logged_features(
else:
kwargs = {}

#TODO always write to directory
# TODO always write to directory
table.to_parquet(
f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs
)
Expand Down Expand Up @@ -346,9 +352,9 @@ def metadata(self) -> Optional[RetrievalMetadata]:
def point_in_time_join(
entity_table: Table,
feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]],
event_timestamp_col = 'event_timestamp'
event_timestamp_col="event_timestamp",
):
#TODO handle ttl
# TODO handle ttl
all_entities = [event_timestamp_col]
for feature_table, timestamp_field, join_key_map, _, _ in feature_tables:
all_entities.extend(join_key_map.values())
Expand All @@ -362,16 +368,25 @@ def point_in_time_join(

acc_table = entity_table

for feature_table, timestamp_field, join_key_map, feature_refs, ttl in feature_tables:
predicates = [feature_table[k] == entity_table[v] for k, v in join_key_map.items()]
for (
feature_table,
timestamp_field,
join_key_map,
feature_refs,
ttl,
) in feature_tables:
predicates = [
feature_table[k] == entity_table[v] for k, v in join_key_map.items()
]

predicates.append(
feature_table[timestamp_field] <= entity_table[event_timestamp_col],
)

if ttl:
predicates.append(
feature_table[timestamp_field] >= entity_table[event_timestamp_col] - ibis.literal(ttl)
feature_table[timestamp_field]
>= entity_table[event_timestamp_col] - ibis.literal(ttl)
)

feature_table = feature_table.inner_join(
Expand All @@ -386,7 +401,9 @@ def point_in_time_join(
.mutate(rn=ibis.row_number())
)

feature_table = feature_table.filter(feature_table["rn"] == ibis.literal(0)).drop("rn")
feature_table = feature_table.filter(
feature_table["rn"] == ibis.literal(0)
).drop("rn")

select_cols = ["entity_row_id"]
select_cols.extend(feature_refs)
Expand All @@ -401,6 +418,6 @@ def point_in_time_join(

acc_table = acc_table.drop(s.endswith("_yyyy"))

acc_table = acc_table.drop('entity_row_id')
acc_table = acc_table.drop("entity_row_id")

return acc_table
return acc_table
98 changes: 74 additions & 24 deletions sdk/python/tests/unit/infra/offline_stores/test_ibis.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,105 @@
from datetime import datetime, timedelta
from typing import Dict, List, Tuple

import ibis
import pyarrow as pa
from typing import List, Tuple, Dict
from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import point_in_time_join
from pprint import pprint

from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import (
point_in_time_join,
)


def pa_datetime(year, month, day):
    """Return a pyarrow timestamp scalar for midnight of the given date.

    The value is built from ``datetime(year, month, day)`` and typed as a
    second-resolution timestamp with the ``UTC`` time zone, so every fixture
    column built from it shares one Arrow type.
    """
    return pa.scalar(datetime(year, month, day), type=pa.timestamp("s", tz="UTC"))


def customer_table():
    """Build the entity-side fixture table used by the join test.

    Returns a ``pa.Table`` with two columns:

    * ``customer_id`` -- ``[1, 1, 2]``
    * ``event_timestamp`` -- one timestamp per row, built with ``pa_datetime``
    """
    return pa.Table.from_arrays(
        arrays=[
            pa.array([1, 1, 2]),
            pa.array(
                [
                    pa_datetime(2024, 1, 1),
                    pa_datetime(2024, 1, 2),
                    pa_datetime(2024, 1, 1),
                ]
            ),
        ],
        names=["customer_id", "event_timestamp"],
    )


def features_table_1():
    """Build the feature-side fixture table used by the join test.

    Returns a ``pa.Table`` with three columns:

    * ``customer_id`` -- ``[1, 1, 1, 2]``
    * ``event_timestamp`` -- one timestamp per row, built with ``pa_datetime``
    * ``feature1`` -- ``[11, 22, 33, 22]``
    """
    return pa.Table.from_arrays(
        arrays=[
            pa.array([1, 1, 1, 2]),
            pa.array(
                [
                    pa_datetime(2023, 12, 31),
                    pa_datetime(2024, 1, 2),
                    pa_datetime(2024, 1, 3),
                    # NOTE(review): year is 2023 here, unlike the other rows;
                    # kept exactly as in the source -- confirm it is intended.
                    pa_datetime(2023, 1, 3),
                ]
            ),
            pa.array([11, 22, 33, 22]),
        ],
        names=["customer_id", "event_timestamp", "feature1"],
    )


def point_in_time_join_brute(
entity_table: pa.Table,
feature_tables: List[Tuple[pa.Table, str, Dict[str, str], List[str], timedelta]],
event_timestamp_col = 'event_timestamp'
event_timestamp_col="event_timestamp",
):
ret_fields = [entity_table.schema.field(n) for n in entity_table.schema.names]

from operator import itemgetter

ret = entity_table.to_pydict()
batch_dict = entity_table.to_pydict()

for i, row_timestmap in enumerate(batch_dict[event_timestamp_col]):
for feature_table, timestamp_key, join_key_map, feature_refs, ttl in feature_tables:
for (
feature_table,
timestamp_key,
join_key_map,
feature_refs,
ttl,
) in feature_tables:
if i == 0:
ret_fields.extend([feature_table.schema.field(f) for f in feature_table.schema.names if f not in join_key_map.values() and f != timestamp_key])
ret_fields.extend(
[
feature_table.schema.field(f)
for f in feature_table.schema.names
if f not in join_key_map.values() and f != timestamp_key
]
)

def check_equality(ft_dict, batch_dict, x, y):
return all([ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()])
return all(
[ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()]
)

ft_dict = feature_table.to_pydict()
found_matches = [
(j, ft_dict[timestamp_key][j]) for j in range(entity_table.num_rows)
if check_equality(ft_dict, batch_dict, j, i) and
ft_dict[timestamp_key][j] <= row_timestmap and
ft_dict[timestamp_key][j] >= row_timestmap - ttl
(j, ft_dict[timestamp_key][j])
for j in range(entity_table.num_rows)
if check_equality(ft_dict, batch_dict, j, i)
and ft_dict[timestamp_key][j] <= row_timestmap
and ft_dict[timestamp_key][j] >= row_timestmap - ttl
]

index_found = max(found_matches, key=itemgetter(1))[0] if found_matches else None
index_found = (
max(found_matches, key=itemgetter(1))[0] if found_matches else None
)
for col in ft_dict.keys():
if col not in feature_refs:
continue

if col not in ret:
ret[col] = []

if index_found is not None:
ret[col].append(ft_dict[col][index_found])
else:
Expand All @@ -74,15 +112,27 @@ def test_point_in_time_join():
expected = point_in_time_join_brute(
customer_table(),
feature_tables=[
(features_table_1(), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10))
]
(
features_table_1(),
"event_timestamp",
{"customer_id": "customer_id"},
["feature1"],
timedelta(days=10),
)
],
)

actual = point_in_time_join(
ibis.memtable(customer_table()),
feature_tables=[
(ibis.memtable(features_table_1()), 'event_timestamp', {'customer_id': 'customer_id'}, ['feature1'], timedelta(days=10))
]
(
ibis.memtable(features_table_1()),
"event_timestamp",
{"customer_id": "customer_id"},
["feature1"],
timedelta(days=10),
)
],
).to_pyarrow()

assert actual.equals(expected)

0 comments on commit fcfe305

Please sign in to comment.