Skip to content

Commit

Permalink
feat: Support client-side parameter resolution in athena.create_ctas_…
Browse files Browse the repository at this point in the history
…table
  • Loading branch information
LeonLuttenberger committed May 1, 2024
1 parent eb9386f commit 7ad96bd
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 1 deletion.
3 changes: 2 additions & 1 deletion awswrangler/athena/_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,8 @@ def _resolve_query_without_cache_ctas(
wait=True,
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
boto3_session=boto3_session,
execution_params=execution_params,
params=execution_params,
paramstyle="qmark",
)
fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"'
ctas_query_metadata = cast(_QueryMetadata, ctas_query_info["ctas_query_metadata"])
Expand Down
27 changes: 27 additions & 0 deletions awswrangler/athena/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,8 @@ def create_ctas_table(
wait: bool = False,
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
execution_params: list[str] | None = None,
params: dict[str, Any] | list[str] | None = None,
paramstyle: Literal["qmark", "named"] = "named",
boto3_session: boto3.Session | None = None,
) -> dict[str, str | _QueryMetadata]:
"""Create a new table populated with the results of a SELECT query.
Expand Down Expand Up @@ -701,6 +703,17 @@ def create_ctas_table(
Whether to wait for the query to finish and return a dictionary with the Query metadata.
athena_query_wait_polling_delay: float, default: 0.25 seconds
Interval in seconds for how often the function will check if the Athena query has completed.
execution_params: List[str], optional [DEPRECATED]
A list of values for the parameters that are used in the SQL query.
This parameter is on a deprecation path.
Use ``params`` and ``paramstyle`` instead.
params: Dict[str, Any] | List[str], optional
Dictionary or list of parameters to pass to execute method.
The syntax used to pass parameters depends on the configuration of ``paramstyle``.
paramstyle: str, optional
The syntax style to use for the parameters.
Supported values are ``named`` and ``qmark``.
The default is ``named``.
boto3_session: boto3.Session, optional
Boto3 Session. The default boto3 session is used if boto3_session is None.
Expand Down Expand Up @@ -752,6 +765,20 @@ def create_ctas_table(
if ctas_database is None:
raise exceptions.InvalidArgumentCombination("Either ctas_database or database must be defined.")

# Substitute execution_params with params
if execution_params:
if params:
raise exceptions.InvalidArgumentCombination("`execution_params` and `params` are mutually exclusive.")

params = execution_params
paramstyle = "qmark"
raise DeprecationWarning(
'`execution_params` is being deprecated. Use `params` and `paramstyle="qmark"` instead.'
)

# Substitute query parameters if applicable
sql, execution_params = _apply_formatter(sql, params, paramstyle)

fully_qualified_name = f'"{ctas_database}"."{ctas_table}"'

wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,94 @@ def test_athena_create_ctas(path, glue_table, glue_table2, glue_database, glue_c
ensure_athena_ctas_table(ctas_query_info=ctas_query_info, boto3_session=boto3_session)


def test_athena_create_ctas_with_named_params(path, glue_table, glue_database, glue_ctas_database):
    """CTAS creation completes when the query uses a ``:name`` placeholder filled from a dict.

    The test has no explicit assertions; it passes as long as neither call raises.
    """
    # Write the sample data as a dataset under ``path``, catalogued as
    # ``glue_database``.``glue_table`` so the CTAS query has a source to select from.
    source_data = get_df_list()
    wr.s3.to_parquet(
        df=source_data,
        path=path,
        index=False,
        dataset=True,
        mode="overwrite",
        database=glue_database,
        table=glue_table,
    )

    # ``:par1`` in the SQL is matched against the ``par1`` key of ``params``.
    ctas_sql = f"SELECT * FROM {glue_table} WHERE par1 = :par1"
    named_values = {"par1": "b"}
    wr.athena.create_ctas_table(
        sql=ctas_sql,
        database=glue_database,
        ctas_database=glue_ctas_database,
        params=named_values,
        paramstyle="named",
        wait=True,
    )


def test_athena_create_ctas_with_qmark_params(path, glue_table, glue_database, glue_ctas_database):
    """CTAS creation completes when the query uses a ``?`` placeholder filled from a list.

    The test has no explicit assertions; it passes as long as neither call raises.
    """
    # Write the sample data as a dataset under ``path``, catalogued as
    # ``glue_database``.``glue_table`` so the CTAS query has a source to select from.
    source_data = get_df_list()
    wr.s3.to_parquet(
        df=source_data,
        path=path,
        index=False,
        dataset=True,
        mode="overwrite",
        database=glue_database,
        table=glue_table,
    )

    # Positional style: the single ``?`` is paired with the single list entry.
    ctas_sql = f"SELECT * FROM {glue_table} WHERE par1 = ?"
    positional_values = ["b"]
    wr.athena.create_ctas_table(
        sql=ctas_sql,
        database=glue_database,
        ctas_database=glue_ctas_database,
        params=positional_values,
        paramstyle="qmark",
        wait=True,
    )


def test_athena_create_ctas_with_execution_params_deprecation_warning(
    path, glue_table, glue_database, glue_ctas_database
):
    """Calling ``create_ctas_table`` with the legacy ``execution_params`` raises ``DeprecationWarning``.

    The check uses ``pytest.raises``, i.e. the warning class is expected to
    propagate as a raised exception rather than be emitted as a warning.
    """
    # Write the sample data as a dataset under ``path``, catalogued as
    # ``glue_database``.``glue_table`` before exercising the CTAS call.
    source_data = get_df_list()
    wr.s3.to_parquet(
        df=source_data,
        path=path,
        index=False,
        dataset=True,
        mode="overwrite",
        database=glue_database,
        table=glue_table,
    )

    ctas_sql = f"SELECT * FROM {glue_table} WHERE par1 = ?"
    legacy_values = ["b"]
    # Only ``execution_params`` is supplied; ``params``/``paramstyle`` keep their defaults.
    with pytest.raises(DeprecationWarning):
        wr.athena.create_ctas_table(
            sql=ctas_sql,
            database=glue_database,
            ctas_database=glue_ctas_database,
            execution_params=legacy_values,
            wait=True,
        )


def test_athena_create_ctas_with_params_and_execution_params_error(path, glue_table, glue_database, glue_ctas_database):
    """Supplying both ``execution_params`` and ``params`` raises ``InvalidArgumentCombination``."""
    # Write the sample data as a dataset under ``path``, catalogued as
    # ``glue_database``.``glue_table`` before exercising the CTAS call.
    source_data = get_df_list()
    wr.s3.to_parquet(
        df=source_data,
        path=path,
        index=False,
        dataset=True,
        mode="overwrite",
        database=glue_database,
        table=glue_table,
    )

    ctas_sql = f"SELECT * FROM {glue_table} WHERE par1 = ?"
    expected_error = wr.exceptions.InvalidArgumentCombination
    # The same value list is passed through both the legacy and the new argument.
    with pytest.raises(expected_error):
        wr.athena.create_ctas_table(
            sql=ctas_sql,
            database=glue_database,
            ctas_database=glue_ctas_database,
            execution_params=["b"],
            params=["b"],
            paramstyle="qmark",
            wait=True,
        )


def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1):
wr.s3.to_parquet(
df=get_df(),
Expand Down

0 comments on commit 7ad96bd

Please sign in to comment.