Skip to content

Commit

Permalink
feat: Add support for DynamoDB online_read in batches (feast-dev#2371)
Browse files Browse the repository at this point in the history
* feat: dynamodb onlin read in batches

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* run linters and format

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* feat: batch_size parameter

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* docs: typo in batch_size description

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* trailing white space

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* fix: batch_size is last argument

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* test: dynamodb online store online_read in batches

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* test: mock dynamodb behavior

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* feat: batch_size value must be less than 40

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* feat: batch_size defaults to 40

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* feat: sort dynamodb responses

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* resolve merge conflicts

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* test online response proto with redshift:dynamodb

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* feat: consistency in batch_size process

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* fix: return batch_size times None

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* remove debug code

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* typo in docstring

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

* batch_size in onlineconfigstore

Signed-off-by: Miguel Trejo <armando.trejo.marrufo@gmail.com>

Co-authored-by: Danny Chiao <danny@tecton.ai>
  • Loading branch information
TremaMiguel and adchia authored Mar 23, 2022
1 parent 45db6dc commit 702ec49
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 18 deletions.
77 changes: 59 additions & 18 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
Expand Down Expand Up @@ -50,10 +51,16 @@ class DynamoDBOnlineStoreConfig(FeastConfigBaseModel):
"""Online store type selector"""

region: StrictStr
""" AWS Region Name """
"""AWS Region Name"""

table_name_template: StrictStr = "{project}.{table_name}"
""" DynamoDB table name template """
"""DynamoDB table name template"""

sort_response: bool = True
"""Whether or not to sort BatchGetItem response."""

batch_size: int = 40
"""Number of items to retrieve in a DynamoDB BatchGetItem call."""


class DynamoDBOnlineStore(OnlineStore):
Expand Down Expand Up @@ -211,26 +218,46 @@ def online_read(
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)
dynamodb_resource = self._get_dynamodb_resource(online_config.region)
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
for entity_key in entity_keys:
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)
entity_id = compute_entity_id(entity_key)
entity_ids = [compute_entity_id(entity_key) for entity_key in entity_keys]
batch_size = online_config.batch_size
sort_response = online_config.sort_response
entity_ids_iter = iter(entity_ids)
while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))
# No more items to insert
if len(batch) == 0:
break
batch_entity_ids = {
table_instance.name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch]
}
}
with tracing_span(name="remote_call"):
response = table_instance.get_item(Key={"entity_id": entity_id})
value = response.get("Item")

if value is not None:
res = {}
for feature_name, value_bin in value["values"].items():
val = ValueProto()
val.ParseFromString(value_bin.value)
res[feature_name] = val
result.append((datetime.fromisoformat(value["event_ts"]), res))
response = dynamodb_resource.batch_get_item(
RequestItems=batch_entity_ids
)
response = response.get("Responses")
table_responses = response.get(table_instance.name)
if table_responses:
if sort_response:
table_responses = self._sort_dynamodb_response(
table_responses, entity_ids
)
for tbl_res in table_responses:
res = {}
for feature_name, value_bin in tbl_res["values"].items():
val = ValueProto()
val.ParseFromString(value_bin.value)
res[feature_name] = val
result.append((datetime.fromisoformat(tbl_res["event_ts"]), res))
else:
result.append((None, None))
batch_size_nones = ((None, None),) * len(batch)
result.extend(batch_size_nones)
return result

def _get_dynamodb_client(self, region: str):
Expand All @@ -243,6 +270,20 @@ def _get_dynamodb_resource(self, region: str):
self._dynamodb_resource = _initialize_dynamodb_resource(region)
return self._dynamodb_resource

