diff --git a/README.md b/README.md
index 218b08e24..225604f86 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
 ![AWS Data Wrangler](docs/source/_static/logo.png?raw=true "AWS Data Wrangler")
 
-> Utility belt to handle data on AWS.
+> DataFrames on AWS.
 
-[![Release](https://img.shields.io/badge/release-0.2.2-brightgreen.svg)](https://pypi.org/project/awswrangler/)
+[![Release](https://img.shields.io/badge/release-0.2.5-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 [![Downloads](https://img.shields.io/pypi/dm/awswrangler.svg)](https://pypi.org/project/awswrangler/)
 [![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7-brightgreen.svg)](https://pypi.org/project/awswrangler/)
 [![Documentation Status](https://readthedocs.org/projects/aws-data-wrangler/badge/?version=latest)](https://aws-data-wrangler.readthedocs.io/en/latest/?badge=latest)
diff --git a/awswrangler/__version__.py b/awswrangler/__version__.py
index b90c8613d..65b246978 100644
--- a/awswrangler/__version__.py
+++ b/awswrangler/__version__.py
@@ -1,4 +1,4 @@
 __title__ = "awswrangler"
-__description__ = "Utility belt to handle data on AWS."
-__version__ = "0.2.2"
+__description__ = "DataFrames on AWS."
+__version__ = "0.2.5"
 __license__ = "Apache License 2.0"
diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index 1e6db38e8..ac4c7c8ac 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -1247,7 +1247,6 @@ def to_redshift(
             generated_conn = True
 
         try:
-
             if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
                 num_partitions: int = 1
             else:
@@ -1558,7 +1557,7 @@ def read_sql_redshift(self,
 
         :param sql: SQL Query
         :param iam_role: AWS IAM role with the related permissions
-        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param temp_s3_path: AWS S3 path to write temporary data (e.g. s3://...) (Default uses the Athena's results bucket)
         :param procs_cpu_bound: Number of cores used for CPU bound tasks
         """
@@ -1574,6 +1573,13 @@ def read_sql_redshift(self,
         logger.debug(f"temp_s3_path: {temp_s3_path}")
         self._session.s3.delete_objects(path=temp_s3_path)
         paths: Optional[List[str]] = None
+
+        generated_conn: bool = False
+        if type(connection) == str:
+            logger.debug("Glue connection (str) provided.")
+            connection = self._session.glue.get_connection(name=connection)
+            generated_conn = True
+
         try:
             paths = self._session.redshift.to_parquet(sql=sql,
                                                       path=temp_s3_path,
@@ -1581,14 +1587,20 @@ def read_sql_redshift(self,
                                                       connection=connection)
             logger.debug(f"paths: {paths}")
             df: pd.DataFrame = self.read_parquet(path=paths, procs_cpu_bound=procs_cpu_bound)  # type: ignore
-            self._session.s3.delete_listed_objects(objects_paths=paths + [temp_s3_path + "/manifest"])  # type: ignore
-            return df
-        except Exception as e:
+        except Exception as ex:
+            connection.rollback()
             if paths is not None:
                 self._session.s3.delete_listed_objects(objects_paths=paths + [temp_s3_path + "/manifest"])
             else:
                 self._session.s3.delete_objects(path=temp_s3_path)
-            raise e
+            if generated_conn is True:
+                connection.close()
+            raise ex
+
+        if generated_conn is True:
+            connection.close()
+        self._session.s3.delete_listed_objects(objects_paths=paths + [temp_s3_path + "/manifest"])  # type: ignore
+        return df
 
     def to_aurora(self,
                   dataframe: pd.DataFrame,
diff --git a/awswrangler/spark.py b/awswrangler/spark.py
index 113b64626..431f438d0 100644
--- a/awswrangler/spark.py
+++ b/awswrangler/spark.py
@@ -71,7 +71,7 @@ def to_redshift(
 
         :param dataframe: Pandas Dataframe
         :param path: S3 path to write temporary files (E.g. s3://BUCKET_NAME/ANY_NAME/)
-        :param connection: A PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
+        :param connection: Glue connection name (str) OR a PEP 249 compatible connection (Can be generated with Redshift.generate_connection())
         :param schema: The Redshift Schema for the table
         :param table: The name of the desired Redshift table
         :param iam_role: AWS IAM role with the related permissions
@@ -93,68 +93,83 @@ def to_redshift(
         dataframe.cache()
         num_rows: int = dataframe.count()
         logger.info(f"Number of rows: {num_rows}")
-        num_partitions: int
-        if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
-            num_partitions = 1
-        else:
-            num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
-            logger.debug(f"Number of slices on Redshift: {num_slices}")
-            num_partitions = num_slices
-            while num_partitions < min_num_partitions:
-                num_partitions += num_slices
-        logger.debug(f"Number of partitions calculated: {num_partitions}")
-        spark.conf.set("spark.sql.execution.arrow.enabled", "true")
-        session_primitives = self._session.primitives
-        par_col_name: str = "aws_data_wrangler_internal_partition_id"
-        @pandas_udf(returnType="objects_paths string", functionType=PandasUDFType.GROUPED_MAP)
-        def write(pandas_dataframe: pd.DataFrame) -> pd.DataFrame:
-            # Exporting ARROW_PRE_0_15_IPC_FORMAT environment variable for
-            # a temporary workaround while waiting for Apache Arrow updates
-            # https://stackoverflow.com/questions/58273063/pandasudf-and-pyarrow-0-15-0
-            os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
+        generated_conn: bool = False
+        if type(connection) == str:
+            logger.debug("Glue connection (str) provided.")
+            connection = self._session.glue.get_connection(name=connection)
+            generated_conn = True
 
-            del pandas_dataframe[par_col_name]
-            paths: List[str] = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
-                                                                            path=path,
-                                                                            preserve_index=False,
-                                                                            mode="append",
-                                                                            procs_cpu_bound=1,
-                                                                            procs_io_bound=1,
-                                                                            cast_columns=casts)
-            return pd.DataFrame.from_dict({"objects_paths": paths})
+        try:
+            num_partitions: int
+            if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
+                num_partitions = 1
+            else:
+                num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
+                logger.debug(f"Number of slices on Redshift: {num_slices}")
+                num_partitions = num_slices
+                while num_partitions < min_num_partitions:
+                    num_partitions += num_slices
+            logger.debug(f"Number of partitions calculated: {num_partitions}")
+            spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+            session_primitives = self._session.primitives
+            par_col_name: str = "aws_data_wrangler_internal_partition_id"
-        df_objects_paths: DataFrame = dataframe.repartition(numPartitions=num_partitions)  # type: ignore
-        df_objects_paths: DataFrame = df_objects_paths.withColumn(par_col_name, spark_partition_id())  # type: ignore
-        df_objects_paths: DataFrame = df_objects_paths.groupby(par_col_name).apply(write)  # type: ignore
+            @pandas_udf(returnType="objects_paths string", functionType=PandasUDFType.GROUPED_MAP)
+            def write(pandas_dataframe: pd.DataFrame) -> pd.DataFrame:
+                # Exporting ARROW_PRE_0_15_IPC_FORMAT environment variable for
+                # a temporary workaround while waiting for Apache Arrow updates
+                # https://stackoverflow.com/questions/58273063/pandasudf-and-pyarrow-0-15-0
+                os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"
-        objects_paths: List[str] = list(df_objects_paths.toPandas()["objects_paths"])
-        dataframe.unpersist()
-        num_files_returned: int = len(objects_paths)
-        if num_files_returned != num_partitions:
-            raise MissingBatchDetected(f"{num_files_returned} files returned. {num_partitions} expected.")
-        logger.debug(f"List of objects returned: {objects_paths}")
-        logger.debug(f"Number of objects returned from UDF: {num_files_returned}")
-        manifest_path: str = f"{path}manifest.json"
-        self._session.redshift.write_load_manifest(manifest_path=manifest_path,
-                                                   objects_paths=objects_paths,
-                                                   procs_io_bound=self._procs_io_bound)
-        self._session.redshift.load_table(dataframe=dataframe,
-                                          dataframe_type="spark",
-                                          manifest_path=manifest_path,
-                                          schema_name=schema,
-                                          table_name=table,
-                                          redshift_conn=connection,
-                                          preserve_index=False,
-                                          num_files=num_partitions,
-                                          iam_role=iam_role,
-                                          diststyle=diststyle,
-                                          distkey=distkey,
-                                          sortstyle=sortstyle,
-                                          sortkey=sortkey,
-                                          mode=mode,
-                                          cast_columns=casts)
-        self._session.s3.delete_objects(path=path, procs_io_bound=self._procs_io_bound)
+                del pandas_dataframe[par_col_name]
+                paths: List[str] = session_primitives.session.pandas.to_parquet(dataframe=pandas_dataframe,
+                                                                                path=path,
+                                                                                preserve_index=False,
+                                                                                mode="append",
+                                                                                procs_cpu_bound=1,
+                                                                                procs_io_bound=1,
+                                                                                cast_columns=casts)
+                return pd.DataFrame.from_dict({"objects_paths": paths})
+
+            df_objects_paths: DataFrame = dataframe.repartition(numPartitions=num_partitions)  # type: ignore
+            df_objects_paths = df_objects_paths.withColumn(par_col_name, spark_partition_id())  # type: ignore
+            df_objects_paths = df_objects_paths.groupby(par_col_name).apply(write)  # type: ignore
+
+            objects_paths: List[str] = list(df_objects_paths.toPandas()["objects_paths"])
+            dataframe.unpersist()
+            num_files_returned: int = len(objects_paths)
+            if num_files_returned != num_partitions:
+                raise MissingBatchDetected(f"{num_files_returned} files returned. {num_partitions} expected.")
+            logger.debug(f"List of objects returned: {objects_paths}")
+            logger.debug(f"Number of objects returned from UDF: {num_files_returned}")
+            manifest_path: str = f"{path}manifest.json"
+            self._session.redshift.write_load_manifest(manifest_path=manifest_path,
+                                                       objects_paths=objects_paths,
+                                                       procs_io_bound=self._procs_io_bound)
+            self._session.redshift.load_table(dataframe=dataframe,
+                                              dataframe_type="spark",
+                                              manifest_path=manifest_path,
+                                              schema_name=schema,
+                                              table_name=table,
+                                              redshift_conn=connection,
+                                              preserve_index=False,
+                                              num_files=num_partitions,
+                                              iam_role=iam_role,
+                                              diststyle=diststyle,
+                                              distkey=distkey,
+                                              sortstyle=sortstyle,
+                                              sortkey=sortkey,
+                                              mode=mode,
+                                              cast_columns=casts)
+            self._session.s3.delete_objects(path=path, procs_io_bound=self._procs_io_bound)
+        except Exception as ex:
+            connection.rollback()
+            if generated_conn is True:
+                connection.close()
+            raise ex
+        if generated_conn is True:
+            connection.close()
 
     def create_glue_table(self,
                           database,
diff --git a/docs/source/index.rst b/docs/source/index.rst
index dae71da2f..fc76c3327 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -8,7 +8,7 @@
     :alt: alternate text
     :figclass: align-center
 
-*Utility belt to handle data on AWS.*
+*DataFrames on AWS.*
 
 `Read the Tutorials `_: `Catalog & Metadata `_ | `Athena Nested `_ | `S3 Write Modes `_
 
diff --git a/requirements.txt b/requirements.txt
index 9c3015f9f..2bb8b8bb0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
 numpy~=1.18.1
 pandas~=0.25.3
 pyarrow~=0.15.1
-botocore~=1.14.1
-boto3~=1.11.1
+botocore~=1.14.2
+boto3~=1.11.2
 s3fs~=0.4.0
 tenacity~=6.0.0
 pg8000~=1.13.2
diff --git a/testing/test_awswrangler/test_redshift.py b/testing/test_awswrangler/test_redshift.py
index f4b74a308..8670e8bdc 100644
--- a/testing/test_awswrangler/test_redshift.py
+++ b/testing/test_awswrangler/test_redshift.py
@@ -347,7 +347,7 @@ def test_to_redshift_spark_bool(session, bucket, redshift_parameters):
     session.spark.to_redshift(
         dataframe=dataframe,
         path=f"s3://{bucket}/redshift-load-bool/",
-        connection=con,
+        connection="aws-data-wrangler-redshift",
         schema="public",
         table="test",
         iam_role=redshift_parameters.get("RedshiftRole"),
@@ -722,3 +722,33 @@ def test_to_redshift_pandas_upsert(session, bucket, redshift_parameters):
     wr.s3.delete_objects(path=f"s3://{bucket}/")
     con.close()
+
+
+@pytest.mark.parametrize("sample_name", ["micro", "small", "nano"])
+def test_read_sql_redshift_pandas_glue_conn(session, bucket, redshift_parameters, sample_name):
+    if sample_name == "micro":
+        dates = ["date"]
+    elif sample_name == "small":
+        dates = ["date"]
+    else:
+        dates = ["date", "time"]
+    df = pd.read_csv(f"data_samples/{sample_name}.csv", parse_dates=dates, infer_datetime_format=True)
+    df["date"] = df["date"].dt.date
+    path = f"s3://{bucket}/test_read_sql_redshift_pandas_glue_conn/"
+    session.pandas.to_redshift(
+        dataframe=df,
+        path=path,
+        schema="public",
+        table="test",
+        connection="aws-data-wrangler-redshift",
+        iam_role=redshift_parameters.get("RedshiftRole"),
+        mode="overwrite",
+        preserve_index=True,
+    )
+    path2 = f"s3://{bucket}/test_read_sql_redshift_pandas_glue_conn2/"
+    df2 = session.pandas.read_sql_redshift(sql="select * from public.test",
+                                           iam_role=redshift_parameters.get("RedshiftRole"),
+                                           connection="aws-data-wrangler-redshift",
+                                           temp_s3_path=path2)
+    assert len(df.index) == len(df2.index)
+    assert len(df.columns) + 1 == len(df2.columns)
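
The sketch below illustrates the Glue-connection-name path added in this diff: passing the name of an AWS Glue connection as a string instead of a PEP 249 connection to session.pandas.to_redshift and session.pandas.read_sql_redshift. It is a minimal example modeled on the new test above, not a definitive reference; the bucket, table, IAM role ARN, and the Glue connection name "aws-data-wrangler-redshift" are illustrative placeholders, and the awswrangler.Session() setup assumes default AWS credentials and region.

    import awswrangler
    import pandas as pd

    # Assumes default credentials/region; mirrors the "session" fixture used in the tests.
    session = awswrangler.Session()

    df = pd.DataFrame({"id": [1, 2, 3], "name": ["foo", "boo", "bar"]})

    # Write: connection is now a Glue connection name (str); the session resolves it
    # via glue.get_connection() and closes the generated connection when done.
    session.pandas.to_redshift(
        dataframe=df,
        path="s3://my-bucket/temp-redshift-load/",          # hypothetical staging path
        schema="public",
        table="my_table",                                   # hypothetical table
        connection="aws-data-wrangler-redshift",            # Glue connection name
        iam_role="arn:aws:iam::123456789012:role/my-role",  # placeholder role ARN
        mode="overwrite",
    )

    # Read back with the same string-based connection; temporary query results are
    # staged under temp_s3_path and cleaned up afterwards.
    df2 = session.pandas.read_sql_redshift(
        sql="SELECT * FROM public.my_table",
        iam_role="arn:aws:iam::123456789012:role/my-role",  # placeholder role ARN
        connection="aws-data-wrangler-redshift",
        temp_s3_path="s3://my-bucket/temp-redshift-unload/",
    )

session.spark.to_redshift accepts the same string form of connection, as exercised by test_to_redshift_spark_bool above.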