Skip to content

Commit

Permalink
feat: refactor ibis point-in-time-join
Browse files Browse the repository at this point in the history
Signed-off-by: tokoko <togurg14@freeuni.edu.ge>
  • Loading branch information
tokoko committed Mar 21, 2024
1 parent 6d9156b commit f652bd4
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 138 deletions.
231 changes: 115 additions & 116 deletions sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,112 +72,6 @@ def _get_entity_df_event_timestamp_range(

return entity_df_event_timestamp_range

@staticmethod
def _get_historical_features_one(
    feature_view: FeatureView,
    entity_table: Table,
    feature_refs: List[str],
    full_feature_names: bool,
    timestamp_range: Tuple,
    acc_table: Table,
    event_timestamp_col: str,
) -> Table:
    """Point-in-time join one feature view's features onto ``acc_table``.

    Reads the feature view's batch source (a parquet path), keeps only rows
    inside the entity timestamp range (minus ttl), picks the latest feature
    row at or before each entity row's event timestamp, and left-joins the
    requested features onto the accumulator via the synthetic
    ``entity_row_id`` column (assumed to already exist on both tables —
    produced upstream by ``_generate_row_id``; TODO confirm).

    Args:
        feature_view: Feature view whose batch source is read.
        entity_table: Entity table carrying ``entity_row_id`` and join keys.
        feature_refs: Full feature refs ("<view>:<feature>"); only those for
            this view are used.
        full_feature_names: If True, output columns are prefixed
            "<view>__<feature>".
        timestamp_range: (start, end) datetimes bounding entity event times.
        acc_table: Accumulator table the result is joined onto.
        event_timestamp_col: Name of the entity event timestamp column.

    Returns:
        ``acc_table`` with this view's feature columns appended.
    """
    # Batch source "name" is used as a parquet path here — scoped to this
    # file-based ibis offline store.
    fv_table: Table = ibis.read_parquet(feature_view.batch_source.name)

    # Apply the source's field mapping (source column -> feature name).
    # Note rename() takes {new: old}; renames happen one at a time and each
    # is skipped if the source column is absent.
    for old_name, new_name in feature_view.batch_source.field_mapping.items():
        if old_name in fv_table.columns:
            fv_table = fv_table.rename({new_name: old_name})

    timestamp_field = feature_view.batch_source.timestamp_field

    # TODO mutate only if tz-naive
    fv_table = fv_table.mutate(
        **{
            timestamp_field: fv_table[timestamp_field].cast(
                dt.Timestamp(timezone="UTC")
            )
        }
    )

    full_name_prefix = feature_view.projection.name_alias or feature_view.name

    # Narrow the refs to this view and strip the "<view>:" prefix.
    feature_refs = [
        fr.split(":")[1]
        for fr in feature_refs
        if fr.startswith(f"{full_name_prefix}:")
    ]

    # Widen the lower bound by ttl so a feature row older than the earliest
    # entity timestamp (but within ttl of it) is still joinable.
    timestamp_range_start_minus_ttl = (
        timestamp_range[0] - feature_view.ttl
        if feature_view.ttl and feature_view.ttl > timedelta(0, 0, 0, 0, 0, 0, 0)
        else timestamp_range[0]
    )

    timestamp_range_start_minus_ttl = ibis.literal(
        timestamp_range_start_minus_ttl.strftime("%Y-%m-%d %H:%M:%S.%f")
    ).cast(dt.Timestamp(timezone="UTC"))

    timestamp_range_end = ibis.literal(
        timestamp_range[1].strftime("%Y-%m-%d %H:%M:%S.%f")
    ).cast(dt.Timestamp(timezone="UTC"))

    # Coarse pre-filter before the join; per-row point-in-time filtering
    # happens via the join predicates below.
    fv_table = fv_table.filter(
        ibis.and_(
            fv_table[timestamp_field] <= timestamp_range_end,
            fv_table[timestamp_field] >= timestamp_range_start_minus_ttl,
        )
    )

    # Join keys: explicit projection mapping if present, else the view's
    # entity columns joined by identical names.
    if feature_view.projection.join_key_map:
        predicates = [
            fv_table[k] == entity_table[v]
            for k, v in feature_view.projection.join_key_map.items()
        ]
    else:
        predicates = [
            fv_table[e.name] == entity_table[e.name]
            for e in feature_view.entity_columns
        ]

    # Point-in-time condition: feature row must not be newer than the
    # entity row's event timestamp.
    predicates.append(
        fv_table[timestamp_field] <= entity_table[event_timestamp_col]
    )

    fv_table = fv_table.inner_join(
        entity_table, predicates, lname="", rname="{name}_y"
    )

    # Rank candidate feature rows per entity row, newest first.
    fv_table = (
        fv_table.group_by(by="entity_row_id")
        .order_by(ibis.desc(fv_table[timestamp_field]))
        .mutate(rn=ibis.row_number())
    )

    # Keep only the newest row (row_number 0) per entity row.
    fv_table = fv_table.filter(fv_table["rn"] == ibis.literal(0))

    select_cols = ["entity_row_id"]
    select_cols.extend(feature_refs)
    fv_table = fv_table.select(select_cols)

    if full_feature_names:
        # rename() takes {new: old}: prefix outputs with "<view>__".
        fv_table = fv_table.rename(
            {f"{full_name_prefix}__{feature}": feature for feature in feature_refs}
        )

    acc_table = acc_table.left_join(
        fv_table,
        predicates=[fv_table.entity_row_id == acc_table.entity_row_id],
        lname="",
        rname="{name}_yyyy",
    )

    # Drop the duplicated join columns introduced by the left join.
    acc_table = acc_table.drop(s.endswith("_yyyy"))

    return acc_table

@staticmethod
def _to_utc(entity_df: pd.DataFrame, event_timestamp_col):
entity_df_event_timestamp = entity_df.loc[
Expand Down Expand Up @@ -228,30 +122,67 @@ def get_historical_features(
entity_schema=entity_schema,
)

# TODO get range with ibis
timestamp_range = IbisOfflineStore._get_entity_df_event_timestamp_range(
entity_df, event_timestamp_col
)

entity_df = IbisOfflineStore._to_utc(entity_df, event_timestamp_col)

entity_table = ibis.memtable(entity_df)
entity_table = IbisOfflineStore._generate_row_id(
entity_table, feature_views, event_timestamp_col
)

res: Table = entity_table
def read_fv(feature_view, feature_refs, full_feature_names):
    """Load one feature view's batch source and describe how to join it.

    Returns the 5-tuple consumed by ``point_in_time_join``:
    (table, timestamp_field, join_key_map, feature_refs, ttl).
    """
    # The batch source "name" doubles as a parquet path for this store.
    table: Table = ibis.read_parquet(feature_view.batch_source.name)

    # Apply the source's field mapping one rename at a time, skipping
    # mappings whose source column is absent. rename() takes {new: old}.
    mapping_items = feature_view.batch_source.field_mapping.items()
    for source_col, mapped_col in mapping_items:
        if source_col not in table.columns:
            continue
        table = table.rename({mapped_col: source_col})

    ts_col = feature_view.batch_source.timestamp_field

    # TODO mutate only if tz-naive
    table = table.mutate(
        **{ts_col: table[ts_col].cast(dt.Timestamp(timezone="UTC"))}
    )

    prefix = feature_view.projection.name_alias or feature_view.name

    # Keep only refs addressed to this view, stripped of the "<view>:" part.
    feature_refs = [
        ref.split(":")[1]
        for ref in feature_refs
        if ref.startswith(f"{prefix}:")
    ]

    if full_feature_names:
        # rename() takes {new: old}: expose features as "<view>__<feature>"
        # and report the prefixed names back to the caller.
        table = table.rename(
            {f"{prefix}__{feature}": feature for feature in feature_refs}
        )
        feature_refs = [f"{prefix}__{feature}" for feature in feature_refs]

    join_key_map = feature_view.projection.join_key_map or {
        e.name: e.name for e in feature_view.entity_columns
    }

    return (table, ts_col, join_key_map, feature_refs, feature_view.ttl)

res = res.drop("entity_row_id")
res = point_in_time_join(
entity_table=entity_table,
feature_tables=[
read_fv(feature_view, feature_refs, full_feature_names)
for feature_view in feature_views
],
event_timestamp_col=event_timestamp_col
)

return IbisRetrievalJob(
res,
Expand Down Expand Up @@ -285,6 +216,10 @@ def pull_all_from_table_or_query(

table = table.select(*fields)

# TODO get rid of this fix
if '__log_date' in table.columns:
table = table.drop('__log_date')

table = table.filter(
ibis.and_(
table[timestamp_field] >= ibis.literal(start_date),
Expand Down Expand Up @@ -320,6 +255,7 @@ def write_logged_features(
else:
kwargs = {}

# TODO: always write to a directory
table.to_parquet(
f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs
)
Expand Down Expand Up @@ -405,3 +341,66 @@ def persist(
@property
def metadata(self) -> Optional[RetrievalMetadata]:
    """Return the retrieval metadata recorded for this job, or None."""
    return self._metadata


def point_in_time_join(
    entity_table: Table,
    feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]],
    event_timestamp_col: str = "event_timestamp",
):
    """Left-join the newest valid row of each feature table onto the entities.

    For every (table, timestamp_field, join_key_map, feature_refs, ttl) in
    ``feature_tables``: join feature rows to entity rows on the join keys,
    keep only feature rows at or before (and, if ttl is set, within ttl of)
    the entity row's event timestamp, pick the newest such row per entity
    row, and left-join the requested feature columns onto the result.

    Args:
        entity_table: Entity dataframe as an ibis Table.
        feature_tables: Per-feature-view join descriptors (see above).
        event_timestamp_col: Entity event timestamp column name.

    Returns:
        The entity table with all requested feature columns appended.
    """
    # Every column that identifies an entity row: the event timestamp plus
    # the entity-side join keys of each feature table.
    all_entities = [event_timestamp_col]
    for _, _, join_key_map, _, _ in feature_tables:
        all_entities.extend(join_key_map.values())

    # Build a synthetic surrogate key ("entity_row_id") by concatenating the
    # stringified key columns, so each entity row can be matched back after
    # the join. dict.fromkeys de-dupes while preserving a deterministic
    # order (a plain set() would vary between interpreter runs).
    row_id_expr = ibis.literal("")
    for col in dict.fromkeys(all_entities):
        row_id_expr = row_id_expr.concat(entity_table[col].cast("string"))  # type: ignore

    entity_table = entity_table.mutate(entity_row_id=row_id_expr)

    acc_table = entity_table

    for feature_table, timestamp_field, join_key_map, feature_refs, ttl in feature_tables:
        # Entity-key equality plus the point-in-time condition: a feature
        # row must not be newer than the entity row's event timestamp.
        predicates = [
            feature_table[k] == entity_table[v] for k, v in join_key_map.items()
        ]
        predicates.append(
            feature_table[timestamp_field] <= entity_table[event_timestamp_col],
        )

        # A ttl additionally bounds how stale a feature row may be.
        if ttl:
            predicates.append(
                feature_table[timestamp_field]
                >= entity_table[event_timestamp_col] - ibis.literal(ttl)
            )

        feature_table = feature_table.inner_join(
            entity_table, predicates, lname="", rname="{name}_y"
        )

        # Drop the duplicated entity columns introduced by the join.
        # NOTE(review): this also drops any *feature* column whose name
        # happens to end in "_y" — confirm no feature views use such names.
        feature_table = feature_table.drop(s.endswith("_y"))

        # Rank candidates per entity row, newest first, and keep the winner.
        feature_table = (
            feature_table.group_by(by="entity_row_id")
            .order_by(ibis.desc(feature_table[timestamp_field]))
            .mutate(rn=ibis.row_number())
        )
        feature_table = feature_table.filter(
            feature_table["rn"] == ibis.literal(0)
        ).drop("rn")

        # Project down to the surrogate key plus the requested features.
        select_cols = ["entity_row_id"]
        select_cols.extend(feature_refs)
        feature_table = feature_table.select(select_cols)

        acc_table = acc_table.left_join(
            feature_table,
            predicates=[feature_table.entity_row_id == acc_table.entity_row_id],
            lname="",
            rname="{name}_yyyy",
        )
        acc_table = acc_table.drop(s.endswith("_yyyy"))

    # The surrogate key is internal; remove it from the final result.
    acc_table = acc_table.drop("entity_row_id")

    return acc_table
29 changes: 20 additions & 9 deletions sdk/python/requirements/py3.10-ci-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ black==22.12.0
# via feast (setup.py)
bleach==6.1.0
# via nbconvert
boto3==1.34.65
boto3==1.34.67
# via
# feast (setup.py)
# moto
botocore==1.34.65
botocore==1.34.67
# via
# boto3
# moto
Expand All @@ -82,7 +82,7 @@ cachecontrol==0.14.0
# via firebase-admin
cachetools==5.3.3
# via google-auth
cassandra-driver==3.29.0
cassandra-driver==3.29.1
# via feast (setup.py)
certifi==2024.2.2
# via
Expand Down Expand Up @@ -164,6 +164,12 @@ docker==7.0.0
# testcontainers
docutils==0.19
# via sphinx
duckdb==0.10.1
# via
# duckdb-engine
# ibis-framework
duckdb-engine==0.11.2
# via ibis-framework
entrypoints==0.4
# via altair
exceptiongroup==1.2.0
Expand Down Expand Up @@ -213,7 +219,7 @@ google-api-core[grpc]==2.17.1
# google-cloud-storage
google-api-python-client==2.122.0
# via firebase-admin
google-auth==2.28.2
google-auth==2.29.0
# via
# google-api-core
# google-api-python-client
Expand Down Expand Up @@ -258,7 +264,7 @@ googleapis-common-protos[grpc]==1.63.0
# google-api-core
# grpc-google-iam-v1
# grpcio-status
great-expectations==0.18.11
great-expectations==0.18.12
# via feast (setup.py)
greenlet==3.0.3
# via sqlalchemy
Expand Down Expand Up @@ -310,7 +316,7 @@ httpx==0.27.0
# via
# feast (setup.py)
# jupyterlab
ibis-framework==8.0.0
ibis-framework[duckdb]==8.0.0
# via
# feast (setup.py)
# ibis-substrait
Expand All @@ -331,7 +337,7 @@ importlib-metadata==6.11.0
# via
# dask
# feast (setup.py)
importlib-resources==6.3.1
importlib-resources==6.3.2
# via feast (setup.py)
iniconfig==2.0.0
# via pytest
Expand Down Expand Up @@ -459,7 +465,7 @@ moreorless==0.4.0
# via bowler
moto==4.2.14
# via feast (setup.py)
msal==1.27.0
msal==1.28.0
# via
# azure-identity
# msal-extensions
Expand Down Expand Up @@ -844,8 +850,13 @@ sphinxcontrib-serializinghtml==1.1.10
# via sphinx
sqlalchemy[mypy]==1.4.52
# via
# duckdb-engine
# feast (setup.py)
# ibis-framework
# sqlalchemy
# sqlalchemy-views
sqlalchemy-views==0.3.2
# via ibis-framework
sqlalchemy2-stubs==0.0.2a38
# via sqlalchemy
sqlglot==20.11.0
Expand Down Expand Up @@ -984,7 +995,7 @@ urllib3==1.26.18
# requests
# responses
# rockset
uvicorn[standard]==0.28.0
uvicorn[standard]==0.29.0
# via feast (setup.py)
uvloop==0.19.0
# via uvicorn
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/requirements/py3.10-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ importlib-metadata==6.11.0
# via
# dask
# feast (setup.py)
importlib-resources==6.3.1
importlib-resources==6.3.2
# via feast (setup.py)
jinja2==3.1.3
# via feast (setup.py)
Expand Down Expand Up @@ -176,7 +176,7 @@ tzdata==2024.1
# via pandas
urllib3==2.2.1
# via requests
uvicorn[standard]==0.28.0
uvicorn[standard]==0.29.0
# via feast (setup.py)
uvloop==0.19.0
# via uvicorn
Expand Down
Loading

0 comments on commit f652bd4

Please sign in to comment.