Skip to content

Commit

Permalink
feat: new text embedding for sparse vector
Browse files Browse the repository at this point in the history
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
  • Loading branch information
cutecutecat committed Apr 16, 2024
1 parent 06137e1 commit 2089054
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 40 deletions.
2 changes: 1 addition & 1 deletion rust-toolchain.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[toolchain]
channel = "nightly-2024-03-27"
channel = "nightly-2024-04-12"
profile = "default"
targets = [
"aarch64-apple-darwin",
Expand Down
26 changes: 16 additions & 10 deletions src/datatype/text_svecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ use std::ffi::{CStr, CString};

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output {
use crate::utils::parse::parse_vector;
use crate::utils::parse::parse_pgvector_svector;
let reserve = Typmod::parse_from_i32(typmod)
.unwrap()
.dims()
.map(|x| x.get())
.unwrap_or(0);
let v = parse_vector(input.to_bytes(), reserve as usize, |s| {
let v = parse_pgvector_svector(input.to_bytes(), reserve as usize, |s| {
s.parse::<F32>().ok()
});
match v {
Expand All @@ -40,16 +40,22 @@ fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output {

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString {
let dims = vector.for_borrow().dims();
let mut buffer = String::new();
buffer.push('[');
buffer.push('{');
let vec = vector.for_borrow().to_vec();
let mut iter = vec.iter();
if let Some(x) = iter.next() {
buffer.push_str(format!("{}", x).as_str());
}
for x in iter {
buffer.push_str(format!(", {}", x).as_str());
let mut need_splitter = true;
for (i, x) in vec.iter().enumerate() {
if *x != F32::zero() {
match need_splitter {
true => {
buffer.push_str(format!("{}:{}", i + 1, x).as_str());
need_splitter = false;
}
false => buffer.push_str(format!(", {}:{}", i + 1, x).as_str()),
}
}
}
buffer.push(']');
buffer.push_str(format!("}}/{}", dims).as_str());
CString::new(buffer).unwrap()
}
103 changes: 103 additions & 0 deletions src/utils/parse.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use num_traits::Zero;
use thiserror::Error;

#[derive(Debug, Error)]
Expand Down Expand Up @@ -83,3 +84,105 @@ where
}
Ok(vector)
}

#[inline(always)]
pub fn parse_pgvector_svector<T: Zero + Clone, F>(
input: &[u8],
reserve: usize,
f: F,
) -> Result<Vec<T>, ParseVectorError>
where
F: Fn(&str) -> Option<T>,
{
use arrayvec::ArrayVec;
if input.is_empty() {
return Err(ParseVectorError::EmptyString {});
}
let left = 'a: {
for position in 0..input.len() - 1 {
match input[position] {
b'{' => break 'a position,
b' ' => continue,
_ => return Err(ParseVectorError::BadCharacter { position }),
}
}
return Err(ParseVectorError::BadParentheses { character: '{' });
};
let mut token: ArrayVec<u8, 48> = ArrayVec::new();
let mut capacity = reserve;
let right = 'a: {
for position in (1..input.len()).rev() {
match input[position] {
b'0'..=b'9' => {
if token.try_push(input[position]).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
}
b'/' => {
token.reverse();
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
capacity = s
.parse::<usize>()
.map_err(|_| ParseVectorError::BadParsing { position })?;
}
b'}' => {
token.clear();
break 'a position;
}
b' ' => continue,
_ => return Err(ParseVectorError::BadCharacter { position }),
}
}
return Err(ParseVectorError::BadParentheses { character: '}' });
};
let mut vector = vec![T::zero(); capacity];
let mut index: usize = 0;
for position in left + 1..right {
let c = input[position];
match c {
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-' => {
if token.is_empty() {
token.push(b'$');
}
if token.try_push(c).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
}
b',' => {
if !token.is_empty() {
// Safety: all bytes in `token` are ascii characters
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
vector[index] = num;
token.clear();
} else {
return Err(ParseVectorError::TooShortNumber { position });
}
}
b':' => {
if !token.is_empty() {
// Safety: all bytes in `token` are ascii characters
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
index = s
.parse::<usize>()
.map_err(|_| ParseVectorError::BadParsing { position })?
- 1;
token.clear();
} else {
return Err(ParseVectorError::TooShortNumber { position });
}
}
b' ' => (),
_ => return Err(ParseVectorError::BadCharacter { position }),
}
}
if !token.is_empty() {
let position = right;
// Safety: all bytes in `token` are ascii characters
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
vector[index] = num;
token.clear();
}
Ok(vector)
}
8 changes: 4 additions & 4 deletions tests/sqllogictest/sparse.slt
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ CREATE INDEX ON t USING vectors (val svector_cos_ops)
WITH (options = "[indexing.hnsw]");

query I
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '{1:3,2:1}/6'::svector limit 10) t2;
----
10

query I
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '{1:3,2:1}/6'::svector limit 10) t2;
----
10

query I
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2;
SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '{1:3,2:1}/6'::svector limit 10) t2;
----
10

Expand All @@ -40,7 +40,7 @@ DROP TABLE t;
query I
SELECT to_svector(5, '{1,2}', '{1,2}');
----
[0, 1, 2, 0, 0]
{2:1, 3:2}/5

statement error Lengths of index and value are not matched.
SELECT to_svector(5, '{1,2,3}', '{1,2}');
Expand Down
50 changes: 25 additions & 25 deletions tests/sqllogictest/svector_subscript.slt
Original file line number Diff line number Diff line change
Expand Up @@ -2,87 +2,87 @@ statement ok
SET search_path TO pg_temp, vectors;

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:6];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:6];
----
[3, 4, 5]
{1:3, 2:4, 3:5}/3

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:4];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:4];
----
[0, 1, 2, 3]
{\2:1, 3:2, 4:3}/4

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:];
----
[5, 6, 7]
{1:5, 2:6, 3:7}/3

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:8];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:8];
----
[1, 2, 3, 4, 5, 6, 7]
{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/7

statement error type svector does only support one subscript
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:3][1:1];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:3][1:1];

statement error type svector does only support slice fetch
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3];

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:4];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:4];
----
NULL

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[9:];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[9:];
----
NULL

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:0];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:0];
----
NULL

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:-1];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:-1];
----
NULL

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:NULL];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:NULL];
----
NULL

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:8];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:8];
----
NULL

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:NULL];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:NULL];
----
NULL

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:];
----
NULL

query I
SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:NULL];
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:NULL];
----
NULL

query I
SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[3:7];
SELECT ('{3:2, 5:4, 8:7}/8'::svector)[3:7];
----
[0, 4, 0, 0]
{2:4}/4

query I
SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[5:7];
SELECT ('{3:2, 5:4, 8:7}/8'::svector)[5:7];
----
[0, 0]
{}/2

query I
SELECT ('[0, 0, 0, 0, 0, 0, 0, 0]'::svector)[5:7];
SELECT ('{}/2/8'::svector)[5:7];
----
[0, 0]
{}/2

0 comments on commit 2089054

Please sign in to comment.