Skip to content

Commit

Permalink
storage: Fix possible check failure when there are duplicated results (#166)
Browse files Browse the repository at this point in the history

Signed-off-by: Wish <breezewish@outlook.com>
  • Loading branch information
breezewish authored and JaySon-Huang committed Aug 6, 2024
1 parent 864106e commit 5b2725b
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ void DMFileWithVectorIndexBlockInputStream::loadVectorSearchResult()

auto perf_begin = PerfContext::vector_search;

RUNTIME_CHECK(valid_rows.size() >= dmfile->getRows(), valid_rows.size(), dmfile->getRows());
auto results_rowid = vec_index->search(ann_query_info, valid_rows);

auto discarded_nodes = PerfContext::vector_search.discarded_nodes - perf_begin.discarded_nodes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ MutableColumnPtr VectorColumnFromIndexReader::calcResultsByPack(

// results must be in ascending order.
std::sort(results.begin(), results.end());
// results must not contain duplicates. Usually there should be no duplicates.
results.erase(std::unique(results.begin(), results.end()), results.end());

std::vector<UInt32> offsets_in_pack;
size_t results_it = 0;
Expand Down
25 changes: 20 additions & 5 deletions dbms/src/Storages/DeltaMerge/Index/VectorIndexHNSW/Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,20 +155,35 @@ std::vector<VectorIndexBuilder::Key> VectorIndexHNSWViewer::search(

std::atomic<size_t> visited_nodes = 0;
std::atomic<size_t> discarded_nodes = 0;
std::atomic<bool> has_exception_in_search = false;

// Filter callback handed to the index search: accepts a candidate member only when
// its key is marked valid in `valid_rows`, and keeps the visited / discarded counters
// up to date.
//
// All of the work is done inside the try block. Any exception raised while evaluating
// a candidate is logged, recorded in `has_exception_in_search` (checked after the
// search call returns), and the candidate is rejected.
auto predicate = [&](typename USearchImplType::member_cref_t const & member) -> bool {
    // Must catch exceptions in the predicate, because search runs on other threads.
    try
    {
        // Note: We don't increase the thread_local perf, because search runs on other threads.
        visited_nodes++;
        // Look the row up once and reuse the answer for both the counter and the result.
        const bool is_valid = valid_rows[member.key];
        if (!is_valid)
            discarded_nodes++;
        return is_valid;
    }
    catch (...)
    {
        tryLogCurrentException(__PRETTY_FUNCTION__);
        has_exception_in_search = true;
        return false;
    }
};

// TODO: Support efSearch.
auto result = index.search( //
reinterpret_cast<const Float32 *>(queryInfo->ref_vec_f32().data() + sizeof(UInt32)),
queryInfo->top_k(),
predicate);

if (has_exception_in_search)
throw Exception(ErrorCodes::INCORRECT_QUERY, "Exception happened occurred during search");

std::vector<Key> keys(result.size());
result.dump_to(keys.data());

Expand Down
67 changes: 67 additions & 0 deletions dbms/src/Storages/DeltaMerge/tests/gtest_dm_vector_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,73 @@ try
}
CATCH

// Writes a single block of five rows in which rows 0, 1 and 3 carry the same vector,
// then runs a top-4 L2 ANN query and checks the rows and vectors that come back.
TEST_P(VectorIndexDMFileTest, OnePackWithDuplicateVectors)
try
{
    auto all_columns = DMTestEnv::getDefaultColumns(DMTestEnv::PkType::HiddenTiDBRowID, /*add_nullable*/ true);

    // Vector column carrying a 3-dimensional HNSW index with L2 distance.
    ColumnDefine vector_column(vec_column_id, vec_column_name, tests::typeFromString("Array(Float32)"));
    const TiDB::VectorIndexDefinition index_definition{
        .kind = tipb::VectorIndexKind::HNSW,
        .dimension = 3,
        .distance_metric = tipb::VectorDistanceMetric::L2,
    };
    vector_column.vector_index = std::make_shared<TiDB::VectorIndexDefinition>(index_definition);
    all_columns->emplace_back(vector_column);

    // Either read every column, or only the vector column, depending on the test parameter.
    ColumnDefines columns_to_read = test_only_vec_column ? ColumnDefines{vector_column} : *all_columns;

    // Prepare DMFile
    {
        Block input_block = DMTestEnv::prepareSimpleWriteBlockWithNullable(0, 5);
        input_block.insert(createVecFloat32Column<Array>(
            {//
             {1.0, 2.0, 3.0},
             {1.0, 2.0, 3.0},
             {0.0, 0.0, 0.0},
             {1.0, 2.0, 3.0},
             {1.0, 2.0, 3.5}},
            vector_column.name,
            vector_column.id));

        auto writer = std::make_shared<DMFileBlockOutputStream>(dbContext(), dm_file, *all_columns);
        writer->writePrefix();
        writer->write(input_block, DMFileBlockOutputStream::BlockProperty{0, 0, 0, 0});
        writer->writeSuffix();
    }

    dm_file = restoreDMFile();

    {
        // Ask for the 4 nearest neighbours of {1.0, 2.0, 3.5} under L2 distance.
        auto query = std::make_shared<tipb::ANNQueryInfo>();
        query->set_column_id(vector_column.id);
        query->set_distance_metric(tipb::VectorDistanceMetric::L2);
        query->set_top_k(4);
        query->set_ref_vec_f32(encodeVectorFloat32({1.0, 2.0, 3.5}));

        DMFileBlockInputStreamBuilder reader_builder(dbContext());
        auto reader = reader_builder.setRSOperator(wrapWithANNQueryInfo(nullptr, query))
                          .setBitmapFilter(BitmapFilterView(std::make_shared<BitmapFilter>(5, true), 0, 5))
                          .build2(
                              dm_file,
                              columns_to_read,
                              RowKeyRanges{RowKeyRange::newAll(false, 1)},
                              std::make_shared<ScanContext>());

        // Expected: rows 0, 1, 3 and 4; row 2 ({0.0, 0.0, 0.0}) is not among the results.
        ASSERT_INPUTSTREAM_COLS_UR(
            reader,
            createColumnNames(),
            createColumnData({
                createColumn<Int64>({0, 1, 3, 4}),
                createVecFloat32Column<Array>({//
                                               {1.0, 2.0, 3.0},
                                               {1.0, 2.0, 3.0},
                                               {1.0, 2.0, 3.0},
                                               {1.0, 2.0, 3.5}}),
            }));
    }
}
CATCH

TEST_P(VectorIndexDMFileTest, MultiPacks)
try
{
Expand Down

0 comments on commit 5b2725b

Please sign in to comment.