From 3eb13f750a5cc9eb91916fa2ddb7685c3b9a88c3 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Wed, 24 Jul 2024 18:16:08 +0400 Subject: [PATCH] fix(python): Fix bool/string usage for "column_totals" in `write_excel` --- py-polars/polars/_typing.py | 2 +- .../polars/io/spreadsheet/_write_utils.py | 27 ++++++++++--------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/py-polars/polars/_typing.py b/py-polars/polars/_typing.py index 1c16e21eb637..2e9ab8f3fad1 100644 --- a/py-polars/polars/_typing.py +++ b/py-polars/polars/_typing.py @@ -193,7 +193,7 @@ ] ColumnTotalsDefinition: TypeAlias = Union[ # dict of colname(s) to str, a collection of str, or a boolean - Mapping[Union[str, Collection[str]], str], + Mapping[Union[ColumnNameOrSelector, Tuple[ColumnNameOrSelector]], str], Sequence[str], bool, ] diff --git a/py-polars/polars/io/spreadsheet/_write_utils.py b/py-polars/polars/io/spreadsheet/_write_utils.py index 505cb4c9365f..ab97658aa206 100644 --- a/py-polars/polars/io/spreadsheet/_write_utils.py +++ b/py-polars/polars/io/spreadsheet/_write_utils.py @@ -17,7 +17,7 @@ from polars.datatypes.group import FLOAT_DTYPES, INTEGER_DTYPES from polars.dependencies import json from polars.exceptions import DuplicateError -from polars.selectors import _expand_selector_dicts, _expand_selectors +from polars.selectors import _expand_selector_dicts, _expand_selectors, numeric if TYPE_CHECKING: from typing import Literal @@ -346,24 +346,30 @@ def _map_str(s: Series) -> Series: if cast_cols: df = df.with_columns(cast_cols) + # expand/normalise column totals + if column_totals is True: + column_totals = {numeric(): "sum"} + elif isinstance(column_totals, str): + column_totals = {numeric(): column_totals.lower()} + column_totals = _unpack_multi_column_dict( # type: ignore[assignment] _expand_selector_dicts(df, column_totals, expand_keys=True, expand_values=False) if isinstance(column_totals, dict) else _expand_selectors(df, column_totals) ) - column_formats = _unpack_multi_column_dict( # type: ignore[assignment] - _expand_selector_dicts( - df, column_formats, expand_keys=True, expand_values=False, tuple_keys=True - ) - ) - - # normalise column totals column_total_funcs = ( {col: "sum" for col in column_totals} if isinstance(column_totals, Sequence) else (column_totals.copy() if isinstance(column_totals, dict) else {}) ) + # expand/normalise column formats + column_formats = _unpack_multi_column_dict( # type: ignore[assignment] + _expand_selector_dicts( + df, column_formats, expand_keys=True, expand_values=False, tuple_keys=True + ) + ) + # normalise row totals if not row_totals: row_total_funcs = {} @@ -444,11 +450,6 @@ def _map_str(s: Series) -> Series: if base_type in dtype_formats: fmt = dtype_formats.get(tp, dtype_formats[base_type]) column_formats.setdefault(col, fmt) - if base_type.is_numeric(): - if column_totals is True: - column_total_funcs.setdefault(col, "sum") - elif isinstance(column_totals, str): - column_total_funcs.setdefault(col, column_totals.lower()) if col not in column_formats: column_formats[col] = fmt_default