Skip to content

Commit

Permalink
Quote identifiers and literals in SQL db/__init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
index-git committed Jan 27, 2023
1 parent bd08081 commit 309a050
Showing 1 changed file with 62 additions and 33 deletions.
95 changes: 62 additions & 33 deletions src/layman/layer/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ def get_workspaces(conn_cur=None):
conn_cur = db_util.get_connection_cursor()
_, cur = conn_cur

try:
cur.execute(f"""select schema_name
query = sql.SQL("""select schema_name
from information_schema.schemata
where schema_name NOT IN ('{"', '".join(settings.PG_NON_USER_SCHEMAS)}\
') AND schema_owner = '{settings.LAYMAN_PG_USER}'""")
where schema_name NOT IN ({schemas}) AND schema_owner = {layman_pg_user}""").format(
schemas=sql.SQL(', ').join([sql.Literal(schema) for schema in settings.PG_NON_USER_SCHEMAS]),
layman_pg_user=sql.Literal(settings.LAYMAN_PG_USER),
)
try:
cur.execute(query)
except BaseException as exc:
logger.error(f'get_workspaces ERROR')
raise LaymanError(7) from exc
Expand All @@ -60,9 +63,12 @@ def ensure_workspace(workspace, conn_cur=None):
conn_cur = db_util.get_connection_cursor()
conn, cur = conn_cur

statement = sql.SQL("""CREATE SCHEMA IF NOT EXISTS {schema} AUTHORIZATION {user}""").format(
schema=sql.Identifier(workspace),
user=sql.Identifier(settings.LAYMAN_PG_USER),
)
try:
cur.execute(
f"""CREATE SCHEMA IF NOT EXISTS "{workspace}" AUTHORIZATION {settings.LAYMAN_PG_USER}""")
cur.execute(statement)
conn.commit()
except BaseException as exc:
logger.error(f'ensure_workspace ERROR')
Expand All @@ -74,9 +80,11 @@ def delete_workspace(workspace, conn_cur=None):
conn_cur = db_util.get_connection_cursor()
conn, cur = conn_cur

