From 2552066987cacd1b4962cf83959ecc2ac0340700 Mon Sep 17 00:00:00 2001 From: neelasha23 Date: Thu, 11 May 2023 13:29:36 +0530 Subject: [PATCH 1/3] --with deprecation in CTE CTE depr Warning msg, sqlglot Lint tests warn test conditions for extract Empty commit removed with test exception handling Fixed tests Lint typo error test condition for typo suggestions plot changes merge conflict fix test Fix test fix test tests list of with sqlcmd delete sqlcmd tests rebase Fixes fixed test fixes removed test non-existent tbl Test fix fix typo test fix typo test space removed delete snippet test test fix Docs docs toc changelog modified args modified versionchanged added message changed tests sqlcmd added deprecation warning changed tests depr warning removed if exists debug prints added clean_conns rebase test fix rebase issues revert error msg test for snippet utils telemetry removed Minor changes Moved string refactor plot tests plot_typo Renamed test file Removed line test fix test fix Review comments added issue numbers rebase lint --- CHANGELOG.md | 2 + doc/_toc.yml | 1 + doc/api/magic-snippets.md | 135 ++++++++++++++++++++++++ doc/compose.md | 21 ++-- doc/plot.md | 10 +- src/sql/command.py | 16 ++- src/sql/error_message.py | 41 +++++--- src/sql/magic.py | 28 ++++- src/sql/magic_cmd.py | 26 ++++- src/sql/magic_plot.py | 21 +++- src/sql/query_util.py | 29 ++++++ src/sql/sqlcmd.py | 90 ++++++++++++++++ src/sql/store.py | 22 ++++ src/sql/util.py | 131 +++++++++++++++++------- src/tests/conftest.py | 8 +- src/tests/test_command.py | 31 ------ src/tests/test_compose.py | 18 ---- src/tests/test_extract_tables.py | 85 ++++++++++++++++ src/tests/test_magic.py | 12 ++- src/tests/test_magic_cmd.py | 130 ++++++++++++++++++++++-- src/tests/test_magic_cte.py | 169 +++++++++++++++++++++++++++++++ src/tests/test_magic_plot.py | 71 ++++++++++++- src/tests/test_util.py | 83 +++++++++------ 23 files changed, 1008 insertions(+), 172 deletions(-) create mode 100644 doc/api/magic-snippets.md create mode 100644 src/sql/query_util.py create mode 100644 src/sql/sqlcmd.py delete mode 100644 src/tests/test_compose.py create mode 100644 src/tests/test_extract_tables.py diff --git a/CHANGELOG.md b/CHANGELOG.md index d3f308def..77edcc255 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ ## 0.7.8dev * [Feature] Add `%sqlplot bar` and `%sqlplot pie` +* [Feature] Automated dependency inference when creating CTEs. `--with` is now deprecated and will display a warning. (#166) + ## 0.7.7 (2023-05-31) diff --git a/doc/_toc.yml b/doc/_toc.yml index f20d771ab..06c920e34 100644 --- a/doc/_toc.yml +++ b/doc/_toc.yml @@ -42,6 +42,7 @@ parts: - file: api/magic-sql - file: api/magic-plot - file: api/magic-render + - file: api/magic-snippets - file: api/configuration - file: api/python - file: api/magic-tables-columns diff --git a/doc/api/magic-snippets.md b/doc/api/magic-snippets.md new file mode 100644 index 000000000..d84d6c9ef --- /dev/null +++ b/doc/api/magic-snippets.md @@ -0,0 +1,135 @@ +--- +jupytext: + notebook_metadata_filter: myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.14.5 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +myst: + html_meta: + description lang=en: Documentation for %sqlcmd snippets + from JupySQL + keywords: jupyter, sql, jupysql, snippets + property=og:locale: en_US +--- + +# `%sqlcmd snippets` + +`%sqlcmd snippets` returns the query snippets saved using `--save` + +## Load Data + +```{code-cell} ipython3 +%load_ext sql +%sql duckdb:// +``` + +```{code-cell} ipython3 +from pathlib import Path +from urllib.request import urlretrieve + +if not Path("penguins.csv").is_file(): + urlretrieve( + "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv", + "penguins.csv", + ) +``` + +```{code-cell} ipython3 +%%sql +SELECT * FROM penguins.csv LIMIT 3 +``` + +Let's save a couple of snippets. + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save gentoo +SELECT * FROM penguins.csv where species == 'Gentoo' +``` + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save chinstrap +SELECT * FROM penguins.csv where species == 'Chinstrap' +``` + +## `%sqlcmd snippets` + ++++ + +Returns all the snippets saved in the environment + +```{code-cell} ipython3 +%sqlcmd snippets +``` + +Arguments: + +`-d`/`--delete` Delete a snippet. + +`-D`/`--delete-force` Force delete a snippet. This may be useful if there are other dependent snippets, and you still need to delete this snippet. + +`-A`/`--delete-force-all` Force delete a snippet and all dependent snippets. + +```{code-cell} ipython3 + +%sqlcmd snippets -d gentoo +``` + +This deletes the stored snippet `gentoo`. + +To demonstrate `force-delete` let's create a snippet dependent on `chinstrap` snippet. + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save chinstrap_sub +SELECT * FROM chinstrap where island == 'Dream' +``` ++++ + +Trying to delete the `chinstrap` snippet will display an error message: + +```{code-cell} ipython3 +:tags: [raises-exception] + +%sqlcmd snippets -d chinstrap +``` + +If you still wish to delete this snippet, you can run the below command: + +```{code-cell} ipython3 + +%sqlcmd snippets -D chinstrap +``` + +Now, let's see how to delete a snippet and all other dependent snippets. We'll create a few snippets again. + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save chinstrap +SELECT * FROM penguins.csv where species == 'Chinstrap' +``` + +```{code-cell} ipython3 +:tags: [hide-output] + +%%sql --save chinstrap_sub +SELECT * FROM chinstrap where island == 'Dream' +``` + +Now, force delete `chinstrap` and its dependent `chinstrap_sub`: + +```{code-cell} ipython3 + +%sqlcmd snippets -A chinstrap +``` diff --git a/doc/compose.md b/doc/compose.md index 199bab303..f8016f009 100644 --- a/doc/compose.md +++ b/doc/compose.md @@ -27,7 +27,8 @@ pip install jupysql matplotlib ``` -*New in version 0.4.3* +```{versionchanged} 0.7.8 +``` ```{note} This is a beta feature, please [join our community](https://ploomber.io/community) and @@ -105,31 +106,39 @@ OR Name LIKE '%metal%' Join the filtered genres and tracks, so we only get Rock and Metal tracks, and save the query as `track_fav` -Note that we are using `--with`; this will retrieve previously saved queries, and prepend them (using CTEs), then, we save the query in `track_fav` . + +We automatically extract the tables from the query and infer the dependencies from all the saved snippets. + ```{code-cell} ipython3 -%%sql --with genres_fav --with tracks_with_info --save track_fav +%%sql --save track_fav SELECT t.* FROM tracks_with_info t JOIN genres_fav ON t.GenreId = genres_fav.GenreId ``` -Use the `track_fav` query to find artists with the most Rock and Metal tracks, and save the query as `top_artist` +Now let's find artists with the most Rock and Metal tracks, and save the query as `top_artist` ```{code-cell} ipython3 -%%sql --with track_fav --save top_artist +%%sql --save top_artist SELECT artist, COUNT(*) FROM track_fav GROUP BY artist ORDER BY COUNT(*) DESC ``` + +```{note} +A saved snippet will override an existing table with the same name during query formation. If you wish to delete a snippet please refer to [sqlcmd snippets API](api/magic-snippets.md). + +``` + #### Data visualization Once we have the desired results from the query `top_artist`, we can generate a visualization using the bar method ```{code-cell} ipython3 -top_artist = %sql --with top_artist SELECT * FROM top_artist +top_artist = %sql SELECT * FROM top_artist top_artist.bar() ``` diff --git a/doc/plot.md b/doc/plot.md index 98c51c499..222754e81 100644 --- a/doc/plot.md +++ b/doc/plot.md @@ -121,10 +121,10 @@ FROM "yellow_tripdata_2021-01.parquet" WHERE trip_distance < 6.3 ``` -Now, let's plot again, but this time let's pass `--table short_trips`. Note that this table *doesn't exist*; however, since we're passing the `--with` argument, JupySQL will use the query we defined above: +Now, let's plot again, but this time let's pass `--table short_trips`. Note that this table *doesn't exist*; JupySQL will automatically infer and use the saved snippet defined above. ```{code-cell} ipython3 -%sqlplot boxplot --table short_trips --column trip_distance --with short_trips +%sqlplot boxplot --table short_trips --column trip_distance ``` We can see the highest value is a bit over 6, that's expected since we set a 6.3 cutoff value. @@ -133,10 +133,10 @@ We can see the highest value is a bit over 6, that's expected since we set a 6.3 ## Histogram -To create a histogram, call `%sqlplot histogram`, and pass the name of the table, the column you want to plot, and the number of bins. Similarly to what we did in the [Boxplot](#boxplot) example, we're using `--with short_trips` so JupySQL uses the query we defined and only plots such data subset. +To create a histogram, call `%sqlplot histogram`, and pass the name of the table, the column you want to plot, and the number of bins. Similarly to what we did in the [Boxplot](#boxplot) example, JupySQL detects a saved snippet and only plots such data subset. ```{code-cell} ipython3 -%sqlplot histogram --table short_trips --column trip_distance --bins 10 --with short_trips +%sqlplot histogram --table short_trips --column trip_distance --bins 10 ``` ## Customize plot @@ -144,7 +144,7 @@ To create a histogram, call `%sqlplot histogram`, and pass the name of the table `%sqlplot` returns a `matplotlib.Axes` object that you can further customize: ```{code-cell} ipython3 -ax = %sqlplot histogram --table short_trips --column trip_distance --bins 50 --with short_trips +ax = %sqlplot histogram --table short_trips --column trip_distance --bins 50 ax.grid() ax.set_title("Trip distance from trips < 6.3") _ = ax.set_xlabel("Trip distance") diff --git a/src/sql/command.py b/src/sql/command.py index 2f4ad7093..b73355519 100644 --- a/src/sql/command.py +++ b/src/sql/command.py @@ -71,10 +71,6 @@ def __init__(self, magic, user_ns, line, cell) -> None: if add_alias: self.parsed["connection"] = self.args.line[0] - if self.args.with_: - final = store.render(self.parsed["sql"], with_=self.args.with_) - self.parsed["sql"] = str(final) - @property def sql(self): """ @@ -112,3 +108,15 @@ def __repr__(self) -> str: f"{type(self).__name__}(line={self._line!r}, cell={self._cell!r}) -> " f"({self.sql!r}, {self.sql_original!r})" ) + + def set_sql_with(self, with_): + """ + Sets the final rendered SQL query using the WITH clause + + Parameters + ---------- + with_ : list + list of all subqueries needed to render the query + """ + final = store.render(self.parsed["sql"], with_) + self.parsed["sql"] = str(final) diff --git a/src/sql/error_message.py b/src/sql/error_message.py index b07c2cb79..640dc94bf 100644 --- a/src/sql/error_message.py +++ b/src/sql/error_message.py @@ -5,6 +5,33 @@ ORIGINAL_ERROR = "\nOriginal error message from DB driver:\n" +def parse_sqlglot_error(e, q): + """ + Function to parse the error message from sqlglot + + Parameters + ---------- + e: sqlglot.errors.ParseError, exception + while parsing through sqlglot + q : str, user query + + Returns + ------- + str + Formatted error message containing description + and positions + """ + err = e.errors + position = "" + for item in err: + position += ( + f"Syntax Error in {q}: {item['description']} at " + f"Line {item['line']}, Column {item['col']}\n" + ) + msg = "Possible reason: \n" + position if position else "" + return msg + + def detail(original_error, query=None): original_error = str(original_error) return_msg = SYNTAX_ERROR @@ -25,18 +52,8 @@ def detail(original_error, query=None): ) except sqlglot.errors.ParseError as e: - err = e.errors - position = "" - for item in err: - position += ( - f"Syntax Error in {q}: {item['description']} at " - f"Line {item['line']}, Column {item['col']}\n" - ) - return_msg = ( - return_msg + "Possible reason: \n" + position - if position - else return_msg - ) + parse_msg = parse_sqlglot_error(e, q) + return_msg = return_msg + parse_msg if parse_msg else return_msg return return_msg + "\n" + ORIGINAL_ERROR + original_error + "\n" diff --git a/src/sql/magic.py b/src/sql/magic.py index 24e8277fa..692b1f686 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -18,6 +18,7 @@ from sqlalchemy.exc import OperationalError, ProgrammingError, DatabaseError import warnings +from difflib import get_close_matches import sql.connection import sql.parse import sql.run @@ -27,12 +28,15 @@ from sql.magic_plot import SqlPlotMagic from sql.magic_cmd import SqlCmdMagic from sql._patch import patch_ipython_usage_error +from sql import query_util +from sql.util import get_suggestions_message, show_deprecation_warning from ploomber_core.dependencies import check_installed from sql.error_message import detail from traitlets.config.configurable import Configurable from traitlets import Bool, Int, TraitError, Unicode, Dict, observe, validate + try: from pandas.core.frame import DataFrame, Series except ModuleNotFoundError: @@ -323,6 +327,13 @@ def interactive_execute_wrapper(**kwargs): args = command.args + with_ = self._store.infer_dependencies(command.sql_original, args.save) + if with_: + command.set_sql_with(with_) + print(f"Generating CTE with stored snippets : {', '.join(with_)}") + else: + with_ = None + # Create the interactive slider if args.interact and not is_interactive_mode: check_installed(["ipywidgets"], "--interactive argument") @@ -410,6 +421,8 @@ def interactive_execute_wrapper(**kwargs): if not command.sql: return + if args.with_: + show_deprecation_warning() # store the query if needed if args.save: if "-" in args.save: @@ -420,7 +433,7 @@ def interactive_execute_wrapper(**kwargs): + " instead for the save argument.", FutureWarning, ) - self._store.store(args.save, command.sql_original, with_=args.with_) + self._store.store(args.save, command.sql_original, with_=with_) if args.no_execute: display.message("Skipping execution...") @@ -469,6 +482,19 @@ def interactive_execute_wrapper(**kwargs): if detailed_msg is not None: err = exceptions.UsageError(detailed_msg) raise err + # TODO: move to error_messages.py + # Added here due to circular dependency issue (#545) + elif "no such table" in str(e): + tables = query_util.extract_tables_from_query(command.sql) + for table in tables: + suggestions = get_close_matches(table, list(self._store)) + if len(suggestions) > 0: + err_message = f"There is no table with name {table!r}." + suggestions_message = get_suggestions_message(suggestions) + raise exceptions.TableNotFoundError( + f"{err_message}{suggestions_message}" + ) + print(e) else: print(e) else: diff --git a/src/sql/magic_cmd.py b/src/sql/magic_cmd.py index f477ea675..56340d840 100644 --- a/src/sql/magic_cmd.py +++ b/src/sql/magic_cmd.py @@ -34,6 +34,10 @@ def error(self, message): raise exceptions.UsageError(message) +# Added here due to circular dependencies (#545) +from sql.sqlcmd import sqlcmd_snippets # noqa + + @magics_class class SqlCmdMagic(Magics, Configurable): """%sqlcmd magic""" @@ -50,13 +54,22 @@ def _validate_execute_inputs(self, line): # We rely on SQLAlchemy when inspecting tables util.support_only_sql_alchemy_connection("%sqlcmd") - AVAILABLE_SQLCMD_COMMANDS = ["tables", "columns", "test", "profile", "explore"] + AVAILABLE_SQLCMD_COMMANDS = [ + "tables", + "columns", + "test", + "profile", + "explore", + "snippets", + ] + + VALID_COMMANDS_MSG = ( + f"Missing argument for %sqlcmd. " + f"Valid commands are: {', '.join(AVAILABLE_SQLCMD_COMMANDS)}" + ) if line == "": - raise exceptions.UsageError( - "Missing argument for %sqlcmd. " - "Valid commands are: {}".format(", ".join(AVAILABLE_SQLCMD_COMMANDS)) - ) + raise exceptions.UsageError(VALID_COMMANDS_MSG) else: split = arg_split(line) command, others = split[0].strip(), split[1:] @@ -223,6 +236,9 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None): table_widget = TableWidget(args.table) display(table_widget) + elif cmd_name == "snippets": + return sqlcmd_snippets(others) + def return_test_results(args, conn, query): try: diff --git a/src/sql/magic_plot.py b/src/sql/magic_plot.py index c2c43e774..c365fb4f4 100644 --- a/src/sql/magic_plot.py +++ b/src/sql/magic_plot.py @@ -83,23 +83,25 @@ def execute(self, line="", cell="", local_ns=None): column = util.sanitize_identifier(column) table = util.sanitize_identifier(cmd.args.table) + if cmd.args.with_: + util.show_deprecation_warning() if cmd.args.line[0] in {"box", "boxplot"}: - util.is_table_exists(table, with_=cmd.args.with_) + with_ = self._check_table_exists(table) + return plot.boxplot( table=table, column=column, - with_=cmd.args.with_, + with_=with_, orient=cmd.args.orient, conn=None, ) elif cmd.args.line[0] in {"hist", "histogram"}: - util.is_table_exists(table, with_=cmd.args.with_) - + with_ = self._check_table_exists(table) return plot.histogram( table=table, column=column, bins=cmd.args.bins, - with_=cmd.args.with_, + with_=with_, conn=None, ) elif cmd.args.line[0] in {"bar"}: @@ -128,3 +130,12 @@ def execute(self, line="", cell="", local_ns=None): raise exceptions.UsageError( f"Unknown plot {cmd.args.line[0]!r}. Must be any of: " f"{plot_str}" ) + + @staticmethod + def _check_table_exists(table): + with_ = None + if util.is_saved_snippet(table): + with_ = [table] + else: + util.is_table_exists(table) + return with_ diff --git a/src/sql/query_util.py b/src/sql/query_util.py new file mode 100644 index 000000000..fb342b812 --- /dev/null +++ b/src/sql/query_util.py @@ -0,0 +1,29 @@ +from sqlglot import parse_one, exp +from sqlglot.errors import ParseError + + +def extract_tables_from_query(query): + """ + Function to extract names of tables from + a syntactically correct query + + Parameters + ---------- + query : str, user query + + Returns + ------- + list + List of tables in the query + [] if error in parsing the query + """ + try: + tables = [table.name for table in parse_one(query).find_all(exp.Table)] + return tables + except ParseError: + # TODO : Instead of returning [] replace with call to + # error_messages.py::parse_sqlglot_error. Currently this + # is not possible because of an exception raised in test + # fixtures. (#546). This function can also be moved to util.py + # after #545 is resolved. + return [] diff --git a/src/sql/sqlcmd.py b/src/sql/sqlcmd.py new file mode 100644 index 000000000..2d6bd1b56 --- /dev/null +++ b/src/sql/sqlcmd.py @@ -0,0 +1,90 @@ +from sql.magic_cmd import CmdParser +from sql import util +from sql.exceptions import UsageError + + +def _modify_display_msg(key, remaining_keys, dependent_keys=None): + """ + + Parameters + ---------- + key : str, + deleted stored snippet + remaining_keys: list + snippets remaining after key is deleted + dependent_keys: list + snippets dependent on key + + Returns + ------- + msg: str + Formatted message + """ + msg = f"{key} has been deleted.\n" + if dependent_keys: + msg = f"{msg}{', '.join(dependent_keys)} depend on {key}\n" + if remaining_keys: + msg = f"{msg}Stored snippets : {', '.join(remaining_keys)}" + else: + msg = f"{msg}There are no stored snippets" + return msg + + +def sqlcmd_snippets(others): + """ + + Parameters + ---------- + This function handles all the arguments related to %sqlcmd snippets, namely + listing stored snippets, and delete/ force delete/ force delete a snippet and + all its dependent snippets. + + """ + parser = CmdParser() + parser.add_argument( + "-d", "--delete", type=str, help="Delete stored snippet", required=False + ) + parser.add_argument( + "-D", + "--delete-force", + type=str, + help="Force delete stored snippet", + required=False, + ) + parser.add_argument( + "-A", + "--delete-force-all", + type=str, + help="Force delete all stored snippets", + required=False, + ) + args = parser.parse_args(others) + SNIPPET_ARGS = [args.delete, args.delete_force, args.delete_force_all] + if SNIPPET_ARGS.count(None) == len(SNIPPET_ARGS): + return ", ".join(util.get_all_keys()) + if args.delete: + deps = util.get_key_dependents(args.delete) + if deps: + deps = ", ".join(deps) + raise UsageError( + f"The following tables are dependent on {args.delete}: {deps}.\n" + f"Pass --delete-force to only delete {args.delete}.\n" + f"Pass --delete-force-all to delete {deps} and {args.delete}" + ) + else: + key = args.delete + remaining_keys = util.del_saved_key(key) + return _modify_display_msg(key, remaining_keys) + + elif args.delete_force: + key = args.delete_force + deps = util.get_key_dependents(key) + remaining_keys = util.del_saved_key(key) + return _modify_display_msg(key, remaining_keys, deps) + + elif args.delete_force_all: + deps = util.get_key_dependents(args.delete_force_all) + deps.append(args.delete_force_all) + for key in deps: + remaining_keys = util.del_saved_key(key) + return _modify_display_msg(", ".join(deps), remaining_keys) diff --git a/src/sql/store.py b/src/sql/store.py index f0aa7303f..ffad19e6b 100644 --- a/src/sql/store.py +++ b/src/sql/store.py @@ -6,6 +6,7 @@ import difflib from sql import exceptions +from sql import query_util class SQLStore(MutableMapping): @@ -68,6 +69,18 @@ def render(self, query, with_=None): # TODO: if with is false, WITH should not appear return SQLQuery(self, query, with_) + def infer_dependencies(self, query, key): + dependencies = [] + saved_keys = [ + saved_key for saved_key in list(self._data.keys()) if saved_key != key + ] + if saved_keys and query: + tables = query_util.extract_tables_from_query(query) + for table in tables: + if table in saved_keys: + dependencies.append(table) + return dependencies + @modify_exceptions def store(self, key, query, with_=None): if "-" in key: @@ -138,6 +151,15 @@ def _get_dependencies(store, keys): return list(dict.fromkeys(deps + keys)) +def _get_dependents_for_key(store, key): + key_dependents = [] + for k in list(store): + deps = _get_dependencies_for_key(store, k) + if key in deps: + key_dependents.append(k) + return key_dependents + + def _get_dependencies_for_key(store, key): """Retrieve dependencies for a single key""" deps = store[key]._with_ diff --git a/src/sql/util.py b/src/sql/util.py index 65eac2f8a..19473ace6 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -1,8 +1,9 @@ +import warnings import sql from sql import inspect import difflib from sql.connection import Connection -from sql.store import store +from sql.store import store, _get_dependents_for_key from sql import exceptions import json @@ -52,11 +53,18 @@ def _is_long_number(num) -> bool: return False +def get_suggestions_message(suggestions): + suggestions_message = "" + if len(suggestions) > 0: + _suggestions_string = pretty_print(suggestions, last_delimiter="or") + suggestions_message = f"\nDid you mean : {_suggestions_string}" + return suggestions_message + + def is_table_exists( table: str, schema: str = None, ignore_error: bool = False, - with_: str = None, conn=None, ) -> bool: """ @@ -70,9 +78,6 @@ def is_table_exists( schema: str, default None Schema name - with_: list, default None - Temporary table - ignore_error: bool, default False Avoid raising a ValueError """ @@ -93,7 +98,7 @@ def is_table_exists( else: table_ = table - _is_exist = _is_table_exists(table_, with_, conn) + _is_exist = _is_table_exists(table_, conn) if not _is_exist: if not ignore_error: @@ -124,22 +129,12 @@ def is_table_exists( f"There is no table with name {table!r} in the default schema" ) - if table in list(store): - # Suggest user use --with when given table - # is in the store - suggestion_message = ( - ", but there is a stored query." - f"\nDid you miss passing --with {table}?" - ) - err_message = f"{err_message}{suggestion_message}" - else: - suggestions = difflib.get_close_matches(invalid_input, expected) - - if len(suggestions) > 0: - _suggestions_string = pretty_print(suggestions, last_delimiter="or") - suggestions_message = f"\nDid you mean : {_suggestions_string}" - err_message = f"{err_message}{suggestions_message}" - + suggestions = difflib.get_close_matches(invalid_input, expected) + suggestions_store = difflib.get_close_matches(invalid_input, list(store)) + suggestions.extend(suggestions_store) + suggestions_message = get_suggestions_message(suggestions) + if suggestions_message: + err_message = f"{err_message}{suggestions_message}" raise exceptions.TableNotFoundError(err_message) return _is_exist @@ -182,7 +177,14 @@ def strip_multiple_chars(string: str, chars: str) -> str: return string.translate(str.maketrans("", "", chars)) -def _is_table_exists(table: str, with_: str, conn) -> bool: +def is_saved_snippet(table: str) -> bool: + if table in list(store): + print(f"Plotting using saved snippet : {table}") + return True + return False + + +def _is_table_exists(table: str, conn) -> bool: """ Runs a SQL query to check if table exists """ @@ -191,21 +193,16 @@ def _is_table_exists(table: str, with_: str, conn) -> bool: identifiers = conn.get_curr_identifiers() - if with_: - return table in list(store) - else: - for iden in identifiers: - if isinstance(iden, tuple): - query = "SELECT * FROM {0}{1}{2} WHERE 1=0".format( - iden[0], table, iden[1] - ) - else: - query = "SELECT * FROM {0}{1}{0} WHERE 1=0".format(iden, table) - try: - conn.execute(query) - return True - except Exception: - pass + for iden in identifiers: + if isinstance(iden, tuple): + query = "SELECT * FROM {0}{1}{2} WHERE 1=0".format(iden[0], table, iden[1]) + else: + query = "SELECT * FROM {0}{1}{0} WHERE 1=0".format(iden, table) + try: + conn.execute(query) + return True + except Exception: + pass return False @@ -305,3 +302,61 @@ def parse_sql_results_to_json(rows, columns) -> str: ) return rows_json + + +def get_all_keys(): + """ + + Returns + ------- + All stored snippets in the current session + """ + return list(store) + + +def get_key_dependents(key: str) -> list: + """ + Function to find the stored snippets dependent on key + Parameters + ---------- + key : str, name of the table + + Returns + ------- + list + List of snippets dependent on key + + """ + deps = _get_dependents_for_key(store, key) + return deps + + +def del_saved_key(key: str) -> str: + """ + Deletes a stored snippet + Parameters + ---------- + key : str, name of the snippet to be deleted + + Returns + ------- + list + Remaining stored snippets + """ + all_keys = get_all_keys() + if key not in all_keys: + raise exceptions.UsageError(f"No such saved snippet found : {key}") + del store[key] + return get_all_keys() + + +def show_deprecation_warning(): + """ + Raises CTE deprecation warning + """ + warnings.warn( + "CTE dependencies are now automatically inferred, " + "you can omit the --with arguments. Using --with will " + "raise an exception in the next major release so please remove it.", + FutureWarning, + ) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 63ce88830..1802bea99 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -95,10 +95,10 @@ def ip(ip_empty): ], ) yield ip_empty - runsql(ip_empty, "DROP TABLE test") - runsql(ip_empty, "DROP TABLE author") - runsql(ip_empty, "DROP TABLE website") - runsql(ip_empty, "DROP TABLE number_table") + runsql(ip_empty, "DROP TABLE IF EXISTS test") + runsql(ip_empty, "DROP TABLE IF EXISTS author") + runsql(ip_empty, "DROP TABLE IF EXISTS website") + runsql(ip_empty, "DROP TABLE IF EXISTS number_table") @pytest.fixture diff --git a/src/tests/test_command.py b/src/tests/test_command.py index 7b73b14d5..89a085c26 100644 --- a/src/tests/test_command.py +++ b/src/tests/test_command.py @@ -95,37 +95,6 @@ def test_parsed( assert cmd.sql_original == parsed_sql -def test_parsed_sql_when_using_with(ip, sql_magic): - ip.run_cell_magic( - "sql", - "--save author_one", - """ - SELECT * FROM author LIMIT 1 - """, - ) - - cmd = SQLCommand( - sql_magic, ip.user_ns, line="--with author_one", cell="SELECT * FROM author_one" - ) - - sql = "WITH `author_one` AS (\n\n SELECT * FROM author LIMIT 1\n )\n\ -SELECT * FROM author_one" - - sql_original = "\nSELECT * FROM author_one" - - assert cmd.parsed == { - "connection": "", - "result_var": None, - "return_result_var": False, - "sql": sql, - "sql_original": sql_original, - } - - assert cmd.connection == "" - assert cmd.sql == sql - assert cmd.sql_original == sql_original - - def test_parsed_sql_when_using_file(ip, sql_magic, tmp_empty): Path("query.sql").write_text("SELECT * FROM author") cmd = SQLCommand(sql_magic, ip.user_ns, "--file query.sql", "") diff --git a/src/tests/test_compose.py b/src/tests/test_compose.py deleted file mode 100644 index 3fd16f0c2..000000000 --- a/src/tests/test_compose.py +++ /dev/null @@ -1,18 +0,0 @@ -def test_compose(ip): - ip.run_cell_magic( - "sql", - "--save author_sub", - "SELECT last_name FROM author WHERE year_of_death > 1900", - ) - - ip.run_cell_magic( - "sql", - "--with author_sub --save final", - "SELECT last_name FROM author_sub;", - ) - - result = ip.run_cell("%sqlrender final").result - expected = "WITH `author_sub` AS (\nSELECT last_name \ -FROM author WHERE year_of_death > 1900)\nSELECT last_name FROM author_sub;" - - assert result == expected diff --git a/src/tests/test_extract_tables.py b/src/tests/test_extract_tables.py new file mode 100644 index 000000000..da39fb608 --- /dev/null +++ b/src/tests/test_extract_tables.py @@ -0,0 +1,85 @@ +import pytest +from sql.query_util import extract_tables_from_query + + +@pytest.mark.parametrize( + "query, expected", + [ + ( + """ + SELECT t.* + FROM tracks_with_info t + JOIN genres_fav + ON t.GenreId = genres_fav.GenreId + """, + ["tracks_with_info", "genres_fav"], + ), + ( + """ + SELECT city FROM Customers + UNION + SELECT city FROM Suppliers""", + ["Customers", "Suppliers"], + ), + ( + """ + SELECT OrderID, Quantity, +CASE + WHEN Quantity > 30 THEN 'The quantity is greater than 30' + WHEN Quantity = 30 THEN 'The quantity is 30' + ELSE 'The quantity is under 30' +END AS QuantityText +FROM OrderDetails;""", + ["OrderDetails"], + ), + ( + """ +SELECT COUNT(CustomerID), Country +FROM Customers +GROUP BY Country +HAVING COUNT(CustomerID) > 5;""", + ["Customers"], + ), + ( + """ +SELECT LEFT(sub.date, 2) AS cleaned_month, + sub.day_of_week, + AVG(sub.incidents) AS average_incidents + FROM ( + SELECT day_of_week, + date, + COUNT(incidnt_num) AS incidents + FROM tutorial.sf_crime_incidents_2014_01 + GROUP BY 1,2 + ) sub + GROUP BY 1,2 + ORDER BY 1,2""", + ["sf_crime_incidents_2014_01"], + ), + ( + """ + SELECT incidents.*, + sub.incidents AS incidents_that_day + FROM tutorial.sf_crime_incidents_2014_01 incidents + JOIN ( SELECT date, + COUNT(incidnt_num) AS incidents + FROM tutorial.sf_crime_incidents_2014_01 + GROUP BY 1 + ) sub + ON incidents.date = sub.date + ORDER BY sub.incidents DESC, time + """, + ["sf_crime_incidents_2014_01", "sf_crime_incidents_2014_01"], + ), + ], + ids=["join", "union", "case", "groupby", "subquery", "subquery_join"], +) +def test_extract(query, expected): + tables = extract_tables_from_query(query) + assert expected == tables + + +def test_invalid_query(): + query = "SELECT city frm Customers" + tables = extract_tables_from_query(query) + assert [] == tables diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index c811bbff6..7ac94b843 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -1155,10 +1155,16 @@ def test_save_with_number_table( def test_save_with_non_existing_with(ip): - out = ip.run_cell( - "%sql --with non_existing_sub_query " "SELECT * FROM non_existing_sub_query" + with pytest.warns(FutureWarning) as record: + ip.run_cell( + "%sql --with non_existing_sub_query " "SELECT * FROM non_existing_sub_query" + ) + assert len(record) == 1 + assert ( + "CTE dependencies are now automatically inferred, you can omit the " + "--with arguments. Using --with will raise an exception in the next " + "major release so please remove it." in record[0].message.args[0] ) - assert isinstance(out.error_in_exec, UsageError) def test_save_with_non_existing_table(ip, capsys): diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index 7a530d274..1a56e9b4c 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -3,8 +3,47 @@ import pytest from IPython.core.error import UsageError from pathlib import Path + from sqlalchemy import create_engine from sql.connection import Connection +from sql.store import store + + +VALID_COMMANDS_MESSAGE = ( + "Valid commands are: tables, " "columns, test, profile, explore, snippets" +) + + +@pytest.fixture +def ip_snippets(ip): + for key in list(store): + del store[key] + ip.run_cell("%sql sqlite://") + ip.run_cell( + """ + %%sql --save high_price --no-execute +SELECT * +FROM "test_store" +WHERE price >= 1.50 +""" + ) + ip.run_cell( + """ + %%sql --save high_price_a --no-execute +SELECT * +FROM "high_price" +WHERE symbol == 'a' +""" + ) + ip.run_cell( + """ + %%sql --save high_price_b --no-execute +SELECT * +FROM "high_price" +WHERE symbol == 'b' +""" + ) + yield ip @pytest.mark.parametrize( @@ -13,32 +52,27 @@ [ "%sqlcmd", UsageError, - "Missing argument for %sqlcmd. " - "Valid commands are: tables, columns, test, profile, explore", + "Missing argument for %sqlcmd. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd ", UsageError, - "Missing argument for %sqlcmd. " - "Valid commands are: tables, columns, test, profile, explore", + "Missing argument for %sqlcmd. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd ", UsageError, - "Missing argument for %sqlcmd. " - "Valid commands are: tables, columns, test, profile, explore", + "Missing argument for %sqlcmd. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd ", UsageError, - "Missing argument for %sqlcmd. " - "Valid commands are: tables, columns, test, profile, explore", + "Missing argument for %sqlcmd. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd stuff", UsageError, - "%sqlcmd has no command: 'stuff'. " - "Valid commands are: tables, columns, test, profile, explore", + "%sqlcmd has no command: 'stuff'. " f"{VALID_COMMANDS_MESSAGE}", ], [ "%sqlcmd columns", @@ -266,3 +300,79 @@ def test_test_error(ip, cell, error_type, error_message): assert isinstance(out.error_in_exec, error_type) assert str(out.error_in_exec) == error_message + + +def test_snippet(ip_snippets): + out = ip_snippets.run_cell("%sqlcmd snippets").result + assert "high_price, high_price_a, high_price_b" in out + + +@pytest.mark.parametrize("arg", ["--delete", "-d"]) +def test_delete_saved_key(ip_snippets, arg): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price_a").result + assert "high_price_a has been deleted.\n" in out + stored_snippets = out[out.find("Stored snippets") + len("Stored snippets: ") :] + assert "high_price, high_price_b" in stored_snippets + assert "high_price_a" not in stored_snippets + + +@pytest.mark.parametrize("arg", ["--delete-force", "-D"]) +def test_force_delete(ip_snippets, arg): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price").result + assert ( + "high_price has been deleted.\nhigh_price_a, " + "high_price_b depend on high_price\n" in out + ) + stored_snippets = out[out.find("Stored snippets") + len("Stored snippets: ") :] + assert "high_price_a, high_price_b" in stored_snippets + assert "high_price," not in stored_snippets + + +@pytest.mark.parametrize("arg", ["--delete-force-all", "-A"]) +def test_force_delete_all(ip_snippets, arg): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price").result + assert "high_price_a, high_price_b, high_price has been deleted" in out + assert "There are no stored snippets" in out + + +@pytest.mark.parametrize("arg", ["--delete-force-all", "-A"]) +def test_force_delete_all_child_query(ip_snippets, arg): + ip_snippets.run_cell( + """ + %%sql --save high_price_b_child --no-execute +SELECT * +FROM "high_price_b" +WHERE symbol == 'b' +LIMIT 3 +""" + ) + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price_b").result + assert "high_price_b_child, high_price_b has been deleted" in out + stored_snippets = out[out.find("Stored snippets") + len("Stored snippets: ") :] + assert "high_price, high_price_a" in stored_snippets + assert "high_price_b," not in stored_snippets + assert "high_price_b_child" not in stored_snippets + + +@pytest.mark.parametrize("arg", ["--delete", "-d"]) +def test_delete_snippet_error(ip_snippets, arg): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price") + assert isinstance(out.error_in_exec, UsageError) + assert ( + str(out.error_in_exec) == "The following tables are dependent on high_price: " + "high_price_a, high_price_b.\nPass --delete-force to only " + "delete high_price.\nPass --delete-force-all to delete " + "high_price_a, high_price_b and high_price" + ) + + +@pytest.mark.parametrize( + "arg", ["--delete", "-d", "--delete-force-all", "-A", "--delete-force", "-D"] +) +def test_delete_invalid_snippet(arg, ip_snippets): + out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} non_existent_snippet") + assert isinstance(out.error_in_exec, UsageError) + assert ( + str(out.error_in_exec) == "No such saved snippet found " + ": non_existent_snippet" + ) diff --git a/src/tests/test_magic_cte.py b/src/tests/test_magic_cte.py index 95e96c0fe..a76a1b957 100644 --- a/src/tests/test_magic_cte.py +++ b/src/tests/test_magic_cte.py @@ -1,3 +1,7 @@ +import pytest +from IPython.core.error import UsageError + + def test_trailing_semicolons_removed_from_cte(ip): ip.run_cell( """%%sql --save positive_x @@ -32,3 +36,168 @@ def test_trailing_semicolons_removed_from_cte(ip): "FROM number_table WHERE y > 0)\nSELECT * FROM positive_x\n" "UNION\nSELECT * FROM positive_y;" ) + + +def test_infer_dependencies(ip, capsys): + ip.run_cell_magic( + "sql", + "--save author_sub", + "SELECT last_name FROM author WHERE year_of_death > 1900", + ) + + ip.run_cell_magic( + "sql", + "--save final", + "SELECT last_name FROM author_sub;", + ) + out, _ = capsys.readouterr() + result = ip.run_cell("%sqlrender final").result + expected = ( + "WITH `author_sub` AS (\nSELECT last_name FROM author " + "WHERE year_of_death > 1900)\nSELECT last_name FROM author_sub;" + ) + + assert result == expected + assert "Generating CTE with stored snippets : author_sub" in out + + +def test_deprecation_warning(ip): + ip.run_cell_magic( + "sql", + "--save author_sub", + "SELECT last_name FROM author WHERE year_of_death > 1900", + ) + + with pytest.warns(FutureWarning) as record: + ip.run_cell_magic( + "sql", + "--with author_sub --save final", + "SELECT last_name FROM author_sub;", + ) + assert len(record) == 1 + assert ( + "CTE dependencies are now automatically inferred," + " you can omit the --with arguments. Using --with will " + "raise an exception in the next major release so please " + "remove it." in record[0].message.args[0] + ) + + +TABLE_NAME_TYPO_ERR_MSG = """ +There is no table with name 'author_subb'. +Did you mean : 'author_sub' +If you need help solving this issue, send us a message: https://ploomber.io/community +""" + + +def test_table_name_typo(ip): + ip.run_cell_magic( + "sql", + "--save author_sub", + "SELECT last_name FROM author WHERE year_of_death > 1900", + ) + + with pytest.raises(UsageError) as excinfo: + ip.run_cell_magic( + "sql", + "--save final", + "SELECT last_name FROM author_subb;", + ) + + assert excinfo.value.error_type == "TableNotFoundError" + assert str(excinfo.value) == TABLE_NAME_TYPO_ERR_MSG.strip() + + +def test_snippets_delete(ip, capsys): + ip.run_cell( + """ + %%sql sqlite:// + CREATE TABLE orders (order_id int, customer_id int, order_value float); + INSERT INTO orders VALUES (123, 15, 150.67); + INSERT INTO orders VALUES (124, 25, 200.66); + INSERT INTO orders VALUES (211, 15, 251.43); + INSERT INTO orders VALUES (312, 5, 333.41); + CREATE TABLE another_orders (order_id int, customer_id int, order_value float); + INSERT INTO another_orders VALUES (511,15, 150.67); + INSERT INTO another_orders VALUES (512, 30, 200.66); + CREATE TABLE customers (customer_id int, name varchar(25)); + INSERT INTO customers VALUES (15, 'John'); + INSERT INTO customers VALUES (25, 'Sheryl'); + INSERT INTO customers VALUES (5, 'Mike'); + INSERT INTO customers VALUES (30, 'Daisy'); + """ + ) + ip.run_cell_magic( + "sql", + "--save orders_less", + "SELECT * FROM orders WHERE order_value < 250.0", + ) + + ip.run_cell_magic( + "sql", + "--save another_orders", + "SELECT * FROM orders WHERE order_value > 250.0", + ) + + ip.run_cell_magic( + "sql", + "--save final", + """ + SELECT o.order_id, customers.name, o.order_value + FROM another_orders o + INNER JOIN customers ON o.customer_id=customers.customer_id; + """, + ) + + out, _ = capsys.readouterr() + assert "Generating CTE with stored snippets : another_orders" in out + result_del = ip.run_cell( + "%sqlcmd snippets --delete-force-all another_orders" + ).result + assert "final, another_orders has been deleted.\n" in result_del + stored_snippets = result_del[ + result_del.find("Stored snippets") + len("Stored snippets: ") : + ] + assert "orders_less" in stored_snippets + ip.run_cell_magic( + "sql", + "--save final", + """ + SELECT o.order_id, customers.name, o.order_value + FROM another_orders o + INNER JOIN customers ON o.customer_id=customers.customer_id; + """, + ) + result = ip.run_cell("%sqlrender final").result + expected = ( + "WITH\n\n SELECT o.order_id, customers.name, " + "o.order_value\n " + "FROM another_orders o\n INNER JOIN customers " + "ON o.customer_id=customers.customer_id" + ) + assert expected in result + + +SYNTAX_ERROR_MESSAGE = """ +Syntax Error in WITH `author_sub` AS ( +SELECT last_name FRM author WHERE year_of_death > 1900) +SELECT last_name FROM author_sub: Expecting ( at Line 1, Column 16 +""" + + +def test_query_syntax_error(ip): + ip.run_cell_magic( + "sql", + "--save author_sub --no-execute", + "SELECT last_name FRM author WHERE year_of_death > 1900", + ) + + with pytest.raises(UsageError) as excinfo: + ip.run_cell_magic( + "sql", + "--save final", + "SELECT last_name FROM author_sub;", + ) + + assert excinfo.value.error_type == "UsageError" + assert SYNTAX_ERROR_MESSAGE.strip() in str(excinfo.value) diff --git a/src/tests/test_magic_plot.py b/src/tests/test_magic_plot.py index 5f78373c4..8e793cdca 100644 --- a/src/tests/test_magic_plot.py +++ b/src/tests/test_magic_plot.py @@ -11,6 +11,35 @@ plot_str = util.pretty_print(SUPPORTED_PLOTS, last_delimiter="or") +@pytest.fixture +def ip_snippets(ip, tmp_empty): + Path("data.csv").write_text( + """\ +x, y +0, 0 +1, 1 +2, 2 +""" + ) + ip.run_cell("%sql duckdb://") + + ip.run_cell( + """%%sql --save subset --no-execute +SELECT * +FROM data.csv +WHERE x > -1 +""" + ) + ip.run_cell( + """%%sql --save subset_another --no-execute +SELECT * +FROM subset +WHERE x > 2 +""" + ) + yield ip + + @pytest.mark.parametrize( "cell, error_type, error_message", [ @@ -69,7 +98,7 @@ def test_validate_arguments(tmp_empty, ip, cell, error_type, error_message): "%sqlplot boxplot --table data.csv --column x", "%sqlplot box --table data.csv --column x", "%sqlplot boxplot --table data.csv --column x --orient h", - "%sqlplot boxplot --table subset --column x --with subset", + "%sqlplot boxplot --table subset --column x", "%sqlplot boxplot -t subset -c x -w subset -o h", "%sqlplot boxplot --table nas.csv --column x", "%sqlplot bar -t data.csv -c x", @@ -318,3 +347,43 @@ def test_pie_one_col_num(load_data_one_col, ip): @image_comparison(baseline_images=["pie_two_col"], extensions=["png"], remove_text=True) def test_pie_two_col(load_data_two_col, ip): ip.run_cell("%sqlplot pie -t data_two.csv -c x y") + + +def test_sqlplot_deprecation_warning(ip_snippets, capsys): + with pytest.warns(FutureWarning) as record: + res = ip_snippets.run_cell( + "%sqlplot boxplot --table subset --column x --with subset" + ) + assert len(record) == 1 + assert ( + "CTE dependencies are now automatically inferred," + " you can omit the --with arguments. Using --with will " + "raise an exception in the next major release so please " + "remove it." in record[0].message.args[0] + ) + out, err = capsys.readouterr() + assert type(res.result).__name__ in {"Axes", "AxesSubplot"} + assert "Plotting using saved snippet : subset" in out + + +@pytest.mark.parametrize( + "arg", ["--delete", "-d", "--delete-force-all", "-A", "--delete-force", "-D"] +) +def test_sqlplot_snippet_deletion(ip_snippets, arg, capsys): + ip_snippets.run_cell(f"%sqlcmd snippets {arg} subset_another") + ip_snippets.run_cell("%sqlplot boxplot --table subset_another --column x") + out, err = capsys.readouterr() + assert "There is no table with name 'subset_another' in the default schema" in err + + +TABLE_NAME_TYPO_MSG = """ +UsageError: There is no table with name 'subst' in the default schema +Did you mean : 'subset' +If you need help solving this issue, send us a message: https://ploomber.io/community +""" + + +def test_sqlplot_snippet_typo(ip_snippets, capsys): + ip_snippets.run_cell("%sqlplot boxplot --table subst --column x") + out, err = capsys.readouterr() + assert TABLE_NAME_TYPO_MSG.strip() == err.strip() diff --git a/src/tests/test_util.py b/src/tests/test_util.py index f3f6c04ba..74a8c8cae 100644 --- a/src/tests/test_util.py +++ b/src/tests/test_util.py @@ -15,41 +15,41 @@ ) -@pytest.mark.parametrize( - "store_table, query", - [ - ("a", "%sqlcmd columns --table {}"), - ("bbb", "%sqlcmd profile --table {}"), - ("c_c", "%sqlplot histogram --table {} --column x"), - ("d_d_d", "%sqlplot boxplot --table {} --column x"), - ], -) -def test_missing_with(ip, store_table, query): +@pytest.fixture +def ip_snippets(ip): ip.run_cell( - f""" - %%sql --save {store_table} --no-execute - SELECT * - FROM number_table """ - ).result - - query = query.format(store_table) - out = ip.run_cell(query) - - expected_store_message = EXPECTED_STORE_SUGGESTIONS.format(store_table) - - error_message = str(out.error_in_exec) - assert isinstance(out.error_in_exec, UsageError) - assert str(expected_store_message).lower() in error_message.lower() +%%sql --save a --no-execute +SELECT * +FROM number_table +""" + ) + ip.run_cell( + """ + %%sql --save b --no-execute + SELECT * + FROM a + WHERE x > 5 + """ + ) + ip.run_cell( + """ + %%sql --save c --no-execute + SELECT * + FROM a + WHERE x < 5 + """ + ) + yield ip @pytest.mark.parametrize( "store_table, query", [ - ("a", "%sqlcmd columns --table {} --with {}"), - ("bbb", "%sqlcmd profile --table {} --with {}"), - ("c_c", "%sqlplot histogram --table {} --with {} --column x"), - ("d_d_d", "%sqlplot boxplot --table {} --with {} --column x"), + ("a", "%sqlcmd columns --table {}"), + ("bbb", "%sqlcmd profile --table {}"), + ("c_c", "%sqlplot histogram --table {} --column x"), + ("d_d_d", "%sqlplot boxplot --table {} --column x"), ], ) def test_no_errors_with_stored_query(ip, store_table, query): @@ -237,7 +237,7 @@ def test_is_table_exists_with(ip, table, expected_error, expected_suggestions): ) if expected_error: with pytest.raises(expected_error) as error: - util.is_table_exists(table, with_=with_) + util.is_table_exists(table) error_suggestions_arr = str(error.value).split(EXPECTED_SUGGESTIONS_MESSAGE) @@ -444,3 +444,28 @@ def test_parse_sql_results_to_json(ip, capsys, rows, columns, expected_json): j = json.loads(j) with capsys.disabled(): assert str(j) == str(expected_json) + + +def test_get_all_keys(ip_snippets): + keys = util.get_all_keys() + assert "a" in keys + assert "b" in keys + assert "c" in keys + + +def test_get_key_dependents(ip_snippets): + keys = util.get_key_dependents("a") + assert "b" in keys + assert "c" in keys + + +def test_del_saved_key(ip_snippets): + keys = util.del_saved_key("c") + assert "a" in keys + assert "b" in keys + + +def test_del_saved_key_error(ip_snippets): + with pytest.raises(UsageError) as excinfo: + util.del_saved_key("non_existent_key") + assert "No such saved snippet found : non_existent_key" in str(excinfo.value) From c685d3f795232db4d6912d10745498014757edb3 Mon Sep 17 00:00:00 2001 From: neelasha23 Date: Thu, 1 Jun 2023 11:43:39 +0530 Subject: [PATCH 2/3] added to bar, pie" --- src/sql/magic_plot.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/sql/magic_plot.py b/src/sql/magic_plot.py index c365fb4f4..f9554099e 100644 --- a/src/sql/magic_plot.py +++ b/src/sql/magic_plot.py @@ -105,23 +105,21 @@ def execute(self, line="", cell="", local_ns=None): conn=None, ) elif cmd.args.line[0] in {"bar"}: - util.is_table_exists(table, with_=cmd.args.with_) - + with_ = self._check_table_exists(table) return plot.bar( table=table, column=column, - with_=cmd.args.with_, + with_=with_, orient=cmd.args.orient, show_num=cmd.args.show_numbers, conn=None, ) elif cmd.args.line[0] in {"pie"}: - util.is_table_exists(table, with_=cmd.args.with_) - + with_ = self._check_table_exists(table) return plot.pie( table=table, column=column, - with_=cmd.args.with_, + with_=with_, show_num=cmd.args.show_numbers, conn=None, ) From f0120d33b037a5ebe9a8a97fabb215a9c0698c60 Mon Sep 17 00:00:00 2001 From: neelasha23 Date: Thu, 1 Jun 2023 12:04:43 +0530 Subject: [PATCH 3/3] Empty commit