-
Notifications
You must be signed in to change notification settings - Fork 998
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add online_read_async for dynamodb #4244
Changes from 13 commits
bb0a4fb
6b178bb
a19d46d
d7a982f
3db88cb
ee0b634
0217149
eeb3929
c4ba283
8c164c3
a7488e3
a1d1ada
6bd29f7
159259b
fc9e603
42e98d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,8 @@ | |
|
||
try: | ||
import boto3 | ||
from aiobotocore import session | ||
from boto3.dynamodb.types import TypeDeserializer | ||
from botocore.config import Config | ||
from botocore.exceptions import ClientError | ||
except ImportError as e: | ||
|
@@ -80,6 +82,7 @@ class DynamoDBOnlineStore(OnlineStore): | |
|
||
_dynamodb_client = None | ||
_dynamodb_resource = None | ||
_aioboto_session = None | ||
|
||
def update( | ||
self, | ||
|
@@ -223,69 +226,103 @@ def online_read( | |
""" | ||
online_config = config.online_store | ||
assert isinstance(online_config, DynamoDBOnlineStoreConfig) | ||
|
||
dynamodb_resource = self._get_dynamodb_resource( | ||
online_config.region, online_config.endpoint_url | ||
) | ||
table_instance = dynamodb_resource.Table( | ||
_get_table_name(online_config, config, table) | ||
) | ||
|
||
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] | ||
entity_ids = [ | ||
compute_entity_id( | ||
entity_key, | ||
entity_key_serialization_version=config.entity_key_serialization_version, | ||
) | ||
for entity_key in entity_keys | ||
] | ||
batch_size = online_config.batch_size | ||
entity_ids = self._to_entity_ids(config, entity_keys) | ||
entity_ids_iter = iter(entity_ids) | ||
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] | ||
|
||
while True: | ||
batch = list(itertools.islice(entity_ids_iter, batch_size)) | ||
batch_result: List[ | ||
Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]] | ||
] = [] | ||
|
||
# No more items to insert | ||
if len(batch) == 0: | ||
break | ||
batch_entity_ids = { | ||
table_instance.name: { | ||
"Keys": [{"entity_id": entity_id} for entity_id in batch], | ||
"ConsistentRead": online_config.consistent_reads, | ||
} | ||
} | ||
batch_entity_ids = self._to_resource_batch_get_payload( | ||
online_config, table_instance.name, batch | ||
) | ||
response = dynamodb_resource.batch_get_item( | ||
RequestItems=batch_entity_ids, | ||
) | ||
response = response.get("Responses") | ||
table_responses = response.get(table_instance.name) | ||
if table_responses: | ||
table_responses = self._sort_dynamodb_response( | ||
table_responses, entity_ids | ||
) | ||
entity_idx = 0 | ||
for tbl_res in table_responses: | ||
entity_id = tbl_res["entity_id"] | ||
while entity_id != batch[entity_idx]: | ||
batch_result.append((None, None)) | ||
entity_idx += 1 | ||
res = {} | ||
for feature_name, value_bin in tbl_res["values"].items(): | ||
val = ValueProto() | ||
val.ParseFromString(value_bin.value) | ||
res[feature_name] = val | ||
batch_result.append( | ||
(datetime.fromisoformat(tbl_res["event_ts"]), res) | ||
) | ||
entity_idx += 1 | ||
|
||
# Not all entities in a batch may have responses | ||
# Pad with remaining values in batch that were not found | ||
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result)) | ||
batch_result.extend(batch_size_nones) | ||
batch_result = self._process_batch_get_response( | ||
table_instance.name, response, entity_ids, batch | ||
) | ||
result.extend(batch_result) | ||
return result | ||
|
||
async def online_read_async( | ||
self, | ||
config: RepoConfig, | ||
table: FeatureView, | ||
entity_keys: List[EntityKeyProto], | ||
requested_features: Optional[List[str]] = None, | ||
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]: | ||
""" | ||
Reads features values for the given entity keys asynchronously. | ||
|
||
Args: | ||
config: The config for the current feature store. | ||
table: The feature view whose feature values should be read. | ||
entity_keys: The list of entity keys for which feature values should be read. | ||
requested_features: The list of features that should be read. | ||
|
||
Returns: | ||
A list of the same length as entity_keys. Each item in the list is a tuple where the first | ||
item is the event timestamp for the row, and the second item is a dict mapping feature names | ||
to values, which are returned in proto format. | ||
""" | ||
online_config = config.online_store | ||
assert isinstance(online_config, DynamoDBOnlineStoreConfig) | ||
|
||
batch_size = online_config.batch_size | ||
entity_ids = self._to_entity_ids(config, entity_keys) | ||
entity_ids_iter = iter(entity_ids) | ||
result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = [] | ||
table_name = _get_table_name(online_config, config, table) | ||
|
||
deserialize = TypeDeserializer().deserialize | ||
|
||
def to_tbl_resp(raw_client_response): | ||
return { | ||
"entity_id": deserialize(raw_client_response["entity_id"]), | ||
"event_ts": deserialize(raw_client_response["event_ts"]), | ||
"values": deserialize(raw_client_response["values"]), | ||
} | ||
|
||
async with self._get_aiodynamodb_client(online_config.region) as client: | ||
while True: | ||
batch = list(itertools.islice(entity_ids_iter, batch_size)) | ||
|
||
# No more items to insert | ||
if len(batch) == 0: | ||
break | ||
batch_entity_ids = self._to_client_batch_get_payload( | ||
online_config, table_name, batch | ||
) | ||
response = await client.batch_get_item( | ||
RequestItems=batch_entity_ids, | ||
) | ||
batch_result = self._process_batch_get_response( | ||
table_name, response, entity_ids, batch, to_tbl_response=to_tbl_resp | ||
) | ||
result.extend(batch_result) | ||
return result | ||
|
||
def _get_aioboto_session(self): | ||
if self._aioboto_session is None: | ||
self._aioboto_session = session.get_session() | ||
return self._aioboto_session | ||
|
||
def _get_aiodynamodb_client(self, region: str): | ||
return self._get_aioboto_session().create_client("dynamodb", region_name=region) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not too sure about this, but is it a good idea to recreate the client object on every call? Isn't there a performance penalty? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The docs demonstrate and recommend usage within an async context manager. It is possible to manually work with |
||
|
||
def _get_dynamodb_client(self, region: str, endpoint_url: Optional[str] = None): | ||
if self._dynamodb_client is None: | ||
self._dynamodb_client = _initialize_dynamodb_client(region, endpoint_url) | ||
|
@@ -298,13 +335,19 @@ def _get_dynamodb_resource(self, region: str, endpoint_url: Optional[str] = None | |
) | ||
return self._dynamodb_resource | ||
|
||
def _sort_dynamodb_response(self, responses: list, order: list) -> Any: | ||
def _sort_dynamodb_response( | ||
self, | ||
responses: list, | ||
order: list, | ||
to_tbl_response: Callable = lambda raw_dict: raw_dict, | ||
) -> Any: | ||
"""DynamoDB Batch Get Item doesn't return items in a particular order.""" | ||
# Assign an index to order | ||
order_with_index = {value: idx for idx, value in enumerate(order)} | ||
# Sort table responses by index | ||
table_responses_ordered: Any = [ | ||
(order_with_index[tbl_res["entity_id"]], tbl_res) for tbl_res in responses | ||
(order_with_index[tbl_res["entity_id"]], tbl_res) | ||
for tbl_res in map(to_tbl_response, responses) | ||
] | ||
table_responses_ordered = sorted( | ||
table_responses_ordered, key=lambda tup: tup[0] | ||
|
@@ -341,6 +384,64 @@ def _write_batch_non_duplicates( | |
if progress: | ||
progress(1) | ||
|
||
def _process_batch_get_response( | ||
self, table_name, response, entity_ids, batch, **sort_kwargs | ||
): | ||
response = response.get("Responses") | ||
table_responses = response.get(table_name) | ||
|
||
batch_result = [] | ||
if table_responses: | ||
table_responses = self._sort_dynamodb_response( | ||
table_responses, entity_ids, **sort_kwargs | ||
) | ||
entity_idx = 0 | ||
for tbl_res in table_responses: | ||
entity_id = tbl_res["entity_id"] | ||
while entity_id != batch[entity_idx]: | ||
batch_result.append((None, None)) | ||
entity_idx += 1 | ||
res = {} | ||
for feature_name, value_bin in tbl_res["values"].items(): | ||
val = ValueProto() | ||
val.ParseFromString(value_bin.value) | ||
res[feature_name] = val | ||
batch_result.append((datetime.fromisoformat(tbl_res["event_ts"]), res)) | ||
entity_idx += 1 | ||
# Not all entities in a batch may have responses | ||
# Pad with remaining values in batch that were not found | ||
batch_size_nones = ((None, None),) * (len(batch) - len(batch_result)) | ||
batch_result.extend(batch_size_nones) | ||
return batch_result | ||
|
||
@staticmethod | ||
def _to_entity_ids(config: RepoConfig, entity_keys: List[EntityKeyProto]): | ||
return [ | ||
compute_entity_id( | ||
entity_key, | ||
entity_key_serialization_version=config.entity_key_serialization_version, | ||
) | ||
for entity_key in entity_keys | ||
] | ||
|
||
@staticmethod | ||
def _to_resource_batch_get_payload(online_config, table_name, batch): | ||
return { | ||
table_name: { | ||
"Keys": [{"entity_id": entity_id} for entity_id in batch], | ||
"ConsistentRead": online_config.consistent_reads, | ||
} | ||
} | ||
|
||
@staticmethod | ||
def _to_client_batch_get_payload(online_config, table_name, batch): | ||
return { | ||
table_name: { | ||
"Keys": [{"entity_id": {"S": entity_id}} for entity_id in batch], | ||
"ConsistentRead": online_config.consistent_reads, | ||
} | ||
} | ||
|
||
|
||
def _initialize_dynamodb_client(region: str, endpoint_url: Optional[str] = None): | ||
return boto3.client( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
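The DynamoDB client returns items in the typed attribute-value form, e.g. `{"entity_id": {"S": "abc"}}`,
as opposed to the resource, which just returns plain Python values, e.g. `{"entity_id": "abc"}`.
boto3 has a deserializer utility (`TypeDeserializer` in `boto3.dynamodb.types`) to do the conversion.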