diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 000000000..597be1d05
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,225 @@
+from datetime import datetime
+
+import boto3
+import pytest
+
+import awswrangler as wr
+
+from ._utils import extract_cloudformation_outputs, get_time_str_with_random_suffix, path_generator
+
+
+@pytest.fixture(scope="session")
+def cloudformation_outputs():
+    return extract_cloudformation_outputs()
+
+
+@pytest.fixture(scope="session")
+def region(cloudformation_outputs):
+    return cloudformation_outputs["Region"]
+
+
+@pytest.fixture(scope="session")
+def bucket(cloudformation_outputs):
+    return cloudformation_outputs["BucketName"]
+
+
+@pytest.fixture(scope="session")
+def glue_database(cloudformation_outputs):
+    return cloudformation_outputs["GlueDatabaseName"]
+
+
+@pytest.fixture(scope="session")
+def kms_key(cloudformation_outputs):
+    return cloudformation_outputs["KmsKeyArn"]
+
+
+@pytest.fixture(scope="session")
+def kms_key_id(kms_key):
+    return kms_key.split("/", 1)[1]
+
+
+@pytest.fixture(scope="session")
+def loggroup(cloudformation_outputs):
+    loggroup_name = cloudformation_outputs["LogGroupName"]
+    logstream_name = cloudformation_outputs["LogStream"]
+    client = boto3.client("logs")
+    response = client.describe_log_streams(logGroupName=loggroup_name, logStreamNamePrefix=logstream_name)
+    token = response["logStreams"][0].get("uploadSequenceToken")
+    events = []
+    for i in range(5):
+        events.append({"timestamp": int(1000 * datetime.now().timestamp()), "message": str(i)})
+    args = {"logGroupName": loggroup_name, "logStreamName": logstream_name, "logEvents": events}
+    if token:
+        args["sequenceToken"] = token
+    try:
+        client.put_log_events(**args)
+    except client.exceptions.DataAlreadyAcceptedException:
+        pass  # Concurrency
+    while True:
+        results = wr.cloudwatch.run_query(log_group_names=[loggroup_name], query="fields @timestamp | limit 5")
+        if len(results) >= 5:
+            break
+    yield loggroup_name
+
+
+@pytest.fixture(scope="session")
+def workgroup0(bucket):
+    wkg_name = "aws_data_wrangler_0"
+    client = boto3.client("athena")
+    wkgs = client.list_work_groups()
+    wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
+    if wkg_name not in wkgs:
+        client.create_work_group(
+            Name=wkg_name,
+            Configuration={
+                "ResultConfiguration": {"OutputLocation": f"s3://{bucket}/athena_workgroup0/"},
+                "EnforceWorkGroupConfiguration": True,
+                "PublishCloudWatchMetricsEnabled": True,
+                "BytesScannedCutoffPerQuery": 100_000_000,
+                "RequesterPaysEnabled": False,
+            },
+            Description="AWS Data Wrangler Test WorkGroup Number 0",
+        )
+    return wkg_name
+
+
+@pytest.fixture(scope="session")
+def workgroup1(bucket):
+    wkg_name = "aws_data_wrangler_1"
+    client = boto3.client("athena")
+    wkgs = client.list_work_groups()
+    wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
+    if wkg_name not in wkgs:
+        client.create_work_group(
+            Name=wkg_name,
+            Configuration={
+                "ResultConfiguration": {
+                    "OutputLocation": f"s3://{bucket}/athena_workgroup1/",
+                    "EncryptionConfiguration": {"EncryptionOption": "SSE_S3"},
+                },
+                "EnforceWorkGroupConfiguration": True,
+                "PublishCloudWatchMetricsEnabled": True,
+                "BytesScannedCutoffPerQuery": 100_000_000,
+                "RequesterPaysEnabled": False,
+            },
+            Description="AWS Data Wrangler Test WorkGroup Number 1",
+        )
+    return wkg_name
+
+
+@pytest.fixture(scope="session")
+def workgroup2(bucket, kms_key):
+    wkg_name = "aws_data_wrangler_2"
+    client = boto3.client("athena")
+    wkgs = client.list_work_groups()
+    wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
+    if wkg_name not in wkgs:
+        client.create_work_group(
+            Name=wkg_name,
+            Configuration={
+                "ResultConfiguration": {
+                    "OutputLocation": f"s3://{bucket}/athena_workgroup2/",
+                    "EncryptionConfiguration": {"EncryptionOption": "SSE_KMS", "KmsKey": kms_key},
+                },
+                "EnforceWorkGroupConfiguration": False,
+                "PublishCloudWatchMetricsEnabled": True,
+                "BytesScannedCutoffPerQuery": 100_000_000,
+                "RequesterPaysEnabled": False,
+            },
+            Description="AWS Data Wrangler Test WorkGroup Number 2",
+        )
+    return wkg_name
+
+
+@pytest.fixture(scope="session")
+def workgroup3(bucket, kms_key):
+    wkg_name = "aws_data_wrangler_3"
+    client = boto3.client("athena")
+    wkgs = client.list_work_groups()
+    wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
+    if wkg_name not in wkgs:
+        client.create_work_group(
+            Name=wkg_name,
+            Configuration={
+                "ResultConfiguration": {
+                    "OutputLocation": f"s3://{bucket}/athena_workgroup3/",
+                    "EncryptionConfiguration": {"EncryptionOption": "SSE_KMS", "KmsKey": kms_key},
+                },
+                "EnforceWorkGroupConfiguration": True,
+                "PublishCloudWatchMetricsEnabled": True,
+                "BytesScannedCutoffPerQuery": 100_000_000,
+                "RequesterPaysEnabled": False,
+            },
+            Description="AWS Data Wrangler Test WorkGroup Number 3",
+        )
+    return wkg_name
+
+
+@pytest.fixture(scope="session")
+def databases_parameters(cloudformation_outputs):
+    parameters = dict(postgresql={}, mysql={}, redshift={})
+    parameters["postgresql"]["host"] = cloudformation_outputs["PostgresqlAddress"]
+    parameters["postgresql"]["port"] = 3306
+    parameters["postgresql"]["schema"] = "public"
+    parameters["postgresql"]["database"] = "postgres"
+    parameters["mysql"]["host"] = cloudformation_outputs["MysqlAddress"]
+    parameters["mysql"]["port"] = 3306
+    parameters["mysql"]["schema"] = "test"
+    parameters["mysql"]["database"] = "test"
+    parameters["redshift"]["host"] = cloudformation_outputs["RedshiftAddress"]
+    parameters["redshift"]["port"] = cloudformation_outputs["RedshiftPort"]
+    parameters["redshift"]["identifier"] = cloudformation_outputs["RedshiftIdentifier"]
+    parameters["redshift"]["schema"] = "public"
+    parameters["redshift"]["database"] = "test"
+    parameters["redshift"]["role"] = cloudformation_outputs["RedshiftRole"]
+    parameters["password"] = cloudformation_outputs["DatabasesPassword"]
+    parameters["user"] = "test"
+    return parameters
+
+
+@pytest.fixture(scope="session")
+def redshift_external_schema(cloudformation_outputs, databases_parameters, glue_database):
+    region = cloudformation_outputs.get("Region")
+    sql = f"""
+    CREATE EXTERNAL SCHEMA IF NOT EXISTS aws_data_wrangler_external FROM data catalog
+    DATABASE '{glue_database}'
+    IAM_ROLE '{databases_parameters["redshift"]["role"]}'
+    REGION '{region}';
+    """
+    engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift")
+    with engine.connect() as con:
+        con.execute(sql)
+    return "aws_data_wrangler_external"
+
+
+@pytest.fixture(scope="function")
+def glue_table(glue_database):
+    name = f"tbl_{get_time_str_with_random_suffix()}"
+    print(f"Table name: {name}")
+    wr.catalog.delete_table_if_exists(database=glue_database, table=name)
+    yield name
+    wr.catalog.delete_table_if_exists(database=glue_database, table=name)
+
+
+@pytest.fixture(scope="function")
+def glue_table2(glue_database):
+    name = f"tbl_{get_time_str_with_random_suffix()}"
+    print(f"Table name: {name}")
+    wr.catalog.delete_table_if_exists(database=glue_database, table=name)
+    yield name
+    wr.catalog.delete_table_if_exists(database=glue_database, table=name)
+
+
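+# Hypothetical usage sketch (the test name below is illustrative, not part of this change):
+# tests consume these fixtures simply by listing them as arguments, and pytest resolves
+# them per scope, e.g.
+#
+#     def test_example(glue_database, glue_table, path):
+#         df = pd.DataFrame({"c0": [1, 2]})
+#         wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table)
+#         assert len(wr.athena.read_sql_table(glue_table, glue_database).index) == 2
+#
+# Session-scoped fixtures (bucket, glue_database, workgroups, databases_parameters) are
+# built once per test session, while function-scoped fixtures (glue_table, glue_table2,
+# and the path fixtures below) are recreated for every test.
+
+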
+@pytest.fixture(scope="function")
+def path(bucket):
+    yield from path_generator(bucket)
+
+
+@pytest.fixture(scope="function")
+def path2(bucket):
+    yield from path_generator(bucket)
+
+
+@pytest.fixture(scope="function")
+def path3(bucket):
+    yield from path_generator(bucket)
diff --git a/tests/test_athena.py b/tests/test_athena.py
index 64dce523b..8ca132800 100644
--- a/tests/test_athena.py
+++ b/tests/test_athena.py
@@ -14,7 +14,6 @@
     ensure_data_types,
     ensure_data_types_category,
     ensure_data_types_csv,
-    extract_cloudformation_outputs,
     get_df,
     get_df_cast,
     get_df_category,
@@ -22,7 +21,6 @@
     get_df_list,
     get_query_long,
     get_time_str_with_random_suffix,
-    path_generator,
     ts,
 )
 
@@ -31,153 +29,7 @@
 logging.getLogger("botocore.credentials").setLevel(logging.CRITICAL)
 
 
-@pytest.fixture(scope="module")
-def cloudformation_outputs():
-    yield extract_cloudformation_outputs()
-
-
-@pytest.fixture(scope="module")
-def bucket(cloudformation_outputs):
-    yield cloudformation_outputs["BucketName"]
-
-
-@pytest.fixture(scope="module")
-def database(cloudformation_outputs):
-    yield cloudformation_outputs["GlueDatabaseName"]
-
-
-@pytest.fixture(scope="module")
-def kms_key(cloudformation_outputs):
-    yield cloudformation_outputs["KmsKeyArn"]
-
-
-@pytest.fixture(scope="module")
-def workgroup0(bucket):
-    wkg_name = "aws_data_wrangler_0"
-    client = boto3.client("athena")
-    wkgs = client.list_work_groups()
-    wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
-    if wkg_name not in wkgs:
-        client.create_work_group(
-            Name=wkg_name,
-            Configuration={
-                "ResultConfiguration": {"OutputLocation": f"s3://{bucket}/athena_workgroup0/"},
-                "EnforceWorkGroupConfiguration": True,
-                "PublishCloudWatchMetricsEnabled": True,
-                "BytesScannedCutoffPerQuery": 100_000_000,
-                "RequesterPaysEnabled": False,
-            },
-            Description="AWS Data Wrangler Test WorkGroup Number 0",
-        )
-    yield wkg_name
-
-
-@pytest.fixture(scope="module")
-def workgroup1(bucket):
-    wkg_name = "aws_data_wrangler_1"
-    client = boto3.client("athena")
-    wkgs = client.list_work_groups()
-    wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
-    if wkg_name not in wkgs:
-        client.create_work_group(
-            Name=wkg_name,
-            Configuration={
-                "ResultConfiguration": {
-                    "OutputLocation": f"s3://{bucket}/athena_workgroup1/",
-                    "EncryptionConfiguration": {"EncryptionOption": "SSE_S3"},
-                },
-                "EnforceWorkGroupConfiguration": True,
-                "PublishCloudWatchMetricsEnabled": True,
-                "BytesScannedCutoffPerQuery": 100_000_000,
-                "RequesterPaysEnabled": False,
-            },
-            Description="AWS Data Wrangler Test WorkGroup Number 1",
-        )
-    yield wkg_name
-
-
-@pytest.fixture(scope="module")
-def workgroup2(bucket, kms_key):
-    wkg_name = "aws_data_wrangler_2"
-    client = boto3.client("athena")
-    wkgs = client.list_work_groups()
-    wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
-    if wkg_name not in wkgs:
-        client.create_work_group(
-            Name=wkg_name,
-            Configuration={
-                "ResultConfiguration": {
-                    "OutputLocation": f"s3://{bucket}/athena_workgroup2/",
-                    "EncryptionConfiguration": {"EncryptionOption": "SSE_KMS", "KmsKey": kms_key},
-                },
-                "EnforceWorkGroupConfiguration": False,
-                "PublishCloudWatchMetricsEnabled": True,
-                "BytesScannedCutoffPerQuery": 100_000_000,
-                "RequesterPaysEnabled": False,
-            },
-            Description="AWS Data Wrangler Test WorkGroup Number 2",
-        )
-    yield wkg_name
-
-
-@pytest.fixture(scope="module")
-def workgroup3(bucket, kms_key):
-    wkg_name = "aws_data_wrangler_3"
-    client = boto3.client("athena")
-    wkgs = client.list_work_groups()
-    wkgs = [x["Name"] for x in wkgs["WorkGroups"]]
-    if wkg_name 
not in wkgs: - client.create_work_group( - Name=wkg_name, - Configuration={ - "ResultConfiguration": { - "OutputLocation": f"s3://{bucket}/athena_workgroup3/", - "EncryptionConfiguration": {"EncryptionOption": "SSE_KMS", "KmsKey": kms_key}, - }, - "EnforceWorkGroupConfiguration": True, - "PublishCloudWatchMetricsEnabled": True, - "BytesScannedCutoffPerQuery": 100_000_000, - "RequesterPaysEnabled": False, - }, - Description="AWS Data Wrangler Test WorkGroup Number 3", - ) - yield wkg_name - - -@pytest.fixture(scope="function") -def table(database): - name = f"tbl_{get_time_str_with_random_suffix()}" - print(f"Table name: {name}") - wr.catalog.delete_table_if_exists(database=database, table=name) - yield name - wr.catalog.delete_table_if_exists(database=database, table=name) - - -@pytest.fixture(scope="function") -def table2(database): - name = f"tbl_{get_time_str_with_random_suffix()}" - print(f"Table name: {name}") - wr.catalog.delete_table_if_exists(database=database, table=name) - yield name - wr.catalog.delete_table_if_exists(database=database, table=name) - - -@pytest.fixture(scope="function") -def path(bucket): - yield from path_generator(bucket) - - -@pytest.fixture(scope="function") -def path2(bucket): - yield from path_generator(bucket) - - -@pytest.fixture(scope="function") -def path3(bucket): - yield from path_generator(bucket) - - -def test_to_parquet_modes(database, table, path): +def test_to_parquet_modes(glue_database, glue_table, path): # Round 1 - Warm up df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") @@ -186,22 +38,22 @@ def test_to_parquet_modes(database, table, path): path=path, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c0", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c0": "0"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(database, table) == "c0" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c0" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c0"] == "0" @@ -212,22 +64,22 @@ def test_to_parquet_modes(database, table, path): path=path, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c1": "1"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == 
str(len(df2.index)) - assert wr.catalog.get_table_description(database, table) == "c1" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c1" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" @@ -238,23 +90,23 @@ def test_to_parquet_modes(database, table, path): path=path, dataset=True, mode="append", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index) * 2)}, columns_comments={"c1": "1"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert len(df.columns) == len(df2.columns) assert len(df.index) * 2 == len(df2.index) assert df.c1.sum() + 1 == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(database, table) == "c1" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c1" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" @@ -265,23 +117,23 @@ def test_to_parquet_modes(database, table, path): path=path, dataset=True, mode="append", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c1+c2", parameters={"num_cols": "2", "num_rows": "9"}, columns_comments={"c1": "1", "c2": "2"}, )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert len(df2.columns) == 2 assert len(df2.index) == 9 assert df2.c1.sum() == 3 - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "9" - assert wr.catalog.get_table_description(database, table) == "c1+c2" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c1+c2" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" assert comments["c2"] == "2" @@ -293,23 +145,23 @@ def test_to_parquet_modes(database, table, path): path=path, dataset=True, mode="append", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c1+c2+c3", parameters={"num_cols": "3", "num_rows": "10"}, columns_comments={"c1": "1!", "c2": "2!", "c3": "3"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert len(df2.columns) == 3 assert len(df2.index) == 10 assert df2.c1.sum() == 4 - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert 
parameters["num_cols"] == "3" assert parameters["num_rows"] == "10" - assert wr.catalog.get_table_description(database, table) == "c1+c2+c3" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c1+c2+c3" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c1"] == "1!" assert comments["c2"] == "2!" @@ -322,23 +174,23 @@ def test_to_parquet_modes(database, table, path): path=path, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, partition_cols=["c1"], description="c0+c1", parameters={"num_cols": "2", "num_rows": "2"}, columns_comments={"c0": "zero", "c1": "one"}, )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "2" - assert wr.catalog.get_table_description(database, table) == "c0+c1" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" @@ -350,24 +202,24 @@ def test_to_parquet_modes(database, table, path): path=path, dataset=True, mode="overwrite_partitions", - database=database, - table=table, + database=glue_database, + table=glue_table, partition_cols=["c1"], description="c0+c1", parameters={"num_cols": "2", "num_rows": "3"}, columns_comments={"c0": "zero", "c1": "one"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert len(df2.columns) == 2 assert len(df2.index) == 3 assert df2.c1.sum() == 3 - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "3" - assert wr.catalog.get_table_description(database, table) == "c0+c1" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" @@ -379,31 +231,31 @@ def test_to_parquet_modes(database, table, path): path=path, dataset=True, mode="overwrite_partitions", - database=database, - table=table, + database=glue_database, + table=glue_table, partition_cols=["c1"], description="c0+c1+c2", parameters={"num_cols": "3", "num_rows": "4"}, columns_comments={"c0": "zero", "c1": "one", "c2": "two"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert len(df2.columns) == 3 assert len(df2.index) == 4 assert df2.c1.sum() == 6 - parameters = wr.catalog.get_table_parameters(database, 
table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == "3" assert parameters["num_rows"] == "4" - assert wr.catalog.get_table_description(database, table) == "c0+c1+c2" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1+c2" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" assert comments["c2"] == "two" -def test_store_parquet_metadata_modes(database, table, path): +def test_store_parquet_metadata_modes(glue_database, glue_table, path): # Round 1 - Warm up df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") @@ -413,21 +265,21 @@ def test_store_parquet_metadata_modes(database, table, path): path=path, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c0", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c0": "0"}, ) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(database, table) == "c0" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c0" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c0"] == "0" @@ -439,21 +291,21 @@ def test_store_parquet_metadata_modes(database, table, path): path=path, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c1": "1"}, ) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(database, table) == "c1" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c1" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" @@ -465,22 +317,22 @@ def test_store_parquet_metadata_modes(database, table, path): path=path, dataset=True, mode="append", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c1", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index) * 2)}, columns_comments={"c1": "1"}, ) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert len(df.columns) 
== len(df2.columns) assert len(df.index) * 2 == len(df2.index) assert df.c1.sum() + 1 == df2.c1.sum() - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == str(len(df2.columns)) assert parameters["num_rows"] == str(len(df2.index)) - assert wr.catalog.get_table_description(database, table) == "c1" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c1" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" @@ -493,22 +345,22 @@ def test_store_parquet_metadata_modes(database, table, path): path=path, dataset=True, mode="append", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c1+c2", parameters={"num_cols": "2", "num_rows": "9"}, columns_comments={"c1": "1", "c2": "2"}, ) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert len(df2.columns) == 2 assert len(df2.index) == 9 assert df2.c1.sum() == 4 - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "9" - assert wr.catalog.get_table_description(database, table) == "c1+c2" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c1+c2" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c1"] == "1" assert comments["c2"] == "2" @@ -521,21 +373,21 @@ def test_store_parquet_metadata_modes(database, table, path): path=path, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c0+c1", parameters={"num_cols": "2", "num_rows": "2"}, columns_comments={"c0": "zero", "c1": "one"}, ) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert df.shape == df2.shape assert df.c1.sum() == df2.c1.astype(int).sum() - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "2" - assert wr.catalog.get_table_description(database, table) == "c0+c1" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" @@ -550,22 +402,22 @@ def test_store_parquet_metadata_modes(database, table, path): path=path, dataset=True, mode="append", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c0+c1", parameters={"num_cols": "2", "num_rows": "3"}, columns_comments={"c0": "zero", "c1": "one"}, ) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert len(df2.columns) == 2 assert len(df2.index) == 3 assert df2.c1.astype(int).sum() == 3 - parameters = 
wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == "2" assert parameters["num_rows"] == "3" - assert wr.catalog.get_table_description(database, table) == "c0+c1" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" @@ -580,29 +432,29 @@ def test_store_parquet_metadata_modes(database, table, path): path=path, dataset=True, mode="append", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c0+c1+c2", parameters={"num_cols": "3", "num_rows": "4"}, columns_comments={"c0": "zero", "c1": "one", "c2": "two"}, ) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert len(df2.columns) == 3 assert len(df2.index) == 4 assert df2.c1.astype(int).sum() == 6 - parameters = wr.catalog.get_table_parameters(database, table) + parameters = wr.catalog.get_table_parameters(glue_database, glue_table) assert len(parameters) >= 5 assert parameters["num_cols"] == "3" assert parameters["num_rows"] == "4" - assert wr.catalog.get_table_description(database, table) == "c0+c1+c2" - comments = wr.catalog.get_columns_comments(database, table) + assert wr.catalog.get_table_description(glue_database, glue_table) == "c0+c1+c2" + comments = wr.catalog.get_columns_comments(glue_database, glue_table) assert len(comments) == len(df.columns) assert comments["c0"] == "zero" assert comments["c1"] == "one" assert comments["c2"] == "two" -def test_athena_ctas(path, path2, path3, table, table2, database, kms_key): +def test_athena_ctas(path, path2, path3, glue_table, glue_table2, glue_database, kms_key): df = get_df_list() columns_types, partitions_types = wr.catalog.extract_athena_types(df=df, partition_cols=["par0", "par1"]) assert len(columns_types) == 17 @@ -616,20 +468,20 @@ def test_athena_ctas(path, path2, path3, table, table2, database, kms_key): use_threads=True, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, partition_cols=["par0", "par1"], )["paths"] wr.s3.wait_objects_exist(paths=paths) dirs = wr.s3.list_directories(path=path) for d in dirs: assert d.startswith(f"{path}par0=") - df = wr.s3.read_parquet_table(table=table, database=database) + df = wr.s3.read_parquet_table(table=glue_table, database=glue_database) assert len(df.index) == 3 ensure_data_types(df=df, has_list=True) df = wr.athena.read_sql_table( - table=table, - database=database, + table=glue_table, + database=glue_database, ctas_approach=True, encryption="SSE_KMS", kms_key=kms_key, @@ -638,20 +490,20 @@ def test_athena_ctas(path, path2, path3, table, table2, database, kms_key): ) assert len(df.index) == 3 ensure_data_types(df=df, has_list=True) - final_destination = f"{path3}{table2}/" + final_destination = f"{path3}{glue_table2}/" # keep_files=False wr.s3.delete_objects(path=path3) dfs = wr.athena.read_sql_query( - sql=f"SELECT * FROM {table}", - database=database, + sql=f"SELECT * FROM {glue_table}", + database=glue_database, ctas_approach=True, chunksize=1, keep_files=False, - ctas_temp_table_name=table2, + ctas_temp_table_name=glue_table2, s3_output=path3, ) - assert 
wr.catalog.does_table_exist(database=database, table=table2) is False + assert wr.catalog.does_table_exist(database=glue_database, table=glue_table2) is False assert len(wr.s3.list_objects(path=path3)) > 2 assert len(wr.s3.list_objects(path=final_destination)) > 0 for df in dfs: @@ -661,15 +513,15 @@ def test_athena_ctas(path, path2, path3, table, table2, database, kms_key): # keep_files=True wr.s3.delete_objects(path=path3) dfs = wr.athena.read_sql_query( - sql=f"SELECT * FROM {table}", - database=database, + sql=f"SELECT * FROM {glue_table}", + database=glue_database, ctas_approach=True, chunksize=2, keep_files=True, - ctas_temp_table_name=table2, + ctas_temp_table_name=glue_table2, s3_output=path3, ) - assert wr.catalog.does_table_exist(database=database, table=table2) is False + assert wr.catalog.does_table_exist(database=glue_database, table=glue_table2) is False assert len(wr.s3.list_objects(path=path3)) > 2 assert len(wr.s3.list_objects(path=final_destination)) > 0 for df in dfs: @@ -677,8 +529,8 @@ def test_athena_ctas(path, path2, path3, table, table2, database, kms_key): assert len(wr.s3.list_objects(path=path3)) > 2 -def test_athena(path, database, kms_key, workgroup0, workgroup1): - wr.catalog.delete_table_if_exists(database=database, table="__test_athena") +def test_athena(path, glue_database, kms_key, workgroup0, workgroup1): + wr.catalog.delete_table_if_exists(database=glue_database, table="__test_athena") paths = wr.s3.to_parquet( df=get_df(), path=path, @@ -686,14 +538,14 @@ def test_athena(path, database, kms_key, workgroup0, workgroup1): use_threads=True, dataset=True, mode="overwrite", - database=database, + database=glue_database, table="__test_athena", partition_cols=["par0", "par1"], )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) dfs = wr.athena.read_sql_query( sql="SELECT * FROM __test_athena", - database=database, + database=glue_database, ctas_approach=False, chunksize=1, encryption="SSE_KMS", @@ -705,24 +557,24 @@ def test_athena(path, database, kms_key, workgroup0, workgroup1): ensure_data_types(df=df2) df = wr.athena.read_sql_query( sql="SELECT * FROM __test_athena", - database=database, + database=glue_database, ctas_approach=False, workgroup=workgroup1, keep_files=False, ) assert len(df.index) == 3 ensure_data_types(df=df) - wr.athena.repair_table(table="__test_athena", database=database) - wr.catalog.delete_table_if_exists(database=database, table="__test_athena") + wr.athena.repair_table(table="__test_athena", database=glue_database) + wr.catalog.delete_table_if_exists(database=glue_database, table="__test_athena") -def test_parquet_catalog(bucket, database): +def test_parquet_catalog(bucket, glue_database): with pytest.raises(wr.exceptions.UndetectedType): wr.s3.to_parquet( df=pd.DataFrame({"A": [None]}), path=f"s3://{bucket}/test_parquet_catalog", dataset=True, - database=database, + database=glue_database, table="test_parquet_catalog", ) df = get_df_list() @@ -733,7 +585,7 @@ def test_parquet_catalog(bucket, database): use_threads=True, dataset=False, mode="overwrite", - database=database, + database=glue_database, table="test_parquet_catalog", ) with pytest.raises(wr.exceptions.InvalidArgumentCombination): @@ -751,7 +603,7 @@ def test_parquet_catalog(bucket, database): use_threads=True, dataset=True, mode="overwrite", - database=database, + database=glue_database, ) wr.s3.to_parquet( df=df, @@ -759,7 +611,7 @@ def test_parquet_catalog(bucket, database): use_threads=True, dataset=True, mode="overwrite", - database=database, + 
database=glue_database, table="test_parquet_catalog", ) wr.s3.to_parquet( @@ -769,7 +621,7 @@ def test_parquet_catalog(bucket, database): use_threads=True, dataset=True, mode="overwrite", - database=database, + database=glue_database, table="test_parquet_catalog2", partition_cols=["iint8", "iint16"], ) @@ -779,32 +631,34 @@ def test_parquet_catalog(bucket, database): assert len(columns_types) == 18 assert len(partitions_types) == 2 columns_types, partitions_types, partitions_values = wr.s3.store_parquet_metadata( - path=f"s3://{bucket}/test_parquet_catalog2", database=database, table="test_parquet_catalog2", dataset=True + path=f"s3://{bucket}/test_parquet_catalog2", database=glue_database, table="test_parquet_catalog2", dataset=True ) assert len(columns_types) == 18 assert len(partitions_types) == 2 assert len(partitions_values) == 2 wr.s3.delete_objects(path=f"s3://{bucket}/test_parquet_catalog/") wr.s3.delete_objects(path=f"s3://{bucket}/test_parquet_catalog2/") - assert wr.catalog.delete_table_if_exists(database=database, table="test_parquet_catalog") is True - assert wr.catalog.delete_table_if_exists(database=database, table="test_parquet_catalog2") is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table="test_parquet_catalog") is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table="test_parquet_catalog2") is True -def test_parquet_catalog_duplicated(path, table, database): +def test_parquet_catalog_duplicated(path, glue_table, glue_database): df = pd.DataFrame({"A": [1], "a": [1]}) - wr.s3.to_parquet(df=df, path=path, index=False, dataset=True, mode="overwrite", database=database, table=table) + wr.s3.to_parquet( + df=df, path=path, index=False, dataset=True, mode="overwrite", database=glue_database, table=glue_table + ) df = wr.s3.read_parquet(path=path) assert df.shape == (1, 1) -def test_parquet_catalog_casting(path, database): +def test_parquet_catalog_casting(path, glue_database): paths = wr.s3.to_parquet( df=get_df_cast(), path=path, index=False, dataset=True, mode="overwrite", - database=database, + database=glue_database, table="__test_parquet_catalog_casting", dtype={ "iint8": "tinyint", @@ -828,22 +682,22 @@ def test_parquet_catalog_casting(path, database): df = wr.s3.read_parquet(path=path) assert df.shape == (3, 16) ensure_data_types(df=df, has_list=False) - df = wr.athena.read_sql_table(table="__test_parquet_catalog_casting", database=database, ctas_approach=True) + df = wr.athena.read_sql_table(table="__test_parquet_catalog_casting", database=glue_database, ctas_approach=True) assert df.shape == (3, 16) ensure_data_types(df=df, has_list=False) - df = wr.athena.read_sql_table(table="__test_parquet_catalog_casting", database=database, ctas_approach=False) + df = wr.athena.read_sql_table(table="__test_parquet_catalog_casting", database=glue_database, ctas_approach=False) assert df.shape == (3, 16) ensure_data_types(df=df, has_list=False) wr.s3.delete_objects(path=path) - assert wr.catalog.delete_table_if_exists(database=database, table="__test_parquet_catalog_casting") is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table="__test_parquet_catalog_casting") is True -def test_catalog(path, database, table): +def test_catalog(path, glue_database, glue_table): account_id = boto3.client("sts").get_caller_identity().get("Account") - assert wr.catalog.does_table_exist(database=database, table=table) is False + assert wr.catalog.does_table_exist(database=glue_database, table=glue_table) is False 
wr.catalog.create_parquet_table( - database=database, - table=table, + database=glue_database, + table=glue_table, path=path, columns_types={"col0": "int", "col1": "double"}, partitions_types={"y": "int", "m": "int"}, @@ -851,14 +705,14 @@ def test_catalog(path, database, table): ) with pytest.raises(wr.exceptions.InvalidArgumentValue): wr.catalog.create_parquet_table( - database=database, table=table, path=path, columns_types={"col0": "string"}, mode="append" + database=glue_database, table=glue_table, path=path, columns_types={"col0": "string"}, mode="append" ) - assert wr.catalog.does_table_exist(database=database, table=table) is True - assert wr.catalog.delete_table_if_exists(database=database, table=table) is True - assert wr.catalog.delete_table_if_exists(database=database, table=table) is False + assert wr.catalog.does_table_exist(database=glue_database, table=glue_table) is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is False wr.catalog.create_parquet_table( - database=database, - table=table, + database=glue_database, + table=glue_table, path=path, columns_types={"col0": "int", "col1": "double"}, partitions_types={"y": "int", "m": "int"}, @@ -869,72 +723,75 @@ def test_catalog(path, database, table): mode="overwrite", ) wr.catalog.add_parquet_partitions( - database=database, - table=table, + database=glue_database, + table=glue_table, partitions_values={f"{path}y=2020/m=1/": ["2020", "1"], f"{path}y=2021/m=2/": ["2021", "2"]}, compression="snappy", ) - assert wr.catalog.get_table_location(database=database, table=table) == path - partitions_values = wr.catalog.get_parquet_partitions(database=database, table=table) + assert wr.catalog.get_table_location(database=glue_database, table=glue_table) == path + partitions_values = wr.catalog.get_parquet_partitions(database=glue_database, table=glue_table) assert len(partitions_values) == 2 partitions_values = wr.catalog.get_parquet_partitions( - database=database, table=table, catalog_id=account_id, expression="y = 2021 AND m = 2" + database=glue_database, table=glue_table, catalog_id=account_id, expression="y = 2021 AND m = 2" ) assert len(partitions_values) == 1 assert len(set(partitions_values[f"{path}y=2021/m=2/"]) & {"2021", "2"}) == 2 - dtypes = wr.catalog.get_table_types(database=database, table=table) + dtypes = wr.catalog.get_table_types(database=glue_database, table=glue_table) assert dtypes["col0"] == "int" assert dtypes["col1"] == "double" assert dtypes["y"] == "int" assert dtypes["m"] == "int" df_dbs = wr.catalog.databases() assert len(wr.catalog.databases(catalog_id=account_id)) == len(df_dbs) - assert database in df_dbs["Database"].to_list() + assert glue_database in df_dbs["Database"].to_list() tables = list(wr.catalog.get_tables()) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == table: + if tbl["Name"] == glue_table: assert tbl["TableType"] == "EXTERNAL_TABLE" - tables = list(wr.catalog.get_tables(database=database)) + tables = list(wr.catalog.get_tables(database=glue_database)) assert len(tables) > 0 for tbl in tables: - assert tbl["DatabaseName"] == database + assert tbl["DatabaseName"] == glue_database # search tables = list(wr.catalog.search_tables(text="parquet", catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == table: + if tbl["Name"] == glue_table: assert tbl["TableType"] == "EXTERNAL_TABLE" # prefix - tables = 
list(wr.catalog.get_tables(name_prefix=table[:4], catalog_id=account_id)) + tables = list(wr.catalog.get_tables(name_prefix=glue_table[:4], catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == table: + if tbl["Name"] == glue_table: assert tbl["TableType"] == "EXTERNAL_TABLE" # suffix - tables = list(wr.catalog.get_tables(name_suffix=table[-4:], catalog_id=account_id)) + tables = list(wr.catalog.get_tables(name_suffix=glue_table[-4:], catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == table: + if tbl["Name"] == glue_table: assert tbl["TableType"] == "EXTERNAL_TABLE" # name_contains - tables = list(wr.catalog.get_tables(name_contains=table[4:-4], catalog_id=account_id)) + tables = list(wr.catalog.get_tables(name_contains=glue_table[4:-4], catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == table: + if tbl["Name"] == glue_table: assert tbl["TableType"] == "EXTERNAL_TABLE" # prefix & suffix & name_contains with pytest.raises(wr.exceptions.InvalidArgumentCombination): list( wr.catalog.get_tables( - name_prefix=table[0], name_contains=table[3], name_suffix=table[-1], catalog_id=account_id + name_prefix=glue_table[0], + name_contains=glue_table[3], + name_suffix=glue_table[-1], + catalog_id=account_id, ) ) # prefix & suffix - tables = list(wr.catalog.get_tables(name_prefix=table[0], name_suffix=table[-1], catalog_id=account_id)) + tables = list(wr.catalog.get_tables(name_prefix=glue_table[0], name_suffix=glue_table[-1], catalog_id=account_id)) assert len(tables) > 0 for tbl in tables: - if tbl["Name"] == table: + if tbl["Name"] == glue_table: assert tbl["TableType"] == "EXTERNAL_TABLE" # DataFrames assert len(wr.catalog.databases().index) > 0 @@ -942,47 +799,49 @@ def test_catalog(path, database, table): assert ( len( wr.catalog.tables( - database=database, + database=glue_database, search_text="parquet", - name_prefix=table[0], - name_contains=table[3], - name_suffix=table[-1], + name_prefix=glue_table[0], + name_contains=glue_table[3], + name_suffix=glue_table[-1], catalog_id=account_id, ).index ) > 0 ) - assert len(wr.catalog.table(database=database, table=table).index) > 0 - assert len(wr.catalog.table(database=database, table=table, catalog_id=account_id).index) > 0 + assert len(wr.catalog.table(database=glue_database, table=glue_table).index) > 0 + assert len(wr.catalog.table(database=glue_database, table=glue_table, catalog_id=account_id).index) > 0 with pytest.raises(wr.exceptions.InvalidTable): - wr.catalog.overwrite_table_parameters({"foo": "boo"}, database, "fake_table") + wr.catalog.overwrite_table_parameters({"foo": "boo"}, glue_database, "fake_table") -def test_catalog_get_databases(database): +def test_catalog_get_databases(glue_database): dbs = list(wr.catalog.get_databases()) assert len(dbs) > 0 for db in dbs: - if db["Name"] == database: + if db["Name"] == glue_database: assert db["Description"] == "AWS Data Wrangler Test Arena - Glue Database" -def test_athena_query_cancelled(database): +def test_athena_query_cancelled(glue_database): session = boto3.DEFAULT_SESSION - query_execution_id = wr.athena.start_query_execution(sql=get_query_long(), database=database, boto3_session=session) + query_execution_id = wr.athena.start_query_execution( + sql=get_query_long(), database=glue_database, boto3_session=session + ) wr.athena.stop_query_execution(query_execution_id=query_execution_id, boto3_session=session) with pytest.raises(wr.exceptions.QueryCancelled): assert 
wr.athena.wait_query(query_execution_id=query_execution_id) -def test_athena_query_failed(database): - query_execution_id = wr.athena.start_query_execution(sql="SELECT random(-1)", database=database) +def test_athena_query_failed(glue_database): + query_execution_id = wr.athena.start_query_execution(sql="SELECT random(-1)", database=glue_database) with pytest.raises(wr.exceptions.QueryFailed): assert wr.athena.wait_query(query_execution_id=query_execution_id) -def test_athena_read_list(database): +def test_athena_read_list(glue_database): with pytest.raises(wr.exceptions.UnsupportedType): - wr.athena.read_sql_query(sql="SELECT ARRAY[1, 2, 3]", database=database, ctas_approach=False) + wr.athena.read_sql_query(sql="SELECT ARRAY[1, 2, 3]", database=glue_database, ctas_approach=False) def test_sanitize_names(): @@ -1006,7 +865,7 @@ def test_sanitize_names(): assert wr.catalog.sanitize_table_name("xyz_Cd") == "xyz_cd" -def test_athena_ctas_empty(database): +def test_athena_ctas_empty(glue_database): sql = """ WITH dataset AS ( SELECT 0 AS id @@ -1015,21 +874,21 @@ def test_athena_ctas_empty(database): FROM dataset WHERE id != 0 """ - assert wr.athena.read_sql_query(sql=sql, database=database).empty is True - assert len(list(wr.athena.read_sql_query(sql=sql, database=database, chunksize=1))) == 0 + assert wr.athena.read_sql_query(sql=sql, database=glue_database).empty is True + assert len(list(wr.athena.read_sql_query(sql=sql, database=glue_database, chunksize=1))) == 0 -def test_athena_struct(database): +def test_athena_struct(glue_database): sql = "SELECT CAST(ROW(1, 'foo') AS ROW(id BIGINT, value VARCHAR)) AS col0" with pytest.raises(wr.exceptions.UnsupportedType): - wr.athena.read_sql_query(sql=sql, database=database, ctas_approach=False) - df = wr.athena.read_sql_query(sql=sql, database=database, ctas_approach=True) + wr.athena.read_sql_query(sql=sql, database=glue_database, ctas_approach=False) + df = wr.athena.read_sql_query(sql=sql, database=glue_database, ctas_approach=True) assert len(df.index) == 1 assert len(df.columns) == 1 assert df["col0"].iloc[0]["id"] == 1 assert df["col0"].iloc[0]["value"] == "foo" sql = "SELECT ROW(1, ROW(2, ROW(3, '4'))) AS col0" - df = wr.athena.read_sql_query(sql=sql, database=database, ctas_approach=True) + df = wr.athena.read_sql_query(sql=sql, database=glue_database, ctas_approach=True) assert len(df.index) == 1 assert len(df.columns) == 1 assert df["col0"].iloc[0]["field0"] == 1 @@ -1038,23 +897,23 @@ def test_athena_struct(database): assert df["col0"].iloc[0]["field1"]["field1"]["field1"] == "4" -def test_athena_time_zone(database): +def test_athena_time_zone(glue_database): sql = "SELECT current_timestamp AS value, typeof(current_timestamp) AS type" - df = wr.athena.read_sql_query(sql=sql, database=database, ctas_approach=False) + df = wr.athena.read_sql_query(sql=sql, database=glue_database, ctas_approach=False) assert len(df.index) == 1 assert len(df.columns) == 2 assert df["type"][0] == "timestamp with time zone" assert df["value"][0].year == datetime.datetime.utcnow().year -def test_category(bucket, database): +def test_category(bucket, glue_database): df = get_df_category() path = f"s3://{bucket}/test_category/" paths = wr.s3.to_parquet( df=df, path=path, dataset=True, - database=database, + database=glue_database, table="test_category", mode="overwrite", partition_cols=["par0", "par1"], @@ -1062,38 +921,46 @@ def test_category(bucket, database): wr.s3.wait_objects_exist(paths=paths, use_threads=False) df2 = wr.s3.read_parquet(path=path, 
dataset=True, categories=[c for c in df.columns if c not in ["par0", "par1"]]) ensure_data_types_category(df2) - df2 = wr.athena.read_sql_query("SELECT * FROM test_category", database=database, categories=list(df.columns)) + df2 = wr.athena.read_sql_query("SELECT * FROM test_category", database=glue_database, categories=list(df.columns)) ensure_data_types_category(df2) - df2 = wr.athena.read_sql_table(table="test_category", database=database, categories=list(df.columns)) + df2 = wr.athena.read_sql_table(table="test_category", database=glue_database, categories=list(df.columns)) ensure_data_types_category(df2) df2 = wr.athena.read_sql_query( - "SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=False + "SELECT * FROM test_category", database=glue_database, categories=list(df.columns), ctas_approach=False ) ensure_data_types_category(df2) dfs = wr.athena.read_sql_query( - "SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=False, chunksize=1 + "SELECT * FROM test_category", + database=glue_database, + categories=list(df.columns), + ctas_approach=False, + chunksize=1, ) for df2 in dfs: ensure_data_types_category(df2) dfs = wr.athena.read_sql_query( - "SELECT * FROM test_category", database=database, categories=list(df.columns), ctas_approach=True, chunksize=1 + "SELECT * FROM test_category", + database=glue_database, + categories=list(df.columns), + ctas_approach=True, + chunksize=1, ) for df2 in dfs: ensure_data_types_category(df2) wr.s3.delete_objects(path=paths) - assert wr.catalog.delete_table_if_exists(database=database, table="test_category") is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table="test_category") is True -def test_csv_dataset(path, database): +def test_csv_dataset(path, glue_database): with pytest.raises(wr.exceptions.UndetectedType): - wr.s3.to_csv(pd.DataFrame({"A": [None]}), path, dataset=True, database=database, table="test_csv_dataset") + wr.s3.to_csv(pd.DataFrame({"A": [None]}), path, dataset=True, database=glue_database, table="test_csv_dataset") df = get_df_csv() with pytest.raises(wr.exceptions.InvalidArgumentCombination): - wr.s3.to_csv(df, path, dataset=False, mode="overwrite", database=database, table="test_csv_dataset") + wr.s3.to_csv(df, path, dataset=False, mode="overwrite", database=glue_database, table="test_csv_dataset") with pytest.raises(wr.exceptions.InvalidArgumentCombination): wr.s3.to_csv(df, path, dataset=False, table="test_csv_dataset") with pytest.raises(wr.exceptions.InvalidArgumentCombination): - wr.s3.to_csv(df, path, dataset=True, mode="overwrite", database=database) + wr.s3.to_csv(df, path, dataset=True, mode="overwrite", database=glue_database) with pytest.raises(wr.exceptions.InvalidArgumentCombination): wr.s3.to_csv(df=df, path=path, mode="append") with pytest.raises(wr.exceptions.InvalidArgumentCombination): @@ -1122,7 +989,7 @@ def test_csv_dataset(path, database): wr.s3.delete_objects(path=paths) -def test_csv_catalog(path, table, database): +def test_csv_catalog(path, glue_table, glue_database): df = get_df_csv() paths = wr.s3.to_csv( df=df, @@ -1135,20 +1002,20 @@ def test_csv_catalog(path, table, database): dataset=True, partition_cols=["par0", "par1"], mode="overwrite", - table=table, - database=database, + table=glue_table, + database=glue_database, )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert 
len(df2.index) == 3 assert len(df2.columns) == 11 assert df2["id"].sum() == 6 ensure_data_types_csv(df2) wr.s3.delete_objects(path=paths) - assert wr.catalog.delete_table_if_exists(database=database, table=table) is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is True -def test_csv_catalog_columns(bucket, database): +def test_csv_catalog_columns(bucket, glue_database): path = f"s3://{bucket}/test_csv_catalog_columns /" paths = wr.s3.to_csv( df=get_df_csv(), @@ -1163,10 +1030,10 @@ def test_csv_catalog_columns(bucket, database): partition_cols=["par0", "par1"], mode="overwrite", table="test_csv_catalog_columns", - database=database, + database=glue_database, )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_table("test_csv_catalog_columns", database) + df2 = wr.athena.read_sql_table("test_csv_catalog_columns", glue_database) assert len(df2.index) == 3 assert len(df2.columns) == 5 assert df2["id"].sum() == 6 @@ -1184,20 +1051,20 @@ def test_csv_catalog_columns(bucket, database): partition_cols=["par0", "par1"], mode="overwrite_partitions", table="test_csv_catalog_columns", - database=database, + database=glue_database, )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_table("test_csv_catalog_columns", database) + df2 = wr.athena.read_sql_table("test_csv_catalog_columns", glue_database) assert len(df2.index) == 3 assert len(df2.columns) == 5 assert df2["id"].sum() == 9 ensure_data_types_csv(df2) wr.s3.delete_objects(path=path) - assert wr.catalog.delete_table_if_exists(database=database, table="test_csv_catalog_columns") is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table="test_csv_catalog_columns") is True -def test_athena_types(bucket, database): +def test_athena_types(bucket, glue_database): path = f"s3://{bucket}/test_athena_types/" df = get_df_csv() paths = wr.s3.to_csv( @@ -1218,26 +1085,26 @@ def test_athena_types(bucket, database): ) wr.catalog.create_csv_table( table="test_athena_types", - database=database, + database=glue_database, path=path, partitions_types=partitions_types, columns_types=columns_types, ) wr.catalog.create_csv_table( - database=database, table="test_athena_types", path=path, columns_types={"col0": "string"}, mode="append" + database=glue_database, table="test_athena_types", path=path, columns_types={"col0": "string"}, mode="append" ) - wr.athena.repair_table("test_athena_types", database) - assert len(wr.catalog.get_csv_partitions(database, "test_athena_types")) == 3 - df2 = wr.athena.read_sql_table("test_athena_types", database) + wr.athena.repair_table("test_athena_types", glue_database) + assert len(wr.catalog.get_csv_partitions(glue_database, "test_athena_types")) == 3 + df2 = wr.athena.read_sql_table("test_athena_types", glue_database) assert len(df2.index) == 3 assert len(df2.columns) == 10 assert df2["id"].sum() == 6 ensure_data_types_csv(df2) wr.s3.delete_objects(path=paths) - assert wr.catalog.delete_table_if_exists(database=database, table="test_athena_types") is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table="test_athena_types") is True -def test_parquet_catalog_columns(bucket, database): +def test_parquet_catalog_columns(bucket, glue_database): path = f"s3://{bucket}/test_parquet_catalog_columns/" paths = wr.s3.to_parquet( df=get_df_csv()[["id", "date", "timestamp", "par0", "par1"]], @@ -1250,10 +1117,10 @@ def test_parquet_catalog_columns(bucket, database): partition_cols=["par0", "par1"], 
mode="overwrite", table="test_parquet_catalog_columns", - database=database, + database=glue_database, )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_table("test_parquet_catalog_columns", database) + df2 = wr.athena.read_sql_table("test_parquet_catalog_columns", glue_database) assert len(df2.index) == 3 assert len(df2.columns) == 5 assert df2["id"].sum() == 6 @@ -1270,41 +1137,43 @@ def test_parquet_catalog_columns(bucket, database): partition_cols=["par0", "par1"], mode="overwrite_partitions", table="test_parquet_catalog_columns", - database=database, + database=glue_database, )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_table("test_parquet_catalog_columns", database) + df2 = wr.athena.read_sql_table("test_parquet_catalog_columns", glue_database) assert len(df2.index) == 3 assert len(df2.columns) == 5 assert df2["id"].sum() == 9 ensure_data_types_csv(df2) wr.s3.delete_objects(path=path) - assert wr.catalog.delete_table_if_exists(database=database, table="test_parquet_catalog_columns") is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table="test_parquet_catalog_columns") is True @pytest.mark.parametrize("compression", [None, "gzip", "snappy"]) -def test_parquet_compress(bucket, database, compression): +def test_parquet_compress(bucket, glue_database, compression): path = f"s3://{bucket}/test_parquet_compress_{compression}/" paths = wr.s3.to_parquet( df=get_df(), path=path, compression=compression, dataset=True, - database=database, + database=glue_database, table=f"test_parquet_compress_{compression}", mode="overwrite", )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_table(f"test_parquet_compress_{compression}", database) + df2 = wr.athena.read_sql_table(f"test_parquet_compress_{compression}", glue_database) ensure_data_types(df2) df2 = wr.s3.read_parquet(path=path) wr.s3.delete_objects(path=path) - assert wr.catalog.delete_table_if_exists(database=database, table=f"test_parquet_compress_{compression}") is True + assert ( + wr.catalog.delete_table_if_exists(database=glue_database, table=f"test_parquet_compress_{compression}") is True + ) ensure_data_types(df2) -def test_parquet_char_length(path, database, table): +def test_parquet_char_length(path, glue_database, glue_table): df = pd.DataFrame( {"id": [1, 2], "cchar": ["foo", "boo"], "date": [datetime.date(2020, 1, 1), datetime.date(2020, 1, 2)]} ) @@ -1312,8 +1181,8 @@ def test_parquet_char_length(path, database, table): df=df, path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, mode="overwrite", partition_cols=["date"], dtype={"cchar": "char(3)"}, @@ -1324,7 +1193,7 @@ def test_parquet_char_length(path, database, table): assert len(df2.columns) == 3 assert df2.id.sum() == 3 - df2 = wr.athena.read_sql_table(table=table, database=database) + df2 = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df2.index) == 2 assert len(df2.columns) == 3 assert df2.id.sum() == 3 @@ -1332,14 +1201,19 @@ def test_parquet_char_length(path, database, table): @pytest.mark.parametrize("col2", [[1, 1, 1, 1, 1], [1, 2, 3, 4, 5], [1, 1, 1, 1, 2], [1, 2, 2, 2, 2]]) @pytest.mark.parametrize("chunked", [True, 1, 2, 100]) -def test_parquet_chunked(bucket, database, col2, chunked): - table = f"test_parquet_chunked_{chunked}_{''.join([str(x) for x in col2])}" - path = f"s3://{bucket}/{table}/" +def test_parquet_chunked(path, glue_database, glue_table, col2, chunked): 
wr.s3.delete_objects(path=path) values = list(range(5)) df = pd.DataFrame({"col1": values, "col2": col2}) paths = wr.s3.to_parquet( - df, path, index=False, dataset=True, database=database, table=table, partition_cols=["col2"], mode="overwrite" + df, + path, + index=False, + dataset=True, + database=glue_database, + table=glue_table, + partition_cols=["col2"], + mode="overwrite", )["paths"] wr.s3.wait_objects_exist(paths=paths) @@ -1353,7 +1227,7 @@ def test_parquet_chunked(bucket, database, col2, chunked): else: assert len(dfs) == len(set(col2)) - dfs = list(wr.athena.read_sql_table(database=database, table=table, chunksize=chunked)) + dfs = list(wr.athena.read_sql_table(database=glue_database, table=glue_table, chunksize=chunked)) assert sum(values) == pd.concat(dfs, ignore_index=True).col1.sum() if chunked is not True: assert len(dfs) == int(math.ceil(len(df) / chunked)) @@ -1362,15 +1236,24 @@ def test_parquet_chunked(bucket, database, col2, chunked): assert chunked >= len(dfs[-1]) wr.s3.delete_objects(path=paths) - assert wr.catalog.delete_table_if_exists(database=database, table=table) is True + assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is True @pytest.mark.parametrize("workgroup", [None, 0, 1, 2, 3]) @pytest.mark.parametrize("encryption", [None, "SSE_S3", "SSE_KMS"]) -# @pytest.mark.parametrize("workgroup", [3]) -# @pytest.mark.parametrize("encryption", [None]) def test_athena_encryption( - path, path2, database, table, table2, kms_key, encryption, workgroup, workgroup0, workgroup1, workgroup2, workgroup3 + path, + path2, + glue_database, + glue_table, + glue_table2, + kms_key, + encryption, + workgroup, + workgroup0, + workgroup1, + workgroup2, + workgroup3, ): kms_key = None if (encryption == "SSE_S3") or (encryption is None) else kms_key if workgroup == 0: @@ -1383,26 +1266,32 @@ def test_athena_encryption( workgroup = workgroup3 df = pd.DataFrame({"a": [1, 2], "b": ["foo", "boo"]}) paths = wr.s3.to_parquet( - df=df, path=path, dataset=True, mode="overwrite", database=database, table=table, s3_additional_kwargs=None + df=df, + path=path, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, + s3_additional_kwargs=None, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) df2 = wr.athena.read_sql_table( - table=table, + table=glue_table, ctas_approach=True, - database=database, + database=glue_database, encryption=encryption, workgroup=workgroup, kms_key=kms_key, keep_files=True, - ctas_temp_table_name=table2, + ctas_temp_table_name=glue_table2, s3_output=path2, ) - assert wr.catalog.does_table_exist(database=database, table=table2) is False + assert wr.catalog.does_table_exist(database=glue_database, table=glue_table2) is False assert len(df2.index) == 2 assert len(df2.columns) == 2 -def test_athena_nested(path, database, table): +def test_athena_nested(path, glue_database, glue_table): df = pd.DataFrame( { "c0": [[1, 2, 3], [4, 5, 6]], @@ -1414,25 +1303,32 @@ def test_athena_nested(path, database, table): } ) paths = wr.s3.to_parquet( - df=df, path=path, index=False, use_threads=True, dataset=True, mode="overwrite", database=database, table=table + df=df, + path=path, + index=False, + use_threads=True, + dataset=True, + mode="overwrite", + database=glue_database, + table=glue_table, )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {table}", database=database) + df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM 
{glue_table}", database=glue_database) assert len(df2.index) == 2 assert len(df2.columns) == 4 -def test_catalog_versioning(bucket, database): - table = "test_catalog_versioning" - wr.catalog.delete_table_if_exists(database=database, table=table) - path = f"s3://{bucket}/{table}/" +def test_catalog_versioning(path, glue_database, glue_table): + wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) wr.s3.delete_objects(path=path) # Version 0 df = pd.DataFrame({"c0": [1, 2]}) - paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table, mode="overwrite")["paths"] + paths = wr.s3.to_parquet( + df=df, path=path, dataset=True, database=glue_database, table=glue_table, mode="overwrite" + )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df = wr.athena.read_sql_table(table=table, database=database) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) == 2 assert len(df.columns) == 1 assert str(df.c0.dtype).startswith("Int") @@ -1440,10 +1336,16 @@ def test_catalog_versioning(bucket, database): # Version 1 df = pd.DataFrame({"c1": ["foo", "boo"]}) paths1 = wr.s3.to_parquet( - df=df, path=path, dataset=True, database=database, table=table, mode="overwrite", catalog_versioning=True + df=df, + path=path, + dataset=True, + database=glue_database, + table=glue_table, + mode="overwrite", + catalog_versioning=True, )["paths"] wr.s3.wait_objects_exist(paths=paths1, use_threads=False) - df = wr.athena.read_sql_table(table=table, database=database) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) == 2 assert len(df.columns) == 1 assert str(df.c1.dtype) == "string" @@ -1454,15 +1356,15 @@ def test_catalog_versioning(bucket, database): df=df, path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, mode="overwrite", catalog_versioning=True, index=False, )["paths"] wr.s3.wait_objects_exist(paths=paths2, use_threads=False) wr.s3.wait_objects_not_exist(paths=paths1, use_threads=False) - df = wr.athena.read_sql_table(table=table, database=database) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) == 2 assert len(df.columns) == 1 assert str(df.c1.dtype).startswith("float") @@ -1473,35 +1375,31 @@ def test_catalog_versioning(bucket, database): df=df, path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, mode="overwrite", catalog_versioning=False, index=False, )["paths"] wr.s3.wait_objects_exist(paths=paths3, use_threads=False) wr.s3.wait_objects_not_exist(paths=paths2, use_threads=False) - df = wr.athena.read_sql_table(table=table, database=database) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) == 2 assert len(df.columns) == 1 assert str(df.c1.dtype).startswith("boolean") - # Cleaning Up - wr.catalog.delete_table_if_exists(database=database, table=table) - wr.s3.delete_objects(path=path) - -def test_unsigned_parquet(bucket, database): - table = "test_unsigned_parquet" - path = f"s3://{bucket}/{table}/" +def test_unsigned_parquet(path, glue_database, glue_table): wr.s3.delete_objects(path=path) df = pd.DataFrame({"c0": [0, 0, (2 ** 8) - 1], "c1": [0, 0, (2 ** 16) - 1], "c2": [0, 0, (2 ** 32) - 1]}) df["c0"] = df.c0.astype("uint8") df["c1"] = df.c1.astype("uint16") df["c2"] = df.c2.astype("uint32") - paths = wr.s3.to_parquet(df=df, path=path, dataset=True, 
database=database, table=table, mode="overwrite")["paths"] + paths = wr.s3.to_parquet( + df=df, path=path, dataset=True, database=glue_database, table=glue_table, mode="overwrite" + )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df = wr.athena.read_sql_table(table=table, database=database) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert df.c0.sum() == (2 ** 8) - 1 assert df.c1.sum() == (2 ** 16) - 1 assert df.c2.sum() == (2 ** 32) - 1 @@ -1517,20 +1415,26 @@ def test_unsigned_parquet(bucket, database): df = pd.DataFrame({"c0": [0, 0, (2 ** 64) - 1]}) df["c0"] = df.c0.astype("uint64") with pytest.raises(wr.exceptions.UnsupportedType): - wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table, mode="overwrite") + wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table, mode="overwrite") wr.s3.delete_objects(path=path) - wr.catalog.delete_table_if_exists(database=database, table=table) + wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) -def test_parquet_overwrite_partition_cols(path, database, table): +def test_parquet_overwrite_partition_cols(path, glue_database, glue_table): df = pd.DataFrame({"c0": [1, 2, 1, 2], "c1": [1, 2, 1, 2], "c2": [2, 1, 2, 1]}) paths = wr.s3.to_parquet( - df=df, path=path, dataset=True, database=database, table=table, mode="overwrite", partition_cols=["c2"] + df=df, + path=path, + dataset=True, + database=glue_database, + table=glue_table, + mode="overwrite", + partition_cols=["c2"], )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df = wr.athena.read_sql_table(table=table, database=database) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) == 4 assert len(df.columns) == 3 assert df.c0.sum() == 6 @@ -1538,10 +1442,16 @@ def test_parquet_overwrite_partition_cols(path, database, table): assert df.c2.sum() == 6 paths = wr.s3.to_parquet( - df=df, path=path, dataset=True, database=database, table=table, mode="overwrite", partition_cols=["c1", "c2"] + df=df, + path=path, + dataset=True, + database=glue_database, + table=glue_table, + mode="overwrite", + partition_cols=["c1", "c2"], )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df = wr.athena.read_sql_table(table=table, database=database) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) == 4 assert len(df.columns) == 3 assert df.c0.sum() == 6 @@ -1549,38 +1459,33 @@ def test_parquet_overwrite_partition_cols(path, database, table): assert df.c2.sum() == 6 -def test_catalog_parameters(bucket, database): - table = "test_catalog_parameters" - path = f"s3://{bucket}/{table}/" - wr.s3.delete_objects(path=path) - wr.catalog.delete_table_if_exists(database=database, table=table) - +def test_catalog_parameters(path, glue_database, glue_table): wr.s3.to_parquet( df=pd.DataFrame({"c0": [1, 2]}), path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, mode="overwrite", parameters={"a": "1", "b": "2"}, ) - pars = wr.catalog.get_table_parameters(database=database, table=table) + pars = wr.catalog.get_table_parameters(database=glue_database, table=glue_table) assert pars["a"] == "1" assert pars["b"] == "2" pars["a"] = "0" pars["c"] = "3" - wr.catalog.upsert_table_parameters(parameters=pars, database=database, table=table) - pars = wr.catalog.get_table_parameters(database=database, table=table) + 
wr.catalog.upsert_table_parameters(parameters=pars, database=glue_database, table=glue_table) + pars = wr.catalog.get_table_parameters(database=glue_database, table=glue_table) assert pars["a"] == "0" assert pars["b"] == "2" assert pars["c"] == "3" - wr.catalog.overwrite_table_parameters(parameters={"d": "4"}, database=database, table=table) - pars = wr.catalog.get_table_parameters(database=database, table=table) + wr.catalog.overwrite_table_parameters(parameters={"d": "4"}, database=glue_database, table=glue_table) + pars = wr.catalog.get_table_parameters(database=glue_database, table=glue_table) assert pars.get("a") is None assert pars.get("b") is None assert pars.get("c") is None assert pars["d"] == "4" - df = wr.athena.read_sql_table(table=table, database=database) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) == 2 assert len(df.columns) == 1 assert df.c0.sum() == 3 @@ -1589,54 +1494,57 @@ def test_catalog_parameters(bucket, database): df=pd.DataFrame({"c0": [3, 4]}), path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, mode="append", parameters={"e": "5"}, ) - pars = wr.catalog.get_table_parameters(database=database, table=table) + pars = wr.catalog.get_table_parameters(database=glue_database, table=glue_table) assert pars.get("a") is None assert pars.get("b") is None assert pars.get("c") is None assert pars["d"] == "4" assert pars["e"] == "5" - df = wr.athena.read_sql_table(table=table, database=database) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) == 4 assert len(df.columns) == 1 assert df.c0.sum() == 10 - wr.s3.delete_objects(path=path) - wr.catalog.delete_table_if_exists(database=database, table=table) - -def test_athena_cache(path, database, table, workgroup1): +def test_athena_cache(path, glue_database, glue_table, workgroup1): df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") - paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", database=database, table=table)["paths"] + paths = wr.s3.to_parquet( + df=df, path=path, dataset=True, mode="overwrite", database=glue_database, table=glue_table + )["paths"] wr.s3.wait_objects_exist(paths=paths) - df2 = wr.athena.read_sql_table(table, database, ctas_approach=False, max_cache_seconds=1, workgroup=workgroup1) + df2 = wr.athena.read_sql_table( + glue_table, glue_database, ctas_approach=False, max_cache_seconds=1, workgroup=workgroup1 + ) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() - df2 = wr.athena.read_sql_table(table, database, ctas_approach=False, max_cache_seconds=900, workgroup=workgroup1) + df2 = wr.athena.read_sql_table( + glue_table, glue_database, ctas_approach=False, max_cache_seconds=900, workgroup=workgroup1 + ) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() dfs = wr.athena.read_sql_table( - table, database, ctas_approach=False, max_cache_seconds=900, workgroup=workgroup1, chunksize=1 + glue_table, glue_database, ctas_approach=False, max_cache_seconds=900, workgroup=workgroup1, chunksize=1 ) assert len(list(dfs)) == 2 -def test_cache_query_ctas_approach_true(path, database, table): +def test_cache_query_ctas_approach_true(path, glue_database, glue_table): df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") paths = wr.s3.to_parquet( df=df, path=path, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c0", parameters={"num_cols": 
str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c0": "0"}, @@ -1646,27 +1554,27 @@ def test_cache_query_ctas_approach_true(path, database, table): with patch( "awswrangler.athena._check_for_cached_results", return_value={"has_valid_cache": False} ) as mocked_cache_attempt: - df2 = wr.athena.read_sql_table(table, database, ctas_approach=True, max_cache_seconds=0) + df2 = wr.athena.read_sql_table(glue_table, glue_database, ctas_approach=True, max_cache_seconds=0) mocked_cache_attempt.assert_called() assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() with patch("awswrangler.athena._resolve_query_without_cache") as resolve_no_cache: - df3 = wr.athena.read_sql_table(table, database, ctas_approach=True, max_cache_seconds=900) + df3 = wr.athena.read_sql_table(glue_table, glue_database, ctas_approach=True, max_cache_seconds=900) resolve_no_cache.assert_not_called() assert df.shape == df3.shape assert df.c0.sum() == df3.c0.sum() -def test_cache_query_ctas_approach_false(path, database, table): +def test_cache_query_ctas_approach_false(path, glue_database, glue_table): df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") paths = wr.s3.to_parquet( df=df, path=path, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, description="c0", parameters={"num_cols": str(len(df.columns)), "num_rows": str(len(df.index))}, columns_comments={"c0": "0"}, @@ -1676,28 +1584,30 @@ def test_cache_query_ctas_approach_false(path, database, table): with patch( "awswrangler.athena._check_for_cached_results", return_value={"has_valid_cache": False} ) as mocked_cache_attempt: - df2 = wr.athena.read_sql_table(table, database, ctas_approach=False, max_cache_seconds=0) + df2 = wr.athena.read_sql_table(glue_table, glue_database, ctas_approach=False, max_cache_seconds=0) mocked_cache_attempt.assert_called() assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() with patch("awswrangler.athena._resolve_query_without_cache") as resolve_no_cache: - df3 = wr.athena.read_sql_table(table, database, ctas_approach=False, max_cache_seconds=900) + df3 = wr.athena.read_sql_table(glue_table, glue_database, ctas_approach=False, max_cache_seconds=900) resolve_no_cache.assert_not_called() assert df.shape == df3.shape assert df.c0.sum() == df3.c0.sum() -def test_cache_query_semicolon(path, database, table): +def test_cache_query_semicolon(path, glue_database, glue_table): df = pd.DataFrame({"c0": [0, None]}, dtype="Int64") - paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", database=database, table=table)["paths"] + paths = wr.s3.to_parquet( + df=df, path=path, dataset=True, mode="overwrite", database=glue_database, table=glue_table + )["paths"] wr.s3.wait_objects_exist(paths=paths) with patch( "awswrangler.athena._check_for_cached_results", return_value={"has_valid_cache": False} ) as mocked_cache_attempt: df2 = wr.athena.read_sql_query( - f"SELECT * FROM {table}", database=database, ctas_approach=True, max_cache_seconds=0 + f"SELECT * FROM {glue_table}", database=glue_database, ctas_approach=True, max_cache_seconds=0 ) mocked_cache_attempt.assert_called() assert df.shape == df2.shape @@ -1705,7 +1615,7 @@ def test_cache_query_semicolon(path, database, table): with patch("awswrangler.athena._resolve_query_without_cache") as resolve_no_cache: df3 = wr.athena.read_sql_query( - f"SELECT * FROM {table};", database=database, ctas_approach=True, max_cache_seconds=900 + f"SELECT * FROM {glue_table};", 
database=glue_database, ctas_approach=True, max_cache_seconds=900 ) resolve_no_cache.assert_not_called() assert df.shape == df3.shape @@ -1713,12 +1623,12 @@ def test_cache_query_semicolon(path, database, table): @pytest.mark.parametrize("partition_cols", [None, ["c2"], ["c1", "c2"]]) -def test_store_metadata_partitions_dataset(database, table, path, partition_cols): +def test_store_metadata_partitions_dataset(glue_database, glue_table, path, partition_cols): df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5], "c2": [6, 7, 8]}) paths = wr.s3.to_parquet(df=df, path=path, dataset=True, partition_cols=partition_cols)["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - wr.s3.store_parquet_metadata(path=path, database=database, table=table, dataset=True) - df2 = wr.athena.read_sql_table(table=table, database=database) + wr.s3.store_parquet_metadata(path=path, database=glue_database, table=glue_table, dataset=True) + df2 = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) == len(df2.index) assert len(df.columns) == len(df2.columns) assert df.c0.sum() == df2.c0.sum() @@ -1727,16 +1637,21 @@ def test_store_metadata_partitions_dataset(database, table, path, partition_cols @pytest.mark.parametrize("partition_cols", [None, ["c2"], ["c1", "c2"]]) -def test_store_metadata_partitions_sample_dataset(database, table, path, partition_cols): +def test_store_metadata_partitions_sample_dataset(glue_database, glue_table, path, partition_cols): num_files = 10 df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5], "c2": [6, 7, 8]}) for _ in range(num_files): paths = wr.s3.to_parquet(df=df, path=path, dataset=True, partition_cols=partition_cols)["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) wr.s3.store_parquet_metadata( - path=path, database=database, table=table, dtype={"c1": "bigint", "c2": "smallint"}, sampling=0.25, dataset=True + path=path, + database=glue_database, + table=glue_table, + dtype={"c1": "bigint", "c2": "smallint"}, + sampling=0.25, + dataset=True, ) - df2 = wr.athena.read_sql_table(table=table, database=database) + df2 = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert len(df.index) * num_files == len(df2.index) assert len(df.columns) == len(df2.columns) assert df.c0.sum() * num_files == df2.c0.sum() @@ -1744,28 +1659,28 @@ def test_store_metadata_partitions_sample_dataset(database, table, path, partiti assert df.c2.sum() * num_files == df2.c2.sum() -def test_athena_undefined_column(database): +def test_athena_undefined_column(glue_database): with pytest.raises(wr.exceptions.InvalidArgumentValue): - wr.athena.read_sql_query("SELECT 1", database) + wr.athena.read_sql_query("SELECT 1", glue_database) with pytest.raises(wr.exceptions.InvalidArgumentValue): - wr.athena.read_sql_query("SELECT NULL AS my_null", database) + wr.athena.read_sql_query("SELECT NULL AS my_null", glue_database) @pytest.mark.parametrize("partition_cols", [None, ["c1"], ["c2"], ["c1", "c2"], ["c2", "c1"]]) -def test_to_parquet_reverse_partitions(database, table, path, partition_cols): +def test_to_parquet_reverse_partitions(glue_database, glue_table, path, partition_cols): df = pd.DataFrame({"c0": [0, 1, 2], "c1": [3, 4, 5], "c2": [6, 7, 8]}) paths = wr.s3.to_parquet( - df=df, path=path, dataset=True, database=database, table=table, partition_cols=partition_cols + df=df, path=path, dataset=True, database=glue_database, table=glue_table, partition_cols=partition_cols )["paths"] wr.s3.wait_objects_exist(paths=paths, 
use_threads=False) - df2 = wr.athena.read_sql_table(table=table, database=database) + df2 = wr.athena.read_sql_table(table=glue_table, database=glue_database) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() assert df.c1.sum() == df2.c1.sum() assert df.c2.sum() == df2.c2.sum() -def test_to_parquet_nested_append(database, table, path): +def test_to_parquet_nested_append(glue_database, glue_table, path): df = pd.DataFrame( { "c0": [[1, 2, 3], [4, 5, 6]], @@ -1776,45 +1691,45 @@ def test_to_parquet_nested_append(database, table, path): "c5": [{"a": {"b": {"c": [1, 2]}}}, {"a": {"b": {"c": [3, 4]}}}], } ) - paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table)["paths"] + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table)["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {table}", database=database) + df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {glue_table}", database=glue_database) assert len(df2.index) == 2 assert len(df2.columns) == 4 - paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table)["paths"] + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table)["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {table}", database=database) + df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c1, c2, c4 FROM {glue_table}", database=glue_database) assert len(df2.index) == 4 assert len(df2.columns) == 4 -def test_to_parquet_nested_cast(database, table, path): +def test_to_parquet_nested_cast(glue_database, glue_table, path): df = pd.DataFrame({"c0": [[1, 2, 3], [4, 5, 6]], "c1": [[], []], "c2": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]}) paths = wr.s3.to_parquet( df=df, path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, dtype={"c0": "array", "c1": "array", "c2": "struct"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) df = pd.DataFrame({"c0": [[1, 2, 3], [4, 5, 6]], "c1": [["a"], ["b"]], "c2": [{"a": 1, "b": 2}, {"a": 3, "b": 4}]}) - paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table)["paths"] + paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table)["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c2 FROM {table}", database=database) + df2 = wr.athena.read_sql_query(sql=f"SELECT c0, c2 FROM {glue_table}", database=glue_database) assert len(df2.index) == 4 assert len(df2.columns) == 2 -def test_to_parquet_projection_integer(database, table, path): +def test_to_parquet_projection_integer(glue_database, glue_table, path): df = pd.DataFrame({"c0": [0, 1, 2], "c1": [0, 1, 2], "c2": [0, 100, 200], "c3": [0, 1, 2]}) paths = wr.s3.to_parquet( df=df, path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, partition_cols=["c1", "c2", "c3"], regular_partitions=False, projection_enabled=True, @@ -1824,7 +1739,7 @@ def test_to_parquet_projection_integer(database, table, path): projection_digits={"c3": "1"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert df.shape 
== df2.shape assert df.c0.sum() == df2.c0.sum() assert df.c1.sum() == df2.c1.sum() @@ -1832,14 +1747,14 @@ def test_to_parquet_projection_integer(database, table, path): assert df.c3.sum() == df2.c3.sum() -def test_to_parquet_projection_enum(database, table, path): +def test_to_parquet_projection_enum(glue_database, glue_table, path): df = pd.DataFrame({"c0": [0, 1, 2], "c1": [1, 2, 3], "c2": ["foo", "boo", "bar"]}) paths = wr.s3.to_parquet( df=df, path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, partition_cols=["c1", "c2"], regular_partitions=False, projection_enabled=True, @@ -1847,13 +1762,13 @@ def test_to_parquet_projection_enum(database, table, path): projection_values={"c1": "1,2,3", "c2": "foo,boo,bar"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() assert df.c1.sum() == df2.c1.sum() -def test_to_parquet_projection_date(database, table, path): +def test_to_parquet_projection_date(glue_database, glue_table, path): df = pd.DataFrame( { "c0": [0, 1, 2], @@ -1865,8 +1780,8 @@ def test_to_parquet_projection_date(database, table, path): df=df, path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, partition_cols=["c1", "c2"], regular_partitions=False, projection_enabled=True, @@ -1874,27 +1789,27 @@ def test_to_parquet_projection_date(database, table, path): projection_ranges={"c1": "2020-01-01,2020-01-03", "c2": "2020-01-01 01:01:00,2020-01-01 01:01:03"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_table(table, database) + df2 = wr.athena.read_sql_table(glue_table, glue_database) print(df2) assert df.shape == df2.shape assert df.c0.sum() == df2.c0.sum() -def test_to_parquet_projection_injected(database, table, path): +def test_to_parquet_projection_injected(glue_database, glue_table, path): df = pd.DataFrame({"c0": [0, 1, 2], "c1": ["foo", "boo", "bar"], "c2": ["0", "1", "2"]}) paths = wr.s3.to_parquet( df=df, path=path, dataset=True, - database=database, - table=table, + database=glue_database, + table=glue_table, partition_cols=["c1", "c2"], regular_partitions=False, projection_enabled=True, projection_types={"c1": "injected", "c2": "injected"}, )["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) - df2 = wr.athena.read_sql_query(f"SELECT * FROM {table} WHERE c1='foo' AND c2='0'", database) + df2 = wr.athena.read_sql_query(f"SELECT * FROM {glue_table} WHERE c1='foo' AND c2='0'", glue_database) assert df2.shape == (1, 3) assert df2.c0.iloc[0] == 0 @@ -1902,30 +1817,30 @@ def test_to_parquet_projection_injected(database, table, path): def test_glue_database(): # Round 1 - Create Database - database_name = f"database_{get_time_str_with_random_suffix()}" - print(f"Database Name: {database_name}") - wr.catalog.create_database(name=database_name, description="Database Description") + glue_database_name = f"database_{get_time_str_with_random_suffix()}" + print(f"Database Name: {glue_database_name}") + wr.catalog.create_database(name=glue_database_name, description="Database Description") databases = wr.catalog.get_databases() test_database_name = "" test_database_description = "" for database in databases: - if database["Name"] == database_name: + if database["Name"] == glue_database_name: test_database_name = database["Name"] 
test_database_description = database["Description"] - assert test_database_name == database_name + assert test_database_name == glue_database_name assert test_database_description == "Database Description" # Round 2 - Delete Database - print(f"Database Name: {database_name}") - wr.catalog.delete_database(name=database_name) + print(f"Glue Database Name: {glue_database_name}") + wr.catalog.delete_database(name=glue_database_name) databases = wr.catalog.get_databases() test_database_name = "" test_database_description = "" for database in databases: - if database["Name"] == database_name: + if database["Name"] == glue_database_name: test_database_name = database["Name"] test_database_description = database["Description"] @@ -1933,7 +1848,7 @@ def test_glue_database(): assert test_database_description == "" -def test_parquet_catalog_casting_to_string(path, table, database): +def test_parquet_catalog_casting_to_string(path, glue_table, glue_database): for df in [get_df(), get_df_cast()]: paths = wr.s3.to_parquet( df=df, @@ -1941,8 +1856,8 @@ def test_parquet_catalog_casting_to_string(path, table, database): index=False, dataset=True, mode="overwrite", - database=database, - table=table, + database=glue_database, + table=glue_table, dtype={ "iint8": "string", "iint16": "string", @@ -1967,11 +1882,11 @@ def test_parquet_catalog_casting_to_string(path, table, database): assert df.shape == (3, 16) for dtype in df.dtypes.values: assert str(dtype) == "string" - df = wr.athena.read_sql_table(table=table, database=database, ctas_approach=True) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database, ctas_approach=True) assert df.shape == (3, 16) for dtype in df.dtypes.values: assert str(dtype) == "string" - df = wr.athena.read_sql_table(table=table, database=database, ctas_approach=False) + df = wr.athena.read_sql_table(table=glue_table, database=glue_database, ctas_approach=False) assert df.shape == (3, 16) for dtype in df.dtypes.values: assert str(dtype) == "string" diff --git a/tests/test_cloudwatch.py b/tests/test_cloudwatch.py index b96460146..cad3f335a 100644 --- a/tests/test_cloudwatch.py +++ b/tests/test_cloudwatch.py @@ -1,5 +1,4 @@ import logging -from datetime import datetime import boto3 import pytest @@ -7,42 +6,11 @@ import awswrangler as wr from awswrangler import exceptions -from ._utils import extract_cloudformation_outputs - logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") logging.getLogger("awswrangler").setLevel(logging.DEBUG) logging.getLogger("botocore.credentials").setLevel(logging.CRITICAL) -@pytest.fixture(scope="module") -def cloudformation_outputs(): - yield extract_cloudformation_outputs() - - -@pytest.fixture(scope="module") -def loggroup(cloudformation_outputs): - loggroup_name = cloudformation_outputs["LogGroupName"] - logstream_name = cloudformation_outputs["LogStream"] - client = boto3.client("logs") - response = client.describe_log_streams(logGroupName=loggroup_name, logStreamNamePrefix=logstream_name) - token = response["logStreams"][0].get("uploadSequenceToken") - events = [] - for i in range(5): - events.append({"timestamp": int(1000 * datetime.now().timestamp()), "message": str(i)}) - args = {"logGroupName": loggroup_name, "logStreamName": logstream_name, "logEvents": events} - if token: - args["sequenceToken"] = token - try: - client.put_log_events(**args) - except client.exceptions.DataAlreadyAcceptedException: - pass # Concurrency - while True: - results = 
wr.cloudwatch.run_query(log_group_names=[loggroup_name], query="fields @timestamp | limit 5") - if len(results) >= 5: - break - yield loggroup_name - - def test_query_cancelled(loggroup): client_logs = boto3.client("logs") query_id = wr.cloudwatch.start_query( diff --git a/tests/test_db.py b/tests/test_db.py index 980df3b11..913e1868e 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -10,98 +10,15 @@ import awswrangler as wr -from ._utils import ( - ensure_data_types, - ensure_data_types_category, - extract_cloudformation_outputs, - get_df, - get_df_category, - get_time_str_with_random_suffix, - path_generator, -) +from ._utils import ensure_data_types, ensure_data_types_category, get_df, get_df_category logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") logging.getLogger("awswrangler").setLevel(logging.DEBUG) logging.getLogger("botocore.credentials").setLevel(logging.CRITICAL) -@pytest.fixture(scope="module") -def cloudformation_outputs(): - yield extract_cloudformation_outputs() - - -@pytest.fixture(scope="function") -def path(bucket): - yield from path_generator(bucket) - - -@pytest.fixture(scope="module") -def bucket(cloudformation_outputs): - if "BucketName" in cloudformation_outputs: - bucket = cloudformation_outputs["BucketName"] - else: - raise Exception("You must deploy/update the test infrastructure (CloudFormation)") - yield bucket - - -@pytest.fixture(scope="module") -def parameters(cloudformation_outputs): - parameters = dict(postgresql={}, mysql={}, redshift={}) - parameters["postgresql"]["host"] = cloudformation_outputs["PostgresqlAddress"] - parameters["postgresql"]["port"] = 3306 - parameters["postgresql"]["schema"] = "public" - parameters["postgresql"]["database"] = "postgres" - parameters["mysql"]["host"] = cloudformation_outputs["MysqlAddress"] - parameters["mysql"]["port"] = 3306 - parameters["mysql"]["schema"] = "test" - parameters["mysql"]["database"] = "test" - parameters["redshift"]["host"] = cloudformation_outputs["RedshiftAddress"] - parameters["redshift"]["port"] = cloudformation_outputs["RedshiftPort"] - parameters["redshift"]["identifier"] = cloudformation_outputs["RedshiftIdentifier"] - parameters["redshift"]["schema"] = "public" - parameters["redshift"]["database"] = "test" - parameters["redshift"]["role"] = cloudformation_outputs["RedshiftRole"] - parameters["password"] = cloudformation_outputs["DatabasesPassword"] - parameters["user"] = "test" - yield parameters - - -@pytest.fixture(scope="module") -def glue_database(cloudformation_outputs): - yield cloudformation_outputs["GlueDatabaseName"] - - -@pytest.fixture(scope="function") -def glue_table(glue_database): - name = f"tbl_{get_time_str_with_random_suffix()}" - print(f"Table name: {name}") - wr.catalog.delete_table_if_exists(database=glue_database, table=name) - yield name - wr.catalog.delete_table_if_exists(database=glue_database, table=name) - - -@pytest.fixture(scope="module") -def external_schema(cloudformation_outputs, parameters, glue_database): - region = cloudformation_outputs.get("Region") - sql = f""" - CREATE EXTERNAL SCHEMA IF NOT EXISTS aws_data_wrangler_external FROM data catalog - DATABASE '{glue_database}' - IAM_ROLE '{parameters["redshift"]["role"]}' - REGION '{region}'; - """ - engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift") - with engine.connect() as con: - con.execute(sql) - yield "aws_data_wrangler_external" - - -@pytest.fixture(scope="module") -def kms_key_id(cloudformation_outputs): - 
yield cloudformation_outputs["KmsKeyArn"].split("/", 1)[1] - - @pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"]) -def test_sql(parameters, db_type): +def test_sql(databases_parameters, db_type): df = get_df() if db_type == "redshift": df.drop(["binary"], axis=1, inplace=True) @@ -111,7 +28,7 @@ def test_sql(parameters, db_type): df=df, con=engine, name="test_sql", - schema=parameters[db_type]["schema"], + schema=databases_parameters[db_type]["schema"], if_exists="replace", index=index, index_label=None, @@ -119,18 +36,18 @@ def test_sql(parameters, db_type): method=None, dtype={"iint32": sqlalchemy.types.Integer}, ) - df = wr.db.read_sql_query(sql=f"SELECT * FROM {parameters[db_type]['schema']}.test_sql", con=engine) + df = wr.db.read_sql_query(sql=f"SELECT * FROM {databases_parameters[db_type]['schema']}.test_sql", con=engine) ensure_data_types(df, has_list=False) engine = wr.db.get_engine( db_type=db_type, - host=parameters[db_type]["host"], - port=parameters[db_type]["port"], - database=parameters[db_type]["database"], - user=parameters["user"], - password=parameters["password"], + host=databases_parameters[db_type]["host"], + port=databases_parameters[db_type]["port"], + database=databases_parameters[db_type]["database"], + user=databases_parameters["user"], + password=databases_parameters["password"], ) dfs = wr.db.read_sql_query( - sql=f"SELECT * FROM {parameters[db_type]['schema']}.test_sql", + sql=f"SELECT * FROM {databases_parameters[db_type]['schema']}.test_sql", con=engine, chunksize=1, dtype={ @@ -158,29 +75,31 @@ def test_sql(parameters, db_type): df=pd.DataFrame({"col0": [1, 2, 3]}, dtype="Int32"), con=engine, name="test_sql", - schema=parameters[db_type]["schema"], + schema=databases_parameters[db_type]["schema"], if_exists="replace", index=True, index_label="index", ) schema = None if db_type == "postgresql": - schema = parameters[db_type]["schema"] + schema = databases_parameters[db_type]["schema"] df = wr.db.read_sql_table(con=engine, table="test_sql", schema=schema, index_col="index") assert len(df.index) == 3 assert len(df.columns) == 1 -def test_redshift_temp_engine(parameters): - engine = wr.db.get_redshift_temp_engine(cluster_identifier=parameters["redshift"]["identifier"], user="test") +def test_redshift_temp_engine(databases_parameters): + engine = wr.db.get_redshift_temp_engine( + cluster_identifier=databases_parameters["redshift"]["identifier"], user="test" + ) with engine.connect() as con: cursor = con.execute("SELECT 1") assert cursor.fetchall()[0][0] == 1 -def test_redshift_temp_engine2(parameters): +def test_redshift_temp_engine2(databases_parameters): engine = wr.db.get_redshift_temp_engine( - cluster_identifier=parameters["redshift"]["identifier"], user="john_doe", duration=900, db_groups=[] + cluster_identifier=databases_parameters["redshift"]["identifier"], user="john_doe", duration=900, db_groups=[] ) with engine.connect() as con: cursor = con.execute("SELECT 1") @@ -195,7 +114,7 @@ def test_postgresql_param(): assert df["col0"].iloc[0] == 1 -def test_redshift_copy_unload(bucket, parameters): +def test_redshift_copy_unload(bucket, databases_parameters): path = f"s3://{bucket}/test_redshift_copy/" df = get_df().drop(["iint8", "binary"], axis=1, inplace=False) engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift") @@ -206,12 +125,12 @@ def test_redshift_copy_unload(bucket, parameters): schema="public", table="__test_redshift_copy", mode="overwrite", - iam_role=parameters["redshift"]["role"], + 
iam_role=databases_parameters["redshift"]["role"], ) df2 = wr.db.unload_redshift( sql="SELECT * FROM public.__test_redshift_copy", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, ) @@ -224,12 +143,12 @@ def test_redshift_copy_unload(bucket, parameters): schema="public", table="__test_redshift_copy", mode="append", - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], ) df2 = wr.db.unload_redshift( sql="SELECT * FROM public.__test_redshift_copy", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, ) @@ -238,7 +157,7 @@ def test_redshift_copy_unload(bucket, parameters): dfs = wr.db.unload_redshift( sql="SELECT * FROM public.__test_redshift_copy", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, chunked=True, @@ -247,7 +166,7 @@ def test_redshift_copy_unload(bucket, parameters): ensure_data_types(df=chunk, has_list=False) -def test_redshift_copy_upsert(bucket, parameters): +def test_redshift_copy_upsert(bucket, databases_parameters): engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift") df = pd.DataFrame({"id": list((range(1_000))), "val": list(["foo" if i % 2 == 0 else "boo" for i in range(1_000)])}) df3 = pd.DataFrame( @@ -265,13 +184,13 @@ def test_redshift_copy_upsert(bucket, parameters): mode="overwrite", index=False, primary_keys=["id"], - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], ) path = f"s3://{bucket}/upsert/test_redshift_copy_upsert2/" df2 = wr.db.unload_redshift( sql="SELECT * FROM public.test_redshift_copy_upsert", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, ) @@ -289,13 +208,13 @@ def test_redshift_copy_upsert(bucket, parameters): mode="upsert", index=False, primary_keys=["id"], - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], ) path = f"s3://{bucket}/upsert/test_redshift_copy_upsert4/" df4 = wr.db.unload_redshift( sql="SELECT * FROM public.test_redshift_copy_upsert", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, ) @@ -311,13 +230,13 @@ def test_redshift_copy_upsert(bucket, parameters): table="test_redshift_copy_upsert", mode="upsert", index=False, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], ) path = f"s3://{bucket}/upsert/test_redshift_copy_upsert4/" df4 = wr.db.unload_redshift( sql="SELECT * FROM public.test_redshift_copy_upsert", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, ) @@ -339,7 +258,7 @@ def test_redshift_copy_upsert(bucket, parameters): (None, None, wr.exceptions.InvalidRedshiftSortstyle, "foo", ["id"]), ], ) -def test_redshift_exceptions(bucket, parameters, diststyle, distkey, sortstyle, sortkey, exc): +def test_redshift_exceptions(bucket, databases_parameters, diststyle, distkey, sortstyle, sortkey, exc): df = pd.DataFrame({"id": [1], "name": "joe"}) engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift") path = f"s3://{bucket}/test_redshift_exceptions_{random.randint(0, 1_000_000)}/" @@ -355,13 +274,13 @@ def 
test_redshift_exceptions(bucket, parameters, diststyle, distkey, sortstyle, distkey=distkey, sortstyle=sortstyle, sortkey=sortkey, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], index=False, ) wr.s3.delete_objects(path=path) -def test_redshift_spectrum(bucket, glue_database, external_schema): +def test_redshift_spectrum(bucket, glue_database, redshift_external_schema): df = pd.DataFrame({"id": [1, 2, 3, 4, 5], "col_str": ["foo", None, "bar", None, "xoo"], "par_int": [0, 1, 0, 1, 1]}) path = f"s3://{bucket}/test_redshift_spectrum/" paths = wr.s3.to_parquet( @@ -377,7 +296,7 @@ def test_redshift_spectrum(bucket, glue_database, external_schema): wr.s3.wait_objects_exist(paths=paths, use_threads=False) engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift") with engine.connect() as con: - cursor = con.execute(f"SELECT * FROM {external_schema}.test_redshift_spectrum") + cursor = con.execute(f"SELECT * FROM {redshift_external_schema}.test_redshift_spectrum") rows = cursor.fetchall() assert len(rows) == len(df.index) for row in rows: @@ -386,7 +305,7 @@ def test_redshift_spectrum(bucket, glue_database, external_schema): assert wr.catalog.delete_table_if_exists(database=glue_database, table="test_redshift_spectrum") is True -def test_redshift_category(bucket, parameters): +def test_redshift_category(bucket, databases_parameters): path = f"s3://{bucket}/test_redshift_category/" df = get_df_category().drop(["binary"], axis=1, inplace=False) engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift") @@ -397,12 +316,12 @@ def test_redshift_category(bucket, parameters): schema="public", table="test_redshift_category", mode="overwrite", - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], ) df2 = wr.db.unload_redshift( sql="SELECT * FROM public.test_redshift_category", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, categories=df.columns, @@ -411,7 +330,7 @@ def test_redshift_category(bucket, parameters): dfs = wr.db.unload_redshift( sql="SELECT * FROM public.test_redshift_category", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, categories=df.columns, @@ -422,9 +341,9 @@ def test_redshift_category(bucket, parameters): wr.s3.delete_objects(path=path) -def test_redshift_unload_extras(bucket, parameters, kms_key_id): +def test_redshift_unload_extras(bucket, databases_parameters, kms_key_id): table = "test_redshift_unload_extras" - schema = parameters["redshift"]["schema"] + schema = databases_parameters["redshift"]["schema"] path = f"s3://{bucket}/{table}/" wr.s3.delete_objects(path=path) engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift") @@ -434,7 +353,7 @@ def test_redshift_unload_extras(bucket, parameters, kms_key_id): sql=f"SELECT * FROM {schema}.{table}", path=path, con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], region=wr.s3.get_bucket_region(bucket), max_file_size=5.0, kms_key_id=kms_key_id, @@ -448,7 +367,7 @@ def test_redshift_unload_extras(bucket, parameters, kms_key_id): df = wr.db.unload_redshift( sql=f"SELECT * FROM {schema}.{table}", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, region=wr.s3.get_bucket_region(bucket), @@ -461,9 
+380,9 @@ def test_redshift_unload_extras(bucket, parameters, kms_key_id): @pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"]) -def test_to_sql_cast(parameters, db_type): +def test_to_sql_cast(databases_parameters, db_type): table = "test_to_sql_cast" - schema = parameters[db_type]["schema"] + schema = databases_parameters[db_type]["schema"] df = pd.DataFrame( { "col": [ @@ -491,9 +410,9 @@ def test_to_sql_cast(parameters, db_type): assert df.equals(df2) -def test_uuid(parameters): +def test_uuid(databases_parameters): table = "test_uuid" - schema = parameters["postgresql"]["schema"] + schema = databases_parameters["postgresql"]["schema"] engine = wr.catalog.get_engine(connection="aws-data-wrangler-postgresql") df = pd.DataFrame( { @@ -524,9 +443,9 @@ def test_uuid(parameters): @pytest.mark.parametrize("db_type", ["mysql", "redshift", "postgresql"]) -def test_null(parameters, db_type): +def test_null(databases_parameters, db_type): table = "test_null" - schema = parameters[db_type]["schema"] + schema = databases_parameters[db_type]["schema"] engine = wr.catalog.get_engine(connection=f"aws-data-wrangler-{db_type}") df = pd.DataFrame({"id": [1, 2, 3], "nothing": [None, None, None]}) wr.db.to_sql( @@ -557,7 +476,7 @@ def test_null(parameters, db_type): assert pd.concat(objs=[df, df], ignore_index=True).equals(df2) -def test_redshift_spectrum_long_string(path, glue_table, glue_database, external_schema): +def test_redshift_spectrum_long_string(path, glue_table, glue_database, redshift_external_schema): df = pd.DataFrame( { "id": [1, 2], @@ -573,14 +492,14 @@ def test_redshift_spectrum_long_string(path, glue_table, glue_database, external wr.s3.wait_objects_exist(paths=paths, use_threads=False) engine = wr.catalog.get_engine(connection="aws-data-wrangler-redshift") with engine.connect() as con: - cursor = con.execute(f"SELECT * FROM {external_schema}.{glue_table}") + cursor = con.execute(f"SELECT * FROM {redshift_external_schema}.{glue_table}") rows = cursor.fetchall() assert len(rows) == len(df.index) for row in rows: assert len(row) == len(df.columns) -def test_redshift_copy_unload_long_string(path, parameters): +def test_redshift_copy_unload_long_string(path, databases_parameters): df = pd.DataFrame( { "id": [1, 2], @@ -599,12 +518,12 @@ def test_redshift_copy_unload_long_string(path, parameters): table="test_redshift_copy_unload_long_string", mode="overwrite", varchar_lengths={"col_str": 300}, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], ) df2 = wr.db.unload_redshift( sql="SELECT * FROM public.test_redshift_copy_unload_long_string", con=engine, - iam_role=parameters["redshift"]["role"], + iam_role=databases_parameters["redshift"]["role"], path=path, keep_files=False, ) diff --git a/tests/test_emr.py b/tests/test_emr.py index 23a106217..41a4c0abc 100644 --- a/tests/test_emr.py +++ b/tests/test_emr.py @@ -5,27 +5,11 @@ import awswrangler as wr -from ._utils import extract_cloudformation_outputs - logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") logging.getLogger("awswrangler").setLevel(logging.DEBUG) logging.getLogger("botocore.credentials").setLevel(logging.CRITICAL) -@pytest.fixture(scope="module") -def cloudformation_outputs(): - yield extract_cloudformation_outputs() - - -@pytest.fixture(scope="module") -def bucket(cloudformation_outputs): - if "BucketName" in cloudformation_outputs: - bucket = cloudformation_outputs["BucketName"] - else: - raise 
Exception("You must deploy/update the test infrastructure (CloudFormation)") - yield bucket - - def test_cluster(bucket, cloudformation_outputs): steps = [] for cmd in ['echo "Hello"', "ls -la"]: diff --git a/tests/test_moto.py b/tests/test_moto.py index 7b8238a2f..a8770edee 100644 --- a/tests/test_moto.py +++ b/tests/test_moto.py @@ -14,28 +14,20 @@ from ._utils import ensure_data_types, get_df_csv, get_df_list -@pytest.fixture(scope="function") -def s3(): - with moto.mock_s3(): - s3 = boto3.resource("s3") - s3.create_bucket(Bucket="bucket") - yield s3 - - @pytest.fixture(scope="module") -def emr(): +def moto_emr(): with moto.mock_emr(): yield True @pytest.fixture(scope="module") -def sts(): +def moto_sts(): with moto.mock_sts(): yield True @pytest.fixture(scope="module") -def subnet(): +def moto_subnet(): with moto.mock_ec2(): ec2 = boto3.resource("ec2", region_name="us-west-1") vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") @@ -43,32 +35,40 @@ def subnet(): yield subnet.id +@pytest.fixture(scope="function") +def moto_s3(): + with moto.mock_s3(): + s3 = boto3.resource("s3") + s3.create_bucket(Bucket="bucket") + yield s3 + + def get_content_md5(desc: dict): result = desc.get("ResponseMetadata").get("HTTPHeaders").get("content-md5") return result -def test_get_bucket_region_succeed(s3): +def test_get_bucket_region_succeed(moto_s3): region = wr.s3.get_bucket_region("bucket", boto3_session=boto3.Session()) assert region == "us-east-1" -def test_object_not_exist_succeed(s3): +def test_object_not_exist_succeed(moto_s3): result = wr.s3.does_object_exist("s3://bucket/test.csv") assert result is False -def test_object_exist_succeed(s3): +def test_object_exist_succeed(moto_s3): path = "s3://bucket/test.csv" wr.s3.to_csv(df=get_df_csv(), path=path, index=False) result = wr.s3.does_object_exist(path) assert result is True -def test_list_directories_succeed(s3): +def test_list_directories_succeed(moto_s3): path = "s3://bucket" - s3_object1 = s3.Object("bucket", "foo/foo.tmp") - s3_object2 = s3.Object("bucket", "bar/bar.tmp") + s3_object1 = moto_s3.Object("bucket", "foo/foo.tmp") + s3_object2 = moto_s3.Object("bucket", "bar/bar.tmp") s3_object1.put(Body=b"foo") s3_object2.put(Body=b"bar") @@ -79,7 +79,7 @@ def test_list_directories_succeed(s3): assert sorted(files) == sorted(["s3://bucket/foo/foo.tmp", "s3://bucket/bar/bar.tmp"]) -def test_describe_no_object_succeed(s3): +def test_describe_no_object_succeed(moto_s3): desc = wr.s3.describe_objects("s3://bucket") @@ -87,10 +87,10 @@ def test_describe_no_object_succeed(s3): assert desc == {} -def test_describe_one_object_succeed(s3): +def test_describe_one_object_succeed(moto_s3): bucket = "bucket" key = "foo/foo.tmp" - s3_object = s3.Object(bucket, key) + s3_object = moto_s3.Object(bucket, key) s3_object.put(Body=b"foo") desc = wr.s3.describe_objects("s3://{}/{}".format(bucket, key)) @@ -99,12 +99,12 @@ def test_describe_one_object_succeed(s3): assert list(desc.keys()) == ["s3://bucket/foo/foo.tmp"] -def test_describe_list_of_objects_succeed(s3): +def test_describe_list_of_objects_succeed(moto_s3): bucket = "bucket" keys = ["foo/foo.tmp", "bar/bar.tmp"] for key in keys: - s3_object = s3.Object(bucket, key) + s3_object = moto_s3.Object(bucket, key) s3_object.put(Body=b"test") desc = wr.s3.describe_objects(["s3://{}/{}".format(bucket, key) for key in keys]) @@ -113,12 +113,12 @@ def test_describe_list_of_objects_succeed(s3): assert sorted(list(desc.keys())) == sorted(["s3://bucket/foo/foo.tmp", "s3://bucket/bar/bar.tmp"]) -def 
test_describe_list_of_objects_under_same_prefix_succeed(s3): +def test_describe_list_of_objects_under_same_prefix_succeed(moto_s3): bucket = "bucket" keys = ["foo/foo.tmp", "bar/bar.tmp"] for key in keys: - s3_object = s3.Object(bucket, key) + s3_object = moto_s3.Object(bucket, key) s3_object.put(Body=b"test") desc = wr.s3.describe_objects("s3://{}".format(bucket)) @@ -127,17 +127,17 @@ def test_describe_list_of_objects_under_same_prefix_succeed(s3): assert sorted(list(desc.keys())) == sorted(["s3://bucket/foo/foo.tmp", "s3://bucket/bar/bar.tmp"]) -def test_size_objects_without_object_succeed(s3): +def test_size_objects_without_object_succeed(moto_s3): size = wr.s3.size_objects("s3://bucket") assert isinstance(size, dict) assert size == {} -def test_size_list_of_objects_succeed(s3): +def test_size_list_of_objects_succeed(moto_s3): bucket = "bucket" - s3_object1 = s3.Object(bucket, "foo/foo.tmp") - s3_object2 = s3.Object(bucket, "bar/bar.tmp") + s3_object1 = moto_s3.Object(bucket, "foo/foo.tmp") + s3_object2 = moto_s3.Object(bucket, "bar/bar.tmp") s3_object1.put(Body=b"foofoo") s3_object2.put(Body=b"bar") @@ -147,10 +147,10 @@ def test_size_list_of_objects_succeed(s3): assert size == {"s3://bucket/foo/foo.tmp": 6, "s3://bucket/bar/bar.tmp": 3} -def test_copy_one_object_without_replace_filename_succeed(s3): +def test_copy_one_object_without_replace_filename_succeed(moto_s3): bucket = "bucket" key = "foo/foo.tmp" - s3_object = s3.Object(bucket, key) + s3_object = moto_s3.Object(bucket, key) s3_object.put(Body=b"foo") wr.s3.copy_objects( @@ -167,10 +167,10 @@ def test_copy_one_object_without_replace_filename_succeed(s3): ) -def test_copy_one_object_with_replace_filename_succeed(s3): +def test_copy_one_object_with_replace_filename_succeed(moto_s3): bucket = "bucket" key = "foo/foo.tmp" - s3_object = s3.Object(bucket, key) + s3_object = moto_s3.Object(bucket, key) s3_object.put(Body=b"foo") wr.s3.copy_objects( @@ -188,12 +188,12 @@ def test_copy_one_object_with_replace_filename_succeed(s3): ) -def test_copy_objects_without_replace_filename_succeed(s3): +def test_copy_objects_without_replace_filename_succeed(moto_s3): bucket = "bucket" keys = ["foo/foo1.tmp", "foo/foo2.tmp", "foo/foo3.tmp"] for key in keys: - s3_object = s3.Object(bucket, key) + s3_object = moto_s3.Object(bucket, key) s3_object.put(Body=b"foo") wr.s3.copy_objects( @@ -213,7 +213,7 @@ def test_copy_objects_without_replace_filename_succeed(s3): ) -def test_csv(s3): +def test_csv(moto_s3): path = "s3://bucket/test.csv" wr.s3.to_csv(df=get_df_csv(), path=path, index=False) df = wr.s3.read_csv(path=path) @@ -221,7 +221,7 @@ def test_csv(s3): assert len(df.columns) == 10 -def test_read_csv_with_chucksize_and_pandas_arguments(s3): +def test_read_csv_with_chucksize_and_pandas_arguments(moto_s3): path = "s3://bucket/test.csv" wr.s3.to_csv(df=get_df_csv(), path=path, index=False) dfs = [dfs for dfs in wr.s3.read_csv(path=path, chunksize=1, usecols=["id", "string"])] @@ -232,11 +232,11 @@ def test_read_csv_with_chucksize_and_pandas_arguments(s3): @mock.patch("pandas.read_csv") @mock.patch("s3fs.S3FileSystem.open") -def test_read_csv_pass_pandas_arguments_and_encoding_succeed(mock_open, mock_read_csv, s3): +def test_read_csv_pass_pandas_arguments_and_encoding_succeed(mock_open, mock_read_csv, moto_s3): bucket = "bucket" key = "foo/foo.csv" path = "s3://{}/{}".format(bucket, key) - s3_object = s3.Object(bucket, key) + s3_object = moto_s3.Object(bucket, key) s3_object.put(Body=b"foo") with pytest.raises(TypeError): @@ -245,7 +245,7 @@ def 
test_read_csv_pass_pandas_arguments_and_encoding_succeed(mock_open, mock_rea mock_read_csv.assert_called_with(ANY, compression=None, encoding="ISO-8859-1", sep=",", lineterminator="\r\n") -def test_to_csv_invalid_argument_combination_raise_when_dataset_false_succeed(s3): +def test_to_csv_invalid_argument_combination_raise_when_dataset_false_succeed(moto_s3): path = "s3://bucket/test.csv" with pytest.raises(InvalidArgumentCombination): wr.s3.to_csv(df=get_df_csv(), path=path, index=False, database="foo") @@ -275,7 +275,7 @@ def test_to_csv_invalid_argument_combination_raise_when_dataset_false_succeed(s3 wr.s3.to_csv(df=get_df_csv(), path=path, index=False, dataset=False, columns_comments={"col0": "test"}) -def test_to_csv_valid_argument_combination_when_dataset_true_succeed(s3): +def test_to_csv_valid_argument_combination_when_dataset_true_succeed(moto_s3): path = "s3://bucket/test.csv" wr.s3.to_csv(df=get_df_csv(), path=path, index=False) wr.s3.to_csv(df=get_df_csv(), path=path, index=False, dataset=True, partition_cols=["par0", "par1"]) @@ -289,13 +289,13 @@ def test_to_csv_valid_argument_combination_when_dataset_true_succeed(s3): wr.s3.to_csv(df=get_df_csv(), path=path, index=False, dataset=True, columns_comments={"col0": "test"}) -def test_to_csv_data_empty_raise_succeed(s3): +def test_to_csv_data_empty_raise_succeed(moto_s3): path = "s3://bucket/test.csv" with pytest.raises(EmptyDataFrame): wr.s3.to_csv(df=pd.DataFrame(), path=path, index=False) -def test_parquet(s3): +def test_parquet(moto_s3): path = "s3://bucket/test.parquet" wr.s3.to_parquet(df=get_df_list(), path=path, index=False, dataset=True, partition_cols=["par0", "par1"]) df = wr.s3.read_parquet(path=path, dataset=True) @@ -303,7 +303,7 @@ def test_parquet(s3): assert df.shape == (3, 19) -def test_s3_delete_object_success(s3): +def test_s3_delete_object_success(moto_s3): path = "s3://bucket/test.parquet" wr.s3.to_parquet(df=get_df_list(), path=path, index=False, dataset=True, partition_cols=["par0", "par1"]) df = wr.s3.read_parquet(path=path, dataset=True) @@ -314,7 +314,7 @@ def test_s3_delete_object_success(s3): wr.s3.read_parquet(path=path, dataset=True) -def test_s3_raise_delete_object_exception_success(s3): +def test_s3_raise_delete_object_exception_success(moto_s3): path = "s3://bucket/test.parquet" wr.s3.to_parquet(df=get_df_list(), path=path, index=False, dataset=True, partition_cols=["par0", "par1"]) df = wr.s3.read_parquet(path=path, dataset=True) @@ -333,13 +333,13 @@ def mock_make_api_call(self, operation_name, kwarg): wr.s3.delete_objects(path=path) -def test_emr(s3, emr, sts, subnet): +def test_emr(moto_s3, moto_emr, moto_sts, moto_subnet): session = boto3.Session(region_name="us-west-1") cluster_id = wr.emr.create_cluster( cluster_name="wrangler_cluster", logging_s3_path="s3://bucket/emr-logs/", emr_release="emr-5.29.0", - subnet_id=subnet, + subnet_id=moto_subnet, emr_ec2_role="EMR_EC2_DefaultRole", emr_role="EMR_DefaultRole", instance_type_master="m5.xlarge", diff --git a/tests/test_quicksight.py b/tests/test_quicksight.py index 348da4933..e0bde4b8f 100644 --- a/tests/test_quicksight.py +++ b/tests/test_quicksight.py @@ -1,53 +1,18 @@ import logging -import pytest - import awswrangler as wr -from ._utils import extract_cloudformation_outputs, get_df_quicksight, get_time_str_with_random_suffix, path_generator +from ._utils import get_df_quicksight logging.basicConfig(level=logging.INFO, format="[%(asctime)s][%(levelname)s][%(name)s][%(funcName)s] %(message)s") 
logging.getLogger("awswrangler").setLevel(logging.DEBUG) logging.getLogger("botocore.credentials").setLevel(logging.CRITICAL) -@pytest.fixture(scope="module") -def cloudformation_outputs(): - yield extract_cloudformation_outputs() - - -@pytest.fixture(scope="module") -def bucket(cloudformation_outputs): - if "BucketName" in cloudformation_outputs: - bucket = cloudformation_outputs["BucketName"] - else: - raise Exception("You must deploy/update the test infrastructure (CloudFormation)") - yield bucket - - -@pytest.fixture(scope="module") -def database(cloudformation_outputs): - yield cloudformation_outputs["GlueDatabaseName"] - - -@pytest.fixture(scope="function") -def table(database): - name = f"tbl_{get_time_str_with_random_suffix()}" - print(f"Table name: {name}") - wr.catalog.delete_table_if_exists(database=database, table=name) - yield name - wr.catalog.delete_table_if_exists(database=database, table=name) - - -@pytest.fixture(scope="function") -def path(bucket): - yield from path_generator(bucket) - - -def test_quicksight(path, database, table): +def test_quicksight(path, glue_database, glue_table): df = get_df_quicksight() paths = wr.s3.to_parquet( - df=df, path=path, dataset=True, database=database, table=table, partition_cols=["par0", "par1"] + df=df, path=path, dataset=True, database=glue_database, table=glue_table, partition_cols=["par0", "par1"] )["paths"] wr.s3.wait_objects_exist(paths, use_threads=False) @@ -69,8 +34,8 @@ def test_quicksight(path, database, table): wr.quicksight.create_athena_dataset( name="test-table", - database=database, - table=table, + database=glue_database, + table=glue_table, data_source_name="test", allowed_to_manage=[wr.sts.get_current_identity_name()], rename_columns={"iint16": "new_col"}, @@ -80,7 +45,7 @@ def test_quicksight(path, database, table): wr.quicksight.create_athena_dataset( name="test-sql", - sql=f"SELECT * FROM {database}.{table}", + sql=f"SELECT * FROM {glue_database}.{glue_table}", data_source_name="test", import_mode="SPICE", allowed_to_use=[wr.sts.get_current_identity_name()], diff --git a/tests/test_s3.py b/tests/test_s3.py index c3b47cf25..308f8196c 100644 --- a/tests/test_s3.py +++ b/tests/test_s3.py @@ -16,7 +16,7 @@ import awswrangler as wr -from ._utils import extract_cloudformation_outputs, get_df_csv, path_generator +from ._utils import get_df_csv API_CALL = botocore.client.BaseClient._make_api_call @@ -25,30 +25,6 @@ logging.getLogger("botocore.credentials").setLevel(logging.CRITICAL) -@pytest.fixture(scope="module") -def cloudformation_outputs(): - yield extract_cloudformation_outputs() - - -@pytest.fixture(scope="module") -def region(cloudformation_outputs): - yield cloudformation_outputs["Region"] - - -@pytest.fixture(scope="module") -def bucket(cloudformation_outputs): - if "BucketName" in cloudformation_outputs: - bucket = cloudformation_outputs["BucketName"] - else: - raise Exception("You must deploy/update the test infrastructure (CloudFormation)") - yield bucket - - -@pytest.fixture(scope="function") -def path(bucket): - yield from path_generator(bucket) - - def test_delete_internal_error(bucket): response = { "Errors": [ @@ -223,10 +199,10 @@ def test_s3_empty_dfs(): wr.s3.to_csv(df=df, path="") -def test_absent_object(bucket): - path = f"s3://{bucket}/test_absent_object" - assert wr.s3.does_object_exist(path=path) is False - assert len(wr.s3.size_objects(path=path)) == 0 +def test_absent_object(path): + path_file = f"{path}test_absent_object" + assert wr.s3.does_object_exist(path=path_file) is False + assert 
len(wr.s3.size_objects(path=path_file)) == 0 assert wr.s3.wait_objects_exist(paths=[]) is None @@ -376,8 +352,7 @@ def test_copy_replacing_filename(bucket): wr.s3.delete_objects(path=path2) -def test_parquet_uint64(bucket): - path = f"s3://{bucket}/test_parquet_uint64/" +def test_parquet_uint64(path): wr.s3.delete_objects(path=path) df = pd.DataFrame( { @@ -396,8 +371,6 @@ def test_parquet_uint64(bucket): paths = wr.s3.to_parquet(df=df, path=path, dataset=True, mode="overwrite", partition_cols=["c4"])["paths"] wr.s3.wait_objects_exist(paths=paths, use_threads=False) df = wr.s3.read_parquet(path=path, dataset=True) - print(df) - print(df.dtypes) assert len(df.index) == 3 assert len(df.columns) == 5 assert df.c0.max() == (2 ** 8) - 1
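
Note (illustrative, not part of the diff): the tests above now request moto-backed fixtures named moto_s3, moto_emr, moto_sts, and moto_subnet, whose definitions live outside this excerpt. A minimal sketch of what the moto_s3 fixture could look like — assuming the moto package's pre-5.0 mock_s3 API and a pre-created bucket named "bucket" to match the hard-coded paths in the tests — is:

import boto3
import pytest

from moto import mock_s3  # assumption: moto < 5.0; newer releases expose mock_aws instead


@pytest.fixture(scope="function")
def moto_s3():
    # Start an in-memory S3 backend and pre-create the bucket the tests rely on,
    # then hand the boto3 resource to the test so calls like moto_s3.Object(...) work.
    with mock_s3():
        s3 = boto3.resource("s3", region_name="us-east-1")
        s3.create_bucket(Bucket="bucket")
        yield s3

The other moto_* fixtures would presumably follow the same pattern with mock_emr, mock_sts, and mock_ec2 (for the subnet), but their exact definitions are not shown in this excerpt.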