From 2089054d8d6c964e286c119a310754111df98fe0 Mon Sep 17 00:00:00 2001 From: cutecutecat Date: Mon, 15 Apr 2024 10:11:01 +0800 Subject: [PATCH] feat: new text embedding for sparse vector Signed-off-by: cutecutecat --- rust-toolchain.toml | 2 +- src/datatype/text_svecf32.rs | 26 +++--- src/utils/parse.rs | 103 +++++++++++++++++++++++ tests/sqllogictest/sparse.slt | 8 +- tests/sqllogictest/svector_subscript.slt | 50 +++++------ 5 files changed, 149 insertions(+), 40 deletions(-) diff --git a/rust-toolchain.toml b/rust-toolchain.toml index a9ee38b8d..1d84399d0 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2024-03-27" +channel = "nightly-2024-04-12" profile = "default" targets = [ "aarch64-apple-darwin", diff --git a/src/datatype/text_svecf32.rs b/src/datatype/text_svecf32.rs index f2a0d6bb4..8a8f11999 100644 --- a/src/datatype/text_svecf32.rs +++ b/src/datatype/text_svecf32.rs @@ -10,13 +10,13 @@ use std::ffi::{CStr, CString}; #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output { - use crate::utils::parse::parse_vector; + use crate::utils::parse::parse_pgvector_svector; let reserve = Typmod::parse_from_i32(typmod) .unwrap() .dims() .map(|x| x.get()) .unwrap_or(0); - let v = parse_vector(input.to_bytes(), reserve as usize, |s| { + let v = parse_pgvector_svector(input.to_bytes(), reserve as usize, |s| { s.parse::().ok() }); match v { @@ -40,16 +40,22 @@ fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output { #[pgrx::pg_extern(immutable, strict, parallel_safe)] fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString { + let dims = vector.for_borrow().dims(); let mut buffer = String::new(); - buffer.push('['); + buffer.push('{'); let vec = vector.for_borrow().to_vec(); - let mut iter = vec.iter(); - if let Some(x) = iter.next() { - buffer.push_str(format!("{}", x).as_str()); - } - for x in iter { - buffer.push_str(format!(", {}", x).as_str()); + let mut need_splitter = true; + for (i, x) in vec.iter().enumerate() { + if *x != F32::zero() { + match need_splitter { + true => { + buffer.push_str(format!("{}:{}", i + 1, x).as_str()); + need_splitter = false; + } + false => buffer.push_str(format!(", {}:{}", i + 1, x).as_str()), + } + } } - buffer.push(']'); + buffer.push_str(format!("}}/{}", dims).as_str()); CString::new(buffer).unwrap() } diff --git a/src/utils/parse.rs b/src/utils/parse.rs index e5ef93316..a3dec2c78 100644 --- a/src/utils/parse.rs +++ b/src/utils/parse.rs @@ -1,3 +1,4 @@ +use num_traits::Zero; use thiserror::Error; #[derive(Debug, Error)] @@ -83,3 +84,105 @@ where } Ok(vector) } + +#[inline(always)] +pub fn parse_pgvector_svector( + input: &[u8], + reserve: usize, + f: F, +) -> Result, ParseVectorError> +where + F: Fn(&str) -> Option, +{ + use arrayvec::ArrayVec; + if input.is_empty() { + return Err(ParseVectorError::EmptyString {}); + } + let left = 'a: { + for position in 0..input.len() - 1 { + match input[position] { + b'{' => break 'a position, + b' ' => continue, + _ => return Err(ParseVectorError::BadCharacter { position }), + } + } + return Err(ParseVectorError::BadParentheses { character: '{' }); + }; + let mut token: ArrayVec = ArrayVec::new(); + let mut capacity = reserve; + let right = 'a: { + for position in (1..input.len()).rev() { + match input[position] { + b'0'..=b'9' => { + if token.try_push(input[position]).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); + } + } + b'/' => { + token.reverse(); + let s = unsafe { std::str::from_utf8_unchecked(&token[..]) }; + capacity = s + .parse::() + .map_err(|_| ParseVectorError::BadParsing { position })?; + } + b'}' => { + token.clear(); + break 'a position; + } + b' ' => continue, + _ => return Err(ParseVectorError::BadCharacter { position }), + } + } + return Err(ParseVectorError::BadParentheses { character: '}' }); + }; + let mut vector = vec![T::zero(); capacity]; + let mut index: usize = 0; + for position in left + 1..right { + let c = input[position]; + match c { + b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => { + if token.is_empty() { + token.push(b'$'); + } + if token.try_push(c).is_err() { + return Err(ParseVectorError::TooLongNumber { position }); + } + } + b',' => { + if !token.is_empty() { + // Safety: all bytes in `token` are ascii characters + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; + vector[index] = num; + token.clear(); + } else { + return Err(ParseVectorError::TooShortNumber { position }); + } + } + b':' => { + if !token.is_empty() { + // Safety: all bytes in `token` are ascii characters + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + index = s + .parse::() + .map_err(|_| ParseVectorError::BadParsing { position })? + - 1; + token.clear(); + } else { + return Err(ParseVectorError::TooShortNumber { position }); + } + } + b' ' => (), + _ => return Err(ParseVectorError::BadCharacter { position }), + } + } + if !token.is_empty() { + let position = right; + // Safety: all bytes in `token` are ascii characters + let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) }; + let num = f(s).ok_or(ParseVectorError::BadParsing { position })?; + vector[index] = num; + token.clear(); + } + Ok(vector) +} diff --git a/tests/sqllogictest/sparse.slt b/tests/sqllogictest/sparse.slt index 018b37423..ee59c9d0c 100644 --- a/tests/sqllogictest/sparse.slt +++ b/tests/sqllogictest/sparse.slt @@ -20,17 +20,17 @@ CREATE INDEX ON t USING vectors (val svector_cos_ops) WITH (options = "[indexing.hnsw]"); query I -SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2; +SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '{1:3,2:1}/6'::svector limit 10) t2; ---- 10 query I -SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2; +SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '{1:3,2:1}/6'::svector limit 10) t2; ---- 10 query I -SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2; +SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '{1:3,2:1}/6'::svector limit 10) t2; ---- 10 @@ -40,7 +40,7 @@ DROP TABLE t; query I SELECT to_svector(5, '{1,2}', '{1,2}'); ---- -[0, 1, 2, 0, 0] +{2:1, 3:2}/5 statement error Lengths of index and value are not matched. SELECT to_svector(5, '{1,2,3}', '{1,2}'); diff --git a/tests/sqllogictest/svector_subscript.slt b/tests/sqllogictest/svector_subscript.slt index ad683b75a..23e07d39f 100644 --- a/tests/sqllogictest/svector_subscript.slt +++ b/tests/sqllogictest/svector_subscript.slt @@ -2,87 +2,87 @@ statement ok SET search_path TO pg_temp, vectors; query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:6]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:6]; ---- -[3, 4, 5] +{1:3, 2:4, 3:5}/3 query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:4]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:4]; ---- -[0, 1, 2, 3] +{\2:1, 3:2, 4:3}/4 query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:]; ---- -[5, 6, 7] +{1:5, 2:6, 3:7}/3 query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:8]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:8]; ---- -[1, 2, 3, 4, 5, 6, 7] +{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/7 statement error type svector does only support one subscript -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:3][1:1]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:3][1:1]; statement error type svector does only support slice fetch -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3]; query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:4]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:4]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[9:]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[9:]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:0]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:0]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:-1]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:-1]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:NULL]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:NULL]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:8]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:8]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:NULL]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:NULL]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:]; ---- NULL query I -SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:NULL]; +SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:NULL]; ---- NULL query I -SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[3:7]; +SELECT ('{3:2, 5:4, 8:7}/8'::svector)[3:7]; ---- -[0, 4, 0, 0] +{2:4}/4 query I -SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[5:7]; +SELECT ('{3:2, 5:4, 8:7}/8'::svector)[5:7]; ---- -[0, 0] +{}/2 query I -SELECT ('[0, 0, 0, 0, 0, 0, 0, 0]'::svector)[5:7]; +SELECT ('{}/2/8'::svector)[5:7]; ---- -[0, 0] \ No newline at end of file +{}/2 \ No newline at end of file