diff --git a/CHANGELOG.md b/CHANGELOG.md index 05c52daa1..9d6406860 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ * [Feature] Using native DuckDB `.df()` method when using `autopandas` * [Doc] documenting `%sqlcmd tables`/`%sqlcmd columns` +* [Feature] Better error messages when function used in plotting API unsupported by DB driver (#159) ## 0.7.4 (2023-04-28) No changes diff --git a/src/sql/plot.py b/src/sql/plot.py index e8c2be9cf..e6c55e842 100644 --- a/src/sql/plot.py +++ b/src/sql/plot.py @@ -6,6 +6,8 @@ from jinja2 import Template from sql.util import flatten +from sqlalchemy.exc import ProgrammingError +from sql import exceptions try: import matplotlib.pyplot as plt @@ -30,6 +32,7 @@ def _summary_stats(conn, table, column, with_=None): if not conn: conn = sql.connection.Connection.current + driver = conn._get_curr_sqlalchemy_connection_info()["driver"] template = Template( """ @@ -44,7 +47,16 @@ def _summary_stats(conn, table, column, with_=None): query = template.render(table=table, column=column) - values = conn.execute(query, with_).fetchone() + try: + values = conn.execute(query, with_).fetchone() + except ProgrammingError as e: + print(e) + raise exceptions.RuntimeError( + f"\nEnsure that percentile_disc function is available on {driver}." + ) + except Exception as e: + raise e + keys = ["q1", "med", "q3", "mean", "N"] return {k: float(v) for k, v in zip(keys, flatten(values))} diff --git a/src/tests/integration/test_mssql.py b/src/tests/integration/test_mssql.py index 63057ea14..4bd8a3f72 100644 --- a/src/tests/integration/test_mssql.py +++ b/src/tests/integration/test_mssql.py @@ -1,6 +1,7 @@ import pyodbc import pytest from matplotlib import pyplot as plt +from IPython.core.error import UsageError def test_query_count(ip_with_MSSQL, test_table_name_dict): @@ -112,3 +113,17 @@ def test_sqlplot_boxplot(ip_with_MSSQL, cell): out = ip_with_MSSQL.run_cell(cell) assert type(out.result).__name__ in {"Axes", "AxesSubplot"} + + +def test_unsupported_function(ip_with_MSSQL, test_table_name_dict): + # clean current Axes + plt.cla() + out = ip_with_MSSQL.run_cell( + f"%sqlplot boxplot --table " f"{test_table_name_dict['taxi']} --column x" + ) + assert isinstance(out.error_in_exec, UsageError) + assert "Ensure that percentile_disc function is available" in str(out.error_in_exec) + assert ( + "If you need help solving this issue, " + "send us a message: https://ploomber.io/community" in str(out.error_in_exec) + ) diff --git a/src/tests/test_magic_plot.py b/src/tests/test_magic_plot.py index 780512092..6e0834af1 100644 --- a/src/tests/test_magic_plot.py +++ b/src/tests/test_magic_plot.py @@ -4,6 +4,34 @@ from IPython.core.error import UsageError import matplotlib.pyplot as plt +from sqlalchemy.exc import OperationalError + +from sql.plot import _summary_stats + + +def test_summary_stats_success(tmp_empty, ip): + Path("data.csv").write_text( + """\ +x, y +0, 0 +1, 1 +2, 2 +5, 7 +9, 9 +""" + ) + ip.run_cell("%sql duckdb://") + result = _summary_stats(None, "data.csv", column="x") + expected = {"q1": 1.0, "med": 2.0, "q3": 5.0, "mean": 3.4, "N": 5.0} + assert result == expected + + +def test_summary_stats_failure(tmp_empty, ip): + ip.run_cell("%sql duckdb://") + with pytest.raises(OperationalError) as e: + _summary_stats(None, "data.csv", column="x") + assert 'No files found that match the pattern "data.csv"' in str(e) + @pytest.mark.parametrize( "cell, error_type, error_message", @@ -166,3 +194,4 @@ def test_sqlplot(tmp_empty, ip, cell): # maptlotlib >= 3.7 has Axes but earlier Python # versions are not compatible assert type(out.result).__name__ in {"Axes", "AxesSubplot"} +