perf: Implement dynamo write_batch_async (#4675)
* rebase

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* offline store init doesn't make sense

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* don't init or close

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* update test to handle event loop for dynamo case

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* use run until complete

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* fix: spelling sigh

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* run integration test as async since that is default for read

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* add pytest async to ci reqs

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* be safe with cleanup in test fixture

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* be safe with cleanup in test fixture

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* update pytest ini

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* not in a finally

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* remove close

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* test client is a lifespan-aware context manager

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* add async writer for dynamo

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* fix dynamo client put item format

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* clarify documentation

Signed-off-by: Rob Howley <howley.robert@gmail.com>

* add deduplication to async dynamo write

Signed-off-by: Rob Howley <howley.robert@gmail.com>

---------

Signed-off-by: Rob Howley <howley.robert@gmail.com>
robhowley authored Oct 24, 2024
1 parent d95ed18 commit ba4404c
Showing 4 changed files with 201 additions and 21 deletions.
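
For orientation before the diffs: this commit adds an asynchronous write path to the DynamoDB online store, mirroring the async read path that already existed. A minimal sketch of the round trip from a caller's perspective, modeled on the integration test later in this diff (the `store` and DataFrame setup are assumed to exist):

import asyncio

async def push_and_read(store, df):
    # New in this commit: async write via push_async (DynamoDB now
    # advertises write=True in SupportedAsyncMethods).
    await store.push_async("location_stats_push_source", df)
    # Async read was already supported before this commit.
    online_resp = await store.get_online_features_async(
        features=["pushable_location_stats:temperature"],
        entity_rows=[{"location_id": 1}],
    )
    return online_resp.to_dict()

# asyncio.run(push_and_read(store, df))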
97 changes: 84 additions & 13 deletions sdk/python/feast/infra/online_stores/dynamodb.py
@@ -15,6 +15,7 @@
import contextlib
import itertools
import logging
from collections import OrderedDict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

@@ -26,6 +27,7 @@
from feast.infra.online_stores.helpers import compute_entity_id
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.supported_async_methods import SupportedAsyncMethods
from feast.infra.utils.aws_utils import dynamo_write_items_async
from feast.protos.feast.core.DynamoDBTable_pb2 import (
DynamoDBTable as DynamoDBTableProto,
)
@@ -103,7 +105,7 @@ async def close(self):

    @property
    def async_supported(self) -> SupportedAsyncMethods:
-       return SupportedAsyncMethods(read=True)
+       return SupportedAsyncMethods(read=True, write=True)

    def update(
        self,
@@ -238,6 +240,42 @@ def online_write_batch(
        )
        self._write_batch_non_duplicates(table_instance, data, progress, config)

    async def online_write_batch_async(
        self,
        config: RepoConfig,
        table: FeatureView,
        data: List[
            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
        ],
        progress: Optional[Callable[[int], Any]],
    ) -> None:
        """
        Writes a batch of feature rows to the online store asynchronously.

        If a tz-naive timestamp is passed to this method, it is assumed to be UTC.

        Args:
            config: The config for the current feature store.
            table: Feature view to which these feature rows correspond.
            data: A list of quadruplets containing feature data. Each quadruplet contains an entity
                key, a dict containing feature values, an event timestamp for the row, and the created
                timestamp for the row if it exists.
            progress: Function to be called once a batch of rows is written to the online store, used
                to show progress.
        """
        online_config = config.online_store
        assert isinstance(online_config, DynamoDBOnlineStoreConfig)

        table_name = _get_table_name(online_config, config, table)
        # Deduplicate to the latest row per entity key, then encode rows in the
        # low-level client item format.
        items = [
            _to_client_write_item(config, entity_key, features, timestamp)
            for entity_key, features, timestamp, _ in _latest_data_to_write(data)
        ]
        client = await _get_aiodynamodb_client(
            online_config.region, config.online_store.max_pool_connections
        )
        await dynamo_write_items_async(client, table_name, items)

    def online_read(
        self,
        config: RepoConfig,
@@ -419,19 +457,10 @@ def _write_batch_non_duplicates(
"""Deduplicate write batch request items on ``entity_id`` primary key."""
with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch:
for entity_key, features, timestamp, created_ts in data:
entity_id = compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
batch.put_item(
Item={
"entity_id": entity_id, # PartitionKey
"event_ts": str(utils.make_tzaware(timestamp)),
"values": {
k: v.SerializeToString()
for k, v in features.items() # Serialized Features
},
}
Item=_to_resource_write_item(
config, entity_key, features, timestamp
)
)
if progress:
progress(1)
@@ -675,3 +704,45 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
                region, endpoint_url
            )
        return self._dynamodb_resource


def _to_resource_write_item(config, entity_key, features, timestamp):
    entity_id = compute_entity_id(
        entity_key,
        entity_key_serialization_version=config.entity_key_serialization_version,
    )
    return {
        "entity_id": entity_id,  # PartitionKey
        "event_ts": str(utils.make_tzaware(timestamp)),
        "values": {
            k: v.SerializeToString()
            for k, v in features.items()  # Serialized Features
        },
    }


def _to_client_write_item(config, entity_key, features, timestamp):
    entity_id = compute_entity_id(
        entity_key,
        entity_key_serialization_version=config.entity_key_serialization_version,
    )
    return {
        "entity_id": {"S": entity_id},  # PartitionKey
        "event_ts": {"S": str(utils.make_tzaware(timestamp))},
        "values": {
            "M": {
                k: {"B": v.SerializeToString()}
                for k, v in features.items()  # Serialized Features
            }
        },
    }


def _latest_data_to_write(
    data: List[
        Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
    ],
):
    # Keep only the newest row per serialized entity key: sort by (key, event_ts)
    # so the OrderedDict overwrites older entries with newer ones.
    as_hashable = ((d[0].SerializeToString(), d) for d in data)
    sorted_data = sorted(as_hashable, key=lambda ah: (ah[0], ah[1][2]))
    return (v for _, v in OrderedDict((ah[0], ah[1]) for ah in sorted_data).items())
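
For reference, the two `_to_*_write_item` helpers above encode the same row for different boto3 surfaces: the resource-level batch writer (sync path) takes native Python values, while the low-level client call used by the async path needs explicit DynamoDB type descriptors. A hypothetical example with invented values:

# Resource API item (sync online_write_batch):
resource_item = {
    "entity_id": "abc123",                    # partition key
    "event_ts": "2024-01-01 00:00:00+00:00",
    "values": {"temperature": b"\x08\x04"},   # feature name -> serialized ValueProto bytes
}

# Client API item (async path); S = string, M = map, B = binary:
client_item = {
    "entity_id": {"S": "abc123"},
    "event_ts": {"S": "2024-01-01 00:00:00+00:00"},
    "values": {"M": {"temperature": {"B": b"\x08\x04"}}},
}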
64 changes: 64 additions & 0 deletions sdk/python/feast/infra/utils/aws_utils.py
@@ -1,4 +1,6 @@
import asyncio
import contextlib
import itertools
import os
import tempfile
import uuid
@@ -10,6 +12,7 @@
import pyarrow as pa
import pyarrow.parquet as pq
from tenacity import (
    AsyncRetrying,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
@@ -1076,3 +1079,64 @@ def upload_arrow_table_to_athena(
    # Clean up S3 temporary data
    # for file_path in uploaded_files:
    #     s3_resource.Object(bucket, file_path).delete()


class DynamoUnprocessedWriteItems(Exception):
    pass


async def dynamo_write_items_async(
    dynamo_client, table_name: str, items: list[dict]
) -> None:
    """
    Writes batches to a DynamoDB table asynchronously. The maximum size of
    each attempted batch is 25.

    Raises DynamoUnprocessedWriteItems if not all items can be written.

    Args:
        dynamo_client: async DynamoDB client
        table_name: name of the table being written to
        items: list of items to be written; see the boto3 docs for the item format.
    """
    DYNAMO_MAX_WRITE_BATCH_SIZE = 25

    async def _do_write(items):
        # Chunk the pending writes into batches of at most 25 items and
        # issue the batch_write_item calls concurrently.
        item_iter = iter(items)
        item_batches = []
        while True:
            item_batch = list(
                itertools.islice(item_iter, DYNAMO_MAX_WRITE_BATCH_SIZE)
            )
            if not item_batch:
                break

            item_batches.append(item_batch)

        return await asyncio.gather(
            *[
                dynamo_client.batch_write_item(
                    RequestItems={table_name: item_batch},
                )
                for item_batch in item_batches
            ]
        )

    put_items = [{"PutRequest": {"Item": item}} for item in items]

    retries = AsyncRetrying(
        retry=retry_if_exception_type(DynamoUnprocessedWriteItems),
        wait=wait_exponential(multiplier=1, max=4),
        reraise=True,
    )

    async for attempt in retries:
        with attempt:
            response_batches = await _do_write(put_items)

            # UnprocessedItems maps table name -> write requests that were not
            # applied (e.g. throttled); collect them and retry with backoff.
            put_items = []
            for response in response_batches:
                put_items.extend(response["UnprocessedItems"].get(table_name, []))

            if put_items:
                raise DynamoUnprocessedWriteItems()
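
A hedged usage sketch of the new helper: the aiobotocore client construction below is an assumption for illustration (Feast builds its own client via `_get_aiodynamodb_client`), and the table name and item contents are invented. DynamoDB's BatchWriteItem API accepts at most 25 items per request, hence the chunking above.

import asyncio

from aiobotocore.session import get_session

from feast.infra.utils.aws_utils import dynamo_write_items_async

async def main():
    session = get_session()
    async with session.create_client("dynamodb", region_name="us-west-2") as client:
        # Items use the low-level client format produced by _to_client_write_item.
        items = [
            {
                "entity_id": {"S": "abc123"},
                "event_ts": {"S": "2024-01-01 00:00:00+00:00"},
            }
        ]
        # Batches of up to 25 are written concurrently; unprocessed items
        # are retried with exponential backoff.
        await dynamo_write_items_async(client, "my_feature_table", items)

asyncio.run(main())

Note that the AsyncRetrying loop has no stop condition, so persistently unprocessed items keep retrying with waits capped at 4 seconds.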
@@ -8,29 +8,51 @@
from tests.integration.feature_repos.universal.entities import location


-@pytest.mark.integration
-@pytest.mark.universal_online_stores
-def test_push_features_and_read(environment, universal_data_sources):
+@pytest.fixture
+def store(environment, universal_data_sources):
    store = environment.feature_store
    _, _, data_sources = universal_data_sources
    feature_views = construct_universal_feature_views(data_sources)
    location_fv = feature_views.pushed_locations
    store.apply([location(), location_fv])
+   return store
+
+
+def _ingest_df():
    data = {
        "location_id": [1],
        "temperature": [4],
        "event_timestamp": [pd.Timestamp(_utc_now()).round("ms")],
        "created": [pd.Timestamp(_utc_now()).round("ms")],
    }
-   df_ingest = pd.DataFrame(data)
-
-   store.push("location_stats_push_source", df_ingest)
+   return pd.DataFrame(data)
+
+
+def assert_response(online_resp):
+   online_resp_dict = online_resp.to_dict()
+   assert online_resp_dict["location_id"] == [1]
+   assert online_resp_dict["temperature"] == [4]
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores
+def test_push_features_and_read(store):
+   store.push("location_stats_push_source", _ingest_df())

    online_resp = store.get_online_features(
        features=["pushable_location_stats:temperature"],
        entity_rows=[{"location_id": 1}],
    )
-   online_resp_dict = online_resp.to_dict()
-   assert online_resp_dict["location_id"] == [1]
-   assert online_resp_dict["temperature"] == [4]
+   assert_response(online_resp)
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores(only=["dynamodb"])
+async def test_push_features_and_read_async(store):
+   await store.push_async("location_stats_push_source", _ingest_df())
+
+   online_resp = await store.get_online_features_async(
+       features=["pushable_location_stats:temperature"],
+       entity_rows=[{"location_id": 1}],
+   )
+   assert_response(online_resp)
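
The new async test relies on pytest's asyncio support; the commit log above mentions adding pytest async to the CI requirements and updating pytest.ini. For reference, a hypothetical marker-based variant that would be needed under pytest-asyncio's strict mode (the marker is an assumption, not shown in this diff):

import pytest

@pytest.mark.asyncio  # redundant if pytest.ini sets asyncio_mode = auto
async def test_push_features_and_read_async(store):
    ...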
@@ -1,5 +1,6 @@
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime

import boto3
import pytest
@@ -10,6 +11,7 @@
    DynamoDBOnlineStore,
    DynamoDBOnlineStoreConfig,
    DynamoDBTable,
    _latest_data_to_write,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -358,3 +360,24 @@ def test_dynamodb_online_store_online_read_unknown_entity_end_of_batch(
    # ensure the entity is not dropped
    assert len(returned_items) == len(entity_keys)
    assert returned_items[-1] == (None, None)


def test_batch_write_deduplication():
    def to_ek_proto(val):
        return EntityKeyProto(
            join_keys=["customer"], entity_values=[ValueProto(string_val=val)]
        )

    # input is out of order and has duplicate keys
    data = [
        (to_ek_proto("key-1"), {}, datetime(2024, 1, 1), None),
        (to_ek_proto("key-2"), {}, datetime(2024, 1, 1), None),
        (to_ek_proto("key-1"), {}, datetime(2024, 1, 3), None),
        (to_ek_proto("key-1"), {}, datetime(2024, 1, 2), None),
        (to_ek_proto("key-3"), {}, datetime(2024, 1, 2), None),
    ]

    # assert we only keep the most recent record per key
    actual = list(_latest_data_to_write(data))
    expected = [data[2], data[1], data[4]]
    assert expected == actual
