Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example script to extract column names. #681

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions examples/extract_column_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#!/usr/bin/env python
#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This example is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause
#
# This example illustrates how to extract the column names of the
# top-level SELECT statement.
#
# See:
# https://groups.google.com/forum/#!forum/sqlparse/browse_thread/thread/b0bd9a022e9d4895

import sqlparse
from sqlparse.sql import IdentifierList, Identifier
from sqlparse.tokens import Keyword, DML


def extract_select_part(parsed):
    """Yield the top-level tokens that sit between SELECT and FROM.

    Only ``parsed.tokens`` (the direct children) are inspected. Tokens
    before the first ``SELECT`` DML token are dropped, and iteration stops
    at the first ``FROM`` keyword, whether or not a ``SELECT`` was seen.
    If no ``FROM`` follows, every token after ``SELECT`` is yielded.
    """
    def matches(token, ttype, word):
        # Case-insensitive comparison against a specific token type.
        return token.ttype is ttype and token.value.upper() == word

    remaining = iter(parsed.tokens)

    # Phase 1: skip everything up to and including the SELECT token.
    for token in remaining:
        if matches(token, Keyword, 'FROM'):
            return
        if matches(token, DML, 'SELECT'):
            break

    # Phase 2: hand out tokens until FROM. If phase 1 never found SELECT,
    # the iterator is already exhausted and nothing is yielded.
    for token in remaining:
        if matches(token, Keyword, 'FROM'):
            return
        yield token


def extract_column_identifiers(token_stream):
    """Yield the name of each identifier found in *token_stream*.

    An ``IdentifierList`` contributes one name per named entry; a lone
    ``Identifier`` contributes its own name. Names come from
    ``get_name()``. Any other token in the stream is ignored.
    """
    for item in token_stream:
        if isinstance(item, IdentifierList):
            for identifier in item.get_identifiers():
                # get_identifiers() also returns ungrouped tokens (e.g. a
                # numeric literal or ``*``) that have no get_name();
                # skip them instead of raising AttributeError.
                get_name = getattr(identifier, 'get_name', None)
                if get_name is not None:
                    yield get_name()
        elif isinstance(item, Identifier):
            yield item.get_name()


def extract_columns(sql):
    """Return the column names selected by the first statement in *sql*.

    The text is parsed with ``sqlparse.parse``; only the first resulting
    statement is examined. Returns a list of names, or an empty list when
    the input contains no statement at all.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        # Empty input yields no statements; indexing [0] would raise
        # IndexError, so report "no columns" instead.
        return []
    stream = extract_select_part(statements[0])
    return list(extract_column_identifiers(stream))


if __name__ == '__main__':
    # Demo query: a CTE followed by a SELECT over nested sub-selects.
    query = """
WITH schema AS (
SELECT a, b, c, d
FROM schema
)
SELECT ALL t0_as_b, `t1_as_c` AS "t1 as c", COUNT(*) AS "count"
FROM (
SELECT ALL `t0_as_b`, max(`t1_as_c`) AS `t1_as_c`, max(`t2 d as d`) AS `t2 d as d`
FROM (
SELECT a, b AS `t0_as_b`
FROM schema
) t0
INNER JOIN (
SELECT a, c AS "t1_as_c"
FROM schema
) t1
ON t0.a = t1.a
INNER JOIN (
SELECT a, d AS 't2 d as d'
FROM schema
) t2
ON t0.a = t2.a
WHERE 1 = 1
GROUP BY a, `t0_as_b`
) "virtual_table"
GROUP BY `t0_as_b`, `t1_as_c`
ORDER BY `t1_as_c` DESC
LIMIT 1000;
"""

    # Print the extracted names as one comma-separated line.
    print('Columns: {}'.format(', '.join(extract_columns(query))))
    # >>> [Output]:
    # >>> Columns: t0_as_b, t1 as c, count