Skip to content

Commit

Permalink
feat(rust, python): add strict parameter to decoding expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 20, 2023
1 parent 8eced68 commit 5cf1ff6
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 33 deletions.
45 changes: 32 additions & 13 deletions polars/polars-core/src/chunked_array/binary/encoding.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
use std::borrow::Cow;

use base64::engine::general_purpose;
use base64::Engine as _;
use hex;

use crate::prelude::*;

impl BinaryChunked {
pub fn hex_decode(&self) -> PolarsResult<BinaryChunked> {
self.try_apply(|s| {
let bytes =
hex::decode(s).map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
Ok(bytes.into())
})
pub fn hex_decode(&self, strict: bool) -> PolarsResult<BinaryChunked> {
if strict {
self.try_apply(|s| {
let bytes = hex::decode(s).map_err(|_e| {
PolarsError::ComputeError(
"Invalid 'hex' encoding found. Try setting 'strict' to false to ignore."
.into(),
)
})?;
Ok(bytes.into())
})
} else {
Ok(self.apply_on_opt(|opt_s| opt_s.and_then(|s| hex::decode(s).ok().map(Cow::Owned))))
}
}

pub fn hex_encode(&self) -> Series {
Expand All @@ -19,13 +29,22 @@ impl BinaryChunked {
.unwrap()
}

pub fn base64_decode(&self) -> PolarsResult<BinaryChunked> {
self.try_apply(|s| {
let bytes = general_purpose::STANDARD
.decode(s)
.map_err(|e| PolarsError::ComputeError(e.to_string().into()))?;
Ok(bytes.into())
})
pub fn base64_decode(&self, strict: bool) -> PolarsResult<BinaryChunked> {
if strict {
self.try_apply(|s| {
let bytes = general_purpose::STANDARD.decode(s).map_err(|_e| {
PolarsError::ComputeError(
"Invalid 'base64' encoding found. Try setting 'strict' to false to ignore."
.into(),
)
})?;
Ok(bytes.into())
})
} else {
Ok(self.apply_on_opt(|opt_s| {
opt_s.and_then(|s| general_purpose::STANDARD.decode(s).ok().map(Cow::Owned))
}))
}
}

pub fn base64_encode(&self) -> Series {
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-core/src/chunked_array/strings/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ impl Utf8Chunked {
}

#[cfg(feature = "binary_encoding")]
pub fn hex_decode(&self) -> PolarsResult<BinaryChunked> {
pub fn hex_decode(&self, strict: bool) -> PolarsResult<BinaryChunked> {
self.cast_unchecked(&DataType::Binary)?
.binary()?
.hex_decode()
.hex_decode(strict)
}

#[must_use]
Expand All @@ -28,10 +28,10 @@ impl Utf8Chunked {
}

#[cfg(feature = "binary_encoding")]
pub fn base64_decode(&self) -> PolarsResult<BinaryChunked> {
pub fn base64_decode(&self, strict: bool) -> PolarsResult<BinaryChunked> {
self.cast_unchecked(&DataType::Binary)?
.binary()?
.base64_decode()
.base64_decode(strict)
}

#[must_use]
Expand Down
9 changes: 6 additions & 3 deletions py-polars/polars/internals/expr/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,23 @@ def starts_with(self, sub: bytes) -> pli.Expr:
"""
return pli.wrap_expr(self._pyexpr.binary_starts_with(sub))

def decode(self, encoding: TransferEncoding) -> pli.Expr:
def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Expr:
"""
Decode a value using the provided encoding.
Parameters
----------
encoding : {'hex', 'base64'}
The encoding to use.
strict
Raise an error if the underlying value cannot be decoded,
otherwise mask out with a null value.
"""
if encoding == "hex":
return pli.wrap_expr(self._pyexpr.binary_hex_decode())
return pli.wrap_expr(self._pyexpr.binary_hex_decode(strict))
elif encoding == "base64":
return pli.wrap_expr(self._pyexpr.binary_base64_decode())
return pli.wrap_expr(self._pyexpr.binary_base64_decode(strict))
else:
raise ValueError(
f"encoding must be one of {{'hex', 'base64'}}, got {encoding}"
Expand Down
9 changes: 6 additions & 3 deletions py-polars/polars/internals/expr/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,20 +665,23 @@ def json_path_match(self, json_path: str) -> pli.Expr:
"""
return pli.wrap_expr(self._pyexpr.str_json_path_match(json_path))

def decode(self, encoding: TransferEncoding) -> pli.Expr:
def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Expr:
"""
Decode a value using the provided encoding.
Parameters
----------
encoding : {'hex', 'base64'}
The encoding to use.
strict
Raise an error if the underlying value cannot be decoded,
otherwise mask out with a null value.
"""
if encoding == "hex":
return pli.wrap_expr(self._pyexpr.str_hex_decode())
return pli.wrap_expr(self._pyexpr.str_hex_decode(strict))
elif encoding == "base64":
return pli.wrap_expr(self._pyexpr.str_base64_decode())
return pli.wrap_expr(self._pyexpr.str_base64_decode(strict))
else:
raise ValueError(
f"encoding must be one of {{'hex', 'base64'}}, got {encoding}"
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/internals/series/binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,17 @@ def starts_with(self, sub: bytes) -> pli.Series:
"""

def decode(self, encoding: TransferEncoding) -> pli.Series:
def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Series:
"""
Decode a value using the provided encoding.
Parameters
----------
encoding : {'hex', 'base64'}
The encoding to use.
strict
Raise an error if the underlying value cannot be decoded,
otherwise mask out with a null value.
"""

Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/internals/series/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,14 +266,17 @@ def starts_with(self, sub: str) -> pli.Series:
"""

def decode(self, encoding: TransferEncoding) -> pli.Series:
def decode(self, encoding: TransferEncoding, *, strict: bool = True) -> pli.Series:
"""
Decode a value using the provided encoding.
Parameters
----------
encoding : {'hex', 'base64'}
The encoding to use.
strict
Raise an error if the underlying value cannot be decoded,
otherwise mask out with a null value.
"""

Expand Down
16 changes: 8 additions & 8 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,11 +739,11 @@ impl PyExpr {
.with_fmt("str.hex_encode")
.into()
}
pub fn str_hex_decode(&self) -> PyExpr {
pub fn str_hex_decode(&self, strict: bool) -> PyExpr {
self.clone()
.inner
.map(
move |s| s.utf8()?.hex_decode().map(|s| s.into_series()),
move |s| s.utf8()?.hex_decode(strict).map(|s| s.into_series()),
GetOutput::same_type(),
)
.with_fmt("str.hex_decode")
Expand All @@ -760,11 +760,11 @@ impl PyExpr {
.into()
}

pub fn str_base64_decode(&self) -> PyExpr {
pub fn str_base64_decode(&self, strict: bool) -> PyExpr {
self.clone()
.inner
.map(
move |s| s.utf8()?.base64_decode().map(|s| s.into_series()),
move |s| s.utf8()?.base64_decode(strict).map(|s| s.into_series()),
GetOutput::same_type(),
)
.with_fmt("str.base64_decode")
Expand All @@ -781,11 +781,11 @@ impl PyExpr {
.with_fmt("binary.hex_encode")
.into()
}
pub fn binary_hex_decode(&self) -> PyExpr {
pub fn binary_hex_decode(&self, strict: bool) -> PyExpr {
self.clone()
.inner
.map(
move |s| s.binary()?.hex_decode().map(|s| s.into_series()),
move |s| s.binary()?.hex_decode(strict).map(|s| s.into_series()),
GetOutput::same_type(),
)
.with_fmt("binary.hex_decode")
Expand All @@ -802,11 +802,11 @@ impl PyExpr {
.into()
}

pub fn binary_base64_decode(&self) -> PyExpr {
pub fn binary_base64_decode(&self, strict: bool) -> PyExpr {
self.clone()
.inner
.map(
move |s| s.binary()?.base64_decode().map(|s| s.into_series()),
move |s| s.binary()?.base64_decode(strict).map(|s| s.into_series()),
GetOutput::same_type(),
)
.with_fmt("binary.base64_decode")
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/test_utf8.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import polars as pl
import pytest


def test_min_max_agg_on_str() -> None:
Expand All @@ -22,3 +23,10 @@ def test_length_vs_nchars() -> None:
]
)
assert df.rows() == [("café", 5, 4), ("東京", 6, 2)]


def test_decode_strict() -> None:
df = pl.DataFrame({"strings": ["0IbQvTc3", "0J%2FQldCf0JA%3D", "0J%2FRgNC%2B0YHRgtC%2B"]})
assert df.select(pl.col("strings").str.decode("base64", strict=False)).to_dict(False) == {'strings': [b'\xd0\x86\xd0\xbd77', None, None]}
with pytest.raises(pl.ComputeError):
df.select(pl.col("strings").str.decode("base64", strict=True))

0 comments on commit 5cf1ff6

Please sign in to comment.