diff --git a/doc/source/whatsnew/v1.2.1.rst b/doc/source/whatsnew/v1.2.1.rst index 769c195229bbd..e280b730679f0 100644 --- a/doc/source/whatsnew/v1.2.1.rst +++ b/doc/source/whatsnew/v1.2.1.rst @@ -15,6 +15,7 @@ including other versions of pandas. Fixed regressions ~~~~~~~~~~~~~~~~~ - The deprecated attributes ``_AXIS_NAMES`` and ``_AXIS_NUMBERS`` of :class:`DataFrame` and :class:`Series` will no longer show up in ``dir`` or ``inspect.getmembers`` calls (:issue:`38740`) +- :meth:`to_csv` created corrupted zip files when there were more rows than ``chunksize`` (issue:`38714`) - .. --------------------------------------------------------------------------- diff --git a/pandas/io/common.py b/pandas/io/common.py index c33ef9ac4ba95..642684ca61480 100644 --- a/pandas/io/common.py +++ b/pandas/io/common.py @@ -4,10 +4,10 @@ from collections import abc import dataclasses import gzip -from io import BufferedIOBase, BytesIO, RawIOBase, TextIOWrapper +from io import BufferedIOBase, BytesIO, RawIOBase, StringIO, TextIOWrapper import mmap import os -from typing import IO, Any, AnyStr, Dict, List, Mapping, Optional, Tuple, cast +from typing import IO, Any, AnyStr, Dict, List, Mapping, Optional, Tuple, Union, cast from urllib.parse import ( urljoin, urlparse as parse_url, @@ -713,17 +713,36 @@ def __init__( archive_name: Optional[str] = None, **kwargs, ): - if mode in ["wb", "rb"]: - mode = mode.replace("b", "") + mode = mode.replace("b", "") self.archive_name = archive_name + self.multiple_write_buffer: Optional[Union[StringIO, BytesIO]] = None + kwargs_zip: Dict[str, Any] = {"compression": zipfile.ZIP_DEFLATED} kwargs_zip.update(kwargs) + super().__init__(file, mode, **kwargs_zip) # type: ignore[arg-type] def write(self, data): + # buffer multiple write calls, write on flush + if self.multiple_write_buffer is None: + self.multiple_write_buffer = ( + BytesIO() if isinstance(data, bytes) else StringIO() + ) + self.multiple_write_buffer.write(data) + + def flush(self) -> None: + # write to actual handle and close write buffer + if self.multiple_write_buffer is None or self.multiple_write_buffer.closed: + return + # ZipFile needs a non-empty string archive_name = self.archive_name or self.filename or "zip" - super().writestr(archive_name, data) + with self.multiple_write_buffer: + super().writestr(archive_name, self.multiple_write_buffer.getvalue()) + + def close(self): + self.flush() + super().close() @property def closed(self): diff --git a/pandas/tests/io/formats/test_to_csv.py b/pandas/tests/io/formats/test_to_csv.py index a9673ded7c377..6416cb93c7ff5 100644 --- a/pandas/tests/io/formats/test_to_csv.py +++ b/pandas/tests/io/formats/test_to_csv.py @@ -640,3 +640,25 @@ def test_to_csv_encoding_binary_handle(self, mode): handle.seek(0) assert handle.read().startswith(b'\xef\xbb\xbf""') + + +def test_to_csv_iterative_compression_name(compression): + # GH 38714 + df = tm.makeDataFrame() + with tm.ensure_clean() as path: + df.to_csv(path, compression=compression, chunksize=1) + tm.assert_frame_equal( + pd.read_csv(path, compression=compression, index_col=0), df + ) + + +def test_to_csv_iterative_compression_buffer(compression): + # GH 38714 + df = tm.makeDataFrame() + with io.BytesIO() as buffer: + df.to_csv(buffer, compression=compression, chunksize=1) + buffer.seek(0) + tm.assert_frame_equal( + pd.read_csv(buffer, compression=compression, index_col=0), df + ) + assert not buffer.closed