Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

518 refactor magic cmd #654

Merged
merged 16 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# CHANGELOG

## 0.7.10dev

* [Feature] Modified `TableDescription` to add styling, generate messages and format the calculated outputs (#459)
* [Feature] Support flexible spacing `myvar=<<` operator ([#525](https://github.com/ploomber/jupysql/issues/525))
* [Doc] Modified integrations content to ensure they're all consistent (#523)
* [Doc] Document --persist-replace in API section (#539)
* [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631)

* [Fix] Refactored `ResultSet` to lazy loading (#470)

## 0.7.9 (2023-06-19)
Expand Down
Empty file added src/sql/cmd/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions src/sql/cmd/cmd_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import argparse
import sys
from sql import exceptions


class CmdParser(argparse.ArgumentParser):
"""
Subclassing ArgumentParser as it throws a SystemExit
error when it encounters argument validation errors.


Now we raise a UsageError in case of argument validation
issues.
"""

def exit(self, status=0, message=None):
if message:
self._print_message(message, sys.stderr)

def error(self, message):
AnirudhVIyer marked this conversation as resolved.
Show resolved Hide resolved
raise exceptions.UsageError(message)
29 changes: 29 additions & 0 deletions src/sql/cmd/columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from sql import inspect
from sql.util import sanitize_identifier
from sql.cmd.cmd_utils import CmdParser


def columns(others):
"""
Implementation of `%sqlcmd columns`
This function takes in a string containing command line arguments,
parses them to extract the name of the table and the schema, and returns
a list of columns for the specified table.

Parameters
----------
others : str,
A string containing the command line arguments.

Returns
-------
columns: list
information of the columns in the specified table
"""
parser = CmdParser()

parser.add_argument("-t", "--table", type=str, help="Table name", required=True)
parser.add_argument("-s", "--schema", type=str, help="Schema name", required=False)

args = parser.parse_args(others)
return inspect.get_columns(name=sanitize_identifier(args.table), schema=args.schema)
24 changes: 24 additions & 0 deletions src/sql/cmd/explore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from sql.widgets import TableWidget
from IPython.display import display
from sql.cmd.cmd_utils import CmdParser


def explore(others):
"""
Implementation of `%sqlcmd explore`
This function takes in a string containing command line arguments,
parses them to extract the name of the table, and displays an interactive
widget for exploring the contents of the specified table.

Parameters
----------
others : str,
A string containing the command line arguments.

"""
parser = CmdParser()
parser.add_argument("-t", "--table", type=str, help="Table name", required=True)
args = parser.parse_args(others)

table_widget = TableWidget(args.table)
display(table_widget)
41 changes: 41 additions & 0 deletions src/sql/cmd/profile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from sql import inspect
from sql.cmd.cmd_utils import CmdParser


def profile(others):
"""
Implementation of `%sqlcmd profile`
This function takes in a string containing command line arguments,
parses them to extract the name of the table, the schema, and the output location.
It then retrieves statistical information about the specified table and either
returns the report or writes it to the specified location.


Parameters
----------
others : str,
A string containing the command line arguments.

Returns
-------
report: PrettyTable
statistics of the table
"""
parser = CmdParser()
parser.add_argument("-t", "--table", type=str, help="Table name", required=True)

parser.add_argument("-s", "--schema", type=str, help="Schema name", required=False)

parser.add_argument(
"-o", "--output", type=str, help="Store report location", required=False
)

args = parser.parse_args(others)

report = inspect.get_table_statistics(schema=args.schema, name=args.table)

if args.output:
with open(args.output, "w") as f:
f.write(report._repr_html_())

return report
14 changes: 9 additions & 5 deletions src/sql/sqlcmd.py → src/sql/cmd/snippets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sql.magic_cmd import CmdParser
from sql import util
from sql.exceptions import UsageError
from sql.cmd.cmd_utils import CmdParser


def _modify_display_msg(key, remaining_keys, dependent_keys=None):
Expand Down Expand Up @@ -30,15 +30,19 @@ def _modify_display_msg(key, remaining_keys, dependent_keys=None):
return msg


def sqlcmd_snippets(others):
def snippets(others):
"""

Parameters
----------
Implementation of `%sqlcmd snippets`
This function handles all the arguments related to %sqlcmd snippets, namely
listing stored snippets, and delete/ force delete/ force delete a snippet and
all its dependent snippets.


Parameters
----------
others : str,
A string containing the command line arguments.

"""
parser = CmdParser()
parser.add_argument(
Expand Down
30 changes: 30 additions & 0 deletions src/sql/cmd/tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from sql import inspect
from sql.cmd.cmd_utils import CmdParser


def tables(others):
"""
Implementation of `%sqlcmd tables`

This function takes in a string containing command line arguments,
parses them to extract the schema name, and returns a list of table names
present in the specified schema or in the default schema if none is specified.

Parameters
----------
others : str,
A string containing the command line arguments.

Returns
-------
table_names: list
list of tables in the schema

"""
parser = CmdParser()

parser.add_argument("-s", "--schema", type=str, help="Schema name", required=False)

args = parser.parse_args(others)

return inspect.get_table_names(schema=args.schema)
180 changes: 180 additions & 0 deletions src/sql/cmd/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from sql import exceptions
import sql.connection
from sqlalchemy import text
from sqlglot import select, condition
from prettytable import PrettyTable
from sql.cmd.cmd_utils import CmdParser


def return_test_results(args, conn, query):
try:
columns = []
column_data = conn.execute(text(query)).cursor.description
res = conn.execute(text(query)).fetchall()
for column in column_data:
columns.append(column[0])
res = [columns, *res]
return res
except Exception as e:
if "column" in str(e):
raise exceptions.UsageError(
f"Referenced column '{args.column}' not found!"
) from e


def run_each_individually(args, conn):
base_query = select("*").from_(args.table)

storage = {}

if args.greater:
where = condition(args.column + "<=" + args.greater)
current_query = base_query.where(where).sql()

res = return_test_results(args, conn, query=current_query)

if res is not None:
storage["greater"] = res
if args.greater_or_equal:
where = condition(args.column + "<" + args.greater_or_equal)

current_query = base_query.where(where).sql()

res = return_test_results(args, conn, query=current_query)

if res is not None:
storage["greater_or_equal"] = res

if args.less_than_or_equal:
where = condition(args.column + ">" + args.less_than_or_equal)
current_query = base_query.where(where).sql()

res = return_test_results(args, conn, query=current_query)

if res is not None:
storage["less_than_or_equal"] = res
if args.less_than:
where = condition(args.column + ">=" + args.less_than)
current_query = base_query.where(where).sql()

res = return_test_results(args, conn, query=current_query)

if res is not None:
storage["less_than"] = res
if args.no_nulls:
where = condition("{} is NULL".format(args.column))
current_query = base_query.where(where).sql()

res = return_test_results(args, conn, query=current_query)

if res is not None:
storage["null"] = res

return storage


def test(others):
"""
Implementation of `%sqlcmd test`

This function takes in a string containing command line arguments,
parses them to extract the table name, column name, and conditions
to return if those conditions are satisfied in that table

Parameters
----------
others : str,
A string containing the command line arguments.

Returns
-------
result: bool
Result of the test

table: PrettyTable
table with rows because of which the test fails


"""
parser = CmdParser()

parser.add_argument("-t", "--table", type=str, help="Table name", required=True)
parser.add_argument("-c", "--column", type=str, help="Column name", required=False)
parser.add_argument(
"-g",
"--greater",
type=str,
help="Greater than a certain number.",
required=False,
)
parser.add_argument(
"-goe",
"--greater-or-equal",
type=str,
help="Greater or equal than a certain number.",
required=False,
)
parser.add_argument(
"-l",
"--less-than",
type=str,
help="Less than a certain number.",
required=False,
)
parser.add_argument(
"-loe",
"--less-than-or-equal",
type=str,
help="Less than or equal to a certain number.",
required=False,
)
parser.add_argument(
"-nn",
"--no-nulls",
help="Returns rows in specified column that are not null.",
action="store_true",
)

args = parser.parse_args(others)

COMPARATOR_ARGS = [
args.greater,
args.greater_or_equal,
args.less_than,
args.less_than_or_equal,
]

if args.table and not any(COMPARATOR_ARGS):
raise exceptions.UsageError("Please use a valid comparator.")

if args.table and any(COMPARATOR_ARGS) and not args.column:
raise exceptions.UsageError("Please pass a column to test.")

if args.greater and args.greater_or_equal:
return exceptions.UsageError(
"You cannot use both greater and greater "
"than or equal to arguments at the same time."
)
elif args.less_than and args.less_than_or_equal:
return exceptions.UsageError(
"You cannot use both less and less than "
"or equal to arguments at the same time."
)

conn = sql.connection.Connection.current.session
result_dict = run_each_individually(args, conn)

if any(len(rows) > 1 for rows in list(result_dict.values())):
for comparator, rows in result_dict.items():
if len(rows) > 1:
print(f"\n{comparator}:\n")
_pretty = PrettyTable()
_pretty.field_names = rows[0]
for row in rows[1:]:
_pretty.add_row(row)
print(_pretty)
raise exceptions.UsageError(
"The above values do not not match your test requirements."
)
else:
return True
Loading
Loading