
Commit
Merge branch 'main' into fix_redshift_pk_with_hyphen
jaidisido authored Jun 30, 2023
2 parents 0a7d286 + ce8a04a commit 1ead3a9
Showing 19 changed files with 712 additions and 110 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -153,9 +153,12 @@ building/lambda/arrow
*.swp

# CDK
node_modules
*package.json
*package-lock.json
*.cdk.staging
*cdk.out
*cdk.context.json

# ruff
.ruff_cache/
2 changes: 2 additions & 0 deletions awswrangler/__init__.py
@@ -11,6 +11,7 @@
athena,
catalog,
chime,
cleanrooms,
cloudwatch,
data_api,
data_quality,
@@ -43,6 +44,7 @@
"athena",
"catalog",
"chime",
"cleanrooms",
"cloudwatch",
"emr",
"emr_serverless",
12 changes: 12 additions & 0 deletions awswrangler/_utils.py
@@ -45,6 +45,7 @@
from boto3.resources.base import ServiceResource
from botocore.client import BaseClient
from mypy_boto3_athena import AthenaClient
from mypy_boto3_cleanrooms import CleanRoomsServiceClient
from mypy_boto3_dynamodb import DynamoDBClient, DynamoDBServiceResource
from mypy_boto3_ec2 import EC2Client
from mypy_boto3_emr.client import EMRClient
@@ -68,6 +69,7 @@

ServiceName = Literal[
"athena",
"cleanrooms",
"dynamodb",
"ec2",
"emr",
@@ -286,6 +288,16 @@ def client(
...


@overload
def client(
    service_name: 'Literal["cleanrooms"]',
    session: Optional[boto3.Session] = None,
    botocore_config: Optional[Config] = None,
    verify: Optional[Union[str, bool]] = None,
) -> "CleanRoomsServiceClient":
    ...


@overload
def client(
    service_name: 'Literal["lakeformation"]',
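These overloads exist purely for type checking: given a literal service name, mypy resolves `_utils.client` to the matching stub type from the corresponding mypy-boto3 package. A minimal sketch of the effect (internal API, illustrative only):

from awswrangler import _utils

# Type checkers now infer CleanRoomsServiceClient (from mypy-boto3-cleanrooms)
# rather than a generic BaseClient, so calls such as start_protected_query are
# checked against the generated stubs.
client = _utils.client(service_name="cleanrooms")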
9 changes: 9 additions & 0 deletions awswrangler/cleanrooms/__init__.py
@@ -0,0 +1,9 @@
"""Amazon Clean Rooms Module."""

from awswrangler.cleanrooms._read import read_sql_query
from awswrangler.cleanrooms._utils import wait_query

__all__ = [
    "read_sql_query",
    "wait_query",
]
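The package `__init__` re-exports the two public entry points, so callers reach them through the `wr.cleanrooms` namespace registered in `awswrangler/__init__.py` above:

import awswrangler as wr

wr.cleanrooms.read_sql_query  # defined in awswrangler/cleanrooms/_read.py
wr.cleanrooms.wait_query      # defined in awswrangler/cleanrooms/_utils.py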
128 changes: 128 additions & 0 deletions awswrangler/cleanrooms/_read.py
@@ -0,0 +1,128 @@
"""Amazon Clean Rooms Module hosting read_* functions."""

import logging
from typing import Any, Dict, Iterator, Optional, Union

import boto3

import awswrangler.pandas as pd
from awswrangler import _utils, s3
from awswrangler._sql_formatter import _process_sql_params
from awswrangler.cleanrooms._utils import wait_query

_logger: logging.Logger = logging.getLogger(__name__)


def _delete_after_iterate(
    dfs: Iterator[pd.DataFrame], keep_files: bool, kwargs: Dict[str, Any]
) -> Iterator[pd.DataFrame]:
    for df in dfs:
        yield df
    if keep_files is False:
        s3.delete_objects(**kwargs)


def read_sql_query(
    sql: str,
    membership_id: str,
    output_bucket: str,
    output_prefix: str,
    keep_files: bool = True,
    params: Optional[Dict[str, Any]] = None,
    chunksize: Optional[Union[int, bool]] = None,
    use_threads: Union[bool, int] = True,
    boto3_session: Optional[boto3.Session] = None,
    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[Iterator[pd.DataFrame], pd.DataFrame]:
"""Execute Clean Rooms Protected SQL query and return the results as a Pandas DataFrame.
Parameters
----------
sql : str
SQL query
membership_id : str
Membership ID
output_bucket : str
S3 output bucket name
output_prefix : str
S3 output prefix
keep_files : bool, optional
Whether files in S3 output bucket/prefix are retained. 'True' by default
params : Dict[str, any], optional
Dict of parameters used for constructing the SQL query. Only named parameters are supported.
The dict must be in the form {'name': 'value'} and the SQL query must contain
`:name`. Note that for varchar columns and similar, you must surround the value in single quotes
chunksize : Union[int, bool], optional
If passed, the data is split into an iterable of DataFrames (Memory friendly).
If `True` an iterable of DataFrames is returned without guarantee of chunksize.
If an `INTEGER` is passed, an iterable of DataFrames is returned with maximum rows
equal to the received INTEGER
use_threads : Union[bool, int], optional
True to enable concurrent requests, False to disable multiple threads.
If enabled os.cpu_count() is used as the maximum number of threads.
If integer is provided, specified number is used
boto3_session : boto3.Session, optional
Boto3 Session. If None, the default boto3 session is used
pyarrow_additional_kwargs : Optional[Dict[str, Any]]
Forwarded to `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
Valid values include "split_blocks", "self_destruct", "ignore_metadata".
e.g. pyarrow_additional_kwargs={'split_blocks': True}
Returns
-------
Union[Iterator[pd.DataFrame], pd.DataFrame]
Pandas DataFrame or Generator of Pandas DataFrames if chunksize is provided.
Examples
--------
>>> import awswrangler as wr
>>> df = wr.cleanrooms.read_sql_query(
>>> sql='SELECT DISTINCT...',
>>> membership_id='membership-id',
>>> output_bucket='output-bucket',
>>> output_prefix='output-prefix',
>>> )
"""
    client_cleanrooms = _utils.client(service_name="cleanrooms", session=boto3_session)

    query_id: str = client_cleanrooms.start_protected_query(
        type="SQL",
        membershipIdentifier=membership_id,
        sqlParameters={"queryString": _process_sql_params(sql, params, engine_type="partiql")},
        resultConfiguration={
            "outputConfiguration": {
                "s3": {
                    "bucket": output_bucket,
                    "keyPrefix": output_prefix,
                    "resultFormat": "PARQUET",
                }
            }
        },
    )["protectedQuery"]["id"]

    _logger.debug("query_id: %s", query_id)
    path: str = wait_query(membership_id=membership_id, query_id=query_id)["protectedQuery"]["result"]["output"]["s3"][
        "location"
    ]

    _logger.debug("path: %s", path)
    chunked: Union[bool, int] = False if chunksize is None else chunksize
    ret = s3.read_parquet(
        path=path,
        use_threads=use_threads,
        chunked=chunked,
        boto3_session=boto3_session,
        pyarrow_additional_kwargs=pyarrow_additional_kwargs,
    )

    _logger.debug("type(ret): %s", type(ret))
    kwargs: Dict[str, Any] = {
        "path": path,
        "use_threads": use_threads,
        "boto3_session": boto3_session,
    }
    if chunked is False:
        if keep_files is False:
            s3.delete_objects(**kwargs)
        return ret
    return _delete_after_iterate(ret, keep_files, kwargs)
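Putting the new reader together, a usage sketch grounded in the docstring above; the membership ID, bucket, prefix, table, and column names are placeholders. Note that varchar values in `params` carry their own single quotes, and that `chunksize` switches the return type to an iterator whose staged Parquet files are removed only after exhaustion when `keep_files=False`:

import awswrangler as wr

# Single DataFrame; ':state' in the SQL is substituted from params.
df = wr.cleanrooms.read_sql_query(
    sql="SELECT city FROM users WHERE state = :state",  # illustrative query
    params={"state": "'WA'"},  # varchar value wrapped in single quotes
    membership_id="membership-id",
    output_bucket="output-bucket",
    output_prefix="output-prefix",
    keep_files=False,  # delete the staged Parquet results afterwards
)

# Iterator of DataFrames of at most 100_000 rows each; with keep_files=False
# the staged files are deleted by _delete_after_iterate once iteration ends.
for chunk in wr.cleanrooms.read_sql_query(
    sql="SELECT * FROM large_table",  # illustrative query
    membership_id="membership-id",
    output_bucket="output-bucket",
    output_prefix="output-prefix",
    chunksize=100_000,
    keep_files=False,
):
    print(chunk.shape)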
60 changes: 60 additions & 0 deletions awswrangler/cleanrooms/_utils.py
@@ -0,0 +1,60 @@
"""Utilities Module for Amazon Clean Rooms."""
import logging
import time
from typing import TYPE_CHECKING, List, Optional

import boto3

from awswrangler import _utils, exceptions

if TYPE_CHECKING:
from mypy_boto3_cleanrooms.type_defs import GetProtectedQueryOutputTypeDef

_QUERY_FINAL_STATES: List[str] = ["CANCELLED", "FAILED", "SUCCESS", "TIMED_OUT"]
_QUERY_WAIT_POLLING_DELAY: float = 2 # SECONDS

_logger: logging.Logger = logging.getLogger(__name__)


def wait_query(
    membership_id: str, query_id: str, boto3_session: Optional[boto3.Session] = None
) -> "GetProtectedQueryOutputTypeDef":
    """Wait for the Clean Rooms protected query to end.

    Parameters
    ----------
    membership_id : str
        Membership ID
    query_id : str
        Protected query execution ID
    boto3_session : boto3.Session, optional
        Boto3 Session. If None, the default boto3 session is used

    Returns
    -------
    Dict[str, Any]
        Dictionary with the get_protected_query response.

    Raises
    ------
    exceptions.QueryFailed
        Raises exception with error message if protected query is cancelled, times out or fails.

    Examples
    --------
    >>> import awswrangler as wr
    >>> res = wr.cleanrooms.wait_query(membership_id='membership-id', query_id='query-id')
    """
    client_cleanrooms = _utils.client(service_name="cleanrooms", session=boto3_session)
    state = "SUBMITTED"

    while state not in _QUERY_FINAL_STATES:
        time.sleep(_QUERY_WAIT_POLLING_DELAY)
        response = client_cleanrooms.get_protected_query(
            membershipIdentifier=membership_id, protectedQueryIdentifier=query_id
        )
        state = response["protectedQuery"].get("status")  # type: ignore[assignment]

    _logger.debug("state: %s", state)
    if state != "SUCCESS":
        raise exceptions.QueryFailed(response["protectedQuery"].get("Error"))
    return response
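`wait_query` also works standalone for queries started outside `read_sql_query`, e.g. directly through boto3. A sketch using the same `start_protected_query` call shape as `_read.py` above (all identifiers are placeholders):

import boto3

import awswrangler as wr

client = boto3.client("cleanrooms")
query_id = client.start_protected_query(
    type="SQL",
    membershipIdentifier="membership-id",
    sqlParameters={"queryString": "SELECT ..."},  # illustrative query
    resultConfiguration={
        "outputConfiguration": {
            "s3": {
                "bucket": "output-bucket",
                "keyPrefix": "output-prefix",
                "resultFormat": "PARQUET",
            }
        }
    },
)["protectedQuery"]["id"]

# Blocks until CANCELLED/FAILED/SUCCESS/TIMED_OUT; raises QueryFailed unless SUCCESS.
res = wr.cleanrooms.wait_query(membership_id="membership-id", query_id=query_id)
print(res["protectedQuery"]["result"]["output"]["s3"]["location"])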
12 changes: 12 additions & 0 deletions docs/source/api.rst
@@ -17,6 +17,7 @@ API Reference
* `Amazon Neptune`_
* `DynamoDB`_
* `Amazon Timestream`_
* `AWS Clean Rooms`_
* `Amazon EMR`_
* `Amazon EMR Serverless`_
* `Amazon CloudWatch Logs`_
@@ -351,6 +352,17 @@ Amazon Timestream
unload_to_files
unload

AWS Clean Rooms
---------------

.. currentmodule:: awswrangler.cleanrooms

.. autosummary::
:toctree: stubs

read_sql_query
wait_query

Amazon EMR
----------

21 changes: 19 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -85,7 +85,7 @@ wheel = "^0.38.1"

# Lint
black = "^23.1.0"
boto3-stubs = {version = "1.26.151", extras = ["athena", "chime", "cloudwatch", "dynamodb", "ec2", "emr", "emr-serverless", "glue", "kms", "lakeformation", "logs", "neptune", "opensearch", "opensearchserverless", "quicksight", "rds", "rds-data", "redshift", "redshift-data", "s3", "secretsmanager", "ssm", "sts", "timestream-query", "timestream-write"]}
boto3-stubs = {version = "^1.26.151", extras = ["athena", "cleanrooms", "chime", "cloudwatch", "dynamodb", "ec2", "emr", "emr-serverless", "glue", "kms", "lakeformation", "logs", "neptune", "opensearch", "opensearchserverless", "quicksight", "rds", "rds-data", "redshift", "redshift-data", "s3", "secretsmanager", "ssm", "sts", "timestream-query", "timestream-write"]}
doc8 = "^1.0"
mypy = "^1.0"
pylint = "^2.17"
7 changes: 7 additions & 0 deletions test_infra/app.py
@@ -3,6 +3,7 @@

from aws_cdk import App, Environment
from stacks.base_stack import BaseStack
from stacks.cleanrooms_stack import CleanRoomsStack
from stacks.databases_stack import DatabasesStack
from stacks.glueray_stack import GlueRayStack
from stacks.opensearch_stack import OpenSearchStack
@@ -42,4 +43,10 @@
**env,
)

CleanRoomsStack(
app,
"aws-sdk-pandas-cleanrooms",
**env,
)

app.synth()
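The `CleanRoomsStack` definition itself (stacks/cleanrooms_stack.py) is among the changed files not rendered in this excerpt; a hypothetical skeleton consistent with the wiring above would be:

from aws_cdk import Stack
from constructs import Construct


class CleanRoomsStack(Stack):
    """Hypothetical skeleton -- the actual stack body is not shown in this diff."""

    def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
        super().__init__(scope, construct_id, **kwargs)
        # Resources backing the Clean Rooms integration tests (e.g. an S3
        # bucket for protected query output) would be defined here.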