Skip to content

Commit

Permalink
feat: PostgreSql improvements - escape and validate table identifiers…
Browse files Browse the repository at this point in the history
… and literals (#2390)

* feat: PostgreSql improvements - escape and validate table identifiers and literals

* fix: Use literals when querying information_schema
  • Loading branch information
kukushking authored Jul 11, 2023
1 parent 95b9715 commit f8590a1
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions awswrangler/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from awswrangler._config import apply_configs

pg8000 = _utils.import_optional_dependency("pg8000")
pg8000_native = _utils.import_optional_dependency("pg8000.native")

_logger: logging.Logger = logging.getLogger(__name__)

Expand All @@ -29,18 +30,18 @@ def _validate_connection(con: "pg8000.Connection") -> None:


def _drop_table(cursor: "pg8000.Cursor", schema: Optional[str], table: str) -> None:
schema_str = f'"{schema}".' if schema else ""
sql = f'DROP TABLE IF EXISTS {schema_str}"{table}"'
schema_str = f"{pg8000_native.identifier(schema)}." if schema else ""
sql = f"DROP TABLE IF EXISTS {schema_str}{pg8000_native.identifier(table)}"
_logger.debug("Drop table query:\n%s", sql)
cursor.execute(sql)


def _does_table_exist(cursor: "pg8000.Cursor", schema: Optional[str], table: str) -> bool:
schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else ""
schema_str = f"TABLE_SCHEMA = {pg8000_native.literal(schema)} AND" if schema else ""
cursor.execute(
f"SELECT true WHERE EXISTS ("
f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE "
f"{schema_str} TABLE_NAME = '{table}'"
f"{schema_str} TABLE_NAME = {pg8000_native.literal(table)}"
f");"
)
return len(cursor.fetchall()) > 0
Expand Down Expand Up @@ -69,7 +70,7 @@ def _create_table(
converter_func=_data_types.pyarrow2postgresql,
)
cols_str: str = "".join([f'"{k}" {v},\n' for k, v in postgresql_types.items()])[:-2]
sql = f'CREATE TABLE IF NOT EXISTS "{schema}"."{table}" (\n{cols_str})'
sql = f"CREATE TABLE IF NOT EXISTS {pg8000_native.identifier(schema)}.{pg8000_native.identifier(table)} (\n{cols_str})"
_logger.debug("Create table query:\n%s", sql)
cursor.execute(sql)

Expand All @@ -94,12 +95,16 @@ def _iterate_server_side_cursor(
"""
with con.cursor() as cursor:
sscursor_name: str = f"c_{uuid.uuid4().hex}"
cursor_args = _db_utils._convert_params(f"DECLARE {sscursor_name} CURSOR FOR {sql}", params)
cursor_args = _db_utils._convert_params(
f"DECLARE {pg8000_native.identifier(sscursor_name)} CURSOR FOR {sql}", params
)
cursor.execute(*cursor_args)

try:
while True:
cursor.execute(f"FETCH FORWARD {chunksize} FROM {sscursor_name}")
cursor.execute(
f"FETCH FORWARD {pg8000_native.literal(chunksize)} FROM {pg8000_native.identifier(sscursor_name)}"
)
records = cursor.fetchall()

if not records:
Expand All @@ -115,7 +120,7 @@ def _iterate_server_side_cursor(
dtype_backend=dtype_backend,
)
finally:
cursor.execute(f"CLOSE {sscursor_name}")
cursor.execute(f"CLOSE {pg8000_native.identifier(sscursor_name)}")


@_utils.check_optional_dependency(pg8000, "pg8000")
Expand Down Expand Up @@ -458,7 +463,11 @@ def read_sql_table(
>>> con.close()
"""
sql: str = f'SELECT * FROM "{table}"' if schema is None else f'SELECT * FROM "{schema}"."{table}"'
sql: str = (
f"SELECT * FROM {pg8000_native.identifier(table)}"
if schema is None
else f"SELECT * FROM {pg8000_native.identifier(schema)}.{pg8000_native.identifier(table)}"
)
return read_sql_query(
sql=sql,
con=con,
Expand Down Expand Up @@ -591,7 +600,7 @@ def to_sql(
df=df, column_placeholders=column_placeholders, chunksize=chunksize
)
for placeholders, parameters in placeholder_parameter_pair_generator:
sql: str = f'INSERT INTO "{schema}"."{table}" {insertion_columns} VALUES {placeholders}{upsert_str}'
sql: str = f"INSERT INTO {pg8000_native.identifier(schema)}.{pg8000_native.identifier(table)} {insertion_columns} VALUES {placeholders}{upsert_str}"
_logger.debug("sql: %s", sql)
cursor.executemany(sql, (parameters,))
con.commit()
Expand Down

0 comments on commit f8590a1

Please sign in to comment.