diff --git a/.github/workflows/reuse.linux.yml b/.github/workflows/reuse.linux.yml index 8825d641718ec..809a0cbd8d382 100644 --- a/.github/workflows/reuse.linux.yml +++ b/.github/workflows/reuse.linux.yml @@ -185,6 +185,33 @@ jobs: with: name: test-sqllogic-standalone-${{ matrix.dirs }} + sqllogic_standalone_udf_server: + name: sqllogic_standalone_${{ matrix.dirs }} + runs-on: [self-hosted, X64, Linux, 4c8g, "${{ inputs.runner_provider }}"] + needs: [build, check] + strategy: + matrix: + dirs: + - "udf_server" + steps: + - uses: actions/checkout@v3 + - name: Start UDF Server + run: | + pip install pyarrow + python3 tests/udf-server/udf_test.py & + sleep 2 + - uses: ./.github/actions/test_sqllogic_standalone_linux + timeout-minutes: 15 + with: + dirs: ${{ matrix.dirs }} + handlers: mysql,http,clickhouse + storage-format: all + - name: Upload failure + if: failure() || cancelled() + uses: ./.github/actions/artifact_failure + with: + name: test-sqllogic-standalone-${{ matrix.dirs }} + sqllogic_standalone_with_native: name: sqllogic_standalone_${{ matrix.dirs }}_${{ matrix.format }} runs-on: [self-hosted, X64, Linux, 4c8g, "${{ inputs.runner_provider }}"] diff --git a/Cargo.lock b/Cargo.lock index 7d689be70a313..a7b83b977e8a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1942,14 +1942,18 @@ name = "common-expression" version = "0.1.0" dependencies = [ "arrow-array", + "arrow-flight", "arrow-ord", "arrow-schema", + "arrow-select 46.0.0 (registry+https://github.com/rust-lang/crates.io-index)", + "async-backtrace", "base64 0.21.0", "chrono", "chrono-tz", "comfy-table 6.1.4", "common-arrow", "common-ast", + "common-base", "common-datavalues", "common-exception", "common-hashtable", @@ -1977,6 +1981,7 @@ dependencies = [ "serde", "serde_json", "terminal_size", + "tonic 0.9.2", "typetag", "unicode-segmentation", ] @@ -2150,6 +2155,7 @@ dependencies = [ "chrono", "common-base", "common-exception", + "common-expression", "common-functions", "common-meta-api", "common-meta-app", @@ -3101,6 +3107,7 @@ dependencies = [ "cidr", "common-base", "common-exception", + "common-expression", "common-grpc", "common-management", "common-meta-api", diff --git a/scripts/ci/deploy/config/databend-query-node-1.toml b/scripts/ci/deploy/config/databend-query-node-1.toml index 280bafa6137de..4a9447abe5518 100644 --- a/scripts/ci/deploy/config/databend-query-node-1.toml +++ b/scripts/ci/deploy/config/databend-query-node-1.toml @@ -38,6 +38,9 @@ table_engine_memory_enabled = true default_storage_format = 'parquet' default_compression = 'zstd' +enable_udf_server = true +udf_server_allow_list = ['http://0.0.0.0:8815'] + [[query.users]] name = "root" auth_type = "no_password" diff --git a/scripts/ci/deploy/config/databend-query-node-native.toml b/scripts/ci/deploy/config/databend-query-node-native.toml index a8e000873abef..6b0bfdb9c7b4c 100644 --- a/scripts/ci/deploy/config/databend-query-node-native.toml +++ b/scripts/ci/deploy/config/databend-query-node-native.toml @@ -32,6 +32,9 @@ cluster_id = "test_cluster" table_engine_memory_enabled = true +enable_udf_server = true +udf_server_allow_list = ['http://0.0.0.0:8815'] + [[query.users]] name = "root" auth_type = "no_password" diff --git a/src/common/exception/src/exception_code.rs b/src/common/exception/src/exception_code.rs index 2502db8576530..1ea07fd106873 100644 --- a/src/common/exception/src/exception_code.rs +++ b/src/common/exception/src/exception_code.rs @@ -271,6 +271,10 @@ build_exceptions! 
{
    IllegalUDFFormat(2601),
    UnknownUDF(2602),
    UdfAlreadyExists(2603),
+    UDFServerConnectError(2604),
+    UDFSchemaMismatch(2605),
+    UnsupportedDataType(2606),
+    UDFDataError(2607),

    // Database error codes.
    UnknownDatabaseEngine(2701),
diff --git a/src/meta/app/src/principal/mod.rs b/src/meta/app/src/principal/mod.rs
index 80277cea23361..22ceaa43c2739 100644
--- a/src/meta/app/src/principal/mod.rs
+++ b/src/meta/app/src/principal/mod.rs
@@ -38,6 +38,9 @@ pub use user_auth::AuthInfo;
 pub use user_auth::AuthType;
 pub use user_auth::PasswordHashMethod;
 pub use user_defined_file_format::UserDefinedFileFormat;
+pub use user_defined_function::LambdaUDF;
+pub use user_defined_function::UDFDefinition;
+pub use user_defined_function::UDFServer;
 pub use user_defined_function::UserDefinedFunction;
 pub use user_grant::GrantEntry;
 pub use user_grant::GrantObject;
diff --git a/src/meta/app/src/principal/user_defined_function.rs b/src/meta/app/src/principal/user_defined_function.rs
index e1a7a3be57147..02feed27f09e7 100644
--- a/src/meta/app/src/principal/user_defined_function.rs
+++ b/src/meta/app/src/principal/user_defined_function.rs
@@ -12,44 +12,114 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-use std::convert::TryFrom;
+use std::fmt::Display;
+use std::fmt::Formatter;

-use common_exception::ErrorCode;
-use common_exception::Result;
-use serde::Deserialize;
-use serde::Serialize;
+use common_expression::types::DataType;

-#[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq, Default)]
-#[serde(default)]
-pub struct UserDefinedFunction {
-    pub name: String,
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct LambdaUDF {
     pub parameters: Vec<String>,
+    pub definition: String,
+}

+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct UDFServer {
+    pub address: String,
+    pub handler: String,
+    pub language: String,
+    pub arg_types: Vec<DataType>,
+    pub return_type: DataType,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum UDFDefinition {
+    LambdaUDF(LambdaUDF),
+    UDFServer(UDFServer),
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct UserDefinedFunction {
+    pub name: String,
     pub description: String,
-    pub definition: String,
+    pub definition: UDFDefinition,
 }

 impl UserDefinedFunction {
-    pub fn new(name: &str, parameters: Vec<String>, definition: &str, description: &str) -> Self {
+    pub fn create_lambda_udf(
+        name: &str,
+        parameters: Vec<String>,
+        definition: &str,
+        description: &str,
+    ) -> Self {
         Self {
             name: name.to_string(),
-            parameters,
-            definition: definition.to_string(),
             description: description.to_string(),
+            definition: UDFDefinition::LambdaUDF(LambdaUDF {
+                parameters,
+                definition: definition.to_string(),
+            }),
         }
     }
-}

-impl TryFrom<Vec<u8>> for UserDefinedFunction {
-    type Error = ErrorCode;
+    pub fn create_udf_server(
+        name: &str,
+        address: &str,
+        handler: &str,
+        language: &str,
+        arg_types: Vec<DataType>,
+        return_type: DataType,
+        description: &str,
+    ) -> Self {
+        Self {
+            name: name.to_string(),
+            description: description.to_string(),
+            definition: UDFDefinition::UDFServer(UDFServer {
+                address: address.to_string(),
+                handler: handler.to_string(),
+                language: language.to_string(),
+                arg_types,
+                return_type,
+            }),
+        }
+    }
+}

-    fn try_from(value: Vec<u8>) -> Result<Self> {
-        match serde_json::from_slice(&value) {
-            Ok(udf) => Ok(udf),
-            Err(serialize_error) => Err(ErrorCode::IllegalUDFFormat(format!(
-                "Cannot deserialize user defined function from bytes.
cause {}", - serialize_error - ))), +impl Display for UDFDefinition { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, " (")?; + match self { + UDFDefinition::LambdaUDF(LambdaUDF { + parameters, + definition, + }) => { + for (i, item) in parameters.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{item}")?; + } + write!(f, ") -> {definition}")?; + } + UDFDefinition::UDFServer(UDFServer { + address, + arg_types, + return_type, + handler, + language, + }) => { + for (i, item) in arg_types.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{item}")?; + } + write!( + f, + ") RETURNS {return_type} LANGUAGE {language} HANDLER = {handler} ADDRESS = {address}" + )?; + } } + Ok(()) } } diff --git a/src/meta/app/tests/it/main.rs b/src/meta/app/tests/it/main.rs index 5736b8f9efdbe..9539255f03b13 100644 --- a/src/meta/app/tests/it/main.rs +++ b/src/meta/app/tests/it/main.rs @@ -13,7 +13,6 @@ // limitations under the License. mod file_format; -mod user_defined_function; mod user_grant; mod user_info; mod user_privilege; diff --git a/src/meta/app/tests/it/user_defined_function.rs b/src/meta/app/tests/it/user_defined_function.rs deleted file mode 100644 index e8ad4838983bf..0000000000000 --- a/src/meta/app/tests/it/user_defined_function.rs +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2021 Datafuse Labs. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use common_exception::exception::Result; -use common_meta_app::principal::UserDefinedFunction; - -#[test] -fn test_udf() -> Result<()> { - let udf = UserDefinedFunction::new( - "is_not_null", - vec!["p".to_string()], - "not(is_null(p))", - "this is a description", - ); - let ser = serde_json::to_string(&udf)?; - - let de = UserDefinedFunction::try_from(ser.into_bytes())?; - assert_eq!(udf, de); - - Ok(()) -} diff --git a/src/meta/proto-conv/src/lib.rs b/src/meta/proto-conv/src/lib.rs index 7823e280dc867..bcddcc8ef3b55 100644 --- a/src/meta/proto-conv/src/lib.rs +++ b/src/meta/proto-conv/src/lib.rs @@ -76,6 +76,7 @@ mod schema_from_to_protobuf_impl; mod share_from_to_protobuf_impl; mod stage_from_to_protobuf_impl; mod table_from_to_protobuf_impl; +mod udf_from_to_protobuf_impl; mod user_from_to_protobuf_impl; mod util; mod virtual_column_from_to_protobuf_impl; diff --git a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs new file mode 100644 index 0000000000000..de8b24d7ee3ba --- /dev/null +++ b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs @@ -0,0 +1,153 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use common_expression::infer_schema_type;
+use common_expression::types::DataType;
+use common_expression::TableDataType;
+use common_meta_app::principal as mt;
+use common_protos::pb;
+
+use crate::reader_check_msg;
+use crate::FromToProto;
+use crate::Incompatible;
+use crate::MIN_READER_VER;
+use crate::VER;
+
+impl FromToProto for mt::LambdaUDF {
+    type PB = pb::LambdaUdf;
+    fn get_pb_ver(p: &Self::PB) -> u64 {
+        p.ver
+    }
+    fn from_pb(p: pb::LambdaUdf) -> Result<Self, Incompatible> {
+        reader_check_msg(p.ver, p.min_reader_ver)?;
+
+        Ok(mt::LambdaUDF {
+            parameters: p.parameters,
+            definition: p.definition,
+        })
+    }
+
+    fn to_pb(&self) -> Result<pb::LambdaUdf, Incompatible> {
+        Ok(pb::LambdaUdf {
+            ver: VER,
+            min_reader_ver: MIN_READER_VER,
+            parameters: self.parameters.clone(),
+            definition: self.definition.clone(),
+        })
+    }
+}
+
+impl FromToProto for mt::UDFServer {
+    type PB = pb::UdfServer;
+    fn get_pb_ver(p: &Self::PB) -> u64 {
+        p.ver
+    }
+    fn from_pb(p: pb::UdfServer) -> Result<Self, Incompatible> {
+        reader_check_msg(p.ver, p.min_reader_ver)?;
+
+        let mut arg_types = Vec::with_capacity(p.arg_types.len());
+        for arg_type in p.arg_types {
+            let arg_type = DataType::from(&TableDataType::from_pb(arg_type)?);
+            arg_types.push(arg_type);
+        }
+        let return_type = DataType::from(&TableDataType::from_pb(p.return_type.ok_or_else(
+            || Incompatible {
+                reason: "UdfServer.return_type can not be None".to_string(),
+            },
+        )?)?);
+
+        Ok(mt::UDFServer {
+            address: p.address,
+            arg_types,
+            return_type,
+            handler: p.handler,
+            language: p.language,
+        })
+    }
+
+    fn to_pb(&self) -> Result<pb::UdfServer, Incompatible> {
+        let mut arg_types = Vec::with_capacity(self.arg_types.len());
+        for arg_type in self.arg_types.iter() {
+            let arg_type = infer_schema_type(arg_type)
+                .map_err(|e| Incompatible {
+                    reason: format!("Convert DataType to TableDataType failed: {}", e.message()),
+                })?
+                .to_pb()?;
+            arg_types.push(arg_type);
+        }
+        let return_type = infer_schema_type(&self.return_type)
+            .map_err(|e| Incompatible {
+                reason: format!("Convert DataType to TableDataType failed: {}", e.message()),
+            })?
+            .to_pb()?;
+
+        Ok(pb::UdfServer {
+            ver: VER,
+            min_reader_ver: MIN_READER_VER,
+            address: self.address.clone(),
+            handler: self.handler.clone(),
+            language: self.language.clone(),
+            arg_types,
+            return_type: Some(return_type),
+        })
+    }
+}
+
+impl FromToProto for mt::UserDefinedFunction {
+    type PB = pb::UserDefinedFunction;
+    fn get_pb_ver(p: &Self::PB) -> u64 {
+        p.ver
+    }
+    fn from_pb(p: pb::UserDefinedFunction) -> Result<Self, Incompatible> {
+        reader_check_msg(p.ver, p.min_reader_ver)?;
+        let udf_def = match p.definition {
+            Some(pb::user_defined_function::Definition::LambdaUdf(lambda_udf)) => {
+                mt::UDFDefinition::LambdaUDF(mt::LambdaUDF::from_pb(lambda_udf)?)
+            }
+            Some(pb::user_defined_function::Definition::UdfServer(udf_server)) => {
+                mt::UDFDefinition::UDFServer(mt::UDFServer::from_pb(udf_server)?)
+            }
+            None => {
+                return Err(Incompatible {
+                    reason: "UserDefinedFunction.definition cannot be None".to_string(),
+                });
+            }
+        };
+
+        Ok(mt::UserDefinedFunction {
+            name: p.name,
+            description: p.description,
+            definition: udf_def,
+        })
+    }
+
+    fn to_pb(&self) -> Result<pb::UserDefinedFunction, Incompatible> {
+        let udf_def = match &self.definition {
+            mt::UDFDefinition::LambdaUDF(lambda_udf) => {
+                pb::user_defined_function::Definition::LambdaUdf(lambda_udf.to_pb()?)
+            }
+            mt::UDFDefinition::UDFServer(udf_server) => {
+                pb::user_defined_function::Definition::UdfServer(udf_server.to_pb()?)
+            }
+        };
+
+        Ok(pb::UserDefinedFunction {
+            ver: VER,
+            min_reader_ver: MIN_READER_VER,
+            name: self.name.clone(),
+            description: self.description.clone(),
+            definition: Some(udf_def),
+        })
+    }
+}
diff --git a/src/meta/proto-conv/src/util.rs b/src/meta/proto-conv/src/util.rs
index 7aa92ba627c63..7137254c79e66 100644
--- a/src/meta/proto-conv/src/util.rs
+++ b/src/meta/proto-conv/src/util.rs
@@ -86,7 +86,8 @@ const META_CHANGE_LOG: &[(u64, &str)] = &[
     (54, "2023-08-17: Add: index.proto/IndexMeta::sync_creation", ),
     (55, "2023-07-31: Add: TableMeta and DatabaseMeta add Ownership", ),
     (56, "2023-08-31: Add: Least Visible Time", ),
-    (57, "2023-09-05: Add: catalog.proto add hdfs config", )
+    (57, "2023-09-05: Add: catalog.proto add hdfs config", ),
+    (58, "2023-09-06: Add: udf.proto/UserDefinedFunction", ),
     // Dear developer:
     //      If you're gonna add a new metadata version, you'll have to add a test for it.
     //      You could just copy an existing test file(e.g., `../tests/it/v024_table_meta.rs`)
diff --git a/src/meta/proto-conv/tests/it/main.rs b/src/meta/proto-conv/tests/it/main.rs
index cd2cf0420af90..d1789392d45eb 100644
--- a/src/meta/proto-conv/tests/it/main.rs
+++ b/src/meta/proto-conv/tests/it/main.rs
@@ -62,3 +62,4 @@ mod v054_index_meta;
 mod v055_table_meta;
 mod v056_least_visible_time;
 mod v057_hdfs_storage;
+mod v058_udf;
diff --git a/src/meta/proto-conv/tests/it/v058_udf.rs b/src/meta/proto-conv/tests/it/v058_udf.rs
new file mode 100644
index 0000000000000..6e82e8a2db3f7
--- /dev/null
+++ b/src/meta/proto-conv/tests/it/v058_udf.rs
@@ -0,0 +1,62 @@
+// Copyright 2023 Datafuse Labs.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use common_expression::types::DataType;
+use common_expression::types::NumberDataType;
+use common_meta_app::principal::UDFDefinition;
+use common_meta_app::principal::UDFServer;
+use common_meta_app::principal::UserDefinedFunction;
+
+use crate::common;
+
+// These bytes are built when a new version is introduced,
+// and are kept for backward compatibility tests.
+//
+// *************************************************************
+// * These messages should never be updated,                   *
+// * only be added when a new version is added,                *
+// * or be removed when an old version is no longer supported. *
+// *************************************************************
+//
+// The message bytes are built from the output of `test_build_pb_buf()`
+#[test]
+fn test_decode_v58_udf() -> anyhow::Result<()> {
+    let bytes: Vec<u8> = vec![
+        10, 8, 112, 108, 117, 115, 95, 105, 110, 116, 18, 21, 84, 104, 105, 115, 32, 105, 115, 32,
+        97, 32, 100, 101, 115, 99, 114, 105, 112, 116, 105, 111, 110, 34, 107, 10, 21, 104, 116,
+        116, 112, 58, 47, 47, 108, 111, 99, 97, 108, 104, 111, 115, 116, 58, 56, 56, 56, 56, 18,
+        11, 112, 108, 117, 115, 95, 105, 110, 116, 95, 112, 121, 26, 6, 112, 121, 116, 104, 111,
+        110, 34, 17, 154, 2, 8, 58, 0, 160, 6, 58, 168, 6, 24, 160, 6, 58, 168, 6, 24, 34, 17, 154,
+        2, 8, 58, 0, 160, 6, 58, 168, 6, 24, 160, 6, 58, 168, 6, 24, 42, 17, 154, 2, 8, 66, 0, 160,
+        6, 58, 168, 6, 24, 160, 6, 58, 168, 6, 24, 160, 6, 58, 168, 6, 24, 160, 6, 58, 168, 6, 24,
+    ];
+
+    let want = || UserDefinedFunction {
+        name: "plus_int".to_string(),
+        description: "This is a description".to_string(),
+        definition: UDFDefinition::UDFServer(UDFServer {
+            address: "http://localhost:8888".to_string(),
+            handler: "plus_int_py".to_string(),
+            language: "python".to_string(),
+            arg_types: vec![
+                DataType::Number(NumberDataType::Int32),
+                DataType::Number(NumberDataType::Int32),
+            ],
+            return_type: DataType::Number(NumberDataType::Int64),
+        }),
+    };
+
+    common::test_pb_from_to(func_name!(), want())?;
+    common::test_load_old(func_name!(), bytes.as_slice(), 58, want())
+}
diff --git a/src/meta/protos/proto/udf.proto b/src/meta/protos/proto/udf.proto
new file mode 100644
index 0000000000000..109067c085793
--- /dev/null
+++ b/src/meta/protos/proto/udf.proto
@@ -0,0 +1,50 @@
+// Copyright 2022 Datafuse Labs.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +syntax = "proto3"; + +package databend_proto; + +import "datatype.proto"; + +message LambdaUDF { + uint64 ver = 100; + uint64 min_reader_ver = 101; + + repeated string parameters = 1; + string definition = 2; +} + +message UDFServer { + uint64 ver = 100; + uint64 min_reader_ver = 101; + + string address = 1; + string handler = 2; + string language = 3; + repeated DataType arg_types = 4; + DataType return_type = 5; +} + +message UserDefinedFunction { + uint64 ver = 100; + uint64 min_reader_ver = 101; + + string name = 1; + string description = 2; + oneof definition { + LambdaUDF lambda_udf = 3; + UDFServer udf_server = 4; + } +} \ No newline at end of file diff --git a/src/query/ast/src/ast/format/ast_format.rs b/src/query/ast/src/ast/format/ast_format.rs index 51c214d0931ec..18f5acbaa099c 100644 --- a/src/query/ast/src/ast/format/ast_format.rs +++ b/src/query/ast/src/ast/format/ast_format.rs @@ -1907,39 +1907,81 @@ impl<'ast> Visitor<'ast> for AstFormatVisitor { self.children.push(node); } - fn visit_create_udf( - &mut self, - _if_not_exists: bool, - udf_name: &'ast Identifier, - parameters: &'ast [Identifier], - definition: &'ast Expr, - description: &'ast Option, - ) { + fn visit_create_udf(&mut self, stmt: &'ast CreateUDFStmt) { let mut children = Vec::new(); - let udf_name_format_ctx = AstFormatContext::new(format!("UdfNameIdentifier {}", udf_name)); + let udf_name_format_ctx = + AstFormatContext::new(format!("UdfNameIdentifier {}", stmt.udf_name)); children.push(FormatTreeNode::new(udf_name_format_ctx)); - if !parameters.is_empty() { - let mut parameters_children = Vec::with_capacity(parameters.len()); - for parameter in parameters.iter() { - self.visit_identifier(parameter); - parameters_children.push(self.children.pop().unwrap()); - } - let parameters_name = "UdfParameters".to_string(); - let parameters_format_ctx = - AstFormatContext::with_children(parameters_name, parameters_children.len()); - children.push(FormatTreeNode::with_children( - parameters_format_ctx, - parameters_children, - )); - } - self.visit_expr(definition); - let definition_child = self.children.pop().unwrap(); - let definition_name = "UdfDefinition".to_string(); - let definition_format_ctx = AstFormatContext::with_children(definition_name, 1); - children.push(FormatTreeNode::with_children(definition_format_ctx, vec![ - definition_child, - ])); - if let Some(description) = description { + + match &stmt.definition { + UDFDefinition::LambdaUDF { + parameters, + definition, + } => { + if !parameters.is_empty() { + let mut parameters_children = Vec::with_capacity(parameters.len()); + for parameter in parameters.iter() { + self.visit_identifier(parameter); + parameters_children.push(self.children.pop().unwrap()); + } + let parameters_name = "UdfParameters".to_string(); + let parameters_format_ctx = + AstFormatContext::with_children(parameters_name, parameters_children.len()); + children.push(FormatTreeNode::with_children( + parameters_format_ctx, + parameters_children, + )); + } + self.visit_expr(definition); + let definition_child = self.children.pop().unwrap(); + let definition_name = "UdfDefinition".to_string(); + let definition_format_ctx = AstFormatContext::with_children(definition_name, 1); + children.push(FormatTreeNode::with_children(definition_format_ctx, vec![ + definition_child, + ])); + } + UDFDefinition::UDFServer { + arg_types, + return_type, + address, + handler, + language, + } => { + if !arg_types.is_empty() { + let mut arg_types_children = Vec::with_capacity(arg_types.len()); + for arg_type in 
arg_types.iter() { + let type_format_ctx = AstFormatContext::new(format!("DataType {arg_type}")); + arg_types_children.push(FormatTreeNode::new(type_format_ctx)); + } + let arg_format_ctx = AstFormatContext::with_children( + "UdfArgTypes".to_string(), + arg_types_children.len(), + ); + children.push(FormatTreeNode::with_children( + arg_format_ctx, + arg_types_children, + )); + } + + let return_type_format_ctx = + AstFormatContext::new(format!("UdfReturnType {return_type}")); + children.push(FormatTreeNode::new(return_type_format_ctx)); + + let handler_format_ctx = + AstFormatContext::new(format!("UdfServerHandler {handler}")); + children.push(FormatTreeNode::new(handler_format_ctx)); + + let language_format_ctx = + AstFormatContext::new(format!("UdfServerLanguage {language}")); + children.push(FormatTreeNode::new(language_format_ctx)); + + let address_format_ctx = + AstFormatContext::new(format!("UdfServerAddress {address}")); + children.push(FormatTreeNode::new(address_format_ctx)); + } + } + + if let Some(description) = &stmt.description { let description_name = format!("UdfDescription {}", description); let description_format_ctx = AstFormatContext::new(description_name); children.push(FormatTreeNode::new(description_format_ctx)); @@ -1961,38 +2003,81 @@ impl<'ast> Visitor<'ast> for AstFormatVisitor { self.children.push(node); } - fn visit_alter_udf( - &mut self, - udf_name: &'ast Identifier, - parameters: &'ast [Identifier], - definition: &'ast Expr, - description: &'ast Option, - ) { + fn visit_alter_udf(&mut self, stmt: &'ast AlterUDFStmt) { let mut children = Vec::new(); - let udf_name_format_ctx = AstFormatContext::new(format!("UdfNameIdentifier {}", udf_name)); + let udf_name_format_ctx = + AstFormatContext::new(format!("UdfNameIdentifier {}", stmt.udf_name)); children.push(FormatTreeNode::new(udf_name_format_ctx)); - if !parameters.is_empty() { - let mut parameters_children = Vec::with_capacity(parameters.len()); - for parameter in parameters.iter() { - self.visit_identifier(parameter); - parameters_children.push(self.children.pop().unwrap()); - } - let parameters_name = "UdfParameters".to_string(); - let parameters_format_ctx = - AstFormatContext::with_children(parameters_name, parameters_children.len()); - children.push(FormatTreeNode::with_children( - parameters_format_ctx, - parameters_children, - )); - } - self.visit_expr(definition); - let definition_child = self.children.pop().unwrap(); - let definition_name = "UdfDefinition".to_string(); - let definition_format_ctx = AstFormatContext::with_children(definition_name, 1); - children.push(FormatTreeNode::with_children(definition_format_ctx, vec![ - definition_child, - ])); - if let Some(description) = description { + + match &stmt.definition { + UDFDefinition::LambdaUDF { + parameters, + definition, + } => { + if !parameters.is_empty() { + let mut parameters_children = Vec::with_capacity(parameters.len()); + for parameter in parameters.iter() { + self.visit_identifier(parameter); + parameters_children.push(self.children.pop().unwrap()); + } + let parameters_name = "UdfParameters".to_string(); + let parameters_format_ctx = + AstFormatContext::with_children(parameters_name, parameters_children.len()); + children.push(FormatTreeNode::with_children( + parameters_format_ctx, + parameters_children, + )); + } + self.visit_expr(definition); + let definition_child = self.children.pop().unwrap(); + let definition_name = "UdfDefinition".to_string(); + let definition_format_ctx = AstFormatContext::with_children(definition_name, 1); + 
children.push(FormatTreeNode::with_children(definition_format_ctx, vec![ + definition_child, + ])); + } + UDFDefinition::UDFServer { + arg_types, + return_type, + address, + handler, + language, + } => { + if !arg_types.is_empty() { + let mut arg_types_children = Vec::with_capacity(arg_types.len()); + for arg_type in arg_types.iter() { + let type_format_ctx = AstFormatContext::new(format!("DataType {arg_type}")); + arg_types_children.push(FormatTreeNode::new(type_format_ctx)); + } + let arg_format_ctx = AstFormatContext::with_children( + "UdfArgTypes".to_string(), + arg_types_children.len(), + ); + children.push(FormatTreeNode::with_children( + arg_format_ctx, + arg_types_children, + )); + } + + let return_type_format_ctx = + AstFormatContext::new(format!("UdfReturnType {return_type}")); + children.push(FormatTreeNode::new(return_type_format_ctx)); + + let handler_format_ctx = + AstFormatContext::new(format!("UdfServerHandler {handler}")); + children.push(FormatTreeNode::new(handler_format_ctx)); + + let language_format_ctx = + AstFormatContext::new(format!("UdfServerLanguage {language}")); + children.push(FormatTreeNode::new(language_format_ctx)); + + let address_format_ctx = + AstFormatContext::new(format!("UdfServerAddress {address}")); + children.push(FormatTreeNode::new(address_format_ctx)); + } + } + + if let Some(description) = &stmt.description { let description_name = format!("UdfDescription {}", description); let description_format_ctx = AstFormatContext::new(description_name); children.push(FormatTreeNode::new(description_format_ctx)); diff --git a/src/query/ast/src/ast/statements/mod.rs b/src/query/ast/src/ast/statements/mod.rs index 3e8505c194f27..a6f67cd81d4ea 100644 --- a/src/query/ast/src/ast/statements/mod.rs +++ b/src/query/ast/src/ast/statements/mod.rs @@ -32,6 +32,7 @@ mod show; mod stage; mod statement; mod table; +mod udf; mod unset; mod update; mod user; @@ -58,6 +59,7 @@ pub use show::*; pub use stage::*; pub use statement::*; pub use table::*; +pub use udf::*; pub use unset::*; pub use update::*; pub use user::*; diff --git a/src/query/ast/src/ast/statements/statement.rs b/src/query/ast/src/ast/statements/statement.rs index ecd33a9d2bed6..5e4f7ba6de3bc 100644 --- a/src/query/ast/src/ast/statements/statement.rs +++ b/src/query/ast/src/ast/statements/statement.rs @@ -21,7 +21,6 @@ use common_meta_app::principal::UserIdentity; use super::merge_into::MergeIntoStmt; use super::*; -use crate::ast::write_comma_separated_list; use crate::ast::Expr; use crate::ast::Identifier; use crate::ast::Query; @@ -164,23 +163,12 @@ pub enum Statement { Revoke(RevokeStmt), // UDF - CreateUDF { - if_not_exists: bool, - udf_name: Identifier, - parameters: Vec, - definition: Box, - description: Option, - }, + CreateUDF(CreateUDFStmt), DropUDF { if_exists: bool, udf_name: Identifier, }, - AlterUDF { - udf_name: Identifier, - parameters: Vec, - definition: Box, - description: Option, - }, + AlterUDF(AlterUDFStmt), // Stages CreateStage(CreateStageStmt), @@ -450,24 +438,7 @@ impl Display for Statement { } } Statement::Revoke(stmt) => write!(f, "{stmt}")?, - Statement::CreateUDF { - if_not_exists, - udf_name, - parameters, - definition, - description, - } => { - write!(f, "CREATE FUNCTION")?; - if *if_not_exists { - write!(f, " IF NOT EXISTS")?; - } - write!(f, " {udf_name} AS (")?; - write_comma_separated_list(f, parameters)?; - write!(f, ") -> {definition}")?; - if let Some(description) = description { - write!(f, " DESC = '{description}'")?; - } - } + Statement::CreateUDF(stmt) => 
write!(f, "{stmt}")?,
             Statement::DropUDF {
                 if_exists,
                 udf_name,
             } => {
@@ -478,19 +449,7 @@ impl Display for Statement {
                 }
                 write!(f, " {udf_name}")?;
             }
-            Statement::AlterUDF {
-                udf_name,
-                parameters,
-                definition,
-                description,
-            } => {
-                write!(f, "ALTER FUNCTION {udf_name} AS (")?;
-                write_comma_separated_list(f, parameters)?;
-                write!(f, ") -> {definition}")?;
-                if let Some(description) = description {
-                    write!(f, " DESC = '{description}'")?;
-                }
-            }
+            Statement::AlterUDF(stmt) => write!(f, "{stmt}")?,
             Statement::ListStage { location, pattern } => {
                 write!(f, "LIST @{location}")?;
                 if !pattern.is_empty() {
diff --git a/src/query/ast/src/ast/statements/udf.rs b/src/query/ast/src/ast/statements/udf.rs
new file mode 100644
index 0000000000000..e634fd541b5b4
--- /dev/null
+++ b/src/query/ast/src/ast/statements/udf.rs
@@ -0,0 +1,106 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::fmt::Display;
+use std::fmt::Formatter;
+
+use crate::ast::write_comma_separated_list;
+use crate::ast::Expr;
+use crate::ast::Identifier;
+use crate::ast::TypeName;
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum UDFDefinition {
+    LambdaUDF {
+        parameters: Vec<Identifier>,
+        definition: Box<Expr>,
+    },
+    UDFServer {
+        arg_types: Vec<TypeName>,
+        return_type: TypeName,
+        address: String,
+        handler: String,
+        language: String,
+    },
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct CreateUDFStmt {
+    pub if_not_exists: bool,
+    pub udf_name: Identifier,
+    pub description: Option<String>,
+    pub definition: UDFDefinition,
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub struct AlterUDFStmt {
+    pub udf_name: Identifier,
+    pub description: Option<String>,
+    pub definition: UDFDefinition,
+}
+
+impl Display for UDFDefinition {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        match self {
+            UDFDefinition::LambdaUDF {
+                parameters,
+                definition,
+            } => {
+                write!(f, "AS (")?;
+                write_comma_separated_list(f, parameters)?;
+                write!(f, ") -> {definition}")?;
+            }
+            UDFDefinition::UDFServer {
+                arg_types,
+                return_type,
+                address,
+                handler,
+                language,
+            } => {
+                write!(f, "(")?;
+                write_comma_separated_list(f, arg_types)?;
+                write!(
+                    f,
+                    ") RETURNS {return_type} LANGUAGE {language} HANDLER = {handler} ADDRESS = {address}"
+                )?;
+            }
+        }
+        Ok(())
+    }
+}
+
+impl Display for CreateUDFStmt {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(f, "CREATE FUNCTION")?;
+        if self.if_not_exists {
+            write!(f, " IF NOT EXISTS")?;
+        }
+        write!(f, " {} {}", self.udf_name, self.definition)?;
+        if let Some(description) = &self.description {
+            write!(f, " DESC = '{description}'")?;
+        }
+        Ok(())
+    }
+}
+
+impl Display for AlterUDFStmt {
+    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(f, "ALTER FUNCTION")?;
+        write!(f, " {} {}", self.udf_name, self.definition)?;
+        if let Some(description) = &self.description {
+            write!(f, " DESC = '{description}'")?;
+        }
+        Ok(())
+    }
+}
diff --git a/src/query/ast/src/parser/statement.rs b/src/query/ast/src/parser/statement.rs
index 4f7b7f6a7b132..caf0f2a28a2c3 100644
--- a/src/query/ast/src/parser/statement.rs
+++ b/src/query/ast/src/parser/statement.rs
@@ -973,31 +973,16 @@ pub fn statement(i: Input) -> IResult<Statement> {
     let create_udf = map(
         rule! {
             CREATE ~ FUNCTION ~ ( IF ~ NOT ~ EXISTS )?
-            ~ #ident
-            ~ AS ~ "(" ~ #comma_separated_list0(ident) ~ ")"
-            ~ "->" ~ #expr
+            ~ #ident ~ #udf_definition
             ~ ( DESC ~ ^"=" ~ ^#literal_string )?
         },
-        |(
-            _,
-            _,
-            opt_if_not_exists,
-            udf_name,
-            _,
-            _,
-            parameters,
-            _,
-            _,
-            definition,
-            opt_description,
-        )| {
-            Statement::CreateUDF {
+        |(_, _, opt_if_not_exists, udf_name, definition, opt_description)| {
+            Statement::CreateUDF(CreateUDFStmt {
                 if_not_exists: opt_if_not_exists.is_some(),
                 udf_name,
-                parameters,
-                definition: Box::new(definition),
                 description: opt_description.map(|(_, _, description)| description),
-            }
+                definition,
+            })
         },
     );

     let drop_udf = map(
@@ -1012,18 +997,15 @@ pub fn statement(i: Input) -> IResult<Statement> {
     let alter_udf = map(
         rule! {
             ALTER ~ FUNCTION
-            ~ #ident
-            ~ AS ~ "(" ~ #comma_separated_list0(ident) ~ ")"
-            ~ "->" ~ #expr
+            ~ #ident ~ #udf_definition
             ~ ( DESC ~ ^"=" ~ ^#literal_string )?
         },
-        |(_, _, udf_name, _, _, parameters, _, _, definition, opt_description)| {
-            Statement::AlterUDF {
+        |(_, _, udf_name, definition, opt_description)| {
+            Statement::AlterUDF(AlterUDFStmt {
                 udf_name,
-                parameters,
-                definition: Box::new(definition),
                 description: opt_description.map(|(_, _, description)| description),
-            }
+                definition,
+            })
         },
     );

@@ -1543,7 +1525,7 @@ pub fn statement(i: Input) -> IResult<Statement> {
             | #show_roles : "`SHOW ROLES`"
             | #create_role : "`CREATE ROLE [IF NOT EXISTS] <role_name>`"
             | #drop_role : "`DROP ROLE [IF EXISTS] <role_name>`"
-            | #create_udf : "`CREATE FUNCTION [IF NOT EXISTS] <udf_name> (<parameter>, ...) -> <definition expr> [DESC = <description>]`"
+            | #create_udf : "`CREATE FUNCTION [IF NOT EXISTS] <udf_name> {AS (<parameter>, ...) -> <definition expr> | (<arg_type>, ...) RETURNS <return_type> LANGUAGE <language> HANDLER=<handler> ADDRESS=<udf_server_address>} [DESC = <description>]`"
             | #drop_udf : "`DROP FUNCTION [IF EXISTS] <udf_name>`"
             | #alter_udf : "`ALTER FUNCTION <udf_name> (<parameter>, ...) -> <definition expr> [DESC = <description>]`"
         ),
@@ -2688,6 +2670,59 @@ pub fn update_expr(i: Input) -> IResult<UpdateExpr> {
     })(i)
 }

+pub fn udf_arg_type(i: Input) -> IResult<TypeName> {
+    let nullable = alt((
+        value(true, rule! { NULL }),
+        value(false, rule! { NOT ~ ^NULL }),
+    ));
+    map(
+        rule! {
+            #type_name ~ #nullable?
+        },
+        |(type_name, nullable)| match nullable {
+            Some(false) => type_name,
+            _ => type_name.wrap_nullable(),
+        },
+    )(i)
+}
+
+pub fn udf_definition(i: Input) -> IResult<UDFDefinition> {
+    let lambda_udf = map(
+        rule! {
+            AS ~ "(" ~ #comma_separated_list0(ident) ~ ")"
+            ~ "->" ~ #expr
+        },
+        |(_, _, parameters, _, _, definition)| UDFDefinition::LambdaUDF {
+            parameters,
+            definition: Box::new(definition),
+        },
+    );
+
+    let udf_server = map(
+        rule! {
+            "(" ~ #comma_separated_list0(udf_arg_type) ~ ")"
+            ~ RETURNS ~ #udf_arg_type
+            ~ LANGUAGE ~ #ident
+            ~ HANDLER ~ ^"=" ~ ^#literal_string
+            ~ ADDRESS ~ ^"=" ~ ^#literal_string
+        },
+        |(_, arg_types, _, _, return_type, _, language, _, _, handler, _, _, address)| {
+            UDFDefinition::UDFServer {
+                arg_types,
+                return_type,
+                address,
+                handler,
+                language: language.to_string(),
+            }
+        },
+    );
+
+    rule!(
+        #udf_server: "(<arg_type>, ...) RETURNS <return_type> LANGUAGE <language> HANDLER=<handler> ADDRESS=<udf_server_address>"
+        | #lambda_udf: "AS (<parameter>, ...) -> <definition expr>"
+    )(i)
+}
+
 pub fn merge_update_expr(i: Input) -> IResult<MergeUpdateExpr> {
     map(
         rule!
{ ( #dot_separated_idents_1_to_3 ~ "=" ~ ^#expr ) }, diff --git a/src/query/ast/src/parser/token.rs b/src/query/ast/src/parser/token.rs index 557cec954bcc2..306ecaf5f3f9a 100644 --- a/src/query/ast/src/parser/token.rs +++ b/src/query/ast/src/parser/token.rs @@ -1002,8 +1002,14 @@ pub enum TokenKind { ROLLUP, #[token("INDEXES", ignore(ascii_case))] INDEXES, + #[token("ADDRESS", ignore(ascii_case))] + ADDRESS, #[token("OWNERSHIP", ignore(ascii_case))] OWNERSHIP, + #[token("HANDLER", ignore(ascii_case))] + HANDLER, + #[token("LANGUAGE", ignore(ascii_case))] + LANGUAGE, } // Reference: https://www.postgresql.org/docs/current/sql-keywords-appendix.html diff --git a/src/query/ast/src/visitors/visitor.rs b/src/query/ast/src/visitors/visitor.rs index 57988a6d6161f..f5ae7b587ed3c 100644 --- a/src/query/ast/src/visitors/visitor.rs +++ b/src/query/ast/src/visitors/visitor.rs @@ -515,26 +515,11 @@ pub trait Visitor<'ast>: Sized { fn visit_revoke(&mut self, _revoke: &'ast RevokeStmt) {} - fn visit_create_udf( - &mut self, - _if_not_exists: bool, - _udf_name: &'ast Identifier, - _parameters: &'ast [Identifier], - _definition: &'ast Expr, - _description: &'ast Option, - ) { - } + fn visit_create_udf(&mut self, _stmt: &'ast CreateUDFStmt) {} fn visit_drop_udf(&mut self, _if_exists: bool, _udf_name: &'ast Identifier) {} - fn visit_alter_udf( - &mut self, - _udf_name: &'ast Identifier, - _parameters: &'ast [Identifier], - _definition: &'ast Expr, - _description: &'ast Option, - ) { - } + fn visit_alter_udf(&mut self, _stmt: &'ast AlterUDFStmt) {} fn visit_create_stage(&mut self, _stmt: &'ast CreateStageStmt) {} diff --git a/src/query/ast/src/visitors/visitor_mut.rs b/src/query/ast/src/visitors/visitor_mut.rs index a8808a4aa0868..fb9ec7ec12cde 100644 --- a/src/query/ast/src/visitors/visitor_mut.rs +++ b/src/query/ast/src/visitors/visitor_mut.rs @@ -530,26 +530,11 @@ pub trait VisitorMut: Sized { fn visit_revoke(&mut self, _revoke: &mut RevokeStmt) {} - fn visit_create_udf( - &mut self, - _if_not_exists: bool, - _udf_name: &mut Identifier, - _parameters: &mut [Identifier], - _definition: &mut Expr, - _description: &mut Option, - ) { - } + fn visit_create_udf(&mut self, _stmt: &mut CreateUDFStmt) {} fn visit_drop_udf(&mut self, _if_exists: bool, _udf_name: &mut Identifier) {} - fn visit_alter_udf( - &mut self, - _udf_name: &mut Identifier, - _parameters: &mut [Identifier], - _definition: &mut Expr, - _description: &mut Option, - ) { - } + fn visit_alter_udf(&mut self, _stmt: &mut AlterUDFStmt) {} fn visit_create_stage(&mut self, _stmt: &mut CreateStageStmt) {} diff --git a/src/query/ast/src/visitors/walk.rs b/src/query/ast/src/visitors/walk.rs index b8c4a8bd5ec3b..383950b98a3ed 100644 --- a/src/query/ast/src/visitors/walk.rs +++ b/src/query/ast/src/visitors/walk.rs @@ -419,29 +419,12 @@ pub fn walk_statement<'a, V: Visitor<'a>>(visitor: &mut V, statement: &'a Statem Statement::Grant(stmt) => visitor.visit_grant(stmt), Statement::ShowGrants { principal } => visitor.visit_show_grant(principal), Statement::Revoke(stmt) => visitor.visit_revoke(stmt), - Statement::CreateUDF { - if_not_exists, - udf_name, - parameters, - definition, - description, - } => visitor.visit_create_udf( - *if_not_exists, - udf_name, - parameters, - definition, - description, - ), + Statement::CreateUDF(stmt) => visitor.visit_create_udf(stmt), Statement::DropUDF { if_exists, udf_name, } => visitor.visit_drop_udf(*if_exists, udf_name), - Statement::AlterUDF { - udf_name, - parameters, - definition, - description, - } => 
visitor.visit_alter_udf(udf_name, parameters, definition, description), + Statement::AlterUDF(stmt) => visitor.visit_alter_udf(stmt), Statement::ListStage { location, pattern } => visitor.visit_list_stage(location, pattern), Statement::ShowStages => visitor.visit_show_stages(), Statement::DropStage { diff --git a/src/query/ast/src/visitors/walk_mut.rs b/src/query/ast/src/visitors/walk_mut.rs index 9d93580a6f040..d29873105d6ae 100644 --- a/src/query/ast/src/visitors/walk_mut.rs +++ b/src/query/ast/src/visitors/walk_mut.rs @@ -394,29 +394,12 @@ pub fn walk_statement_mut(visitor: &mut V, statement: &mut Statem Statement::Grant(stmt) => visitor.visit_grant(stmt), Statement::ShowGrants { principal } => visitor.visit_show_grant(principal), Statement::Revoke(stmt) => visitor.visit_revoke(stmt), - Statement::CreateUDF { - if_not_exists, - udf_name, - parameters, - definition, - description, - } => visitor.visit_create_udf( - *if_not_exists, - udf_name, - parameters, - definition, - description, - ), + Statement::CreateUDF(stmt) => visitor.visit_create_udf(stmt), Statement::DropUDF { if_exists, udf_name, } => visitor.visit_drop_udf(*if_exists, udf_name), - Statement::AlterUDF { - udf_name, - parameters, - definition, - description, - } => visitor.visit_alter_udf(udf_name, parameters, definition, description), + Statement::AlterUDF(stmt) => visitor.visit_alter_udf(stmt), Statement::ListStage { location, pattern } => visitor.visit_list_stage(location, pattern), Statement::ShowStages => visitor.visit_show_stages(), Statement::DropStage { diff --git a/src/query/config/src/config.rs b/src/query/config/src/config.rs index 8fce562275d64..459f6191a3b74 100644 --- a/src/query/config/src/config.rs +++ b/src/query/config/src/config.rs @@ -1529,6 +1529,13 @@ pub struct QueryConfig { /// https://platform.openai.com/docs/guides/chat #[clap(long, default_value = "gpt-3.5-turbo")] pub openai_api_completion_model: String, + + #[clap(long, default_value = "false")] + pub enable_udf_server: bool, + + /// A list of allowed udf server addresses. 
+    #[clap(long)]
+    pub udf_server_allow_list: Vec<String>,
 }

 impl Default for QueryConfig {
@@ -1602,6 +1609,8 @@ impl TryInto<inner::QueryConfig> for QueryConfig {
             openai_api_completion_model: self.openai_api_completion_model,
             openai_api_embedding_model: self.openai_api_embedding_model,
             openai_api_version: self.openai_api_version,
+            enable_udf_server: self.enable_udf_server,
+            udf_server_allow_list: self.udf_server_allow_list,
         })
     }
 }
@@ -1686,6 +1695,8 @@ impl From<inner::QueryConfig> for QueryConfig {
             openai_api_version: inner.openai_api_version,
             openai_api_completion_model: inner.openai_api_completion_model,
             openai_api_embedding_model: inner.openai_api_embedding_model,
+            enable_udf_server: inner.enable_udf_server,
+            udf_server_allow_list: inner.udf_server_allow_list,
         }
     }
 }
diff --git a/src/query/config/src/inner.rs b/src/query/config/src/inner.rs
index 341407c48eae3..ac4b7cd9a7e87 100644
--- a/src/query/config/src/inner.rs
+++ b/src/query/config/src/inner.rs
@@ -208,6 +208,9 @@ pub struct QueryConfig {
     pub openai_api_embedding_base_url: String,
     pub openai_api_embedding_model: String,
     pub openai_api_completion_model: String,
+
+    pub enable_udf_server: bool,
+    pub udf_server_allow_list: Vec<String>,
 }

 impl Default for QueryConfig {
@@ -271,6 +274,8 @@ impl Default for QueryConfig {
             openai_api_version: "".to_string(),
             openai_api_completion_model: "gpt-3.5-turbo".to_string(),
             openai_api_embedding_model: "text-embedding-ada-002".to_string(),
+            enable_udf_server: false,
+            udf_server_allow_list: Vec::new(),
         }
     }
 }
diff --git a/src/query/expression/Cargo.toml b/src/query/expression/Cargo.toml
index 1c95b5ba29b7e..383b233073374 100755
--- a/src/query/expression/Cargo.toml
+++ b/src/query/expression/Cargo.toml
@@ -12,6 +12,7 @@ test = false
 [dependencies] # In alphabetical order
 # Workspace dependencies
 common-arrow = { path = "../../common/arrow" }
+common-base = { path = "../../common/base" }
 common-datavalues = { path = "../datavalues" }
 common-exception = { path = "../../common/exception" }
 common-hashtable = { path = "../../common/hashtable" }
@@ -21,7 +22,10 @@ common-io = { path = "../../common/io" }

 # Crates.io dependencies
 arrow-array = "46.0.0"
+arrow-flight = "46.0.0"
 arrow-schema = "46.0.0"
+arrow-select = "46.0.0"
+async-backtrace = { workspace = true }
 base64 = "0.21.0"
 chrono = { workspace = true }
 chrono-tz = { workspace = true }
@@ -47,7 +51,7 @@ rust_decimal = "1.26"
 serde = { workspace = true }
 serde_json = { workspace = true }
 terminal_size = "0.2.6"
-
+tonic = { workspace = true }
 typetag = "0.2.3"
 unicode-segmentation = "1.10.1"
diff --git a/src/query/expression/src/convert_arrow_rs/record_batch.rs b/src/query/expression/src/convert_arrow_rs/record_batch.rs
index eb005a759d884..888c2b1cee312 100644
--- a/src/query/expression/src/convert_arrow_rs/record_batch.rs
+++ b/src/query/expression/src/convert_arrow_rs/record_batch.rs
@@ -17,6 +17,7 @@ use std::sync::Arc;
 use arrow_array::RecordBatch;
 use arrow_schema::ArrowError;

+use crate::convert_arrow_rs::schema::data_schema_to_arrow_schema;
 use crate::Column;
 use crate::DataBlock;
 use crate::DataSchema;
@@ -32,6 +33,20 @@ impl DataBlock {
         RecordBatch::try_new(schema, arrays)
     }

+    /// Convert DataBlock to RecordBatch without changing the schema.
+    pub fn to_record_batch_keep_schema(
+        self,
+        data_schema: &DataSchema,
+    ) -> Result<RecordBatch, ArrowError> {
+        let mut arrays = Vec::with_capacity(self.columns().len());
+        for entry in self.convert_to_full().columns() {
+            let column = entry.value.to_owned().into_column().unwrap();
+            arrays.push(column.into_arrow_rs()?)
+        }
+        let schema = Arc::new(data_schema_to_arrow_schema(data_schema));
+        RecordBatch::try_new(schema, arrays)
+    }
+
     pub fn from_record_batch(batch: &RecordBatch) -> Result<(Self, DataSchema), ArrowError> {
         let schema: DataSchema = DataSchema::try_from(&(*batch.schema()))?;
         if batch.num_columns() == 0 {
diff --git a/src/query/expression/src/convert_arrow_rs/schema/from_data_schema.rs b/src/query/expression/src/convert_arrow_rs/schema/from_data_schema.rs
index 184e7b7f331ed..5cd15bd74ea18 100644
--- a/src/query/expression/src/convert_arrow_rs/schema/from_data_schema.rs
+++ b/src/query/expression/src/convert_arrow_rs/schema/from_data_schema.rs
@@ -51,3 +51,19 @@ impl From<&DataType> for ArrowDataType {
         (&infer_schema_type(ty).expect("Generic type can not convert to arrow")).into()
     }
 }
+
+/// This function is similar to the above, but does not change the nullability of any type.
+pub fn data_schema_to_arrow_schema(data_schema: &DataSchema) -> ArrowSchema {
+    let fields = data_schema
+        .fields
+        .iter()
+        .map(|f| {
+            let ty = f.data_type().into();
+            ArrowField::new(f.name(), ty, f.is_nullable_or_null())
+        })
+        .collect::<Vec<ArrowField>>();
+    ArrowSchema {
+        fields: Fields::from(fields),
+        metadata: Default::default(),
+    }
+}
diff --git a/src/query/expression/src/convert_arrow_rs/schema/mod.rs b/src/query/expression/src/convert_arrow_rs/schema/mod.rs
index b7b1e44695b3d..0ad2414c15e60 100644
--- a/src/query/expression/src/convert_arrow_rs/schema/mod.rs
+++ b/src/query/expression/src/convert_arrow_rs/schema/mod.rs
@@ -17,4 +17,5 @@ mod from_table_schema;
 mod to_data_schema;
 mod to_table_schema;

+pub use from_data_schema::data_schema_to_arrow_schema;
 use from_table_schema::set_nullable;
diff --git a/src/query/expression/src/evaluator.rs b/src/query/expression/src/evaluator.rs
index 322263fa724b2..265922bf829a1 100644
--- a/src/query/expression/src/evaluator.rs
+++ b/src/query/expression/src/evaluator.rs
@@ -18,6 +18,7 @@ use std::ops::Not;
 use common_arrow::arrow::bitmap;
 use common_arrow::arrow::bitmap::Bitmap;
 use common_arrow::arrow::bitmap::MutableBitmap;
+use common_base::runtime::GlobalQueryRuntime;
 use common_exception::ErrorCode;
 use common_exception::Result;
 use common_exception::Span;
@@ -38,13 +39,18 @@ use crate::types::nullable::NullableDomain;
 use crate::types::BooleanType;
 use crate::types::DataType;
 use crate::types::NullableType;
+use crate::udf_client::UDFFlightClient;
 use crate::utils::arrow::constant_bitmap;
+use crate::utils::variant_transform::contains_variant;
+use crate::utils::variant_transform::transform_variant;
 use crate::values::Column;
 use crate::values::ColumnBuilder;
 use crate::values::Scalar;
 use crate::values::Value;
 use crate::BlockEntry;
 use crate::ColumnIndex;
+use crate::DataField;
+use crate::DataSchema;
 use crate::FunctionContext;
 use crate::FunctionDomain;
 use crate::FunctionEval;
@@ -159,6 +165,13 @@ impl<'a> Evaluator<'a> {
                 ctx.render_error(*span, id.params(), &args, &function.signature.name)?;
                 Ok(result)
             }
+            Expr::UDFServerCall {
+                func_name,
+                server_addr,
+                return_type,
+                args,
+                ..
+            } => self.run_udf_server_call(func_name, server_addr, return_type, args, validity),
         };

         #[cfg(debug_assertions)]
@@ -194,6 +207,100 @@ impl<'a> Evaluator<'a> {
         result
     }

+    fn run_udf_server_call(
+        &self,
+        func_name: &str,
+        server_addr: &str,
+        return_type: &DataType,
+        args: &[Expr],
+        validity: Option<Bitmap>,
+    ) -> Result<Value<AnyType>> {
+        let inputs = args
+            .iter()
+            .map(|expr| self.partial_run(expr, validity.clone()))
+            .collect::<Result<Vec<_>>>()?;
+        assert!(
+            inputs
+                .iter()
+                .filter_map(|val| match val {
+                    Value::Column(col) => Some(col.len()),
+                    Value::Scalar(_) => None,
+                })
+                .all_equal()
+        );
+
+        // construct input record_batch
+        let num_rows = self.input_columns.num_rows();
+        let fields = args
+            .iter()
+            .enumerate()
+            .map(|(idx, arg)| DataField::new(&format!("arg{}", idx + 1), arg.data_type().clone()))
+            .collect_vec();
+        let data_schema = DataSchema::new(fields);
+
+        let block_entries = inputs
+            .into_iter()
+            .zip(args.iter())
+            .map(|(col, arg)| {
+                let arg_type = arg.data_type().clone();
+                let block = if contains_variant(&arg_type) {
+                    BlockEntry::new(arg_type, transform_variant(&col, true)?)
+                } else {
+                    BlockEntry::new(arg_type, col)
+                };
+                Ok(block)
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let input_batch = DataBlock::new(block_entries, num_rows)
+            .to_record_batch_keep_schema(&data_schema)
+            .map_err(|err| ErrorCode::from_string(format!("{err}")))?;
+
+        let func_name = func_name.to_string();
+        let server_addr = server_addr.to_string();
+        let result_batch = GlobalQueryRuntime::instance()
+            .runtime()
+            .block_on(async move {
+                let mut client = UDFFlightClient::connect(&server_addr).await?;
+                client.do_exchange(&func_name, input_batch).await
+            })?;
+
+        let (result_block, result_schema) =
+            DataBlock::from_record_batch(&result_batch).map_err(|err| {
+                ErrorCode::UDFDataError(format!(
+                    "Cannot convert arrow record batch to data block: {err}"
+                ))
+            })?;
+
+        let result_fields = result_schema.fields();
+        if result_fields.is_empty() || result_block.is_empty() {
+            return Err(ErrorCode::EmptyDataFromServer(
+                "Get empty data from UDF Server",
+            ));
+        }
+
+        if result_fields[0].data_type() != return_type {
+            return Err(ErrorCode::UDFSchemaMismatch(format!(
+                "UDF server returned an incorrect type, expected: {}, but got: {}",
+                return_type,
+                result_fields[0].data_type()
+            )));
+        }
+        if result_block.num_rows() != num_rows {
+            return Err(ErrorCode::UDFDataError(format!(
+                "UDF server should return {} rows, but it returned {} rows",
+                num_rows,
+                result_block.num_rows()
+            )));
+        }
+
+        if contains_variant(return_type) {
+            transform_variant(&result_block.get_by_offset(0).value, false)
+        } else {
+            Ok(result_block.get_by_offset(0).value.clone())
+        }
+    }
+
     fn run_cast(
         &self,
         span: Span,
@@ -1228,6 +1335,7 @@ impl<'a, Index: ColumnIndex> ConstantFolder<'a, Index> {

                 (func_expr, func_domain)
             }
+            Expr::UDFServerCall { .. } => (expr.clone(), None),
         };

         debug_assert_eq!(expr.data_type(), new_expr.data_type());
diff --git a/src/query/expression/src/expression.rs b/src/query/expression/src/expression.rs
index 3177c03870581..a9fd71538fd92 100644
--- a/src/query/expression/src/expression.rs
+++ b/src/query/expression/src/expression.rs
@@ -63,6 +63,14 @@ pub enum RawExpr<Index: ColumnIndex = usize> {
         params: Vec<usize>,
         args: Vec<RawExpr<Index>>,
     },
+    UDFServerCall {
+        span: Span,
+        func_name: String,
+        server_addr: String,
+        arg_types: Vec<DataType>,
+        return_type: DataType,
+        args: Vec<RawExpr<Index>>,
+    },
 }

 /// A type-checked and ready to be evaluated expression, having all overloads chosen for function calls.
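Between the two enums, it may help to trace the intended flow: the binder emits a `RawExpr::UDFServerCall` carrying the signature declared in `CREATE FUNCTION`, and `type_check::check` (extended later in this diff) casts every argument to the declared `arg_types` before the evaluator above ever contacts the server. A minimal, hedged sketch under the names in this diff — registry construction is elided:

```rust
use common_exception::Result;
use common_expression::type_check::check;
use common_expression::types::{DataType, NumberDataType};
use common_expression::{FunctionRegistry, RawExpr};

// `registry` would normally be the built-in function registry; setup elided.
fn typecheck_udf_call(registry: &FunctionRegistry) -> Result<()> {
    let int32 = DataType::Number(NumberDataType::Int32);
    let raw: RawExpr<usize> = RawExpr::UDFServerCall {
        span: None,
        func_name: "plus_int".to_string(),
        server_addr: "http://0.0.0.0:8815".to_string(),
        arg_types: vec![int32.clone(), int32.clone()],
        return_type: DataType::Number(NumberDataType::Int64),
        args: vec![
            RawExpr::ColumnRef {
                span: None,
                id: 0,
                data_type: int32.clone(),
                display_name: "a".to_string(),
            },
            RawExpr::ColumnRef {
                span: None,
                id: 1,
                data_type: int32,
                display_name: "b".to_string(),
            },
        ],
    };

    // check() verifies the arity, auto-casts each argument to its declared
    // arg_type (erroring when no cast rule applies), and yields Expr::UDFServerCall.
    let expr = check(&raw, registry)?;
    assert_eq!(expr.data_type(), &DataType::Number(NumberDataType::Int64));
    Ok(())
}
```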
@@ -102,6 +110,14 @@ pub enum Expr { args: Vec>, return_type: DataType, }, + UDFServerCall { + #[educe(Hash(ignore), PartialEq(ignore), Eq(ignore))] + span: Span, + func_name: String, + server_addr: String, + return_type: DataType, + args: Vec>, + }, } /// Serializable expression used to share executable expression between nodes. @@ -136,6 +152,13 @@ pub enum RemoteExpr { args: Vec>, return_type: DataType, }, + UDFServerCall { + span: Span, + func_name: String, + server_addr: String, + return_type: DataType, + args: Vec>, + }, } impl RawExpr { @@ -148,6 +171,7 @@ impl RawExpr { RawExpr::Cast { expr, .. } => walk(expr, buf), RawExpr::FunctionCall { args, .. } => args.iter().for_each(|expr| walk(expr, buf)), RawExpr::Constant { .. } => (), + RawExpr::UDFServerCall { args, .. } => args.iter().for_each(|expr| walk(expr, buf)), } } @@ -198,6 +222,21 @@ impl RawExpr { params: params.clone(), args: args.iter().map(|expr| expr.project_column_ref(f)).collect(), }, + RawExpr::UDFServerCall { + span, + func_name, + server_addr, + arg_types, + return_type, + args, + } => RawExpr::UDFServerCall { + span: *span, + func_name: func_name.clone(), + server_addr: server_addr.clone(), + arg_types: arg_types.clone(), + return_type: return_type.clone(), + args: args.iter().map(|expr| expr.project_column_ref(f)).collect(), + }, } } } @@ -209,6 +248,7 @@ impl Expr { Expr::ColumnRef { span, .. } => *span, Expr::Cast { span, .. } => *span, Expr::FunctionCall { span, .. } => *span, + Expr::UDFServerCall { span, .. } => *span, } } @@ -218,6 +258,7 @@ impl Expr { Expr::ColumnRef { data_type, .. } => data_type, Expr::Cast { dest_type, .. } => dest_type, Expr::FunctionCall { return_type, .. } => return_type, + Expr::UDFServerCall { return_type, .. } => return_type, } } @@ -230,6 +271,7 @@ impl Expr { Expr::Cast { expr, .. } => walk(expr, buf), Expr::FunctionCall { args, .. } => args.iter().for_each(|expr| walk(expr, buf)), Expr::Constant { .. } => (), + Expr::UDFServerCall { args, .. } => args.iter().for_each(|expr| walk(expr, buf)), } } @@ -289,6 +331,19 @@ impl Expr { args: args.iter().map(|expr| expr.project_column_ref(f)).collect(), return_type: return_type.clone(), }, + Expr::UDFServerCall { + span, + func_name, + server_addr, + return_type, + args, + } => Expr::UDFServerCall { + span: *span, + func_name: func_name.clone(), + server_addr: server_addr.clone(), + return_type: return_type.clone(), + args: args.iter().map(|expr| expr.project_column_ref(f)).collect(), + }, } } @@ -339,6 +394,19 @@ impl Expr { args: args.iter().map(Expr::as_remote_expr).collect(), return_type: return_type.clone(), }, + Expr::UDFServerCall { + span, + func_name, + server_addr, + return_type, + args, + } => RemoteExpr::UDFServerCall { + span: *span, + func_name: func_name.clone(), + server_addr: server_addr.clone(), + return_type: return_type.clone(), + args: args.iter().map(Expr::as_remote_expr).collect(), + }, } } @@ -354,6 +422,7 @@ impl Expr { .non_deterministic && args.iter().all(|arg| arg.is_deterministic(registry)) } + Expr::UDFServerCall { .. 
} => false, } } } @@ -409,6 +478,19 @@ impl RemoteExpr { return_type: return_type.clone(), } } + RemoteExpr::UDFServerCall { + span, + func_name, + server_addr, + return_type, + args, + } => Expr::UDFServerCall { + span: *span, + func_name: func_name.clone(), + server_addr: server_addr.clone(), + return_type: return_type.clone(), + args: args.iter().map(|arg| arg.as_expr(fn_registry)).collect(), + }, } } } diff --git a/src/query/expression/src/type_check.rs b/src/query/expression/src/type_check.rs index 33be2f2188efd..412c0284fa4f6 100755 --- a/src/query/expression/src/type_check.rs +++ b/src/query/expression/src/type_check.rs @@ -126,6 +126,60 @@ pub fn check( check_function(*span, name, params, &args_expr, fn_registry) } + RawExpr::UDFServerCall { + span, + func_name, + server_addr, + arg_types, + return_type, + args, + } => { + let args: Vec<_> = args + .iter() + .map(|arg| check(arg, fn_registry)) + .try_collect()?; + if arg_types.len() != args.len() { + return Err(ErrorCode::SyntaxException(format!( + "Require {} parameters, but got: {}", + arg_types.len(), + args.len() + )) + .set_span(*span)); + } + + let checked_args = args + .iter() + .zip(arg_types) + .map(|(arg, dest_type)| { + let src_type = arg.data_type(); + if !can_auto_cast_to(src_type, dest_type, &fn_registry.default_cast_rules) { + return Err(ErrorCode::InvalidArgument(format!( + "Cannot auto cast datatype {} to {}", + src_type, dest_type, + )) + .set_span(arg.span())); + } + let is_try = fn_registry.is_auto_try_cast_rule(src_type, dest_type); + check_cast(arg.span(), is_try, arg.clone(), dest_type, fn_registry) + }) + .collect::>>()?; + + debug_assert_eq!( + &checked_args + .iter() + .map(|arg| arg.data_type().clone()) + .collect::>(), + arg_types + ); + + Ok(Expr::UDFServerCall { + span: *span, + func_name: func_name.clone(), + server_addr: server_addr.clone(), + return_type: return_type.clone(), + args: checked_args, + }) + } } } diff --git a/src/query/expression/src/utils/display.rs b/src/query/expression/src/utils/display.rs index 051dca7410b5a..cd869a1e23db5 100755 --- a/src/query/expression/src/utils/display.rs +++ b/src/query/expression/src/utils/display.rs @@ -437,6 +437,18 @@ impl Display for RawExpr { } write!(f, ")") } + RawExpr::UDFServerCall { + func_name, args, .. + } => { + write!(f, "{}(", func_name)?; + for (i, arg) in args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{arg}")?; + } + write!(f, ")") + } } } } @@ -640,6 +652,18 @@ impl Display for Expr { } write!(f, ")") } + Expr::UDFServerCall { + func_name, args, .. + } => { + write!(f, "{}(", func_name)?; + for (i, arg) in args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{arg}")?; + } + write!(f, ")") + } } } } @@ -772,6 +796,21 @@ impl Expr { s } }, + Expr::UDFServerCall { + func_name, args, .. 
+        } => {
+            let mut s = String::new();
+            s += func_name;
+            s += "(";
+            for (i, arg) in args.iter().enumerate() {
+                if i > 0 {
+                    s += ", ";
+                }
+                s += &arg.sql_display();
+            }
+            s += ")";
+            s
+        }
         }
     }
diff --git a/src/query/expression/src/utils/mod.rs b/src/query/expression/src/utils/mod.rs
index c3c59329226b7..86866a8a3f613 100644
--- a/src/query/expression/src/utils/mod.rs
+++ b/src/query/expression/src/utils/mod.rs
@@ -21,6 +21,8 @@ pub mod date_helper;
 pub mod display;
 pub mod filter_helper;
 pub mod serialize;
+pub mod udf_client;
+pub mod variant_transform;
 
 use common_arrow::arrow::bitmap::Bitmap;
 use common_exception::Result;
diff --git a/src/query/expression/src/utils/udf_client.rs b/src/query/expression/src/utils/udf_client.rs
new file mode 100644
index 0000000000000..923a6fdd5922d
--- /dev/null
+++ b/src/query/expression/src/utils/udf_client.rs
@@ -0,0 +1,147 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::time::Duration;
+
+use arrow_array::RecordBatch;
+use arrow_flight::decode::FlightRecordBatchStream;
+use arrow_flight::encode::FlightDataEncoderBuilder;
+use arrow_flight::flight_service_client::FlightServiceClient;
+use arrow_flight::FlightDescriptor;
+use arrow_select::concat::concat_batches;
+use common_exception::ErrorCode;
+use common_exception::Result;
+use futures::stream;
+use futures::StreamExt;
+use futures::TryStreamExt;
+use tonic::transport::channel::Channel;
+use tonic::transport::Endpoint;
+use tonic::Request;
+
+use crate::types::DataType;
+use crate::DataSchema;
+
+const UDF_REQUEST_TIMEOUT_SEC: u64 = 180; // 180 seconds
+
+#[derive(Debug, Clone)]
+pub struct UDFFlightClient {
+    inner: FlightServiceClient<Channel>,
+}
+
+impl UDFFlightClient {
+    #[async_backtrace::framed]
+    pub async fn connect(addr: &str) -> Result<UDFFlightClient> {
+        let endpoint = Endpoint::from_shared(addr.to_string())
+            .map_err(|err| {
+                ErrorCode::UDFServerConnectError(format!("Invalid UDF Server address: {err}"))
+            })?
+            .connect_timeout(Duration::from_secs(UDF_REQUEST_TIMEOUT_SEC));
+        let inner = FlightServiceClient::connect(endpoint)
+            .await
+            .map_err(|err| {
+                ErrorCode::UDFServerConnectError(format!(
+                    "Cannot connect to UDF Server {addr}: {err}"
+                ))
+            })?;
+        Ok(UDFFlightClient { inner })
+    }
+
+    fn make_request<T>(&self, t: T) -> Request<T> {
+        let mut request = Request::new(t);
+        request.set_timeout(Duration::from_secs(UDF_REQUEST_TIMEOUT_SEC));
+        request
+    }
+
+    #[async_backtrace::framed]
+    pub async fn check_schema(
+        &mut self,
+        func_name: &str,
+        arg_types: &[DataType],
+        return_type: &DataType,
+    ) -> Result<()> {
+        let descriptor = FlightDescriptor::new_path(vec![func_name.to_string()]);
+        let request = self.make_request(descriptor);
+        let flight_info = self.inner.get_flight_info(request).await?.into_inner();
+        let schema = flight_info
+            .try_decode_schema()
+            .and_then(|schema| DataSchema::try_from(&schema))
+            .map_err(|err| ErrorCode::UDFDataError(format!("Decode UDF schema error: {err}")))?;
+
+        let fields_num = schema.fields().len();
+        if fields_num == 0 {
+            return Err(ErrorCode::UDFSchemaMismatch(
+                "UDF Server should return at least one column",
+            ));
+        }
+
+        let (input_fields, output_fields) = schema.fields().split_at(fields_num - 1);
+        let expect_arg_types = input_fields
+            .iter()
+            .map(|f| f.data_type().clone())
+            .collect::<Vec<_>>();
+        let expect_return_type = output_fields
+            .iter()
+            .map(|f| f.data_type().clone())
+            .collect::<Vec<_>>();
+        if expect_arg_types != arg_types {
+            return Err(ErrorCode::UDFSchemaMismatch(format!(
+                "UDF arg types mismatch, actual arg types: ({:?})",
+                expect_arg_types
+                    .iter()
+                    .map(ToString::to_string)
+                    .collect::<Vec<_>>()
+                    .join(", ")
+            )));
+        }
+
+        if &expect_return_type[0] != return_type {
+            return Err(ErrorCode::UDFSchemaMismatch(format!(
+                "UDF return type mismatch, actual return type: {}",
+                expect_return_type[0]
+            )));
+        }
+
+        Ok(())
+    }
+
+    #[async_backtrace::framed]
+    pub async fn do_exchange(
+        &mut self,
+        func_name: &str,
+        input_batch: RecordBatch,
+    ) -> Result<RecordBatch> {
+        let descriptor = FlightDescriptor::new_path(vec![func_name.to_string()]);
+        let flight_data_stream = FlightDataEncoderBuilder::new()
+            .with_flight_descriptor(Some(descriptor))
+            .build(stream::iter(vec![Ok(input_batch)]))
+            .map(|data| data.unwrap());
+        let request = self.make_request(flight_data_stream);
+        let flight_data_stream = self.inner.do_exchange(request).await?.into_inner();
+        let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
+            flight_data_stream.map_err(|err| err.into()),
+        )
+        .map_err(|err| ErrorCode::UDFDataError(format!("Decode record batch error: {err}")));
+
+        let batches: Vec<RecordBatch> = record_batch_stream.try_collect().await?;
+        if batches.is_empty() {
+            return Err(ErrorCode::EmptyDataFromServer(
+                "Get empty data from UDF Server",
+            ));
+        }
+
+        let schema = batches[0].schema();
+        concat_batches(&schema, batches.iter())
+            .map_err(|err| ErrorCode::UDFDataError(err.to_string()))
+    }
+}
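For orientation, the client above is driven in three steps: connect, verify the handler's schema, then exchange batches. A minimal sketch (not part of this diff; the address, handler name, and types are illustrative and mirror the `strlen` fixture used in the tests later in this patch):

    // Sketch only: end-to-end use of UDFFlightClient under assumed values.
    use arrow_array::RecordBatch;
    use common_exception::Result;
    use common_expression::types::DataType;
    use common_expression::types::NumberDataType;
    use common_expression::udf_client::UDFFlightClient;

    async fn call_strlen(input: RecordBatch) -> Result<RecordBatch> {
        let mut client = UDFFlightClient::connect("http://0.0.0.0:8815").await?;
        // get_flight_info must advertise the argument fields first and the
        // result field last; check_schema enforces exactly that layout.
        client
            .check_schema(
                "strlen_py",
                &[DataType::String],
                &DataType::Number(NumberDataType::Int64),
            )
            .await?;
        // One do_exchange round trip; replies are concatenated into one batch.
        client.do_exchange("strlen_py", input).await
    }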
diff --git a/src/query/expression/src/utils/variant_transform.rs b/src/query/expression/src/utils/variant_transform.rs
new file mode 100644
index 0000000000000..9a4967abd5ddd
--- /dev/null
+++ b/src/query/expression/src/utils/variant_transform.rs
@@ -0,0 +1,101 @@
+// Copyright 2021 Datafuse Labs
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use common_exception::ErrorCode;
+use common_exception::Result;
+use jsonb::parse_value;
+use jsonb::to_string;
+
+use crate::types::AnyType;
+use crate::types::DataType;
+use crate::values::Column;
+use crate::values::Scalar;
+use crate::values::Value;
+use crate::ColumnBuilder;
+use crate::ScalarRef;
+
+pub fn contains_variant(data_type: &DataType) -> bool {
+    match data_type {
+        DataType::Variant => true,
+        DataType::Null
+        | DataType::EmptyArray
+        | DataType::EmptyMap
+        | DataType::Boolean
+        | DataType::String
+        | DataType::Number(_)
+        | DataType::Decimal(_)
+        | DataType::Timestamp
+        | DataType::Date
+        | DataType::Bitmap
+        | DataType::Generic(_) => false,
+        DataType::Nullable(ty) => contains_variant(ty.as_ref()),
+        DataType::Array(ty) => contains_variant(ty.as_ref()),
+        DataType::Map(ty) => contains_variant(ty.as_ref()),
+        DataType::Tuple(types) => types.iter().any(contains_variant),
+    }
+}
+
+/// Convert between the engine's JSONB encoding of variant data and its JSON string form.
+/// When `decode` is true, the variant data is decoded into a JSON string so that the UDF
+/// server can work with it; otherwise the string is parsed back into JSONB variant data.
+pub fn transform_variant(value: &Value<AnyType>, decode: bool) -> Result<Value<AnyType>> {
+    let value = match value {
+        Value::Scalar(scalar) => Value::Scalar(transform_scalar(scalar.as_ref(), decode)?),
+        Value::Column(col) => Value::Column(transform_column(col, decode)?),
+    };
+    Ok(value)
+}
+
+fn transform_column(col: &Column, decode: bool) -> Result<Column> {
+    let mut builder = ColumnBuilder::with_capacity(&col.data_type(), col.len());
+    for scalar in col.iter() {
+        builder.push(transform_scalar(scalar, decode)?.as_ref());
+    }
+    Ok(builder.build())
+}
+
+fn transform_scalar(scalar: ScalarRef<'_>, decode: bool) -> Result<Scalar> {
+    let scalar = match scalar {
+        ScalarRef::Null
+        | ScalarRef::EmptyArray
+        | ScalarRef::EmptyMap
+        | ScalarRef::Number(_)
+        | ScalarRef::Decimal(_)
+        | ScalarRef::Timestamp(_)
+        | ScalarRef::Date(_)
+        | ScalarRef::Boolean(_)
+        | ScalarRef::String(_)
+        | ScalarRef::Bitmap(_) => scalar.to_owned(),
+        ScalarRef::Array(col) => Scalar::Array(transform_column(&col, decode)?),
+        ScalarRef::Map(col) => Scalar::Map(transform_column(&col, decode)?),
+        ScalarRef::Tuple(scalars) => {
+            let scalars = scalars
+                .into_iter()
+                .map(|scalar| transform_scalar(scalar, decode))
+                .collect::<Result<Vec<_>>>()?;
+            Scalar::Tuple(scalars)
+        }
+        ScalarRef::Variant(data) => {
+            if decode {
+                Scalar::Variant(to_string(data).into_bytes())
+            } else {
+                let value = parse_value(data).map_err(|err| {
+                    ErrorCode::UDFDataError(format!("parse json value error: {err}"))
+                })?;
+                Scalar::Variant(value.to_vec())
+            }
+        }
+    };
+    Ok(scalar)
+}
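The `decode` flag's contract, shown as a round trip (a sketch with assumed import paths, not code from this diff): variants are shipped to the UDF server as JSON text and parsed back into JSONB when results return.

    // Sketch only: JSONB -> JSON text -> JSONB via transform_variant.
    use common_exception::Result;
    use common_expression::types::AnyType;
    use common_expression::variant_transform::transform_variant;
    use common_expression::Value;

    fn roundtrip(input: Value<AnyType>) -> Result<Value<AnyType>> {
        // decode = true: JSONB bytes become a JSON string for the server.
        let for_server = transform_variant(&input, true)?;
        // decode = false: the string is parsed back into JSONB bytes.
        transform_variant(&for_server, false)
    }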
diff --git a/src/query/expression/src/values.rs b/src/query/expression/src/values.rs
index 73785c13ff843..57420c733f25b 100755
--- a/src/query/expression/src/values.rs
+++ b/src/query/expression/src/values.rs
@@ -1337,7 +1337,12 @@ impl Column {
                 let offsets = arrow_col.offsets().clone().into_inner();
                 let offsets =
                     unsafe { std::mem::transmute::<Buffer<i64>, Buffer<u64>>(offsets) };
-                Column::String(StringColumn::new(arrow_col.values().clone(), offsets))
+                if data_type.is_variant() {
+                    // A Variant column coming back from the UDF server is encoded as
+                    // LargeBinary; restore it to a Variant column here.
+                    Column::Variant(StringColumn::new(arrow_col.values().clone(), offsets))
+                } else {
+                    Column::String(StringColumn::new(arrow_col.values().clone(), offsets))
+                }
             }
             // TODO: deprecate it and use LargeBinary instead
             ArrowDataType::Binary => {
diff --git a/src/query/management/Cargo.toml b/src/query/management/Cargo.toml
index 6fa1a10f53420..46e906a4ff2a0 100644
--- a/src/query/management/Cargo.toml
+++ b/src/query/management/Cargo.toml
@@ -30,6 +30,7 @@ async-trait = "0.1.57"
 serde_json = { workspace = true }
 
 [dev-dependencies]
+common-expression = { path = "../../query/expression" }
 common-meta-embedded = { path = "../../meta/embedded" }
 common-storage = { path = "../../common/storage" }
 mockall = "0.11.2"
diff --git a/src/query/management/src/udf/udf_mgr.rs b/src/query/management/src/udf/udf_mgr.rs
index 62b57c4ce7994..85c00161e3d16 100644
--- a/src/query/management/src/udf/udf_mgr.rs
+++ b/src/query/management/src/udf/udf_mgr.rs
@@ -21,13 +21,14 @@ use common_functions::is_builtin_function;
 use common_meta_app::principal::UserDefinedFunction;
 use common_meta_kvapi::kvapi;
 use common_meta_kvapi::kvapi::UpsertKVReq;
-use common_meta_types::IntoSeqV;
 use common_meta_types::MatchSeq;
 use common_meta_types::MatchSeqExt;
 use common_meta_types::MetaError;
 use common_meta_types::Operation;
 use common_meta_types::SeqV;
 
+use crate::serde::deserialize_struct;
+use crate::serde::serialize_struct;
 use crate::udf::UdfApi;
 
 static UDF_API_KEY_PREFIX: &str = "__fd_udfs";
@@ -64,7 +65,7 @@ impl UdfApi for UdfMgr {
         }
 
         let seq = MatchSeq::Exact(0);
-        let val = Operation::Update(serde_json::to_vec(&info)?);
+        let val = Operation::Update(serialize_struct(&info, ErrorCode::IllegalUDFFormat, || "")?);
         let key = format!("{}/{}", self.udf_prefix, escape_for_key(&info.name)?);
         let upsert_info = self
             .kv_api
@@ -89,7 +90,7 @@ impl UdfApi for UdfMgr {
         // Check if UDF is defined
         let _ = self.get_udf(info.name.as_str(), seq).await?;
 
-        let val = Operation::Update(serde_json::to_vec(&info)?);
+        let val = Operation::Update(serialize_struct(&info, ErrorCode::IllegalUDFFormat, || "")?);
         let key = format!("{}/{}", self.udf_prefix, escape_for_key(&info.name)?);
         let upsert_info = self
             .kv_api
@@ -115,7 +116,11 @@
             res.ok_or_else(|| ErrorCode::UnknownUDF(format!("Unknown Function {}", udf_name)))?;
 
         match seq.match_seq(&seq_value) {
-            Ok(_) => Ok(seq_value.into_seqv()?),
+            Ok(_) => Ok(SeqV::with_meta(
+                seq_value.seq,
+                seq_value.meta.clone(),
+                deserialize_struct(&seq_value.data, ErrorCode::IllegalUDFFormat, || "")?,
+            )),
             Err(_) => Err(ErrorCode::UnknownUDF(format!(
                 "Unknown Function {}",
                 udf_name
@@ -129,7 +134,7 @@
 
         let mut udfs = Vec::with_capacity(values.len());
         for (_, value) in values {
-            let udf = serde_json::from_slice::<UserDefinedFunction>(&value.data)?;
+            let udf = deserialize_struct(&value.data, ErrorCode::IllegalUDFFormat, || "")?;
             udfs.push(udf);
         }
         Ok(udfs)
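UDF metadata now round-trips through the management crate's protobuf serde helpers instead of serde_json, which is what allows the new enum-based `UDFDefinition` to be persisted. A sketch mirroring the call shapes above (the helpers come from `crate::serde` as imported in this file; exact signatures abbreviated, not code from this diff):

    // Sketch only: persist and restore a UDF the way udf_mgr.rs now does.
    use common_exception::ErrorCode;
    use common_exception::Result;
    use common_meta_app::principal::UserDefinedFunction;

    fn roundtrip_udf(udf: &UserDefinedFunction) -> Result<UserDefinedFunction> {
        // serialize_struct / deserialize_struct are the crate::serde helpers.
        let bytes = serialize_struct(udf, ErrorCode::IllegalUDFFormat, || "")?;
        deserialize_struct(&bytes, ErrorCode::IllegalUDFFormat, || "")
    }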
diff --git a/src/query/management/tests/it/udf.rs b/src/query/management/tests/it/udf.rs
index 016bbf99f4d70..e90c32d42727f 100644
--- a/src/query/management/tests/it/udf.rs
+++ b/src/query/management/tests/it/udf.rs
@@ -15,7 +15,10 @@
 use std::sync::Arc;
 
 use common_base::base::tokio;
+use common_exception::ErrorCode;
 use common_exception::Result;
+use common_expression::types::DataType;
+use common_expression::types::NumberDataType;
 use common_management::*;
 use common_meta_app::principal::UserDefinedFunction;
 use common_meta_embedded::MetaEmbedded;
@@ -27,9 +30,12 @@ use common_meta_types::SeqV;
 async fn test_add_udf() -> Result<()> {
     let (kv_api, udf_api) = new_udf_api().await?;
 
-    let udf = create_test_udf();
+    // lambda udf
+    let udf = create_test_lambda_udf();
     udf_api.add_udf(udf.clone()).await?;
-    let value = kv_api.get_kv("__fd_udfs/admin/isnotempty").await?;
+    let value = kv_api
+        .get_kv(format!("__fd_udfs/admin/{}", udf.name).as_str())
+        .await?;
 
     match value {
         Some(SeqV {
@@ -37,7 +43,30 @@ async fn test_add_udf() -> Result<()> {
             seq: 1,
             meta: _,
             data: value,
         }) => {
-            assert_eq!(value, serde_json::to_vec(&udf)?);
+            assert_eq!(
+                value,
+                serialize_struct(&udf, ErrorCode::IllegalUDFFormat, || "")?
+            );
+        }
+        catch => panic!("GetKVActionReply{:?}", catch),
+    }
+
+    // udf server
+    let udf = create_test_udf_server();
+    udf_api.add_udf(udf.clone()).await?;
+    let value = kv_api
+        .get_kv(format!("__fd_udfs/admin/{}", udf.name).as_str())
+        .await?;
+
+    match value {
+        Some(SeqV {
+            seq: 2,
+            meta: _,
+            data: value,
+        }) => {
+            assert_eq!(
+                value,
+                serialize_struct(&udf, ErrorCode::IllegalUDFFormat, || "")?
+            );
         }
         catch => panic!("GetKVActionReply{:?}", catch),
     }
@@ -49,9 +78,17 @@ async fn test_already_exists_add_udf() -> Result<()> {
     let (_, udf_api) = new_udf_api().await?;
 
-    let udf = create_test_udf();
+    // lambda udf
+    let udf = create_test_lambda_udf();
     udf_api.add_udf(udf.clone()).await?;
 
+    match udf_api.add_udf(udf.clone()).await {
+        Ok(_) => panic!("Adding an existing UDF must return Err."),
+        Err(cause) => assert_eq!(cause.code(), 2603),
+    }
+
+    // udf server
+    let udf = create_test_udf_server();
+    udf_api.add_udf(udf.clone()).await?;
     match udf_api.add_udf(udf.clone()).await {
         Ok(_) => panic!("Adding an existing UDF must return Err."),
         Err(cause) => assert_eq!(cause.code(), 2603),
@@ -67,11 +104,14 @@ async fn test_successfully_get_udfs() -> Result<()> {
     let udfs = udf_api.get_udfs().await?;
     assert_eq!(udfs, vec![]);
 
-    let udf = create_test_udf();
-    udf_api.add_udf(udf.clone()).await?;
+    let lambda_udf = create_test_lambda_udf();
+    let udf_server = create_test_udf_server();
+
+    udf_api.add_udf(lambda_udf.clone()).await?;
+    udf_api.add_udf(udf_server.clone()).await?;
 
     let udfs = udf_api.get_udfs().await?;
-    assert_eq!(udfs[0], udf);
+    assert_eq!(udfs, vec![lambda_udf, udf_server]);
     Ok(())
 }
 
@@ -79,13 +119,17 @@ async fn test_successfully_drop_udf() -> Result<()> {
     let (_, udf_api) = new_udf_api().await?;
 
-    let udf = create_test_udf();
-    udf_api.add_udf(udf.clone()).await?;
+    let lambda_udf = create_test_lambda_udf();
+    let udf_server = create_test_udf_server();
+
+    udf_api.add_udf(lambda_udf.clone()).await?;
+    udf_api.add_udf(udf_server.clone()).await?;
 
     let udfs = udf_api.get_udfs().await?;
-    assert_eq!(udfs, vec![udf.clone()]);
+    assert_eq!(udfs, vec![lambda_udf.clone(), udf_server.clone()]);
 
-    udf_api.drop_udf(&udf.name, MatchSeq::GE(1)).await?;
+    udf_api.drop_udf(&lambda_udf.name, MatchSeq::GE(1)).await?;
+    udf_api.drop_udf(&udf_server.name, MatchSeq::GE(1)).await?;
 
     let udfs = udf_api.get_udfs().await?;
     assert_eq!(udfs, vec![]);
@@ -104,8 +148,8 @@ async fn test_unknown_udf_drop_udf() -> Result<()> {
     Ok(())
 }
 
-fn create_test_udf() -> UserDefinedFunction {
-    UserDefinedFunction::new(
+fn create_test_lambda_udf() -> UserDefinedFunction {
+    UserDefinedFunction::create_lambda_udf(
         "isnotempty",
         vec!["p".to_string()],
         "not(is_null(p))",
         "This is a description",
) } +fn create_test_udf_server() -> UserDefinedFunction { + UserDefinedFunction::create_udf_server( + "strlen", + "http://localhost:8888", + "strlen_py", + "python", + vec![DataType::String], + DataType::Number(NumberDataType::Int64), + "This is a description", + ) +} + async fn new_udf_api() -> Result<(Arc, UdfMgr)> { let test_api = Arc::new(MetaEmbedded::new_temp().await?); let mgr = UdfMgr::create(test_api.clone(), "admin")?; diff --git a/src/query/service/src/interpreters/interpreter_delete.rs b/src/query/service/src/interpreters/interpreter_delete.rs index ebdf993d1ea55..91aaabf70a1ec 100644 --- a/src/query/service/src/interpreters/interpreter_delete.rs +++ b/src/query/service/src/interpreters/interpreter_delete.rs @@ -430,6 +430,11 @@ pub fn replace_subquery( replace_subquery(filters, arg)?; } } + ScalarExpr::UDFServerCall(udf) => { + for arg in &mut udf.arguments { + replace_subquery(filters, arg)?; + } + } ScalarExpr::SubqueryExpr { .. } => { let filter = filters.pop_back().unwrap(); *selection = filter; diff --git a/src/query/service/tests/it/storages/testdata/configs_table_basic.txt b/src/query/service/tests/it/storages/testdata/configs_table_basic.txt index d8f48dfdeedc2..8eb6abc04a180 100644 --- a/src/query/service/tests/it/storages/testdata/configs_table_basic.txt +++ b/src/query/service/tests/it/storages/testdata/configs_table_basic.txt @@ -60,6 +60,7 @@ DB.Table: 'system'.'configs', Table: configs-table_id:1, ver:0, Engine: SystemCo | 'query' | 'default_compression' | 'auto' | '' | | 'query' | 'default_storage_format' | 'auto' | '' | | 'query' | 'disable_system_table_load' | 'false' | '' | +| 'query' | 'enable_udf_server' | 'false' | '' | | 'query' | 'flight_api_address' | '127.0.0.1:9090' | '' | | 'query' | 'flight_sql_handler_host' | '127.0.0.1' | '' | | 'query' | 'flight_sql_handler_port' | '8900' | '' | @@ -105,6 +106,7 @@ DB.Table: 'system'.'configs', Table: configs-table_id:1, ver:0, Engine: SystemCo | 'query' | 'share_endpoint_auth_token_file' | '' | '' | | 'query' | 'table_engine_memory_enabled' | 'true' | '' | | 'query' | 'tenant_id' | 'test' | '' | +| 'query' | 'udf_server_allow_list' | '' | '' | | 'query' | 'users' | '{"name":"root","auth_type":"no_password","auth_string":null}' | '' | | 'query' | 'wait_timeout_mills' | '5000' | '' | | 'storage' | 'allow_insecure' | 'false' | '' | diff --git a/src/query/sql/src/evaluator/cse.rs b/src/query/sql/src/evaluator/cse.rs index 1cf770d9a2731..fcf1589cb1d27 100644 --- a/src/query/sql/src/evaluator/cse.rs +++ b/src/query/sql/src/evaluator/cse.rs @@ -142,6 +142,14 @@ fn count_expressions(expr: &Expr, counter: &mut HashMap) { } // ignore constant and column ref Expr::Constant { .. } | Expr::ColumnRef { .. } => {} + Expr::UDFServerCall { args, .. } => { + let entry = counter.entry(expr.clone()).or_insert(0); + *entry += 1; + + for arg in args { + count_expressions(arg, counter); + } + } } } @@ -167,5 +175,10 @@ fn perform_cse_replacement(expr: &mut Expr, cse_replacements: &HashMap {} + Expr::UDFServerCall { args, .. 
} => { + for arg in args.iter_mut() { + perform_cse_replacement(arg, cse_replacements); + } + } } } diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs index e64babb1985a3..28a868a7c5908 100644 --- a/src/query/sql/src/planner/binder/aggregate.rs +++ b/src/query/sql/src/planner/binder/aggregate.rs @@ -49,6 +49,7 @@ use crate::plans::LambdaFunc; use crate::plans::NthValueFunction; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; +use crate::plans::UDFServerCall; use crate::plans::WindowFunc; use crate::plans::WindowFuncType; use crate::plans::WindowOrderBy; @@ -211,7 +212,22 @@ impl<'a> AggregateRewriter<'a> { } .into()) } - + ScalarExpr::UDFServerCall(udf) => { + let new_args = udf + .arguments + .iter() + .map(|arg| self.visit(arg)) + .collect::>>()?; + Ok(UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments: new_args, + } + .into()) + } ScalarExpr::LambdaFunction(lambda_func) => { let new_args = lambda_func .args diff --git a/src/query/sql/src/planner/binder/binder.rs b/src/query/sql/src/planner/binder/binder.rs index 310fd1593bcb7..34f62c7d4f94f 100644 --- a/src/query/sql/src/planner/binder/binder.rs +++ b/src/query/sql/src/planner/binder/binder.rs @@ -34,7 +34,6 @@ use common_expression::ConstantFolder; use common_expression::Expr; use common_functions::BUILTIN_FUNCTIONS; use common_meta_app::principal::StageFileFormatType; -use common_meta_app::principal::UserDefinedFunction; use indexmap::IndexMap; use log::warn; @@ -43,11 +42,8 @@ use crate::binder::ColumnBindingBuilder; use crate::binder::CteInfo; use crate::normalize_identifier; use crate::optimizer::SExpr; -use crate::planner::udf_validator::UDFValidator; -use crate::plans::AlterUDFPlan; use crate::plans::CreateFileFormatPlan; use crate::plans::CreateRolePlan; -use crate::plans::CreateUDFPlan; use crate::plans::DropFileFormatPlan; use crate::plans::DropRolePlan; use crate::plans::DropStagePlan; @@ -440,54 +436,8 @@ impl<'a> Binder { Statement::ShowFileFormats => Plan::ShowFileFormats(Box::new(ShowFileFormatsPlan {})), // UDFs - Statement::CreateUDF { - if_not_exists, - udf_name, - parameters, - definition, - description, - } => { - let mut validator = UDFValidator { - name: udf_name.to_string(), - parameters: parameters.iter().map(|v| v.to_string()).collect(), - ..Default::default() - }; - validator.verify_definition_expr(definition)?; - let udf = UserDefinedFunction { - name: validator.name, - parameters: validator.parameters, - definition: definition.to_string(), - description: description.clone().unwrap_or_default(), - }; - - Plan::CreateUDF(Box::new(CreateUDFPlan { - if_not_exists: *if_not_exists, - udf, - })) - } - Statement::AlterUDF { - udf_name, - parameters, - definition, - description, - } => { - let mut validator = UDFValidator { - name: udf_name.to_string(), - parameters: parameters.iter().map(|v| v.to_string()).collect(), - ..Default::default() - }; - validator.verify_definition_expr(definition)?; - let udf = UserDefinedFunction { - name: validator.name, - parameters: validator.parameters, - definition: definition.to_string(), - description: description.clone().unwrap_or_default(), - }; - - Plan::AlterUDF(Box::new(AlterUDFPlan { - udf, - })) - } + Statement::CreateUDF(stmt) => self.bind_create_udf(stmt).await?, + Statement::AlterUDF(stmt) => self.bind_alter_udf(stmt).await?, Statement::DropUDF { if_exists, 
udf_name, diff --git a/src/query/sql/src/planner/binder/delete.rs b/src/query/sql/src/planner/binder/delete.rs index 0e2e985a0fa02..d7cd0e9185809 100644 --- a/src/query/sql/src/planner/binder/delete.rs +++ b/src/query/sql/src/planner/binder/delete.rs @@ -200,6 +200,12 @@ impl Binder { .await?; subquery_desc.push(desc); } + ScalarExpr::UDFServerCall(scalar) => { + for arg in scalar.arguments.iter() { + self.subquery_desc(arg, table_expr.clone(), subquery_desc) + .await?; + } + } ScalarExpr::BoundColumnRef(_) | ScalarExpr::ConstantExpr(_) | ScalarExpr::WindowFunction(_) diff --git a/src/query/sql/src/planner/binder/lambda.rs b/src/query/sql/src/planner/binder/lambda.rs index fb32ff363551f..318b0b11f3653 100644 --- a/src/query/sql/src/planner/binder/lambda.rs +++ b/src/query/sql/src/planner/binder/lambda.rs @@ -29,6 +29,7 @@ use crate::plans::Lambda; use crate::plans::LambdaFunc; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; +use crate::plans::UDFServerCall; use crate::plans::WindowFunc; use crate::plans::WindowOrderBy; use crate::BindContext; @@ -86,6 +87,23 @@ impl<'a> LambdaRewriter<'a> { } .into()), + ScalarExpr::UDFServerCall(udf) => { + let new_args = udf + .arguments + .iter() + .map(|arg| self.visit(arg)) + .collect::>>()?; + Ok(UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments: new_args, + } + .into()) + } + // TODO(leiysky): should we recursively process subquery here? ScalarExpr::SubqueryExpr(_) => Ok(scalar.clone()), diff --git a/src/query/sql/src/planner/binder/mod.rs b/src/query/sql/src/planner/binder/mod.rs index 8ed93cd02ed19..aeaaeb96c9339 100644 --- a/src/query/sql/src/planner/binder/mod.rs +++ b/src/query/sql/src/planner/binder/mod.rs @@ -47,6 +47,7 @@ mod sort; mod stage; mod table; mod table_args; +mod udf; mod update; mod values; mod window; diff --git a/src/query/sql/src/planner/binder/scalar_common.rs b/src/query/sql/src/planner/binder/scalar_common.rs index 488ad58637722..32ccd9e98cf32 100644 --- a/src/query/sql/src/planner/binder/scalar_common.rs +++ b/src/query/sql/src/planner/binder/scalar_common.rs @@ -161,6 +161,7 @@ pub fn contain_subquery(scalar: &ScalarExpr) -> bool { } ScalarExpr::FunctionCall(func) => func.arguments.iter().any(contain_subquery), ScalarExpr::CastExpr(CastExpr { argument, .. 
}) => contain_subquery(argument), + ScalarExpr::UDFServerCall(udf) => udf.arguments.iter().any(contain_subquery), _ => false, } } @@ -212,6 +213,10 @@ pub fn prune_by_children(scalar: &ScalarExpr, columns: &HashSet) -> .all(|arg| prune_by_children(arg, columns)), ScalarExpr::CastExpr(expr) => prune_by_children(expr.argument.as_ref(), columns), ScalarExpr::SubqueryExpr(_) => false, + ScalarExpr::UDFServerCall(udf) => udf + .arguments + .iter() + .all(|arg| prune_by_children(arg, columns)), } } diff --git a/src/query/sql/src/planner/binder/scalar_visitor.rs b/src/query/sql/src/planner/binder/scalar_visitor.rs index 87a94af703724..470497ac19b9a 100644 --- a/src/query/sql/src/planner/binder/scalar_visitor.rs +++ b/src/query/sql/src/planner/binder/scalar_visitor.rs @@ -98,6 +98,11 @@ pub trait ScalarVisitor: Sized { stack.push(RecursionProcessing::Call(&cast.argument)) } ScalarExpr::SubqueryExpr(_) => {} + ScalarExpr::UDFServerCall(udf) => { + for arg in udf.arguments.iter() { + stack.push(RecursionProcessing::Call(arg)); + } + } } visitor diff --git a/src/query/sql/src/planner/binder/sort.rs b/src/query/sql/src/planner/binder/sort.rs index 2fcf342356066..b0d5bd350dece 100644 --- a/src/query/sql/src/planner/binder/sort.rs +++ b/src/query/sql/src/planner/binder/sort.rs @@ -40,6 +40,7 @@ use crate::plans::ScalarExpr; use crate::plans::ScalarItem; use crate::plans::Sort; use crate::plans::SortItem; +use crate::plans::UDFServerCall; use crate::BindContext; use crate::IndexType; use crate::WindowChecker; @@ -401,6 +402,24 @@ impl Binder { target_type: target_type.clone(), })) } + ScalarExpr::UDFServerCall(udf) => { + let new_args = udf + .arguments + .iter() + .map(|arg| { + self.rewrite_scalar_with_replacement(bind_context, arg, replacement_fn) + }) + .collect::>>()?; + Ok(UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments: new_args, + } + .into()) + } _ => Ok(original_scalar.clone()), }, } diff --git a/src/query/sql/src/planner/binder/udf.rs b/src/query/sql/src/planner/binder/udf.rs new file mode 100644 index 0000000000000..0db050f47a8fa --- /dev/null +++ b/src/query/sql/src/planner/binder/udf.rs @@ -0,0 +1,134 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+use common_ast::ast::AlterUDFStmt;
+use common_ast::ast::CreateUDFStmt;
+use common_ast::ast::Identifier;
+use common_ast::ast::UDFDefinition;
+use common_config::GlobalConfig;
+use common_exception::ErrorCode;
+use common_exception::Result;
+use common_expression::types::DataType;
+use common_expression::udf_client::UDFFlightClient;
+use common_meta_app::principal::LambdaUDF;
+use common_meta_app::principal::UDFDefinition as PlanUDFDefinition;
+use common_meta_app::principal::UDFServer;
+use common_meta_app::principal::UserDefinedFunction;
+
+use crate::planner::resolve_type_name;
+use crate::planner::udf_validator::UDFValidator;
+use crate::plans::AlterUDFPlan;
+use crate::plans::CreateUDFPlan;
+use crate::plans::Plan;
+use crate::Binder;
+
+impl Binder {
+    pub(in crate::planner::binder) async fn bind_udf_definition(
+        &mut self,
+        udf_name: &Identifier,
+        udf_description: &Option<String>,
+        udf_definition: &UDFDefinition,
+    ) -> Result<UserDefinedFunction> {
+        match udf_definition {
+            UDFDefinition::LambdaUDF {
+                parameters,
+                definition,
+            } => {
+                let mut validator = UDFValidator {
+                    name: udf_name.to_string(),
+                    parameters: parameters.iter().map(|v| v.to_string()).collect(),
+                    ..Default::default()
+                };
+                validator.verify_definition_expr(definition)?;
+                Ok(UserDefinedFunction {
+                    name: validator.name,
+                    description: udf_description.clone().unwrap_or_default(),
+                    definition: PlanUDFDefinition::LambdaUDF(LambdaUDF {
+                        parameters: validator.parameters,
+                        definition: definition.to_string(),
+                    }),
+                })
+            }
+            UDFDefinition::UDFServer {
+                arg_types,
+                return_type,
+                address,
+                handler,
+                language,
+            } => {
+                if !GlobalConfig::instance().query.enable_udf_server {
+                    return Err(ErrorCode::Unimplemented(
+                        "UDF server is not allowed, you can enable it by setting 'enable_udf_server = true' in query node config",
+                    ));
+                }
+
+                let udf_server_allow_list = &GlobalConfig::instance().query.udf_server_allow_list;
+                if udf_server_allow_list
+                    .iter()
+                    .all(|addr| addr.trim_end_matches('/') != address.trim_end_matches('/'))
+                {
+                    return Err(ErrorCode::InvalidArgument(format!(
+                        "Unallowed UDF server address, '{address}' is not in udf_server_allow_list"
+                    )));
+                }
+
+                let mut arg_datatypes = Vec::with_capacity(arg_types.len());
+                for arg_type in arg_types {
+                    arg_datatypes.push(DataType::from(&resolve_type_name(arg_type, true)?));
+                }
+                let return_type = DataType::from(&resolve_type_name(return_type, true)?);
+
+                let mut client = UDFFlightClient::connect(address).await?;
+                client
+                    .check_schema(handler, &arg_datatypes, &return_type)
+                    .await?;
+
+                Ok(UserDefinedFunction {
+                    name: udf_name.to_string(),
+                    description: udf_description.clone().unwrap_or_default(),
+                    definition: PlanUDFDefinition::UDFServer(UDFServer {
+                        address: address.clone(),
+                        arg_types: arg_datatypes,
+                        return_type,
+                        handler: handler.clone(),
+                        language: language.clone(),
+                    }),
+                })
+            }
+        }
+    }
+
+    pub(in crate::planner::binder) async fn bind_create_udf(
+        &mut self,
+        stmt: &CreateUDFStmt,
+    ) -> Result<Plan> {
+        let udf = self
+            .bind_udf_definition(&stmt.udf_name, &stmt.description, &stmt.definition)
+            .await?;
+        Ok(Plan::CreateUDF(Box::new(CreateUDFPlan {
+            if_not_exists: stmt.if_not_exists,
+            udf,
+        })))
+    }
+
+    pub(in crate::planner::binder) async fn bind_alter_udf(
+        &mut self,
+        stmt: &AlterUDFStmt,
+    ) -> Result<Plan> {
+        let udf = self
+            .bind_udf_definition(&stmt.udf_name, &stmt.description, &stmt.definition)
+            .await?;
+        Ok(Plan::AlterUDF(Box::new(AlterUDFPlan { udf })))
+    }
+}
diff --git a/src/query/sql/src/planner/binder/window.rs
b/src/query/sql/src/planner/binder/window.rs index ebdefdafcccb8..c287fc130c974 100644 --- a/src/query/sql/src/planner/binder/window.rs +++ b/src/query/sql/src/planner/binder/window.rs @@ -33,6 +33,7 @@ use crate::plans::LambdaFunc; use crate::plans::NthValueFunction; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; +use crate::plans::UDFServerCall; use crate::plans::Window; use crate::plans::WindowFunc; use crate::plans::WindowFuncFrame; @@ -306,7 +307,22 @@ impl<'a> WindowRewriter<'a> { self.in_window = false; Ok(scalar) } - + ScalarExpr::UDFServerCall(udf) => { + let new_args = udf + .arguments + .iter() + .map(|arg| self.visit(arg)) + .collect::>>()?; + Ok(UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments: new_args, + } + .into()) + } ScalarExpr::LambdaFunction(lambda_func) => { let new_args = lambda_func .args diff --git a/src/query/sql/src/planner/format/display_rel_operator.rs b/src/query/sql/src/planner/format/display_rel_operator.rs index 66a377b744feb..0434ad18be379 100644 --- a/src/query/sql/src/planner/format/display_rel_operator.rs +++ b/src/query/sql/src/planner/format/display_rel_operator.rs @@ -124,6 +124,17 @@ pub fn format_scalar(scalar: &ScalarExpr) -> String { ) } ScalarExpr::SubqueryExpr(_) => "SUBQUERY".to_string(), + ScalarExpr::UDFServerCall(udf) => { + format!( + "{}({})", + &udf.func_name, + udf.arguments + .iter() + .map(|arg| { format_scalar(arg) }) + .collect::>() + .join(", ") + ) + } } } diff --git a/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs b/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs index 442d284c5f8be..d877c1ec4b871 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs @@ -48,6 +48,7 @@ use crate::plans::ScalarItem; use crate::plans::Scan; use crate::plans::SubqueryExpr; use crate::plans::SubqueryType; +use crate::plans::UDFServerCall; use crate::BaseTableColumn; use crate::ColumnEntry; use crate::DerivedColumn; @@ -837,6 +838,21 @@ impl SubqueryRewriter { target_type: cast_expr.target_type.clone(), })) } + ScalarExpr::UDFServerCall(udf) => { + let arguments = udf + .arguments + .iter() + .map(|arg| self.flatten_scalar(arg, correlated_columns)) + .collect::>>()?; + Ok(ScalarExpr::UDFServerCall(UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments, + })) + } _ => Err(ErrorCode::Internal( "Invalid scalar for flattening subquery", )), diff --git a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs index bcf0c2b0012f1..2048eeeec6cdd 100644 --- a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs +++ b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs @@ -47,6 +47,7 @@ use crate::plans::ScalarExpr; use crate::plans::ScalarItem; use crate::plans::SubqueryExpr; use crate::plans::SubqueryType; +use crate::plans::UDFServerCall; use crate::plans::WindowFuncType; use crate::IndexType; use crate::MetadataRef; @@ -338,6 +339,27 @@ impl SubqueryRewriter { Ok((scalar, s_expr)) } + ScalarExpr::UDFServerCall(udf) => { + let mut args = vec![]; + let mut s_expr = s_expr.clone(); + for arg in udf.arguments.iter() { + let res = 
self.try_rewrite_subquery(arg, &s_expr, false)?; + s_expr = res.1; + args.push(res.0); + } + + let expr: ScalarExpr = UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments: args, + } + .into(); + + Ok((expr, s_expr)) + } } } diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs index 310f4ae347c4f..71e87b8c20cd5 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/agg_index/query_rewrite.rs @@ -38,6 +38,7 @@ use crate::plans::EvalScalar; use crate::plans::FunctionCall; use crate::plans::RelOperator; use crate::plans::ScalarItem; +use crate::plans::UDFServerCall; use crate::ColumnEntry; use crate::ColumnSet; use crate::IndexType; @@ -360,6 +361,11 @@ fn rewrite_scalar_index( ScalarExpr::CastExpr(cast) => { rewrite_scalar_index(table_index, columns, &mut cast.argument); } + ScalarExpr::UDFServerCall(udf) => { + udf.arguments + .iter_mut() + .for_each(|arg| rewrite_scalar_index(table_index, columns, arg)); + } _ => { /* do nothing */ } } } @@ -707,6 +713,15 @@ impl RewriteInfomartion<'_> { .join(", ") ) } + ScalarExpr::UDFServerCall(udf) => format!( + "{}({})", + &udf.func_name, + udf.arguments + .iter() + .map(|arg| { self.format_scalar(arg) }) + .collect::>() + .join(", ") + ), _ => unreachable!(), // Window function and subquery will not appear in index. } } @@ -994,6 +1009,24 @@ fn rewrite_query_item( .into(), ) } + ScalarExpr::UDFServerCall(udf) => { + let mut new_args = Vec::with_capacity(udf.arguments.len()); + for arg in udf.arguments.iter() { + let new_arg = rewrite_by_selection(query_info, arg, index_selection)?; + new_args.push(new_arg); + } + Some( + UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments: new_args, + } + .into(), + ) + } ScalarExpr::AggregateFunction(_) => None, /* Aggregate function must appear in index selection. */ _ => unreachable!(), // Window function and subquery will not appear in index. 
} diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/filter_join/derive_filter.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/filter_join/derive_filter.rs index 1cc581958e366..33cbe3d93dec9 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/filter_join/derive_filter.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/filter_join/derive_filter.rs @@ -175,5 +175,10 @@ fn replace_column(scalar: &mut ScalarExpr, col_to_scalar: &HashMap<&IndexType, & replace_column(&mut expr.argument, col_to_scalar); } ScalarExpr::ConstantExpr(_) | ScalarExpr::SubqueryExpr(_) => {} + ScalarExpr::UDFServerCall(expr) => { + for arg in expr.arguments.iter_mut() { + replace_column(arg, col_to_scalar) + } + } } } diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs index bad89e0860d6b..928b4285a820a 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_eval_scalar.rs @@ -35,6 +35,7 @@ use crate::plans::PatternPlan; use crate::plans::RelOp; use crate::plans::ScalarExpr; use crate::plans::ScalarItem; +use crate::plans::UDFServerCall; use crate::plans::WindowFunc; use crate::plans::WindowFuncType; use crate::plans::WindowOrderBy; @@ -225,6 +226,22 @@ impl RulePushDownFilterEvalScalar { target_type: cast.target_type.clone(), })) } + ScalarExpr::UDFServerCall(udf) => { + let arguments = udf + .arguments + .iter() + .map(|arg| Self::replace_predicate(arg, items)) + .collect::>>()?; + + Ok(ScalarExpr::UDFServerCall(UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments, + })) + } _ => Ok(predicate.clone()), } } diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs index c19fcc1b8fc94..5812589c11d89 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_scan.rs @@ -32,6 +32,7 @@ use crate::plans::NthValueFunction; use crate::plans::PatternPlan; use crate::plans::RelOp; use crate::plans::Scan; +use crate::plans::UDFServerCall; use crate::plans::WindowFunc; use crate::plans::WindowFuncType; use crate::plans::WindowOrderBy; @@ -251,6 +252,22 @@ impl RulePushDownFilterScan { target_type: cast.target_type.clone(), })) } + ScalarExpr::UDFServerCall(udf) => { + let arguments = udf + .arguments + .iter() + .map(|arg| Self::replace_view_column(arg, table_entries, column_entries)) + .collect::>>()?; + + Ok(ScalarExpr::UDFServerCall(UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments, + })) + } _ => Ok(predicate.clone()), } } diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_union.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_union.rs index 94a61fe77a431..f9f4806b1e0fa 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_union.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_filter_union.rs @@ -34,6 +34,7 @@ use 
crate::plans::NthValueFunction; use crate::plans::PatternPlan; use crate::plans::RelOp; use crate::plans::ScalarExpr; +use crate::plans::UDFServerCall; use crate::plans::UnionAll; use crate::plans::WindowFunc; use crate::plans::WindowFuncType; @@ -271,5 +272,21 @@ fn replace_column_binding( ScalarExpr::SubqueryExpr(_) => Err(ErrorCode::Unimplemented( "replace_column_binding: don't support subquery", )), + ScalarExpr::UDFServerCall(udf) => { + let arguments = udf + .arguments + .into_iter() + .map(|arg| replace_column_binding(index_pairs, arg)) + .collect::>>()?; + + Ok(ScalarExpr::UDFServerCall(UDFServerCall { + span: udf.span, + func_name: udf.func_name, + server_addr: udf.server_addr, + arg_types: udf.arg_types, + return_type: udf.return_type, + arguments, + })) + } } } diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_prewhere.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_prewhere.rs index fa01f9f2293a4..accbed029a802 100644 --- a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_prewhere.rs +++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_prewhere.rs @@ -89,6 +89,11 @@ impl RulePushDownPrewhere { Self::collect_columns_impl(table_index, schema, cast.argument.as_ref(), columns)?; } ScalarExpr::ConstantExpr(_) => {} + ScalarExpr::UDFServerCall(udf) => { + for arg in udf.arguments.iter() { + Self::collect_columns_impl(table_index, schema, arg, columns)?; + } + } _ => { // SubqueryExpr and AggregateFunction will not appear in Filter-LogicalGet return Err(ErrorCode::Unimplemented(format!( diff --git a/src/query/sql/src/planner/optimizer/s_expr.rs b/src/query/sql/src/planner/optimizer/s_expr.rs index 9b55f416a09f2..ca726cea325bf 100644 --- a/src/query/sql/src/planner/optimizer/s_expr.rs +++ b/src/query/sql/src/planner/optimizer/s_expr.rs @@ -314,5 +314,6 @@ fn find_subquery_in_expr(expr: &ScalarExpr) -> bool { ScalarExpr::FunctionCall(expr) => expr.arguments.iter().any(find_subquery_in_expr), ScalarExpr::CastExpr(expr) => find_subquery_in_expr(&expr.argument), ScalarExpr::SubqueryExpr(_) => true, + ScalarExpr::UDFServerCall(expr) => expr.arguments.iter().any(find_subquery_in_expr), } } diff --git a/src/query/sql/src/planner/plans/scalar_expr.rs b/src/query/sql/src/planner/plans/scalar_expr.rs index 3dbfb7c7bddee..1a7d450c22f26 100644 --- a/src/query/sql/src/planner/plans/scalar_expr.rs +++ b/src/query/sql/src/planner/plans/scalar_expr.rs @@ -43,6 +43,7 @@ pub enum ScalarExpr { FunctionCall(FunctionCall), CastExpr(CastExpr), SubqueryExpr(SubqueryExpr), + UDFServerCall(UDFServerCall), } impl ScalarExpr { @@ -87,6 +88,13 @@ impl ScalarExpr { } ScalarExpr::CastExpr(scalar) => scalar.argument.used_columns(), ScalarExpr::SubqueryExpr(scalar) => scalar.outer_columns.clone(), + ScalarExpr::UDFServerCall(scalar) => { + let mut result = ColumnSet::new(); + for scalar in &scalar.arguments { + result = result.union(&scalar.used_columns()).cloned().collect(); + } + result + } } } @@ -128,6 +136,13 @@ impl ScalarExpr { "SubqueryExpr/WindowFunction doesn't support used_tables method".to_string(), )) } + ScalarExpr::UDFServerCall(scalar) => { + let mut result = vec![]; + for scalar in &scalar.arguments { + result.append(&mut scalar.used_tables(metadata.clone())?); + } + Ok(result) + } } } @@ -147,6 +162,7 @@ impl ScalarExpr { }), ScalarExpr::CastExpr(expr) => expr.span.or(expr.argument.span()), ScalarExpr::SubqueryExpr(expr) => expr.span, + ScalarExpr::UDFServerCall(expr) => expr.span, _ => None, } } @@ -156,7 +172,8 @@ impl 
ScalarExpr { ScalarExpr::BoundColumnRef(_) | ScalarExpr::ConstantExpr(_) => true, ScalarExpr::WindowFunction(_) | ScalarExpr::AggregateFunction(_) - | ScalarExpr::SubqueryExpr(_) => false, + | ScalarExpr::SubqueryExpr(_) + | ScalarExpr::UDFServerCall(_) => false, ScalarExpr::FunctionCall(func) => { func.arguments.iter().all(|arg| arg.valid_for_clustering()) } @@ -316,6 +333,25 @@ impl TryFrom for SubqueryExpr { } } +impl From for ScalarExpr { + fn from(v: UDFServerCall) -> Self { + Self::UDFServerCall(v) + } +} + +impl TryFrom for UDFServerCall { + type Error = ErrorCode; + fn try_from(value: ScalarExpr) -> Result { + if let ScalarExpr::UDFServerCall(value) = value { + Ok(value) + } else { + Err(ErrorCode::Internal( + "Cannot downcast Scalar to UDFServerCall", + )) + } + } +} + #[derive(Clone, Debug, Educe)] #[educe(PartialEq, Eq, Hash)] pub struct BoundColumnRef { @@ -536,3 +572,15 @@ impl SubqueryExpr { fn hash_column_set(columns: &ColumnSet, state: &mut H) { columns.iter().for_each(|c| c.hash(state)); } + +#[derive(Clone, Debug, Educe)] +#[educe(PartialEq, Eq, Hash)] +pub struct UDFServerCall { + #[educe(Hash(ignore), PartialEq(ignore), Eq(ignore))] + pub span: Span, + pub func_name: String, + pub server_addr: String, + pub arg_types: Vec, + pub return_type: Box, + pub arguments: Vec, +} diff --git a/src/query/sql/src/planner/semantic/grouping_check.rs b/src/query/sql/src/planner/semantic/grouping_check.rs index f633955fa42c1..d457b24468faf 100644 --- a/src/query/sql/src/planner/semantic/grouping_check.rs +++ b/src/query/sql/src/planner/semantic/grouping_check.rs @@ -23,6 +23,7 @@ use crate::plans::CastExpr; use crate::plans::FunctionCall; use crate::plans::LambdaFunc; use crate::plans::ScalarExpr; +use crate::plans::UDFServerCall; use crate::BindContext; /// Check validity of scalar expression in a grouping context. @@ -192,6 +193,22 @@ impl<'a> GroupingChecker<'a> { } Err(ErrorCode::Internal("Invalid aggregate function")) } + ScalarExpr::UDFServerCall(udf) => { + let args = udf + .arguments + .iter() + .map(|arg| self.resolve(arg, span)) + .collect::>>()?; + Ok(UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments: args, + } + .into()) + } } } } diff --git a/src/query/sql/src/planner/semantic/lowering.rs b/src/query/sql/src/planner/semantic/lowering.rs index 92a482bebd421..58b257a3bc74b 100644 --- a/src/query/sql/src/planner/semantic/lowering.rs +++ b/src/query/sql/src/planner/semantic/lowering.rs @@ -124,6 +124,27 @@ fn resolve_column_type( }) } RawExpr::Constant { .. 
} => Ok(raw_expr.clone()), + RawExpr::UDFServerCall { + span, + func_name, + server_addr, + arg_types, + return_type, + args, + } => { + let args = args + .iter() + .map(|arg| resolve_column_type(arg, context)) + .collect::>>()?; + Ok(RawExpr::UDFServerCall { + span: *span, + func_name: func_name.clone(), + server_addr: server_addr.clone(), + arg_types: arg_types.clone(), + return_type: return_type.clone(), + args, + }) + } } } @@ -245,6 +266,14 @@ impl ScalarExpr { data_type: subquery.data_type(), display_name: "DUMMY".to_string(), }, + ScalarExpr::UDFServerCall(udf) => RawExpr::UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: (*udf.return_type).clone(), + args: udf.arguments.iter().map(ScalarExpr::as_raw_expr).collect(), + }, } } diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index 00ac7050aeb52..0e443846d3220 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -38,6 +38,7 @@ use common_ast::parser::tokenize_sql; use common_ast::Dialect; use common_catalog::catalog::CatalogManager; use common_catalog::table_context::TableContext; +use common_config::GlobalConfig; use common_exception::ErrorCode; use common_exception::Result; use common_exception::Span; @@ -66,8 +67,12 @@ use common_functions::GENERAL_LAMBDA_FUNCTIONS; use common_functions::GENERAL_WINDOW_FUNCTIONS; use common_license::license::Feature::VirtualColumn; use common_license::license_manager::get_license_manager; +use common_meta_app::principal::LambdaUDF; +use common_meta_app::principal::UDFDefinition; +use common_meta_app::principal::UDFServer; use common_users::UserApiProvider; use indexmap::IndexMap; +use itertools::Itertools; use simsearch::SimSearch; use super::name_resolution::NameResolutionContext; @@ -94,6 +99,7 @@ use crate::plans::NtileFunction; use crate::plans::ScalarExpr; use crate::plans::SubqueryExpr; use crate::plans::SubqueryType; +use crate::plans::UDFServerCall; use crate::plans::WindowFunc; use crate::plans::WindowFuncFrame; use crate::plans::WindowFuncFrameBound; @@ -2563,7 +2569,7 @@ impl<'a> TypeChecker<'a> { async fn resolve_udf( &mut self, span: Span, - func_name: &str, + udf_name: &str, arguments: &[Expr], ) -> Result>> { if self.forbid_udf { @@ -2571,7 +2577,7 @@ impl<'a> TypeChecker<'a> { } let udf = UserApiProvider::instance() - .get_udf(self.ctx.get_tenant().as_str(), func_name) + .get_udf(self.ctx.get_tenant().as_str(), udf_name) .await; let udf = if let Ok(udf) = udf { @@ -2580,7 +2586,83 @@ impl<'a> TypeChecker<'a> { return Ok(None); }; - let parameters = udf.parameters; + match udf.definition { + UDFDefinition::LambdaUDF(udf_def) => Ok(Some( + self.resolve_lambda_udf(span, arguments, udf_def).await?, + )), + UDFDefinition::UDFServer(udf_def) => Ok(Some( + self.resolve_udf_server(span, arguments, udf_def).await?, + )), + } + } + + #[async_recursion::async_recursion] + #[async_backtrace::framed] + async fn resolve_udf_server( + &mut self, + span: Span, + arguments: &[Expr], + udf_definition: UDFServer, + ) -> Result> { + if !GlobalConfig::instance().query.enable_udf_server { + return Err(ErrorCode::Unimplemented( + "UDF server is not allowed, you can enable it by setting 'enable_udf_server = true' in query node config", + )); + } + + let udf_server_allow_list = &GlobalConfig::instance().query.udf_server_allow_list; + let address = 
&udf_definition.address; + if udf_server_allow_list + .iter() + .all(|addr| addr.trim_end_matches('/') != address.trim_end_matches('/')) + { + return Err(ErrorCode::InvalidArgument(format!( + "Unallowed UDF server address, '{address}' is not in udf_server_allow_list" + ))); + } + + let mut args = Vec::with_capacity(arguments.len()); + for argument in arguments { + let box (arg, _) = self.resolve(argument).await?; + args.push(arg); + } + + let raw_expr_args = args.iter().map(|arg| arg.as_raw_expr()).collect_vec(); + let raw_expr = RawExpr::UDFServerCall { + span, + func_name: udf_definition.handler.clone(), + server_addr: udf_definition.address.clone(), + arg_types: udf_definition.arg_types.clone(), + return_type: udf_definition.return_type.clone(), + args: raw_expr_args, + }; + + type_check::check(&raw_expr, &BUILTIN_FUNCTIONS)?; + + self.ctx.set_cacheable(false); + Ok(Box::new(( + UDFServerCall { + span, + func_name: udf_definition.handler, + server_addr: udf_definition.address, + arg_types: udf_definition.arg_types, + return_type: Box::new(udf_definition.return_type.clone()), + arguments: args, + } + .into(), + udf_definition.return_type.clone(), + ))) + } + + #[async_recursion::async_recursion] + #[async_backtrace::framed] + async fn resolve_lambda_udf( + &mut self, + span: Span, + arguments: &[Expr], + udf_definition: LambdaUDF, + ) -> Result> { + let parameters = udf_definition.parameters; if parameters.len() != arguments.len() { return Err(ErrorCode::SyntaxException(format!( "Require {} parameters, but got: {}", @@ -2591,7 +2673,7 @@ impl<'a> TypeChecker<'a> { } let settings = self.ctx.get_settings(); let sql_dialect = settings.get_sql_dialect()?; - let sql_tokens = tokenize_sql(udf.definition.as_str())?; + let sql_tokens = tokenize_sql(udf_definition.definition.as_str())?; let expr = parse_expr(&sql_tokens, sql_dialect)?; let mut args_map = HashMap::new(); arguments.iter().enumerate().for_each(|(idx, argument)| { @@ -2610,7 +2692,7 @@ impl<'a> TypeChecker<'a> { }) .map_err(|e| e.set_span(span))?; - Ok(Some(self.resolve(&udf_expr).await?)) + self.resolve(&udf_expr).await } #[async_recursion::async_recursion] diff --git a/src/query/sql/src/planner/semantic/window_check.rs b/src/query/sql/src/planner/semantic/window_check.rs index 8379fbf474d91..42018055c2a78 100644 --- a/src/query/sql/src/planner/semantic/window_check.rs +++ b/src/query/sql/src/planner/semantic/window_check.rs @@ -19,6 +19,7 @@ use crate::binder::ColumnBindingBuilder; use crate::plans::BoundColumnRef; use crate::plans::CastExpr; use crate::plans::FunctionCall; +use crate::plans::UDFServerCall; use crate::BindContext; use crate::ScalarExpr; use crate::Visibility; @@ -97,6 +98,22 @@ impl<'a> WindowChecker<'a> { } ScalarExpr::AggregateFunction(_) => unreachable!(), + ScalarExpr::UDFServerCall(udf) => { + let new_args = udf + .arguments + .iter() + .map(|arg| self.resolve(arg)) + .collect::>>()?; + Ok(UDFServerCall { + span: udf.span, + func_name: udf.func_name.clone(), + server_addr: udf.server_addr.clone(), + arg_types: udf.arg_types.clone(), + return_type: udf.return_type.clone(), + arguments: new_args, + } + .into()) + } } } } diff --git a/src/query/storages/system/src/functions_table.rs b/src/query/storages/system/src/functions_table.rs index c0ded48ecc78d..d69b531cfa6e3 100644 --- a/src/query/storages/system/src/functions_table.rs +++ b/src/query/storages/system/src/functions_table.rs @@ -80,13 +80,13 @@ impl AsyncSystemTable for FunctionsTable { let definitions = (0..names.len()) .map(|i| { if i < builtin_func_len 
{ - "" + "".to_string() } else { udfs.get(i - builtin_func_len) - .map_or("", |udf| udf.definition.as_str()) + .map_or("".to_string(), |udf| udf.definition.to_string()) } }) - .collect::>(); + .collect::>(); let categories = (0..names.len()) .map(|i| if i < builtin_func_len { "" } else { "UDF" }) @@ -106,24 +106,28 @@ impl AsyncSystemTable for FunctionsTable { let syntaxes = (0..names.len()) .map(|i| { if i < builtin_func_len { - "" + "".to_string() } else { udfs.get(i - builtin_func_len) - .map_or("", |udf| udf.definition.as_str()) + .map_or("".to_string(), |udf| udf.definition.to_string()) } }) - .collect::>(); + .collect::>(); let examples = (0..names.len()).map(|_| "").collect::>(); - Ok(DataBlock::new_from_columns(vec![ StringType::from_data(names), BooleanType::from_data(is_builtin), BooleanType::from_data(is_aggregate), - StringType::from_data(definitions), + StringType::from_data( + definitions + .iter() + .map(String::as_str) + .collect::>(), + ), StringType::from_data(categories), StringType::from_data(descriptions), - StringType::from_data(syntaxes), + StringType::from_data(syntaxes.iter().map(String::as_str).collect::>()), StringType::from_data(examples), ])) } diff --git a/src/query/storages/system/src/util.rs b/src/query/storages/system/src/util.rs index 2a4bba36e01d7..203aef5a05cb4 100644 --- a/src/query/storages/system/src/util.rs +++ b/src/query/storages/system/src/util.rs @@ -17,7 +17,7 @@ use common_expression::Scalar; pub fn find_eq_filter(expr: &Expr, visitor: &mut impl FnMut(&str, &Scalar)) { match expr { - Expr::Constant { .. } | Expr::ColumnRef { .. } => {} + Expr::Constant { .. } | Expr::ColumnRef { .. } | Expr::UDFServerCall { .. } => {} Expr::Cast { expr, .. } => find_eq_filter(expr, visitor), Expr::FunctionCall { function, args, .. } => { if function.signature.name == "eq" { diff --git a/src/query/users/Cargo.toml b/src/query/users/Cargo.toml index 1792ce1b24063..93654756ed5b9 100644 --- a/src/query/users/Cargo.toml +++ b/src/query/users/Cargo.toml @@ -41,5 +41,6 @@ serde = { workspace = true } serde_json = "1" [dev-dependencies] +common-expression = { path = "../expression" } pretty_assertions = "1.3.0" wiremock = "0.5.14" diff --git a/src/query/users/tests/it/user_udf.rs b/src/query/users/tests/it/user_udf.rs index e21b5d487d2b2..70fc93670d1eb 100644 --- a/src/query/users/tests/it/user_udf.rs +++ b/src/query/users/tests/it/user_udf.rs @@ -14,13 +14,14 @@ use common_base::base::tokio; use common_exception::Result; +use common_expression::types::DataType; use common_grpc::RpcClientConf; use common_meta_app::principal::UserDefinedFunction; use common_users::UserApiProvider; use pretty_assertions::assert_eq; #[tokio::test(flavor = "multi_thread", worker_threads = 1)] -async fn test_user_udf() -> Result<()> { +async fn test_user_lambda_udf() -> Result<()> { let conf = RpcClientConf::default(); let user_mgr = UserApiProvider::try_create_simple(conf).await?; @@ -31,43 +32,120 @@ async fn test_user_udf() -> Result<()> { let if_not_exists = false; // add isempty. + let isempty_udf = UserDefinedFunction::create_lambda_udf( + isempty, + vec!["p".to_string()], + "is_null(p)", + description, + ); + user_mgr + .add_udf(tenant, isempty_udf.clone(), if_not_exists) + .await?; + + // add isnotempty. + let isnotempty_udf = UserDefinedFunction::create_lambda_udf( + isnotempty, + vec!["p".to_string()], + "not(isempty(p))", + description, + ); + user_mgr + .add_udf(tenant, isnotempty_udf.clone(), if_not_exists) + .await?; + + // get all. 
{ - let udf = - UserDefinedFunction::new(isempty, vec!["p".to_string()], "is_null(p)", description); - user_mgr.add_udf(tenant, udf, if_not_exists).await?; + let udfs = user_mgr.get_udfs(tenant).await?; + assert_eq!(udfs, vec![isempty_udf.clone(), isnotempty_udf.clone()]); } - // add isnotempty. + // get. + { + let udf = user_mgr.get_udf(tenant, isempty).await?; + assert_eq!(udf, isempty_udf.clone()); + } + + // drop. { - let udf = UserDefinedFunction::new( - isnotempty, - vec!["p".to_string()], - "not(isempty(p))", - description, - ); - user_mgr.add_udf(tenant, udf, if_not_exists).await?; + user_mgr.drop_udf(tenant, isnotempty, false).await?; + let udfs = user_mgr.get_udfs(tenant).await?; + assert_eq!(udfs, vec![isempty_udf]); } + // repeat drop same one not with if exist. + { + let res = user_mgr.drop_udf(tenant, isnotempty, false).await; + assert!(res.is_err()); + } + + // repeat drop same one with if exist. + { + let res = user_mgr.drop_udf(tenant, isnotempty, true).await; + assert!(res.is_ok()); + } + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn test_user_udf_server() -> Result<()> { + let conf = RpcClientConf::default(); + let user_mgr = UserApiProvider::try_create_simple(conf).await?; + + let tenant = "test"; + let address = "http://127.0.0.1:8888"; + let arg_types = vec![DataType::String]; + let return_type = DataType::Boolean; + let description = "this is a description"; + let isempty = "isempty"; + let isnotempty = "isnotempty"; + let if_not_exists = false; + + // add isempty. + let isempty_udf = UserDefinedFunction::create_udf_server( + isempty, + address, + isempty, + "python", + arg_types.clone(), + return_type.clone(), + description, + ); + user_mgr + .add_udf(tenant, isempty_udf.clone(), if_not_exists) + .await?; + + // add isnotempty. + let isnotempty_udf = UserDefinedFunction::create_udf_server( + isnotempty, + address, + isnotempty, + "python", + arg_types.clone(), + return_type.clone(), + description, + ); + user_mgr + .add_udf(tenant, isnotempty_udf.clone(), if_not_exists) + .await?; + // get all. { let udfs = user_mgr.get_udfs(tenant).await?; - assert_eq!(2, udfs.len()); - assert_eq!(isempty, udfs[0].name); - assert_eq!(isnotempty, udfs[1].name); + assert_eq!(udfs, vec![isempty_udf.clone(), isnotempty_udf.clone()]); } // get. { let udf = user_mgr.get_udf(tenant, isempty).await?; - assert_eq!(isempty, udf.name); + assert_eq!(udf, isempty_udf.clone()); } // drop. { user_mgr.drop_udf(tenant, isnotempty, false).await?; let udfs = user_mgr.get_udfs(tenant).await?; - assert_eq!(1, udfs.len()); - assert_eq!(isempty, udfs[0].name); + assert_eq!(udfs, vec![isempty_udf]); } // repeat drop same one not with if exist. 
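For context, `create_udf_server` only records metadata (address, handler, argument and return types); the handler itself lives in an external Python process, and the test above never contacts it. A minimal sketch of what a matching server for the `isempty` UDF could look like, built on the `udf` helper module added later in this PR (the handler body and bind address are hypothetical):

```python
from udf import udf, UDFServer

# Hypothetical body for the `isempty` handler referenced by the test:
# one VARCHAR argument in, BOOLEAN out, matching arg_types/return_type.
@udf(input_types=["VARCHAR"], result_type="BOOLEAN")
def isempty(s):
    # Treat NULL and the empty string as empty.
    return s is None or len(s) == 0

if __name__ == "__main__":
    # The test registers the UDF at http://127.0.0.1:8888,
    # so a live server would need to listen on port 8888.
    server = UDFServer("0.0.0.0:8888")
    server.add_function(isempty)
    server.serve()
```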
diff --git a/tests/fuse-compat/test-fuse-compat.sh b/tests/fuse-compat/test-fuse-compat.sh index be87be566c812..a7055c302b12f 100755 --- a/tests/fuse-compat/test-fuse-compat.sh +++ b/tests/fuse-compat/test-fuse-compat.sh @@ -17,7 +17,7 @@ query_config_path="scripts/ci/deploy/config/databend-query-node-1.toml" usage() { echo " === Assert that latest query being compatible with an old version query on fuse-table format" echo " === Expect ./bins/current contains current version binaries" - echo " === Usage: $0 " + echo " === Usage: $0 " } source "${SCRIPT_PATH}/util.sh" @@ -46,10 +46,13 @@ echo " === current metasrv ver: $(./bins/current/databend-meta --single --cmd ve echo " === current query ver: $(./bins/current/databend-query --cmd ver | tr '\n' ' ')" echo " === old query ver: $old_query_ver" -download_binary "$old_query_ver" mkdir -p ./target/${BUILD_PROFILE}/ +download_query_config "$old_query_ver" old_config +download_binary "$old_query_ver" + +old_config_path="old_config/$query_config_path" run_test $old_query_ver $old_config_path $logictest_path if [ -n "$stateless_test_path" ]; diff --git a/tests/fuse-compat/util.sh b/tests/fuse-compat/util.sh index 229f8000a43f2..96e39f0d02e61 100755 --- a/tests/fuse-compat/util.sh +++ b/tests/fuse-compat/util.sh @@ -1,10 +1,59 @@ #!/bin/bash +query_config_path="scripts/ci/deploy/config/databend-query-node-1.toml" +query_test_path="tests/sqllogictests" +bend_repo_url="https://github.com/datafuselabs/databend" + + binary_url() { local ver="$1" echo "https://github.com/datafuselabs/databend/releases/download/v${ver}-nightly/databend-v${ver}-nightly-x86_64-unknown-linux-gnu.tar.gz" } +# Clone only specified dir or file in the specified commit +git_partial_clone() { + local repo_url="$1" + local branch="$2" + local worktree_path="$3" + local local_path="$4" + + echo " === Clone $repo_url@$branch:$worktree_path" + echo " === To $local_path/$worktree_path" + + rm -rf "$local_path" || echo "no $local_path" + + git clone \ + -b "$branch" \ + --depth 1 \ + --quiet \ + --filter=blob:none \ + --sparse \ + "$repo_url" \ + "$local_path" + + cd "$local_path" + git sparse-checkout set "$worktree_path" + + echo " === Done clone from $repo_url@$branch:$worktree_path" + + ls "$worktree_path" + + cd - +} + +# Download config.toml for a specific version of query. +download_query_config() { + local ver="$1" + local local_dir="$2" + + config_dir="$(dirname $query_config_path)" + echo " === Download query config.toml from $ver:$config_dir" + + git_partial_clone "$bend_repo_url" "v$ver-nightly" "$config_dir" "$local_dir" +} + + + # download a specific version of databend, untar it to folder `./bins/$ver` # `ver` is semver without prefix `v` or `-nightly` download_binary() { @@ -66,7 +115,8 @@ run_test() { python3 -m pip list local query_old_ver="$1" - local logictest_path="tests/fuse-compat/compat-logictest/$2" + local old_config_path="$2" + local logictest_path="tests/fuse-compat/compat-logictest/$3" echo " === Test with query-$query_old_ver and current query" @@ -76,6 +126,7 @@ run_test() { local metasrv_new="./bins/current/databend-meta" local sqllogictests="./bins/current/databend-sqllogictests" + echo " === metasrv version:" # TODO remove --single "$metasrv_new" --single --cmd ver || echo " === no version yet" @@ -106,12 +157,10 @@ run_test() { echo ' === Start old databend-query...' - config_path="scripts/ci/deploy/config/databend-query-node-1.toml" - # TODO clean up data? 
echo " === bring up $query_old" - nohup "$query_old" -c "$config_path" --log-level DEBUG --meta-endpoints "0.0.0.0:9191" >query-old.log & + nohup "$query_old" -c "$old_config_path" --log-level DEBUG --meta-endpoints "0.0.0.0:9191" >query-old.log & python3 scripts/ci/wait_tcp.py --timeout 5 --port 3307 echo " === Run test: fuse_compat_write with old query" @@ -130,6 +179,7 @@ run_test() { echo " === Start new databend-query..." + config_path="scripts/ci/deploy/config/databend-query-node-1.toml" echo "new databend config path: $config_path" nohup "$query_new" -c "$config_path" --log-level DEBUG --meta-endpoints "0.0.0.0:9191" >query-current.log & diff --git a/tests/sqllogictests/suites/base/05_ddl/05_0010_ddl_create_udf b/tests/sqllogictests/suites/base/05_ddl/05_0010_ddl_create_udf index 127a07756b8b0..fbde1dbc25b28 100644 --- a/tests/sqllogictests/suites/base/05_ddl/05_0010_ddl_create_udf +++ b/tests/sqllogictests/suites/base/05_ddl/05_0010_ddl_create_udf @@ -27,4 +27,3 @@ DROP FUNCTION isnotempty statement ok DROP FUNCTION isnotempty_with_desc - diff --git a/tests/sqllogictests/suites/base/05_ddl/05_0011_ddl_drop_udf b/tests/sqllogictests/suites/base/05_ddl/05_0011_ddl_drop_udf index 669301952717c..270f4ea95541c 100644 --- a/tests/sqllogictests/suites/base/05_ddl/05_0011_ddl_drop_udf +++ b/tests/sqllogictests/suites/base/05_ddl/05_0011_ddl_drop_udf @@ -12,4 +12,3 @@ DROP FUNCTION IF EXISTS isnotempty statement error 2602 DROP FUNCTION isnotempty - diff --git a/tests/sqllogictests/suites/base/05_ddl/05_0013_ddl_alter_udf b/tests/sqllogictests/suites/base/05_ddl/05_0013_ddl_alter_udf index 014f6a06bdfa7..5aff03a7ee064 100644 --- a/tests/sqllogictests/suites/base/05_ddl/05_0013_ddl_alter_udf +++ b/tests/sqllogictests/suites/base/05_ddl/05_0013_ddl_alter_udf @@ -15,4 +15,3 @@ ALTER FUNCTION is_not_null AS (d) -> not(is_null(d)) statement ok DROP FUNCTION test_alter_udf - diff --git a/tests/sqllogictests/suites/udf_server/udf_server_test b/tests/sqllogictests/suites/udf_server/udf_server_test new file mode 100644 index 0000000000000..afff4a93be547 --- /dev/null +++ b/tests/sqllogictests/suites/udf_server/udf_server_test @@ -0,0 +1,358 @@ +# Please start the UDF Server first before running this test: +# python3 tests/udf-server/udf_test.py +# + + +statement ok +CREATE FUNCTION add_signed (TINYINT, SMALLINT, INT, BIGINT) RETURNS BIGINT LANGUAGE python HANDLER = 'add_signed' ADDRESS = 'http://0.0.0.0:8815' + +statement ok +CREATE FUNCTION add_unsigned (TINYINT UNSIGNED, SMALLINT UNSIGNED, INT UNSIGNED, BIGINT UNSIGNED) RETURNS BIGINT UNSIGNED LANGUAGE python HANDLER = 'add_unsigned' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION add_float (FLOAT, DOUBLE) RETURNS DOUBLE LANGUAGE python HANDLER = 'add_float' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION bool_select (BOOLEAN, BIGINT, BIGINT) RETURNS BIGINT LANGUAGE python HANDLER = 'bool_select' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION gcd (INT, INT) RETURNS INT LANGUAGE python HANDLER = 'gcd' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION split_and_join (VARCHAR, VARCHAR, VARCHAR) RETURNS VARCHAR LANGUAGE python HANDLER = 'split_and_join' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION decimal_div (DECIMAL(36, 18), DECIMAL(36, 18)) RETURNS DECIMAL(72, 28) LANGUAGE python HANDLER = 'decimal_div' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION hex_to_dec (VARCHAR) RETURNS DECIMAL(36, 18) LANGUAGE python HANDLER = 'hex_to_dec' ADDRESS 
= 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION add_days_py (DATE, INT) RETURNS DATE LANGUAGE python HANDLER = 'add_days_py' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION add_hours_py (TIMESTAMP, INT) RETURNS TIMESTAMP LANGUAGE python HANDLER = 'add_hours_py' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION array_access (ARRAY(VARCHAR), INT) RETURNS VARCHAR LANGUAGE python HANDLER = 'array_access' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION array_index_of (ARRAY(BIGINT NULL), BIGINT) RETURNS INT NOT NULL LANGUAGE python HANDLER = 'array_index_of' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION map_access (MAP(VARCHAR, VARCHAR), VARCHAR) RETURNS VARCHAR LANGUAGE python HANDLER = 'map_access' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION json_access (VARIANT, VARCHAR) RETURNS VARIANT LANGUAGE python HANDLER = 'json_access' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION json_concat (ARRAY(VARIANT)) RETURNS VARIANT LANGUAGE python HANDLER = 'json_concat' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION tuple_access (TUPLE(ARRAY(VARIANT NULL), INT, VARCHAR), INT, INT) RETURNS TUPLE(VARIANT NULL, VARIANT NULL) LANGUAGE python HANDLER = 'tuple_access' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION return_all (BOOLEAN, TINYINT, SMALLINT, INT, BIGINT, TINYINT UNSIGNED, SMALLINT UNSIGNED, INT UNSIGNED, BIGINT UNSIGNED, FLOAT, DOUBLE, DATE, TIMESTAMP, VARCHAR, VARIANT) RETURNS TUPLE(BOOLEAN NULL, TINYINT NULL, SMALLINT NULL, INT NULL, BIGINT NULL, TINYINT UNSIGNED NULL, SMALLINT UNSIGNED NULL, INT UNSIGNED NULL, BIGINT UNSIGNED NULL, FLOAT NULL, DOUBLE NULL, DATE NULL, TIMESTAMP NULL, VARCHAR NULL, VARIANT NULL) LANGUAGE python HANDLER = 'return_all' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION return_all_arrays (ARRAY(BOOLEAN), ARRAY(TINYINT), ARRAY(SMALLINT), ARRAY(INT), ARRAY(BIGINT), ARRAY(TINYINT UNSIGNED), ARRAY(SMALLINT UNSIGNED), ARRAY(INT UNSIGNED), ARRAY(BIGINT UNSIGNED), ARRAY(FLOAT), ARRAY(DOUBLE), ARRAY(DATE), ARRAY(TIMESTAMP), ARRAY(VARCHAR), ARRAY(VARIANT)) RETURNS TUPLE(ARRAY(BOOLEAN), ARRAY(TINYINT), ARRAY(SMALLINT), ARRAY(INT), ARRAY(BIGINT), ARRAY(TINYINT UNSIGNED), ARRAY(SMALLINT UNSIGNED), ARRAY(INT UNSIGNED), ARRAY(BIGINT UNSIGNED), ARRAY(FLOAT), ARRAY(DOUBLE), ARRAY(DATE), ARRAY(TIMESTAMP), ARRAY(VARCHAR), ARRAY(VARIANT)) LANGUAGE python HANDLER = 'return_all_arrays' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION return_all_non_nullable (BOOLEAN NOT NULL, TINYINT NOT NULL, SMALLINT NOT NULL, INT NOT NULL, BIGINT NOT NULL, TINYINT UNSIGNED NOT NULL, SMALLINT UNSIGNED NOT NULL, INT UNSIGNED NOT NULL, BIGINT UNSIGNED NOT NULL, FLOAT NOT NULL, DOUBLE NOT NULL, DATE NOT NULL, TIMESTAMP NOT NULL, VARCHAR NOT NULL, VARIANT NOT NULL) RETURNS TUPLE(BOOLEAN, TINYINT, SMALLINT, INT, BIGINT, TINYINT UNSIGNED, SMALLINT UNSIGNED, INT UNSIGNED, BIGINT UNSIGNED, FLOAT, DOUBLE, DATE, TIMESTAMP, VARCHAR, VARIANT) LANGUAGE python HANDLER = 'return_all_non_nullable' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION wait (INT) RETURNS INT LANGUAGE python HANDLER = 'wait' ADDRESS = 'http://0.0.0.0:8815'; + +statement ok +CREATE FUNCTION wait_concurrent (INT) RETURNS INT LANGUAGE python HANDLER = 'wait_concurrent' ADDRESS = 'http://0.0.0.0:8815'; + +## scalar values test +query II +select add_signed(-1, 2, -3, 4), add_signed(-1, -3, 4, -5); +---- +2 -5 + +query II +select 
add_unsigned(1, 2, 3, 4), add_unsigned(2, 3, 4, 5); +---- +10 14 + +query F +select add_float(1.5, 2.5); +---- +4.0 + +query II +select bool_select(true, 1, 2), bool_select(false, 1, 2); +---- +1 2 + +query I +select gcd(unnest([4, 5, NULL, 8, 12, NULL]), unnest([12, 2, 3, NULL, 18, NULL])); +---- +4 +1 +NULL +NULL +6 +NULL + +query TT +select split_and_join('1; 3; 5; 7; 9', '; ', ':'); +---- +1:3:5:7:9 + +query I +select hex_to_dec('0000000000da7134f0e'); +---- +58637635342.000000000000000000 + +query F +select decimal_div(1, 7); +---- +0.1428571428571428571428571429 + +query TT +select to_date(18875), add_days_py(to_date(18875), 2); +---- +2021-09-05 2021-09-07 + +query TT +select to_datetime(1630833797), add_hours_py(to_datetime(1630833797), 2); +---- +2021-09-05 09:23:17.000000 2021-09-05 11:23:17.000000 + +query TTT +select array_access(['hello','world','rust'], 0), array_access(['hello','world','rust'], 1), array_access(['hello','world','rust'], 4); +---- +NULL hello NULL + +query IIII +select array_index_of(NULL, 1), array_index_of([3, 5, 7], 5), array_index_of([4, 6], 3), array_index_of([2, 3, NULL], NULL); +---- +0 2 0 3 + +query TT +select map_access({'ip': '192.168.1.1', 'url': 'example.com/home'}, 'ip'), map_access({'ip': '192.168.1.2', 'url': 'example.com/about'}, 'ip'); +---- +192.168.1.1 192.168.1.2 + +query T +select json_access(parse_json('{"customer_id": 123, "order_id": 1001, "items": [{"name": "Shoes", "price": 59.99}, {"name": "T-shirt", "price": 19.99}]}'), 'items'); +---- +[{"name":"Shoes","price":59.99},{"name":"T-shirt","price":19.99}] + +query T +select json_concat([parse_json('{"age": 30, "isPremium": "false", "lastActive": "2023-03-15"}'), parse_json('{"age": 25, "isPremium": "true", "lastActive": "2023-04-10"}')]); +---- +[{"age":30,"isPremium":"false","lastActive":"2023-03-15"},{"age":25,"isPremium":"true","lastActive":"2023-04-10"}] + +query T +select tuple_access(([NULL, parse_json('{"color":"red", "fontSize":16, "theme":"dark"}')], 2, 'foo'), 0, 1); +---- +(NULL,'[null,{"color":"red","fontSize":16,"theme":"dark"}]') + +query T +select return_all(true, NULL, NULL, 3, 4, NULL, 6, 7, 8, NULL, 10.2, NULL, to_datetime(1630833797), 'foo', NULL); +---- +(1,NULL,NULL,3,4,NULL,6,7,8,NULL,10.2,NULL,'2021-09-05 09:23:17.000000','foo',NULL) + +query T +select return_all_non_nullable(true, -1, 2, 3, 4, 5, 6, 7, 8, 9.1, 10.2, to_date(18866), to_datetime(1630833797), 'foo', parse_json('{"foo": 30, "bar": "false"}')); +---- +(1,-1,2,3,4,5,6,7,8,9.1,10.2,'2021-08-27','2021-09-05 09:23:17.000000','foo','{"bar":"false","foo":30}') + +query T +select return_all_arrays([true], [-1, -2], [2,64,67], [3,1234], [4,2341], [5,10], [6,1231], [7,1234], [8,63435], [9.1,231.123], [10.2,6547.789], [to_date(18866)], [to_datetime(1630833797)], ['foo'], [parse_json('{"foo": 30, "bar": "false"}')]); +---- +([1],[-1,-2],[2,64,67],[3,1234],[4,2341],[5,10],[6,1231],[7,1234],[8,63435],[9.1,231.123],[10.2,6547.789],['2021-08-27'],['2021-09-05 09:23:17.000000'],['foo'],['{"bar":"false","foo":30}']) + +## table test + +statement ok +create table decimal(value decimal(36, 18)); + +statement ok +insert into decimal values(0.152587668674722117), (0.017820781941443176); + +query F +select decimal_div(value, 3.3) from decimal; +---- +0.0462386874771885203030303030 +0.0054002369519524775757575758 + +statement ok +DROP TABLE decimal; + +statement ok +CREATE TABLE test_dt (date DATE, ts TIMESTAMP); + +statement ok +INSERT INTO test_dt VALUES ('2022-04-07', '2022-04-07 01:01:01.123456'), ('2022-04-08', 
'2022-04-08 01:01:01'); + +query TT +select add_days_py(date, 2), add_hours_py(ts, 2) from test_dt; +---- +2022-04-09 2022-04-07 03:01:01.123456 +2022-04-10 2022-04-08 03:01:01.000000 + +statement ok +DROP TABLE test_dt; + +statement ok +CREATE TABLE array_table(col1 ARRAY(VARCHAR), col2 ARRAY(INT64) NULL, col3 INT); + +statement ok +INSERT INTO array_table VALUES (['hello world', 'foo', 'bar'], [1, 2, 3, 4], 1), (['databend', 'sql', 'olap'], [5, 6, 1, 3], 2), (['aaaa', 'bbbb', 'cccc'], NULL, 3); + +query +select array_access(col1, col3), array_access(col1, 2), array_access(col1, 0) from array_table; +---- +hello world foo NULL +sql sql NULL +cccc bbbb NULL + +query +select array_index_of(col2, col3), array_index_of(col2, 2), array_index_of(col2, NULL) from array_table; +---- +1 2 0 +0 0 0 +0 0 0 + +statement ok +DROP TABLE array_table; + +statement ok +CREATE TABLE web_traffic_data(id INT64, traffic_info MAP(STRING, STRING)); + +statement ok +INSERT INTO web_traffic_data VALUES(1, {'ip': '192.168.1.1', 'url': 'example.com/home'}), + (2, {'ip': '192.168.1.2', 'url': 'example.com/about'}), + (3, {'ip': '192.168.1.1', 'url': 'example.com/contact'}); + +query +SELECT map_access(traffic_info, 'ip') as ip_address, COUNT(*) as visits FROM web_traffic_data GROUP BY map_access(traffic_info, 'ip') ORDER BY map_access(traffic_info, 'ip'); +---- +192.168.1.1 2 +192.168.1.2 1 + +statement ok +DROP TABLE web_traffic_data; + +statement ok +CREATE TABLE customer_orders(id INT64, order_data VARIANT); + +statement ok +INSERT INTO customer_orders VALUES(1, parse_json('{"customer_id": 123, "order_id": 1001, "items": [{"name": "Shoes", "price": 59.99}, {"name": "T-shirt", "price": 19.99}]}')), + (2, parse_json('{"customer_id": 456, "order_id": 1002, "items": [{"name": "Backpack", "price": 79.99}, {"name": "Socks", "price": 4.99}]}')), + (3, parse_json('{"customer_id": 123, "order_id": 1003, "items": [{"name": "Shoes", "price": 59.99}, {"name": "Socks", "price": 4.99}]}')); + +query +select json_access(order_data, 'customer_id')::INT64, json_access(order_data, 'order_id'), json_access(order_data, 'items') from customer_orders; +---- +123 1001 [{"name":"Shoes","price":59.99},{"name":"T-shirt","price":19.99}] +456 1002 [{"name":"Backpack","price":79.99},{"name":"Socks","price":4.99}] +123 1003 [{"name":"Shoes","price":59.99},{"name":"Socks","price":4.99}] + +statement ok +DROP TABLE customer_orders; + +statement ok +create table test_wait(col int); + +statement ok +insert into test_wait select * from numbers(10); + +query I +select wait(col) from test_wait; +---- +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + +query I +select wait_concurrent(col) from test_wait; +---- +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + +statement ok +DROP TABLE test_wait; + +statement ok +DROP FUNCTION add_signed; + +statement ok +DROP FUNCTION add_unsigned; + +statement ok +DROP FUNCTION add_float; + +statement ok +DROP FUNCTION bool_select; + +statement ok +DROP FUNCTION gcd; + +statement ok +DROP FUNCTION split_and_join; + +statement ok +DROP FUNCTION decimal_div; + +statement ok +DROP FUNCTION hex_to_dec; + +statement ok +DROP FUNCTION add_days_py; + +statement ok +DROP FUNCTION add_hours_py; + +statement ok +DROP FUNCTION array_access; + +statement ok +DROP FUNCTION array_index_of; + +statement ok +DROP FUNCTION map_access; + +statement ok +DROP FUNCTION json_access; + +statement ok +DROP FUNCTION json_concat; + +statement ok +DROP FUNCTION tuple_access; + +statement ok +DROP FUNCTION return_all; + +statement ok +DROP FUNCTION return_all_arrays; + +statement 
ok +DROP FUNCTION return_all_non_nullable; + +statement ok +DROP FUNCTION wait; + +statement ok +DROP FUNCTION wait_concurrent; diff --git a/tests/udf-server/README.md b/tests/udf-server/README.md new file mode 100644 index 0000000000000..a882444ca84c5 --- /dev/null +++ b/tests/udf-server/README.md @@ -0,0 +1,152 @@ +## Databend UDF Server Tests + +```sh +pip install pyarrow +# start UDF server +python3 udf_test.py +``` + +```sh +./target/debug/databend-sqllogictests --run_dir udf_server +``` + +## Databend Python UDF Server API +This library provides a Python API for creating user-defined function (UDF) servers for Databend. + +### Introduction +Databend supports user-defined functions implemented as external functions. With the Databend Python UDF API, users can define custom UDFs using Python and start a Python process as a UDF server. Users can then call the customized UDFs in Databend. Databend will remotely access the UDF server to execute the defined functions. + +### Usage + +#### 1. Define your functions in a Python file +```python +from udf import * +import time + +# Define a function +@udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR") +def split_and_join(s: str, split_s: str, join_s: str) -> str: + return join_s.join(s.split(split_s)) + +# Define a function that accepts nullable values, and set skip_null to True so that it returns NULL if any argument is NULL. +@udf( + input_types=["INT", "INT"], + result_type="INT", + skip_null=True, +) +def gcd(x: int, y: int) -> int: + while y != 0: + (x, y) = (y, x % y) + return x + +# Define a function that accepts nullable values, and set skip_null to False so that NULL values are handled inside the function. +@udf( + input_types=["ARRAY(INT64 NULL)", "INT64"], + result_type="INT NOT NULL", + skip_null=False, +) +def array_index_of(array: List[int], item: int): + if array is None: + return 0 + + try: + return array.index(item) + 1 + except ValueError: + return 0 + +# Define an I/O-bound function, and set io_threads so that it can be executed concurrently. +@udf(input_types=["INT"], result_type="INT", io_threads=32) +def wait_concurrent(x): + # assume an IO operation costs 2s + time.sleep(2) + return x + +if __name__ == '__main__': + # create a UDF server listening at '0.0.0.0:8815' + server = UDFServer("0.0.0.0:8815") + # add defined functions + server.add_function(split_and_join) + server.add_function(gcd) + server.add_function(array_index_of) + server.add_function(wait_concurrent) + # start the UDF server + server.serve() +``` + +`@udf` is an annotation for creating a UDF. It supports the following parameters: + +- input_types: A list of strings or Arrow data types that specifies the input data types. +- result_type: A string or an Arrow data type that specifies the return value type. +- name: An optional string specifying the function name. If not provided, the original name will be used. +- io_threads: Number of I/O threads used per data chunk for I/O bound functions. +- skip_null: A boolean value specifying whether to skip NULL values. If it is set to True, NULL values will not be passed to the function, and the corresponding return value is set to NULL. Defaults to False. + +#### 2. Start the UDF Server +Then we can start the UDF server by running: +```sh +python3 udf_server.py +``` + +#### 3. Update Databend query node config +The UDF server feature is disabled by default in Databend. You can enable it by setting `enable_udf_server = true` in the query node config.
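Before wiring the address into the config, it can be useful to confirm the server from step 2 is reachable and has the function registered. A minimal sketch using the `pyarrow.flight` client (assuming the server is listening on `0.0.0.0:8815` and exposes `split_and_join` as above):

```python
import pyarrow.flight as flight

# Fetch the schema of a registered function; this raises if the server is
# unreachable or the function name is unknown.
client = flight.connect("grpc://0.0.0.0:8815")
info = client.get_flight_info(flight.FlightDescriptor.for_path("split_and_join"))
print(info.schema)  # concatenation of the input fields and the output field
```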
+ +In addition, for security reasons, only addresses listed in the config can be accessed by Databend. The list of allowed UDF server addresses is specified through the `udf_server_allow_list` variable in the query node config. + +Here is an example config: +``` +[query] +... +enable_udf_server = true +udf_server_allow_list = [ "http://0.0.0.0:8815", "http://example.com" ] +``` + +#### 4. Add the functions to Databend +We can use the `CREATE FUNCTION` command to add the functions you defined to Databend: +``` +CREATE FUNCTION [IF NOT EXISTS] <udf_name> (<arg_type>, ...) RETURNS <return_type> LANGUAGE <language> HANDLER = '<handler>' ADDRESS = '<udf_server_address>' +``` +The `<udf_name>` is the name of the UDF declared in Databend. The `<handler>` is the function name you defined in the Python UDF server. + +For example: +```sql +CREATE FUNCTION split_and_join (VARCHAR, VARCHAR, VARCHAR) RETURNS VARCHAR LANGUAGE python HANDLER = 'split_and_join' ADDRESS = 'http://0.0.0.0:8815'; +``` + +NOTE: The `<udf_server_address>` you specify must appear in the `udf_server_allow_list` explained in the previous step. + +> In step 2, when you start the UDF server, the corresponding SQL statement for each function is printed out. You can use these statements directly. + +#### 5. Use the functions in Databend +``` +mysql> select split_and_join('3,5,7', ',', ':'); ++-----------------------------------+ +| split_and_join('3,5,7', ',', ':') | ++-----------------------------------+ +| 3:5:7 | ++-----------------------------------+ +``` + +### Data Types +The data types supported by the Python UDF API and their corresponding Python types are as follows: + +| SQL Type | Python Type | +| ------------------- | ----------------- | +| BOOLEAN | bool | +| TINYINT (UNSIGNED) | int | +| SMALLINT (UNSIGNED) | int | +| INT (UNSIGNED) | int | +| BIGINT (UNSIGNED) | int | +| FLOAT | float | +| DOUBLE | float | +| DECIMAL | decimal.Decimal | +| DATE | datetime.date | +| TIMESTAMP | datetime.datetime | +| VARCHAR | str | +| VARIANT | any | +| MAP(K,V) | dict | +| ARRAY(T) | list[T] | +| TUPLE(T...) | tuple(T...) | + +NULL in SQL is represented by None in Python. + +### Acknowledgement +Databend Python UDF Server API is inspired by [RisingWave Python API](https://pypi.org/project/risingwave/). \ No newline at end of file diff --git a/tests/udf-server/udf.py b/tests/udf-server/udf.py new file mode 100644 index 0000000000000..9729305b6a5ff --- /dev/null +++ b/tests/udf-server/udf.py @@ -0,0 +1,495 @@ +from concurrent.futures import ThreadPoolExecutor +from typing import * +import concurrent +import inspect +import json +import pyarrow as pa +import pyarrow.flight +import traceback + +# comes from Databend +MAX_DECIMAL128_PRECISION = 38 +MAX_DECIMAL256_PRECISION = 76 +EXTENSION_KEY = "Extension" +ARROW_EXT_TYPE_VARIANT = "Variant" + +TIMESTAMP_UINT = "us" + + +class UserDefinedFunction: + """ + Base interface for user-defined function. + """ + + _name: str + _input_schema: pa.Schema + _result_schema: pa.Schema + + def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: + """ + Apply the function on a batch of inputs. + """ + return iter([]) + + +class ScalarFunction(UserDefinedFunction): + """ + Base interface for user-defined scalar function. A user-defined scalar function maps zero, one, + or multiple scalar values to a new scalar value.
+ """ + + _func: Callable + _io_threads: Optional[int] + _executor: Optional[ThreadPoolExecutor] + _skip_null: bool + + def __init__( + self, func, input_types, result_type, name=None, io_threads=None, skip_null=None + ): + self._func = func + self._input_schema = pa.schema( + field.with_name(arg_name) + for arg_name, field in zip( + inspect.getfullargspec(func)[0], + [_to_arrow_field(t) for t in _to_list(input_types)], + ) + ) + self._result_schema = pa.schema( + [_to_arrow_field(result_type).with_name("output")] + ) + self._name = name or ( + func.__name__ if hasattr(func, "__name__") else func.__class__.__name__ + ) + self._io_threads = io_threads + self._executor = ( + ThreadPoolExecutor(max_workers=self._io_threads) + if self._io_threads is not None + else None + ) + + if skip_null and not self._result_schema.field(0).nullable: + raise ValueError( + f"Return type of function {self._name} must be nullable when skip_null is True" + ) + + self._skip_null = skip_null or False + super().__init__() + + def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]: + inputs = [[v.as_py() for v in array] for array in batch] + inputs = [ + _process_func(pa.list_(type), False)(array) + for array, type in zip(inputs, self._input_schema.types) + ] + if self._executor is not None: + # concurrently evaluate the function for each row + if self._skip_null: + null_func = lambda *v: None + tasks = [] + for row in range(batch.num_rows): + args = [col[row] for col in inputs] + func = null_func if None in args else self._func + tasks.append(self._executor.submit(func, *args)) + else: + tasks = [ + self._executor.submit(self._func, *[col[row] for col in inputs]) + for row in range(batch.num_rows) + ] + column = [future.result() for future in tasks] + else: + # evaluate the function for each row + if self._skip_null: + column = [] + for row in range(batch.num_rows): + args = [col[row] for col in inputs] + column.append(None if None in args else self._func(*args)) + else: + column = [ + self._func(*[col[row] for col in inputs]) + for row in range(batch.num_rows) + ] + + column = _process_func(pa.list_(self._result_schema.types[0]), True)(column) + + array = pa.array(column, type=self._result_schema.types[0]) + yield pa.RecordBatch.from_arrays([array], schema=self._result_schema) + + def __call__(self, *args): + return self._func(*args) + + +def udf( + input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType]], + result_type: Union[str, pa.DataType], + name: Optional[str] = None, + io_threads: Optional[int] = None, + skip_null: Optional[bool] = False, +) -> Callable: + """ + Annotation for creating a user-defined scalar function. + + Parameters: + - input_types: A list of strings or Arrow data types that specifies the input data types. + - result_type: A string or an Arrow data type that specifies the return value type. + - name: An optional string specifying the function name. If not provided, the original name will be used. + - io_threads: Number of I/O threads used per data chunk for I/O bound functions. + - skip_null: A boolean value specifying whether to skip NULL value. If it is set to True, NULL values + will not be passed to the function, and the corresponding return value is set to NULL. Default to False. 
+ + Example: + ``` + @udf(input_types=['INT', 'INT'], result_type='INT') + def gcd(x, y): + while y != 0: + (x, y) = (y, x % y) + return x + ``` + + I/O bound Example: + ``` + @udf(input_types=['INT'], result_type='INT', io_threads=64) + def external_api(x): + response = requests.get(my_endpoint + '?param=' + str(x)) + return response.json()["data"] + ``` + """ + + if io_threads is not None and io_threads > 1: + return lambda f: ScalarFunction( + f, + input_types, + result_type, + name, + io_threads=io_threads, + skip_null=skip_null, + ) + else: + return lambda f: ScalarFunction( + f, input_types, result_type, name, skip_null=skip_null + ) + + +class UDFServer(pa.flight.FlightServerBase): + """ + A server that provides user-defined functions to clients. + + Example: + ``` + server = UDFServer(location="0.0.0.0:8815") + server.add_function(my_udf) + server.serve() + ``` + """ + + _location: str + _functions: Dict[str, UserDefinedFunction] + + def __init__(self, location="0.0.0.0:8815", **kwargs): + super(UDFServer, self).__init__("grpc://" + location, **kwargs) + self._location = location + self._functions = {} + + def get_flight_info(self, context, descriptor): + """Return the result schema of a function.""" + func_name = descriptor.path[0].decode("utf-8") + if func_name not in self._functions: + raise ValueError(f"Function {func_name} does not exist") + udf = self._functions[func_name] + # return the concatenation of input and output schema + full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema)) + return pa.flight.FlightInfo( + schema=full_schema, + descriptor=descriptor, + endpoints=[], + total_records=len(full_schema), + total_bytes=0, + ) + + def do_exchange(self, context, descriptor, reader, writer): + """Call a function from the client.""" + func_name = descriptor.path[0].decode("utf-8") + if func_name not in self._functions: + raise ValueError(f"Function {func_name} does not exist") + udf = self._functions[func_name] + writer.begin(udf._result_schema) + try: + for batch in reader: + for output_batch in udf.eval_batch(batch.data): + writer.write_batch(output_batch) + except Exception as e: + traceback.print_exc() + raise e + + def add_function(self, udf: UserDefinedFunction): + """Add a function to the server.""" + name = udf._name + if name in self._functions: + raise ValueError("Function already exists: " + name) + self._functions[name] = udf + input_types = ", ".join( + _arrow_field_to_string(field) for field in udf._input_schema + ) + output_type = _arrow_field_to_string(udf._result_schema[0]) + sql = f"CREATE FUNCTION {name} ({input_types}) RETURNS {output_type} LANGUAGE python HANDLER = '{name}' ADDRESS = 'http://{self._location}';" + print(f"added function: {name}, corresponding SQL:\n{sql}\n") + + def serve(self): + """Start the server.""" + print(f"listening on {self._location}") + super(UDFServer, self).serve() + + +def _process_func(type: pa.DataType, output: bool) -> Callable: + """ + Return a function to process input or output value.
+ + For input type: + - String=pa.string(): bytes -> str + - Tuple=pa.struct(): dict -> tuple + - Json=pa.large_binary(): bytes -> Any + - Map=pa.map_(): list[tuple(k,v)] -> dict + + For output type: + - Json=pa.large_binary(): Any -> str + - Map=pa.map_(): dict -> list[tuple(k,v)] + """ + if pa.types.is_list(type): + func = _process_func(type.value_type, output) + return ( + lambda array: [(func(v) if v is not None else None) for v in array] + if array is not None + else None + ) + if pa.types.is_struct(type): + funcs = [_process_func(field.type, output) for field in type] + if output: + return ( + lambda tup: tuple( + (func(v) if v is not None else None) for v, func in zip(tup, funcs) + ) + if tup is not None + else None + ) + else: + # the input value of struct type is a dict + # we convert it into tuple here + return ( + lambda map: tuple( + (func(v) if v is not None else None) + for v, func in zip(map.values(), funcs) + ) + if map is not None + else None + ) + if pa.types.is_map(type): + funcs = [ + _process_func(type.key_type, output), + _process_func(type.item_type, output), + ] + if output: + # dict -> list[tuple[k,v]] + return ( + lambda map: [ + tuple(func(v) for v, func in zip(item, funcs)) + for item in map.items() + ] + if map is not None + else None + ) + else: + # list[tuple[k,v]] -> dict + return ( + lambda array: dict( + tuple(func(v) for v, func in zip(item, funcs)) for item in array + ) + if array is not None + else None + ) + + if pa.types.is_string(type) and not output: + # string type is converted to LargeBinary in Databend, + # we cast it back to string here + return lambda v: v.decode("utf-8") if v is not None else None + if pa.types.is_large_binary(type): + if output: + return lambda v: json.dumps(v) if v is not None else None + else: + return lambda v: json.loads(v) if v is not None else None + return lambda v: v + + +def _to_list(x): + if isinstance(x, list): + return x + else: + return [x] + + +def _to_arrow_field(t: Union[str, pa.DataType]) -> pa.Field: + """ + Convert a string or pyarrow.DataType to pyarrow.Field. + """ + if isinstance(t, str): + return _type_str_to_arrow_field(t) + else: + return pa.field("", t, False) + + +def _type_str_to_arrow_field(type_str: str) -> pa.Field: + """ + Convert a SQL data type to `pyarrow.Field`. 
+ """ + type_str = type_str.strip().upper() + nullable = True + if type_str.endswith("NULL"): + type_str = type_str[:-4].strip() + if type_str.endswith("NOT"): + type_str = type_str[:-3].strip() + nullable = False + + return _type_str_to_arrow_field_inner(type_str).with_nullable(nullable) + + +def _type_str_to_arrow_field_inner(type_str: str) -> pa.Field: + type_str = type_str.strip().upper() + if type_str in ("BOOLEAN", "BOOL"): + return pa.field("", pa.bool_(), False) + elif type_str in ("TINYINT", "INT8"): + return pa.field("", pa.int8(), False) + elif type_str in ("SMALLINT", "INT16"): + return pa.field("", pa.int16(), False) + elif type_str in ("INT", "INTEGER", "INT32"): + return pa.field("", pa.int32(), False) + elif type_str in ("BIGINT", "INT64"): + return pa.field("", pa.int64(), False) + elif type_str in ("TINYINT UNSIGNED", "UINT8"): + return pa.field("", pa.uint8(), False) + elif type_str in ("SMALLINT UNSIGNED", "UINT16"): + return pa.field("", pa.uint16(), False) + elif type_str in ("INT UNSIGNED", "INTEGER UNSIGNED", "UINT32"): + return pa.field("", pa.uint32(), False) + elif type_str in ("BIGINT UNSIGNED", "UINT64"): + return pa.field("", pa.uint64(), False) + elif type_str in ("FLOAT", "FLOAT32"): + return pa.field("", pa.float32(), False) + elif type_str in ("FLOAT64", "DOUBLE"): + return pa.field("", pa.float64(), False) + elif type_str == "DATE": + return pa.field("", pa.date32(), False) + elif type_str in ("DATETIME", "TIMESTAMP"): + return pa.field("", pa.timestamp(TIMESTAMP_UINT), False) + elif type_str in ("STRING", "VARCHAR", "CHAR", "CHARACTER", "TEXT"): + return pa.field("", pa.string(), False) + elif type_str in ("VARIANT", "JSON"): + # In Databend, JSON type is identified by the "EXTENSION" key in the metadata. + return pa.field( + "", + pa.large_binary(), + nullable=False, + metadata={EXTENSION_KEY: ARROW_EXT_TYPE_VARIANT}, + ) + elif type_str.startswith("NULLABLE"): + type_str = type_str[8:].strip("()").strip() + return _type_str_to_arrow_field_inner(type_str).with_nullable(True) + elif type_str.endswith("NULL"): + type_str = type_str[:-4].strip() + return _type_str_to_arrow_field_inner(type_str).with_nullable(True) + elif type_str.startswith("DECIMAL"): + # DECIMAL(precision, scale) + str_list = type_str[7:].strip("()").split(",") + precision = int(str_list[0].strip()) + scale = int(str_list[1].strip()) + if precision < 1 or precision > MAX_DECIMAL256_PRECISION: + raise ValueError( + f"Decimal precision must be between 1 and {MAX_DECIMAL256_PRECISION}" + ) + elif scale > precision: + raise ValueError( + f"Decimal scale must be between 0 and precision {precision}" + ) + + if precision < MAX_DECIMAL128_PRECISION: + return pa.field("", pa.decimal128(precision, scale), False) + else: + return pa.field("", pa.decimal256(precision, scale), False) + elif type_str.startswith("ARRAY"): + # ARRAY(INT) + type_str = type_str[5:].strip("()").strip() + return pa.field("", pa.list_(_type_str_to_arrow_field_inner(type_str)), False) + elif type_str.startswith("MAP"): + # MAP(STRING, INT) + str_list = type_str[3:].strip("()").split(",") + key_field = _type_str_to_arrow_field_inner(str_list[0].strip()) + val_field = _type_str_to_arrow_field_inner(str_list[1].strip()) + return pa.field("", pa.map_(key_field, val_field), False) + elif type_str.startswith("TUPLE"): + # TUPLE(STRING, INT, INT) + str_list = type_str[5:].strip("()").split(",") + fields = [] + for type_str in str_list: + type_str = type_str.strip() + fields.append(_type_str_to_arrow_field_inner(type_str)) + return 
pa.field("", pa.struct(fields), False) + else: + raise ValueError(f"Unsupported type: {type_str}") + + +def _arrow_field_to_string(field: pa.Field) -> str: + """ + Convert a `pyarrow.Field` to a SQL data type string. + """ + type_str = _data_type_to_string(field.type) + return f"{type_str} NOT NULL" if not field.nullable else type_str + + +def _inner_field_to_string(field: pa.Field) -> str: + ## inner field default is NOT NULL in databend + type_str = _data_type_to_string(field.type) + return f"{type_str} NULL" if field.nullable else type_str + + +def _data_type_to_string(t: pa.DataType) -> str: + """ + Convert a `pyarrow.DataType` to a SQL data type string. + """ + if pa.types.is_boolean(t): + return "BOOLEAN" + elif pa.types.is_int8(t): + return "TINYINT" + elif pa.types.is_int16(t): + return "SMALLINT" + elif pa.types.is_int32(t): + return "INT" + elif pa.types.is_int64(t): + return "BIGINT" + elif pa.types.is_uint8(t): + return "TINYINT UNSIGNED" + elif pa.types.is_uint16(t): + return "SMALLINT UNSIGNED" + elif pa.types.is_uint32(t): + return "INT UNSIGNED" + elif pa.types.is_uint64(t): + return "BIGINT UNSIGNED" + elif pa.types.is_float32(t): + return "FLOAT" + elif pa.types.is_float64(t): + return "DOUBLE" + elif pa.types.is_decimal(t): + return f"DECIMAL({t.precision}, {t.scale})" + elif pa.types.is_date32(t): + return "DATE" + elif pa.types.is_timestamp(t): + return "TIMESTAMP" + elif pa.types.is_string(t): + return "VARCHAR" + elif pa.types.is_large_binary(t): + return "VARIANT" + elif pa.types.is_list(t): + return f"ARRAY({_inner_field_to_string(t.value_field)})" + elif pa.types.is_map(t): + return f"MAP({_inner_field_to_string(t.key_field)}, {_inner_field_to_string(t.item_field)})" + elif pa.types.is_struct(t): + args_str = ", ".join(_inner_field_to_string(field) for field in t) + return f"TUPLE({args_str})" + else: + raise ValueError(f"Unsupported type: {t}") diff --git a/tests/udf-server/udf_test.py b/tests/udf-server/udf_test.py new file mode 100644 index 0000000000000..c31f72471b979 --- /dev/null +++ b/tests/udf-server/udf_test.py @@ -0,0 +1,282 @@ +import datetime +from udf import * +from decimal import Decimal +import time + + +@udf(input_types=["TINYINT", "SMALLINT", "INT", "BIGINT"], result_type="BIGINT") +def add_signed(a, b, c, d): + return a + b + c + d + + +@udf(input_types=["UINT8", "UINT16", "UINT32", "UINT64"], result_type="UINT64") +def add_unsigned(a, b, c, d): + return a + b + c + d + + +@udf(input_types=["FLOAT", "DOUBLE"], result_type="DOUBLE") +def add_float(a, b): + return a + b + + +@udf(input_types=["BOOLEAN", "BIGINT", "BIGINT"], result_type="BIGINT") +def bool_select(condition, a, b): + return a if condition else b + + +@udf( + name="gcd", + input_types=["INT", "INT"], + result_type="INT", + skip_null=True, +) +def gcd(x: int, y: int) -> int: + while y != 0: + (x, y) = (y, x % y) + return x + + +@udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR") +def split_and_join(s: str, split_s: str, join_s: str) -> str: + return join_s.join(s.split(split_s)) + + +@udf(input_types="VARCHAR", result_type="DECIMAL(36, 18)") +def hex_to_dec(hex: str) -> Decimal: + hex = hex.strip() + + dec = Decimal(0) + while hex: + chunk = hex[:16] + chunk_value = int(hex[:16], 16) + dec = dec * (1 << (4 * len(chunk))) + chunk_value + hex = hex[len(chunk) :] + return dec + + +@udf(input_types=["DECIMAL(36, 18)", "DECIMAL(36, 18)"], result_type="DECIMAL(72, 28)") +def decimal_div(v1: Decimal, v2: Decimal) -> Decimal: + result = v1 / v2 + return 
result.quantize(Decimal("0." + "0" * 28)) + + +@udf(input_types=["DATE", "INT"], result_type="DATE") +def add_days_py(dt: datetime.date, days: int): + return dt + datetime.timedelta(days=days) + + +@udf(input_types=["TIMESTAMP", "INT"], result_type="TIMESTAMP") +def add_hours_py(dt: datetime.datetime, hours: int): + return dt + datetime.timedelta(hours=hours) + + +@udf(input_types=["ARRAY(VARCHAR)", "INT"], result_type="VARCHAR") +def array_access(array: List[str], idx: int) -> Optional[str]: + if idx == 0 or idx > len(array): + return None + return array[idx - 1] + + +@udf( + input_types=["ARRAY(INT64 NULL)", "INT64"], + result_type="INT NOT NULL", + skip_null=False, +) +def array_index_of(array: List[int], item: int): + if array is None: + return 0 + + try: + return array.index(item) + 1 + except ValueError: + return 0 + + +@udf(input_types=["MAP(VARCHAR,VARCHAR)", "VARCHAR"], result_type="VARCHAR") +def map_access(map: Dict[str, str], key: str) -> str: + return map[key] if key in map else None + + +@udf(input_types=["VARIANT", "VARCHAR"], result_type="VARIANT") +def json_access(json: Any, key: str) -> Any: + return json[key] + + +@udf(input_types=["ARRAY(VARIANT)"], result_type="VARIANT") +def json_concat(list: List[Any]) -> Any: + return list + + +@udf( + input_types=["TUPLE(ARRAY(VARIANT NULL), INT, VARCHAR)", "INT", "INT"], + result_type="TUPLE(VARIANT NULL, VARIANT NULL)", +) +def tuple_access( + tup: Tuple[List[Any], int, str], idx1: int, idx2: int +) -> Tuple[Any, Any]: + v1 = None if idx1 == 0 or idx1 > len(tup) else tup[idx1 - 1] + v2 = None if idx2 == 0 or idx2 > len(tup) else tup[idx2 - 1] + return v1, v2 + + +ALL_SCALAR_TYPES = "BOOLEAN,TINYINT,SMALLINT,INT,BIGINT,UINT8,UINT16,UINT32,UINT64,FLOAT,DOUBLE,DATE,TIMESTAMP,VARCHAR,VARIANT".split( + "," +) + + +@udf( + input_types=ALL_SCALAR_TYPES, + result_type=f"TUPLE({','.join(f'{t} NULL' for t in ALL_SCALAR_TYPES)})", +) +def return_all( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, +): + return ( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, + ) + + +@udf( + input_types=[f"ARRAY({t})" for t in ALL_SCALAR_TYPES], + result_type=f"TUPLE({','.join(f'ARRAY({t})' for t in ALL_SCALAR_TYPES)})", +) +def return_all_arrays( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, +): + return ( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, + ) + + +@udf( + input_types=[f"{t} NOT NULL" for t in ALL_SCALAR_TYPES], + result_type=f"TUPLE({','.join(f'{t}' for t in ALL_SCALAR_TYPES)})", +) +def return_all_non_nullable( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, +): + return ( + bool, + i8, + i16, + i32, + i64, + u8, + u16, + u32, + u64, + f32, + f64, + date, + timestamp, + varchar, + json, + ) + + +@udf(input_types=["INT"], result_type="INT") +def wait(x): + time.sleep(0.1) + return x + + +@udf(input_types=["INT"], result_type="INT", io_threads=32) +def wait_concurrent(x): + time.sleep(0.1) + return x + + +if __name__ == "__main__": + udf_server = UDFServer("0.0.0.0:8815") + udf_server.add_function(add_signed) + udf_server.add_function(add_unsigned) + udf_server.add_function(add_float) + udf_server.add_function(bool_select) + udf_server.add_function(gcd) + udf_server.add_function(split_and_join) + 
udf_server.add_function(decimal_div) + udf_server.add_function(hex_to_dec) + udf_server.add_function(add_days_py) + udf_server.add_function(add_hours_py) + udf_server.add_function(array_access) + udf_server.add_function(array_index_of) + udf_server.add_function(map_access) + udf_server.add_function(json_access) + udf_server.add_function(json_concat) + udf_server.add_function(tuple_access) + udf_server.add_function(return_all) + udf_server.add_function(return_all_arrays) + udf_server.add_function(return_all_non_nullable) + udf_server.add_function(wait) + udf_server.add_function(wait_concurrent) + udf_server.serve()
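For reference, the round trip Databend performs when evaluating a `UDFServerCall` can also be driven by hand over Arrow Flight `do_exchange`. A minimal sketch against the test server above (assuming it is running on port 8815; the batch layout follows `gcd`'s `INT, INT` signature):

```python
import pyarrow as pa
import pyarrow.flight as flight

client = flight.connect("grpc://0.0.0.0:8815")
descriptor = flight.FlightDescriptor.for_path("gcd")

# Two nullable INT32 columns, matching gcd's declared input types.
schema = pa.schema([pa.field("x", pa.int32()), pa.field("y", pa.int32())])
batch = pa.RecordBatch.from_arrays(
    [pa.array([12, 18], type=pa.int32()), pa.array([8, 24], type=pa.int32())],
    schema=schema,
)

# Send one input batch and read the result batches back.
writer, reader = client.do_exchange(descriptor)
writer.begin(schema)
writer.write_batch(batch)
writer.done_writing()
for chunk in reader:
    print(chunk.data.column(0))  # expected: gcd(12, 8) = 4, gcd(18, 24) = 6
writer.close()
```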