diff --git a/sdk/python/feast/infra/online_stores/dynamodb.py b/sdk/python/feast/infra/online_stores/dynamodb.py index 86a96239bb..c161a5b955 100644 --- a/sdk/python/feast/infra/online_stores/dynamodb.py +++ b/sdk/python/feast/infra/online_stores/dynamodb.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple @@ -50,10 +51,16 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel): """Online store type selector""" region: StrictStr - """ AWS Region Name """ + """AWS Region Name""" table_name_template: StrictStr = "{project}.{table_name}" - """ DynamoDB table name template """ + """DynamoDB table name template""" + + sort_response: bool = True + """Whether or not to sort BatchGetItem response.""" + + batch_size: int = 40 + """Number of items to retrieve in a DynamoDB BatchGetItem call.""" class DynamoDBOnlineStore(OnlineStore): @@ -211,26 +218,46 @@ def online_read( online_config = config.online_store assert isinstance(online_config, DynamoDBOnlineStoreConfig) dynamodb_resource = self._get_dynamodb_resource(online_config.region) + table_instance = dynamodb_resource.Table( + _get_table_name(online_config, config, table) + ) result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] - for entity_key in entity_keys: - table_instance = dynamodb_resource.Table( - _get_table_name(online_config, config, table) - ) - entity_id = compute_entity_id(entity_key) + entity_ids = [compute_entity_id(entity_key) for entity_key in entity_keys] + batch_size = online_config.batch_size + sort_response = online_config.sort_response + entity_ids_iter = iter(entity_ids) + while True: + batch = list(itertools.islice(entity_ids_iter, batch_size)) + # No more items to insert + if len(batch) == 0: + break + batch_entity_ids = { + table_instance.name: { + "Keys": [{"entity_id": entity_id} for entity_id in batch] + } + } with tracing_span(name="remote_call"): - response = table_instance.get_item(Key={"entity_id": entity_id}) - value = response.get("Item") - - if value is not None: - res = {} - for feature_name, value_bin in value["values"].items(): - val = ValueProto() - val.ParseFromString(value_bin.value) - res[feature_name] = val - result.append((datetime.fromisoformat(value["event_ts"]), res)) + response = dynamodb_resource.batch_get_item( + RequestItems=batch_entity_ids + ) + response = response.get("Responses") + table_responses = response.get(table_instance.name) + if table_responses: + if sort_response: + table_responses = self._sort_dynamodb_response( + table_responses, entity_ids + ) + for tbl_res in table_responses: + res = {} + for feature_name, value_bin in tbl_res["values"].items(): + val = ValueProto() + val.ParseFromString(value_bin.value) + res[feature_name] = val + result.append((datetime.fromisoformat(tbl_res["event_ts"]), res)) else: - result.append((None, None)) + batch_size_nones = ((None, None),) * len(batch) + result.extend(batch_size_nones) return result def _get_dynamodb_client(self, region: str): @@ -243,6 +270,20 @@ def _get_dynamodb_resource(self, region: str): self._dynamodb_resource = _initialize_dynamodb_resource(region) return self._dynamodb_resource + def _sort_dynamodb_response(self, responses: list, order: list): + """DynamoDB Batch Get Item doesn't return items in a particular order.""" + # Assign an index to order + order_with_index = {value: idx for idx, value in enumerate(order)} + # Sort table responses by index + table_responses_ordered = [ + (order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in responses + ] + table_responses_ordered = sorted( + table_responses_ordered, key=lambda tup: tup[0] + ) + _, table_responses_ordered = zip(*table_responses_ordered) + return table_responses_ordered + def _initialize_dynamodb_client(region: str): return boto3.client("dynamodb", region_name=region) diff --git a/sdk/python/tests/unit/online_store/test_dynamodb_online_store.py b/sdk/python/tests/unit/online_store/test_dynamodb_online_store.py new file mode 100644 index 0000000000..0f42230ef5 --- /dev/null +++ b/sdk/python/tests/unit/online_store/test_dynamodb_online_store.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass + +import pytest +from moto import mock_dynamodb2 + +from feast.infra.offline_stores.file import FileOfflineStoreConfig +from feast.infra.online_stores.dynamodb import ( + DynamoDBOnlineStore, + DynamoDBOnlineStoreConfig, +) +from feast.repo_config import RepoConfig +from tests.utils.online_store_utils import ( + _create_n_customer_test_samples, + _create_test_table, + _insert_data_test_table, +) + +REGISTRY = "s3://test_registry/registry.db" +PROJECT = "test_aws" +PROVIDER = "aws" +TABLE_NAME = "dynamodb_online_store" +REGION = "us-west-2" + + +@dataclass +class MockFeatureView: + name: str + + +@pytest.fixture +def repo_config(): + return RepoConfig( + registry=REGISTRY, + project=PROJECT, + provider=PROVIDER, + online_store=DynamoDBOnlineStoreConfig(region=REGION), + offline_store=FileOfflineStoreConfig(), + ) + + +@mock_dynamodb2 +@pytest.mark.parametrize("n_samples", [5, 50, 100]) +def test_online_read(repo_config, n_samples): + """Test DynamoDBOnlineStore online_read method.""" + _create_test_table(PROJECT, f"{TABLE_NAME}_{n_samples}", REGION) + data = _create_n_customer_test_samples(n=n_samples) + _insert_data_test_table(data, PROJECT, f"{TABLE_NAME}_{n_samples}", REGION) + + entity_keys, features = zip(*data) + dynamodb_store = DynamoDBOnlineStore() + returned_items = dynamodb_store.online_read( + config=repo_config, + table=MockFeatureView(name=f"{TABLE_NAME}_{n_samples}"), + entity_keys=entity_keys, + ) + assert len(returned_items) == len(data) + assert [item[1] for item in returned_items] == list(features) diff --git a/sdk/python/tests/utils/online_store_utils.py b/sdk/python/tests/utils/online_store_utils.py new file mode 100644 index 0000000000..ee90c2a542 --- /dev/null +++ b/sdk/python/tests/utils/online_store_utils.py @@ -0,0 +1,54 @@ +from datetime import datetime + +import boto3 + +from feast import utils +from feast.infra.online_stores.helpers import compute_entity_id +from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto +from feast.protos.feast.types.Value_pb2 import Value as ValueProto + + +def _create_n_customer_test_samples(n=10): + return [ + ( + EntityKeyProto( + join_keys=["customer"], entity_values=[ValueProto(string_val=str(i))] + ), + { + "avg_orders_day": ValueProto(float_val=1.0), + "name": ValueProto(string_val="John"), + "age": ValueProto(int64_val=3), + }, + ) + for i in range(n) + ] + + +def _create_test_table(project, tbl_name, region): + client = boto3.client("dynamodb", region_name=region) + client.create_table( + TableName=f"{project}.{tbl_name}", + KeySchema=[{"AttributeName": "entity_id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "entity_id", "AttributeType": "S"}], + BillingMode="PAY_PER_REQUEST", + ) + + +def _delete_test_table(project, tbl_name, region): + client = boto3.client("dynamodb", region_name=region) + client.delete_table(TableName=f"{project}.{tbl_name}") + + +def _insert_data_test_table(data, project, tbl_name, region): + dynamodb_resource = boto3.resource("dynamodb", region_name=region) + table_instance = dynamodb_resource.Table(f"{project}.{tbl_name}") + for entity_key, features in data: + entity_id = compute_entity_id(entity_key) + with table_instance.batch_writer() as batch: + batch.put_item( + Item={ + "entity_id": entity_id, + "event_ts": str(utils.make_tzaware(datetime.utcnow())), + "values": {k: v.SerializeToString() for k, v in features.items()}, + } + )