From f8590a12a46d0cae8d4c08331bb1447ccc524a90 Mon Sep 17 00:00:00 2001 From: kukushking Date: Tue, 11 Jul 2023 17:09:12 +0100 Subject: [PATCH] feat: PostgreSql improvements - escape and validate table identifiers and literals (#2390) * feat: PostgreSql improvements - escape and validate table identifiers and literals * fix: Use literals when querying information_schema --- awswrangler/postgresql.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/awswrangler/postgresql.py b/awswrangler/postgresql.py index fd7a89903..7b67d6cf2 100644 --- a/awswrangler/postgresql.py +++ b/awswrangler/postgresql.py @@ -15,6 +15,7 @@ from awswrangler._config import apply_configs pg8000 = _utils.import_optional_dependency("pg8000") +pg8000_native = _utils.import_optional_dependency("pg8000.native") _logger: logging.Logger = logging.getLogger(__name__) @@ -29,18 +30,18 @@ def _validate_connection(con: "pg8000.Connection") -> None: def _drop_table(cursor: "pg8000.Cursor", schema: Optional[str], table: str) -> None: - schema_str = f'"{schema}".' if schema else "" - sql = f'DROP TABLE IF EXISTS {schema_str}"{table}"' + schema_str = f"{pg8000_native.identifier(schema)}." if schema else "" + sql = f"DROP TABLE IF EXISTS {schema_str}{pg8000_native.identifier(table)}" _logger.debug("Drop table query:\n%s", sql) cursor.execute(sql) def _does_table_exist(cursor: "pg8000.Cursor", schema: Optional[str], table: str) -> bool: - schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else "" + schema_str = f"TABLE_SCHEMA = {pg8000_native.literal(schema)} AND" if schema else "" cursor.execute( f"SELECT true WHERE EXISTS (" f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " - f"{schema_str} TABLE_NAME = '{table}'" + f"{schema_str} TABLE_NAME = {pg8000_native.literal(table)}" f");" ) return len(cursor.fetchall()) > 0 @@ -69,7 +70,7 @@ def _create_table( converter_func=_data_types.pyarrow2postgresql, ) cols_str: str = "".join([f'"{k}" {v},\n' for k, v in postgresql_types.items()])[:-2] - sql = f'CREATE TABLE IF NOT EXISTS "{schema}"."{table}" (\n{cols_str})' + sql = f"CREATE TABLE IF NOT EXISTS {pg8000_native.identifier(schema)}.{pg8000_native.identifier(table)} (\n{cols_str})" _logger.debug("Create table query:\n%s", sql) cursor.execute(sql) @@ -94,12 +95,16 @@ def _iterate_server_side_cursor( """ with con.cursor() as cursor: sscursor_name: str = f"c_{uuid.uuid4().hex}" - cursor_args = _db_utils._convert_params(f"DECLARE {sscursor_name} CURSOR FOR {sql}", params) + cursor_args = _db_utils._convert_params( + f"DECLARE {pg8000_native.identifier(sscursor_name)} CURSOR FOR {sql}", params + ) cursor.execute(*cursor_args) try: while True: - cursor.execute(f"FETCH FORWARD {chunksize} FROM {sscursor_name}") + cursor.execute( + f"FETCH FORWARD {pg8000_native.literal(chunksize)} FROM {pg8000_native.identifier(sscursor_name)}" + ) records = cursor.fetchall() if not records: @@ -115,7 +120,7 @@ def _iterate_server_side_cursor( dtype_backend=dtype_backend, ) finally: - cursor.execute(f"CLOSE {sscursor_name}") + cursor.execute(f"CLOSE {pg8000_native.identifier(sscursor_name)}") @_utils.check_optional_dependency(pg8000, "pg8000") @@ -458,7 +463,11 @@ def read_sql_table( >>> con.close() """ - sql: str = f'SELECT * FROM "{table}"' if schema is None else f'SELECT * FROM "{schema}"."{table}"' + sql: str = ( + f"SELECT * FROM {pg8000_native.identifier(table)}" + if schema is None + else f"SELECT * FROM {pg8000_native.identifier(schema)}.{pg8000_native.identifier(table)}" + ) return read_sql_query( sql=sql, con=con, @@ -591,7 +600,7 @@ def to_sql( df=df, column_placeholders=column_placeholders, chunksize=chunksize ) for placeholders, parameters in placeholder_parameter_pair_generator: - sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES {placeholders}{upsert_str}' + sql: str = f"INSERT INTO {pg8000_native.identifier(schema)}.{pg8000_native.identifier(table)} {insertion_columns} VALUES {placeholders}{upsert_str}" _logger.debug("sql: %s", sql) cursor.executemany(sql, (parameters,)) con.commit()