diff --git a/CHANGELOG.md b/CHANGELOG.md
index 26faf9020..57cf4740b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,8 @@
* [Doc] Document --persist-replace in API section (#539)
* [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631)
+* [Fix] Refactored `ResultSet` to load query results lazily (#470)
+
## 0.7.9 (2023-06-19)
* [Feature] Modified `histogram` command to support data with NULL values ([#176](https://github.com/ploomber/jupysql/issues/176))
diff --git a/src/sql/run.py b/src/sql/run.py
index 9fa52b7a6..fb8bcaff3 100644
--- a/src/sql/run.py
+++ b/src/sql/run.py
@@ -113,41 +113,14 @@ def __init__(self, sqlaproxy, config):
self.config = config
self.keys = {}
self._results = []
+ self.truncated = False
+ self.sqlaproxy = sqlaproxy
+
# https://peps.python.org/pep-0249/#description
- is_dbapi_results = hasattr(sqlaproxy, "description")
+ self.is_dbapi_results = hasattr(sqlaproxy, "description")
self.pretty = None
- if is_dbapi_results:
- should_try_fetch_results = True
- else:
- should_try_fetch_results = sqlaproxy.returns_rows
-
- if should_try_fetch_results:
- # sql alchemy results
- if not is_dbapi_results:
- self.keys = sqlaproxy.keys()
- elif isinstance(sqlaproxy.description, Iterable):
- self.keys = [i[0] for i in sqlaproxy.description]
- else:
- self.keys = []
-
- if len(self.keys) > 0:
- if isinstance(config.autolimit, int) and config.autolimit > 0:
- self._results = sqlaproxy.fetchmany(size=config.autolimit)
- else:
- self._results = sqlaproxy.fetchall()
-
- self.field_names = unduplicate_field_names(self.keys)
-
- _style = None
-
- self.pretty = PrettyTable(self.field_names)
-
- if isinstance(config.style, str):
- _style = prettytable.__dict__[config.style.upper()]
- self.pretty.set_style(_style)
-
def _repr_html_(self):
_cell_with_spaces_pattern = re.compile(r"(<td>)( {2,})")
if self.pretty:
@@ -156,17 +129,17 @@ def _repr_html_(self):
# to create clickable links
result = html.unescape(result)
result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result)
- if len(self) > self.pretty.row_count:
+ if self.truncated:
HTML = (
'%s\n'
- "%d rows, truncated to displaylimit of %d"
+ "Truncated to displaylimit of %d"
" "
''
"If you want to see more, please visit "
'displaylimit' # noqa: E501
" configuration"
)
- result = HTML % (result, len(self), self.pretty.row_count)
+ result = HTML % (result, self.pretty.row_count)
return result
else:
return None
@@ -175,7 +148,9 @@ def __len__(self):
return len(self._results)
def __iter__(self):
- for result in self._results:
+ results = self._fetch_query_results(fetch_all=True)
+
+ for result in results:
yield result
def __str__(self, *arg, **kwarg):
@@ -364,6 +339,105 @@ def csv(self, filename=None, **format_params):
else:
return outfile.getvalue()
+ def fetch_results(self, fetch_all=False):
+ """
+ Loads query rows into this ResultSet (a limited number unless fetch_all) and returns self.
+
+ Parameters
+ ----------
+ fetch_all : bool default False
+ Return all query rows
+ """
+ is_dbapi_results = self.is_dbapi_results
+ sqlaproxy = self.sqlaproxy
+ config = self.config
+
+ if is_dbapi_results:
+ should_try_fetch_results = True
+ else:
+ should_try_fetch_results = sqlaproxy.returns_rows
+
+ if should_try_fetch_results:
+ # sql alchemy results
+ if not is_dbapi_results:
+ self.keys = sqlaproxy.keys()
+ elif isinstance(sqlaproxy.description, Iterable):
+ self.keys = [i[0] for i in sqlaproxy.description]
+ else:
+ self.keys = []
+
+ if len(self.keys) > 0:
+ self._results = self._fetch_query_results(fetch_all=fetch_all)
+
+ self.field_names = unduplicate_field_names(self.keys)
+
+ _style = None
+
+ self.pretty = PrettyTable(self.field_names)
+
+ if isinstance(config.style, str):
+ _style = prettytable.__dict__[config.style.upper()]
+ self.pretty.set_style(_style)
+
+ return self
+
+ def _fetch_query_results(self, fetch_all=False):
+ """
+ Returns rows of a query result as a list of tuples.
+
+ Parameters
+ ----------
+ fetch_all : bool default False
+ Return all query rows
+ """
+ sqlaproxy = self.sqlaproxy
+ config = self.config
+ _should_try_lazy_fetch = hasattr(sqlaproxy, "_soft_closed")
+
+ _should_fetch_all = (
+ (config.displaylimit == 0 or not config.displaylimit)
+ or fetch_all
+ or not _should_try_lazy_fetch
+ )
+
+ is_autolimit = isinstance(config.autolimit, int) and config.autolimit > 0
+ is_connection_closed = (
+ sqlaproxy._soft_closed if _should_try_lazy_fetch else False
+ )
+
+ should_return_results = is_connection_closed or (
+ len(self._results) > 0 and is_autolimit
+ )
+
+ if should_return_results:
+ # this means we already loaded all
+ # the results to self._results or we use
+ # autolimit and shouldn't fetch more
+ results = self._results
+ else:
+ if is_autolimit:
+ results = sqlaproxy.fetchmany(size=config.autolimit)
+ else:
+ if _should_fetch_all:
+ all_results = sqlaproxy.fetchall()
+ results = self._results + all_results
+ self._results = results
+ else:
+ results = sqlaproxy.fetchmany(size=config.displaylimit)
+
+ if _should_try_lazy_fetch:
+ # Try to fetch an extra row to find out
+ # if there are more results to fetch
+ row = sqlaproxy.fetchone()
+ if row is not None:
+ results += [row]
+
+ # Check if we have more rows to show
+ if config.displaylimit > 0:
+ self.truncated = len(results) > config.displaylimit
+
+ return results
+
def display_affected_rowcount(rowcount):
if rowcount > 0:
@@ -557,6 +631,9 @@ def run(conn, sql, config):
return df
else:
resultset = ResultSet(result, config)
+
+ # lazy load
+ resultset.fetch_results()
return select_df_type(resultset, config)
diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py
index c49a9c076..6c25a9110 100644
--- a/src/tests/integration/test_generic_db_operations.py
+++ b/src/tests/integration/test_generic_db_operations.py
@@ -84,6 +84,8 @@ def test_create_table_with_indexed_df(
ip_with_dynamic_db = request.getfixturevalue(ip_with_dynamic_db)
# Clean up
+ ip_with_dynamic_db.run_cell("%config SqlMagic.displaylimit = 0")
+
ip_with_dynamic_db.run_cell(
f"%sql DROP TABLE {test_table_name_dict['new_table_from_df']}"
)
diff --git a/src/tests/integration/test_mssql.py b/src/tests/integration/test_mssql.py
index 4bd8a3f72..ab8e07ae9 100644
--- a/src/tests/integration/test_mssql.py
+++ b/src/tests/integration/test_mssql.py
@@ -30,6 +30,9 @@ def test_cte(ip_with_MSSQL, test_table_name_dict):
def test_create_table_with_indexed_df(ip_with_MSSQL, test_table_name_dict):
# MSSQL gives error if DB doesn't exist
+
+ ip_with_MSSQL.run_cell("%config SqlMagic.displaylimit = 0")
+
try:
ip_with_MSSQL.run_cell(
f"%sql DROP TABLE {test_table_name_dict['new_table_from_df']}"
diff --git a/src/tests/integration/test_oracle.py b/src/tests/integration/test_oracle.py
index cb4c47698..d80935480 100644
--- a/src/tests/integration/test_oracle.py
+++ b/src/tests/integration/test_oracle.py
@@ -17,6 +17,8 @@ def test_query_count(ip_with_oracle, test_table_name_dict):
@pytest.mark.xfail(reason="Some issue with checking isidentifier part in persist")
def test_create_table_with_indexed_df(ip_with_oracle, test_table_name_dict):
+ ip_with_oracle.run_cell("%config SqlMagic.displaylimit = 0")
+
# Prepare DF
ip_with_oracle.run_cell(
f"""results = %sql SELECT * FROM {test_table_name_dict['taxi']} \
diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py
index 56c18e65d..814099e90 100644
--- a/src/tests/test_magic.py
+++ b/src/tests/test_magic.py
@@ -556,7 +556,7 @@ def test_displaylimit_default(ip):
ip.run_cell("%sql INSERT INTO number_table VALUES (4, 3)")
out = runsql(ip, "SELECT * FROM number_table;")
- assert "truncated to displaylimit of 10" in out._repr_html_()
+ assert "Truncated to displaylimit of 10" in out._repr_html_()
def test_displaylimit(ip):
@@ -577,7 +577,7 @@ def test_displaylimit_enabled_truncated_length(ip, config_value, expected_length
ip.run_cell(f"%config SqlMagic.displaylimit = {config_value}")
out = runsql(ip, "SELECT * FROM number_table;")
- assert f"truncated to displaylimit of {expected_length}" in out._repr_html_()
+ assert f"Truncated to displaylimit of {expected_length}" in out._repr_html_()
@pytest.mark.parametrize("config_value", [(None), (0)])
@@ -591,7 +591,7 @@ def test_displaylimit_enabled_no_limit(
ip.run_cell(f"%config SqlMagic.displaylimit = {config_value}")
out = runsql(ip, "SELECT * FROM number_table;")
- assert "truncated to displaylimit of " not in out._repr_html_()
+ assert "Truncated to displaylimit of " not in out._repr_html_()
@pytest.mark.parametrize(
@@ -645,10 +645,7 @@ def test_displaylimit_with_conditional_clause(
out = runsql(ip, query_clause)
if expected_truncated_length:
- assert (
- f"{expected_truncated_length} rows, truncated to displaylimit of 10"
- in out._repr_html_()
- )
+ assert "Truncated to displaylimit of 10" in out._repr_html_()
def test_column_local_vars(ip):
diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py
index 07fc325a1..52c7f338e 100644
--- a/src/tests/test_resultset.py
+++ b/src/tests/test_resultset.py
@@ -11,6 +11,8 @@
from sql.run import ResultSet
from sql import run as run_module
+import re
+
@pytest.fixture
def config():
@@ -33,7 +35,7 @@ def result():
@pytest.fixture
def result_set(result, config):
- return ResultSet(result, config)
+ return ResultSet(result, config).fetch_results()
def test_resultset_getitem(result_set):
@@ -81,15 +83,15 @@ def test_resultset_repr_html(result_set):
def test_resultset_config_autolimit_dict(result, config):
config.autolimit = 1
-
- assert ResultSet(result, config).dict() == {"x": (0,)}
+ resultset = ResultSet(result, config).fetch_results()
+ assert resultset.dict() == {"x": (0,)}
def test_resultset_with_non_sqlalchemy_results(config):
df = pd.DataFrame({"x": range(3)}) # noqa
conn = Connection(engine=create_engine("duckdb://"))
result = conn.execute("SELECT * FROM df")
- assert ResultSet(result, config) == [(0,), (1,), (2,)]
+ assert ResultSet(result, config).fetch_results() == [(0,), (1,), (2,)]
def test_none_pretty(config):
@@ -98,3 +100,160 @@ def test_none_pretty(config):
result_set = ResultSet(result, config)
assert result_set.pretty is None
assert "" == str(result_set)
+
+
+def test_lazy_loading(result, config):
+ resultset = ResultSet(result, config)
+ assert len(resultset._results) == 0
+ resultset.fetch_results()
+ assert len(resultset._results) == 3
+
+
+@pytest.mark.parametrize(
+ "autolimit, expected",
+ [
+ (None, 3),
+ (False, 3),
+ (0, 3),
+ (1, 1),
+ (2, 2),
+ (3, 3),
+ (4, 3),
+ ],
+)
+def test_lazy_loading_autolimit(result, config, autolimit, expected):
+ config.autolimit = autolimit
+ resultset = ResultSet(result, config)
+ assert len(resultset._results) == 0
+ resultset.fetch_results()
+ assert len(resultset._results) == expected
+
+
+@pytest.mark.parametrize(
+ "displaylimit, expected",
+ [
+ (0, 3),
+ (1, 1),
+ (2, 2),
+ (3, 3),
+ (4, 3),
+ ],
+)
+def test_lazy_loading_displaylimit(result, config, displaylimit, expected):
+ config.displaylimit = displaylimit
+ result_set = ResultSet(result, config)
+
+ assert len(result_set._results) == 0
+ result_set.fetch_results()
+ html = result_set._repr_html_()
+ row_count = _get_number_of_rows_in_html_table(html)
+ assert row_count == expected
+
+
+@pytest.mark.parametrize(
+ "displaylimit, expected_display",
+ [
+ (0, 3),
+ (1, 1),
+ (2, 2),
+ (3, 3),
+ (4, 3),
+ ],
+)
+def test_lazy_loading_displaylimit_fetch_all(
+ result, config, displaylimit, expected_display
+):
+ max_results_count = 3
+ config.autolimit = False
+ config.displaylimit = displaylimit
+ result_set = ResultSet(result, config)
+
+ # Initialize result_set without fetching results
+ assert len(result_set._results) == 0
+
+ # Fetch the min number of rows (based on configuration)
+ result_set.fetch_results()
+
+ html = result_set._repr_html_()
+ row_count = _get_number_of_rows_in_html_table(html)
+ expected_results = (
+ max_results_count
+ if expected_display + 1 >= max_results_count
+ else expected_display + 1
+ )
+
+ assert len(result_set._results) == expected_results
+ assert row_count == expected_display
+
+ # Fetch the rest of the results, but don't display them in the table
+ result_set.fetch_results(fetch_all=True)
+
+ html = result_set._repr_html_()
+ row_count = _get_number_of_rows_in_html_table(html)
+
+ assert len(result_set._results) == max_results_count
+ assert row_count == expected_display
+
+
+@pytest.mark.parametrize(
+ "displaylimit, expected_display",
+ [
+ (0, 3),
+ (1, 1),
+ (2, 2),
+ (3, 3),
+ (4, 3),
+ ],
+)
+def test_lazy_loading_list(result, config, displaylimit, expected_display):
+ max_results_count = 3
+ config.autolimit = False
+ config.displaylimit = displaylimit
+ result_set = ResultSet(result, config)
+
+ # Initialize result_set without fetching results
+ assert len(result_set._results) == 0
+
+ # Fetch the min number of rows (based on configuration)
+ result_set.fetch_results()
+
+ expected_results = (
+ max_results_count
+ if expected_display + 1 >= max_results_count
+ else expected_display + 1
+ )
+
+ assert len(result_set._results) == expected_results
+ assert len(list(result_set)) == max_results_count
+
+
+@pytest.mark.parametrize(
+ "autolimit, expected_results",
+ [
+ (0, 3),
+ (1, 1),
+ (2, 2),
+ (3, 3),
+ (4, 3),
+ ],
+)
+def test_lazy_loading_autolimit_list(result, config, autolimit, expected_results):
+ config.autolimit = autolimit
+ result_set = ResultSet(result, config)
+ assert len(result_set._results) == 0
+
+ result_set.fetch_results()
+
+ assert len(result_set._results) == expected_results
+ assert len(list(result_set)) == expected_results
+
+
+def _get_number_of_rows_in_html_table(html):
+ """
+ Returns the number of <tr> tags within the <tbody> section
+ """
+ pattern = r"<tbody>(.*?)<\/tbody>"
+ tbody_content = re.findall(pattern, html, re.DOTALL)[0]
+ row_count = len(re.findall(r"<tr>", tbody_content))
+
+ return row_count
diff --git a/src/tests/test_run.py b/src/tests/test_run.py
index c3197dba0..ebfd97214 100644
--- a/src/tests/test_run.py
+++ b/src/tests/test_run.py
@@ -74,6 +74,10 @@ def DataFrame(cls):
def PolarsDataFrame(cls):
return polars.DataFrame()
+ @classmethod
+ def fetch_results(self, fetch_all=False):
+ pass
+
return ResultSet