From 0a1e2ed0c318742b567da58ecbb3aba641fdeaae Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Fri, 20 Jan 2023 14:36:04 +0100 Subject: [PATCH] feat(rust, python): add strict parameter to decoding expressions --- .../src/chunked_array/binary/encoding.rs | 45 +++++++++++++------ .../src/chunked_array/strings/encoding.rs | 8 ++-- py-polars/polars/internals/expr/binary.py | 9 ++-- py-polars/polars/internals/expr/string.py | 9 ++-- py-polars/polars/internals/series/binary.py | 5 ++- py-polars/polars/internals/series/string.py | 5 ++- py-polars/src/lazy/dsl.rs | 16 +++---- py-polars/tests/unit/test_utf8.py | 13 ++++++ 8 files changed, 77 insertions(+), 33 deletions(-) diff --git a/polars/polars-core/src/chunked_array/binary/encoding.rs b/polars/polars-core/src/chunked_array/binary/encoding.rs index 740a850e7dce..3bae1e6ebdbc 100644 --- a/polars/polars-core/src/chunked_array/binary/encoding.rs +++ b/polars/polars-core/src/chunked_array/binary/encoding.rs @@ -1,3 +1,5 @@ +use std::borrow::Cow; + use base64::engine::general_purpose; use base64::Engine as _; use hex; @@ -5,12 +7,20 @@ use hex; use crate::prelude::*; impl BinaryChunked { - pub fn hex_decode(&self) -> PolarsResult { - self.try_apply(|s| { - let bytes = - hex::decode(s).map_err(|e| PolarsError::ComputeError(e.to_string().into()))?; - Ok(bytes.into()) - }) + pub fn hex_decode(&self, strict: bool) -> PolarsResult { + if strict { + self.try_apply(|s| { + let bytes = hex::decode(s).map_err(|_e| { + PolarsError::ComputeError( + "Invalid 'hex' encoding found. Try setting 'strict' to false to ignore." + .into(), + ) + })?; + Ok(bytes.into()) + }) + } else { + Ok(self.apply_on_opt(|opt_s| opt_s.and_then(|s| hex::decode(s).ok().map(Cow::Owned)))) + } } pub fn hex_encode(&self) -> Series { @@ -19,13 +29,22 @@ impl BinaryChunked { .unwrap() } - pub fn base64_decode(&self) -> PolarsResult { - self.try_apply(|s| { - let bytes = general_purpose::STANDARD - .decode(s) - .map_err(|e| PolarsError::ComputeError(e.to_string().into()))?; - Ok(bytes.into()) - }) + pub fn base64_decode(&self, strict: bool) -> PolarsResult { + if strict { + self.try_apply(|s| { + let bytes = general_purpose::STANDARD.decode(s).map_err(|_e| { + PolarsError::ComputeError( + "Invalid 'base64' encoding found. Try setting 'strict' to false to ignore." + .into(), + ) + })?; + Ok(bytes.into()) + }) + } else { + Ok(self.apply_on_opt(|opt_s| { + opt_s.and_then(|s| general_purpose::STANDARD.decode(s).ok().map(Cow::Owned)) + })) + } } pub fn base64_encode(&self) -> Series { diff --git a/polars/polars-core/src/chunked_array/strings/encoding.rs b/polars/polars-core/src/chunked_array/strings/encoding.rs index 93c1af545369..ca79ffc98418 100644 --- a/polars/polars-core/src/chunked_array/strings/encoding.rs +++ b/polars/polars-core/src/chunked_array/strings/encoding.rs @@ -11,10 +11,10 @@ impl Utf8Chunked { } #[cfg(feature = "binary_encoding")] - pub fn hex_decode(&self) -> PolarsResult { + pub fn hex_decode(&self, strict: bool) -> PolarsResult { self.cast_unchecked(&DataType::Binary)? .binary()? - .hex_decode() + .hex_decode(strict) } #[must_use] @@ -28,10 +28,10 @@ impl Utf8Chunked { } #[cfg(feature = "binary_encoding")] - pub fn base64_decode(&self) -> PolarsResult { + pub fn base64_decode(&self, strict: bool) -> PolarsResult { self.cast_unchecked(&DataType::Binary)? .binary()? - .base64_decode() + .base64_decode(strict) } #[must_use] diff --git a/py-polars/polars/internals/expr/binary.py b/py-polars/polars/internals/expr/binary.py index 4d2c12665fa7..ae63e6d82f0b 100644 --- a/py-polars/polars/internals/expr/binary.py +++ b/py-polars/polars/internals/expr/binary.py @@ -56,7 +56,7 @@ def starts_with(self, sub: bytes) -> pli.Expr: """ return pli.wrap_expr(self._pyexpr.binary_starts_with(sub)) - def decode(self, encoding: TransferEncoding) -> pli.Expr: + def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Expr: """ Decode a value using the provided encoding. @@ -64,12 +64,15 @@ def decode(self, encoding: TransferEncoding) -> pli.Expr: ---------- encoding : {'hex', 'base64'} The encoding to use. + strict + Raise an error if the underlying value cannot be decoded, + otherwise mask out with a null value. """ if encoding == "hex": - return pli.wrap_expr(self._pyexpr.binary_hex_decode()) + return pli.wrap_expr(self._pyexpr.binary_hex_decode(strict)) elif encoding == "base64": - return pli.wrap_expr(self._pyexpr.binary_base64_decode()) + return pli.wrap_expr(self._pyexpr.binary_base64_decode(strict)) else: raise ValueError( f"encoding must be one of {{'hex', 'base64'}}, got {encoding}" diff --git a/py-polars/polars/internals/expr/string.py b/py-polars/polars/internals/expr/string.py index 42832d9a1492..6a990bd1b90f 100644 --- a/py-polars/polars/internals/expr/string.py +++ b/py-polars/polars/internals/expr/string.py @@ -665,7 +665,7 @@ def json_path_match(self, json_path: str) -> pli.Expr: """ return pli.wrap_expr(self._pyexpr.str_json_path_match(json_path)) - def decode(self, encoding: TransferEncoding) -> pli.Expr: + def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Expr: """ Decode a value using the provided encoding. @@ -673,12 +673,15 @@ def decode(self, encoding: TransferEncoding) -> pli.Expr: ---------- encoding : {'hex', 'base64'} The encoding to use. + strict + Raise an error if the underlying value cannot be decoded, + otherwise mask out with a null value. """ if encoding == "hex": - return pli.wrap_expr(self._pyexpr.str_hex_decode()) + return pli.wrap_expr(self._pyexpr.str_hex_decode(strict)) elif encoding == "base64": - return pli.wrap_expr(self._pyexpr.str_base64_decode()) + return pli.wrap_expr(self._pyexpr.str_base64_decode(strict)) else: raise ValueError( f"encoding must be one of {{'hex', 'base64'}}, got {encoding}" diff --git a/py-polars/polars/internals/series/binary.py b/py-polars/polars/internals/series/binary.py index 067b1060b3e7..d90b0b03f2e7 100644 --- a/py-polars/polars/internals/series/binary.py +++ b/py-polars/polars/internals/series/binary.py @@ -56,7 +56,7 @@ def starts_with(self, sub: bytes) -> pli.Series: """ - def decode(self, encoding: TransferEncoding) -> pli.Series: + def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Series: """ Decode a value using the provided encoding. @@ -64,6 +64,9 @@ def decode(self, encoding: TransferEncoding) -> pli.Series: ---------- encoding : {'hex', 'base64'} The encoding to use. + strict + Raise an error if the underlying value cannot be decoded, + otherwise mask out with a null value. """ diff --git a/py-polars/polars/internals/series/string.py b/py-polars/polars/internals/series/string.py index ca6e5db791f9..2a9d4bb0ae7f 100644 --- a/py-polars/polars/internals/series/string.py +++ b/py-polars/polars/internals/series/string.py @@ -266,7 +266,7 @@ def starts_with(self, sub: str) -> pli.Series: """ - def decode(self, encoding: TransferEncoding) -> pli.Series: + def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Series: """ Decode a value using the provided encoding. @@ -274,6 +274,9 @@ def decode(self, encoding: TransferEncoding) -> pli.Series: ---------- encoding : {'hex', 'base64'} The encoding to use. + strict + Raise an error if the underlying value cannot be decoded, + otherwise mask out with a null value. """ diff --git a/py-polars/src/lazy/dsl.rs b/py-polars/src/lazy/dsl.rs index 528bcd283641..350238b84f7b 100644 --- a/py-polars/src/lazy/dsl.rs +++ b/py-polars/src/lazy/dsl.rs @@ -739,11 +739,11 @@ impl PyExpr { .with_fmt("str.hex_encode") .into() } - pub fn str_hex_decode(&self) -> PyExpr { + pub fn str_hex_decode(&self, strict: bool) -> PyExpr { self.clone() .inner .map( - move |s| s.utf8()?.hex_decode().map(|s| s.into_series()), + move |s| s.utf8()?.hex_decode(strict).map(|s| s.into_series()), GetOutput::same_type(), ) .with_fmt("str.hex_decode") @@ -760,11 +760,11 @@ impl PyExpr { .into() } - pub fn str_base64_decode(&self) -> PyExpr { + pub fn str_base64_decode(&self, strict: bool) -> PyExpr { self.clone() .inner .map( - move |s| s.utf8()?.base64_decode().map(|s| s.into_series()), + move |s| s.utf8()?.base64_decode(strict).map(|s| s.into_series()), GetOutput::same_type(), ) .with_fmt("str.base64_decode") @@ -781,11 +781,11 @@ impl PyExpr { .with_fmt("binary.hex_encode") .into() } - pub fn binary_hex_decode(&self) -> PyExpr { + pub fn binary_hex_decode(&self, strict: bool) -> PyExpr { self.clone() .inner .map( - move |s| s.binary()?.hex_decode().map(|s| s.into_series()), + move |s| s.binary()?.hex_decode(strict).map(|s| s.into_series()), GetOutput::same_type(), ) .with_fmt("binary.hex_decode") @@ -802,11 +802,11 @@ impl PyExpr { .into() } - pub fn binary_base64_decode(&self) -> PyExpr { + pub fn binary_base64_decode(&self, strict: bool) -> PyExpr { self.clone() .inner .map( - move |s| s.binary()?.base64_decode().map(|s| s.into_series()), + move |s| s.binary()?.base64_decode(strict).map(|s| s.into_series()), GetOutput::same_type(), ) .with_fmt("binary.base64_decode") diff --git a/py-polars/tests/unit/test_utf8.py b/py-polars/tests/unit/test_utf8.py index 4ca0cafe32d7..3a79e53bdebd 100644 --- a/py-polars/tests/unit/test_utf8.py +++ b/py-polars/tests/unit/test_utf8.py @@ -1,3 +1,5 @@ +import pytest + import polars as pl @@ -22,3 +24,14 @@ def test_length_vs_nchars() -> None: ] ) assert df.rows() == [("café", 5, 4), ("東京", 6, 2)] + + +def test_decode_strict() -> None: + df = pl.DataFrame( + {"strings": ["0IbQvTc3", "0J%2FQldCf0JA%3D", "0J%2FRgNC%2B0YHRgtC%2B"]} + ) + assert df.select(pl.col("strings").str.decode("base64", strict=False)).to_dict( + False + ) == {"strings": [b"\xd0\x86\xd0\xbd77", None, None]} + with pytest.raises(pl.ComputeError): + df.select(pl.col("strings").str.decode("base64", strict=True))