statement = sql.SQL("""DROP SCHEMA IF EXISTS {schema} RESTRICT""").format(
schema=sql.Identifier(workspace),
)
try:
cur.execute(
f"""DROP SCHEMA IF EXISTS "{workspace}" RESTRICT""")
cur.execute(statement, (workspace, ))
conn.commit()
except BaseException as exc:
logger.error(f'delete_workspace ERROR')
Expand Down Expand Up @@ -187,14 +195,15 @@ def import_layer_vector_file_async(schema, table_name, main_filepath,
def get_text_column_names(schema, table_name, conn_cur=None):
_, cur = conn_cur or db_util.get_connection_cursor()

try:
cur.execute(f"""
statement = """
SELECT QUOTE_IDENT(column_name) AS column_name
FROM information_schema.columns
WHERE table_schema = '{schema}'
AND table_name = '{table_name}'
WHERE table_schema = %s
AND table_name = %s
AND data_type IN ('character varying', 'varchar', 'character', 'char', 'text')
""")
"""
try:
cur.execute(statement, (schema, table_name))
except BaseException as exc:
logger.error(f'get_text_column_names ERROR')
raise LaymanError(7) from exc
Expand Down Expand Up @@ -234,11 +243,14 @@ def get_all_column_infos(schema, table_name, *, conn_cur=None, omit_geometry_col
def get_number_of_features(schema, table_name, conn_cur=None):
_, cur = conn_cur or db_util.get_connection_cursor()

try:
cur.execute(f"""
statement = sql.SQL("""
select count(*)
from {schema}.{table_name}
""")
from {table}
""").format(
table=sql.Identifier(schema, table_name),
)
try:
cur.execute(statement)
except BaseException as exc:
logger.error(f'get_number_of_features ERROR')
raise LaymanError(7) from exc
Expand All @@ -255,13 +267,18 @@ def get_text_data(schema, table_name, conn_cur=None):
if num_features == 0:
return [], 0
limit = max(100, num_features // 10)
try:
cur.execute(f"""
select {', '.join(col_names)}
from {schema}.{table_name}
statement = sql.SQL("""
select {fields}
from {table}
order by ogc_fid
limit {limit}
""")
""").format(
fields=sql.SQL(',').join([sql.Identifier(col) for col in col_names]),
table=sql.Identifier(schema, table_name),
limit=sql.Literal(limit),
)
try:
cur.execute(statement)
except BaseException as exc:
logger.error(f'get_text_data ERROR')
raise LaymanError(7) from exc
Expand Down Expand Up @@ -296,15 +313,15 @@ def get_text_languages(schema, table_name, *, conn_cur=None):


def get_most_frequent_lower_distance_query(schema, table_name):
query = f"""
query = sql.SQL("""
with t1 as (
select
row_number() over (partition by ogc_fid) AS dump_id,
sub_view.*
from (
SELECT
ogc_fid, (st_dump(wkb_geometry)).geom as geometry
FROM {{schema}}.{{table_name}}
FROM {table}
) sub_view
order by ST_NPoints(geometry), ogc_fid, dump_id
limit 5000
Expand Down Expand Up @@ -383,11 +400,9 @@ def get_most_frequent_lower_distance_query(schema, table_name):
from tfreq, tstat
order by freq desc
limit 1
"""

query = query.format(schema=schema,
table_name=table_name,
)
""").format(
table=sql.Identifier(schema, table_name),
)
return query


Expand Down Expand Up @@ -449,7 +464,14 @@ def guess_scale_denominator(schema, table_name, *, conn_cur=None):

def create_string_attributes(attribute_tuples, conn_cur=None):
_, cur = conn_cur or db_util.get_connection_cursor()
query = "\n".join([f"""ALTER TABLE {schema}.{table} ADD COLUMN {attrname} VARCHAR(1024);""" for schema, _, table, attrname in attribute_tuples]) + "\n COMMIT;"
query = sql.SQL('{alters} \n COMMIT;').format(
alters=sql.SQL('\n').join(
[sql.SQL("""ALTER TABLE {table} ADD COLUMN {fattrname} VARCHAR(1024);""").format(
table=sql.Identifier(schema, table_name),
fattrname=sql.Identifier(attrname),
) for schema, _, table_name, attrname in attribute_tuples]
)
)
try:
cur.execute(query)
except BaseException as exc:
Expand All @@ -463,12 +485,19 @@ def get_missing_attributes(attribute_tuples, conn_cur=None):
table_name = {(workspace, layer): get_table_name(workspace, layer) for workspace, layer, _ in attribute_tuples}

# Find all foursomes which do not already exist
query = f"""select attribs.*
from (""" + "\n union all\n".join([f"select '{workspace}' workspace, '{layername}' layername, '{table_name[(workspace, layername)]}' table_name, '{attrname}' attrname" for workspace, layername, attrname in attribute_tuples]) + """) attribs left join
query = sql.SQL("""select attribs.*
from ({selects}) attribs left join
information_schema.columns c on c.table_schema = attribs.workspace
and c.table_name = attribs.table_name
and c.column_name = attribs.attrname
where c.column_name is null"""
where c.column_name is null""").format(
selects=sql.SQL("\n union all\n").join([sql.SQL("select '{fworkspace}' workspace, '{flayername}' layername, '{ftablename}' table_name, '{fattrname}' attrname").format(
fworkspace=sql.Identifier(workspace),
flayername=sql.Identifier(layername),
ftablename=sql.Identifier(table_name[(workspace, layername)]),
fattrname=sql.Identifier(attrname),
) for workspace, layername, attrname in attribute_tuples])
)

try:
if attribute_tuples:
Expand Down

0 comments on commit 309a050

Please sign in to comment.