Skip to content

Commit

Permalink
feat(python): improve read_database interop with sqlalchemy `Sessio…
Browse files Browse the repository at this point in the history
…n` connections (#14557)
  • Loading branch information
alexander-beedie committed Feb 18, 2024
1 parent b50831a commit e40140f
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 27 deletions.
52 changes: 35 additions & 17 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,15 @@ def _from_rows(
from polars import DataFrame

if hasattr(self.result, "fetchall"):
description = (
self.result.cursor.description
if self.driver_name == "sqlalchemy"
else self.result.description
)
column_names = [desc[0] for desc in description]
if self.driver_name == "sqlalchemy":
column_names = (
list(self.result._metadata.keys)
if hasattr(self.result, "_metadata")
else [desc[0] for desc in self.result.cursor.description]
)
else:
column_names = [desc[0] for desc in self.result.description]

frames = (
DataFrame(
data=rows,
Expand Down Expand Up @@ -318,18 +321,33 @@ def execute(
options = options or {}
cursor_execute = self.cursor.execute

if self.driver_name == "sqlalchemy" and isinstance(query, str):
params = options.get("parameters")
if isinstance(params, Sequence) and hasattr(self.cursor, "exec_driver_sql"):
cursor_execute = self.cursor.exec_driver_sql
if isinstance(params, list) and not all(
isinstance(p, (dict, tuple)) for p in params
if self.driver_name == "sqlalchemy":
from sqlalchemy.orm import Session

param_key = "parameters"
if (
isinstance(self.cursor, Session)
and "parameters" in options
and "params" not in options
):
options = options.copy()
options["params"] = options.pop("parameters")
param_key = "params"

if isinstance(query, str):
params = options.get(param_key)
if isinstance(params, Sequence) and hasattr(
self.cursor, "exec_driver_sql"
):
options["parameters"] = tuple(params)
else:
from sqlalchemy.sql import text

query = text(query) # type: ignore[assignment]
cursor_execute = self.cursor.exec_driver_sql
if isinstance(params, list) and not all(
isinstance(p, (dict, tuple)) for p in params
):
options[param_key] = tuple(params)
else:
from sqlalchemy.sql import text

query = text(query) # type: ignore[assignment]

# note: some cursor execute methods (eg: sqlite3) only take positional
# params, hence the slightly convoluted resolution of the 'options' dict
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sys

from sqlalchemy import Engine
from sqlalchemy.orm import Session

from polars import DataFrame, Expr, LazyFrame, Series
from polars.datatypes import DataType, DataTypeClass, IntegerType, TemporalType
Expand Down Expand Up @@ -250,4 +251,4 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any:
"""Fetch results in batches."""


ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine"]
ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine", "Session"]
29 changes: 20 additions & 9 deletions py-polars/tests/unit/io/test_database_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pytest
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.expression import cast as alchemy_cast

import polars as pl
Expand Down Expand Up @@ -353,8 +354,13 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None:
# setup underlying test data
tmp_path.mkdir(exist_ok=True)
create_temp_sqlite_db(test_db := str(tmp_path / "test.db"))
conn = create_engine(f"sqlite:///{test_db}")
t = Table("test_data", MetaData(), autoload_with=conn)

# various flavours of alchemy connection
alchemy_engine = create_engine(f"sqlite:///{test_db}")
alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()
alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()

t = Table("test_data", MetaData(), autoload_with=alchemy_engine)

# establish sqlalchemy "selectable" and validate usage
selectable_query = select(
Expand All @@ -363,21 +369,23 @@ def test_read_database_alchemy_selectable(tmp_path: Path) -> None:
t.c.value,
).where(t.c.value < 0)

assert_frame_equal(
pl.read_database(selectable_query, connection=conn.connect()),
pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}),
)
for conn in (alchemy_session, alchemy_engine, alchemy_conn):
assert_frame_equal(
pl.read_database(selectable_query, connection=conn),
pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]}),
)


def test_read_database_parameterised(tmp_path: Path) -> None:
# setup underlying test data
tmp_path.mkdir(exist_ok=True)
create_temp_sqlite_db(test_db := str(tmp_path / "test.db"))
alchemy_engine = create_engine(f"sqlite:///{test_db}")

# raw cursor "execute" only takes positional params, alchemy cursor takes kwargs
raw_conn: ConnectionOrCursor = sqlite3.connect(test_db)
alchemy_conn: ConnectionOrCursor = create_engine(f"sqlite:///{test_db}").connect()
test_conns = (alchemy_conn, raw_conn)
alchemy_conn: ConnectionOrCursor = alchemy_engine.connect()
alchemy_session: ConnectionOrCursor = sessionmaker(bind=alchemy_engine)()

# establish parameterised queries and validate usage
query = """
Expand All @@ -390,7 +398,10 @@ def test_read_database_parameterised(tmp_path: Path) -> None:
("?", (0,)),
("?", [0]),
):
for conn in test_conns:
for conn in (alchemy_session, alchemy_engine, alchemy_conn, raw_conn):
if alchemy_session is conn and param == "?":
continue # alchemy session.execute() doesn't support positional params

assert_frame_equal(
pl.read_database(
query.format(n=param),
Expand Down

0 comments on commit e40140f

Please sign in to comment.