From 57247284e9f4418889b0e3f6de539ac51e89b31e Mon Sep 17 00:00:00 2001
From: Kevin Sheppard
Date: Tue, 12 May 2020 17:10:35 +0100
Subject: [PATCH] ENH: Add compression to stata writers (#34013)

---
 doc/source/whatsnew/v1.1.0.rst |   7 +-
 pandas/core/frame.py           |  16 +++++
 pandas/io/stata.py             | 115 ++++++++++++++++++++++++++++++---
 pandas/tests/io/test_stata.py  |  60 +++++++++++++++++
 4 files changed, 189 insertions(+), 9 deletions(-)

diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst
index 8526ccb57396f..9f0aaecacd383 100644
--- a/doc/source/whatsnew/v1.1.0.rst
+++ b/doc/source/whatsnew/v1.1.0.rst
@@ -227,7 +227,12 @@ Other enhancements
 - The ``ExtensionArray`` class now has an :meth:`~pandas.arrays.ExtensionArray.equals` method, similar to :meth:`Series.equals` (:issue:`27081`).
 - The minimum supported dta version has increased to 105 in :meth:`~pandas.io.stata.read_stata` and :class:`~pandas.io.stata.StataReader` (:issue:`26667`).
-
+- :meth:`~pandas.core.frame.DataFrame.to_stata` supports compression using the ``compression``
+  keyword argument. Compression can either be inferred or explicitly set using a string or a
+  dictionary containing both the method and any additional arguments that are passed to the
+  compression library. Compression was also added to the low-level Stata-file writers
+  :class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`,
+  and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`26599`).
 
 .. ---------------------------------------------------------------------------

diff --git a/pandas/core/frame.py b/pandas/core/frame.py
index 102b305cc7a99..445d168ff875d 100644
--- a/pandas/core/frame.py
+++ b/pandas/core/frame.py
@@ -25,6 +25,7 @@
     Iterable,
     Iterator,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -2015,6 +2016,7 @@ def to_stata(
         variable_labels: Optional[Dict[Label, str]] = None,
         version: Optional[int] = 114,
         convert_strl: Optional[Sequence[Label]] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ) -> None:
         """
         Export DataFrame object to Stata dta format.
@@ -2078,6 +2080,19 @@ def to_stata(
 
             .. versionadded:: 0.23.0
 
+        compression : str or dict, default 'infer'
+            For on-the-fly compression of the output dta. If string, specifies
+            compression mode. If dict, value at key 'method' specifies
+            compression mode. Compression mode must be one of {'infer', 'gzip',
+            'bz2', 'zip', 'xz', None}. If compression mode is 'infer' and
+            `path` is path-like, then detect compression from the following
+            extensions: '.gz', '.bz2', '.zip', or '.xz' (otherwise no
+            compression). If dict and compression mode is one of {'zip',
+            'gzip', 'bz2'}, or inferred as one of the above, other entries
+            passed as additional compression options.
+
+            .. versionadded:: 1.1.0
+
         Raises
         ------
         NotImplementedError
@@ -2133,6 +2148,7 @@ def to_stata(
             data_label=data_label,
             write_index=write_index,
             variable_labels=variable_labels,
+            compression=compression,
             **kwargs,
         )
         writer.write_file()
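The new keyword is easiest to evaluate in use. A minimal sketch of both spellings documented above (illustrative paths, not part of the patch):

    import pandas as pd

    df = pd.DataFrame({"A": [1.0, 2.0], "B": [3.0, 4.0]})

    # The default compression='infer' detects gzip from the '.gz' extension.
    df.to_stata("data.dta.gz")

    # The dict form names the method and forwards any extra entries to the
    # compression library; for zip, 'archive_name' sets the file name stored
    # inside the archive.
    df.to_stata("data.zip", compression={"method": "zip", "archive_name": "data.dta"})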
diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index 789e08d0652c9..fe8dcf1bdb9aa 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -16,7 +16,18 @@
 from pathlib import Path
 import struct
 import sys
-from typing import Any, AnyStr, BinaryIO, Dict, List, Optional, Sequence, Tuple, Union
+from typing import (
+    Any,
+    AnyStr,
+    BinaryIO,
+    Dict,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
 import warnings
 
 from dateutil.relativedelta import relativedelta
@@ -47,7 +58,13 @@
 from pandas.core.indexes.base import Index
 from pandas.core.series import Series
 
-from pandas.io.common import get_filepath_or_buffer, stringify_path
+from pandas.io.common import (
+    get_compression_method,
+    get_filepath_or_buffer,
+    get_handle,
+    infer_compression,
+    stringify_path,
+)
 
 _version_error = (
     "Version of given Stata file is {version}. pandas supports importing "
@@ -1854,13 +1871,18 @@ def read_stata(
     return data
 
 
-def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
+def _open_file_binary_write(
+    fname: FilePathOrBuffer, compression: Union[str, Mapping[str, str], None],
+) -> Tuple[BinaryIO, bool, Optional[Union[str, Mapping[str, str]]]]:
     """
     Open a binary file or no-op if file-like.
 
     Parameters
     ----------
     fname : string path, path object or buffer
+        The file name or buffer.
+    compression : {str, dict, None}
+        The compression method to use.
 
     Returns
     -------
@@ -1871,9 +1893,21 @@ def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
     """
     if hasattr(fname, "write"):
         # See https://github.com/python/mypy/issues/1424 for hasattr challenges
-        return fname, False  # type: ignore
+        return fname, False, None  # type: ignore
     elif isinstance(fname, (str, Path)):
-        return open(fname, "wb"), True
+        # Extract compression mode as given, if dict
+        compression_typ, compression_args = get_compression_method(compression)
+        compression_typ = infer_compression(fname, compression_typ)
+        path_or_buf, _, compression_typ, _ = get_filepath_or_buffer(
+            fname, compression=compression_typ
+        )
+        if compression_typ is not None:
+            compression = compression_args
+            compression["method"] = compression_typ
+        else:
+            compression = None
+        f, _ = get_handle(path_or_buf, "wb", compression=compression, is_text=False)
+        return f, True, compression
     else:
         raise TypeError("fname must be a binary file, buffer or path-like.")
@@ -2050,6 +2084,17 @@ class StataWriter(StataParser):
     variable_labels : dict
         Dictionary containing columns as keys and variable labels as values.
         Each label must be 80 characters or smaller.
+    compression : str or dict, default 'infer'
+        For on-the-fly compression of the output dta. If string, specifies
+        compression mode. If dict, value at key 'method' specifies compression
+        mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
+        'xz', None}. If compression mode is 'infer' and `fname` is path-like,
+        then detect compression from the following extensions: '.gz', '.bz2',
+        '.zip', or '.xz' (otherwise no compression). If dict and compression
+        mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
+        other entries passed as additional compression options.
+
+        .. versionadded:: 1.1.0
 
     Returns
     -------
@@ -2074,7 +2119,12 @@ class StataWriter(StataParser):
     >>> writer = StataWriter('./data_file.dta', data)
     >>> writer.write_file()
 
-    Or with dates
+    Directly write a zip file
+    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
+    >>> writer = StataWriter('./data_file.zip', data, compression=compression)
+    >>> writer.write_file()
+
+    Save a DataFrame with dates
     >>> from datetime import datetime
     >>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
     >>> writer = StataWriter('./date_data_file.dta', data, {'date' : 'tw'})
@@ -2094,6 +2144,7 @@ def __init__(
         time_stamp: Optional[datetime.datetime] = None,
         data_label: Optional[str] = None,
         variable_labels: Optional[Dict[Label, str]] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ):
         super().__init__()
         self._convert_dates = {} if convert_dates is None else convert_dates
@@ -2102,6 +2153,8 @@ def __init__(
         self._data_label = data_label
         self._variable_labels = variable_labels
         self._own_file = True
+        self._compression = compression
+        self._output_file: Optional[BinaryIO] = None
         # attach nobs, nvars, data, varlist, typlist
         self._prepare_pandas(data)
 
@@ -2389,7 +2442,12 @@ def _encode_strings(self) -> None:
                 self.data[col] = encoded
 
     def write_file(self) -> None:
-        self._file, self._own_file = _open_file_binary_write(self._fname)
+        self._file, self._own_file, compression = _open_file_binary_write(
+            self._fname, self._compression
+        )
+        if compression is not None:
+            self._output_file = self._file
+            self._file = BytesIO()
         try:
             self._write_header(data_label=self._data_label, time_stamp=self._time_stamp)
             self._write_map()
@@ -2434,6 +2492,12 @@ def _close(self) -> None:
         """
         # Some file-like objects might not support flush
         assert self._file is not None
+        if self._output_file is not None:
+            assert isinstance(self._file, BytesIO)
+            bio = self._file
+            bio.seek(0)
+            self._file = self._output_file
+            self._file.write(bio.read())
         try:
             self._file.flush()
         except AttributeError:
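The mechanics added to ``StataWriter`` above are: resolve the compression method up front, and when one is active, stage the dta payload in a ``BytesIO`` during ``write_file()`` and only push it through the compressed handle in ``_close()``. A self-contained sketch of that stage-then-flush pattern, using plain ``gzip`` rather than the pandas internals (names here are illustrative):

    import gzip
    from io import BytesIO

    def write_compressed(path: str, payload: bytes) -> None:
        # Stage the complete output in memory, as write_file() does when a
        # compression method is resolved.
        buf = BytesIO()
        buf.write(payload)
        # On close, rewind and copy the staged bytes through the real,
        # compressed handle, as _close() does.
        buf.seek(0)
        with gzip.open(path, "wb") as handle:
            handle.write(buf.read())

    write_compressed("example.dta.gz", b"not a real dta payload")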
@@ -2898,6 +2962,17 @@ class StataWriter117(StataWriter):
         Smaller columns can be converted by including the column name. Using
         StrLs can reduce output file size when strings are longer than 8
         characters, and either frequently repeated or sparse.
+    compression : str or dict, default 'infer'
+        For on-the-fly compression of the output dta. If string, specifies
+        compression mode. If dict, value at key 'method' specifies compression
+        mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
+        'xz', None}. If compression mode is 'infer' and `fname` is path-like,
+        then detect compression from the following extensions: '.gz', '.bz2',
+        '.zip', or '.xz' (otherwise no compression). If dict and compression
+        mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
+        other entries passed as additional compression options.
+
+        .. versionadded:: 1.1.0
 
     Returns
     -------
@@ -2923,8 +2998,12 @@ class StataWriter117(StataWriter):
     >>> writer = StataWriter117('./data_file.dta', data)
     >>> writer.write_file()
 
-    Or with long strings stored in strl format
+    Directly write a zip file
+    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
+    >>> writer = StataWriter117('./data_file.zip', data, compression=compression)
+    >>> writer.write_file()
 
+    Or with long strings stored in strl format
     >>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
     ...                     columns=['strls'])
     >>> writer = StataWriter117('./data_file_with_long_strings.dta', data,
@@ -2946,6 +3025,7 @@ def __init__(
         data_label: Optional[str] = None,
         variable_labels: Optional[Dict[Label, str]] = None,
         convert_strl: Optional[Sequence[Label]] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ):
         # Copy to new list since convert_strl might be modified later
         self._convert_strl: List[Label] = []
@@ -2961,6 +3041,7 @@ def __init__(
             time_stamp=time_stamp,
             data_label=data_label,
             variable_labels=variable_labels,
+            compression=compression,
         )
         self._map: Dict[str, int] = {}
         self._strl_blob = b""
@@ -3281,6 +3362,17 @@ class StataWriterUTF8(StataWriter117):
         The dta version to use. By default, uses the size of data to determine
         the version. 118 is used if data.shape[1] <= 32767, and 119 is used
         for storing larger DataFrames.
+    compression : str or dict, default 'infer'
+        For on-the-fly compression of the output dta. If string, specifies
+        compression mode. If dict, value at key 'method' specifies compression
+        mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
+        'xz', None}. If compression mode is 'infer' and `fname` is path-like,
+        then detect compression from the following extensions: '.gz', '.bz2',
+        '.zip', or '.xz' (otherwise no compression). If dict and compression
+        mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
+        other entries passed as additional compression options.
+
+        .. versionadded:: 1.1.0
 
     Returns
     -------
@@ -3308,6 +3400,11 @@ class StataWriterUTF8(StataWriter117):
     >>> writer = StataWriterUTF8('./data_file.dta', data)
     >>> writer.write_file()
 
+    Directly write a zip file
+    >>> compression = {"method": "zip", "archive_name": "data_file.dta"}
+    >>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
+    >>> writer.write_file()
+
     Or with long strings stored in strl format
 
     >>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
@@ -3331,6 +3428,7 @@ def __init__(
         variable_labels: Optional[Dict[Label, str]] = None,
         convert_strl: Optional[Sequence[Label]] = None,
         version: Optional[int] = None,
+        compression: Union[str, Mapping[str, str], None] = "infer",
     ):
         if version is None:
             version = 118 if data.shape[1] <= 32767 else 119
@@ -3352,6 +3450,7 @@ def __init__(
             data_label=data_label,
             variable_labels=variable_labels,
             convert_strl=convert_strl,
+            compression=compression,
        )
         # Override version set in StataWriter117 init
         self._dta_version = version
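All three writers route through the same ``_open_file_binary_write`` helper, so inference behaves identically at this level. A sketch of the two edge cases worth reviewing; the file names are placeholders, and the behavior is read off the helper above rather than documented anywhere in the patch:

    import pandas as pd
    from pandas.io.stata import StataWriter

    data = pd.DataFrame([[1.0, 1]], columns=["float_", "int_"])

    # The default 'infer' resolves compression from the extension alone,
    # so this writes an xz-compressed dta.
    StataWriter("./data.dta.xz", data).write_file()

    # An explicit None short-circuits inference: the result is an
    # uncompressed dta despite the '.gz' extension.
    StataWriter("./data.dta.gz", data, compression=None).write_file()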
diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py
index 783e06c9b7f2e..e670e0cdf2ade 100644
--- a/pandas/tests/io/test_stata.py
+++ b/pandas/tests/io/test_stata.py
@@ -1,10 +1,13 @@
+import bz2
 import datetime as dt
 from datetime import datetime
 import gzip
 import io
+import lzma
 import os
 import struct
 import warnings
+import zipfile
 
 import numpy as np
 import pytest
@@ -1863,3 +1866,60 @@ def test_backward_compat(version, datapath):
     expected = pd.read_stata(ref)
     old_dta = pd.read_stata(old)
     tm.assert_frame_equal(old_dta, expected, check_dtype=False)
+
+
+@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
+@pytest.mark.parametrize("use_dict", [True, False])
+@pytest.mark.parametrize("infer", [True, False])
+def test_compression(compression, version, use_dict, infer):
+    file_name = "dta_inferred_compression.dta"
+    if compression:
+        file_ext = "gz" if compression == "gzip" and not use_dict else compression
+        file_name += f".{file_ext}"
+    compression_arg = compression
+    if infer:
+        compression_arg = "infer"
+    if use_dict:
+        compression_arg = {"method": compression}
+
+    df = DataFrame(np.random.randn(10, 2), columns=list("AB"))
+    df.index.name = "index"
+    with tm.ensure_clean(file_name) as path:
+        df.to_stata(path, version=version, compression=compression_arg)
+        if compression == "gzip":
+            with gzip.open(path, "rb") as comp:
+                fp = io.BytesIO(comp.read())
+        elif compression == "zip":
+            with zipfile.ZipFile(path, "r") as comp:
+                fp = io.BytesIO(comp.read(comp.filelist[0]))
+        elif compression == "bz2":
+            with bz2.open(path, "rb") as comp:
+                fp = io.BytesIO(comp.read())
+        elif compression == "xz":
+            with lzma.open(path, "rb") as comp:
+                fp = io.BytesIO(comp.read())
+        elif compression is None:
+            fp = path
+        reread = read_stata(fp, index_col="index")
+        tm.assert_frame_equal(reread, df)
+
+
+@pytest.mark.parametrize("method", ["zip", "infer"])
+@pytest.mark.parametrize("file_ext", [None, "dta", "zip"])
+def test_compression_dict(method, file_ext):
+    file_name = f"test.{file_ext}"
+    archive_name = "test.dta"
+    df = DataFrame(np.random.randn(10, 2), columns=list("AB"))
+    df.index.name = "index"
+    with tm.ensure_clean(file_name) as path:
+        compression = {"method": method, "archive_name": archive_name}
+        df.to_stata(path, compression=compression)
+        if method == "zip" or file_ext == "zip":
+            zp = zipfile.ZipFile(path, "r")
+            assert len(zp.filelist) == 1
+            assert zp.filelist[0].filename == archive_name
+            fp = io.BytesIO(zp.read(zp.filelist[0]))
+        else:
+            fp = path
+        reread = read_stata(fp, index_col="index")
+        tm.assert_frame_equal(reread, df)
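The tests decompress by hand before rereading, since nothing in this patch teaches ``read_stata`` to undo the compression itself. A condensed version of the round trip they exercise, standalone and with a throwaway path:

    import gzip
    import io

    import numpy as np
    import pandas as pd

    df = pd.DataFrame(np.random.randn(10, 2), columns=list("AB"))
    df.index.name = "index"
    df.to_stata("roundtrip.dta.gz", compression="gzip")

    # Unpack manually, then reread from the in-memory buffer.
    with gzip.open("roundtrip.dta.gz", "rb") as comp:
        fp = io.BytesIO(comp.read())
    reread = pd.read_stata(fp, index_col="index")
    pd.testing.assert_frame_equal(reread, df)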