Skip to content

Commit

Permalink
fix: by comments
Browse files Browse the repository at this point in the history
Signed-off-by: cutecutecat <junyuchen@tensorchord.ai>
  • Loading branch information
cutecutecat committed Jun 18, 2024
1 parent 8a8561e commit 94312ef
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 90 deletions.
11 changes: 6 additions & 5 deletions src/datatype/text_svecf32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,21 @@ use std::fmt::Write;

#[pgrx::pg_extern(immutable, strict, parallel_safe)]
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output {
use crate::utils::parse::{parse_pgvector_svector, svector_filter_nonzero};
use crate::utils::parse::{parse_pgvector_svector, svector_filter_nonzero, svector_sorted};
let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::<F32>().ok());
match v {
Err(e) => {
bad_literal(&e.to_string());
}
Ok((indexes, values, dims)) => {
let (mut sorted_indexes, mut sorted_values) = svector_sorted(&indexes, &values);
check_value_dims_1048575(dims);
check_index_in_bound(&indexes, dims);
let (non_zero_indexes, non_zero_values) = svector_filter_nonzero(&indexes, &values);
check_index_in_bound(&sorted_indexes, dims);
svector_filter_nonzero(&mut sorted_indexes, &mut sorted_values);
SVecf32Output::new(SVecf32Borrowed::new(
dims as u32,
&non_zero_indexes,
&non_zero_values,
&sorted_indexes,
&sorted_values,
))
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,15 @@ ADVICE: Check if dimensions of the vector are among 1 and 1_048_575."
}

pub fn check_index_in_bound(indexes: &[u32], dims: usize) -> NonZeroU32 {
let mut last: u32 = 0;
for (i, index) in indexes.iter().enumerate() {
if i > 0 && last == *index {
let mut last: Option<u32> = None;
for index in indexes {
if last == Some(*index) {
error!("Indexes need to be unique, but there are more than one same index {index}")
}
if *index >= dims as u32 {
error!("Index out of bounds: the dim is {dims} but the index is {index}");
}
last = *index;
last = Some(*index);
}
NonZeroU32::new(dims as u32).unwrap()
}
Expand Down
160 changes: 79 additions & 81 deletions src/utils/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,35 +85,54 @@ where
Ok(vector)
}

#[derive(PartialEq, Debug)]
#[derive(PartialEq, Debug, Clone)]
enum ParseState {
Start,
LeftBracket,
Index,
Colon,
Value,
Splitter,
Comma,
Length,
RightBracket,
Splitter,
Dims,
}

#[inline(always)]
pub fn svector_filter_nonzero<T: Zero + Clone + PartialEq>(
pub fn svector_sorted<T: Zero + Clone + PartialEq>(
indexes: &[u32],
values: &[T],
) -> (Vec<u32>, Vec<T>) {
let non_zero_indexes: Vec<u32> = indexes
.iter()
.enumerate()
.filter(|(i, _)| values.get(*i).unwrap() != &T::zero())
.map(|(_, x)| *x)
.collect();
let non_zero_values: Vec<T> = indexes
.iter()
.enumerate()
.filter(|(i, _)| values.get(*i).unwrap() != &T::zero())
.map(|(i, _)| values.get(i).unwrap().clone())
.collect();
(non_zero_indexes, non_zero_values)
let mut indices = (0..indexes.len()).collect::<Vec<_>>();
indices.sort_by_key(|&i| &indexes[i]);

let mut sorted_indexes: Vec<u32> = Vec::with_capacity(indexes.len());
let mut sorted_values: Vec<T> = Vec::with_capacity(indexes.len());
for i in indices {
sorted_indexes.push(*indexes.get(i).unwrap());
sorted_values.push(values.get(i).unwrap().clone());
}
(sorted_indexes, sorted_values)
}

#[inline(always)]
pub fn svector_filter_nonzero<T: Zero + Clone + PartialEq>(
indexes: &mut Vec<u32>,
values: &mut Vec<T>,
) {
// Index must be sorted!
let mut i = 0;
let mut j = 0;
while j < values.len() {
if !values[j].is_zero() {
indexes[i] = indexes[j];
values[i] = values[j].clone();
i += 1;
}
j += 1;
}
indexes.truncate(i);
values.truncate(i);
}

#[inline(always)]
Expand All @@ -133,110 +152,82 @@ where
let mut values = Vec::<T>::new();

let mut state = ParseState::Start;
for (position, char) in input.iter().enumerate() {
let c = *char;
match (&state, c) {
(_, b' ') => {}
(ParseState::Start, b'{') => {
state = ParseState::LeftBracket;
}
for (position, c) in input.iter().copied().enumerate() {
state = match (&state, c) {
(_, b' ') => state,
(ParseState::Start, b'{') => ParseState::LeftBracket,
(
ParseState::LeftBracket | ParseState::Index | ParseState::Comma,
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-',
) => {
if token.is_empty() {
token.push(b'$');
}
if token.try_push(c).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
state = ParseState::Index;
ParseState::Index
}
(ParseState::LeftBracket | ParseState::Comma, b'}') => {
state = ParseState::Splitter;
(ParseState::Colon, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => {
if token.try_push(c).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
ParseState::Value
}
(ParseState::LeftBracket | ParseState::Comma, b'}') => ParseState::RightBracket,
(ParseState::Index, b':') => {
if token.is_empty() {
return Err(ParseVectorError::TooShortNumber { position });
}
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
let index = s
.parse::<u32>()
.map_err(|_| ParseVectorError::BadParsing { position })?;
indexes.push(index);
token.clear();
state = ParseState::Value;
ParseState::Colon
}
(ParseState::Value, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => {
if token.is_empty() {
token.push(b'$');
}
if token.try_push(c).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
ParseState::Value
}
(ParseState::Value, b',') => {
if token.is_empty() {
return Err(ParseVectorError::TooShortNumber { position });
}
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
values.push(num);
token.clear();
state = ParseState::Comma;
ParseState::Comma
}
(ParseState::Value, b'}') => {
if token.is_empty() {
return Err(ParseVectorError::TooShortNumber { position });
}
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
values.push(num);
token.clear();
state = ParseState::Splitter;
}
(ParseState::Splitter, b'/') => {
state = ParseState::Length;
ParseState::RightBracket
}
(ParseState::Length, b'0'..=b'9') => {
if token.is_empty() {
token.push(b'$');
}
(ParseState::RightBracket, b'/') => ParseState::Splitter,
(ParseState::Dims | ParseState::Splitter, b'0'..=b'9') => {
if token.try_push(c).is_err() {
return Err(ParseVectorError::TooLongNumber { position });
}
ParseState::Dims
}
(_, _) => {
return Err(ParseVectorError::BadCharacter { position });
}
}
}
if state != ParseState::Length {
if state != ParseState::Dims {
return Err(ParseVectorError::BadParsing {
position: input.len(),
});
}
if token.is_empty() {
return Err(ParseVectorError::TooShortNumber {
position: input.len(),
});
}
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
let dims = s
.parse::<usize>()
.map_err(|_| ParseVectorError::BadParsing {
position: input.len(),
})?;

let mut indices = (0..indexes.len()).collect::<Vec<_>>();
indices.sort_by_key(|&i| &indexes[i]);
let sorted_values: Vec<T> = indices
.iter()
.map(|i| values.get(*i).unwrap().clone())
.collect();
indexes.sort();

Ok((indexes, sorted_values, dims))
Ok((indexes, values, dims))
}

#[cfg(test)]
Expand Down Expand Up @@ -266,8 +257,8 @@ mod tests {
(
"{3:3, 2:2, 1:1, 0:0}/4",
(
vec![0, 1, 2, 3],
vec![F32(0.0), F32(1.0), F32(2.0), F32(3.0)],
vec![3, 2, 1, 0],
vec![F32(3.0), F32(2.0), F32(1.0), F32(0.0)],
4,
),
),
Expand All @@ -294,16 +285,13 @@ mod tests {
"{0:1, 1:2, 2:3",
ParseVectorError::BadParsing { position: 14 },
),
(
"{0:1, 1:2}/",
ParseVectorError::TooShortNumber { position: 11 },
),
("{0:1, 1:2}/", ParseVectorError::BadParsing { position: 11 }),
("{0}/5", ParseVectorError::BadCharacter { position: 2 }),
("{0:}/5", ParseVectorError::TooShortNumber { position: 3 }),
("{0:}/5", ParseVectorError::BadCharacter { position: 3 }),
("{:0}/5", ParseVectorError::BadCharacter { position: 1 }),
(
"{0:, 1:2}/5",
ParseVectorError::TooShortNumber { position: 3 },
ParseVectorError::BadCharacter { position: 3 },
),
("{0:1, 1}/5", ParseVectorError::BadCharacter { position: 7 }),
("/2", ParseVectorError::BadCharacter { position: 0 }),
Expand Down Expand Up @@ -347,23 +335,33 @@ mod tests {
),
(
"{2:0, 1:0}/2",
(vec![1, 2], vec![F32(0.0), F32(0.0)], 2),
(vec![2, 1], vec![F32(0.0), F32(0.0)], 2),
(vec![], vec![]),
),
(
"{2:0, 1:0, }/2",
(vec![1, 2], vec![F32(0.0), F32(0.0)], 2),
(vec![2, 1], vec![F32(0.0), F32(0.0)], 2),
(vec![], vec![]),
),
(
"{3:2, 2:1, 1:0, 0:-1}/4",
(
vec![3, 2, 1, 0],
vec![F32(2.0), F32(1.0), F32(0.0), F32(-1.0)],
4,
),
(vec![0, 2, 3], vec![F32(-1.0), F32(1.0), F32(2.0)]),
),
];
for (e, parsed, filtered) in exprs {
let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::<F32>().ok());
assert!(ret.is_ok(), "at expr {:?}: {:?}", e, ret);
assert_eq!(ret.unwrap(), parsed, "parsed at expr {:?}", e);

let (indexes, values, _) = parsed;
let nonzero = svector_filter_nonzero(&indexes, &values);
assert_eq!(nonzero, filtered, "filtered at expr {:?}", e);
let (mut indexes, mut values) = svector_sorted(&indexes, &values);
svector_filter_nonzero(&mut indexes, &mut values);
assert_eq!((indexes, values), filtered, "filtered at expr {:?}", e);
}
}
}

0 comments on commit 94312ef

Please sign in to comment.