Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): support use of KùzuDB via pl.read_database #14822

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions py-polars/polars/io/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,15 @@


class _ArrowDriverProperties_(TypedDict):
fetch_all: str # name of the method that fetches all arrow data
fetch_batches: str | None # name of the method that fetches arrow data in batches
exact_batch_size: bool | None # whether indicated batch size is respected exactly
repeat_batch_calls: bool # repeat batch calls (if False, batch call is generator)
# name of the method that fetches all arrow data; tuple form
# calls the fetch_all method with the given chunk size (int)
fetch_all: str | tuple[str, int]
# name of the method that fetches arrow data in batches
fetch_batches: str | None
# indicate whether the given batch size is respected exactly
exact_batch_size: bool | None
# repeat batch calls (if False, the batch call is a generator)
repeat_batch_calls: bool


_ARROW_DRIVER_REGISTRY_: dict[str, _ArrowDriverProperties_] = {
Expand Down Expand Up @@ -64,6 +69,13 @@ class _ArrowDriverProperties_(TypedDict):
"exact_batch_size": True,
"repeat_batch_calls": False,
},
"kuzu": {
# 'get_as_arrow' currently takes a mandatory chunk size
"fetch_all": ("get_as_arrow", 10_000),
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
"fetch_batches": None,
"exact_batch_size": None,
"repeat_batch_calls": False,
},
"snowflake": {
"fetch_all": "fetch_arrow_all",
"fetch_batches": "fetch_arrow_batches",
Expand Down Expand Up @@ -153,7 +165,7 @@ def __exit__(
) -> None:
# if we created it and are finished with it, we can
# close the cursor (but NOT the connection)
if self.can_close_cursor:
if self.can_close_cursor and hasattr(self.cursor, "close"):
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
self.cursor.close()

def __repr__(self) -> str:
Expand All @@ -169,8 +181,11 @@ def _arrow_batches(
"""Yield Arrow data in batches, or as a single 'fetchall' batch."""
fetch_batches = driver_properties["fetch_batches"]
if not iter_batches or fetch_batches is None:
fetch_method = driver_properties["fetch_all"]
yield getattr(self.result, fetch_method)()
fetch_method, sz = driver_properties["fetch_all"], []
if isinstance(fetch_method, tuple):
fetch_method, chunk_size = fetch_method
sz = [chunk_size]
yield getattr(self.result, fetch_method)(*sz)
else:
size = batch_size if driver_properties["exact_batch_size"] else None
repeat_batch_calls = driver_properties["repeat_batch_calls"]
Expand Down
6 changes: 0 additions & 6 deletions py-polars/polars/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,17 +228,11 @@ class SeriesBuffers(TypedDict):
# minimal protocol definitions that can reasonably represent
# an executable connection, cursor, or equivalent object
class BasicConnection(Protocol): # noqa: D101
def close(self) -> None:
"""Close the connection."""

def cursor(self, *args: Any, **kwargs: Any) -> Any:
"""Return a cursor object."""


class BasicCursor(Protocol): # noqa: D101
def close(self) -> None:
"""Close the cursor."""

def execute(self, *args: Any, **kwargs: Any) -> Any:
"""Execute a query."""

Expand Down
3 changes: 2 additions & 1 deletion py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ module = [
"fsspec.*",
"gevent",
"hvplot.*",
"kuzu",
"matplotlib.*",
"moto.server",
"openpyxl",
Expand Down Expand Up @@ -179,7 +180,7 @@ ignore = [
]

[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["D100", "D103", "B018", "FBT001"]
"tests/**/*.py" = ["D100", "D102", "D103", "B018", "FBT001"]
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved

[tool.ruff.lint.pycodestyle]
max-doc-length = 88
Expand Down
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows'
# TODO: Remove version constraint for connectorx when Python 3.12 is supported:
# https://github.com/sfu-db/connector-x/issues/527
connectorx; python_version <= '3.11'
kuzu
# Cloud
cloudpickle
fsspec
Expand Down
4 changes: 4 additions & 0 deletions py-polars/tests/unit/io/files/graph-data/follows.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Adam,Karissa,2020
Adam,Zhang,2020
Karissa,Zhang,2021
Zhang,Noura,2022
4 changes: 4 additions & 0 deletions py-polars/tests/unit/io/files/graph-data/user.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Adam,30
Karissa,40
Zhang,50
Noura,25
62 changes: 57 additions & 5 deletions py-polars/tests/unit/io/test_database_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def __init__(
test_data=test_data,
)

def close(self) -> None: # noqa: D102
def close(self) -> None:
pass

def cursor(self) -> Any: # noqa: D102
def cursor(self) -> Any:
return self._cursor


Expand All @@ -143,10 +143,10 @@ def __getattr__(self, item: str) -> Any:
return self.resultset
super().__getattr__(item) # type: ignore[misc]

def close(self) -> Any: # noqa: D102
def close(self) -> Any:
pass

def execute(self, query: str) -> Any: # noqa: D102
def execute(self, query: str) -> Any:
return self


Expand All @@ -161,7 +161,7 @@ def __init__(
self.batched = batched
self.n_calls = 1

def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: D102
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if self.repeat_batched_calls:
res = self.test_data[: None if self.n_calls else 0]
self.n_calls -= 1
Expand Down Expand Up @@ -632,3 +632,55 @@ def test_read_database_cx_credentials(uri: str) -> None:
# can reasonably mitigate the issue.
with pytest.raises(BaseException, match=r"fakedb://\*\*\*:\*\*\*@\w+"):
pl.read_database_uri("SELECT * FROM data", uri=uri)


@pytest.mark.write_disk()
def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None:
# validate reading from a kuzu graph database
import kuzu

tmp_path.mkdir(exist_ok=True)
if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists():
kuzu_test_db.unlink()

test_db = str(kuzu_test_db).replace("\\", "/")

db = kuzu.Database(test_db)
conn = kuzu.Connection(db)
conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")

users = str(io_files_path / "graph-data" / "user.csv").replace("\\", "/")
follows = str(io_files_path / "graph-data" / "follows.csv").replace("\\", "/")

conn.execute(f'COPY User FROM "{users}"')
conn.execute(f'COPY Follows FROM "{follows}"')

df1 = pl.read_database(
query="MATCH (u:User) RETURN u.name, u.age",
connection=conn,
)
assert_frame_equal(
df1,
pl.DataFrame(
{
"u.name": ["Adam", "Karissa", "Zhang", "Noura"],
"u.age": [30, 40, 50, 25],
}
),
)

df2 = pl.read_database(
query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name",
connection=conn,
)
assert_frame_equal(
df2,
pl.DataFrame(
{
"a.name": ["Adam", "Adam", "Karissa", "Zhang"],
"f.since": [2020, 2020, 2021, 2022],
"b.name": ["Karissa", "Zhang", "Zhang", "Noura"],
}
),
)
2 changes: 1 addition & 1 deletion py-polars/tests/unit/utils/test_deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def hello(oof: str, rab: str, ham: str) -> None: ...

class Foo: # noqa: D101
@deprecate_nonkeyword_arguments(allowed_args=["self", "baz"], version="0.1.2")
def bar( # noqa: D102
def bar(
self, baz: str, ham: str | None = None, foobar: str | None = None
) -> None: ...

Expand Down
Loading