diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 752121a90e50..e9ee15faed4f 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -33,10 +33,15 @@ class _ArrowDriverProperties_(TypedDict): - fetch_all: str # name of the method that fetches all arrow data - fetch_batches: str | None # name of the method that fetches arrow data in batches - exact_batch_size: bool | None # whether indicated batch size is respected exactly - repeat_batch_calls: bool # repeat batch calls (if False, batch call is generator) + # name of the method that fetches all arrow data; tuple form + # calls the fetch_all method with the given chunk size (int) + fetch_all: str | tuple[str, int] + # name of the method that fetches arrow data in batches + fetch_batches: str | None + # indicate whether the given batch size is respected exactly + exact_batch_size: bool | None + # repeat batch calls (if False, the batch call is a generator) + repeat_batch_calls: bool _ARROW_DRIVER_REGISTRY_: dict[str, _ArrowDriverProperties_] = { @@ -64,6 +69,13 @@ class _ArrowDriverProperties_(TypedDict): "exact_batch_size": True, "repeat_batch_calls": False, }, + "kuzu": { + # 'get_as_arrow' currently takes a mandatory chunk size + "fetch_all": ("get_as_arrow", 10_000), + "fetch_batches": None, + "exact_batch_size": None, + "repeat_batch_calls": False, + }, "snowflake": { "fetch_all": "fetch_arrow_all", "fetch_batches": "fetch_arrow_batches", @@ -153,7 +165,7 @@ def __exit__( ) -> None: # if we created it and are finished with it, we can # close the cursor (but NOT the connection) - if self.can_close_cursor: + if self.can_close_cursor and hasattr(self.cursor, "close"): self.cursor.close() def __repr__(self) -> str: @@ -169,8 +181,11 @@ def _arrow_batches( """Yield Arrow data in batches, or as a single 'fetchall' batch.""" fetch_batches = driver_properties["fetch_batches"] if not iter_batches or fetch_batches is None: - fetch_method = driver_properties["fetch_all"] - yield getattr(self.result, fetch_method)() + fetch_method, sz = driver_properties["fetch_all"], [] + if isinstance(fetch_method, tuple): + fetch_method, chunk_size = fetch_method + sz = [chunk_size] + yield getattr(self.result, fetch_method)(*sz) else: size = batch_size if driver_properties["exact_batch_size"] else None repeat_batch_calls = driver_properties["repeat_batch_calls"] diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index ea1a8a0cf6c7..5a572be8bef9 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -228,17 +228,11 @@ class SeriesBuffers(TypedDict): # minimal protocol definitions that can reasonably represent # an executable connection, cursor, or equivalent object class BasicConnection(Protocol): # noqa: D101 - def close(self) -> None: - """Close the connection.""" - def cursor(self, *args: Any, **kwargs: Any) -> Any: """Return a cursor object.""" class BasicCursor(Protocol): # noqa: D101 - def close(self) -> None: - """Close the cursor.""" - def execute(self, *args: Any, **kwargs: Any) -> Any: """Execute a query.""" diff --git a/py-polars/pyproject.toml b/py-polars/pyproject.toml index 4911687cea2e..2088b3310578 100644 --- a/py-polars/pyproject.toml +++ b/py-polars/pyproject.toml @@ -90,6 +90,7 @@ module = [ "fsspec.*", "gevent", "hvplot.*", + "kuzu", "matplotlib.*", "moto.server", "openpyxl", @@ -179,7 +180,7 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] -"tests/**/*.py" = ["D100", "D103", "B018", "FBT001"] +"tests/**/*.py" = ["D100", "D102", "D103", "B018", "FBT001"] [tool.ruff.lint.pycodestyle] max-doc-length = 88 diff --git a/py-polars/requirements-dev.txt b/py-polars/requirements-dev.txt index 4a3785e77f15..2006cab3f5d5 100644 --- a/py-polars/requirements-dev.txt +++ b/py-polars/requirements-dev.txt @@ -28,6 +28,7 @@ adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows' # TODO: Remove version constraint for connectorx when Python 3.12 is supported: # https://github.com/sfu-db/connector-x/issues/527 connectorx; python_version <= '3.11' +kuzu # Cloud cloudpickle fsspec diff --git a/py-polars/tests/unit/io/files/graph-data/follows.csv b/py-polars/tests/unit/io/files/graph-data/follows.csv new file mode 100644 index 000000000000..5ec090c283cd --- /dev/null +++ b/py-polars/tests/unit/io/files/graph-data/follows.csv @@ -0,0 +1,4 @@ +Adam,Karissa,2020 +Adam,Zhang,2020 +Karissa,Zhang,2021 +Zhang,Noura,2022 diff --git a/py-polars/tests/unit/io/files/graph-data/user.csv b/py-polars/tests/unit/io/files/graph-data/user.csv new file mode 100644 index 000000000000..0421e38ee559 --- /dev/null +++ b/py-polars/tests/unit/io/files/graph-data/user.csv @@ -0,0 +1,4 @@ +Adam,30 +Karissa,40 +Zhang,50 +Noura,25 diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index 9fc68921905d..480e5522eb9d 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -116,10 +116,10 @@ def __init__( test_data=test_data, ) - def close(self) -> None: # noqa: D102 + def close(self) -> None: pass - def cursor(self) -> Any: # noqa: D102 + def cursor(self) -> Any: return self._cursor @@ -143,10 +143,10 @@ def __getattr__(self, item: str) -> Any: return self.resultset super().__getattr__(item) # type: ignore[misc] - def close(self) -> Any: # noqa: D102 + def close(self) -> Any: pass - def execute(self, query: str) -> Any: # noqa: D102 + def execute(self, query: str) -> Any: return self @@ -161,7 +161,7 @@ def __init__( self.batched = batched self.n_calls = 1 - def __call__(self, *args: Any, **kwargs: Any) -> Any: # noqa: D102 + def __call__(self, *args: Any, **kwargs: Any) -> Any: if self.repeat_batched_calls: res = self.test_data[: None if self.n_calls else 0] self.n_calls -= 1 @@ -632,3 +632,55 @@ def test_read_database_cx_credentials(uri: str) -> None: # can reasonably mitigate the issue. with pytest.raises(BaseException, match=r"fakedb://\*\*\*:\*\*\*@\w+"): pl.read_database_uri("SELECT * FROM data", uri=uri) + + +@pytest.mark.write_disk() +def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: + # validate reading from a kuzu graph database + import kuzu + + tmp_path.mkdir(exist_ok=True) + if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists(): + kuzu_test_db.unlink() + + test_db = str(kuzu_test_db).replace("\\", "/") + + db = kuzu.Database(test_db) + conn = kuzu.Connection(db) + conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))") + conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)") + + users = str(io_files_path / "graph-data" / "user.csv").replace("\\", "/") + follows = str(io_files_path / "graph-data" / "follows.csv").replace("\\", "/") + + conn.execute(f'COPY User FROM "{users}"') + conn.execute(f'COPY Follows FROM "{follows}"') + + df1 = pl.read_database( + query="MATCH (u:User) RETURN u.name, u.age", + connection=conn, + ) + assert_frame_equal( + df1, + pl.DataFrame( + { + "u.name": ["Adam", "Karissa", "Zhang", "Noura"], + "u.age": [30, 40, 50, 25], + } + ), + ) + + df2 = pl.read_database( + query="MATCH (a:User)-[f:Follows]->(b:User) RETURN a.name, f.since, b.name", + connection=conn, + ) + assert_frame_equal( + df2, + pl.DataFrame( + { + "a.name": ["Adam", "Adam", "Karissa", "Zhang"], + "f.since": [2020, 2020, 2021, 2022], + "b.name": ["Karissa", "Zhang", "Zhang", "Noura"], + } + ), + ) diff --git a/py-polars/tests/unit/utils/test_deprecation.py b/py-polars/tests/unit/utils/test_deprecation.py index 639ff7615daf..5c067404afd0 100644 --- a/py-polars/tests/unit/utils/test_deprecation.py +++ b/py-polars/tests/unit/utils/test_deprecation.py @@ -49,7 +49,7 @@ def hello(oof: str, rab: str, ham: str) -> None: ... class Foo: # noqa: D101 @deprecate_nonkeyword_arguments(allowed_args=["self", "baz"], version="0.1.2") - def bar( # noqa: D102 + def bar( self, baz: str, ham: str | None = None, foobar: str | None = None ) -> None: ...