Skip to content

Commit

Permalink
sqlcmd tests
Browse files Browse the repository at this point in the history
  • Loading branch information
neelasha23 committed May 22, 2023
1 parent 548197d commit c2bfe3e
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 85 deletions.
7 changes: 2 additions & 5 deletions src/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from .magic import RenderMagic, SqlMagic, load_ipython_extension
from .util import del_all_saved_keys

__version__ = "0.7.5dev"


__all__ = [
"RenderMagic",
"SqlMagic",
"load_ipython_extension",
]
__all__ = ["RenderMagic", "SqlMagic", "load_ipython_extension", "del_all_saved_keys"]
17 changes: 11 additions & 6 deletions src/sql/magic_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,15 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None):
)
parser.add_argument(
"-da",
"--delete-force--all",
"--delete-force-all",
type=str,
help="Force Delete all stored snippets",
required=False,
)
args = parser.parse_args(others)
# util.get_all_snippets()
SNIPPET_ARGS = [args.delete, args.delete_force, args.delete_force_all]
if SNIPPET_ARGS.count(None) == len(SNIPPET_ARGS):
return ", ".join(util.get_all_keys())
if args.delete:
deps = util.get_key_dependents(args.delete)
if deps:
Expand All @@ -246,7 +248,10 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None):
remaining_keys = util.del_saved_key(key)
display_msg = f"Deleted snippet : {key}"
if remaining_keys:
display_msg = f"{display_msg}. Current saved snippets : {', '.join(remaining_keys)}"
display_msg = (
f"{display_msg}. Current saved snippets : "
f"{', '.join(remaining_keys)}"
)
else:
display_msg = (
f"{display_msg}. There are no more saved snippets."
Expand All @@ -266,9 +271,9 @@ def execute(self, cmd_name="", others="", cell="", local_ns=None):

elif args.delete_force_all:
deps = util.get_key_dependents(args.delete_force_all)
keys = deps + args.delete_force_all
display_msg = f"Deleted snippets : {', '.join(keys)}"
for key in keys:
deps.append(args.delete_force_all)
display_msg = f"Deleted snippets : {', '.join(deps)}"
for key in deps:
remaining_keys = util.del_saved_key(key)
display_msg = _modify_display_msg(display_msg, remaining_keys)
return display_msg
Expand Down
10 changes: 4 additions & 6 deletions src/sql/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sql import inspect
import difflib
from sql.connection import Connection
from sql.store import store, _get_dependencies, _get_dependents_for_key
from sql.store import store, _get_dependents_for_key
from sql import exceptions

SINGLE_QUOTE = "'"
Expand Down Expand Up @@ -262,9 +262,7 @@ def del_saved_key(key: str) -> str:
return get_all_keys()


def del_saved_key_with_message(key: str) -> str:
def del_all_saved_keys():
all_keys = get_all_keys()
if key not in all_keys:
return f"No such saved snippet found : {key}"

remaining_keys = get_all_keys()
for key in all_keys:
del store[key]
147 changes: 79 additions & 68 deletions src/tests/test_magic_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,7 @@
import pytest
from IPython.core.error import UsageError
from pathlib import Path


high_price_snippet = """
%%sql --save high_price --no-execute
SELECT *
FROM "test_store"
WHERE price >= 1.50
"""


def test_force_delete_all(ip, tmp_empty, clean_conns):
ip.run_cell(high_price_snippet)
ip.run_cell(
"""
%%sql --save high_price_a --no-execute
SELECT *
FROM "high_price"
WHERE symbol == 'a'
"""
)

out = ip.run_cell("%sqlcmd snippets --delete-force-all high_price").result
assert (
"Deleted snippet : high_price. Current saved snippets : 'high_price_a'" in out
)
from sql.util import del_all_saved_keys


@pytest.mark.parametrize(
Expand Down Expand Up @@ -290,54 +266,84 @@ def test_test_error(ip, cell, error_type, error_message):
assert str(out.error_in_exec) == error_message


# @pytest.fixture()
# def data():
# cell = """%%sql sqlite://
# CREATE TABLE test_store (rating, price, number, symbol);
# INSERT INTO test_store VALUES (14.44, 2.48, 82, 'a');
# INSERT INTO test_store VALUES (13.13, 1.50, 93, 'b');
# INSERT INTO test_store VALUES (12.59, 0.20, 98, 'a');
# INSERT INTO test_store VALUES (11.54, 0.41, 89, 'a');
# INSERT INTO test_store VALUES (11.22, 3.01, 89, 'a');
# INSERT INTO test_store VALUES (13.54, 4.51, 89, 'b');
# """
# return cell
high_price_snippet = """
%%sql --save high_price --no-execute
SELECT *
FROM "test_store"
WHERE price >= 1.50
"""


def test_snippet(ip_empty):
ip_empty.run_cell("%sql sqlite://")
ip_empty.run_cell(high_price_snippet)
out = ip_empty.run_cell("%sqlcmd snippets").result
assert "high_price" == out
# Clean saved keys
del_all_saved_keys()

def test_force_delete(ip, tmp_empty, data, high_price_snippet, clean_conns):
ip.run_cell(high_price_snippet)
ip.run_cell(

def test_invalid_args(ip_empty):
ip_empty.run_cell("%sql sqlite://")
ip_empty.run_cell(high_price_snippet)
out = ip_empty.run_cell("%sqlcmd snippets --delete").result
assert "high_price" == out
# Clean saved keys
del_all_saved_keys()


def test_delete_only_saved_key(ip_empty):
ip_empty.run_cell("%sql sqlite://")
ip_empty.run_cell(high_price_snippet)

out = ip_empty.run_cell("%sqlcmd snippets --delete high_price").result
assert "Deleted snippet : high_price. There are no more saved snippets." in out
# Clean saved keys
del_all_saved_keys()


def test_delete_saved_key(ip_empty):
ip_empty.run_cell("%sql sqlite://")
ip_empty.run_cell(high_price_snippet)
ip_empty.run_cell(
"""
%%sql --save high_price_a --no-execute
SELECT *
FROM "high_price"
WHERE symbol == 'a'
"""
)
with pytest.warns(UserWarning) as record:
out = ip.run_cell("%sqlcmd snippets --delete-force high_price").result
assert len(record) == 1
assert "Force deleting saved snippet" in record[0].message.args[0]
assert (
"Deleted snippet : high_price. Current saved snippets : 'high_price_a'" in out
)

out = ip_empty.run_cell("%sqlcmd snippets --delete high_price_a").result
assert "Deleted snippet : high_price_a. Current saved snippets : high_price" in out
# Clean saved keys
del_all_saved_keys()

def test_delete_only_saved_key(ip, tmp_empty, data, high_price_snippet, clean_conns):
ip.run_cell(data)

ip.run_cell(high_price_snippet)

out = ip.run_cell("%sqlcmd snippets --delete high_price").result
assert "Deleted snippet : high_price. There are no more saved snippets." in out

def test_force_delete(ip_empty):
ip_empty.run_cell("%sql sqlite://")
ip_empty.run_cell(high_price_snippet)
ip_empty.run_cell(
"""
%%sql --save high_price_a --no-execute
SELECT *
FROM "high_price"
WHERE symbol == 'a'
"""
)
with pytest.warns(UserWarning) as record:
out = ip_empty.run_cell("%sqlcmd snippets --delete-force high_price").result
assert len(record) == 1
assert "Force deleting saved snippet" in record[0].message.args[0]
assert "Deleted snippet : high_price. Current saved snippets : high_price_a" in out
# Clean saved keys
del_all_saved_keys()

def test_delete_saved_key(ip, tmp_empty, data, high_price_snippet, clean_conns):
ip.run_cell(data)

ip.run_cell(high_price_snippet)

ip.run_cell(
def test_force_delete_all(ip_empty):
ip_empty.run_cell("%sql sqlite://")
ip_empty.run_cell(high_price_snippet)
ip_empty.run_cell(
"""
%%sql --save high_price_a --no-execute
SELECT *
Expand All @@ -346,24 +352,27 @@ def test_delete_saved_key(ip, tmp_empty, data, high_price_snippet, clean_conns):
"""
)

out = ip.run_cell("%sqlcmd snippets --delete high_price_a").result
out = ip_empty.run_cell("%sqlcmd snippets --delete-force-all high_price").result
assert (
"Deleted snippet : high_price_a. Current saved snippets : 'high_price'" in out
"Deleted snippets : high_price_a, high_price. There are no more saved snippets."
== out
)
# Clean saved keys
del_all_saved_keys()


def test_delete_snippet_error(ip, tmp_empty, data, high_price_snippet):
ip.run_cell(data)
ip.run_cell(high_price_snippet)
ip.run_cell(
def test_delete_snippet_error(ip_empty):
ip_empty.run_cell("%sql sqlite://")
ip_empty.run_cell(high_price_snippet)
ip_empty.run_cell(
"""
%%sql --save high_price_a --no-execute
SELECT *
FROM "high_price"
WHERE symbol == 'a'
"""
)
ip.run_cell(
ip_empty.run_cell(
"""
%%sql --save high_price_b --no-execute
SELECT *
Expand All @@ -372,11 +381,13 @@ def test_delete_snippet_error(ip, tmp_empty, data, high_price_snippet):
"""
)

out = ip.run_cell("%sqlcmd snippets --delete high_price")
out = ip_empty.run_cell("%sqlcmd snippets --delete high_price")
assert isinstance(out.error_in_exec, UsageError)
assert (
str(out.error_in_exec) == "The following tables are dependent on high_price: "
"high_price_a, high_price_b.Pass --delete-force to only "
"delete high_price.Pass --delete-force-all to delete "
"high_price_a, high_price_b, high_price"
"high_price_a, high_price_b and high_price"
)
# Clean saved keys
del_all_saved_keys()

0 comments on commit c2bfe3e

Please sign in to comment.