diff --git a/src/sql/__init__.py b/src/sql/__init__.py index 0377bd8d8..e55222310 100644 --- a/src/sql/__init__.py +++ b/src/sql/__init__.py @@ -1,10 +1,7 @@ from .magic import RenderMagic, SqlMagic, load_ipython_extension +from .util import del_all_saved_keys __version__ = "0.7.5dev" -__all__ = [ - "RenderMagic", - "SqlMagic", - "load_ipython_extension", -] +__all__ = ["RenderMagic", "SqlMagic", "load_ipython_extension", "del_all_saved_keys"] diff --git a/src/sql/magic_cmd.py b/src/sql/magic_cmd.py index b3724051f..193638586 100644 --- a/src/sql/magic_cmd.py +++ b/src/sql/magic_cmd.py @@ -225,13 +225,15 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None): ) parser.add_argument( "-da", - "--delete-force--all", + "--delete-force-all", type=str, help="Force Delete all stored snippets", required=False, ) args = parser.parse_args(others) - # util.get_all_snippets() + SNIPPET_ARGS = [args.delete, args.delete_force, args.delete_force_all] + if SNIPPET_ARGS.count(None) == len(SNIPPET_ARGS): + return ", ".join(util.get_all_keys()) if args.delete: deps = util.get_key_dependents(args.delete) if deps: @@ -246,7 +248,10 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None): remaining_keys = util.del_saved_key(key) display_msg = f"Deleted snippet : {key}" if remaining_keys: - display_msg = f"{display_msg}. Current saved snippets : {', '.join(remaining_keys)}" + display_msg = ( + f"{display_msg}. Current saved snippets : " + f"{', '.join(remaining_keys)}" + ) else: display_msg = ( f"{display_msg}. There are no more saved snippets." @@ -266,9 +271,9 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None): elif args.delete_force_all: deps = util.get_key_dependents(args.delete_force_all) - keys = deps + args.delete_force_all - display_msg = f"Deleted snippets : {', '.join(keys)}" - for key in keys: + deps.append(args.delete_force_all) + display_msg = f"Deleted snippets : {', '.join(deps)}" + for key in deps: remaining_keys = util.del_saved_key(key) display_msg = _modify_display_msg(display_msg, remaining_keys) return display_msg diff --git a/src/sql/util.py b/src/sql/util.py index 3c3ec72d9..ee4c0d300 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -1,7 +1,7 @@ from sql import inspect import difflib from sql.connection import Connection -from sql.store import store, _get_dependencies, _get_dependents_for_key +from sql.store import store, _get_dependents_for_key from sql import exceptions SINGLE_QUOTE = "'" @@ -262,9 +262,7 @@ def del_saved_key(key: str) -> str: return get_all_keys() -def del_saved_key_with_message(key: str) -> str: +def del_all_saved_keys(): all_keys = get_all_keys() - if key not in all_keys: - return f"No such saved snippet found : {key}" - - remaining_keys = get_all_keys() + for key in all_keys: + del store[key] diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index 2d6216852..6d17e103a 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -4,31 +4,7 @@ import pytest from IPython.core.error import UsageError from pathlib import Path - - -high_price_snippet = """ - %%sql --save high_price --no-execute -SELECT * -FROM "test_store" -WHERE price >= 1.50 -""" - - -def test_force_delete_all(ip, tmp_empty, clean_conns): - ip.run_cell(high_price_snippet) - ip.run_cell( - """ - %%sql --save high_price_a --no-execute -SELECT * -FROM "high_price" -WHERE symbol == 'a' -""" - ) - - out = ip.run_cell("%sqlcmd snippets --delete-force-all high_price").result - assert ( - "Deleted snippet : high_price. Current saved snippets : 'high_price_a'" in out - ) +from sql.util import del_all_saved_keys @pytest.mark.parametrize( @@ -290,23 +266,46 @@ def test_test_error(ip, cell, error_type, error_message): assert str(out.error_in_exec) == error_message -# @pytest.fixture() -# def data(): -# cell = """%%sql sqlite:// -# CREATE TABLE test_store (rating, price, number, symbol); -# INSERT INTO test_store VALUES (14.44, 2.48, 82, 'a'); -# INSERT INTO test_store VALUES (13.13, 1.50, 93, 'b'); -# INSERT INTO test_store VALUES (12.59, 0.20, 98, 'a'); -# INSERT INTO test_store VALUES (11.54, 0.41, 89, 'a'); -# INSERT INTO test_store VALUES (11.22, 3.01, 89, 'a'); -# INSERT INTO test_store VALUES (13.54, 4.51, 89, 'b'); -# """ -# return cell +high_price_snippet = """ + %%sql --save high_price --no-execute +SELECT * +FROM "test_store" +WHERE price >= 1.50 +""" + +def test_snippet(ip_empty): + ip_empty.run_cell("%sql sqlite://") + ip_empty.run_cell(high_price_snippet) + out = ip_empty.run_cell("%sqlcmd snippets").result + assert "high_price" == out + # Clean saved keys + del_all_saved_keys() -def test_force_delete(ip, tmp_empty, data, high_price_snippet, clean_conns): - ip.run_cell(high_price_snippet) - ip.run_cell( + +def test_invalid_args(ip_empty): + ip_empty.run_cell("%sql sqlite://") + ip_empty.run_cell(high_price_snippet) + out = ip_empty.run_cell("%sqlcmd snippets --delete").result + assert "high_price" == out + # Clean saved keys + del_all_saved_keys() + + +def test_delete_only_saved_key(ip_empty): + ip_empty.run_cell("%sql sqlite://") + ip_empty.run_cell(high_price_snippet) + + out = ip_empty.run_cell("%sqlcmd snippets --delete high_price").result + assert "Deleted snippet : high_price. There are no more saved snippets." in out + # Clean saved keys + del_all_saved_keys() + + +def test_delete_saved_key(ip_empty): + ip_empty.run_cell("%sql sqlite://") + ip_empty.run_cell(high_price_snippet) + ip_empty.run_cell( """ %%sql --save high_price_a --no-execute SELECT * @@ -314,30 +313,37 @@ def test_force_delete(ip, tmp_empty, data, high_price_snippet, clean_conns): WHERE symbol == 'a' """ ) - with pytest.warns(UserWarning) as record: - out = ip.run_cell("%sqlcmd snippets --delete-force high_price").result - assert len(record) == 1 - assert "Force deleting saved snippet" in record[0].message.args[0] - assert ( - "Deleted snippet : high_price. Current saved snippets : 'high_price_a'" in out - ) + out = ip_empty.run_cell("%sqlcmd snippets --delete high_price_a").result + assert "Deleted snippet : high_price_a. Current saved snippets : high_price" in out + # Clean saved keys + del_all_saved_keys() -def test_delete_only_saved_key(ip, tmp_empty, data, high_price_snippet, clean_conns): - ip.run_cell(data) - - ip.run_cell(high_price_snippet) - - out = ip.run_cell("%sqlcmd snippets --delete high_price").result - assert "Deleted snippet : high_price. There are no more saved snippets." in out +def test_force_delete(ip_empty): + ip_empty.run_cell("%sql sqlite://") + ip_empty.run_cell(high_price_snippet) + ip_empty.run_cell( + """ + %%sql --save high_price_a --no-execute +SELECT * +FROM "high_price" +WHERE symbol == 'a' +""" + ) + with pytest.warns(UserWarning) as record: + out = ip_empty.run_cell("%sqlcmd snippets --delete-force high_price").result + assert len(record) == 1 + assert "Force deleting saved snippet" in record[0].message.args[0] + assert "Deleted snippet : high_price. Current saved snippets : high_price_a" in out + # Clean saved keys + del_all_saved_keys() -def test_delete_saved_key(ip, tmp_empty, data, high_price_snippet, clean_conns): - ip.run_cell(data) - - ip.run_cell(high_price_snippet) - ip.run_cell( +def test_force_delete_all(ip_empty): + ip_empty.run_cell("%sql sqlite://") + ip_empty.run_cell(high_price_snippet) + ip_empty.run_cell( """ %%sql --save high_price_a --no-execute SELECT * @@ -346,16 +352,19 @@ def test_delete_saved_key(ip, tmp_empty, data, high_price_snippet, clean_conns): """ ) - out = ip.run_cell("%sqlcmd snippets --delete high_price_a").result + out = ip_empty.run_cell("%sqlcmd snippets --delete-force-all high_price").result assert ( - "Deleted snippet : high_price_a. Current saved snippets : 'high_price'" in out + "Deleted snippets : high_price_a, high_price. There are no more saved snippets." + == out ) + # Clean saved keys + del_all_saved_keys() -def test_delete_snippet_error(ip, tmp_empty, data, high_price_snippet): - ip.run_cell(data) - ip.run_cell(high_price_snippet) - ip.run_cell( +def test_delete_snippet_error(ip_empty): + ip_empty.run_cell("%sql sqlite://") + ip_empty.run_cell(high_price_snippet) + ip_empty.run_cell( """ %%sql --save high_price_a --no-execute SELECT * @@ -363,7 +372,7 @@ def test_delete_snippet_error(ip, tmp_empty, data, high_price_snippet): WHERE symbol == 'a' """ ) - ip.run_cell( + ip_empty.run_cell( """ %%sql --save high_price_b --no-execute SELECT * @@ -372,11 +381,13 @@ def test_delete_snippet_error(ip, tmp_empty, data, high_price_snippet): """ ) - out = ip.run_cell("%sqlcmd snippets --delete high_price") + out = ip_empty.run_cell("%sqlcmd snippets --delete high_price") assert isinstance(out.error_in_exec, UsageError) assert ( str(out.error_in_exec) == "The following tables are dependent on high_price: " "high_price_a, high_price_b.Pass --delete-force to only " "delete high_price.Pass --delete-force-all to delete " - "high_price_a, high_price_b, high_price" + "high_price_a, high_price_b and high_price" ) + # Clean saved keys + del_all_saved_keys()