
perf: Implement dynamo write_batch_async #4675

Merged (19 commits) on Oct 24, 2024
97 changes: 84 additions & 13 deletions sdk/python/feast/infra/online_stores/dynamodb.py
@@ -15,6 +15,7 @@
import contextlib
import itertools
import logging
from collections import OrderedDict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

@@ -26,6 +27,7 @@
from feast.infra.online_stores.helpers import compute_entity_id
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.supported_async_methods import SupportedAsyncMethods
from feast.infra.utils.aws_utils import dynamo_write_items_async
from feast.protos.feast.core.DynamoDBTable_pb2 import (
    DynamoDBTable as DynamoDBTableProto,
)
@@ -103,7 +105,7 @@ async def close(self):

     @property
     def async_supported(self) -> SupportedAsyncMethods:
-        return SupportedAsyncMethods(read=True)
+        return SupportedAsyncMethods(read=True, write=True)
 
     def update(
         self,
@@ -238,6 +240,42 @@ def online_write_batch(
        )
        self._write_batch_non_duplicates(table_instance, data, progress, config)

    async def online_write_batch_async(
        self,
        config: RepoConfig,
        table: FeatureView,
        data: List[
            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
        ],
        progress: Optional[Callable[[int], Any]],
    ) -> None:
        """
        Writes a batch of feature rows to the online store asynchronously.

        If a tz-naive timestamp is passed to this method, it is assumed to be UTC.

        Args:
            config: The config for the current feature store.
            table: Feature view to which these feature rows correspond.
            data: A list of quadruplets containing feature data. Each quadruplet contains an entity
                key, a dict containing feature values, an event timestamp for the row, and the created
                timestamp for the row if it exists.
            progress: Function to be called once a batch of rows is written to the online store, used
                to show progress.
        """
        online_config = config.online_store
        assert isinstance(online_config, DynamoDBOnlineStoreConfig)

        table_name = _get_table_name(online_config, config, table)
        items = [
            _to_client_write_item(config, entity_key, features, timestamp)
            for entity_key, features, timestamp, _ in _latest_data_to_write(data)
        ]
Comment on lines +270 to +273 (Contributor Author):

Writing only the most recent item per key: since we have to use the low-level client instead of the boto3 resource, we don't get the built-in deduplication of keys that batch_writer(overwrite_by_pkeys=...) provides.
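
(A minimal sketch of the difference, with an illustrative table name "feast-demo"; the duplicate-key error message is paraphrased from DynamoDB's ValidationException.)

import boto3

# Resource API: batch_writer buffers puts and, with overwrite_by_pkeys,
# silently collapses duplicate keys before sending the request.
table = boto3.resource("dynamodb").Table("feast-demo")
with table.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch:
    batch.put_item(Item={"entity_id": "k1", "event_ts": "2024-01-01"})
    batch.put_item(Item={"entity_id": "k1", "event_ts": "2024-01-02"})  # overwrites the buffered row

# Low-level client: the same duplicate inside one batch_write_item request is
# rejected with a ValidationException ("provided list of item keys contains
# duplicates"), so callers must dedupe first, as _latest_data_to_write does.
client = boto3.client("dynamodb")
client.batch_write_item(
    RequestItems={
        "feast-demo": [
            {"PutRequest": {"Item": {"entity_id": {"S": "k1"}, "event_ts": {"S": "2024-01-01"}}}},
            {"PutRequest": {"Item": {"entity_id": {"S": "k1"}, "event_ts": {"S": "2024-01-02"}}}},
        ]
    }
)  # raises botocore.exceptions.ClientError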

        client = await _get_aiodynamodb_client(
            online_config.region, config.online_store.max_pool_connections
        )
        await dynamo_write_items_async(client, table_name, items)

    def online_read(
        self,
        config: RepoConfig,
@@ -419,19 +457,10 @@ def _write_batch_non_duplicates(
"""Deduplicate write batch request items on ``entity_id`` primary key."""
with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch:
for entity_key, features, timestamp, created_ts in data:
entity_id = compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
batch.put_item(
Item={
"entity_id": entity_id, # PartitionKey
"event_ts": str(utils.make_tzaware(timestamp)),
"values": {
k: v.SerializeToString()
for k, v in features.items() # Serialized Features
},
}
Item=_to_resource_write_item(
config, entity_key, features, timestamp
)
)
if progress:
progress(1)
@@ -675,3 +704,45 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
            region, endpoint_url
        )
        return self._dynamodb_resource


def _to_resource_write_item(config, entity_key, features, timestamp):
    entity_id = compute_entity_id(
        entity_key,
        entity_key_serialization_version=config.entity_key_serialization_version,
    )
    return {
        "entity_id": entity_id,  # PartitionKey
        "event_ts": str(utils.make_tzaware(timestamp)),
        "values": {
            k: v.SerializeToString()
            for k, v in features.items()  # Serialized Features
        },
    }


def _to_client_write_item(config, entity_key, features, timestamp):
    entity_id = compute_entity_id(
        entity_key,
        entity_key_serialization_version=config.entity_key_serialization_version,
    )
    return {
        "entity_id": {"S": entity_id},  # PartitionKey
        "event_ts": {"S": str(utils.make_tzaware(timestamp))},
        "values": {
            "M": {
                k: {"B": v.SerializeToString()}
                for k, v in features.items()  # Serialized Features
            }
        },
    }
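
(For illustration, the same row rendered by each helper; the entity id hash and serialized feature bytes are placeholders. The resource API accepts plain Python types and maps them to DynamoDB attribute values itself, while the low-level client needs the explicit "S"/"M"/"B" annotations.)

resource_item = {
    "entity_id": "3f8a...",                       # str, mapped to S by the resource API
    "event_ts": "2024-01-01 00:00:00+00:00",
    "values": {"temperature": b"\x18\x04"},       # bytes, mapped to B
}

client_item = {
    "entity_id": {"S": "3f8a..."},                # explicit attribute-value types
    "event_ts": {"S": "2024-01-01 00:00:00+00:00"},
    "values": {"M": {"temperature": {"B": b"\x18\x04"}}},
}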


def _latest_data_to_write(
    data: List[
        Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
    ],
):
    as_hashable = ((d[0].SerializeToString(), d) for d in data)
    sorted_data = sorted(as_hashable, key=lambda ah: (ah[0], ah[1][2]))
    return (v for _, v in OrderedDict((ah[0], ah[1]) for ah in sorted_data).items())
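
(The OrderedDict trick works because rows are sorted by (serialized key, event timestamp) first, so for each key the last assignment into the dict is the newest row. A quick trace with assumed keys k1/k2 and timestamps t1 < t2 < t3:)

# input (key, ts):     (k1, t1), (k2, t1), (k1, t3), (k1, t2)
# after sort:          (k1, t1), (k1, t2), (k1, t3), (k2, t1)
# OrderedDict inserts: k1 -> t1, k1 -> t2, k1 -> t3 (each overwriting), k2 -> t1
# yielded values:      (k1, t3), (k2, t1)  # newest row per key, keys in sorted order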
64 changes: 64 additions & 0 deletions sdk/python/feast/infra/utils/aws_utils.py
@@ -1,4 +1,6 @@
import asyncio
import contextlib
import itertools
import os
import tempfile
import uuid
@@ -10,6 +12,7 @@
import pyarrow as pa
import pyarrow.parquet as pq
from tenacity import (
    AsyncRetrying,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
@@ -1076,3 +1079,64 @@ def upload_arrow_table_to_athena(
    # Clean up S3 temporary data
    # for file_path in uploaded_files:
    #     s3_resource.Object(bucket, file_path).delete()


class DynamoUnprocessedWriteItems(Exception):
    pass


async def dynamo_write_items_async(
    dynamo_client, table_name: str, items: list[dict]
) -> None:
    """
    Writes items to a DynamoDB table asynchronously, in batches of at
    most 25 items (the BatchWriteItem limit).
    Raises DynamoUnprocessedWriteItems if not all items can be written.

    Args:
        dynamo_client: async dynamodb client
        table_name: name of the table being written to
        items: list of items to be written. See the boto3 docs for the item format.
    """
    DYNAMO_MAX_WRITE_BATCH_SIZE = 25

    async def _do_write(items):
        item_iter = iter(items)
        item_batches = []
        while True:
            item_batch = [
                item
                for item in itertools.islice(item_iter, DYNAMO_MAX_WRITE_BATCH_SIZE)
            ]
            if not item_batch:
                break

            item_batches.append(item_batch)

        return await asyncio.gather(
            *[
                dynamo_client.batch_write_item(
                    RequestItems={table_name: item_batch},
                )
                for item_batch in item_batches
            ]
        )

    put_items = [{"PutRequest": {"Item": item}} for item in items]

    retries = AsyncRetrying(
        retry=retry_if_exception_type(DynamoUnprocessedWriteItems),
        wait=wait_exponential(multiplier=1, max=4),
        reraise=True,
    )

    async for attempt in retries:
        with attempt:
            response_batches = await _do_write(put_items)

            put_items = []
            for response in response_batches:
                put_items.extend(response["UnprocessedItems"].get(table_name, []))

            if put_items:
                raise DynamoUnprocessedWriteItems()
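
(A minimal usage sketch, assuming an aiobotocore client and an illustrative table name "feast-demo"; Feast itself obtains its client via _get_aiodynamodb_client.)

import asyncio

from aiobotocore.session import get_session

from feast.infra.utils.aws_utils import dynamo_write_items_async


async def main():
    session = get_session()
    async with session.create_client("dynamodb", region_name="us-west-2") as client:
        items = [
            {"entity_id": {"S": "k1"}, "event_ts": {"S": "2024-01-01"}, "values": {"M": {}}},
            {"entity_id": {"S": "k2"}, "event_ts": {"S": "2024-01-01"}, "values": {"M": {}}},
        ]
        # Batches of up to 25 items are written concurrently; unprocessed
        # items are retried with exponential backoff until none remain.
        await dynamo_write_items_async(client, "feast-demo", items)


asyncio.run(main())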
@@ -8,29 +8,51 @@
from tests.integration.feature_repos.universal.entities import location


-@pytest.mark.integration
-@pytest.mark.universal_online_stores
-def test_push_features_and_read(environment, universal_data_sources):
+@pytest.fixture
+def store(environment, universal_data_sources):
     store = environment.feature_store
     _, _, data_sources = universal_data_sources
     feature_views = construct_universal_feature_views(data_sources)
     location_fv = feature_views.pushed_locations
     store.apply([location(), location_fv])
+    return store
+
+
+def _ingest_df():
     data = {
         "location_id": [1],
         "temperature": [4],
         "event_timestamp": [pd.Timestamp(_utc_now()).round("ms")],
         "created": [pd.Timestamp(_utc_now()).round("ms")],
     }
-    df_ingest = pd.DataFrame(data)
+    return pd.DataFrame(data)
 
-    store.push("location_stats_push_source", df_ingest)
+
+def assert_response(online_resp):
+    online_resp_dict = online_resp.to_dict()
+    assert online_resp_dict["location_id"] == [1]
+    assert online_resp_dict["temperature"] == [4]
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores
+def test_push_features_and_read(store):
+    store.push("location_stats_push_source", _ingest_df())
 
     online_resp = store.get_online_features(
         features=["pushable_location_stats:temperature"],
         entity_rows=[{"location_id": 1}],
     )
-    online_resp_dict = online_resp.to_dict()
-    assert online_resp_dict["location_id"] == [1]
-    assert online_resp_dict["temperature"] == [4]
+    assert_response(online_resp)
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores(only=["dynamodb"])
+async def test_push_features_and_read_async(store):
+    await store.push_async("location_stats_push_source", _ingest_df())
+
+    online_resp = await store.get_online_features_async(
+        features=["pushable_location_stats:temperature"],
+        entity_rows=[{"location_id": 1}],
+    )
+    assert_response(online_resp)
@@ -1,5 +1,6 @@
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime

import boto3
import pytest
@@ -10,6 +11,7 @@
    DynamoDBOnlineStore,
    DynamoDBOnlineStoreConfig,
    DynamoDBTable,
    _latest_data_to_write,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -358,3 +360,24 @@ def test_dynamodb_online_store_online_read_unknown_entity_end_of_batch(
    # ensure the entity is not dropped
    assert len(returned_items) == len(entity_keys)
    assert returned_items[-1] == (None, None)


def test_batch_write_deduplication():
    def to_ek_proto(val):
        return EntityKeyProto(
            join_keys=["customer"], entity_values=[ValueProto(string_val=val)]
        )

    # out of order and contains duplicate keys
    data = [
        (to_ek_proto("key-1"), {}, datetime(2024, 1, 1), None),
        (to_ek_proto("key-2"), {}, datetime(2024, 1, 1), None),
        (to_ek_proto("key-1"), {}, datetime(2024, 1, 3), None),
        (to_ek_proto("key-1"), {}, datetime(2024, 1, 2), None),
        (to_ek_proto("key-3"), {}, datetime(2024, 1, 2), None),
    ]

    # assert we only keep the most recent record per key
    actual = list(_latest_data_to_write(data))
    expected = [data[2], data[1], data[4]]
    assert expected == actual