Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

703 Quote identifiers and literals in SQL db/__init__.py #760

Merged
merged 2 commits into from
Jan 27, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 67 additions & 45 deletions src/layman/layer/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ def get_workspaces(conn_cur=None):
conn_cur = db_util.get_connection_cursor()
_, cur = conn_cur

try:
cur.execute(f"""select schema_name
query = sql.SQL("""select schema_name
from information_schema.schemata
where schema_name NOT IN ('{"', '".join(settings.PG_NON_USER_SCHEMAS)}\
') AND schema_owner = '{settings.LAYMAN_PG_USER}'""")
where schema_name NOT IN ({schemas}) AND schema_owner = {layman_pg_user}""").format(
schemas=sql.SQL(', ').join([sql.Literal(schema) for schema in settings.PG_NON_USER_SCHEMAS]),
layman_pg_user=sql.Literal(settings.LAYMAN_PG_USER),
)
try:
cur.execute(query)
except BaseException as exc:
logger.error(f'get_workspaces ERROR')
raise LaymanError(7) from exc
Expand All @@ -60,9 +63,12 @@ def ensure_workspace(workspace, conn_cur=None):
conn_cur = db_util.get_connection_cursor()
conn, cur = conn_cur

statement = sql.SQL("""CREATE SCHEMA IF NOT EXISTS {schema} AUTHORIZATION {user}""").format(
schema=sql.Identifier(workspace),
user=sql.Identifier(settings.LAYMAN_PG_USER),
)
try:
cur.execute(
f"""CREATE SCHEMA IF NOT EXISTS "{workspace}" AUTHORIZATION {settings.LAYMAN_PG_USER}""")
cur.execute(statement)
conn.commit()
except BaseException as exc:
logger.error(f'ensure_workspace ERROR')
Expand All @@ -74,9 +80,11 @@ def delete_workspace(workspace, conn_cur=None):
conn_cur = db_util.get_connection_cursor()
conn, cur = conn_cur

