From f186f68a4fa164ac90fc19b2768930b0fe805837 Mon Sep 17 00:00:00 2001 From: index-git Date: Thu, 26 Jan 2023 15:34:51 +0100 Subject: [PATCH] Quote identifiers and literals in SQL db/__init__.py --- src/layman/layer/db/__init__.py | 97 +++++++++++++++++++++------------ 1 file changed, 63 insertions(+), 34 deletions(-) diff --git a/src/layman/layer/db/__init__.py b/src/layman/layer/db/__init__.py index df1b2410a..453244092 100644 --- a/src/layman/layer/db/__init__.py +++ b/src/layman/layer/db/__init__.py @@ -32,11 +32,14 @@ def get_workspaces(conn_cur=None): conn_cur = db_util.get_connection_cursor() _, cur = conn_cur - try: - cur.execute(f"""select schema_name + query = sql.SQL("""select schema_name from information_schema.schemata - where schema_name NOT IN ('{"', '".join(settings.PG_NON_USER_SCHEMAS)}\ -') AND schema_owner = '{settings.LAYMAN_PG_USER}'""") + where schema_name NOT IN ({schemas}) AND schema_owner = {layman_pg_user}""").format( + schemas=sql.SQL(', ').join([sql.Literal(schema) for schema in settings.PG_NON_USER_SCHEMAS]), + layman_pg_user=sql.Literal(settings.LAYMAN_PG_USER), + ) + try: + cur.execute(query) except BaseException as exc: logger.error(f'get_workspaces ERROR') raise LaymanError(7) from exc @@ -60,9 +63,12 @@ def ensure_workspace(workspace, conn_cur=None): conn_cur = db_util.get_connection_cursor() conn, cur = conn_cur + statement = sql.SQL("""CREATE SCHEMA IF NOT EXISTS {schema} AUTHORIZATION {user}""").format( + schema=sql.Identifier(workspace), + user=sql.Identifier(settings.LAYMAN_PG_USER), + ) try: - cur.execute( - f"""CREATE SCHEMA IF NOT EXISTS "{workspace}" AUTHORIZATION {settings.LAYMAN_PG_USER}""") + cur.execute(statement) conn.commit() except BaseException as exc: logger.error(f'ensure_workspace ERROR') @@ -74,9 +80,11 @@ def delete_workspace(workspace, conn_cur=None): conn_cur = db_util.get_connection_cursor() conn, cur = conn_cur + statement = sql.SQL("""DROP SCHEMA IF EXISTS {schema} RESTRICT""").format( + schema=sql.Identifier(workspace), + ) try: - cur.execute( - f"""DROP SCHEMA IF EXISTS "{workspace}" RESTRICT""") + cur.execute(statement, (workspace, )) conn.commit() except BaseException as exc: logger.error(f'delete_workspace ERROR') @@ -187,14 +195,15 @@ def import_layer_vector_file_async(schema, table_name, main_filepath, def get_text_column_names(schema, table_name, conn_cur=None): _, cur = conn_cur or db_util.get_connection_cursor() - try: - cur.execute(f""" -SELECT QUOTE_IDENT(column_name) AS column_name + statement = """ +SELECT column_name FROM information_schema.columns -WHERE table_schema = '{schema}' -AND table_name = '{table_name}' +WHERE table_schema = %s +AND table_name = %s AND data_type IN ('character varying', 'varchar', 'character', 'char', 'text') -""") +""" + try: + cur.execute(statement, (schema, table_name)) except BaseException as exc: logger.error(f'get_text_column_names ERROR') raise LaymanError(7) from exc @@ -234,11 +243,14 @@ def get_all_column_infos(schema, table_name, *, conn_cur=None, omit_geometry_col def get_number_of_features(schema, table_name, conn_cur=None): _, cur = conn_cur or db_util.get_connection_cursor() - try: - cur.execute(f""" + statement = sql.SQL(""" select count(*) -from {schema}.{table_name} -""") +from {table} +""").format( + table=sql.Identifier(schema, table_name), + ) + try: + cur.execute(statement) except BaseException as exc: logger.error(f'get_number_of_features ERROR') raise LaymanError(7) from exc @@ -255,13 +267,18 @@ def get_text_data(schema, table_name, conn_cur=None): if num_features == 0: return [], 0 limit = max(100, num_features // 10) - try: - cur.execute(f""" -select {', '.join(col_names)} -from {schema}.{table_name} + statement = sql.SQL(""" +select {fields} +from {table} order by ogc_fid limit {limit} -""") +""").format( + fields=sql.SQL(',').join([sql.Identifier(col) for col in col_names]), + table=sql.Identifier(schema, table_name), + limit=sql.Literal(limit), + ) + try: + cur.execute(statement) except BaseException as exc: logger.error(f'get_text_data ERROR') raise LaymanError(7) from exc @@ -296,7 +313,7 @@ def get_text_languages(schema, table_name, *, conn_cur=None): def get_most_frequent_lower_distance_query(schema, table_name): - query = f""" + query = sql.SQL(""" with t1 as ( select row_number() over (partition by ogc_fid) AS dump_id, @@ -304,7 +321,7 @@ def get_most_frequent_lower_distance_query(schema, table_name): from ( SELECT ogc_fid, (st_dump(wkb_geometry)).geom as geometry - FROM {{schema}}.{{table_name}} + FROM {table} ) sub_view order by ST_NPoints(geometry), ogc_fid, dump_id limit 5000 @@ -383,11 +400,9 @@ def get_most_frequent_lower_distance_query(schema, table_name): from tfreq, tstat order by freq desc limit 1 - """ - - query = query.format(schema=schema, - table_name=table_name, - ) + """).format( + table=sql.Identifier(schema, table_name), + ) return query @@ -449,7 +464,14 @@ def guess_scale_denominator(schema, table_name, *, conn_cur=None): def create_string_attributes(attribute_tuples, conn_cur=None): _, cur = conn_cur or db_util.get_connection_cursor() - query = "\n".join([f"""ALTER TABLE {schema}.{table} ADD COLUMN {attrname} VARCHAR(1024);""" for schema, _, table, attrname in attribute_tuples]) + "\n COMMIT;" + query = sql.SQL('{alters} \n COMMIT;').format( + alters=sql.SQL('\n').join( + [sql.SQL("""ALTER TABLE {table} ADD COLUMN {fattrname} VARCHAR(1024);""").format( + table=sql.Identifier(schema, table_name), + fattrname=sql.Identifier(attrname), + ) for schema, _, table_name, attrname in attribute_tuples] + ) + ) try: cur.execute(query) except BaseException as exc: @@ -463,12 +485,19 @@ def get_missing_attributes(attribute_tuples, conn_cur=None): table_name = {(workspace, layer): get_table_name(workspace, layer) for workspace, layer, _ in attribute_tuples} # Find all foursomes which do not already exist - query = f"""select attribs.* -from (""" + "\n union all\n".join([f"select '{workspace}' workspace, '{layername}' layername, '{table_name[(workspace, layername)]}' table_name, '{attrname}' attrname" for workspace, layername, attrname in attribute_tuples]) + """) attribs left join + query = sql.SQL("""select attribs.* +from ({selects}) attribs left join information_schema.columns c on c.table_schema = attribs.workspace and c.table_name = attribs.table_name and c.column_name = attribs.attrname -where c.column_name is null""" +where c.column_name is null""").format( + selects=sql.SQL("\n union all\n").join([sql.SQL("select {fworkspace} workspace, {flayername} layername, {ftablename} table_name, {fattrname} attrname").format( + fworkspace=sql.Literal(workspace), + flayername=sql.Literal(layername), + ftablename=sql.Literal(table_name[(workspace, layername)]), + fattrname=sql.Literal(attrname), + ) for workspace, layername, attrname in attribute_tuples]) + ) try: if attribute_tuples: