From 40558fdd9959c79293f3963c673fe0fafbdc0572 Mon Sep 17 00:00:00 2001 From: jayceslesar Date: Thu, 16 Nov 2023 18:27:22 -0500 Subject: [PATCH] feat(export): allow passing keyword arguments to PyArrow `ParquetWriter` and `CSVWriter` --- ibis/backends/base/__init__.py | 4 ++-- ibis/backends/duckdb/__init__.py | 22 ++++++++++++++++++++-- ibis/backends/tests/test_export.py | 27 +++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/ibis/backends/base/__init__.py b/ibis/backends/base/__init__.py index 704bc3e63b74..27833225820d 100644 --- a/ibis/backends/base/__init__.py +++ b/ibis/backends/base/__init__.py @@ -547,7 +547,7 @@ def to_parquet( import pyarrow.parquet as pq with expr.to_pyarrow_batches(params=params) as batch_reader: - with pq.ParquetWriter(path, batch_reader.schema) as writer: + with pq.ParquetWriter(path, batch_reader.schema, **kwargs) as writer: for batch in batch_reader: writer.write_batch(batch) @@ -582,7 +582,7 @@ def to_csv( import pyarrow.csv as pcsv with expr.to_pyarrow_batches(params=params) as batch_reader: - with pcsv.CSVWriter(path, batch_reader.schema) as writer: + with pcsv.CSVWriter(path, batch_reader.schema, **kwargs) as writer: for batch in batch_reader: writer.write_batch(batch) diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index c1570f207539..ed8429bbf379 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -98,6 +98,24 @@ class Backend(AlchemyCrossSchemaBackend, CanCreateSchema): name = "duckdb" compiler = DuckDBSQLCompiler supports_create_or_replace = True + reserved_csv_copy_args = [ + "COMPRESSION", + "FORCE_QUOTE", + "DATEFORMAT", + "DELIM", + "SEP", + "ESCAPE", + "HEADER", + "NULLSTR", + "QUOTE", + "TIMESTAMP_FORMAT" + ] + reserved_parquet_copy_args = [ + "COMPRESSION", + "ROW_GROUP_SIZE", + "ROW_GROUP_SIZE_BYTES", + "FIELD_IDS", + ] @property def settings(self) -> _Settings: @@ -1089,7 +1107,7 @@ def to_parquet( """ self._run_pre_execute_hooks(expr) query = self._to_sql(expr, params=params) - args = ["FORMAT 'parquet'", *(f"{k.upper()} {v!r}" for k, v in kwargs.items())] + args = ["FORMAT 'parquet'", *(f"{k.upper()} {v!r}" for k, v in kwargs.items() if k.upper() in self.reserved_parquet_copy_args)] copy_cmd = f"COPY ({query}) TO {str(path)!r} ({', '.join(args)})" with self.begin() as con: con.exec_driver_sql(copy_cmd) @@ -1127,7 +1145,7 @@ def to_csv( args = [ "FORMAT 'csv'", f"HEADER {int(header)}", - *(f"{k.upper()} {v!r}" for k, v in kwargs.items()), + *(f"{k.upper()} {v!r}" for k, v in kwargs.items() if k.upper() in self.reserved_csv_copy_args), ] copy_cmd = f"COPY ({query}) TO {str(path)!r} ({', '.join(args)})" with self.begin() as con: diff --git a/ibis/backends/tests/test_export.py b/ibis/backends/tests/test_export.py index f4bc5b16ecd0..4d536d72e385 100644 --- a/ibis/backends/tests/test_export.py +++ b/ibis/backends/tests/test_export.py @@ -3,6 +3,7 @@ import pandas as pd import pandas.testing as tm import pyarrow as pa +import pyarrow.csv as pcsv import pytest import sqlalchemy as sa from pytest import param @@ -220,6 +221,21 @@ def test_table_to_parquet(tmp_path, backend, awards_players): backend.assert_frame_equal(awards_players.to_pandas(), df) +@pytest.mark.notimpl(["flink"]) +@pytest.mark.parametrize(("kwargs"), [({"version": "1.0"}), ({"version": "2.6"})]) +def test_table_to_parquet_writer_kwargs(kwargs, tmp_path, backend, awards_players): + outparquet = tmp_path / "out.parquet" + awards_players.to_parquet(outparquet, **kwargs) + + df = pd.read_parquet(outparquet) + + backend.assert_frame_equal(awards_players.to_pandas(), df) + + file = pa.parquet.ParquetFile(outparquet) + + assert file.metadata.format_version == kwargs["version"] + + @pytest.mark.notimpl( [ "bigquery", @@ -299,6 +315,17 @@ def test_table_to_csv(tmp_path, backend, awards_players): backend.assert_frame_equal(awards_players.to_pandas(), df) +@pytest.mark.notimpl(["flink"]) +@pytest.mark.parametrize(("kwargs", "delimiter"), [({"write_options": pcsv.WriteOptions(delimiter=";")}, ";"), ({"write_options": pcsv.WriteOptions(delimiter="\t")}, "\t")]) +def test_table_to_csv_writer_kwargs(kwargs, delimiter, tmp_path, backend, awards_players): + outcsv = tmp_path / "out.csv" + # avoid pandas NaNonense + awards_players = awards_players.select("playerID", "awardID", "yearID", "lgID") + + awards_players.to_csv(outcsv, **kwargs) + pd.read_csv(outcsv, delimiter=delimiter) + + @pytest.mark.parametrize( ("dtype", "pyarrow_dtype"), [