Add typed overloads for pandas.io.stata.read_stata and tests for them.

The stub gains @overload variants so that read_stata() is typed as StataReader
when iterator=True or an integer chunksize is given, and as DataFrame otherwise.
Post-image blob ids on the "index" lines predate the latest edits to this patch.

diff --git a/pandas-stubs/io/stata.pyi b/pandas-stubs/io/stata.pyi
index 42342928..2b7de0a6 100644
--- a/pandas-stubs/io/stata.pyi
+++ b/pandas-stubs/io/stata.pyi
@@ -6,6 +6,7 @@ from typing import (
     Hashable,
     Literal,
     Sequence,
+    overload,
 )
 
 import numpy as np
@@ -22,6 +23,8 @@ from pandas._typing import (
     WriteBuffer,
 )
 
+# iterator=True given by keyword -> StataReader.
+@overload
 def read_stata(
     path: FilePath | ReadBuffer[bytes],
     convert_dates: bool = ...,
@@ -32,57 +35,81 @@ def read_stata(
     columns: list[HashableT] | None = ...,
     order_categoricals: bool = ...,
     chunksize: int | None = ...,
-    iterator: bool = ...,
+    *,
+    iterator: Literal[True],
     compression: CompressionOptions = ...,
     storage_options: StorageOptions = ...,
-) -> DataFrame | StataReader: ...
-
-stata_epoch: datetime.datetime = ...
-excessive_string_length_error: str
+) -> StataReader: ...
+# iterator=True given positionally (every earlier argument supplied) -> StataReader.
+@overload
+def read_stata(
+    path: FilePath | ReadBuffer[bytes],
+    convert_dates: bool,
+    convert_categoricals: bool,
+    index_col: str | None,
+    convert_missing: bool,
+    preserve_dtypes: bool,
+    columns: list[HashableT] | None,
+    order_categoricals: bool,
+    chunksize: int | None,
+    iterator: Literal[True],
+    compression: CompressionOptions = ...,
+    storage_options: StorageOptions = ...,
+) -> StataReader: ...
+# Integer chunksize given by keyword -> StataReader, whatever iterator is.
+@overload
+def read_stata(
+    path: FilePath | ReadBuffer[bytes],
+    convert_dates: bool = ...,
+    convert_categoricals: bool = ...,
+    index_col: str | None = ...,
+    convert_missing: bool = ...,
+    preserve_dtypes: bool = ...,
+    columns: list[HashableT] | None = ...,
+    order_categoricals: bool = ...,
+    *,
+    chunksize: int,
+    iterator: bool = ...,
+    compression: CompressionOptions = ...,
+    storage_options: StorageOptions = ...,
+) -> StataReader: ...
+# Integer chunksize given positionally -> StataReader, whatever iterator is.
+@overload
+def read_stata(
+    path: FilePath | ReadBuffer[bytes],
+    convert_dates: bool,
+    convert_categoricals: bool,
+    index_col: str | None,
+    convert_missing: bool,
+    preserve_dtypes: bool,
+    columns: list[HashableT] | None,
+    order_categoricals: bool,
+    chunksize: int,
+    iterator: bool = ...,
+    compression: CompressionOptions = ...,
+    storage_options: StorageOptions = ...,
+) -> StataReader: ...
+# Neither chunksize nor iterator=True -> DataFrame.
+@overload
+def read_stata(
+    path: FilePath | ReadBuffer[bytes],
+    convert_dates: bool = ...,
+    convert_categoricals: bool = ...,
+    index_col: str | None = ...,
+    convert_missing: bool = ...,
+    preserve_dtypes: bool = ...,
+    columns: list[HashableT] | None = ...,
+    order_categoricals: bool = ...,
+    chunksize: None = ...,
+    iterator: Literal[False] = ...,
+    compression: CompressionOptions = ...,
+    storage_options: StorageOptions = ...,
+) -> DataFrame: ...
 
 class PossiblePrecisionLoss(Warning): ...
-
-precision_loss_doc: str
-
 class ValueLabelTypeMismatch(Warning): ...
-
-value_label_mismatch_doc: str
-
 class InvalidColumnName(Warning): ...
 
-invalid_name_doc: str
-
-class StataValueLabel:
-    labname: Hashable = ...
-    value_labels: list[tuple[float, str]] = ...
-    text_len: int = ...
-    off: npt.NDArray[np.int32] = ...
-    val: npt.NDArray[np.int32] = ...
-    txt: list[bytes] = ...
-    n: int = ...
-    len: int = ...
-    def __init__(
-        self, catarray: pd.Series, encoding: Literal["latin-1", "utf-8"] = ...
-    ) -> None: ...
-    def generate_value_label(self, byteorder: str) -> bytes: ...
-
-class StataMissingValue:
-    MISSING_VALUES: dict[float, str] = ...
-    bases: tuple[int, int, int] = ...
-    float32_base: bytes = ...
-    increment: int = ...
-    int_value: int = ...
-    float64_base: bytes = ...
-    BASE_MISSING_VALUES: dict[str, int] = ...
-    def __init__(self, value: float) -> None: ...
-    def __eq__(self, other: object) -> bool: ...
-    @property
-    def string(self) -> str: ...
-    @property
-    def value(self) -> float: ...
-    @classmethod
-    def get_base_missing_value(cls, dtype): ...
-
 class StataParser:
     DTYPE_MAP: dict[int, np.dtype] = ...
     DTYPE_MAP_XML: dict[int, np.dtype] = ...
