From 21e4a27552fb7a4f265280d4c5f39a93085dc441 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Sun, 3 Mar 2024 01:49:24 +0400 Subject: [PATCH] fix windows test, add inline comment, further streamline --- py-polars/polars/io/database.py | 14 +++++++------- py-polars/tests/unit/io/test_database_read.py | 6 +++++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/py-polars/polars/io/database.py b/py-polars/polars/io/database.py index 8d74290dfb781..e9ee15faed4ff 100644 --- a/py-polars/polars/io/database.py +++ b/py-polars/polars/io/database.py @@ -34,7 +34,7 @@ class _ArrowDriverProperties_(TypedDict): # name of the method that fetches all arrow data; tuple form - # calls the fetch_all method with the give chunk size + # calls the fetch_all method with the given chunk size (int) fetch_all: str | tuple[str, int] # name of the method that fetches arrow data in batches fetch_batches: str | None @@ -70,6 +70,7 @@ class _ArrowDriverProperties_(TypedDict): "repeat_batch_calls": False, }, "kuzu": { + # 'get_as_arrow' currently takes a mandatory chunk size "fetch_all": ("get_as_arrow", 10_000), "fetch_batches": None, "exact_batch_size": None, @@ -180,12 +181,11 @@ def _arrow_batches( """Yield Arrow data in batches, or as a single 'fetchall' batch.""" fetch_batches = driver_properties["fetch_batches"] if not iter_batches or fetch_batches is None: - fetch_method = driver_properties["fetch_all"] - if not isinstance(fetch_method, tuple): - yield getattr(self.result, fetch_method)() - else: - fetch_method, sz = fetch_method - yield getattr(self.result, fetch_method)(sz) + fetch_method, sz = driver_properties["fetch_all"], [] + if isinstance(fetch_method, tuple): + fetch_method, chunk_size = fetch_method + sz = [chunk_size] + yield getattr(self.result, fetch_method)(*sz) else: size = batch_size if driver_properties["exact_batch_size"] else None repeat_batch_calls = driver_properties["repeat_batch_calls"] diff --git a/py-polars/tests/unit/io/test_database_read.py b/py-polars/tests/unit/io/test_database_read.py index fd4f07f5ebd51..3db0285f4ce17 100644 --- a/py-polars/tests/unit/io/test_database_read.py +++ b/py-polars/tests/unit/io/test_database_read.py @@ -643,7 +643,11 @@ def test_read_kuzu_graph_database(tmp_path: Path, io_files_path: Path) -> None: if (kuzu_test_db := (tmp_path / "kuzu_test.db")).exists(): kuzu_test_db.unlink() - db = kuzu.Database(str(kuzu_test_db)) + test_db = str(kuzu_test_db) + if sys.platform == "win32": + test_db = test_db.replace("\\", "/") + + db = kuzu.Database(test_db) conn = kuzu.Connection(db) conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))") conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")