fix: enforce use_threads=False when Limit is supplied
jaidisido committed Jun 30, 2023
1 parent ce8a04a commit bf718c6
Showing 2 changed files with 78 additions and 9 deletions.
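
In short: when `max_items_evaluated` is passed, `read_items` now translates it into a DynamoDB `Limit` and forces a single-threaded read so the limit can actually be honored. A hedged usage sketch (the table name is hypothetical; `allow_full_scan` and `max_items_evaluated` are parameters documented in the diff below):

```python
import awswrangler as wr

# Hypothetical table name; any existing DynamoDB table would do.
# Because max_items_evaluated is set, the library now forces
# use_threads=False internally, so a single sequential Scan can
# stop as soon as 5 items have been evaluated.
df = wr.dynamodb.read_items(
    table_name="movies",
    allow_full_scan=True,
    max_items_evaluated=5,
)
assert len(df) <= 5  # Limit counts items evaluated, not items returned
```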
37 changes: 31 additions & 6 deletions awswrangler/dynamodb/_read.py
@@ -194,6 +194,7 @@ def _read_scan_chunked(

deserializer = boto3.dynamodb.types.TypeDeserializer()
next_token = "init_token" # Dummy token
total_items = 0

kwargs = dict(kwargs)
if segment is not None:
@@ -208,8 +209,12 @@
{k: v["B"] if list(v.keys())[0] == "B" else deserializer.deserialize(v) for k, v in d.items()}
for d in response.get("Items", [])
]
total_items += len(items)
yield _utils.list_to_arrow_table(mapping=items) if as_dataframe else items

if ("Limit" in kwargs) and (total_items >= kwargs["Limit"]):
break

next_token = response.get("LastEvaluatedKey", None) # type: ignore[assignment]
if next_token:
kwargs["ExclusiveStartKey"] = next_token
@@ -237,14 +242,22 @@ def _read_query_chunked(
table_name: str, boto3_session: Optional[boto3.Session] = None, **kwargs: Any
) -> Iterator[_ItemsListType]:
table = get_table(table_name=table_name, boto3_session=boto3_session)
response = table.query(**kwargs)
yield response.get("Items", [])
next_token = "init_token" # Dummy token
total_items = 0

# Handle pagination
while "LastEvaluatedKey" in response:
kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]
while next_token:
response = table.query(**kwargs)
yield response.get("Items", [])
items = response.get("Items", [])
total_items += len(items)
yield items

if ("Limit" in kwargs) and (total_items >= kwargs["Limit"]):
break

next_token = response.get("LastEvaluatedKey", None) # type: ignore[assignment]
if next_token:
kwargs["ExclusiveStartKey"] = next_token


@_handle_reserved_keyword_error
@@ -352,9 +365,10 @@ def _read_items(
boto3_session: Optional[boto3.Session] = None,
**kwargs: Any,
) -> Union[pd.DataFrame, Iterator[pd.DataFrame], _ItemsListType, Iterator[_ItemsListType]]:
# Extract 'Keys' and 'IndexName' from provided kwargs: if needed, will be reinserted later on
# Extract 'Keys', 'IndexName' and 'Limit' from provided kwargs: if needed, will be reinserted later on
keys = kwargs.pop("Keys", None)
index = kwargs.pop("IndexName", None)
limit = kwargs.pop("Limit", None)

# Conditionally define optimal reading strategy
use_get_item = (keys is not None) and (len(keys) == 1)
@@ -372,6 +386,11 @@ def _read_items(
items = _read_batch_items(table_name, chunked, boto3_session, **kwargs)

else:
if limit:
kwargs["Limit"] = limit
_logger.debug("`max_items_evaluated` argument detected, setting use_threads to False")
use_threads = False

if index:
kwargs["IndexName"] = index

@@ -438,6 +457,11 @@ def read_items( # pylint: disable=too-many-branches
of the table or index.
See: https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Scan.html#Scan.ParallelScan
Note
----
If `max_items_evaluated` is specified, then `use_threads=False` is enforced. This is because
it's not possible to limit the number of items in a Query/Scan operation across threads.
Parameters
----------
table_name : str
@@ -466,6 +490,7 @@
If True, allow full table scan without any filtering. Defaults to False.
max_items_evaluated : int, optional
Limit the number of items evaluated in case of query or scan operations. Defaults to None (all matching items).
When set, `use_threads` is enforced to False.
dtype_backend: str, optional
Which dtype_backend to use, e.g. whether a DataFrame should have NumPy arrays,
nullable dtypes are used for all dtypes that have a nullable implementation when
"numpy_nullable" is set, pyarrow is used for all dtypes if "pyarrow" is set.
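
Putting the docstring change to work, a hedged Query example (table, index and key values are hypothetical; `key_condition_expression` is the awswrangler spelling of DynamoDB's `KeyConditionExpression`, as used in the tests below):

```python
import awswrangler as wr
from boto3.dynamodb.conditions import Key

# Hypothetical table, index and key values. With max_items_evaluated set,
# use_threads is forced to False and the underlying Query evaluates at
# most 3 items before the paginator stops.
df = wr.dynamodb.read_items(
    table_name="movies",
    key_condition_expression=Key("Category").eq("Suspense"),
    index_name="CategoryIndex",
    max_items_evaluated=3,
)
```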
50 changes: 47 additions & 3 deletions tests/unit/test_dynamodb.py
@@ -274,14 +274,14 @@ def test_read_items_simple(params: Dict[str, Any], dynamodb_table: str, use_thre

df2 = wr.dynamodb.read_items(
table_name=dynamodb_table,
max_items_evaluated=5,
max_items_evaluated=2,
pyarrow_additional_kwargs={"types_mapper": None},
use_threads=use_threads,
chunked=chunked,
)
if chunked:
df2 = pd.concat(df2)
assert df2.shape == df.shape
assert df2.shape == (2, len(df.columns))
assert df2.dtypes.to_list() == df.dtypes.to_list()

df3 = wr.dynamodb.read_items(
@@ -377,11 +377,12 @@ def test_read_items_index(params: Dict[str, Any], dynamodb_table: str, use_threa
table_name=dynamodb_table,
key_condition_expression=Key("Category").eq("Suspense"),
index_name="CategoryIndex",
max_items_evaluated=1,
chunked=chunked,
)
if chunked:
df2 = pd.concat(df2)
assert df2.shape == df.shape
assert df2.shape == (1, len(df.columns))

df3 = wr.dynamodb.read_items(
table_name=dynamodb_table, allow_full_scan=True, index_name="CategoryIndex", use_threads=1, chunked=chunked
@@ -456,3 +457,46 @@ def test_read_items_expression(params: Dict[str, Any], dynamodb_table: str, use_
expression_attribute_values={":v": "Eido"},
)
assert df6.shape == (1, len(df.columns))


@pytest.mark.parametrize(
"params",
[
{
"KeySchema": [{"AttributeName": "id", "KeyType": "HASH"}],
"AttributeDefinitions": [{"AttributeName": "id", "AttributeType": "N"}],
}
],
)
@pytest.mark.parametrize("max_items_evaluated", [1, 3, 5])
@pytest.mark.parametrize("chunked", [False, True])
def test_read_items_limited(
params: Dict[str, Any], dynamodb_table: str, max_items_evaluated: int, chunked: bool
) -> None:
df = pd.DataFrame(
{
"id": [1, 2, 3, 4],
"word": ["this", "is", "a", "test"],
"char_count": [4, 2, 1, 4],
}
)
wr.dynamodb.put_df(df=df, table_name=dynamodb_table)

df2 = wr.dynamodb.read_items(
table_name=dynamodb_table,
filter_expression=Attr("id").eq(1),
max_items_evaluated=max_items_evaluated,
chunked=chunked,
)
if chunked:
df2 = pd.concat(df2)
assert df2.shape == (1, len(df.columns))

df3 = wr.dynamodb.read_items(
table_name=dynamodb_table,
max_items_evaluated=max_items_evaluated,
chunked=chunked,
)
if chunked:
df3 = pd.concat(df3)
assert df3.shape == (min(max_items_evaluated, len(df)), len(df.columns))
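
One semantic detail these tests lean on: DynamoDB's `Limit` bounds the number of items *evaluated*, not returned, so combined with a filter a page can contain fewer matches than the limit. A minimal boto3 sketch of that behavior (table and attribute names hypothetical):

```python
import boto3
from boto3.dynamodb.conditions import Attr

table = boto3.resource("dynamodb").Table("words")  # hypothetical table
response = table.scan(
    FilterExpression=Attr("id").eq(1),
    Limit=3,  # at most 3 items are evaluated; the filter runs afterwards
)
matches = response.get("Items", [])  # anywhere from 0 to 3 items
```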
