ENH: Add compression to stata writers (#34013)

bashtage committed May 12, 2020
1 parent f5ea8ca commit 5724728

Showing 4 changed files with 189 additions and 9 deletions.
7 changes: 6 additions & 1 deletion doc/source/whatsnew/v1.1.0.rst
@@ -227,7 +227,12 @@ Other enhancements
- The ``ExtensionArray`` class has now an :meth:`~pandas.arrays.ExtensionArray.equals`
method, similarly to :meth:`Series.equals` (:issue:`27081`).
- The minimum supported dta version has increased to 105 in :meth:`~pandas.io.stata.read_stata` and :class:`~pandas.io.stata.StataReader` (:issue:`26667`).

- :meth:`~pandas.core.frame.DataFrame.to_stata` supports compression using the ``compression``
keyword argument. Compression can either be inferred or explicitly set using a string or a
dictionary containing both the method and any additional arguments that are passed to the
compression library. Compression was also added to the low-level Stata-file writers
:class:`~pandas.io.stata.StataWriter`, :class:`~pandas.io.stata.StataWriter117`,
and :class:`~pandas.io.stata.StataWriterUTF8` (:issue:`26599`).

.. ---------------------------------------------------------------------------
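In practice, the new keyword described in the whatsnew entry above can be exercised as follows (a minimal sketch; file names are illustrative):

    import pandas as pd

    df = pd.DataFrame({"A": range(3)})
    # compression='infer' is the default: a '.gz' suffix selects gzip
    df.to_stata("data.dta.gz")
    # or name the method explicitly with a string
    df.to_stata("data.dta.bz2", compression="bz2")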
16 changes: 16 additions & 0 deletions pandas/core/frame.py
@@ -25,6 +25,7 @@
Iterable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Set,
@@ -2015,6 +2016,7 @@ def to_stata(
variable_labels: Optional[Dict[Label, str]] = None,
version: Optional[int] = 114,
convert_strl: Optional[Sequence[Label]] = None,
compression: Union[str, Mapping[str, str], None] = "infer",
) -> None:
"""
Export DataFrame object to Stata dta format.
@@ -2078,6 +2080,19 @@ def to_stata(
.. versionadded:: 0.23.0
compression : str or dict, default 'infer'
For on-the-fly compression of the output dta. If string, specifies
compression mode. If dict, value at key 'method' specifies
compression mode. Compression mode must be one of {'infer', 'gzip',
'bz2', 'zip', 'xz', None}. If compression mode is 'infer' and
`fname` is path-like, then detect compression from the following
extensions: '.gz', '.bz2', '.zip', or '.xz' (otherwise no
compression). If dict and compression mode is one of {'zip',
'gzip', 'bz2'}, or inferred as one of the above, other entries
are passed as additional compression options.
.. versionadded:: 1.1.0
Raises
------
NotImplementedError
@@ -2133,6 +2148,7 @@ def to_stata(
data_label=data_label,
write_index=write_index,
variable_labels=variable_labels,
compression=compression,
**kwargs,
)
writer.write_file()
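As a usage sketch of the dict form documented above (mirroring the StataWriter example later in this diff; file names are illustrative):

    import pandas as pd

    df = pd.DataFrame({"A": range(3)})
    # keys other than 'method' are forwarded to the compression library;
    # for zip output, 'archive_name' names the dta file inside the archive
    df.to_stata("data.zip", compression={"method": "zip", "archive_name": "data.dta"})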
115 changes: 107 additions & 8 deletions pandas/io/stata.py
@@ -16,7 +16,18 @@
from pathlib import Path
import struct
import sys
from typing import Any, AnyStr, BinaryIO, Dict, List, Optional, Sequence, Tuple, Union
from typing import (
Any,
AnyStr,
BinaryIO,
Dict,
List,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
import warnings

from dateutil.relativedelta import relativedelta
@@ -47,7 +58,13 @@
from pandas.core.indexes.base import Index
from pandas.core.series import Series

from pandas.io.common import get_filepath_or_buffer, stringify_path
from pandas.io.common import (
get_compression_method,
get_filepath_or_buffer,
get_handle,
infer_compression,
stringify_path,
)

_version_error = (
"Version of given Stata file is {version}. pandas supports importing "
@@ -1854,13 +1871,18 @@ def read_stata(
return data


def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
def _open_file_binary_write(
fname: FilePathOrBuffer, compression: Union[str, Mapping[str, str], None],
) -> Tuple[BinaryIO, bool, Optional[Union[str, Mapping[str, str]]]]:
"""
Open a binary file or no-op if file-like.
Parameters
----------
fname : string path, path object or buffer
The file name or buffer.
compression : {str, dict, None}
The compression method to use.
Returns
-------
@@ -1871,9 +1893,21 @@ def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
"""
if hasattr(fname, "write"):
# See https://github.com/python/mypy/issues/1424 for hasattr challenges
return fname, False # type: ignore
return fname, False, None # type: ignore
elif isinstance(fname, (str, Path)):
return open(fname, "wb"), True
# Extract compression mode as given, if dict
compression_typ, compression_args = get_compression_method(compression)
compression_typ = infer_compression(fname, compression_typ)
path_or_buf, _, compression_typ, _ = get_filepath_or_buffer(
fname, compression=compression_typ
)
if compression_typ is not None:
compression = compression_args
compression["method"] = compression_typ
else:
compression = None
f, _ = get_handle(path_or_buf, "wb", compression=compression, is_text=False)
return f, True, compression
else:
raise TypeError("fname must be a binary file, buffer or path-like.")
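The inference step above amounts to an extension lookup. A simplified, standard-library-only sketch of that logic (not the pandas implementation itself):

    from typing import Optional

    _EXTENSION_TO_METHOD = {".gz": "gzip", ".bz2": "bz2", ".zip": "zip", ".xz": "xz"}

    def infer_method(fname: str, method: Optional[str]) -> Optional[str]:
        # any value other than 'infer' passes through unchanged
        if method != "infer":
            return method
        for ext, meth in _EXTENSION_TO_METHOD.items():
            if fname.endswith(ext):
                return meth
        return None  # unrecognized extension -> no compression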

@@ -2050,6 +2084,17 @@ class StataWriter(StataParser):
variable_labels : dict
Dictionary containing columns as keys and variable labels as values.
Each label must be 80 characters or smaller.
compression : str or dict, default 'infer'
For on-the-fly compression of the output dta. If string, specifies
compression mode. If dict, value at key 'method' specifies compression
mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
'xz', None}. If compression mode is 'infer' and `fname` is path-like,
then detect compression from the following extensions: '.gz', '.bz2',
'.zip', or '.xz' (otherwise no compression). If dict and compression
mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
other entries are passed as additional compression options.
.. versionadded:: 1.1.0
Returns
-------
@@ -2074,7 +2119,12 @@ class StataWriter(StataParser):
>>> writer = StataWriter('./data_file.dta', data)
>>> writer.write_file()
Or with dates
Directly write a zip file
>>> compression = {"method": "zip", "archive_name": "data_file.dta"}
>>> writer = StataWriter('./data_file.zip', data, compression=compression)
>>> writer.write_file()
Save a DataFrame with dates
>>> from datetime import datetime
>>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date'])
>>> writer = StataWriter('./date_data_file.dta', data, {'date' : 'tw'})
@@ -2094,6 +2144,7 @@ def __init__(
time_stamp: Optional[datetime.datetime] = None,
data_label: Optional[str] = None,
variable_labels: Optional[Dict[Label, str]] = None,
compression: Union[str, Mapping[str, str], None] = "infer",
):
super().__init__()
self._convert_dates = {} if convert_dates is None else convert_dates
Expand All @@ -2102,6 +2153,8 @@ def __init__(
self._data_label = data_label
self._variable_labels = variable_labels
self._own_file = True
self._compression = compression
self._output_file: Optional[BinaryIO] = None
# attach nobs, nvars, data, varlist, typlist
self._prepare_pandas(data)

@@ -2389,7 +2442,12 @@ def _encode_strings(self) -> None:
self.data[col] = encoded

def write_file(self) -> None:
self._file, self._own_file = _open_file_binary_write(self._fname)
self._file, self._own_file, compression = _open_file_binary_write(
self._fname, self._compression
)
if compression is not None:
self._output_file = self._file
self._file = BytesIO()
try:
self._write_header(data_label=self._data_label, time_stamp=self._time_stamp)
self._write_map()
@@ -2434,6 +2492,12 @@ def _close(self) -> None:
"""
# Some file-like objects might not support flush
assert self._file is not None
if self._output_file is not None:
assert isinstance(self._file, BytesIO)
bio = self._file
bio.seek(0)
self._file = self._output_file
self._file.write(bio.read())
try:
self._file.flush()
except AttributeError:
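Together, write_file and _close implement a stage-then-flush pattern: when compression is active, the dta payload is assembled in an in-memory BytesIO and only pushed through the compressed handle once writing completes. A standard-library-only sketch of the same idea (not the pandas code itself):

    import gzip
    from io import BytesIO

    staged = BytesIO()
    staged.write(b"fake dta payload")  # writer methods target the in-memory buffer
    staged.seek(0)
    with gzip.open("data.dta.gz", "wb") as fh:
        fh.write(staged.read())  # compressed output is written once, at the end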
@@ -2898,6 +2962,17 @@ class StataWriter117(StataWriter):
Smaller columns can be converted by including the column name. Using
StrLs can reduce output file size when strings are longer than 8
characters, and either frequently repeated or sparse.
compression : str or dict, default 'infer'
For on-the-fly compression of the output dta. If string, specifies
compression mode. If dict, value at key 'method' specifies compression
mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
'xz', None}. If compression mode is 'infer' and `fname` is path-like,
then detect compression from the following extensions: '.gz', '.bz2',
'.zip', or '.xz' (otherwise no compression). If dict and compression
mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
other entries are passed as additional compression options.
.. versionadded:: 1.1.0
Returns
-------
@@ -2923,8 +2998,12 @@ class StataWriter117(StataWriter):
>>> writer = StataWriter117('./data_file.dta', data)
>>> writer.write_file()
Or with long strings stored in strl format
Directly write a zip file
>>> compression = {"method": "zip", "archive_name": "data_file.dta"}
>>> writer = StataWriter117('./data_file.zip', data, compression=compression)
>>> writer.write_file()
Or with long strings stored in strl format
>>> data = pd.DataFrame([['A relatively long string'], [''], ['']],
... columns=['strls'])
>>> writer = StataWriter117('./data_file_with_long_strings.dta', data,
@@ -2946,6 +3025,7 @@ def __init__(
data_label: Optional[str] = None,
variable_labels: Optional[Dict[Label, str]] = None,
convert_strl: Optional[Sequence[Label]] = None,
compression: Union[str, Mapping[str, str], None] = "infer",
):
# Copy to new list since convert_strl might be modified later
self._convert_strl: List[Label] = []
@@ -2961,6 +3041,7 @@
time_stamp=time_stamp,
data_label=data_label,
variable_labels=variable_labels,
compression=compression,
)
self._map: Dict[str, int] = {}
self._strl_blob = b""
@@ -3281,6 +3362,17 @@ class StataWriterUTF8(StataWriter117):
The dta version to use. By default, uses the size of data to determine
the version. 118 is used if data.shape[1] <= 32767, and 119 is used
for storing larger DataFrames.
compression : str or dict, default 'infer'
For on-the-fly compression of the output dta. If string, specifies
compression mode. If dict, value at key 'method' specifies compression
mode. Compression mode must be one of {'infer', 'gzip', 'bz2', 'zip',
'xz', None}. If compression mode is 'infer' and `fname` is path-like,
then detect compression from the following extensions: '.gz', '.bz2',
'.zip', or '.xz' (otherwise no compression). If dict and compression
mode is one of {'zip', 'gzip', 'bz2'}, or inferred as one of the above,
other entries are passed as additional compression options.
.. versionadded:: 1.1.0
Returns
-------
@@ -3308,6 +3400,11 @@ class StataWriterUTF8(StataWriter117):
>>> writer = StataWriterUTF8('./data_file.dta', data)
>>> writer.write_file()
Directly write a zip file
>>> compression = {"method": "zip", "archive_name": "data_file.dta"}
>>> writer = StataWriterUTF8('./data_file.zip', data, compression=compression)
>>> writer.write_file()
Or with long strings stored in strl format
>>> data = pd.DataFrame([['ᴀ relatively long ŝtring'], [''], ['']],
@@ -3331,6 +3428,7 @@ def __init__(
variable_labels: Optional[Dict[Label, str]] = None,
convert_strl: Optional[Sequence[Label]] = None,
version: Optional[int] = None,
compression: Union[str, Mapping[str, str], None] = "infer",
):
if version is None:
version = 118 if data.shape[1] <= 32767 else 119
@@ -3352,6 +3450,7 @@
data_label=data_label,
variable_labels=variable_labels,
convert_strl=convert_strl,
compression=compression,
)
# Override version set in StataWriter117 init
self._dta_version = version
60 changes: 60 additions & 0 deletions pandas/tests/io/test_stata.py
@@ -1,10 +1,13 @@
import bz2
import datetime as dt
from datetime import datetime
import gzip
import io
import lzma
import os
import struct
import warnings
import zipfile

import numpy as np
import pytest
@@ -1863,3 +1866,60 @@ def test_backward_compat(version, datapath):
expected = pd.read_stata(ref)
old_dta = pd.read_stata(old)
tm.assert_frame_equal(old_dta, expected, check_dtype=False)


@pytest.mark.parametrize("version", [114, 117, 118, 119, None])
@pytest.mark.parametrize("use_dict", [True, False])
@pytest.mark.parametrize("infer", [True, False])
def test_compression(compression, version, use_dict, infer):
file_name = "dta_inferred_compression.dta"
if compression:
file_ext = "gz" if compression == "gzip" and not use_dict else compression
file_name += f".{file_ext}"
compression_arg = compression
if infer:
compression_arg = "infer"
if use_dict:
compression_arg = {"method": compression}

df = DataFrame(np.random.randn(10, 2), columns=list("AB"))
df.index.name = "index"
with tm.ensure_clean(file_name) as path:
df.to_stata(path, version=version, compression=compression_arg)
if compression == "gzip":
with gzip.open(path, "rb") as comp:
fp = io.BytesIO(comp.read())
elif compression == "zip":
with zipfile.ZipFile(path, "r") as comp:
fp = io.BytesIO(comp.read(comp.filelist[0]))
elif compression == "bz2":
with bz2.open(path, "rb") as comp:
fp = io.BytesIO(comp.read())
elif compression == "xz":
with lzma.open(path, "rb") as comp:
fp = io.BytesIO(comp.read())
elif compression is None:
fp = path
reread = read_stata(fp, index_col="index")
tm.assert_frame_equal(reread, df)


@pytest.mark.parametrize("method", ["zip", "infer"])
@pytest.mark.parametrize("file_ext", [None, "dta", "zip"])
def test_compression_dict(method, file_ext):
file_name = f"test.{file_ext}"
archive_name = "test.dta"
df = DataFrame(np.random.randn(10, 2), columns=list("AB"))
df.index.name = "index"
with tm.ensure_clean(file_name) as path:
compression = {"method": method, "archive_name": archive_name}
df.to_stata(path, compression=compression)
if method == "zip" or file_ext == "zip":
zp = zipfile.ZipFile(path, "r")
assert len(zp.filelist) == 1
assert zp.filelist[0].filename == archive_name
fp = io.BytesIO(zp.read(zp.filelist[0]))
else:
fp = path
reread = read_stata(fp, index_col="index")
tm.assert_frame_equal(reread, df)
