
Add coverage for athena.read_sql_* w/o results. #299
igorborgest committed Jul 16, 2020
1 parent 32c7b9f commit 820693d
Showing 4 changed files with 54 additions and 16 deletions.
38 changes: 23 additions & 15 deletions awswrangler/athena/_read.py
@@ -8,6 +8,7 @@
 from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Union
 
 import boto3 # type: ignore
+import botocore.exceptions # type: ignore
 import pandas as pd # type: ignore
 
 from awswrangler import _utils, catalog, exceptions, s3
@@ -233,8 +234,8 @@ def _fetch_csv_result(
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
     _chunksize: Optional[int] = chunksize if isinstance(chunksize, int) else None
     _logger.debug("_chunksize: %s", _chunksize)
-    if query_metadata.output_location is None:
-        return pd.DataFrame() if chunksize is None is False else _utils.empty_generator()
+    if query_metadata.output_location is None or query_metadata.output_location.endswith(".csv") is False:
+        return pd.DataFrame() if _chunksize is None else _utils.empty_generator()
     path: str = query_metadata.output_location
     s3.wait_objects_exist(paths=[path], use_threads=False, boto3_session=boto3_session)
     _logger.debug("Start CSV reading from %s", path)
@@ -253,7 +254,7 @@
     )
     _logger.debug("Start type casting...")
     _logger.debug(type(ret))
-    if chunksize is None:
+    if _chunksize is None:
         df = _fix_csv_types(df=ret, parse_dates=query_metadata.parse_dates, binaries=query_metadata.binaries)
         if keep_files is False:
             s3.delete_objects(path=[path, f"{path}.metadata"], use_threads=use_threads, boto3_session=boto3_session)
@@ -274,7 +275,7 @@ def _resolve_query_with_cache( # pylint: disable=too-many-return-statements
     session: Optional[boto3.Session],
 ):
     """Fetch cached data and return it as a pandas DataFrame (or list of DataFrames)."""
-    _logger.debug("cache_info: %s", cache_info)
+    _logger.debug("cache_info:\n%s", cache_info)
     query_metadata: _QueryMetadata = _get_query_metadata(
         query_execution_id=cache_info.query_execution_id,
         boto3_session=session,
@@ -328,16 +329,23 @@ def _resolve_query_without_cache_ctas(
         f"{sql}"
     )
     _logger.debug("sql: %s", sql)
-    query_id: str = _start_query_execution(
-        sql=sql,
-        wg_config=wg_config,
-        database=database,
-        s3_output=s3_output,
-        workgroup=workgroup,
-        encryption=encryption,
-        kms_key=kms_key,
-        boto3_session=boto3_session,
-    )
+    try:
+        query_id: str = _start_query_execution(
+            sql=sql,
+            wg_config=wg_config,
+            database=database,
+            s3_output=s3_output,
+            workgroup=workgroup,
+            encryption=encryption,
+            kms_key=kms_key,
+            boto3_session=boto3_session,
+        )
+    except botocore.exceptions.ClientError as ex:
+        error: Dict[str, Any] = ex.response['Error']
+        if error['Code'] == 'InvalidRequestException' and "extraneous input" in error['Message']:
+            raise exceptions.InvalidCtasApproachQuery("Is not possible to wrap this query into a CTAS statement. "
+                                                      "Please use ctas_approach=False.")
+        raise ex
     _logger.debug("query_id: %s", query_id)
     try:
         query_metadata: _QueryMetadata = _get_query_metadata(
@@ -596,7 +604,7 @@ def read_sql_query(
         max_cache_seconds=max_cache_seconds,
         max_cache_query_inspections=max_cache_query_inspections,
     )
-    _logger.debug("cache_info: %s", cache_info)
+    _logger.debug("cache_info:\n%s", cache_info)
     if cache_info.has_valid_cache is True:
         _logger.debug("Valid cache found. Retrieving...")
         try:
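With this change, a statement that Athena cannot wrap in a CTAS surfaces as exceptions.InvalidCtasApproachQuery instead of a raw botocore ClientError. A minimal caller-side sketch of how that might be handled (the database, table, and S3 path below are placeholders, not taken from this commit): retry without the CTAS approach, as the error message suggests.

import awswrangler as wr

# Hypothetical DDL statement that returns no result set and cannot be wrapped in a CTAS.
sql = "ALTER TABLE my_db.my_table SET LOCATION 's3://my-bucket/prefix/dir/'"

try:
    df = wr.athena.read_sql_query(sql, database="my_db", ctas_approach=True)
except wr.exceptions.InvalidCtasApproachQuery:
    # Fall back to the regular (non-CTAS) execution path; for statements without
    # results this now returns an empty DataFrame.
    df = wr.athena.read_sql_query(sql, database="my_db", ctas_approach=False)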
4 changes: 3 additions & 1 deletion awswrangler/athena/_utils.py
@@ -164,7 +164,7 @@ def _get_query_metadata(
     if "Statistics" in _query_execution_payload:
         if "DataManifestLocation" in _query_execution_payload["Statistics"]:
             manifest_location = _query_execution_payload["Statistics"]["DataManifestLocation"]
-    return _QueryMetadata(
+    query_metadata: _QueryMetadata = _QueryMetadata(
         execution_id=query_execution_id,
         dtype=dtype,
         parse_timestamps=parse_timestamps,
@@ -174,6 +174,8 @@
         output_location=output_location,
         manifest_location=manifest_location,
     )
+    _logger.debug("query_metadata:\n%s", query_metadata)
+    return query_metadata
 
 
 def get_query_columns_types(query_execution_id: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, str]:
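The new debug line only shows up if logging is configured for the library. A minimal sketch using the standard logging module (not part of this commit):

import logging

logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s %(message)s")
# awswrangler uses module-level loggers under the "awswrangler" namespace.
logging.getLogger("awswrangler").setLevel(logging.DEBUG)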
4 changes: 4 additions & 0 deletions awswrangler/exceptions.py
@@ -83,3 +83,7 @@ class InvalidRedshiftPrimaryKeys(Exception):
 
 class InvalidSchemaConvergence(Exception):
     """InvalidSchemaMerge exception."""
+
+
+class InvalidCtasApproachQuery(Exception):
+    """InvalidCtasApproachQuery exception."""
24 changes: 24 additions & 0 deletions tests/test_athena.py
@@ -691,3 +691,27 @@ def test_catalog_columns(path, glue_table, glue_database):
     ensure_data_types_csv(df2)
 
     assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is True
+
+
+def test_read_sql_query_wo_results(path, glue_database, glue_table):
+    wr.catalog.create_parquet_table(
+        database=glue_database,
+        table=glue_table,
+        path=path,
+        columns_types={"c0": "int"}
+    )
+    sql = f"ALTER TABLE {glue_database}.{glue_table} SET LOCATION '{path}dir/'"
+    df = wr.athena.read_sql_query(sql, database=glue_database, ctas_approach=False)
+    assert df.empty
+
+
+def test_read_sql_query_wo_results_ctas(path, glue_database, glue_table):
+    wr.catalog.create_parquet_table(
+        database=glue_database,
+        table=glue_table,
+        path=path,
+        columns_types={"c0": "int"}
+    )
+    sql = f"ALTER TABLE {glue_database}.{glue_table} SET LOCATION '{path}dir/'"
+    with pytest.raises(wr.exceptions.InvalidCtasApproachQuery):
+        wr.athena.read_sql_query(sql, database=glue_database, ctas_approach=True)
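The chunked path gets the same no-result handling: with chunksize set, _fetch_csv_result returns _utils.empty_generator(), so iteration simply yields nothing. A small sketch under the same placeholder assumptions as above:

import awswrangler as wr

sql = "ALTER TABLE my_db.my_table SET LOCATION 's3://my-bucket/prefix/dir/'"
chunks = wr.athena.read_sql_query(sql, database="my_db", ctas_approach=False, chunksize=1000)

# A statement without results is expected to produce no chunks at all.
assert sum(1 for _ in chunks) == 0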
