From 221bea8ce71a45c2eb049cab82d519aa03e62861 Mon Sep 17 00:00:00 2001 From: Rui He <118280419+ruihe774@users.noreply.github.com> Date: Fri, 5 Jul 2024 20:41:25 +0800 Subject: [PATCH] fix(python): Fix handling of TextIOWrapper in write_csv (#17328) --- .github/workflows/test-python.yml | 1 + py-polars/polars/dataframe/frame.py | 5 +- py-polars/src/file.rs | 93 ++++++++++++++++------------- py-polars/tests/unit/io/test_csv.py | 26 ++++++++ 4 files changed, 81 insertions(+), 44 deletions(-) diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index 10a4adc54615..6c8659db8fab 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -25,6 +25,7 @@ concurrency: env: RUSTFLAGS: -C debuginfo=0 # Do not produce debug symbols to keep memory usage down RUST_BACKTRACE: 1 + PYTHONUTF8: 1 defaults: run: diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index ceb19ac14a3f..e1ad0acb5d52 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -7,7 +7,7 @@ import random from collections import defaultdict from collections.abc import Sized -from io import BytesIO, StringIO, TextIOWrapper +from io import BytesIO, StringIO from operator import itemgetter from pathlib import Path from typing import ( @@ -24,7 +24,6 @@ NoReturn, Sequence, TypeVar, - cast, get_args, overload, ) @@ -2695,8 +2694,6 @@ def write_csv( should_return_buffer = True elif isinstance(file, (str, os.PathLike)): file = normalize_filepath(file) - elif isinstance(file, TextIOWrapper): - file = cast(TextIOWrapper, file.buffer) self._df.write_csv( file, diff --git a/py-polars/src/file.rs b/py-polars/src/file.rs index 56ff242f06fe..adee1d3125f3 100644 --- a/py-polars/src/file.rs +++ b/py-polars/src/file.rs @@ -51,36 +51,34 @@ impl PyFileLikeObject { Cursor::new(buf) } - /// Same as `PyFileLikeObject::new`, but validates that the underlying + /// Validates that the underlying /// python object has a `read`, `write`, and `seek` methods in respect to parameters. /// Will return a `TypeError` if object does not have `read`, `seek`, and `write` methods. - pub fn with_requirements( - object: PyObject, + pub fn ensure_requirements( + object: &Bound, read: bool, write: bool, seek: bool, - ) -> PyResult { - Python::with_gil(|py| { - if read && object.getattr(py, "read").is_err() { - return Err(PyErr::new::( - "Object does not have a .read() method.", - )); - } + ) -> PyResult<()> { + if read && object.getattr("read").is_err() { + return Err(PyErr::new::( + "Object does not have a .read() method.", + )); + } - if seek && object.getattr(py, "seek").is_err() { - return Err(PyErr::new::( - "Object does not have a .seek() method.", - )); - } + if seek && object.getattr("seek").is_err() { + return Err(PyErr::new::( + "Object does not have a .seek() method.", + )); + } - if write && object.getattr(py, "write").is_err() { - return Err(PyErr::new::( - "Object does not have a .write() method.", - )); - } + if write && object.getattr("write").is_err() { + return Err(PyErr::new::( + "Object does not have a .write() method.", + )); + } - Ok(PyFileLikeObject::new(object)) - }) + Ok(()) } } @@ -196,7 +194,7 @@ fn get_either_file_and_path( write: bool, ) -> PyResult<(EitherRustPythonFile, Option)> { Python::with_gil(|py| { - let py_f = py_f.bind(py); + let py_f = py_f.into_bound(py); if let Ok(s) = py_f.extract::>() { let file_path = std::path::Path::new(&*s); let file_path = resolve_homedir(file_path); @@ -208,6 +206,15 @@ fn get_either_file_and_path( Ok((EitherRustPythonFile::Rust(f), Some(file_path))) } else { let io = py.import_bound("io").unwrap(); + let is_utf8_encoding = |py_f: &Bound| -> PyResult { + let encoding = py_f.getattr("encoding")?; + let encoding = encoding.extract::>()?; + Ok(encoding.eq_ignore_ascii_case("utf-8") || encoding.eq_ignore_ascii_case("utf8")) + }; + let flush_file = |py_f: &Bound| -> PyResult<()> { + py_f.getattr("flush")?.call0()?; + Ok(()) + }; #[cfg(target_family = "unix")] if let Some(fd) = ((py_f.is_exact_instance(&io.getattr("FileIO").unwrap()) || py_f.is_exact_instance(&io.getattr("BufferedReader").unwrap()) @@ -215,22 +222,8 @@ fn get_either_file_and_path( || py_f.is_exact_instance(&io.getattr("BufferedRandom").unwrap()) || py_f.is_exact_instance(&io.getattr("BufferedRWPair").unwrap()) || (py_f.is_exact_instance(&io.getattr("TextIOWrapper").unwrap()) - && py_f - .getattr("encoding") - .ok() - .filter(|encoding| match encoding.extract::>() { - Ok(encoding) => { - encoding.eq_ignore_ascii_case("utf-8") - || encoding.eq_ignore_ascii_case("utf8") - }, - Err(_) => false, - }) - .is_some())) - && (!write - || py_f - .getattr("flush") - .and_then(|flush| flush.call0()) - .is_ok())) + && is_utf8_encoding(&py_f)?)) + && (!write || flush_file(&py_f).is_ok())) .then(|| { py_f.getattr("fileno") .and_then(|fileno| fileno.call0()) @@ -256,7 +249,27 @@ fn get_either_file_and_path( Ensure you pass a path to the file instead of a python file object when possible for best \ performance."); } - let f = PyFileLikeObject::with_requirements(py_f.to_object(py), !write, write, !write)?; + // Unwrap TextIOWrapper + // Allow subclasses to allow things like pytest.capture.CaptureIO + let py_f = if py_f + .is_instance(&io.getattr("TextIOWrapper").unwrap()) + .unwrap_or_default() + { + if !is_utf8_encoding(&py_f)? { + return Err(PyPolarsErr::from( + polars_err!(InvalidOperation: "file encoding is not UTF-8"), + ) + .into()); + } + if write { + flush_file(&py_f)?; + } + py_f.getattr("buffer")? + } else { + py_f + }; + PyFileLikeObject::ensure_requirements(&py_f, !write, write, !write)?; + let f = PyFileLikeObject::new(py_f.to_object(py)); Ok((EitherRustPythonFile::Py(f), None)) } }) diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index 113d0cedc89c..94dc428a1223 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -2221,3 +2221,29 @@ def test_projection_applied_on_file_with_no_rows_16606(tmp_path: Path) -> None: out = pl.scan_csv(path).select(columns).collect().columns assert out == columns + + +@pytest.mark.write_disk() +def test_write_csv_to_dangling_file_17328( + df_no_lists: pl.DataFrame, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + df_no_lists.write_csv((tmp_path / "dangling.csv").open("w")) + + +def test_write_csv_raise_on_non_utf8_17328( + df_no_lists: pl.DataFrame, tmp_path: Path +) -> None: + tmp_path.mkdir(exist_ok=True) + with pytest.raises(InvalidOperationError, match="file encoding is not UTF-8"): + df_no_lists.write_csv((tmp_path / "dangling.csv").open("w", encoding="gbk")) + + +@pytest.mark.write_disk() +def test_write_csv_appending_17328(tmp_path: Path) -> None: + tmp_path.mkdir(exist_ok=True) + with (tmp_path / "append.csv").open("w") as f: + f.write("# test\n") + pl.DataFrame({"col": ["value"]}).write_csv(f) + with (tmp_path / "append.csv").open("r") as f: + assert f.read() == "# test\ncol\nvalue\n"