Skip to content

Commit

Permalink
feat(python): add "execute_options" support for read_database_uri
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Mar 4, 2024
1 parent baacf3d commit f738c90
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 11 deletions.
42 changes: 32 additions & 10 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from polars._utils.deprecation import issue_deprecation_warning
from polars.convert import from_arrow
from polars.dependencies import pyarrow as pa
from polars.exceptions import InvalidOperationError, UnsuitableSQLError

if TYPE_CHECKING:
Expand All @@ -23,7 +24,6 @@
from typing_extensions import Self

from polars import DataFrame
from polars.dependencies import pyarrow as pa
from polars.type_aliases import ConnectionOrCursor, Cursor, DbReadEngine, SchemaDict

try:
Expand Down Expand Up @@ -522,12 +522,11 @@ def read_database( # noqa: D417
more details about using this driver (notable databases implementing Flight SQL
include Dremio and InfluxDB).
* The `read_database_uri` function is likely to be noticeably faster than
`read_database` if you are using a SQLAlchemy or DBAPI2 connection, as
`connectorx` will optimise translation of the result set into Arrow format
in Rust, whereas these libraries will return row-wise data to Python *before*
we can load into Arrow. Note that you can easily determine the connection's
URI from a SQLAlchemy engine object by calling
* The `read_database_uri` function can be noticeably faster than `read_database`
if you are using a SQLAlchemy or DBAPI2 connection, as `connectorx` optimises
translation of the result set into Arrow format in Rust, whereas these libraries
will return row-wise data to Python *before* we can load into Arrow. Note that
you can determine the connection's URI from a SQLAlchemy engine object by calling
`conn.engine.url.render_as_string(hide_password=False)`.
* If polars has to create a cursor from your connection in order to execute the
Expand Down Expand Up @@ -638,6 +637,7 @@ def read_database_uri(
protocol: str | None = None,
engine: DbReadEngine | None = None,
schema_overrides: SchemaDict | None = None,
execute_options: dict[str, Any] | None = None,
) -> DataFrame:
"""
Read the results of a SQL query into a DataFrame, given a URI.
Expand Down Expand Up @@ -684,6 +684,9 @@ def read_database_uri(
schema_overrides
A dictionary mapping column names to dtypes, used to override the schema
given in the data returned by the query.
execute_options
These options will be passed to the underlying query execution method as
kwargs. Note that connectorx does not support this parameter.
Notes
-----
Expand Down Expand Up @@ -752,6 +755,9 @@ def read_database_uri(
engine = "connectorx"

if engine == "connectorx":
if execute_options:
msg = "the 'connectorx' engine does not support use of `execute_options`"
raise ValueError(msg)
return _read_sql_connectorx(
query,
connection_uri=uri,
Expand All @@ -765,7 +771,12 @@ def read_database_uri(
if not isinstance(query, str):
msg = "only a single SQL query string is accepted for adbc"
raise ValueError(msg)
return _read_sql_adbc(query, uri, schema_overrides)
return _read_sql_adbc(
query,
connection_uri=uri,
schema_overrides=schema_overrides,
execute_options=execute_options,
)
else:
msg = f"engine must be one of {{'connectorx', 'adbc'}}, got {engine!r}"
raise ValueError(msg)
Expand Down Expand Up @@ -805,10 +816,21 @@ def _read_sql_connectorx(


def _read_sql_adbc(
    query: str,
    connection_uri: str,
    schema_overrides: SchemaDict | None,
    execute_options: dict[str, Any] | None = None,
) -> DataFrame:
    """
    Execute a single SQL query through an ADBC cursor and return a DataFrame.

    Parameters
    ----------
    query
        The SQL query string to execute.
    connection_uri
        URI handed to `_open_adbc_connection` to obtain the connection.
    schema_overrides
        Forwarded to `from_arrow` when converting the fetched Arrow table.
    execute_options
        Optional keyword arguments forwarded to `cursor.execute`. If it contains
        a "parameters" entry that is a dict, that entry is converted into a
        single-row pyarrow Table (each value wrapped in a one-element list)
        before being passed on. The caller's dict is never modified.

    Returns
    -------
    DataFrame
        The query result, built from the Arrow table returned by the cursor.
    """
    # Work on a shallow copy: the "parameters" entry may be replaced below, and
    # that replacement must not leak back into the dict owned by the caller
    # (who may legitimately reuse it for another call).
    options: dict[str, Any] = dict(execute_options) if execute_options else {}

    params = options.get("parameters")
    if isinstance(params, dict):
        # Named parameters are supplied to the cursor as a one-row Arrow table.
        options["parameters"] = pa.Table.from_pydict(
            {k: [v] for k, v in params.items()}
        )

    with _open_adbc_connection(connection_uri) as conn, conn.cursor() as cursor:
        cursor.execute(query, **options)
        tbl = cursor.fetch_arrow_table()
    return from_arrow(tbl, schema_overrides=schema_overrides)  # type: ignore[return-value]

Expand Down
31 changes: 30 additions & 1 deletion py-polars/tests/unit/io/test_database_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None:


def test_read_database_parameterised(tmp_path: Path) -> None:
supports_adbc_sqlite = sys.version_info >= (3, 9) and sys.platform != "win32"

# setup underlying test data
tmp_path.mkdir(exist_ok=True)
create_temp_sqlite_db(test_db := str(tmp_path / "test.db"))
Expand All @@ -393,6 +395,8 @@ def test_read_database_parameterised(tmp_path: Path) -> None:
FROM test_data
WHERE value < {n}
"""
expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

for param, param_value in (
(":n", {"n": 0}),
("?", (0,)),
Expand All @@ -403,12 +407,37 @@ def test_read_database_parameterised(tmp_path: Path) -> None:
continue # alchemy session.execute() doesn't support positional params

assert_frame_equal(
expected_frame,
pl.read_database(
query.format(n=param),
connection=conn,
execute_options={"parameters": param_value},
),
pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}),
)

# test URI read method (adbc only; no connectorx support for execute_options)
if supports_adbc_sqlite:
uri = alchemy_engine.url.render_as_string(hide_password=False)
assert_frame_equal(
expected_frame,
pl.read_database_uri(
query.format(n=param),
uri=uri,
engine="adbc",
execute_options={"parameters": param_value},
),
)

if supports_adbc_sqlite:
with pytest.raises(
ValueError,
match="connectorx.*does not support.*execute_options",
):
pl.read_database_uri(
query.format(n=":n"),
uri=uri,
engine="connectorx",
execute_options={"parameters": (":n", {"n": 0})},
)


Expand Down

0 comments on commit f738c90

Please sign in to comment.