Skip to content

Commit

Permalink
don't translate dict parameter input to pyarrow.Table (for adbc)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Mar 7, 2024
1 parent 1c4cc90 commit 92a5151
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
11 changes: 2 additions & 9 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

from polars._utils.deprecation import issue_deprecation_warning
from polars.convert import from_arrow
from polars.dependencies import pyarrow as pa
from polars.exceptions import InvalidOperationError, UnsuitableSQLError

if TYPE_CHECKING:
from types import TracebackType

import pyarrow as pa

if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
Expand Down Expand Up @@ -822,14 +823,6 @@ def _read_sql_adbc(
execute_options: dict[str, Any] | None = None,
) -> DataFrame:
with _open_adbc_connection(connection_uri) as conn, conn.cursor() as cursor:
if (
execute_options
and (params := execute_options.get("parameters")) is not None
):
if isinstance(params, dict):
params = pa.Table.from_pydict({k: [v] for k, v in params.items()})
execute_options["parameters"] = params

cursor.execute(query, **(execute_options or {}))
tbl = cursor.fetch_arrow_table()
return from_arrow(tbl, schema_overrides=schema_overrides) # type: ignore[return-value]
Expand Down
5 changes: 2 additions & 3 deletions py-polars/tests/unit/io/test_database_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from types import GeneratorType
from typing import TYPE_CHECKING, Any, NamedTuple

import pyarrow as pa
import pytest
from sqlalchemy import Integer, MetaData, Table, create_engine, func, select
from sqlalchemy.orm import sessionmaker
Expand All @@ -20,8 +21,6 @@
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
import pyarrow as pa

from polars.type_aliases import (
ConnectionOrCursor,
DbReadEngine,
Expand Down Expand Up @@ -431,7 +430,7 @@ def test_read_database_parameterised_uri(
expected_frame = pl.DataFrame({"year": [2021], "name": ["other"], "value": [-99.5]})

for param, param_value in (
(":n", {"n": 0}),
(":n", pa.Table.from_pydict({"n": [0]})),
("?", (0,)),
("?", [0]),
):
Expand Down

0 comments on commit 92a5151

Please sign in to comment.