diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py
index a97e81bc44..15e8357754 100644
--- a/sdk/python/feast/infra/online_stores/dynamodb.py
+++ b/sdk/python/feast/infra/online_stores/dynamodb.py
@@ -15,6 +15,7 @@
 import contextlib
 import itertools
 import logging
+from collections import OrderedDict
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
 
@@ -26,6 +27,7 @@
 from feast.infra.online_stores.helpers import compute_entity_id
 from feast.infra.online_stores.online_store import OnlineStore
 from feast.infra.supported_async_methods import SupportedAsyncMethods
+from feast.infra.utils.aws_utils import dynamo_write_items_async
 from feast.protos.feast.core.DynamoDBTable_pb2 import (
     DynamoDBTable as DynamoDBTableProto,
 )
@@ -103,7 +105,7 @@ async def close(self):
 
     @property
     def async_supported(self) -> SupportedAsyncMethods:
-        return SupportedAsyncMethods(read=True)
+        return SupportedAsyncMethods(read=True, write=True)
 
     def update(
         self,
@@ -238,6 +240,42 @@ def online_write_batch(
             )
             self._write_batch_non_duplicates(table_instance, data, progress, config)
 
+    async def online_write_batch_async(
+        self,
+        config: RepoConfig,
+        table: FeatureView,
+        data: List[
+            Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
+        ],
+        progress: Optional[Callable[[int], Any]],
+    ) -> None:
+        """
+        Writes a batch of feature rows to the online store asynchronously.
+
+        If a tz-naive timestamp is passed to this method, it is assumed to be UTC.
+
+        Args:
+            config: The config for the current feature store.
+            table: Feature view to which these feature rows correspond.
+            data: A list of quadruplets containing feature data. Each quadruplet contains an entity
+                key, a dict containing feature values, an event timestamp for the row, and the created
+                timestamp for the row if it exists.
+            progress: Function to be called once a batch of rows is written to the online store, used
+                to show progress.
+ """ + online_config = config.online_store + assert isinstance(online_config, DynamoDBOnlineStoreConfig) + + table_name = _get_table_name(online_config, config, table) + items = [ + _to_client_write_item(config, entity_key, features, timestamp) + for entity_key, features, timestamp, _ in _latest_data_to_write(data) + ] + client = await _get_aiodynamodb_client( + online_config.region, config.online_store.max_pool_connections + ) + await dynamo_write_items_async(client, table_name, items) + def online_read( self, config: RepoConfig, @@ -419,19 +457,10 @@ def _write_batch_non_duplicates( """Deduplicate write batch request items on ``entity_id`` primary key.""" with table_instance.batch_writer(overwrite_by_pkeys=["entity_id"]) as batch: for entity_key, features, timestamp, created_ts in data: - entity_id = compute_entity_id( - entity_key, - entity_key_serialization_version=config.entity_key_serialization_version, - ) batch.put_item( - Item={ - "entity_id": entity_id, # PartitionKey - "event_ts": str(utils.make_tzaware(timestamp)), - "values": { - k: v.SerializeToString() - for k, v in features.items() # Serialized Features - }, - } + Item=_to_resource_write_item( + config, entity_key, features, timestamp + ) ) if progress: progress(1) @@ -675,3 +704,45 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None region, endpoint_url ) return self._dynamodb_resource + + +def _to_resource_write_item(config, entity_key, features, timestamp): + entity_id = compute_entity_id( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ) + return { + "entity_id": entity_id, # PartitionKey + "event_ts": str(utils.make_tzaware(timestamp)), + "values": { + k: v.SerializeToString() + for k, v in features.items() # Serialized Features + }, + } + + +def _to_client_write_item(config, entity_key, features, timestamp): + entity_id = compute_entity_id( + entity_key, + entity_key_serialization_version=config.entity_key_serialization_version, + ) + return { + "entity_id": {"S": entity_id}, # PartitionKey + "event_ts": {"S": str(utils.make_tzaware(timestamp))}, + "values": { + "M": { + k: {"B": v.SerializeToString()} + for k, v in features.items() # Serialized Features + } + }, + } + + +def _latest_data_to_write( + data: List[ + Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]] + ], +): + as_hashable = ((d[0].SerializeToString(), d) for d in data) + sorted_data = sorted(as_hashable, key=lambda ah: (ah[0], ah[1][2])) + return (v for _, v in OrderedDict((ah[0], ah[1]) for ah in sorted_data).items()) diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py index 8e1b182249..0526cf8b65 100644 --- a/sdk/python/feast/infra/utils/aws_utils.py +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -1,4 +1,6 @@ +import asyncio import contextlib +import itertools import os import tempfile import uuid @@ -10,6 +12,7 @@ import pyarrow as pa import pyarrow.parquet as pq from tenacity import ( + AsyncRetrying, retry, retry_if_exception_type, stop_after_attempt, @@ -1076,3 +1079,64 @@ def upload_arrow_table_to_athena( # Clean up S3 temporary data # for file_path in uploaded_files: # s3_resource.Object(bucket, file_path).delete() + + +class DynamoUnprocessedWriteItems(Exception): + pass + + +async def dynamo_write_items_async( + dynamo_client, table_name: str, items: list[dict] +) -> None: + """ + Writes in batches to a dynamo table asynchronously. Max size of each + attempted batch is 25. 
+    Raises DynamoUnprocessedWriteItems if not all items can be written.
+
+    Args:
+        dynamo_client: Async DynamoDB client.
+        table_name: Name of the table being written to.
+        items: List of items to write; see the boto3 docs for the expected item format.
+    """
+    DYNAMO_MAX_WRITE_BATCH_SIZE = 25
+
+    async def _do_write(items):
+        item_iter = iter(items)
+        item_batches = []
+        while True:
+            item_batch = [
+                item
+                for item in itertools.islice(item_iter, DYNAMO_MAX_WRITE_BATCH_SIZE)
+            ]
+            if not item_batch:
+                break
+
+            item_batches.append(item_batch)
+
+        return await asyncio.gather(
+            *[
+                dynamo_client.batch_write_item(
+                    RequestItems={table_name: item_batch},
+                )
+                for item_batch in item_batches
+            ]
+        )
+
+    put_items = [{"PutRequest": {"Item": item}} for item in items]
+
+    retries = AsyncRetrying(
+        retry=retry_if_exception_type(DynamoUnprocessedWriteItems),
+        wait=wait_exponential(multiplier=1, max=4),
+        reraise=True,
+    )
+
+    async for attempt in retries:
+        with attempt:
+            response_batches = await _do_write(put_items)
+
+            put_items = []
+            for response in response_batches:
+                put_items.extend(response["UnprocessedItems"].get(table_name, []))
+
+            if put_items:
+                raise DynamoUnprocessedWriteItems()
diff --git a/sdk/python/tests/integration/online_store/test_push_features_to_online_store.py b/sdk/python/tests/integration/online_store/test_push_features_to_online_store.py
index 98fe3ab1ec..8986e21c57 100644
--- a/sdk/python/tests/integration/online_store/test_push_features_to_online_store.py
+++ b/sdk/python/tests/integration/online_store/test_push_features_to_online_store.py
@@ -8,29 +8,51 @@
 from tests.integration.feature_repos.universal.entities import location
 
 
-@pytest.mark.integration
-@pytest.mark.universal_online_stores
-def test_push_features_and_read(environment, universal_data_sources):
+@pytest.fixture
+def store(environment, universal_data_sources):
     store = environment.feature_store
     _, _, data_sources = universal_data_sources
     feature_views = construct_universal_feature_views(data_sources)
 
     location_fv = feature_views.pushed_locations
     store.apply([location(), location_fv])
+    return store
+
+def _ingest_df():
     data = {
         "location_id": [1],
         "temperature": [4],
         "event_timestamp": [pd.Timestamp(_utc_now()).round("ms")],
         "created": [pd.Timestamp(_utc_now()).round("ms")],
     }
-    df_ingest = pd.DataFrame(data)
+    return pd.DataFrame(data)
 
-    store.push("location_stats_push_source", df_ingest)
+
+def assert_response(online_resp):
+    online_resp_dict = online_resp.to_dict()
+    assert online_resp_dict["location_id"] == [1]
+    assert online_resp_dict["temperature"] == [4]
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores
+def test_push_features_and_read(store):
+    store.push("location_stats_push_source", _ingest_df())
 
     online_resp = store.get_online_features(
         features=["pushable_location_stats:temperature"],
         entity_rows=[{"location_id": 1}],
     )
-    online_resp_dict = online_resp.to_dict()
-    assert online_resp_dict["location_id"] == [1]
-    assert online_resp_dict["temperature"] == [4]
+    assert_response(online_resp)
+
+
+@pytest.mark.integration
+@pytest.mark.universal_online_stores(only=["dynamodb"])
+async def test_push_features_and_read_async(store):
+    await store.push_async("location_stats_push_source", _ingest_df())
+
+    online_resp = await store.get_online_features_async(
+        features=["pushable_location_stats:temperature"],
+        entity_rows=[{"location_id": 1}],
+    )
+    assert_response(online_resp)
diff --git a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py
index 6ff7b3c360..cb1c15ee6e 100644
--- a/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py
+++ b/sdk/python/tests/unit/infra/online_store/test_dynamodb_online_store.py
@@ -1,5 +1,6 @@
 from copy import deepcopy
 from dataclasses import dataclass
+from datetime import datetime
 
 import boto3
 import pytest
@@ -10,6 +11,7 @@
     DynamoDBOnlineStore,
     DynamoDBOnlineStoreConfig,
     DynamoDBTable,
+    _latest_data_to_write,
 )
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
 from feast.protos.feast.types.Value_pb2 import Value as ValueProto
@@ -358,3 +360,24 @@ def test_dynamodb_online_store_online_read_unknown_entity_end_of_batch(
     # ensure the entity is not dropped
     assert len(returned_items) == len(entity_keys)
     assert returned_items[-1] == (None, None)
+
+
+def test_batch_write_deduplication():
+    def to_ek_proto(val):
+        return EntityKeyProto(
+            join_keys=["customer"], entity_values=[ValueProto(string_val=val)]
+        )
+
+    # data is out of order and has duplicate keys
+    data = [
+        (to_ek_proto("key-1"), {}, datetime(2024, 1, 1), None),
+        (to_ek_proto("key-2"), {}, datetime(2024, 1, 1), None),
+        (to_ek_proto("key-1"), {}, datetime(2024, 1, 3), None),
+        (to_ek_proto("key-1"), {}, datetime(2024, 1, 2), None),
+        (to_ek_proto("key-3"), {}, datetime(2024, 1, 2), None),
+    ]
+
+    # assert we only keep the most recent record per key
+    actual = list(_latest_data_to_write(data))
+    expected = [data[2], data[1], data[4]]
+    assert expected == actual
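
Usage sketch (not part of the patch): the snippet below illustrates how the new async write path is exercised end to end, mirroring test_push_features_and_read_async above. It assumes a feature repo already configured with the DynamoDB online store and a push source named "location_stats_push_source"; the feature and source names are taken from the tests, and repo_path="." is an assumption, not a fixed convention.

    import asyncio
    from datetime import datetime, timezone

    import pandas as pd

    from feast import FeatureStore


    async def main():
        # Assumes feature_store.yaml in the current directory points at DynamoDB.
        store = FeatureStore(repo_path=".")

        df = pd.DataFrame(
            {
                "location_id": [1],
                "temperature": [4],
                "event_timestamp": [pd.Timestamp(datetime.now(timezone.utc)).round("ms")],
                "created": [pd.Timestamp(datetime.now(timezone.utc)).round("ms")],
            }
        )

        # push_async routes rows through DynamoDBOnlineStore.online_write_batch_async,
        # which keeps only the latest row per entity key and writes the result via
        # dynamo_write_items_async in BatchWriteItem batches of at most 25 items.
        await store.push_async("location_stats_push_source", df)

        resp = await store.get_online_features_async(
            features=["pushable_location_stats:temperature"],
            entity_rows=[{"location_id": 1}],
        )
        print(resp.to_dict())


    asyncio.run(main())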