forked from catherinedevlin/ipython-sql
-
Notifications
You must be signed in to change notification settings - Fork 76
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/master' into 643-document-post…
…gres-test merging with master to update ci.yml fix
- Loading branch information
Showing
14 changed files
with
764 additions
and
308 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import argparse | ||
import sys | ||
from sql import exceptions | ||
|
||
|
||
class CmdParser(argparse.ArgumentParser): | ||
""" | ||
Subclassing ArgumentParser as it throws a SystemExit | ||
error when it encounters argument validation errors. | ||
Now we raise a UsageError in case of argument validation | ||
issues. | ||
""" | ||
|
||
def exit(self, status=0, message=None): | ||
if message: | ||
self._print_message(message, sys.stderr) | ||
|
||
def error(self, message): | ||
raise exceptions.UsageError(message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from sql import inspect | ||
from sql.util import sanitize_identifier | ||
from sql.cmd.cmd_utils import CmdParser | ||
|
||
|
||
def columns(others): | ||
""" | ||
Implementation of `%sqlcmd columns` | ||
This function takes in a string containing command line arguments, | ||
parses them to extract the name of the table and the schema, and returns | ||
a list of columns for the specified table. | ||
Parameters | ||
---------- | ||
others : str, | ||
A string containing the command line arguments. | ||
Returns | ||
------- | ||
columns: list | ||
information of the columns in the specified table | ||
""" | ||
parser = CmdParser() | ||
|
||
parser.add_argument("-t", "--table", type=str, help="Table name", required=True) | ||
parser.add_argument("-s", "--schema", type=str, help="Schema name", required=False) | ||
|
||
args = parser.parse_args(others) | ||
return inspect.get_columns(name=sanitize_identifier(args.table), schema=args.schema) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from sql.widgets import TableWidget | ||
from IPython.display import display | ||
from sql.cmd.cmd_utils import CmdParser | ||
|
||
|
||
def explore(others): | ||
""" | ||
Implementation of `%sqlcmd explore` | ||
This function takes in a string containing command line arguments, | ||
parses them to extract the name of the table, and displays an interactive | ||
widget for exploring the contents of the specified table. | ||
Parameters | ||
---------- | ||
others : str, | ||
A string containing the command line arguments. | ||
""" | ||
parser = CmdParser() | ||
parser.add_argument("-t", "--table", type=str, help="Table name", required=True) | ||
args = parser.parse_args(others) | ||
|
||
table_widget = TableWidget(args.table) | ||
display(table_widget) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from sql import inspect | ||
from sql.cmd.cmd_utils import CmdParser | ||
|
||
|
||
def profile(others): | ||
""" | ||
Implementation of `%sqlcmd profile` | ||
This function takes in a string containing command line arguments, | ||
parses them to extract the name of the table, the schema, and the output location. | ||
It then retrieves statistical information about the specified table and either | ||
returns the report or writes it to the specified location. | ||
Parameters | ||
---------- | ||
others : str, | ||
A string containing the command line arguments. | ||
Returns | ||
------- | ||
report: PrettyTable | ||
statistics of the table | ||
""" | ||
parser = CmdParser() | ||
parser.add_argument("-t", "--table", type=str, help="Table name", required=True) | ||
|
||
parser.add_argument("-s", "--schema", type=str, help="Schema name", required=False) | ||
|
||
parser.add_argument( | ||
"-o", "--output", type=str, help="Store report location", required=False | ||
) | ||
|
||
args = parser.parse_args(others) | ||
|
||
report = inspect.get_table_statistics(schema=args.schema, name=args.table) | ||
|
||
if args.output: | ||
with open(args.output, "w") as f: | ||
f.write(report._repr_html_()) | ||
|
||
return report |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from sql import inspect | ||
from sql.cmd.cmd_utils import CmdParser | ||
|
||
|
||
def tables(others): | ||
""" | ||
Implementation of `%sqlcmd tables` | ||
This function takes in a string containing command line arguments, | ||
parses them to extract the schema name, and returns a list of table names | ||
present in the specified schema or in the default schema if none is specified. | ||
Parameters | ||
---------- | ||
others : str, | ||
A string containing the command line arguments. | ||
Returns | ||
------- | ||
table_names: list | ||
list of tables in the schema | ||
""" | ||
parser = CmdParser() | ||
|
||
parser.add_argument("-s", "--schema", type=str, help="Schema name", required=False) | ||
|
||
args = parser.parse_args(others) | ||
|
||
return inspect.get_table_names(schema=args.schema) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
from sql import exceptions | ||
import sql.connection | ||
from sqlalchemy import text | ||
from sqlglot import select, condition | ||
from prettytable import PrettyTable | ||
from sql.cmd.cmd_utils import CmdParser | ||
|
||
|
||
def return_test_results(args, conn, query): | ||
try: | ||
columns = [] | ||
column_data = conn.execute(text(query)).cursor.description | ||
res = conn.execute(text(query)).fetchall() | ||
for column in column_data: | ||
columns.append(column[0]) | ||
res = [columns, *res] | ||
return res | ||
except Exception as e: | ||
if "column" in str(e): | ||
raise exceptions.UsageError( | ||
f"Referenced column '{args.column}' not found!" | ||
) from e | ||
|
||
|
||
def run_each_individually(args, conn): | ||
base_query = select("*").from_(args.table) | ||
|
||
storage = {} | ||
|
||
if args.greater: | ||
where = condition(args.column + "<=" + args.greater) | ||
current_query = base_query.where(where).sql() | ||
|
||
res = return_test_results(args, conn, query=current_query) | ||
|
||
if res is not None: | ||
storage["greater"] = res | ||
if args.greater_or_equal: | ||
where = condition(args.column + "<" + args.greater_or_equal) | ||
|
||
current_query = base_query.where(where).sql() | ||
|
||
res = return_test_results(args, conn, query=current_query) | ||
|
||
if res is not None: | ||
storage["greater_or_equal"] = res | ||
|
||
if args.less_than_or_equal: | ||
where = condition(args.column + ">" + args.less_than_or_equal) | ||
current_query = base_query.where(where).sql() | ||
|
||
res = return_test_results(args, conn, query=current_query) | ||
|
||
if res is not None: | ||
storage["less_than_or_equal"] = res | ||
if args.less_than: | ||
where = condition(args.column + ">=" + args.less_than) | ||
current_query = base_query.where(where).sql() | ||
|
||
res = return_test_results(args, conn, query=current_query) | ||
|
||
if res is not None: | ||
storage["less_than"] = res | ||
if args.no_nulls: | ||
where = condition("{} is NULL".format(args.column)) | ||
current_query = base_query.where(where).sql() | ||
|
||
res = return_test_results(args, conn, query=current_query) | ||
|
||
if res is not None: | ||
storage["null"] = res | ||
|
||
return storage | ||
|
||
|
||
def test(others): | ||
""" | ||
Implementation of `%sqlcmd test` | ||
This function takes in a string containing command line arguments, | ||
parses them to extract the table name, column name, and conditions | ||
to return if those conditions are satisfied in that table | ||
Parameters | ||
---------- | ||
others : str, | ||
A string containing the command line arguments. | ||
Returns | ||
------- | ||
result: bool | ||
Result of the test | ||
table: PrettyTable | ||
table with rows because of which the test fails | ||
""" | ||
parser = CmdParser() | ||
|
||
parser.add_argument("-t", "--table", type=str, help="Table name", required=True) | ||
parser.add_argument("-c", "--column", type=str, help="Column name", required=False) | ||
parser.add_argument( | ||
"-g", | ||
"--greater", | ||
type=str, | ||
help="Greater than a certain number.", | ||
required=False, | ||
) | ||
parser.add_argument( | ||
"-goe", | ||
"--greater-or-equal", | ||
type=str, | ||
help="Greater or equal than a certain number.", | ||
required=False, | ||
) | ||
parser.add_argument( | ||
"-l", | ||
"--less-than", | ||
type=str, | ||
help="Less than a certain number.", | ||
required=False, | ||
) | ||
parser.add_argument( | ||
"-loe", | ||
"--less-than-or-equal", | ||
type=str, | ||
help="Less than or equal to a certain number.", | ||
required=False, | ||
) | ||
parser.add_argument( | ||
"-nn", | ||
"--no-nulls", | ||
help="Returns rows in specified column that are not null.", | ||
action="store_true", | ||
) | ||
|
||
args = parser.parse_args(others) | ||
|
||
COMPARATOR_ARGS = [ | ||
args.greater, | ||
args.greater_or_equal, | ||
args.less_than, | ||
args.less_than_or_equal, | ||
] | ||
|
||
if args.table and not any(COMPARATOR_ARGS): | ||
raise exceptions.UsageError("Please use a valid comparator.") | ||
|
||
if args.table and any(COMPARATOR_ARGS) and not args.column: | ||
raise exceptions.UsageError("Please pass a column to test.") | ||
|
||
if args.greater and args.greater_or_equal: | ||
return exceptions.UsageError( | ||
"You cannot use both greater and greater " | ||
"than or equal to arguments at the same time." | ||
) | ||
elif args.less_than and args.less_than_or_equal: | ||
return exceptions.UsageError( | ||
"You cannot use both less and less than " | ||
"or equal to arguments at the same time." | ||
) | ||
|
||
conn = sql.connection.Connection.current.session | ||
result_dict = run_each_individually(args, conn) | ||
|
||
if any(len(rows) > 1 for rows in list(result_dict.values())): | ||
for comparator, rows in result_dict.items(): | ||
if len(rows) > 1: | ||
print(f"\n{comparator}:\n") | ||
_pretty = PrettyTable() | ||
_pretty.field_names = rows[0] | ||
for row in rows[1:]: | ||
_pretty.add_row(row) | ||
print(_pretty) | ||
raise exceptions.UsageError( | ||
"The above values do not not match your test requirements." | ||
) | ||
else: | ||
return True |
Oops, something went wrong.