Skip to content

Commit

Permalink
refactor: use sparse struct to parse
Browse files: browse the repository at this point in the history
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
  • Loading branch information
cutecutecat committed Jun 3, 2024
1 parent 2d6c196 commit 1b376ab
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 34 deletions.
27 changes: 5 additions & 22 deletions src/datatype/text_svecf32.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,22 @@
use super::memory_svecf32::SVecf32Output;
use crate::datatype::memory_svecf32::SVecf32Input;
use crate::datatype::typmod::Typmod;
use crate::error::*;
use base::scalar::*;
use base::vector::*;
use num_traits::Zero;
use pgrx::pg_sys::Oid;
use std::ffi::{CStr, CString};

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output {
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output {
use crate::utils::parse::parse_pgvector_svector;
let reserve = Typmod::parse_from_i32(typmod)
.unwrap()
.dims()
.map(|x| x.get())
.unwrap_or(0);
let v = parse_pgvector_svector(input.to_bytes(), reserve as usize, |s| {
s.parse::<F32>().ok()
});
let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::<F32>().ok());
match v {
Err(e) => {
bad_literal(&e.to_string());
}
Ok(vector) => {
check_value_dims_1048575(vector.len());
let mut indexes = Vec::<u32>::new();
let mut values = Vec::<F32>::new();
for (i, &x) in vector.iter().enumerate() {
if !x.is_zero() {
indexes.push(i as u32);
values.push(x);
}
}
SVecf32Output::new(SVecf32Borrowed::new(vector.len() as u32, &indexes, &values))
Ok((indexes, values, dims)) => {
check_value_dims_1048575(dims);
SVecf32Output::new(SVecf32Borrowed::new(dims as u32, &indexes, &values))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/sql/finalize.sql
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,7 @@ CREATE AGGREGATE avg(svector) (
STYPE = svector_accumulate_state,
COMBINEFUNC = _vectors_svector_combine,
FINALFUNC = _vectors_svector_final,
INITCOND = '(0, [0])',
INITCOND = '(0, {}/1)',
PARALLEL = SAFE
);

Expand Down
22 changes: 12 additions & 10 deletions src/utils/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ where
#[inline(always)]
pub fn parse_pgvector_svector<T: Zero + Clone, F>(
input: &[u8],
reserve: usize,
f: F,
) -> Result<Vec<T>, ParseVectorError>
) -> Result<(Vec<u32>, Vec<T>, usize), ParseVectorError>
where
F: Fn(&str) -> Option<T>,
{
use arrayvec::ArrayVec;
if input.is_empty() {
return Err(ParseVectorError::EmptyString {});
}
let mut dims: usize = 0;
let left = 'a: {
for position in 0..input.len() - 1 {
match input[position] {
Expand All @@ -109,7 +109,6 @@ where
return Err(ParseVectorError::BadParentheses { character: '{' });
};
let mut token: ArrayVec<u8, 48> = ArrayVec::new();
let mut capacity = reserve;
let right = 'a: {
for position in (1..input.len()).rev() {
match input[position] {
Expand All @@ -121,7 +120,7 @@ where
b'/' => {
token.reverse();
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
capacity = s
dims = s
.parse::<usize>()
.map_err(|_| ParseVectorError::BadParsing { position })?;
}
Expand All @@ -135,8 +134,9 @@ where
}
return Err(ParseVectorError::BadParentheses { character: '}' });
};
let mut vector = vec![T::zero(); capacity];
let mut index: usize = 0;
let mut indexes = Vec::<u32>::new();
let mut values = Vec::<T>::new();
let mut index: u32 = 0;
for position in left + 1..right {
let c = input[position];
match c {
Expand All @@ -153,7 +153,8 @@ where
// Safety: all bytes in `token` are ascii characters
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
vector[index] = num;
indexes.push(index);
values.push(num);
token.clear();
} else {
return Err(ParseVectorError::TooShortNumber { position });
Expand All @@ -164,7 +165,7 @@ where
// Safety: all bytes in `token` are ascii characters
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
index = s
.parse::<usize>()
.parse::<u32>()
.map_err(|_| ParseVectorError::BadParsing { position })?;
token.clear();
} else {
Expand All @@ -180,8 +181,9 @@ where
// Safety: all bytes in `token` are ascii characters
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
vector[index] = num;
indexes.push(index);
values.push(num);
token.clear();
}
Ok(vector)
Ok((indexes, values, dims))
}
2 changes: 1 addition & 1 deletion tests/sqllogictest/sparse.slt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ DROP TABLE t;
query I
SELECT to_svector(5, '{1,2}', '{1,2}');
----
{2:1, 3:2}/5
{1:1, 2:2}/5

query I
SELECT to_svector(5, '{1,2}', '{1,1}') * to_svector(5, '{1,3}', '{2,2}');
Expand Down

0 comments on commit 1b376ab

Please sign in to comment.