From bac7e7bac964f5e409065d8a31e5ec580e719bd2 Mon Sep 17 00:00:00 2001 From: Zhenchi Date: Mon, 9 Dec 2024 15:19:00 +0800 Subject: [PATCH] refactor: extract implicit conversion helper functions of vector type (#5118) refactor: extract implicit conversion helper functions of vector Signed-off-by: Zhenchi --- src/common/function/src/scalars/vector.rs | 1 + .../function/src/scalars/vector/distance.rs | 132 +-------------- .../function/src/scalars/vector/impl_conv.rs | 156 ++++++++++++++++++ 3 files changed, 165 insertions(+), 124 deletions(-) create mode 100644 src/common/function/src/scalars/vector/impl_conv.rs diff --git a/src/common/function/src/scalars/vector.rs b/src/common/function/src/scalars/vector.rs index 602504ec83ba..7c8cf5550e25 100644 --- a/src/common/function/src/scalars/vector.rs +++ b/src/common/function/src/scalars/vector.rs @@ -14,6 +14,7 @@ mod convert; mod distance; +pub(crate) mod impl_conv; use std::sync::Arc; diff --git a/src/common/function/src/scalars/vector/distance.rs b/src/common/function/src/scalars/vector/distance.rs index 1905a375f3e4..f17eec5b042c 100644 --- a/src/common/function/src/scalars/vector/distance.rs +++ b/src/common/function/src/scalars/vector/distance.rs @@ -18,18 +18,17 @@ mod l2sq; use std::borrow::Cow; use std::fmt::Display; -use std::sync::Arc; use common_query::error::{InvalidFuncArgsSnafu, Result}; use common_query::prelude::Signature; use datatypes::prelude::ConcreteDataType; use datatypes::scalars::ScalarVectorBuilder; -use datatypes::value::ValueRef; -use datatypes::vectors::{Float32VectorBuilder, MutableVector, Vector, VectorRef}; +use datatypes::vectors::{Float32VectorBuilder, MutableVector, VectorRef}; use snafu::ensure; use crate::function::{Function, FunctionContext}; use crate::helper; +use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const}; macro_rules! define_distance_function { ($StructName:ident, $display_name:expr, $similarity_method:path) => { @@ -80,17 +79,17 @@ macro_rules! define_distance_function { return Ok(result.to_vector()); } - let arg0_const = parse_if_constant_string(arg0)?; - let arg1_const = parse_if_constant_string(arg1)?; + let arg0_const = as_veclit_if_const(arg0)?; + let arg1_const = as_veclit_if_const(arg1)?; for i in 0..size { let vec0 = match arg0_const.as_ref() { - Some(a) => Some(Cow::Borrowed(a.as_slice())), - None => as_vector(arg0.get_ref(i))?, + Some(a) => Some(Cow::Borrowed(a.as_ref())), + None => as_veclit(arg0.get_ref(i))?, }; let vec1 = match arg1_const.as_ref() { - Some(b) => Some(Cow::Borrowed(b.as_slice())), - None => as_vector(arg1.get_ref(i))?, + Some(b) => Some(Cow::Borrowed(b.as_ref())), + None => as_veclit(arg1.get_ref(i))?, }; if let (Some(vec0), Some(vec1)) = (vec0, vec1) { @@ -129,98 +128,6 @@ define_distance_function!(CosDistanceFunction, "vec_cos_distance", cos::cos); define_distance_function!(L2SqDistanceFunction, "vec_l2sq_distance", l2sq::l2sq); define_distance_function!(DotProductFunction, "vec_dot_product", dot::dot); -/// Parse a vector value if the value is a constant string. -fn parse_if_constant_string(arg: &Arc) -> Result>> { - if !arg.is_const() { - return Ok(None); - } - if arg.data_type() != ConcreteDataType::string_datatype() { - return Ok(None); - } - arg.get_ref(0) - .as_string() - .unwrap() // Safe: checked if it is a string - .map(parse_f32_vector_from_string) - .transpose() -} - -/// Convert a value to a vector value. -/// Supported data types are binary and string. -fn as_vector(arg: ValueRef<'_>) -> Result>> { - match arg.data_type() { - ConcreteDataType::Binary(_) => arg - .as_binary() - .unwrap() // Safe: checked if it is a binary - .map(binary_as_vector) - .transpose(), - ConcreteDataType::String(_) => arg - .as_string() - .unwrap() // Safe: checked if it is a string - .map(|s| Ok(Cow::Owned(parse_f32_vector_from_string(s)?))) - .transpose(), - ConcreteDataType::Null(_) => Ok(None), - _ => InvalidFuncArgsSnafu { - err_msg: format!("Unsupported data type: {:?}", arg.data_type()), - } - .fail(), - } -} - -/// Convert a u8 slice to a vector value. -fn binary_as_vector(bytes: &[u8]) -> Result> { - if bytes.len() % std::mem::size_of::() != 0 { - return InvalidFuncArgsSnafu { - err_msg: format!("Invalid binary length of vector: {}", bytes.len()), - } - .fail(); - } - - if cfg!(target_endian = "little") { - Ok(unsafe { - let vec = std::slice::from_raw_parts( - bytes.as_ptr() as *const f32, - bytes.len() / std::mem::size_of::(), - ); - Cow::Borrowed(vec) - }) - } else { - let v = bytes - .chunks_exact(std::mem::size_of::()) - .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap())) - .collect::>(); - Ok(Cow::Owned(v)) - } -} - -/// Parse a string to a vector value. -/// Valid inputs are strings like "[1.0, 2.0, 3.0]". -fn parse_f32_vector_from_string(s: &str) -> Result> { - let trimmed = s.trim(); - if !trimmed.starts_with('[') || !trimmed.ends_with(']') { - return InvalidFuncArgsSnafu { - err_msg: format!( - "Failed to parse {s} to Vector value: not properly enclosed in brackets" - ), - } - .fail(); - } - let content = trimmed[1..trimmed.len() - 1].trim(); - if content.is_empty() { - return Ok(Vec::new()); - } - - content - .split(',') - .map(|s| s.trim().parse::()) - .collect::>() - .map_err(|e| { - InvalidFuncArgsSnafu { - err_msg: format!("Failed to parse {s} to Vector value: {e}"), - } - .build() - }) -} - #[cfg(test)] mod tests { use std::sync::Arc; @@ -456,27 +363,4 @@ mod tests { assert!(result.is_err()); } } - - #[test] - fn test_parse_vector_from_string() { - let result = parse_f32_vector_from_string("[1.0, 2.0, 3.0]").unwrap(); - assert_eq!(result, vec![1.0, 2.0, 3.0]); - - let result = parse_f32_vector_from_string("[]").unwrap(); - assert_eq!(result, Vec::::new()); - - let result = parse_f32_vector_from_string("[1.0, a, 3.0]"); - assert!(result.is_err()); - } - - #[test] - fn test_binary_as_vector() { - let bytes = [0, 0, 128, 63]; - let result = binary_as_vector(&bytes).unwrap(); - assert_eq!(result.as_ref(), &[1.0]); - - let invalid_bytes = [0, 0, 128]; - let result = binary_as_vector(&invalid_bytes); - assert!(result.is_err()); - } } diff --git a/src/common/function/src/scalars/vector/impl_conv.rs b/src/common/function/src/scalars/vector/impl_conv.rs new file mode 100644 index 000000000000..903bfb2a0336 --- /dev/null +++ b/src/common/function/src/scalars/vector/impl_conv.rs @@ -0,0 +1,156 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::borrow::Cow; +use std::sync::Arc; + +use common_query::error::{InvalidFuncArgsSnafu, Result}; +use datatypes::prelude::ConcreteDataType; +use datatypes::value::ValueRef; +use datatypes::vectors::Vector; + +/// Convert a constant string or binary literal to a vector literal. +pub fn as_veclit_if_const(arg: &Arc) -> Result>> { + if !arg.is_const() { + return Ok(None); + } + if arg.data_type() != ConcreteDataType::string_datatype() + && arg.data_type() != ConcreteDataType::binary_datatype() + { + return Ok(None); + } + as_veclit(arg.get_ref(0)) +} + +/// Convert a string or binary literal to a vector literal. +pub fn as_veclit(arg: ValueRef<'_>) -> Result>> { + match arg.data_type() { + ConcreteDataType::Binary(_) => arg + .as_binary() + .unwrap() // Safe: checked if it is a binary + .map(binlit_as_veclit) + .transpose(), + ConcreteDataType::String(_) => arg + .as_string() + .unwrap() // Safe: checked if it is a string + .map(|s| Ok(Cow::Owned(parse_veclit_from_strlit(s)?))) + .transpose(), + ConcreteDataType::Null(_) => Ok(None), + _ => InvalidFuncArgsSnafu { + err_msg: format!("Unsupported data type: {:?}", arg.data_type()), + } + .fail(), + } +} + +/// Convert a u8 slice to a vector literal. +pub fn binlit_as_veclit(bytes: &[u8]) -> Result> { + if bytes.len() % std::mem::size_of::() != 0 { + return InvalidFuncArgsSnafu { + err_msg: format!("Invalid binary length of vector: {}", bytes.len()), + } + .fail(); + } + + if cfg!(target_endian = "little") { + Ok(unsafe { + let vec = std::slice::from_raw_parts( + bytes.as_ptr() as *const f32, + bytes.len() / std::mem::size_of::(), + ); + Cow::Borrowed(vec) + }) + } else { + let v = bytes + .chunks_exact(std::mem::size_of::()) + .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap())) + .collect::>(); + Ok(Cow::Owned(v)) + } +} + +/// Parse a string literal to a vector literal. +/// Valid inputs are strings like "[1.0, 2.0, 3.0]". +pub fn parse_veclit_from_strlit(s: &str) -> Result> { + let trimmed = s.trim(); + if !trimmed.starts_with('[') || !trimmed.ends_with(']') { + return InvalidFuncArgsSnafu { + err_msg: format!( + "Failed to parse {s} to Vector value: not properly enclosed in brackets" + ), + } + .fail(); + } + let content = trimmed[1..trimmed.len() - 1].trim(); + if content.is_empty() { + return Ok(Vec::new()); + } + + content + .split(',') + .map(|s| s.trim().parse::()) + .collect::>() + .map_err(|e| { + InvalidFuncArgsSnafu { + err_msg: format!("Failed to parse {s} to Vector value: {e}"), + } + .build() + }) +} + +#[allow(unused)] +/// Convert a vector literal to a binary literal. +pub fn veclit_to_binlit(vec: &[f32]) -> Vec { + if cfg!(target_endian = "little") { + unsafe { + std::slice::from_raw_parts(vec.as_ptr() as *const u8, std::mem::size_of_val(vec)) + .to_vec() + } + } else { + let mut bytes = Vec::with_capacity(std::mem::size_of_val(vec)); + for e in vec { + bytes.extend_from_slice(&e.to_le_bytes()); + } + bytes + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_veclit_from_strlit() { + let result = parse_veclit_from_strlit("[1.0, 2.0, 3.0]").unwrap(); + assert_eq!(result, vec![1.0, 2.0, 3.0]); + + let result = parse_veclit_from_strlit("[]").unwrap(); + assert_eq!(result, Vec::::new()); + + let result = parse_veclit_from_strlit("[1.0, a, 3.0]"); + assert!(result.is_err()); + } + + #[test] + fn test_binlit_as_veclit() { + let vec = &[1.0, 2.0, 3.0]; + let bytes = veclit_to_binlit(vec); + let result = binlit_as_veclit(&bytes).unwrap(); + assert_eq!(result.as_ref(), vec); + + let invalid_bytes = [0, 0, 128]; + let result = binlit_as_veclit(&invalid_bytes); + assert!(result.is_err()); + } +}