diff --git a/CHANGELOG.md b/CHANGELOG.md index fb69fe39a..c803e6a37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * [Doc] Document --persist-replace in API section (#539) * [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631) * [Fix] Refactored `ResultSet` to lazy loading (#470) +* [Fix] Error when executing multiple SQL statements when using DuckDB with `autopandas` on (#674) * [Fix] Removed `WITH` when a snippet does not have a dependency (#657) ## 0.7.9 (2023-06-19) diff --git a/src/sql/connection.py b/src/sql/connection.py index d5435055b..180f93d2a 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -203,6 +203,8 @@ def __init__(self, engine, alias=None): ) self.dialect = engine.url.get_dialect() + + # TODO: delete this! self.session = self._create_session(engine, self.url) self.connections[alias or self.url] = self @@ -651,7 +653,7 @@ def execute(self, query, with_=None): Executes SQL query on a given connection """ query = self._prepare_query(query, with_) - return self.session.execute(query) + return self.engine.execute(query) atexit.register(Connection.close_all, verbose=True) @@ -672,11 +674,15 @@ def __init__(self, connection, engine): } ) + # TODO: need to close the cursor def execute(self, query): cur = self.engine.cursor() cur.execute(query) return cur + def commit(self): + self.engine.commit() + class CustomConnection(Connection): """ @@ -698,6 +704,7 @@ def __init__(self, payload, engine=None, alias=None): self.name = connection_name_ self.dialect = connection_name_ self.session = CustomSession(self, engine) + self.engine = self.session self.connections[alias or connection_name_] = self diff --git a/src/sql/inspect.py b/src/sql/inspect.py index f4455370c..ada6b16b3 100644 --- a/src/sql/inspect.py +++ b/src/sql/inspect.py @@ -18,7 +18,7 @@ def _get_inspector(conn): if not Connection.current: raise exceptions.RuntimeError("No active connection") else: - return 
inspect(Connection.current.session) + return inspect(Connection.current.engine) class DatabaseInspection: @@ -239,11 +239,15 @@ def __init__(self, table_name, schema=None) -> None: columns_query_result = sql.run.raw_run( Connection.current, f"SELECT * FROM {table_name} WHERE 1=0" ) + if Connection.is_custom_connection(): columns = [i[0] for i in columns_query_result.description] else: columns = columns_query_result.keys() + # TODO: abstract it internally + columns_query_result.close() + table_stats = dict({}) columns_to_include_in_report = set() columns_with_styles = [] diff --git a/src/sql/run.py b/src/sql/run.py index fb8bcaff3..fa27ea24a 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -9,6 +9,7 @@ import prettytable import sqlalchemy +from sqlalchemy.exc import ResourceClosedError import sqlparse from sql.connection import Connection from sql import exceptions, display @@ -109,27 +110,88 @@ class ResultSet(ColumnGuesserMixin): Can access rows listwise, or by string value of leftmost column. 
""" - def __init__(self, sqlaproxy, config): + def __init__(self, sqlaproxy, config, sql=None, engine=None): self.config = config - self.keys = {} - self._results = [] self.truncated = False self.sqlaproxy = sqlaproxy + self.sql = sql + self.engine = engine + + self._keys = None + self._field_names = None + self._results = [] # https://peps.python.org/pep-0249/#description self.is_dbapi_results = hasattr(sqlaproxy, "description") - self.pretty = None + # NOTE: this will trigger key fetching + self.pretty_table = self._init_table() + + self._done_fetching = False + + if self.config.autolimit == 1: + self.fetchmany(size=1) + self.did_finish_fetching() + # EXPLAIN WHY WE NEED TWO + else: + self.fetchmany(size=2) + + def extend_results(self, elements): + self._results.extend(elements) + + # NOTE: we should use add_rows but there is a subclass that behaves weird + # so using add_row for now + for e in elements: + self.pretty_table.add_row(e) + + def done_fetching(self): + self._done_fetching = True + self.sqlaproxy.close() + + def did_finish_fetching(self): + return self._done_fetching + + # NOTE: this triggers key fetching + @property + def field_names(self): + if self._field_names is None: + self._field_names = unduplicate_field_names(self.keys) + + return self._field_names + + # NOTE: this triggers key fetching + @property + def keys(self): + if self._keys is not None: + return self._keys + + if not self.is_dbapi_results: + try: + self._keys = self.sqlaproxy.keys() + # sqlite raises this error when running a script that doesn't return rows + # e.g, 'CREATE TABLE' but others don't (e.g., duckdb) + except ResourceClosedError: + self._keys = [] + return self._keys + + elif isinstance(self.sqlaproxy.description, Iterable): + self._keys = [i[0] for i in self.sqlaproxy.description] + else: + self._keys = [] + + return self._keys def _repr_html_(self): + self.fetch_for_repr_if_needed() + _cell_with_spaces_pattern = re.compile(r"()( {2,})") - if self.pretty: - 
self.pretty.add_rows(self) - result = self.pretty.get_html_string() + if self.pretty_table: + self.pretty_table.add_rows(self) + result = self.pretty_table.get_html_string() # to create clickable links result = html.unescape(result) result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result) - if self.truncated: + if self.config.displaylimit != 0: HTML = ( '%s\n' "Truncated to displaylimit of %d" @@ -139,7 +201,7 @@ def _repr_html_(self): 'displaylimit' # noqa: E501 " configuration" ) - result = HTML % (result, self.pretty.row_count) + result = HTML % (result, self.config.displaylimit) return result else: return None @@ -148,18 +210,18 @@ def __len__(self): return len(self._results) def __iter__(self): - results = self._fetch_query_results(fetch_all=True) + self.fetchall() - for result in results: + for result in self._results: yield result def __str__(self, *arg, **kwarg): - if self.pretty: - self.pretty.add_rows(self) - return str(self.pretty or "") + self.fetch_for_repr_if_needed() + return str(self.pretty_table) def __repr__(self) -> str: - return str(self) + self.fetch_for_repr_if_needed() + return str(self.pretty_table) def __eq__(self, another: object) -> bool: return self._results == another @@ -195,10 +257,28 @@ def DataFrame(self, payload): "Returns a Pandas DataFrame instance built from the result set." import pandas as pd + # TODO: re-add + # payload[ + # "connection_info" + # ] = Connection.current._get_curr_sqlalchemy_connection_info() + + if self.did_finish_fetching(): + return pd.DataFrame(self, columns=(self and self.keys) or []) + + # only run when using duckdb + if self.engine.dialect.name == "duckdb": + # why do we need this? 
+ self.sqlaproxy.close() + + conn_duckdb_raw = self.engine.raw_connection() + cursor = conn_duckdb_raw.cursor() + cursor.execute(str(self.sql)) + df = cursor.df() + conn_duckdb_raw.close() + return df + frame = pd.DataFrame(self, columns=(self and self.keys) or []) - payload[ - "connection_info" - ] = Connection.current._get_curr_sqlalchemy_connection_info() + return frame @telemetry.log_call("polars-data-frame") @@ -320,9 +400,9 @@ def bar(self, key_word_sep=" ", title=None, **kwargs): def csv(self, filename=None, **format_params): """Generate results in comma-separated form. Write to ``filename`` if given. Any other parameters will be passed on to csv.writer.""" - if not self.pretty: + if not self.pretty_table: return None # no results - self.pretty.add_rows(self) + self.pretty_table.add_rows(self) if filename: encoding = format_params.get("encoding", "utf-8") outfile = open(filename, "w", newline="", encoding=encoding) @@ -339,104 +419,51 @@ def csv(self, filename=None, **format_params): else: return outfile.getvalue() - def fetch_results(self, fetch_all=False): - """ - Returns a limited representation of the query results. 
- - Parameters - ---------- - fetch_all : bool default False - Return all query rows - """ - is_dbapi_results = self.is_dbapi_results - sqlaproxy = self.sqlaproxy - config = self.config - - if is_dbapi_results: - should_try_fetch_results = True - else: - should_try_fetch_results = sqlaproxy.returns_rows - - if should_try_fetch_results: - # sql alchemy results - if not is_dbapi_results: - self.keys = sqlaproxy.keys() - elif isinstance(sqlaproxy.description, Iterable): - self.keys = [i[0] for i in sqlaproxy.description] - else: - self.keys = [] - - if len(self.keys) > 0: - self._results = self._fetch_query_results(fetch_all=fetch_all) + def fetchmany(self, size): + """Fetch n results and add it to the results""" + if not self.did_finish_fetching(): + try: + returned = self.sqlaproxy.fetchmany(size=size) + # sqlite raises this error when running a script that doesn't return rows + # e.g, 'CREATE TABLE' but others don't (e.g., duckdb) + except ResourceClosedError: + self.done_fetching() + return - self.field_names = unduplicate_field_names(self.keys) + self.extend_results(returned) - _style = None + if len(returned) < size: + self.done_fetching() - self.pretty = PrettyTable(self.field_names) + if ( + self.config.autolimit is not None + and self.config.autolimit != 0 + and len(self._results) >= self.config.autolimit + ): + self.done_fetching() - if isinstance(config.style, str): - _style = prettytable.__dict__[config.style.upper()] - self.pretty.set_style(_style) + def fetch_for_repr_if_needed(self): + if self.config.displaylimit == 0: + self.fetchall() - return self + missing = self.config.displaylimit - len(self._results) - def _fetch_query_results(self, fetch_all=False): - """ - Returns rows of a query result as a list of tuples. 
- - Parameters - ---------- - fetch_all : bool default False - Return all query rows - """ - sqlaproxy = self.sqlaproxy - config = self.config - _should_try_lazy_fetch = hasattr(sqlaproxy, "_soft_closed") - - _should_fetch_all = ( - (config.displaylimit == 0 or not config.displaylimit) - or fetch_all - or not _should_try_lazy_fetch - ) + if missing > 0: + self.fetchmany(missing) - is_autolimit = isinstance(config.autolimit, int) and config.autolimit > 0 - is_connection_closed = ( - sqlaproxy._soft_closed if _should_try_lazy_fetch else False - ) - - should_return_results = is_connection_closed or ( - len(self._results) > 0 and is_autolimit - ) + def fetchall(self): + if not self.did_finish_fetching(): + self.extend_results(self.sqlaproxy.fetchall()) + self.done_fetching() - if should_return_results: - # this means we already loaded all - # the results to self._results or we use - # autolimit and shouldn't fetch more - results = self._results - else: - if is_autolimit: - results = sqlaproxy.fetchmany(size=config.autolimit) - else: - if _should_fetch_all: - all_results = sqlaproxy.fetchall() - results = self._results + all_results - self._results = results - else: - results = sqlaproxy.fetchmany(size=config.displaylimit) + def _init_table(self): + pretty = PrettyTable(self.field_names) - if _should_try_lazy_fetch: - # Try to fetch an extra row to find out - # if there are more results to fetch - row = sqlaproxy.fetchone() - if row is not None: - results += [row] + if isinstance(self.config.style, str): + _style = prettytable.__dict__[self.config.style.upper()] + pretty.set_style(_style) - # Check if we have more rows to show - if config.displaylimit > 0: - self.truncated = len(results) > config.displaylimit - - return results + return pretty def display_affected_rowcount(rowcount): @@ -507,6 +534,7 @@ def _commit(conn, config, manual_commit): with Session(conn.session) as session: session.commit() except sqlalchemy.exc.OperationalError: + # TODO: missing rollback 
here? print("The database does not support the COMMIT command") @@ -565,7 +593,6 @@ def select_df_type(resultset, config): return resultset.PolarsDataFrame(**config.polars_dataframe_kwargs) else: return resultset - # returning only last result, intentionally def run(conn, sql, config): @@ -582,17 +609,16 @@ def run(conn, sql, config): config Configuration object """ - info = conn._get_curr_sqlalchemy_connection_info() - - duckdb_autopandas = info and info.get("dialect") == "duckdb" and config.autopandas - if not sql.strip(): # returning only when sql is empty string return "Connected: %s" % conn.name - for statement in sqlparse.split(sql): - first_word = sql.strip().split()[0].lower() - manual_commit = False + statements = sqlparse.split(sql) + + for statement in statements: + first_word = _first_word(statement) + is_select = first_word == "select" + # manual_commit = False # attempting to run a transaction if first_word == "begin": @@ -604,41 +630,48 @@ def run(conn, sql, config): # regular query else: - manual_commit = set_autocommit(conn, config) + # TODO: add commit feature again + # _commit(conn=conn, config=config, manual_commit=manual_commit) + # manual_commit = set_autocommit(conn, config) is_custom_connection = Connection.is_custom_connection(conn) # if regular sqlalchemy, pass a text object if not is_custom_connection: statement = sqlalchemy.sql.text(statement) - if duckdb_autopandas: - conn = conn.engine.raw_connection() - cursor = conn.cursor() - cursor.execute(str(statement)) - + # NOTE: are there any edge cases we should cover? 
+ if not is_select: + # TODO: we should abstract this + if is_custom_connection: + result = conn.engine.execute(statement) + resultset = ResultSet(result, config, statement, conn.engine) + conn.engine.commit() + else: + with Session(conn.engine) as session: + result = session.execute(statement) + resultset = ResultSet(result, config, statement, conn.engine) + session.commit() else: - result = conn.session.execute(statement) - _commit(conn=conn, config=config, manual_commit=manual_commit) - - if result and config.feedback: - if hasattr(result, "rowcount"): - display_affected_rowcount(result.rowcount) - - # bypass ResultSet and use duckdb's native method to return a pandas data frame - if duckdb_autopandas: - df = cursor.df() - conn.close() - return df - else: - resultset = ResultSet(result, config) + # we need the session so sqlalchemy 2.x works + # result = conn.engine.execute(statement) + with Session(conn.engine) as session: + result = session.execute(statement) + + resultset = ResultSet(result, config, statement, conn.engine) + + if result and config.feedback: + if hasattr(result, "rowcount"): + display_affected_rowcount(result.rowcount) + + return select_df_type(resultset, config) + - # lazy load - resultset.fetch_results() - return select_df_type(resultset, config) +def _first_word(sql): + return sql.strip().split()[0].lower() def raw_run(conn, sql): - return conn.session.execute(sqlalchemy.sql.text(sql)) + return conn.engine.execute(sqlalchemy.sql.text(sql)) class PrettyTable(prettytable.PrettyTable): diff --git a/src/sql/util.py b/src/sql/util.py index 19473ace6..b621f383b 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -81,6 +81,7 @@ def is_table_exists( ignore_error: bool, default False Avoid raising a ValueError """ + if table is None: if ignore_error: return False @@ -188,6 +189,7 @@ def _is_table_exists(table: str, conn) -> bool: """ Runs a SQL query to check if table exists """ + if not conn: conn = Connection.current @@ -201,6 +203,7 @@ def 
_is_table_exists(table: str, conn) -> bool: try: conn.execute(query) return True + # TODO: Add specific exception except Exception: pass diff --git a/src/tests/conftest.py b/src/tests/conftest.py index bd2422e4f..a96e5f832 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -48,6 +48,29 @@ def clean_conns(): yield +class TestingShell(InteractiveShell): + def run_cell(self, *args, **kwargs): + result = super().run_cell(*args, **kwargs) + + if result.error_in_exec is not None: + raise result.error_in_exec + + return result + + +# migrate to this one +@pytest.fixture +def ip_empty_testing_shell(): + ip_session = TestingShell() + ip_session.register_magics(SqlMagic) + ip_session.register_magics(RenderMagic) + ip_session.register_magics(SqlPlotMagic) + ip_session.register_magics(SqlCmdMagic) + + yield ip_session + Connection.close_all() + + @pytest.fixture def ip_empty(): ip_session = InteractiveShell() @@ -55,6 +78,7 @@ def ip_empty(): ip_session.register_magics(RenderMagic) ip_session.register_magics(SqlPlotMagic) ip_session.register_magics(SqlCmdMagic) + yield ip_session Connection.close_all() diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index 18b47a56f..116bc9053 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -1,5 +1,13 @@ +from unittest.mock import Mock import logging +from sqlalchemy import create_engine +import pandas as pd +import pytest + +from sql.connection import Connection +from sql.run import ResultSet + def test_auto_commit_mode_on(ip_with_duckDB, caplog): with caplog.at_level(logging.DEBUG): @@ -27,3 +35,122 @@ def test_auto_commit_mode_off(ip_with_duckDB, caplog): # Check the tables is created tables_out = ip_with_duckDB.run_cell("%sql SHOW TABLES;").result assert any("weather" == table[0] for table in tables_out) + + +@pytest.mark.parametrize( + "config", + [ + "%config SqlMagic.autopandas = True", + "%config SqlMagic.autopandas = False", + ], + 
ids=[ + "autopandas_on", + "autopandas_off", + ], +) +@pytest.mark.parametrize( + "sql, tables", + [ + ["%sql SELECT * FROM weather; SELECT * FROM weather;", ["weather"]], + [ + "%sql CREATE TABLE names (name VARCHAR,); SELECT * FROM weather;", + ["weather", "names"], + ], + [ + ( + "%sql CREATE TABLE names (city VARCHAR,);" + "CREATE TABLE more_names (city VARCHAR,);" + "INSERT INTO names VALUES ('NYC');" + "SELECT * FROM names UNION ALL SELECT * FROM more_names;" + ), + ["weather", "names", "more_names"], + ], + ], + ids=[ + "multiple_selects", + "multiple_statements", + "multiple_tables_created", + ], +) +def test_multiple_statements(ip_empty, config, sql, tables): + ip_empty.run_cell("%sql duckdb://") + ip_empty.run_cell(config) + + ip_empty.run_cell("%sql CREATE TABLE weather (city VARCHAR,);") + ip_empty.run_cell("%sql INSERT INTO weather VALUES ('NYC');") + ip_empty.run_cell("%sql SELECT * FROM weather;") + + out = ip_empty.run_cell(sql) + out_tables = ip_empty.run_cell("%sqlcmd tables") + + assert out.error_in_exec is None + + if config == "%config SqlMagic.autopandas = True": + assert out.result.to_dict() == {"city": {0: "NYC"}} + else: + assert out.result.dict() == {"city": ("NYC",)} + + assert set(tables) == set(r[0] for r in out_tables.result._table.rows) + + +def test_autopandas_when_last_result_is_not_a_select_statement(ip_empty): + ip_empty.run_cell("%sql duckdb://") + ip_empty.run_cell("%config SqlMagic.autopandas=True") + + out = ip_empty.run_cell( + "%sql CREATE TABLE a (c VARCHAR,); CREATE TABLE b (c VARCHAR,);" + ) + + assert out.error_in_exec is None + assert out.result.to_dict() == dict() + + +@pytest.mark.parametrize( + "sql", + [ + ( + "%sql CREATE TABLE a (x INT,); CREATE TABLE b (x INT,); " + "INSERT INTO a VALUES (1,); INSERT INTO b VALUES(2,); " + "SELECT * FROM a UNION ALL SELECT * FROM b;" + ), + """\ +%%sql +CREATE TABLE a (x INT,); +CREATE TABLE b (x INT,); +INSERT INTO a VALUES (1,); +INSERT INTO b VALUES(2,); +SELECT * FROM a UNION 
ALL SELECT * FROM b; +""", + ], +) +def test_commits_all_statements(ip_empty, sql): + ip_empty.run_cell("%sql duckdb://") + out = ip_empty.run_cell(sql) + assert out.error_in_exec is None + assert out.result.dict() == {"x": (1, 2)} + + +def test_resultset_uses_native_duckdb_df(ip_empty): + engine = create_engine("duckdb://") + engine.execute("CREATE TABLE a (x INT,);") + engine.execute("INSERT INTO a(x) VALUES (10),(20),(30);") + + sql = "SELECT * FROM a" + results = engine.execute(sql) + + Connection.set(engine, displaycon=False) + results.fetchmany = Mock(wraps=results.fetchmany) + results.fetchone = Mock(side_effect=ValueError("Should not be called")) + results.fetchall = Mock(side_effect=ValueError("Should not be called")) + + mock = Mock() + mock.displaylimit = 1 + mock.autolimit = 0 + + result_set = ResultSet(results, mock, sql=sql, engine=engine) + df = result_set.DataFrame() + + assert isinstance(df, pd.DataFrame) + assert df.to_dict() == {"x": {0: 10, 1: 20, 2: 30}} + + results.fetchmany.assert_called_once_with(size=2) diff --git a/src/tests/test_inspect.py b/src/tests/test_inspect.py index 4234d412d..5bf32fbab 100644 --- a/src/tests/test_inspect.py +++ b/src/tests/test_inspect.py @@ -224,20 +224,20 @@ def test_columns_with_missing_values( monkeypatch.setattr(inspect, "_get_inspector", lambda _: mock) ip.run_cell( - """%%sql sqlite:///another.db -CREATE TABLE IF NOT EXISTS another_table (id INT) + """%%sql duckdb:// +CREATE TABLE IF NOT EXISTS test_table (id INT) """ ) ip.run_cell( - """%%sql sqlite:///my.db -CREATE TABLE IF NOT EXISTS test_table (id INT) + """%%sql +CREATE SCHEMA another_schema; """ ) ip.run_cell( """%%sql -ATTACH DATABASE 'another.db' as 'another_schema'; +CREATE TABLE IF NOT EXISTS another_schema.another_table (id INT) """ ) diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 814099e90..f4f3fc81f 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -1302,3 +1302,17 @@ def 
test_interact_and_missing_ipywidgets_installed(ip): "%sql --interact my_variable SELECT * FROM author LIMIT {{my_variable}}" ) assert isinstance(out.error_in_exec, ModuleNotFoundError) + + +def test_generic_driver(ip_empty_testing_shell): + # TODO: add duckdb support + # ip_empty_testing_shell.run_cell("import duckdb") + # ip_empty_testing_shell.run_cell("conn = duckdb.connect(':memory:')") + ip_empty_testing_shell.run_cell("import sqlite3") + ip_empty_testing_shell.run_cell("conn = sqlite3.connect(':memory:')") + + ip_empty_testing_shell.run_cell("%sql conn") + ip_empty_testing_shell.run_cell("%sql CREATE TABLE test (a INTEGER, b INTEGER)") + ip_empty_testing_shell.run_cell("%sql INSERT INTO test VALUES (1, 2)") + result = ip_empty_testing_shell.run_cell("%sql SELECT * FROM test") + assert result.result == [(1, 2)] diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index d9b7ef97b..ea524741b 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -1,11 +1,10 @@ +import sqlite3 import sys import math import pytest from IPython.core.error import UsageError from pathlib import Path -from sqlalchemy import create_engine -from sql.connection import Connection from sql.store import store from sql.inspect import _is_numeric @@ -112,8 +111,15 @@ def test_tables(ip): def test_tables_with_schema(ip, tmp_empty): - conn = Connection(engine=create_engine("sqlite:///my.db")) - conn.execute("CREATE TABLE numbers (some_number FLOAT)") + # TODO: why does this fail? 
+ # ip.run_cell( + # """%%sql sqlite:///my.db + # CREATE TABLE numbers (some_number FLOAT) + # """ + # ) + + with sqlite3.connect("my.db") as conn: + conn.execute("CREATE TABLE numbers (some_number FLOAT)") ip.run_cell( """%%sql @@ -150,8 +156,8 @@ def test_columns(ip, cmd, cols): def test_columns_with_schema(ip, tmp_empty): - conn = Connection(engine=create_engine("sqlite:///my.db")) - conn.execute("CREATE TABLE numbers (some_number FLOAT)") + with sqlite3.connect("my.db") as conn: + conn.execute("CREATE TABLE numbers (some_number FLOAT)") ip.run_cell( """%%sql diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py index 52c7f338e..96fa90952 100644 --- a/src/tests/test_resultset.py +++ b/src/tests/test_resultset.py @@ -1,7 +1,8 @@ +from unittest.mock import Mock, call + from sqlalchemy import create_engine -from sql.connection import Connection from pathlib import Path -from unittest.mock import Mock + import pytest import pandas as pd @@ -9,7 +10,6 @@ import sqlalchemy from sql.run import ResultSet -from sql import run as run_module import re @@ -35,7 +35,7 @@ def result(): @pytest.fixture def result_set(result, config): - return ResultSet(result, config).fetch_results() + return ResultSet(result, config) def test_resultset_getitem(result_set): @@ -51,12 +51,6 @@ def test_resultset_dicts(result_set): assert list(result_set.dicts()) == [{"x": 0}, {"x": 1}, {"x": 2}] -def test_resultset_dataframe(result_set, monkeypatch): - monkeypatch.setattr(run_module.Connection, "current", Mock()) - - assert result_set.DataFrame().equals(pd.DataFrame({"x": range(3)})) - - def test_resultset_polars_dataframe(result_set, monkeypatch): assert result_set.PolarsDataFrame().frame_equal(pl.DataFrame({"x": range(3)})) @@ -71,189 +65,276 @@ def test_resultset_str(result_set): assert str(result_set) == "+---+\n| x |\n+---+\n| 0 |\n| 1 |\n| 2 |\n+---+" -def test_resultset_repr_html(result_set): - assert result_set._repr_html_() == ( - "\n \n \n " - "\n \n \n \n " - "\n \n 
\n \n " - "\n \n \n \n " - "\n \n
x
0
1
2
" - ) - - def test_resultset_config_autolimit_dict(result, config): config.autolimit = 1 - resultset = ResultSet(result, config).fetch_results() + resultset = ResultSet(result, config) assert resultset.dict() == {"x": (0,)} -def test_resultset_with_non_sqlalchemy_results(config): - df = pd.DataFrame({"x": range(3)}) # noqa - conn = Connection(engine=create_engine("duckdb://")) - result = conn.execute("SELECT * FROM df") - assert ResultSet(result, config).fetch_results() == [(0,), (1,), (2,)] +def _get_number_of_rows_in_html_table(html): + """ + Returns the number of tags within the section + """ + pattern = r"(.*?)<\/tbody>" + tbody_content = re.findall(pattern, html, re.DOTALL)[0] + row_count = len(re.findall(r"", tbody_content)) + return row_count -def test_none_pretty(config): - conn = Connection(engine=create_engine("sqlite://")) - result = conn.execute("create table some_table (name, age)") - result_set = ResultSet(result, config) - assert result_set.pretty is None - assert "" == str(result_set) +# TODO: add dbapi tests -def test_lazy_loading(result, config): - resultset = ResultSet(result, config) - assert len(resultset._results) == 0 - resultset.fetch_results() - assert len(resultset._results) == 3 +@pytest.fixture +def results(ip_empty): + engine = create_engine("duckdb://") -@pytest.mark.parametrize( - "autolimit, expected", - [ - (None, 3), - (False, 3), - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_autolimit(result, config, autolimit, expected): - config.autolimit = autolimit - resultset = ResultSet(result, config) - assert len(resultset._results) == 0 - resultset.fetch_results() - assert len(resultset._results) == expected + engine.execute("CREATE TABLE a (x INT,);") + engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") -@pytest.mark.parametrize( - "displaylimit, expected", - [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_displaylimit(result, config, displaylimit, 
expected): - config.displaylimit = displaylimit - result_set = ResultSet(result, config) + sql = "SELECT * FROM a" + results = engine.execute(sql) - assert len(result_set._results) == 0 - result_set.fetch_results() - html = result_set._repr_html_() - row_count = _get_number_of_rows_in_html_table(html) - assert row_count == expected + results.fetchmany = Mock(wraps=results.fetchmany) + results.fetchall = Mock(wraps=results.fetchall) + results.fetchone = Mock(wraps=results.fetchone) + # results.fetchone = Mock(side_effect=ValueError("fetchone called")) -@pytest.mark.parametrize( - "displaylimit, expected_display", - [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_displaylimit_fetch_all( - result, config, displaylimit, expected_display -): - max_results_count = 3 - config.autolimit = False - config.displaylimit = displaylimit - result_set = ResultSet(result, config) + yield results - # Initialize result_set without fetching results - assert len(result_set._results) == 0 - # Fetch the min number of rows (based on configuration) - result_set.fetch_results() +@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"]) +def test_convert_to_dataframe(ip_empty, uri): + engine = create_engine(uri) - html = result_set._repr_html_() - row_count = _get_number_of_rows_in_html_table(html) - expected_results = ( - max_results_count - if expected_display + 1 >= max_results_count - else expected_display + 1 - ) + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 100000 - assert len(result_set._results) == expected_results - assert row_count == expected_display + results = engine.execute("CREATE TABLE a (x INT);") + # results = engine.execute("CREATE TABLE test (n INT, name TEXT)") - # Fetch the the rest results, but don't display them in the table - result_set.fetch_results(fetch_all=True) + rs = ResultSet(results, mock) + df = rs.DataFrame() - html = result_set._repr_html_() - row_count = _get_number_of_rows_in_html_table(html) + assert 
df.to_dict() == {} - assert len(result_set._results) == max_results_count - assert row_count == expected_display +def test_convert_to_dataframe_2(ip_empty): + engine = create_engine("duckdb://") + + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 100000 + + engine.execute("CREATE TABLE a (x INT,);") + results = engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") + rs = ResultSet(results, mock) + df = rs.DataFrame() + + assert df.to_dict() == {"Count": {0: 5}} + + +@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"]) +def test_convert_to_dataframe_3(ip_empty, uri): + engine = create_engine(uri) + + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 100000 + + engine.execute("CREATE TABLE a (x INT);") + engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") + results = engine.execute("SELECT * FROM a") + # from ipdb import set_trace + + # set_trace() + rs = ResultSet(results, mock, sql="SELECT * FROM a", engine=engine) + df = rs.DataFrame() + + # TODO: check native duckdb was called if using duckdb + assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}} -@pytest.mark.parametrize( - "displaylimit, expected_display", - [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_list(result, config, displaylimit, expected_display): - max_results_count = 3 - config.autolimit = False - config.displaylimit = displaylimit - result_set = ResultSet(result, config) - # Initialize result_set without fetching results - assert len(result_set._results) == 0 +def test_done_fetching_if_reached_autolimit(results): + mock = Mock() + mock.autolimit = 2 + mock.displaylimit = 100 - # Fetch the min number of rows (based on configuration) - result_set.fetch_results() + rs = ResultSet(results, mock) - expected_results = ( - max_results_count - if expected_display + 1 >= max_results_count - else expected_display + 1 - ) + assert rs.did_finish_fetching() is True - assert len(result_set._results) == expected_results - 
assert len(list(result_set)) == max_results_count + +def test_done_fetching_if_reached_autolimit_2(results): + mock = Mock() + mock.autolimit = 4 + mock.displaylimit = 100 + + rs = ResultSet(results, mock) + list(rs) + + assert rs.did_finish_fetching() is True + + +@pytest.mark.parametrize("method", ["__repr__", "_repr_html_"]) +@pytest.mark.parametrize("autolimit", [1000_000, 0]) +def test_no_displaylimit(results, method, autolimit): + mock = Mock() + mock.displaylimit = 0 + mock.autolimit = autolimit + + rs = ResultSet(results, mock) + getattr(rs, method)() + + assert rs._results == [(1,), (2,), (3,), (4,), (5,)] + assert rs.did_finish_fetching() is True + + +def test_no_fetching_if_one_result(): + engine = create_engine("duckdb://") + engine.execute("CREATE TABLE a (x INT,);") + engine.execute("INSERT INTO a(x) VALUES (1);") + + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 1000_000 + + results = engine.execute("SELECT * FROM a") + results.fetchmany = Mock(wraps=results.fetchmany) + results.fetchall = Mock(wraps=results.fetchall) + results.fetchone = Mock(wraps=results.fetchone) + + rs = ResultSet(results, mock) + + assert rs.did_finish_fetching() is True + results.fetchall.assert_not_called() + results.fetchmany.assert_called_once_with(size=2) + results.fetchone.assert_not_called() + + str(rs) + list(rs) + + results.fetchall.assert_not_called() + results.fetchmany.assert_called_once_with(size=2) + results.fetchone.assert_not_called() + + +# TODO: try with other values, and also change the display limit and re-run the repr +# TODO: try with __repr__ and __str__ +def test_resultset_fetches_required_rows(results): + mock = Mock() + mock.displaylimit = 3 + mock.autolimit = 1000_000 + + ResultSet(results, mock) + # rs.fetch_results() + # rs._repr_html_() + # str(rs) + + results.fetchall.assert_not_called() + results.fetchmany.assert_called_once_with(size=2) + results.fetchone.assert_not_called() + + +def test_fetches_remaining_rows(results): + mock = 
Mock() + mock.displaylimit = 1 + mock.autolimit = 1000_000 + + rs = ResultSet(results, mock) + + # this will trigger fetching two + str(rs) + + results.fetchall.assert_not_called() + results.fetchmany.assert_has_calls([call(size=2)]) + results.fetchone.assert_not_called() + + # this will trigger fetching the rest + assert list(rs) == [(1,), (2,), (3,), (4,), (5,)] + + results.fetchall.assert_called_once_with() + results.fetchmany.assert_has_calls([call(size=2)]) + results.fetchone.assert_not_called() + + rs.sqlaproxy.fetchmany = Mock(side_effect=ValueError("fetchmany called")) + rs.sqlaproxy.fetchall = Mock(side_effect=ValueError("fetchall called")) + rs.sqlaproxy.fetchone = Mock(side_effect=ValueError("fetchone called")) + + # this should not trigger any more fetching + assert list(rs) == [(1,), (2,), (3,), (4,), (5,)] @pytest.mark.parametrize( - "autolimit, expected_results", + "method, repr_", [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), + [ + "__repr__", + "+---+\n| x |\n+---+\n| 1 |\n| 2 |\n| 3 |\n+---+", + ], + [ + "_repr_html_", + "\n \n \n \n " + "\n \n \n \n " + "\n \n \n " + "\n \n \n " + "\n \n \n
x
1
2
3
", + ], ], ) -def test_lazy_loading_autolimit_list(result, config, autolimit, expected_results): - config.autolimit = autolimit - result_set = ResultSet(result, config) - assert len(result_set._results) == 0 +def test_resultset_fetches_required_rows_repr_html(results, method, repr_): + mock = Mock() + mock.displaylimit = 3 + mock.autolimit = 1000_000 - result_set.fetch_results() + rs = ResultSet(results, mock) + rs_repr = getattr(rs, method)() - assert len(result_set._results) == expected_results - assert len(list(result_set)) == expected_results + assert repr_ in rs_repr + assert rs.did_finish_fetching() is False + results.fetchall.assert_not_called() + results.fetchmany.assert_has_calls([call(size=2), call(size=1)]) + results.fetchone.assert_not_called() -def _get_number_of_rows_in_html_table(html): - """ - Returns the number of tags within the section - """ - pattern = r"(.*?)<\/tbody>" - tbody_content = re.findall(pattern, html, re.DOTALL)[0] - row_count = len(re.findall(r"", tbody_content)) +def test_resultset_fetches_no_rows(results): + mock = Mock() + mock.displaylimit = 1 + mock.autolimit = 1000_000 - return row_count + ResultSet(results, mock) + + results.fetchmany.assert_has_calls([call(size=2)]) + results.fetchone.assert_not_called() + results.fetchall.assert_not_called() + + +def test_resultset_autolimit_one(results): + mock = Mock() + mock.displaylimit = 10 + mock.autolimit = 1 + + rs = ResultSet(results, mock) + repr(rs) + str(rs) + rs._repr_html_() + list(rs) + + results.fetchmany.assert_has_calls([call(size=1)]) + results.fetchone.assert_not_called() + results.fetchall.assert_not_called() + + +# TODO: try with more values of displaylimit +# TODO: test some edge cases. e.g., displaylimit is set to 10 but we only have 5 rows +def test_displaylimit_message(results): + mock = Mock() + mock.displaylimit = 1 + mock.autolimit = 0 + + rs = ResultSet(results, mock) + + assert "Truncated to displaylimit of 1" in rs._repr_html_()