From 5d981ea10c7c9170f4f0515ef9cfb0dce323a305 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Wed, 24 Jan 2024 18:54:50 -0600 Subject: [PATCH] parent aaf40215d94da3a8152a3f907f2cae5b50458c15 author Eduardo Blancas 1706144090 -0600 committer neelasha23 1706161587 +0530 parent aaf40215d94da3a8152a3f907f2cae5b50458c15 author Eduardo Blancas 1706144090 -0600 committer neelasha23 1706161586 +0530 parent aaf40215d94da3a8152a3f907f2cae5b50458c15 author Eduardo Blancas 1706144090 -0600 committer neelasha23 1706161584 +0530 parent aaf40215d94da3a8152a3f907f2cae5b50458c15 author Eduardo Blancas 1706144090 -0600 committer neelasha23 1706161581 +0530 Update README.md arg expansion guide docstring modified doc fix minor fix supported args table orient Parse line new util func rendering string fix test Removed table with doc fix magic_sql removed comment removed comment docstring Lint revert removals fix add no_var explorer guide profile testing docs tables doc tables doc tests schema snippets docs snippets --- CHANGELOG.md | 1 + doc/_toc.yml | 1 + doc/api/magic-plot.md | 49 +- doc/api/magic-profile.md | 26 +- doc/api/magic-snippets.md | 25 + doc/api/magic-sql.md | 20 +- doc/api/magic-tables-columns.md | 27 + doc/howto/testing-columns.md | 13 + doc/user-guide/argument-expansion.md | 108 ++ doc/user-guide/data-profiling.md | 20 + doc/user-guide/table_explorer.ipynb | 1621 +++++++++++++++++++++++++- doc/user-guide/tables-columns.md | 12 + src/sql/cmd/columns.py | 13 +- src/sql/cmd/explore.py | 12 +- src/sql/cmd/profile.py | 9 +- src/sql/cmd/snippets.py | 12 +- src/sql/cmd/tables.py | 9 +- src/sql/cmd/test.py | 10 +- src/sql/command.py | 3 + src/sql/magic.py | 6 +- src/sql/magic_cmd.py | 7 +- src/sql/magic_plot.py | 6 +- src/sql/util.py | 58 + src/tests/test_magic.py | 89 ++ src/tests/test_magic_cmd.py | 278 +++++ src/tests/test_magic_plot.py | 237 ++++ 26 files changed, 2648 insertions(+), 24 deletions(-) create mode 100644 doc/user-guide/argument-expansion.md diff --git 
a/CHANGELOG.md b/CHANGELOG.md index f2b91b554..0ad62bd6a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## 0.10.8dev * [Fix] Fix edge case where `select` and other SQL keywords were not properly used to find where the user's query started, causing argument parsing issues (#973) +* [Feature] Add support for parametrizing string type arguments of `%%sql`, `%sqlplot`, `%sqlcmd` (#699) ## 0.10.7 (2023-12-23) diff --git a/doc/_toc.yml b/doc/_toc.yml index f667e807b..aae6c7aaf 100644 --- a/doc/_toc.yml +++ b/doc/_toc.yml @@ -13,6 +13,7 @@ parts: - file: user-guide/tables-columns - file: user-guide/ggplot - file: user-guide/template + - file: user-guide/argument-expansion - file: user-guide/connection-file - file: user-guide/table_explorer - file: user-guide/data-profiling diff --git a/doc/api/magic-plot.md b/doc/api/magic-plot.md index 33f627c7c..bc37fff64 100644 --- a/doc/api/magic-plot.md +++ b/doc/api/magic-plot.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.1 + jupytext_version: 1.16.0 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -293,3 +293,50 @@ You can also show the percentage on top of the pie using the `S`/`show-numbers` ```{code-cell} ipython3 %sqlplot pie --table penguins.csv --column species -S ``` + +## Parametrizing arguments + +JupySQL supports variable expansion of arguments in the form of `{{variable}}`. This allows the user to specify arguments with placeholders that can be replaced by variables dynamically. 
+ +```{code-cell} ipython3 +%%sql +DROP TABLE IF EXISTS penguins; +CREATE SCHEMA IF NOT EXISTS s1; +CREATE TABLE s1.penguins ( + species VARCHAR(255), + island VARCHAR(255), + bill_length_mm DECIMAL(5, 2), + bill_depth_mm DECIMAL(5, 2), + flipper_length_mm DECIMAL(5, 2), + body_mass_g INTEGER, + sex VARCHAR(255) +); +COPY s1.penguins FROM 'penguins.csv' WITH (FORMAT CSV, HEADER TRUE); +``` + +```{code-cell} ipython3 +table = "penguins" +schema = "s1" +orient = "h" +column = "bill_length_mm" +``` + +```{code-cell} ipython3 +%sqlplot boxplot --table {{table}} --schema {{schema}} --column {{column}} --orient {{orient}} +``` + +Now let's see another example using `--with`: + +```{code-cell} ipython3 +snippet = "gentoo" +``` + +```{code-cell} ipython3 +%%sql --save {{snippet}} +SELECT * FROM {{schema}}.{{table}} +WHERE species == 'Gentoo' +``` + +```{code-cell} ipython3 +%sqlplot boxplot --table {{snippet}} --with {{snippet}} --column {{column}} +``` diff --git a/doc/api/magic-profile.md b/doc/api/magic-profile.md index efb074deb..054617cc5 100644 --- a/doc/api/magic-profile.md +++ b/doc/api/magic-profile.md @@ -201,4 +201,28 @@ Let’s profile `my_numbers` of `b_schema` ```{code-cell} ipython3 %sqlcmd profile --table my_numbers --schema b_schema -``` \ No newline at end of file +``` + +# Parametrizing arguments + +JupySQL supports variable expansion of arguments in the form of `{{variable}}`. This allows the user to specify arguments with placeholders that can be replaced by variables dynamically. 
+ +Let's look at an example that uses variable expansion for `table`, `schema` and `output` arguments: + +```{code-cell} ipython3 +table = "my_numbers" +schema = "b_schema" +output = "numbers-report.html" +``` + +```{code-cell} ipython3 +:tags: [hide-output] + +%sqlcmd profile --table {{table}} --schema {{schema}} --output {{output}} +``` + +```{code-cell} ipython3 +from IPython.display import HTML + +HTML(output) +``` diff --git a/doc/api/magic-snippets.md b/doc/api/magic-snippets.md index b7926b0d2..1b5be7f4e 100644 --- a/doc/api/magic-snippets.md +++ b/doc/api/magic-snippets.md @@ -150,3 +150,28 @@ Now, force delete `chinstrap` and its dependent `chinstrap_sub`: ```{code-cell} ipython3 %sqlcmd snippets -A chinstrap ``` + + +## Parameterizing arguments + +JupySQL supports variable expansion of arguments in the form of `{{variable}}`. This allows the user to specify arguments with placeholders that can be replaced by variables dynamically. + +Let's see some examples: + +```{code-cell} ipython3 +snippet_name = "gentoo" +``` + +```{code-cell} ipython3 +%%sql --save {{snippet_name}} +SELECT * FROM penguins.csv where species == 'Gentoo' +``` + +```{code-cell} ipython3 +gentoo_snippet = %sqlcmd snippets {{snippet_name}} +print(gentoo_snippet) +``` + +```{code-cell} ipython3 +%sqlcmd snippets -d {{snippet_name}} +``` \ No newline at end of file diff --git a/doc/api/magic-sql.md b/doc/api/magic-sql.md index a6c059305..c5f4d3b75 100644 --- a/doc/api/magic-sql.md +++ b/doc/api/magic-sql.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.15.1 + jupytext_version: 1.16.0 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -359,3 +359,21 @@ LIMIT 3 ```{code-cell} ipython3 %sql --file my-query.sql ``` + +## Parameterizing arguments + +JupySQL supports variable expansion of arguments in the form of `{{variable}}`. 
This allows the user to specify arguments with placeholders that can be replaced by variables dynamically. + +Let's see an example of creating a connection using an alias and closing the same through variable substitution. + +```{code-cell} ipython3 +alias = "db-four" +``` + +```{code-cell} ipython3 +%sql sqlite:///db_four.db --alias {{alias}} +``` + +```{code-cell} ipython3 +%sql --close {{alias}} +``` \ No newline at end of file diff --git a/doc/api/magic-tables-columns.md b/doc/api/magic-tables-columns.md index f3d66265d..cef3d1798 100644 --- a/doc/api/magic-tables-columns.md +++ b/doc/api/magic-tables-columns.md @@ -101,6 +101,20 @@ CREATE TABLE s2.t2(id INTEGER PRIMARY KEY, j VARCHAR); %sqlcmd tables -s s1 ``` +JupySQL supports variable expansion of arguments in the form of `{{variable}}`. This allows the user to specify arguments with placeholders that can be replaced by variables dynamically. + +Let's see an example: + +```{code-cell} ipython3 +schema = "s1" +``` + +```{code-cell} ipython3 +:tags: [hide-output] + +%sqlcmd tables -s {{schema}} +``` + As expected, the argument returns the table names under schema s1, which is t1. +++ @@ -123,3 +137,16 @@ Arguments: %sqlcmd columns -s s1 -t t1 ``` + +JupySQL also supports variable expansion of arguments of `columns`. Let's see an example: + +```{code-cell} ipython3 + +table = "t1" +schema = "s1" +``` + +```{code-cell} ipython3 + +%sqlcmd columns -s {{schema}} -t {{table}} +``` diff --git a/doc/howto/testing-columns.md b/doc/howto/testing-columns.md index 07d48f4cc..acda93ae8 100644 --- a/doc/howto/testing-columns.md +++ b/doc/howto/testing-columns.md @@ -80,3 +80,16 @@ The test fails, returning both Shakespeare and Brecht. Currently, 5 different comparator arguments are supported: `greater`, `greater-or-equal`, `less-than`, `less-than-or-equal`, and `no-nulls`. +## Parametrizing arguments + +JupySQL supports variable expansion of arguments in the form of `{{variable}}`. 
Let's see an example of running tests using parametrization: + +```{code-cell} ipython3 +table = "writer" +column = "year_of_death" +limit = "2000" +``` + +```{code-cell} ipython3 +%sqlcmd test --table {{table}} --column {{column}} --less-than {{limit}} +``` \ No newline at end of file diff --git a/doc/user-guide/argument-expansion.md b/doc/user-guide/argument-expansion.md new file mode 100644 index 000000000..4bf70b453 --- /dev/null +++ b/doc/user-guide/argument-expansion.md @@ -0,0 +1,108 @@ +--- +jupytext: + notebook_metadata_filter: myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.14.7 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +myst: + html_meta: + description lang=en: Variable substitution of arguments in Jupyter via JupySQL + keywords: jupyter, sql, jupysql, jinja + property=og:locale: en_US +--- + +# Parameterizing arguments + +```{versionadded} 0.10.8 +JupySQL uses Jinja templates for enabling parametrization of arguments. Arguments are parametrized with `{{variable}}`. +``` + + +## Parametrization via `{{variable}}` + +JupySQL supports variable expansion of arguments in the form of `{{variable}}`. This allows the user to specify arguments with placeholders that can be replaced by variables dynamically. + +The benefit of using parametrized arguments is that they can be reused for different purposes. 
+ +Let's load some data and connect to the in-memory DuckDB instance: + +```{code-cell} ipython3 +%load_ext sql +%sql duckdb:// +%config SqlMagic.displaylimit = 3 +``` + +```{code-cell} ipython3 +filename = "penguins.csv" +``` + + +```{code-cell} ipython3 +from pathlib import Path +from urllib.request import urlretrieve + +if not Path("penguins.csv").is_file(): + urlretrieve( + "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv", + filename, + ) +``` + +Now let's create a snippet from the data by declaring a `table` variable and using it in the `--save` argument. + ++++ + +### Create a snippet + +```{code-cell} ipython3 +table = "penguins_data" +``` + +```{code-cell} ipython3 +%%sql --save {{table}} +SELECT * +FROM penguins.csv +``` + +```{code-cell} ipython3 +snippet = %sqlcmd snippets {{table}} +print(snippet) +``` + + +### Plot a boxplot + +Now, let's declare a variable `column` and draw a boxplot of the data. + +```{code-cell} ipython3 +column = "body_mass_g" +``` + +```{code-cell} ipython3 +%sqlplot boxplot --table {{table}} --column {{column}} +``` + +### Profile and Explore + +We can use the `filename` variable to profile and explore the data as well: + +```{code-cell} ipython3 +%sqlcmd profile --table {{filename}} +``` + +```{code-cell} ipython3 +%sqlcmd explore --table {{filename}} +``` + +### Run some tests + +```{code-cell} ipython3 +%sqlcmd test --table {{table}} --column {{column}} --greater 3500 +``` + diff --git a/doc/user-guide/data-profiling.md b/doc/user-guide/data-profiling.md index 83ebd4d4c..8e9bd282f 100644 --- a/doc/user-guide/data-profiling.md +++ b/doc/user-guide/data-profiling.md @@ -119,3 +119,23 @@ Let's profile `my_numbers` of `b_schema` ```{code-cell} ipython3 %sqlcmd profile --table trips --schema some_schema ``` + +### Parametrizing arguments + +JupySQL supports variable expansion of arguments in the form of `{{variable}}`. Let's see an example using `table`, `schema` and `output`. 
+ +```{code-cell} ipython3 +table = "trips" +schema = "some_schema" +output = "my-report.html" +``` + +```{code-cell} ipython3 +%sqlcmd profile --table {{table}} --schema {{schema}} --output {{output}} +``` + +```{code-cell} ipython3 +from IPython.display import HTML + +HTML(output) +``` diff --git a/doc/user-guide/table_explorer.ipynb b/doc/user-guide/table_explorer.ipynb index 0b40eb73f..26840e320 100644 --- a/doc/user-guide/table_explorer.ipynb +++ b/doc/user-guide/table_explorer.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "67e9f89e", "metadata": {}, "outputs": [], @@ -50,10 +50,18 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "2708d4a7", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], "source": [ "%pip install jupysql --upgrade --quiet" ] @@ -73,10 +81,119 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "dbe40317", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "Tip: You may define configurations in /Users/neelashasen/Dev/jupysql_master/jupysql/pyproject.toml or /Users/neelashasen/.jupysql/config. " + ], + "text/plain": [ + "Tip: You may define configurations in /Users/neelashasen/Dev/jupysql_master/jupysql/pyproject.toml or /Users/neelashasen/.jupysql/config. " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Please review our configuration guideline." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Loading configurations from /Users/neelashasen/.jupysql/config." + ], + "text/plain": [ + "Loading configurations from /Users/neelashasen/.jupysql/config." 
+ ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Settings changed:" + ], + "text/plain": [ + "Settings changed:" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
Configvalue
feedbackTrue
autopandasTrue
" + ], + "text/plain": [ + "\n", + "+------------+-------+\n", + "| Config | value |\n", + "+------------+-------+\n", + "| feedback | True |\n", + "| autopandas | True |\n", + "+------------+-------+" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Connecting to 'default'" + ], + "text/plain": [ + "Connecting to 'default'" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Connecting and switching to connection 'duckdb://'" + ], + "text/plain": [ + "Connecting and switching to connection 'duckdb://'" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "%load_ext sql\n", "%sql duckdb://" @@ -95,13 +212,1491 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "7e6c6c7d", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "%sqlcmd explore --table \"yellow_tripdata_2021.parquet\"" ] + }, + { + "cell_type": "markdown", + "id": "0c008e2e-3a38-47ef-9073-3b0379a5b13e", + "metadata": {}, + "source": [ + "## Parametrizing arguments\n", + "\n", + "JupySQL supports variable expansion of arguments in the form of `{{variable}}`. This allows the user to specify arguments with placeholders that can be replaced by variables dynamically." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8c603c30-261f-4beb-863b-493d9d441625", + "metadata": {}, + "outputs": [], + "source": [ + "table_name = \"yellow_tripdata_2021.parquet\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1d2768ba-d2ad-4e09-842e-21a06609e94d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + "
\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%sqlcmd explore --table {{table_name}}" + ] } ], "metadata": { @@ -113,6 +1708,18 @@ "language": "python", "name": "python3" }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, "myst": { "html_meta": { "description lang=en": "Templatize SQL queries in Jupyter via JupySQL", diff --git a/doc/user-guide/tables-columns.md b/doc/user-guide/tables-columns.md index 42e60b060..83f001bd5 100644 --- a/doc/user-guide/tables-columns.md +++ b/doc/user-guide/tables-columns.md @@ -68,6 +68,7 @@ Pass `--schema/-s` to get tables in a different schema: +++ + ## List columns Use `%sqlcmd columns --table/-t` to get the columns for the given table. @@ -104,3 +105,14 @@ Get the columns for the table in the newly created schema: ```{code-cell} ipython3 %sqlcmd columns --table numbers --schema some_schema ``` + +JupySQL supports variable expansion of arguments in the form of `{{variable}}`. 
Let's see an example of parametrizing `table` and `schema`: + +```{code-cell} ipython3 +table = "numbers" +schema = "some_schema" +``` + +```{code-cell} ipython3 +%sqlcmd columns --table {{table}} --schema {{schema}} +``` \ No newline at end of file diff --git a/src/sql/cmd/columns.py b/src/sql/cmd/columns.py index 7d1ac3e59..d15539ab2 100644 --- a/src/sql/cmd/columns.py +++ b/src/sql/cmd/columns.py @@ -1,20 +1,25 @@ from sql import inspect from sql.util import sanitize_identifier from sql.cmd.cmd_utils import CmdParser +from sql.util import expand_args, is_rendering_required -def columns(others): +def columns(others, user_ns): """ Implementation of `%sqlcmd columns` This function takes in a string containing command line arguments, parses them to extract the name of the table and the schema, and returns - a list of columns for the specified table. + a list of columns for the specified table. It also uses the kernel + namespace for expanding arguments declared as variables. Parameters ---------- others : str, A string containing the command line arguments. 
+ user_ns : dict, + User namespace of IPython kernel + Returns ------- columns: list @@ -26,4 +31,8 @@ def columns(others): parser.add_argument("-s", "--schema", type=str, help="Schema name", required=False) args = parser.parse_args(others) + + if is_rendering_required(" ".join(others)): + expand_args(args, user_ns) + return inspect.get_columns(name=sanitize_identifier(args.table), schema=args.schema) diff --git a/src/sql/cmd/explore.py b/src/sql/cmd/explore.py index 3a089d1df..2d144d517 100644 --- a/src/sql/cmd/explore.py +++ b/src/sql/cmd/explore.py @@ -1,23 +1,31 @@ from sql.widgets import TableWidget from sql.cmd.cmd_utils import CmdParser +from sql.util import expand_args, is_rendering_required -def explore(others): +def explore(others, user_ns): """ Implementation of `%sqlcmd explore` This function takes in a string containing command line arguments, parses them to extract the name of the table, and displays an interactive - widget for exploring the contents of the specified table. + widget for exploring the contents of the specified table. It also uses the + kernel namespace for expanding arguments declared as variables. Parameters ---------- others : str, A string containing the command line arguments. 
+ user_ns : dict, + User namespace of IPython kernel + """ parser = CmdParser() parser.add_argument("-t", "--table", type=str, help="Table name", required=True) parser.add_argument("-s", "--schema", type=str, help="Schema name", required=False) args = parser.parse_args(others) + if is_rendering_required(" ".join(others)): + expand_args(args, user_ns) + table_widget = TableWidget(args.table, args.schema) return table_widget diff --git a/src/sql/cmd/profile.py b/src/sql/cmd/profile.py index a369145dd..600de1072 100644 --- a/src/sql/cmd/profile.py +++ b/src/sql/cmd/profile.py @@ -1,14 +1,16 @@ from sql import inspect from sql.cmd.cmd_utils import CmdParser +from sql.util import expand_args, is_rendering_required -def profile(others): +def profile(others, user_ns): """ Implementation of `%sqlcmd profile` This function takes in a string containing command line arguments, parses them to extract the name of the table, the schema, and the output location. It then retrieves statistical information about the specified table and either returns the report or writes it to the specified location. + It also uses the kernel namespace for expanding arguments declared as variables. Parameters @@ -16,6 +18,9 @@ def profile(others): others : str, A string containing the command line arguments. 
+ user_ns : dict, + User namespace of IPython kernel + Returns ------- report: PrettyTable @@ -31,6 +36,8 @@ def profile(others): ) args = parser.parse_args(others) + if is_rendering_required(" ".join(others)): + expand_args(args, user_ns) report = inspect.get_table_statistics(schema=args.schema, name=args.table) diff --git a/src/sql/cmd/snippets.py b/src/sql/cmd/snippets.py index 88d5bf8d0..496aa6df9 100644 --- a/src/sql/cmd/snippets.py +++ b/src/sql/cmd/snippets.py @@ -3,6 +3,7 @@ from sql.exceptions import UsageError from sql.cmd.cmd_utils import CmdParser from sql.display import Table, Message +from sql.util import expand_args, is_rendering_required, render_string_using_namespace def _modify_display_msg(key, remaining_keys, dependent_keys=None): @@ -32,12 +33,13 @@ def _modify_display_msg(key, remaining_keys, dependent_keys=None): return msg -def snippets(others): +def snippets(others, user_ns): """ Implementation of `%sqlcmd snippets` This function handles all the arguments related to %sqlcmd snippets, namely listing stored snippets, and delete/ force delete/ force delete a snippet and - all its dependent snippets. + all its dependent snippets. It also uses the kernel namespace for expanding + arguments declared as variables. Parameters @@ -45,6 +47,8 @@ def snippets(others): others : str, A string containing the command line arguments. 
+ user_ns : dict, + User namespace of IPython kernel """ parser = CmdParser() parser.add_argument( @@ -66,6 +70,7 @@ def snippets(others): ) all_snippets = store.get_all_keys() if len(others) == 1: + others[0] = render_string_using_namespace(others[0], user_ns) if others[0] in all_snippets: return str(store.store[others[0]]) @@ -79,6 +84,9 @@ def snippets(others): raise UsageError(err_msg) args = parser.parse_args(others) + if is_rendering_required(" ".join(others)): + expand_args(args, user_ns) + SNIPPET_ARGS = [args.delete, args.delete_force, args.delete_force_all] if SNIPPET_ARGS.count(None) == len(SNIPPET_ARGS): if len(all_snippets) == 0: diff --git a/src/sql/cmd/tables.py b/src/sql/cmd/tables.py index 7956b02a5..d0b940a40 100644 --- a/src/sql/cmd/tables.py +++ b/src/sql/cmd/tables.py @@ -1,20 +1,25 @@ from sql import inspect from sql.cmd.cmd_utils import CmdParser +from sql.util import expand_args, is_rendering_required -def tables(others): +def tables(others, user_ns): """ Implementation of `%sqlcmd tables` This function takes in a string containing command line arguments, parses them to extract the schema name, and returns a list of table names present in the specified schema or in the default schema if none is specified. + It also uses the kernel namespace for expanding arguments declared as variables. Parameters ---------- others : str, A string containing the command line arguments. 
+ user_ns : dict, + User namespace of IPython kernel + Returns ------- table_names: list @@ -26,5 +31,7 @@ def tables(others): parser.add_argument("-s", "--schema", type=str, help="Schema name", required=False) args = parser.parse_args(others) + if is_rendering_required(" ".join(others)): + expand_args(args, user_ns) return inspect.get_table_names(schema=args.schema) diff --git a/src/sql/cmd/test.py b/src/sql/cmd/test.py index 0390ed952..bde8495b3 100644 --- a/src/sql/cmd/test.py +++ b/src/sql/cmd/test.py @@ -3,6 +3,7 @@ from sqlglot import select, condition from prettytable import PrettyTable from sql.cmd.cmd_utils import CmdParser +from sql.util import expand_args, is_rendering_required def return_test_results(args, conn, query): @@ -77,19 +78,24 @@ def run_each_individually(args, conn): return storage -def test(others): +def test(others, user_ns): """ Implementation of `%sqlcmd test` This function takes in a string containing command line arguments, parses them to extract the table name, column name, and conditions to return if those conditions are satisfied in that table + It also uses the kernel namespace for expanding arguments declared as + variables. Parameters ---------- others : str, A string containing the command line arguments. 
+ user_ns : dict, + User namespace of IPython kernel + Returns ------- result: bool @@ -141,6 +147,8 @@ def test(others): ) args = parser.parse_args(others) + if is_rendering_required(" ".join(others)): + expand_args(args, user_ns) COMPARATOR_ARGS = [ args.greater, diff --git a/src/sql/command.py b/src/sql/command.py index 675420d74..25fb554d7 100644 --- a/src/sql/command.py +++ b/src/sql/command.py @@ -90,6 +90,9 @@ def __init__(self, magic, user_ns, line, cell) -> None: self.parsed["connection"] = self.args.line[0] if self.args.with_: + self.args.with_ = [ + Template(item).render(user_ns) for item in self.args.with_ + ] final = store.render(self.parsed["sql"], with_=self.args.with_) self.parsed["sql"] = str(final) diff --git a/src/sql/magic.py b/src/sql/magic.py index 17a2a2a49..464bf298e 100644 --- a/src/sql/magic.py +++ b/src/sql/magic.py @@ -40,7 +40,6 @@ from sql.magic_cmd import SqlCmdMagic from sql._patch import patch_ipython_usage_error from sql import util -from sql.util import pretty_print from sql.error_handler import handle_exception from sql._current import _set_sql_magic @@ -409,6 +408,9 @@ def interactive_execute_wrapper(**kwargs): args = command.args + if util.is_rendering_required(line): + util.expand_args(args, user_ns) + if args.section and args.alias: raise exceptions.UsageError( "Cannot use --section with --alias since the section name " @@ -435,7 +437,7 @@ def interactive_execute_wrapper(**kwargs): command.set_sql_with(with_) display.message( f"Generating CTE with stored snippets: \ -{pretty_print(with_)}" +{util.pretty_print(with_)}" ) else: with_ = None diff --git a/src/sql/magic_cmd.py b/src/sql/magic_cmd.py index 970067e50..da6d370e2 100644 --- a/src/sql/magic_cmd.py +++ b/src/sql/magic_cmd.py @@ -2,7 +2,7 @@ import argparse import shlex -from IPython.core.magic import Magics, line_magic, magics_class +from IPython.core.magic import Magics, line_magic, magics_class, no_var_expand from IPython.core.magic_arguments import argument, 
magic_arguments from sql.inspect import support_only_sql_alchemy_connection from sql.cmd.tables import tables @@ -35,6 +35,7 @@ def error(self, message): class SqlCmdMagic(Magics, Configurable): """%sqlcmd magic""" + @no_var_expand @line_magic("sqlcmd") @magic_arguments() @argument("line", type=str, help="Command name") @@ -128,5 +129,7 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None): } cmd = router.get(cmd_name) - if cmd: + if cmd_name == "connect": return cmd(others) + else: + return cmd(others, self.shell.user_ns.copy()) diff --git a/src/sql/magic_plot.py b/src/sql/magic_plot.py index c4dbd534b..92c1aecad 100644 --- a/src/sql/magic_plot.py +++ b/src/sql/magic_plot.py @@ -1,4 +1,4 @@ -from IPython.core.magic import Magics, line_magic, magics_class +from IPython.core.magic import Magics, line_magic, magics_class, no_var_expand from IPython.core.magic_arguments import argument, magic_arguments from ploomber_core.exceptions import modify_exceptions @@ -21,6 +21,7 @@ class SqlPlotMagic(Magics, Configurable): """%sqlplot magic""" + @no_var_expand @line_magic("sqlplot") @magic_arguments() @argument( @@ -83,6 +84,9 @@ def execute(self, line="", cell="", local_ns=None): cmd = SQLPlotCommand(self, line) + if util.is_rendering_required(line): + util.expand_args(cmd.args, self.shell.user_ns.copy()) + if len(cmd.args.column) == 1: column = cmd.args.column[0] else: diff --git a/src/sql/util.py b/src/sql/util.py index 4f7fd6d4f..b58419ec6 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -11,6 +11,8 @@ from os.path import isfile import re +from jinja2 import Template + try: import toml @@ -574,3 +576,59 @@ def enclose_table_with_double_quotations(table, conn): _table = _table.replace('"', "`") return _table + + +def is_rendering_required(line): + """Function to check possibility of line + text containing expandable arguments""" + + return "{{" in line and "}}" in line + + +def render_string_using_namespace(value, user_ns): + """ + Function to 
substitute command line arguments + with variables defined by user in the IPython + kernel. + + Parameters + ---------- + value : str, + text to be rendered + + user_ns : dict, + User namespace of IPython kernel + """ + + if isinstance(value, str) and value.startswith("{{") and value.endswith("}}"): + return Template(value).render(user_ns) + return value + + +def expand_args(args, user_ns): + """ + Function to substitute command line arguments + with variables defined by user in the IPython + kernel. + + Parameters + ---------- + args : argparse.Namespace, + object to hold the command line arguments. + + user_ns : dict, + User namespace of IPython kernel + """ + + for attribute in vars(args): + value = getattr(args, attribute) + if value: + if isinstance(value, list): + substituted_value = [] + for item in value: + rendered_value = render_string_using_namespace(item, user_ns) + substituted_value.append(rendered_value) + setattr(args, attribute, substituted_value) + else: + rendered_value = render_string_using_namespace(value, user_ns) + setattr(args, attribute, rendered_value) diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 175b450a7..e241b9767 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -2635,3 +2635,92 @@ def test_table_does_not_exist_with_snippet_error( def test_negative_operations_query(ip, query, expected): result = ip.run_cell(query).result assert list(result.dict().values())[-1] == expected + + +def test_bracket_var_substitution_save(ip): + ip.user_global_ns["col"] = "first_name" + ip.user_global_ns["snippet"] = "mysnippet" + ip.run_cell( + "%sql --save {{snippet}} SELECT * FROM author WHERE {{col}} = 'William' " + ) + out = ip.run_cell("%sql SELECT * FROM {{snippet}}").result + assert out[0] == ( + "William", + "Shakespeare", + 1616, + ) + + +def test_var_substitution_save_with(ip): + ip.user_global_ns["col"] = "first_name" + ip.user_global_ns["snippet_one"] = "william" + ip.user_global_ns["snippet_two"] = 
"bertold" + ip.run_cell( + "%sql --save {{snippet_one}} SELECT * FROM author WHERE {{col}} = 'William' " + ) + ip.run_cell( + "%sql --save {{snippet_two}} SELECT * FROM author WHERE {{col}} = 'Bertold' " + ) + out = ip.run_cell( + """%%sql --with {{snippet_one}} --with {{snippet_two}} +SELECT * FROM {{snippet_one}} +UNION +SELECT * FROM {{snippet_two}} +""" + ).result + + assert out[1] == ( + "William", + "Shakespeare", + 1616, + ) + assert out[0] == ( + "Bertold", + "Brecht", + 1956, + ) + + +def test_var_substitution_alias(clean_conns, ip_empty, tmp_empty): + ip_empty.user_global_ns["alias"] = "one" + ip_empty.run_cell("%sql sqlite:///one.db --alias {{alias}}") + assert {"one"} == set(ConnectionManager.connections) + + +@pytest.mark.parametrize( + "close_cell", + [ + "%sql -x {{alias}}", + "%sql --close {{alias}}", + ], +) +def test_var_substitution_close_connection_with_alias(ip, tmp_empty, close_cell): + ip.user_global_ns["alias"] = "one" + process = psutil.Process() + + ip.run_cell("%sql sqlite:///one.db --alias {{alias}}") + + assert {Path(f.path).name for f in process.open_files()} >= {"one.db"} + + ip.run_cell(close_cell) + + assert "sqlite:///one.db" not in ConnectionManager.connections + assert "first" not in ConnectionManager.connections + assert "one.db" not in {Path(f.path).name for f in process.open_files()} + + +def test_var_substitution_section(ip_empty, tmp_empty): + Path("connections.ini").write_text( + """ +[duck] +drivername = duckdb +""" + ) + ip_empty.user_global_ns["section"] = "duck" + + ip_empty.run_cell("%config SqlMagic.dsn_filename = 'connections.ini'") + + ip_empty.run_cell("%sql --section {{section}}") + + conns = ConnectionManager.connections + assert conns == {"duck": ANY} diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index ed4a244ee..c2550b33b 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -95,6 +95,24 @@ def sample_schema_with_table(ip_empty): ) +@pytest.mark.parametrize( + 
"cmd, cols, table_name", + [ + [ + "%sqlcmd columns -t {{table}}", + ["first_name", "last_name", "year_of_death"], + "author", + ], + ["%sqlcmd columns -t {{table}}", ["first", "second"], "table with spaces"], + ["%sqlcmd columns -t {{table}}", ["first", "second"], "table with spaces"], + ], +) +def test_columns_with_variable_substitution(ip, cmd, cols, table_name): + ip.user_global_ns["table"] = table_name + out = ip.run_cell(cmd).result._repr_html_() + assert all(col in out for col in cols) + + @pytest.mark.parametrize( "cell, error_message", [ @@ -205,6 +223,23 @@ def test_tables_with_schema(ip, tmp_empty): assert "numbers" in out +def test_tables_with_schema_variable_substitution(ip, tmp_empty): + conn = SQLAlchemyConnection(engine=create_engine("sqlite:///my.db")) + conn.execute("CREATE TABLE numbers (some_number FLOAT)") + + ip.user_global_ns["schema"] = "some_schema" + + ip.run_cell( + """%%sql +ATTACH DATABASE 'my.db' AS {{schema}} +""" + ) + + out = ip.run_cell("%sqlcmd tables --schema {{schema}}").result._repr_html_() + + assert "numbers" in out + + @pytest.mark.parametrize( "cmd, cols", [ @@ -242,6 +277,28 @@ def test_columns_with_schema(ip, tmp_empty, arguments): assert "some_number" in out +@pytest.mark.parametrize( + "arguments", + ["--table {{table}} --schema {{schema}}", "--table {{schema}}.{{table}}"], +) +def test_columns_with_schema_variable_substitution(ip, tmp_empty, arguments): + conn = SQLAlchemyConnection(engine=create_engine("sqlite:///my.db")) + conn.execute("CREATE TABLE numbers (some_number FLOAT)") + + ip.user_global_ns["table"] = "numbers" + ip.user_global_ns["schema"] = "some_schema" + + ip.run_cell( + """%%sql +ATTACH DATABASE 'my.db' AS {{schema}} +""" + ) + + out = ip.run_cell(f"%sqlcmd columns {arguments}").result._repr_html_() + + assert "some_number" in out + + @pytest.mark.parametrize( "conn", [ @@ -299,6 +356,65 @@ def test_table_profile(ip_with_connections, tmp_empty, conn): assert "position: sticky;" in out._table_html 
+@pytest.mark.parametrize( + "conn", + [ + ("sqlite_sqlalchemy"), + ("sqlite_dbapi"), + ], +) +def test_table_profile_with_substitution(ip_with_connections, tmp_empty, conn): + ip_with_connections.run_cell(f"%sql {conn}") + ip_with_connections.run_cell( + """ + %%sql + CREATE TABLE numbers (rating float, price float, number int, word varchar(50)); + INSERT INTO numbers VALUES (14.44, 2.48, 82, 'a'); + INSERT INTO numbers VALUES (13.13, 1.50, 93, 'b'); + INSERT INTO numbers VALUES (12.59, 0.20, 98, 'a'); + INSERT INTO numbers VALUES (11.54, 0.41, 89, 'a'); + INSERT INTO numbers VALUES (10.532, 0.1, 88, 'c'); + INSERT INTO numbers VALUES (11.5, 0.2, 84, ' '); + INSERT INTO numbers VALUES (11.1, 0.3, 90, 'a'); + INSERT INTO numbers VALUES (12.9, 0.31, 86, ''); + """ + ) + + expected = { + "count": [8, 8, 8, 8], + "mean": ["12.2165", "0.6875", "88.7500", math.nan], + "min": [10.532, 0.1, 82, math.nan], + "max": [14.44, 2.48, 98, math.nan], + "unique": [8, 7, 8, 5], + "freq": [math.nan, math.nan, math.nan, 4], + "top": [math.nan, math.nan, math.nan, "a"], + } + + ip_with_connections.user_global_ns["table"] = "numbers" + + out = ip_with_connections.run_cell("%sqlcmd profile -t {{table}}").result + + stats_table = out._table + + assert len(stats_table.rows) == len(expected) + + for row in stats_table: + profile_metric = _get_row_string(row, " ") + rating = _get_row_string(row, "rating") + price = _get_row_string(row, "price") + number = _get_row_string(row, "number") + word = _get_row_string(row, "word") + + assert profile_metric in expected + assert rating == str(expected[profile_metric][0]) + assert price == str(expected[profile_metric][1]) + assert number == str(expected[profile_metric][2]) + assert word == str(expected[profile_metric][3]) + + # Test sticky column style was injected + assert "position: sticky;" in out._table_html + + @pytest.mark.parametrize( "conn", [ @@ -410,6 +526,59 @@ def test_table_schema_profile(ip, tmp_empty, arguments): assert cell == 
str(expected[profile_metric][0]) +@pytest.mark.parametrize( + "arguments", + ["--table {{table}} --schema {{schema}}", "--table {{schema}}.{{table}}"], +) +def test_table_schema_profile_with_substitution(ip, tmp_empty, arguments): + ip.run_cell("%sql sqlite:///a.db") + ip.run_cell("%sql CREATE TABLE t (n FLOAT)") + ip.run_cell("%sql INSERT INTO t VALUES (1)") + ip.run_cell("%sql INSERT INTO t VALUES (2)") + ip.run_cell("%sql INSERT INTO t VALUES (3)") + ip.run_cell("%sql --close sqlite:///a.db") + + ip.run_cell("%sql sqlite:///b.db") + ip.run_cell("%sql CREATE TABLE t (n FLOAT)") + ip.run_cell("%sql INSERT INTO t VALUES (11)") + ip.run_cell("%sql INSERT INTO t VALUES (22)") + ip.run_cell("%sql INSERT INTO t VALUES (33)") + ip.run_cell("%sql --close sqlite:///b.db") + + ip.run_cell( + """ + %%sql sqlite:// + ATTACH DATABASE 'a.db' AS a_schema; + ATTACH DATABASE 'b.db' AS b_schema; + """ + ) + + expected = { + "count": ["3"], + "mean": ["22.0000"], + "min": ["11.0"], + "max": ["33.0"], + "std": ["11.0000"], + "unique": ["3"], + "freq": [math.nan], + "top": [math.nan], + } + + ip.user_global_ns["table"] = "t" + ip.user_global_ns["schema"] = "b_schema" + out = ip.run_cell(f"%sqlcmd profile {arguments}").result + + stats_table = out._table + + for row in stats_table: + profile_metric = _get_row_string(row, " ") + + cell = row.get_string(fields=["n"], border=False, header=False).strip() + + if profile_metric in expected: + assert cell == str(expected[profile_metric][0]) + + @pytest.mark.parametrize( "arguments", ["--table sample_table --schema test_schema", "--table test_schema.sample_table"], @@ -548,6 +717,35 @@ def test_table_profile_store(ip_with_connections, tmp_empty, conn, report_fname) assert report.is_file() +@pytest.mark.parametrize( + "conn, report_fname", + [ + ("sqlite_sqlalchemy", "test_report.html"), + ("sqlite_dbapi", "test_report_dbapi.html"), + ], +) +def test_table_profile_store_with_substitution( + ip_with_connections, tmp_empty, conn, report_fname 
+): + ip_with_connections.run_cell( + f""" + %%sql {conn} + CREATE TABLE test_store (rating, price, number, symbol); + INSERT INTO test_store VALUES (14.44, 2.48, 82, 'a'); + INSERT INTO test_store VALUES (13.13, 1.50, 93, 'b'); + INSERT INTO test_store VALUES (12.59, 0.20, 98, 'a'); + INSERT INTO test_store VALUES (11.54, 0.41, 89, 'a'); + """ + ) + ip_with_connections.user_global_ns["table"] = "test_store" + ip_with_connections.user_global_ns["output"] = report_fname + + ip_with_connections.run_cell("%sqlcmd profile -t {{table}} --output {{output}}") + + report = Path(report_fname) + assert report.is_file() + + @pytest.mark.parametrize( "cell, error_message", [ @@ -609,6 +807,27 @@ def test_passing_test_with_schema(ip_empty, sample_schema_with_table, arguments) assert out is True +@pytest.mark.parametrize( + "arguments", + ["--table {{schema}}.{{table}}", "--table {{table}} --schema {{schema}}"], +) +def test_test_with_schema_variable_substitution( + ip_empty, sample_schema_with_table, arguments +): + ip_empty.user_global_ns["table"] = "table1" + ip_empty.user_global_ns["schema"] = "schema1" + out = ip_empty.run_cell(f"%sqlcmd test {arguments} --column x --less-than 3").result + assert out is True + + +def test_test_column_variable_substitution(ip_empty, sample_schema_with_table): + ip_empty.user_global_ns["column"] = "x" + out = ip_empty.run_cell( + "%sqlcmd test --table schema1.table1 --column {{column}} --less-than 3" + ).result + assert out is True + + @pytest.mark.parametrize( "cmds, result", [ @@ -712,6 +931,15 @@ def test_delete_saved_key(ip_snippets, arg): assert "high_price_a" not in stored_snippets +def test_delete_saved_key_with_variable_substitution(ip_snippets): + ip_snippets.user_global_ns["snippet_name"] = "high_price_a" + out = ip_snippets.run_cell("%sqlcmd snippets --delete {{snippet_name}}").result + assert "high_price_a has been deleted.\n" in out + stored_snippets = out[out.find("Stored snippets") + len("Stored snippets: ") :] + assert 
"high_price, high_price_b" in stored_snippets + assert "high_price_a" not in stored_snippets + + @pytest.mark.parametrize("arg", ["--delete-force", "-D"]) def test_force_delete(ip_snippets, arg): out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price").result @@ -724,6 +952,20 @@ def test_force_delete(ip_snippets, arg): assert "high_price," not in stored_snippets +def test_force_delete_with_variable_substitution(ip_snippets): + ip_snippets.user_global_ns["snippet_name"] = "high_price" + out = ip_snippets.run_cell( + "%sqlcmd snippets --delete-force {{snippet_name}}" + ).result + assert ( + "high_price has been deleted.\nhigh_price_a, " + "high_price_b depend on high_price\n" in out + ) + stored_snippets = out[out.find("Stored snippets") + len("Stored snippets: ") :] + assert "high_price_a, high_price_b" in stored_snippets + assert "high_price," not in stored_snippets + + @pytest.mark.parametrize("arg", ["--delete-force-all", "-A"]) def test_force_delete_all(ip_snippets, arg): out = ip_snippets.run_cell(f"%sqlcmd snippets {arg} high_price").result @@ -731,6 +973,15 @@ def test_force_delete_all(ip_snippets, arg): assert "There are no stored snippets" in out +def test_force_delete_all_with_variable_substitution(ip_snippets): + ip_snippets.user_global_ns["snippet_name"] = "high_price" + out = ip_snippets.run_cell( + "%sqlcmd snippets --delete-force-all {{snippet_name}}" + ).result + assert "high_price_a, high_price_b, high_price has been deleted" in out + assert "There are no stored snippets" in out + + @pytest.mark.parametrize("arg", ["--delete-force-all", "-A"]) def test_force_delete_all_child_query(ip_snippets, arg): ip_snippets.run_cell( @@ -796,6 +1047,18 @@ def test_delete_snippet_when_dependency_force_deleted(ip_snippets, arg): assert "high_price_a has been deleted.\nStored snippets: high_price_b" in out +def test_view_snippet_variable_substitution(ip_snippets): + ip_snippets.user_global_ns["snippet_name"] = "test_snippet" + ip_snippets.run_cell( + 
"""%%sql --save {{snippet_name}} --no-execute +SELECT * FROM "test_store" WHERE price >= 1.50 +""" + ) + + out = ip_snippets.run_cell("%sqlcmd snippets {{snippet_name}}").result + assert 'SELECT * FROM "test_store" WHERE price >= 1.50' in out + + @pytest.mark.parametrize( "arguments", ["--table schema1.table1", "--table table1 --schema schema1"] ) @@ -807,6 +1070,21 @@ def test_explore_with_schema(ip_empty, sample_schema_with_table, arguments): assert [row in out._repr_html_() for row in expected_rows] +@pytest.mark.parametrize( + "arguments", + ["--table {{schema}}.{{table}}", "--table {{table}} --schema {{schema}}"], +) +def test_explore_with_schema_variable_substitution( + ip_empty, sample_schema_with_table, arguments +): + expected_rows = ['"x": 1', '"y": "one"', '"x": 2', '"y": "two"'] + ip_empty.user_global_ns["table"] = "table1" + ip_empty.user_global_ns["schema"] = "schema1" + out = ip_empty.run_cell(f"%sqlcmd explore {arguments}").result + assert isinstance(out, TableWidget) + assert [row in out._repr_html_() for row in expected_rows] + + @pytest.mark.parametrize( "file_content, stored_conns", [ diff --git a/src/tests/test_magic_plot.py b/src/tests/test_magic_plot.py index cd8aed92f..af3f1745d 100644 --- a/src/tests/test_magic_plot.py +++ b/src/tests/test_magic_plot.py @@ -798,3 +798,240 @@ def test_sqlplot_missing_table(ip_snippets, capsys): ip_snippets.run_cell("%sqlplot boxplot --table missing --column x") assert MISSING_TABLE_ERROR_MSG.strip() in str(excinfo.value).strip() + + +@_cleanup_cm() +@pytest.mark.parametrize( + "cell", + [ + "%sqlplot boxplot --table {{table}} --column body_mass_g", + "%sqlplot boxplot --table penguins.csv --column {{column}}", + "%sqlplot boxplot --table {{table}} --column {{column}}", + ], + ids=["table", "column", "table-column"], +) +@image_comparison(baseline_images=["boxplot"], extensions=["png"], remove_text=True) +def test_boxplot_with_variable_substitution(load_penguin, ip, cell): + ip.user_global_ns["table"] = 
"penguins.csv" + ip.user_global_ns["column"] = "body_mass_g" + ip.run_cell(cell) + + +@_cleanup_cm() +@pytest.mark.parametrize( + "cell", + [ + "%sqlplot histogram --table {{table}} --column body_mass_g", + "%sqlplot histogram --table penguins.csv --column {{column}}", + "%sqlplot histogram --table {{table}} --column {{column}}", + ], + ids=["table", "column", "table-column"], +) +@image_comparison(baseline_images=["hist"], extensions=["png"], remove_text=True) +def test_hist_with_variable_substitution(load_penguin, ip, cell): + ip.user_global_ns["table"] = "penguins.csv" + ip.user_global_ns["column"] = "body_mass_g" + ip.run_cell(cell) + + +@_cleanup_cm() +@pytest.mark.parametrize( + "cell", + [ + "%sqlplot bar --table {{table}} --column x", + "%sqlplot bar --table data_one.csv --column {{column}}", + "%sqlplot bar --table {{table}} --column {{column}}", + ], + ids=["table", "column", "table-column"], +) +@image_comparison(baseline_images=["bar_one_col"], extensions=["png"], remove_text=True) +def test_bar_with_variable_substitution(load_data_one_col, ip, cell): + ip.user_global_ns["table"] = "data_one.csv" + ip.user_global_ns["column"] = "x" + ip.run_cell(cell) + + +@_cleanup_cm() +@pytest.mark.parametrize( + "cell", + [ + "%sqlplot pie --table {{table}} --column x", + "%sqlplot pie --table data_one.csv --column {{column}}", + "%sqlplot pie --table {{table}} --column {{column}}", + ], + ids=["table", "column", "table-column"], +) +@image_comparison(baseline_images=["pie_one_col"], extensions=["png"], remove_text=True) +def test_pie_with_variable_substitution(load_data_one_col, ip, cell): + ip.user_global_ns["table"] = "data_one.csv" + ip.user_global_ns["column"] = "x" + ip.run_cell(cell) + + +@_cleanup_cm() +@pytest.mark.parametrize( + "cell", + [ + "%sqlplot histogram --table {{table}} --column {{column}}", + "%sqlplot hist --table {{table}} --column {{column}}", + "%sqlplot boxplot --table {{table}} --column {{column}}", + "%sqlplot box --table {{table}} 
--column {{column}}", + "%sqlplot boxplot --table {{table}} --column {{column}} --orient {{orient}}", + "%sqlplot boxplot --table {{subset_table}} --column {{column}}", + "%sqlplot boxplot --table {{subset_table}} --column " + "{{column}} --with {{subset_table}}", + "%sqlplot boxplot -t {{subset_table}} -c {{column}} -w {{subset_table}} -o h", + "%sqlplot boxplot --table {{nas_table}} --column {{column}}", + "%sqlplot bar -t {{table}} -c {{column}}", + "%sqlplot bar --table {{subset_table}} --column {{column}}", + "%sqlplot bar --table {{subset_table}} --column {{column}} " + "--with {{subset_table}}", + "%sqlplot bar -t {{table}} -c {{column}} -S", + "%sqlplot bar -t {{table}} -c {{column}} -o h", + "%sqlplot bar -t {{table}} -c {{column}} y", + "%sqlplot pie -t {{table}} -c {{column}}", + "%sqlplot pie --table {{subset_table}} --column {{column}}", + "%sqlplot pie --table {{subset_table}} --column {{column}} " + "--with {{subset_table}}", + "%sqlplot pie -t {{table}} -c {{column}} -S", + "%sqlplot pie -t {{table}} -c {{column}} y", + '%sqlplot boxplot --table {{spaces_table}} --column "some column"', + '%sqlplot histogram --table {{spaces_table}} --column "some column"', + '%sqlplot bar --table {{spaces_table}} --column "some column"', + '%sqlplot pie --table {{spaces_table}} --column "some column"', + ], + ids=[ + "histogram", + "hist", + "boxplot", + "boxplot-with", + "box", + "boxplot-horizontal", + "boxplot-with", + "boxplot-shortcuts", + "boxplot-nas", + "bar-1-col", + "bar-subset", + "bar-subset-with", + "bar-1-col-show_num", + "bar-1-col-horizontal", + "bar-2-col", + "pie-1-col", + "pie-subset", + "pie-subset-with", + "pie-1-col-show_num", + "pie-2-col", + "boxplot-column-name-with-spaces", + "histogram-column-name-with-spaces", + "bar-column-name-with-spaces", + "pie-column-name-with-spaces", + ], +) +def test_sqlplot_with_variable_substitution(tmp_empty, ip, cell): + # clean current Axes + ip.user_global_ns["table"] = "data.csv" + 
ip.user_global_ns["column"] = "x" + ip.user_global_ns["subset_table"] = "subset" + ip.user_global_ns["nas_table"] = "nas.csv" + ip.user_global_ns["spaces_table"] = "spaces.csv" + ip.user_global_ns["file_spaces"] = "file with spaces.csv" + ip.user_global_ns["orient"] = "h" + plt.cla() + + Path("spaces.csv").write_text( + """\ +"some column", y +0, 0 +1, 1 +2, 2 +""" + ) + + Path("data.csv").write_text( + """\ +x, y +0, 0 +1, 1 +2, 2 +""" + ) + + Path("nas.csv").write_text( + """\ +x, y +, 0 +1, 1 +2, 2 +""" + ) + + Path("file with spaces.csv").write_text( + """\ +x, y +0, 0 +1, 1 +2, 2 +""" + ) + ip.run_cell("%sql duckdb://") + + ip.run_cell( + """%%sql --save subset --no-execute +SELECT * +FROM data.csv +WHERE x > -1 +""" + ) + + out = ip.run_cell(cell) + + # maptlotlib >= 3.7 has Axes but earlier Python + # versions are not compatible + assert type(out.result).__name__ in {"Axes", "AxesSubplot"} + + +@pytest.mark.parametrize( + "cell", + [ + "%sqlplot histogram --table {{schema}}.{{table}} " "--column {{column}}", + "%sqlplot histogram --table {{table}} --schema {{schema}} " + "--column {{column}}", + ], +) +@_cleanup_cm() +@image_comparison( + baseline_images=["histogram_with_table_in_schema"], + extensions=["png"], + remove_text=True, +) +def test_histogram_with_table_in_schema_variable_substitution( + ip_with_schema_and_table, cell +): + ip_with_schema_and_table.user_global_ns["table"] = "penguins1" + ip_with_schema_and_table.user_global_ns["column"] = "body_mass_g" + ip_with_schema_and_table.user_global_ns["schema"] = "sqlalchemy_schema" + ip_with_schema_and_table.run_cell("%sql duckdb://") + ip_with_schema_and_table.run_cell(cell) + + +@pytest.mark.parametrize( + "cell", + [ + "%sqlplot boxplot --table {{schema}}.{{table}} --column {{column}}", + "%sqlplot boxplot --table {{table}} --schema {{schema}} " "--column {{column}}", + ], +) +@_cleanup_cm() +@image_comparison( + baseline_images=["boxplot_with_table_in_schema"], + extensions=["png"], + 
remove_text=True, +) +def test_boxplot_with_table_in_schema_variable_substitution( + ip_with_schema_and_table, cell +): + ip_with_schema_and_table.user_global_ns["table"] = "penguins1" + ip_with_schema_and_table.user_global_ns["column"] = "body_mass_g" + ip_with_schema_and_table.user_global_ns["schema"] = "sqlalchemy_schema" + ip_with_schema_and_table.run_cell("%sql duckdb://") + ip_with_schema_and_table.run_cell(cell)