diff --git a/CHANGELOG.md b/CHANGELOG.md index 26faf9020..57cf4740b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ * [Doc] Document --persist-replace in API section (#539) * [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631) +* [Fix] Refactored `ResultSet` to lazy loading (#470) + ## 0.7.9 (2023-06-19) * [Feature] Modified `histogram` command to support data with NULL values ([#176](https://github.com/ploomber/jupysql/issues/176)) diff --git a/src/sql/run.py b/src/sql/run.py index 9fa52b7a6..fb8bcaff3 100644 --- a/src/sql/run.py +++ b/src/sql/run.py @@ -113,41 +113,14 @@ def __init__(self, sqlaproxy, config): self.config = config self.keys = {} self._results = [] + self.truncated = False + self.sqlaproxy = sqlaproxy + # https://peps.python.org/pep-0249/#description - is_dbapi_results = hasattr(sqlaproxy, "description") + self.is_dbapi_results = hasattr(sqlaproxy, "description") self.pretty = None - if is_dbapi_results: - should_try_fetch_results = True - else: - should_try_fetch_results = sqlaproxy.returns_rows - - if should_try_fetch_results: - # sql alchemy results - if not is_dbapi_results: - self.keys = sqlaproxy.keys() - elif isinstance(sqlaproxy.description, Iterable): - self.keys = [i[0] for i in sqlaproxy.description] - else: - self.keys = [] - - if len(self.keys) > 0: - if isinstance(config.autolimit, int) and config.autolimit > 0: - self._results = sqlaproxy.fetchmany(size=config.autolimit) - else: - self._results = sqlaproxy.fetchall() - - self.field_names = unduplicate_field_names(self.keys) - - _style = None - - self.pretty = PrettyTable(self.field_names) - - if isinstance(config.style, str): - _style = prettytable.__dict__[config.style.upper()] - self.pretty.set_style(_style) - def _repr_html_(self): _cell_with_spaces_pattern = re.compile(r"()( {2,})") if self.pretty: @@ -156,17 +129,17 @@ def _repr_html_(self): # to create clickable links result = html.unescape(result) result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result) - if len(self) > self.pretty.row_count: + if self.truncated: HTML = ( '%s\n' - "%d rows, truncated to displaylimit of %d" + "Truncated to displaylimit of %d" "
" '' "If you want to see more, please visit " 'displaylimit' # noqa: E501 " configuration" ) - result = HTML % (result, len(self), self.pretty.row_count) + result = HTML % (result, self.pretty.row_count) return result else: return None @@ -175,7 +148,9 @@ def __len__(self): return len(self._results) def __iter__(self): - for result in self._results: + results = self._fetch_query_results(fetch_all=True) + + for result in results: yield result def __str__(self, *arg, **kwarg): @@ -364,6 +339,105 @@ def csv(self, filename=None, **format_params): else: return outfile.getvalue() + def fetch_results(self, fetch_all=False): + """ + Returns a limited representation of the query results. + + Parameters + ---------- + fetch_all : bool default False + Return all query rows + """ + is_dbapi_results = self.is_dbapi_results + sqlaproxy = self.sqlaproxy + config = self.config + + if is_dbapi_results: + should_try_fetch_results = True + else: + should_try_fetch_results = sqlaproxy.returns_rows + + if should_try_fetch_results: + # sql alchemy results + if not is_dbapi_results: + self.keys = sqlaproxy.keys() + elif isinstance(sqlaproxy.description, Iterable): + self.keys = [i[0] for i in sqlaproxy.description] + else: + self.keys = [] + + if len(self.keys) > 0: + self._results = self._fetch_query_results(fetch_all=fetch_all) + + self.field_names = unduplicate_field_names(self.keys) + + _style = None + + self.pretty = PrettyTable(self.field_names) + + if isinstance(config.style, str): + _style = prettytable.__dict__[config.style.upper()] + self.pretty.set_style(_style) + + return self + + def _fetch_query_results(self, fetch_all=False): + """ + Returns rows of a query result as a list of tuples. + + Parameters + ---------- + fetch_all : bool default False + Return all query rows + """ + sqlaproxy = self.sqlaproxy + config = self.config + _should_try_lazy_fetch = hasattr(sqlaproxy, "_soft_closed") + + _should_fetch_all = ( + (config.displaylimit == 0 or not config.displaylimit) + or fetch_all + or not _should_try_lazy_fetch + ) + + is_autolimit = isinstance(config.autolimit, int) and config.autolimit > 0 + is_connection_closed = ( + sqlaproxy._soft_closed if _should_try_lazy_fetch else False + ) + + should_return_results = is_connection_closed or ( + len(self._results) > 0 and is_autolimit + ) + + if should_return_results: + # this means we already loaded all + # the results to self._results or we use + # autolimit and shouldn't fetch more + results = self._results + else: + if is_autolimit: + results = sqlaproxy.fetchmany(size=config.autolimit) + else: + if _should_fetch_all: + all_results = sqlaproxy.fetchall() + results = self._results + all_results + self._results = results + else: + results = sqlaproxy.fetchmany(size=config.displaylimit) + + if _should_try_lazy_fetch: + # Try to fetch an extra row to find out + # if there are more results to fetch + row = sqlaproxy.fetchone() + if row is not None: + results += [row] + + # Check if we have more rows to show + if config.displaylimit > 0: + self.truncated = len(results) > config.displaylimit + + return results + def display_affected_rowcount(rowcount): if rowcount > 0: @@ -557,6 +631,9 @@ def run(conn, sql, config): return df else: resultset = ResultSet(result, config) + + # lazy load + resultset.fetch_results() return select_df_type(resultset, config) diff --git a/src/tests/integration/test_generic_db_operations.py b/src/tests/integration/test_generic_db_operations.py index c49a9c076..6c25a9110 100644 --- a/src/tests/integration/test_generic_db_operations.py +++ b/src/tests/integration/test_generic_db_operations.py @@ -84,6 +84,8 @@ def test_create_table_with_indexed_df( ip_with_dynamic_db = request.getfixturevalue(ip_with_dynamic_db) # Clean up + ip_with_dynamic_db.run_cell("%config SqlMagic.displaylimit = 0") + ip_with_dynamic_db.run_cell( f"%sql DROP TABLE {test_table_name_dict['new_table_from_df']}" ) diff --git a/src/tests/integration/test_mssql.py b/src/tests/integration/test_mssql.py index 4bd8a3f72..ab8e07ae9 100644 --- a/src/tests/integration/test_mssql.py +++ b/src/tests/integration/test_mssql.py @@ -30,6 +30,9 @@ def test_cte(ip_with_MSSQL, test_table_name_dict): def test_create_table_with_indexed_df(ip_with_MSSQL, test_table_name_dict): # MSSQL gives error if DB doesn't exist + + ip_with_MSSQL.run_cell("%config SqlMagic.displaylimit = 0") + try: ip_with_MSSQL.run_cell( f"%sql DROP TABLE {test_table_name_dict['new_table_from_df']}" diff --git a/src/tests/integration/test_oracle.py b/src/tests/integration/test_oracle.py index cb4c47698..d80935480 100644 --- a/src/tests/integration/test_oracle.py +++ b/src/tests/integration/test_oracle.py @@ -17,6 +17,8 @@ def test_query_count(ip_with_oracle, test_table_name_dict): @pytest.mark.xfail(reason="Some issue with checking isidentifier part in persist") def test_create_table_with_indexed_df(ip_with_oracle, test_table_name_dict): + ip_with_oracle.run_cell("%config SqlMagic.displaylimit = 0") + # Prepare DF ip_with_oracle.run_cell( f"""results = %sql SELECT * FROM {test_table_name_dict['taxi']} \ diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 56c18e65d..814099e90 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -556,7 +556,7 @@ def test_displaylimit_default(ip): ip.run_cell("%sql INSERT INTO number_table VALUES (4, 3)") out = runsql(ip, "SELECT * FROM number_table;") - assert "truncated to displaylimit of 10" in out._repr_html_() + assert "Truncated to displaylimit of 10" in out._repr_html_() def test_displaylimit(ip): @@ -577,7 +577,7 @@ def test_displaylimit_enabled_truncated_length(ip, config_value, expected_length ip.run_cell(f"%config SqlMagic.displaylimit = {config_value}") out = runsql(ip, "SELECT * FROM number_table;") - assert f"truncated to displaylimit of {expected_length}" in out._repr_html_() + assert f"Truncated to displaylimit of {expected_length}" in out._repr_html_() @pytest.mark.parametrize("config_value", [(None), (0)]) @@ -591,7 +591,7 @@ def test_displaylimit_enabled_no_limit( ip.run_cell(f"%config SqlMagic.displaylimit = {config_value}") out = runsql(ip, "SELECT * FROM number_table;") - assert "truncated to displaylimit of " not in out._repr_html_() + assert "Truncated to displaylimit of " not in out._repr_html_() @pytest.mark.parametrize( @@ -645,10 +645,7 @@ def test_displaylimit_with_conditional_clause( out = runsql(ip, query_clause) if expected_truncated_length: - assert ( - f"{expected_truncated_length} rows, truncated to displaylimit of 10" - in out._repr_html_() - ) + assert "Truncated to displaylimit of 10" in out._repr_html_() def test_column_local_vars(ip): diff --git a/src/tests/test_resultset.py b/src/tests/test_resultset.py index 07fc325a1..52c7f338e 100644 --- a/src/tests/test_resultset.py +++ b/src/tests/test_resultset.py @@ -11,6 +11,8 @@ from sql.run import ResultSet from sql import run as run_module +import re + @pytest.fixture def config(): @@ -33,7 +35,7 @@ def result(): @pytest.fixture def result_set(result, config): - return ResultSet(result, config) + return ResultSet(result, config).fetch_results() def test_resultset_getitem(result_set): @@ -81,15 +83,15 @@ def test_resultset_repr_html(result_set): def test_resultset_config_autolimit_dict(result, config): config.autolimit = 1 - - assert ResultSet(result, config).dict() == {"x": (0,)} + resultset = ResultSet(result, config).fetch_results() + assert resultset.dict() == {"x": (0,)} def test_resultset_with_non_sqlalchemy_results(config): df = pd.DataFrame({"x": range(3)}) # noqa conn = Connection(engine=create_engine("duckdb://")) result = conn.execute("SELECT * FROM df") - assert ResultSet(result, config) == [(0,), (1,), (2,)] + assert ResultSet(result, config).fetch_results() == [(0,), (1,), (2,)] def test_none_pretty(config): @@ -98,3 +100,160 @@ def test_none_pretty(config): result_set = ResultSet(result, config) assert result_set.pretty is None assert "" == str(result_set) + + +def test_lazy_loading(result, config): + resultset = ResultSet(result, config) + assert len(resultset._results) == 0 + resultset.fetch_results() + assert len(resultset._results) == 3 + + +@pytest.mark.parametrize( + "autolimit, expected", + [ + (None, 3), + (False, 3), + (0, 3), + (1, 1), + (2, 2), + (3, 3), + (4, 3), + ], +) +def test_lazy_loading_autolimit(result, config, autolimit, expected): + config.autolimit = autolimit + resultset = ResultSet(result, config) + assert len(resultset._results) == 0 + resultset.fetch_results() + assert len(resultset._results) == expected + + +@pytest.mark.parametrize( + "displaylimit, expected", + [ + (0, 3), + (1, 1), + (2, 2), + (3, 3), + (4, 3), + ], +) +def test_lazy_loading_displaylimit(result, config, displaylimit, expected): + config.displaylimit = displaylimit + result_set = ResultSet(result, config) + + assert len(result_set._results) == 0 + result_set.fetch_results() + html = result_set._repr_html_() + row_count = _get_number_of_rows_in_html_table(html) + assert row_count == expected + + +@pytest.mark.parametrize( + "displaylimit, expected_display", + [ + (0, 3), + (1, 1), + (2, 2), + (3, 3), + (4, 3), + ], +) +def test_lazy_loading_displaylimit_fetch_all( + result, config, displaylimit, expected_display +): + max_results_count = 3 + config.autolimit = False + config.displaylimit = displaylimit + result_set = ResultSet(result, config) + + # Initialize result_set without fetching results + assert len(result_set._results) == 0 + + # Fetch the min number of rows (based on configuration) + result_set.fetch_results() + + html = result_set._repr_html_() + row_count = _get_number_of_rows_in_html_table(html) + expected_results = ( + max_results_count + if expected_display + 1 >= max_results_count + else expected_display + 1 + ) + + assert len(result_set._results) == expected_results + assert row_count == expected_display + + # Fetch the the rest results, but don't display them in the table + result_set.fetch_results(fetch_all=True) + + html = result_set._repr_html_() + row_count = _get_number_of_rows_in_html_table(html) + + assert len(result_set._results) == max_results_count + assert row_count == expected_display + + +@pytest.mark.parametrize( + "displaylimit, expected_display", + [ + (0, 3), + (1, 1), + (2, 2), + (3, 3), + (4, 3), + ], +) +def test_lazy_loading_list(result, config, displaylimit, expected_display): + max_results_count = 3 + config.autolimit = False + config.displaylimit = displaylimit + result_set = ResultSet(result, config) + + # Initialize result_set without fetching results + assert len(result_set._results) == 0 + + # Fetch the min number of rows (based on configuration) + result_set.fetch_results() + + expected_results = ( + max_results_count + if expected_display + 1 >= max_results_count + else expected_display + 1 + ) + + assert len(result_set._results) == expected_results + assert len(list(result_set)) == max_results_count + + +@pytest.mark.parametrize( + "autolimit, expected_results", + [ + (0, 3), + (1, 1), + (2, 2), + (3, 3), + (4, 3), + ], +) +def test_lazy_loading_autolimit_list(result, config, autolimit, expected_results): + config.autolimit = autolimit + result_set = ResultSet(result, config) + assert len(result_set._results) == 0 + + result_set.fetch_results() + + assert len(result_set._results) == expected_results + assert len(list(result_set)) == expected_results + + +def _get_number_of_rows_in_html_table(html): + """ + Returns the number of tags within the section + """ + pattern = r"(.*?)<\/tbody>" + tbody_content = re.findall(pattern, html, re.DOTALL)[0] + row_count = len(re.findall(r"", tbody_content)) + + return row_count diff --git a/src/tests/test_run.py b/src/tests/test_run.py index c3197dba0..ebfd97214 100644 --- a/src/tests/test_run.py +++ b/src/tests/test_run.py @@ -74,6 +74,10 @@ def DataFrame(cls): def PolarsDataFrame(cls): return polars.DataFrame() + @classmethod + def fetch_results(self, fetch_all=False): + pass + return ResultSet