Skip to content

Commit

Permalink
snippets display improvement (#716)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbeat2782 authored Jul 11, 2023
1 parent 60c8a32 commit f56dfab
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* [Feature] Support flexible spacing `myvar=<<` operator ([#525](https://github.com/ploomber/jupysql/issues/525))
* [Feature] Added a line under `ResultSet` to distinguish it from data frame and error message when invalid operations are performed (#468)
* [Feature] Moved `%sqlrender` feature to `%sqlcmd snippets` (#647)
* [Feature] Added tables listing stored snippets when `%sqlcmd snippets` is called (#648)

* [Doc] Modified integrations content to ensure they're all consistent (#523)
* [Doc] Document --persist-replace in API section (#539)
Expand Down
9 changes: 7 additions & 2 deletions src/sql/cmd/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sql.exceptions import UsageError
from sql.cmd.cmd_utils import CmdParser
from sql.store import store
from sql.display import Table, Message


def _modify_display_msg(key, remaining_keys, dependent_keys=None):
Expand Down Expand Up @@ -63,8 +64,8 @@ def snippets(others):
help="Force delete all stored snippets",
required=False,
)
all_snippets = util.get_all_keys()
if len(others) == 1:
all_snippets = util.get_all_keys()
if others[0] in all_snippets:
return str(store[others[0]])

Expand All @@ -80,7 +81,11 @@ def snippets(others):
args = parser.parse_args(others)
SNIPPET_ARGS = [args.delete, args.delete_force, args.delete_force_all]
if SNIPPET_ARGS.count(None) == len(SNIPPET_ARGS):
return ", ".join(util.get_all_keys())
if len(all_snippets) == 0:
return Message("No snippets stored")
else:
return Table(["Stored snippets"], [[snippet] for snippet in all_snippets])

if args.delete:
deps = util.get_key_dependents(args.delete)
if deps:
Expand Down
67 changes: 64 additions & 3 deletions src/tests/test_magic_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sql.connection import Connection
from sql.store import store
from sql.inspect import _is_numeric
from sql.display import Table, Message


VALID_COMMANDS_MESSAGE = (
Expand Down Expand Up @@ -62,6 +63,14 @@ def ip_snippets(ip):
yield ip


@pytest.fixture
def test_snippet_ip(ip):
for key in list(store):
del store[key]
ip.run_cell("%sql sqlite://")
yield ip


@pytest.mark.parametrize(
"cell, error_type, error_message",
[
Expand Down Expand Up @@ -416,9 +425,61 @@ def test_test_error(ip, cell, error_type, error_message):
assert str(out.error_in_exec) == error_message


def test_snippet(ip_snippets):
out = ip_snippets.run_cell("%sqlcmd snippets").result
assert "high_price, high_price_a, high_price_b" in out
@pytest.mark.parametrize(
"cmds, result",
[
(["%sqlcmd snippets"], Message("No snippets stored")),
(
[
"""%%sql --save test_snippet --no-execute
SELECT * FROM "test_store" WHERE price >= 1.50
""",
"%sqlcmd snippets",
],
Table(
["Stored snippets"],
[["test_snippet"]],
),
),
(
[
"""%%sql --save test_snippet --no-execute
SELECT * FROM "test_store" WHERE price >= 1.50
""",
"""%%sql --save test_snippet_a --no-execute
SELECT * FROM "test_snippet" WHERE symbol == 'a'
""",
"%sqlcmd snippets",
],
Table(
["Stored snippets"],
[["test_snippet"], ["test_snippet_a"]],
),
),
(
[
"""%%sql --save test_snippet --no-execute
SELECT * FROM "test_store" WHERE price >= 1.50
""",
"""%%sql --save test_snippet_a --no-execute
SELECT * FROM "test_snippet" WHERE symbol == 'a'
""",
"""%%sql --save test_snippet_b --no-execute
SELECT * FROM "test_snippet" WHERE symbol == 'b'
""",
"%sqlcmd snippets",
],
Table(
["Stored snippets"],
[["test_snippet"], ["test_snippet_a"], ["test_snippet_b"]],
),
),
],
)
def test_snippet(test_snippet_ip, cmds, result):
out = [test_snippet_ip.run_cell(cmd) for cmd in cmds][-1].result
assert str(out) == str(result)
assert isinstance(out, type(result))


@pytest.mark.parametrize(
Expand Down

0 comments on commit f56dfab

Please sign in to comment.