statement = sql.SQL("""DROP SCHEMA IF EXISTS {schema} RESTRICT""").format(
schema=sql.Identifier(workspace),
)
try:
cur.execute(
f"""DROP SCHEMA IF EXISTS "{workspace}" RESTRICT""")
cur.execute(statement, (workspace, ))
conn.commit()
except BaseException as exc:
logger.error(f'delete_workspace ERROR')
Expand Down Expand Up @@ -187,14 +195,15 @@ def import_layer_vector_file_async(schema, table_name, main_filepath,
def get_text_column_names(schema, table_name, conn_cur=None):
_, cur = conn_cur or db_util.get_connection_cursor()

try:
cur.execute(f"""
SELECT QUOTE_IDENT(column_name) AS column_name
statement = """
SELECT column_name
FROM information_schema.columns
WHERE table_schema = '{schema}'
AND table_name = '{table_name}'
WHERE table_schema = %s
AND table_name = %s
AND data_type IN ('character varying', 'varchar', 'character', 'char', 'text')
""")
"""
try:
cur.execute(statement, (schema, table_name))
except BaseException as exc:
logger.error(f'get_text_column_names ERROR')
raise LaymanError(7) from exc
Expand Down Expand Up @@ -234,11 +243,14 @@ def get_all_column_infos(schema, table_name, *, conn_cur=None, omit_geometry_col
def get_number_of_features(schema, table_name, conn_cur=None):
_, cur = conn_cur or db_util.get_connection_cursor()

try:
cur.execute(f"""
statement = sql.SQL("""
select count(*)
from {schema}.{table_name}
""")
from {table}
""").format(
table=sql.Identifier(schema, table_name),
)
try:
cur.execute(statement)
except BaseException as exc:
logger.error(f'get_number_of_features ERROR')
raise LaymanError(7) from exc
Expand All @@ -255,13 +267,18 @@ def get_text_data(schema, table_name, conn_cur=None):
if num_features == 0:
return [], 0
limit = max(100, num_features // 10)
try:
cur.execute(f"""
select {', '.join(col_names)}
from {schema}.{table_name}
statement = sql.SQL("""
select {fields}
from {table}
order by ogc_fid
limit {limit}
""")
""").format(
fields=sql.SQL(',').join([sql.Identifier(col) for col in col_names]),
table=sql.Identifier(schema, table_name),
limit=sql.Literal(limit),
)
try:
cur.execute(statement)
except BaseException as exc:
logger.error(f'get_text_data ERROR')
raise LaymanError(7) from exc
Expand Down Expand Up @@ -295,18 +312,18 @@ def get_text_languages(schema, table_name, *, conn_cur=None):
return sorted(list(all_langs))


def get_most_frequent_lower_distance_query(schema, table_name, order_by_methods):
query = f"""
def get_most_frequent_lower_distance_query(schema, table_name):
query = sql.SQL("""
with t1 as (
select
row_number() over (partition by ogc_fid) AS dump_id,
sub_view.*
from (
SELECT
ogc_fid, (st_dump(wkb_geometry)).geom as geometry
FROM {{schema}}.{{table_name}}
FROM {table}
) sub_view
order by {{order_by_prefix}}geometry{{order_by_suffix}}, ogc_fid, dump_id
order by ST_NPoints(geometry), ogc_fid, dump_id
limit 5000
)
, t2 as (
Expand All @@ -326,7 +343,7 @@ def get_most_frequent_lower_distance_query(schema, table_name, order_by_methods)
where st_geometrytype(geometry) = 'ST_LineString'
)
) sub_view
order by {{order_by_prefix}}geometry{{order_by_suffix}}, ogc_fid, dump_id, ring_id
order by ST_NPoints(geometry), ogc_fid, dump_id, ring_id
limit 5000
)
, t2cumsum as (
Expand Down Expand Up @@ -383,25 +400,16 @@ def get_most_frequent_lower_distance_query(schema, table_name, order_by_methods)
from tfreq, tstat
order by freq desc
limit 1
"""

order_by_prefix = ''.join([f"{method}(" for method in order_by_methods])
order_by_suffix = ')' * len(order_by_methods)

query = query.format(schema=schema,
table_name=table_name,
order_by_prefix=order_by_prefix,
order_by_suffix=order_by_suffix,
)
""").format(
table=sql.Identifier(schema, table_name),
)
return query


def get_most_frequent_lower_distance(schema, table_name, conn_cur=None):
_, cur = conn_cur or db_util.get_connection_cursor()

query = get_most_frequent_lower_distance_query(schema, table_name, [
'ST_NPoints'
])
query = get_most_frequent_lower_distance_query(schema, table_name)

# print(f"\nget_most_frequent_lower_distance v1\nusername={username}, layername={layername}")
# print(query)
Expand Down Expand Up @@ -456,7 +464,14 @@ def guess_scale_denominator(schema, table_name, *, conn_cur=None):

def create_string_attributes(attribute_tuples, conn_cur=None):
_, cur = conn_cur or db_util.get_connection_cursor()
query = "\n".join([f"""ALTER TABLE {schema}.{table} ADD COLUMN {attrname} VARCHAR(1024);""" for schema, _, table, attrname in attribute_tuples]) + "\n COMMIT;"
query = sql.SQL('{alters} \n COMMIT;').format(
alters=sql.SQL('\n').join(
[sql.SQL("""ALTER TABLE {table} ADD COLUMN {fattrname} VARCHAR(1024);""").format(
table=sql.Identifier(schema, table_name),
fattrname=sql.Identifier(attrname),
) for schema, _, table_name, attrname in attribute_tuples]
)
)
try:
cur.execute(query)
except BaseException as exc:
Expand All @@ -470,12 +485,19 @@ def get_missing_attributes(attribute_tuples, conn_cur=None):
table_name = {(workspace, layer): get_table_name(workspace, layer) for workspace, layer, _ in attribute_tuples}

# Find all foursomes which do not already exist
query = f"""select attribs.*
from (""" + "\n union all\n".join([f"select '{workspace}' workspace, '{layername}' layername, '{table_name[(workspace, layername)]}' table_name, '{attrname}' attrname" for workspace, layername, attrname in attribute_tuples]) + """) attribs left join
query = sql.SQL("""select attribs.*
from ({selects}) attribs left join
information_schema.columns c on c.table_schema = attribs.workspace
and c.table_name = attribs.table_name
and c.column_name = attribs.attrname
where c.column_name is null"""
where c.column_name is null""").format(
selects=sql.SQL("\n union all\n").join([sql.SQL("select {fworkspace} workspace, {flayername} layername, {ftablename} table_name, {fattrname} attrname").format(
fworkspace=sql.Literal(workspace),
flayername=sql.Literal(layername),
ftablename=sql.Literal(table_name[(workspace, layername)]),
fattrname=sql.Literal(attrname),
) for workspace, layername, attrname in attribute_tuples])
)

try:
if attribute_tuples:
Expand Down