From 1ac5ba5709d5089cfc73e7337d8b4cf3046421c1 Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Tue, 22 Oct 2024 22:04:11 +0800 Subject: [PATCH 01/10] move utf8 from dsl to functions --- daft/daft/__init__.pyi | 60 +- daft/expressions/expressions.py | 56 +- src/daft-dsl/src/functions/mod.rs | 8 +- src/daft-dsl/src/functions/utf8/capitalize.rs | 41 - src/daft-dsl/src/functions/utf8/endswith.rs | 45 -- src/daft-dsl/src/functions/utf8/extract.rs | 51 -- .../src/functions/utf8/extract_all.rs | 51 -- src/daft-dsl/src/functions/utf8/length.rs | 41 - .../src/functions/utf8/length_bytes.rs | 41 - src/daft-dsl/src/functions/utf8/lower.rs | 41 - src/daft-dsl/src/functions/utf8/lstrip.rs | 41 - src/daft-dsl/src/functions/utf8/mod.rs | 356 --------- src/daft-dsl/src/functions/utf8/normalize.rs | 47 -- src/daft-dsl/src/functions/utf8/reverse.rs | 41 - src/daft-dsl/src/functions/utf8/rstrip.rs | 41 - src/daft-dsl/src/functions/utf8/split.rs | 51 -- src/daft-dsl/src/functions/utf8/to_date.rs | 47 -- .../src/functions/utf8/to_datetime.rs | 66 -- src/daft-dsl/src/functions/utf8/upper.rs | 41 - src/daft-dsl/src/python.rs | 160 ---- src/daft-functions/src/lib.rs | 2 + src/daft-functions/src/utf8/capitalize.rs | 68 ++ .../src}/utf8/contains.rs | 47 +- .../src/utf8/endswith.rs} | 47 +- src/daft-functions/src/utf8/extract.rs | 74 ++ src/daft-functions/src/utf8/extract_all.rs | 74 ++ .../src}/utf8/find.rs | 47 +- .../src}/utf8/ilike.rs | 47 +- .../src}/utf8/left.rs | 47 +- src/daft-functions/src/utf8/length.rs | 68 ++ src/daft-functions/src/utf8/length_bytes.rs | 68 ++ .../src}/utf8/like.rs | 47 +- src/daft-functions/src/utf8/lower.rs | 68 ++ .../src}/utf8/lpad.rs | 47 +- src/daft-functions/src/utf8/lstrip.rs | 68 ++ .../src}/utf8/match_.rs | 47 +- src/daft-functions/src/utf8/mod.rs | 111 +++ src/daft-functions/src/utf8/normalize.rs | 87 +++ .../src}/utf8/repeat.rs | 47 +- .../src}/utf8/replace.rs | 67 +- src/daft-functions/src/utf8/reverse.rs | 68 ++ .../src}/utf8/right.rs | 47 +- .../src}/utf8/rpad.rs | 47 +- src/daft-functions/src/utf8/rstrip.rs | 68 ++ src/daft-functions/src/utf8/split.rs | 74 ++ src/daft-functions/src/utf8/startswith.rs | 72 ++ .../src}/utf8/substr.rs | 47 +- src/daft-functions/src/utf8/to_date.rs | 77 ++ src/daft-functions/src/utf8/to_datetime.rs | 90 +++ src/daft-functions/src/utf8/upper.rs | 68 ++ src/daft-logical-plan/src/display.rs | 7 +- src/daft-sql/src/modules/utf8.rs | 727 +++++++++++------- src/daft-sql/src/planner.rs | 9 +- 53 files changed, 2229 insertions(+), 1676 deletions(-) delete mode 100644 src/daft-dsl/src/functions/utf8/capitalize.rs delete mode 100644 src/daft-dsl/src/functions/utf8/endswith.rs delete mode 100644 src/daft-dsl/src/functions/utf8/extract.rs delete mode 100644 src/daft-dsl/src/functions/utf8/extract_all.rs delete mode 100644 src/daft-dsl/src/functions/utf8/length.rs delete mode 100644 src/daft-dsl/src/functions/utf8/length_bytes.rs delete mode 100644 src/daft-dsl/src/functions/utf8/lower.rs delete mode 100644 src/daft-dsl/src/functions/utf8/lstrip.rs delete mode 100644 src/daft-dsl/src/functions/utf8/mod.rs delete mode 100644 src/daft-dsl/src/functions/utf8/normalize.rs delete mode 100644 src/daft-dsl/src/functions/utf8/reverse.rs delete mode 100644 src/daft-dsl/src/functions/utf8/rstrip.rs delete mode 100644 src/daft-dsl/src/functions/utf8/split.rs delete mode 100644 src/daft-dsl/src/functions/utf8/to_date.rs delete mode 100644 src/daft-dsl/src/functions/utf8/to_datetime.rs delete mode 100644 src/daft-dsl/src/functions/utf8/upper.rs create mode 100644 src/daft-functions/src/utf8/capitalize.rs rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/contains.rs (52%) rename src/{daft-dsl/src/functions/utf8/startswith.rs => daft-functions/src/utf8/endswith.rs} (52%) create mode 100644 src/daft-functions/src/utf8/extract.rs create mode 100644 src/daft-functions/src/utf8/extract_all.rs rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/find.rs (53%) rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/ilike.rs (53%) rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/left.rs (53%) create mode 100644 src/daft-functions/src/utf8/length.rs create mode 100644 src/daft-functions/src/utf8/length_bytes.rs rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/like.rs (53%) create mode 100644 src/daft-functions/src/utf8/lower.rs rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/lpad.rs (52%) create mode 100644 src/daft-functions/src/utf8/lstrip.rs rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/match_.rs (52%) create mode 100644 src/daft-functions/src/utf8/mod.rs create mode 100644 src/daft-functions/src/utf8/normalize.rs rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/repeat.rs (53%) rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/replace.rs (50%) create mode 100644 src/daft-functions/src/utf8/reverse.rs rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/right.rs (53%) rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/rpad.rs (52%) create mode 100644 src/daft-functions/src/utf8/rstrip.rs create mode 100644 src/daft-functions/src/utf8/split.rs create mode 100644 src/daft-functions/src/utf8/startswith.rs rename src/{daft-dsl/src/functions => daft-functions/src}/utf8/substr.rs (53%) create mode 100644 src/daft-functions/src/utf8/to_date.rs create mode 100644 src/daft-functions/src/utf8/to_datetime.rs create mode 100644 src/daft-functions/src/utf8/upper.rs diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 3598a4042d..7c650ac32d 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1095,34 +1095,6 @@ class PyExpr: def __repr__(self) -> str: ... def __hash__(self) -> int: ... def __reduce__(self) -> tuple: ... - def utf8_endswith(self, pattern: PyExpr) -> PyExpr: ... - def utf8_startswith(self, pattern: PyExpr) -> PyExpr: ... - def utf8_contains(self, pattern: PyExpr) -> PyExpr: ... - def utf8_match(self, pattern: PyExpr) -> PyExpr: ... - def utf8_split(self, pattern: PyExpr, regex: bool) -> PyExpr: ... - def utf8_extract(self, pattern: PyExpr, index: int) -> PyExpr: ... - def utf8_extract_all(self, pattern: PyExpr, index: int) -> PyExpr: ... - def utf8_replace(self, pattern: PyExpr, replacement: PyExpr, regex: bool) -> PyExpr: ... - def utf8_length(self) -> PyExpr: ... - def utf8_length_bytes(self) -> PyExpr: ... - def utf8_lower(self) -> PyExpr: ... - def utf8_upper(self) -> PyExpr: ... - def utf8_lstrip(self) -> PyExpr: ... - def utf8_rstrip(self) -> PyExpr: ... - def utf8_reverse(self) -> PyExpr: ... - def utf8_capitalize(self) -> PyExpr: ... - def utf8_left(self, nchars: PyExpr) -> PyExpr: ... - def utf8_right(self, nchars: PyExpr) -> PyExpr: ... - def utf8_find(self, substr: PyExpr) -> PyExpr: ... - def utf8_rpad(self, length: PyExpr, pad: PyExpr) -> PyExpr: ... - def utf8_lpad(self, length: PyExpr, pad: PyExpr) -> PyExpr: ... - def utf8_repeat(self, n: PyExpr) -> PyExpr: ... - def utf8_like(self, pattern: PyExpr) -> PyExpr: ... - def utf8_ilike(self, pattern: PyExpr) -> PyExpr: ... - def utf8_substr(self, start: PyExpr, length: PyExpr) -> PyExpr: ... - def utf8_to_date(self, format: str) -> PyExpr: ... - def utf8_to_datetime(self, format: str, timezone: str | None = None) -> PyExpr: ... - def utf8_normalize(self, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PyExpr: ... def struct_get(self, name: str) -> PyExpr: ... def map_get(self, key: PyExpr) -> PyExpr: ... def partitioning_days(self) -> PyExpr: ... @@ -1320,6 +1292,38 @@ def list_max(expr: PyExpr) -> PyExpr: ... def list_slice(expr: PyExpr, start: PyExpr, end: PyExpr | None = None) -> PyExpr: ... def list_chunk(expr: PyExpr, size: int) -> PyExpr: ... +# --- +# expr.utf8 namespace +# --- +def utf8_endswith(pattern: PyExpr) -> PyExpr: ... +def utf8_startswith(pattern: PyExpr) -> PyExpr: ... +def utf8_contains(pattern: PyExpr) -> PyExpr: ... +def utf8_match(pattern: PyExpr) -> PyExpr: ... +def utf8_split(pattern: PyExpr, regex: bool) -> PyExpr: ... +def utf8_extract(pattern: PyExpr, index: int) -> PyExpr: ... +def utf8_extract_all(pattern: PyExpr, index: int) -> PyExpr: ... +def utf8_replace(pattern: PyExpr, replacement: PyExpr, regex: bool) -> PyExpr: ... +def utf8_length() -> PyExpr: ... +def utf8_length_bytes() -> PyExpr: ... +def utf8_lower() -> PyExpr: ... +def utf8_upper() -> PyExpr: ... +def utf8_lstrip() -> PyExpr: ... +def utf8_rstrip() -> PyExpr: ... +def utf8_reverse() -> PyExpr: ... +def utf8_capitalize() -> PyExpr: ... +def utf8_left(nchars: PyExpr) -> PyExpr: ... +def utf8_right(nchars: PyExpr) -> PyExpr: ... +def utf8_find(substr: PyExpr) -> PyExpr: ... +def utf8_rpad(length: PyExpr, pad: PyExpr) -> PyExpr: ... +def utf8_lpad(length: PyExpr, pad: PyExpr) -> PyExpr: ... +def utf8_repeat(n: PyExpr) -> PyExpr: ... +def utf8_like(pattern: PyExpr) -> PyExpr: ... +def utf8_ilike(pattern: PyExpr) -> PyExpr: ... +def utf8_substr(start: PyExpr, length: PyExpr) -> PyExpr: ... +def utf8_to_date(format: str) -> PyExpr: ... +def utf8_to_datetime(format: str, timezone: str | None = None) -> PyExpr: ... +def utf8_normalize(remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PyExpr: ... + class PyCatalog: @staticmethod def new() -> PyCatalog: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 44bbc302e8..df75ec5002 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1887,7 +1887,7 @@ def contains(self, substr: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value contains the provided pattern """ substr_expr = Expression._to_expression(substr) - return Expression._from_pyexpr(self._expr.utf8_contains(substr_expr._expr)) + return Expression._from_pyexpr(native.utf8_contains(substr_expr._expr)) def match(self, pattern: str | Expression) -> Expression: """Checks whether each string matches the given regular expression pattern in a string column @@ -1917,7 +1917,7 @@ def match(self, pattern: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value matches the provided pattern """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_match(pattern_expr._expr)) + return Expression._from_pyexpr(native.utf8_match(pattern_expr._expr)) def endswith(self, suffix: str | Expression) -> Expression: """Checks whether each string ends with the given pattern in a string column @@ -1947,7 +1947,7 @@ def endswith(self, suffix: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value ends with the provided pattern """ suffix_expr = Expression._to_expression(suffix) - return Expression._from_pyexpr(self._expr.utf8_endswith(suffix_expr._expr)) + return Expression._from_pyexpr(native.utf8_endswith(suffix_expr._expr)) def startswith(self, prefix: str | Expression) -> Expression: """Checks whether each string starts with the given pattern in a string column @@ -1977,7 +1977,7 @@ def startswith(self, prefix: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value starts with the provided pattern """ prefix_expr = Expression._to_expression(prefix) - return Expression._from_pyexpr(self._expr.utf8_startswith(prefix_expr._expr)) + return Expression._from_pyexpr(native.utf8_startswith(prefix_expr._expr)) def split(self, pattern: str | Expression, regex: bool = False) -> Expression: r"""Splits each string on the given literal or regex pattern, into a list of strings. @@ -2028,7 +2028,7 @@ def split(self, pattern: str | Expression, regex: bool = False) -> Expression: Expression: A List[Utf8] expression containing the string splits for each string in the column. """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_split(pattern_expr._expr, regex)) + return Expression._from_pyexpr(native.utf8_split(pattern_expr._expr, regex)) def concat(self, other: str | Expression) -> Expression: """Concatenates two string expressions together @@ -2119,7 +2119,7 @@ def extract(self, pattern: str | Expression, index: int = 0) -> Expression: `extract_all` """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_extract(pattern_expr._expr, index)) + return Expression._from_pyexpr(native.utf8_extract(pattern_expr._expr, index)) def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression: r"""Extracts the specified match group from all regex matches in each string in a string column. @@ -2175,7 +2175,7 @@ def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression: `extract` """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_extract_all(pattern_expr._expr, index)) + return Expression._from_pyexpr(native.utf8_extract_all(pattern_expr._expr, index)) def replace( self, @@ -2232,7 +2232,7 @@ def replace( """ pattern_expr = Expression._to_expression(pattern) replacement_expr = Expression._to_expression(replacement) - return Expression._from_pyexpr(self._expr.utf8_replace(pattern_expr._expr, replacement_expr._expr, regex)) + return Expression._from_pyexpr(native.utf8_replace(pattern_expr._expr, replacement_expr._expr, regex)) def length(self) -> Expression: """Retrieves the length for a UTF-8 string column @@ -2259,7 +2259,7 @@ def length(self) -> Expression: Returns: Expression: an UInt64 expression with the length of each string """ - return Expression._from_pyexpr(self._expr.utf8_length()) + return Expression._from_pyexpr(native.utf8_length()) def length_bytes(self) -> Expression: """Retrieves the length for a UTF-8 string column in bytes. @@ -2286,7 +2286,7 @@ def length_bytes(self) -> Expression: Returns: Expression: an UInt64 expression with the length of each string """ - return Expression._from_pyexpr(self._expr.utf8_length_bytes()) + return Expression._from_pyexpr(native.utf8_length_bytes()) def lower(self) -> Expression: """Convert UTF-8 string to all lowercase @@ -2313,7 +2313,7 @@ def lower(self) -> Expression: Returns: Expression: a String expression which is `self` lowercased """ - return Expression._from_pyexpr(self._expr.utf8_lower()) + return Expression._from_pyexpr(native.utf8_lower()) def upper(self) -> Expression: """Convert UTF-8 string to all upper @@ -2340,7 +2340,7 @@ def upper(self) -> Expression: Returns: Expression: a String expression which is `self` uppercased """ - return Expression._from_pyexpr(self._expr.utf8_upper()) + return Expression._from_pyexpr(native.utf8_upper()) def lstrip(self) -> Expression: """Strip whitespace from the left side of a UTF-8 string @@ -2367,7 +2367,7 @@ def lstrip(self) -> Expression: Returns: Expression: a String expression which is `self` with leading whitespace stripped """ - return Expression._from_pyexpr(self._expr.utf8_lstrip()) + return Expression._from_pyexpr(native.utf8_lstrip()) def rstrip(self) -> Expression: """Strip whitespace from the right side of a UTF-8 string @@ -2394,7 +2394,7 @@ def rstrip(self) -> Expression: Returns: Expression: a String expression which is `self` with trailing whitespace stripped """ - return Expression._from_pyexpr(self._expr.utf8_rstrip()) + return Expression._from_pyexpr(native.utf8_rstrip()) def reverse(self) -> Expression: """Reverse a UTF-8 string @@ -2421,7 +2421,7 @@ def reverse(self) -> Expression: Returns: Expression: a String expression which is `self` reversed """ - return Expression._from_pyexpr(self._expr.utf8_reverse()) + return Expression._from_pyexpr(native.utf8_reverse()) def capitalize(self) -> Expression: """Capitalize a UTF-8 string @@ -2448,7 +2448,7 @@ def capitalize(self) -> Expression: Returns: Expression: a String expression which is `self` uppercased with the first character and lowercased the rest """ - return Expression._from_pyexpr(self._expr.utf8_capitalize()) + return Expression._from_pyexpr(native.utf8_capitalize()) def left(self, nchars: int | Expression) -> Expression: """Gets the n (from nchars) left-most characters of each string @@ -2476,7 +2476,7 @@ def left(self, nchars: int | Expression) -> Expression: Expression: a String expression which is the `n` left-most characters of `self` """ nchars_expr = Expression._to_expression(nchars) - return Expression._from_pyexpr(self._expr.utf8_left(nchars_expr._expr)) + return Expression._from_pyexpr(native.utf8_left(nchars_expr._expr)) def right(self, nchars: int | Expression) -> Expression: """Gets the n (from nchars) right-most characters of each string @@ -2504,7 +2504,7 @@ def right(self, nchars: int | Expression) -> Expression: Expression: a String expression which is the `n` right-most characters of `self` """ nchars_expr = Expression._to_expression(nchars) - return Expression._from_pyexpr(self._expr.utf8_right(nchars_expr._expr)) + return Expression._from_pyexpr(native.utf8_right(nchars_expr._expr)) def find(self, substr: str | Expression) -> Expression: """Returns the index of the first occurrence of the substring in each string @@ -2536,7 +2536,7 @@ def find(self, substr: str | Expression) -> Expression: Expression: an Int64 expression with the index of the first occurrence of the substring in each string """ substr_expr = Expression._to_expression(substr) - return Expression._from_pyexpr(self._expr.utf8_find(substr_expr._expr)) + return Expression._from_pyexpr(native.utf8_find(substr_expr._expr)) def rpad(self, length: int | Expression, pad: str | Expression) -> Expression: """Right-pads each string by truncating or padding with the character @@ -2569,7 +2569,7 @@ def rpad(self, length: int | Expression, pad: str | Expression) -> Expression: """ length_expr = Expression._to_expression(length) pad_expr = Expression._to_expression(pad) - return Expression._from_pyexpr(self._expr.utf8_rpad(length_expr._expr, pad_expr._expr)) + return Expression._from_pyexpr(native.utf8_rpad(length_expr._expr, pad_expr._expr)) def lpad(self, length: int | Expression, pad: str | Expression) -> Expression: """Left-pads each string by truncating on the right or padding with the character @@ -2602,7 +2602,7 @@ def lpad(self, length: int | Expression, pad: str | Expression) -> Expression: """ length_expr = Expression._to_expression(length) pad_expr = Expression._to_expression(pad) - return Expression._from_pyexpr(self._expr.utf8_lpad(length_expr._expr, pad_expr._expr)) + return Expression._from_pyexpr(native.utf8_lpad(length_expr._expr, pad_expr._expr)) def repeat(self, n: int | Expression) -> Expression: """Repeats each string n times @@ -2630,7 +2630,7 @@ def repeat(self, n: int | Expression) -> Expression: Expression: a String expression which is `self` repeated `n` times """ n_expr = Expression._to_expression(n) - return Expression._from_pyexpr(self._expr.utf8_repeat(n_expr._expr)) + return Expression._from_pyexpr(native.utf8_repeat(n_expr._expr)) def like(self, pattern: str | Expression) -> Expression: """Checks whether each string matches the given SQL LIKE pattern, case sensitive @@ -2661,7 +2661,7 @@ def like(self, pattern: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value matches the provided pattern """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_like(pattern_expr._expr)) + return Expression._from_pyexpr(native.utf8_like(pattern_expr._expr)) def ilike(self, pattern: str | Expression) -> Expression: """Checks whether each string matches the given SQL LIKE pattern, case insensitive @@ -2692,7 +2692,7 @@ def ilike(self, pattern: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value matches the provided pattern """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(self._expr.utf8_ilike(pattern_expr._expr)) + return Expression._from_pyexpr(native.utf8_ilike(pattern_expr._expr)) def substr(self, start: int | Expression, length: int | Expression | None = None) -> Expression: """Extract a substring from a string, starting at a specified index and extending for a given length. @@ -2724,7 +2724,7 @@ def substr(self, start: int | Expression, length: int | Expression | None = None """ start_expr = Expression._to_expression(start) length_expr = Expression._to_expression(length) - return Expression._from_pyexpr(self._expr.utf8_substr(start_expr._expr, length_expr._expr)) + return Expression._from_pyexpr(native.utf8_substr(start_expr._expr, length_expr._expr)) def to_date(self, format: str) -> Expression: """Converts a string to a date using the specified format @@ -2755,7 +2755,7 @@ def to_date(self, format: str) -> Expression: Returns: Expression: a Date expression which is parsed by given format """ - return Expression._from_pyexpr(self._expr.utf8_to_date(format)) + return Expression._from_pyexpr(native.utf8_to_date(format)) def to_datetime(self, format: str, timezone: str | None = None) -> Expression: """Converts a string to a datetime using the specified format and timezone @@ -2805,7 +2805,7 @@ def to_datetime(self, format: str, timezone: str | None = None) -> Expression: Returns: Expression: a DateTime expression which is parsed by given format and timezone """ - return Expression._from_pyexpr(self._expr.utf8_to_datetime(format, timezone)) + return Expression._from_pyexpr(native.utf8_to_datetime(format, timezone)) def normalize( self, @@ -2849,7 +2849,7 @@ def normalize( Returns: Expression: a String expression which is normalized. """ - return Expression._from_pyexpr(self._expr.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space)) + return Expression._from_pyexpr(native.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space)) def tokenize_encode( self, diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 48a7751197..33962c0cc9 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -5,7 +5,6 @@ pub mod python; pub mod scalar; pub mod sketch; pub mod struct_; -pub mod utf8; use std::{ fmt::{Display, Formatter, Result, Write}, @@ -18,15 +17,11 @@ use python::PythonUDF; pub use scalar::*; use serde::{Deserialize, Serialize}; -use self::{ - map::MapExpr, partitioning::PartitioningExpr, sketch::SketchExpr, struct_::StructExpr, - utf8::Utf8Expr, -}; +use self::{map::MapExpr, partitioning::PartitioningExpr, sketch::SketchExpr, struct_::StructExpr}; use crate::{Expr, ExprRef, Operator}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum FunctionExpr { - Utf8(Utf8Expr), Map(MapExpr), Sketch(SketchExpr), Struct(StructExpr), @@ -49,7 +44,6 @@ impl FunctionExpr { #[inline] fn get_evaluator(&self) -> &dyn FunctionEvaluator { match self { - Self::Utf8(expr) => expr.get_evaluator(), Self::Map(expr) => expr.get_evaluator(), Self::Sketch(expr) => expr.get_evaluator(), Self::Struct(expr) => expr.get_evaluator(), diff --git a/src/daft-dsl/src/functions/utf8/capitalize.rs b/src/daft-dsl/src/functions/utf8/capitalize.rs deleted file mode 100644 index caa3c25359..0000000000 --- a/src/daft-dsl/src/functions/utf8/capitalize.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct CapitalizeEvaluator {} - -impl FunctionEvaluator for CapitalizeEvaluator { - fn fn_name(&self) -> &'static str { - "capitalize" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to capitalize to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_capitalize(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/endswith.rs b/src/daft-dsl/src/functions/utf8/endswith.rs deleted file mode 100644 index 5785f92257..0000000000 --- a/src/daft-dsl/src/functions/utf8/endswith.rs +++ /dev/null @@ -1,45 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct EndswithEvaluator {} - -impl FunctionEvaluator for EndswithEvaluator { - fn fn_name(&self) -> &'static str { - "endswith" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { - (Ok(data_field), Ok(pattern_field)) => { - match (&data_field.dtype, &pattern_field.dtype) { - (DataType::Utf8, DataType::Utf8) => { - Ok(Field::new(data_field.name, DataType::Boolean)) - } - _ => Err(DaftError::TypeError(format!( - "Expects inputs to endswith to be utf8, but received {data_field} and {pattern_field}", - ))), - } - } - (Err(e), _) | (_, Err(e)) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => data.utf8_endswith(pattern), - _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/extract.rs b/src/daft-dsl/src/functions/utf8/extract.rs deleted file mode 100644 index abe9d4df16..0000000000 --- a/src/daft-dsl/src/functions/utf8/extract.rs +++ /dev/null @@ -1,51 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ExtractEvaluator {} - -impl FunctionEvaluator for ExtractEvaluator { - fn fn_name(&self) -> &'static str { - "extract" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { - (Ok(data_field), Ok(pattern_field)) => { - match (&data_field.dtype, &pattern_field.dtype) { - (DataType::Utf8, DataType::Utf8) => { - Ok(Field::new(data_field.name, DataType::Utf8)) - } - _ => Err(DaftError::TypeError(format!( - "Expects inputs to extract to be utf8, but received {data_field} and {pattern_field}", - ))), - } - } - (Err(e), _) | (_, Err(e)) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => { - let index = match expr { - FunctionExpr::Utf8(Utf8Expr::Extract(index)) => index, - _ => panic!("Expected Utf8 Extract Expr, got {expr}"), - }; - data.utf8_extract(pattern, *index) - } - _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/extract_all.rs b/src/daft-dsl/src/functions/utf8/extract_all.rs deleted file mode 100644 index e2395e8c19..0000000000 --- a/src/daft-dsl/src/functions/utf8/extract_all.rs +++ /dev/null @@ -1,51 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ExtractAllEvaluator {} - -impl FunctionEvaluator for ExtractAllEvaluator { - fn fn_name(&self) -> &'static str { - "extractall" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { - (Ok(data_field), Ok(pattern_field)) => { - match (&data_field.dtype, &pattern_field.dtype) { - (DataType::Utf8, DataType::Utf8) => { - Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8)))) - } - _ => Err(DaftError::TypeError(format!( - "Expects inputs to extractAll to be utf8, but received {data_field} and {pattern_field}", - ))), - } - } - (Err(e), _) | (_, Err(e)) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => { - let index = match expr { - FunctionExpr::Utf8(Utf8Expr::ExtractAll(index)) => index, - _ => panic!("Expected Utf8 ExtractAll Expr, got {expr}"), - }; - data.utf8_extract_all(pattern, *index) - } - _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/length.rs b/src/daft-dsl/src/functions/utf8/length.rs deleted file mode 100644 index 9f4729ac76..0000000000 --- a/src/daft-dsl/src/functions/utf8/length.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct LengthEvaluator {} - -impl FunctionEvaluator for LengthEvaluator { - fn fn_name(&self) -> &'static str { - "length" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::UInt64)), - _ => Err(DaftError::TypeError(format!( - "Expects input to length to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_length(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/length_bytes.rs b/src/daft-dsl/src/functions/utf8/length_bytes.rs deleted file mode 100644 index cdf0af383a..0000000000 --- a/src/daft-dsl/src/functions/utf8/length_bytes.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct LengthBytesEvaluator {} - -impl FunctionEvaluator for LengthBytesEvaluator { - fn fn_name(&self) -> &'static str { - "length_bytes" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::UInt64)), - _ => Err(DaftError::TypeError(format!( - "Expects input to length_bytes to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_length_bytes(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/lower.rs b/src/daft-dsl/src/functions/utf8/lower.rs deleted file mode 100644 index f3fd7a8c47..0000000000 --- a/src/daft-dsl/src/functions/utf8/lower.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct LowerEvaluator {} - -impl FunctionEvaluator for LowerEvaluator { - fn fn_name(&self) -> &'static str { - "lower" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to lower to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_lower(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/lstrip.rs b/src/daft-dsl/src/functions/utf8/lstrip.rs deleted file mode 100644 index 534aa1cd37..0000000000 --- a/src/daft-dsl/src/functions/utf8/lstrip.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct LstripEvaluator {} - -impl FunctionEvaluator for LstripEvaluator { - fn fn_name(&self) -> &'static str { - "lstrip" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to lstrip to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_lstrip(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/mod.rs b/src/daft-dsl/src/functions/utf8/mod.rs deleted file mode 100644 index 7a795250ff..0000000000 --- a/src/daft-dsl/src/functions/utf8/mod.rs +++ /dev/null @@ -1,356 +0,0 @@ -mod capitalize; -mod contains; -mod endswith; -mod extract; -mod extract_all; -mod find; -mod ilike; -mod left; -mod length; -mod length_bytes; -mod like; -mod lower; -mod lpad; -mod lstrip; -mod match_; -mod normalize; -mod repeat; -mod replace; -mod reverse; -mod right; -mod rpad; -mod rstrip; -mod split; -mod startswith; -mod substr; -mod to_date; -mod to_datetime; -mod upper; - -use capitalize::CapitalizeEvaluator; -use contains::ContainsEvaluator; -use daft_core::array::ops::Utf8NormalizeOptions; -use endswith::EndswithEvaluator; -use extract::ExtractEvaluator; -use extract_all::ExtractAllEvaluator; -use find::FindEvaluator; -use ilike::IlikeEvaluator; -use left::LeftEvaluator; -use length::LengthEvaluator; -use length_bytes::LengthBytesEvaluator; -use like::LikeEvaluator; -use lower::LowerEvaluator; -use lpad::LpadEvaluator; -use lstrip::LstripEvaluator; -use normalize::NormalizeEvaluator; -use repeat::RepeatEvaluator; -use replace::ReplaceEvaluator; -use reverse::ReverseEvaluator; -use right::RightEvaluator; -use rpad::RpadEvaluator; -use rstrip::RstripEvaluator; -use serde::{Deserialize, Serialize}; -use split::SplitEvaluator; -use startswith::StartswithEvaluator; -use substr::SubstrEvaluator; -use to_date::ToDateEvaluator; -use to_datetime::ToDatetimeEvaluator; -use upper::UpperEvaluator; - -use super::FunctionEvaluator; -use crate::{functions::utf8::match_::MatchEvaluator, Expr, ExprRef}; - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] -pub enum Utf8Expr { - EndsWith, - StartsWith, - Contains, - Split(bool), - Match, - Extract(usize), - ExtractAll(usize), - Replace(bool), - Length, - LengthBytes, - Lower, - Upper, - Lstrip, - Rstrip, - Reverse, - Capitalize, - Left, - Right, - Find, - Rpad, - Lpad, - Repeat, - Like, - Ilike, - Substr, - ToDate(String), - ToDatetime(String, Option), - Normalize(Utf8NormalizeOptions), -} - -impl Utf8Expr { - #[inline] - pub fn get_evaluator(&self) -> &dyn FunctionEvaluator { - match self { - Self::EndsWith => &EndswithEvaluator {}, - Self::StartsWith => &StartswithEvaluator {}, - Self::Contains => &ContainsEvaluator {}, - Self::Split(_) => &SplitEvaluator {}, - Self::Match => &MatchEvaluator {}, - Self::Extract(_) => &ExtractEvaluator {}, - Self::ExtractAll(_) => &ExtractAllEvaluator {}, - Self::Replace(_) => &ReplaceEvaluator {}, - Self::Length => &LengthEvaluator {}, - Self::LengthBytes => &LengthBytesEvaluator {}, - Self::Lower => &LowerEvaluator {}, - Self::Upper => &UpperEvaluator {}, - Self::Lstrip => &LstripEvaluator {}, - Self::Rstrip => &RstripEvaluator {}, - Self::Reverse => &ReverseEvaluator {}, - Self::Capitalize => &CapitalizeEvaluator {}, - Self::Left => &LeftEvaluator {}, - Self::Right => &RightEvaluator {}, - Self::Find => &FindEvaluator {}, - Self::Rpad => &RpadEvaluator {}, - Self::Lpad => &LpadEvaluator {}, - Self::Repeat => &RepeatEvaluator {}, - Self::Like => &LikeEvaluator {}, - Self::Ilike => &IlikeEvaluator {}, - Self::Substr => &SubstrEvaluator {}, - Self::ToDate(_) => &ToDateEvaluator {}, - Self::ToDatetime(_, _) => &ToDatetimeEvaluator {}, - Self::Normalize(_) => &NormalizeEvaluator {}, - } - } -} - -pub fn endswith(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::EndsWith), - inputs: vec![data, pattern], - } - .into() -} - -pub fn startswith(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::StartsWith), - inputs: vec![data, pattern], - } - .into() -} - -pub fn contains(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Contains), - inputs: vec![data, pattern], - } - .into() -} - -pub fn match_(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Match), - inputs: vec![data, pattern], - } - .into() -} - -pub fn split(data: ExprRef, pattern: ExprRef, regex: bool) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Split(regex)), - inputs: vec![data, pattern], - } - .into() -} - -pub fn extract(data: ExprRef, pattern: ExprRef, index: usize) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Extract(index)), - inputs: vec![data, pattern], - } - .into() -} - -pub fn extract_all(data: ExprRef, pattern: ExprRef, index: usize) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::ExtractAll(index)), - inputs: vec![data, pattern], - } - .into() -} - -pub fn replace(data: ExprRef, pattern: ExprRef, replacement: ExprRef, regex: bool) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Replace(regex)), - inputs: vec![data, pattern, replacement], - } - .into() -} - -pub fn length(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Length), - inputs: vec![data], - } - .into() -} - -pub fn length_bytes(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::LengthBytes), - inputs: vec![data], - } - .into() -} - -pub fn lower(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Lower), - inputs: vec![data], - } - .into() -} - -pub fn upper(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Upper), - inputs: vec![data], - } - .into() -} - -pub fn lstrip(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Lstrip), - inputs: vec![data], - } - .into() -} - -pub fn rstrip(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Rstrip), - inputs: vec![data], - } - .into() -} - -pub fn reverse(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Reverse), - inputs: vec![data], - } - .into() -} - -pub fn capitalize(data: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Capitalize), - inputs: vec![data], - } - .into() -} - -pub fn left(data: ExprRef, count: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Left), - inputs: vec![data, count], - } - .into() -} - -pub fn right(data: ExprRef, count: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Right), - inputs: vec![data, count], - } - .into() -} - -pub fn find(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Find), - inputs: vec![data, pattern], - } - .into() -} - -pub fn rpad(data: ExprRef, length: ExprRef, pad: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Rpad), - inputs: vec![data, length, pad], - } - .into() -} - -pub fn lpad(data: ExprRef, length: ExprRef, pad: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Lpad), - inputs: vec![data, length, pad], - } - .into() -} - -pub fn repeat(data: ExprRef, count: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Repeat), - inputs: vec![data, count], - } - .into() -} - -pub fn like(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Like), - inputs: vec![data, pattern], - } - .into() -} - -pub fn ilike(data: ExprRef, pattern: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Ilike), - inputs: vec![data, pattern], - } - .into() -} - -pub fn substr(data: ExprRef, start: ExprRef, length: ExprRef) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Substr), - inputs: vec![data, start, length], - } - .into() -} - -pub fn to_date(data: ExprRef, format: &str) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::ToDate(format.to_string())), - inputs: vec![data], - } - .into() -} - -pub fn to_datetime(data: ExprRef, format: &str, timezone: Option<&str>) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::ToDatetime( - format.to_string(), - timezone.map(|s| s.to_string()), - )), - inputs: vec![data], - } - .into() -} - -pub fn normalize(data: ExprRef, opts: Utf8NormalizeOptions) -> ExprRef { - Expr::Function { - func: super::FunctionExpr::Utf8(Utf8Expr::Normalize(opts)), - inputs: vec![data], - } - .into() -} diff --git a/src/daft-dsl/src/functions/utf8/normalize.rs b/src/daft-dsl/src/functions/utf8/normalize.rs deleted file mode 100644 index b693e2c017..0000000000 --- a/src/daft-dsl/src/functions/utf8/normalize.rs +++ /dev/null @@ -1,47 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct NormalizeEvaluator {} - -impl FunctionEvaluator for NormalizeEvaluator { - fn fn_name(&self) -> &'static str { - "normalize" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to normalize to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data] => { - let opts = match expr { - FunctionExpr::Utf8(Utf8Expr::Normalize(opts)) => opts, - _ => panic!("Expected Utf8 Normalize Expr, got {expr}"), - }; - data.utf8_normalize(*opts) - } - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/reverse.rs b/src/daft-dsl/src/functions/utf8/reverse.rs deleted file mode 100644 index cff9363a82..0000000000 --- a/src/daft-dsl/src/functions/utf8/reverse.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ReverseEvaluator {} - -impl FunctionEvaluator for ReverseEvaluator { - fn fn_name(&self) -> &'static str { - "reverse" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to reverse to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_reverse(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/rstrip.rs b/src/daft-dsl/src/functions/utf8/rstrip.rs deleted file mode 100644 index c138d4c86c..0000000000 --- a/src/daft-dsl/src/functions/utf8/rstrip.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct RstripEvaluator {} - -impl FunctionEvaluator for RstripEvaluator { - fn fn_name(&self) -> &'static str { - "rstrip" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to rstrip to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_rstrip(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/split.rs b/src/daft-dsl/src/functions/utf8/split.rs deleted file mode 100644 index 0518786055..0000000000 --- a/src/daft-dsl/src/functions/utf8/split.rs +++ /dev/null @@ -1,51 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct SplitEvaluator {} - -impl FunctionEvaluator for SplitEvaluator { - fn fn_name(&self) -> &'static str { - "split" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { - (Ok(data_field), Ok(pattern_field)) => { - match (&data_field.dtype, &pattern_field.dtype) { - (DataType::Utf8, DataType::Utf8) => { - Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8)))) - } - _ => Err(DaftError::TypeError(format!( - "Expects inputs to split to be utf8, but received {data_field} and {pattern_field}", - ))), - } - } - (Err(e), _) | (_, Err(e)) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data, pattern] => { - let regex = match expr { - FunctionExpr::Utf8(Utf8Expr::Split(regex)) => regex, - _ => panic!("Expected Utf8 Split Expr, got {expr}"), - }; - data.utf8_split(pattern, *regex) - } - _ => Err(DaftError::ValueError(format!( - "Expected 2 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/to_date.rs b/src/daft-dsl/src/functions/utf8/to_date.rs deleted file mode 100644 index 58adecbc05..0000000000 --- a/src/daft-dsl/src/functions/utf8/to_date.rs +++ /dev/null @@ -1,47 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ToDateEvaluator {} - -impl FunctionEvaluator for ToDateEvaluator { - fn fn_name(&self) -> &'static str { - "to_date" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Date)), - _ => Err(DaftError::TypeError(format!( - "Expects inputs to to_date to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data] => { - let format = match expr { - FunctionExpr::Utf8(Utf8Expr::ToDate(format)) => format, - _ => panic!("Expected Utf8 ToDate Expr, got {expr}"), - }; - data.utf8_to_date(format) - } - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/to_datetime.rs b/src/daft-dsl/src/functions/utf8/to_datetime.rs deleted file mode 100644 index 25368c8e64..0000000000 --- a/src/daft-dsl/src/functions/utf8/to_datetime.rs +++ /dev/null @@ -1,66 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::{datatypes::infer_timeunit_from_format_string, prelude::*}; - -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ToDatetimeEvaluator {} - -impl FunctionEvaluator for ToDatetimeEvaluator { - fn fn_name(&self) -> &'static str { - "to_datetime" - } - - fn to_field( - &self, - inputs: &[ExprRef], - schema: &Schema, - expr: &FunctionExpr, - ) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => { - let (format, timezone) = match expr { - FunctionExpr::Utf8(Utf8Expr::ToDatetime(format, timezone)) => { - (format, timezone) - } - _ => panic!("Expected Utf8 ToDatetime Expr, got {expr}"), - }; - let timeunit = infer_timeunit_from_format_string(format); - Ok(Field::new( - data_field.name, - DataType::Timestamp(timeunit, timezone.clone()), - )) - } - _ => Err(DaftError::TypeError(format!( - "Expects inputs to to_datetime to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { - match inputs { - [data] => { - let (format, timezone) = match expr { - FunctionExpr::Utf8(Utf8Expr::ToDatetime(format, timezone)) => { - (format, timezone) - } - _ => panic!("Expected Utf8 ToDatetime Expr, got {expr}"), - }; - data.utf8_to_datetime(format, timezone.as_deref()) - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/utf8/upper.rs b/src/daft-dsl/src/functions/utf8/upper.rs deleted file mode 100644 index a02438b495..0000000000 --- a/src/daft-dsl/src/functions/utf8/upper.rs +++ /dev/null @@ -1,41 +0,0 @@ -use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; - -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct UpperEvaluator {} - -impl FunctionEvaluator for UpperEvaluator { - fn fn_name(&self) -> &'static str { - "upper" - } - - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => match data.to_field(schema) { - Ok(data_field) => match &data_field.dtype { - DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), - _ => Err(DaftError::TypeError(format!( - "Expects input to upper to be utf8, but received {data_field}", - ))), - }, - Err(e) => Err(e), - }, - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { - match inputs { - [data] => data.utf8_upper(), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input args, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 0ba4ad8b92..01d6866b94 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -9,7 +9,6 @@ use common_error::DaftError; use common_py_serde::impl_bincode_py_state_serialization; use common_resource_request::ResourceRequest; use daft_core::{ - array::ops::Utf8NormalizeOptions, datatypes::{IntervalValue, IntervalValueBuilder}, prelude::*, python::{PyDataType, PyField, PySchema, PySeries, PyTimeUnit}, @@ -487,165 +486,6 @@ impl PyExpr { hasher.finish() } - pub fn utf8_endswith(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::endswith; - Ok(endswith(self.into(), pattern.expr.clone()).into()) - } - - pub fn utf8_startswith(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::startswith; - Ok(startswith(self.into(), pattern.expr.clone()).into()) - } - - pub fn utf8_contains(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::contains; - Ok(contains(self.into(), pattern.expr.clone()).into()) - } - - pub fn utf8_match(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::match_; - Ok(match_(self.into(), pattern.expr.clone()).into()) - } - - pub fn utf8_split(&self, pattern: &Self, regex: bool) -> PyResult { - use crate::functions::utf8::split; - Ok(split(self.into(), pattern.expr.clone(), regex).into()) - } - - pub fn utf8_extract(&self, pattern: &Self, index: usize) -> PyResult { - use crate::functions::utf8::extract; - Ok(extract(self.into(), pattern.expr.clone(), index).into()) - } - - pub fn utf8_extract_all(&self, pattern: &Self, index: usize) -> PyResult { - use crate::functions::utf8::extract_all; - Ok(extract_all(self.into(), pattern.expr.clone(), index).into()) - } - - pub fn utf8_replace(&self, pattern: &Self, replacement: &Self, regex: bool) -> PyResult { - use crate::functions::utf8::replace; - Ok(replace( - self.into(), - pattern.expr.clone(), - replacement.expr.clone(), - regex, - ) - .into()) - } - - pub fn utf8_length(&self) -> PyResult { - use crate::functions::utf8::length; - Ok(length(self.into()).into()) - } - - pub fn utf8_length_bytes(&self) -> PyResult { - use crate::functions::utf8::length_bytes; - Ok(length_bytes(self.into()).into()) - } - - pub fn utf8_lower(&self) -> PyResult { - use crate::functions::utf8::lower; - Ok(lower(self.into()).into()) - } - - pub fn utf8_upper(&self) -> PyResult { - use crate::functions::utf8::upper; - Ok(upper(self.into()).into()) - } - - pub fn utf8_lstrip(&self) -> PyResult { - use crate::functions::utf8::lstrip; - Ok(lstrip(self.into()).into()) - } - - pub fn utf8_rstrip(&self) -> PyResult { - use crate::functions::utf8::rstrip; - Ok(rstrip(self.into()).into()) - } - - pub fn utf8_reverse(&self) -> PyResult { - use crate::functions::utf8::reverse; - Ok(reverse(self.into()).into()) - } - - pub fn utf8_capitalize(&self) -> PyResult { - use crate::functions::utf8::capitalize; - Ok(capitalize(self.into()).into()) - } - - pub fn utf8_left(&self, count: &Self) -> PyResult { - use crate::functions::utf8::left; - Ok(left(self.into(), count.into()).into()) - } - - pub fn utf8_right(&self, count: &Self) -> PyResult { - use crate::functions::utf8::right; - Ok(right(self.into(), count.into()).into()) - } - - pub fn utf8_find(&self, substr: &Self) -> PyResult { - use crate::functions::utf8::find; - Ok(find(self.into(), substr.into()).into()) - } - - pub fn utf8_rpad(&self, length: &Self, pad: &Self) -> PyResult { - use crate::functions::utf8::rpad; - Ok(rpad(self.into(), length.into(), pad.into()).into()) - } - - pub fn utf8_lpad(&self, length: &Self, pad: &Self) -> PyResult { - use crate::functions::utf8::lpad; - Ok(lpad(self.into(), length.into(), pad.into()).into()) - } - - pub fn utf8_repeat(&self, n: &Self) -> PyResult { - use crate::functions::utf8::repeat; - Ok(repeat(self.into(), n.into()).into()) - } - - pub fn utf8_like(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::like; - Ok(like(self.into(), pattern.into()).into()) - } - - pub fn utf8_ilike(&self, pattern: &Self) -> PyResult { - use crate::functions::utf8::ilike; - Ok(ilike(self.into(), pattern.into()).into()) - } - - pub fn utf8_substr(&self, start: &Self, length: &Self) -> PyResult { - use crate::functions::utf8::substr; - Ok(substr(self.into(), start.into(), length.into()).into()) - } - - pub fn utf8_to_date(&self, format: &str) -> PyResult { - use crate::functions::utf8::to_date; - Ok(to_date(self.into(), format).into()) - } - - pub fn utf8_to_datetime(&self, format: &str, timezone: Option<&str>) -> PyResult { - use crate::functions::utf8::to_datetime; - Ok(to_datetime(self.into(), format, timezone).into()) - } - - pub fn utf8_normalize( - &self, - remove_punct: bool, - lowercase: bool, - nfd_unicode: bool, - white_space: bool, - ) -> PyResult { - use crate::functions::utf8::normalize; - let opts = Utf8NormalizeOptions { - remove_punct, - lowercase, - nfd_unicode, - white_space, - }; - - Ok(normalize(self.into(), opts).into()) - } - pub fn struct_get(&self, name: &str) -> PyResult { use crate::functions::struct_::get; Ok(get(self.into(), name).into()) diff --git a/src/daft-functions/src/lib.rs b/src/daft-functions/src/lib.rs index a8b1c8d0cd..e73a73cda0 100644 --- a/src/daft-functions/src/lib.rs +++ b/src/daft-functions/src/lib.rs @@ -11,6 +11,7 @@ pub mod temporal; pub mod to_struct; pub mod tokenize; pub mod uri; +pub mod utf8; use common_error::DaftError; #[cfg(feature = "python")] @@ -50,6 +51,7 @@ pub fn register_modules(parent: &Bound) -> PyResult<()> { float::register_modules(parent)?; temporal::register_modules(parent)?; list::register_modules(parent)?; + utf8::register_modules(parent)?; Ok(()) } diff --git a/src/daft-functions/src/utf8/capitalize.rs b/src/daft-functions/src/utf8/capitalize.rs new file mode 100644 index 0000000000..f666e1f3cb --- /dev/null +++ b/src/daft-functions/src/utf8/capitalize.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Capitalize {} + +#[typetag::serde] +impl ScalarUDF for Utf8Capitalize { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_capitalize" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to capitalize to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_capitalize(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_capitalize(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Capitalize {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_capitalize")] +pub fn py_utf8_capitalize(expr: PyExpr) -> PyResult { + Ok(utf8_capitalize(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/contains.rs b/src/daft-functions/src/utf8/contains.rs similarity index 52% rename from src/daft-dsl/src/functions/utf8/contains.rs rename to src/daft-functions/src/utf8/contains.rs index 8c63b17be3..9427a45858 100644 --- a/src/daft-dsl/src/functions/utf8/contains.rs +++ b/src/daft-functions/src/utf8/contains.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Contains {} -pub(super) struct ContainsEvaluator {} - -impl FunctionEvaluator for ContainsEvaluator { - fn fn_name(&self) -> &'static str { - "contains" +#[typetag::serde] +impl ScalarUDF for Utf8Contains { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_contains" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for ContainsEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_contains(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for ContainsEvaluator { } } } + +#[must_use] +pub fn utf8_contains(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Contains {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_contains")] +pub fn py_utf8_contains(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_contains(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/startswith.rs b/src/daft-functions/src/utf8/endswith.rs similarity index 52% rename from src/daft-dsl/src/functions/utf8/startswith.rs rename to src/daft-functions/src/utf8/endswith.rs index 01ae5eda7e..6a78b90542 100644 --- a/src/daft-dsl/src/functions/utf8/startswith.rs +++ b/src/daft-functions/src/utf8/endswith.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Endswith {} -pub(super) struct StartswithEvaluator {} - -impl FunctionEvaluator for StartswithEvaluator { - fn fn_name(&self) -> &'static str { - "startswith" +#[typetag::serde] +impl ScalarUDF for Utf8Endswith { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_endswith" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for StartswithEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_startswith(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for StartswithEvaluator { } } } + +#[must_use] +pub fn utf8_endswith(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Endswith {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_endswith")] +pub fn py_utf8_endswith(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_endswith(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-functions/src/utf8/extract.rs b/src/daft-functions/src/utf8/extract.rs new file mode 100644 index 0000000000..679bb2916b --- /dev/null +++ b/src/daft-functions/src/utf8/extract.rs @@ -0,0 +1,74 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Extract { + pub index: usize, +} + +#[typetag::serde] +impl ScalarUDF for Utf8Extract { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_extract" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { + (Ok(data_field), Ok(pattern_field)) => { + match (&data_field.dtype, &pattern_field.dtype) { + (DataType::Utf8, DataType::Utf8) => { + Ok(Field::new(data_field.name, DataType::Utf8)) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to extract to be utf8, but received {data_field} and {pattern_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data, pattern] => data.utf8_extract(pattern, self.index), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_extract(input: ExprRef, pattern: ExprRef, index: usize) -> ExprRef { + ScalarFunction::new(Utf8Extract { index }, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_extract")] +pub fn py_utf8_extract(expr: PyExpr, pattern: PyExpr, index: usize) -> PyResult { + Ok(utf8_extract(expr.into(), pattern.into(), index).into()) +} diff --git a/src/daft-functions/src/utf8/extract_all.rs b/src/daft-functions/src/utf8/extract_all.rs new file mode 100644 index 0000000000..a3c318ad0c --- /dev/null +++ b/src/daft-functions/src/utf8/extract_all.rs @@ -0,0 +1,74 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8ExtractAll { + pub index: usize, +} + +#[typetag::serde] +impl ScalarUDF for Utf8ExtractAll { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_extract_all" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { + (Ok(data_field), Ok(pattern_field)) => { + match (&data_field.dtype, &pattern_field.dtype) { + (DataType::Utf8, DataType::Utf8) => { + Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8)))) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to extractAll to be utf8, but received {data_field} and {pattern_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data, pattern] => data.utf8_extract_all(pattern, self.index), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_extract_all(input: ExprRef, pattern: ExprRef, index: usize) -> ExprRef { + ScalarFunction::new(Utf8ExtractAll { index }, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_extract_all")] +pub fn py_utf8_extract_all(expr: PyExpr, pattern: PyExpr, index: usize) -> PyResult { + Ok(utf8_extract_all(expr.into(), pattern.into(), index).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/find.rs b/src/daft-functions/src/utf8/find.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/find.rs rename to src/daft-functions/src/utf8/find.rs index d184d17c5f..547ce874d0 100644 --- a/src/daft-dsl/src/functions/utf8/find.rs +++ b/src/daft-functions/src/utf8/find.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Find {} -pub(super) struct FindEvaluator {} - -impl FunctionEvaluator for FindEvaluator { - fn fn_name(&self) -> &'static str { - "find" +#[typetag::serde] +impl ScalarUDF for Utf8Find { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_find" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, substr] => match (data.to_field(schema), substr.to_field(schema)) { (Ok(data_field), Ok(substr_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for FindEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, substr] => data.utf8_find(substr), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for FindEvaluator { } } } + +#[must_use] +pub fn utf8_find(input: ExprRef, substr: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Find {}, vec![input, substr]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_find")] +pub fn py_utf8_find(expr: PyExpr, substr: PyExpr) -> PyResult { + Ok(utf8_find(expr.into(), substr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/ilike.rs b/src/daft-functions/src/utf8/ilike.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/ilike.rs rename to src/daft-functions/src/utf8/ilike.rs index 35c0ce1e20..6b94ff2010 100644 --- a/src/daft-dsl/src/functions/utf8/ilike.rs +++ b/src/daft-functions/src/utf8/ilike.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Ilike {} -pub(super) struct IlikeEvaluator {} - -impl FunctionEvaluator for IlikeEvaluator { - fn fn_name(&self) -> &'static str { - "ilike" +#[typetag::serde] +impl ScalarUDF for Utf8Ilike { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_ilike" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for IlikeEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_ilike(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for IlikeEvaluator { } } } + +#[must_use] +pub fn utf8_ilike(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Ilike {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_ilike")] +pub fn py_utf8_ilike(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_ilike(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/left.rs b/src/daft-functions/src/utf8/left.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/left.rs rename to src/daft-functions/src/utf8/left.rs index ffde503901..1b8548fd86 100644 --- a/src/daft-dsl/src/functions/utf8/left.rs +++ b/src/daft-functions/src/utf8/left.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Left {} -pub(super) struct LeftEvaluator {} - -impl FunctionEvaluator for LeftEvaluator { - fn fn_name(&self) -> &'static str { - "left" +#[typetag::serde] +impl ScalarUDF for Utf8Left { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_left" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, nchars] => match (data.to_field(schema), nchars.to_field(schema)) { (Ok(data_field), Ok(nchars_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for LeftEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, nchars] => data.utf8_left(nchars), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for LeftEvaluator { } } } + +#[must_use] +pub fn utf8_left(input: ExprRef, nchars: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Left {}, vec![input, nchars]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_left")] +pub fn py_utf8_left(expr: PyExpr, nchars: PyExpr) -> PyResult { + Ok(utf8_left(expr.into(), nchars.into()).into()) +} diff --git a/src/daft-functions/src/utf8/length.rs b/src/daft-functions/src/utf8/length.rs new file mode 100644 index 0000000000..4010915e3f --- /dev/null +++ b/src/daft-functions/src/utf8/length.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Length {} + +#[typetag::serde] +impl ScalarUDF for Utf8Length { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_length" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::UInt64)), + _ => Err(DaftError::TypeError(format!( + "Expects input to length to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_length(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_length(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Length {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_length")] +pub fn py_utf8_length(expr: PyExpr) -> PyResult { + Ok(utf8_length(expr.into()).into()) +} diff --git a/src/daft-functions/src/utf8/length_bytes.rs b/src/daft-functions/src/utf8/length_bytes.rs new file mode 100644 index 0000000000..71f0d03a01 --- /dev/null +++ b/src/daft-functions/src/utf8/length_bytes.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8LengthBytes {} + +#[typetag::serde] +impl ScalarUDF for Utf8LengthBytes { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_length_bytes" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::UInt64)), + _ => Err(DaftError::TypeError(format!( + "Expects input to length_bytes to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_length_bytes(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_length_bytes(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8LengthBytes {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_length_bytes")] +pub fn py_utf8_length_bytes(expr: PyExpr) -> PyResult { + Ok(utf8_length_bytes(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/like.rs b/src/daft-functions/src/utf8/like.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/like.rs rename to src/daft-functions/src/utf8/like.rs index a2a2a96def..cdd47275f6 100644 --- a/src/daft-dsl/src/functions/utf8/like.rs +++ b/src/daft-functions/src/utf8/like.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Like {} -pub(super) struct LikeEvaluator {} - -impl FunctionEvaluator for LikeEvaluator { - fn fn_name(&self) -> &'static str { - "like" +#[typetag::serde] +impl ScalarUDF for Utf8Like { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_like" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for LikeEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_like(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for LikeEvaluator { } } } + +#[must_use] +pub fn utf8_like(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Like {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_like")] +pub fn py_utf8_like(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_like(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-functions/src/utf8/lower.rs b/src/daft-functions/src/utf8/lower.rs new file mode 100644 index 0000000000..b98594123b --- /dev/null +++ b/src/daft-functions/src/utf8/lower.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Lower {} + +#[typetag::serde] +impl ScalarUDF for Utf8Lower { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_lower" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to lower to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_lower(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_lower(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Lower {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_lower")] +pub fn py_utf8_lower(expr: PyExpr) -> PyResult { + Ok(utf8_lower(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/lpad.rs b/src/daft-functions/src/utf8/lpad.rs similarity index 52% rename from src/daft-dsl/src/functions/utf8/lpad.rs rename to src/daft-functions/src/utf8/lpad.rs index 9880568aed..090aa260fc 100644 --- a/src/daft-dsl/src/functions/utf8/lpad.rs +++ b/src/daft-functions/src/utf8/lpad.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Lpad {} -pub(super) struct LpadEvaluator {} - -impl FunctionEvaluator for LpadEvaluator { - fn fn_name(&self) -> &'static str { - "lpad" +#[typetag::serde] +impl ScalarUDF for Utf8Lpad { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_lpad" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, length, pad] => { let data = data.to_field(schema)?; @@ -35,7 +45,7 @@ impl FunctionEvaluator for LpadEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, length, pad] => data.utf8_lpad(length, pad), _ => Err(DaftError::ValueError(format!( @@ -45,3 +55,20 @@ impl FunctionEvaluator for LpadEvaluator { } } } + +#[must_use] +pub fn utf8_lpad(input: ExprRef, length: ExprRef, pad: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Lpad {}, vec![input, length, pad]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_lpad")] +pub fn py_utf8_lpad(expr: PyExpr, length: PyExpr, pad: PyExpr) -> PyResult { + Ok(utf8_lpad(expr.into(), length.into(), pad.into()).into()) +} diff --git a/src/daft-functions/src/utf8/lstrip.rs b/src/daft-functions/src/utf8/lstrip.rs new file mode 100644 index 0000000000..f3d25a9179 --- /dev/null +++ b/src/daft-functions/src/utf8/lstrip.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Lstrip {} + +#[typetag::serde] +impl ScalarUDF for Utf8Lstrip { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_lstrip" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to lstrip to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_lstrip(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_lstrip(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Lstrip {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_lstrip")] +pub fn py_utf8_lstrip(expr: PyExpr) -> PyResult { + Ok(utf8_lstrip(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/match_.rs b/src/daft-functions/src/utf8/match_.rs similarity index 52% rename from src/daft-dsl/src/functions/utf8/match_.rs rename to src/daft-functions/src/utf8/match_.rs index 7455aca17c..4791861806 100644 --- a/src/daft-dsl/src/functions/utf8/match_.rs +++ b/src/daft-functions/src/utf8/match_.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Match {} -pub(super) struct MatchEvaluator {} - -impl FunctionEvaluator for MatchEvaluator { - fn fn_name(&self) -> &'static str { - "match" +#[typetag::serde] +impl ScalarUDF for Utf8Match { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_match" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { (Ok(data_field), Ok(pattern_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for MatchEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, pattern] => data.utf8_match(pattern), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for MatchEvaluator { } } } + +#[must_use] +pub fn utf8_match(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Match {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_match")] +pub fn py_utf8_match(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_match(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-functions/src/utf8/mod.rs b/src/daft-functions/src/utf8/mod.rs new file mode 100644 index 0000000000..c6c4105b55 --- /dev/null +++ b/src/daft-functions/src/utf8/mod.rs @@ -0,0 +1,111 @@ +mod capitalize; +mod contains; +mod endswith; +mod extract; +mod extract_all; +mod find; +mod ilike; +mod left; +mod length; +mod length_bytes; +mod like; +mod lower; +mod lpad; +mod lstrip; +mod match_; +mod normalize; +mod repeat; +mod replace; +mod reverse; +mod right; +mod rpad; +mod rstrip; +mod split; +mod startswith; +mod substr; +mod to_date; +mod to_datetime; +mod upper; + +pub use capitalize::{utf8_capitalize as capitalize, Utf8Capitalize}; +pub use contains::{utf8_contains as contains, Utf8Contains}; +pub use endswith::{utf8_endswith as endswith, Utf8Endswith}; +pub use extract::{utf8_extract as extract, Utf8Extract}; +pub use extract_all::{utf8_extract_all as extract_all, Utf8ExtractAll}; +pub use find::{utf8_find as find, Utf8Find}; +pub use ilike::{utf8_ilike as ilike, Utf8Ilike}; +pub use left::{utf8_left as left, Utf8Left}; +pub use length::{utf8_length as length, Utf8Length}; +pub use length_bytes::{utf8_length_bytes as length_bytes, Utf8LengthBytes}; +pub use like::{utf8_like as like, Utf8Like}; +pub use lower::{utf8_lower as lower, Utf8Lower}; +pub use lpad::{utf8_lpad as lpad, Utf8Lpad}; +pub use lstrip::{utf8_lstrip as lstrip, Utf8Lstrip}; +pub use match_::{utf8_match as match_, Utf8Match}; +pub use normalize::{utf8_normalize as normalize, Utf8Normalize}; +#[cfg(feature = "python")] +use pyo3::prelude::*; +pub use repeat::{utf8_repeat as repeat, Utf8Repeat}; +pub use replace::{utf8_replace as replace, Utf8Replace}; +pub use reverse::{utf8_reverse as reverse, Utf8Reverse}; +pub use right::{utf8_right as right, Utf8Right}; +pub use rpad::{utf8_rpad as rpad, Utf8Rpad}; +pub use rstrip::{utf8_rstrip as rstrip, Utf8Rstrip}; +pub use split::{utf8_split as split, Utf8Split}; +pub use startswith::{utf8_startswith as startswith, Utf8Startswith}; +pub use substr::{utf8_substr as substr, Utf8Substr}; +pub use to_date::{utf8_to_date as to_date, Utf8ToDate}; +pub use to_datetime::{utf8_to_datetime as to_datetime, Utf8ToDatetime}; +pub use upper::{utf8_upper as upper, Utf8Upper}; + +#[cfg(feature = "python")] +pub fn register_modules(parent: &Bound) -> PyResult<()> { + parent.add_function(wrap_pyfunction_bound!( + capitalize::py_utf8_capitalize, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(contains::py_utf8_contains, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(endswith::py_utf8_endswith, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(extract::py_utf8_extract, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + extract_all::py_utf8_extract_all, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(find::py_utf8_find, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(like::py_utf8_like, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(ilike::py_utf8_ilike, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(left::py_utf8_left, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(length::py_utf8_length, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + length_bytes::py_utf8_length_bytes, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(lower::py_utf8_lower, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(lpad::py_utf8_lpad, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(lstrip::py_utf8_lstrip, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(match_::py_utf8_match, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + normalize::py_utf8_normalize, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(repeat::py_utf8_repeat, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(replace::py_utf8_replace, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(reverse::py_utf8_reverse, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(right::py_utf8_right, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(rpad::py_utf8_rpad, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(rstrip::py_utf8_rstrip, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(split::py_utf8_split, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + startswith::py_utf8_startswith, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(substr::py_utf8_substr, parent)?)?; + parent.add_function(wrap_pyfunction_bound!(to_date::py_utf8_to_date, parent)?)?; + parent.add_function(wrap_pyfunction_bound!( + to_datetime::py_utf8_to_datetime, + parent + )?)?; + parent.add_function(wrap_pyfunction_bound!(upper::py_utf8_upper, parent)?)?; + + Ok(()) +} diff --git a/src/daft-functions/src/utf8/normalize.rs b/src/daft-functions/src/utf8/normalize.rs new file mode 100644 index 0000000000..744ba78035 --- /dev/null +++ b/src/daft-functions/src/utf8/normalize.rs @@ -0,0 +1,87 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Normalize { + pub opts: Utf8NormalizeOptions, +} + +#[typetag::serde] +impl ScalarUDF for Utf8Normalize { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_normalize" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to normalize to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_normalize(self.opts), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_normalize(input: ExprRef, opts: Utf8NormalizeOptions) -> ExprRef { + ScalarFunction::new(Utf8Normalize { opts }, vec![input]).into() +} + +use daft_core::array::ops::Utf8NormalizeOptions; +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_normalize")] +pub fn py_utf8_normalize( + expr: PyExpr, + remove_punct: bool, + lowercase: bool, + nfd_unicode: bool, + white_space: bool, +) -> PyResult { + Ok(utf8_normalize( + expr.into(), + Utf8NormalizeOptions { + remove_punct, + lowercase, + nfd_unicode, + white_space, + }, + ) + .into()) +} diff --git a/src/daft-dsl/src/functions/utf8/repeat.rs b/src/daft-functions/src/utf8/repeat.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/repeat.rs rename to src/daft-functions/src/utf8/repeat.rs index c321a6920a..a864fd701b 100644 --- a/src/daft-dsl/src/functions/utf8/repeat.rs +++ b/src/daft-functions/src/utf8/repeat.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Repeat {} -pub(super) struct RepeatEvaluator {} - -impl FunctionEvaluator for RepeatEvaluator { - fn fn_name(&self) -> &'static str { - "repeat" +#[typetag::serde] +impl ScalarUDF for Utf8Repeat { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_repeat" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, ntimes] => match (data.to_field(schema), ntimes.to_field(schema)) { (Ok(data_field), Ok(ntimes_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for RepeatEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, ntimes] => data.utf8_repeat(ntimes), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for RepeatEvaluator { } } } + +#[must_use] +pub fn utf8_repeat(input: ExprRef, ntimes: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Repeat {}, vec![input, ntimes]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_repeat")] +pub fn py_utf8_repeat(expr: PyExpr, ntimes: PyExpr) -> PyResult { + Ok(utf8_repeat(expr.into(), ntimes.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/replace.rs b/src/daft-functions/src/utf8/replace.rs similarity index 50% rename from src/daft-dsl/src/functions/utf8/replace.rs rename to src/daft-functions/src/utf8/replace.rs index 022f98ac17..cfbb612611 100644 --- a/src/daft-dsl/src/functions/utf8/replace.rs +++ b/src/daft-functions/src/utf8/replace.rs @@ -1,17 +1,29 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::{super::FunctionEvaluator, Utf8Expr}; -use crate::{functions::FunctionExpr, ExprRef}; - -pub(super) struct ReplaceEvaluator {} +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Replace { + pub regex: bool, +} -impl FunctionEvaluator for ReplaceEvaluator { - fn fn_name(&self) -> &'static str { - "replace" +#[typetag::serde] +impl ScalarUDF for Utf8Replace { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_replace" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, pattern, replacement] => match ( data.to_field(schema), @@ -37,15 +49,9 @@ impl FunctionEvaluator for ReplaceEvaluator { } } - fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { - [data, pattern, replacement] => { - let regex = match expr { - FunctionExpr::Utf8(Utf8Expr::Replace(regex)) => regex, - _ => panic!("Expected Utf8 Replace Expr, got {expr}"), - }; - data.utf8_replace(pattern, replacement, *regex) - } + [data, pattern, replacement] => data.utf8_replace(pattern, replacement, self.regex), _ => Err(DaftError::ValueError(format!( "Expected 3 input args, got {}", inputs.len() @@ -53,3 +59,30 @@ impl FunctionEvaluator for ReplaceEvaluator { } } } + +#[must_use] +pub fn utf8_replace( + input: ExprRef, + pattern: ExprRef, + replacement: ExprRef, + regex: bool, +) -> ExprRef { + ScalarFunction::new(Utf8Replace { regex }, vec![input, pattern, replacement]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_replace")] +pub fn py_utf8_replace( + expr: PyExpr, + pattern: PyExpr, + replacement: PyExpr, + regex: bool, +) -> PyResult { + Ok(utf8_replace(expr.into(), pattern.into(), replacement.into(), regex).into()) +} diff --git a/src/daft-functions/src/utf8/reverse.rs b/src/daft-functions/src/utf8/reverse.rs new file mode 100644 index 0000000000..07227f578e --- /dev/null +++ b/src/daft-functions/src/utf8/reverse.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Reverse {} + +#[typetag::serde] +impl ScalarUDF for Utf8Reverse { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_reverse" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to reverse to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_reverse(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_reverse(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Reverse {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_reverse")] +pub fn py_utf8_reverse(expr: PyExpr) -> PyResult { + Ok(utf8_reverse(expr.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/right.rs b/src/daft-functions/src/utf8/right.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/right.rs rename to src/daft-functions/src/utf8/right.rs index 892c0f7341..b32b4080c2 100644 --- a/src/daft-dsl/src/functions/utf8/right.rs +++ b/src/daft-functions/src/utf8/right.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Right {} -pub(super) struct RightEvaluator {} - -impl FunctionEvaluator for RightEvaluator { - fn fn_name(&self) -> &'static str { - "right" +#[typetag::serde] +impl ScalarUDF for Utf8Right { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_right" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, nchars] => match (data.to_field(schema), nchars.to_field(schema)) { (Ok(data_field), Ok(nchars_field)) => { @@ -33,7 +43,7 @@ impl FunctionEvaluator for RightEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, nchars] => data.utf8_right(nchars), _ => Err(DaftError::ValueError(format!( @@ -43,3 +53,20 @@ impl FunctionEvaluator for RightEvaluator { } } } + +#[must_use] +pub fn utf8_right(input: ExprRef, nchars: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Right {}, vec![input, nchars]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_right")] +pub fn py_utf8_right(expr: PyExpr, nchars: PyExpr) -> PyResult { + Ok(utf8_right(expr.into(), nchars.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/rpad.rs b/src/daft-functions/src/utf8/rpad.rs similarity index 52% rename from src/daft-dsl/src/functions/utf8/rpad.rs rename to src/daft-functions/src/utf8/rpad.rs index f7c0769fac..01a9766203 100644 --- a/src/daft-dsl/src/functions/utf8/rpad.rs +++ b/src/daft-functions/src/utf8/rpad.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Rpad {} -pub(super) struct RpadEvaluator {} - -impl FunctionEvaluator for RpadEvaluator { - fn fn_name(&self) -> &'static str { - "rpad" +#[typetag::serde] +impl ScalarUDF for Utf8Rpad { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_rpad" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, length, pad] => { let data = data.to_field(schema)?; @@ -35,7 +45,7 @@ impl FunctionEvaluator for RpadEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, length, pad] => data.utf8_rpad(length, pad), _ => Err(DaftError::ValueError(format!( @@ -45,3 +55,20 @@ impl FunctionEvaluator for RpadEvaluator { } } } + +#[must_use] +pub fn utf8_rpad(input: ExprRef, length: ExprRef, pad: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Rpad {}, vec![input, length, pad]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_rpad")] +pub fn py_utf8_rpad(expr: PyExpr, length: PyExpr, pad: PyExpr) -> PyResult { + Ok(utf8_rpad(expr.into(), length.into(), pad.into()).into()) +} diff --git a/src/daft-functions/src/utf8/rstrip.rs b/src/daft-functions/src/utf8/rstrip.rs new file mode 100644 index 0000000000..df8044ec76 --- /dev/null +++ b/src/daft-functions/src/utf8/rstrip.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Rstrip {} + +#[typetag::serde] +impl ScalarUDF for Utf8Rstrip { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_rstrip" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to rstrip to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_rstrip(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_rstrip(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Rstrip {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_rstrip")] +pub fn py_utf8_rstrip(expr: PyExpr) -> PyResult { + Ok(utf8_rstrip(expr.into()).into()) +} diff --git a/src/daft-functions/src/utf8/split.rs b/src/daft-functions/src/utf8/split.rs new file mode 100644 index 0000000000..f8a1725b66 --- /dev/null +++ b/src/daft-functions/src/utf8/split.rs @@ -0,0 +1,74 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Split { + pub regex: bool, +} + +#[typetag::serde] +impl ScalarUDF for Utf8Split { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_split" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { + (Ok(data_field), Ok(pattern_field)) => { + match (&data_field.dtype, &pattern_field.dtype) { + (DataType::Utf8, DataType::Utf8) => { + Ok(Field::new(data_field.name, DataType::List(Box::new(DataType::Utf8)))) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to split to be utf8, but received {data_field} and {pattern_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data, pattern] => data.utf8_split(pattern, self.regex), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_split(input: ExprRef, pattern: ExprRef, regex: bool) -> ExprRef { + ScalarFunction::new(Utf8Split { regex }, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_split")] +pub fn py_utf8_split(expr: PyExpr, pattern: PyExpr, regex: bool) -> PyResult { + Ok(utf8_split(expr.into(), pattern.into(), regex).into()) +} diff --git a/src/daft-functions/src/utf8/startswith.rs b/src/daft-functions/src/utf8/startswith.rs new file mode 100644 index 0000000000..98a3ce4556 --- /dev/null +++ b/src/daft-functions/src/utf8/startswith.rs @@ -0,0 +1,72 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Startswith {} + +#[typetag::serde] +impl ScalarUDF for Utf8Startswith { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_startswith" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data, pattern] => match (data.to_field(schema), pattern.to_field(schema)) { + (Ok(data_field), Ok(pattern_field)) => { + match (&data_field.dtype, &pattern_field.dtype) { + (DataType::Utf8, DataType::Utf8) => { + Ok(Field::new(data_field.name, DataType::Boolean)) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to startswith to be utf8, but received {data_field} and {pattern_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data, pattern] => data.utf8_startswith(pattern), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_startswith(input: ExprRef, pattern: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Startswith {}, vec![input, pattern]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_startswith")] +pub fn py_utf8_startswith(expr: PyExpr, pattern: PyExpr) -> PyResult { + Ok(utf8_startswith(expr.into(), pattern.into()).into()) +} diff --git a/src/daft-dsl/src/functions/utf8/substr.rs b/src/daft-functions/src/utf8/substr.rs similarity index 53% rename from src/daft-dsl/src/functions/utf8/substr.rs rename to src/daft-functions/src/utf8/substr.rs index d2ec60256a..a8eac19215 100644 --- a/src/daft-dsl/src/functions/utf8/substr.rs +++ b/src/daft-functions/src/utf8/substr.rs @@ -1,17 +1,27 @@ use common_error::{DaftError, DaftResult}; -use daft_core::prelude::*; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; -use super::super::FunctionEvaluator; -use crate::{functions::FunctionExpr, ExprRef}; +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Substr {} -pub(super) struct SubstrEvaluator {} - -impl FunctionEvaluator for SubstrEvaluator { - fn fn_name(&self) -> &'static str { - "substr" +#[typetag::serde] +impl ScalarUDF for Utf8Substr { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_substr" } - fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { match inputs { [data, start, length] => { let data = data.to_field(schema)?; @@ -37,7 +47,7 @@ impl FunctionEvaluator for SubstrEvaluator { } } - fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { [data, start, length] => data.utf8_substr(start, length), _ => Err(DaftError::ValueError(format!( @@ -47,3 +57,20 @@ impl FunctionEvaluator for SubstrEvaluator { } } } + +#[must_use] +pub fn utf8_substr(input: ExprRef, start: ExprRef, length: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Substr {}, vec![input, start, length]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_substr")] +pub fn py_utf8_substr(expr: PyExpr, start: PyExpr, length: PyExpr) -> PyResult { + Ok(utf8_substr(expr.into(), start.into(), length.into()).into()) +} diff --git a/src/daft-functions/src/utf8/to_date.rs b/src/daft-functions/src/utf8/to_date.rs new file mode 100644 index 0000000000..5b2f176783 --- /dev/null +++ b/src/daft-functions/src/utf8/to_date.rs @@ -0,0 +1,77 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8ToDate { + pub format: String, +} + +#[typetag::serde] +impl ScalarUDF for Utf8ToDate { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_to_date" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Date)), + _ => Err(DaftError::TypeError(format!( + "Expects inputs to to_date to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_to_date(&self.format), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_to_date>(input: ExprRef, format: S) -> ExprRef { + ScalarFunction::new( + Utf8ToDate { + format: format.into(), + }, + vec![input], + ) + .into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_to_date")] +pub fn py_utf8_to_date(expr: PyExpr, format: &str) -> PyResult { + Ok(utf8_to_date::<&str>(expr.into(), format).into()) +} diff --git a/src/daft-functions/src/utf8/to_datetime.rs b/src/daft-functions/src/utf8/to_datetime.rs new file mode 100644 index 0000000000..d642799ed7 --- /dev/null +++ b/src/daft-functions/src/utf8/to_datetime.rs @@ -0,0 +1,90 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + datatypes::infer_timeunit_from_format_string, + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8ToDatetime { + pub format: String, + pub timezone: Option, +} + +#[typetag::serde] +impl ScalarUDF for Utf8ToDatetime { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_to_datetime" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => { + let timeunit = infer_timeunit_from_format_string(&self.format); + Ok(Field::new( + data_field.name, + DataType::Timestamp(timeunit, self.timezone.clone()), + )) + } + _ => Err(DaftError::TypeError(format!( + "Expects inputs to to_datetime to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_to_datetime(&self.format, self.timezone.as_deref()), + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_to_datetime>( + input: ExprRef, + format: S, + timezone: Option, +) -> ExprRef { + ScalarFunction::new( + Utf8ToDatetime { + format: format.into(), + timezone: timezone.map(|s| s.into()), + }, + vec![input], + ) + .into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; + +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_to_datetime")] +pub fn py_utf8_to_datetime(expr: PyExpr, format: &str, timezone: Option<&str>) -> PyResult { + Ok(utf8_to_datetime::<&str>(expr.into(), format, timezone).into()) +} diff --git a/src/daft-functions/src/utf8/upper.rs b/src/daft-functions/src/utf8/upper.rs new file mode 100644 index 0000000000..cbaf2fff8b --- /dev/null +++ b/src/daft-functions/src/utf8/upper.rs @@ -0,0 +1,68 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{ + prelude::{DataType, Field, Schema}, + series::Series, +}; +use daft_dsl::{ + functions::{ScalarFunction, ScalarUDF}, + ExprRef, +}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub struct Utf8Upper {} + +#[typetag::serde] +impl ScalarUDF for Utf8Upper { + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn name(&self) -> &'static str { + "utf8_upper" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)), + _ => Err(DaftError::TypeError(format!( + "Expects input to upper to be utf8, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series]) -> DaftResult { + match inputs { + [data] => data.utf8_upper(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} + +#[must_use] +pub fn utf8_upper(input: ExprRef) -> ExprRef { + ScalarFunction::new(Utf8Upper {}, vec![input]).into() +} + +#[cfg(feature = "python")] +use { + daft_dsl::python::PyExpr, + pyo3::{pyfunction, PyResult}, +}; +#[cfg(feature = "python")] +#[pyfunction] +#[pyo3(name = "utf8_upper")] +pub fn py_utf8_upper(expr: PyExpr) -> PyResult { + Ok(utf8_upper(expr.into()).into()) +} diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs index 26c7470fa7..6bcdba360b 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -34,11 +34,8 @@ mod test { use common_display::mermaid::{MermaidDisplay, MermaidDisplayOptions, SubgraphOptions}; use common_error::DaftResult; use daft_core::prelude::*; - use daft_dsl::{ - col, - functions::utf8::{endswith, startswith}, - lit, - }; + use daft_dsl::{col, lit}; + use daft_functions::utf8::{endswith, startswith}; use pretty_assertions::assert_eq; use crate::{ diff --git a/src/daft-sql/src/modules/utf8.rs b/src/daft-sql/src/modules/utf8.rs index 084da08962..d2a7a0ef01 100644 --- a/src/daft-sql/src/modules/utf8.rs +++ b/src/daft-sql/src/modules/utf8.rs @@ -1,10 +1,6 @@ use daft_core::array::ops::Utf8NormalizeOptions; use daft_dsl::{ binary_op, - functions::{ - self, - utf8::{normalize, Utf8Expr}, - }, ExprRef, LiteralValue, Operator, }; use daft_functions::{ @@ -14,51 +10,168 @@ use daft_functions::{ use super::SQLModule; use crate::{ - ensure, error::{PlannerError, SQLPlannerResult}, functions::{SQLFunction, SQLFunctionArguments}, - invalid_operation_err, unsupported_sql_err, + invalid_operation_err, }; +macro_rules! utf8_function_one_argument { + ($name:ident, $sql_name:expr, $func:expr, $doc:expr, $arg_name:expr) => { + pub struct $name; + + impl SQLFunction for $name { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok($func(input)) + } + _ => invalid_operation_err!(concat!( + "invalid arguments for ", + $sql_name, + ". Expected ", + $sql_name, + "(", + stringify!($arg_name), + ")" + )), + } + } + + fn docstrings(&self, _alias: &str) -> String { + $doc.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[$arg_name] + } + } + }; +} + +macro_rules! utf8_function_two_arguments { + ($name:ident, $sql_name:expr, $func:expr, $doc:expr, $arg_name_1:expr, $arg_name_2:expr) => { + pub struct $name; + + impl SQLFunction for $name { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input1, input2] => { + let input1 = planner.plan_function_arg(input1)?; + let input2 = planner.plan_function_arg(input2)?; + Ok($func(input1, input2)) + } + _ => invalid_operation_err!(concat!( + "invalid arguments for ", + $sql_name, + ". Expected ", + $sql_name, + "(", + stringify!($arg_name_1), + ", ", + stringify!($arg_name_2), + ")" + )), + } + } + + fn docstrings(&self, _alias: &str) -> String { + $doc.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[$arg_name_1, $arg_name_2] + } + } + }; +} + +macro_rules! utf8_function_three_arguments { + ($name:ident, $sql_name:expr, $func:expr, $doc:expr, $arg_name_1:expr, $arg_name_2:expr, $arg_name_3:expr) => { + pub struct $name; + + impl SQLFunction for $name { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input1, input2, input3] => { + let input1 = planner.plan_function_arg(input1)?; + let input2 = planner.plan_function_arg(input2)?; + let input3 = planner.plan_function_arg(input3)?; + Ok($func(input1, input2, input3)) + } + _ => invalid_operation_err!(concat!( + "invalid arguments for ", + $sql_name, + ". Expected ", + $sql_name, + "(", + stringify!($arg_name_1), + ", ", + stringify!($arg_name_2), + ", ", + stringify!($arg_name_3), + ")" + )), + } + } + + fn docstrings(&self, _alias: &str) -> String { + $doc.to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[$arg_name_1, $arg_name_2, $arg_name_3] + } + } + }; +} + pub struct SQLModuleUtf8; impl SQLModule for SQLModuleUtf8 { fn register(parent: &mut crate::functions::SQLFunctions) { - use Utf8Expr::{ - Capitalize, Contains, EndsWith, Extract, ExtractAll, Find, Left, Length, LengthBytes, - Lower, Lpad, Lstrip, Match, Repeat, Replace, Reverse, Right, Rpad, Rstrip, Split, - StartsWith, ToDate, ToDatetime, Upper, - }; - parent.add_fn("ends_with", EndsWith); - parent.add_fn("starts_with", StartsWith); - parent.add_fn("contains", Contains); - parent.add_fn("split", Split(false)); + parent.add_fn("ends_with", SQLUtf8EndsWith); + parent.add_fn("starts_with", SQLUtf8StartsWith); + parent.add_fn("contains", SQLUtf8Contains); + parent.add_fn("split", SQLUtf8Split); // TODO add split variants // parent.add("split", f(Split(false))); - parent.add_fn("regexp_match", Match); - parent.add_fn("regexp_extract", Extract(0)); - parent.add_fn("regexp_extract_all", ExtractAll(0)); - parent.add_fn("regexp_replace", Replace(true)); - parent.add_fn("regexp_split", Split(true)); + parent.add_fn("regexp_match", SQLUtf8RegexpMatch); + parent.add_fn("regexp_extract", SQLUtf8RegexpExtract); + parent.add_fn("regexp_extract_all", SQLUtf8RegexpExtractAll); + parent.add_fn("regexp_replace", SQLUtf8RegexpReplace); + parent.add_fn("regexp_split", SQLUtf8RegexpSplit); // TODO add replace variants // parent.add("replace", f(Replace(false))); - parent.add_fn("length", Length); - parent.add_fn("length_bytes", LengthBytes); - parent.add_fn("lower", Lower); - parent.add_fn("upper", Upper); - parent.add_fn("lstrip", Lstrip); - parent.add_fn("rstrip", Rstrip); - parent.add_fn("reverse", Reverse); - parent.add_fn("capitalize", Capitalize); - parent.add_fn("left", Left); - parent.add_fn("right", Right); - parent.add_fn("find", Find); - parent.add_fn("rpad", Rpad); - parent.add_fn("lpad", Lpad); - parent.add_fn("repeat", Repeat); - - parent.add_fn("to_date", ToDate(String::new())); - parent.add_fn("to_datetime", ToDatetime(String::new(), None)); + parent.add_fn("length", SQLUtf8Length); + parent.add_fn("length_bytes", SQLUtf8LengthBytes); + parent.add_fn("lower", SQLUtf8Lower); + parent.add_fn("upper", SQLUtf8Upper); + parent.add_fn("lstrip", SQLUtf8Lstrip); + parent.add_fn("rstrip", SQLUtf8Rstrip); + parent.add_fn("reverse", SQLUtf8Reverse); + parent.add_fn("capitalize", SQLUtf8Capitalize); + parent.add_fn("left", SQLUtf8Left); + parent.add_fn("right", SQLUtf8Right); + parent.add_fn("find", SQLUtf8Find); + parent.add_fn("rpad", SQLUtf8Rpad); + parent.add_fn("lpad", SQLUtf8Lpad); + parent.add_fn("repeat", SQLUtf8Repeat); + + parent.add_fn("to_date", SQLUtf8ToDate); + parent.add_fn("to_datetime", SQLUtf8ToDatetime); parent.add_fn("count_matches", SQLCountMatches); parent.add_fn("normalize", SQLNormalize); parent.add_fn("tokenize_encode", SQLTokenizeEncode); @@ -67,255 +180,334 @@ impl SQLModule for SQLModuleUtf8 { } } -impl SQLModuleUtf8 {} - -impl SQLFunction for Utf8Expr { +utf8_function_two_arguments!( + SQLUtf8EndsWith, + "ends_with", + daft_functions::utf8::endswith, + "Returns true if the string ends with the specified substring", + "string_input", + "substring" +); + +utf8_function_two_arguments!( + SQLUtf8StartsWith, + "starts_with", + daft_functions::utf8::startswith, + "Returns true if the string starts with the specified substring", + "string_input", + "substring" +); + +utf8_function_two_arguments!( + SQLUtf8Contains, + "contains", + daft_functions::utf8::contains, + "Returns true if the string contains the specified substring", + "string_input", + "substring" +); + +utf8_function_two_arguments!( + SQLUtf8Split, + "split", + |input, pattern| daft_functions::utf8::split(input, pattern, false), + "Splits the string by the specified delimiter and returns an array of substrings", + "string_input", + "delimiter" +); + +utf8_function_two_arguments!( + SQLUtf8RegexpMatch, + "regexp_match", + daft_functions::utf8::match_, + "Returns true if the string matches the specified regular expression pattern", + "string_input", + "pattern" +); + +utf8_function_three_arguments!( + SQLUtf8RegexpReplace, + "regexp_replace", + |input, pattern, replacement| daft_functions::utf8::replace(input, pattern, replacement, true), + "Replaces all occurrences of a substring with a new string", + "string_input", + "pattern", + "replacement" +); + +utf8_function_two_arguments!( + SQLUtf8RegexpSplit, + "regexp_split", + |input, pattern| daft_functions::utf8::split(input, pattern, true), + "Splits the string by the specified delimiter and returns an array of substrings", + "string_input", + "delimiter" +); + +utf8_function_one_argument!( + SQLUtf8Length, + "length", + daft_functions::utf8::length, + "Returns the length of the string", + "string_input" +); + +utf8_function_one_argument!( + SQLUtf8LengthBytes, + "length_bytes", + daft_functions::utf8::length_bytes, + "Returns the length of the string in bytes", + "string_input" +); + +utf8_function_one_argument!( + SQLUtf8Lower, + "lower", + daft_functions::utf8::lower, + "Converts the string to lowercase", + "string_input" +); + +utf8_function_one_argument!( + SQLUtf8Upper, + "upper", + daft_functions::utf8::upper, + "Converts the string to uppercase", + "string_input" +); + +utf8_function_one_argument!( + SQLUtf8Lstrip, + "lstrip", + daft_functions::utf8::lstrip, + "Removes leading whitespace from the string", + "string_input" +); + +utf8_function_one_argument!( + SQLUtf8Rstrip, + "rstrip", + daft_functions::utf8::rstrip, + "Removes trailing whitespace from the string", + "string_input" +); + +utf8_function_one_argument!( + SQLUtf8Reverse, + "reverse", + daft_functions::utf8::reverse, + "Reverses the order of characters in the string", + "string_input" +); + +utf8_function_one_argument!( + SQLUtf8Capitalize, + "capitalize", + daft_functions::utf8::capitalize, + "Capitalizes the first character of the string", + "string_input" +); + +utf8_function_two_arguments!( + SQLUtf8Left, + "left", + daft_functions::utf8::left, + "Returns the specified number of leftmost characters from the string", + "string_input", + "length" +); + +utf8_function_two_arguments!( + SQLUtf8Right, + "right", + daft_functions::utf8::right, + "Returns the specified number of rightmost characters from the string", + "string_input", + "length" +); + +utf8_function_two_arguments!( + SQLUtf8Find, + "find", + daft_functions::utf8::find, + "Returns the index of the first occurrence of a substring within the string", + "string_input", + "substring" +); + +utf8_function_three_arguments!( + SQLUtf8Rpad, + "rpad", + daft_functions::utf8::rpad, + "Pads the string on the right side with the specified string until it reaches the specified length", + "string_input", "length", "pad" +); + +utf8_function_three_arguments!( + SQLUtf8Lpad, + "lpad", + daft_functions::utf8::lpad, + "Pads the string on the left side with the specified string until it reaches the specified length", + "string_input", "length", "pad" +); + +utf8_function_two_arguments!( + SQLUtf8Repeat, + "repeat", + daft_functions::utf8::repeat, + "Repeats the string the specified number of times", + "string_input", + "count" +); + +pub struct SQLUtf8RegexpExtract; + +impl SQLFunction for SQLUtf8RegexpExtract { fn to_expr( &self, inputs: &[sqlparser::ast::FunctionArg], planner: &crate::planner::SQLPlanner, ) -> SQLPlannerResult { - let inputs = self.args_to_expr_unnamed(inputs, planner)?; - to_expr(self, &inputs) + match inputs { + [input, pattern] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + Ok(daft_functions::utf8::extract(input, pattern, 0)) + } + [input, pattern, idx] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + let idx = planner.plan_function_arg(idx)?.as_literal().and_then(LiteralValue::as_i64).ok_or_else(|| { + PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {idx:?}")) + })? as usize; + Ok(daft_functions::utf8::extract(input, pattern, idx)) + } + _ => invalid_operation_err!("regexp_extract takes exactly two or three arguments"), + } } fn docstrings(&self, _alias: &str) -> String { - match self { - Self::EndsWith => "Returns true if the string ends with the specified substring".to_string(), - Self::StartsWith => "Returns true if the string starts with the specified substring".to_string(), - Self::Contains => "Returns true if the string contains the specified substring".to_string(), - Self::Split(_) => "Splits the string by the specified delimiter and returns an array of substrings".to_string(), - Self::Match => "Returns true if the string matches the specified regular expression pattern".to_string(), - Self::Extract(_) => "Extracts the first substring that matches the specified regular expression pattern".to_string(), - Self::ExtractAll(_) => "Extracts all substrings that match the specified regular expression pattern".to_string(), - Self::Replace(_) => "Replaces all occurrences of a substring with a new string".to_string(), - Self::Like => "Returns true if the string matches the specified SQL LIKE pattern".to_string(), - Self::Ilike => "Returns true if the string matches the specified SQL LIKE pattern (case-insensitive)".to_string(), - Self::Length => "Returns the length of the string".to_string(), - Self::Lower => "Converts the string to lowercase".to_string(), - Self::Upper => "Converts the string to uppercase".to_string(), - Self::Lstrip => "Removes leading whitespace from the string".to_string(), - Self::Rstrip => "Removes trailing whitespace from the string".to_string(), - Self::Reverse => "Reverses the order of characters in the string".to_string(), - Self::Capitalize => "Capitalizes the first character of the string".to_string(), - Self::Left => "Returns the specified number of leftmost characters from the string".to_string(), - Self::Right => "Returns the specified number of rightmost characters from the string".to_string(), - Self::Find => "Returns the index of the first occurrence of a substring within the string".to_string(), - Self::Rpad => "Pads the string on the right side with the specified string until it reaches the specified length".to_string(), - Self::Lpad => "Pads the string on the left side with the specified string until it reaches the specified length".to_string(), - Self::Repeat => "Repeats the string the specified number of times".to_string(), - Self::Substr => "Returns a substring of the string starting at the specified position and length".to_string(), - Self::ToDate(_) => "Parses the string as a date using the specified format.".to_string(), - Self::ToDatetime(_, _) => "Parses the string as a datetime using the specified format.".to_string(), - Self::LengthBytes => "Returns the length of the string in bytes".to_string(), - Self::Normalize(_) => "Normalizes a string for more useful deduplication and data cleaning".to_string(), - } + "Extracts the first substring that matches the specified regular expression pattern" + .to_string() } fn arg_names(&self) -> &'static [&'static str] { - match self { - Self::EndsWith => &["string_input", "substring"], - Self::StartsWith => &["string_input", "substring"], - Self::Contains => &["string_input", "substring"], - Self::Split(_) => &["string_input", "delimiter"], - Self::Match => &["string_input", "pattern"], - Self::Extract(_) => &["string_input", "pattern"], - Self::ExtractAll(_) => &["string_input", "pattern"], - Self::Replace(_) => &["string_input", "pattern", "replacement"], - Self::Like => &["string_input", "pattern"], - Self::Ilike => &["string_input", "pattern"], - Self::Length => &["string_input"], - Self::Lower => &["string_input"], - Self::Upper => &["string_input"], - Self::Lstrip => &["string_input"], - Self::Rstrip => &["string_input"], - Self::Reverse => &["string_input"], - Self::Capitalize => &["string_input"], - Self::Left => &["string_input", "length"], - Self::Right => &["string_input", "length"], - Self::Find => &["string_input", "substring"], - Self::Rpad => &["string_input", "length", "pad"], - Self::Lpad => &["string_input", "length", "pad"], - Self::Repeat => &["string_input", "count"], - Self::Substr => &["string_input", "start", "length"], - Self::ToDate(_) => &["string_input", "format"], - Self::ToDatetime(_, _) => &["string_input", "format"], - Self::LengthBytes => &["string_input"], - Self::Normalize(_) => &[ - "input", - "remove_punct", - "lowercase", - "nfd_unicode", - "white_space", - ], - } + &["string_input", "pattern"] } } -fn to_expr(expr: &Utf8Expr, args: &[ExprRef]) -> SQLPlannerResult { - use functions::utf8::{ - capitalize, contains, endswith, extract, extract_all, find, left, length, length_bytes, - lower, lpad, lstrip, match_, repeat, replace, reverse, right, rpad, rstrip, split, - startswith, to_date, to_datetime, upper, Utf8Expr, - }; - use Utf8Expr::{ - Capitalize, Contains, EndsWith, Extract, ExtractAll, Find, Ilike, Left, Length, - LengthBytes, Like, Lower, Lpad, Lstrip, Match, Normalize, Repeat, Replace, Reverse, Right, - Rpad, Rstrip, Split, StartsWith, Substr, ToDate, ToDatetime, Upper, - }; - match expr { - EndsWith => { - ensure!(args.len() == 2, "endswith takes exactly two arguments"); - Ok(endswith(args[0].clone(), args[1].clone())) - } - StartsWith => { - ensure!(args.len() == 2, "startswith takes exactly two arguments"); - Ok(startswith(args[0].clone(), args[1].clone())) - } - Contains => { - ensure!(args.len() == 2, "contains takes exactly two arguments"); - Ok(contains(args[0].clone(), args[1].clone())) - } - Split(true) => { - ensure!(args.len() == 2, "split takes exactly two arguments"); - Ok(split(args[0].clone(), args[1].clone(), true)) - } - Split(false) => { - ensure!(args.len() == 2, "split takes exactly two arguments"); - Ok(split(args[0].clone(), args[1].clone(), false)) - } - Match => { - ensure!(args.len() == 2, "regexp_match takes exactly two arguments"); - Ok(match_(args[0].clone(), args[1].clone())) - } - Extract(_) => match args { - [input, pattern] => Ok(extract(input.clone(), pattern.clone(), 0)), - [input, pattern, idx] => { - let idx = idx.as_literal().and_then(daft_dsl::LiteralValue::as_i64).ok_or_else(|| { - PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {idx:?}")) - })?; +pub struct SQLUtf8RegexpExtractAll; - Ok(extract(input.clone(), pattern.clone(), idx as usize)) - } - _ => { - invalid_operation_err!("regexp_extract takes exactly two or three arguments") +impl SQLFunction for SQLUtf8RegexpExtractAll { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, pattern] => { + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + Ok(daft_functions::utf8::extract_all(input, pattern, 0)) } - }, - ExtractAll(_) => match args { - [input, pattern] => Ok(extract_all(input.clone(), pattern.clone(), 0)), [input, pattern, idx] => { - let idx = idx.as_literal().and_then(daft_dsl::LiteralValue::as_i64).ok_or_else(|| { - PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract, found {idx:?}")) - })?; - - Ok(extract_all(input.clone(), pattern.clone(), idx as usize)) - } - _ => { - invalid_operation_err!("regexp_extract_all takes exactly two or three arguments") + let input = planner.plan_function_arg(input)?; + let pattern = planner.plan_function_arg(pattern)?; + let idx = planner.plan_function_arg(idx)?.as_literal().and_then(LiteralValue::as_i64).ok_or_else(|| { + PlannerError::invalid_operation(format!("Expected a literal integer for the third argument of regexp_extract_all, found {idx:?}")) + })? as usize; + Ok(daft_functions::utf8::extract_all(input, pattern, idx)) } - }, - Replace(_) => { - ensure!(args.len() == 3, "replace takes exactly three arguments"); - Ok(replace( - args[0].clone(), - args[1].clone(), - args[2].clone(), - false, - )) - } - Like => { - unreachable!("like should be handled by the parser") - } - Ilike => { - unreachable!("ilike should be handled by the parser") - } - Length => { - ensure!(args.len() == 1, "length takes exactly one argument"); - Ok(length(args[0].clone())) - } - LengthBytes => { - ensure!(args.len() == 1, "length_bytes takes exactly one argument"); - Ok(length_bytes(args[0].clone())) - } - Lower => { - ensure!(args.len() == 1, "lower takes exactly one argument"); - Ok(lower(args[0].clone())) - } - Upper => { - ensure!(args.len() == 1, "upper takes exactly one argument"); - Ok(upper(args[0].clone())) + _ => invalid_operation_err!("regexp_extract_all takes exactly two or three arguments"), } - Lstrip => { - ensure!(args.len() == 1, "lstrip takes exactly one argument"); - Ok(lstrip(args[0].clone())) - } - Rstrip => { - ensure!(args.len() == 1, "rstrip takes exactly one argument"); - Ok(rstrip(args[0].clone())) - } - Reverse => { - ensure!(args.len() == 1, "reverse takes exactly one argument"); - Ok(reverse(args[0].clone())) - } - Capitalize => { - ensure!(args.len() == 1, "capitalize takes exactly one argument"); - Ok(capitalize(args[0].clone())) - } - Left => { - ensure!(args.len() == 2, "left takes exactly two arguments"); - Ok(left(args[0].clone(), args[1].clone())) - } - Right => { - ensure!(args.len() == 2, "right takes exactly two arguments"); - Ok(right(args[0].clone(), args[1].clone())) - } - Find => { - ensure!(args.len() == 2, "find takes exactly two arguments"); - Ok(find(args[0].clone(), args[1].clone())) - } - Rpad => { - ensure!(args.len() == 3, "rpad takes exactly three arguments"); - Ok(rpad(args[0].clone(), args[1].clone(), args[2].clone())) - } - Lpad => { - ensure!(args.len() == 3, "lpad takes exactly three arguments"); - Ok(lpad(args[0].clone(), args[1].clone(), args[2].clone())) - } - Repeat => { - ensure!(args.len() == 2, "repeat takes exactly two arguments"); - Ok(repeat(args[0].clone(), args[1].clone())) - } - Substr => { - unreachable!("substr should be handled by the parser") - } - ToDate(_) => { - ensure!(args.len() == 2, "to_date takes exactly two arguments"); - let fmt = match args[1].as_ref().as_literal() { - Some(LiteralValue::Utf8(s)) => s, - _ => invalid_operation_err!("to_date format must be a string"), - }; - Ok(to_date(args[0].clone(), fmt)) - } - ToDatetime(..) => { - ensure!( - args.len() >= 2, - "to_datetime takes either two or three arguments" - ); - let fmt = match args[1].as_ref().as_literal() { - Some(LiteralValue::Utf8(s)) => s, - _ => invalid_operation_err!("to_datetime format must be a string"), - }; - let tz = match args.get(2).and_then(|e| e.as_ref().as_literal()) { - Some(LiteralValue::Utf8(s)) => Some(s.as_str()), - _ => invalid_operation_err!("to_datetime timezone must be a string"), - }; - - Ok(to_datetime(args[0].clone(), fmt, tz)) + } + + fn docstrings(&self, _alias: &str) -> String { + "Extracts all substrings that match the specified regular expression pattern".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["string_input", "pattern"] + } +} + +pub struct SQLUtf8ToDate; + +impl SQLFunction for SQLUtf8ToDate { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, fmt] => { + let input = planner.plan_function_arg(input)?; + let fmt = planner.plan_function_arg(fmt)?; + let fmt = fmt + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation("to_date format must be a string") + })?; + Ok(daft_functions::utf8::to_date(input, fmt)) + } + _ => invalid_operation_err!("to_date takes exactly two arguments"), } - Normalize(_) => { - unsupported_sql_err!("normalize") + } + + fn docstrings(&self, _alias: &str) -> String { + "Parses the string as a date using the specified format.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["string_input", "format"] + } +} + +pub struct SQLUtf8ToDatetime; + +impl SQLFunction for SQLUtf8ToDatetime { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, fmt] => { + let input = planner.plan_function_arg(input)?; + let fmt = planner.plan_function_arg(fmt)?; + let fmt = fmt + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation("to_datetime format must be a string") + })?; + Ok(daft_functions::utf8::to_datetime(input, fmt, None)) + } + [input, fmt, tz] => { + let input = planner.plan_function_arg(input)?; + let fmt = planner.plan_function_arg(fmt)?; + let fmt = fmt + .as_literal() + .and_then(|lit| lit.as_str()) + .ok_or_else(|| { + PlannerError::invalid_operation("to_datetime format must be a string") + })?; + let tz = planner.plan_function_arg(tz)?; + let tz = tz.as_literal().and_then(|lit| lit.as_str()); + Ok(daft_functions::utf8::to_datetime(input, fmt, tz)) + } + _ => invalid_operation_err!("to_datetime takes either two or three arguments"), } } + + fn docstrings(&self, _alias: &str) -> String { + "Parses the string as a datetime using the specified format.".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &["string_input", "format"] + } } pub struct SQLCountMatches; @@ -404,7 +596,10 @@ impl SQLFunction for SQLNormalize { match inputs { [input] => { let input = planner.plan_function_arg(input)?; - Ok(normalize(input, Utf8NormalizeOptions::default())) + Ok(daft_functions::utf8::normalize( + input, + Utf8NormalizeOptions::default(), + )) } [input, args @ ..] => { let input = planner.plan_function_arg(input)?; @@ -413,7 +608,7 @@ impl SQLFunction for SQLNormalize { &["remove_punct", "lowercase", "nfd_unicode", "white_space"], 0, )?; - Ok(normalize(input, args)) + Ok(daft_functions::utf8::normalize(input, args)) } _ => invalid_operation_err!("Invalid arguments for normalize"), } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index d241fc7fbb..419d961151 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -8,12 +8,15 @@ use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; use daft_dsl::{ col, - functions::utf8::{ilike, like, to_date, to_datetime}, has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, Subquery, }; -use daft_functions::numeric::{ceil::ceil, floor::floor}; +use daft_functions::{ + numeric::{ceil::ceil, floor::floor}, + utf8::{ilike, like, to_date, to_datetime}, +}; use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef}; + use sqlparser::{ ast::{ ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo, @@ -1261,7 +1264,7 @@ impl SQLPlanner { let start = self.plan_expr(substring_from)?; let length = self.plan_expr(substring_for)?; - Ok(daft_dsl::functions::utf8::substr(expr, start, length)) + Ok(daft_functions::utf8::substr(expr, start, length)) } SQLExpr::Substring { special: false, .. } => { unsupported_sql_err!("`SUBSTRING(expr [FROM start] [FOR len])` syntax") From 813922d948d3a477492787686bafa0bb70cb4027 Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Tue, 22 Oct 2024 22:10:07 +0800 Subject: [PATCH 02/10] update --- src/daft-sql/src/modules/utf8.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/daft-sql/src/modules/utf8.rs b/src/daft-sql/src/modules/utf8.rs index d2a7a0ef01..645c611eb0 100644 --- a/src/daft-sql/src/modules/utf8.rs +++ b/src/daft-sql/src/modules/utf8.rs @@ -36,7 +36,7 @@ macro_rules! utf8_function_one_argument { ". Expected ", $sql_name, "(", - stringify!($arg_name), + $arg_name, ")" )), } @@ -75,9 +75,9 @@ macro_rules! utf8_function_two_arguments { ". Expected ", $sql_name, "(", - stringify!($arg_name_1), + $arg_name_1, ", ", - stringify!($arg_name_2), + $arg_name_2, ")" )), } @@ -117,11 +117,11 @@ macro_rules! utf8_function_three_arguments { ". Expected ", $sql_name, "(", - stringify!($arg_name_1), + $arg_name_1, ", ", - stringify!($arg_name_2), + $arg_name_2, ", ", - stringify!($arg_name_3), + $arg_name_3, ")" )), } From ba5736824608fe3fd8e4d91038053276de96b271 Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Tue, 22 Oct 2024 22:37:03 +0800 Subject: [PATCH 03/10] fixes --- daft/daft/__init__.pyi | 58 ++++++++++++++++--------------- daft/expressions/expressions.py | 60 ++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 56 deletions(-) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 7c650ac32d..5fe1307747 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1295,34 +1295,36 @@ def list_chunk(expr: PyExpr, size: int) -> PyExpr: ... # --- # expr.utf8 namespace # --- -def utf8_endswith(pattern: PyExpr) -> PyExpr: ... -def utf8_startswith(pattern: PyExpr) -> PyExpr: ... -def utf8_contains(pattern: PyExpr) -> PyExpr: ... -def utf8_match(pattern: PyExpr) -> PyExpr: ... -def utf8_split(pattern: PyExpr, regex: bool) -> PyExpr: ... -def utf8_extract(pattern: PyExpr, index: int) -> PyExpr: ... -def utf8_extract_all(pattern: PyExpr, index: int) -> PyExpr: ... -def utf8_replace(pattern: PyExpr, replacement: PyExpr, regex: bool) -> PyExpr: ... -def utf8_length() -> PyExpr: ... -def utf8_length_bytes() -> PyExpr: ... -def utf8_lower() -> PyExpr: ... -def utf8_upper() -> PyExpr: ... -def utf8_lstrip() -> PyExpr: ... -def utf8_rstrip() -> PyExpr: ... -def utf8_reverse() -> PyExpr: ... -def utf8_capitalize() -> PyExpr: ... -def utf8_left(nchars: PyExpr) -> PyExpr: ... -def utf8_right(nchars: PyExpr) -> PyExpr: ... -def utf8_find(substr: PyExpr) -> PyExpr: ... -def utf8_rpad(length: PyExpr, pad: PyExpr) -> PyExpr: ... -def utf8_lpad(length: PyExpr, pad: PyExpr) -> PyExpr: ... -def utf8_repeat(n: PyExpr) -> PyExpr: ... -def utf8_like(pattern: PyExpr) -> PyExpr: ... -def utf8_ilike(pattern: PyExpr) -> PyExpr: ... -def utf8_substr(start: PyExpr, length: PyExpr) -> PyExpr: ... -def utf8_to_date(format: str) -> PyExpr: ... -def utf8_to_datetime(format: str, timezone: str | None = None) -> PyExpr: ... -def utf8_normalize(remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool) -> PyExpr: ... +def utf8_endswith(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_startswith(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_contains(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_match(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_split(expr: PyExpr, pattern: PyExpr, regex: bool) -> PyExpr: ... +def utf8_extract(expr: PyExpr, pattern: PyExpr, index: int) -> PyExpr: ... +def utf8_extract_all(expr: PyExpr, pattern: PyExpr, index: int) -> PyExpr: ... +def utf8_replace(expr: PyExpr, pattern: PyExpr, replacement: PyExpr, regex: bool) -> PyExpr: ... +def utf8_length(expr: PyExpr) -> PyExpr: ... +def utf8_length_bytes(expr: PyExpr) -> PyExpr: ... +def utf8_lower(expr: PyExpr) -> PyExpr: ... +def utf8_upper(expr: PyExpr) -> PyExpr: ... +def utf8_lstrip(expr: PyExpr) -> PyExpr: ... +def utf8_rstrip(expr: PyExpr) -> PyExpr: ... +def utf8_reverse(expr: PyExpr) -> PyExpr: ... +def utf8_capitalize(expr: PyExpr) -> PyExpr: ... +def utf8_left(expr: PyExpr, nchars: PyExpr) -> PyExpr: ... +def utf8_right(expr: PyExpr, nchars: PyExpr) -> PyExpr: ... +def utf8_find(expr: PyExpr, substr: PyExpr) -> PyExpr: ... +def utf8_rpad(expr: PyExpr, length: PyExpr, pad: PyExpr) -> PyExpr: ... +def utf8_lpad(expr: PyExpr, length: PyExpr, pad: PyExpr) -> PyExpr: ... +def utf8_repeat(expr: PyExpr, n: PyExpr) -> PyExpr: ... +def utf8_like(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_ilike(expr: PyExpr, pattern: PyExpr) -> PyExpr: ... +def utf8_substr(expr: PyExpr, start: PyExpr, length: PyExpr) -> PyExpr: ... +def utf8_to_date(expr: PyExpr, format: str) -> PyExpr: ... +def utf8_to_datetime(expr: PyExpr, format: str, timezone: str | None = None) -> PyExpr: ... +def utf8_normalize( + expr: PyExpr, remove_punct: bool, lowercase: bool, nfd_unicode: bool, white_space: bool +) -> PyExpr: ... class PyCatalog: @staticmethod diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index df75ec5002..5a4cb1c7c6 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1887,7 +1887,7 @@ def contains(self, substr: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value contains the provided pattern """ substr_expr = Expression._to_expression(substr) - return Expression._from_pyexpr(native.utf8_contains(substr_expr._expr)) + return Expression._from_pyexpr(native.utf8_contains(self._expr, substr_expr._expr)) def match(self, pattern: str | Expression) -> Expression: """Checks whether each string matches the given regular expression pattern in a string column @@ -1917,7 +1917,7 @@ def match(self, pattern: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value matches the provided pattern """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(native.utf8_match(pattern_expr._expr)) + return Expression._from_pyexpr(native.utf8_match(self._expr, pattern_expr._expr)) def endswith(self, suffix: str | Expression) -> Expression: """Checks whether each string ends with the given pattern in a string column @@ -1947,7 +1947,7 @@ def endswith(self, suffix: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value ends with the provided pattern """ suffix_expr = Expression._to_expression(suffix) - return Expression._from_pyexpr(native.utf8_endswith(suffix_expr._expr)) + return Expression._from_pyexpr(native.utf8_endswith(self._expr, suffix_expr._expr)) def startswith(self, prefix: str | Expression) -> Expression: """Checks whether each string starts with the given pattern in a string column @@ -1977,7 +1977,7 @@ def startswith(self, prefix: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value starts with the provided pattern """ prefix_expr = Expression._to_expression(prefix) - return Expression._from_pyexpr(native.utf8_startswith(prefix_expr._expr)) + return Expression._from_pyexpr(native.utf8_startswith(self._expr, prefix_expr._expr)) def split(self, pattern: str | Expression, regex: bool = False) -> Expression: r"""Splits each string on the given literal or regex pattern, into a list of strings. @@ -2028,7 +2028,7 @@ def split(self, pattern: str | Expression, regex: bool = False) -> Expression: Expression: A List[Utf8] expression containing the string splits for each string in the column. """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(native.utf8_split(pattern_expr._expr, regex)) + return Expression._from_pyexpr(native.utf8_split(self._expr, pattern_expr._expr, regex)) def concat(self, other: str | Expression) -> Expression: """Concatenates two string expressions together @@ -2119,7 +2119,7 @@ def extract(self, pattern: str | Expression, index: int = 0) -> Expression: `extract_all` """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(native.utf8_extract(pattern_expr._expr, index)) + return Expression._from_pyexpr(native.utf8_extract(self._expr, pattern_expr._expr, index)) def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression: r"""Extracts the specified match group from all regex matches in each string in a string column. @@ -2175,7 +2175,7 @@ def extract_all(self, pattern: str | Expression, index: int = 0) -> Expression: `extract` """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(native.utf8_extract_all(pattern_expr._expr, index)) + return Expression._from_pyexpr(native.utf8_extract_all(self._expr, pattern_expr._expr, index)) def replace( self, @@ -2232,7 +2232,9 @@ def replace( """ pattern_expr = Expression._to_expression(pattern) replacement_expr = Expression._to_expression(replacement) - return Expression._from_pyexpr(native.utf8_replace(pattern_expr._expr, replacement_expr._expr, regex)) + return Expression._from_pyexpr( + native.utf8_replace(self._expr, pattern_expr._expr, replacement_expr._expr, regex) + ) def length(self) -> Expression: """Retrieves the length for a UTF-8 string column @@ -2259,7 +2261,7 @@ def length(self) -> Expression: Returns: Expression: an UInt64 expression with the length of each string """ - return Expression._from_pyexpr(native.utf8_length()) + return Expression._from_pyexpr(native.utf8_length(self._expr)) def length_bytes(self) -> Expression: """Retrieves the length for a UTF-8 string column in bytes. @@ -2286,7 +2288,7 @@ def length_bytes(self) -> Expression: Returns: Expression: an UInt64 expression with the length of each string """ - return Expression._from_pyexpr(native.utf8_length_bytes()) + return Expression._from_pyexpr(native.utf8_length_bytes(self._expr)) def lower(self) -> Expression: """Convert UTF-8 string to all lowercase @@ -2313,7 +2315,7 @@ def lower(self) -> Expression: Returns: Expression: a String expression which is `self` lowercased """ - return Expression._from_pyexpr(native.utf8_lower()) + return Expression._from_pyexpr(native.utf8_lower(self._expr)) def upper(self) -> Expression: """Convert UTF-8 string to all upper @@ -2340,7 +2342,7 @@ def upper(self) -> Expression: Returns: Expression: a String expression which is `self` uppercased """ - return Expression._from_pyexpr(native.utf8_upper()) + return Expression._from_pyexpr(native.utf8_upper(self._expr)) def lstrip(self) -> Expression: """Strip whitespace from the left side of a UTF-8 string @@ -2367,7 +2369,7 @@ def lstrip(self) -> Expression: Returns: Expression: a String expression which is `self` with leading whitespace stripped """ - return Expression._from_pyexpr(native.utf8_lstrip()) + return Expression._from_pyexpr(native.utf8_lstrip(self._expr)) def rstrip(self) -> Expression: """Strip whitespace from the right side of a UTF-8 string @@ -2394,7 +2396,7 @@ def rstrip(self) -> Expression: Returns: Expression: a String expression which is `self` with trailing whitespace stripped """ - return Expression._from_pyexpr(native.utf8_rstrip()) + return Expression._from_pyexpr(native.utf8_rstrip(self._expr)) def reverse(self) -> Expression: """Reverse a UTF-8 string @@ -2421,7 +2423,7 @@ def reverse(self) -> Expression: Returns: Expression: a String expression which is `self` reversed """ - return Expression._from_pyexpr(native.utf8_reverse()) + return Expression._from_pyexpr(native.utf8_reverse(self._expr)) def capitalize(self) -> Expression: """Capitalize a UTF-8 string @@ -2448,7 +2450,7 @@ def capitalize(self) -> Expression: Returns: Expression: a String expression which is `self` uppercased with the first character and lowercased the rest """ - return Expression._from_pyexpr(native.utf8_capitalize()) + return Expression._from_pyexpr(native.utf8_capitalize(self._expr)) def left(self, nchars: int | Expression) -> Expression: """Gets the n (from nchars) left-most characters of each string @@ -2476,7 +2478,7 @@ def left(self, nchars: int | Expression) -> Expression: Expression: a String expression which is the `n` left-most characters of `self` """ nchars_expr = Expression._to_expression(nchars) - return Expression._from_pyexpr(native.utf8_left(nchars_expr._expr)) + return Expression._from_pyexpr(native.utf8_left(self._expr, nchars_expr._expr)) def right(self, nchars: int | Expression) -> Expression: """Gets the n (from nchars) right-most characters of each string @@ -2504,7 +2506,7 @@ def right(self, nchars: int | Expression) -> Expression: Expression: a String expression which is the `n` right-most characters of `self` """ nchars_expr = Expression._to_expression(nchars) - return Expression._from_pyexpr(native.utf8_right(nchars_expr._expr)) + return Expression._from_pyexpr(native.utf8_right(self._expr, nchars_expr._expr)) def find(self, substr: str | Expression) -> Expression: """Returns the index of the first occurrence of the substring in each string @@ -2536,7 +2538,7 @@ def find(self, substr: str | Expression) -> Expression: Expression: an Int64 expression with the index of the first occurrence of the substring in each string """ substr_expr = Expression._to_expression(substr) - return Expression._from_pyexpr(native.utf8_find(substr_expr._expr)) + return Expression._from_pyexpr(native.utf8_find(self._expr, substr_expr._expr)) def rpad(self, length: int | Expression, pad: str | Expression) -> Expression: """Right-pads each string by truncating or padding with the character @@ -2569,7 +2571,7 @@ def rpad(self, length: int | Expression, pad: str | Expression) -> Expression: """ length_expr = Expression._to_expression(length) pad_expr = Expression._to_expression(pad) - return Expression._from_pyexpr(native.utf8_rpad(length_expr._expr, pad_expr._expr)) + return Expression._from_pyexpr(native.utf8_rpad(self._expr, length_expr._expr, pad_expr._expr)) def lpad(self, length: int | Expression, pad: str | Expression) -> Expression: """Left-pads each string by truncating on the right or padding with the character @@ -2602,7 +2604,7 @@ def lpad(self, length: int | Expression, pad: str | Expression) -> Expression: """ length_expr = Expression._to_expression(length) pad_expr = Expression._to_expression(pad) - return Expression._from_pyexpr(native.utf8_lpad(length_expr._expr, pad_expr._expr)) + return Expression._from_pyexpr(native.utf8_lpad(self._expr, length_expr._expr, pad_expr._expr)) def repeat(self, n: int | Expression) -> Expression: """Repeats each string n times @@ -2630,7 +2632,7 @@ def repeat(self, n: int | Expression) -> Expression: Expression: a String expression which is `self` repeated `n` times """ n_expr = Expression._to_expression(n) - return Expression._from_pyexpr(native.utf8_repeat(n_expr._expr)) + return Expression._from_pyexpr(native.utf8_repeat(self._expr, n_expr._expr)) def like(self, pattern: str | Expression) -> Expression: """Checks whether each string matches the given SQL LIKE pattern, case sensitive @@ -2661,7 +2663,7 @@ def like(self, pattern: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value matches the provided pattern """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(native.utf8_like(pattern_expr._expr)) + return Expression._from_pyexpr(native.utf8_like(self._expr, pattern_expr._expr)) def ilike(self, pattern: str | Expression) -> Expression: """Checks whether each string matches the given SQL LIKE pattern, case insensitive @@ -2692,7 +2694,7 @@ def ilike(self, pattern: str | Expression) -> Expression: Expression: a Boolean expression indicating whether each value matches the provided pattern """ pattern_expr = Expression._to_expression(pattern) - return Expression._from_pyexpr(native.utf8_ilike(pattern_expr._expr)) + return Expression._from_pyexpr(native.utf8_ilike(self._expr, pattern_expr._expr)) def substr(self, start: int | Expression, length: int | Expression | None = None) -> Expression: """Extract a substring from a string, starting at a specified index and extending for a given length. @@ -2724,7 +2726,7 @@ def substr(self, start: int | Expression, length: int | Expression | None = None """ start_expr = Expression._to_expression(start) length_expr = Expression._to_expression(length) - return Expression._from_pyexpr(native.utf8_substr(start_expr._expr, length_expr._expr)) + return Expression._from_pyexpr(native.utf8_substr(self._expr, start_expr._expr, length_expr._expr)) def to_date(self, format: str) -> Expression: """Converts a string to a date using the specified format @@ -2755,7 +2757,7 @@ def to_date(self, format: str) -> Expression: Returns: Expression: a Date expression which is parsed by given format """ - return Expression._from_pyexpr(native.utf8_to_date(format)) + return Expression._from_pyexpr(native.utf8_to_date(self._expr, format)) def to_datetime(self, format: str, timezone: str | None = None) -> Expression: """Converts a string to a datetime using the specified format and timezone @@ -2805,7 +2807,7 @@ def to_datetime(self, format: str, timezone: str | None = None) -> Expression: Returns: Expression: a DateTime expression which is parsed by given format and timezone """ - return Expression._from_pyexpr(native.utf8_to_datetime(format, timezone)) + return Expression._from_pyexpr(native.utf8_to_datetime(self._expr, format, timezone)) def normalize( self, @@ -2849,7 +2851,9 @@ def normalize( Returns: Expression: a String expression which is normalized. """ - return Expression._from_pyexpr(native.utf8_normalize(remove_punct, lowercase, nfd_unicode, white_space)) + return Expression._from_pyexpr( + native.utf8_normalize(self._expr, remove_punct, lowercase, nfd_unicode, white_space) + ) def tokenize_encode( self, From 8bf1a5922d97350bc1dda3cef5760e4c64a68111 Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Tue, 22 Oct 2024 22:41:29 +0800 Subject: [PATCH 04/10] fixes --- src/daft-functions/src/utf8/endswith.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-functions/src/utf8/endswith.rs b/src/daft-functions/src/utf8/endswith.rs index 6a78b90542..bd6e08c836 100644 --- a/src/daft-functions/src/utf8/endswith.rs +++ b/src/daft-functions/src/utf8/endswith.rs @@ -45,7 +45,7 @@ impl ScalarUDF for Utf8Endswith { fn evaluate(&self, inputs: &[Series]) -> DaftResult { match inputs { - [data, pattern] => data.utf8_startswith(pattern), + [data, pattern] => data.utf8_endswith(pattern), _ => Err(DaftError::ValueError(format!( "Expected 2 input args, got {}", inputs.len() From 342e01230568fbc164c760d6871209fde692ee8e Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Tue, 22 Oct 2024 22:43:19 +0800 Subject: [PATCH 05/10] fixes --- src/daft-functions/src/utf8/endswith.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/daft-functions/src/utf8/endswith.rs b/src/daft-functions/src/utf8/endswith.rs index bd6e08c836..1c57b48019 100644 --- a/src/daft-functions/src/utf8/endswith.rs +++ b/src/daft-functions/src/utf8/endswith.rs @@ -30,7 +30,7 @@ impl ScalarUDF for Utf8Endswith { Ok(Field::new(data_field.name, DataType::Boolean)) } _ => Err(DaftError::TypeError(format!( - "Expects inputs to startswith to be utf8, but received {data_field} and {pattern_field}", + "Expects inputs to endswith to be utf8, but received {data_field} and {pattern_field}", ))), } } From b6a5381ef3e3d15fb5cd78974c575e8245b0278d Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Wed, 23 Oct 2024 11:09:51 +0800 Subject: [PATCH 06/10] address comments & fixes --- src/daft-functions/src/utf8/capitalize.rs | 2 +- src/daft-functions/src/utf8/contains.rs | 2 +- src/daft-functions/src/utf8/endswith.rs | 2 +- src/daft-functions/src/utf8/extract.rs | 2 +- src/daft-functions/src/utf8/extract_all.rs | 2 +- src/daft-functions/src/utf8/find.rs | 2 +- src/daft-functions/src/utf8/ilike.rs | 2 +- src/daft-functions/src/utf8/left.rs | 2 +- src/daft-functions/src/utf8/length.rs | 2 +- src/daft-functions/src/utf8/length_bytes.rs | 2 +- src/daft-functions/src/utf8/like.rs | 2 +- src/daft-functions/src/utf8/lower.rs | 2 +- src/daft-functions/src/utf8/lpad.rs | 2 +- src/daft-functions/src/utf8/lstrip.rs | 2 +- src/daft-functions/src/utf8/match_.rs | 2 +- src/daft-functions/src/utf8/normalize.rs | 2 +- src/daft-functions/src/utf8/repeat.rs | 2 +- src/daft-functions/src/utf8/replace.rs | 2 +- src/daft-functions/src/utf8/reverse.rs | 2 +- src/daft-functions/src/utf8/right.rs | 4 +- src/daft-functions/src/utf8/rpad.rs | 2 +- src/daft-functions/src/utf8/rstrip.rs | 2 +- src/daft-functions/src/utf8/split.rs | 2 +- src/daft-functions/src/utf8/startswith.rs | 2 +- src/daft-functions/src/utf8/substr.rs | 2 +- src/daft-functions/src/utf8/to_date.rs | 2 +- src/daft-functions/src/utf8/to_datetime.rs | 2 +- src/daft-functions/src/utf8/upper.rs | 2 +- src/daft-sql/src/modules/utf8.rs | 180 ++++++++++---------- 29 files changed, 123 insertions(+), 115 deletions(-) diff --git a/src/daft-functions/src/utf8/capitalize.rs b/src/daft-functions/src/utf8/capitalize.rs index f666e1f3cb..abf770924c 100644 --- a/src/daft-functions/src/utf8/capitalize.rs +++ b/src/daft-functions/src/utf8/capitalize.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Capitalize { self } fn name(&self) -> &'static str { - "utf8_capitalize" + "capitalize" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/contains.rs b/src/daft-functions/src/utf8/contains.rs index 9427a45858..2ea8708711 100644 --- a/src/daft-functions/src/utf8/contains.rs +++ b/src/daft-functions/src/utf8/contains.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Contains { self } fn name(&self) -> &'static str { - "utf8_contains" + "contains" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/endswith.rs b/src/daft-functions/src/utf8/endswith.rs index 1c57b48019..8f11cb8db8 100644 --- a/src/daft-functions/src/utf8/endswith.rs +++ b/src/daft-functions/src/utf8/endswith.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Endswith { self } fn name(&self) -> &'static str { - "utf8_endswith" + "endswith" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/extract.rs b/src/daft-functions/src/utf8/extract.rs index 679bb2916b..f5a97b1c3e 100644 --- a/src/daft-functions/src/utf8/extract.rs +++ b/src/daft-functions/src/utf8/extract.rs @@ -20,7 +20,7 @@ impl ScalarUDF for Utf8Extract { self } fn name(&self) -> &'static str { - "utf8_extract" + "extract" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/extract_all.rs b/src/daft-functions/src/utf8/extract_all.rs index a3c318ad0c..b40c6ad47c 100644 --- a/src/daft-functions/src/utf8/extract_all.rs +++ b/src/daft-functions/src/utf8/extract_all.rs @@ -20,7 +20,7 @@ impl ScalarUDF for Utf8ExtractAll { self } fn name(&self) -> &'static str { - "utf8_extract_all" + "extractall" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/find.rs b/src/daft-functions/src/utf8/find.rs index 547ce874d0..3ec11bec97 100644 --- a/src/daft-functions/src/utf8/find.rs +++ b/src/daft-functions/src/utf8/find.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Find { self } fn name(&self) -> &'static str { - "utf8_find" + "find" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/ilike.rs b/src/daft-functions/src/utf8/ilike.rs index 6b94ff2010..54c0d4ca8e 100644 --- a/src/daft-functions/src/utf8/ilike.rs +++ b/src/daft-functions/src/utf8/ilike.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Ilike { self } fn name(&self) -> &'static str { - "utf8_ilike" + "ilike" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/left.rs b/src/daft-functions/src/utf8/left.rs index 1b8548fd86..c055ad7ecb 100644 --- a/src/daft-functions/src/utf8/left.rs +++ b/src/daft-functions/src/utf8/left.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Left { self } fn name(&self) -> &'static str { - "utf8_left" + "left" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/length.rs b/src/daft-functions/src/utf8/length.rs index 4010915e3f..8d58d8ae27 100644 --- a/src/daft-functions/src/utf8/length.rs +++ b/src/daft-functions/src/utf8/length.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Length { self } fn name(&self) -> &'static str { - "utf8_length" + "length" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/length_bytes.rs b/src/daft-functions/src/utf8/length_bytes.rs index 71f0d03a01..dbcb841701 100644 --- a/src/daft-functions/src/utf8/length_bytes.rs +++ b/src/daft-functions/src/utf8/length_bytes.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8LengthBytes { self } fn name(&self) -> &'static str { - "utf8_length_bytes" + "length_bytes" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/like.rs b/src/daft-functions/src/utf8/like.rs index cdd47275f6..915a805a9b 100644 --- a/src/daft-functions/src/utf8/like.rs +++ b/src/daft-functions/src/utf8/like.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Like { self } fn name(&self) -> &'static str { - "utf8_like" + "like" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/lower.rs b/src/daft-functions/src/utf8/lower.rs index b98594123b..7935168d9b 100644 --- a/src/daft-functions/src/utf8/lower.rs +++ b/src/daft-functions/src/utf8/lower.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Lower { self } fn name(&self) -> &'static str { - "utf8_lower" + "lower" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/lpad.rs b/src/daft-functions/src/utf8/lpad.rs index 090aa260fc..89808d645a 100644 --- a/src/daft-functions/src/utf8/lpad.rs +++ b/src/daft-functions/src/utf8/lpad.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Lpad { self } fn name(&self) -> &'static str { - "utf8_lpad" + "lpad" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/lstrip.rs b/src/daft-functions/src/utf8/lstrip.rs index f3d25a9179..f7441a8ac2 100644 --- a/src/daft-functions/src/utf8/lstrip.rs +++ b/src/daft-functions/src/utf8/lstrip.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Lstrip { self } fn name(&self) -> &'static str { - "utf8_lstrip" + "lstrip" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/match_.rs b/src/daft-functions/src/utf8/match_.rs index 4791861806..0a9cbc8a8c 100644 --- a/src/daft-functions/src/utf8/match_.rs +++ b/src/daft-functions/src/utf8/match_.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Match { self } fn name(&self) -> &'static str { - "utf8_match" + "match" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/normalize.rs b/src/daft-functions/src/utf8/normalize.rs index 744ba78035..b9455d23d3 100644 --- a/src/daft-functions/src/utf8/normalize.rs +++ b/src/daft-functions/src/utf8/normalize.rs @@ -20,7 +20,7 @@ impl ScalarUDF for Utf8Normalize { self } fn name(&self) -> &'static str { - "utf8_normalize" + "normalize" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/repeat.rs b/src/daft-functions/src/utf8/repeat.rs index a864fd701b..dc74a4bfed 100644 --- a/src/daft-functions/src/utf8/repeat.rs +++ b/src/daft-functions/src/utf8/repeat.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Repeat { self } fn name(&self) -> &'static str { - "utf8_repeat" + "repeat" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/replace.rs b/src/daft-functions/src/utf8/replace.rs index cfbb612611..76134c8136 100644 --- a/src/daft-functions/src/utf8/replace.rs +++ b/src/daft-functions/src/utf8/replace.rs @@ -20,7 +20,7 @@ impl ScalarUDF for Utf8Replace { self } fn name(&self) -> &'static str { - "utf8_replace" + "replace" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/reverse.rs b/src/daft-functions/src/utf8/reverse.rs index 07227f578e..60674fc168 100644 --- a/src/daft-functions/src/utf8/reverse.rs +++ b/src/daft-functions/src/utf8/reverse.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Reverse { self } fn name(&self) -> &'static str { - "utf8_reverse" + "reverse" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/right.rs b/src/daft-functions/src/utf8/right.rs index b32b4080c2..fbac7742b4 100644 --- a/src/daft-functions/src/utf8/right.rs +++ b/src/daft-functions/src/utf8/right.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Right { self } fn name(&self) -> &'static str { - "utf8_right" + "right" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { @@ -30,7 +30,7 @@ impl ScalarUDF for Utf8Right { Ok(Field::new(data_field.name, DataType::Utf8)) } _ => Err(DaftError::TypeError(format!( - "Expects inputs to left to be utf8 and integer, but received {data_field} and {nchars_field}", + "Expects inputs to right to be utf8 and integer, but received {data_field} and {nchars_field}", ))), } } diff --git a/src/daft-functions/src/utf8/rpad.rs b/src/daft-functions/src/utf8/rpad.rs index 01a9766203..2a0864a578 100644 --- a/src/daft-functions/src/utf8/rpad.rs +++ b/src/daft-functions/src/utf8/rpad.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Rpad { self } fn name(&self) -> &'static str { - "utf8_rpad" + "rpad" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/rstrip.rs b/src/daft-functions/src/utf8/rstrip.rs index df8044ec76..b2528a99ac 100644 --- a/src/daft-functions/src/utf8/rstrip.rs +++ b/src/daft-functions/src/utf8/rstrip.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Rstrip { self } fn name(&self) -> &'static str { - "utf8_rstrip" + "rstrip" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/split.rs b/src/daft-functions/src/utf8/split.rs index f8a1725b66..60e8110ce1 100644 --- a/src/daft-functions/src/utf8/split.rs +++ b/src/daft-functions/src/utf8/split.rs @@ -20,7 +20,7 @@ impl ScalarUDF for Utf8Split { self } fn name(&self) -> &'static str { - "utf8_split" + "split" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/startswith.rs b/src/daft-functions/src/utf8/startswith.rs index 98a3ce4556..3a0bb50d2b 100644 --- a/src/daft-functions/src/utf8/startswith.rs +++ b/src/daft-functions/src/utf8/startswith.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Startswith { self } fn name(&self) -> &'static str { - "utf8_startswith" + "startswith" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/substr.rs b/src/daft-functions/src/utf8/substr.rs index a8eac19215..90628e441c 100644 --- a/src/daft-functions/src/utf8/substr.rs +++ b/src/daft-functions/src/utf8/substr.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Substr { self } fn name(&self) -> &'static str { - "utf8_substr" + "substr" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/to_date.rs b/src/daft-functions/src/utf8/to_date.rs index 5b2f176783..911cca84c9 100644 --- a/src/daft-functions/src/utf8/to_date.rs +++ b/src/daft-functions/src/utf8/to_date.rs @@ -20,7 +20,7 @@ impl ScalarUDF for Utf8ToDate { self } fn name(&self) -> &'static str { - "utf8_to_date" + "to_date" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/to_datetime.rs b/src/daft-functions/src/utf8/to_datetime.rs index d642799ed7..862859699f 100644 --- a/src/daft-functions/src/utf8/to_datetime.rs +++ b/src/daft-functions/src/utf8/to_datetime.rs @@ -22,7 +22,7 @@ impl ScalarUDF for Utf8ToDatetime { self } fn name(&self) -> &'static str { - "utf8_to_datetime" + "to_datetime" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-functions/src/utf8/upper.rs b/src/daft-functions/src/utf8/upper.rs index cbaf2fff8b..b6500e825f 100644 --- a/src/daft-functions/src/utf8/upper.rs +++ b/src/daft-functions/src/utf8/upper.rs @@ -18,7 +18,7 @@ impl ScalarUDF for Utf8Upper { self } fn name(&self) -> &'static str { - "utf8_upper" + "upper" } fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult { diff --git a/src/daft-sql/src/modules/utf8.rs b/src/daft-sql/src/modules/utf8.rs index 645c611eb0..464627ffa8 100644 --- a/src/daft-sql/src/modules/utf8.rs +++ b/src/daft-sql/src/modules/utf8.rs @@ -15,31 +15,76 @@ use crate::{ invalid_operation_err, }; -macro_rules! utf8_function_one_argument { +fn utf8_unary( + func: impl Fn(ExprRef) -> ExprRef, + sql_name: &str, + arg_name: &str, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, +) -> SQLPlannerResult { + match inputs { + [input] => { + let input = planner.plan_function_arg(input)?; + Ok(func(input)) + } + _ => invalid_operation_err!( + "invalid arguments for {sql_name}. Expected {sql_name}({arg_name})", + ), + } +} + +fn utf8_binary( + func: impl Fn(ExprRef, ExprRef) -> ExprRef, + sql_name: &str, + arg_name_1: &str, + arg_name_2: &str, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, +) -> SQLPlannerResult { + match inputs { + [input1, input2] => { + let input1 = planner.plan_function_arg(input1)?; + let input2 = planner.plan_function_arg(input2)?; + Ok(func(input1, input2)) + } + _ => invalid_operation_err!( + "invalid arguments for {sql_name}. Expected {sql_name}({arg_name_1}, {arg_name_2})", + ), + } +} + +fn utf8_ternary( + func: impl Fn(ExprRef, ExprRef, ExprRef) -> ExprRef, + sql_name: &str, + arg_name_1: &str, + arg_name_2: &str, + arg_name_3: &str, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, +) -> SQLPlannerResult { + match inputs { + [input1, input2, input3] => { + let input1 = planner.plan_function_arg(input1)?; + let input2 = planner.plan_function_arg(input2)?; + let input3 = planner.plan_function_arg(input3)?; + Ok(func(input1, input2, input3)) + }, + _ => invalid_operation_err!( + "invalid arguments for {sql_name}. Expected {sql_name}({arg_name_1}, {arg_name_2}, {arg_name_3})", + ), + } +} + +macro_rules! utf8_function { ($name:ident, $sql_name:expr, $func:expr, $doc:expr, $arg_name:expr) => { pub struct $name; - impl SQLFunction for $name { fn to_expr( &self, inputs: &[sqlparser::ast::FunctionArg], planner: &crate::planner::SQLPlanner, ) -> SQLPlannerResult { - match inputs { - [input] => { - let input = planner.plan_function_arg(input)?; - Ok($func(input)) - } - _ => invalid_operation_err!(concat!( - "invalid arguments for ", - $sql_name, - ". Expected ", - $sql_name, - "(", - $arg_name, - ")" - )), - } + utf8_unary($func, $sql_name, $arg_name, inputs, planner) } fn docstrings(&self, _alias: &str) -> String { @@ -51,36 +96,15 @@ macro_rules! utf8_function_one_argument { } } }; -} - -macro_rules! utf8_function_two_arguments { ($name:ident, $sql_name:expr, $func:expr, $doc:expr, $arg_name_1:expr, $arg_name_2:expr) => { pub struct $name; - impl SQLFunction for $name { fn to_expr( &self, inputs: &[sqlparser::ast::FunctionArg], planner: &crate::planner::SQLPlanner, ) -> SQLPlannerResult { - match inputs { - [input1, input2] => { - let input1 = planner.plan_function_arg(input1)?; - let input2 = planner.plan_function_arg(input2)?; - Ok($func(input1, input2)) - } - _ => invalid_operation_err!(concat!( - "invalid arguments for ", - $sql_name, - ". Expected ", - $sql_name, - "(", - $arg_name_1, - ", ", - $arg_name_2, - ")" - )), - } + utf8_binary($func, $sql_name, $arg_name_1, $arg_name_2, inputs, planner) } fn docstrings(&self, _alias: &str) -> String { @@ -92,39 +116,23 @@ macro_rules! utf8_function_two_arguments { } } }; -} - -macro_rules! utf8_function_three_arguments { ($name:ident, $sql_name:expr, $func:expr, $doc:expr, $arg_name_1:expr, $arg_name_2:expr, $arg_name_3:expr) => { pub struct $name; - impl SQLFunction for $name { fn to_expr( &self, inputs: &[sqlparser::ast::FunctionArg], planner: &crate::planner::SQLPlanner, ) -> SQLPlannerResult { - match inputs { - [input1, input2, input3] => { - let input1 = planner.plan_function_arg(input1)?; - let input2 = planner.plan_function_arg(input2)?; - let input3 = planner.plan_function_arg(input3)?; - Ok($func(input1, input2, input3)) - } - _ => invalid_operation_err!(concat!( - "invalid arguments for ", - $sql_name, - ". Expected ", - $sql_name, - "(", - $arg_name_1, - ", ", - $arg_name_2, - ", ", - $arg_name_3, - ")" - )), - } + utf8_ternary( + $func, + $sql_name, + $arg_name_1, + $arg_name_2, + $arg_name_3, + inputs, + planner, + ) } fn docstrings(&self, _alias: &str) -> String { @@ -180,7 +188,7 @@ impl SQLModule for SQLModuleUtf8 { } } -utf8_function_two_arguments!( +utf8_function!( SQLUtf8EndsWith, "ends_with", daft_functions::utf8::endswith, @@ -189,7 +197,7 @@ utf8_function_two_arguments!( "substring" ); -utf8_function_two_arguments!( +utf8_function!( SQLUtf8StartsWith, "starts_with", daft_functions::utf8::startswith, @@ -198,7 +206,7 @@ utf8_function_two_arguments!( "substring" ); -utf8_function_two_arguments!( +utf8_function!( SQLUtf8Contains, "contains", daft_functions::utf8::contains, @@ -207,7 +215,7 @@ utf8_function_two_arguments!( "substring" ); -utf8_function_two_arguments!( +utf8_function!( SQLUtf8Split, "split", |input, pattern| daft_functions::utf8::split(input, pattern, false), @@ -216,7 +224,7 @@ utf8_function_two_arguments!( "delimiter" ); -utf8_function_two_arguments!( +utf8_function!( SQLUtf8RegexpMatch, "regexp_match", daft_functions::utf8::match_, @@ -225,17 +233,17 @@ utf8_function_two_arguments!( "pattern" ); -utf8_function_three_arguments!( +utf8_function!( SQLUtf8RegexpReplace, "regexp_replace", - |input, pattern, replacement| daft_functions::utf8::replace(input, pattern, replacement, true), + |input, pattern, replacement| daft_functions::utf8::replace(input, pattern, replacement, false), "Replaces all occurrences of a substring with a new string", "string_input", "pattern", "replacement" ); -utf8_function_two_arguments!( +utf8_function!( SQLUtf8RegexpSplit, "regexp_split", |input, pattern| daft_functions::utf8::split(input, pattern, true), @@ -244,7 +252,7 @@ utf8_function_two_arguments!( "delimiter" ); -utf8_function_one_argument!( +utf8_function!( SQLUtf8Length, "length", daft_functions::utf8::length, @@ -252,7 +260,7 @@ utf8_function_one_argument!( "string_input" ); -utf8_function_one_argument!( +utf8_function!( SQLUtf8LengthBytes, "length_bytes", daft_functions::utf8::length_bytes, @@ -260,7 +268,7 @@ utf8_function_one_argument!( "string_input" ); -utf8_function_one_argument!( +utf8_function!( SQLUtf8Lower, "lower", daft_functions::utf8::lower, @@ -268,7 +276,7 @@ utf8_function_one_argument!( "string_input" ); -utf8_function_one_argument!( +utf8_function!( SQLUtf8Upper, "upper", daft_functions::utf8::upper, @@ -276,7 +284,7 @@ utf8_function_one_argument!( "string_input" ); -utf8_function_one_argument!( +utf8_function!( SQLUtf8Lstrip, "lstrip", daft_functions::utf8::lstrip, @@ -284,7 +292,7 @@ utf8_function_one_argument!( "string_input" ); -utf8_function_one_argument!( +utf8_function!( SQLUtf8Rstrip, "rstrip", daft_functions::utf8::rstrip, @@ -292,7 +300,7 @@ utf8_function_one_argument!( "string_input" ); -utf8_function_one_argument!( +utf8_function!( SQLUtf8Reverse, "reverse", daft_functions::utf8::reverse, @@ -300,7 +308,7 @@ utf8_function_one_argument!( "string_input" ); -utf8_function_one_argument!( +utf8_function!( SQLUtf8Capitalize, "capitalize", daft_functions::utf8::capitalize, @@ -308,7 +316,7 @@ utf8_function_one_argument!( "string_input" ); -utf8_function_two_arguments!( +utf8_function!( SQLUtf8Left, "left", daft_functions::utf8::left, @@ -317,7 +325,7 @@ utf8_function_two_arguments!( "length" ); -utf8_function_two_arguments!( +utf8_function!( SQLUtf8Right, "right", daft_functions::utf8::right, @@ -326,7 +334,7 @@ utf8_function_two_arguments!( "length" ); -utf8_function_two_arguments!( +utf8_function!( SQLUtf8Find, "find", daft_functions::utf8::find, @@ -335,7 +343,7 @@ utf8_function_two_arguments!( "substring" ); -utf8_function_three_arguments!( +utf8_function!( SQLUtf8Rpad, "rpad", daft_functions::utf8::rpad, @@ -343,7 +351,7 @@ utf8_function_three_arguments!( "string_input", "length", "pad" ); -utf8_function_three_arguments!( +utf8_function!( SQLUtf8Lpad, "lpad", daft_functions::utf8::lpad, @@ -351,7 +359,7 @@ utf8_function_three_arguments!( "string_input", "length", "pad" ); -utf8_function_two_arguments!( +utf8_function!( SQLUtf8Repeat, "repeat", daft_functions::utf8::repeat, From cd21b04e010b602475447618a92b813af74bf299 Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Tue, 12 Nov 2024 14:15:07 +0800 Subject: [PATCH 07/10] fixes codestyle --- src/daft-sql/src/modules/utf8.rs | 5 +---- src/daft-sql/src/planner.rs | 3 +-- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/daft-sql/src/modules/utf8.rs b/src/daft-sql/src/modules/utf8.rs index 464627ffa8..edf3b3133e 100644 --- a/src/daft-sql/src/modules/utf8.rs +++ b/src/daft-sql/src/modules/utf8.rs @@ -1,8 +1,5 @@ use daft_core::array::ops::Utf8NormalizeOptions; -use daft_dsl::{ - binary_op, - ExprRef, LiteralValue, Operator, -}; +use daft_dsl::{binary_op, ExprRef, LiteralValue, Operator}; use daft_functions::{ count_matches::{utf8_count_matches, CountMatchesFunction}, tokenize::{tokenize_decode, tokenize_encode, TokenizeDecodeFunction, TokenizeEncodeFunction}, diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 419d961151..cbeb751ca0 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -9,14 +9,13 @@ use daft_core::prelude::*; use daft_dsl::{ col, has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, - Subquery, + Subquery }; use daft_functions::{ numeric::{ceil::ceil, floor::floor}, utf8::{ilike, like, to_date, to_datetime}, }; use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef}; - use sqlparser::{ ast::{ ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo, From fca75ab24e62ddf83d1bc18d74cf8998e3e8cd22 Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Fri, 15 Nov 2024 10:37:57 +0800 Subject: [PATCH 08/10] format --- src/daft-sql/src/planner.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index cbeb751ca0..6f73cedd04 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -7,9 +7,8 @@ use std::{ use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; use daft_dsl::{ - col, - has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, - Subquery + col, has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, + Operator, Subquery, }; use daft_functions::{ numeric::{ceil::ceil, floor::floor}, From 4c7be016ea4b5510e8bf2646ca85b2360bd8b430 Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Fri, 15 Nov 2024 11:04:04 +0800 Subject: [PATCH 09/10] fixes unrelated codestyle --- daft/dataframe/dataframe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index dc691cb4b6..a79443e327 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -1601,7 +1601,7 @@ def limit(self, num: int) -> "DataFrame": │ --- │ │ Int64 │ ╞═══════╡ - │ 1 │ f + │ 1 │ ├╌╌╌╌╌╌╌┤ │ 2 │ ├╌╌╌╌╌╌╌┤ From 030483f4039b757b220987d4b7963cff66f28d57 Mon Sep 17 00:00:00 2001 From: xianyangliu Date: Fri, 15 Nov 2024 11:23:00 +0800 Subject: [PATCH 10/10] Trigger Build