diff --git a/CHANGELOG.md b/CHANGELOG.md index d3f308def..77edcc255 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ ## 0.7.8dev * [Feature] Add `%sqlplot bar` and `%sqlplot pie` +* [Feature] Automated dependency inference when creating CTEs. `--with` is now deprecated and will display a warning. (#166) + ## 0.7.7 (2023-05-31) diff --git a/doc/_toc.yml b/doc/_toc.yml index f20d771ab..06c920e34 100644 --- a/doc/_toc.yml +++ b/doc/_toc.yml @@ -42,6 +42,7 @@ parts: - file: api/magic-sql - file: api/magic-plot - file: api/magic-render + - file: api/magic-snippets - file: api/configuration - file: api/python - file: api/magic-tables-columns diff --git a/doc/api/magic-snippets.md b/doc/api/magic-snippets.md new file mode 100644 index 000000000..d84d6c9ef --- /dev/null +++ b/doc/api/magic-snippets.md @@ -0,0 +1,135 @@ +--- +jupytext: + notebook_metadata_filter: myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.14.5 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +myst: + html_meta: + description lang=en: Documentation for %sqlcmd snippets + from JupySQL + keywords: jupyter, sql, jupysql, snippets + property=og:locale: en_US +--- + +# `%sqlcmd snippets` + +`%sqlcmd snippets` returns the query snippets saved using `--save` + +## Load Data + +```{code-cell} ipython3 +%load_ext sql +%sql duckdb:// +``` + +```{code-cell} ipython3 +from pathlib import Path +from urllib.request import urlretrieve + +if not Path("penguins.csv").is_file(): + urlretrieve( + "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv", + "penguins.csv", + ) +``` + +```{code-cell} ipython3 +%%sql +SELECT * FROM penguins.csv LIMIT 3 +``` + +Let's save a couple of snippets. + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save gentoo +SELECT * FROM penguins.csv where species == 'Gentoo' +``` + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save chinstrap +SELECT * FROM penguins.csv where species == 'Chinstrap' +``` + +## `%sqlcmd snippets` + ++++ + +Returns all the snippets saved in the environment + +```{code-cell} ipython3 +%sqlcmd snippets +``` + +Arguments: + +`-d`/`--delete` Delete a snippet. + +`-D`/`--delete-force` Force delete a snippet. This may be useful if there are other dependent snippets, and you still need to delete this snippet. + +`-A`/`--delete-force-all` Force delete a snippet and all dependent snippets. + +```{code-cell} ipython3 + +%sqlcmd snippets -d gentoo +``` + +This deletes the stored snippet `gentoo`. + +To demonstrate `force-delete` let's create a snippet dependent on `chinstrap` snippet. + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save chinstrap_sub +SELECT * FROM chinstrap where island == 'Dream' +``` ++++ + +Trying to delete the `chinstrap` snippet will display an error message: + +```{code-cell} ipython3 +:tags: [raises-exception] + +%sqlcmd snippets -d chinstrap +``` + +If you still wish to delete this snippet, you can run the below command: + +```{code-cell} ipython3 + +%sqlcmd snippets -D chinstrap +``` + +Now, let's see how to delete a snippet and all other dependent snippets. We'll create a few snippets again. + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save chinstrap +SELECT * FROM penguins.csv where species == 'Chinstrap' +``` + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save chinstrap_sub +SELECT * FROM chinstrap where island == 'Dream' +``` + +Now, force delete `chinstrap` and its dependent `chinstrap_sub`: + +```{code-cell} ipython3 + +%sqlcmd snippets -A chinstrap +``` diff --git a/doc/compose.md b/doc/compose.md index 199bab303..f8016f009 100644 --- a/doc/compose.md +++ b/doc/compose.md @@ -27,7 +27,8 @@ pip install jupysql matplotlib ``` -*New in version 0.4.3* +```{versionchanged} 0.7.8 +``` ```{note} This is a beta feature, please [join our community](https://ploomber.io/community) and @@ -105,31 +106,39 @@ OR Name LIKE '%metal%' Join the filtered genres and tracks, so we only get Rock and Metal tracks, and save the query as `track_fav` -Note that we are using `--with`; this will retrieve previously saved queries, and prepend them (using CTEs), then, we save the query in `track_fav` . + +We automatically extract the tables from the query and infer the dependencies from all the saved snippets. + ```{code-cell} ipython3 -%%sql --with genres_fav --with tracks_with_info --save track_fav +%%sql --save track_fav SELECT t.* FROM tracks_with_info t JOIN genres_fav ON t.GenreId = genres_fav.GenreId ``` -Use the `track_fav` query to find artists with the most Rock and Metal tracks, and save the query as `top_artist` +Now let's find artists with the most Rock and Metal tracks, and save the query as `top_artist` ```{code-cell} ipython3 -%%sql --with track_fav --save top_artist +%%sql --save top_artist SELECT artist, COUNT(*) FROM track_fav GROUP BY artist ORDER BY COUNT(*) DESC ``` + +```{note} +A saved snippet will override an existing table with the same name during query formation. If you wish to delete a snippet please refer to [sqlcmd snippets API](api/magic-snippets.md). + +``` + #### Data visualization Once we have the desired results from the query `top_artist`, we can generate a visualization using the bar method ```{code-cell} ipython3 -top_artist = %sql --with top_artist SELECT * FROM top_artist +top_artist = %sql SELECT * FROM top_artist top_artist.bar() ``` diff --git a/doc/plot.md b/doc/plot.md index 98c51c499..222754e81 100644 --- a/doc/plot.md +++ b/doc/plot.md @@ -121,10 +121,10 @@ FROM "yellow_tripdata_2021-01.parquet" WHERE trip_distance < 6.3 ``` -Now, let's plot again, but this time let's pass `--table short_trips`. Note that this table *doesn't exist*; however, since we're passing the `--with` argument, JupySQL will use the query we defined above: +Now, let's plot again, but this time let's pass `--table short_trips`. Note that this table *doesn't exist*; JupySQL will automatically infer and use the saved snippet defined above. ```{code-cell} ipython3 -%sqlplot boxplot --table short_trips --column trip_distance --with short_trips +%sqlplot boxplot --table short_trips --column trip_distance ``` We can see the highest value is a bit over 6, that's expected since we set a 6.3 cutoff value. @@ -133,10 +133,10 @@ We can see the highest value is a bit over 6, that's expected since we set a 6.3 ## Histogram -To create a histogram, call `%sqlplot histogram`, and pass the name of the table, the column you want to plot, and the number of bins. Similarly to what we did in the [Boxplot](#boxplot) example, we're using `--with short_trips` so JupySQL uses the query we defined and only plots such data subset. +To create a histogram, call `%sqlplot histogram`, and pass the name of the table, the column you want to plot, and the number of bins. Similarly to what we did in the [Boxplot](#boxplot) example, JupySQL detects a saved snippet and only plots such data subset. ```{code-cell} ipython3 -%sqlplot histogram --table short_trips --column trip_distance --bins 10 --with short_trips +%sqlplot histogram --table short_trips --column trip_distance --bins 10 ``` ## Customize plot @@ -144,7 +144,7 @@ To create a histogram, call `%sqlplot histogram`, and pass the name of the table `%sqlplot` returns a `matplotlib.Axes` object that you can further customize: ```{code-cell} ipython3 -ax = %sqlplot histogram --table short_trips --column trip_distance --bins 50 --with short_trips +ax = %sqlplot histogram --table short_trips --column trip_distance --bins 50 ax.grid() ax.set_title("Trip distance from trips < 6.3") _ = ax.set_xlabel("Trip distance") diff --git a/src/sql/command.py b/src/sql/command.py index 2f4ad7093..b73355519 100644 --- a/src/sql/command.py +++ b/src/sql/command.py @@ -71,10 +71,6 @@ def __init__(self, magic, user_ns, line, cell) -> None: if add_alias: self.parsed["connection"] = self.args.line[0] - if self.args.with_: - final = store.render(self.parsed["sql"], with_=self.args.with_) - self.parsed["sql"] = str(final) - @property def sql(self): """ @@ -112,3 +108,15 @@ def __repr__(self) -> str: f"{type(self).__name__}(line={self._line!r}, cell={self._cell!r}) -> " f"({self.sql!r}, {self.sql_original!r})" ) + + def set_sql_with(self, with_): + """ + Sets the final rendered SQL query using the WITH clause + + Parameters + ---------- + with_ : list + list of all subqueries needed to render the query + """ + final = store.render(self.parsed["sql"], with_) + self.parsed["sql"] = str(final) diff --git a/src/sql/error_message.py b/src/sql/error_message.py index b07c2cb79..640dc94bf 100644 --- a/src/sql/error_message.py +++ b/src/sql/error_message.py @@ -5,6 +5,33 @@ ORIGINAL_ERROR = "\nOriginal error message from DB driver:\n" +def parse_sqlglot_error(e, q): + """ + Function to parse the error message from sqlglot + + Parameters + ---------- + e: sqlglot.errors.ParseError, exception + while parsing through sqlglot + q : str, user query + + Returns + ------- + str + Formatted error message containing description + and positions + """ + err = e.errors + position = "" + for item in err: + position += ( + f"Syntax Error in {q}: {item['description']} at " + f"Line {item['line']}, Column {item['col']}\n" + ) + msg = "Possible reason: \n" + position if position else "" + return msg + + def detail(original_error, query=None): original_error = str(original_error) return_msg = SYNTAX_ERROR @@ -25,18 +52,8 @@ def detail(original_error, query=None): ) except sqlglot.errors.ParseError as e: - err = e.errors - position = "" - for item in err: - position += ( - f"Syntax Error in {q}: {item['description']} at " - f"Line {item['line']}, Column {item['col']}\n" - ) - return_msg = ( - return_msg + "Possible reason: \n" + position - if position - else return_msg - ) + parse_msg = parse_sqlglot_error(e, q) + return_msg = return_msg + parse_msg if parse_msg else return_msg return return_msg + "\n" + ORIGINAL_ERROR + original_error + "\n" diff --git a/src/sql/magic.py b/src/sql/magic.py index 24e8277fa..692b1f686 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -18,6 +18,7 @@ from sqlalchemy.exc import OperationalError, ProgrammingError, DatabaseError import warnings +from difflib import get_close_matches import sql.connection import sql.parse import sql.run @@ -27,12 +28,15 @@ from sql.magic_plot import SqlPlotMagic from sql.magic_cmd import SqlCmdMagic from sql._patch import patch_ipython_usage_error +from sql import query_util +from sql.util import get_suggestions_message, show_deprecation_warning from ploomber_core.dependencies import check_installed from sql.error_message import detail from traitlets.config.configurable import Configurable from traitlets import Bool, Int, TraitError, Unicode, Dict, observe, validate + try: from pandas.core.frame import DataFrame, Series except ModuleNotFoundError: @@ -323,6 +327,13 @@ def interactive_execute_wrapper(**kwargs): args = command.args + with_ = self._store.infer_dependencies(command.sql_original, args.save) + if with_: + command.set_sql_with(with_) + print(f"Generating CTE with stored snippets : {', '.join(with_)}") + else: + with_ = None + # Create the interactive slider if args.interact and not is_interactive_mode: check_installed(["ipywidgets"], "--interactive argument") @@ -410,6 +421,8 @@ def interactive_execute_wrapper(**kwargs): if not command.sql: return + if args.with_: + show_deprecation_warning() # store the query if needed if args.save: if "-" in args.save: @@ -420,7 +433,7 @@ def interactive_execute_wrapper(**kwargs): + " instead for the save argument.", FutureWarning, ) - self._store.store(args.save, command.sql_original, with_=args.with_) + self._store.store(args.save, command.sql_original, with_=with_) if args.no_execute: display.message("Skipping execution...") @@ -469,6 +482,19 @@ def interactive_execute_wrapper(**kwargs): if detailed_msg is not None: err = exceptions.UsageError(detailed_msg) raise err + # TODO: move to error_messages.py + # Added here due to circular dependency issue (#545) + elif "no such table" in str(e): + tables = query_util.extract_tables_from_query(command.sql) + for table in tables: + suggestions = get_close_matches(table, list(self._store)) + if len(suggestions) > 0: + err_message = f"There is no table with name {table!r}." + suggestions_message = get_suggestions_message(suggestions) + raise exceptions.TableNotFoundError( + f"{err_message}{suggestions_message}" + ) + print(e) else: print(e) else: diff --git a/src/sql/magic_cmd.py b/src/sql/magic_cmd.py index f477ea675..56340d840 100644 --- a/src/sql/magic_cmd.py +++ b/src/sql/magic_cmd.py @@ -34,6 +34,10 @@ def error(self, message): raise exceptions.UsageError(message) +# Added here due to circular dependencies (#545) +from sql.sqlcmd import sqlcmd_snippets # noqa + + @magics_class class SqlCmdMagic(Magics, Configurable): """%sqlcmd magic""" @@ -50,13 +54,22 @@ def _validate_execute_inputs(self, line): # We rely on SQLAlchemy when inspecting tables util.support_only_sql_alchemy_connection("%sqlcmd") - AVAILABLE_SQLCMD_COMMANDS = ["tables", "columns", "test", "profile", "explore"] + AVAILABLE_SQLCMD_COMMANDS = [ + "tables", + "columns", + "test", + "profile", + "explore", + "snippets", + ] + + VALID_COMMANDS_MSG = ( + f"Missing argument for %sqlcmd. " + f"Valid commands are: {', '.join(AVAILABLE_SQLCMD_COMMANDS)}" + ) if line == "": - raise exceptions.UsageError( - "Missing argument for %sqlcmd. " - "Valid commands are: {}".format(", ".join(AVAILABLE_SQLCMD_COMMANDS)) - ) + raise exceptions.UsageError(VALID_COMMANDS_MSG) else: split = arg_split(line) command, others = split[0].strip(), split[1:] @@ -223,6 +236,9 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None): table_widget = TableWidget(args.table) display(table_widget) + elif cmd_name == "snippets": + return sqlcmd_snippets(others) + def return_test_results(args, conn, query): try: diff --git a/src/sql/magic_plot.py b/src/sql/magic_plot.py index c2c43e774..f9554099e 100644 --- a/src/sql/magic_plot.py +++ b/src/sql/magic_plot.py @@ -83,43 +83,43 @@ def execute(self, line="", cell="", local_ns=None): column = util.sanitize_identifier(column) table = util.sanitize_identifier(cmd.args.table) + if cmd.args.with_: + util.show_deprecation_warning() if cmd.args.line[0] in {"box", "boxplot"}: - util.is_table_exists(table, with_=cmd.args.with_) + with_ = self._check_table_exists(table) + return plot.boxplot( table=table, column=column, - with_=cmd.args.with_, + with_=with_, orient=cmd.args.orient, conn=None, ) elif cmd.args.line[0] in {"hist", "histogram"}: - util.is_table_exists(table, with_=cmd.args.with_) - + with_ = self._check_table_exists(table) return plot.histogram( table=table, column=column, bins=cmd.args.bins, - with_=cmd.args.with_, + with_=with_, conn=None, ) elif cmd.args.line[0] in {"bar"}: - util.is_table_exists(table, with_=cmd.args.with_) - + with_ = self._check_table_exists(table) return plot.bar( table=table, column=column, - with_=cmd.args.with_, + with_=with_, orient=cmd.args.orient, show_num=cmd.args.show_numbers, conn=None, ) elif cmd.args.line[0] in {"pie"}: - util.is_table_exists(table, with_=cmd.args.with_) - + with_ = self._check_table_exists(table) return plot.pie( table=table, column=column, - with_=cmd.args.with_, + with_=with_, show_num=cmd.args.show_numbers, conn=None, ) @@ -128,3 +128,12 @@ def execute(self, line="", cell="", local_ns=None): raise exceptions.UsageError( f"Unknown plot {cmd.args.line[0]!r}. Must be any of: " f"{plot_str}" ) + + @staticmethod + def _check_table_exists(table): + with_ = None + if util.is_saved_snippet(table): + with_ = [table] + else: + util.is_table_exists(table) + return with_ diff --git a/src/sql/query_util.py b/src/sql/query_util.py new file mode 100644 index 000000000..fb342b812 --- /dev/null +++ b/src/sql/query_util.py @@ -0,0 +1,29 @@ +from sqlglot import parse_one, exp +from sqlglot.errors import ParseError + + +def extract_tables_from_query(query): + """ + Function to extract names of tables from + a syntactically correct query + + Parameters + ---------- + query : str, user query + + Returns + ------- + list + List of tables in the query + [] if error in parsing the query + """ + try: + tables = [table.name for table in parse_one(query).find_all(exp.Table)] + return tables + except ParseError: + # TODO : Instead of returning [] replace with call to + # error_messages.py::parse_sqlglot_error. Currently this + # is not possible because of an exception raised in test + # fixtures. (#546). This function can also be moved to util.py + # after #545 is resolved. + return [] diff --git a/src/sql/sqlcmd.py b/src/sql/sqlcmd.py new file mode 100644 index 000000000..2d6bd1b56 --- /dev/null +++ b/src/sql/sqlcmd.py @@ -0,0 +1,90 @@ +from sql.magic_cmd import CmdParser +from sql import util +from sql.exceptions import UsageError + + +def _modify_display_msg(key, remaining_keys, dependent_keys=None): + """ + + Parameters + ---------- + key : str, + deleted stored snippet + remaining_keys: list + snippets remaining after key is deleted + dependent_keys: list + snippets dependent on key + + Returns + ------- + msg: str + Formatted message + """ + msg = f"{key} has been deleted.\n" + if dependent_keys: + msg = f"{msg}{', '.join(dependent_keys)} depend on {key}\n" + if remaining_keys: + msg = f"{msg}Stored snippets : {', '.join(remaining_keys)}" + else: + msg = f"{msg}There are no stored snippets" + return msg + + +def sqlcmd_snippets(others): + """ + + Parameters + ---------- + This function handles all the arguments related to %sqlcmd snippets, namely + listing stored snippets, and delete/ force delete/ force delete a snippet and + all its dependent snippets. + + """ + parser = CmdParser() + parser.add_argument( + "-d", "--delete", type=str, help="Delete stored snippet", required=False + ) + parser.add_argument( + "-D", + "--delete-force", + type=str, + help="Force delete stored snippet", + required=False, + ) + parser.add_argument( + "-A", + "--delete-force-all", + type=str, + help="Force delete all stored snippets", + required=False, + ) + args = parser.parse_args(others) + SNIPPET_ARGS = [args.delete, args.delete_force, args.delete_force_all] + if SNIPPET_ARGS.count(None) == len(SNIPPET_ARGS): + return ", ".join(util.get_all_keys()) + if args.delete: + deps = util.get_key_dependents(args.delete) + if deps: + deps = ", ".join(deps) + raise UsageError( + f"The following tables are dependent on {args.delete}: {deps}.\n" + f"Pass --delete-force to only delete {args.delete}.\n" + f"Pass --delete-force-all to delete {deps} and {args.delete}" + ) + else: + key = args.delete + remaining_keys = util.del_saved_key(key) + return _modify_display_msg(key, remaining_keys) + + elif args.delete_force: + key = args.delete_force + deps = util.get_key_dependents(key) + remaining_keys = util.del_saved_key(key) + return _modify_display_msg(key, remaining_keys, deps) + + elif args.delete_force_all: + deps = util.get_key_dependents(args.delete_force_all) + deps.append(args.delete_force_all) + for key in deps: + remaining_keys = util.del_saved_key(key) + return _modify_display_msg(", ".join(deps), remaining_keys) diff --git a/src/sql/store.py b/src/sql/store.py index f0aa7303f..ffad19e6b 100644 --- a/src/sql/store.py +++ b/src/sql/store.py @@ -6,6 +6,7 @@ import difflib from sql import exceptions +from sql import query_util class SQLStore(MutableMapping): @@ -68,6 +69,18 @@ def render(self, query, with_=None): # TODO: if with is false, WITH should not appear return SQLQuery(self, query, with_) + def infer_dependencies(self, query, key): + dependencies = [] + saved_keys = [ + saved_key for saved_key in list(self._data.keys()) if saved_key != key + ] + if saved_keys and query: + tables = query_util.extract_tables_from_query(query) + for table in tables: + if table in saved_keys: + dependencies.append(table) + return dependencies + @modify_exceptions def store(self, key, query, with_=None): if "-" in key: @@ -138,6 +151,15 @@ def _get_dependencies(store, keys): return list(dict.fromkeys(deps + keys)) +def _get_dependents_for_key(store, key): + key_dependents = [] + for k in list(store): + deps = _get_dependencies_for_key(store, k) + if key in deps: + key_dependents.append(k) + return key_dependents + + def _get_dependencies_for_key(store, key): """Retrieve dependencies for a single key""" deps = store[key]._with_ diff --git a/src/sql/util.py b/src/sql/util.py index 65eac2f8a..19473ace6 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -1,8 +1,9 @@ +import warnings import sql from sql import inspect import difflib from sql.connection import Connection -from sql.store import store +from sql.store import store, _get_dependents_for_key from sql import exceptions import json @@ -52,11 +53,18 @@ def _is_long_number(num) -> bool: return False +def get_suggestions_message(suggestions): + suggestions_message = "" + if len(suggestions) > 0: + _suggestions_string = pretty_print(suggestions, last_delimiter="or") + suggestions_message = f"\nDid you mean : {_suggestions_string}" + return suggestions_message + + def is_table_exists( table: str, schema: str = None, ignore_error: bool = False, - with_: str = None, conn=None, ) -> bool: """ @@ -70,9 +78,6 @@ def is_table_exists( schema: str, default None Schema name - with_: list, default None - Temporary table - ignore_error: bool, default False Avoid raising a ValueError """ @@ -93,7 +98,7 @@ def is_table_exists( else: table_ = table - _is_exist = _is_table_exists(table_, with_, conn) + _is_exist = _is_table_exists(table_, conn) if not _is_exist: if not ignore_error: @@ -124,22 +129,12 @@ def is_table_exists( f"There is no table with name {table!r} in the default schema" ) - if table in list(store): - # Suggest user use --with when given table - # is in the store - suggestion_message = ( - ", but there is a stored query." - f"\nDid you miss passing --with {table}?" - ) - err_message = f"{err_message}{suggestion_message}" - else: - suggestions = difflib.get_close_matches(invalid_input, expected) - - if len(suggestions) > 0: - _suggestions_string = pretty_print(suggestions, last_delimiter="or") - suggestions_message = f"\nDid you mean : {_suggestions_string}" - err_message = f"{err_message}{suggestions_message}" - + suggestions = difflib.get_close_matches(invalid_input, expected) + suggestions_store = difflib.get_close_matches(invalid_input, list(store)) + suggestions.extend(suggestions_store) + suggestions_message = get_suggestions_message(suggestions) + if suggestions_message: + err_message = f"{err_message}{suggestions_message}" raise exceptions.TableNotFoundError(err_message) return _is_exist @@ -182,7 +177,14 @@ def strip_multiple_chars(string: str, chars: str) -> str: return string.translate(str.maketrans("", "", chars)) -def _is_table_exists(table: str, with_: str, conn) -> bool: +def is_saved_snippet(table: str) -> bool: + if table in list(store): + print(f"Plotting using saved snippet : {table}") + return True + return False + + +def _is_table_exists(table: str, conn) -> bool: """ Runs a SQL query to check if table exists """ @@ -191,21 +193,16 @@ def _is_table_exists(table: str, with_: str, conn) -> bool: identifiers = conn.get_curr_identifiers() - if with_: - return table in list(store) - else: - for iden in identifiers: - if isinstance(iden, tuple): - query = "SELECT * FROM {0}{1}{2} WHERE 1=0".format( - iden[0], table, iden[1] - ) - else: - query = "SELECT * FROM {0}{1}{0} WHERE 1=0".format(iden, table) - try: - conn.execute(query) - return True - except Exception: - pass + for iden in identifiers: + if isinstance(iden, tuple): + query = "SELECT * FROM {0}{1}{2} WHERE 1=0".format(iden[0], table, iden[1]) + else: + query = "SELECT * FROM {0}{1}{0} WHERE 1=0".format(iden, table) + try: + conn.execute(query) + return True + except Exception: + pass return False @@ -305,3 +302,61 @@ def parse_sql_results_to_json(rows, columns) -> str: ) return rows_json + + +def get_all_keys(): + """ + + Returns + ------- + All stored snippets in the current session + """ + return list(store) + + +def get_key_dependents(key: str) -> list: + """ + Function to find the stored snippets dependent on key + Parameters + ---------- + key : str, name of the table + + Returns + ------- + list + List of snippets dependent on key + + """ + deps = _get_dependents_for_key(store, key) + return deps + + +def del_saved_key(key: str) -> str: + """ + Deletes a stored snippet + Parameters + ---------- + key : str, name of the snippet to be deleted + + Returns + ------- + list + Remaining stored snippets + """ + all_keys = get_all_keys() + if key not in all_keys: + raise exceptions.UsageError(f"No such saved snippet found : {key}") + del store[key] + return get_all_keys() + + +def show_deprecation_warning(): + """ + Raises CTE deprecation warning + """ + warnings.warn( + "CTE dependencies are now automatically inferred, " + "you can omit the --with arguments. Using --with will " + "raise an exception in the next major release so please remove it.", + FutureWarning, + ) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 63ce88830..1802bea99 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -95,10 +95,10 @@ def ip(ip_empty): ], ) yield ip_empty - runsql(ip_empty, "DROP TABLE test") - runsql(ip_empty, "DROP TABLE author") - runsql(ip_empty, "DROP TABLE website") - runsql(ip_empty, "DROP TABLE number_table") + runsql(ip_empty, "DROP TABLE IF EXISTS test") + runsql(ip_empty, "DROP TABLE IF EXISTS author") + runsql(ip_empty, "DROP TABLE IF EXISTS website") + runsql(ip_empty, "DROP TABLE IF EXISTS number_table") @pytest.fixture diff --git a/src/tests/test_command.py b/src/tests/test_command.py index 7b73b14d5..89a085c26 100644 --- a/src/tests/test_command.py +++ b/src/tests/test_command.py @@ -95,37 +95,6 @@ def test_parsed( assert cmd.sql_original == parsed_sql -def test_parsed_sql_when_using_with(ip, sql_magic): - ip.run_cell_magic( - "sql", - "--save author_one", - """ - SELECT * FROM author LIMIT 1 - """, - ) - - cmd = SQLCommand( - sql_magic, ip.user_ns, line="--with author_one", cell="SELECT * FROM author_one" - ) - - sql = "WITH `author_one` AS (\n\n SELECT * FROM author LIMIT 1\n )\n\ -SELECT * FROM author_one" - - sql_original = "\nSELECT * FROM author_one" - - assert cmd.parsed == { - "connection": "", - "result_var": None, - "return_result_var": False, - "sql": sql, - "sql_original": sql_original, - } - - assert cmd.connection == "" - assert cmd.sql == sql - assert cmd.sql_original == sql_original - - def test_parsed_sql_when_using_file(ip, sql_magic, tmp_empty): Path("query.sql").write_text("SELECT * FROM author") cmd = SQLCommand(sql_magic, ip.user_ns, "--file query.sql", "") diff --git a/src/tests/test_compose.py b/src/tests/test_compose.py deleted file mode 100644 index 3fd16f0c2..000000000 --- a/src/tests/test_compose.py +++ /dev/null @@ -1,18 +0,0 @@ -def test_compose(ip): - ip.run_cell_magic( - "sql", - "--save author_sub", - "SELECT last_name FROM author WHERE year_of_death > 1900", - ) - - ip.run_cell_magic( - "sql", - "--with author_sub --save final", - "SELECT last_name FROM author_sub;", - ) - - result = ip.run_cell("%sqlrender final").result - expected = "WITH `author_sub` AS (\nSELECT last_name \ -FROM author WHERE year_of_death > 1900)\nSELECT last_name FROM author_sub;" - - assert result == expected diff --git a/src/tests/test_extract_tables.py b/src/tests/test_extract_tables.py new file mode 100644 index 000000000..da39fb608 --- /dev/null +++ b/src/tests/test_extract_tables.py @@ -0,0 +1,85 @@ +import pytest +from sql.query_util import extract_tables_from_query + + +@pytest.mark.parametrize( + "query, expected", + [ + ( + """ + SELECT t.* + FROM tracks_with_info t + JOIN genres_fav + ON t.GenreId = genres_fav.GenreId + """, + ["tracks_with_info", "genres_fav"], + ), + ( + """ + SELECT city FROM Customers + UNION + SELECT city FROM Suppliers""", + ["Customers", "Suppliers"], + ), + ( + """ + SELECT OrderID, Quantity, +CASE + WHEN Quantity > 30 THEN 'The quantity is greater than 30' + WHEN Quantity = 30 THEN 'The quantity is 30' + ELSE 'The quantity is under 30' +END AS QuantityText +FROM OrderDetails;""", + ["OrderDetails"], + ), + ( + """ +SELECT COUNT(CustomerID), Country +FROM Customers +GROUP BY Country +HAVING COUNT(CustomerID) > 5;""", + ["Customers"], + ), + ( + """ +SELECT LEFT(sub.date, 2) AS cleaned_month, + sub.day_of_week, + AVG(sub.incidents) AS average_incidents + FROM ( + SELECT day_of_week, + date, + COUNT(incidnt_num) AS incidents + FROM tutorial.sf_crime_incidents_2014_01 + GROUP BY 1,2 + ) sub + GROUP BY 1,2 + ORDER BY 1,2""", + ["sf_crime_incidents_2014_01"], + ), + ( + """ + SELECT incidents.*, + sub.incidents AS incidents_that_day + FROM tutorial.sf_crime_incidents_2014_01 incidents + JOIN ( SELECT date, + COUNT(incidnt_num) AS incidents + FROM tutorial.sf_crime_incidents_2014_01 + GROUP BY 1 + ) sub + ON incidents.date = sub.date + ORDER BY sub.incidents DESC, time + """, + ["sf_crime_incidents_2014_01", "sf_crime_incidents_2014_01"], + ), + ], + ids=["join", "union", "case", "groupby", "subquery", "subquery_join"], +) +def test_extract(query, expected): + tables = extract_tables_from_query(query) + assert expected == tables + + +def test_invalid_query(): + query = "SELECT city frm Customers" + tables = extract_tables_from_query(query) + assert [] == tables diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index c811bbff6..7ac94b843 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -1155,10 +1155,16 @@ def test_save_with_number_table( def test_save_with_non_existing_with(ip): - out = ip.run_cell( - "%sql --with non_existing_sub_query " "SELECT * FROM non_existing_sub_query" + with pytest.warns(FutureWarning) as record: + ip.run_cell( + "%sql --with non_existing_sub_query " "SELECT * FROM non_existing_sub_query" + ) + assert len(record) == 1 + assert ( + "CTE dependencies are now automatically inferred, you can omit the " + "--with arguments. Using --with will raise an exception in the next " + "major release so please remove it." in record[0].message.args[0] ) - assert isinstance(out.error_in_exec, UsageError) def test_save_with_non_existing_table(ip, capsys): diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index 7a530d274..1a56e9b4c 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -3,8 +3,47 @@ import pytest from IPython.core.error import UsageError from pathlib import Path + from sqlalchemy import create_engine from sql.connection import Connection +from sql.store import store + + +VALID_COMMANDS_MESSAGE = ( + "Valid commands are: tables, " "columns, test, profile, explore, snippets" +) + + +@pytest.fixture +def ip_snippets(ip): + for key in list(store): + del store[key] + ip.run_cell("%sql sqlite://") + ip.run_cell( + """ + %%sql --save high_price --no-execute +SELECT * +FROM "test_store" +WHERE price >= 1.50 +""" + ) + ip.run_cell( + """ + %%sql --save high_price_a --no-execute +SELECT * +FROM "high_price" +WHERE symbol == 'a' +""" + ) + ip.run_cell( + """ + %%sql --save high_price_b --no-execute +SELECT * +FROM "high_price" +WHERE symbol == 'b' +""" + ) + yield ip @pytest.mark.parametrize( @@ -13,32 +52,27 @@ [ "%sqlcmd", UsageError, - "Missing argument for %sqlcmd. " - "Valid commands are: tables, columns, test, profile, explore", + "Missing argument for %sqlcmd. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd ", UsageError, - "Missing argument for %sqlcmd. " - "Valid commands are: tables, columns, test, profile, explore", + "Missing argument for %sqlcmd. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd ", UsageError, - "Missing argument for %sqlcmd. " - "Valid commands are: tables, columns, test, profile, explore", + "Missing argument for %sqlcmd. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd ", UsageError, - "Missing argument for %sqlcmd. " - "Valid commands are: tables, columns, test, profile, explore", + "Missing argument for %sqlcmd. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd stuff", UsageError, - "%sqlcmd has no command: 'stuff'. " - "Valid commands are: tables, columns, test, profile, explore", + "%sqlcmd has no command: 'stuff'. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd columns", @@ -266,3 +300,79 @@ def test_test_error(ip, cell, error_type, error_message): assert isinstance(out.error_in_exec, error_type) assert str(out.error_in_exec) == error_message + + +def test_snippet(ip_snippets): + out = ip_snippets.run_cell("%sqlcmd snippets").result + assert "high_price, high_price_a, high_price_b" in out + + +@pytest.mark.parametrize("arg", ["--delete", "-d"]) +def test_delete_saved_key(ip_snippets, arg): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price_a").result + assert "high_price_a has been deleted.\n" in out + stored_snippets = out[out.find("Stored snippets") + len("Stored snippets: ") :] + assert "high_price, high_price_b" in stored_snippets + assert "high_price_a" not in stored_snippets + + +@pytest.mark.parametrize("arg", ["--delete-force", "-D"]) +def test_force_delete(ip_snippets, arg): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price").result + assert ( + "high_price has been deleted.\nhigh_price_a, " + "high_price_b depend on high_price\n" in out + ) + stored_snippets = out[out.find("Stored snippets") + len("Stored snippets: ") :] + assert "high_price_a, high_price_b" in stored_snippets + assert "high_price," not in stored_snippets + + +@pytest.mark.parametrize("arg", ["--delete-force-all", "-A"]) +def test_force_delete_all(ip_snippets, arg): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price").result + assert "high_price_a, high_price_b, high_price has been deleted" in out + assert "There are no stored snippets" in out + + +@pytest.mark.parametrize("arg", ["--delete-force-all", "-A"]) +def test_force_delete_all_child_query(ip_snippets, arg): + ip_snippets.run_cell( + """ + %%sql --save high_price_b_child --no-execute +SELECT * +FROM "high_price_b" +WHERE symbol == 'b' +LIMIT 3 +""" + ) + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price_b").result + assert "high_price_b_child, high_price_b has been deleted" in out + stored_snippets = out[out.find("Stored snippets") + len("Stored snippets: ") :] + assert "high_price, high_price_a" in stored_snippets + assert "high_price_b," not in stored_snippets + assert "high_price_b_child" not in stored_snippets + + +@pytest.mark.parametrize("arg", ["--delete", "-d"]) +def test_delete_snippet_error(ip_snippets, arg): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price") + assert isinstance(out.error_in_exec, UsageError) + assert ( + str(out.error_in_exec) == "The following tables are dependent on high_price: " + "high_price_a, high_price_b.\nPass --delete-force to only " + "delete high_price.\nPass --delete-force-all to delete " + "high_price_a, high_price_b and high_price" + ) + + +@pytest.mark.parametrize( + "arg", ["--delete", "-d", "--delete-force-all", "-A", "--delete-force", "-D"] +) +def test_delete_invalid_snippet(arg, ip_snippets): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} non_existent_snippet") + assert isinstance(out.error_in_exec, UsageError) + assert ( + str(out.error_in_exec) == "No such saved snippet found " + ": non_existent_snippet" + ) diff --git a/src/tests/test_magic_cte.py b/src/tests/test_magic_cte.py index 95e96c0fe..a76a1b957 100644 --- a/src/tests/test_magic_cte.py +++ b/src/tests/test_magic_cte.py @@ -1,3 +1,7 @@ +import pytest +from IPython.core.error import UsageError + + def test_trailing_semicolons_removed_from_cte(ip): ip.run_cell( """%%sql --save positive_x @@ -32,3 +36,168 @@ def test_trailing_semicolons_removed_from_cte(ip): "FROM number_table WHERE y > 0)\nSELECT * FROM positive_x\n" "UNION\nSELECT * FROM positive_y;" ) + + +def test_infer_dependencies(ip, capsys): + ip.run_cell_magic( + "sql", + "--save author_sub", + "SELECT last_name FROM author WHERE year_of_death > 1900", + ) + + ip.run_cell_magic( + "sql", + "--save final", + "SELECT last_name FROM author_sub;", + ) + out, _ = capsys.readouterr() + result = ip.run_cell("%sqlrender final").result + expected = ( + "WITH `author_sub` AS (\nSELECT last_name FROM author " + "WHERE year_of_death > 1900)\nSELECT last_name FROM author_sub;" + ) + + assert result == expected + assert "Generating CTE with stored snippets : author_sub" in out + + +def test_deprecation_warning(ip): + ip.run_cell_magic( + "sql", + "--save author_sub", + "SELECT last_name FROM author WHERE year_of_death > 1900", + ) + + with pytest.warns(FutureWarning) as record: + ip.run_cell_magic( + "sql", + "--with author_sub --save final", + "SELECT last_name FROM author_sub;", + ) + assert len(record) == 1 + assert ( + "CTE dependencies are now automatically inferred," + " you can omit the --with arguments. Using --with will " + "raise an exception in the next major release so please " + "remove it." in record[0].message.args[0] + ) + + +TABLE_NAME_TYPO_ERR_MSG = """ +There is no table with name 'author_subb'. +Did you mean : 'author_sub' +If you need help solving this issue, send us a message: https://ploomber.io/community +""" + + +def test_table_name_typo(ip): + ip.run_cell_magic( + "sql", + "--save author_sub", + "SELECT last_name FROM author WHERE year_of_death > 1900", + ) + + with pytest.raises(UsageError) as excinfo: + ip.run_cell_magic( + "sql", + "--save final", + "SELECT last_name FROM author_subb;", + ) + + assert excinfo.value.error_type == "TableNotFoundError" + assert str(excinfo.value) == TABLE_NAME_TYPO_ERR_MSG.strip() + + +def test_snippets_delete(ip, capsys): + ip.run_cell( + """ + %%sql sqlite:// + CREATE TABLE orders (order_id int, customer_id int, order_value float); + INSERT INTO orders VALUES (123, 15, 150.67); + INSERT INTO orders VALUES (124, 25, 200.66); + INSERT INTO orders VALUES (211, 15, 251.43); + INSERT INTO orders VALUES (312, 5, 333.41); + CREATE TABLE another_orders (order_id int, customer_id int, order_value float); + INSERT INTO another_orders VALUES (511,15, 150.67); + INSERT INTO another_orders VALUES (512, 30, 200.66); + CREATE TABLE customers (customer_id int, name varchar(25)); + INSERT INTO customers VALUES (15, 'John'); + INSERT INTO customers VALUES (25, 'Sheryl'); + INSERT INTO customers VALUES (5, 'Mike'); + INSERT INTO customers VALUES (30, 'Daisy'); + """ + ) + ip.run_cell_magic( + "sql", + "--save orders_less", + "SELECT * FROM orders WHERE order_value < 250.0", + ) + + ip.run_cell_magic( + "sql", + "--save another_orders", + "SELECT * FROM orders WHERE order_value > 250.0", + ) + + ip.run_cell_magic( + "sql", + "--save final", + """ + SELECT o.order_id, customers.name, o.order_value + FROM another_orders o + INNER JOIN customers ON o.customer_id=customers.customer_id; + """, + ) + + out, _ = capsys.readouterr() + assert "Generating CTE with stored snippets : another_orders" in out + result_del = ip.run_cell( + "%sqlcmd snippets --delete-force-all another_orders" + ).result + assert "final, another_orders has been deleted.\n" in result_del + stored_snippets = result_del[ + result_del.find("Stored snippets") + len("Stored snippets: ") : + ] + assert "orders_less" in stored_snippets + ip.run_cell_magic( + "sql", + "--save final", + """ + SELECT o.order_id, customers.name, o.order_value + FROM another_orders o + INNER JOIN customers ON o.customer_id=customers.customer_id; + """, + ) + result = ip.run_cell("%sqlrender final").result + expected = ( + "WITH\n\n SELECT o.order_id, customers.name, " + "o.order_value\n " + "FROM another_orders o\n INNER JOIN customers " + "ON o.customer_id=customers.customer_id" + ) + assert expected in result + + +SYNTAX_ERROR_MESSAGE = """ +Syntax Error in WITH `author_sub` AS ( +SELECT last_name FRM author WHERE year_of_death > 1900) +SELECT last_name FROM author_sub: Expecting ( at Line 1, Column 16 +""" + + +def test_query_syntax_error(ip): + ip.run_cell_magic( + "sql", + "--save author_sub --no-execute", + "SELECT last_name FRM author WHERE year_of_death > 1900", + ) + + with pytest.raises(UsageError) as excinfo: + ip.run_cell_magic( + "sql", + "--save final", + "SELECT last_name FROM author_sub;", + ) + + assert excinfo.value.error_type == "UsageError" + assert SYNTAX_ERROR_MESSAGE.strip() in str(excinfo.value) diff --git a/src/tests/test_magic_plot.py b/src/tests/test_magic_plot.py index 5f78373c4..8e793cdca 100644 --- a/src/tests/test_magic_plot.py +++ b/src/tests/test_magic_plot.py @@ -11,6 +11,35 @@ plot_str = util.pretty_print(SUPPORTED_PLOTS, last_delimiter="or") +@pytest.fixture +def ip_snippets(ip, tmp_empty): + Path("data.csv").write_text( + """\ +x, y +0, 0 +1, 1 +2, 2 +""" + ) + ip.run_cell("%sql duckdb://") + + ip.run_cell( + """%%sql --save subset --no-execute +SELECT * +FROM data.csv +WHERE x > -1 +""" + ) + ip.run_cell( + """%%sql --save subset_another --no-execute +SELECT * +FROM subset +WHERE x > 2 +""" + ) + yield ip + + @pytest.mark.parametrize( "cell, error_type, error_message", [ @@ -69,7 +98,7 @@ def test_validate_arguments(tmp_empty, ip, cell, error_type, error_message): "%sqlplot boxplot --table data.csv --column x", "%sqlplot box --table data.csv --column x", "%sqlplot boxplot --table data.csv --column x --orient h", - "%sqlplot boxplot --table subset --column x --with subset", + "%sqlplot boxplot --table subset --column x", "%sqlplot boxplot -t subset -c x -w subset -o h", "%sqlplot boxplot --table nas.csv --column x", "%sqlplot bar -t data.csv -c x", @@ -318,3 +347,43 @@ def test_pie_one_col_num(load_data_one_col, ip): @image_comparison(baseline_images=["pie_two_col"], extensions=["png"], remove_text=True) def test_pie_two_col(load_data_two_col, ip): ip.run_cell("%sqlplot pie -t data_two.csv -c x y") + + +def test_sqlplot_deprecation_warning(ip_snippets, capsys): + with pytest.warns(FutureWarning) as record: + res = ip_snippets.run_cell( + "%sqlplot boxplot --table subset --column x --with subset" + ) + assert len(record) == 1 + assert ( + "CTE dependencies are now automatically inferred," + " you can omit the --with arguments. Using --with will " + "raise an exception in the next major release so please " + "remove it." in record[0].message.args[0] + ) + out, err = capsys.readouterr() + assert type(res.result).__name__ in {"Axes", "AxesSubplot"} + assert "Plotting using saved snippet : subset" in out + + +@pytest.mark.parametrize( + "arg", ["--delete", "-d", "--delete-force-all", "-A", "--delete-force", "-D"] +) +def test_sqlplot_snippet_deletion(ip_snippets, arg, capsys): + ip_snippets.run_cell(f"%sqlcmd snippets {arg} subset_another") + ip_snippets.run_cell("%sqlplot boxplot --table subset_another --column x") + out, err = capsys.readouterr() + assert "There is no table with name 'subset_another' in the default schema" in err + + +TABLE_NAME_TYPO_MSG = """ +UsageError: There is no table with name 'subst' in the default schema +Did you mean : 'subset' +If you need help solving this issue, send us a message: https://ploomber.io/community +""" + + +def test_sqlplot_snippet_typo(ip_snippets, capsys): + ip_snippets.run_cell("%sqlplot boxplot --table subst --column x") + out, err = capsys.readouterr() + assert TABLE_NAME_TYPO_MSG.strip() == err.strip() diff --git a/src/tests/test_util.py b/src/tests/test_util.py index f3f6c04ba..74a8c8cae 100644 --- a/src/tests/test_util.py +++ b/src/tests/test_util.py @@ -15,41 +15,41 @@ ) -@pytest.mark.parametrize( - "store_table, query", - [ - ("a", "%sqlcmd columns --table {}"), - ("bbb", "%sqlcmd profile --table {}"), - ("c_c", "%sqlplot histogram --table {} --column x"), - ("d_d_d", "%sqlplot boxplot --table {} --column x"), - ], -) -def test_missing_with(ip, store_table, query): +@pytest.fixture +def ip_snippets(ip): ip.run_cell( - f""" - %%sql --save {store_table} --no-execute - SELECT * - FROM number_table """ - ).result - - query = query.format(store_table) - out = ip.run_cell(query) - - expected_store_message = EXPECTED_STORE_SUGGESTIONS.format(store_table) - - error_message = str(out.error_in_exec) - assert isinstance(out.error_in_exec, UsageError) - assert str(expected_store_message).lower() in error_message.lower() +%%sql --save a --no-execute +SELECT * +FROM number_table +""" + ) + ip.run_cell( + """ + %%sql --save b --no-execute + SELECT * + FROM a + WHERE x > 5 + """ + ) + ip.run_cell( + """ + %%sql --save c --no-execute + SELECT * + FROM a + WHERE x < 5 + """ + ) + yield ip @pytest.mark.parametrize( "store_table, query", [ - ("a", "%sqlcmd columns --table {} --with {}"), - ("bbb", "%sqlcmd profile --table {} --with {}"), - ("c_c", "%sqlplot histogram --table {} --with {} --column x"), - ("d_d_d", "%sqlplot boxplot --table {} --with {} --column x"), + ("a", "%sqlcmd columns --table {}"), + ("bbb", "%sqlcmd profile --table {}"), + ("c_c", "%sqlplot histogram --table {} --column x"), + ("d_d_d", "%sqlplot boxplot --table {} --column x"), ], ) def test_no_errors_with_stored_query(ip, store_table, query): @@ -237,7 +237,7 @@ def test_is_table_exists_with(ip, table, expected_error, expected_suggestions): ) if expected_error: with pytest.raises(expected_error) as error: - util.is_table_exists(table, with_=with_) + util.is_table_exists(table) error_suggestions_arr = str(error.value).split(EXPECTED_SUGGESTIONS_MESSAGE) @@ -444,3 +444,28 @@ def test_parse_sql_results_to_json(ip, capsys, rows, columns, expected_json): j = json.loads(j) with capsys.disabled(): assert str(j) == str(expected_json) + + +def test_get_all_keys(ip_snippets): + keys = util.get_all_keys() + assert "a" in keys + assert "b" in keys + assert "c" in keys + + +def test_get_key_dependents(ip_snippets): + keys = util.get_key_dependents("a") + assert "b" in keys + assert "c" in keys + + +def test_del_saved_key(ip_snippets): + keys = util.del_saved_key("c") + assert "a" in keys + assert "b" in keys + + +def test_del_saved_key_error(ip_snippets): + with pytest.raises(UsageError) as excinfo: + util.del_saved_key("non_existent_key") + assert "No such saved snippet found : non_existent_key" in str(excinfo.value)