From 8cd533165bcbd07a03b5a74ef1979f53762c1082 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sat, 2 Mar 2024 00:46:15 +0400 Subject: [PATCH 1/4] feat(python): support use of the "kuzu" graph db via `pl.read_database` --- py-polars/polars/io/database.py | 27 +++++++-- py-polars/polars/type_aliases.py | 6 -- py-polars/pyproject.toml | 3 +- py-polars/requirements-dev.txt | 1 + .../unit/io/files/graph-data/follows.csv | 4 ++ .../tests/unit/io/files/graph-data/user.csv | 4 ++ py-polars/tests/unit/io/test_database_read.py | 59 +++++++++++++++++-- .../tests/unit/utils/test_deprecation.py | 2 +- 8 files changed, 87 insertions(+), 19 deletions(-) create mode 100644 py-polars/tests/unit/io/files/graph-data/follows.csv create mode 100644 py-polars/tests/unit/io/files/graph-data/user.csv diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 752121a90e50..8d74290dfb78 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -33,10 +33,15 @@ class _ArrowDriverProperties_(TypedDict): - fetch_all: str # name of the method that fetches all arrow data - fetch_batches: str | None # name of the method that fetches arrow data in batches - exact_batch_size: bool | None # whether indicated batch size is respected exactly - repeat_batch_calls: bool # repeat batch calls (if False, batch call is generator) + # name of the method that fetches all arrow data; tuple form + # calls the fetch_all method with the give chunk size + fetch_all: str | tuple[str, int] + # name of the method that fetches arrow data in batches + fetch_batches: str | None + # indicate whether the given batch size is respected exactly + exact_batch_size: bool | None + # repeat batch calls (if False, the batch call is a generator) + repeat_batch_calls: bool _ARROW_DRIVER_REGISTRY_: dict[str, _ArrowDriverProperties_] = { @@ -64,6 +69,12 @@ class _ArrowDriverProperties_(TypedDict): "exact_batch_size": True, "repeat_batch_calls": False, }, + "kuzu": { + "fetch_all": ("get_as_arrow", 10_000), + "fetch_batches": None, + "exact_batch_size": None, + "repeat_batch_calls": False, + }, "snowflake": { "fetch_all": "fetch_arrow_all", "fetch_batches": "fetch_arrow_batches", @@ -153,7 +164,7 @@ def __exit__( ) -> None: # if we created it and are finished with it, we can # close the cursor (but NOT the connection) - if self.can_close_cursor: + if self.can_close_cursor and hasattr(self.cursor, "close"): self.cursor.close() def __repr__(self) -> str: @@ -170,7 +181,11 @@ def _arrow_batches( fetch_batches = driver_properties["fetch_batches"] if not iter_batches or fetch_batches is None: fetch_method = driver_properties["fetch_all"] - yield getattr(self.result, fetch_method)() + if not isinstance(fetch_method, tuple): + yield getattr(self.result, fetch_method)() + else: + fetch_method, sz = fetch_method + yield getattr(self.result, fetch_method)(sz) else: size = batch_size if driver_properties["exact_batch_size"] else None repeat_batch_calls = driver_properties["repeat_batch_calls"] diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index ea1a8a0cf6c7..5a572be8bef9 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -228,17 +228,11 @@ class SeriesBuffers(TypedDict): # minimal protocol definitions that can reasonably represent # an executable connection, cursor, or equivalent object class BasicConnection(Protocol): # noqa: D101 - def close(self) -> None: - """Close the connection.""" - def cursor(self, *args: Any, **kwargs: Any) -> Any: 
"""Return a cursor object.""" class BasicCursor(Protocol): # noqa: D101 - def close(self) -> None: - """Close the cursor.""" - def execute(self, *args: Any, **kwargs: Any) -> Any: """Execute a query.""" diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 4911687cea2e..2088b3310578 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -90,6 +90,7 @@ module = [ "fsspec.*", "gevent", "hvplot.*", + "kuzu", "matplotlib.*", "moto.server", "openpyxl", @@ -179,7 +180,7 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] -"tests/**/*.py" = ["D100", "D103", "B018", "FBT001"] +"tests/**/*.py" = ["D100", "D102", "D103", "B018", "FBT001"] [tool.ruff.lint.pycodestyle] max-doc-length = 88 diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 4a3785e77f15..2006cab3f5d5 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -28,6 +28,7 @@ adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows' # TODO: Remove version constraint for connectorx when Python 3.12 is supported: # https://github.com/sfu-db/connector-x/issues/527 connectorx; python_version <= '3.11' +kuzu # Cloud cloudpickle fsspec diff --git a/py-polars/tests/unit/io/files/graph-data/follows.csv b/py-polars/tests/unit/io/files/graph-data/follows.csv new file mode 100644 index 000000000000..5ec090c283cd --- /dev/null +++ b/py-polars/tests/unit/io/files/graph-data/follows.csv @@ -0,0 +1,4 @@ +Adam,Karissa,2020 +Adam,Zhang,2020 +Karissa,Zhang,2021 +Zhang,Noura,2022 diff --git a/py-polars/tests/unit/io/files/graph-data/user.csv b/py-polars/tests/unit/io/files/graph-data/user.csv new file mode 100644 index 000000000000..0421e38ee559 --- /dev/null +++ b/py-polars/tests/unit/io/files/graph-data/user.csv @@ -0,0 +1,4 @@ +Adam,30 +Karissa,40 +Zhang,50 +Noura,25 diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 9fc68921905d..fd4f07f5ebd5 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -116,10 +116,10 @@ def __init__( test_data=test_data, ) - def close(self) -> None: # noqa: D102 + def close(self) -> None: pass - def cursor(self) -> Any: # noqa: D102 + def cursor(self) -> Any: return self._cursor @@ -143,10 +143,10 @@ def __getattr__(self, item: str) -> Any: return self.resultset super().__getattr__(item) # type: ignore[misc] - def close(self) -> Any: # noqa: D102 + def close(self) -> Any: pass - def execute(self, query: str) -> Any: # noqa: D102 + def execute(self, query: str) -> Any: return self @@ -161,7 +161,7 @@ def __init__( self.batched = batched self.n_calls = 1 - def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: D102 + def __call__(self, *args: Any, **kwargs: Any) -> Any: if self.repeat_batched_calls: res = self.test_data[: None if self.n_calls else 0] self.n_calls -= 1 @@ -632,3 +632,52 @@ def test_read_database_cx_credentials(uri: str) -> None: # can reasonably mitigate the issue. 
with pytest.raises(BaseException, match=r"fakedb://\*\*\*:\*\*\*@\w+"): pl.read_database_uri("SELECT * FROM data", uri=uri) + + +@pytest.mark.write_disk() +def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: + # validate reading from a kuzu graph database + import kuzu + + tmp_path.mkdir(exist_ok=True) + if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists(): + kuzu_test_db.unlink() + + db = kuzu.Database(str(kuzu_test_db)) + conn = kuzu.Connection(db) + conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))") + conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)") + + users = io_files_path / "graph-data" / "user.csv" + follows = io_files_path / "graph-data" / "follows.csv" + conn.execute(f'COPY User FROM "{users}"') + conn.execute(f'COPY Follows FROM "{follows}"') + + df1 = pl.read_database( + query="MATCH (u:User) RETURN u.name, u.age", + connection=conn, + ) + assert_frame_equal( + df1, + pl.DataFrame( + { + "u.name": ["Adam", "Karissa", "Zhang", "Noura"], + "u.age": [30, 40, 50, 25], + } + ), + ) + + df2 = pl.read_database( + query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name", + connection=conn, + ) + assert_frame_equal( + df2, + pl.DataFrame( + { + "a.name": ["Adam", "Adam", "Karissa", "Zhang"], + "f.since": [2020, 2020, 2021, 2022], + "b.name": ["Karissa", "Zhang", "Zhang", "Noura"], + } + ), + ) diff --git a/py-polars/tests/unit/utils/test_deprecation.py b/py-polars/tests/unit/utils/test_deprecation.py index 639ff7615daf..5c067404afd0 100644 --- a/py-polars/tests/unit/utils/test_deprecation.py +++ b/py-polars/tests/unit/utils/test_deprecation.py @@ -49,7 +49,7 @@ def hello(oof: str, rab: str, ham: str) -> None: ... class Foo: # noqa: D101 @deprecate_nonkeyword_arguments(allowed_args=["self", "baz"], version="0.1.2") - def bar( # noqa: D102 + def bar( self, baz: str, ham: str | None = None, foobar: str | None = None ) -> None: ... 
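
For context, the tuple-valued `fetch_all` registry entry introduced above is resolved along these lines. This is a minimal standalone sketch only: `_StubResult` is a hypothetical stand-in for a driver result object (it is not the real kuzu API), and the chunk size simply mirrors the registry entry for kuzu.

    from __future__ import annotations

    from typing import Any


    class _StubResult:
        # hypothetical stand-in for a driver result such as kuzu's query result
        def get_as_arrow(self, chunk_size: int) -> str:
            return f"arrow data fetched with chunk_size={chunk_size}"


    def _fetch_all(result: Any, fetch_method: str | tuple[str, int]) -> Any:
        # plain string: call the named method with no arguments
        # (str, int) tuple: call the named method with the registered chunk
        # size, as kuzu's 'get_as_arrow' currently requires one
        if isinstance(fetch_method, tuple):
            name, chunk_size = fetch_method
            return getattr(result, name)(chunk_size)
        return getattr(result, fetch_method)()


    print(_fetch_all(_StubResult(), ("get_as_arrow", 10_000)))
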
From 1a79ee20531de9cbae8be63e130646eb0d7f28c4 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 3 Mar 2024 01:49:24 +0400 Subject: [PATCH 2/4] add inline comment, further streamline --- py-polars/polars/io/database.py | 14 +++++++------- py-polars/tests/unit/io/test_database_read.py | 6 +++++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 8d74290dfb78..e9ee15faed4f 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -34,7 +34,7 @@ class _ArrowDriverProperties_(TypedDict): # name of the method that fetches all arrow data; tuple form - # calls the fetch_all method with the give chunk size + # calls the fetch_all method with the given chunk size (int) fetch_all: str | tuple[str, int] # name of the method that fetches arrow data in batches fetch_batches: str | None @@ -70,6 +70,7 @@ class _ArrowDriverProperties_(TypedDict): "repeat_batch_calls": False, }, "kuzu": { + # 'get_as_arrow' currently takes a mandatory chunk size "fetch_all": ("get_as_arrow", 10_000), "fetch_batches": None, "exact_batch_size": None, @@ -180,12 +181,11 @@ def _arrow_batches( """Yield Arrow data in batches, or as a single 'fetchall' batch.""" fetch_batches = driver_properties["fetch_batches"] if not iter_batches or fetch_batches is None: - fetch_method = driver_properties["fetch_all"] - if not isinstance(fetch_method, tuple): - yield getattr(self.result, fetch_method)() - else: - fetch_method, sz = fetch_method - yield getattr(self.result, fetch_method)(sz) + fetch_method, sz = driver_properties["fetch_all"], [] + if isinstance(fetch_method, tuple): + fetch_method, chunk_size = fetch_method + sz = [chunk_size] + yield getattr(self.result, fetch_method)(*sz) else: size = batch_size if driver_properties["exact_batch_size"] else None repeat_batch_calls = driver_properties["repeat_batch_calls"] diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index fd4f07f5ebd5..3db0285f4ce1 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -643,7 +643,11 @@ def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists(): kuzu_test_db.unlink() - db = kuzu.Database(str(kuzu_test_db)) + test_db = str(kuzu_test_db) + if sys.platform == "win32": + test_db = test_db.replace("\\", "/") + + db = kuzu.Database(test_db) conn = kuzu.Connection(db) conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))") conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)") From 0803fea32237f3675fa380195330740edd666d7b Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 3 Mar 2024 10:50:37 +0400 Subject: [PATCH 3/4] fix windows test/paths --- py-polars/tests/unit/io/test_database_read.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 3db0285f4ce1..635d391f87ad 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -9,19 +9,17 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any, NamedTuple -import pytest -from sqlalchemy import Integer, MetaData, Table, create_engine, func, select -from sqlalchemy.orm import sessionmaker -from sqlalchemy.sql.expression import cast as alchemy_cast - 
import polars as pl +import pytest from polars.exceptions import UnsuitableSQLError from polars.io.database import _ARROW_DRIVER_REGISTRY_ from polars.testing import assert_frame_equal +from sqlalchemy import Integer, MetaData, Table, create_engine, func, select +from sqlalchemy.orm import sessionmaker +from sqlalchemy.sql.expression import cast as alchemy_cast if TYPE_CHECKING: import pyarrow as pa - from polars.type_aliases import ( ConnectionOrCursor, DbReadEngine, @@ -643,17 +641,15 @@ def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists(): kuzu_test_db.unlink() - test_db = str(kuzu_test_db) - if sys.platform == "win32": - test_db = test_db.replace("\\", "/") + test_db = str(kuzu_test_db).replace("\\", "/") db = kuzu.Database(test_db) conn = kuzu.Connection(db) conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))") conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)") - users = io_files_path / "graph-data" / "user.csv" - follows = io_files_path / "graph-data" / "follows.csv" + users = str(io_files_path / "graph-data" / "user.csv").replace("\\", "/") + follows = str(io_files_path / "graph-data" / "follows.csv").replace("\\", "/") conn.execute(f'COPY User FROM "{users}"') conn.execute(f'COPY Follows FROM "{follows}"') From c89d9a5f516ffab08949c49dc32f51484b0b2b54 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 3 Mar 2024 11:08:08 +0400 Subject: [PATCH 4/4] lint --- py-polars/tests/unit/io/test_database_read.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 635d391f87ad..480e5522eb9d 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -9,17 +9,19 @@ from types import GeneratorType from typing import TYPE_CHECKING, Any, NamedTuple -import polars as pl import pytest -from polars.exceptions import UnsuitableSQLError -from polars.io.database import _ARROW_DRIVER_REGISTRY_ -from polars.testing import assert_frame_equal from sqlalchemy import Integer, MetaData, Table, create_engine, func, select from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.expression import cast as alchemy_cast +import polars as pl +from polars.exceptions import UnsuitableSQLError +from polars.io.database import _ARROW_DRIVER_REGISTRY_ +from polars.testing import assert_frame_equal + if TYPE_CHECKING: import pyarrow as pa + from polars.type_aliases import ( ConnectionOrCursor, DbReadEngine, @@ -650,6 +652,7 @@ def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: users = str(io_files_path / "graph-data" / "user.csv").replace("\\", "/") follows = str(io_files_path / "graph-data" / "follows.csv").replace("\\", "/") + conn.execute(f'COPY User FROM "{users}"') conn.execute(f'COPY Follows FROM "{follows}"')
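
End to end, the new path behaves roughly like the test added above. A hedged usage sketch (assumes the `kuzu` package is installed; the database path and the inline CREATE statements are illustrative and stand in for the COPY-from-CSV setup used in the test):

    import kuzu

    import polars as pl

    # create a small on-disk kuzu database (path is illustrative)
    db = kuzu.Database("kuzu_example.db")
    conn = kuzu.Connection(db)
    conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
    conn.execute("CREATE (u:User {name: 'Adam', age: 30})")
    conn.execute("CREATE (u:User {name: 'Karissa', age: 40})")

    # the kuzu Connection is passed directly to read_database; the result is
    # fetched via the registered 'get_as_arrow' method and returned as a DataFrame
    df = pl.read_database(
        query="MATCH (u:User) RETURN u.name, u.age",
        connection=conn,
    )
    print(df)
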