diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index e9ee15faed4f..cf5dbcf7b35f 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -13,6 +13,8 @@ if TYPE_CHECKING: from types import TracebackType + import pyarrow as pa + if sys.version_info >= (3, 10): from typing import TypeAlias else: @@ -23,7 +25,6 @@ from typing_extensions import Self from polars import DataFrame - from polars.dependencies import pyarrow as pa from polars.type_aliases import ConnectionOrCursor, Cursor, DbReadEngine, SchemaDict try: @@ -522,12 +523,11 @@ def read_database( # noqa: D417 more details about using this driver (notable databases implementing Flight SQL include Dremio and InfluxDB). - * The `read_database_uri` function is likely to be noticeably faster than - `read_database` if you are using a SQLAlchemy or DBAPI2 connection, as - `connectorx` will optimise translation of the result set into Arrow format - in Rust, whereas these libraries will return row-wise data to Python *before* - we can load into Arrow. Note that you can easily determine the connection's - URI from a SQLAlchemy engine object by calling + * The `read_database_uri` function can be noticeably faster than `read_database` + if you are using a SQLAlchemy or DBAPI2 connection, as `connectorx` optimises + translation of the result set into Arrow format in Rust, whereas these libraries + will return row-wise data to Python *before* we can load into Arrow. Note that + you can determine the connection's URI from a SQLAlchemy engine object by calling `conn.engine.url.render_as_string(hide_password=False)`. 
* If polars has to create a cursor from your connection in order to execute the @@ -638,6 +638,7 @@ def read_database_uri( protocol: str | None = None, engine: DbReadEngine | None = None, schema_overrides: SchemaDict | None = None, + execute_options: dict[str, Any] | None = None, ) -> DataFrame: """ Read the results of a SQL query into a DataFrame, given a URI. @@ -684,6 +685,9 @@ def read_database_uri( schema_overrides A dictionary mapping column names to dtypes, used to override the schema given in the data returned by the query. + execute_options + These options will be passed to the underlying query execution method as + kwargs. Note that connectorx does not support this parameter. Notes ----- @@ -752,6 +756,9 @@ def read_database_uri( engine = "connectorx" if engine == "connectorx": + if execute_options: + msg = "the 'connectorx' engine does not support use of `execute_options`" + raise ValueError(msg) return _read_sql_connectorx( query, connection_uri=uri, @@ -765,7 +772,12 @@ def read_database_uri( if not isinstance(query, str): msg = "only a single SQL query string is accepted for adbc" raise ValueError(msg) - return _read_sql_adbc(query, uri, schema_overrides) + return _read_sql_adbc( + query, + connection_uri=uri, + schema_overrides=schema_overrides, + execute_options=execute_options, + ) else: msg = f"engine must be one of {{'connectorx', 'adbc'}}, got {engine!r}" raise ValueError(msg) @@ -805,10 +817,13 @@ def _read_sql_connectorx( def _read_sql_adbc( - query: str, connection_uri: str, schema_overrides: SchemaDict | None + query: str, + connection_uri: str, + schema_overrides: SchemaDict | None, + execute_options: dict[str, Any] | None = None, ) -> DataFrame: with _open_adbc_connection(connection_uri) as conn, conn.cursor() as cursor: - cursor.execute(query) + cursor.execute(query, **(execute_options or {})) tbl = cursor.fetch_arrow_table() return from_arrow(tbl, schema_overrides=schema_overrides) # type: ignore[return-value] diff --git 
a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 480e5522eb9d..c9d66347eb67 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -9,6 +9,7 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any, NamedTuple +import pyarrow as pa import pytest from sqlalchemy import Integer, MetaData, Table, create_engine, func, select from sqlalchemy.orm import sessionmaker @@ -20,8 +21,6 @@ from polars.testing import assert_frame_equal if TYPE_CHECKING: - import pyarrow as pa - from polars.type_aliases import ( ConnectionOrCursor, DbReadEngine, @@ -34,24 +33,26 @@ def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any: with suppress(ModuleNotFoundError): # not available on 3.8/windows from adbc_driver_sqlite.dbapi import connect + args = tuple(str(a) if isinstance(a, Path) else a for a in args) return connect(*args, **kwargs) -def create_temp_sqlite_db(test_db: str) -> None: - Path(test_db).unlink(missing_ok=True) +@pytest.fixture() +def tmp_sqlite_db(tmp_path: Path) -> Path: + test_db = tmp_path / "test.db" + test_db.unlink(missing_ok=True) def convert_date(val: bytes) -> date: """Convert ISO 8601 date to datetime.date object.""" return date.fromisoformat(val.decode()) - sqlite3.register_converter("date", convert_date) - # NOTE: at the time of writing adcb/connectorx have weak SQLite support (poor or # no bool/date/datetime dtypes, for example) and there is a bug in connectorx that # causes float rounding < py 3.11, hence we are only testing/storing simple values # in this test db for now. as support improves, we can add/test additional dtypes). 
- + sqlite3.register_converter("date", convert_date) conn = sqlite3.connect(test_db) + # ┌─────┬───────┬───────┬────────────┐ # │ id ┆ name ┆ value ┆ date │ # │ --- ┆ --- ┆ --- ┆ --- │ @@ -62,17 +63,19 @@ def convert_date(val: bytes) -> date: # └─────┴───────┴───────┴────────────┘ conn.executescript( """ - CREATE TABLE test_data ( + CREATE TABLE IF NOT EXISTS test_data ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, value FLOAT, date DATE ); - INSERT INTO test_data(name,value,date) - VALUES ('misc',100.0,'2020-01-01'), ('other',-99.5,'2021-12-31'); + REPLACE INTO test_data(name,value,date) + VALUES ('misc',100.0,'2020-01-01'), + ('other',-99.5,'2021-12-31'); """ ) conn.close() + return test_db class DatabaseReadTestParams(NamedTuple): @@ -313,23 +316,19 @@ def test_read_database( expected_dates: list[date | str], schema_overrides: SchemaDict | None, batch_size: int | None, - tmp_path: Path, + tmp_sqlite_db: Path, ) -> None: - tmp_path.mkdir(exist_ok=True) - test_db = str(tmp_path / "test.db") - create_temp_sqlite_db(test_db) - if read_method == "read_database_uri": # instantiate the connection ourselves, using connectorx/adbc df = pl.read_database_uri( - uri=f"sqlite:///{test_db}", + uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data", engine=str(connect_using), # type: ignore[arg-type] schema_overrides=schema_overrides, ) elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]: # externally instantiated adbc connections - with connect_using(test_db) as conn, conn.cursor(): + with connect_using(tmp_sqlite_db) as conn, conn.cursor(): df = pl.read_database( connection=conn, query="SELECT * FROM test_data", @@ -339,7 +338,7 @@ def test_read_database( else: # other user-supplied connections df = pl.read_database( - connection=connect_using(test_db), + connection=connect_using(tmp_sqlite_db), query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'", schema_overrides=schema_overrides, batch_size=batch_size, @@ -350,13 +349,9 @@ def test_read_database( assert 
df["date"].to_list() == expected_dates -def test_read_database_alchemy_selectable(tmp_path: Path) -> None: - # setup underlying test data - tmp_path.mkdir(exist_ok=True) - create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) - +def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None: # various flavours of alchemy connection - alchemy_engine = create_engine(f"sqlite:///{test_db}") + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() @@ -376,16 +371,12 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None: ) -def test_read_database_parameterised(tmp_path: Path) -> None: - # setup underlying test data - tmp_path.mkdir(exist_ok=True) - create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) - alchemy_engine = create_engine(f"sqlite:///{test_db}") - +def test_read_database_parameterised(tmp_sqlite_db: Path) -> None: # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs - raw_conn: ConnectionOrCursor = sqlite3.connect(test_db) + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db) # establish parameterised queries and validate usage query = """ @@ -393,6 +384,8 @@ def test_read_database_parameterised(tmp_path: Path) -> None: FROM test_data WHERE value < {n} """ + expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + for param, param_value in ( (":n", {"n": 0}), ("?", (0,)), @@ -403,15 +396,64 @@ def test_read_database_parameterised(tmp_path: Path) -> None: continue # alchemy session.execute() doesn't support positional params assert_frame_equal( + expected_frame, pl.read_database( query.format(n=param), connection=conn, 
execute_options={"parameters": param_value}, ), - pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}), ) +@pytest.mark.parametrize( + ("param", "param_value"), + [ + # each case: (placeholder written into the query, value passed as "parameters") + (":n", pa.Table.from_pydict({"n": [0]})), + ("?", (0,)), + ("?", [0]), + ], +) +@pytest.mark.skipif( + sys.version_info < (3, 9) or sys.platform == "win32", + reason="adbc_driver_sqlite not available on py3.8/windows", +) +def test_read_database_parameterised_uri( + param: str, param_value: Any, tmp_sqlite_db: Path +) -> None: + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + uri = alchemy_engine.url.render_as_string(hide_password=False) + query = """ + SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value + FROM test_data + WHERE value < {n} + """ + expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + # test URI read method (adbc only) + assert_frame_equal( + expected_frame, + pl.read_database_uri( + query.format(n=param), + uri=uri, + engine="adbc", + execute_options={"parameters": param_value}, + ), + ) + + # no connectorx support for execute_options + with pytest.raises( + ValueError, + match="connectorx.*does not support.*execute_options", + ): + pl.read_database_uri( + query.format(n=":n"), + uri=uri, + engine="connectorx", + execute_options={"parameters": {"n": 0}}, + ) + + @pytest.mark.parametrize( ("driver", "batch_size", "iter_batches", "expected_call"), [ @@ -598,7 +640,6 @@ def test_read_database_exceptions( engine: DbReadEngine | None, execute_options: dict[str, Any] | None, kwargs: dict[str, Any] | None, - tmp_path: Path, ) -> None: if read_method == "read_database_uri": conn = f"{protocol}://test" if isinstance(protocol, str) else protocol