Skip to content

Commit

Permalink
Merge pull request #166 from nmslib/develop
Browse files Browse the repository at this point in the history
Merge develop into master
  • Loading branch information
yurymalkov authored Nov 11, 2019
2 parents bbddf19 + 38482db commit 0dcfb91
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 15 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ Index methods:

* `get_ids_list()` - returns a list of all elements' ids.

* `get_max_elements()` - returns the current capacity of the index

* `get_current_count()` - returns the current number of elements stored in the index


Expand Down
29 changes: 27 additions & 2 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <unordered_map>
#include <fstream>
#include <mutex>
#include <algorithm>

namespace hnswlib {
template<typename dist_t>
Expand All @@ -21,6 +22,8 @@ namespace hnswlib {
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *) malloc(maxElements * size_per_element_);
// A failed malloc must abort construction. The exception has to be thrown:
// merely constructing std::runtime_error creates a temporary that is discarded,
// and execution would continue with a null data_ buffer.
if (data_ == nullptr)
    throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
cur_element_count = 0;
}

Expand All @@ -40,7 +43,7 @@ namespace hnswlib {

std::unordered_map<labeltype,size_t > dict_external_to_internal;

void addPoint(void *datapoint, labeltype label) {
void addPoint(const void *datapoint, labeltype label) {

int idx;
{
Expand Down Expand Up @@ -84,8 +87,10 @@ namespace hnswlib {
}


std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *query_data, size_t k) const {
std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k) const {
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
if (cur_element_count == 0) return topResults;
for (int i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
Expand All @@ -106,6 +111,24 @@ namespace hnswlib {
return topResults;
};

template <typename Comp>
std::vector<std::pair<dist_t, labeltype>>
searchKnn(const void* query_data, size_t k, Comp comp) {
std::vector<std::pair<dist_t, labeltype>> result;
if (cur_element_count == 0) return result;

auto ret = searchKnn(query_data, k);

while (!ret.empty()) {
result.push_back(ret.top());
ret.pop();
}

std::sort(result.begin(), result.end(), comp);

return result;
}

void saveIndex(const std::string &location) {
std::ofstream output(location, std::ios::binary);
std::streampos position;
Expand Down Expand Up @@ -134,6 +157,8 @@ namespace hnswlib {
dist_func_param_ = s->get_dist_func_param();
size_per_element_ = data_size_ + sizeof(labeltype);
data_ = (char *) malloc(maxelements_ * size_per_element_);
// A failed malloc must stop loading before the stream is read into data_.
// The exception has to be thrown: constructing std::runtime_error without
// `throw` only creates a discarded temporary and execution would continue.
if (data_ == nullptr)
    throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");

input.read(data_, maxelements_ * size_per_element_);

Expand Down
55 changes: 45 additions & 10 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ namespace hnswlib {
maxlevel_ = -1;

linkLists_ = (char **) malloc(sizeof(void *) * max_elements_);
if (linkLists_ == nullptr)
throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists");
size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint);
mult_ = 1 / log(1.0 * M_);
revSize_ = 1.0 / mult_;
Expand Down Expand Up @@ -150,7 +152,7 @@ namespace hnswlib {
}

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayer(tableint ep_id, void *data_point, int layer) {
searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
Expand Down Expand Up @@ -371,7 +373,7 @@ namespace hnswlib {
return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_);
};

void mutuallyConnectNewElement(void *data_point, tableint cur_c,
void mutuallyConnectNewElement(const void *data_point, tableint cur_c,
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates,
int level) {

Expand Down Expand Up @@ -484,6 +486,8 @@ namespace hnswlib {


std::priority_queue<std::pair<dist_t, tableint>> searchKnnInternal(void *query_data, int k) {
std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
if (cur_element_count == 0) return top_candidates;
tableint currObj = enterpoint_node_;
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);

Expand All @@ -510,8 +514,6 @@ namespace hnswlib {
}
}


std::priority_queue<std::pair<dist_t, tableint >> top_candidates;
if (has_deletions_) {
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
ef_);
Expand Down Expand Up @@ -546,12 +548,16 @@ namespace hnswlib {

// Reallocate base layer
char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_);
if (data_level0_memory_new == nullptr)
throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer");
memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_);
free(data_level0_memory_);
data_level0_memory_=data_level0_memory_new;

// Reallocate all other layers
char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements);
if (linkLists_new == nullptr)
throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers");
memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *));
free(linkLists_);
linkLists_=linkLists_new;
Expand Down Expand Up @@ -659,6 +665,8 @@ namespace hnswlib {


data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_);
if (data_level0_memory_ == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0");
input.read(data_level0_memory_, cur_element_count * size_data_per_element_);


Expand All @@ -675,6 +683,8 @@ namespace hnswlib {


linkLists_ = (char **) malloc(sizeof(void *) * max_elements);
if (linkLists_ == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists");
element_levels_ = std::vector<int>(max_elements);
revSize_ = 1.0 / mult_;
ef_ = 10;
Expand All @@ -689,6 +699,8 @@ namespace hnswlib {
} else {
element_levels_[i] = linkListSize / size_links_per_element_;
linkLists_[i] = (char *) malloc(linkListSize);
if (linkLists_[i] == nullptr)
throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist");
input.read(linkLists_[i], linkListSize);
}
}
Expand Down Expand Up @@ -779,11 +791,11 @@ namespace hnswlib {
*((unsigned short int*)(ptr))=*((unsigned short int *)&size);
}

void addPoint(void *data_point, labeltype label) {
void addPoint(const void *data_point, labeltype label) {
addPoint(data_point, label,-1);
}

tableint addPoint(void *data_point, labeltype label, int level) {
tableint addPoint(const void *data_point, labeltype label, int level) {
tableint cur_c = 0;
{
std::unique_lock <std::mutex> lock(cur_element_count_guard_);
Expand All @@ -797,6 +809,7 @@ namespace hnswlib {
auto search = label_lookup_.find(label);
if (search != label_lookup_.end()) {
std::unique_lock <std::mutex> lock_el(link_list_locks_[search->second]);
has_deletions_ = true;
markDeletedInternal(search->second);
}
label_lookup_[label] = cur_c;
Expand Down Expand Up @@ -827,6 +840,8 @@ namespace hnswlib {

if (curlevel) {
linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1);
if (linkLists_[cur_c] == nullptr)
throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist");
memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1);
}

Expand Down Expand Up @@ -895,7 +910,11 @@ namespace hnswlib {
return cur_c;
};

std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *query_data, size_t k) const {
std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

tableint currObj = enterpoint_node_;
dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_);

Expand Down Expand Up @@ -934,18 +953,34 @@ namespace hnswlib {
currObj, query_data, std::max(ef_, k));
top_candidates.swap(top_candidates1);
}
std::priority_queue<std::pair<dist_t, labeltype >> results;
while (top_candidates.size() > k) {
top_candidates.pop();
}
while (top_candidates.size() > 0) {
std::pair<dist_t, tableint> rez = top_candidates.top();
results.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
result.push(std::pair<dist_t, labeltype>(rez.first, getExternalLabel(rez.second)));
top_candidates.pop();
}
return results;
return result;
};

template <typename Comp>
std::vector<std::pair<dist_t, labeltype>>
searchKnn(const void* query_data, size_t k, Comp comp) {
std::vector<std::pair<dist_t, labeltype>> result;
if (cur_element_count == 0) return result;

auto ret = searchKnn(query_data, k);

while (!ret.empty()) {
result.push_back(ret.top());
ret.pop();
}

std::sort(result.begin(), result.end(), comp);

return result;
}

};

Expand Down
14 changes: 13 additions & 1 deletion hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,21 @@
#endif

#include <queue>
#include <vector>

#include <string.h>

namespace hnswlib {
typedef size_t labeltype;

// Comparison functor for pair-like values (anything exposing `.first`).
// Returns true when p1.first is strictly greater than p2.first; `.second`
// is never inspected.
template <typename T>
class pairGreater {
public:
    // const-qualified: the functor has no state to modify, and this lets it be
    // invoked through const objects and const references.
    bool operator()(const T& p1, const T& p2) const {
        return p1.first > p2.first;
    }
};

template<typename T>
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
out.write((char *) &podRef, sizeof(T));
Expand Down Expand Up @@ -60,8 +69,11 @@ namespace hnswlib {
template<typename dist_t>
class AlgorithmInterface {
public:
virtual void addPoint(void *datapoint, labeltype label)=0;
virtual void addPoint(const void *datapoint, labeltype label)=0;
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;
// Comparator-taking search overload on the interface.
// This version performs no search and reports no neighbors. It returns an
// empty vector explicitly, because flowing off the end of a value-returning
// function is undefined behavior.
template <typename Comp>
std::vector<std::pair<dist_t, labeltype>> searchKnn(const void*, size_t, Comp) {
    return std::vector<std::pair<dist_t, labeltype>>();
}
virtual void saveIndex(const std::string &location)=0;
virtual ~AlgorithmInterface(){
}
Expand Down
10 changes: 10 additions & 0 deletions python_bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,14 @@ class Index {
appr_alg->resizeIndex(new_size);
}

// Reports the current capacity of the wrapped index, read from
// appr_alg->max_elements_.
size_t getMaxElements() const {
    const size_t capacity = appr_alg->max_elements_;
    return capacity;
}

// Reports how many elements the wrapped index currently stores, read from
// appr_alg->cur_element_count.
size_t getCurrentCount() const {
    const size_t stored = appr_alg->cur_element_count;
    return stored;
}

std::string space_name;
int dim;

Expand Down Expand Up @@ -397,6 +405,8 @@ PYBIND11_PLUGIN(hnswlib) {
.def("load_index", &Index<float>::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0)
.def("mark_deleted", &Index<float>::markDeleted, py::arg("label"))
.def("resize_index", &Index<float>::resizeIndex, py::arg("new_size"))
.def("get_max_elements", &Index<float>::getMaxElements)
.def("get_current_count", &Index<float>::getCurrentCount)
.def("__repr__",
[](const Index<float> &a) {
return "<HNSW-lib index>";
Expand Down
4 changes: 2 additions & 2 deletions python_bindings/tests/bindings_test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def testRandomSelf(self):
self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data1))),1.0,3)

# Check that the returned element data is correct:
diff_with_gt_labels=np.max(np.abs(data1-items))
diff_with_gt_labels=np.mean(np.abs(data1-items))
self.assertAlmostEqual(diff_with_gt_labels, 0, delta = 1e-4)

# Serializing and deleting the index.
Expand Down Expand Up @@ -83,7 +83,7 @@ def testRandomSelf(self):
self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))),1.0,3)

# Check that the returned element data is correct:
diff_with_gt_labels=np.max(np.abs(data-items))
diff_with_gt_labels=np.mean(np.abs(data-items))
self.assertAlmostEqual(diff_with_gt_labels, 0, delta = 1e-4) # deleting index.

# Checking that all labels are returned correctly:
Expand Down

0 comments on commit 0dcfb91

Please sign in to comment.