def _sort_dynamodb_response(self, responses: list, order: list):
"""DynamoDB Batch Get Item doesn't return items in a particular order."""
# Assign an index to order
order_with_index = {value: idx for idx, value in enumerate(order)}
# Sort table responses by index
table_responses_ordered = [
(order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in responses
]
table_responses_ordered = sorted(
table_responses_ordered, key=lambda tup: tup[0]
)
_, table_responses_ordered = zip(*table_responses_ordered)
return table_responses_ordered


def _initialize_dynamodb_client(region: str):
return boto3.client("dynamodb", region_name=region)
Expand Down
57 changes: 57 additions & 0 deletions sdk/python/tests/unit/online_store/test_dynamodb_online_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from dataclasses import dataclass

import pytest
from moto import mock_dynamodb2

from feast.infra.offline_stores.file import FileOfflineStoreConfig
from feast.infra.online_stores.dynamodb import (
DynamoDBOnlineStore,
DynamoDBOnlineStoreConfig,
)
from feast.repo_config import RepoConfig
from tests.utils.online_store_utils import (
_create_n_customer_test_samples,
_create_test_table,
_insert_data_test_table,
)

REGISTRY = "s3://test_registry/registry.db"
PROJECT = "test_aws"
PROVIDER = "aws"
TABLE_NAME = "dynamodb_online_store"
REGION = "us-west-2"


@dataclass
class MockFeatureView:
name: str


@pytest.fixture
def repo_config():
return RepoConfig(
registry=REGISTRY,
project=PROJECT,
provider=PROVIDER,
online_store=DynamoDBOnlineStoreConfig(region=REGION),
offline_store=FileOfflineStoreConfig(),
)


@mock_dynamodb2
@pytest.mark.parametrize("n_samples", [5, 50, 100])
def test_online_read(repo_config, n_samples):
"""Test DynamoDBOnlineStore online_read method."""
_create_test_table(PROJECT, f"{TABLE_NAME}_{n_samples}", REGION)
data = _create_n_customer_test_samples(n=n_samples)
_insert_data_test_table(data, PROJECT, f"{TABLE_NAME}_{n_samples}", REGION)

entity_keys, features = zip(*data)
dynamodb_store = DynamoDBOnlineStore()
returned_items = dynamodb_store.online_read(
config=repo_config,
table=MockFeatureView(name=f"{TABLE_NAME}_{n_samples}"),
entity_keys=entity_keys,
)
assert len(returned_items) == len(data)
assert [item[1] for item in returned_items] == list(features)
54 changes: 54 additions & 0 deletions sdk/python/tests/utils/online_store_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from datetime import datetime

import boto3

from feast import utils
from feast.infra.online_stores.helpers import compute_entity_id
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto


def _create_n_customer_test_samples(n=10):
return [
(
EntityKeyProto(
join_keys=["customer"], entity_values=[ValueProto(string_val=str(i))]
),
{
"avg_orders_day": ValueProto(float_val=1.0),
"name": ValueProto(string_val="John"),
"age": ValueProto(int64_val=3),
},
)
for i in range(n)
]


def _create_test_table(project, tbl_name, region):
client = boto3.client("dynamodb", region_name=region)
client.create_table(
TableName=f"{project}.{tbl_name}",
KeySchema=[{"AttributeName": "entity_id", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "entity_id", "AttributeType": "S"}],
BillingMode="PAY_PER_REQUEST",
)


def _delete_test_table(project, tbl_name, region):
client = boto3.client("dynamodb", region_name=region)
client.delete_table(TableName=f"{project}.{tbl_name}")


def _insert_data_test_table(data, project, tbl_name, region):
dynamodb_resource = boto3.resource("dynamodb", region_name=region)
table_instance = dynamodb_resource.Table(f"{project}.{tbl_name}")
for entity_key, features in data:
entity_id = compute_entity_id(entity_key)
with table_instance.batch_writer() as batch:
batch.put_item(
Item={
"entity_id": entity_id,
"event_ts": str(utils.make_tzaware(datetime.utcnow())),
"values": {k: v.SerializeToString() for k, v in features.items()},
}
)

0 comments on commit 702ec49

Please sign in to comment.