From 6253355d58eb2df69e16693840caa743c884089f Mon Sep 17 00:00:00 2001
From: alexander-beedie
Date: Sat, 27 Jan 2024 15:28:47 +0400
Subject: [PATCH] fix(python): multiple `read_excel` updates

---
 crates/polars-lazy/src/frame/mod.rs          |   7 +-
 py-polars/polars/io/spreadsheet/functions.py | 131 +++++++++++--------
 py-polars/tests/unit/io/test_spreadsheet.py  | 110 +++++++++++-----
 py-polars/tests/unit/sql/test_temporal.py    |  26 +++-
 4 files changed, 183 insertions(+), 91 deletions(-)

diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs
index 278590eacf34..342ced25c82d 100644
--- a/crates/polars-lazy/src/frame/mod.rs
+++ b/crates/polars-lazy/src/frame/mod.rs
@@ -502,7 +502,12 @@ impl LazyFrame {
             }
         })
         .collect();
-        self.with_columns(cast_cols)
+
+        if cast_cols.is_empty() {
+            self.clone()
+        } else {
+            self.with_columns(cast_cols)
+        }
     }
 
     /// Cast all frame columns to the given dtype, resulting in a new LazyFrame
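Reviewer note: the Rust change means `LazyFrame::cast` now returns the plan unchanged when the
dtype map produces no cast expressions, instead of wrapping it in an empty `with_columns` node.
A minimal Python-side sketch of the observable effect (assuming a build that includes this
patch; the frame contents are illustrative only):

    import polars as pl

    lf = pl.LazyFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

    # an empty mapping yields no cast expressions, so the plan should come
    # back without an extra (empty) WITH_COLUMNS node
    print(lf.cast({}).explain())

    # a non-empty mapping still adds the expected cast
    print(lf.cast({"a": pl.Float64}).schema)  # {'a': Float64, 'b': String}
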
diff --git a/py-polars/polars/io/spreadsheet/functions.py b/py-polars/polars/io/spreadsheet/functions.py
index c731a207a94e..3c6767b1f5eb 100644
--- a/py-polars/polars/io/spreadsheet/functions.py
+++ b/py-polars/polars/io/spreadsheet/functions.py
@@ -10,7 +10,15 @@
 
 import polars._reexport as pl
 from polars import functions as F
-from polars.datatypes import FLOAT_DTYPES, Date, Datetime, Int64, Null, String
+from polars.datatypes import (
+    FLOAT_DTYPES,
+    NUMERIC_DTYPES,
+    Date,
+    Datetime,
+    Int64,
+    Null,
+    String,
+)
 from polars.dependencies import import_optional
 from polars.exceptions import NoDataError, ParameterCollisionError
 from polars.io._utils import _looks_like_url, _process_file_url
@@ -32,7 +40,7 @@ def read_excel(
     sheet_name: str,
     engine: ExcelSpreadsheetEngine | None = ...,
     engine_options: dict[str, Any] | None = ...,
-    read_csv_options: dict[str, Any] | None = ...,
+    read_options: dict[str, Any] | None = ...,
     schema_overrides: SchemaDict | None = ...,
     raise_if_empty: bool = ...,
 ) -> pl.DataFrame:
@@ -47,7 +55,7 @@ def read_excel(
     sheet_name: None = ...,
     engine: ExcelSpreadsheetEngine | None = ...,
     engine_options: dict[str, Any] | None = ...,
-    read_csv_options: dict[str, Any] | None = ...,
+    read_options: dict[str, Any] | None = ...,
     schema_overrides: SchemaDict | None = ...,
     raise_if_empty: bool = ...,
 ) -> pl.DataFrame:
@@ -62,7 +70,7 @@ def read_excel(
     sheet_name: str,
     engine: ExcelSpreadsheetEngine | None = ...,
     engine_options: dict[str, Any] | None = ...,
-    read_csv_options: dict[str, Any] | None = ...,
+    read_options: dict[str, Any] | None = ...,
     schema_overrides: SchemaDict | None = ...,
     raise_if_empty: bool = ...,
 ) -> NoReturn:
@@ -79,7 +87,7 @@ def read_excel(
     sheet_name: None = ...,
     engine: ExcelSpreadsheetEngine | None = ...,
     engine_options: dict[str, Any] | None = ...,
-    read_csv_options: dict[str, Any] | None = ...,
+    read_options: dict[str, Any] | None = ...,
     schema_overrides: SchemaDict | None = ...,
     raise_if_empty: bool = ...,
 ) -> dict[str, pl.DataFrame]:
@@ -94,7 +102,7 @@ def read_excel(
     sheet_name: None = ...,
     engine: ExcelSpreadsheetEngine | None = ...,
     engine_options: dict[str, Any] | None = ...,
-    read_csv_options: dict[str, Any] | None = ...,
+    read_options: dict[str, Any] | None = ...,
     schema_overrides: SchemaDict | None = ...,
     raise_if_empty: bool = ...,
 ) -> pl.DataFrame:
@@ -109,7 +117,7 @@ def read_excel(
     sheet_name: list[str] | tuple[str],
     engine: ExcelSpreadsheetEngine | None = ...,
     engine_options: dict[str, Any] | None = ...,
-    read_csv_options: dict[str, Any] | None = ...,
+    read_options: dict[str, Any] | None = ...,
     schema_overrides: SchemaDict | None = ...,
     raise_if_empty: bool = ...,
 ) -> dict[str, pl.DataFrame]:
@@ -117,6 +125,7 @@ def read_excel(
 
 
 @deprecate_renamed_parameter("xlsx2csv_options", "engine_options", version="0.20.6")
+@deprecate_renamed_parameter("read_csv_options", "read_options", version="0.20.7")
 def read_excel(
     source: str | BytesIO | Path | BinaryIO | bytes,
     *,
@@ -124,7 +133,7 @@ def read_excel(
     sheet_name: str | list[str] | tuple[str] | None = None,
     engine: ExcelSpreadsheetEngine | None = None,
     engine_options: dict[str, Any] | None = None,
-    read_csv_options: dict[str, Any] | None = None,
+    read_options: dict[str, Any] | None = None,
     schema_overrides: SchemaDict | None = None,
     raise_if_empty: bool = True,
 ) -> pl.DataFrame | dict[str, pl.DataFrame]:
@@ -157,7 +166,7 @@ def read_excel(
 
         * "xlsx2csv": converts the data to an in-memory CSV before using the native
           polars `read_csv` method to parse the result. You can pass `engine_options`
-          and `read_csv_options` to refine the conversion.
+          and `read_options` to refine the conversion.
         * "openpyxl": this engine is significantly slower than `xlsx2csv` but supports
           additional automatic type inference; potentially useful if you are otherwise
           unable to parse your sheet with the (default) `xlsx2csv` engine in
@@ -170,13 +179,19 @@ def read_excel(
         other options, using the `fastexcel` module to bind calamine.
 
     engine_options
-        Extra options passed to the underlying engine's Workbook-reading constructor.
-        For example, if using `xlsx2csv` you could pass `{"skip_empty_lines": True}`.
-    read_csv_options
-        Extra options passed to :func:`read_csv` for parsing the CSV file returned by
-        `xlsx2csv.Xlsx2csv().convert()`. This option is *only* applicable when using
-        the `xlsx2csv` engine. For example, you could pass ``{"has_header": False,
-        "new_columns": ["a", "b", "c"], "infer_schema_length": None}``
+        Additional options passed to the underlying engine's primary parsing
+        constructor (given below), if supported:
+
+        * "xlsx2csv": `Xlsx2csv`
+        * "openpyxl": `load_workbook`
+        * "pyxlsb": `open_workbook`
+        * "calamine": `n/a`
+
+    read_options
+        Extra options passed to the function that reads the sheet data (for example,
+        the `read_csv` method if using the "xlsx2csv" engine, to which you could
+        pass ``{"infer_schema_length": None}``, or the `load_sheet_by_name` method
+        if using the "calamine" engine).
     schema_overrides
         Support type specification or override of one or more columns.
     raise_if_empty
@@ -187,7 +202,7 @@ def read_excel(
     -----
     When using the default `xlsx2csv` engine the target Excel sheet is first converted
     to CSV using `xlsx2csv.Xlsx2csv(source).convert()` and then parsed with Polars'
-    :func:`read_csv` function. You can pass additional options to `read_csv_options`
+    :func:`read_csv` function. You can pass additional options to `read_options`
     to influence this part of the parsing pipeline.
 
     Returns
@@ -209,13 +224,13 @@ def read_excel(
     Read table data from sheet 3 in an Excel workbook as a DataFrame while skipping
     empty lines in the sheet. As sheet 3 does not have a header row and the default
     engine is `xlsx2csv` you can pass the necessary additional settings for this
-    to the "read_csv_options" parameter; these will be passed to :func:`read_csv`.
+    to the "read_options" parameter; these will be passed to :func:`read_csv`.
 
     >>> pl.read_excel(
     ...     source="test.xlsx",
     ...     sheet_id=3,
     ...     engine_options={"skip_empty_lines": True},
-    ...     read_csv_options={"has_header": False, "new_columns": ["a", "b", "c"]},
+    ...     read_options={"has_header": False, "new_columns": ["a", "b", "c"]},
     ... )  # doctest: +SKIP
 
     If the correct datatypes can't be determined you can use `schema_overrides` and/or
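Reviewer note: because `read_options` is now engine-dependent, the header-less example above
only applies to "xlsx2csv"; with "calamine" the equivalent options go to fastexcel's
`load_sheet_by_name` instead. A sketch of the two spellings, mirroring the option dicts used
by the round-trip test later in this patch ("test.xlsx" is illustrative):

    import polars as pl

    # xlsx2csv: options are forwarded to pl.read_csv
    df1 = pl.read_excel(
        "test.xlsx",
        engine="xlsx2csv",
        read_options={"has_header": False, "new_columns": ["a", "b", "c"]},
    )

    # calamine: options are forwarded to fastexcel's load_sheet_by_name
    df2 = pl.read_excel(
        "test.xlsx",
        engine="calamine",
        read_options={"header_row": None, "column_names": ["a", "b", "c"]},
    )
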
read_options={"has_header": False, "new_columns": ["a", "b", "c"]}, ... ) # doctest: +SKIP If the correct datatypes can't be determined you can use `schema_overrides` and/or @@ -227,14 +242,14 @@ def read_excel( >>> pl.read_excel( ... source="test.xlsx", - ... read_csv_options={"infer_schema_length": 1000}, + ... read_options={"infer_schema_length": 1000}, ... schema_overrides={"dt": pl.Date}, ... ) # doctest: +SKIP The `openpyxl` package can also be used to parse Excel data; it has slightly better default type detection, but is slower than `xlsx2csv`. If you have a sheet that is better read using this package you can set the engine as "openpyxl" (if you - use this engine then `read_csv_options` cannot be set). + use this engine then `read_options` cannot be set). >>> pl.read_excel( ... source="test.xlsx", @@ -242,17 +257,13 @@ def read_excel( ... schema_overrides={"dt": pl.Datetime, "value": pl.Int32}, ... ) # doctest: +SKIP """ - if engine and engine != "xlsx2csv" and read_csv_options: - msg = f"cannot specify `read_csv_options` when engine={engine!r}" - raise ValueError(msg) - return _read_spreadsheet( sheet_id, sheet_name, source=source, engine=engine, engine_options=engine_options, - read_csv_options=read_csv_options, + read_options=read_options, schema_overrides=schema_overrides, raise_if_empty=raise_if_empty, ) @@ -390,7 +401,7 @@ def read_ods( source=source, engine="ods", engine_options={}, - read_csv_options={}, + read_options={}, schema_overrides=schema_overrides, raise_if_empty=raise_if_empty, ) @@ -402,7 +413,7 @@ def _read_spreadsheet( source: str | BytesIO | Path | BinaryIO | bytes, engine: ExcelSpreadsheetEngine | Literal["ods"] | None, engine_options: dict[str, Any] | None = None, - read_csv_options: dict[str, Any] | None = None, + read_options: dict[str, Any] | None = None, schema_overrides: SchemaDict | None = None, *, raise_if_empty: bool = True, @@ -429,8 +440,8 @@ def _read_spreadsheet( name: reader_fn( parser=parser, sheet_name=name, - read_csv_options=read_csv_options, schema_overrides=schema_overrides, + read_options=(read_options or {}), raise_if_empty=raise_if_empty, ) for name in sheet_names @@ -571,7 +582,7 @@ def _initialise_spreadsheet_parser( def _csv_buffer_to_frame( csv: StringIO, separator: str, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, @@ -587,23 +598,23 @@ def _csv_buffer_to_frame( raise NoDataError(msg) return pl.DataFrame() - if read_csv_options is None: - read_csv_options = {} + if read_options is None: + read_options = {} if schema_overrides: - if (csv_dtypes := read_csv_options.get("dtypes", {})) and set( + if (csv_dtypes := read_options.get("dtypes", {})) and set( csv_dtypes ).intersection(schema_overrides): - msg = "cannot specify columns in both `schema_overrides` and `read_csv_options['dtypes']`" + msg = "cannot specify columns in both `schema_overrides` and `read_options['dtypes']`" raise ParameterCollisionError(msg) - read_csv_options = read_csv_options.copy() - read_csv_options["dtypes"] = {**csv_dtypes, **schema_overrides} + read_options = read_options.copy() + read_options["dtypes"] = {**csv_dtypes, **schema_overrides} # otherwise rewind the buffer and parse as csv csv.seek(0) df = read_csv( csv, separator=separator, - **read_csv_options, + **read_options, ) return _drop_null_data(df, raise_if_empty=raise_if_empty) @@ -616,7 +627,14 @@ def _drop_null_data(df: pl.DataFrame, *, raise_if_empty: bool) -> pl.DataFrame: # will be named as 
"_duplicated_{n}" (or "__UNNAMED__{n}" from calamine) if col_name == "" or re.match(r"(_duplicated_|__UNNAMED__)\d+$", col_name): col = df[col_name] - if col.dtype == Null or col.null_count() == len(df): + if ( + col.dtype == Null + or col.null_count() == len(df) + or ( + col.dtype in NUMERIC_DTYPES + and col.replace(0, None).null_count() == len(df) + ) + ): null_cols.append(col_name) if null_cols: df = df.drop(*null_cols) @@ -637,7 +655,7 @@ def _drop_null_data(df: pl.DataFrame, *, raise_if_empty: bool) -> pl.DataFrame: def _read_spreadsheet_ods( parser: Any, sheet_name: str | None, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, @@ -705,7 +723,7 @@ def _read_spreadsheet_ods( def _read_spreadsheet_openpyxl( parser: Any, sheet_name: str | None, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, @@ -753,12 +771,12 @@ def _read_spreadsheet_openpyxl( def _read_spreadsheet_calamine( parser: Any, sheet_name: str | None, - read_csv_options: dict[str, Any] | None, + read_options: dict[str, Any], schema_overrides: SchemaDict | None, *, raise_if_empty: bool, ) -> pl.DataFrame: - ws = parser.load_sheet_by_name(sheet_name) + ws = parser.load_sheet_by_name(sheet_name, **read_options) df = ws.to_polars() if schema_overrides: @@ -766,26 +784,27 @@ def _read_spreadsheet_calamine( df = _drop_null_data(df, raise_if_empty=raise_if_empty) - # calamine may read integer data as float; cast back to int where possible. - # do a similar downcast check for datetime -> date dtypes. + # refine dtypes type_checks = [] for c, dtype in df.schema.items(): + # may read integer data as float; cast back to int where possible. if dtype in FLOAT_DTYPES: - check_cast = [F.col(c).floor().eq_missing(F.col(c)), F.col(c).cast(Int64)] + check_cast = [F.col(c).floor().eq(F.col(c)), F.col(c).cast(Int64)] type_checks.append(check_cast) + # do a similar check for datetime columns that have only 00:00:00 times. elif dtype == Datetime: check_cast = [ - F.col(c).drop_nulls().dt.time().eq_missing(time(0, 0, 0)), + F.col(c).dt.time().eq(time(0, 0, 0)), F.col(c).cast(Date), ] type_checks.append(check_cast) if type_checks: - apply_downcast = df.select([d[0] for d in type_checks]).row(0) - - # do a similar check for datetime columns that have only 00:00:00 times. 
@@ -753,12 +771,12 @@ def _read_spreadsheet_calamine(
     parser: Any,
     sheet_name: str | None,
-    read_csv_options: dict[str, Any] | None,
+    read_options: dict[str, Any],
     schema_overrides: SchemaDict | None,
     *,
     raise_if_empty: bool,
 ) -> pl.DataFrame:
-    ws = parser.load_sheet_by_name(sheet_name)
+    ws = parser.load_sheet_by_name(sheet_name, **read_options)
     df = ws.to_polars()
 
     if schema_overrides:
@@ -766,26 +784,27 @@ def _read_spreadsheet_calamine(
 
     df = _drop_null_data(df, raise_if_empty=raise_if_empty)
 
-    # calamine may read integer data as float; cast back to int where possible.
-    # do a similar downcast check for datetime -> date dtypes.
+    # refine dtypes
     type_checks = []
     for c, dtype in df.schema.items():
+        # may read integer data as float; cast back to int where possible.
         if dtype in FLOAT_DTYPES:
-            check_cast = [F.col(c).floor().eq_missing(F.col(c)), F.col(c).cast(Int64)]
+            check_cast = [F.col(c).floor().eq(F.col(c)), F.col(c).cast(Int64)]
             type_checks.append(check_cast)
+        # do a similar check for datetime columns that have only 00:00:00 times.
         elif dtype == Datetime:
             check_cast = [
-                F.col(c).drop_nulls().dt.time().eq_missing(time(0, 0, 0)),
+                F.col(c).dt.time().eq(time(0, 0, 0)),
                 F.col(c).cast(Date),
             ]
             type_checks.append(check_cast)
 
     if type_checks:
-        apply_downcast = df.select([d[0] for d in type_checks]).row(0)
-
-        # do a similar check for datetime columns that have only 00:00:00 times.
+        apply_cast = df.select(
+            [d[0].all(ignore_nulls=True) for d in type_checks],
+        ).row(0)
         if downcast := [
-            cast for apply, (_, cast) in zip(apply_downcast, type_checks) if apply
+            cast for apply, (_, cast) in zip(apply_cast, type_checks) if apply
         ]:
             df = df.with_columns(*downcast)
 
     return df
@@ -795,7 +814,7 @@ def _read_spreadsheet_pyxlsb(
     parser: Any,
     sheet_name: str | None,
-    read_csv_options: dict[str, Any] | None,
+    read_options: dict[str, Any],
     schema_overrides: SchemaDict | None,
     *,
     raise_if_empty: bool,
@@ -850,7 +869,7 @@ def _read_spreadsheet_xlsx2csv(
     parser: Any,
     sheet_name: str | None,
-    read_csv_options: dict[str, Any] | None,
+    read_options: dict[str, Any],
     schema_overrides: SchemaDict | None,
     *,
     raise_if_empty: bool,
@@ -861,14 +880,14 @@ def _read_spreadsheet_xlsx2csv(
         outfile=csv_buffer,
         sheetname=sheet_name,
     )
-    if read_csv_options is None:
-        read_csv_options = {}
-    read_csv_options.setdefault("truncate_ragged_lines", True)
+    if read_options is None:
+        read_options = {}
+    read_options.setdefault("truncate_ragged_lines", True)
 
     return _csv_buffer_to_frame(
         csv_buffer,
         separator=",",
-        read_csv_options=read_csv_options,
+        read_options=read_options,
         schema_overrides=schema_overrides,
         raise_if_empty=raise_if_empty,
     )
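Reviewer note: the calamine post-processing above decides per column whether a lossless
downcast is possible, and switching to `.all(ignore_nulls=True)` is what lets columns
containing nulls (as exercised by the new test below) still qualify. A rough standalone
equivalent of the float-to-int branch (illustrative frame, not the engine code itself):

    import polars as pl

    df = pl.DataFrame({"n": [1.0, 2.0, None]})

    # True when every non-null value is integral; nulls no longer veto the cast
    all_integral = df.select(
        pl.col("n").floor().eq(pl.col("n")).all(ignore_nulls=True)
    ).item()
    if all_integral:
        df = df.with_columns(pl.col("n").cast(pl.Int64))
    print(df.schema)  # {'n': Int64}
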
diff --git a/py-polars/tests/unit/io/test_spreadsheet.py b/py-polars/tests/unit/io/test_spreadsheet.py
index 91ee0d634145..09ddffa06b75 100644
--- a/py-polars/tests/unit/io/test_spreadsheet.py
+++ b/py-polars/tests/unit/io/test_spreadsheet.py
@@ -410,7 +410,7 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N
     df2 = pl.read_excel(
         path_xlsx,
         sheet_name="test4",
-        read_csv_options={"dtypes": {"cardinality": pl.UInt16}},
+        read_options={"dtypes": {"cardinality": pl.UInt16}},
     ).drop_nulls()
     assert df2.schema["cardinality"] == pl.UInt16
     assert df2.schema["rows_by_key"] == pl.Float64
@@ -420,7 +420,7 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N
         path_xlsx,
         sheet_name="test4",
         schema_overrides={"cardinality": pl.UInt16},
-        read_csv_options={
+        read_options={
             "dtypes": {
                 "rows_by_key": pl.Float32,
                 "iter_groups": pl.Float32,
@@ -453,12 +453,12 @@ def test_schema_overrides(path_xlsx: Path, path_xlsb: Path, path_ods: Path) -> N
     )
 
     with pytest.raises(ParameterCollisionError):
-        # cannot specify 'cardinality' in both schema_overrides and read_csv_options
+        # cannot specify 'cardinality' in both schema_overrides and read_options
         pl.read_excel(
             path_xlsx,
             sheet_name="test4",
            schema_overrides={"cardinality": pl.UInt16},
-            read_csv_options={"dtypes": {"cardinality": pl.Int32}},
+            read_options={"dtypes": {"cardinality": pl.Int32}},
         )
 
     # read multiple sheets in conjunction with 'schema_overrides'
@@ -625,29 +625,40 @@ def test_excel_round_trip(write_params: dict[str, Any]) -> None:
             "val": [100.5, 55.0, -99.5],
         }
     )
-    header_opts = (
-        {}
-        if write_params.get("include_header", True)
-        else {"has_header": False, "new_columns": ["dtm", "str", "val"]}
-    )
-    fmt_strptime = "%Y-%m-%d"
-    if write_params.get("dtype_formats", {}).get(pl.Date) == "dd-mm-yyyy":
-        fmt_strptime = "%d-%m-%Y"
-
-    # write to an xlsx with polars, using various parameters...
-    xls = BytesIO()
-    _wb = df.write_excel(workbook=xls, worksheet="data", **write_params)
-
-    # ...and read it back again:
-    xldf = pl.read_excel(
-        xls,
-        sheet_name="data",
-        read_csv_options=header_opts,
-    )[:3]
-    xldf = xldf.select(xldf.columns[:3]).with_columns(
-        pl.col("dtm").str.strptime(pl.Date, fmt_strptime)
-    )
-    assert_frame_equal(df, xldf)
+    engine: ExcelSpreadsheetEngine
+    for engine in ("calamine", "xlsx2csv"):  # type: ignore[assignment]
+        # TODO: remove the skip when calamine supported on windows
+        if sys.platform == "win32" and engine == "calamine":
+            continue
+
+        table_params = (
+            {}
+            if write_params.get("include_header", True)
+            else (
+                {"has_header": False, "new_columns": ["dtm", "str", "val"]}
+                if engine == "xlsx2csv"
+                else {"header_row": None, "column_names": ["dtm", "str", "val"]}
+            )
+        )
+        fmt_strptime = "%Y-%m-%d"
+        if write_params.get("dtype_formats", {}).get(pl.Date) == "dd-mm-yyyy":
+            fmt_strptime = "%d-%m-%Y"
+
+        # write to an xlsx with polars, using various parameters...
+        xls = BytesIO()
+        _wb = df.write_excel(workbook=xls, worksheet="data", **write_params)
+
+        # ...and read it back again:
+        xldf = pl.read_excel(
+            xls,
+            sheet_name="data",
+            engine=engine,
+            read_options=table_params,
+        )[:3].select(df.columns[:3])
+        if engine == "xlsx2csv":
+            xldf = xldf.with_columns(pl.col("dtm").str.strptime(pl.Date, fmt_strptime))
+        assert_frame_equal(df, xldf)
 
 
 @pytest.mark.parametrize(
@@ -887,11 +898,44 @@ def test_excel_hidden_columns(
     assert_frame_equal(df, read_df)
 
 
-def test_invalid_engine_options() -> None:
-    # read_csv_options only applicable with 'xlsx2csv' engine
-    with pytest.raises(ValueError, match="cannot specify `read_csv_options`"):
-        pl.read_excel(
-            "",
-            engine="openpyxl",
-            read_csv_options={"sep": "\t"},
-        )
+@pytest.mark.parametrize(
+    "engine",
+    [
+        "xlsx2csv",
+        "openpyxl",
+        pytest.param(
+            "calamine",
+            marks=pytest.mark.skipif(
+                sys.platform == "win32",
+                reason="fastexcel not yet available on Windows",
+            ),
+        ),
+    ],
+)
+def test_excel_type_inference_with_nulls(engine: ExcelSpreadsheetEngine) -> None:
+    df = pl.DataFrame(
+        {
+            "a": [1, 2, None],
+            "b": [1.0, None, 3.5],
+            "c": ["x", None, "z"],
+            "d": [True, False, None],
+            "e": [date(2023, 1, 1), None, date(2023, 1, 4)],
+            "f": [
+                datetime(2023, 1, 1),
+                datetime(2000, 10, 10, 10, 10),
+                None,
+            ],
+        }
+    )
+    xls = BytesIO()
+    df.write_excel(xls)
+
+    read_df = pl.read_excel(
+        xls,
+        engine=engine,
+        schema_overrides={
+            "e": pl.Date,
+            "f": pl.Datetime("us"),
+        },
+    )
+    assert_frame_equal(df, read_df)
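Reviewer note: the new inference test boils down to a round trip like the following; without
the explicit overrides, the date/datetime columns are not guaranteed to come back with their
original dtypes on every engine (in-memory buffer and column name are illustrative):

    from datetime import date
    from io import BytesIO

    import polars as pl

    df = pl.DataFrame({"e": [date(2023, 1, 1), None, date(2023, 1, 4)]})
    xls = BytesIO()
    df.write_excel(xls)

    # pin the dtype so the round trip is lossless regardless of engine inference
    read_df = pl.read_excel(xls, schema_overrides={"e": pl.Date})
    print(read_df.schema)  # {'e': Date}
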
diff --git a/py-polars/tests/unit/sql/test_temporal.py b/py-polars/tests/unit/sql/test_temporal.py
index da8cdeff784b..77bf04b44fa5 100644
--- a/py-polars/tests/unit/sql/test_temporal.py
+++ b/py-polars/tests/unit/sql/test_temporal.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from datetime import date, datetime, time
-from typing import Any
+from typing import Any, Literal
 
 import pytest
 
@@ -32,6 +32,30 @@ def test_date() -> None:
     assert_frame_equal(result, expected)
 
 
+@pytest.mark.parametrize("time_unit", ["ms", "us", "ns"])
+def test_datetime_to_time(time_unit: Literal["ns", "us", "ms"]) -> None:
+    df = pl.DataFrame(
+        {
+            "dtm": [
+                datetime(2099, 12, 31, 23, 59, 59),
+                datetime(1999, 12, 31, 12, 30, 30),
+                datetime(1969, 12, 31, 1, 1, 1),
+                datetime(1899, 12, 31, 0, 0, 0),
+            ],
+        },
+        schema={"dtm": pl.Datetime(time_unit)},
+    )
+    with pl.SQLContext(df=df, eager_execution=True) as ctx:
+        result = ctx.execute("SELECT dtm::time as tm from df")["tm"].to_list()
+
+    assert result == [
+        time(23, 59, 59),
+        time(12, 30, 30),
+        time(1, 1, 1),
+        time(0, 0, 0),
+    ]
+
+
 @pytest.mark.parametrize(
     ("part", "dtype", "expected"),
     [