diff --git a/dbms/src/Common/LRUCache.h b/dbms/src/Common/LRUCache.h index 82c2c8c49e1..63284ae2969 100644 --- a/dbms/src/Common/LRUCache.h +++ b/dbms/src/Common/LRUCache.h @@ -71,6 +71,14 @@ class LRUCache return res; } + /// Returns whether a specific key is in the LRU cache + /// without updating the LRU order. + bool contains(const Key & key) + { + std::lock_guard cache_lock(mutex); + return cells.find(key) != cells.end(); + } + void set(const Key & key, const MappedPtr & mapped) { std::lock_guard cache_lock(mutex); diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h index 412aa67f6dd..d1440061b16 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.h @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include diff --git a/dbms/src/Interpreters/Context.cpp b/dbms/src/Interpreters/Context.cpp index b3465ef0b05..1679e089d2a 100644 --- a/dbms/src/Interpreters/Context.cpp +++ b/dbms/src/Interpreters/Context.cpp @@ -57,7 +57,7 @@ #include #include #include -#include +#include #include #include #include @@ -1405,16 +1405,16 @@ void Context::dropMinMaxIndexCache() const { auto lock = getLock(); if (shared->minmax_index_cache) - shared->minmax_index_cache->reset(); + shared->minmax_index_cache.reset(); } -void Context::setVectorIndexCache(size_t cache_size_in_bytes) +void Context::setVectorIndexCache(size_t cache_entities) { auto lock = getLock(); RUNTIME_CHECK(!shared->vector_index_cache); - shared->vector_index_cache = std::make_shared(cache_size_in_bytes); + shared->vector_index_cache = std::make_shared(cache_entities); } DM::VectorIndexCachePtr Context::getVectorIndexCache() const @@ -1427,7 +1427,7 @@ void Context::dropVectorIndexCache() const { auto lock = getLock(); if (shared->vector_index_cache) - shared->vector_index_cache->reset(); + shared->vector_index_cache.reset(); } bool Context::isDeltaIndexLimited() const diff --git a/dbms/src/Interpreters/Context.h b/dbms/src/Interpreters/Context.h index 42b1ee667a4..686da5d4696 100644 --- a/dbms/src/Interpreters/Context.h +++ b/dbms/src/Interpreters/Context.h @@ -398,7 +398,7 @@ class Context std::shared_ptr getMinMaxIndexCache() const; void dropMinMaxIndexCache() const; - void setVectorIndexCache(size_t cache_size_in_bytes); + void setVectorIndexCache(size_t cache_entities); std::shared_ptr getVectorIndexCache() const; void dropVectorIndexCache() const; diff --git a/dbms/src/Server/Server.cpp b/dbms/src/Server/Server.cpp index 2688c11478c..24e021a2406 100644 --- a/dbms/src/Server/Server.cpp +++ b/dbms/src/Server/Server.cpp @@ -1428,10 +1428,9 @@ int Server::main(const std::vector & /*args*/) if (minmax_index_cache_size) global_context->setMinMaxIndexCache(minmax_index_cache_size); - // 1GiB vector index cache. - size_t vec_index_cache_size = config().getUInt64("vec_index_cache_size", 1ULL * 1024 * 1024 * 1024); - if (vec_index_cache_size) - global_context->setVectorIndexCache(vec_index_cache_size); + size_t vec_index_cache_entities = config().getUInt64("vec_index_cache_entities", 1000); + if (vec_index_cache_entities) + global_context->setVectorIndexCache(vec_index_cache_entities); /// Size of max memory usage of DeltaIndex, used by DeltaMerge engine. /// - In non-disaggregated mode, its default value is 0, means unlimited, and it diff --git a/dbms/src/Storages/DeltaMerge/ColumnStat.h b/dbms/src/Storages/DeltaMerge/ColumnStat.h index 07a251388e7..f23b743ec77 100644 --- a/dbms/src/Storages/DeltaMerge/ColumnStat.h +++ b/dbms/src/Storages/DeltaMerge/ColumnStat.h @@ -41,7 +41,7 @@ struct ColumnStat size_t array_sizes_bytes = 0; size_t array_sizes_mark_bytes = 0; - std::optional vector_index = std::nullopt; + std::optional vector_index = std::nullopt; dtpb::ColumnStat toProto() const { diff --git a/dbms/src/Storages/DeltaMerge/DeltaMergeDefines.h b/dbms/src/Storages/DeltaMerge/DeltaMergeDefines.h index 639351bc33e..b02408166eb 100644 --- a/dbms/src/Storages/DeltaMerge/DeltaMergeDefines.h +++ b/dbms/src/Storages/DeltaMerge/DeltaMergeDefines.h @@ -93,14 +93,14 @@ struct ColumnDefine /// Note: ColumnDefine is used in both Write path and Read path. /// In the read path, vector_index is usually not available. Use AnnQueryInfo for /// read related vector index information. - TiDB::VectorIndexInfoPtr vector_index; + TiDB::VectorIndexDefinitionPtr vector_index; explicit ColumnDefine( ColId id_ = 0, String name_ = "", DataTypePtr type_ = nullptr, Field default_value_ = Field{}, - TiDB::VectorIndexInfoPtr vector_index_ = nullptr) + TiDB::VectorIndexDefinitionPtr vector_index_ = nullptr) : id(id_) , name(std::move(name_)) , type(std::move(type_)) diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp new file mode 100644 index 00000000000..008c3ffeba9 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.cpp @@ -0,0 +1,466 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "Storages/S3/FileCachePerf.h" + +namespace DB::ErrorCodes +{ +extern const int S3_ERROR; +} // namespace DB::ErrorCodes + +namespace DB::DM +{ + +DMFileWithVectorIndexBlockInputStream::DMFileWithVectorIndexBlockInputStream( + const ANNQueryInfoPtr & ann_query_info_, + const DMFilePtr & dmfile_, + Block && header_layout_, + DMFileReader && reader_, + ColumnDefine && vec_cd_, + const FileProviderPtr & file_provider_, + const ReadLimiterPtr & read_limiter_, + const ScanContextPtr & scan_context_, + const VectorIndexCachePtr & vec_index_cache_, + const BitmapFilterView & valid_rows_, + const String & tracing_id) + : log(Logger::get(tracing_id)) + , ann_query_info(ann_query_info_) + , dmfile(dmfile_) + , header_layout(std::move(header_layout_)) + , reader(std::move(reader_)) + , vec_cd(std::move(vec_cd_)) + , file_provider(file_provider_) + , read_limiter(read_limiter_) + , scan_context(scan_context_) + , vec_index_cache(vec_index_cache_) + , valid_rows(valid_rows_) +{ + RUNTIME_CHECK(ann_query_info); + RUNTIME_CHECK(vec_cd.id == ann_query_info->column_id()); + for (const auto & cd : reader.read_columns) + { + RUNTIME_CHECK(header_layout.has(cd.name), cd.name); + RUNTIME_CHECK(cd.id != vec_cd.id); + } + RUNTIME_CHECK(header_layout.has(vec_cd.name)); + RUNTIME_CHECK(header_layout.columns() == reader.read_columns.size() + 1); + + // Fill start_offset_to_pack_id + const auto & pack_stats = dmfile->getPackStats(); + start_offset_to_pack_id.reserve(pack_stats.size()); + UInt32 start_offset = 0; + for (size_t pack_id = 0, pack_id_max = pack_stats.size(); pack_id < pack_id_max; pack_id++) + { + start_offset_to_pack_id[start_offset] = pack_id; + start_offset += pack_stats[pack_id].rows; + } + + // Fill header + header = toEmptyBlock(reader.read_columns); + addColumnToBlock(header, vec_cd.id, vec_cd.name, vec_cd.type, vec_cd.type->createColumn(), vec_cd.default_value); +} + +DMFileWithVectorIndexBlockInputStream::~DMFileWithVectorIndexBlockInputStream() +{ + if (!vec_column_reader) + return; + + scan_context->total_vector_idx_read_vec_time_ms += static_cast(duration_read_from_vec_index_seconds * 1000); + scan_context->total_vector_idx_read_others_time_ms + += static_cast(duration_read_from_other_columns_seconds * 1000); + + LOG_DEBUG( // + log, + "Finished read DMFile with vector index for column dmf_{}/{}(id={}), " + "query_top_k={} load_index+result={:.2f}s read_from_index={:.2f}s read_from_others={:.2f}s", + dmfile->fileId(), + vec_cd.name, + vec_cd.id, + ann_query_info->top_k(), + duration_load_vec_index_and_results_seconds, + duration_read_from_vec_index_seconds, + duration_read_from_other_columns_seconds); +} + + +Block DMFileWithVectorIndexBlockInputStream::read(FilterPtr & res_filter, bool return_filter) +{ + if (return_filter) + return readImpl(res_filter); + + // If return_filter == false, we must filter by ourselves. + + FilterPtr filter = nullptr; + auto res = readImpl(filter); + if (filter != nullptr) + { + for (auto & col : res) + col.column = col.column->filter(*filter, -1); + } + // filter == nullptr means all rows are valid and no need to filter. + + return res; +} + +Block DMFileWithVectorIndexBlockInputStream::readImpl(FilterPtr & res_filter) +{ + load(); + + Block res; + if (!reader.read_columns.empty()) + res = readByFollowingOtherColumns(); + else + res = readByIndexReader(); + + // Filter the output rows. If no rows need to filter, res_filter is nullptr. + filter.resize(res.rows()); + bool all_match = valid_rows_after_search.get(filter, res.startOffset(), res.rows()); + + if unlikely (all_match) + res_filter = nullptr; + else + res_filter = &filter; + return res; +} + +Block DMFileWithVectorIndexBlockInputStream::readByIndexReader() +{ + const auto & pack_stats = dmfile->getPackStats(); + size_t all_packs = pack_stats.size(); + const auto & use_packs = reader.pack_filter.getUsePacksConst(); + + RUNTIME_CHECK(use_packs.size() == all_packs); + + // Skip as many packs as possible according to Pack Filter + while (index_reader_next_pack_id < all_packs) + { + if (use_packs[index_reader_next_pack_id]) + break; + index_reader_next_row_id += pack_stats[index_reader_next_pack_id].rows; + index_reader_next_pack_id++; + } + + if (index_reader_next_pack_id >= all_packs) + // Finished + return {}; + + auto read_pack_id = index_reader_next_pack_id; + auto block_start_row_id = index_reader_next_row_id; + auto read_rows = pack_stats[read_pack_id].rows; + + index_reader_next_pack_id++; + index_reader_next_row_id += read_rows; + + Block block; + block.setStartOffset(block_start_row_id); + + auto vec_column = vec_cd.type->createColumn(); + + Stopwatch w; + vec_column_reader->read(vec_column, read_pack_id, read_rows); + duration_read_from_vec_index_seconds += w.elapsedSeconds(); + + block.insert(ColumnWithTypeAndName{// + std::move(vec_column), + vec_cd.type, + vec_cd.name, + vec_cd.id}); + + return block; +} + +Block DMFileWithVectorIndexBlockInputStream::readByFollowingOtherColumns() +{ + // First read other columns. + Stopwatch w; + auto block_others = reader.read(); + duration_read_from_other_columns_seconds += w.elapsedSeconds(); + + if (!block_others) + return {}; + + // Using vec_cd.type to construct a Column directly instead of using + // the type from dmfile, so that we don't need extra transforms + // (e.g. wrap with a Nullable). vec_column_reader is compatible with + // both Nullable and NotNullable. + auto vec_column = vec_cd.type->createColumn(); + + // Then read from vector index for the same pack. + w.restart(); + + vec_column_reader->read(vec_column, getPackIdFromBlock(block_others), block_others.rows()); + duration_read_from_vec_index_seconds += w.elapsedSeconds(); + + // Re-assemble block using the same layout as header_layout. + Block res = header_layout.cloneEmpty(); + // Note: the start offset counts from the beginning of THIS dmfile. It + // is not a global offset. + res.setStartOffset(block_others.startOffset()); + for (const auto & elem : block_others) + { + RUNTIME_CHECK(res.has(elem.name)); + res.getByName(elem.name).column = std::move(elem.column); + } + RUNTIME_CHECK(res.has(vec_cd.name)); + res.getByName(vec_cd.name).column = std::move(vec_column); + + return res; +} + +void DMFileWithVectorIndexBlockInputStream::load() +{ + if (loaded) + return; + + Stopwatch w; + + loadVectorIndex(); + loadVectorSearchResult(); + + duration_load_vec_index_and_results_seconds = w.elapsedSeconds(); + + loaded = true; +} + +void DMFileWithVectorIndexBlockInputStream::loadVectorIndex() +{ + bool has_s3_download = false; + bool has_load_from_file = false; + + double duration_load_index = 0; // include download from s3 and load from fs + + auto col_id = ann_query_info->column_id(); + + RUNTIME_CHECK(dmfile->useMetaV2()); // v3 + + // Check vector index exists on the column + const auto & column_stat = dmfile->getColumnStat(col_id); + RUNTIME_CHECK(column_stat.index_bytes > 0); + RUNTIME_CHECK(column_stat.vector_index.has_value()); + + // If local file is invalidated, cache is not valid anymore. So we + // need to ensure file exists on local fs first. + const auto file_name_base = DMFile::getFileNameBase(col_id); + const auto index_file_path = dmfile->colIndexPath(file_name_base); + String local_index_file_path; + FileSegmentPtr file_guard = nullptr; + if (auto s3_file_name = S3::S3FilenameView::fromKeyWithPrefix(index_file_path); s3_file_name.isValid()) + { + // Disaggregated mode + auto * file_cache = FileCache::instance(); + RUNTIME_CHECK_MSG(file_cache, "Must enable S3 file cache to use vector index"); + + Stopwatch watch; + + auto perf_begin = PerfContext::file_cache; + + // If download file failed, retry a few times. + for (auto i = 3; i > 0; --i) + { + try + { + file_guard = file_cache->downloadFileForLocalRead( // + s3_file_name, + column_stat.index_bytes); + if (file_guard) + { + local_index_file_path = file_guard->getLocalFileName(); + break; // Successfully downloaded index into local cache + } + + throw Exception( // + ErrorCodes::S3_ERROR, + "Failed to download vector index file {}", + index_file_path); + } + catch (...) + { + if (i <= 1) + throw; + } + } + + if ( // + PerfContext::file_cache.fg_download_from_s3 > perf_begin.fg_download_from_s3 || // + PerfContext::file_cache.fg_wait_download_from_s3 > perf_begin.fg_wait_download_from_s3) + has_s3_download = true; + + duration_load_index += watch.elapsedSeconds(); + } + else + { + // Not disaggregated mode + local_index_file_path = index_file_path; + } + + auto load_from_file = [&]() { + has_load_from_file = true; + return VectorIndexViewer::view(*column_stat.vector_index, local_index_file_path); + }; + + Stopwatch watch; + if (vec_index_cache) + // Note: must use local_index_file_path as the cache key, because cache + // will check whether file is still valid and try to remove memory references + // when file is dropped. + vec_index = vec_index_cache->getOrSet(local_index_file_path, load_from_file); + else + vec_index = load_from_file(); + + duration_load_index += watch.elapsedSeconds(); + RUNTIME_CHECK(vec_index != nullptr); + + scan_context->total_vector_idx_load_time_ms += static_cast(duration_load_index * 1000); + if (has_s3_download) + // it could be possible that s3=true but load_from_file=false, it means we download a file + // and then reuse the memory cache. The majority time comes from s3 download + // so we still count it as s3 download. + scan_context->total_vector_idx_load_from_s3++; + else if (has_load_from_file) + scan_context->total_vector_idx_load_from_disk++; + else + scan_context->total_vector_idx_load_from_cache++; + + LOG_DEBUG( // + log, + "Loaded vector index for column dmf_{}/{}(id={}), index_size={} kind={} cost={:.2f}s {} {}", + dmfile->fileId(), + vec_cd.name, + vec_cd.id, + column_stat.index_bytes, + column_stat.vector_index->index_kind(), + duration_load_index, + has_s3_download ? "(S3)" : "", + has_load_from_file ? "(LoadFile)" : ""); +} + +void DMFileWithVectorIndexBlockInputStream::loadVectorSearchResult() +{ + Stopwatch watch; + + auto perf_begin = PerfContext::vector_search; + + auto results_rowid = vec_index->search(ann_query_info, valid_rows); + + auto discarded_nodes = PerfContext::vector_search.discarded_nodes - perf_begin.discarded_nodes; + auto visited_nodes = PerfContext::vector_search.visited_nodes - perf_begin.visited_nodes; + + double search_duration = watch.elapsedSeconds(); + scan_context->total_vector_idx_search_time_ms += static_cast(search_duration * 1000); + scan_context->total_vector_idx_search_discarded_nodes += discarded_nodes; + scan_context->total_vector_idx_search_visited_nodes += visited_nodes; + + size_t rows_after_mvcc = valid_rows.count(); + size_t rows_after_vector_search = results_rowid.size(); + + // After searching with the BitmapFilter, we create a bitmap + // to exclude rows that are not in the search result, because these rows + // are produced as [] or NULL, which is not a valid vector for future use. + // The bitmap will be used when returning the output to the caller. + { + valid_rows_after_search = BitmapFilter(valid_rows.size(), false); + for (auto rowid : results_rowid) + valid_rows_after_search.set(rowid, 1, true); + valid_rows_after_search.runOptimize(); + } + + vec_column_reader = std::make_shared( // + dmfile, + vec_index, + std::move(results_rowid)); + + // Vector index is very likely to filter out some packs. For example, + // if we query for Top 1, then only 1 pack will be remained. So we + // update the pack filter used by the DMFileReader to avoid reading + // unnecessary data for other columns. + size_t valid_packs_before_search = 0; + size_t valid_packs_after_search = 0; + const auto & pack_stats = dmfile->getPackStats(); + auto & use_packs = reader.pack_filter.getUsePacks(); + + size_t results_it = 0; + const size_t results_it_max = results_rowid.size(); + + UInt32 pack_start = 0; + + for (size_t pack_id = 0, pack_id_max = dmfile->getPacks(); pack_id < pack_id_max; pack_id++) + { + if (use_packs[pack_id]) + valid_packs_before_search++; + + bool pack_has_result = false; + + UInt32 pack_end = pack_start + pack_stats[pack_id].rows; + while (results_it < results_it_max // + && results_rowid[results_it] >= pack_start // + && results_rowid[results_it] < pack_end) + { + pack_has_result = true; + results_it++; + } + + if (!pack_has_result) + use_packs[pack_id] = 0; + + if (use_packs[pack_id]) + valid_packs_after_search++; + + pack_start = pack_end; + } + + RUNTIME_CHECK_MSG(results_it == results_it_max, "All packs has been visited but not all results are consumed"); + + LOG_DEBUG( // + log, + "Finished vector search over column dmf_{}/{}(id={}), cost={:.2f}s " + "top_k_[query/visited/discarded/result]={}/{}/{}/{} " + "rows_[file/after_mvcc/after_search]={}/{}/{} " + "pack_[total/before_search/after_search]={}/{}/{}", + + dmfile->fileId(), + vec_cd.name, + vec_cd.id, + search_duration, + + ann_query_info->top_k(), + visited_nodes, // Visited nodes will be larger than query_top_k when there are MVCC rows + discarded_nodes, // How many nodes are skipped by MVCC + results_rowid.size(), + + dmfile->getRows(), + rows_after_mvcc, + rows_after_vector_search, + + pack_stats.size(), + valid_packs_before_search, + valid_packs_after_search); +} + +UInt32 DMFileWithVectorIndexBlockInputStream::getPackIdFromBlock(const Block & block) +{ + // The start offset of a block is ensured to be aligned with the pack. + // This is how we know which pack the block comes from. + auto start_offset = block.startOffset(); + auto it = start_offset_to_pack_id.find(start_offset); + RUNTIME_CHECK(it != start_offset_to_pack_id.end()); + return it->second; +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h index 20aa03110d7..b992f88ef8a 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFileWithVectorIndexBlockInputStream.h @@ -14,10 +14,12 @@ #pragma once -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include #include namespace DB::DM @@ -88,72 +90,9 @@ class DMFileWithVectorIndexBlockInputStream : public SkippableBlockInputStream const ScanContextPtr & scan_context_, const VectorIndexCachePtr & vec_index_cache_, const BitmapFilterView & valid_rows_, - const String & tracing_id) - : log(Logger::get(tracing_id)) - , ann_query_info(ann_query_info_) - , dmfile(dmfile_) - , header_layout(std::move(header_layout_)) - , reader(std::move(reader_)) - , vec_cd(std::move(vec_cd_)) - , file_provider(file_provider_) - , read_limiter(read_limiter_) - , scan_context(scan_context_) - , vec_index_cache(vec_index_cache_) - , valid_rows(valid_rows_) - { - RUNTIME_CHECK(ann_query_info); - RUNTIME_CHECK(vec_cd.id == ann_query_info->column_id()); - for (const auto & cd : reader.read_columns) - { - RUNTIME_CHECK(header_layout.has(cd.name), cd.name); - RUNTIME_CHECK(cd.id != vec_cd.id); - } - RUNTIME_CHECK(header_layout.has(vec_cd.name)); - RUNTIME_CHECK(header_layout.columns() == reader.read_columns.size() + 1); - - // Fill start_offset_to_pack_id - const auto & pack_stats = dmfile->getPackStats(); - start_offset_to_pack_id.reserve(pack_stats.size()); - UInt32 start_offset = 0; - for (size_t pack_id = 0, pack_id_max = pack_stats.size(); pack_id < pack_id_max; pack_id++) - { - start_offset_to_pack_id[start_offset] = pack_id; - start_offset += pack_stats[pack_id].rows; - } - - // Fill header - header = toEmptyBlock(reader.read_columns); - addColumnToBlock( - header, - vec_cd.id, - vec_cd.name, - vec_cd.type, - vec_cd.type->createColumn(), - vec_cd.default_value); - } - - ~DMFileWithVectorIndexBlockInputStream() override - { - if (!vec_column_reader) - return; - - scan_context->total_vector_idx_read_vec_time_ms - += static_cast(duration_read_from_vec_index_seconds * 1000); - scan_context->total_vector_idx_read_others_time_ms - += static_cast(duration_read_from_other_columns_seconds * 1000); + const String & tracing_id); - LOG_DEBUG( // - log, - "Finished read DMFile with vector index for column dmf_{}/{}(id={}), " - "query_top_k={} load_index+result={:.2f}s read_from_index={:.2f}s read_from_others={:.2f}s", - dmfile->fileId(), - vec_cd.name, - vec_cd.id, - ann_query_info->top_k(), - duration_load_vec_index_and_results_seconds, - duration_read_from_vec_index_seconds, - duration_read_from_other_columns_seconds); - } + ~DMFileWithVectorIndexBlockInputStream() override; public: Block read() override @@ -165,48 +104,12 @@ class DMFileWithVectorIndexBlockInputStream : public SkippableBlockInputStream // When all rows in block are not filtered out, // `res_filter` will be set to null. // The caller needs to do handle this situation. - Block read(FilterPtr & res_filter, bool return_filter) override - { - if (return_filter) - return readImpl(res_filter); - - // If return_filter == false, we must filter by ourselves. - - FilterPtr filter = nullptr; - auto res = readImpl(filter); - if (filter != nullptr) - { - for (auto & col : res) - col.column = col.column->filter(*filter, -1); - } - // filter == nullptr means all rows are valid and no need to filter. - - return res; - } + Block read(FilterPtr & res_filter, bool return_filter) override; // When all rows in block are not filtered out, // `res_filter` will be set to null. // The caller needs to do handle this situation. - Block readImpl(FilterPtr & res_filter) - { - load(); - - Block res; - if (!reader.read_columns.empty()) - res = readByFollowingOtherColumns(); - else - res = readByIndexReader(); - - // Filter the output rows. If no rows need to filter, res_filter is nullptr. - filter.resize(res.rows()); - bool all_match = valid_rows_after_search.get(filter, res.startOffset(), res.rows()); - - if unlikely (all_match) - res_filter = nullptr; - else - res_filter = &filter; - return res; - } + Block readImpl(FilterPtr & res_filter); bool getSkippedRows(size_t &) override { @@ -237,320 +140,20 @@ class DMFileWithVectorIndexBlockInputStream : public SkippableBlockInputStream // Read data totally from the VectorColumnFromIndexReader. This is used // when there is no other column to read. - Block readByIndexReader() - { - const auto & pack_stats = dmfile->getPackStats(); - size_t all_packs = pack_stats.size(); - const auto & use_packs = reader.pack_filter.getUsePacksConst(); - - RUNTIME_CHECK(use_packs.size() == all_packs); - - // Skip as many packs as possible according to Pack Filter - while (index_reader_next_pack_id < all_packs) - { - if (use_packs[index_reader_next_pack_id]) - break; - index_reader_next_row_id += pack_stats[index_reader_next_pack_id].rows; - index_reader_next_pack_id++; - } - - if (index_reader_next_pack_id >= all_packs) - // Finished - return {}; - - auto read_pack_id = index_reader_next_pack_id; - auto block_start_row_id = index_reader_next_row_id; - auto read_rows = pack_stats[read_pack_id].rows; - - index_reader_next_pack_id++; - index_reader_next_row_id += read_rows; - - Block block; - block.setStartOffset(block_start_row_id); - - auto vec_column = vec_cd.type->createColumn(); - - Stopwatch w; - vec_column_reader->read(vec_column, read_pack_id, read_rows); - duration_read_from_vec_index_seconds += w.elapsedSeconds(); - - block.insert(ColumnWithTypeAndName{// - std::move(vec_column), - vec_cd.type, - vec_cd.name, - vec_cd.id}); - - return block; - } + Block readByIndexReader(); // Read data from other columns first, then read from VectorColumnFromIndexReader. This is used // when there are other columns to read. - Block readByFollowingOtherColumns() - { - // First read other columns. - Stopwatch w; - auto block_others = reader.read(); - duration_read_from_other_columns_seconds += w.elapsedSeconds(); - - if (!block_others) - return {}; - - // Using vec_cd.type to construct a Column directly instead of using - // the type from dmfile, so that we don't need extra transforms - // (e.g. wrap with a Nullable). vec_column_reader is compatible with - // both Nullable and NotNullable. - auto vec_column = vec_cd.type->createColumn(); - - // Then read from vector index for the same pack. - w.restart(); - - vec_column_reader->read(vec_column, getPackIdFromBlock(block_others), block_others.rows()); - duration_read_from_vec_index_seconds += w.elapsedSeconds(); - - // Re-assemble block using the same layout as header_layout. - Block res = header_layout.cloneEmpty(); - // Note: the start offset counts from the beginning of THIS dmfile. It - // is not a global offset. - res.setStartOffset(block_others.startOffset()); - for (const auto & elem : block_others) - { - RUNTIME_CHECK(res.has(elem.name)); - res.getByName(elem.name).column = std::move(elem.column); - } - RUNTIME_CHECK(res.has(vec_cd.name)); - res.getByName(vec_cd.name).column = std::move(vec_column); - - return res; - } + Block readByFollowingOtherColumns(); private: - void load() - { - if (loaded) - return; + void load(); - Stopwatch w; + void loadVectorIndex(); - loadVectorIndex(); - loadVectorSearchResult(); - - duration_load_vec_index_and_results_seconds = w.elapsedSeconds(); - - loaded = true; - } - - void loadVectorIndex() - { - bool is_index_load_from_cache = true; + void loadVectorSearchResult(); - auto col_id = ann_query_info->column_id(); - - RUNTIME_CHECK(dmfile->useMetaV2()); // v3 - - // Check vector index exists on the column - const auto & column_stat = dmfile->getColumnStat(col_id); - RUNTIME_CHECK(column_stat.index_bytes > 0); - - const auto & type = column_stat.type; - RUNTIME_CHECK(VectorIndex::isSupportedType(*type)); - RUNTIME_CHECK(column_stat.vector_index.has_value()); - - const auto file_name_base = DMFile::getFileNameBase(col_id); - auto load_vector_index = [&]() { - is_index_load_from_cache = false; - - auto index_guard = S3::S3RandomAccessFile::setReadFileInfo( - {dmfile->getReadFileSize(col_id, dmfile->colIndexFileName(file_name_base)), scan_context}); - - auto info = dmfile->merged_sub_file_infos.find(dmfile->colIndexFileName(file_name_base)); - if (info == dmfile->merged_sub_file_infos.end()) - { - throw Exception( - fmt::format("Unknown index file {}", dmfile->colIndexPath(file_name_base)), - ErrorCodes::LOGICAL_ERROR); - } - - auto file_path = dmfile->mergedPath(info->second.number); - auto encryp_path = dmfile->encryptionMergedPath(info->second.number); - auto offset = info->second.offset; - auto data_size = info->second.size; - - auto buffer = ReadBufferFromFileProvider( - file_provider, - file_path, - encryp_path, - dmfile->getConfiguration()->getChecksumFrameLength(), - read_limiter); - buffer.seek(offset); - - // TODO: Read from file directly? - String raw_data; - raw_data.resize(data_size); - buffer.read(reinterpret_cast(raw_data.data()), data_size); - - auto buf = createReadBufferFromData( - std::move(raw_data), - dmfile->colDataPath(file_name_base), - dmfile->getConfiguration()->getChecksumFrameLength(), - dmfile->configuration->getChecksumAlgorithm(), - dmfile->configuration->getChecksumFrameLength()); - - auto index_kind = magic_enum::enum_cast(column_stat.vector_index->index_kind()); - RUNTIME_CHECK(index_kind.has_value()); - RUNTIME_CHECK(index_kind.value() != TiDB::VectorIndexKind::INVALID); - - auto index_distance_metric - = magic_enum::enum_cast(column_stat.vector_index->distance_metric()); - RUNTIME_CHECK(index_distance_metric.has_value()); - RUNTIME_CHECK(index_distance_metric.value() != TiDB::DistanceMetric::INVALID); - - auto index = VectorIndex::load(index_kind.value(), index_distance_metric.value(), *buf); - return index; - }; - - Stopwatch watch; - - if (vec_index_cache) - { - // TODO: Is cache key valid on Compute Node for different Write Nodes? - vec_index = vec_index_cache->getOrSet(dmfile->colIndexCacheKey(file_name_base), load_vector_index); - } - else - { - // try load from the cache first - if (vec_index_cache) - vec_index = vec_index_cache->get(dmfile->colIndexCacheKey(file_name_base)); - if (vec_index == nullptr) - vec_index = load_vector_index(); - } - - double duration_load_index = watch.elapsedSeconds(); - RUNTIME_CHECK(vec_index != nullptr); - scan_context->total_vector_idx_load_time_ms += static_cast(duration_load_index * 1000); - if (is_index_load_from_cache) - scan_context->total_vector_idx_load_from_cache++; - else - scan_context->total_vector_idx_load_from_disk++; - - LOG_DEBUG( // - log, - "Loaded vector index for column dmf_{}/{}(id={}), index_size={} kind={} cost={:.2f}s from_cache={}", - dmfile->fileId(), - vec_cd.name, - vec_cd.id, - column_stat.index_bytes, - column_stat.vector_index->index_kind(), - duration_load_index, - is_index_load_from_cache); - } - - void loadVectorSearchResult() - { - Stopwatch watch; - - VectorIndex::SearchStatistics statistics; - auto results_rowid = vec_index->search(ann_query_info, valid_rows, statistics); - - double search_duration = watch.elapsedSeconds(); - scan_context->total_vector_idx_search_time_ms += static_cast(search_duration * 1000); - scan_context->total_vector_idx_search_discarded_nodes += statistics.discarded_nodes; - scan_context->total_vector_idx_search_visited_nodes += statistics.visited_nodes; - - size_t rows_after_mvcc = valid_rows.count(); - size_t rows_after_vector_search = results_rowid.size(); - - // After searching with the BitmapFilter, we create a bitmap - // to exclude rows that are not in the search result, because these rows - // are produced as [] or NULL, which is not a valid vector for future use. - // The bitmap will be used when returning the output to the caller. - { - valid_rows_after_search = BitmapFilter(valid_rows.size(), false); - for (auto rowid : results_rowid) - valid_rows_after_search.set(rowid, 1, true); - valid_rows_after_search.runOptimize(); - } - - vec_column_reader = std::make_shared( // - dmfile, - vec_index, - std::move(results_rowid)); - - // Vector index is very likely to filter out some packs. For example, - // if we query for Top 1, then only 1 pack will be remained. So we - // update the pack filter used by the DMFileReader to avoid reading - // unnecessary data for other columns. - size_t valid_packs_before_search = 0; - size_t valid_packs_after_search = 0; - const auto & pack_stats = dmfile->getPackStats(); - auto & use_packs = reader.pack_filter.getUsePacks(); - - size_t results_it = 0; - const size_t results_it_max = results_rowid.size(); - - UInt32 pack_start = 0; - - for (size_t pack_id = 0, pack_id_max = dmfile->getPacks(); pack_id < pack_id_max; pack_id++) - { - if (use_packs[pack_id]) - valid_packs_before_search++; - - bool pack_has_result = false; - - UInt32 pack_end = pack_start + pack_stats[pack_id].rows; - while (results_it < results_it_max // - && results_rowid[results_it] >= pack_start // - && results_rowid[results_it] < pack_end) - { - pack_has_result = true; - results_it++; - } - - if (!pack_has_result) - use_packs[pack_id] = 0; - - if (use_packs[pack_id]) - valid_packs_after_search++; - - pack_start = pack_end; - } - - RUNTIME_CHECK_MSG(results_it == results_it_max, "All packs has been visited but not all results are consumed"); - - LOG_DEBUG( // - log, - "Finished vector search over column dmf_{}/{}(id={}), cost={:.2f}s " - "top_k_[query/visited/discarded/result]={}/{}/{}/{} " - "rows_[file/after_mvcc/after_search]={}/{}/{} " - "pack_[total/before_search/after_search]={}/{}/{}", - - dmfile->fileId(), - vec_cd.name, - vec_cd.id, - search_duration, - - ann_query_info->top_k(), - statistics.visited_nodes, // Visited nodes will be larger than query_top_k when there are MVCC rows - statistics.discarded_nodes, // How many nodes are skipped by MVCC - results_rowid.size(), - - dmfile->getRows(), - rows_after_mvcc, - rows_after_vector_search, - - pack_stats.size(), - valid_packs_before_search, - valid_packs_after_search); - } - - inline UInt32 getPackIdFromBlock(const Block & block) - { - // The start offset of a block is ensured to be aligned with the pack. - // This is how we know which pack the block comes from. - auto start_offset = block.startOffset(); - auto it = start_offset_to_pack_id.find(start_offset); - RUNTIME_CHECK(it != start_offset_to_pack_id.end()); - return it->second; - } + UInt32 getPackIdFromBlock(const Block & block); private: const LoggerPtr log; @@ -575,7 +178,7 @@ class DMFileWithVectorIndexBlockInputStream : public SkippableBlockInputStream std::unordered_map start_offset_to_pack_id; // Filled from reader in constructor // Set after load(). - VectorIndexPtr vec_index = nullptr; + VectorIndexViewerPtr vec_index = nullptr; // Set after load(). VectorColumnFromIndexReaderPtr vec_column_reader = nullptr; // Set after load(). Used to filter the output rows. diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWriter.cpp b/dbms/src/Storages/DeltaMerge/File/DMFileWriter.cpp index 278e814d88f..82436679897 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileWriter.cpp +++ b/dbms/src/Storages/DeltaMerge/File/DMFileWriter.cpp @@ -18,9 +18,9 @@ #include #include #include +#include #ifndef NDEBUG -#include #include #include #endif @@ -60,7 +60,7 @@ DMFileWriter::DMFileWriter( for (auto & cd : write_columns) { if (cd.vector_index) - RUNTIME_CHECK(VectorIndex::isSupportedType(*cd.type)); + RUNTIME_CHECK(VectorIndexBuilder::isSupportedType(*cd.type)); // TODO: currently we only generate index for Integers, Date, DateTime types, and this should be configurable by user. /// for handle column always generate index @@ -115,7 +115,11 @@ DMFileWriter::WriteBufferFromFileBasePtr DMFileWriter::createPackStatsFile() options.max_compress_block_size); } -void DMFileWriter::addStreams(ColId col_id, DataTypePtr type, bool do_index, TiDB::VectorIndexInfoPtr do_vector_index) +void DMFileWriter::addStreams( + ColId col_id, + DataTypePtr type, + bool do_index, + TiDB::VectorIndexDefinitionPtr do_vector_index) { auto callback = [&](const IDataType::SubstreamPath & substream_path) { const auto stream_name = DMFile::getFileNameBase(col_id, substream_path); @@ -358,34 +362,20 @@ void DMFileWriter::finalizeColumn(ColId col_id, DataTypePtr type) if (stream->vector_index && !is_empty_file) { - dmfile->checkMergedFile(merged_file, file_provider, write_limiter); - - auto fname = dmfile->colIndexFileName(stream_name); + // Vector index files are always not written into the merged file + // because we want to allow to be mmaped by the usearch. - auto buffer = createWriteBufferFromFileBaseByWriterBuffer( - merged_file.buffer, - dmfile->configuration->getChecksumAlgorithm(), - dmfile->configuration->getChecksumFrameLength()); - - stream->vector_index->serializeBinary(*buffer); - - col_stat.index_bytes = buffer->getMaterializedBytes(); + const auto index_name = dmfile->colIndexPath(stream_name); + stream->vector_index->save(index_name); + col_stat.index_bytes = Poco::File(index_name).getSize(); // Memorize what kind of vector index it is, so that we can correctly restore it when reading. - col_stat.vector_index = dtpb::ColumnVectorIndexInfo{}; - col_stat.vector_index->set_index_kind(String(magic_enum::enum_name(stream->vector_index->kind))); + col_stat.vector_index.emplace(); + col_stat.vector_index->set_index_kind( + tipb::VectorIndexKind_Name(stream->vector_index->definition->kind)); col_stat.vector_index->set_distance_metric( - String(magic_enum::enum_name(stream->vector_index->distance_metric))); - - MergedSubFileInfo info{ - fname, - merged_file.file_info.number, - merged_file.file_info.size, - col_stat.index_bytes}; - dmfile->merged_sub_file_infos[fname] = info; - - merged_file.file_info.size += col_stat.index_bytes; - buffer->next(); + tipb::VectorDistanceMetric_Name(stream->vector_index->definition->distance_metric)); + col_stat.vector_index->set_dimensions(stream->vector_index->definition->dimension); } // write mark into merged_file_writer diff --git a/dbms/src/Storages/DeltaMerge/File/DMFileWriter.h b/dbms/src/Storages/DeltaMerge/File/DMFileWriter.h index 960d448d093..540c7838a2e 100644 --- a/dbms/src/Storages/DeltaMerge/File/DMFileWriter.h +++ b/dbms/src/Storages/DeltaMerge/File/DMFileWriter.h @@ -55,7 +55,7 @@ class DMFileWriter FileProviderPtr & file_provider, const WriteLimiterPtr & write_limiter_, bool do_index, - TiDB::VectorIndexInfoPtr do_vector_index) + TiDB::VectorIndexDefinitionPtr do_vector_index) : plain_file(WriteBufferByFileProviderBuilder( dmfile->configuration.has_value(), file_provider, @@ -73,7 +73,7 @@ class DMFileWriter : std::unique_ptr( new CompressedWriteBuffer(*plain_file, compression_settings))) , minmaxes(do_index ? std::make_shared(*type) : nullptr) - , vector_index(do_vector_index ? VectorIndex::create(*do_vector_index) : nullptr) + , vector_index(do_vector_index ? VectorIndexBuilder::create(do_vector_index) : nullptr) { if (!dmfile->useMetaV2()) { @@ -100,7 +100,7 @@ class DMFileWriter WriteBufferPtr compressed_buf; MinMaxIndexPtr minmaxes; - VectorIndexPtr vector_index; + VectorIndexBuilderPtr vector_index; MarksInCompressedFilePtr marks; @@ -162,7 +162,7 @@ class DMFileWriter /// Add streams with specified column id. Since a single column may have more than one Stream, /// for example Nullable column has a NullMap column, we would track them with a mapping /// FileNameBase -> Stream. - void addStreams(ColId col_id, DataTypePtr type, bool do_index, TiDB::VectorIndexInfoPtr do_vector_index); + void addStreams(ColId col_id, DataTypePtr type, bool do_index, TiDB::VectorIndexDefinitionPtr do_vector_index); WriteBufferFromFileBasePtr createMetaFile(); WriteBufferFromFileBasePtr createMetaV2File(); diff --git a/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp index c263643c250..c445312da0e 100644 --- a/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp +++ b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.cpp @@ -14,8 +14,6 @@ #include -#include - namespace DB::DM { @@ -32,7 +30,7 @@ std::vector VectorColumnFromIndexReader::calcPackStartRowID(const DMFile } MutableColumnPtr VectorColumnFromIndexReader::calcResultsByPack( - std::vector && results, + std::vector && results, const DMFile::PackStats & pack_stats, const std::vector & pack_start_rowid) { @@ -110,7 +108,7 @@ void VectorColumnFromIndexReader::read(MutableColumnPtr & column, size_t start_p RUNTIME_CHECK(filled_result_rows == offset_in_pack); // TODO: We could fill multiple rows if rowid is continuous. - VectorIndex::Key rowid = pack_start_rowid[pack_id] + offset_in_pack; + VectorIndexViewer::Key rowid = pack_start_rowid[pack_id] + offset_in_pack; index->get(rowid, value); column->insertData(reinterpret_cast(value.data()), value.size() * sizeof(Float32)); filled_result_rows++; diff --git a/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.h b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.h index 14e807385bd..83b91b4e63e 100644 --- a/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.h +++ b/dbms/src/Storages/DeltaMerge/File/VectorColumnFromIndexReader.h @@ -44,7 +44,7 @@ class VectorColumnFromIndexReader const DMFile::PackStats & pack_stats; const std::vector pack_start_rowid; - const VectorIndexPtr index; + const VectorIndexViewerPtr index; /// results_by_pack[i]=[a,b,c...] means pack[i]'s row offset [a,b,c,...] is contained in the result set. /// The rowid of a is pack_start_rowid[i]+a. MutableColumnPtr /* ColumnArray of UInt32 */ results_by_pack; @@ -53,7 +53,7 @@ class VectorColumnFromIndexReader static std::vector calcPackStartRowID(const DMFile::PackStats & pack_stats); static MutableColumnPtr calcResultsByPack( - std::vector && results, + std::vector && results, const DMFile::PackStats & pack_stats, const std::vector & pack_start_rowid); @@ -62,8 +62,8 @@ class VectorColumnFromIndexReader /// including NULLs and delete marks. explicit VectorColumnFromIndexReader( const DMFilePtr & dmfile_, - const VectorIndexPtr & index_, - std::vector && results_) + const VectorIndexViewerPtr & index_, + std::vector && results_) : dmfile(dmfile_) , pack_stats(dmfile_->getPackStats()) , pack_start_rowid(calcPackStartRowID(pack_stats)) diff --git a/dbms/src/Storages/DeltaMerge/File/dtpb/dmfile.proto b/dbms/src/Storages/DeltaMerge/File/dtpb/dmfile.proto index 77256bfff43..b5e3e5d915a 100644 --- a/dbms/src/Storages/DeltaMerge/File/dtpb/dmfile.proto +++ b/dbms/src/Storages/DeltaMerge/File/dtpb/dmfile.proto @@ -49,13 +49,6 @@ message ChecksumConfig { repeated ChecksumDebugInfo debug_info = 5; } -// Note: This message does not contain all fields of VectorIndexInfo, -// because this message is only used for reading the vector index carried with the column. -message ColumnVectorIndexInfo { - optional string index_kind = 1; - optional string distance_metric = 2; -} - message ColumnStat { optional int64 col_id = 1; optional string type_name = 2; @@ -69,9 +62,22 @@ message ColumnStat { optional uint64 array_sizes_bytes = 10; optional uint64 array_sizes_mark_bytes = 11; - optional ColumnVectorIndexInfo vector_index = 101; + reserved 101; // old VectorIndexFileProps which does not have dimensions, we just treat index as not exist. + optional VectorIndexFileProps vector_index = 102; } message ColumnStats { repeated ColumnStat column_stats = 1; } + +// Note: This message is something different to VectorIndexDefinition. +// VectorIndexDefinition defines an index, comes from table DDL. +// It includes information about how index should be constructed, +// for example, it contains HNSW's 'efConstruct' parameter. +// However, VectorIndexFileProps provides information for read out the index, +// for example, very basic information about what the index is, and how it is stored. +message VectorIndexFileProps { + optional string index_kind = 1; // The value is tipb.VectorIndexKind + optional string distance_metric = 2; // The value is tipb.VectorDistanceMetric + optional uint64 dimensions = 3; +} diff --git a/dbms/src/Storages/DeltaMerge/Index/RSIndex.h b/dbms/src/Storages/DeltaMerge/Index/RSIndex.h index b171e146286..39dc3b7bdb7 100644 --- a/dbms/src/Storages/DeltaMerge/Index/RSIndex.h +++ b/dbms/src/Storages/DeltaMerge/Index/RSIndex.h @@ -15,7 +15,6 @@ #pragma once #include -#include namespace DB { @@ -36,18 +35,12 @@ struct RSIndex DataTypePtr type; MinMaxIndexPtr minmax; EqualIndexPtr equal; - VectorIndexPtr vector; // TODO: Actually this is not a rough index. We put it here for convenience. RSIndex(const DataTypePtr & type_, const MinMaxIndexPtr & minmax_) : type(type_) , minmax(minmax_) {} - RSIndex(const DataTypePtr & type_, const VectorIndexPtr & vector_) - : type(type_) - , vector(vector_) - {} - RSIndex(const DataTypePtr & type_, const MinMaxIndexPtr & minmax_, const EqualIndexPtr & equal_) : type(type_) , minmax(minmax_) diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.cpp index 652f5168d9d..ed53c33d3a5 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.cpp @@ -15,8 +15,10 @@ #include #include #include +#include #include #include +#include namespace DB::ErrorCodes { @@ -26,7 +28,7 @@ extern const int INCORRECT_QUERY; namespace DB::DM { -bool VectorIndex::isSupportedType(const IDataType & type) +bool VectorIndexBuilder::isSupportedType(const IDataType & type) { const auto * nullable = checkAndGetDataType(&type); if (nullable) @@ -35,53 +37,40 @@ bool VectorIndex::isSupportedType(const IDataType & type) return checkDataTypeArray(&type); } -VectorIndexPtr VectorIndex::create(const TiDB::VectorIndexInfo & index_info) +VectorIndexBuilderPtr VectorIndexBuilder::create(const TiDB::VectorIndexDefinitionPtr & definition) { - RUNTIME_CHECK(index_info.dimension > 0); - RUNTIME_CHECK(index_info.dimension <= std::numeric_limits::max()); + RUNTIME_CHECK(definition->dimension > 0); + RUNTIME_CHECK(definition->dimension <= std::numeric_limits::max()); - switch (index_info.kind) + switch (definition->kind) { - case TiDB::VectorIndexKind::HNSW: - switch (index_info.distance_metric) - { - case TiDB::DistanceMetric::L2: - return std::make_shared>(index_info.dimension); - case TiDB::DistanceMetric::COSINE: - return std::make_shared>(index_info.dimension); - default: - throw Exception( - ErrorCodes::INCORRECT_QUERY, - "Unsupported vector index distance metric {}", - index_info.distance_metric); - } + case tipb::VectorIndexKind::HNSW: + return std::make_shared(definition); default: - throw Exception(ErrorCodes::INCORRECT_QUERY, "Unsupported vector index {}", index_info.kind); + throw Exception( // + ErrorCodes::INCORRECT_QUERY, + "Unsupported vector index {}", + tipb::VectorIndexKind_Name(definition->kind)); } } -VectorIndexPtr VectorIndex::load(TiDB::VectorIndexKind kind, TiDB::DistanceMetric distance_metric, ReadBuffer & istr) +VectorIndexViewerPtr VectorIndexViewer::view(const dtpb::VectorIndexFileProps & file_props, std::string_view path) { - RUNTIME_CHECK(kind != TiDB::VectorIndexKind::INVALID); - RUNTIME_CHECK(distance_metric != TiDB::DistanceMetric::INVALID); + RUNTIME_CHECK(file_props.dimensions() > 0); + RUNTIME_CHECK(file_props.dimensions() <= std::numeric_limits::max()); + + tipb::VectorIndexKind kind; + RUNTIME_CHECK(tipb::VectorIndexKind_Parse(file_props.index_kind(), &kind)); switch (kind) { - case TiDB::VectorIndexKind::HNSW: - switch (distance_metric) - { - case TiDB::DistanceMetric::L2: - return VectorIndexHNSW::deserializeBinary(istr); - case TiDB::DistanceMetric::COSINE: - return VectorIndexHNSW::deserializeBinary(istr); - default: - throw Exception( - ErrorCodes::INCORRECT_QUERY, - "Unsupported vector index distance metric {}", - distance_metric); - } + case tipb::VectorIndexKind::HNSW: + return VectorIndexHNSWViewer::view(file_props, path); default: - throw Exception(ErrorCodes::INCORRECT_QUERY, "Unsupported vector index {}", kind); + throw Exception( // + ErrorCodes::INCORRECT_QUERY, + "Unsupported vector index {}", + file_props.index_kind()); } } diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h index d249d7a91e7..1d61d06a604 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex.h @@ -16,91 +16,75 @@ #include #include -#include #include #include #include #include +#include #include #include - -namespace DB -{ -namespace DM +namespace DB::DM { -class VectorIndex +/// Builds a VectorIndex in memory. +class VectorIndexBuilder { public: /// The key is the row's offset in the DMFile. using Key = UInt32; - /// True bit means the row is valid and should be kept in the search result. - /// False bit lets the row filtered out and will search for more results. - using RowFilter = BitmapFilterView; - - struct SearchStatistics - { - size_t visited_nodes = 0; - size_t discarded_nodes = 0; // Rows filtered out by MVCC - }; +public: + static VectorIndexBuilderPtr create(const TiDB::VectorIndexDefinitionPtr & definition); static bool isSupportedType(const IDataType & type); - static VectorIndexPtr create(const TiDB::VectorIndexInfo & index_info); - - static VectorIndexPtr load(TiDB::VectorIndexKind kind, TiDB::DistanceMetric distance_metric, ReadBuffer & istr); - - VectorIndex(TiDB::VectorIndexKind kind_, TiDB::DistanceMetric distance_metric_) - : kind(kind_) - , distance_metric(distance_metric_) +public: + explicit VectorIndexBuilder(const TiDB::VectorIndexDefinitionPtr & definition_) + : definition(definition_) {} - virtual ~VectorIndex() = default; + virtual ~VectorIndexBuilder() = default; virtual void addBlock(const IColumn & column, const ColumnVector * del_mark) = 0; - virtual void serializeBinary(WriteBuffer & ostr) const = 0; - - virtual size_t memoryUsage() const = 0; - - virtual std::vector search( // - const ANNQueryInfoPtr & queryInfo, - const RowFilter & valid_rows, - SearchStatistics & statistics) const - = 0; - - // Get the value (i.e. vector content) of a Key. - virtual void get(Key key, std::vector & out) const = 0; + virtual void save(std::string_view path) const = 0; public: - const TiDB::VectorIndexKind kind; - const TiDB::DistanceMetric distance_metric; + const TiDB::VectorIndexDefinitionPtr definition; }; -struct VectorIndexWeightFunction +/// Views a VectorIndex file. +/// It may nor may not read the whole content of the file into memory. +class VectorIndexViewer { - size_t operator()(const String &, const VectorIndex & index) const { return index.memoryUsage(); } -}; +public: + /// The key is the row's offset in the DMFile. + using Key = VectorIndexBuilder::Key; -class VectorIndexCache : public LRUCache, VectorIndexWeightFunction> -{ -private: - using Base = LRUCache, VectorIndexWeightFunction>; + /// True bit means the row is valid and should be kept in the search result. + /// False bit lets the row filtered out and will search for more results. + using RowFilter = BitmapFilterView; + +public: + static VectorIndexViewerPtr view(const dtpb::VectorIndexFileProps & file_props, std::string_view path); public: - explicit VectorIndexCache(size_t max_size_in_bytes) - : Base(max_size_in_bytes) + explicit VectorIndexViewer(const dtpb::VectorIndexFileProps & file_props_) + : file_props(file_props_) {} - template - MappedPtr getOrSet(const Key & key, LoadFunc && load) - { - auto result = Base::getOrSet(key, load); - return result.first; - } -}; + virtual ~VectorIndexViewer() = default; -} // namespace DM + virtual std::vector search( // + const ANNQueryInfoPtr & queryInfo, + const RowFilter & valid_rows) const + = 0; + + // Get the value (i.e. vector content) of a Key. + virtual void get(Key key, std::vector & out) const = 0; + +public: + const dtpb::VectorIndexFileProps file_props; +}; -} // namespace DB +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexCache.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndexCache.cpp new file mode 100644 index 00000000000..55350df8642 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexCache.cpp @@ -0,0 +1,100 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include + +namespace DB::DM +{ + +size_t VectorIndexCache::cleanOutdatedCacheEntries() +{ + size_t cleaned = 0; + + std::unordered_set files; + { + // Copy out the list to avoid occupying lock for too long. + // The complexity is just O(N) which is fine. + std::shared_lock lock(mu); + files = files_to_check; + } + + for (const auto & file_path : files) + { + if (is_shutting_down) + break; + + if (!cache.contains(file_path)) + { + // It is evicted from LRU cache + std::unique_lock lock(mu); + files_to_check.erase(file_path); + } + else if (!Poco::File(file_path).exists()) + { + LOG_INFO(log, "Dropping in-memory Vector Index cache because on-disk file is dropped, file={}", file_path); + { + std::unique_lock lock(mu); + files_to_check.erase(file_path); + } + cache.remove(file_path); + cleaned++; + } + } + + LOG_DEBUG(log, "Cleaned {} outdated Vector Index cache entries", cleaned); + + return cleaned; +} + +void VectorIndexCache::cleanOutdatedLoop() +{ + while (true) + { + { + std::unique_lock lock(shutdown_mu); + shutdown_cv.wait_for(lock, std::chrono::minutes(1), [this] { return is_shutting_down.load(); }); + } + + if (is_shutting_down) + break; + + try + { + cleanOutdatedCacheEntries(); + } + catch (...) + { + tryLogCurrentException(__PRETTY_FUNCTION__); + } + } +} + +VectorIndexCache::VectorIndexCache(size_t max_entities) + : cache(max_entities) + , log(Logger::get()) +{ + cleaner_thread = std::thread([this] { cleanOutdatedLoop(); }); +} + +VectorIndexCache::~VectorIndexCache() +{ + is_shutting_down = true; + shutdown_cv.notify_all(); + cleaner_thread.join(); +} + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexCache.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndexCache.h new file mode 100644 index 00000000000..2cf51b73812 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexCache.h @@ -0,0 +1,79 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace DB::DM +{ + +class VectorIndexCache +{ +private: + using Cache = LRUCache; + + Cache cache; + LoggerPtr log; + + // Note: Key exists if cache does internal eviction. However it is fine, because + // we will remove them periodically. + std::unordered_set files_to_check; + std::shared_mutex mu; + + std::atomic is_shutting_down = false; + std::condition_variable shutdown_cv; + std::mutex shutdown_mu; + +#ifdef DBMS_PUBLIC_GTEST +public: +#else +private: +#endif + + // Drop the in-memory Vector Index if the on-disk file is deleted. + // mmaped file could be unmmaped so that disk space can be reclaimed. + size_t cleanOutdatedCacheEntries(); + + void cleanOutdatedLoop(); + + std::thread cleaner_thread; + +public: + explicit VectorIndexCache(size_t max_entities); + + ~VectorIndexCache(); + + template + Cache::MappedPtr getOrSet(const Cache::Key & file_path, LoadFunc && load) + { + { + std::scoped_lock lock(mu); + files_to_check.insert(file_path); + } + + auto result = cache.getOrSet(file_path, load); + return result.first; + } +}; + +} // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp index a58539def4b..9c051e9934f 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp @@ -16,9 +16,14 @@ #include #include #include +#include #include +#include +#include #include +#include +#include namespace DB::ErrorCodes { @@ -30,69 +35,36 @@ extern const int CANNOT_ALLOCATE_MEMORY; namespace DB::DM { -template -USearchIndexWithSerialization::USearchIndexWithSerialization(size_t dimensions) - : Base(Base::make(unum::usearch::metric_punned_t(dimensions, Metric))) -{} - -template -void USearchIndexWithSerialization::serialize(WriteBuffer & ostr) const -{ - auto callback = [&ostr](void * from, size_t n) { - ostr.write(reinterpret_cast(from), n); - return true; - }; - Base::save_to_stream(callback); -} - -template -void USearchIndexWithSerialization::deserialize(ReadBuffer & istr) -{ - auto callback = [&istr](void * from, size_t n) { - istr.readStrict(reinterpret_cast(from), n); - return true; - }; - Base::load_from_stream(callback); -} - -template class USearchIndexWithSerialization; -template class USearchIndexWithSerialization; - -constexpr TiDB::DistanceMetric toTiDBDistanceMetric(unum::usearch::metric_kind_t metric) +unum::usearch::metric_kind_t getUSearchMetricKind(tipb::VectorDistanceMetric d) { - switch (metric) + switch (d) { - case unum::usearch::metric_kind_t::l2sq_k: - return TiDB::DistanceMetric::L2; - case unum::usearch::metric_kind_t::cos_k: - return TiDB::DistanceMetric::COSINE; + case tipb::VectorDistanceMetric::INNER_PRODUCT: + return unum::usearch::metric_kind_t::ip_k; + case tipb::VectorDistanceMetric::COSINE: + return unum::usearch::metric_kind_t::cos_k; + case tipb::VectorDistanceMetric::L2: + return unum::usearch::metric_kind_t::l2sq_k; default: - return TiDB::DistanceMetric::INVALID; + // Specifically, L1 is currently unsupported by usearch. + + RUNTIME_CHECK_MSG( // + false, + "Unsupported vector distance {}", + tipb::VectorDistanceMetric_Name(d)); } } -constexpr tipb::VectorDistanceMetric toTiDBQueryDistanceMetric(unum::usearch::metric_kind_t metric) +VectorIndexHNSWBuilder::VectorIndexHNSWBuilder(const TiDB::VectorIndexDefinitionPtr & definition_) + : VectorIndexBuilder(definition_) + , index(USearchImplType::make(unum::usearch::metric_punned_t( // + definition_->dimension, + getUSearchMetricKind(definition->distance_metric)))) { - switch (metric) - { - case unum::usearch::metric_kind_t::l2sq_k: - return tipb::VectorDistanceMetric::L2; - case unum::usearch::metric_kind_t::cos_k: - return tipb::VectorDistanceMetric::COSINE; - default: - return tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC; - } + RUNTIME_CHECK(definition_->kind == tipb::VectorIndexKind::HNSW); } -template -VectorIndexHNSW::VectorIndexHNSW(UInt32 dimensions_) - : VectorIndex(TiDB::VectorIndexKind::HNSW, toTiDBDistanceMetric(Metric)) - , dimensions(dimensions_) - , index(std::make_shared>(static_cast(dimensions_))) -{} - -template -void VectorIndexHNSW::addBlock(const IColumn & column, const ColumnVector * del_mark) +void VectorIndexHNSWBuilder::addBlock(const IColumn & column, const ColumnVector * del_mark) { // Note: column may be nullable. const ColumnArray * col_array; @@ -106,7 +78,7 @@ void VectorIndexHNSW::addBlock(const IColumn & column, const ColumnVecto const auto * del_mark_data = (!del_mark) ? nullptr : &(del_mark->getData()); - if (!index->reserve(unum::usearch::ceil2(index->size() + column.size()))) + if (!index.reserve(unum::usearch::ceil2(index.size() + column.size()))) { throw Exception(ErrorCodes::CANNOT_ALLOCATE_MEMORY, "Could not reserve memory for HNSW index"); } @@ -125,86 +97,84 @@ void VectorIndexHNSW::addBlock(const IColumn & column, const ColumnVecto continue; // Expect all data to have matching dimensions. - RUNTIME_CHECK(col_array->sizeAt(i) == dimensions); + RUNTIME_CHECK(col_array->sizeAt(i) == definition->dimension); auto data = col_array->getDataAt(i); - RUNTIME_CHECK(data.size == dimensions * sizeof(Float32)); + RUNTIME_CHECK(data.size == definition->dimension * sizeof(Float32)); - if (auto rc = index->add(row_offset, reinterpret_cast(data.data)); !rc) + if (auto rc = index.add(row_offset, reinterpret_cast(data.data)); !rc) throw Exception(ErrorCodes::INCORRECT_DATA, rc.error.release()); } } -template -void VectorIndexHNSW::serializeBinary(WriteBuffer & ostr) const +void VectorIndexHNSWBuilder::save(std::string_view path) const { - writeStringBinary(magic_enum::enum_name(kind), ostr); - writeStringBinary(magic_enum::enum_name(distance_metric), ostr); - writeIntBinary(dimensions, ostr); - index->serialize(ostr); + auto result = index.save(unum::usearch::output_file_t(path.data())); + RUNTIME_CHECK_MSG(result, "Failed to save vector index: {}", result.error.what()); } -template -VectorIndexPtr VectorIndexHNSW::deserializeBinary(ReadBuffer & istr) +VectorIndexViewerPtr VectorIndexHNSWViewer::view(const dtpb::VectorIndexFileProps & file_props, std::string_view path) { - String kind; - readStringBinary(kind, istr); - RUNTIME_CHECK(magic_enum::enum_cast(kind) == TiDB::VectorIndexKind::HNSW); + RUNTIME_CHECK(file_props.index_kind() == tipb::VectorIndexKind_Name(tipb::VectorIndexKind::HNSW)); - String distance_metric; - readStringBinary(distance_metric, istr); - RUNTIME_CHECK(magic_enum::enum_cast(distance_metric) == toTiDBDistanceMetric(Metric)); + tipb::VectorDistanceMetric metric; + RUNTIME_CHECK(tipb::VectorDistanceMetric_Parse(file_props.distance_metric(), &metric)); + RUNTIME_CHECK(metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC); - UInt32 dimensions; - readIntBinary(dimensions, istr); + auto vi = std::make_shared(file_props); + vi->index = USearchImplType::make(unum::usearch::metric_punned_t( // + file_props.dimensions(), + getUSearchMetricKind(metric))); + auto result = vi->index.view(unum::usearch::memory_mapped_file_t(path.data())); + RUNTIME_CHECK_MSG(result, "Failed to load vector index: {}", result.error.what()); - auto vi = std::make_shared>(dimensions); - vi->index->deserialize(istr); return vi; } -template -std::vector VectorIndexHNSW::search( +std::vector VectorIndexHNSWViewer::search( const ANNQueryInfoPtr & queryInfo, - const RowFilter & valid_rows, - SearchStatistics & statistics) const + const RowFilter & valid_rows) const { RUNTIME_CHECK(queryInfo->ref_vec_f32().size() >= sizeof(UInt32)); auto query_vec_size = readLittleEndian(queryInfo->ref_vec_f32().data()); - if (query_vec_size != dimensions) + if (query_vec_size != file_props.dimensions()) throw Exception( ErrorCodes::INCORRECT_QUERY, "Query vector size {} does not match index dimensions {}", query_vec_size, - dimensions); + file_props.dimensions()); RUNTIME_CHECK(queryInfo->ref_vec_f32().size() >= sizeof(UInt32) + query_vec_size * sizeof(Float32)); - if (queryInfo->distance_metric() != toTiDBQueryDistanceMetric(Metric)) + if (tipb::VectorDistanceMetric_Name(queryInfo->distance_metric()) != file_props.distance_metric()) throw Exception( ErrorCodes::INCORRECT_QUERY, "Query distance metric {} does not match index distance metric {}", tipb::VectorDistanceMetric_Name(queryInfo->distance_metric()), - tipb::VectorDistanceMetric_Name(toTiDBQueryDistanceMetric(Metric))); + file_props.distance_metric()); - RUNTIME_CHECK(index != nullptr); + std::atomic visited_nodes = 0; + std::atomic discarded_nodes = 0; - auto predicate - = [&valid_rows, &statistics](typename USearchIndexWithSerialization::member_cref_t const & member) { - statistics.visited_nodes++; - if (!valid_rows[member.key]) - statistics.discarded_nodes++; - return valid_rows[member.key]; - }; + auto predicate = [&](typename USearchImplType::member_cref_t const & member) { + // Note: We don't increase the thread_local perf, to be compatible with future multi-thread change. + visited_nodes++; + if (!valid_rows[member.key]) + discarded_nodes++; + return valid_rows[member.key]; + }; // TODO: Support efSearch. - auto result = index->search( // + auto result = index.search( // reinterpret_cast(queryInfo->ref_vec_f32().data() + sizeof(UInt32)), queryInfo->top_k(), predicate); std::vector keys(result.size()); result.dump_to(keys.data()); + PerfContext::vector_search.visited_nodes += visited_nodes; + PerfContext::vector_search.discarded_nodes += discarded_nodes; + // For some reason usearch does not always do the predicate for all search results. // So we need to filter again. keys.erase( @@ -214,14 +184,10 @@ std::vector VectorIndexHNSW::search( return keys; } -template -void VectorIndexHNSW::get(Key key, std::vector & out) const +void VectorIndexHNSWViewer::get(Key key, std::vector & out) const { - out.resize(dimensions); - index->get(key, out.data()); + out.resize(file_props.dimensions()); + index.get(key, out.data()); } -template class VectorIndexHNSW; -template class VectorIndexHNSW; - } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h index a548d1b8e42..9a5804f9970 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.h @@ -14,6 +14,7 @@ #pragma once +#include #include #if __clang__ @@ -26,48 +27,40 @@ namespace DB::DM { -using USearchImplType - = unum::usearch::index_dense_gt; +using USearchImplType = unum::usearch:: + index_dense_gt; -template -class USearchIndexWithSerialization : public USearchImplType +class VectorIndexHNSWBuilder : public VectorIndexBuilder { - using Base = USearchImplType; - public: - explicit USearchIndexWithSerialization(size_t dimensions); - void serialize(WriteBuffer & ostr) const; - void deserialize(ReadBuffer & istr); -}; + explicit VectorIndexHNSWBuilder(const TiDB::VectorIndexDefinitionPtr & definition_); -template -using USearchIndexWithSerializationPtr = std::shared_ptr>; + void addBlock(const IColumn & column, const ColumnVector * del_mark) override; -template -class VectorIndexHNSW : public VectorIndex -{ -public: - explicit VectorIndexHNSW(UInt32 dimensions_); + void save(std::string_view path) const override; - void addBlock(const IColumn & column, const ColumnVector * del_mark) override; +private: + USearchImplType index; + UInt64 added_rows = 0; // Includes nulls and deletes. Used as the index key. +}; - void serializeBinary(WriteBuffer & ostr) const override; - static VectorIndexPtr deserializeBinary(ReadBuffer & istr); +class VectorIndexHNSWViewer : public VectorIndexViewer +{ +public: + static VectorIndexViewerPtr view(const dtpb::VectorIndexFileProps & props, std::string_view path); - size_t memoryUsage() const override { return index->memory_usage(); } + explicit VectorIndexHNSWViewer(const dtpb::VectorIndexFileProps & props) + : VectorIndexViewer(props) + {} std::vector search( // const ANNQueryInfoPtr & queryInfo, - const RowFilter & valid_rows, - SearchStatistics & statistics) const override; + const RowFilter & valid_rows) const override; void get(Key key, std::vector & out) const override; private: - const UInt32 dimensions; - const USearchIndexWithSerializationPtr index; - - UInt64 added_rows = 0; // Includes nulls and deletes. Used as the index key. + USearchImplType index; }; } // namespace DB::DM diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorIndex_fwd.h b/dbms/src/Storages/DeltaMerge/Index/VectorIndex_fwd.h index abf481411f7..131715302e5 100644 --- a/dbms/src/Storages/DeltaMerge/Index/VectorIndex_fwd.h +++ b/dbms/src/Storages/DeltaMerge/Index/VectorIndex_fwd.h @@ -21,8 +21,11 @@ namespace DB::DM using ANNQueryInfoPtr = std::shared_ptr; -class VectorIndex; -using VectorIndexPtr = std::shared_ptr; +class VectorIndexBuilder; +using VectorIndexBuilderPtr = std::shared_ptr; + +class VectorIndexViewer; +using VectorIndexViewerPtr = std::shared_ptr; class VectorIndexCache; using VectorIndexCachePtr = std::shared_ptr; diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorSearchPerf.cpp b/dbms/src/Storages/DeltaMerge/Index/VectorSearchPerf.cpp new file mode 100644 index 00000000000..a7cca6be6a6 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorSearchPerf.cpp @@ -0,0 +1,22 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace DB::PerfContext +{ + +thread_local VectorSearchPerfContext vector_search = {}; + +} diff --git a/dbms/src/Storages/DeltaMerge/Index/VectorSearchPerf.h b/dbms/src/Storages/DeltaMerge/Index/VectorSearchPerf.h new file mode 100644 index 00000000000..6fb3f1a7405 --- /dev/null +++ b/dbms/src/Storages/DeltaMerge/Index/VectorSearchPerf.h @@ -0,0 +1,37 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +/// Remove the population of thread_local from Poco +#ifdef thread_local +#undef thread_local +#endif + +namespace DB::PerfContext +{ + +struct VectorSearchPerfContext +{ + size_t visited_nodes = 0; + size_t discarded_nodes = 0; // Rows filtered out by MVCC + + void reset() { *this = {}; } +}; + +extern thread_local VectorSearchPerfContext vector_search; + +} // namespace DB::PerfContext diff --git a/dbms/src/Storages/DeltaMerge/ScanContext.h b/dbms/src/Storages/DeltaMerge/ScanContext.h index 7659e04cdfc..fe46891b090 100644 --- a/dbms/src/Storages/DeltaMerge/ScanContext.h +++ b/dbms/src/Storages/DeltaMerge/ScanContext.h @@ -25,6 +25,10 @@ #include +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" +#endif namespace DB::DM { @@ -84,6 +88,7 @@ class ScanContext // Building bitmap std::atomic build_bitmap_time_ns{0}; + std::atomic total_vector_idx_load_from_s3{0}; std::atomic total_vector_idx_load_from_disk{0}; std::atomic total_vector_idx_load_from_cache{0}; std::atomic total_vector_idx_load_time_ms{0}; @@ -139,6 +144,7 @@ class ScanContext deserializeRegionNumberOfInstance(tiflash_scan_context_pb); + total_vector_idx_load_from_s3 = tiflash_scan_context_pb.total_vector_idx_load_from_s3(); total_vector_idx_load_from_disk = tiflash_scan_context_pb.total_vector_idx_load_from_disk(); total_vector_idx_load_from_cache = tiflash_scan_context_pb.total_vector_idx_load_from_cache(); total_vector_idx_load_time_ms = tiflash_scan_context_pb.total_vector_idx_load_time_ms(); @@ -189,6 +195,7 @@ class ScanContext serializeRegionNumOfInstance(tiflash_scan_context_pb); + tiflash_scan_context_pb.set_total_vector_idx_load_from_s3(total_vector_idx_load_from_s3); tiflash_scan_context_pb.set_total_vector_idx_load_from_disk(total_vector_idx_load_from_disk); tiflash_scan_context_pb.set_total_vector_idx_load_from_cache(total_vector_idx_load_from_cache); tiflash_scan_context_pb.set_total_vector_idx_load_time_ms(total_vector_idx_load_time_ms); @@ -245,6 +252,7 @@ class ScanContext mergeRegionNumberOfInstance(other); + total_vector_idx_load_from_s3 += other.total_vector_idx_load_from_s3; total_vector_idx_load_from_disk += other.total_vector_idx_load_from_disk; total_vector_idx_load_from_cache += other.total_vector_idx_load_from_cache; total_vector_idx_load_time_ms += other.total_vector_idx_load_time_ms; @@ -268,6 +276,7 @@ class ScanContext create_snapshot_time_ns += other.total_build_snapshot_ms() * 1000000; total_local_region_num += other.local_regions(); total_remote_region_num += other.remote_regions(); + user_read_bytes += other.user_read_bytes(); learner_read_ns += other.total_learner_read_ms() * 1000000; disagg_read_cache_hit_size += other.disagg_read_cache_hit_bytes(); @@ -297,6 +306,7 @@ class ScanContext disagg_read_cache_hit_size += other.disagg_read_cache_hit_bytes(); disagg_read_cache_miss_size += other.disagg_read_cache_miss_bytes(); + total_vector_idx_load_from_s3 += other.total_vector_idx_load_from_s3(); total_vector_idx_load_from_disk += other.total_vector_idx_load_from_disk(); total_vector_idx_load_from_cache += other.total_vector_idx_load_from_cache(); total_vector_idx_load_time_ms += other.total_vector_idx_load_time_ms(); diff --git a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp index ba9aa5fab11..64e811cd63b 100644 --- a/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp +++ b/dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp @@ -12,20 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include +#include #include #include #include +#include +#include +#include +#include #include #include #include +#include #include +#include #include #include #include #include #include +#include + +#include "Storages/S3/FileCachePerf.h" + +namespace CurrentMetrics +{ +extern const Metric DT_SnapshotOfRead; +} // namespace CurrentMetrics + +namespace DB::FailPoints +{ +extern const char force_use_dmfile_format_v3[]; +} // namespace DB::FailPoints + namespace DB::DM::tests { @@ -65,6 +87,12 @@ class VectorIndexTestUtils EncodeVectorFloat32(arr, wb); return wb.str(); } + + ColumnDefine cdVec() + { + // When used in read, no need to assign vector_index. + return ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); + } }; class VectorIndexDMFileTest @@ -172,10 +200,10 @@ try { auto cols = DMTestEnv::getDefaultColumns(DMTestEnv::PkType::HiddenTiDBRowID, /*add_nullable*/ true); auto vec_cd = ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); - vec_cd.vector_index = std::make_shared(TiDB::VectorIndexInfo{ - .kind = TiDB::VectorIndexKind::HNSW, + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + .kind = tipb::VectorIndexKind::HNSW, .dimension = 3, - .distance_metric = TiDB::DistanceMetric::L2, + .distance_metric = tipb::VectorDistanceMetric::L2, }); cols->emplace_back(vec_cd); @@ -433,7 +461,7 @@ try } catch (const DB::Exception & ex) { - ASSERT_STREQ("Query distance metric Cosine does not match index distance metric L2", ex.message().c_str()); + ASSERT_STREQ("Query distance metric COSINE does not match index distance metric L2", ex.message().c_str()); } catch (...) { @@ -475,10 +503,10 @@ try { auto cols = DMTestEnv::getDefaultColumns(DMTestEnv::PkType::HiddenTiDBRowID, /*add_nullable*/ true); auto vec_cd = ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); - vec_cd.vector_index = std::make_shared(TiDB::VectorIndexInfo{ - .kind = TiDB::VectorIndexKind::HNSW, + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + .kind = tipb::VectorIndexKind::HNSW, .dimension = 3, - .distance_metric = TiDB::DistanceMetric::L2, + .distance_metric = tipb::VectorDistanceMetric::L2, }); cols->emplace_back(vec_cd); @@ -615,10 +643,10 @@ try { auto cols = DMTestEnv::getDefaultColumns(DMTestEnv::PkType::HiddenTiDBRowID, /*add_nullable*/ true); auto vec_cd = ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); - vec_cd.vector_index = std::make_shared(TiDB::VectorIndexInfo{ - .kind = TiDB::VectorIndexKind::HNSW, + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + .kind = tipb::VectorIndexKind::HNSW, .dimension = 1, - .distance_metric = TiDB::DistanceMetric::L2, + .distance_metric = tipb::VectorDistanceMetric::L2, }); cols->emplace_back(vec_cd); @@ -773,12 +801,6 @@ class VectorIndexSegmentTestBase ColumnDefine cdPK() { return getExtraHandleColumnDefine(options.is_common_handle); } - ColumnDefine cdVec() - { - // When used in read, no need to assign vector_index. - return ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); - } - protected: Block prepareWriteBlockImpl(Int64 start_key, Int64 end_key, bool is_deleted) override { @@ -790,10 +812,10 @@ class VectorIndexSegmentTestBase void prepareColumns(const ColumnDefinesPtr & columns) override { auto vec_cd = ColumnDefine(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)")); - vec_cd.vector_index = std::make_shared(TiDB::VectorIndexInfo{ - .kind = TiDB::VectorIndexKind::HNSW, + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + .kind = tipb::VectorIndexKind::HNSW, .dimension = 1, - .distance_metric = TiDB::DistanceMetric::L2, + .distance_metric = tipb::VectorDistanceMetric::L2, }); columns->emplace_back(vec_cd); } @@ -1108,4 +1130,686 @@ try } CATCH +class VectorIndexSegmentOnS3Test + : public VectorIndexTestUtils + , public DB::base::TiFlashStorageTestBasic +{ +public: + void SetUp() override + { + FailPointHelper::enableFailPoint(FailPoints::force_use_dmfile_format_v3); + + DB::tests::TiFlashTestEnv::enableS3Config(); + auto s3_client = S3::ClientFactory::instance().sharedTiFlashClient(); + ASSERT_TRUE(::DB::tests::TiFlashTestEnv::createBucketIfNotExist(*s3_client)); + TiFlashStorageTestBasic::SetUp(); + + auto & global_context = TiFlashTestEnv::getGlobalContext(); + + global_context.getSharedContextDisagg()->initRemoteDataStore( + global_context.getFileProvider(), + /*s3_enabled*/ true); + ASSERT_TRUE(global_context.getSharedContextDisagg()->remote_data_store != nullptr); + + orig_mode = global_context.getPageStorageRunMode(); + global_context.setPageStorageRunMode(PageStorageRunMode::UNI_PS); + global_context.tryReleaseWriteNodePageStorageForTest(); + global_context.initializeWriteNodePageStorageIfNeed(global_context.getPathPool()); + + global_context.setVectorIndexCache(1000); + + auto kvstore = db_context->getTMTContext().getKVStore(); + { + auto meta_store = metapb::Store{}; + meta_store.set_id(100); + kvstore->setStore(meta_store); + } + + TiFlashStorageTestBasic::reload(DB::Settings()); + storage_path_pool = std::make_shared(db_context->getPathPool().withTable("test", "t1", false)); + page_id_allocator = std::make_shared(); + storage_pool = std::make_shared( + *db_context, + NullspaceID, + ns_id, + *storage_path_pool, + page_id_allocator, + "test.t1"); + storage_pool->restore(); + + StorageRemoteCacheConfig file_cache_config{ + .dir = fmt::format("{}/fs_cache", getTemporaryPath()), + .capacity = 1 * 1000 * 1000 * 1000, + }; + FileCache::initialize(global_context.getPathCapacity(), file_cache_config); + + auto cols = DMTestEnv::getDefaultColumns(); + auto vec_cd = cdVec(); + vec_cd.vector_index = std::make_shared(TiDB::VectorIndexDefinition{ + .kind = tipb::VectorIndexKind::HNSW, + .dimension = 1, + .distance_metric = tipb::VectorDistanceMetric::L2, + }); + cols->emplace_back(vec_cd); + setColumns(cols); + + auto dm_context = dmContext(); + wn_segment = Segment::newSegment( + Logger::get(), + *dm_context, + table_columns, + RowKeyRange::newAll(false, 1), + DELTA_MERGE_FIRST_SEGMENT_ID, + 0); + ASSERT_EQ(wn_segment->segmentId(), DELTA_MERGE_FIRST_SEGMENT_ID); + } + + void TearDown() override + { + FailPointHelper::disableFailPoint(FailPoints::force_use_dmfile_format_v3); + + FileCache::shutdown(); + + auto & global_context = TiFlashTestEnv::getGlobalContext(); + global_context.dropVectorIndexCache(); + global_context.getSharedContextDisagg()->remote_data_store = nullptr; + global_context.setPageStorageRunMode(orig_mode); + + auto s3_client = S3::ClientFactory::instance().sharedTiFlashClient(); + ::DB::tests::TiFlashTestEnv::deleteBucket(*s3_client); + DB::tests::TiFlashTestEnv::disableS3Config(); + } + + static ColumnDefine cdPK() { return getExtraHandleColumnDefine(false); } + + BlockInputStreamPtr createComputeNodeStream( + const SegmentPtr & write_node_segment, + const ColumnDefines & columns_to_read, + const PushDownFilterPtr & filter, + const ScanContextPtr & read_scan_context = nullptr) + { + auto write_dm_context = dmContext(); + auto snap = write_node_segment->createSnapshot(*write_dm_context, false, CurrentMetrics::DT_SnapshotOfRead); + auto snap_proto = Remote::Serializer::serializeTo( + snap, + write_node_segment->segmentId(), + 0, + write_node_segment->rowkey_range, + {write_node_segment->rowkey_range}, + dummy_mem_tracker); + + auto cn_segment = std::make_shared( + Logger::get(), + /*epoch*/ 0, + write_node_segment->getRowKeyRange(), + write_node_segment->segmentId(), + /*next_segment_id*/ 0, + nullptr, + nullptr); + + auto read_dm_context = dmContext(read_scan_context); + auto cn_segment_snap = Remote::Serializer::deserializeSegmentSnapshotFrom( + *read_dm_context, + /* store_id */ 100, + 0, + /* table_id */ 100, + snap_proto); + + auto stream = cn_segment->getInputStream( + ReadMode::Bitmap, + *read_dm_context, + columns_to_read, + cn_segment_snap, + {write_node_segment->getRowKeyRange()}, + filter, + std::numeric_limits::max(), + DEFAULT_BLOCK_SIZE); + + return stream; + } + + static void removeAllFileCache() + { + auto * file_cache = FileCache::instance(); + auto file_segments = file_cache->getAll(); + for (const auto & file_seg : file_cache->getAll()) + file_cache->remove(file_cache->toS3Key(file_seg->getLocalFileName()), true); + + RUNTIME_CHECK(file_cache->getAll().empty()); + } + + void prepareWriteNodeStable() + { + auto dm_context = dmContext(); + Block block = DMTestEnv::prepareSimpleWriteBlockWithNullable(0, 100); + block.insert(colVecFloat32("[0, 100)", vec_column_name, vec_column_id)); + wn_segment->write(*dm_context, std::move(block), true); + wn_segment = wn_segment->mergeDelta(*dm_context, tableColumns()); + + // Let's just make sure we are later indeed reading from S3 + RUNTIME_CHECK(wn_segment->stable->getDMFiles()[0]->path().rfind("s3://") == 0); + } + + BlockInputStreamPtr computeNodeTableScan() + { + return createComputeNodeStream(wn_segment, {cdPK(), cdVec()}, nullptr); + } + + BlockInputStreamPtr computeNodeANNQuery( + const std::vector ref_vec, + UInt32 top_k = 1, + const ScanContextPtr & read_scan_context = nullptr) + { + auto ann_query_info = std::make_shared(); + ann_query_info->set_column_id(vec_column_id); + ann_query_info->set_distance_metric(tipb::VectorDistanceMetric::L2); + ann_query_info->set_top_k(top_k); + ann_query_info->set_ref_vec_f32(encodeVectorFloat32(ref_vec)); + + auto stream = createComputeNodeStream( + wn_segment, + {cdPK(), cdVec()}, + std::make_shared(wrapWithANNQueryInfo(nullptr, ann_query_info)), + read_scan_context); + return stream; + } + +protected: + // setColumns should update dm_context at the same time + void setColumns(const ColumnDefinesPtr & columns) { table_columns = columns; } + + const ColumnDefinesPtr & tableColumns() const { return table_columns; } + + DMContextPtr dmContext(const ScanContextPtr & scan_context = nullptr) + { + return std::make_unique( + *db_context, + storage_path_pool, + storage_pool, + /*min_version_*/ 0, + NullspaceID, + /*physical_table_id*/ 100, + false, + 1, + db_context->getSettingsRef(), + scan_context); + } + +protected: + /// all these var lives as ref in dm_context + GlobalPageIdAllocatorPtr page_id_allocator; + std::shared_ptr storage_path_pool; + std::shared_ptr storage_pool; + ColumnDefinesPtr table_columns; + DM::DeltaMergeStore::Settings settings; + + NamespaceID ns_id = 100; + + // the segment we are going to test + SegmentPtr wn_segment; + + DB::PageStorageRunMode orig_mode = PageStorageRunMode::ONLY_V3; + + // MemoryTrackerPtr memory_tracker; + MemTrackerWrapper dummy_mem_tracker = MemTrackerWrapper(0, root_of_query_mem_trackers.get()); +}; + +TEST_F(VectorIndexSegmentOnS3Test, FileCacheNotEnabled) +try +{ + prepareWriteNodeStable(); + + FileCache::shutdown(); + auto stream = computeNodeANNQuery({5.0}); + + try + { + stream->readPrefix(); + stream->read(); + FAIL(); + } + catch (const DB::Exception & ex) + { + ASSERT_STREQ("Check file_cache failed: Must enable S3 file cache to use vector index", ex.message().c_str()); + } + catch (...) + { + FAIL(); + } +} +CATCH + +TEST_F(VectorIndexSegmentOnS3Test, ReadWithoutIndex) +try +{ + prepareWriteNodeStable(); + { + auto * file_cache = FileCache::instance(); + ASSERT_EQ(0, file_cache->getAll().size()); + } + { + auto stream = computeNodeTableScan(); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[0, 100)"), + colVecFloat32("[0, 100)"), + })); + } + { + auto * file_cache = FileCache::instance(); + ASSERT_FALSE(file_cache->getAll().empty()); + ASSERT_FALSE(std::filesystem::is_empty(file_cache->cache_dir)); + } +} +CATCH + +TEST_F(VectorIndexSegmentOnS3Test, ReadFromIndex) +try +{ + prepareWriteNodeStable(); + { + auto * file_cache = FileCache::instance(); + ASSERT_EQ(0, file_cache->getAll().size()); + } + { + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + } + { + auto * file_cache = FileCache::instance(); + ASSERT_FALSE(file_cache->getAll().empty()); + ASSERT_FALSE(std::filesystem::is_empty(file_cache->cache_dir)); + } + { + // Read again, we should be reading from memory cache. + + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 1); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 0); + } +} +CATCH + +TEST_F(VectorIndexSegmentOnS3Test, FileCacheEvict) +try +{ + prepareWriteNodeStable(); + { + auto * file_cache = FileCache::instance(); + ASSERT_EQ(0, file_cache->getAll().size()); + } + { + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + } + { + auto * file_cache = FileCache::instance(); + ASSERT_FALSE(file_cache->getAll().empty()); + ASSERT_FALSE(std::filesystem::is_empty(file_cache->cache_dir)); + } + { + // Simulate cache evict. + removeAllFileCache(); + } + { + // Check whether on-disk file is successfully unlinked when there is a memory + // cache. + auto * file_cache = FileCache::instance(); + ASSERT_TRUE(std::filesystem::is_empty(file_cache->cache_dir)); + } + { + // When cache is evicted (but memory cache exists), the query should be fine. + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + } + { + // Read again, we should be reading from memory cache. + + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 1); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 0); + } +} +CATCH + +TEST_F(VectorIndexSegmentOnS3Test, FileCacheEvictAndVectorCacheDrop) +try +{ + prepareWriteNodeStable(); + { + auto * file_cache = FileCache::instance(); + ASSERT_EQ(0, file_cache->getAll().size()); + } + { + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + } + { + auto * file_cache = FileCache::instance(); + ASSERT_FALSE(file_cache->getAll().empty()); + ASSERT_FALSE(std::filesystem::is_empty(file_cache->cache_dir)); + } + { + // Simulate cache evict. + removeAllFileCache(); + } + { + // Check whether on-disk file is successfully unlinked when there is a memory + // cache. + auto * file_cache = FileCache::instance(); + ASSERT_TRUE(std::filesystem::is_empty(file_cache->cache_dir)); + } + { + // We should be able to clear something from the vector index cache. + auto vec_cache = TiFlashTestEnv::getGlobalContext().getVectorIndexCache(); + ASSERT_EQ(1, vec_cache->cleanOutdatedCacheEntries()); + } + { + // When cache is evicted (and memory cache is dropped), the query should be fine. + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + } + { + // Read again, we should be reading from memory cache. + + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 1); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 0); + } +} +CATCH + +TEST_F(VectorIndexSegmentOnS3Test, FileCacheDeleted) +try +{ + prepareWriteNodeStable(); + { + auto * file_cache = FileCache::instance(); + ASSERT_EQ(0, file_cache->getAll().size()); + } + { + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + } + { + auto * file_cache = FileCache::instance(); + ASSERT_FALSE(file_cache->getAll().empty()); + ASSERT_FALSE(std::filesystem::is_empty(file_cache->cache_dir)); + + // Simulate cache file is deleted by user. + std::filesystem::remove_all(file_cache->cache_dir); + } + { + // Query should be fine. + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + } + { + // Read again, we should be reading from memory cache. + + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 1); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 0); + } +} +CATCH + +TEST_F(VectorIndexSegmentOnS3Test, FileCacheDeletedAndVectorCacheDrop) +try +{ + prepareWriteNodeStable(); + { + auto * file_cache = FileCache::instance(); + ASSERT_EQ(0, file_cache->getAll().size()); + } + { + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + } + { + auto * file_cache = FileCache::instance(); + ASSERT_FALSE(file_cache->getAll().empty()); + ASSERT_FALSE(std::filesystem::is_empty(file_cache->cache_dir)); + + // Simulate cache file is deleted by user. + std::filesystem::remove_all(file_cache->cache_dir); + } + { + // We should be able to clear something from the vector index cache. + auto vec_cache = TiFlashTestEnv::getGlobalContext().getVectorIndexCache(); + ASSERT_EQ(1, vec_cache->cleanOutdatedCacheEntries()); + } + { + // Query should be fine. + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + } + { + // Read again, we should be reading from memory cache. + + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 1); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 0); + } +} +CATCH + +TEST_F(VectorIndexSegmentOnS3Test, ConcurrentDownloadFromS3) +try +{ + prepareWriteNodeStable(); + { + auto * file_cache = FileCache::instance(); + ASSERT_EQ(0, file_cache->getAll().size()); + } + + auto sp_s3_fg_download = SyncPointCtl::enableInScope("FileCache::fgDownload"); + auto sp_wait_other_s3 = SyncPointCtl::enableInScope("before_FileSegment::waitForNotEmpty_wait"); + + auto th_1 = std::async([&]() { + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({5.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[5, 6)"), + colVecFloat32("[5, 6)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + + ASSERT_EQ(PerfContext::file_cache.fg_download_from_s3, 1); + ASSERT_EQ(PerfContext::file_cache.fg_wait_download_from_s3, 0); + }); + + // th_1 should be blocked when downloading from s3. + sp_s3_fg_download.waitAndPause(); + + auto th_2 = std::async([&]() { + auto scan_context = std::make_shared(); + auto stream = computeNodeANNQuery({7.0}, 1, scan_context); + ASSERT_INPUTSTREAM_COLS_UR( + stream, + Strings({DMTestEnv::pk_name, vec_column_name}), + createColumns({ + colInt64("[7, 8)"), + colVecFloat32("[7, 8)"), + })); + + ASSERT_EQ(scan_context->total_vector_idx_load_from_cache, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_disk, 0); + ASSERT_EQ(scan_context->total_vector_idx_load_from_s3, 1); + + ASSERT_EQ(PerfContext::file_cache.fg_download_from_s3, 0); + ASSERT_EQ(PerfContext::file_cache.fg_wait_download_from_s3, 1); + }); + + // th_2 should be blocked by waiting th_1 to finish downloading from s3. + sp_wait_other_s3.waitAndNext(); + + // Let th_1 finish downloading from s3. + sp_s3_fg_download.next(); + + // Both th_1 and th_2 should be able to finish without hitting sync points again. + // e.g. th_2 should not ever try to fgDownload. + th_1.get(); + th_2.get(); +} +CATCH + + } // namespace DB::DM::tests diff --git a/dbms/src/Storages/S3/FileCache.cpp b/dbms/src/Storages/S3/FileCache.cpp index 1646217df7e..8712c40fa56 100644 --- a/dbms/src/Storages/S3/FileCache.cpp +++ b/dbms/src/Storages/S3/FileCache.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -25,6 +26,7 @@ #include #include #include +#include #include #include @@ -56,6 +58,47 @@ namespace DB { using FileType = FileSegment::FileType; +FileSegment::Status FileSegment::waitForNotEmpty() +{ + std::unique_lock lock(mtx); + + if (status != Status::Empty) + return status; + + PerfContext::file_cache.fg_wait_download_from_s3++; + + Stopwatch watch; + + while (true) + { + SYNC_FOR("before_FileSegment::waitForNotEmpty_wait"); // just before actual waiting... + + auto is_done = cv_ready.wait_for(lock, std::chrono::seconds(30), [&] { return status != Status::Empty; }); + if (is_done) + break; + + double elapsed_secs = watch.elapsedSeconds(); + LOG_WARNING( + Logger::get(), + "FileCache is still waiting FileSegment ready, file={} elapsed={}s", + local_fname, + elapsed_secs); + + // Snapshot time is 300s + if (elapsed_secs > 300) + { + throw Exception( + ErrorCodes::S3_ERROR, + "Failed to wait until S3 file {} is ready after {}s", + local_fname, + elapsed_secs); + return status; + } + } + + return status; +} + FileCache::FileCache(PathCapacityMetricsPtr capacity_metrics_, const StorageRemoteCacheConfig & config_) : capacity_metrics(capacity_metrics_) , cache_dir(config_.getDTFileCacheDir()) @@ -107,6 +150,24 @@ RandomAccessFilePtr FileCache::getRandomAccessFile( } } +FileSegmentPtr FileCache::downloadFileForLocalRead( + const S3::S3FilenameView & s3_fname, + const std::optional & filesize) +{ + auto file_seg = getOrWait(s3_fname, filesize); + if (!file_seg) + return nullptr; + + auto path = file_seg->getLocalFileName(); + if likely (Poco::File(path).exists()) + return file_seg; + + // Normally, this would not happen. But if someone removes cache files manually, the status of memory and filesystem are inconsistent. + // We can handle this situation by remove it from FileCache. + remove(s3_fname.toFullKey(), /*force*/ true); + return nullptr; +} + FileSegmentPtr FileCache::get(const S3::S3FilenameView & s3_fname, const std::optional & filesize) { auto s3_key = s3_fname.toFullKey(); @@ -165,6 +226,72 @@ FileSegmentPtr FileCache::get(const S3::S3FilenameView & s3_fname, const std::op return nullptr; } +FileSegmentPtr FileCache::getOrWait(const S3::S3FilenameView & s3_fname, const std::optional & filesize) +{ + auto s3_key = s3_fname.toFullKey(); + auto file_type = getFileType(s3_key); + auto & table = tables[static_cast(file_type)]; + + std::unique_lock lock(mtx); + + auto f = table.get(s3_key); + if (f != nullptr) + { + lock.unlock(); + f->setLastAccessTime(std::chrono::system_clock::now()); + auto status = f->waitForNotEmpty(); + if (status == FileSegment::Status::Complete) + { + GET_METRIC(tiflash_storage_remote_cache, type_dtfile_hit).Increment(); + return f; + } + else + { + // On-going download failed, let the caller retry. + return nullptr; + } + } + + GET_METRIC(tiflash_storage_remote_cache, type_dtfile_miss).Increment(); + + auto estimzted_size = filesize ? *filesize : getEstimatedSizeOfFileType(file_type); + if (!reserveSpaceImpl(file_type, estimzted_size, /*try_evict*/ true)) + { + // Space not enough. + GET_METRIC(tiflash_storage_remote_cache, type_dtfile_full).Increment(); + LOG_DEBUG( + log, + "s3_key={} space not enough(capacity={} used={} estimzted_size={}), skip cache", + s3_key, + cache_capacity, + cache_used, + estimzted_size); + + // Just throw, no need to let the caller retry. + throw Exception( // + ErrorCodes::S3_ERROR, + "Cannot reserve {} space for object {}", + estimzted_size, + s3_key); + return nullptr; + } + + auto file_seg + = std::make_shared(toLocalFilename(s3_key), FileSegment::Status::Empty, estimzted_size, file_type); + table.set(s3_key, file_seg); + lock.unlock(); + + PerfContext::file_cache.fg_download_from_s3++; + fgDownload(lock, s3_key, file_seg); + if (!file_seg->isReadyToRead()) + throw Exception( // + ErrorCodes::S3_ERROR, + "Download object {} failed", + s3_key); + + return file_seg; +} + // Remove `local_fname` from disk and remove parent directory if parent directory is empty. void FileCache::removeDiskFile(const String & local_fname, bool update_fsize_metrics) const { @@ -202,20 +329,22 @@ void FileCache::removeDiskFile(const String & local_fname, bool update_fsize_met } } -void FileCache::remove(const String & s3_key, bool force) +void FileCache::remove(std::unique_lock &, const String & s3_key, bool force) { auto file_type = getFileType(s3_key); auto & table = tables[static_cast(file_type)]; - - std::lock_guard lock(mtx); auto f = table.get(s3_key, /*update_lru*/ false); if (f == nullptr) - { return; - } std::ignore = removeImpl(table, s3_key, f, force); } +void FileCache::remove(const String & s3_key, bool force) +{ + std::unique_lock lock(mtx); + remove(lock, s3_key, force); +} + std::pair::iterator> FileCache::removeImpl( LRUFileTable & table, const String & s3_key, @@ -514,6 +643,7 @@ void FileCache::download(const String & s3_key, FileSegmentPtr & file_seg) if (!file_seg->isReadyToRead()) { + file_seg->setStatus(FileSegment::Status::Failed); GET_METRIC(tiflash_storage_remote_cache, type_dtfile_download_failed).Increment(); bg_download_fail_count.fetch_add(1, std::memory_order_relaxed); file_seg.reset(); @@ -543,6 +673,35 @@ void FileCache::bgDownload(const String & s3_key, FileSegmentPtr & file_seg) [this, s3_key = s3_key, file_seg = file_seg]() mutable { download(s3_key, file_seg); }); } +void FileCache::fgDownload(std::unique_lock & cache_lock, const String & s3_key, FileSegmentPtr & file_seg) +{ + SYNC_FOR("FileCache::fgDownload"); // simulate long s3 download + + try + { + GET_METRIC(tiflash_storage_remote_cache, type_dtfile_download).Increment(); + downloadImpl(s3_key, file_seg); + } + catch (...) + { + tryLogCurrentException(log, fmt::format("Download s3_key={} failed", s3_key)); + } + + if (!file_seg->isReadyToRead()) + { + file_seg->setStatus(FileSegment::Status::Failed); + GET_METRIC(tiflash_storage_remote_cache, type_dtfile_download_failed).Increment(); + file_seg.reset(); + remove(cache_lock, s3_key); + } + + LOG_DEBUG( + log, + "foreground downloading count {} => s3_key {} finished", + bg_downloading_count.load(std::memory_order_relaxed), + s3_key); +} + bool FileCache::isS3Filename(const String & fname) { return S3::S3FilenameView::fromKey(fname).isValid(); diff --git a/dbms/src/Storages/S3/FileCache.h b/dbms/src/Storages/S3/FileCache.h index c4d8d66aada..7e5a1498c59 100644 --- a/dbms/src/Storages/S3/FileCache.h +++ b/dbms/src/Storages/S3/FileCache.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -71,6 +72,8 @@ class FileSegment return status == Status::Complete; } + Status waitForNotEmpty(); + void setSize(UInt64 size_) { std::lock_guard lock(mtx); @@ -81,6 +84,8 @@ class FileSegment { std::lock_guard lock(mtx); status = s; + if (status != Status::Empty) + cv_ready.notify_all(); } UInt64 getSize() const @@ -126,6 +131,7 @@ class FileSegment UInt64 size; const FileType file_type; std::chrono::time_point last_access_time; + std::condition_variable cv_ready; }; using FileSegmentPtr = std::shared_ptr; @@ -218,6 +224,13 @@ class FileCache const S3::S3FilenameView & s3_fname, const std::optional & filesize); + /// Download the file if it is not in the local cache and returns the + /// file guard of the local cache file. When file guard is alive, + /// local file will not be evicted. + FileSegmentPtr downloadFileForLocalRead( + const S3::S3FilenameView & s3_fname, + const std::optional & filesize); + void updateConfig(const Settings & settings); #ifndef DBMS_PUBLIC_GTEST @@ -232,8 +245,14 @@ class FileCache DISALLOW_COPY_AND_MOVE(FileCache); FileSegmentPtr get(const S3::S3FilenameView & s3_fname, const std::optional & filesize = std::nullopt); + /// Try best to wait until the file is available in cache. If the file is not in cache, it will download the file in foreground. + /// It may return nullptr after wait. In this case the caller could retry. + FileSegmentPtr getOrWait( + const S3::S3FilenameView & s3_fname, + const std::optional & filesize = std::nullopt); void bgDownload(const String & s3_key, FileSegmentPtr & file_seg); + void fgDownload(std::unique_lock & cache_lock, const String & s3_key, FileSegmentPtr & file_seg); void download(const String & s3_key, FileSegmentPtr & file_seg); void downloadImpl(const String & s3_key, FileSegmentPtr & file_seg); @@ -250,6 +269,7 @@ class FileCache void restoreTable(const std::filesystem::directory_entry & table_entry); void restoreDMFile(const std::filesystem::directory_entry & dmfile_entry); + void remove(std::unique_lock & cache_lock, const String & s3_key, bool force = false); void remove(const String & s3_key, bool force = false); std::pair::iterator> removeImpl( LRUFileTable & table, diff --git a/dbms/src/Storages/S3/FileCachePerf.cpp b/dbms/src/Storages/S3/FileCachePerf.cpp new file mode 100644 index 00000000000..937dd3ff2ea --- /dev/null +++ b/dbms/src/Storages/S3/FileCachePerf.cpp @@ -0,0 +1,22 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace DB::PerfContext +{ + +thread_local FileCachePerfContext file_cache = {}; + +} diff --git a/dbms/src/Storages/S3/FileCachePerf.h b/dbms/src/Storages/S3/FileCachePerf.h new file mode 100644 index 00000000000..e206de87f68 --- /dev/null +++ b/dbms/src/Storages/S3/FileCachePerf.h @@ -0,0 +1,37 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +/// Remove the population of thread_local from Poco +#ifdef thread_local +#undef thread_local +#endif + +namespace DB::PerfContext +{ + +struct FileCachePerfContext +{ + size_t fg_download_from_s3 = 0; + size_t fg_wait_download_from_s3 = 0; + + void reset() { *this = {}; } +}; + +extern thread_local FileCachePerfContext file_cache; + +} // namespace DB::PerfContext diff --git a/dbms/src/TiDB/Schema/TiDB.cpp b/dbms/src/TiDB/Schema/TiDB.cpp index 35645fd30fb..cc6058e7139 100644 --- a/dbms/src/TiDB/Schema/TiDB.cpp +++ b/dbms/src/TiDB/Schema/TiDB.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -406,13 +407,13 @@ try if (vector_index) { - RUNTIME_CHECK(vector_index->kind != VectorIndexKind::INVALID); - RUNTIME_CHECK(vector_index->distance_metric != DistanceMetric::INVALID); + RUNTIME_CHECK(vector_index->kind != tipb::VectorIndexKind::INVALID_INDEX_KIND); + RUNTIME_CHECK(vector_index->distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC); Poco::JSON::Object::Ptr vector_index_json = new Poco::JSON::Object(); - vector_index_json->set("kind", String(magic_enum::enum_name(vector_index->kind))); + vector_index_json->set("kind", tipb::VectorIndexKind_Name(vector_index->kind)); vector_index_json->set("dimension", vector_index->dimension); - vector_index_json->set("distance_metric", String(magic_enum::enum_name(vector_index->distance_metric))); + vector_index_json->set("distance_metric", tipb::VectorDistanceMetric_Name(vector_index->distance_metric)); json->set("vector_index", vector_index_json); } @@ -470,22 +471,29 @@ try auto vector_index_json = json->getObject("vector_index"); if (vector_index_json) { - vector_index = std::make_shared(); - - auto vector_kind = magic_enum::enum_cast(vector_index_json->getValue("kind")); - RUNTIME_CHECK(vector_kind.has_value()); - RUNTIME_CHECK(vector_kind.value() != VectorIndexKind::INVALID); - vector_index->kind = vector_kind.value(); - - vector_index->dimension = vector_index_json->getValue("dimension"); - RUNTIME_CHECK(vector_index->dimension > 0); - RUNTIME_CHECK(vector_index->dimension <= 16000); // Just a protection - - auto distance_metric - = magic_enum::enum_cast(vector_index_json->getValue("distance_metric")); - RUNTIME_CHECK(distance_metric.has_value()); - RUNTIME_CHECK(distance_metric.value() != DistanceMetric::INVALID); - vector_index->distance_metric = distance_metric.value(); + tipb::VectorIndexKind kind = tipb::VectorIndexKind::INVALID_INDEX_KIND; + auto ok = tipb::VectorIndexKind_Parse( // + vector_index_json->getValue("kind"), + &kind); + RUNTIME_CHECK(ok); + RUNTIME_CHECK(kind != tipb::VectorIndexKind::INVALID_INDEX_KIND); + + auto dimension = vector_index_json->getValue("dimension"); + RUNTIME_CHECK(dimension > 0); + RUNTIME_CHECK(dimension <= 16000); // Just a protection + + tipb::VectorDistanceMetric distance_metric = tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC; + ok = tipb::VectorDistanceMetric_Parse( // + vector_index_json->getValue("distance_metric"), + &distance_metric); + RUNTIME_CHECK(ok); + RUNTIME_CHECK(distance_metric != tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC); + + vector_index = std::make_shared(VectorIndexDefinition{ + .kind = kind, + .dimension = dimension, + .distance_metric = distance_metric, + }); } } catch (const Poco::Exception & e) diff --git a/dbms/src/TiDB/Schema/TiDB.h b/dbms/src/TiDB/Schema/TiDB.h index 49abf0ed44a..5d30b950b92 100644 --- a/dbms/src/TiDB/Schema/TiDB.h +++ b/dbms/src/TiDB/Schema/TiDB.h @@ -201,7 +201,7 @@ struct ColumnInfo SchemaState state = StateNone; String comment; - VectorIndexInfoPtr vector_index = nullptr; + VectorIndexDefinitionPtr vector_index = nullptr; #ifdef M #error "Please undefine macro M first." diff --git a/dbms/src/TiDB/Schema/VectorIndex.h b/dbms/src/TiDB/Schema/VectorIndex.h index 8ba8b0a0d98..9d5901ad695 100644 --- a/dbms/src/TiDB/Schema/VectorIndex.h +++ b/dbms/src/TiDB/Schema/VectorIndex.h @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -23,80 +24,46 @@ namespace TiDB { -enum class VectorIndexKind +// Constructed from table definition. +struct VectorIndexDefinition { - INVALID = 0, - - // Note: Field names must match TiDB's enum definition. - HNSW, -}; - -enum class DistanceMetric -{ - INVALID = 0, - - // Note: Field names must match TiDB's enum definition. - L1, - L2, - COSINE, - INNER_PRODUCT, -}; - - -struct VectorIndexInfo -{ - VectorIndexKind kind = VectorIndexKind::INVALID; + tipb::VectorIndexKind kind = tipb::VectorIndexKind::INVALID_INDEX_KIND; UInt64 dimension = 0; - DistanceMetric distance_metric = DistanceMetric::INVALID; -}; - -using VectorIndexInfoPtr = std::shared_ptr; - -} // namespace TiDB - -template <> -struct fmt::formatter -{ - static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } + tipb::VectorDistanceMetric distance_metric = tipb::VectorDistanceMetric::INVALID_DISTANCE_METRIC; - template - auto format(const TiDB::VectorIndexKind & v, FormatContext & ctx) const -> decltype(ctx.out()) - { - return format_to(ctx.out(), "{}", magic_enum::enum_name(v)); - } + // TODO: There are possibly more fields, like efConstruct. + // Will be added later. }; -template <> -struct fmt::formatter -{ - static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } +// As this is constructed from TiDB's table definition, we should not +// ever try to modify it anyway. +using VectorIndexDefinitionPtr = std::shared_ptr; - template - auto format(const TiDB::DistanceMetric & d, FormatContext & ctx) const -> decltype(ctx.out()) - { - return format_to(ctx.out(), "{}", magic_enum::enum_name(d)); - } -}; +} // namespace TiDB template <> -struct fmt::formatter +struct fmt::formatter { static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } template - auto format(const TiDB::VectorIndexInfo & vi, FormatContext & ctx) const -> decltype(ctx.out()) + auto format(const TiDB::VectorIndexDefinition & vi, FormatContext & ctx) const -> decltype(ctx.out()) { - return format_to(ctx.out(), "{}:{}", vi.kind, vi.distance_metric); + return format_to( + ctx.out(), // + "{}:{}", + tipb::VectorIndexKind_Name(vi.kind), + tipb::VectorDistanceMetric_Name(vi.distance_metric)); } }; template <> -struct fmt::formatter +struct fmt::formatter { static constexpr auto parse(format_parse_context & ctx) { return ctx.begin(); } template - auto format(const TiDB::VectorIndexInfoPtr & vi, FormatContext & ctx) const -> decltype(ctx.out()) + auto format(const TiDB::VectorIndexDefinitionPtr & vi, FormatContext & ctx) const -> decltype(ctx.out()) { if (!vi) return format_to(ctx.out(), "");