@@ -160,19 +187,6 @@ class StataWriter(StataParser):
     ) -> None: ...
     def write_file(self) -> None: ...
 
-class StataStrLWriter:
-    df: DataFrame = ...
-    columns: Sequence[str] = ...
-    def __init__(
-        self,
-        df: DataFrame,
-        columns: Sequence[str],
-        version: int = ...,
-        byteorder: str | None = ...,
-    ) -> None: ...
-    def generate_table(self) -> tuple[dict[str, tuple[int, int]], DataFrame]: ...
-    def generate_blob(self, gso_table: dict[str, tuple[int, int]]) -> bytes: ...
-
 class StataWriter117(StataWriter):
     def __init__(
         self,
diff --git a/tests/test_io.py b/tests/test_io.py
new file mode 100644
index 00000000..1837c96d
--- /dev/null
+++ b/tests/test_io.py
@@ -0,0 +1,113 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from pathlib import Path
+import tempfile
+from typing import (
+    IO,
+    Any,
+)
+import uuid
+
+import pandas as pd
+from pandas import DataFrame
+from typing_extensions import assert_type
+
+from tests import check
+
+from pandas.io.stata import (
+    StataReader,
+    read_stata,
+)
+
+DF = DataFrame({"a": [1, 2, 3], "b": [0.0, 0.0, 0.0]})
+
+
+@contextmanager
+def ensure_clean(filename=None, return_filelike: bool = False, **kwargs: Any):
+    """
+    Yield a temporary path (or open handle) whose file is removed on exit.
+
+    This implementation does not use tempfile.mkstemp to avoid having a file
+    handle. If the code using the returned path wants to delete the file itself,
+    Windows requires that no program has a file handle to it.
+
+    Parameters
+    ----------
+    filename : str (optional)
+        Suffix appended to the random name of the created file.
+    return_filelike : bool (default False)
+        If True, yield an open file object instead of the path string; it is
+        closed on exit.
+    **kwargs
+        Additional keywords are passed to open(); ``mode`` defaults to "w+b".
+
+    Yields
+    ------
+    str or IO
+        The path of the temporary file, or the open handle when
+        ``return_filelike`` is True.
+    """
+    folder = Path(tempfile.gettempdir())
+
+    if filename is None:
+        filename = ""
+    filename = str(uuid.uuid4()) + filename
+    path = folder / filename
+
+    path.touch()
+
+    handle_or_str: str | IO = str(path)
+    try:
+        # Open inside the try block so the touched file is still removed when
+        # open() itself fails (e.g. on an invalid mode passed via kwargs).
+        if return_filelike:
+            kwargs.setdefault("mode", "w+b")
+            handle_or_str = open(path, **kwargs)
+        yield handle_or_str
+    finally:
+        if not isinstance(handle_or_str, str):
+            handle_or_str.close()
+        if path.is_file():
+            path.unlink()
+
+
+def test_read_stata_df():
+    """Reading without iterator or chunksize is typed as DataFrame."""
+    with ensure_clean() as path:
+        DF.to_stata(path)
+        check(assert_type(read_stata(path), pd.DataFrame), pd.DataFrame)
+
+
+def test_read_stata_iterator_positional():
+    """iterator=True passed positionally is typed as StataReader."""
+    with ensure_clean() as path:
+        str_path = str(path)
+        DF.to_stata(str_path)
+        check(
+            assert_type(
+                read_stata(
+                    str_path, False, False, None, False, False, None, False, 2, True
+                ),
+                StataReader,
+            ),
+            StataReader,
+        )
+
+
+def test_read_stata_iterator():
+    """iterator=True passed by keyword is typed as StataReader."""
+    with ensure_clean() as path:
+        str_path = str(path)
+        DF.to_stata(str_path)
+        check(
+            assert_type(read_stata(str_path, iterator=True), StataReader), StataReader
+        )
+
+
+def test_read_stata_chunksize():
+    """An integer chunksize without iterator is typed as StataReader."""
+    with ensure_clean() as path:
+        str_path = str(path)
+        DF.to_stata(str_path)
+        check(assert_type(read_stata(str_path, chunksize=2), StataReader), StataReader)