From a19b35ea1912d6a1f1e2aae494dbf1d5c5e37850 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 25 Feb 2024 12:45:48 +0400 Subject: [PATCH 1/6] feat(python): add "execute_options" support for `read_database_uri` --- py-polars/polars/io/database.py | 42 ++++++++++++++----- py-polars/tests/unit/io/test_database_read.py | 31 +++++++++++++- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index e9ee15faed4f..7e7de93fedaa 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -8,6 +8,7 @@ from polars._utils.deprecation import issue_deprecation_warning from polars.convert import from_arrow +from polars.dependencies import pyarrow as pa from polars.exceptions import InvalidOperationError, UnsuitableSQLError if TYPE_CHECKING: @@ -23,7 +24,6 @@ from typing_extensions import Self from polars import DataFrame - from polars.dependencies import pyarrow as pa from polars.type_aliases import ConnectionOrCursor, Cursor, DbReadEngine, SchemaDict try: @@ -522,12 +522,11 @@ def read_database( # noqa: D417 more details about using this driver (notable databases implementing Flight SQL include Dremio and InfluxDB). - * The `read_database_uri` function is likely to be noticeably faster than - `read_database` if you are using a SQLAlchemy or DBAPI2 connection, as - `connectorx` will optimise translation of the result set into Arrow format - in Rust, whereas these libraries will return row-wise data to Python *before* - we can load into Arrow. Note that you can easily determine the connection's - URI from a SQLAlchemy engine object by calling + * The `read_database_uri` function can be noticeably faster than `read_database` + if you are using a SQLAlchemy or DBAPI2 connection, as `connectorx` optimises + translation of the result set into Arrow format in Rust, whereas these libraries + will return row-wise data to Python *before* we can load into Arrow. Note that + you can determine the connection's URI from a SQLAlchemy engine object by calling `conn.engine.url.render_as_string(hide_password=False)`. * If polars has to create a cursor from your connection in order to execute the @@ -638,6 +637,7 @@ def read_database_uri( protocol: str | None = None, engine: DbReadEngine | None = None, schema_overrides: SchemaDict | None = None, + execute_options: dict[str, Any] | None = None, ) -> DataFrame: """ Read the results of a SQL query into a DataFrame, given a URI. @@ -684,6 +684,9 @@ def read_database_uri( schema_overrides A dictionary mapping column names to dtypes, used to override the schema given in the data returned by the query. + execute_options + These options will be passed to the underlying query execution method as + kwargs. Note that connectorx does not support this parameter. Notes ----- @@ -752,6 +755,9 @@ def read_database_uri( engine = "connectorx" if engine == "connectorx": + if execute_options: + msg = "the 'connectorx' engine does not support use of `execute_options`" + raise ValueError(msg) return _read_sql_connectorx( query, connection_uri=uri, @@ -765,7 +771,12 @@ def read_database_uri( if not isinstance(query, str): msg = "only a single SQL query string is accepted for adbc" raise ValueError(msg) - return _read_sql_adbc(query, uri, schema_overrides) + return _read_sql_adbc( + query, + connection_uri=uri, + schema_overrides=schema_overrides, + execute_options=execute_options, + ) else: msg = f"engine must be one of {{'connectorx', 'adbc'}}, got {engine!r}" raise ValueError(msg) @@ -805,10 +816,21 @@ def _read_sql_connectorx( def _read_sql_adbc( - query: str, connection_uri: str, schema_overrides: SchemaDict | None + query: str, + connection_uri: str, + schema_overrides: SchemaDict | None, + execute_options: dict[str, Any] | None = None, ) -> DataFrame: with _open_adbc_connection(connection_uri) as conn, conn.cursor() as cursor: - cursor.execute(query) + if ( + execute_options + and (params := execute_options.get("parameters")) is not None + ): + if isinstance(params, dict): + params = pa.Table.from_pydict({k: [v] for k, v in params.items()}) + execute_options["parameters"] = params + + cursor.execute(query, **(execute_options or {})) tbl = cursor.fetch_arrow_table() return from_arrow(tbl, schema_overrides=schema_overrides) # type: ignore[return-value] diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 480e5522eb9d..e82bffefb62b 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -377,6 +377,8 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None: def test_read_database_parameterised(tmp_path: Path) -> None: + supports_adbc_sqlite = sys.version_info >= (3, 9) and sys.platform != "win32" + # setup underlying test data tmp_path.mkdir(exist_ok=True) create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) @@ -393,6 +395,8 @@ def test_read_database_parameterised(tmp_path: Path) -> None: FROM test_data WHERE value < {n} """ + expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + for param, param_value in ( (":n", {"n": 0}), ("?", (0,)), @@ -403,12 +407,37 @@ def test_read_database_parameterised(tmp_path: Path) -> None: continue # alchemy session.execute() doesn't support positional params assert_frame_equal( + expected_frame, pl.read_database( query.format(n=param), connection=conn, execute_options={"parameters": param_value}, ), - pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}), + ) + + # test URI read method (adbc only; no connectorx support for execute_options) + if supports_adbc_sqlite: + uri = alchemy_engine.url.render_as_string(hide_password=False) + assert_frame_equal( + expected_frame, + pl.read_database_uri( + query.format(n=param), + uri=uri, + engine="adbc", + execute_options={"parameters": param_value}, + ), + ) + + if supports_adbc_sqlite: + with pytest.raises( + ValueError, + match="connectorx.*does not support.*execute_options", + ): + pl.read_database_uri( + query.format(n=":n"), + uri=uri, + engine="connectorx", + execute_options={"parameters": (":n", {"n": 0})}, ) From cd2cc7e9ce30faa9f383e10945d5edbcacb6c822 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 4 Mar 2024 23:29:31 +0400 Subject: [PATCH 2/6] make the `sqlite` test db setup a fixture --- py-polars/tests/unit/io/test_database_read.py | 59 ++++++++----------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index e82bffefb62b..5ab6965be40e 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -9,19 +9,17 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any, NamedTuple -import pytest -from sqlalchemy import Integer, MetaData, Table, create_engine, func, select -from sqlalchemy.orm import sessionmaker -from sqlalchemy.sql.expression import cast as alchemy_cast - import polars as pl +import pytest from polars.exceptions import UnsuitableSQLError from polars.io.database import _ARROW_DRIVER_REGISTRY_ from polars.testing import assert_frame_equal +from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy.orm import sessionmaker +from sqlalchemy.sql.expression import cast as alchemy_cast if TYPE_CHECKING: import pyarrow as pa - from polars.type_aliases import ( ConnectionOrCursor, DbReadEngine, @@ -34,24 +32,26 @@ def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any: with suppress(ModuleNotFoundError): # not available on 3.8/windows from adbc_driver_sqlite.dbapi import connect + args = [str(a) if isinstance(a, Path) else a for a in args] return connect(*args, **kwargs) -def create_temp_sqlite_db(test_db: str) -> None: - Path(test_db).unlink(missing_ok=True) +@pytest.fixture() +def tmp_sqlite_db(tmp_path: Path) -> Path: + test_db = tmp_path / "test.db" + test_db.unlink(missing_ok=True) def convert_date(val: bytes) -> date: """Convert ISO 8601 date to datetime.date object.""" return date.fromisoformat(val.decode()) - sqlite3.register_converter("date", convert_date) - # NOTE: at the time of writing adcb/connectorx have weak SQLite support (poor or # no bool/date/datetime dtypes, for example) and there is a bug in connectorx that # causes float rounding < py 3.11, hence we are only testing/storing simple values # in this test db for now. as support improves, we can add/test additional dtypes). - + sqlite3.register_converter("date", convert_date) conn = sqlite3.connect(test_db) + # ┌─────┬───────┬───────┬────────────┐ # │ id ┆ name ┆ value ┆ date │ # │ --- ┆ --- ┆ --- ┆ --- │ @@ -62,17 +62,19 @@ def convert_date(val: bytes) -> date: # └─────┴───────┴───────┴────────────┘ conn.executescript( """ - CREATE TABLE test_data ( + CREATE TABLE IF NOT EXISTS test_data ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, value FLOAT, date DATE ); - INSERT INTO test_data(name,value,date) - VALUES ('misc',100.0,'2020-01-01'), ('other',-99.5,'2021-12-31'); + REPLACE INTO test_data(name,value,date) + VALUES ('misc',100.0,'2020-01-01'), + ('other',-99.5,'2021-12-31'); """ ) conn.close() + return test_db class DatabaseReadTestParams(NamedTuple): @@ -314,22 +316,19 @@ def test_read_database( schema_overrides: SchemaDict | None, batch_size: int | None, tmp_path: Path, + tmp_sqlite_db: Path, ) -> None: - tmp_path.mkdir(exist_ok=True) - test_db = str(tmp_path / "test.db") - create_temp_sqlite_db(test_db) - if read_method == "read_database_uri": # instantiate the connection ourselves, using connectorx/adbc df = pl.read_database_uri( - uri=f"sqlite:///{test_db}", + uri=f"sqlite:///{tmp_sqlite_db}", query="SELECT * FROM test_data", engine=str(connect_using), # type: ignore[arg-type] schema_overrides=schema_overrides, ) elif "adbc" in os.environ["PYTEST_CURRENT_TEST"]: # externally instantiated adbc connections - with connect_using(test_db) as conn, conn.cursor(): + with connect_using(tmp_sqlite_db) as conn, conn.cursor(): df = pl.read_database( connection=conn, query="SELECT * FROM test_data", @@ -339,7 +338,7 @@ def test_read_database( else: # other user-supplied connections df = pl.read_database( - connection=connect_using(test_db), + connection=connect_using(tmp_sqlite_db), query="SELECT * FROM test_data WHERE name NOT LIKE '%polars%'", schema_overrides=schema_overrides, batch_size=batch_size, @@ -350,13 +349,9 @@ def test_read_database( assert df["date"].to_list() == expected_dates -def test_read_database_alchemy_selectable(tmp_path: Path) -> None: - # setup underlying test data - tmp_path.mkdir(exist_ok=True) - create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) - +def test_read_database_alchemy_selectable(tmp_path: Path, tmp_sqlite_db: Path) -> None: # various flavours of alchemy connection - alchemy_engine = create_engine(f"sqlite:///{test_db}") + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() @@ -376,18 +371,14 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None: ) -def test_read_database_parameterised(tmp_path: Path) -> None: +def test_read_database_parameterised(tmp_path: Path, tmp_sqlite_db: Path) -> None: supports_adbc_sqlite = sys.version_info >= (3, 9) and sys.platform != "win32" - # setup underlying test data - tmp_path.mkdir(exist_ok=True) - create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) - alchemy_engine = create_engine(f"sqlite:///{test_db}") - # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs - raw_conn: ConnectionOrCursor = sqlite3.connect(test_db) + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + raw_conn: ConnectionOrCursor = sqlite3.connect(tmp_sqlite_db) # establish parameterised queries and validate usage query = """ From 73715a4a74ed615bbcf862ea452c513d8ebdc25a Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 4 Mar 2024 23:30:45 +0400 Subject: [PATCH 3/6] remove extraneous `tmp_path` refs --- py-polars/tests/unit/io/test_database_read.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 5ab6965be40e..311a4d3362f9 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -315,7 +315,6 @@ def test_read_database( expected_dates: list[date | str], schema_overrides: SchemaDict | None, batch_size: int | None, - tmp_path: Path, tmp_sqlite_db: Path, ) -> None: if read_method == "read_database_uri": @@ -349,7 +348,7 @@ def test_read_database( assert df["date"].to_list() == expected_dates -def test_read_database_alchemy_selectable(tmp_path: Path, tmp_sqlite_db: Path) -> None: +def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None: # various flavours of alchemy connection alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() @@ -371,7 +370,7 @@ def test_read_database_alchemy_selectable(tmp_path: Path, tmp_sqlite_db: Path) - ) -def test_read_database_parameterised(tmp_path: Path, tmp_sqlite_db: Path) -> None: +def test_read_database_parameterised(tmp_sqlite_db: Path) -> None: supports_adbc_sqlite = sys.version_info >= (3, 9) and sys.platform != "win32" # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs @@ -618,7 +617,6 @@ def test_read_database_exceptions( engine: DbReadEngine | None, execute_options: dict[str, Any] | None, kwargs: dict[str, Any] | None, - tmp_path: Path, ) -> None: if read_method == "read_database_uri": conn = f"{protocol}://test" if isinstance(protocol, str) else protocol From 7878a6bb969d7d5e7ef0dfbff49d0949f1b035da Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 4 Mar 2024 23:41:21 +0400 Subject: [PATCH 4/6] lint --- py-polars/tests/unit/io/test_database_read.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 311a4d3362f9..b49a714edf0c 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -9,17 +9,19 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any, NamedTuple -import polars as pl import pytest -from polars.exceptions import UnsuitableSQLError -from polars.io.database import _ARROW_DRIVER_REGISTRY_ -from polars.testing import assert_frame_equal from sqlalchemy import Integer, MetaData, Table, create_engine, func, select from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import cast as alchemy_cast +import polars as pl +from polars.exceptions import UnsuitableSQLError +from polars.io.database import _ARROW_DRIVER_REGISTRY_ +from polars.testing import assert_frame_equal + if TYPE_CHECKING: import pyarrow as pa + from polars.type_aliases import ( ConnectionOrCursor, DbReadEngine, From 1c4cc90c8c3c8d7b418df65660f2d690a40c863d Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 4 Mar 2024 23:52:14 +0400 Subject: [PATCH 5/6] `parametrize` the uri/parameters test --- py-polars/tests/unit/io/test_database_read.py | 74 +++++++++++++------ 1 file changed, 50 insertions(+), 24 deletions(-) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index b49a714edf0c..2e2a236831d8 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -34,7 +34,7 @@ def adbc_sqlite_connect(*args: Any, **kwargs: Any) -> Any: with suppress(ModuleNotFoundError): # not available on 3.8/windows from adbc_driver_sqlite.dbapi import connect - args = [str(a) if isinstance(a, Path) else a for a in args] + args = tuple(str(a) if isinstance(a, Path) else a for a in args) return connect(*args, **kwargs) @@ -373,8 +373,6 @@ def test_read_database_alchemy_selectable(tmp_sqlite_db: Path) -> None: def test_read_database_parameterised(tmp_sqlite_db: Path) -> None: - supports_adbc_sqlite = sys.version_info >= (3, 9) and sys.platform != "win32" - # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() @@ -407,30 +405,58 @@ def test_read_database_parameterised(tmp_sqlite_db: Path) -> None: ), ) - # test URI read method (adbc only; no connectorx support for execute_options) - if supports_adbc_sqlite: - uri = alchemy_engine.url.render_as_string(hide_password=False) - assert_frame_equal( - expected_frame, - pl.read_database_uri( - query.format(n=param), - uri=uri, - engine="adbc", - execute_options={"parameters": param_value}, - ), - ) - if supports_adbc_sqlite: - with pytest.raises( - ValueError, - match="connectorx.*does not support.*execute_options", - ): +@pytest.mark.parametrize( + ("param", "param_value"), + [ + (":n", {"n": 0}), + ("?", (0,)), + ("?", [0]), + ], +) +@pytest.mark.skipif( + sys.version_info < (3, 9) or sys.platform == "win32", + reason="adbc_driver_sqlite not available on py3.8/windows", +) +def test_read_database_parameterised_uri( + param: str, param_value: Any, tmp_sqlite_db: Path +) -> None: + alchemy_engine = create_engine(f"sqlite:///{tmp_sqlite_db}") + uri = alchemy_engine.url.render_as_string(hide_password=False) + query = """ + SELECT CAST(STRFTIME('%Y',"date") AS INT) as "year", name, value + FROM test_data + WHERE value < {n} + """ + expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) + + for param, param_value in ( + (":n", {"n": 0}), + ("?", (0,)), + ("?", [0]), + ): + # test URI read method (adbc only) + assert_frame_equal( + expected_frame, pl.read_database_uri( - query.format(n=":n"), + query.format(n=param), uri=uri, - engine="connectorx", - execute_options={"parameters": (":n", {"n": 0})}, - ) + engine="adbc", + execute_options={"parameters": param_value}, + ), + ) + + # no connectorx support for execute_options + with pytest.raises( + ValueError, + match="connectorx.*does not support.*execute_options", + ): + pl.read_database_uri( + query.format(n=":n"), + uri=uri, + engine="connectorx", + execute_options={"parameters": (":n", {"n": 0})}, + ) @pytest.mark.parametrize( From 92a5151787c330040e5a1ab76e1930b308d76c99 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Thu, 7 Mar 2024 17:57:19 +0400 Subject: [PATCH 6/6] don't translate `dict` parameter input to `pyarrow.Table` (for `adbc`) --- py-polars/polars/io/database.py | 11 ++--------- py-polars/tests/unit/io/test_database_read.py | 5 ++--- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 7e7de93fedaa..cf5dbcf7b35f 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -8,12 +8,13 @@ from polars._utils.deprecation import issue_deprecation_warning from polars.convert import from_arrow -from polars.dependencies import pyarrow as pa from polars.exceptions import InvalidOperationError, UnsuitableSQLError if TYPE_CHECKING: from types import TracebackType + import pyarrow as pa + if sys.version_info >= (3, 10): from typing import TypeAlias else: @@ -822,14 +823,6 @@ def _read_sql_adbc( execute_options: dict[str, Any] | None = None, ) -> DataFrame: with _open_adbc_connection(connection_uri) as conn, conn.cursor() as cursor: - if ( - execute_options - and (params := execute_options.get("parameters")) is not None - ): - if isinstance(params, dict): - params = pa.Table.from_pydict({k: [v] for k, v in params.items()}) - execute_options["parameters"] = params - cursor.execute(query, **(execute_options or {})) tbl = cursor.fetch_arrow_table() return from_arrow(tbl, schema_overrides=schema_overrides) # type: ignore[return-value] diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 2e2a236831d8..c9d66347eb67 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -9,6 +9,7 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any, NamedTuple +import pyarrow as pa import pytest from sqlalchemy import Integer, MetaData, Table, create_engine, func, select from sqlalchemy.orm import sessionmaker @@ -20,8 +21,6 @@ from polars.testing import assert_frame_equal if TYPE_CHECKING: - import pyarrow as pa - from polars.type_aliases import ( ConnectionOrCursor, DbReadEngine, @@ -431,7 +430,7 @@ def test_read_database_parameterised_uri( expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}) for param, param_value in ( - (":n", {"n": 0}), + (":n", pa.Table.from_pydict({"n": [0]})), ("?", (0,)), ("?", [0]), ):