Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add online_read_async for dynamodb #4244

Merged
merged 16 commits into from
Jun 5, 2024
191 changes: 146 additions & 45 deletions sdk/python/feast/infra/online_stores/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

try:
import boto3
from aiobotocore import session
from boto3.dynamodb.types import TypeDeserializer
from botocore.config import Config
from botocore.exceptions import ClientError
except ImportError as e:
Expand Down Expand Up @@ -80,6 +82,7 @@ class DynamoDBOnlineStore(OnlineStore):

_dynamodb_client = None
_dynamodb_resource = None
_aioboto_session = None

def update(
self,
Expand Down Expand Up @@ -223,69 +226,103 @@ def online_read(
"""
online_config = config.online_store
assert isinstance(online_config, DynamoDBOnlineStoreConfig)

dynamodb_resource = self._get_dynamodb_resource(
online_config.region, online_config.endpoint_url
)
table_instance = dynamodb_resource.Table(
_get_table_name(online_config, config, table)
)

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
entity_ids = [
compute_entity_id(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
)
for entity_key in entity_keys
]
batch_size = online_config.batch_size
entity_ids = self._to_entity_ids(config, entity_keys)
entity_ids_iter = iter(entity_ids)
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

while True:
batch = list(itertools.islice(entity_ids_iter, batch_size))
batch_result: List[
Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]
] = []

# No more items to insert
if len(batch) == 0:
break
batch_entity_ids = {
table_instance.name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}
batch_entity_ids = self._to_resource_batch_get_payload(
online_config, table_instance.name, batch
)
response = dynamodb_resource.batch_get_item(
RequestItems=batch_entity_ids,
)
response = response.get("Responses")
table_responses = response.get(table_instance.name)
if table_responses:
table_responses = self._sort_dynamodb_response(
table_responses, entity_ids
)
entity_idx = 0
for tbl_res in table_responses:
entity_id = tbl_res["entity_id"]
while entity_id != batch[entity_idx]:
batch_result.append((None, None))
entity_idx += 1
res = {}
for feature_name, value_bin in tbl_res["values"].items():
val = ValueProto()
val.ParseFromString(value_bin.value)
res[feature_name] = val
batch_result.append(
(datetime.fromisoformat(tbl_res["event_ts"]), res)
)
entity_idx += 1

# Not all entities in a batch may have responses
# Pad with remaining values in batch that were not found
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result))
batch_result.extend(batch_size_nones)
batch_result = self._process_batch_get_response(
table_instance.name, response, entity_ids, batch
)
result.extend(batch_result)
return result

async def online_read_async(
    self,
    config: RepoConfig,
    table: FeatureView,
    entity_keys: List[EntityKeyProto],
    requested_features: Optional[List[str]] = None,
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
    """
    Reads features values for the given entity keys asynchronously.

    Args:
        config: The config for the current feature store.
        table: The feature view whose feature values should be read.
        entity_keys: The list of entity keys for which feature values should be read.
        requested_features: The list of features that should be read.

    Returns:
        A list of the same length as entity_keys. Each item in the list is a tuple where the first
        item is the event timestamp for the row, and the second item is a dict mapping feature names
        to values, which are returned in proto format.
    """
    online_config = config.online_store
    assert isinstance(online_config, DynamoDBOnlineStoreConfig)

    # Read ids are fetched in chunks of batch_size per BatchGetItem call.
    batch_size = online_config.batch_size
    entity_ids = self._to_entity_ids(config, entity_keys)
    entity_ids_iter = iter(entity_ids)
    result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []
    table_name = _get_table_name(online_config, config, table)

    deserialize = TypeDeserializer().deserialize

    # The low-level dynamodb client returns attributes in wire format,
    # {"field_name": {"TYPE": value}}, whereas the boto3 resource API returns
    # plain {"field_name": value}. TypeDeserializer converts client responses
    # into the resource-style shape expected by _process_batch_get_response.
    def to_tbl_resp(raw_client_response):
        return {
            "entity_id": deserialize(raw_client_response["entity_id"]),
            "event_ts": deserialize(raw_client_response["event_ts"]),
            "values": deserialize(raw_client_response["values"]),
        }

    # A fresh async client is created per call and scoped to this context
    # manager, as recommended by the aiobotocore docs.
    async with self._get_aiodynamodb_client(online_config.region) as client:
        while True:
            batch = list(itertools.islice(entity_ids_iter, batch_size))

            # No more items to insert
            if len(batch) == 0:
                break
            batch_entity_ids = self._to_client_batch_get_payload(
                online_config, table_name, batch
            )
            response = await client.batch_get_item(
                RequestItems=batch_entity_ids,
            )
            batch_result = self._process_batch_get_response(
                table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp
            )
            result.extend(batch_result)
    return result

def _get_aioboto_session(self):
if self._aioboto_session is None:
self._aioboto_session = session.get_session()
return self._aioboto_session

def _get_aiodynamodb_client(self, region: str):
    """Create an async DynamoDB client for *region* from the shared aiobotocore session.

    NOTE(review): a new client object is created on every call; aiobotocore
    documents client usage inside an async context manager, so callers are
    expected to ``async with`` the returned value. Reusing one long-lived
    client is a possible follow-up if per-call creation shows overhead.
    """
    return self._get_aioboto_session().create_client("dynamodb", region_name=region)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not too sure about this, but is it a good idea to recreate client object on every call? Isn't there a performance penalty?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the docs demonstrate and recommend usage within an async context manager. it is possible to manually work with _exit_stack.enter_async_context at app startup/shutdown, but I wanted to avoid using protected methods and adding app complexity without first knowing it was really needed. seemed like an "as needed follow up" type of investigation.


def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None):
if self._dynamodb_client is None:
self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url)
Expand All @@ -298,13 +335,19 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None
)
return self._dynamodb_resource

def _sort_dynamodb_response(self, responses: list, order: list) -> Any:
def _sort_dynamodb_response(
self,
responses: list,
order: list,
to_tbl_response: Callable = lambda raw_dict: raw_dict,
) -> Any:
"""DynamoDB Batch Get Item doesn't return items in a particular order."""
# Assign an index to order
order_with_index = {value: idx for idx, value in enumerate(order)}
# Sort table responses by index
table_responses_ordered: Any = [
(order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in responses
(order_with_index[tbl_res["entity_id"]], tbl_res)
for tbl_res in map(to_tbl_response, responses)
]
table_responses_ordered = sorted(
table_responses_ordered, key=lambda tup: tup[0]
Expand Down Expand Up @@ -341,6 +384,64 @@ def _write_batch_non_duplicates(
if progress:
progress(1)

def _process_batch_get_response(
    self, table_name, response, entity_ids, batch, **sort_kwargs
):
    """Convert one raw batch_get_item response into (event_ts, features) tuples.

    Args:
        table_name: Name of the DynamoDB table the batch was read from.
        response: Raw batch_get_item response dict.
        entity_ids: Full, ordered list of entity ids for the overall read;
            used to restore request order, since DynamoDB BatchGetItem does
            not return items in any particular order.
        batch: The slice of entity_ids covered by this response.
        **sort_kwargs: Forwarded to _sort_dynamodb_response (e.g.
            ``to_tbl_response`` to normalize client-format responses).

    Returns:
        A list of the same length as ``batch``, aligned with it; entries for
        entity ids with no stored row are (None, None).
    """
    response = response.get("Responses")
    table_responses = response.get(table_name)

    batch_result = []
    if table_responses:
        table_responses = self._sort_dynamodb_response(
            table_responses, entity_ids, **sort_kwargs
        )
        entity_idx = 0
        for tbl_res in table_responses:
            entity_id = tbl_res["entity_id"]
            # Responses are now in request order, so any batch ids skipped
            # before the next matching response were not found: pad them.
            while entity_id != batch[entity_idx]:
                batch_result.append((None, None))
                entity_idx += 1
            # Each stored feature value is a serialized ValueProto blob.
            res = {}
            for feature_name, value_bin in tbl_res["values"].items():
                val = ValueProto()
                val.ParseFromString(value_bin.value)
                res[feature_name] = val
            batch_result.append((datetime.fromisoformat(tbl_res["event_ts"]), res))
            entity_idx += 1
    # Not all entities in a batch may have responses
    # Pad with remaining values in batch that were not found
    batch_size_nones = ((None, None),) * (len(batch) - len(batch_result))
    batch_result.extend(batch_size_nones)
    return batch_result

@staticmethod
def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]):
    """Compute the serialized DynamoDB entity id for every entity key, in order."""
    serialization_version = config.entity_key_serialization_version
    return [
        compute_entity_id(
            entity_key,
            entity_key_serialization_version=serialization_version,
        )
        for entity_key in entity_keys
    ]

@staticmethod
def _to_resource_batch_get_payload(online_config, table_name, batch):
return {
table_name: {
"Keys": [{"entity_id": entity_id} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}

@staticmethod
def _to_client_batch_get_payload(online_config, table_name, batch):
return {
table_name: {
"Keys": [{"entity_id": {"S": entity_id}} for entity_id in batch],
"ConsistentRead": online_config.consistent_reads,
}
}


def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None):
return boto3.client(
Expand Down
Loading
Loading