diff --git a/Cargo.toml b/Cargo.toml index 1eb39ca..b290b63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ repository = "https://github.com/diesel-rs/diesel_full_text_search" edition = "2021" [dependencies] +byteorder = "1.5.0" diesel = { version = "~2.2.0", features = ["postgres_backend"], default-features = false } [features] diff --git a/src/lib.rs b/src/lib.rs index f7443a3..6d5b443 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,8 @@ mod types { - use diesel::sql_types::*; + use std::io::{BufRead, Cursor}; + + use byteorder::{NetworkEndian, ReadBytesExt}; + use diesel::{deserialize::FromSql, pg::Pg, sql_types::*, Queryable}; #[derive(Clone, Copy, SqlType)] #[diesel(postgres_type(oid = 3615, array_oid = 3645))] @@ -18,6 +21,70 @@ mod types { #[derive(SqlType)] #[diesel(postgres_type(name = "regconfig"))] pub struct RegConfig; + + impl FromSql for PgTsVector { + fn from_sql( + bytes: ::RawValue<'_>, + ) -> diesel::deserialize::Result { + let mut cursor = Cursor::new(bytes.as_bytes()); + + // From Postgres `tsvector.c`: + // + // The binary format is as follows: + // + // uint32 number of lexemes + // + // for each lexeme: + // lexeme text in client encoding, null-terminated + // uint16 number of positions + // for each position: + // uint16 WordEntryPos + + // Number of lexemes (uint32) + let num_lexemes = cursor.read_u32::()?; + + let mut entries = Vec::with_capacity(num_lexemes as usize); + + for _ in 0..num_lexemes { + let mut lexeme = Vec::new(); + cursor.read_until(0, &mut lexeme)?; + // Remove null terminator + lexeme.pop(); + let lexeme = String::from_utf8(lexeme)?; + + // Number of positions (uint16) + let num_positions = cursor.read_u16::()?; + + let mut positions = Vec::with_capacity(num_positions as usize); + for _ in 0..num_positions { + positions.push(cursor.read_u16::()?); + } + + entries.push(PgTsVectorEntry { lexeme, positions }); + } + + Ok(PgTsVector { entries }) + } + } + + impl Queryable for PgTsVector { + type Row = Self; + + fn build(row: Self::Row) -> diesel::deserialize::Result { + Ok(row) + } + } + + #[derive(Debug, Clone, PartialEq)] + pub struct PgTsVector { + pub entries: Vec, + } + + #[derive(Debug, Clone, PartialEq)] + pub struct PgTsVectorEntry { + pub lexeme: String, + pub positions: Vec, + } } pub mod configuration { @@ -219,3 +286,110 @@ mod dsl { pub use self::dsl::*; pub use self::functions::*; pub use self::types::*; + +#[cfg(all(test, feature = "with-diesel-postgres"))] +mod tests { + use super::*; + + use diesel::dsl::sql; + use diesel::pg::PgConnection; + use diesel::prelude::*; + + #[test] + fn test_tsvector_from_sql_with_positions() { + let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + let mut conn = + PgConnection::establish(&database_url).expect("Error connecting to database"); + + let query = diesel::select(sql::( + "to_tsvector('a fat cat sat on a mat and ate a fat rat')", + )); + let result: PgTsVector = query.get_result(&mut conn).expect("Error executing query"); + + let expected = PgTsVector { + entries: vec![ + PgTsVectorEntry { + lexeme: "ate".to_owned(), + positions: vec![9], + }, + PgTsVectorEntry { + lexeme: "cat".to_owned(), + positions: vec![3], + }, + PgTsVectorEntry { + lexeme: "fat".to_owned(), + positions: vec![2, 11], + }, + PgTsVectorEntry { + lexeme: "mat".to_owned(), + positions: vec![7], + }, + PgTsVectorEntry { + lexeme: "rat".to_owned(), + positions: vec![12], + }, + PgTsVectorEntry { + lexeme: "sat".to_owned(), + positions: vec![4], + }, + ], + }; + + assert_eq!(expected, result); + } + + #[test] + fn test_tsvector_from_sql_without_positions() { + let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + let mut conn = + PgConnection::establish(&database_url).expect("Error connecting to database"); + + let query = diesel::select(sql::( + "'a fat cat sat on a mat and ate a fat rat'::tsvector", + )); + let result: PgTsVector = query.get_result(&mut conn).expect("Error executing query"); + + let expected = PgTsVector { + entries: vec![ + PgTsVectorEntry { + lexeme: "a".to_owned(), + positions: vec![], + }, + PgTsVectorEntry { + lexeme: "and".to_owned(), + positions: vec![], + }, + PgTsVectorEntry { + lexeme: "ate".to_owned(), + positions: vec![], + }, + PgTsVectorEntry { + lexeme: "cat".to_owned(), + positions: vec![], + }, + PgTsVectorEntry { + lexeme: "fat".to_owned(), + positions: vec![], + }, + PgTsVectorEntry { + lexeme: "mat".to_owned(), + positions: vec![], + }, + PgTsVectorEntry { + lexeme: "on".to_owned(), + positions: vec![], + }, + PgTsVectorEntry { + lexeme: "rat".to_owned(), + positions: vec![], + }, + PgTsVectorEntry { + lexeme: "sat".to_owned(), + positions: vec![], + }, + ], + }; + + assert_eq!(expected, result); + } +}