Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored ResultSet to lazy loading #624

Merged
merged 10 commits into from
Jun 22, 2023
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## 0.7.10dev

* [Fix] Refactored `ResultSet` to lazy loading (#470)

## 0.7.9 (2023-06-19)

* [Feature] Modified `histogram` command to support data with NULL values ([#176](https://github.com/ploomber/jupysql/issues/176))
Expand Down
147 changes: 112 additions & 35 deletions src/sql/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,41 +113,14 @@ def __init__(self, sqlaproxy, config):
self.config = config
self.keys = {}
self._results = []
self.truncated = False
self.sqlaproxy = sqlaproxy

# https://peps.python.org/pep-0249/#description
is_dbapi_results = hasattr(sqlaproxy, "description")
self.is_dbapi_results = hasattr(sqlaproxy, "description")

self.pretty = None

if is_dbapi_results:
should_try_fetch_results = True
else:
should_try_fetch_results = sqlaproxy.returns_rows

if should_try_fetch_results:
# sql alchemy results
if not is_dbapi_results:
self.keys = sqlaproxy.keys()
elif isinstance(sqlaproxy.description, Iterable):
self.keys = [i[0] for i in sqlaproxy.description]
else:
self.keys = []

if len(self.keys) > 0:
if isinstance(config.autolimit, int) and config.autolimit > 0:
self._results = sqlaproxy.fetchmany(size=config.autolimit)
else:
self._results = sqlaproxy.fetchall()

self.field_names = unduplicate_field_names(self.keys)

_style = None

self.pretty = PrettyTable(self.field_names)

if isinstance(config.style, str):
_style = prettytable.__dict__[config.style.upper()]
self.pretty.set_style(_style)

def _repr_html_(self):
_cell_with_spaces_pattern = re.compile(r"(<td>)( {2,})")
if self.pretty:
Expand All @@ -156,17 +129,17 @@ def _repr_html_(self):
# to create clickable links
result = html.unescape(result)
result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result)
if len(self) > self.pretty.row_count:
if self.truncated:
HTML = (
'%s\n<span style="font-style:italic;text-align:center;">'
"%d rows, truncated to displaylimit of %d</span>"
"Truncated to displaylimit of %d</span>"
"<br>"
'<span style="font-style:italic;text-align:center;">'
"If you want to see more, please visit "
'<a href="https://jupysql.ploomber.io/en/latest/api/configuration.html#displaylimit">displaylimit</a>' # noqa: E501
" configuration</span>"
)
result = HTML % (result, len(self), self.pretty.row_count)
result = HTML % (result, self.pretty.row_count)
return result
else:
return None
Expand All @@ -175,7 +148,9 @@ def __len__(self):
return len(self._results)

def __iter__(self):
for result in self._results:
results = self._fetch_query_results(fetch_all=True)

for result in results:
yield result

def __str__(self, *arg, **kwarg):
Expand Down Expand Up @@ -364,6 +339,105 @@ def csv(self, filename=None, **format_params):
else:
return outfile.getvalue()

def fetch_results(self, fetch_all=False):
"""
Returns a limited representation of the query results.

Parameters
----------
fetch_all : bool default False
Return all query rows
"""
is_dbapi_results = self.is_dbapi_results
sqlaproxy = self.sqlaproxy
config = self.config

if is_dbapi_results:
should_try_fetch_results = True
else:
should_try_fetch_results = sqlaproxy.returns_rows

if should_try_fetch_results:
# sql alchemy results
if not is_dbapi_results:
self.keys = sqlaproxy.keys()
elif isinstance(sqlaproxy.description, Iterable):
self.keys = [i[0] for i in sqlaproxy.description]
else:
self.keys = []

if len(self.keys) > 0:
self._results = self._fetch_query_results(fetch_all=fetch_all)

self.field_names = unduplicate_field_names(self.keys)

_style = None

self.pretty = PrettyTable(self.field_names)

if isinstance(config.style, str):
_style = prettytable.__dict__[config.style.upper()]
self.pretty.set_style(_style)

return self

def _fetch_query_results(self, fetch_all=False):
"""
Returns rows of a query result as a list of tuples.

Parameters
----------
fetch_all : bool default False
Return all query rows
"""
sqlaproxy = self.sqlaproxy
config = self.config
_should_try_lazy_fetch = hasattr(sqlaproxy, "_soft_closed")

_should_fetch_all = (
(config.displaylimit == 0 or not config.displaylimit)
or fetch_all
or not _should_try_lazy_fetch
)

is_autolimit = isinstance(config.autolimit, int) and config.autolimit > 0
is_connection_closed = (
sqlaproxy._soft_closed if _should_try_lazy_fetch else False
)

should_return_results = is_connection_closed or (
len(self._results) > 0 and is_autolimit
)

if should_return_results:
# this means we already loaded all
# the results to self._results or we use
# autolimit and shouldn't fetch more
results = self._results
else:
if is_autolimit:
results = sqlaproxy.fetchmany(size=config.autolimit)
else:
if _should_fetch_all:
all_results = sqlaproxy.fetchall()
results = self._results + all_results
self._results = results
else:
results = sqlaproxy.fetchmany(size=config.displaylimit)

if _should_try_lazy_fetch:
# Try to fetch an extra row to find out
# if there are more results to fetch
row = sqlaproxy.fetchone()
if row is not None:
results += [row]

# Check if we have more rows to show
if config.displaylimit > 0:
self.truncated = len(results) > config.displaylimit

return results


def display_affected_rowcount(rowcount):
if rowcount > 0:
Expand Down Expand Up @@ -557,6 +631,9 @@ def run(conn, sql, config):
return df
else:
resultset = ResultSet(result, config)

# lazy load
resultset.fetch_results()
return select_df_type(resultset, config)


Expand Down
2 changes: 2 additions & 0 deletions src/tests/integration/test_generic_db_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def test_create_table_with_indexed_df(
ip_with_dynamic_db = request.getfixturevalue(ip_with_dynamic_db)
# Clean up

ip_with_dynamic_db.run_cell("%config SqlMagic.displaylimit = 0")

ip_with_dynamic_db.run_cell(
f"%sql DROP TABLE {test_table_name_dict['new_table_from_df']}"
)
Expand Down
3 changes: 3 additions & 0 deletions src/tests/integration/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def test_cte(ip_with_MSSQL, test_table_name_dict):

def test_create_table_with_indexed_df(ip_with_MSSQL, test_table_name_dict):
# MSSQL gives error if DB doesn't exist

ip_with_MSSQL.run_cell("%config SqlMagic.displaylimit = 0")

try:
ip_with_MSSQL.run_cell(
f"%sql DROP TABLE {test_table_name_dict['new_table_from_df']}"
Expand Down
2 changes: 2 additions & 0 deletions src/tests/integration/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def test_query_count(ip_with_oracle, test_table_name_dict):

@pytest.mark.xfail(reason="Some issue with checking isidentifier part in persist")
def test_create_table_with_indexed_df(ip_with_oracle, test_table_name_dict):
ip_with_oracle.run_cell("%config SqlMagic.displaylimit = 0")

# Prepare DF
ip_with_oracle.run_cell(
f"""results = %sql SELECT * FROM {test_table_name_dict['taxi']} \
Expand Down
11 changes: 4 additions & 7 deletions src/tests/test_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def test_displaylimit_default(ip):
ip.run_cell("%sql INSERT INTO number_table VALUES (4, 3)")

out = runsql(ip, "SELECT * FROM number_table;")
assert "truncated to displaylimit of 10" in out._repr_html_()
assert "Truncated to displaylimit of 10" in out._repr_html_()


def test_displaylimit(ip):
Expand All @@ -545,7 +545,7 @@ def test_displaylimit_enabled_truncated_length(ip, config_value, expected_length

ip.run_cell(f"%config SqlMagic.displaylimit = {config_value}")
out = runsql(ip, "SELECT * FROM number_table;")
assert f"truncated to displaylimit of {expected_length}" in out._repr_html_()
assert f"Truncated to displaylimit of {expected_length}" in out._repr_html_()


@pytest.mark.parametrize("config_value", [(None), (0)])
Expand All @@ -559,7 +559,7 @@ def test_displaylimit_enabled_no_limit(

ip.run_cell(f"%config SqlMagic.displaylimit = {config_value}")
out = runsql(ip, "SELECT * FROM number_table;")
assert "truncated to displaylimit of " not in out._repr_html_()
assert "Truncated to displaylimit of " not in out._repr_html_()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -613,10 +613,7 @@ def test_displaylimit_with_conditional_clause(
out = runsql(ip, query_clause)

if expected_truncated_length:
assert (
f"{expected_truncated_length} rows, truncated to displaylimit of 10"
in out._repr_html_()
)
assert "Truncated to displaylimit of 10" in out._repr_html_()


def test_column_local_vars(ip):
Expand Down
Loading