From 4df42c696695e063ac7abe0d28743eb4334c4faa Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Tue, 27 Jun 2023 10:50:36 -0600 Subject: [PATCH 01/16] Error when executing multiple SQL statements when using DuckDB with `autopandas` on --- CHANGELOG.md | 1 + src/sql/run.py | 6 +++--- src/tests/integration/test_duckDB.py | 9 +++++++++ 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 60ab95b2d..b9ec5a724 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ * [Doc] Document --persist-replace in API section (#539) * [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631) * [Fix] Refactored `ResultSet` to lazy loading (#470) +* [Fix] Erro when executing multiple SQL statements when using DuckDB with `autopandas` on ## 0.7.9 (2023-06-19) diff --git a/src/sql/run.py b/src/sql/run.py index fb8bcaff3..dc93df57d 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -612,8 +612,8 @@ def run(conn, sql, config): statement = sqlalchemy.sql.text(statement) if duckdb_autopandas: - conn = conn.engine.raw_connection() - cursor = conn.cursor() + conn_duckdb_raw = conn.engine.raw_connection() + cursor = conn_duckdb_raw.cursor() cursor.execute(str(statement)) else: @@ -627,7 +627,7 @@ def run(conn, sql, config): # bypass ResultSet and use duckdb's native method to return a pandas data frame if duckdb_autopandas: df = cursor.df() - conn.close() + conn_duckdb_raw.close() return df else: resultset = ResultSet(result, config) diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index 18b47a56f..90ee5620a 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -27,3 +27,12 @@ def test_auto_commit_mode_off(ip_with_duckDB, caplog): # Check the tables is created tables_out = ip_with_duckDB.run_cell("%sql SHOW TABLES;").result assert any("weather" == table[0] for table in tables_out) + + +def test_autopandas_with_multiple_statements(ip_with_duckDB): + ip_with_duckDB.run_cell("%config SqlMagic.autopandas=True") + out = ip_with_duckDB.run_cell( + "%sql SELECT * FROM weather LIMIT 3; SELECT * FROM weather LIMIT 3;" + ) + + assert out.error_in_exec is None From 17799b219ca377c94b54e2a8954be2c2ea1e2c23 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Tue, 27 Jun 2023 10:54:22 -0600 Subject: [PATCH 02/16] lint --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b9ec5a724..fa60a0fb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ * [Doc] Document --persist-replace in API section (#539) * [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631) * [Fix] Refactored `ResultSet` to lazy loading (#470) -* [Fix] Erro when executing multiple SQL statements when using DuckDB with `autopandas` on +* [Fix] Error when executing multiple SQL statements when using DuckDB with `autopandas` on ## 0.7.9 (2023-06-19) From 585f125101c2cf143f050a2c36fe6c0bb919d3ee Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Tue, 27 Jun 2023 11:06:27 -0600 Subject: [PATCH 03/16] fix --- src/tests/integration/test_duckDB.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index 90ee5620a..fb61b2053 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -31,8 +31,7 @@ def test_auto_commit_mode_off(ip_with_duckDB, caplog): def test_autopandas_with_multiple_statements(ip_with_duckDB): ip_with_duckDB.run_cell("%config SqlMagic.autopandas=True") - out = ip_with_duckDB.run_cell( - "%sql SELECT * FROM weather LIMIT 3; SELECT * FROM weather LIMIT 3;" - ) + ip_with_duckDB.run_cell("%sql CREATE TABLE weather (city VARCHAR,);") + out = ip_with_duckDB.run_cell("%sql SELECT * FROM weather; SELECT * FROM weather;") assert out.error_in_exec is None From 36e106b5915c6348c0d8e007e12b5eefcfedf56f Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Tue, 27 Jun 2023 11:30:44 -0600 Subject: [PATCH 04/16] fix --- src/tests/integration/test_duckDB.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index fb61b2053..d869873bc 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -29,9 +29,10 @@ def test_auto_commit_mode_off(ip_with_duckDB, caplog): assert any("weather" == table[0] for table in tables_out) -def test_autopandas_with_multiple_statements(ip_with_duckDB): - ip_with_duckDB.run_cell("%config SqlMagic.autopandas=True") - ip_with_duckDB.run_cell("%sql CREATE TABLE weather (city VARCHAR,);") - out = ip_with_duckDB.run_cell("%sql SELECT * FROM weather; SELECT * FROM weather;") +def test_autopandas_with_multiple_statements(ip_empty): + ip_empty.run_cell("%sql duckdb://") + # ip_empty.run_cell("%config SqlMagic.autopandas=True") + ip_empty.run_cell("%sql CREATE TABLE weather (city VARCHAR,);") + out = ip_empty.run_cell("%sql SELECT * FROM weather; SELECT * FROM weather;") assert out.error_in_exec is None From 57b6fa7bb4a931ccd1c464ae3e6011308a3ba894 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Tue, 27 Jun 2023 21:52:31 -0600 Subject: [PATCH 05/16] more testing, fixes test case --- src/sql/run.py | 16 ++++++++++---- src/tests/integration/test_duckDB.py | 32 +++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 7 deletions(-) diff --git a/src/sql/run.py b/src/sql/run.py index dc93df57d..5b7d5a595 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -590,9 +590,13 @@ def run(conn, sql, config): # returning only when sql is empty string return "Connected: %s" % conn.name - for statement in sqlparse.split(sql): - first_word = sql.strip().split()[0].lower() + statements = sqlparse.split(sql) + last_statement_is_select = _first_word(statements[-1]) + + for index, statement in enumerate(statements): + first_word = _first_word(statement) manual_commit = False + is_last_statement = index == len(statements) - 1 # attempting to run a transaction if first_word == "begin": @@ -611,7 +615,7 @@ def run(conn, sql, config): if not is_custom_connection: statement = sqlalchemy.sql.text(statement) - if duckdb_autopandas: + if duckdb_autopandas and last_statement_is_select and is_last_statement: conn_duckdb_raw = conn.engine.raw_connection() cursor = conn_duckdb_raw.cursor() cursor.execute(str(statement)) @@ -625,7 +629,7 @@ def run(conn, sql, config): display_affected_rowcount(result.rowcount) # bypass ResultSet and use duckdb's native method to return a pandas data frame - if duckdb_autopandas: + if duckdb_autopandas and last_statement_is_select: df = cursor.df() conn_duckdb_raw.close() return df @@ -637,6 +641,10 @@ def run(conn, sql, config): return select_df_type(resultset, config) +def _first_word(sql): + return sql.strip().split()[0].lower() + + def raw_run(conn, sql): return conn.session.execute(sqlalchemy.sql.text(sql)) diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index d869873bc..e5fd38b2e 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -1,5 +1,7 @@ import logging +import pytest + def test_auto_commit_mode_on(ip_with_duckDB, caplog): with caplog.at_level(logging.DEBUG): @@ -29,10 +31,34 @@ def test_auto_commit_mode_off(ip_with_duckDB, caplog): assert any("weather" == table[0] for table in tables_out) -def test_autopandas_with_multiple_statements(ip_empty): +@pytest.mark.parametrize( + "config", + [ + "SqlMagic.autopandas=True", + "SqlMagic.autopandas=False", + ], + ids=[ + "autopandas_on", + "autopandas_off", + ], +) +@pytest.mark.parametrize( + "sql", + [ + "%sql SELECT * FROM weather; SELECT * FROM weather;", + "%sql CREATE TABLE names (name VARCHAR,); SELECT * FROM weather;", + ], + ids=[ + "multiple_selects", + "multiple_statements", + ], +) +def test_multiple_statements(ip_empty, config, sql): ip_empty.run_cell("%sql duckdb://") - # ip_empty.run_cell("%config SqlMagic.autopandas=True") + ip_empty.run_cell(config) ip_empty.run_cell("%sql CREATE TABLE weather (city VARCHAR,);") - out = ip_empty.run_cell("%sql SELECT * FROM weather; SELECT * FROM weather;") + ip_empty.run_cell("%sql INSERT INTO weather VALUES ('NYC');") + out = ip_empty.run_cell(sql) assert out.error_in_exec is None + assert out.result.dict() == {"city": ("NYC",)} From 345c922fd5755d5e69ae84c23363c5e3df73598b Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Wed, 28 Jun 2023 07:31:27 -0600 Subject: [PATCH 06/16] adds missing issue number --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fa60a0fb5..71af0aefd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,7 @@ * [Doc] Document --persist-replace in API section (#539) * [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631) * [Fix] Refactored `ResultSet` to lazy loading (#470) -* [Fix] Error when executing multiple SQL statements when using DuckDB with `autopandas` on +* [Fix] Error when executing multiple SQL statements when using DuckDB with `autopandas` on (#674) ## 0.7.9 (2023-06-19) From 400bfd025f711b9d01dd608eec9b7c221031a76b Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Wed, 28 Jun 2023 23:43:32 -0600 Subject: [PATCH 07/16] wip --- src/sql/connection.py | 2 + src/sql/inspect.py | 2 +- src/sql/run.py | 281 ++++++++---- src/tests/conftest.py | 14 + src/tests/integration/test_duckDB.py | 116 ++++- src/tests/test_resultset.py | 619 +++++++++++++++++++-------- 6 files changed, 752 insertions(+), 282 deletions(-) diff --git a/src/sql/connection.py b/src/sql/connection.py index d5435055b..58352f797 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -203,6 +203,8 @@ def __init__(self, engine, alias=None): ) self.dialect = engine.url.get_dialect() + + # TODO: delete this! self.session = self._create_session(engine, self.url) self.connections[alias or self.url] = self diff --git a/src/sql/inspect.py b/src/sql/inspect.py index f4455370c..1d7cfc103 100644 --- a/src/sql/inspect.py +++ b/src/sql/inspect.py @@ -18,7 +18,7 @@ def _get_inspector(conn): if not Connection.current: raise exceptions.RuntimeError("No active connection") else: - return inspect(Connection.current.session) + return inspect(Connection.current.engine) class DatabaseInspection: diff --git a/src/sql/run.py b/src/sql/run.py index 5b7d5a595..acaab205b 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -9,6 +9,7 @@ import prettytable import sqlalchemy +from sqlalchemy.exc import ResourceClosedError import sqlparse from sql.connection import Connection from sql import exceptions, display @@ -109,27 +110,88 @@ class ResultSet(ColumnGuesserMixin): Can access rows listwise, or by string value of leftmost column. """ - def __init__(self, sqlaproxy, config): + def __init__(self, sqlaproxy, config, sql=None, engine=None): self.config = config - self.keys = {} - self._results = [] self.truncated = False self.sqlaproxy = sqlaproxy + self.sql = sql + self.engine = engine + + self._keys = None + self._field_names = None + self._results = [] # https://peps.python.org/pep-0249/#description self.is_dbapi_results = hasattr(sqlaproxy, "description") - self.pretty = None + # NOTE: this will trigger key fetching + self.pretty_table = self._init_table() + + self._done_fetching = False + + if self.config.autolimit == 1: + self.fetchmany(size=1) + self.did_finish_fetching() + # EXPLAIN WHY WE NEED TWO + else: + self.fetchmany(size=2) + + def extend_results(self, elements): + self._results.extend(elements) + + # NOTE: we shouold use add_rows but there is a subclass that behaves weird + # so using add_row for now + for e in elements: + self.pretty_table.add_row(e) + + def done_fetching(self): + self._done_fetching = True + self.sqlaproxy.close() + + def did_finish_fetching(self): + return self._done_fetching + + # NOTE: this triggers key fetching + @property + def field_names(self): + if self._field_names is None: + self._field_names = unduplicate_field_names(self.keys) + + return self._field_names + + # NOTE: this triggers key fetching + @property + def keys(self): + if self._keys is not None: + return self._keys + + if not self.is_dbapi_results: + try: + self._keys = self.sqlaproxy.keys() + # sqlite raises this error when running a script that doesn't return rows + # e.g, 'CREATE TABLE' but others don't (e.g., duckdb) + except ResourceClosedError: + self._keys = [] + return self._keys + + elif isinstance(self.sqlaproxy.description, Iterable): + self._keys = [i[0] for i in self.sqlaproxy.description] + else: + self._keys = [] + + return self._keys def _repr_html_(self): + self.fetch_for_repr_if_needed() + _cell_with_spaces_pattern = re.compile(r"()( {2,})") - if self.pretty: - self.pretty.add_rows(self) - result = self.pretty.get_html_string() + if self.pretty_table: + self.pretty_table.add_rows(self) + result = self.pretty_table.get_html_string() # to create clickable links result = html.unescape(result) result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result) - if self.truncated: + if self.config.displaylimit != 0: HTML = ( '%s\n' "Truncated to displaylimit of %d" @@ -139,7 +201,7 @@ def _repr_html_(self): 'displaylimit' # noqa: E501 " configuration" ) - result = HTML % (result, self.pretty.row_count) + result = HTML % (result, self.config.displaylimit) return result else: return None @@ -148,18 +210,18 @@ def __len__(self): return len(self._results) def __iter__(self): - results = self._fetch_query_results(fetch_all=True) + self.fetchall() - for result in results: + for result in self._results: yield result def __str__(self, *arg, **kwarg): - if self.pretty: - self.pretty.add_rows(self) - return str(self.pretty or "") + self.fetch_for_repr_if_needed() + return str(self.pretty_table) def __repr__(self) -> str: - return str(self) + self.fetch_for_repr_if_needed() + return str(self.pretty_table) def __eq__(self, another: object) -> bool: return self._results == another @@ -195,10 +257,36 @@ def DataFrame(self, payload): "Returns a Pandas DataFrame instance built from the result set." import pandas as pd + # TODO: re-add + # payload[ + # "connection_info" + # ] = Connection.current._get_curr_sqlalchemy_connection_info() + + # TODO: test with generic db connection + self.sqlaproxy.dialect.name + + # list(self.engine.execute(self.sql)) + + # if not self.has_more_results: + # return pd.DataFrame() + + if self.did_finish_fetching(): + return pd.DataFrame(self, columns=(self and self.keys) or []) + + # only run when using duckdb + if self.engine.dialect.name == "duckdb": + # why do we need this? + self.sqlaproxy.close() + + conn_duckdb_raw = self.engine.raw_connection() + cursor = conn_duckdb_raw.cursor() + cursor.execute(str(self.sql)) + df = cursor.df() + conn_duckdb_raw.close() + return df + frame = pd.DataFrame(self, columns=(self and self.keys) or []) - payload[ - "connection_info" - ] = Connection.current._get_curr_sqlalchemy_connection_info() + return frame @telemetry.log_call("polars-data-frame") @@ -320,9 +408,9 @@ def bar(self, key_word_sep=" ", title=None, **kwargs): def csv(self, filename=None, **format_params): """Generate results in comma-separated form. Write to ``filename`` if given. Any other parameters will be passed on to csv.writer.""" - if not self.pretty: + if not self.pretty_table: return None # no results - self.pretty.add_rows(self) + self.pretty_table.add_rows(self) if filename: encoding = format_params.get("encoding", "utf-8") outfile = open(filename, "w", newline="", encoding=encoding) @@ -339,47 +427,56 @@ def csv(self, filename=None, **format_params): else: return outfile.getvalue() - def fetch_results(self, fetch_all=False): - """ - Returns a limited representation of the query results. + def fetchmany(self, size): + """Fetch n results and add it to the results""" + if not self.did_finish_fetching(): + try: + returned = self.sqlaproxy.fetchmany(size=size) + # sqlite raises this error when running a script that doesn't return rows + # e.g, 'CREATE TABLE' but others don't (e.g., duckdb) + except ResourceClosedError: + self.done_fetching() + return - Parameters - ---------- - fetch_all : bool default False - Return all query rows - """ - is_dbapi_results = self.is_dbapi_results - sqlaproxy = self.sqlaproxy - config = self.config + self.extend_results(returned) - if is_dbapi_results: - should_try_fetch_results = True - else: - should_try_fetch_results = sqlaproxy.returns_rows - - if should_try_fetch_results: - # sql alchemy results - if not is_dbapi_results: - self.keys = sqlaproxy.keys() - elif isinstance(sqlaproxy.description, Iterable): - self.keys = [i[0] for i in sqlaproxy.description] - else: - self.keys = [] + # if "SELECT * FROM number_table" in str(self.sql): + # from ipdb import set_trace + + # set_trace() - if len(self.keys) > 0: - self._results = self._fetch_query_results(fetch_all=fetch_all) + if len(returned) < size: + self.done_fetching() - self.field_names = unduplicate_field_names(self.keys) + if ( + self.config.autolimit is not None + and self.config.autolimit != 0 + and len(self._results) >= self.config.autolimit + ): + self.done_fetching() - _style = None + def fetch_for_repr_if_needed(self): + if self.config.displaylimit == 0: + self.fetchall() - self.pretty = PrettyTable(self.field_names) + missing = self.config.displaylimit - len(self._results) - if isinstance(config.style, str): - _style = prettytable.__dict__[config.style.upper()] - self.pretty.set_style(_style) + if missing > 0: + self.fetchmany(missing) - return self + def fetchall(self): + if not self.did_finish_fetching(): + self.extend_results(self.sqlaproxy.fetchall()) + self.done_fetching() + + def _init_table(self): + pretty = PrettyTable(self.field_names) + + if isinstance(self.config.style, str): + _style = prettytable.__dict__[self.config.style.upper()] + pretty.set_style(_style) + + return pretty def _fetch_query_results(self, fetch_all=False): """ @@ -390,19 +487,19 @@ def _fetch_query_results(self, fetch_all=False): fetch_all : bool default False Return all query rows """ - sqlaproxy = self.sqlaproxy - config = self.config - _should_try_lazy_fetch = hasattr(sqlaproxy, "_soft_closed") + _should_try_lazy_fetch = hasattr(self.sqlaproxy, "_soft_closed") _should_fetch_all = ( - (config.displaylimit == 0 or not config.displaylimit) + (self.config.displaylimit == 0 or not self.config.displaylimit) or fetch_all or not _should_try_lazy_fetch ) - is_autolimit = isinstance(config.autolimit, int) and config.autolimit > 0 + is_autolimit = ( + isinstance(self.config.autolimit, int) and self.config.autolimit > 0 + ) is_connection_closed = ( - sqlaproxy._soft_closed if _should_try_lazy_fetch else False + self.sqlaproxy._soft_closed if _should_try_lazy_fetch else False ) should_return_results = is_connection_closed or ( @@ -416,25 +513,31 @@ def _fetch_query_results(self, fetch_all=False): results = self._results else: if is_autolimit: - results = sqlaproxy.fetchmany(size=config.autolimit) + results = self.sqlaproxy.fetchmany(size=self.config.autolimit) else: if _should_fetch_all: - all_results = sqlaproxy.fetchall() + all_results = self.sqlaproxy.fetchall() results = self._results + all_results self._results = results else: - results = sqlaproxy.fetchmany(size=config.displaylimit) + # operational errors are silenced! + try: + results = self.sqlaproxy.fetchmany( + size=self.config.displaylimit + ) + except Exception as e: + raise RuntimeError("Could not fetch from database") from e if _should_try_lazy_fetch: # Try to fetch an extra row to find out # if there are more results to fetch - row = sqlaproxy.fetchone() + row = self.sqlaproxy.fetchone() if row is not None: results += [row] # Check if we have more rows to show - if config.displaylimit > 0: - self.truncated = len(results) > config.displaylimit + if self.config.displaylimit > 0: + self.truncated = len(results) > self.config.displaylimit return results @@ -507,6 +610,7 @@ def _commit(conn, config, manual_commit): with Session(conn.session) as session: session.commit() except sqlalchemy.exc.OperationalError: + # TODO: missing rollback here? print("The database does not support the COMMIT command") @@ -582,21 +686,16 @@ def run(conn, sql, config): config Configuration object """ - info = conn._get_curr_sqlalchemy_connection_info() - - duckdb_autopandas = info and info.get("dialect") == "duckdb" and config.autopandas - if not sql.strip(): # returning only when sql is empty string return "Connected: %s" % conn.name statements = sqlparse.split(sql) - last_statement_is_select = _first_word(statements[-1]) for index, statement in enumerate(statements): first_word = _first_word(statement) - manual_commit = False - is_last_statement = index == len(statements) - 1 + is_select = first_word == "select" + # manual_commit = False # attempting to run a transaction if first_word == "begin": @@ -608,37 +707,31 @@ def run(conn, sql, config): # regular query else: - manual_commit = set_autocommit(conn, config) + # TODO: re add commmit feature + # _commit(conn=conn, config=config, manual_commit=manual_commit) + # manual_commit = set_autocommit(conn, config) is_custom_connection = Connection.is_custom_connection(conn) # if regular sqlalchemy, pass a text object if not is_custom_connection: statement = sqlalchemy.sql.text(statement) - if duckdb_autopandas and last_statement_is_select and is_last_statement: - conn_duckdb_raw = conn.engine.raw_connection() - cursor = conn_duckdb_raw.cursor() - cursor.execute(str(statement)) + if not is_select: + with Session(conn.engine, expire_on_commit=False) as session: + result = session.execute(statement) + resultset = ResultSet(result, config, statement, conn.engine) + session.commit() else: - result = conn.session.execute(statement) - _commit(conn=conn, config=config, manual_commit=manual_commit) - - if result and config.feedback: - if hasattr(result, "rowcount"): - display_affected_rowcount(result.rowcount) - - # bypass ResultSet and use duckdb's native method to return a pandas data frame - if duckdb_autopandas and last_statement_is_select: - df = cursor.df() - conn_duckdb_raw.close() - return df - else: - resultset = ResultSet(result, config) + result = conn.engine.execute(statement) + resultset = ResultSet(result, config, statement, conn.engine) + + if result and config.feedback: + if hasattr(result, "rowcount"): + display_affected_rowcount(result.rowcount) - # lazy load - resultset.fetch_results() - return select_df_type(resultset, config) + # resultset = ResultSet(result, config, statement, conn.engine) + return select_df_type(resultset, config) def _first_word(sql): diff --git a/src/tests/conftest.py b/src/tests/conftest.py index bd2422e4f..5bceaaad8 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -48,6 +48,19 @@ def clean_conns(): yield +# if we enable it, we'll have to update tests! +# because they expect the error not to be raised +class TestingShell(InteractiveShell): + def run_cell(self, *args, **kwargs): + result = super().run_cell(*args, **kwargs) + + if result.error_in_exec is not None: + # raise RuntimeError("a") from result.error_in_exec + raise result.error_in_exec + + return result + + @pytest.fixture def ip_empty(): ip_session = InteractiveShell() @@ -55,6 +68,7 @@ def ip_empty(): ip_session.register_magics(RenderMagic) ip_session.register_magics(SqlPlotMagic) ip_session.register_magics(SqlCmdMagic) + yield ip_session Connection.close_all() diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index e5fd38b2e..dff8bbd1a 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -1,7 +1,12 @@ +from unittest.mock import Mock import logging +import pandas as pd import pytest +from sql import connection +from sql.run import ResultSet + def test_auto_commit_mode_on(ip_with_duckDB, caplog): with caplog.at_level(logging.DEBUG): @@ -34,8 +39,8 @@ def test_auto_commit_mode_off(ip_with_duckDB, caplog): @pytest.mark.parametrize( "config", [ - "SqlMagic.autopandas=True", - "SqlMagic.autopandas=False", + "%config SqlMagic.autopandas = True", + "%config SqlMagic.autopandas = False", ], ids=[ "autopandas_on", @@ -43,22 +48,119 @@ def test_auto_commit_mode_off(ip_with_duckDB, caplog): ], ) @pytest.mark.parametrize( - "sql", + "sql, tables", [ - "%sql SELECT * FROM weather; SELECT * FROM weather;", - "%sql CREATE TABLE names (name VARCHAR,); SELECT * FROM weather;", + ["%sql SELECT * FROM weather; SELECT * FROM weather;", ["weather"]], + [ + "%sql CREATE TABLE names (name VARCHAR,); SELECT * FROM weather;", + ["weather", "names"], + ], + [ + ( + "%sql CREATE TABLE names (city VARCHAR,);" + "CREATE TABLE more_names (city VARCHAR,);" + "INSERT INTO names VALUES ('NYC');" + "SELECT * FROM names UNION ALL SELECT * FROM more_names;" + ), + ["weather", "names", "more_names"], + ], ], ids=[ "multiple_selects", "multiple_statements", + "multiple_tables_created", ], ) -def test_multiple_statements(ip_empty, config, sql): +def test_multiple_statements(ip_empty, config, sql, tables): ip_empty.run_cell("%sql duckdb://") ip_empty.run_cell(config) + ip_empty.run_cell("%sql CREATE TABLE weather (city VARCHAR,);") ip_empty.run_cell("%sql INSERT INTO weather VALUES ('NYC');") + ip_empty.run_cell("%sql SELECT * FROM weather;") + out = ip_empty.run_cell(sql) + out_tables = ip_empty.run_cell("%sqlcmd tables") + + assert out.error_in_exec is None + + if config == "%config SqlMagic.autopandas = True": + assert out.result.to_dict() == {"city": {0: "NYC"}} + else: + assert out.result.dict() == {"city": ("NYC",)} + + assert set(tables) == set(r[0] for r in out_tables.result._table.rows) + +def test_dataframe_returned_only_if_last_statement_is_select(ip_empty): + ip_empty.run_cell("%sql duckdb://") + ip_empty.run_cell("%config SqlMagic.autopandas=True") + connection.Connection.connections["duckdb://"].engine.raw_connection = Mock( + side_effect=ValueError("some error") + ) + + out = ip_empty.run_cell( + "%sql CREATE TABLE a (c VARCHAR,); CREATE TABLE b (c VARCHAR,);" + ) + + assert out.error_in_exec is None + + +@pytest.mark.parametrize( + "sql", + [ + ( + "%sql CREATE TABLE a (x INT,); CREATE TABLE b (x INT,); " + "INSERT INTO a VALUES (1,); INSERT INTO b VALUES(2,); " + "SELECT * FROM a UNION ALL SELECT * FROM b;" + ), + """\ +%%sql +CREATE TABLE a (x INT,); +CREATE TABLE b (x INT,); +INSERT INTO a VALUES (1,); +INSERT INTO b VALUES(2,); +SELECT * FROM a UNION ALL SELECT * FROM b; +""", + ], +) +def test_commits_all_statements(ip_empty, sql): + ip_empty.run_cell("%sql duckdb://") + out = ip_empty.run_cell(sql) assert out.error_in_exec is None - assert out.result.dict() == {"city": ("NYC",)} + assert out.result.dict() == {"x": (1, 2)} + + +def test_resultset_uses_native_duckdb_df(ip_empty): + from sqlalchemy import create_engine + from sql.connection import Connection + + engine = create_engine("duckdb://") + + engine.execute("CREATE TABLE a (x INT,);") + engine.execute("INSERT INTO a(x) VALUES (10),(20),(30);") + + sql = "SELECT * FROM a" + + # this breaks if there's an open results set + engine.execute(sql).fetchall() + + results = engine.execute(sql) + + Connection.set(engine, displaycon=False) + + results.fetchmany = Mock(wraps=results.fetchmany) + + mock = Mock() + mock.displaylimit = 1 + + result_set = ResultSet(results, mock, sql=sql, engine=engine) + + result_set.fetch_results() + + df = result_set.DataFrame() + + assert isinstance(df, pd.DataFrame) + assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3}} + + results.fetchmany.assert_called_once_with(size=1) diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py index 52c7f338e..922a4b2c1 100644 --- a/src/tests/test_resultset.py +++ b/src/tests/test_resultset.py @@ -1,3 +1,5 @@ +from unittest.mock import Mock, call + from sqlalchemy import create_engine from sql.connection import Connection from pathlib import Path @@ -14,246 +16,503 @@ import re -@pytest.fixture -def config(): - config = Mock() - config.displaylimit = 5 - config.autolimit = 100 - return config +# @pytest.fixture +# def config(): +# config = Mock() +# config.displaylimit = 5 +# config.autolimit = 100 +# return config -@pytest.fixture -def result(): - df = pd.DataFrame({"x": range(3)}) # noqa - engine = sqlalchemy.create_engine("duckdb://") +# @pytest.fixture +# def result(): +# df = pd.DataFrame({"x": range(3)}) # noqa +# engine = sqlalchemy.create_engine("duckdb://") + +# conn = engine.connect() +# result = conn.execute(sqlalchemy.text("select * from df")) +# yield result +# conn.close() + + +# @pytest.fixture +# def result_set(result, config): +# return ResultSet(result, config).fetch_results() + + +# def test_resultset_getitem(result_set): +# assert result_set[0] == (0,) +# assert result_set[0:2] == [(0,), (1,)] + + +# def test_resultset_dict(result_set): +# assert result_set.dict() == {"x": (0, 1, 2)} + + +# def test_resultset_dicts(result_set): +# assert list(result_set.dicts()) == [{"x": 0}, {"x": 1}, {"x": 2}] + + +# def test_resultset_dataframe(result_set, monkeypatch): +# monkeypatch.setattr(run_module.Connection, "current", Mock()) + +# assert result_set.DataFrame().equals(pd.DataFrame({"x": range(3)})) + + +# def test_resultset_polars_dataframe(result_set, monkeypatch): +# assert result_set.PolarsDataFrame().frame_equal(pl.DataFrame({"x": range(3)})) + + +# def test_resultset_csv(result_set, tmp_empty): +# result_set.csv("file.csv") + +# assert Path("file.csv").read_text() == "x\n0\n1\n2\n" + + +# def test_resultset_str(result_set): +# assert str(result_set) == "+---+\n| x |\n+---+\n| 0 |\n| 1 |\n| 2 |\n+---+" + + +# def test_resultset_repr_html(result_set): +# assert result_set._repr_html_() == ( +# "\n \n \n " +# "\n \n \n \n " +# "\n \n \n \n " +# "\n \n \n \n " +# "\n \n
x
0
1
2
" +# ) + + +# def test_resultset_config_autolimit_dict(result, config): +# config.autolimit = 1 +# resultset = ResultSet(result, config).fetch_results() +# assert resultset.dict() == {"x": (0,)} + + +# def test_resultset_with_non_sqlalchemy_results(config): +# df = pd.DataFrame({"x": range(3)}) # noqa +# conn = Connection(engine=create_engine("duckdb://")) +# result = conn.execute("SELECT * FROM df") +# assert ResultSet(result, config).fetch_results() == [(0,), (1,), (2,)] + + +# def test_none_pretty(config): +# conn = Connection(engine=create_engine("sqlite://")) +# result = conn.execute("create table some_table (name, age)") +# result_set = ResultSet(result, config) +# assert result_set.pretty is None +# assert "" == str(result_set) + + +# def test_lazy_loading(result, config): +# resultset = ResultSet(result, config) +# assert len(resultset._results) == 0 +# resultset.fetch_results() +# assert len(resultset._results) == 3 + + +# @pytest.mark.parametrize( +# "autolimit, expected", +# [ +# (None, 3), +# (False, 3), +# (0, 3), +# (1, 1), +# (2, 2), +# (3, 3), +# (4, 3), +# ], +# ) +# def test_lazy_loading_autolimit(result, config, autolimit, expected): +# config.autolimit = autolimit +# resultset = ResultSet(result, config) +# assert len(resultset._results) == 0 +# resultset.fetch_results() +# assert len(resultset._results) == expected + - conn = engine.connect() - result = conn.execute(sqlalchemy.text("select * from df")) - yield result - conn.close() +# @pytest.mark.parametrize( +# "displaylimit, expected", +# [ +# (0, 3), +# (1, 1), +# (2, 2), +# (3, 3), +# (4, 3), +# ], +# ) +# def test_lazy_loading_displaylimit(result, config, displaylimit, expected): +# config.displaylimit = displaylimit +# result_set = ResultSet(result, config) + +# assert len(result_set._results) == 0 +# result_set.fetch_results() +# html = result_set._repr_html_() +# row_count = _get_number_of_rows_in_html_table(html) +# assert row_count == expected + + +# @pytest.mark.parametrize( +# "displaylimit, expected_display", +# [ +# (0, 3), +# (1, 1), +# (2, 2), +# (3, 3), +# (4, 3), +# ], +# ) +# def test_lazy_loading_displaylimit_fetch_all( +# result, config, displaylimit, expected_display +# ): +# max_results_count = 3 +# config.autolimit = False +# config.displaylimit = displaylimit +# result_set = ResultSet(result, config) + +# # Initialize result_set without fetching results +# assert len(result_set._results) == 0 + +# # Fetch the min number of rows (based on configuration) +# result_set.fetch_results() + +# html = result_set._repr_html_() +# row_count = _get_number_of_rows_in_html_table(html) +# expected_results = ( +# max_results_count +# if expected_display + 1 >= max_results_count +# else expected_display + 1 +# ) + +# assert len(result_set._results) == expected_results +# assert row_count == expected_display + +# # Fetch the the rest results, but don't display them in the table +# result_set.fetch_results(fetch_all=True) + +# html = result_set._repr_html_() +# row_count = _get_number_of_rows_in_html_table(html) + +# assert len(result_set._results) == max_results_count +# assert row_count == expected_display + + +# @pytest.mark.parametrize( +# "displaylimit, expected_display", +# [ +# (0, 3), +# (1, 1), +# (2, 2), +# (3, 3), +# (4, 3), +# ], +# ) +# def test_lazy_loading_list(result, config, displaylimit, expected_display): +# max_results_count = 3 +# config.autolimit = False +# config.displaylimit = displaylimit +# result_set = ResultSet(result, config) + +# # Initialize result_set without fetching results +# assert len(result_set._results) == 0 + +# # Fetch the min number of rows (based on configuration) +# result_set.fetch_results() + +# expected_results = ( +# max_results_count +# if expected_display + 1 >= max_results_count +# else expected_display + 1 +# ) + +# assert len(result_set._results) == expected_results +# assert len(list(result_set)) == max_results_count + + +# @pytest.mark.parametrize( +# "autolimit, expected_results", +# [ +# (0, 3), +# (1, 1), +# (2, 2), +# (3, 3), +# (4, 3), +# ], +# ) +# def test_lazy_loading_autolimit_list(result, config, autolimit, expected_results): +# config.autolimit = autolimit +# result_set = ResultSet(result, config) +# assert len(result_set._results) == 0 + +# result_set.fetch_results() + +# assert len(result_set._results) == expected_results +# assert len(list(result_set)) == expected_results + + +# def _get_number_of_rows_in_html_table(html): +# """ +# Returns the number of tags within the section +# """ +# pattern = r"(.*?)<\/tbody>" +# tbody_content = re.findall(pattern, html, re.DOTALL)[0] +# row_count = len(re.findall(r"", tbody_content)) + +# return row_count + +# TODO: add dbapi tests @pytest.fixture -def result_set(result, config): - return ResultSet(result, config).fetch_results() +def results(ip_empty): + engine = create_engine("duckdb://") + engine.execute("CREATE TABLE a (x INT,);") -def test_resultset_getitem(result_set): - assert result_set[0] == (0,) - assert result_set[0:2] == [(0,), (1,)] + engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") + sql = "SELECT * FROM a" + results = engine.execute(sql) -def test_resultset_dict(result_set): - assert result_set.dict() == {"x": (0, 1, 2)} + results.fetchmany = Mock(wraps=results.fetchmany) + results.fetchall = Mock(wraps=results.fetchall) + results.fetchone = Mock(wraps=results.fetchone) + # results.fetchone = Mock(side_effect=ValueError("fetchone called")) -def test_resultset_dicts(result_set): - assert list(result_set.dicts()) == [{"x": 0}, {"x": 1}, {"x": 2}] + yield results -def test_resultset_dataframe(result_set, monkeypatch): - monkeypatch.setattr(run_module.Connection, "current", Mock()) +@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"]) +def test_convert_to_dataframe(ip_empty, uri): + engine = create_engine(uri) - assert result_set.DataFrame().equals(pd.DataFrame({"x": range(3)})) + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 100000 + results = engine.execute("CREATE TABLE a (x INT);") + # results = engine.execute("CREATE TABLE test (n INT, name TEXT)") -def test_resultset_polars_dataframe(result_set, monkeypatch): - assert result_set.PolarsDataFrame().frame_equal(pl.DataFrame({"x": range(3)})) + rs = ResultSet(results, mock) + df = rs.DataFrame() + assert df.to_dict() == {} -def test_resultset_csv(result_set, tmp_empty): - result_set.csv("file.csv") - assert Path("file.csv").read_text() == "x\n0\n1\n2\n" +def test_convert_to_dataframe_2(ip_empty): + engine = create_engine("duckdb://") + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 100000 -def test_resultset_str(result_set): - assert str(result_set) == "+---+\n| x |\n+---+\n| 0 |\n| 1 |\n| 2 |\n+---+" + engine.execute("CREATE TABLE a (x INT,);") + results = engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") + rs = ResultSet(results, mock) + df = rs.DataFrame() + assert df.to_dict() == {"Count": {0: 5}} -def test_resultset_repr_html(result_set): - assert result_set._repr_html_() == ( - "\n \n \n " - "\n \n \n \n " - "\n \n \n \n " - "\n \n \n \n " - "\n \n
x
0
1
2
" - ) +@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"]) +def test_convert_to_dataframe_3(ip_empty, uri): + engine = create_engine(uri) -def test_resultset_config_autolimit_dict(result, config): - config.autolimit = 1 - resultset = ResultSet(result, config).fetch_results() - assert resultset.dict() == {"x": (0,)} + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 100000 + engine.execute("CREATE TABLE a (x INT);") + engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);") + results = engine.execute("SELECT * FROM a") + # from ipdb import set_trace -def test_resultset_with_non_sqlalchemy_results(config): - df = pd.DataFrame({"x": range(3)}) # noqa - conn = Connection(engine=create_engine("duckdb://")) - result = conn.execute("SELECT * FROM df") - assert ResultSet(result, config).fetch_results() == [(0,), (1,), (2,)] + # set_trace() + rs = ResultSet(results, mock, sql="SELECT * FROM a", engine=engine) + df = rs.DataFrame() + # TODO: check native duckdb was called if using duckb + assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}} -def test_none_pretty(config): - conn = Connection(engine=create_engine("sqlite://")) - result = conn.execute("create table some_table (name, age)") - result_set = ResultSet(result, config) - assert result_set.pretty is None - assert "" == str(result_set) +def test_done_fetching_if_reached_autolimit(results): + mock = Mock() + mock.autolimit = 2 + mock.displaylimit = 100 -def test_lazy_loading(result, config): - resultset = ResultSet(result, config) - assert len(resultset._results) == 0 - resultset.fetch_results() - assert len(resultset._results) == 3 + rs = ResultSet(results, mock) + assert rs.did_finish_fetching() is True -@pytest.mark.parametrize( - "autolimit, expected", - [ - (None, 3), - (False, 3), - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_autolimit(result, config, autolimit, expected): - config.autolimit = autolimit - resultset = ResultSet(result, config) - assert len(resultset._results) == 0 - resultset.fetch_results() - assert len(resultset._results) == expected +def test_done_fetching_if_reached_autolimit_2(results): + mock = Mock() + mock.autolimit = 4 + mock.displaylimit = 100 -@pytest.mark.parametrize( - "displaylimit, expected", - [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_displaylimit(result, config, displaylimit, expected): - config.displaylimit = displaylimit - result_set = ResultSet(result, config) + rs = ResultSet(results, mock) + list(rs) - assert len(result_set._results) == 0 - result_set.fetch_results() - html = result_set._repr_html_() - row_count = _get_number_of_rows_in_html_table(html) - assert row_count == expected + assert rs.did_finish_fetching() is True -@pytest.mark.parametrize( - "displaylimit, expected_display", - [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_displaylimit_fetch_all( - result, config, displaylimit, expected_display -): - max_results_count = 3 - config.autolimit = False - config.displaylimit = displaylimit - result_set = ResultSet(result, config) +@pytest.mark.parametrize("method", ["__repr__", "_repr_html_"]) +@pytest.mark.parametrize("autolimit", [1000_000, 0]) +def test_no_displaylimit(results, method, autolimit): + mock = Mock() + mock.displaylimit = 0 + mock.autolimit = autolimit - # Initialize result_set without fetching results - assert len(result_set._results) == 0 + rs = ResultSet(results, mock) + getattr(rs, method)() - # Fetch the min number of rows (based on configuration) - result_set.fetch_results() + assert rs._results == [(1,), (2,), (3,), (4,), (5,)] + assert rs.did_finish_fetching() is True - html = result_set._repr_html_() - row_count = _get_number_of_rows_in_html_table(html) - expected_results = ( - max_results_count - if expected_display + 1 >= max_results_count - else expected_display + 1 - ) - assert len(result_set._results) == expected_results - assert row_count == expected_display +def test_no_fetching_if_one_result(): + engine = create_engine("duckdb://") + engine.execute("CREATE TABLE a (x INT,);") + engine.execute("INSERT INTO a(x) VALUES (1);") - # Fetch the the rest results, but don't display them in the table - result_set.fetch_results(fetch_all=True) + mock = Mock() + mock.displaylimit = 100 + mock.autolimit = 1000_000 - html = result_set._repr_html_() - row_count = _get_number_of_rows_in_html_table(html) + results = engine.execute("SELECT * FROM a") + results.fetchmany = Mock(wraps=results.fetchmany) + results.fetchall = Mock(wraps=results.fetchall) + results.fetchone = Mock(wraps=results.fetchone) - assert len(result_set._results) == max_results_count - assert row_count == expected_display + rs = ResultSet(results, mock) + assert rs.did_finish_fetching() is True + results.fetchall.assert_not_called() + results.fetchmany.assert_called_once_with(size=2) + results.fetchone.assert_not_called() + + str(rs) + list(rs) + + results.fetchall.assert_not_called() + results.fetchmany.assert_called_once_with(size=2) + results.fetchone.assert_not_called() -@pytest.mark.parametrize( - "displaylimit, expected_display", - [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), - ], -) -def test_lazy_loading_list(result, config, displaylimit, expected_display): - max_results_count = 3 - config.autolimit = False - config.displaylimit = displaylimit - result_set = ResultSet(result, config) - # Initialize result_set without fetching results - assert len(result_set._results) == 0 +# TODO: try with other values, and also change the display limit and re-run the repr +# TODO: try with __repr__ and __str__ +def test_resultset_fetches_required_rows(results): + mock = Mock() + mock.displaylimit = 3 + mock.autolimit = 1000_000 - # Fetch the min number of rows (based on configuration) - result_set.fetch_results() + rs = ResultSet(results, mock) + # rs.fetch_results() + # rs._repr_html_() + # str(rs) - expected_results = ( - max_results_count - if expected_display + 1 >= max_results_count - else expected_display + 1 - ) + results.fetchall.assert_not_called() + results.fetchmany.assert_called_once_with(size=2) + results.fetchone.assert_not_called() - assert len(result_set._results) == expected_results - assert len(list(result_set)) == max_results_count + +def test_fetches_remaining_rows(results): + mock = Mock() + mock.displaylimit = 1 + mock.autolimit = 1000_000 + + rs = ResultSet(results, mock) + + # this will trigger fetching two + str(rs) + + results.fetchall.assert_not_called() + results.fetchmany.assert_has_calls([call(size=2)]) + results.fetchone.assert_not_called() + + # this will trigger fetching the rest + assert list(rs) == [(1,), (2,), (3,), (4,), (5,)] + + results.fetchall.assert_called_once_with() + results.fetchmany.assert_has_calls([call(size=2)]) + results.fetchone.assert_not_called() + + rs.sqlaproxy.fetchmany = Mock(side_effect=ValueError("fetchmany called")) + rs.sqlaproxy.fetchall = Mock(side_effect=ValueError("fetchall called")) + rs.sqlaproxy.fetchone = Mock(side_effect=ValueError("fetchone called")) + + # this should not trigger any more fetching + assert list(rs) == [(1,), (2,), (3,), (4,), (5,)] @pytest.mark.parametrize( - "autolimit, expected_results", + "method, repr_", [ - (0, 3), - (1, 1), - (2, 2), - (3, 3), - (4, 3), + [ + "__repr__", + "+---+\n| x |\n+---+\n| 1 |\n| 2 |\n| 3 |\n+---+", + ], + [ + "_repr_html_", + "\n \n \n \n " + "\n \n \n \n " + "\n \n \n " + "\n \n \n " + "\n \n \n
x
1
2
3
", + ], ], ) -def test_lazy_loading_autolimit_list(result, config, autolimit, expected_results): - config.autolimit = autolimit - result_set = ResultSet(result, config) - assert len(result_set._results) == 0 +def test_resultset_fetches_required_rows_repr_html(results, method, repr_): + mock = Mock() + mock.displaylimit = 3 + mock.autolimit = 1000_000 + + rs = ResultSet(results, mock) + rs_repr = getattr(rs, method)() + + assert repr_ in rs_repr + assert rs.did_finish_fetching() is False + results.fetchall.assert_not_called() + results.fetchmany.assert_has_calls([call(size=2), call(size=1)]) + results.fetchone.assert_not_called() + + +def test_resultset_fetches_no_rows(results): + mock = Mock() + mock.displaylimit = 1 + mock.autolimit = 1000_000 + + ResultSet(results, mock) + + results.fetchmany.assert_has_calls([call(size=2)]) + results.fetchone.assert_not_called() + results.fetchall.assert_not_called() + + +def test_resultset_autolimit_one(results): + mock = Mock() + mock.displaylimit = 10 + mock.autolimit = 1 + + rs = ResultSet(results, mock) + repr(rs) + str(rs) + rs._repr_html_() + list(rs) - result_set.fetch_results() + results.fetchmany.assert_has_calls([call(size=1)]) + results.fetchone.assert_not_called() + results.fetchall.assert_not_called() - assert len(result_set._results) == expected_results - assert len(list(result_set)) == expected_results +# TODO: try with more values of displaylimit +# TODO: test some edge cases. e.g., displaylimit is set to 10 but we only have 5 rows +def test_displaylimit_message(results): + mock = Mock() + mock.displaylimit = 1 + mock.autolimit = 0 -def _get_number_of_rows_in_html_table(html): - """ - Returns the number of tags within the section - """ - pattern = r"(.*?)<\/tbody>" - tbody_content = re.findall(pattern, html, re.DOTALL)[0] - row_count = len(re.findall(r"", tbody_content)) + rs = ResultSet(results, mock) - return row_count + assert "Truncated to displaylimit of 1" in rs._repr_html_() From f5de93fabe4a0acc0658ad6f01e85cc66343ad5c Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Wed, 28 Jun 2023 23:53:02 -0600 Subject: [PATCH 08/16] cleans up tests --- src/tests/test_resultset.py | 268 ++++++------------------------------ 1 file changed, 45 insertions(+), 223 deletions(-) diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py index 922a4b2c1..96fa90952 100644 --- a/src/tests/test_resultset.py +++ b/src/tests/test_resultset.py @@ -1,9 +1,8 @@ from unittest.mock import Mock, call from sqlalchemy import create_engine -from sql.connection import Connection from pathlib import Path -from unittest.mock import Mock + import pytest import pandas as pd @@ -11,254 +10,77 @@ import sqlalchemy from sql.run import ResultSet -from sql import run as run_module import re -# @pytest.fixture -# def config(): -# config = Mock() -# config.displaylimit = 5 -# config.autolimit = 100 -# return config - - -# @pytest.fixture -# def result(): -# df = pd.DataFrame({"x": range(3)}) # noqa -# engine = sqlalchemy.create_engine("duckdb://") - -# conn = engine.connect() -# result = conn.execute(sqlalchemy.text("select * from df")) -# yield result -# conn.close() - - -# @pytest.fixture -# def result_set(result, config): -# return ResultSet(result, config).fetch_results() - - -# def test_resultset_getitem(result_set): -# assert result_set[0] == (0,) -# assert result_set[0:2] == [(0,), (1,)] - - -# def test_resultset_dict(result_set): -# assert result_set.dict() == {"x": (0, 1, 2)} - - -# def test_resultset_dicts(result_set): -# assert list(result_set.dicts()) == [{"x": 0}, {"x": 1}, {"x": 2}] +@pytest.fixture +def config(): + config = Mock() + config.displaylimit = 5 + config.autolimit = 100 + return config -# def test_resultset_dataframe(result_set, monkeypatch): -# monkeypatch.setattr(run_module.Connection, "current", Mock()) +@pytest.fixture +def result(): + df = pd.DataFrame({"x": range(3)}) # noqa + engine = sqlalchemy.create_engine("duckdb://") -# assert result_set.DataFrame().equals(pd.DataFrame({"x": range(3)})) + conn = engine.connect() + result = conn.execute(sqlalchemy.text("select * from df")) + yield result + conn.close() -# def test_resultset_polars_dataframe(result_set, monkeypatch): -# assert result_set.PolarsDataFrame().frame_equal(pl.DataFrame({"x": range(3)})) +@pytest.fixture +def result_set(result, config): + return ResultSet(result, config) -# def test_resultset_csv(result_set, tmp_empty): -# result_set.csv("file.csv") +def test_resultset_getitem(result_set): + assert result_set[0] == (0,) + assert result_set[0:2] == [(0,), (1,)] -# assert Path("file.csv").read_text() == "x\n0\n1\n2\n" +def test_resultset_dict(result_set): + assert result_set.dict() == {"x": (0, 1, 2)} -# def test_resultset_str(result_set): -# assert str(result_set) == "+---+\n| x |\n+---+\n| 0 |\n| 1 |\n| 2 |\n+---+" +def test_resultset_dicts(result_set): + assert list(result_set.dicts()) == [{"x": 0}, {"x": 1}, {"x": 2}] -# def test_resultset_repr_html(result_set): -# assert result_set._repr_html_() == ( -# "\n \n \n " -# "\n \n \n \n " -# "\n \n \n \n " -# "\n \n \n \n " -# "\n \n
x
0
1
2
" -# ) +def test_resultset_polars_dataframe(result_set, monkeypatch): + assert result_set.PolarsDataFrame().frame_equal(pl.DataFrame({"x": range(3)})) -# def test_resultset_config_autolimit_dict(result, config): -# config.autolimit = 1 -# resultset = ResultSet(result, config).fetch_results() -# assert resultset.dict() == {"x": (0,)} +def test_resultset_csv(result_set, tmp_empty): + result_set.csv("file.csv") -# def test_resultset_with_non_sqlalchemy_results(config): -# df = pd.DataFrame({"x": range(3)}) # noqa -# conn = Connection(engine=create_engine("duckdb://")) -# result = conn.execute("SELECT * FROM df") -# assert ResultSet(result, config).fetch_results() == [(0,), (1,), (2,)] + assert Path("file.csv").read_text() == "x\n0\n1\n2\n" -# def test_none_pretty(config): -# conn = Connection(engine=create_engine("sqlite://")) -# result = conn.execute("create table some_table (name, age)") -# result_set = ResultSet(result, config) -# assert result_set.pretty is None -# assert "" == str(result_set) +def test_resultset_str(result_set): + assert str(result_set) == "+---+\n| x |\n+---+\n| 0 |\n| 1 |\n| 2 |\n+---+" -# def test_lazy_loading(result, config): -# resultset = ResultSet(result, config) -# assert len(resultset._results) == 0 -# resultset.fetch_results() -# assert len(resultset._results) == 3 +def test_resultset_config_autolimit_dict(result, config): + config.autolimit = 1 + resultset = ResultSet(result, config) + assert resultset.dict() == {"x": (0,)} -# @pytest.mark.parametrize( -# "autolimit, expected", -# [ -# (None, 3), -# (False, 3), -# (0, 3), -# (1, 1), -# (2, 2), -# (3, 3), -# (4, 3), -# ], -# ) -# def test_lazy_loading_autolimit(result, config, autolimit, expected): -# config.autolimit = autolimit -# resultset = ResultSet(result, config) -# assert len(resultset._results) == 0 -# resultset.fetch_results() -# assert len(resultset._results) == expected +def _get_number_of_rows_in_html_table(html): + """ + Returns the number of tags within the section + """ + pattern = r"(.*?)<\/tbody>" + tbody_content = re.findall(pattern, html, re.DOTALL)[0] + row_count = len(re.findall(r"", tbody_content)) + return row_count -# @pytest.mark.parametrize( -# "displaylimit, expected", -# [ -# (0, 3), -# (1, 1), -# (2, 2), -# (3, 3), -# (4, 3), -# ], -# ) -# def test_lazy_loading_displaylimit(result, config, displaylimit, expected): -# config.displaylimit = displaylimit -# result_set = ResultSet(result, config) - -# assert len(result_set._results) == 0 -# result_set.fetch_results() -# html = result_set._repr_html_() -# row_count = _get_number_of_rows_in_html_table(html) -# assert row_count == expected - - -# @pytest.mark.parametrize( -# "displaylimit, expected_display", -# [ -# (0, 3), -# (1, 1), -# (2, 2), -# (3, 3), -# (4, 3), -# ], -# ) -# def test_lazy_loading_displaylimit_fetch_all( -# result, config, displaylimit, expected_display -# ): -# max_results_count = 3 -# config.autolimit = False -# config.displaylimit = displaylimit -# result_set = ResultSet(result, config) - -# # Initialize result_set without fetching results -# assert len(result_set._results) == 0 - -# # Fetch the min number of rows (based on configuration) -# result_set.fetch_results() - -# html = result_set._repr_html_() -# row_count = _get_number_of_rows_in_html_table(html) -# expected_results = ( -# max_results_count -# if expected_display + 1 >= max_results_count -# else expected_display + 1 -# ) - -# assert len(result_set._results) == expected_results -# assert row_count == expected_display - -# # Fetch the the rest results, but don't display them in the table -# result_set.fetch_results(fetch_all=True) - -# html = result_set._repr_html_() -# row_count = _get_number_of_rows_in_html_table(html) - -# assert len(result_set._results) == max_results_count -# assert row_count == expected_display - - -# @pytest.mark.parametrize( -# "displaylimit, expected_display", -# [ -# (0, 3), -# (1, 1), -# (2, 2), -# (3, 3), -# (4, 3), -# ], -# ) -# def test_lazy_loading_list(result, config, displaylimit, expected_display): -# max_results_count = 3 -# config.autolimit = False -# config.displaylimit = displaylimit -# result_set = ResultSet(result, config) - -# # Initialize result_set without fetching results -# assert len(result_set._results) == 0 - -# # Fetch the min number of rows (based on configuration) -# result_set.fetch_results() - -# expected_results = ( -# max_results_count -# if expected_display + 1 >= max_results_count -# else expected_display + 1 -# ) - -# assert len(result_set._results) == expected_results -# assert len(list(result_set)) == max_results_count - - -# @pytest.mark.parametrize( -# "autolimit, expected_results", -# [ -# (0, 3), -# (1, 1), -# (2, 2), -# (3, 3), -# (4, 3), -# ], -# ) -# def test_lazy_loading_autolimit_list(result, config, autolimit, expected_results): -# config.autolimit = autolimit -# result_set = ResultSet(result, config) -# assert len(result_set._results) == 0 - -# result_set.fetch_results() - -# assert len(result_set._results) == expected_results -# assert len(list(result_set)) == expected_results - - -# def _get_number_of_rows_in_html_table(html): -# """ -# Returns the number of tags within the section -# """ -# pattern = r"(.*?)<\/tbody>" -# tbody_content = re.findall(pattern, html, re.DOTALL)[0] -# row_count = len(re.findall(r"", tbody_content)) - -# return row_count # TODO: add dbapi tests @@ -407,7 +229,7 @@ def test_resultset_fetches_required_rows(results): mock.displaylimit = 3 mock.autolimit = 1000_000 - rs = ResultSet(results, mock) + ResultSet(results, mock) # rs.fetch_results() # rs._repr_html_() # str(rs) From bf6cb38296cb1f1c39114df087e8edc9c1db4d74 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 29 Jun 2023 00:06:33 -0600 Subject: [PATCH 09/16] cleaning up code --- src/sql/run.py | 79 +------------------------------------------------- 1 file changed, 1 insertion(+), 78 deletions(-) diff --git a/src/sql/run.py b/src/sql/run.py index acaab205b..32c0d9731 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -262,14 +262,6 @@ def DataFrame(self, payload): # "connection_info" # ] = Connection.current._get_curr_sqlalchemy_connection_info() - # TODO: test with generic db connection - self.sqlaproxy.dialect.name - - # list(self.engine.execute(self.sql)) - - # if not self.has_more_results: - # return pd.DataFrame() - if self.did_finish_fetching(): return pd.DataFrame(self, columns=(self and self.keys) or []) @@ -440,11 +432,6 @@ def fetchmany(self, size): self.extend_results(returned) - # if "SELECT * FROM number_table" in str(self.sql): - # from ipdb import set_trace - - # set_trace() - if len(returned) < size: self.done_fetching() @@ -478,69 +465,6 @@ def _init_table(self): return pretty - def _fetch_query_results(self, fetch_all=False): - """ - Returns rows of a query result as a list of tuples. - - Parameters - ---------- - fetch_all : bool default False - Return all query rows - """ - _should_try_lazy_fetch = hasattr(self.sqlaproxy, "_soft_closed") - - _should_fetch_all = ( - (self.config.displaylimit == 0 or not self.config.displaylimit) - or fetch_all - or not _should_try_lazy_fetch - ) - - is_autolimit = ( - isinstance(self.config.autolimit, int) and self.config.autolimit > 0 - ) - is_connection_closed = ( - self.sqlaproxy._soft_closed if _should_try_lazy_fetch else False - ) - - should_return_results = is_connection_closed or ( - len(self._results) > 0 and is_autolimit - ) - - if should_return_results: - # this means we already loaded all - # the results to self._results or we use - # autolimit and shouldn't fetch more - results = self._results - else: - if is_autolimit: - results = self.sqlaproxy.fetchmany(size=self.config.autolimit) - else: - if _should_fetch_all: - all_results = self.sqlaproxy.fetchall() - results = self._results + all_results - self._results = results - else: - # operational errors are silenced! - try: - results = self.sqlaproxy.fetchmany( - size=self.config.displaylimit - ) - except Exception as e: - raise RuntimeError("Could not fetch from database") from e - - if _should_try_lazy_fetch: - # Try to fetch an extra row to find out - # if there are more results to fetch - row = self.sqlaproxy.fetchone() - if row is not None: - results += [row] - - # Check if we have more rows to show - if self.config.displaylimit > 0: - self.truncated = len(results) > self.config.displaylimit - - return results - def display_affected_rowcount(rowcount): if rowcount > 0: @@ -669,7 +593,6 @@ def select_df_type(resultset, config): return resultset.PolarsDataFrame(**config.polars_dataframe_kwargs) else: return resultset - # returning only last result, intentionally def run(conn, sql, config): @@ -716,6 +639,7 @@ def run(conn, sql, config): if not is_custom_connection: statement = sqlalchemy.sql.text(statement) + # NOTE: are thre any edge cases we should cover? if not is_select: with Session(conn.engine, expire_on_commit=False) as session: result = session.execute(statement) @@ -730,7 +654,6 @@ def run(conn, sql, config): if hasattr(result, "rowcount"): display_affected_rowcount(result.rowcount) - # resultset = ResultSet(result, config, statement, conn.engine) return select_df_type(resultset, config) From f4a2703287fc5b3bcdba2bff58bb12963e7b5636 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 29 Jun 2023 08:45:10 -0600 Subject: [PATCH 10/16] removes unused session arg --- src/sql/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sql/run.py b/src/sql/run.py index 32c0d9731..229334cc7 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -641,7 +641,7 @@ def run(conn, sql, config): # NOTE: are thre any edge cases we should cover? if not is_select: - with Session(conn.engine, expire_on_commit=False) as session: + with Session(conn.engine) as session: result = session.execute(statement) resultset = ResultSet(result, config, statement, conn.engine) session.commit() From b70479385c78e288c360a55d40af194677875c58 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 29 Jun 2023 10:28:50 -0600 Subject: [PATCH 11/16] fix --- src/sql/connection.py | 2 +- src/sql/inspect.py | 4 ++++ src/sql/run.py | 2 +- src/tests/test_magic_cmd.py | 18 ++++++++++++------ 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/sql/connection.py b/src/sql/connection.py index 58352f797..f151921fe 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -653,7 +653,7 @@ def execute(self, query, with_=None): Executes SQL query on a given connection """ query = self._prepare_query(query, with_) - return self.session.execute(query) + return self.engine.execute(query) atexit.register(Connection.close_all, verbose=True) diff --git a/src/sql/inspect.py b/src/sql/inspect.py index 1d7cfc103..ada6b16b3 100644 --- a/src/sql/inspect.py +++ b/src/sql/inspect.py @@ -239,11 +239,15 @@ def __init__(self, table_name, schema=None) -> None: columns_query_result = sql.run.raw_run( Connection.current, f"SELECT * FROM {table_name} WHERE 1=0" ) + if Connection.is_custom_connection(): columns = [i[0] for i in columns_query_result.description] else: columns = columns_query_result.keys() + # TODO: abstract it internally + columns_query_result.close() + table_stats = dict({}) columns_to_include_in_report = set() columns_with_styles = [] diff --git a/src/sql/run.py b/src/sql/run.py index 229334cc7..9e492420e 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -662,7 +662,7 @@ def _first_word(sql): def raw_run(conn, sql): - return conn.session.execute(sqlalchemy.sql.text(sql)) + return conn.engine.execute(sqlalchemy.sql.text(sql)) class PrettyTable(prettytable.PrettyTable): diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index d9b7ef97b..ea524741b 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -1,11 +1,10 @@ +import sqlite3 import sys import math import pytest from IPython.core.error import UsageError from pathlib import Path -from sqlalchemy import create_engine -from sql.connection import Connection from sql.store import store from sql.inspect import _is_numeric @@ -112,8 +111,15 @@ def test_tables(ip): def test_tables_with_schema(ip, tmp_empty): - conn = Connection(engine=create_engine("sqlite:///my.db")) - conn.execute("CREATE TABLE numbers (some_number FLOAT)") + # TODO: why does this fail? + # ip.run_cell( + # """%%sql sqlite:///my.db + # CREATE TABLE numbers (some_number FLOAT) + # """ + # ) + + with sqlite3.connect("my.db") as conn: + conn.execute("CREATE TABLE numbers (some_number FLOAT)") ip.run_cell( """%%sql @@ -150,8 +156,8 @@ def test_columns(ip, cmd, cols): def test_columns_with_schema(ip, tmp_empty): - conn = Connection(engine=create_engine("sqlite:///my.db")) - conn.execute("CREATE TABLE numbers (some_number FLOAT)") + with sqlite3.connect("my.db") as conn: + conn.execute("CREATE TABLE numbers (some_number FLOAT)") ip.run_cell( """%%sql From 4c91298fbf986e711650b2b85391ae752f6787c1 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 29 Jun 2023 10:32:32 -0600 Subject: [PATCH 12/16] fix --- src/sql/run.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sql/run.py b/src/sql/run.py index 9e492420e..d919c6101 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -630,7 +630,7 @@ def run(conn, sql, config): # regular query else: - # TODO: re add commmit feature + # TODO: add commit feature again # _commit(conn=conn, config=config, manual_commit=manual_commit) # manual_commit = set_autocommit(conn, config) is_custom_connection = Connection.is_custom_connection(conn) @@ -639,7 +639,7 @@ def run(conn, sql, config): if not is_custom_connection: statement = sqlalchemy.sql.text(statement) - # NOTE: are thre any edge cases we should cover? + # NOTE: are there any edge cases we should cover? if not is_select: with Session(conn.engine) as session: result = session.execute(statement) From dc2d1f545776188c0079ccf4db05fb6e859546a4 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 29 Jun 2023 11:21:15 -0600 Subject: [PATCH 13/16] fixing generic driver test --- src/sql/connection.py | 5 +++++ src/sql/run.py | 12 +++++++++--- src/tests/conftest.py | 16 +++++++++++++--- src/tests/test_magic.py | 13 +++++++++++++ 4 files changed, 40 insertions(+), 6 deletions(-) diff --git a/src/sql/connection.py b/src/sql/connection.py index f151921fe..180f93d2a 100644 --- a/src/sql/connection.py +++ b/src/sql/connection.py @@ -674,11 +674,15 @@ def __init__(self, connection, engine): } ) + # TODO: need to close the cursor def execute(self, query): cur = self.engine.cursor() cur.execute(query) return cur + def commit(self): + self.engine.commit() + class CustomConnection(Connection): """ @@ -700,6 +704,7 @@ def __init__(self, payload, engine=None, alias=None): self.name = connection_name_ self.dialect = connection_name_ self.session = CustomSession(self, engine) + self.engine = self.session self.connections[alias or connection_name_] = self diff --git a/src/sql/run.py b/src/sql/run.py index d919c6101..6b1dfae2b 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -641,10 +641,16 @@ def run(conn, sql, config): # NOTE: are there any edge cases we should cover? if not is_select: - with Session(conn.engine) as session: - result = session.execute(statement) + # TODO: we should abstract this + if is_custom_connection: + result = conn.engine.execute(statement) resultset = ResultSet(result, config, statement, conn.engine) - session.commit() + conn.engine.commit() + else: + with Session(conn.engine) as session: + result = session.execute(statement) + resultset = ResultSet(result, config, statement, conn.engine) + session.commit() else: result = conn.engine.execute(statement) diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 5bceaaad8..a96e5f832 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -48,19 +48,29 @@ def clean_conns(): yield -# if we enable it, we'll have to update tests! -# because they expect the error not to be raised class TestingShell(InteractiveShell): def run_cell(self, *args, **kwargs): result = super().run_cell(*args, **kwargs) if result.error_in_exec is not None: - # raise RuntimeError("a") from result.error_in_exec raise result.error_in_exec return result +# migrate to this one +@pytest.fixture +def ip_empty_testing_shell(): + ip_session = TestingShell() + ip_session.register_magics(SqlMagic) + ip_session.register_magics(RenderMagic) + ip_session.register_magics(SqlPlotMagic) + ip_session.register_magics(SqlCmdMagic) + + yield ip_session + Connection.close_all() + + @pytest.fixture def ip_empty(): ip_session = InteractiveShell() diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 814099e90..1834b9087 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -1302,3 +1302,16 @@ def test_interact_and_missing_ipywidgets_installed(ip): "%sql --interact my_variable SELECT * FROM author LIMIT {{my_variable}}" ) assert isinstance(out.error_in_exec, ModuleNotFoundError) + + +def test_generic_driver(ip_empty_testing_shell): + # ip_empty_testing_shell.run_cell("import duckdb") + # ip_empty_testing_shell.run_cell("conn = duckdb.connect(':memory:')") + ip_empty_testing_shell.run_cell("import sqlite3") + ip_empty_testing_shell.run_cell("conn = sqlite3.connect(':memory:')") + + ip_empty_testing_shell.run_cell("%sql conn") + ip_empty_testing_shell.run_cell("%sql CREATE TABLE test (a INTEGER, b INTEGER)") + ip_empty_testing_shell.run_cell("%sql INSERT INTO test VALUES (1, 2)") + result = ip_empty_testing_shell.run_cell("%sql SELECT * FROM test") + assert result.result == [(1, 2)] From e739e9ae5784c2c8ae252563237acafa1738e435 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 29 Jun 2023 12:29:18 -0600 Subject: [PATCH 14/16] test fix --- src/sql/run.py | 1 - src/sql/util.py | 3 +++ src/tests/test_inspect.py | 10 +++++----- src/tests/test_magic.py | 1 + 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/sql/run.py b/src/sql/run.py index 6b1dfae2b..4c9d341c4 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -651,7 +651,6 @@ def run(conn, sql, config): result = session.execute(statement) resultset = ResultSet(result, config, statement, conn.engine) session.commit() - else: result = conn.engine.execute(statement) resultset = ResultSet(result, config, statement, conn.engine) diff --git a/src/sql/util.py b/src/sql/util.py index 19473ace6..b621f383b 100644 --- a/src/sql/util.py +++ b/src/sql/util.py @@ -81,6 +81,7 @@ def is_table_exists( ignore_error: bool, default False Avoid raising a ValueError """ + if table is None: if ignore_error: return False @@ -188,6 +189,7 @@ def _is_table_exists(table: str, conn) -> bool: """ Runs a SQL query to check if table exists """ + if not conn: conn = Connection.current @@ -201,6 +203,7 @@ def _is_table_exists(table: str, conn) -> bool: try: conn.execute(query) return True + # TODO: Add specific exception except Exception: pass diff --git a/src/tests/test_inspect.py b/src/tests/test_inspect.py index 4234d412d..5bf32fbab 100644 --- a/src/tests/test_inspect.py +++ b/src/tests/test_inspect.py @@ -224,20 +224,20 @@ def test_columns_with_missing_values( monkeypatch.setattr(inspect, "_get_inspector", lambda _: mock) ip.run_cell( - """%%sql sqlite:///another.db -CREATE TABLE IF NOT EXISTS another_table (id INT) + """%%sql duckdb:// +CREATE TABLE IF NOT EXISTS test_table (id INT) """ ) ip.run_cell( - """%%sql sqlite:///my.db -CREATE TABLE IF NOT EXISTS test_table (id INT) + """%%sql +CREATE SCHEMA another_schema; """ ) ip.run_cell( """%%sql -ATTACH DATABASE 'another.db' as 'another_schema'; +CREATE TABLE IF NOT EXISTS another_schema.another_table (id INT) """ ) diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 1834b9087..f4f3fc81f 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -1305,6 +1305,7 @@ def test_interact_and_missing_ipywidgets_installed(ip): def test_generic_driver(ip_empty_testing_shell): + # TODO: add duckdb support # ip_empty_testing_shell.run_cell("import duckdb") # ip_empty_testing_shell.run_cell("conn = duckdb.connect(':memory:')") ip_empty_testing_shell.run_cell("import sqlite3") From 879afd596f47e5f810a5f92984e1d095b16b55d0 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 29 Jun 2023 12:52:50 -0600 Subject: [PATCH 15/16] more fixes --- src/sql/run.py | 2 +- src/tests/integration/test_duckDB.py | 28 +++++++++------------------- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/src/sql/run.py b/src/sql/run.py index 4c9d341c4..965c01a5e 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -615,7 +615,7 @@ def run(conn, sql, config): statements = sqlparse.split(sql) - for index, statement in enumerate(statements): + for statement in statements: first_word = _first_word(statement) is_select = first_word == "select" # manual_commit = False diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py index dff8bbd1a..116bc9053 100644 --- a/src/tests/integration/test_duckDB.py +++ b/src/tests/integration/test_duckDB.py @@ -1,10 +1,11 @@ from unittest.mock import Mock import logging +from sqlalchemy import create_engine import pandas as pd import pytest -from sql import connection +from sql.connection import Connection from sql.run import ResultSet @@ -92,18 +93,16 @@ def test_multiple_statements(ip_empty, config, sql, tables): assert set(tables) == set(r[0] for r in out_tables.result._table.rows) -def test_dataframe_returned_only_if_last_statement_is_select(ip_empty): +def test_autopandas_when_last_result_is_not_a_select_statement(ip_empty): ip_empty.run_cell("%sql duckdb://") ip_empty.run_cell("%config SqlMagic.autopandas=True") - connection.Connection.connections["duckdb://"].engine.raw_connection = Mock( - side_effect=ValueError("some error") - ) out = ip_empty.run_cell( "%sql CREATE TABLE a (c VARCHAR,); CREATE TABLE b (c VARCHAR,);" ) assert out.error_in_exec is None + assert out.result.to_dict() == dict() @pytest.mark.parametrize( @@ -132,35 +131,26 @@ def test_commits_all_statements(ip_empty, sql): def test_resultset_uses_native_duckdb_df(ip_empty): - from sqlalchemy import create_engine - from sql.connection import Connection - engine = create_engine("duckdb://") - engine.execute("CREATE TABLE a (x INT,);") engine.execute("INSERT INTO a(x) VALUES (10),(20),(30);") sql = "SELECT * FROM a" - - # this breaks if there's an open results set - engine.execute(sql).fetchall() - results = engine.execute(sql) Connection.set(engine, displaycon=False) - results.fetchmany = Mock(wraps=results.fetchmany) + results.fetchone = Mock(side_effect=ValueError("Should not be called")) + results.fetchall = Mock(side_effect=ValueError("Should not be called")) mock = Mock() mock.displaylimit = 1 + mock.autolimit = 0 result_set = ResultSet(results, mock, sql=sql, engine=engine) - - result_set.fetch_results() - df = result_set.DataFrame() assert isinstance(df, pd.DataFrame) - assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3}} + assert df.to_dict() == {"x": {0: 10, 1: 20, 2: 30}} - results.fetchmany.assert_called_once_with(size=1) + results.fetchmany.assert_called_once_with(size=2) From a68b10a930e8aff18da52280f6d6dbc798028be8 Mon Sep 17 00:00:00 2001 From: Eduardo Blancas Date: Thu, 29 Jun 2023 12:57:51 -0600 Subject: [PATCH 16/16] sqlalchemyy 2 fix --- src/sql/run.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/sql/run.py b/src/sql/run.py index 965c01a5e..fa27ea24a 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -652,7 +652,11 @@ def run(conn, sql, config): resultset = ResultSet(result, config, statement, conn.engine) session.commit() else: - result = conn.engine.execute(statement) + # we need the session so sqlalchemy 2.x works + # result = conn.engine.execute(statement) + with Session(conn.engine) as session: + result = session.execute(statement) + resultset = ResultSet(result, config, statement, conn.engine) if result and config.feedback: