Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support int8 for hnsw #65

Merged
merged 12 commits into from
Oct 23, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
modify
Signed-off-by: jinjiabao.jjb <jinjiabao.jjb@antgroup.com>
  • Loading branch information
jinjiabao.jjb committed Oct 16, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
commit 318c637e4be7168bb762b3661769d6f14bdda828
2 changes: 1 addition & 1 deletion src/index/hnsw.cpp
Original file line number Diff line number Diff line change
@@ -819,7 +819,7 @@ HNSW::get_vectors(const vsag::DatasetPtr& base, void*& vectors, size_t& data_siz
vectors = (void*)base->GetInt8Vectors();
data_size = dim_ * sizeof(int8_t);
} else {
throw std::invalid_argument("fail to support this metric");
throw std::invalid_argument("fail to support this data type");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the exception message need include unknown type id

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i removed this choice because the part will be completed during the parameter checking phase

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this function, it is still recommended to add a check for the parameter type_.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}
}

1 change: 1 addition & 0 deletions src/index/hnsw_zparameters.cpp
Original file line number Diff line number Diff line change
@@ -140,6 +140,7 @@ CreateFreshHnswParameters::FromJson(const std::string& json_string) {
obj.space = parrent_obj.space;
obj.use_static = false;
obj.normalize = parrent_obj.normalize;
obj.type = parrent_obj.type;

// set obj.use_reversed_edges
obj.use_reversed_edges = true;
136 changes: 73 additions & 63 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
@@ -782,8 +782,13 @@ TEST_CASE("remove vectors from the index", "[ft][index]") {
vsag::Options::Instance().logger()->SetLevel(vsag::Logger::Level::kDEBUG);
int64_t num_vectors = 1000;
int64_t dim = 64;
auto index_name = "fresh_hnsw";
auto index_name = GENERATE("fresh_hnsw", "diskann");
auto metric_type = GENERATE("cosine", "ip", "l2");

if (index_name == std::string("diskann") and metric_type == std::string("cosine")) {
return; // TODO: support cosine for diskann
}

bool need_normalize = metric_type != std::string("cosine");
auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_vectors, dim, need_normalize);
auto index = fixtures::generate_index(index_name, metric_type, num_vectors, dim, ids, vectors);
@@ -802,82 +807,87 @@ TEST_CASE("remove vectors from the index", "[ft][index]") {
}
)";

// remove half data
int correct = 0;
for (int i = 0; i < num_vectors; i++) {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false);
if (index_name != std::string("diskann")) { // index that supports remove
// remove half data

int64_t k = 10;
auto result = index->KnnSearch(query, k, search_parameters);
REQUIRE(result.has_value());
if (result.value()->GetIds()[0] == ids[i]) {
correct += 1;
int correct = 0;
for (int i = 0; i < num_vectors; i++) {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false);

int64_t k = 10;
auto result = index->KnnSearch(query, k, search_parameters);
REQUIRE(result.has_value());
if (result.value()->GetIds()[0] == ids[i]) {
correct += 1;
}
}
}
float recall_before = ((float)correct) / num_vectors;
float recall_before = ((float)correct) / num_vectors;

for (int i = 0; i < num_vectors / 2; ++i) {
auto result = index->Remove(ids[i]);
REQUIRE(result.has_value());
REQUIRE(result.value());
}
auto wrong_result = index->Remove(-1);
REQUIRE(wrong_result.has_value());
REQUIRE_FALSE(wrong_result.value());
for (int i = 0; i < num_vectors / 2; ++i) {
auto result = index->Remove(ids[i]);
REQUIRE(result.has_value());
REQUIRE(result.value());
}
auto wrong_result = index->Remove(-1);
REQUIRE(wrong_result.has_value());
REQUIRE_FALSE(wrong_result.value());

REQUIRE(index->GetNumElements() == num_vectors / 2);
REQUIRE(index->GetNumElements() == num_vectors / 2);

// test recall for half data
correct = 0;
for (int i = 0; i < num_vectors; i++) {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false);
// test recall for half data
correct = 0;
for (int i = 0; i < num_vectors; i++) {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false);

int64_t k = 10;
auto result = index->KnnSearch(query, k, search_parameters);
REQUIRE(result.has_value());
if (i < num_vectors / 2) {
REQUIRE(result.value()->GetIds()[0] != ids[i]);
} else {
if (result.value()->GetIds()[0] == ids[i]) {
correct += 1;
int64_t k = 10;
auto result = index->KnnSearch(query, k, search_parameters);
REQUIRE(result.has_value());
if (i < num_vectors / 2) {
REQUIRE(result.value()->GetIds()[0] != ids[i]);
} else {
if (result.value()->GetIds()[0] == ids[i]) {
correct += 1;
}
}
}
}
float recall = ((float)correct) / (num_vectors / 2);
REQUIRE(recall >= 0.98);
float recall = ((float)correct) / (num_vectors / 2);
REQUIRE(recall >= 0.98);

// remove all data
for (int i = num_vectors / 2; i < num_vectors; ++i) {
auto result = index->Remove(i);
REQUIRE(result.has_value());
REQUIRE(result.value());
}
// remove all data
for (int i = num_vectors / 2; i < num_vectors; ++i) {
auto result = index->Remove(i);
REQUIRE(result.has_value());
REQUIRE(result.value());
}

// add data into index again
correct = 0;
auto dataset = vsag::Dataset::Make();
dataset->NumElements(num_vectors)
->Dim(dim)
->Float32Vectors(vectors.data())
->Ids(ids.data())
->Owner(false);
auto result = index->Add(dataset);
// add data into index again
correct = 0;
auto dataset = vsag::Dataset::Make();
dataset->NumElements(num_vectors)
->Dim(dim)
->Float32Vectors(vectors.data())
->Ids(ids.data())
->Owner(false);
auto result = index->Add(dataset);

for (int i = 0; i < num_vectors; i++) {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false);
for (int i = 0; i < num_vectors; i++) {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(vectors.data() + i * dim)->Owner(false);

int64_t k = 10;
auto result = index->KnnSearch(query, k, search_parameters);
REQUIRE(result.has_value());
if (result.value()->GetIds()[0] == ids[i]) {
correct += 1;
int64_t k = 10;
auto result = index->KnnSearch(query, k, search_parameters);
REQUIRE(result.has_value());
if (result.value()->GetIds()[0] == ids[i]) {
correct += 1;
}
}
float recall_after = ((float)correct) / num_vectors;
REQUIRE(std::abs(recall_before - recall_after) < 0.001);
} else { // index that does not supports remove
REQUIRE_THROWS(index->Remove(-1));
}
float recall_after = ((float)correct) / num_vectors;
REQUIRE(abs(recall_before - recall_after) < 0.001);
}

TEST_CASE("index with bsa", "[ft][index]") {