From 01e323d46a94edb503178043b1803792101bd144 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 17 Feb 2024 13:25:41 +0400 Subject: [PATCH] feat(python): improve `read_database` interop with sqlalchemy `Session` connections --- py-polars/polars/io/database.py | 52 +++++++++++++------ py-polars/polars/type_aliases.py | 3 +- py-polars/tests/unit/io/test_database_read.py | 29 +++++++---- 3 files changed, 57 insertions(+), 27 deletions(-) diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 17e020cf7a3e..5bae271947fd 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -279,12 +279,15 @@ def _from_rows( from polars import DataFrame if hasattr(self.result, "fetchall"): - description = ( - self.result.cursor.description - if self.driver_name == "sqlalchemy" - else self.result.description - ) - column_names = [desc[0] for desc in description] + if self.driver_name == "sqlalchemy": + column_names = ( + list(self.result._metadata.keys) + if hasattr(self.result, "_metadata") + else [desc[0] for desc in self.result.cursor.description] + ) + else: + column_names = [desc[0] for desc in self.result.description] + frames = ( DataFrame( data=rows, @@ -318,18 +321,33 @@ def execute( options = options or {} cursor_execute = self.cursor.execute - if self.driver_name == "sqlalchemy" and isinstance(query, str): - params = options.get("parameters") - if isinstance(params, Sequence) and hasattr(self.cursor, "exec_driver_sql"): - cursor_execute = self.cursor.exec_driver_sql - if isinstance(params, list) and not all( - isinstance(p, (dict, tuple)) for p in params + if self.driver_name == "sqlalchemy": + from sqlalchemy.orm import Session + + param_key = "parameters" + if ( + isinstance(self.cursor, Session) + and "parameters" in options + and "params" not in options + ): + options = options.copy() + options["params"] = options.pop("parameters") + param_key = "params" + + if isinstance(query, str): + params = options.get(param_key) + if isinstance(params, Sequence) and hasattr( + self.cursor, "exec_driver_sql" ): - options["parameters"] = tuple(params) - else: - from sqlalchemy.sql import text - - query = text(query) # type: ignore[assignment] + cursor_execute = self.cursor.exec_driver_sql + if isinstance(params, list) and not all( + isinstance(p, (dict, tuple)) for p in params + ): + options[param_key] = tuple(params) + else: + from sqlalchemy.sql import text + + query = text(query) # type: ignore[assignment] # note: some cursor execute methods (eg: sqlite3) only take positional # params, hence the slightly convoluted resolution of the 'options' dict diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index eee5e670d8a6..ea1a8a0cf6c7 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -23,6 +23,7 @@ import sys from sqlalchemy import Engine + from sqlalchemy.orm import Session from polars import DataFrame, Expr, LazyFrame, Series from polars.datatypes import DataType, DataTypeClass, IntegerType, TemporalType @@ -250,4 +251,4 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any: """Fetch results in batches.""" -ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine"] +ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine", "Session"] diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 7b4ad1d8bc32..9fc68921905d 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -11,6 +11,7 @@ import pytest from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import cast as alchemy_cast import polars as pl @@ -353,8 +354,13 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None: # setup underlying test data tmp_path.mkdir(exist_ok=True) create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) - conn = create_engine(f"sqlite:///{test_db}") - t = Table("test_data", MetaData(), autoload_with=conn) + + # various flavours of alchemy connection + alchemy_engine = create_engine(f"sqlite:///{test_db}") + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + + t = Table("test_data", MetaData(), autoload_with=alchemy_engine) # establish sqlalchemy "selectable" and validate usage selectable_query = select( @@ -363,21 +369,23 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None: t.c.value, ).where(t.c.value < 0) - assert_frame_equal( - pl.read_database(selectable_query, connection=conn.connect()), - pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}), - ) + for conn in (alchemy_session, alchemy_engine, alchemy_conn): + assert_frame_equal( + pl.read_database(selectable_query, connection=conn), + pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}), + ) def test_read_database_parameterised(tmp_path: Path) -> None: # setup underlying test data tmp_path.mkdir(exist_ok=True) create_temp_sqlite_db(test_db := str(tmp_path / "test.db")) + alchemy_engine = create_engine(f"sqlite:///{test_db}") # raw cursor "execute" only takes positional params, alchemy cursor takes kwargs raw_conn: ConnectionOrCursor = sqlite3.connect(test_db) - alchemy_conn: ConnectionOrCursor = create_engine(f"sqlite:///{test_db}").connect() - test_conns = (alchemy_conn, raw_conn) + alchemy_conn: ConnectionOrCursor = alchemy_engine.connect() + alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)() # establish parameterised queries and validate usage query = """ @@ -390,7 +398,10 @@ def test_read_database_parameterised(tmp_path: Path) -> None: ("?", (0,)), ("?", [0]), ): - for conn in test_conns: + for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn): + if alchemy_session is conn and param == "?": + continue # alchemy session.execute() doesn't support positional params + assert_frame_equal( pl.read_database( query.format(n=param),