Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Add "execute_options" support for read_database_uri #14682

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 25 additions & 10 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
if TYPE_CHECKING:
from types import TracebackType

import pyarrow as pa

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
Expand All @@ -23,7 +25,6 @@
from typing_extensions import Self

from polars import DataFrame
from polars.dependencies import pyarrow as pa
from polars.type_aliases import ConnectionOrCursor, Cursor, DbReadEngine, SchemaDict

try:
Expand Down Expand Up @@ -522,12 +523,11 @@ def read_database( # noqa: D417
more details about using this driver (notable databases implementing Flight SQL
include Dremio and InfluxDB).

* The `read_database_uri` function is likely to be noticeably faster than
`read_database` if you are using a SQLAlchemy or DBAPI2 connection, as
`connectorx` will optimise translation of the result set into Arrow format
in Rust, whereas these libraries will return row-wise data to Python *before*
we can load into Arrow. Note that you can easily determine the connection's
URI from a SQLAlchemy engine object by calling
* The `read_database_uri` function can be noticeably faster than `read_database`
if you are using a SQLAlchemy or DBAPI2 connection, as `connectorx` optimises
translation of the result set into Arrow format in Rust, whereas these libraries
will return row-wise data to Python *before* we can load into Arrow. Note that
you can determine the connection's URI from a SQLAlchemy engine object by calling
`conn.engine.url.render_as_string(hide_password=False)`.

* If polars has to create a cursor from your connection in order to execute the
Expand Down Expand Up @@ -638,6 +638,7 @@ def read_database_uri(
protocol: str | None = None,
engine: DbReadEngine | None = None,
schema_overrides: SchemaDict | None = None,
execute_options: dict[str, Any] | None = None,
) -> DataFrame:
"""
Read the results of a SQL query into a DataFrame, given a URI.
Expand Down Expand Up @@ -684,6 +685,9 @@ def read_database_uri(
schema_overrides
A dictionary mapping column names to dtypes, used to override the schema
given in the data returned by the query.
execute_options
These options will be passed to the underlying query execution method as
kwargs. Note that connectorx does not support this parameter.

Notes
-----
Expand Down Expand Up @@ -752,6 +756,9 @@ def read_database_uri(
engine = "connectorx"

if engine == "connectorx":
if execute_options:
msg = "the 'connectorx' engine does not support use of `execute_options`"
raise ValueError(msg)
return _read_sql_connectorx(
query,
connection_uri=uri,
Expand All @@ -765,7 +772,12 @@ def read_database_uri(
if not isinstance(query, str):
msg = "only a single SQL query string is accepted for adbc"
raise ValueError(msg)
return _read_sql_adbc(query, uri, schema_overrides)
return _read_sql_adbc(
query,
connection_uri=uri,
schema_overrides=schema_overrides,
execute_options=execute_options,
)
else:
msg = f"engine must be one of {{'connectorx', 'adbc'}}, got {engine!r}"
raise ValueError(msg)
Expand Down Expand Up @@ -805,10 +817,13 @@ def _read_sql_connectorx(


def _read_sql_adbc(
    query: str,
    connection_uri: str,
    schema_overrides: SchemaDict | None,
    execute_options: dict[str, Any] | None = None,
) -> DataFrame:
    """
    Execute a SQL query over an ADBC connection and load the result into a DataFrame.

    Parameters
    ----------
    query
        A single SQL query string.
    connection_uri
        Database URI handed to the ADBC driver.
    schema_overrides
        Optional dtype overrides applied when converting the Arrow result.
    execute_options
        Optional kwargs forwarded verbatim to `cursor.execute` (e.g. a
        `parameters` entry for parameterised queries).
    """
    with _open_adbc_connection(connection_uri) as conn, conn.cursor() as cursor:
        # forward any user-supplied execution kwargs (no-op when None or empty)
        cursor.execute(query, **(execute_options or {}))
        tbl = cursor.fetch_arrow_table()
    return from_arrow(tbl, schema_overrides=schema_overrides)  # type: ignore[return-value]

Expand Down
111 changes: 78 additions & 33 deletions py-polars/tests/unit/io/test_database_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from types import GeneratorType
from typing import TYPE_CHECKING, Any, NamedTuple

import pyarrow as pa
import pytest
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select
from sqlalchemy.orm import sessionmaker
Expand All @@ -20,8 +21,6 @@
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
import pyarrow as pa

from polars.type_aliases import (
ConnectionOrCursor,
DbReadEngine,
def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any:
    """Return an ADBC SQLite connection, or None when the driver is unavailable."""
    with suppress(ModuleNotFoundError):  # not available on 3.8/windows
        from adbc_driver_sqlite.dbapi import connect

        # the adbc sqlite driver requires string paths, not pathlib.Path objects
        args = tuple(str(a) if isinstance(a, Path) else a for a in args)
        return connect(*args, **kwargs)


@pytest.fixture()
def tmp_sqlite_db(tmp_path: Path) -> Path:
    """Create a small SQLite test database under `tmp_path` and return its path.

    The database holds a single `test_data` table with two rows of simple
    int/str/float/date values (see the diagram below).
    """
    test_db = tmp_path / "test.db"
    test_db.unlink(missing_ok=True)

    def convert_date(val: bytes) -> date:
        """Convert ISO 8601 date to datetime.date object."""
        return date.fromisoformat(val.decode())

    # NOTE: at the time of writing adbc/connectorx have weak SQLite support (poor or
    # no bool/date/datetime dtypes, for example) and there is a bug in connectorx that
    # causes float rounding < py 3.11, hence we are only testing/storing simple values
    # in this test db for now. as support improves, we can add/test additional dtypes).
    sqlite3.register_converter("date", convert_date)
    conn = sqlite3.connect(test_db)

    # ┌─────┬───────┬───────┬────────────┐
    # │ id  ┆ name  ┆ value ┆ date       │
    # │ --- ┆ ---   ┆ ---   ┆ ---        │
    # │ i64 ┆ str   ┆ f64   ┆ date       │
    # ╞═════╪═══════╪═══════╪════════════╡
    # │ 1   ┆ misc  ┆ 100.0 ┆ 2020-01-01 │
    # │ 2   ┆ other ┆ -99.5 ┆ 2021-12-31 │
    # └─────┴───────┴───────┴────────────┘
    conn.executescript(
        """
        CREATE TABLE IF NOT EXISTS test_data (
            id    INTEGER PRIMARY KEY,
            name  TEXT NOT NULL,
            value FLOAT,
            date  DATE
        );
        REPLACE INTO test_data(name,value,date)
        VALUES ('misc',100.0,'2020-01-01'),
               ('other',-99.5,'2021-12-31');
        """
    )
    conn.close()
    return test_db


class DatabaseReadTestParams(NamedTuple):
Expand Down Expand Up @@ -313,23 +316,19 @@ def test_read_database(
expected_dates: list[date | str],
schema_overrides: SchemaDict | None,
batch_size: int | None,
tmp_path: Path,
tmp_sqlite_db: Path,
) -> None:
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / "test.db")
create_temp_sqlite_db(test_db)

if read_method == "read_database_uri":
# instantiate the connection ourselves, using connectorx/adbc
df = pl.read_database_uri(
uri=f"sqlite:///{test_db}",
uri=f"sqlite:///{tmp_sqlite_db}",
query="SELECT * FROM test_data",
engine=str(connect_using), # type: ignore[arg-type]
schema_overrides=schema_overrides,
)
elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]:
# externally instantiated adbc connections
with connect_using(test_db) as conn, conn.cursor():
with connect_using(tmp_sqlite_db) as conn, conn.cursor():
df = pl.read_database(
connection=conn,
query="SELECT * FROM test_data",
Expand All @@ -339,7 +338,7 @@ def test_read_database(
else:
# other user-supplied connections
df = pl.read_database(
connection=connect_using(test_db),
connection=connect_using(tmp_sqlite_db),
query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'",
schema_overrides=schema_overrides,
batch_size=batch_size,
Expand All @@ -350,13 +349,9 @@ def test_read_database(
assert df["date"].to_list() == expected_dates


def test_read_database_alchemy_selectable(tmp_path: Path) -> None:
# setup underlying test data
tmp_path.mkdir(exist_ok=True)
create_temp_sqlite_db(test_db := str(tmp_path / "test.db"))

def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None:
# various flavours of alchemy connection
alchemy_engine = create_engine(f"sqlite:///{test_db}")
alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

Expand All @@ -376,23 +371,21 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None:
)


def test_read_database_parameterised(tmp_path: Path) -> None:
# setup underlying test data
tmp_path.mkdir(exist_ok=True)
create_temp_sqlite_db(test_db := str(tmp_path / "test.db"))
alchemy_engine = create_engine(f"sqlite:///{test_db}")

def test_read_database_parameterised(tmp_sqlite_db: Path) -> None:
# raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
raw_conn: ConnectionOrCursor = sqlite3.connect(test_db)
alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db)

# establish parameterised queries and validate usage
query = """
SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
FROM test_data
WHERE value < {n}
"""
expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

for param, param_value in (
(":n", {"n": 0}),
("?", (0,)),
Expand All @@ -403,15 +396,68 @@ def test_read_database_parameterised(tmp_path: Path) -> None:
continue # alchemy session.execute() doesn't support positional params

assert_frame_equal(
expected_frame,
pl.read_database(
query.format(n=param),
connection=conn,
execute_options={"parameters": param_value},
),
pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}),
)


@pytest.mark.parametrize(
    ("param", "param_value"),
    [
        (":n", {"n": 0}),
        (":n", pa.Table.from_pydict({"n": [0]})),
        ("?", (0,)),
        ("?", [0]),
    ],
)
@pytest.mark.skipif(
    sys.version_info < (3, 9) or sys.platform == "win32",
    reason="adbc_driver_sqlite not available on py3.8/windows",
)
def test_read_database_parameterised_uri(
    param: str, param_value: Any, tmp_sqlite_db: Path
) -> None:
    """Parameterised queries through `read_database_uri` (adbc engine only)."""
    alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}")
    uri = alchemy_engine.url.render_as_string(hide_password=False)
    query = """
        SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value
        FROM test_data
        WHERE value < {n}
    """
    expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

    # test URI read method (adbc only), using the parametrised values directly;
    # an inner for-loop over the same cases would shadow the parametrise args
    # and re-run every case on each invocation.
    assert_frame_equal(
        expected_frame,
        pl.read_database_uri(
            query.format(n=param),
            uri=uri,
            engine="adbc",
            execute_options={"parameters": param_value},
        ),
    )

    # no connectorx support for execute_options
    with pytest.raises(
        ValueError,
        match="connectorx.*does not support.*execute_options",
    ):
        pl.read_database_uri(
            query.format(n=":n"),
            uri=uri,
            engine="connectorx",
            execute_options={"parameters": (":n", {"n": 0})},
        )


@pytest.mark.parametrize(
("driver", "batch_size", "iter_batches", "expected_call"),
[
Expand Down Expand Up @@ -598,7 +644,6 @@ def test_read_database_exceptions(
engine: DbReadEngine | None,
execute_options: dict[str, Any] | None,
kwargs: dict[str, Any] | None,
tmp_path: Path,
) -> None:
if read_method == "read_database_uri":
conn = f"{protocol}://test" if isinstance(protocol, str) else protocol
Expand Down
Loading