Commit 600d38e
event time check before writing features, use pipelines, ttl (#1961)
* eventtime check before writing features, use pipelines, ttl
* redis write optimizations and event time check
* small fixes and test
* formatting
* typing fix
* formatting, comments, test
* formatting
* test fixes for online store write order
* comment on test
* remove commented out tests for now

Signed-off-by: Vitaly Sergeyev <vsergeyev@better.com>
Vitaly Sergeyev authored Nov 1, 2021
1 parent 3e42fb3 commit 600d38e
Showing 5 changed files with 215 additions and 74 deletions.
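At a glance, the new write path batches a read of each entity's stored event timestamp, skips rows that are not newer, and batches the surviving writes. Below is a minimal standalone sketch of that pattern, not the Feast code itself: `r`, `rows`, and `write_batch` are assumed names, and the `_ts:<feature view>` field convention follows the diff in redis.py further down.

```python
import redis
from google.protobuf.timestamp_pb2 import Timestamp

r = redis.Redis()  # assumption: a local Redis instance on the default port
TS_FIELD = b"_ts:my_feature_view"  # per-feature-view event-time field, as in the diff


def write_batch(rows):
    # rows: list of (key, field_map, event_ts) with event_ts a protobuf Timestamp
    with r.pipeline() as pipe:
        # one round trip: fetch the stored event timestamp for every key
        for key, _, _ in rows:
            pipe.hmget(key, TS_FIELD)
        prev_ts_bytes = [res[0] for res in pipe.execute()]

        for (key, field_map, event_ts), prev in zip(rows, prev_ts_bytes):
            if prev:
                prev_ts = Timestamp()
                prev_ts.ParseFromString(prev)
                if event_ts.seconds <= prev_ts.seconds:
                    continue  # stored record is at least as fresh; skip the write
            # queue the write; all HSETs are flushed in one round trip below
            pipe.hset(key, mapping={TS_FIELD: event_ts.SerializeToString(), **field_map})
        pipe.execute()
```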
143 changes: 99 additions & 44 deletions sdk/python/feast/infra/online_stores/redis.py
@@ -15,7 +15,17 @@
import logging
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
    Any,
    ByteString,
    Callable,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from google.protobuf.timestamp_pb2 import Timestamp
from pydantic import StrictStr
@@ -36,7 +46,6 @@

    raise FeastExtrasDependencyImportError("redis", str(e))

EX_SECONDS = 253402300799
logger = logging.getLogger(__name__)


@@ -166,26 +175,53 @@ def online_write_batch(
        client = self._get_client(online_store_config)
        project = config.project

        entity_hset = {}
        feature_view = table.name

        ex = Timestamp()
        ex.seconds = EX_SECONDS
        ex_str = ex.SerializeToString()
        for entity_key, values, timestamp, created_ts in data:
            redis_key_bin = _redis_key(project, entity_key)
            ts = Timestamp()
            ts.seconds = int(utils.make_tzaware(timestamp).timestamp())
            entity_hset[f"_ts:{feature_view}"] = ts.SerializeToString()
            entity_hset[f"_ex:{feature_view}"] = ex_str

            for feature_name, val in values.items():
                f_key = _mmh3(f"{feature_view}:{feature_name}")
                entity_hset[f_key] = val.SerializeToString()

            client.hset(redis_key_bin, mapping=entity_hset)
ts_key = f"_ts:{feature_view}"
keys = []
# redis pipelining optimization: send multiple commands to redis server without waiting for every reply
with client.pipeline() as pipe:
# check if a previous record under the key bin exists
# TODO: investigate if check and set is a better approach rather than pulling all entity ts and then setting
# it may be significantly slower but avoids potential (rare) race conditions
for entity_key, _, _, _ in data:
redis_key_bin = _redis_key(project, entity_key)
keys.append(redis_key_bin)
pipe.hmget(redis_key_bin, ts_key)
prev_event_timestamps = pipe.execute()
# flattening the list of lists. `hmget` does the lookup assuming a list of keys in the key bin
prev_event_timestamps = [i[0] for i in prev_event_timestamps]

for redis_key_bin, prev_event_time, (_, values, timestamp, _) in zip(
keys, prev_event_timestamps, data
):
event_time_seconds = int(utils.make_tzaware(timestamp).timestamp())

# ignore if event_timestamp is before the event features that are currently in the feature store
if prev_event_time:
prev_ts = Timestamp()
prev_ts.ParseFromString(prev_event_time)
if prev_ts.seconds and event_time_seconds <= prev_ts.seconds:
# TODO: somehow signal that it's not overwriting the current record?
if progress:
progress(1)
continue

ts = Timestamp()
ts.seconds = event_time_seconds
entity_hset = dict()
entity_hset[ts_key] = ts.SerializeToString()

for feature_name, val in values.items():
f_key = _mmh3(f"{feature_view}:{feature_name}")
entity_hset[f_key] = val.SerializeToString()

pipe.hset(redis_key_bin, mapping=entity_hset)
# TODO: support expiring the entity / features in Redis
# otherwise entity features remain in redis until cleaned up in separate process
# client.expire redis_key_bin based a ttl setting
results = pipe.execute()
if progress:
progress(1)
progress(len(results))
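The expiry TODO above could be satisfied by queuing an EXPIRE on the same pipeline right after each accepted HSET. A sketch, reusing `pipe`, `redis_key_bin`, and `online_store_config` from the method above and assuming a hypothetical `key_ttl_seconds` config option that is not part of this commit:

```python
# hypothetical config option; this commit only leaves the TODO in place
ttl_seconds = getattr(online_store_config, "key_ttl_seconds", None)
if ttl_seconds:
    # EXPIRE applies to the whole entity hash, so every accepted
    # write refreshes the TTL for all features stored under the key
    pipe.expire(redis_key_bin, ttl_seconds)
```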

    def online_read(
        self,
@@ -206,30 +242,49 @@
        if not requested_features:
            requested_features = [f.name for f in table.features]

        hset_keys = [_mmh3(f"{feature_view}:{k}") for k in requested_features]

        ts_key = f"_ts:{feature_view}"
        hset_keys.append(ts_key)
        requested_features.append(ts_key)

        keys = []
        for entity_key in entity_keys:
            redis_key_bin = _redis_key(project, entity_key)
            hset_keys = [_mmh3(f"{feature_view}:{k}") for k in requested_features]
            ts_key = f"_ts:{feature_view}"
            hset_keys.append(ts_key)
            values = client.hmget(redis_key_bin, hset_keys)
            requested_features.append(ts_key)
            res_val = dict(zip(requested_features, values))

            res_ts = Timestamp()
            ts_val = res_val.pop(ts_key)
            if ts_val:
                res_ts.ParseFromString(ts_val)

            res = {}
            for feature_name, val_bin in res_val.items():
                val = ValueProto()
                if val_bin:
                    val.ParseFromString(val_bin)
                res[feature_name] = val

            if not res:
                result.append((None, None))
            else:
                timestamp = datetime.fromtimestamp(res_ts.seconds)
                result.append((timestamp, res))
            keys.append(redis_key_bin)
        with client.pipeline() as pipe:
            for redis_key_bin in keys:
                pipe.hmget(redis_key_bin, hset_keys)
            redis_values = pipe.execute()
        for values in redis_values:
            features = self._get_features_for_entity(
                values, feature_view, requested_features
            )
            result.append(features)
        return result

    def _get_features_for_entity(
        self,
        values: List[ByteString],
        feature_view: str,
        requested_features: List[str],
    ) -> Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]:
        res_val = dict(zip(requested_features, values))

        res_ts = Timestamp()
        ts_val = res_val.pop(f"_ts:{feature_view}")
        if ts_val:
            res_ts.ParseFromString(ts_val)

        res = {}
        for feature_name, val_bin in res_val.items():
            val = ValueProto()
            if val_bin:
                val.ParseFromString(val_bin)
            res[feature_name] = val

        if not res:
            return None, None
        else:
            timestamp = datetime.fromtimestamp(res_ts.seconds)
            return timestamp, res
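`_get_features_for_entity` relies on HMGET's contract: values come back in the same order as the requested fields, with `None` for absent fields, which is why a `zip` back to field names works and an entity with no stored features decodes to `(None, None)`. A quick self-contained check of that behavior (assumes a local Redis on the default port; the `demo:*` keys are placeholders):

```python
import redis

r = redis.Redis()
r.hset("demo:entity", mapping={b"f1": b"x"})

print(r.hmget("demo:entity", [b"f1", b"missing"]))  # [b'x', None] - order preserved
print(r.hmget("demo:no_such_key", [b"f1"]))         # [None] - missing key yields all None
```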
12 changes: 12 additions & 0 deletions sdk/python/tests/conftest.py
@@ -21,8 +21,12 @@
from _pytest.nodes import Item

from tests.data.data_creator import create_dataset
from tests.integration.feature_repos.integration_test_repo_config import (
    IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.repo_configuration import (
    FULL_REPO_CONFIGS,
    REDIS_CONFIG,
    Environment,
    construct_test_environment,
    construct_universal_data_sources,
@@ -138,6 +142,14 @@ def environment(request):
        yield e


@pytest.fixture()
def local_redis_environment():
    with construct_test_environment(
        IntegrationTestRepoConfig(online_store=REDIS_CONFIG)
    ) as e:
        yield e

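A test opts into the Redis-backed environment simply by naming the fixture as a parameter, as the new event-time test in test_universal_online.py does. A minimal hypothetical example:

```python
@pytest.mark.integration
def test_uses_local_redis(local_redis_environment):  # hypothetical test name
    fs = local_redis_environment.feature_store  # FeatureStore wired to local Redis
```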

@pytest.fixture(scope="session")
def universal_data_sources(environment):
    entities = construct_universal_entities()
1 change: 1 addition & 0 deletions sdk/python/tests/integration/e2e/test_universal_e2e.py
@@ -16,6 +16,7 @@
@pytest.mark.parametrize("infer_features", [True, False])
def test_e2e_consistency(environment, e2e_data_sources, infer_features):
    fs = environment.feature_store
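    # assumption: suffixing the project with the parametrized value gives each
    # run its own project so online-store state doesn't leak between runs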
    fs.config.project = fs.config.project + str(infer_features)
    df, data_source = e2e_data_sources
    fv = driver_feature_view(data_source=data_source, infer_features=infer_features)

104 changes: 103 additions & 1 deletion sdk/python/tests/integration/online_store/test_universal_online.py
@@ -7,7 +7,7 @@
import pandas as pd
import pytest

from feast import FeatureService
from feast import Entity, Feature, FeatureService, FeatureView, ValueType
from feast.errors import (
    FeatureNameCollisionError,
    RequestDataNotFoundInEntityRowsException,
@@ -23,6 +23,108 @@
from tests.integration.feature_repos.universal.feature_views import (
    create_driver_hourly_stats_feature_view,
)
from tests.utils.data_source_utils import prep_file_source


# TODO: make this work with all universal (all online store types)
@pytest.mark.integration
def test_write_to_online_store_event_check(local_redis_environment):
    fs = local_redis_environment.feature_store

    # write the same data points 3 times with different timestamps
    now = pd.Timestamp(datetime.datetime.utcnow()).round("ms")
    hour_ago = pd.Timestamp(datetime.datetime.utcnow() - timedelta(hours=1)).round("ms")
    latest = pd.Timestamp(datetime.datetime.utcnow() + timedelta(seconds=1)).round("ms")

    data = {
        "id": [123, 567, 890],
        "string_col": ["OLD_FEATURE", "LATEST_VALUE2", "LATEST_VALUE3"],
        "ts_1": [hour_ago, now, now],
    }
    dataframe_source = pd.DataFrame(data)
    with prep_file_source(
        df=dataframe_source, event_timestamp_column="ts_1"
    ) as file_source:
        e = Entity(name="id", value_type=ValueType.STRING)

        # Create Feature View
        fv1 = FeatureView(
            name="feature_view_123",
            features=[Feature(name="string_col", dtype=ValueType.STRING)],
            entities=["id"],
            batch_source=file_source,
            ttl=timedelta(minutes=5),
        )
        # Register Feature View and Entity
        fs.apply([fv1, e])

        # data to ingest into the Online Store (recent)
        data = {
            "id": [123],
            "string_col": ["hi_123"],
            "ts_1": [now],
        }
        df_data = pd.DataFrame(data)

        # directly ingest data into the Online Store
        fs.write_to_online_store("feature_view_123", df_data)

        df = fs.get_online_features(
            features=["feature_view_123:string_col"], entity_rows=[{"id": 123}]
        ).to_df()
        assert df["string_col"].iloc[0] == "hi_123"

        # data to ingest into the Online Store (1 hour delayed data);
        # should NOT overwrite features for id=123 because it's less recent data
        data = {
            "id": [123, 567, 890],
            "string_col": ["bye_321", "hello_123", "greetings_321"],
            "ts_1": [hour_ago, hour_ago, hour_ago],
        }
        df_data = pd.DataFrame(data)

        # directly ingest data into the Online Store
        fs.write_to_online_store("feature_view_123", df_data)

        df = fs.get_online_features(
            features=["feature_view_123:string_col"],
            entity_rows=[{"id": 123}, {"id": 567}, {"id": 890}],
        ).to_df()
        assert df["string_col"].iloc[0] == "hi_123"
        assert df["string_col"].iloc[1] == "hello_123"
        assert df["string_col"].iloc[2] == "greetings_321"

        # should overwrite string_col for id=123 because it's the most recent based on event_timestamp
        data = {
            "id": [123],
            "string_col": ["LATEST_VALUE"],
            "ts_1": [latest],
        }
        df_data = pd.DataFrame(data)

        fs.write_to_online_store("feature_view_123", df_data)

        df = fs.get_online_features(
            features=["feature_view_123:string_col"],
            entity_rows=[{"id": 123}, {"id": 567}, {"id": 890}],
        ).to_df()
        assert df["string_col"].iloc[0] == "LATEST_VALUE"
        assert df["string_col"].iloc[1] == "hello_123"
        assert df["string_col"].iloc[2] == "greetings_321"

        # write to the online store by materializing from the batch source (dataframe_source)
        fs.materialize(
            start_date=datetime.datetime.now() - timedelta(hours=12),
            end_date=datetime.datetime.utcnow(),
        )

        df = fs.get_online_features(
            features=["feature_view_123:string_col"],
            entity_rows=[{"id": 123}, {"id": 567}, {"id": 890}],
        ).to_df()
        assert df["string_col"].iloc[0] == "LATEST_VALUE"
        assert df["string_col"].iloc[1] == "LATEST_VALUE2"
        assert df["string_col"].iloc[2] == "LATEST_VALUE3"


@pytest.mark.integration
29 changes: 0 additions & 29 deletions sdk/python/tests/utils/online_read_write_test.py
@@ -67,17 +67,6 @@ def _driver_rw_test(event_ts, created_ts, write, expect_read):
        event_ts=time_1, created_ts=time_1, write=(1.1, "3.1"), expect_read=(1.1, "3.1")
    )

    # Note: This behavior has changed for performance. We should test that an older
    # value can't overwrite a newer value once we add the respective flag
    """ Values with an older event_ts should overwrite newer ones """
    time_2 = datetime.utcnow()
    _driver_rw_test(
        event_ts=time_1 - timedelta(hours=1),
        created_ts=time_2,
        write=(-1000, "OLD"),
        expect_read=(-1000, "OLD"),
    )

    """ Values with a new event_ts should overwrite older ones """
    time_3 = datetime.utcnow()
    _driver_rw_test(
@@ -86,21 +75,3 @@
        write=(1123, "NEWER"),
        expect_read=(1123, "NEWER"),
    )

    # Note: This behavior has changed for performance. We should test that an older
    # value can't overwrite a newer value once we add the respective flag
    """ created_ts is used as a tie breaker, using older created_ts here, but we still overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 - timedelta(hours=1),
        write=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
        expect_read=(54321, "I HAVE AN OLDER created_ts SO I LOSE"),
    )

    """ created_ts is used as a tie breaker, using newer created_ts here so we should overwrite """
    _driver_rw_test(
        event_ts=time_1 + timedelta(hours=1),
        created_ts=time_3 + timedelta(hours=1),
        write=(96864, "I HAVE A NEWER created_ts SO I WIN"),
        expect_read=(96864, "I HAVE A NEWER created_ts SO I WIN"),
    )
