diff --git a/pgvectorscale/src/access_method/build.rs b/pgvectorscale/src/access_method/build.rs index aa93e28..8899b8e 100644 --- a/pgvectorscale/src/access_method/build.rs +++ b/pgvectorscale/src/access_method/build.rs @@ -62,6 +62,14 @@ impl<'a, 'b> BuildState<'a, 'b> { } } +/// Maximum number of dimensions supported by pgvector's vector type. Also +/// the maximum number of dimensions that can be indexed with diskann. +pub const MAX_DIMENSION: u32 = 16000; + +/// Maximum number of dimensions that can be indexed with diskann without +/// using the SBQ storage type. +pub const MAX_DIMENSION_NO_SBQ: u32 = 2000; + #[pg_guard] pub extern "C" fn ambuild( heaprel: pg_sys::Relation, @@ -73,7 +81,7 @@ pub extern "C" fn ambuild( let opt = TSVIndexOptions::from_relation(&index_relation); notice!( - "Starting index build. num_neighbors={} search_list_size={}, max_alpha={}, storage_layout={:?}", + "Starting index build with num_neighbors={}, search_list_size={}, max_alpha={}, storage_layout={:?}.", opt.get_num_neighbors(), opt.search_list_size, opt.max_alpha, @@ -98,10 +106,22 @@ pub extern "C" fn ambuild( let meta_page = unsafe { MetaPage::create(&index_relation, dimensions as _, distance_type, opt) }; - assert!( - meta_page.get_num_dimensions_to_index() > 0 - && meta_page.get_num_dimensions_to_index() <= 2000 - ); + if meta_page.get_num_dimensions_to_index() == 0 { + error!("No dimensions to index"); + } + + if meta_page.get_num_dimensions_to_index() > MAX_DIMENSION { + error!("Too many dimensions to index (max is {})", MAX_DIMENSION); + } + + if meta_page.get_num_dimensions_to_index() > MAX_DIMENSION_NO_SBQ + && meta_page.get_storage_type() == StorageType::Plain + { + error!( + "Too many dimensions to index with plain storage (max is {}). Use storage_layout=memory_optimized instead.", + MAX_DIMENSION_NO_SBQ + ); + } let ntuples = do_heap_scan(index_info, &heap_relation, &index_relation, meta_page); @@ -878,7 +898,7 @@ pub mod tests { ); select setseed(0.5); - -- generate 300 vectors + -- generate {expected_cnt} vectors INSERT INTO {table_name} (id, embedding) SELECT * @@ -1036,7 +1056,7 @@ pub mod tests { ); select setseed(0.5); - -- generate 300 vectors + -- generate {expected_cnt} vectors INSERT INTO test_data (id, embedding) SELECT * @@ -1086,7 +1106,7 @@ pub mod tests { CREATE INDEX idx_diskann_bq ON test_data USING diskann (embedding) WITH ({index_options}); select setseed(0.5); - -- generate 300 vectors + -- generate {expected_cnt} vectors INSERT INTO test_data (id, embedding) SELECT * @@ -1114,4 +1134,51 @@ pub mod tests { verify_index_accuracy(expected_cnt, dimensions)?; Ok(()) } + + #[pg_test] + pub unsafe fn test_high_dimension_index() -> spi::Result<()> { + let index_options = "num_neighbors=10, search_list_size=10"; + let expected_cnt = 1000; + + for dimensions in [4000, 8000, 12000, 16000] { + Spi::run(&format!( + "CREATE TABLE test_data ( + id int, + embedding vector ({dimensions}) + ); + + CREATE INDEX idx_diskann_bq ON test_data USING diskann (embedding) WITH ({index_options}); + + select setseed(0.5); + -- generate {expected_cnt} vectors + INSERT INTO test_data (id, embedding) + SELECT + * + FROM ( + SELECT + i % {expected_cnt}, + ('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding + FROM + generate_series(1, {dimensions} * {expected_cnt}) i + GROUP BY + i % {expected_cnt}) g; + + SET enable_seqscan = 0; + -- perform index scans on the vectors + SELECT + * + FROM + test_data + ORDER BY + embedding <=> ( + SELECT + ('[' || array_to_string(array_agg(random()), ',', '0') || ']')::vector AS embedding + FROM generate_series(1, {dimensions}));"))?; + + verify_index_accuracy(expected_cnt, dimensions)?; + + Spi::run("DROP TABLE test_data CASCADE;")?; + } + Ok(()) + } } diff --git a/pgvectorscale/src/access_method/sbq.rs b/pgvectorscale/src/access_method/sbq.rs index 1e5b963..df6873b 100644 --- a/pgvectorscale/src/access_method/sbq.rs +++ b/pgvectorscale/src/access_method/sbq.rs @@ -20,8 +20,11 @@ use pgrx::{ use rkyv::{vec::ArchivedVec, Archive, Deserialize, Serialize}; use crate::util::{ - page::PageType, table_slot::TableSlot, tape::Tape, ArchivedItemPointer, HeapPointer, - IndexPointer, ItemPointer, ReadableBuffer, + chain::{ChainItemReader, ChainTapeWriter}, + page::{PageType, ReadablePage}, + table_slot::TableSlot, + tape::Tape, + ArchivedItemPointer, HeapPointer, IndexPointer, ItemPointer, ReadableBuffer, }; use super::{meta_page::MetaPage, neighbor_with_distance::NeighborWithDistance}; @@ -33,33 +36,28 @@ const BITS_STORE_TYPE_SIZE: usize = 64; #[derive(Archive, Deserialize, Serialize, Readable, Writeable)] #[archive(check_bytes)] #[repr(C)] -pub struct SbqMeans { +pub struct SbqMeansV1 { count: u64, means: Vec, m2: Vec, } -impl SbqMeans { +impl SbqMeansV1 { pub unsafe fn load( index: &PgRelation, - meta_page: &super::meta_page::MetaPage, + mut quantizer: SbqQuantizer, + qip: ItemPointer, stats: &mut S, ) -> SbqQuantizer { - let mut quantizer = SbqQuantizer::new(meta_page); - if quantizer.use_mean { - if meta_page.get_quantizer_metadata_pointer().is_none() { - pgrx::error!("No SBQ pointer found in meta page"); - } - let quantizer_item_pointer = meta_page.get_quantizer_metadata_pointer().unwrap(); - let bq = SbqMeans::read(index, quantizer_item_pointer, stats); - let archived = bq.get_archived_node(); - - quantizer.load( - archived.count, - archived.means.to_vec(), - archived.m2.to_vec(), - ); - } + assert!(quantizer.use_mean); + let bq = SbqMeansV1::read(index, qip, stats); + let archived = bq.get_archived_node(); + + quantizer.load( + archived.count, + archived.means.to_vec(), + archived.m2.to_vec(), + ); quantizer } @@ -69,7 +67,7 @@ impl SbqMeans { stats: &mut S, ) -> ItemPointer { let mut tape = Tape::new(index, PageType::SbqMeans); - let node = SbqMeans { + let node = SbqMeansV1 { count: quantizer.count, means: quantizer.mean.to_vec(), m2: quantizer.m2.to_vec(), @@ -80,6 +78,66 @@ impl SbqMeans { } } +#[derive(Archive, Deserialize, Serialize)] +#[archive(check_bytes)] +#[repr(C)] +pub struct SbqMeans { + count: u64, + means: Vec, + m2: Vec, +} + +impl SbqMeans { + pub unsafe fn load( + index: &PgRelation, + meta_page: &super::meta_page::MetaPage, + stats: &mut S, + ) -> SbqQuantizer { + let mut quantizer = SbqQuantizer::new(meta_page); + if !quantizer.use_mean { + return quantizer; + } + let qip = meta_page + .get_quantizer_metadata_pointer() + .unwrap_or_else(|| pgrx::error!("No SBQ pointer found in meta page")); + + let page = ReadablePage::read(index, qip.block_number); + let page_type = page.get_type(); + match page_type { + PageType::SbqMeansV1 => SbqMeansV1::load(index, quantizer, qip, stats), + PageType::SbqMeans => { + let mut tape_reader = ChainItemReader::new(index, PageType::SbqMeans, stats); + let mut buf: Vec = Vec::new(); + for item in tape_reader.read(qip) { + buf.extend_from_slice(item.get_data_slice()); + } + + let means = rkyv::from_bytes::(buf.as_slice()).unwrap(); + quantizer.load(means.count, means.means, means.m2); + quantizer + } + _ => { + pgrx::error!("Invalid page type {} for SbqMeans", page_type as u8); + } + } + } + + pub unsafe fn store( + index: &PgRelation, + quantizer: &SbqQuantizer, + stats: &mut S, + ) -> ItemPointer { + let bq = SbqMeans { + count: quantizer.count, + means: quantizer.mean.clone(), + m2: quantizer.m2.clone(), + }; + let mut tape = ChainTapeWriter::new(index, PageType::SbqMeans, stats); + let buf = rkyv::to_bytes::<_, 1024>(&bq).unwrap(); + tape.write(&buf) + } +} + #[derive(Clone)] pub struct SbqQuantizer { pub use_mean: bool, diff --git a/pgvectorscale/src/util/chain.rs b/pgvectorscale/src/util/chain.rs new file mode 100644 index 0000000..9791d79 --- /dev/null +++ b/pgvectorscale/src/util/chain.rs @@ -0,0 +1,280 @@ +//! This module defines the `ChainTape` data structure, which is used to store large data items that +//! are too big to fit in a single page. See `Tape` for a simpler version that assumes each data +//! item fits in a single page. +//! +//! All page entries begin with a header that contains an item pointer to the next chunk in the chain, +//! if applicable. The last chunk in the chain has an invalid item pointer. +//! +//! The implementation supports an append-only sequence of writes via `ChainTapeWriter` and reads +//! via `ChainTapeReader`. The writer returns an `ItemPointer` that can be used to read the data +//! back. Reads are done via an iterator that returns `ReadableBuffer` objects for the segments +//! of the data. + +use pgrx::{ + pg_sys::{BlockNumber, InvalidBlockNumber}, + PgRelation, +}; +use rkyv::{Archive, Deserialize, Serialize}; + +use crate::access_method::stats::{StatsNodeRead, StatsNodeWrite}; + +use super::{ + page::{PageType, ReadablePage, WritablePage}, + ItemPointer, ReadableBuffer, +}; + +#[derive(Clone, PartialEq, Archive, Deserialize, Serialize)] +#[archive(check_bytes)] +struct ChainItemHeader { + next: ItemPointer, +} + +const CHAIN_ITEM_HEADER_SIZE: usize = std::mem::size_of::(); + +// Empirically-measured slop factor for how much `pg_sys::PageGetFreeSpace` can +// overestimate the free space in a page in our usage patterns. +const PG_SLOP_SIZE: usize = 4; + +pub struct ChainTapeWriter<'a, S: StatsNodeWrite> { + page_type: PageType, + index: &'a PgRelation, + current: BlockNumber, + stats: &'a mut S, +} + +impl<'a, S: StatsNodeWrite> ChainTapeWriter<'a, S> { + /// Create a ChainTape that starts writing on a new page. + pub fn new(index: &'a PgRelation, page_type: PageType, stats: &'a mut S) -> Self { + assert!(page_type.is_chained()); + let page = WritablePage::new(index, page_type); + let block_number = page.get_block_number(); + page.commit(); + Self { + page_type, + index, + current: block_number, + stats, + } + } + + /// Write chained data to the tape, returning an `ItemPointer` to the start of the data. + pub fn write(&mut self, mut data: &[u8]) -> super::ItemPointer { + let mut current_page = WritablePage::modify(self.index, self.current); + + // If there isn't enough space for the header plus some data, start a new page. + if current_page.get_free_space() < CHAIN_ITEM_HEADER_SIZE + PG_SLOP_SIZE + 1 { + current_page = WritablePage::new(self.index, self.page_type); + self.current = current_page.get_block_number(); + } + + // ItemPointer to the first item in the chain. + let mut result: Option = None; + + // Write the data in chunks, creating new pages as needed. + while CHAIN_ITEM_HEADER_SIZE + data.len() + PG_SLOP_SIZE > current_page.get_free_space() { + let next_page = WritablePage::new(self.index, self.page_type); + let header = ChainItemHeader { + next: ItemPointer::new(next_page.get_block_number(), 1), + }; + let header_bytes = rkyv::to_bytes::<_, 256>(&header).unwrap(); + let data_size = current_page.get_free_space() - PG_SLOP_SIZE - CHAIN_ITEM_HEADER_SIZE; + let chunk = &data[..data_size]; + let combined = [header_bytes.as_slice(), chunk].concat(); + let offset_number = current_page.add_item(combined.as_ref()); + result.get_or_insert_with(|| { + ItemPointer::new(current_page.get_block_number(), offset_number) + }); + current_page.commit(); + self.stats.record_write(); + current_page = next_page; + data = &data[data_size..]; + } + + // Write the last chunk of data. + let header = ChainItemHeader { + next: ItemPointer::new_invalid(), + }; + let header_bytes = rkyv::to_bytes::<_, 256>(&header).unwrap(); + let combined = [header_bytes.as_slice(), data].concat(); + let offset_number = current_page.add_item(combined.as_ref()); + let result = result + .unwrap_or_else(|| ItemPointer::new(current_page.get_block_number(), offset_number)); + self.current = current_page.get_block_number(); + current_page.commit(); + self.stats.record_write(); + + result + } +} + +pub struct ChainItemReader<'a, S: StatsNodeRead> { + page_type: PageType, + index: &'a PgRelation, + stats: &'a mut S, +} + +impl<'a, S: StatsNodeRead> ChainItemReader<'a, S> { + pub fn new(index: &'a PgRelation, page_type: PageType, stats: &'a mut S) -> Self { + assert!(page_type.is_chained()); + Self { + page_type, + index, + stats, + } + } + + pub fn read(&'a mut self, ip: ItemPointer) -> ChainItemIterator<'a, S> { + ChainItemIterator { + index: self.index, + ip, + page_type: self.page_type, + stats: self.stats, + } + } +} + +pub struct ChainItemIterator<'a, S: StatsNodeRead> { + index: &'a PgRelation, + ip: ItemPointer, + page_type: PageType, + stats: &'a mut S, +} + +impl<'a, S: StatsNodeRead> Iterator for ChainItemIterator<'a, S> { + type Item = ReadableBuffer<'a>; + + fn next(&mut self) -> Option { + if self.ip.block_number == InvalidBlockNumber { + return None; + } + + unsafe { + let page = ReadablePage::read(self.index, self.ip.block_number); + self.stats.record_read(); + assert!(page.get_type() == self.page_type); + let mut item = page.get_item_unchecked(self.ip.offset); + let slice = item.get_data_slice(); + assert!(slice.len() > CHAIN_ITEM_HEADER_SIZE); + let header_slice = &slice[..CHAIN_ITEM_HEADER_SIZE]; + + let header = rkyv::check_archived_root::(header_slice).unwrap(); + self.ip = ItemPointer::new(header.next.block_number, header.next.offset); + + item.advance(CHAIN_ITEM_HEADER_SIZE); + + Some(item) + } + } +} + +#[cfg(any(test, feature = "pg_test"))] +#[pgrx::pg_schema] +mod tests { + use pgrx::{ + pg_sys::{self, BLCKSZ}, + pg_test, Spi, + }; + + use crate::access_method::stats::InsertStats; + + use super::*; + + fn make_test_relation() -> PgRelation { + Spi::run( + "CREATE TABLE test(encoding vector(3)); + CREATE INDEX idxtest + ON test + USING diskann(encoding) + WITH (num_neighbors=30);", + ) + .unwrap(); + + let index_oid = Spi::get_one::("SELECT 'idxtest'::regclass::oid") + .unwrap() + .expect("oid was null"); + unsafe { PgRelation::from_pg(pg_sys::RelationIdGetRelation(index_oid)) } + } + + #[pg_test] + #[allow(clippy::needless_range_loop)] + fn test_chain_tape() { + let mut rstats = InsertStats::default(); + let mut wstats = InsertStats::default(); + + let index = make_test_relation(); + { + // ChainTape can be used for small items too + let mut tape = ChainTapeWriter::new(&index, PageType::SbqMeans, &mut wstats); + for i in 0..100 { + let data = format!("hello world {i}"); + let ip = tape.write(data.as_bytes()); + let mut reader = ChainItemReader::new(&index, PageType::SbqMeans, &mut rstats); + + let mut iter = reader.read(ip); + let item = iter.next().unwrap(); + assert_eq!(item.get_data_slice(), data.as_bytes()); + assert!(iter.next().is_none()); + } + } + + for data_size in BLCKSZ - 100..BLCKSZ + 100 { + // Exhaustively test around the neighborhood of a page size + let mut bigdata = vec![0u8; data_size as usize]; + for i in 0..bigdata.len() { + bigdata[i] = (i % 256) as u8; + } + + let mut tape = ChainTapeWriter::new(&index, PageType::SbqMeans, &mut wstats); + for _ in 0..10 { + let ip = tape.write(&bigdata); + let mut count = 0; + let mut reader = ChainItemReader::new(&index, PageType::SbqMeans, &mut rstats); + for item in reader.read(ip) { + assert_eq!(item.get_data_slice(), &bigdata[count..count + item.len]); + count += item.len; + } + assert_eq!(count, bigdata.len()); + } + } + + for data_size in (2 * BLCKSZ - 100)..(2 * BLCKSZ + 100) { + // Exhaustively test around the neighborhood of a 2-page size + let mut bigdata = vec![0u8; data_size as usize]; + for i in 0..bigdata.len() { + bigdata[i] = (i % 256) as u8; + } + + let mut tape = ChainTapeWriter::new(&index, PageType::SbqMeans, &mut wstats); + for _ in 0..10 { + let ip = tape.write(&bigdata); + let mut count = 0; + let mut reader = ChainItemReader::new(&index, PageType::SbqMeans, &mut rstats); + for item in reader.read(ip) { + assert_eq!(item.get_data_slice(), &bigdata[count..count + item.len]); + count += item.len; + } + assert_eq!(count, bigdata.len()); + } + } + + for data_size in (3 * BLCKSZ - 100)..(3 * BLCKSZ + 100) { + // Exhaustively test around the neighborhood of a 3-page size + let mut bigdata = vec![0u8; data_size as usize]; + for i in 0..bigdata.len() { + bigdata[i] = (i % 256) as u8; + } + + let mut tape = ChainTapeWriter::new(&index, PageType::SbqMeans, &mut wstats); + for _ in 0..10 { + let ip = tape.write(&bigdata); + let mut count = 0; + let mut reader = ChainItemReader::new(&index, PageType::SbqMeans, &mut rstats); + for item in reader.read(ip) { + assert_eq!(item.get_data_slice(), &bigdata[count..count + item.len]); + count += item.len; + } + assert_eq!(count, bigdata.len()); + } + } + } +} diff --git a/pgvectorscale/src/util/mod.rs b/pgvectorscale/src/util/mod.rs index db6d765..8c0ecea 100644 --- a/pgvectorscale/src/util/mod.rs +++ b/pgvectorscale/src/util/mod.rs @@ -1,4 +1,5 @@ pub mod buffer; +pub mod chain; pub mod page; pub mod ports; pub mod table_slot; @@ -54,6 +55,16 @@ impl<'a> ReadableBuffer<'a> { pub fn get_owned_page(self) -> ReadablePage<'a> { self._page } + + pub fn len(&self) -> usize { + self.len + } + + pub fn advance(&mut self, len: usize) { + assert!(self.len >= len); + self.ptr = unsafe { self.ptr.add(len) }; + self.len -= len; + } } pub struct WritableBuffer<'a> { @@ -83,6 +94,13 @@ impl ItemPointer { } } + pub fn new_invalid() -> Self { + Self { + block_number: pgrx::pg_sys::InvalidBlockNumber, + offset: pgrx::pg_sys::InvalidOffsetNumber, + } + } + pub fn is_valid(&self) -> bool { self.block_number != pgrx::pg_sys::InvalidBlockNumber && self.offset != pgrx::pg_sys::InvalidOffsetNumber diff --git a/pgvectorscale/src/util/page.rs b/pgvectorscale/src/util/page.rs index 4dea201..120b309 100644 --- a/pgvectorscale/src/util/page.rs +++ b/pgvectorscale/src/util/page.rs @@ -30,25 +30,35 @@ pub enum PageType { Node = 1, PqQuantizerDef = 2, PqQuantizerVector = 3, - SbqMeans = 4, + SbqMeansV1 = 4, SbqNode = 5, Meta = 6, + SbqMeans = 7, } impl PageType { - fn from_u8(value: u8) -> Self { + pub fn from_u8(value: u8) -> Self { match value { 0 => PageType::MetaV1, 1 => PageType::Node, 2 => PageType::PqQuantizerDef, 3 => PageType::PqQuantizerVector, - 4 => PageType::SbqMeans, + 4 => PageType::SbqMeansV1, 5 => PageType::SbqNode, 6 => PageType::Meta, + 7 => PageType::SbqMeans, _ => panic!("Unknown PageType number {}", value), } } + + /// `ChainTape` supports chaining of pages that might contain large data. + /// This is not supported for all page types. Note that `Tape` requires + /// that the page type not be chained. + pub fn is_chained(self) -> bool { + matches!(self, PageType::SbqMeans) + } } + /// This is the Tsv-specific data that goes on every "diskann-owned" page /// It is placed at the end of a page in the "special" area diff --git a/pgvectorscale/src/util/tape.rs b/pgvectorscale/src/util/tape.rs index c9135ba..7592f1b 100644 --- a/pgvectorscale/src/util/tape.rs +++ b/pgvectorscale/src/util/tape.rs @@ -2,8 +2,8 @@ use super::page::{PageType, ReadablePage, WritablePage}; use pgrx::{ - pg_sys::{BlockNumber, BLCKSZ}, - *, + pg_sys::{BlockNumber, ForkNumber, RelationGetNumberOfBlocksInFork, BLCKSZ}, + PgRelation, }; pub struct Tape<'a> { @@ -13,8 +13,9 @@ pub struct Tape<'a> { } impl<'a> Tape<'a> { - /// Create a Tape that starts writing on a new page. + /// Create a `Tape` that starts writing on a new page. pub unsafe fn new(index: &'a PgRelation, page_type: PageType) -> Self { + assert!(!page_type.is_chained()); let page = WritablePage::new(index, page_type); let block_number = page.get_block_number(); page.commit(); @@ -27,10 +28,8 @@ impl<'a> Tape<'a> { /// Create a Tape that resumes writing on the newest page of the given type, if possible. pub unsafe fn resume(index: &'a PgRelation, page_type: PageType) -> Self { - let nblocks = pg_sys::RelationGetNumberOfBlocksInFork( - index.as_ptr(), - pg_sys::ForkNumber::MAIN_FORKNUM, - ); + assert!(!page_type.is_chained()); + let nblocks = RelationGetNumberOfBlocksInFork(index.as_ptr(), ForkNumber::MAIN_FORKNUM); let mut current_block = None; for block in (0..nblocks).rev() { if ReadablePage::read(index, block).get_type() == page_type { @@ -51,14 +50,12 @@ impl<'a> Tape<'a> { pub unsafe fn write(&mut self, data: &[u8]) -> super::ItemPointer { let size = data.len(); assert!(size < BLCKSZ as usize); + assert!(!self.page_type.is_chained()); let mut current_page = WritablePage::modify(self.index, self.current); - //don't split data over pages. Depending on packing, - //we may have to implement that in the future. + // Don't split data over pages. (See chain.rs for that.) if current_page.get_free_space() < size { - //TODO update forward pointer; - current_page = WritablePage::new(self.index, self.page_type); self.current = current_page.get_block_number(); if current_page.get_free_space() < size { @@ -69,6 +66,7 @@ impl<'a> Tape<'a> { let item_pointer = super::ItemPointer::with_page(¤t_page, offset_number); current_page.commit(); + item_pointer } @@ -78,6 +76,8 @@ impl<'a> Tape<'a> { #[cfg(any(test, feature = "pg_test"))] #[pgrx::pg_schema] mod tests { + use pgrx::{pg_sys, pg_test, Spi}; + use super::*; fn make_test_relation() -> PgRelation { @@ -121,19 +121,6 @@ mod tests { node_page }; - { - let mut tape = Tape::resume(&indexrel, PageType::SbqMeans); - let ip = tape.write(&[99]); - assert_eq!( - ip.block_number, tape.current, - "Tape block number should match IP" - ); - assert_ne!( - tape.current, node_page, - "Data can only be written to page of its type" - ); - } - { let mut tape = Tape::resume(&indexrel, PageType::PqQuantizerVector); let ip = tape.write(&[99]);