diff --git a/CHANGELOG.md b/CHANGELOG.md index 43f2f3e44..32cc1a727 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ * [Fix] Fix Twice message printing when switching to the current connection ([#772](https://github.com/ploomber/jupysql/issues/772)) * [Fix] Error when using %sqlplot in snowflake ([#697](https://github.com/ploomber/jupysql/issues/697)) * [Doc] Fixes documentation inaccuracy that said `:variable` was deprecated (we brought it back in `0.9.0`) +* [Fix] Descriptive error messages when specific syntax error occurs when running query in DuckDB or Oracle. ## 0.9.1 (2023-08-10) diff --git a/src/sql/cmd/snippets.py b/src/sql/cmd/snippets.py index fe6e19eb9..40a1f60d9 100644 --- a/src/sql/cmd/snippets.py +++ b/src/sql/cmd/snippets.py @@ -2,6 +2,7 @@ from sql.exceptions import UsageError from sql.cmd.cmd_utils import CmdParser from sql.store import store +from sql import store_utils from sql.display import Table, Message @@ -64,7 +65,7 @@ def snippets(others): help="Force delete all stored snippets", required=False, ) - all_snippets = util.get_all_keys() + all_snippets = store_utils.get_all_keys() if len(others) == 1: if others[0] in all_snippets: return str(store[others[0]]) @@ -87,7 +88,7 @@ def snippets(others): return Table(["Stored snippets"], [[snippet] for snippet in all_snippets]) if args.delete: - deps = util.get_key_dependents(args.delete) + deps = store_utils.get_key_dependents(args.delete) if deps: deps = ", ".join(deps) raise UsageError( @@ -97,18 +98,18 @@ def snippets(others): ) else: key = args.delete - remaining_keys = util.del_saved_key(key) + remaining_keys = store_utils.del_saved_key(key) return _modify_display_msg(key, remaining_keys) elif args.delete_force: key = args.delete_force - deps = util.get_key_dependents(key) - remaining_keys = util.del_saved_key(key) + deps = store_utils.get_key_dependents(key) + remaining_keys = store_utils.del_saved_key(key) return _modify_display_msg(key, remaining_keys, deps) elif args.delete_force_all: - deps = util.get_key_dependents(args.delete_force_all) + deps = store_utils.get_key_dependents(args.delete_force_all) deps.append(args.delete_force_all) for key in deps: - remaining_keys = util.del_saved_key(key) + remaining_keys = store_utils.del_saved_key(key) return _modify_display_msg(", ".join(deps), remaining_keys) diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 3b0e95f96..9dfee4079 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -24,7 +24,7 @@ from sql.store import store from sql.telemetry import telemetry from sql import exceptions, display -from sql.error_message import detail +from sql.error_handler import handle_exception from sql.parse import ( escape_string_literals_with_colon_prefix, find_named_parameters, @@ -935,11 +935,7 @@ def _start_sqlalchemy_connection(cls, engine, connect_str): connection = engine.connect() return connection except OperationalError as e: - detailed_msg = detail(e) - if detailed_msg is not None: - raise exceptions.RuntimeError(detailed_msg) from e - else: - raise exceptions.RuntimeError(str(e)) from e + handle_exception(e) except Exception as e: raise _error_invalid_connection_info(e, connect_str) from e diff --git a/src/sql/error_handler.py b/src/sql/error_handler.py new file mode 100644 index 000000000..2b12b260c --- /dev/null +++ b/src/sql/error_handler.py @@ -0,0 +1,102 @@ +from sql import display +from sql import util +from sql.store_utils import get_all_keys +from sql.exceptions import RuntimeError, TableNotFoundError + + +ORIGINAL_ERROR = "\nOriginal error message from DB driver:\n" +CTE_MSG = ( + "If using snippets, you may pass the --with argument explicitly.\n" + "For more details please refer: " + "https://jupysql.ploomber.io/en/latest/compose.html#with-argument" +) +POSTGRES_MSG = """\nLooks like you have run into some issues. + Review our DB connection via URL strings guide: + https://jupysql.ploomber.io/en/latest/connecting.html . + Using Ubuntu? Check out this guide: " + https://help.ubuntu.com/community/PostgreSQL#fe_sendauth:_ + no_password_supplied\n""" + + +def _snippet_typo_error_message(query): + """Function to generate message for possible + snippets if snippet name in user query is a + typo + """ + if query: + tables = util.extract_tables_from_query(query) + for table in tables: + suggestions = util.find_close_match(table, get_all_keys()) + err_message = f"There is no table with name {table!r}." + if len(suggestions) > 0: + suggestions_message = util.get_suggestions_message(suggestions) + return f"{err_message}{suggestions_message}" + return "" + + +def _detailed_message_with_error_type(error, query): + """Function to generate descriptive error message. + Currently it handles syntax error messages, table not found messages + and password issue when connecting to postgres + """ + original_error = str(error) + syntax_error_substrings = [ + "syntax error", + "error in your sql syntax", + "incorrect syntax", + "invalid sql", + ] + not_found_substrings = [ + "does not exist", + "not found", + "could not find", + "no such table", + ] + if util.if_substring_exists(original_error.lower(), syntax_error_substrings): + return f"{CTE_MSG}\n\n{ORIGINAL_ERROR}{original_error}\n", RuntimeError + elif util.if_substring_exists(original_error.lower(), not_found_substrings): + typo_err_msg = _snippet_typo_error_message(query) + if typo_err_msg: + return ( + f"{CTE_MSG}\n\n{typo_err_msg}\n\n" + f"{ORIGINAL_ERROR}{original_error}\n", + TableNotFoundError, + ) + else: + return f"{CTE_MSG}\n\n{ORIGINAL_ERROR}{original_error}\n", RuntimeError + elif "fe_sendauth: no password supplied" in original_error: + return f"{POSTGRES_MSG}\n{ORIGINAL_ERROR}{original_error}\n", RuntimeError + return None, None + + +def _display_error_msg_with_trace(error, message): + """Displays the detailed error message and prints + original stack trace as well.""" + if message is not None: + display.message(message) + error.modify_exception = True + raise error + + +def _raise_error(error, message, error_type): + """Raise specific error from the detailed message. If detailed + message is None reraise original error""" + if message is not None: + raise error_type(message) from error + else: + raise RuntimeError(str(error)) from error + + +def handle_exception(error, query=None, short_error=True): + """ + This function is the entry point for detecting error type + and handling it accordingly. + """ + if util.is_sqlalchemy_error(error) or util.is_non_sqlalchemy_error(error): + detailed_message, error_type = _detailed_message_with_error_type(error, query) + if short_error: + _raise_error(error, detailed_message, error_type) + else: + _display_error_msg_with_trace(error, detailed_message) + else: + raise error diff --git a/src/sql/error_message.py b/src/sql/error_message.py deleted file mode 100644 index 7c404f266..000000000 --- a/src/sql/error_message.py +++ /dev/null @@ -1,39 +0,0 @@ -ORIGINAL_ERROR = "\nOriginal error message from DB driver:\n" -CTE_MSG = ( - "If using snippets, you may pass the --with argument explicitly.\n" - "For more details please refer: " - "https://jupysql.ploomber.io/en/latest/compose.html#with-argument" -) - - -def _is_syntax_error(error): - """ - Function to detect whether error message from DB driver - is related to syntax error in user query. - """ - error_lower = error.lower() - return ( - "syntax error" in error_lower - or ("catalog error" in error_lower and "does not exist" in error_lower) - or "error in your sql syntax" in error_lower - or "incorrect syntax" in error_lower - or "not found" in error_lower - ) - - -def detail(original_error): - original_error = str(original_error) - if _is_syntax_error(original_error): - return f"{CTE_MSG}\n\n{ORIGINAL_ERROR}{original_error}\n" - - if "fe_sendauth: no password supplied" in original_error: - specific_error = """\nLooks like you have run into some issues. - Review our DB connection via URL strings guide: - https://jupysql.ploomber.io/en/latest/connecting.html . - Using Ubuntu? Check out this guide: " - https://help.ubuntu.com/community/PostgreSQL#fe_sendauth:_ - no_password_supplied\n""" - - return f"{specific_error}\n{ORIGINAL_ERROR}{original_error}\n" - - return None diff --git a/src/sql/inspect.py b/src/sql/inspect.py index 176347b8a..b1c216b26 100644 --- a/src/sql/inspect.py +++ b/src/sql/inspect.py @@ -6,6 +6,7 @@ from sql import exceptions import math from sql import util +from sql.store_utils import get_all_keys from IPython.core.display import HTML import uuid @@ -172,7 +173,7 @@ class Columns(DatabaseInspection): """ def __init__(self, name, schema, conn=None) -> None: - util.is_table_exists(name, schema) + is_table_exists(name, schema) inspector = _get_inspector(conn) @@ -230,7 +231,7 @@ class TableDescription(DatabaseInspection): """ def __init__(self, table_name, schema=None) -> None: - util.is_table_exists(table_name, schema) + is_table_exists(table_name, schema) if schema: table_name = f"{schema}.{table_name}" @@ -501,3 +502,171 @@ def get_schema_names(conn=None): """Get list of schema names for a given connection""" inspector = _get_inspector(conn) return inspector.get_schema_names() + + +def support_only_sql_alchemy_connection(command): + """ + Throws a sql.exceptions.RuntimeError if connection is not SQLAlchemy + """ + if ConnectionManager.current.is_dbapi_connection: + raise exceptions.RuntimeError( + f"{command} is only supported with SQLAlchemy " + "connections, not with DBAPI connections" + ) + + +def _is_table_exists(table: str, conn) -> bool: + """ + Runs a SQL query to check if table exists + """ + if not conn: + conn = ConnectionManager.current + + identifiers = conn.get_curr_identifiers() + + for iden in identifiers: + if isinstance(iden, tuple): + query = "SELECT * FROM {0}{1}{2} WHERE 1=0".format(iden[0], table, iden[1]) + else: + query = "SELECT * FROM {0}{1}{0} WHERE 1=0".format(iden, table) + try: + conn.execute(query) + return True + except Exception: + pass + + return False + + +def _get_list_of_existing_tables() -> list: + """ + Returns a list of table names for a given connection + """ + tables = [] + tables_rows = get_table_names()._table + for row in tables_rows: + table_name = row.get_string(fields=["Name"], border=False, header=False).strip() + + tables.append(table_name) + return tables + + +def is_table_exists( + table: str, + schema: str = None, + ignore_error: bool = False, + conn=None, +) -> bool: + """ + Checks if a given table exists for a given connection + + Parameters + ---------- + table: str + Table name + + schema: str, default None + Schema name + + ignore_error: bool, default False + Avoid raising a ValueError + """ + if table is None: + if ignore_error: + return False + else: + raise exceptions.UsageError("Table cannot be None") + if not ConnectionManager.current: + raise exceptions.RuntimeError("No active connection") + if not conn: + conn = ConnectionManager.current + + table = util.strip_multiple_chars(table, "\"'") + + if schema: + table_ = f"{schema}.{table}" + else: + table_ = table + + _is_exist = _is_table_exists(table_, conn) + + if not _is_exist: + if not ignore_error: + try_find_suggestions = not conn.is_dbapi_connection + expected = [] + existing_schemas = [] + existing_tables = [] + + if try_find_suggestions: + existing_schemas = get_schema_names() + + if schema and schema not in existing_schemas: + expected = existing_schemas + invalid_input = schema + else: + if try_find_suggestions: + existing_tables = _get_list_of_existing_tables() + + expected = existing_tables + invalid_input = table + + if schema: + err_message = ( + f"There is no table with name {table!r} in schema {schema!r}" + ) + else: + err_message = ( + f"There is no table with name {table!r} in the default schema" + ) + + if table not in get_all_keys(): + suggestions = util.find_close_match(invalid_input, expected) + suggestions_store = util.find_close_match(invalid_input, get_all_keys()) + suggestions.extend(suggestions_store) + suggestions_message = util.get_suggestions_message(suggestions) + if suggestions_message: + err_message = f"{err_message}{suggestions_message}" + raise exceptions.TableNotFoundError(err_message) + + return _is_exist + + +def fetch_sql_with_pagination( + table, offset, n_rows, sort_column=None, sort_order=None +) -> tuple: + """ + Returns next n_rows and columns from table starting at the offset + + Parameters + ---------- + table : str + Table name + + offset : int + Specifies the number of rows to skip before + it starts to return rows from the query expression. + + n_rows : int + Number of rows to return. + + sort_column : str, default None + Sort by column + + sort_order : 'DESC' or 'ASC', default None + Order list + """ + is_table_exists(table) + + order_by = "" if not sort_column else f"ORDER BY {sort_column} {sort_order}" + + query = f""" + SELECT * FROM {table} {order_by} + OFFSET {offset} ROWS FETCH NEXT {n_rows} ROWS ONLY""" + + rows = ConnectionManager.current.execute(query).fetchall() + + columns = ConnectionManager.current.raw_execute( + f"SELECT * FROM {table} WHERE 1=0" + ).keys() + + return rows, columns diff --git a/src/sql/magic.py b/src/sql/magic.py index af6e1cf21..4734ed91b 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -27,7 +27,6 @@ import warnings import shlex -from difflib import get_close_matches import sql.connection import sql.parse from sql.run.run import run_statements @@ -38,10 +37,9 @@ from sql.magic_plot import SqlPlotMagic from sql.magic_cmd import SqlCmdMagic from sql._patch import patch_ipython_usage_error -from sql import query_util, util -from sql.util import get_suggestions_message, pretty_print -from sql.exceptions import RuntimeError -from sql.error_message import detail +from sql import util +from sql.util import pretty_print +from sql.error_handler import handle_exception from sql._current import _set_sql_magic @@ -250,33 +248,6 @@ def check_random_arguments(self, line="", cell=""): "Unrecognized argument(s): {}".format(check_argument) ) - def _error_handling(self, e, query): - detailed_msg = detail(e) - if self.short_errors: - if detailed_msg is not None: - raise exceptions.RuntimeError(detailed_msg) from e - # TODO: move to error_messages.py - # Added here due to circular dependency issue (#545) - elif "no such table" in str(e): - tables = query_util.extract_tables_from_query(query) - for table in tables: - suggestions = get_close_matches(table, list(self._store)) - err_message = f"There is no table with name {table!r}." - # with_message = "Alternatively, please specify table - # name using --with argument" - if len(suggestions) > 0: - suggestions_message = get_suggestions_message(suggestions) - raise exceptions.TableNotFoundError( - f"{err_message}{suggestions_message}" - ) from e - - raise RuntimeError(str(e)) from e - else: - if detailed_msg is not None: - display.message(detailed_msg) - e.modify_exception = True - raise e - @no_var_expand @needs_local_scope @line_magic("sql") @@ -612,13 +583,10 @@ def interactive_execute_wrapper(**kwargs): StatementError, ) as e: # Sqlite apparently return all errors as OperationalError :/ - self._error_handling(e, command.sql) + handle_exception(e, command.sql, self.short_errors) except Exception as e: - # handle DuckDB exceptions - if "Catalog Error" in str(e): - self._error_handling(e, command.sql) - else: - raise e + # Handle non SQLAlchemy errors + handle_exception(e, command.sql, self.short_errors) legal_sql_identifier = re.compile(r"^[A-Za-z0-9#_$]+") diff --git a/src/sql/magic_cmd.py b/src/sql/magic_cmd.py index 66fb43084..b6f33d486 100644 --- a/src/sql/magic_cmd.py +++ b/src/sql/magic_cmd.py @@ -4,7 +4,7 @@ from IPython.core.magic import Magics, line_magic, magics_class from IPython.core.magic_arguments import argument, magic_arguments -from sql import util +from sql.inspect import support_only_sql_alchemy_connection from sql.cmd.tables import tables from sql.cmd.columns import columns from sql.cmd.test import test @@ -87,7 +87,7 @@ def _validate_execute_inputs(self, line): ) if command in COMMANDS_SQLALCHEMY_ONLY: - util.support_only_sql_alchemy_connection(f"%sqlcmd {command}") + support_only_sql_alchemy_connection(f"%sqlcmd {command}") return self.execute(command, others) else: diff --git a/src/sql/magic_plot.py b/src/sql/magic_plot.py index 5c9ba07b8..c590b4d23 100644 --- a/src/sql/magic_plot.py +++ b/src/sql/magic_plot.py @@ -16,6 +16,8 @@ from sql.command import SQLPlotCommand from sql import exceptions from sql import util +from sql.inspect import is_table_exists +from sql.store_utils import is_saved_snippet SUPPORTED_PLOTS = ["histogram", "boxplot", "bar", "pie"] @@ -154,8 +156,8 @@ def execute(self, line="", cell="", local_ns=None): @staticmethod def _check_table_exists(table): with_ = None - if util.is_saved_snippet(table): + if is_saved_snippet(table): with_ = [table] else: - util.is_table_exists(table) + is_table_exists(table) return with_ diff --git a/src/sql/query_util.py b/src/sql/query_util.py deleted file mode 100644 index fb342b812..000000000 --- a/src/sql/query_util.py +++ /dev/null @@ -1,29 +0,0 @@ -from sqlglot import parse_one, exp -from sqlglot.errors import ParseError - - -def extract_tables_from_query(query): - """ - Function to extract names of tables from - a syntactically correct query - - Parameters - ---------- - query : str, user query - - Returns - ------- - list - List of tables in the query - [] if error in parsing the query - """ - try: - tables = [table.name for table in parse_one(query).find_all(exp.Table)] - return tables - except ParseError: - # TODO : Instead of returning [] replace with call to - # error_messages.py::parse_sqlglot_error. Currently this - # is not possible because of an exception raised in test - # fixtures. (#546). This function can also be moved to util.py - # after #545 is resolved. - return [] diff --git a/src/sql/store.py b/src/sql/store.py index c671ca017..6953559c0 100644 --- a/src/sql/store.py +++ b/src/sql/store.py @@ -6,7 +6,7 @@ import difflib from sql import exceptions -from sql import query_util +from sql import util class SQLStore(MutableMapping): @@ -73,7 +73,7 @@ def infer_dependencies(self, query, key): saved_key for saved_key in list(self._data.keys()) if saved_key != key ] if saved_keys and query: - tables = query_util.extract_tables_from_query(query) + tables = util.extract_tables_from_query(query) for table in tables: if table in saved_keys: dependencies.append(table) @@ -154,15 +154,6 @@ def _get_dependencies(store, keys): return list(dict.fromkeys(deps + keys)) -def _get_dependents_for_key(store, key): - key_dependents = [] - for k in list(store): - deps = _get_dependencies_for_key(store, k) - if key in deps: - key_dependents.append(k) - return key_dependents - - def _get_dependencies_for_key(store, key): """Retrieve dependencies for a single key""" deps = store[key]._with_ @@ -175,5 +166,14 @@ def _flatten(elements): return [element for sub in elements for element in sub] +def get_dependents_for_key(store, key): + key_dependents = [] + for k in list(store): + deps = _get_dependencies_for_key(store, k) + if key in deps: + key_dependents.append(k) + return key_dependents + + # session-wide store store = SQLStore() diff --git a/src/sql/store_utils.py b/src/sql/store_utils.py new file mode 100644 index 000000000..4bcda2eab --- /dev/null +++ b/src/sql/store_utils.py @@ -0,0 +1,49 @@ +from sql.store import store, get_dependents_for_key +from sql import exceptions + + +def get_all_keys(): + """ + Function to get list of all stored snippets in the current session + """ + return list(store) + + +def get_key_dependents(key: str) -> list: + """ + Function to find the stored snippets dependent on key + Parameters + ---------- + key : str, name of the table + + Returns + ------- + list + List of snippets dependent on key + + """ + deps = get_dependents_for_key(store, key) + return deps + + +def del_saved_key(key: str) -> str: + """ + Deletes a stored snippet + Parameters + ---------- + key : str, name of the snippet to be deleted + + Returns + ------- + list + Remaining stored snippets + """ + all_keys = get_all_keys() + if key not in all_keys: + raise exceptions.UsageError(f"No such saved snippet found : {key}") + del store[key] + return get_all_keys() + + +def is_saved_snippet(table: str) -> bool: + return table in get_all_keys() diff --git a/src/sql/util.py b/src/sql/util.py index 31fc260a3..99ed7d8e5 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -1,14 +1,15 @@ import warnings -from sql import inspect import difflib -from sql.connection import ConnectionManager -from sql.store import store, _get_dependents_for_key from sql import exceptions, display import json from pathlib import Path +from sqlglot import parse_one, exp +from sqlglot.errors import ParseError +from sqlalchemy.exc import SQLAlchemyError from ploomber_core.dependencies import requires import ast + try: import toml except ModuleNotFoundError: @@ -64,105 +65,10 @@ def get_suggestions_message(suggestions): suggestions_message = "" if len(suggestions) > 0: _suggestions_string = pretty_print(suggestions, last_delimiter="or") - suggestions_message = f"\nDid you mean : {_suggestions_string}" + suggestions_message = f"\nDid you mean: {_suggestions_string}" return suggestions_message -def is_table_exists( - table: str, - schema: str = None, - ignore_error: bool = False, - conn=None, -) -> bool: - """ - Checks if a given table exists for a given connection - - Parameters - ---------- - table: str - Table name - - schema: str, default None - Schema name - - ignore_error: bool, default False - Avoid raising a ValueError - """ - if table is None: - if ignore_error: - return False - else: - raise exceptions.UsageError("Table cannot be None") - if not ConnectionManager.current: - raise exceptions.RuntimeError("No active connection") - if not conn: - conn = ConnectionManager.current - - table = strip_multiple_chars(table, "\"'") - - if schema: - table_ = f"{schema}.{table}" - else: - table_ = table - - _is_exist = _is_table_exists(table_, conn) - - if not _is_exist: - if not ignore_error: - try_find_suggestions = not conn.is_dbapi_connection - expected = [] - existing_schemas = [] - existing_tables = [] - - if try_find_suggestions: - existing_schemas = inspect.get_schema_names() - - if schema and schema not in existing_schemas: - expected = existing_schemas - invalid_input = schema - else: - if try_find_suggestions: - existing_tables = _get_list_of_existing_tables() - - expected = existing_tables - invalid_input = table - - if schema: - err_message = ( - f"There is no table with name {table!r} in schema {schema!r}" - ) - else: - err_message = ( - f"There is no table with name {table!r} in the default schema" - ) - - if table not in list(store): - suggestions = difflib.get_close_matches(invalid_input, expected) - suggestions_store = difflib.get_close_matches( - invalid_input, list(store) - ) - suggestions.extend(suggestions_store) - suggestions_message = get_suggestions_message(suggestions) - if suggestions_message: - err_message = f"{err_message}{suggestions_message}" - raise exceptions.TableNotFoundError(err_message) - - return _is_exist - - -def _get_list_of_existing_tables() -> list: - """ - Returns a list of table names for a given connection - """ - tables = [] - tables_rows = inspect.get_table_names()._table - for row in tables_rows: - table_name = row.get_string(fields=["Name"], border=False, header=False).strip() - - tables.append(table_name) - return tables - - def pretty_print( obj: list, delimiter: str = ",", last_delimiter: str = "and", repr_: bool = False ) -> str: @@ -187,36 +93,6 @@ def strip_multiple_chars(string: str, chars: str) -> str: return string.translate(str.maketrans("", "", chars)) -def is_saved_snippet(table: str) -> bool: - if table in list(store): - display.message(f"Plotting using saved snippet : {table}") - return True - return False - - -def _is_table_exists(table: str, conn) -> bool: - """ - Runs a SQL query to check if table exists - """ - if not conn: - conn = ConnectionManager.current - - identifiers = conn.get_curr_identifiers() - - for iden in identifiers: - if isinstance(iden, tuple): - query = "SELECT * FROM {0}{1}{2} WHERE 1=0".format(iden[0], table, iden[1]) - else: - query = "SELECT * FROM {0}{1}{0} WHERE 1=0".format(iden, table) - try: - conn.execute(query) - return True - except Exception: - pass - - return False - - def flatten(src, ltypes=(list, tuple)): """The flatten function creates a new tuple / list with all sub-tuple / sub-list elements concatenated into it recursively @@ -253,58 +129,6 @@ def flatten(src, ltypes=(list, tuple)): return process_list -def support_only_sql_alchemy_connection(command): - """ - Throws a sql.exceptions.RuntimeError if connection is not SQLAlchemy - """ - if ConnectionManager.current.is_dbapi_connection: - raise exceptions.RuntimeError( - f"{command} is only supported with SQLAlchemy " - "connections, not with DBAPI connections" - ) - - -def fetch_sql_with_pagination( - table, offset, n_rows, sort_column=None, sort_order=None -) -> tuple: - """ - Returns next n_rows and columns from table starting at the offset - - Parameters - ---------- - table : str - Table name - - offset : int - Specifies the number of rows to skip before - it starts to return rows from the query expression. - - n_rows : int - Number of rows to return. - - sort_column : str, default None - Sort by column - - sort_order : 'DESC' or 'ASC', default None - Order list - """ - is_table_exists(table) - - order_by = "" if not sort_column else f"ORDER BY {sort_column} {sort_order}" - - query = f""" - SELECT * FROM {table} {order_by} - OFFSET {offset} ROWS FETCH NEXT {n_rows} ROWS ONLY""" - - rows = ConnectionManager.current.execute(query).fetchall() - - columns = ConnectionManager.current.raw_execute( - f"SELECT * FROM {table} WHERE 1=0" - ).keys() - - return rows, columns - - def parse_sql_results_to_json(rows, columns) -> str: """ Serializes sql rows to a JSON formatted ``str`` @@ -317,52 +141,6 @@ def parse_sql_results_to_json(rows, columns) -> str: return rows_json -def get_all_keys(): - """ - - Returns - ------- - All stored snippets in the current session - """ - return list(store) - - -def get_key_dependents(key: str) -> list: - """ - Function to find the stored snippets dependent on key - Parameters - ---------- - key : str, name of the table - - Returns - ------- - list - List of snippets dependent on key - - """ - deps = _get_dependents_for_key(store, key) - return deps - - -def del_saved_key(key: str) -> str: - """ - Deletes a stored snippet - Parameters - ---------- - key : str, name of the snippet to be deleted - - Returns - ------- - list - Remaining stored snippets - """ - all_keys = get_all_keys() - if key not in all_keys: - raise exceptions.UsageError(f"No such saved snippet found : {key}") - del store[key] - return get_all_keys() - - def show_deprecation_warning(): """ Raises CTE deprecation warning @@ -391,6 +169,11 @@ def find_path_from_root(file_name): return str(Path(current, file_name)) +def find_close_match(word, possibilities): + """Find closest match between invalid input and possible options""" + return difflib.get_close_matches(word, possibilities) + + def find_close_match_config(word, possibilities, n=3): """Finds closest matching configurations and displays message""" closest_matches = difflib.get_close_matches(word, possibilities, n=n) @@ -554,3 +337,48 @@ def is_valid_python_code(code): return True except SyntaxError: return False + + +def extract_tables_from_query(query): + """ + Function to extract names of tables from + a syntactically correct query + + Parameters + ---------- + query : str, user query + + Returns + ------- + list + List of tables in the query + [] if error in parsing the query + """ + try: + tables = [table.name for table in parse_one(query).find_all(exp.Table)] + return tables + except ParseError: + # TODO : Instead of returning [] return the + # exact parse error + return [] + + +def is_sqlalchemy_error(error): + """Function to check if error is SQLAlchemy error""" + return isinstance(error, SQLAlchemyError) + + +def is_non_sqlalchemy_error(error): + """Function to check if error is a specific non-SQLAlchemy error""" + specific_db_errors = [ + "duckdb.CatalogException", + "Parser Error", + "pyodbc.ProgrammingError", + ] + return any(msg in str(error) for msg in specific_db_errors) + + +def if_substring_exists(string, substrings): + """Function to check if any of substring in + substrings exist in string""" + return any(msg in string for msg in substrings) diff --git a/src/sql/widgets/table_widget/table_widget.py b/src/sql/widgets/table_widget/table_widget.py index 4c8d955fe..04cc85dbc 100644 --- a/src/sql/widgets/table_widget/table_widget.py +++ b/src/sql/widgets/table_widget/table_widget.py @@ -2,12 +2,8 @@ from IPython import get_ipython import math import time -from sql.util import ( - fetch_sql_with_pagination, - parse_sql_results_to_json, - is_table_exists, -) - +from sql.util import parse_sql_results_to_json +from sql.inspect import fetch_sql_with_pagination, is_table_exists from sql.widgets import utils from sql.telemetry import telemetry diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index 9be1315e8..4c5c0fe70 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -4,7 +4,7 @@ import pytest import warnings from sql.telemetry import telemetry -from sql.error_message import CTE_MSG +from sql.error_handler import CTE_MSG from unittest.mock import ANY, Mock from IPython.core.error import UsageError @@ -851,12 +851,7 @@ def test_sql_query(ip_with_dynamic_db, cell, request, test_table_name_dict): "ip_with_SQLite", "ip_with_duckDB_native", "ip_with_duckDB", - pytest.param( - "ip_with_MSSQL", - marks=pytest.mark.xfail( - reason="We need to close any pending results for this to work" - ), - ), + "ip_with_MSSQL", "ip_with_Snowflake", "ip_with_oracle", pytest.param( @@ -884,19 +879,11 @@ def test_sql_query_cte(ip_with_dynamic_db, request, test_table_name_dict, cell): "ip_with_mySQL", "ip_with_mariaDB", "ip_with_SQLite", - pytest.param( - "ip_with_duckDB_native", - marks=pytest.mark.xfail(reason="Not yet implemented"), - ), + "ip_with_duckDB_native", "ip_with_duckDB", "ip_with_Snowflake", - pytest.param( - "ip_with_MSSQL", marks=pytest.mark.xfail(reason="Not yet implemented") - ), - pytest.param( - "ip_with_oracle", - marks=pytest.mark.xfail(reason="Not yet implemented"), - ), + "ip_with_MSSQL", + "ip_with_oracle", pytest.param( "ip_with_clickhouse", marks=pytest.mark.xfail(reason="Not yet implemented"), diff --git a/src/tests/test_extract_tables.py b/src/tests/test_extract_tables.py index da39fb608..ea3298db7 100644 --- a/src/tests/test_extract_tables.py +++ b/src/tests/test_extract_tables.py @@ -1,5 +1,5 @@ import pytest -from sql.query_util import extract_tables_from_query +from sql.util import extract_tables_from_query @pytest.mark.parametrize( diff --git a/src/tests/test_inspect.py b/src/tests/test_inspect.py index cd2cdee63..4e13c3557 100644 --- a/src/tests/test_inspect.py +++ b/src/tests/test_inspect.py @@ -11,6 +11,13 @@ from sql import inspect, connection +EXPECTED_SUGGESTIONS_MESSAGE = "Did you mean:" +EXPECTED_NO_TABLE_IN_SCHEMA = "There is no table with name {0!r} in schema {1!r}" +EXPECTED_NO_TABLE_IN_DEFAULT_SCHEMA = ( + "There is no table with name {0!r} in the default schema" +) + + @pytest.fixture def sample_db(ip_empty, tmp_empty): ip_empty.run_cell("%sql sqlite:///first.db --alias first") @@ -89,6 +96,48 @@ def test_get_column(sample_db, name, first, second, schema): assert second in columns._repr_html_() +@pytest.mark.parametrize( + "table, offset, n_rows, expected_rows, expected_columns", + [ + ("number_table", 0, 0, [], ["x", "y"]), + ("number_table", 5, 0, [], ["x", "y"]), + ("number_table", 50, 0, [], ["x", "y"]), + ("number_table", 50, 10, [], ["x", "y"]), + ( + "number_table", + 2, + 10, + [(2, 4), (0, 2), (-5, -1), (-2, -3), (-2, -3), (-4, 2), (2, -5), (4, 3)], + ["x", "y"], + ), + ( + "number_table", + 2, + 100, + [(2, 4), (0, 2), (-5, -1), (-2, -3), (-2, -3), (-4, 2), (2, -5), (4, 3)], + ["x", "y"], + ), + ("number_table", 0, 2, [(4, -2), (-5, 0)], ["x", "y"]), + ("number_table", 2, 2, [(2, 4), (0, 2)], ["x", "y"]), + ( + "number_table", + 2, + 5, + [(2, 4), (0, 2), (-5, -1), (-2, -3), (-2, -3)], + ["x", "y"], + ), + ("empty_table", 2, 5, [], ["column", "another"]), + ], +) +def test_fetch_sql_with_pagination_no_sort( + ip, table, offset, n_rows, expected_rows, expected_columns +): + rows, columns = inspect.fetch_sql_with_pagination(table, offset, n_rows) + + assert rows == expected_rows + assert columns == expected_columns + + @pytest.mark.parametrize( "name, schema, error", [ @@ -251,3 +300,262 @@ def test_columns_with_missing_values( pt.add_rows(rows) assert str(inspect.get_columns(name=name, schema=schema)) == str(pt) + + +@pytest.mark.parametrize( + "table", + ["no_such_table", ""], +) +def test_fetch_sql_with_pagination_no_table_error(ip, table): + with pytest.raises(UsageError) as excinfo: + inspect.fetch_sql_with_pagination(table, 0, 2) + + assert excinfo.value.error_type == "TableNotFoundError" + + +def test_fetch_sql_with_pagination_none_table(ip): + with pytest.raises(UsageError) as excinfo: + inspect.fetch_sql_with_pagination(None, 0, 2) + + assert excinfo.value.error_type == "UsageError" + + +@pytest.mark.parametrize( + "table, offset, n_rows, sort_by, order_by, expected_rows, expected_columns", + [ + ("number_table", 0, 0, "x", "DESC", [], ["x", "y"]), + ("number_table", 5, 0, "x", "DESC", [], ["x", "y"]), + ("number_table", 50, 0, "y", "ASC", [], ["x", "y"]), + ("number_table", 50, 10, "y", "ASC", [], ["x", "y"]), + ("number_table", 0, 2, "x", "DESC", [(4, -2), (4, 3)], ["x", "y"]), + ("number_table", 0, 2, "x", "ASC", [(-5, 0), (-5, -1)], ["x", "y"]), + ("empty_table", 2, 5, "column", "ASC", [], ["column", "another"]), + ("number_table", 2, 2, "x", "ASC", [(-4, 2), (-2, -3)], ["x", "y"]), + ("number_table", 2, 2, "x", "DESC", [(2, 4), (2, -5)], ["x", "y"]), + ( + "number_table", + 2, + 10, + "x", + "DESC", + [(2, 4), (2, -5), (0, 2), (-2, -3), (-2, -3), (-4, 2), (-5, 0), (-5, -1)], + ["x", "y"], + ), + ( + "number_table", + 2, + 100, + "x", + "DESC", + [(2, 4), (2, -5), (0, 2), (-2, -3), (-2, -3), (-4, 2), (-5, 0), (-5, -1)], + ["x", "y"], + ), + ( + "number_table", + 2, + 5, + "y", + "ASC", + [(-2, -3), (4, -2), (-5, -1), (-5, 0), (0, 2)], + ["x", "y"], + ), + ], +) +def test_fetch_sql_with_pagination_with_sort( + ip, table, offset, n_rows, sort_by, order_by, expected_rows, expected_columns +): + rows, columns = inspect.fetch_sql_with_pagination( + table, offset, n_rows, sort_by, order_by + ) + + assert rows == expected_rows + assert columns == expected_columns + + +@pytest.mark.parametrize( + "table, expected_result", + [ + ("number_table", True), + ("test", True), + ("author", True), + ("empty_table", True), + ("numbers1", False), + ("test1", False), + ("author1", False), + ("empty_table1", False), + (None, False), + ], +) +def test_is_table_exists_ignore_error(ip, table, expected_result): + assert expected_result is inspect.is_table_exists(table, ignore_error=True) + + +@pytest.mark.parametrize( + "table, expected_error, error_type", + [ + ("number_table", False, "TableNotFoundError"), + ("test", False, "TableNotFoundError"), + ("author", False, "TableNotFoundError"), + ("empty_table", False, "TableNotFoundError"), + ("numbers1", True, "TableNotFoundError"), + ("test1", True, "TableNotFoundError"), + ("author1", True, "TableNotFoundError"), + ("empty_table1", True, "TableNotFoundError"), + (None, True, "UsageError"), + ], +) +def test_is_table_exists(ip, table, expected_error, error_type): + if expected_error: + with pytest.raises(UsageError) as excinfo: + inspect.is_table_exists(table) + + assert excinfo.value.error_type == error_type + else: + inspect.is_table_exists(table) + + +@pytest.mark.parametrize( + "table, expected_error, expected_suggestions", + [ + ("number_table", None, []), + ("number_tale", UsageError, ["number_table"]), + ("_table", UsageError, ["number_table", "empty_table"]), + (None, UsageError, []), + ], +) +def test_is_table_exists_with(ip, table, expected_error, expected_suggestions): + with_ = ["temp"] + + ip.run_cell( + f""" + %%sql --save {with_[0]} --no-execute + SELECT * + FROM {table} + WHERE x > 2 + """ + ) + if expected_error: + with pytest.raises(expected_error) as error: + inspect.is_table_exists(table) + + error_suggestions_arr = str(error.value).split(EXPECTED_SUGGESTIONS_MESSAGE) + + if len(expected_suggestions) > 0: + assert len(error_suggestions_arr) > 1 + for suggestion in expected_suggestions: + assert suggestion in error_suggestions_arr[1] + else: + assert len(error_suggestions_arr) == 1 + else: + inspect.is_table_exists(table) + + +def test_get_list_of_existing_tables(ip): + expected = ["author", "empty_table", "number_table", "test", "website"] + list_of_tables = inspect._get_list_of_existing_tables() + for table in expected: + assert table in list_of_tables + + +@pytest.mark.parametrize( + "table, query, suggestions", + [ + ("tes", "%sqlcmd columns --table {}", ["test"]), + ("_table", "%sqlcmd columns --table {}", ["empty_table", "number_table"]), + ("no_similar_tables", "%sqlcmd columns --table {}", []), + ("tes", "%sqlcmd profile --table {}", ["test"]), + ("_table", "%sqlcmd profile --table {}", ["empty_table", "number_table"]), + ("no_similar_tables", "%sqlcmd profile --table {}", []), + ("tes", "%sqlplot histogram --table {} --column x", ["test"]), + ("tes", "%sqlplot boxplot --table {} --column x", ["test"]), + ], +) +def test_bad_table_error_message(ip, table, query, suggestions): + query = query.format(table) + + with pytest.raises(UsageError) as excinfo: + ip.run_cell(query) + + expected_error_message = EXPECTED_NO_TABLE_IN_DEFAULT_SCHEMA.format(table) + + error_message = str(excinfo.value) + assert str(expected_error_message).lower() in error_message.lower() + + error_suggestions_arr = error_message.split(EXPECTED_SUGGESTIONS_MESSAGE) + + if len(suggestions) > 0: + assert len(error_suggestions_arr) > 1 + for suggestion in suggestions: + assert suggestion in error_suggestions_arr[1] + + +@pytest.mark.parametrize( + "table, schema, query, suggestions", + [ + ( + "test_table", + "invalid_name_no_match", + "%sqlcmd columns --table {} --schema {}", + [], + ), + ( + "test_table", + "te_schema", + "%sqlcmd columns --table {} --schema {}", + ["test_schema"], + ), + ( + "invalid_name_no_match", + "test_schema", + "%sqlcmd columns --table {} --schema {}", + [], + ), + ( + "test_tabl", + "test_schema", + "%sqlcmd columns --table {} --schema {}", + ["test_table", "test"], + ), + ( + "invalid_name_no_match", + "invalid_name_no_match", + "%sqlcmd columns --table {} --schema {}", + [], + ), + ( + "_table", + "_schema", + "%sqlcmd columns --table {} --schema {}", + ["test_schema"], + ), + ], +) +def test_bad_table_error_message_with_schema(ip, query, suggestions, table, schema): + query = query.format(table, schema) + + expected_error_message = EXPECTED_NO_TABLE_IN_SCHEMA.format(table, schema) + + ip.run_cell( + """%%sql sqlite:///my.db +CREATE TABLE IF NOT EXISTS test_table (id INT) +""" + ) + + ip.run_cell( + """%%sql +ATTACH DATABASE 'my.db' AS test_schema +""" + ) + + with pytest.raises(UsageError) as excinfo: + ip.run_cell(query) + + error_message = str(excinfo.value) + assert str(expected_error_message).lower() in error_message.lower() + + error_suggestions_arr = error_message.split(EXPECTED_SUGGESTIONS_MESSAGE) + + if len(suggestions) > 0: + assert len(error_suggestions_arr) > 1 + for suggestion in suggestions: + assert suggestion in error_suggestions_arr[1] diff --git a/src/tests/test_magic_cte.py b/src/tests/test_magic_cte.py index 0f8acc62d..6988b27b9 100644 --- a/src/tests/test_magic_cte.py +++ b/src/tests/test_magic_cte.py @@ -1,6 +1,6 @@ import pytest from IPython.core.error import UsageError -from sql.error_message import CTE_MSG +from sql.error_handler import CTE_MSG def test_trailing_semicolons_removed_from_cte(ip): @@ -62,8 +62,12 @@ def test_infer_dependencies(ip, capsys): TABLE_NAME_TYPO_ERR_MSG = """ There is no table with name 'author_subb'. -Did you mean : 'author_sub' -If you need help solving this issue, send us a message: https://ploomber.io/community +Did you mean: 'author_sub' + + +Original error message from DB driver: +(sqlite3.OperationalError) no such table: author_subb +[SQL: SELECT last_name FROM author_subb;] """ @@ -82,7 +86,7 @@ def test_table_name_typo(ip): ) assert excinfo.value.error_type == "TableNotFoundError" - assert str(excinfo.value) == TABLE_NAME_TYPO_ERR_MSG.strip() + assert TABLE_NAME_TYPO_ERR_MSG.strip() in str(excinfo.value) def test_snippets_delete(ip, capsys): diff --git a/src/tests/test_magic_plot.py b/src/tests/test_magic_plot.py index a61a38f2d..e6f9fecdb 100644 --- a/src/tests/test_magic_plot.py +++ b/src/tests/test_magic_plot.py @@ -577,7 +577,7 @@ def test_sqlplot_snippet_deletion(ip_snippets, arg): TABLE_NAME_TYPO_MSG = """ There is no table with name 'subst' in the default schema -Did you mean : 'subset' +Did you mean: 'subset' If you need help solving this issue, send us a message: https://ploomber.io/community """ diff --git a/src/tests/test_store.py b/src/tests/test_store.py index 5912f8778..fdd5f7e32 100644 --- a/src/tests/test_store.py +++ b/src/tests/test_store.py @@ -1,7 +1,8 @@ import pytest from sql.connection import SQLAlchemyConnection, ConnectionManager from IPython.core.error import UsageError -from sql.store import SQLStore, SQLQuery +from sql import store +from sql import store_utils from sqlalchemy import create_engine @@ -10,16 +11,44 @@ def setup_no_current_connect(monkeypatch): monkeypatch.setattr(ConnectionManager, "current", None) +@pytest.fixture +def ip_snippets(ip): + ip.run_cell( + """ +%%sql --save a --no-execute +SELECT * +FROM number_table +""" + ) + ip.run_cell( + """ + %%sql --save b --no-execute + SELECT * + FROM a + WHERE x > 5 + """ + ) + ip.run_cell( + """ + %%sql --save c --no-execute + SELECT * + FROM a + WHERE x < 5 + """ + ) + yield ip + + def test_sqlstore_setitem(): - store = SQLStore() - store["a"] = "SELECT * FROM a" - assert store["a"] == "SELECT * FROM a" + sql_store = store.SQLStore() + sql_store["a"] = "SELECT * FROM a" + assert sql_store["a"] == "SELECT * FROM a" def test_sqlstore_getitem_success(): - store = SQLStore() - store["first"] = "SELECT * FROM a" - assert store["first"] == "SELECT * FROM a" + sql_store = store.SQLStore() + sql_store["first"] = "SELECT * FROM a" + assert sql_store["first"] == "SELECT * FROM a" @pytest.mark.parametrize( @@ -43,23 +72,23 @@ def test_sqlstore_getitem_success(): ], ) def test_sqlstore_getitem(key, expected_error): - store = SQLStore() - store["first"] = "SELECT * FROM a" + sql_store = store.SQLStore() + sql_store["first"] = "SELECT * FROM a" with pytest.raises(UsageError) as excinfo: - store[key] + sql_store[key] assert excinfo.value.error_type == "UsageError" assert str(excinfo.value) == expected_error def test_sqlstore_getitem_with_multiple_existing_snippets(): - store = SQLStore() - store["first"] = "SELECT * FROM a" - store["first2"] = "SELECT * FROM a" + sql_store = store.SQLStore() + sql_store["first"] = "SELECT * FROM a" + sql_store["first2"] = "SELECT * FROM a" with pytest.raises(UsageError) as excinfo: - store["second"] + sql_store["second"] assert excinfo.value.error_type == "UsageError" assert ( @@ -70,19 +99,19 @@ def test_sqlstore_getitem_with_multiple_existing_snippets(): def test_hyphen(): - store = SQLStore() + sql_store = store.SQLStore() with pytest.raises(UsageError) as excinfo: - SQLQuery(store, "SELECT * FROM a", with_=["first-"]) + store.SQLQuery(sql_store, "SELECT * FROM a", with_=["first-"]) assert "Using hyphens is not allowed." in str(excinfo.value) def test_key(): - store = SQLStore() + sql_store = store.SQLStore() with pytest.raises(UsageError) as excinfo: - store.store("first", "SELECT * FROM first WHERE x > 20", with_=["first"]) + sql_store.store("first", "SELECT * FROM first WHERE x > 20", with_=["first"]) assert "cannot appear in with_ argument" in str(excinfo.value) @@ -128,13 +157,15 @@ def test_serial(with_, is_dialect_support_backtick, monkeypatch): ) identifier = "`" if is_dialect_support_backtick else "" - store = SQLStore() - store.store("first", "SELECT * FROM a WHERE x > 10") - store.store("second", "SELECT * FROM first WHERE x > 20", with_=["first"]) + sql_store = store.SQLStore() + sql_store.store("first", "SELECT * FROM a WHERE x > 10") + sql_store.store("second", "SELECT * FROM first WHERE x > 20", with_=["first"]) - store.store("third", "SELECT * FROM second WHERE x > 30", with_=["second", "first"]) + sql_store.store( + "third", "SELECT * FROM second WHERE x > 30", with_=["second", "first"] + ) - result = store.render("SELECT * FROM third", with_=with_) + result = sql_store.render("SELECT * FROM third", with_=with_) assert ( str(result) @@ -172,14 +203,16 @@ def test_branch_root(is_dialect_support_backtick, monkeypatch): ) identifier = "`" if is_dialect_support_backtick else "" - store = SQLStore() - store.store("first_a", "SELECT * FROM a WHERE x > 10") - store.store("second_a", "SELECT * FROM first_a WHERE x > 20", with_=["first_a"]) - store.store("third_a", "SELECT * FROM second_a WHERE x > 30", with_=["second_a"]) + sql_store = store.SQLStore() + sql_store.store("first_a", "SELECT * FROM a WHERE x > 10") + sql_store.store("second_a", "SELECT * FROM first_a WHERE x > 20", with_=["first_a"]) + sql_store.store( + "third_a", "SELECT * FROM second_a WHERE x > 30", with_=["second_a"] + ) - store.store("first_b", "SELECT * FROM b WHERE y > 10") + sql_store.store("first_b", "SELECT * FROM b WHERE y > 10") - result = store.render("SELECT * FROM third", with_=["third_a", "first_b"]) + result = sql_store.render("SELECT * FROM third", with_=["third_a", "first_b"]) assert ( str(result) == "WITH {0}first_a{0} AS (SELECT * FROM a WHERE x > 10), \ @@ -218,15 +251,17 @@ def test_branch_root_reverse_final_with(is_dialect_support_backtick, monkeypatch ) identifier = "`" if is_dialect_support_backtick else "" - store = SQLStore() + sql_store = store.SQLStore() - store.store("first_a", "SELECT * FROM a WHERE x > 10") - store.store("second_a", "SELECT * FROM first_a WHERE x > 20", with_=["first_a"]) - store.store("third_a", "SELECT * FROM second_a WHERE x > 30", with_=["second_a"]) + sql_store.store("first_a", "SELECT * FROM a WHERE x > 10") + sql_store.store("second_a", "SELECT * FROM first_a WHERE x > 20", with_=["first_a"]) + sql_store.store( + "third_a", "SELECT * FROM second_a WHERE x > 30", with_=["second_a"] + ) - store.store("first_b", "SELECT * FROM b WHERE y > 10") + sql_store.store("first_b", "SELECT * FROM b WHERE y > 10") - result = store.render("SELECT * FROM third", with_=["first_b", "third_a"]) + result = sql_store.render("SELECT * FROM third", with_=["first_b", "third_a"]) assert ( str(result) == "WITH {0}first_a{0} AS (SELECT * FROM a WHERE x > 10), \ @@ -263,15 +298,19 @@ def test_branch(is_dialect_support_backtick, monkeypatch): ) identifier = "`" if is_dialect_support_backtick else "" - store = SQLStore() + sql_store = store.SQLStore() - store.store("first_a", "SELECT * FROM a WHERE x > 10") - store.store("second_a", "SELECT * FROM first_a WHERE x > 20", with_=["first_a"]) - store.store("third_a", "SELECT * FROM second_a WHERE x > 30", with_=["second_a"]) + sql_store.store("first_a", "SELECT * FROM a WHERE x > 10") + sql_store.store("second_a", "SELECT * FROM first_a WHERE x > 20", with_=["first_a"]) + sql_store.store( + "third_a", "SELECT * FROM second_a WHERE x > 30", with_=["second_a"] + ) - store.store("first_b", "SELECT * FROM second_a WHERE y > 10", with_=["second_a"]) + sql_store.store( + "first_b", "SELECT * FROM second_a WHERE y > 10", with_=["second_a"] + ) - result = store.render("SELECT * FROM third", with_=["first_b", "third_a"]) + result = sql_store.render("SELECT * FROM third", with_=["first_b", "third_a"]) assert ( str(result) == "WITH {0}first_a{0} AS (SELECT * FROM a WHERE x > 10), \ @@ -281,3 +320,28 @@ def test_branch(is_dialect_support_backtick, monkeypatch): identifier ) ) + + +def test_get_all_keys(ip_snippets): + keys = store_utils.get_all_keys() + assert "a" in keys + assert "b" in keys + assert "c" in keys + + +def test_get_key_dependents(ip_snippets): + keys = store_utils.get_key_dependents("a") + assert "b" in keys + assert "c" in keys + + +def test_del_saved_key(ip_snippets): + keys = store_utils.del_saved_key("c") + assert "a" in keys + assert "b" in keys + + +def test_del_saved_key_error(ip_snippets): + with pytest.raises(UsageError) as excinfo: + store_utils.del_saved_key("non_existent_key") + assert "No such saved snippet found : non_existent_key" in str(excinfo.value) diff --git a/src/tests/test_syntax_errors.py b/src/tests/test_syntax_errors.py index 2dc4439c2..d4021543d 100644 --- a/src/tests/test_syntax_errors.py +++ b/src/tests/test_syntax_errors.py @@ -4,7 +4,7 @@ from sqlalchemy.exc import OperationalError from IPython.core.error import UsageError -from sql.error_message import ORIGINAL_ERROR, CTE_MSG +from sql.error_handler import ORIGINAL_ERROR, CTE_MSG from ploomber_core.exceptions import COMMUNITY diff --git a/src/tests/test_util.py b/src/tests/test_util.py index 5da1d1650..2b61a7cd1 100644 --- a/src/tests/test_util.py +++ b/src/tests/test_util.py @@ -2,47 +2,13 @@ import pytest from sql import util import json -from IPython.core.error import UsageError ERROR_MESSAGE = "Table cannot be None" -EXPECTED_SUGGESTIONS_MESSAGE = "Did you mean :" -EXPECTED_NO_TABLE_IN_DEFAULT_SCHEMA = ( - "There is no table with name {0!r} in the default schema" -) -EXPECTED_NO_TABLE_IN_SCHEMA = "There is no table with name {0!r} in schema {1!r}" EXPECTED_STORE_SUGGESTIONS = ( "but there is a stored query.\nDid you miss passing --with {0}?" ) -@pytest.fixture -def ip_snippets(ip): - ip.run_cell( - """ -%%sql --save a --no-execute -SELECT * -FROM number_table -""" - ) - ip.run_cell( - """ - %%sql --save b --no-execute - SELECT * - FROM a - WHERE x > 5 - """ - ) - ip.run_cell( - """ - %%sql --save c --no-execute - SELECT * - FROM a - WHERE x < 5 - """ - ) - yield ip - - @pytest.mark.parametrize( "store_table, query", [ @@ -91,195 +57,6 @@ def test_no_errors_with_stored_query(ip_empty, store_table, query): assert out.success -@pytest.mark.parametrize( - "table, query, suggestions", - [ - ("tes", "%sqlcmd columns --table {}", ["test"]), - ("_table", "%sqlcmd columns --table {}", ["empty_table", "number_table"]), - ("no_similar_tables", "%sqlcmd columns --table {}", []), - ("tes", "%sqlcmd profile --table {}", ["test"]), - ("_table", "%sqlcmd profile --table {}", ["empty_table", "number_table"]), - ("no_similar_tables", "%sqlcmd profile --table {}", []), - ("tes", "%sqlplot histogram --table {} --column x", ["test"]), - ("tes", "%sqlplot boxplot --table {} --column x", ["test"]), - ], -) -def test_bad_table_error_message(ip, table, query, suggestions): - query = query.format(table) - - with pytest.raises(UsageError) as excinfo: - ip.run_cell(query) - - expected_error_message = EXPECTED_NO_TABLE_IN_DEFAULT_SCHEMA.format(table) - - error_message = str(excinfo.value) - assert str(expected_error_message).lower() in error_message.lower() - - error_suggestions_arr = error_message.split(EXPECTED_SUGGESTIONS_MESSAGE) - - if len(suggestions) > 0: - assert len(error_suggestions_arr) > 1 - for suggestion in suggestions: - assert suggestion in error_suggestions_arr[1] - - -@pytest.mark.parametrize( - "table, schema, query, suggestions", - [ - ( - "test_table", - "invalid_name_no_match", - "%sqlcmd columns --table {} --schema {}", - [], - ), - ( - "test_table", - "te_schema", - "%sqlcmd columns --table {} --schema {}", - ["test_schema"], - ), - ( - "invalid_name_no_match", - "test_schema", - "%sqlcmd columns --table {} --schema {}", - [], - ), - ( - "test_tabl", - "test_schema", - "%sqlcmd columns --table {} --schema {}", - ["test_table", "test"], - ), - ( - "invalid_name_no_match", - "invalid_name_no_match", - "%sqlcmd columns --table {} --schema {}", - [], - ), - ( - "_table", - "_schema", - "%sqlcmd columns --table {} --schema {}", - ["test_schema"], - ), - ], -) -def test_bad_table_error_message_with_schema(ip, query, suggestions, table, schema): - query = query.format(table, schema) - - expected_error_message = EXPECTED_NO_TABLE_IN_SCHEMA.format(table, schema) - - ip.run_cell( - """%%sql sqlite:///my.db -CREATE TABLE IF NOT EXISTS test_table (id INT) -""" - ) - - ip.run_cell( - """%%sql -ATTACH DATABASE 'my.db' AS test_schema -""" - ) - - with pytest.raises(UsageError) as excinfo: - ip.run_cell(query) - - error_message = str(excinfo.value) - assert str(expected_error_message).lower() in error_message.lower() - - error_suggestions_arr = error_message.split(EXPECTED_SUGGESTIONS_MESSAGE) - - if len(suggestions) > 0: - assert len(error_suggestions_arr) > 1 - for suggestion in suggestions: - assert suggestion in error_suggestions_arr[1] - - -@pytest.mark.parametrize( - "table, expected_result", - [ - ("number_table", True), - ("test", True), - ("author", True), - ("empty_table", True), - ("numbers1", False), - ("test1", False), - ("author1", False), - ("empty_table1", False), - (None, False), - ], -) -def test_is_table_exists_ignore_error(ip, table, expected_result): - assert expected_result is util.is_table_exists(table, ignore_error=True) - - -@pytest.mark.parametrize( - "table, expected_error, error_type", - [ - ("number_table", False, "TableNotFoundError"), - ("test", False, "TableNotFoundError"), - ("author", False, "TableNotFoundError"), - ("empty_table", False, "TableNotFoundError"), - ("numbers1", True, "TableNotFoundError"), - ("test1", True, "TableNotFoundError"), - ("author1", True, "TableNotFoundError"), - ("empty_table1", True, "TableNotFoundError"), - (None, True, "UsageError"), - ], -) -def test_is_table_exists(ip, table, expected_error, error_type): - if expected_error: - with pytest.raises(UsageError) as excinfo: - util.is_table_exists(table) - - assert excinfo.value.error_type == error_type - else: - util.is_table_exists(table) - - -@pytest.mark.parametrize( - "table, expected_error, expected_suggestions", - [ - ("number_table", None, []), - ("number_tale", UsageError, ["number_table"]), - ("_table", UsageError, ["number_table", "empty_table"]), - (None, UsageError, []), - ], -) -def test_is_table_exists_with(ip, table, expected_error, expected_suggestions): - with_ = ["temp"] - - ip.run_cell( - f""" - %%sql --save {with_[0]} --no-execute - SELECT * - FROM {table} - WHERE x > 2 - """ - ) - if expected_error: - with pytest.raises(expected_error) as error: - util.is_table_exists(table) - - error_suggestions_arr = str(error.value).split(EXPECTED_SUGGESTIONS_MESSAGE) - - if len(expected_suggestions) > 0: - assert len(error_suggestions_arr) > 1 - for suggestion in expected_suggestions: - assert suggestion in error_suggestions_arr[1] - else: - assert len(error_suggestions_arr) == 1 - else: - util.is_table_exists(table) - - -def test_get_list_of_existing_tables(ip): - expected = ["author", "empty_table", "number_table", "test", "website"] - list_of_tables = util._get_list_of_existing_tables() - for table in expected: - assert table in list_of_tables - - @pytest.mark.parametrize( "src, ltypes, expected", [ @@ -306,118 +83,6 @@ def test_flatten(src, ltypes, expected): assert util.flatten(src) == expected -@pytest.mark.parametrize( - "table, offset, n_rows, expected_rows, expected_columns", - [ - ("number_table", 0, 0, [], ["x", "y"]), - ("number_table", 5, 0, [], ["x", "y"]), - ("number_table", 50, 0, [], ["x", "y"]), - ("number_table", 50, 10, [], ["x", "y"]), - ( - "number_table", - 2, - 10, - [(2, 4), (0, 2), (-5, -1), (-2, -3), (-2, -3), (-4, 2), (2, -5), (4, 3)], - ["x", "y"], - ), - ( - "number_table", - 2, - 100, - [(2, 4), (0, 2), (-5, -1), (-2, -3), (-2, -3), (-4, 2), (2, -5), (4, 3)], - ["x", "y"], - ), - ("number_table", 0, 2, [(4, -2), (-5, 0)], ["x", "y"]), - ("number_table", 2, 2, [(2, 4), (0, 2)], ["x", "y"]), - ( - "number_table", - 2, - 5, - [(2, 4), (0, 2), (-5, -1), (-2, -3), (-2, -3)], - ["x", "y"], - ), - ("empty_table", 2, 5, [], ["column", "another"]), - ], -) -def test_fetch_sql_with_pagination_no_sort( - ip, table, offset, n_rows, expected_rows, expected_columns -): - rows, columns = util.fetch_sql_with_pagination(table, offset, n_rows) - - assert rows == expected_rows - assert columns == expected_columns - - -@pytest.mark.parametrize( - "table, offset, n_rows, sort_by, order_by, expected_rows, expected_columns", - [ - ("number_table", 0, 0, "x", "DESC", [], ["x", "y"]), - ("number_table", 5, 0, "x", "DESC", [], ["x", "y"]), - ("number_table", 50, 0, "y", "ASC", [], ["x", "y"]), - ("number_table", 50, 10, "y", "ASC", [], ["x", "y"]), - ("number_table", 0, 2, "x", "DESC", [(4, -2), (4, 3)], ["x", "y"]), - ("number_table", 0, 2, "x", "ASC", [(-5, 0), (-5, -1)], ["x", "y"]), - ("empty_table", 2, 5, "column", "ASC", [], ["column", "another"]), - ("number_table", 2, 2, "x", "ASC", [(-4, 2), (-2, -3)], ["x", "y"]), - ("number_table", 2, 2, "x", "DESC", [(2, 4), (2, -5)], ["x", "y"]), - ( - "number_table", - 2, - 10, - "x", - "DESC", - [(2, 4), (2, -5), (0, 2), (-2, -3), (-2, -3), (-4, 2), (-5, 0), (-5, -1)], - ["x", "y"], - ), - ( - "number_table", - 2, - 100, - "x", - "DESC", - [(2, 4), (2, -5), (0, 2), (-2, -3), (-2, -3), (-4, 2), (-5, 0), (-5, -1)], - ["x", "y"], - ), - ( - "number_table", - 2, - 5, - "y", - "ASC", - [(-2, -3), (4, -2), (-5, -1), (-5, 0), (0, 2)], - ["x", "y"], - ), - ], -) -def test_fetch_sql_with_pagination_with_sort( - ip, table, offset, n_rows, sort_by, order_by, expected_rows, expected_columns -): - rows, columns = util.fetch_sql_with_pagination( - table, offset, n_rows, sort_by, order_by - ) - - assert rows == expected_rows - assert columns == expected_columns - - -@pytest.mark.parametrize( - "table", - ["no_such_table", ""], -) -def test_fetch_sql_with_pagination_no_table_error(ip, table): - with pytest.raises(UsageError) as excinfo: - util.fetch_sql_with_pagination(table, 0, 2) - - assert excinfo.value.error_type == "TableNotFoundError" - - -def test_fetch_sql_with_pagination_none_table(ip): - with pytest.raises(UsageError) as excinfo: - util.fetch_sql_with_pagination(None, 0, 2) - - assert excinfo.value.error_type == "UsageError" - - date_format = "%Y-%m-%d %H:%M:%S" @@ -468,26 +133,13 @@ def test_parse_sql_results_to_json(ip, capsys, rows, columns, expected_json): assert str(j) == str(expected_json) -def test_get_all_keys(ip_snippets): - keys = util.get_all_keys() - assert "a" in keys - assert "b" in keys - assert "c" in keys - - -def test_get_key_dependents(ip_snippets): - keys = util.get_key_dependents("a") - assert "b" in keys - assert "c" in keys - - -def test_del_saved_key(ip_snippets): - keys = util.del_saved_key("c") - assert "a" in keys - assert "b" in keys - - -def test_del_saved_key_error(ip_snippets): - with pytest.raises(UsageError) as excinfo: - util.del_saved_key("non_existent_key") - assert "No such saved snippet found : non_existent_key" in str(excinfo.value) +@pytest.mark.parametrize( + "string, substrings, expected", + [ + ["some-string", ["some", "another"], True], + ["some-string", ["another", "word"], False], + ], +) +def test_is_sqlalchemy_error(string, substrings, expected): + result = util.if_substring_exists(string, substrings) + assert result == expected