diff --git a/CHANGELOG.md b/CHANGELOG.md
index fb69fe39a..c803e6a37 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@
* [Doc] Document --persist-replace in API section (#539)
* [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631)
* [Fix] Refactored `ResultSet` to lazy loading (#470)
+* [Fix] Error when executing multiple SQL statements when using DuckDB with `autopandas` on (#674)
* [Fix] Removed `WITH` when a snippet does not have a dependency (#657)
## 0.7.9 (2023-06-19)
diff --git a/src/sql/connection.py b/src/sql/connection.py
index d5435055b..180f93d2a 100644
--- a/src/sql/connection.py
+++ b/src/sql/connection.py
@@ -203,6 +203,8 @@ def __init__(self, engine, alias=None):
)
self.dialect = engine.url.get_dialect()
+
+ # TODO: delete this!
self.session = self._create_session(engine, self.url)
self.connections[alias or self.url] = self
@@ -651,7 +653,7 @@ def execute(self, query, with_=None):
Executes SQL query on a given connection
"""
query = self._prepare_query(query, with_)
- return self.session.execute(query)
+ return self.engine.execute(query)
atexit.register(Connection.close_all, verbose=True)
@@ -672,11 +674,15 @@ def __init__(self, connection, engine):
}
)
+ # TODO: need to close the cursor
def execute(self, query):
cur = self.engine.cursor()
cur.execute(query)
return cur
+ def commit(self):
+ self.engine.commit()
+
class CustomConnection(Connection):
"""
@@ -698,6 +704,7 @@ def __init__(self, payload, engine=None, alias=None):
self.name = connection_name_
self.dialect = connection_name_
self.session = CustomSession(self, engine)
+ self.engine = self.session
self.connections[alias or connection_name_] = self
diff --git a/src/sql/inspect.py b/src/sql/inspect.py
index f4455370c..ada6b16b3 100644
--- a/src/sql/inspect.py
+++ b/src/sql/inspect.py
@@ -18,7 +18,7 @@ def _get_inspector(conn):
if not Connection.current:
raise exceptions.RuntimeError("No active connection")
else:
- return inspect(Connection.current.session)
+ return inspect(Connection.current.engine)
class DatabaseInspection:
@@ -239,11 +239,15 @@ def __init__(self, table_name, schema=None) -> None:
columns_query_result = sql.run.raw_run(
Connection.current, f"SELECT * FROM {table_name} WHERE 1=0"
)
+
if Connection.is_custom_connection():
columns = [i[0] for i in columns_query_result.description]
else:
columns = columns_query_result.keys()
+ # TODO: abstract it internally
+ columns_query_result.close()
+
table_stats = dict({})
columns_to_include_in_report = set()
columns_with_styles = []
diff --git a/src/sql/run.py b/src/sql/run.py
index fb8bcaff3..fa27ea24a 100644
--- a/src/sql/run.py
+++ b/src/sql/run.py
@@ -9,6 +9,7 @@
import prettytable
import sqlalchemy
+from sqlalchemy.exc import ResourceClosedError
import sqlparse
from sql.connection import Connection
from sql import exceptions, display
@@ -109,27 +110,88 @@ class ResultSet(ColumnGuesserMixin):
Can access rows listwise, or by string value of leftmost column.
"""
- def __init__(self, sqlaproxy, config):
+ def __init__(self, sqlaproxy, config, sql=None, engine=None):
self.config = config
- self.keys = {}
- self._results = []
self.truncated = False
self.sqlaproxy = sqlaproxy
+ self.sql = sql
+ self.engine = engine
+
+ self._keys = None
+ self._field_names = None
+ self._results = []
# https://peps.python.org/pep-0249/#description
self.is_dbapi_results = hasattr(sqlaproxy, "description")
- self.pretty = None
+ # NOTE: this will trigger key fetching
+ self.pretty_table = self._init_table()
+
+ self._done_fetching = False
+
+ if self.config.autolimit == 1:
+ self.fetchmany(size=1)
+ self.did_finish_fetching()
+        # fetch two rows so we can tell whether the result has more than one row — TODO confirm intent
+ else:
+ self.fetchmany(size=2)
+
+ def extend_results(self, elements):
+ self._results.extend(elements)
+
+    # NOTE: we should use add_rows but there is a subclass that behaves weirdly
+ # so using add_row for now
+ for e in elements:
+ self.pretty_table.add_row(e)
+
+ def done_fetching(self):
+ self._done_fetching = True
+ self.sqlaproxy.close()
+
+ def did_finish_fetching(self):
+ return self._done_fetching
+
+ # NOTE: this triggers key fetching
+ @property
+ def field_names(self):
+ if self._field_names is None:
+ self._field_names = unduplicate_field_names(self.keys)
+
+ return self._field_names
+
+ # NOTE: this triggers key fetching
+ @property
+ def keys(self):
+ if self._keys is not None:
+ return self._keys
+
+ if not self.is_dbapi_results:
+ try:
+ self._keys = self.sqlaproxy.keys()
+ # sqlite raises this error when running a script that doesn't return rows
+ # e.g, 'CREATE TABLE' but others don't (e.g., duckdb)
+ except ResourceClosedError:
+ self._keys = []
+ return self._keys
+
+ elif isinstance(self.sqlaproxy.description, Iterable):
+ self._keys = [i[0] for i in self.sqlaproxy.description]
+ else:
+ self._keys = []
+
+ return self._keys
def _repr_html_(self):
+ self.fetch_for_repr_if_needed()
+
_cell_with_spaces_pattern = re.compile(r"(
)( {2,})")
- if self.pretty:
- self.pretty.add_rows(self)
- result = self.pretty.get_html_string()
+ if self.pretty_table:
+ self.pretty_table.add_rows(self)
+ result = self.pretty_table.get_html_string()
# to create clickable links
result = html.unescape(result)
result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result)
- if self.truncated:
+ if self.config.displaylimit != 0:
HTML = (
'%s\n'
"Truncated to displaylimit of %d"
@@ -139,7 +201,7 @@ def _repr_html_(self):
'displaylimit' # noqa: E501
" configuration"
)
- result = HTML % (result, self.pretty.row_count)
+ result = HTML % (result, self.config.displaylimit)
return result
else:
return None
@@ -148,18 +210,18 @@ def __len__(self):
return len(self._results)
def __iter__(self):
- results = self._fetch_query_results(fetch_all=True)
+ self.fetchall()
- for result in results:
+ for result in self._results:
yield result
def __str__(self, *arg, **kwarg):
- if self.pretty:
- self.pretty.add_rows(self)
- return str(self.pretty or "")
+ self.fetch_for_repr_if_needed()
+ return str(self.pretty_table)
def __repr__(self) -> str:
- return str(self)
+ self.fetch_for_repr_if_needed()
+ return str(self.pretty_table)
def __eq__(self, another: object) -> bool:
return self._results == another
@@ -195,10 +257,28 @@ def DataFrame(self, payload):
"Returns a Pandas DataFrame instance built from the result set."
import pandas as pd
+ # TODO: re-add
+ # payload[
+ # "connection_info"
+ # ] = Connection.current._get_curr_sqlalchemy_connection_info()
+
+ if self.did_finish_fetching():
+ return pd.DataFrame(self, columns=(self and self.keys) or [])
+
+ # only run when using duckdb
+ if self.engine.dialect.name == "duckdb":
+            # TODO: clarify why the proxy must be closed before using the raw connection
+ self.sqlaproxy.close()
+
+ conn_duckdb_raw = self.engine.raw_connection()
+ cursor = conn_duckdb_raw.cursor()
+ cursor.execute(str(self.sql))
+ df = cursor.df()
+ conn_duckdb_raw.close()
+ return df
+
frame = pd.DataFrame(self, columns=(self and self.keys) or [])
- payload[
- "connection_info"
- ] = Connection.current._get_curr_sqlalchemy_connection_info()
+
return frame
@telemetry.log_call("polars-data-frame")
@@ -320,9 +400,9 @@ def bar(self, key_word_sep=" ", title=None, **kwargs):
def csv(self, filename=None, **format_params):
"""Generate results in comma-separated form. Write to ``filename`` if given.
Any other parameters will be passed on to csv.writer."""
- if not self.pretty:
+ if not self.pretty_table:
return None # no results
- self.pretty.add_rows(self)
+ self.pretty_table.add_rows(self)
if filename:
encoding = format_params.get("encoding", "utf-8")
outfile = open(filename, "w", newline="", encoding=encoding)
@@ -339,104 +419,51 @@ def csv(self, filename=None, **format_params):
else:
return outfile.getvalue()
- def fetch_results(self, fetch_all=False):
- """
- Returns a limited representation of the query results.
-
- Parameters
- ----------
- fetch_all : bool default False
- Return all query rows
- """
- is_dbapi_results = self.is_dbapi_results
- sqlaproxy = self.sqlaproxy
- config = self.config
-
- if is_dbapi_results:
- should_try_fetch_results = True
- else:
- should_try_fetch_results = sqlaproxy.returns_rows
-
- if should_try_fetch_results:
- # sql alchemy results
- if not is_dbapi_results:
- self.keys = sqlaproxy.keys()
- elif isinstance(sqlaproxy.description, Iterable):
- self.keys = [i[0] for i in sqlaproxy.description]
- else:
- self.keys = []
-
- if len(self.keys) > 0:
- self._results = self._fetch_query_results(fetch_all=fetch_all)
+ def fetchmany(self, size):
+ """Fetch n results and add it to the results"""
+ if not self.did_finish_fetching():
+ try:
+ returned = self.sqlaproxy.fetchmany(size=size)
+ # sqlite raises this error when running a script that doesn't return rows
+ # e.g, 'CREATE TABLE' but others don't (e.g., duckdb)
+ except ResourceClosedError:
+ self.done_fetching()
+ return
- self.field_names = unduplicate_field_names(self.keys)
+ self.extend_results(returned)
- _style = None
+ if len(returned) < size:
+ self.done_fetching()
- self.pretty = PrettyTable(self.field_names)
+ if (
+ self.config.autolimit is not None
+ and self.config.autolimit != 0
+ and len(self._results) >= self.config.autolimit
+ ):
+ self.done_fetching()
- if isinstance(config.style, str):
- _style = prettytable.__dict__[config.style.upper()]
- self.pretty.set_style(_style)
+ def fetch_for_repr_if_needed(self):
+ if self.config.displaylimit == 0:
+ self.fetchall()
- return self
+ missing = self.config.displaylimit - len(self._results)
- def _fetch_query_results(self, fetch_all=False):
- """
- Returns rows of a query result as a list of tuples.
-
- Parameters
- ----------
- fetch_all : bool default False
- Return all query rows
- """
- sqlaproxy = self.sqlaproxy
- config = self.config
- _should_try_lazy_fetch = hasattr(sqlaproxy, "_soft_closed")
-
- _should_fetch_all = (
- (config.displaylimit == 0 or not config.displaylimit)
- or fetch_all
- or not _should_try_lazy_fetch
- )
+ if missing > 0:
+ self.fetchmany(missing)
- is_autolimit = isinstance(config.autolimit, int) and config.autolimit > 0
- is_connection_closed = (
- sqlaproxy._soft_closed if _should_try_lazy_fetch else False
- )
-
- should_return_results = is_connection_closed or (
- len(self._results) > 0 and is_autolimit
- )
+ def fetchall(self):
+ if not self.did_finish_fetching():
+ self.extend_results(self.sqlaproxy.fetchall())
+ self.done_fetching()
- if should_return_results:
- # this means we already loaded all
- # the results to self._results or we use
- # autolimit and shouldn't fetch more
- results = self._results
- else:
- if is_autolimit:
- results = sqlaproxy.fetchmany(size=config.autolimit)
- else:
- if _should_fetch_all:
- all_results = sqlaproxy.fetchall()
- results = self._results + all_results
- self._results = results
- else:
- results = sqlaproxy.fetchmany(size=config.displaylimit)
+ def _init_table(self):
+ pretty = PrettyTable(self.field_names)
- if _should_try_lazy_fetch:
- # Try to fetch an extra row to find out
- # if there are more results to fetch
- row = sqlaproxy.fetchone()
- if row is not None:
- results += [row]
+ if isinstance(self.config.style, str):
+ _style = prettytable.__dict__[self.config.style.upper()]
+ pretty.set_style(_style)
- # Check if we have more rows to show
- if config.displaylimit > 0:
- self.truncated = len(results) > config.displaylimit
-
- return results
+ return pretty
def display_affected_rowcount(rowcount):
@@ -507,6 +534,7 @@ def _commit(conn, config, manual_commit):
with Session(conn.session) as session:
session.commit()
except sqlalchemy.exc.OperationalError:
+ # TODO: missing rollback here?
print("The database does not support the COMMIT command")
@@ -565,7 +593,6 @@ def select_df_type(resultset, config):
return resultset.PolarsDataFrame(**config.polars_dataframe_kwargs)
else:
return resultset
- # returning only last result, intentionally
def run(conn, sql, config):
@@ -582,17 +609,16 @@ def run(conn, sql, config):
config
Configuration object
"""
- info = conn._get_curr_sqlalchemy_connection_info()
-
- duckdb_autopandas = info and info.get("dialect") == "duckdb" and config.autopandas
-
if not sql.strip():
# returning only when sql is empty string
return "Connected: %s" % conn.name
- for statement in sqlparse.split(sql):
- first_word = sql.strip().split()[0].lower()
- manual_commit = False
+ statements = sqlparse.split(sql)
+
+ for statement in statements:
+ first_word = _first_word(statement)
+ is_select = first_word == "select"
+ # manual_commit = False
# attempting to run a transaction
if first_word == "begin":
@@ -604,41 +630,48 @@ def run(conn, sql, config):
# regular query
else:
- manual_commit = set_autocommit(conn, config)
+ # TODO: add commit feature again
+ # _commit(conn=conn, config=config, manual_commit=manual_commit)
+ # manual_commit = set_autocommit(conn, config)
is_custom_connection = Connection.is_custom_connection(conn)
# if regular sqlalchemy, pass a text object
if not is_custom_connection:
statement = sqlalchemy.sql.text(statement)
- if duckdb_autopandas:
- conn = conn.engine.raw_connection()
- cursor = conn.cursor()
- cursor.execute(str(statement))
-
+ # NOTE: are there any edge cases we should cover?
+ if not is_select:
+ # TODO: we should abstract this
+ if is_custom_connection:
+ result = conn.engine.execute(statement)
+ resultset = ResultSet(result, config, statement, conn.engine)
+ conn.engine.commit()
+ else:
+ with Session(conn.engine) as session:
+ result = session.execute(statement)
+ resultset = ResultSet(result, config, statement, conn.engine)
+ session.commit()
else:
- result = conn.session.execute(statement)
- _commit(conn=conn, config=config, manual_commit=manual_commit)
-
- if result and config.feedback:
- if hasattr(result, "rowcount"):
- display_affected_rowcount(result.rowcount)
-
- # bypass ResultSet and use duckdb's native method to return a pandas data frame
- if duckdb_autopandas:
- df = cursor.df()
- conn.close()
- return df
- else:
- resultset = ResultSet(result, config)
+ # we need the session so sqlalchemy 2.x works
+ # result = conn.engine.execute(statement)
+ with Session(conn.engine) as session:
+ result = session.execute(statement)
+
+ resultset = ResultSet(result, config, statement, conn.engine)
+
+ if result and config.feedback:
+ if hasattr(result, "rowcount"):
+ display_affected_rowcount(result.rowcount)
+
+ return select_df_type(resultset, config)
+
- # lazy load
- resultset.fetch_results()
- return select_df_type(resultset, config)
+def _first_word(sql):
+ return sql.strip().split()[0].lower()
def raw_run(conn, sql):
- return conn.session.execute(sqlalchemy.sql.text(sql))
+ return conn.engine.execute(sqlalchemy.sql.text(sql))
class PrettyTable(prettytable.PrettyTable):
diff --git a/src/sql/util.py b/src/sql/util.py
index 19473ace6..b621f383b 100644
--- a/src/sql/util.py
+++ b/src/sql/util.py
@@ -81,6 +81,7 @@ def is_table_exists(
ignore_error: bool, default False
Avoid raising a ValueError
"""
+
if table is None:
if ignore_error:
return False
@@ -188,6 +189,7 @@ def _is_table_exists(table: str, conn) -> bool:
"""
Runs a SQL query to check if table exists
"""
+
if not conn:
conn = Connection.current
@@ -201,6 +203,7 @@ def _is_table_exists(table: str, conn) -> bool:
try:
conn.execute(query)
return True
+ # TODO: Add specific exception
except Exception:
pass
diff --git a/src/tests/conftest.py b/src/tests/conftest.py
index bd2422e4f..a96e5f832 100644
--- a/src/tests/conftest.py
+++ b/src/tests/conftest.py
@@ -48,6 +48,29 @@ def clean_conns():
yield
+class TestingShell(InteractiveShell):
+ def run_cell(self, *args, **kwargs):
+ result = super().run_cell(*args, **kwargs)
+
+ if result.error_in_exec is not None:
+ raise result.error_in_exec
+
+ return result
+
+
+# migrate to this one
+@pytest.fixture
+def ip_empty_testing_shell():
+ ip_session = TestingShell()
+ ip_session.register_magics(SqlMagic)
+ ip_session.register_magics(RenderMagic)
+ ip_session.register_magics(SqlPlotMagic)
+ ip_session.register_magics(SqlCmdMagic)
+
+ yield ip_session
+ Connection.close_all()
+
+
@pytest.fixture
def ip_empty():
ip_session = InteractiveShell()
@@ -55,6 +78,7 @@ def ip_empty():
ip_session.register_magics(RenderMagic)
ip_session.register_magics(SqlPlotMagic)
ip_session.register_magics(SqlCmdMagic)
+
yield ip_session
Connection.close_all()
diff --git a/src/tests/integration/test_duckDB.py b/src/tests/integration/test_duckDB.py
index 18b47a56f..116bc9053 100644
--- a/src/tests/integration/test_duckDB.py
+++ b/src/tests/integration/test_duckDB.py
@@ -1,5 +1,13 @@
+from unittest.mock import Mock
import logging
+from sqlalchemy import create_engine
+import pandas as pd
+import pytest
+
+from sql.connection import Connection
+from sql.run import ResultSet
+
def test_auto_commit_mode_on(ip_with_duckDB, caplog):
with caplog.at_level(logging.DEBUG):
@@ -27,3 +35,122 @@ def test_auto_commit_mode_off(ip_with_duckDB, caplog):
# Check the tables is created
tables_out = ip_with_duckDB.run_cell("%sql SHOW TABLES;").result
assert any("weather" == table[0] for table in tables_out)
+
+
+@pytest.mark.parametrize(
+ "config",
+ [
+ "%config SqlMagic.autopandas = True",
+ "%config SqlMagic.autopandas = False",
+ ],
+ ids=[
+ "autopandas_on",
+ "autopandas_off",
+ ],
+)
+@pytest.mark.parametrize(
+ "sql, tables",
+ [
+ ["%sql SELECT * FROM weather; SELECT * FROM weather;", ["weather"]],
+ [
+ "%sql CREATE TABLE names (name VARCHAR,); SELECT * FROM weather;",
+ ["weather", "names"],
+ ],
+ [
+ (
+ "%sql CREATE TABLE names (city VARCHAR,);"
+ "CREATE TABLE more_names (city VARCHAR,);"
+ "INSERT INTO names VALUES ('NYC');"
+ "SELECT * FROM names UNION ALL SELECT * FROM more_names;"
+ ),
+ ["weather", "names", "more_names"],
+ ],
+ ],
+ ids=[
+ "multiple_selects",
+ "multiple_statements",
+ "multiple_tables_created",
+ ],
+)
+def test_multiple_statements(ip_empty, config, sql, tables):
+ ip_empty.run_cell("%sql duckdb://")
+ ip_empty.run_cell(config)
+
+ ip_empty.run_cell("%sql CREATE TABLE weather (city VARCHAR,);")
+ ip_empty.run_cell("%sql INSERT INTO weather VALUES ('NYC');")
+ ip_empty.run_cell("%sql SELECT * FROM weather;")
+
+ out = ip_empty.run_cell(sql)
+ out_tables = ip_empty.run_cell("%sqlcmd tables")
+
+ assert out.error_in_exec is None
+
+ if config == "%config SqlMagic.autopandas = True":
+ assert out.result.to_dict() == {"city": {0: "NYC"}}
+ else:
+ assert out.result.dict() == {"city": ("NYC",)}
+
+ assert set(tables) == set(r[0] for r in out_tables.result._table.rows)
+
+
+def test_autopandas_when_last_result_is_not_a_select_statement(ip_empty):
+ ip_empty.run_cell("%sql duckdb://")
+ ip_empty.run_cell("%config SqlMagic.autopandas=True")
+
+ out = ip_empty.run_cell(
+ "%sql CREATE TABLE a (c VARCHAR,); CREATE TABLE b (c VARCHAR,);"
+ )
+
+ assert out.error_in_exec is None
+ assert out.result.to_dict() == dict()
+
+
+@pytest.mark.parametrize(
+ "sql",
+ [
+ (
+ "%sql CREATE TABLE a (x INT,); CREATE TABLE b (x INT,); "
+ "INSERT INTO a VALUES (1,); INSERT INTO b VALUES(2,); "
+ "SELECT * FROM a UNION ALL SELECT * FROM b;"
+ ),
+ """\
+%%sql
+CREATE TABLE a (x INT,);
+CREATE TABLE b (x INT,);
+INSERT INTO a VALUES (1,);
+INSERT INTO b VALUES(2,);
+SELECT * FROM a UNION ALL SELECT * FROM b;
+""",
+ ],
+)
+def test_commits_all_statements(ip_empty, sql):
+ ip_empty.run_cell("%sql duckdb://")
+ out = ip_empty.run_cell(sql)
+ assert out.error_in_exec is None
+ assert out.result.dict() == {"x": (1, 2)}
+
+
+def test_resultset_uses_native_duckdb_df(ip_empty):
+ engine = create_engine("duckdb://")
+ engine.execute("CREATE TABLE a (x INT,);")
+ engine.execute("INSERT INTO a(x) VALUES (10),(20),(30);")
+
+ sql = "SELECT * FROM a"
+ results = engine.execute(sql)
+
+ Connection.set(engine, displaycon=False)
+ results.fetchmany = Mock(wraps=results.fetchmany)
+ results.fetchone = Mock(side_effect=ValueError("Should not be called"))
+ results.fetchall = Mock(side_effect=ValueError("Should not be called"))
+
+ mock = Mock()
+ mock.displaylimit = 1
+ mock.autolimit = 0
+
+ result_set = ResultSet(results, mock, sql=sql, engine=engine)
+ df = result_set.DataFrame()
+
+ assert isinstance(df, pd.DataFrame)
+ assert df.to_dict() == {"x": {0: 10, 1: 20, 2: 30}}
+
+ results.fetchmany.assert_called_once_with(size=2)
diff --git a/src/tests/test_inspect.py b/src/tests/test_inspect.py
index 4234d412d..5bf32fbab 100644
--- a/src/tests/test_inspect.py
+++ b/src/tests/test_inspect.py
@@ -224,20 +224,20 @@ def test_columns_with_missing_values(
monkeypatch.setattr(inspect, "_get_inspector", lambda _: mock)
ip.run_cell(
- """%%sql sqlite:///another.db
-CREATE TABLE IF NOT EXISTS another_table (id INT)
+ """%%sql duckdb://
+CREATE TABLE IF NOT EXISTS test_table (id INT)
"""
)
ip.run_cell(
- """%%sql sqlite:///my.db
-CREATE TABLE IF NOT EXISTS test_table (id INT)
+ """%%sql
+CREATE SCHEMA another_schema;
"""
)
ip.run_cell(
"""%%sql
-ATTACH DATABASE 'another.db' as 'another_schema';
+CREATE TABLE IF NOT EXISTS another_schema.another_table (id INT)
"""
)
diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py
index 814099e90..f4f3fc81f 100644
--- a/src/tests/test_magic.py
+++ b/src/tests/test_magic.py
@@ -1302,3 +1302,17 @@ def test_interact_and_missing_ipywidgets_installed(ip):
"%sql --interact my_variable SELECT * FROM author LIMIT {{my_variable}}"
)
assert isinstance(out.error_in_exec, ModuleNotFoundError)
+
+
+def test_generic_driver(ip_empty_testing_shell):
+ # TODO: add duckdb support
+ # ip_empty_testing_shell.run_cell("import duckdb")
+ # ip_empty_testing_shell.run_cell("conn = duckdb.connect(':memory:')")
+ ip_empty_testing_shell.run_cell("import sqlite3")
+ ip_empty_testing_shell.run_cell("conn = sqlite3.connect(':memory:')")
+
+ ip_empty_testing_shell.run_cell("%sql conn")
+ ip_empty_testing_shell.run_cell("%sql CREATE TABLE test (a INTEGER, b INTEGER)")
+ ip_empty_testing_shell.run_cell("%sql INSERT INTO test VALUES (1, 2)")
+ result = ip_empty_testing_shell.run_cell("%sql SELECT * FROM test")
+ assert result.result == [(1, 2)]
diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py
index d9b7ef97b..ea524741b 100644
--- a/src/tests/test_magic_cmd.py
+++ b/src/tests/test_magic_cmd.py
@@ -1,11 +1,10 @@
+import sqlite3
import sys
import math
import pytest
from IPython.core.error import UsageError
from pathlib import Path
-from sqlalchemy import create_engine
-from sql.connection import Connection
from sql.store import store
from sql.inspect import _is_numeric
@@ -112,8 +111,15 @@ def test_tables(ip):
def test_tables_with_schema(ip, tmp_empty):
- conn = Connection(engine=create_engine("sqlite:///my.db"))
- conn.execute("CREATE TABLE numbers (some_number FLOAT)")
+ # TODO: why does this fail?
+ # ip.run_cell(
+ # """%%sql sqlite:///my.db
+ # CREATE TABLE numbers (some_number FLOAT)
+ # """
+ # )
+
+ with sqlite3.connect("my.db") as conn:
+ conn.execute("CREATE TABLE numbers (some_number FLOAT)")
ip.run_cell(
"""%%sql
@@ -150,8 +156,8 @@ def test_columns(ip, cmd, cols):
def test_columns_with_schema(ip, tmp_empty):
- conn = Connection(engine=create_engine("sqlite:///my.db"))
- conn.execute("CREATE TABLE numbers (some_number FLOAT)")
+ with sqlite3.connect("my.db") as conn:
+ conn.execute("CREATE TABLE numbers (some_number FLOAT)")
ip.run_cell(
"""%%sql
diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py
index 52c7f338e..96fa90952 100644
--- a/src/tests/test_resultset.py
+++ b/src/tests/test_resultset.py
@@ -1,7 +1,8 @@
+from unittest.mock import Mock, call
+
from sqlalchemy import create_engine
-from sql.connection import Connection
from pathlib import Path
-from unittest.mock import Mock
+
import pytest
import pandas as pd
@@ -9,7 +10,6 @@
import sqlalchemy
from sql.run import ResultSet
-from sql import run as run_module
import re
@@ -35,7 +35,7 @@ def result():
@pytest.fixture
def result_set(result, config):
- return ResultSet(result, config).fetch_results()
+ return ResultSet(result, config)
def test_resultset_getitem(result_set):
@@ -51,12 +51,6 @@ def test_resultset_dicts(result_set):
assert list(result_set.dicts()) == [{"x": 0}, {"x": 1}, {"x": 2}]
-def test_resultset_dataframe(result_set, monkeypatch):
- monkeypatch.setattr(run_module.Connection, "current", Mock())
-
- assert result_set.DataFrame().equals(pd.DataFrame({"x": range(3)}))
-
-
def test_resultset_polars_dataframe(result_set, monkeypatch):
assert result_set.PolarsDataFrame().frame_equal(pl.DataFrame({"x": range(3)}))
@@ -71,189 +65,276 @@ def test_resultset_str(result_set):
assert str(result_set) == "+---+\n| x |\n+---+\n| 0 |\n| 1 |\n| 2 |\n+---+"
-def test_resultset_repr_html(result_set):
- assert result_set._repr_html_() == (
- "\n \n \n "
- "x | \n \n \n \n "
- "\n 0 | \n \n \n "
- "1 | \n \n \n 2 | \n "
- " \n \n "
- )
-
-
def test_resultset_config_autolimit_dict(result, config):
config.autolimit = 1
- resultset = ResultSet(result, config).fetch_results()
+ resultset = ResultSet(result, config)
assert resultset.dict() == {"x": (0,)}
-def test_resultset_with_non_sqlalchemy_results(config):
- df = pd.DataFrame({"x": range(3)}) # noqa
- conn = Connection(engine=create_engine("duckdb://"))
- result = conn.execute("SELECT * FROM df")
- assert ResultSet(result, config).fetch_results() == [(0,), (1,), (2,)]
+def _get_number_of_rows_in_html_table(html):
+ """
+ Returns the number of | tags within the
section
+ """
+ pattern = r"(.*?)<\/tbody>"
+ tbody_content = re.findall(pattern, html, re.DOTALL)[0]
+ row_count = len(re.findall(r"", tbody_content))
+ return row_count
-def test_none_pretty(config):
- conn = Connection(engine=create_engine("sqlite://"))
- result = conn.execute("create table some_table (name, age)")
- result_set = ResultSet(result, config)
- assert result_set.pretty is None
- assert "" == str(result_set)
+# TODO: add dbapi tests
-def test_lazy_loading(result, config):
- resultset = ResultSet(result, config)
- assert len(resultset._results) == 0
- resultset.fetch_results()
- assert len(resultset._results) == 3
+@pytest.fixture
+def results(ip_empty):
+ engine = create_engine("duckdb://")
-@pytest.mark.parametrize(
- "autolimit, expected",
- [
- (None, 3),
- (False, 3),
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
- ],
-)
-def test_lazy_loading_autolimit(result, config, autolimit, expected):
- config.autolimit = autolimit
- resultset = ResultSet(result, config)
- assert len(resultset._results) == 0
- resultset.fetch_results()
- assert len(resultset._results) == expected
+ engine.execute("CREATE TABLE a (x INT,);")
+ engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);")
-@pytest.mark.parametrize(
- "displaylimit, expected",
- [
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
- ],
-)
-def test_lazy_loading_displaylimit(result, config, displaylimit, expected):
- config.displaylimit = displaylimit
- result_set = ResultSet(result, config)
+ sql = "SELECT * FROM a"
+ results = engine.execute(sql)
- assert len(result_set._results) == 0
- result_set.fetch_results()
- html = result_set._repr_html_()
- row_count = _get_number_of_rows_in_html_table(html)
- assert row_count == expected
+ results.fetchmany = Mock(wraps=results.fetchmany)
+ results.fetchall = Mock(wraps=results.fetchall)
+ results.fetchone = Mock(wraps=results.fetchone)
+ # results.fetchone = Mock(side_effect=ValueError("fetchone called"))
-@pytest.mark.parametrize(
- "displaylimit, expected_display",
- [
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
- ],
-)
-def test_lazy_loading_displaylimit_fetch_all(
- result, config, displaylimit, expected_display
-):
- max_results_count = 3
- config.autolimit = False
- config.displaylimit = displaylimit
- result_set = ResultSet(result, config)
+ yield results
- # Initialize result_set without fetching results
- assert len(result_set._results) == 0
- # Fetch the min number of rows (based on configuration)
- result_set.fetch_results()
+@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"])
+def test_convert_to_dataframe(ip_empty, uri):
+ engine = create_engine(uri)
- html = result_set._repr_html_()
- row_count = _get_number_of_rows_in_html_table(html)
- expected_results = (
- max_results_count
- if expected_display + 1 >= max_results_count
- else expected_display + 1
- )
+ mock = Mock()
+ mock.displaylimit = 100
+ mock.autolimit = 100000
- assert len(result_set._results) == expected_results
- assert row_count == expected_display
+ results = engine.execute("CREATE TABLE a (x INT);")
+ # results = engine.execute("CREATE TABLE test (n INT, name TEXT)")
- # Fetch the the rest results, but don't display them in the table
- result_set.fetch_results(fetch_all=True)
+ rs = ResultSet(results, mock)
+ df = rs.DataFrame()
- html = result_set._repr_html_()
- row_count = _get_number_of_rows_in_html_table(html)
+ assert df.to_dict() == {}
- assert len(result_set._results) == max_results_count
- assert row_count == expected_display
+def test_convert_to_dataframe_2(ip_empty):
+ engine = create_engine("duckdb://")
+
+ mock = Mock()
+ mock.displaylimit = 100
+ mock.autolimit = 100000
+
+ engine.execute("CREATE TABLE a (x INT,);")
+ results = engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);")
+ rs = ResultSet(results, mock)
+ df = rs.DataFrame()
+
+ assert df.to_dict() == {"Count": {0: 5}}
+
+
+@pytest.mark.parametrize("uri", ["duckdb://", "sqlite://"])
+def test_convert_to_dataframe_3(ip_empty, uri):
+ engine = create_engine(uri)
+
+ mock = Mock()
+ mock.displaylimit = 100
+ mock.autolimit = 100000
+
+ engine.execute("CREATE TABLE a (x INT);")
+ engine.execute("INSERT INTO a(x) VALUES (1),(2),(3),(4),(5);")
+ results = engine.execute("SELECT * FROM a")
+ # from ipdb import set_trace
+
+ # set_trace()
+ rs = ResultSet(results, mock, sql="SELECT * FROM a", engine=engine)
+ df = rs.DataFrame()
+
+ # TODO: check native duckdb was called if using duckb
+ assert df.to_dict() == {"x": {0: 1, 1: 2, 2: 3, 3: 4, 4: 5}}
-@pytest.mark.parametrize(
- "displaylimit, expected_display",
- [
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
- ],
-)
-def test_lazy_loading_list(result, config, displaylimit, expected_display):
- max_results_count = 3
- config.autolimit = False
- config.displaylimit = displaylimit
- result_set = ResultSet(result, config)
- # Initialize result_set without fetching results
- assert len(result_set._results) == 0
+def test_done_fetching_if_reached_autolimit(results):
+ mock = Mock()
+ mock.autolimit = 2
+ mock.displaylimit = 100
- # Fetch the min number of rows (based on configuration)
- result_set.fetch_results()
+ rs = ResultSet(results, mock)
- expected_results = (
- max_results_count
- if expected_display + 1 >= max_results_count
- else expected_display + 1
- )
+ assert rs.did_finish_fetching() is True
- assert len(result_set._results) == expected_results
- assert len(list(result_set)) == max_results_count
+
+def test_done_fetching_if_reached_autolimit_2(results):
+ mock = Mock()
+ mock.autolimit = 4
+ mock.displaylimit = 100
+
+ rs = ResultSet(results, mock)
+ list(rs)
+
+ assert rs.did_finish_fetching() is True
+
+
+@pytest.mark.parametrize("method", ["__repr__", "_repr_html_"])
+@pytest.mark.parametrize("autolimit", [1000_000, 0])
+def test_no_displaylimit(results, method, autolimit):
+ mock = Mock()
+ mock.displaylimit = 0
+ mock.autolimit = autolimit
+
+ rs = ResultSet(results, mock)
+ getattr(rs, method)()
+
+ assert rs._results == [(1,), (2,), (3,), (4,), (5,)]
+ assert rs.did_finish_fetching() is True
+
+
+def test_no_fetching_if_one_result():
+ engine = create_engine("duckdb://")
+ engine.execute("CREATE TABLE a (x INT,);")
+ engine.execute("INSERT INTO a(x) VALUES (1);")
+
+ mock = Mock()
+ mock.displaylimit = 100
+ mock.autolimit = 1000_000
+
+ results = engine.execute("SELECT * FROM a")
+ results.fetchmany = Mock(wraps=results.fetchmany)
+ results.fetchall = Mock(wraps=results.fetchall)
+ results.fetchone = Mock(wraps=results.fetchone)
+
+ rs = ResultSet(results, mock)
+
+ assert rs.did_finish_fetching() is True
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_called_once_with(size=2)
+ results.fetchone.assert_not_called()
+
+ str(rs)
+ list(rs)
+
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_called_once_with(size=2)
+ results.fetchone.assert_not_called()
+
+
+# TODO: try with other values, and also change the display limit and re-run the repr
+# TODO: try with __repr__ and __str__
+def test_resultset_fetches_required_rows(results):
+ mock = Mock()
+ mock.displaylimit = 3
+ mock.autolimit = 1000_000
+
+ ResultSet(results, mock)
+ # rs.fetch_results()
+ # rs._repr_html_()
+ # str(rs)
+
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_called_once_with(size=2)
+ results.fetchone.assert_not_called()
+
+
+def test_fetches_remaining_rows(results):
+ mock = Mock()
+ mock.displaylimit = 1
+ mock.autolimit = 1000_000
+
+ rs = ResultSet(results, mock)
+
+ # this will trigger fetching two
+ str(rs)
+
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_has_calls([call(size=2)])
+ results.fetchone.assert_not_called()
+
+ # this will trigger fetching the rest
+ assert list(rs) == [(1,), (2,), (3,), (4,), (5,)]
+
+ results.fetchall.assert_called_once_with()
+ results.fetchmany.assert_has_calls([call(size=2)])
+ results.fetchone.assert_not_called()
+
+ rs.sqlaproxy.fetchmany = Mock(side_effect=ValueError("fetchmany called"))
+ rs.sqlaproxy.fetchall = Mock(side_effect=ValueError("fetchall called"))
+ rs.sqlaproxy.fetchone = Mock(side_effect=ValueError("fetchone called"))
+
+ # this should not trigger any more fetching
+ assert list(rs) == [(1,), (2,), (3,), (4,), (5,)]
@pytest.mark.parametrize(
- "autolimit, expected_results",
+ "method, repr_",
[
- (0, 3),
- (1, 1),
- (2, 2),
- (3, 3),
- (4, 3),
+ [
+ "__repr__",
+ "+---+\n| x |\n+---+\n| 1 |\n| 2 |\n| 3 |\n+---+",
+ ],
+ [
+ "_repr_html_",
+ "\n \n \n x | \n "
+ "
\n \n \n \n "
+ "1 | \n
\n \n "
+ "2 | \n
\n \n "
+ "3 | \n
\n \n
",
+ ],
],
)
-def test_lazy_loading_autolimit_list(result, config, autolimit, expected_results):
- config.autolimit = autolimit
- result_set = ResultSet(result, config)
- assert len(result_set._results) == 0
+def test_resultset_fetches_required_rows_repr_html(results, method, repr_):
+ mock = Mock()
+ mock.displaylimit = 3
+ mock.autolimit = 1000_000
- result_set.fetch_results()
+ rs = ResultSet(results, mock)
+ rs_repr = getattr(rs, method)()
- assert len(result_set._results) == expected_results
- assert len(list(result_set)) == expected_results
+ assert repr_ in rs_repr
+ assert rs.did_finish_fetching() is False
+ results.fetchall.assert_not_called()
+ results.fetchmany.assert_has_calls([call(size=2), call(size=1)])
+ results.fetchone.assert_not_called()
-def _get_number_of_rows_in_html_table(html):
- """
- Returns the number of
tags within the
section
- """
- pattern = r"(.*?)<\/tbody>"
- tbody_content = re.findall(pattern, html, re.DOTALL)[0]
- row_count = len(re.findall(r"", tbody_content))
+def test_resultset_fetches_no_rows(results):
+ mock = Mock()
+ mock.displaylimit = 1
+ mock.autolimit = 1000_000
- return row_count
+ ResultSet(results, mock)
+
+ results.fetchmany.assert_has_calls([call(size=2)])
+ results.fetchone.assert_not_called()
+ results.fetchall.assert_not_called()
+
+
+def test_resultset_autolimit_one(results):
+ mock = Mock()
+ mock.displaylimit = 10
+ mock.autolimit = 1
+
+ rs = ResultSet(results, mock)
+ repr(rs)
+ str(rs)
+ rs._repr_html_()
+ list(rs)
+
+ results.fetchmany.assert_has_calls([call(size=1)])
+ results.fetchone.assert_not_called()
+ results.fetchall.assert_not_called()
+
+
+# TODO: try with more values of displaylimit
+# TODO: test some edge cases. e.g., displaylimit is set to 10 but we only have 5 rows
+def test_displaylimit_message(results):
+ mock = Mock()
+ mock.displaylimit = 1
+ mock.autolimit = 0
+
+ rs = ResultSet(results, mock)
+
+ assert "Truncated to displaylimit of 1" in rs._repr_html_()