Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support cosine in diskann #98

Merged
merged 9 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion extern/diskann/DiskANN/include/disk_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,5 @@ DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std
const std::string reorder_data_file = std::string(""));
template <typename T>
void create_disk_layout(const T *data, uint32_t npts, uint32_t ndims, const std::vector<size_t>& skip_locs, std::stringstream &vamana_reader,
std::stringstream &diskann_writer, size_t sector_len, const std::string reorder_data_file);
std::stringstream &diskann_writer, size_t sector_len, diskann::Metric metric);
} // namespace diskann
3 changes: 2 additions & 1 deletion extern/diskann/DiskANN/include/in_mem_data_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace diskann
template <typename data_t> class InMemDataStore : public AbstractDataStore<data_t>
{
public:
InMemDataStore(const location_t capacity, const size_t dim, std::shared_ptr<Distance<data_t>> distance_fn);
InMemDataStore(const location_t capacity, const size_t dim, std::shared_ptr<Distance<data_t>> distance_fn, bool compute_norm = false);
virtual ~InMemDataStore();

virtual location_t load(const std::string &filename) override;
Expand Down Expand Up @@ -79,6 +79,7 @@ template <typename data_t> class InMemDataStore : public AbstractDataStore<data_
std::shared_ptr<Distance<data_t>> _distance_fn;

// in case we need to save vector norms for optimization
bool _compute_norms = false;
std::shared_ptr<float[]> _pre_computed_norms;

bool _use_data_reference = false;
Expand Down
35 changes: 7 additions & 28 deletions extern/diskann/DiskANN/src/disk_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ void create_disk_layout(const std::string base_file, const std::string mem_index

template <typename T>
void create_disk_layout(const T *data, uint32_t npts, uint32_t ndims, const std::vector<size_t>& skip_locs, std::stringstream &vamana_reader,
std::stringstream &diskann_writer, size_t sector_len, const std::string reorder_data_file)
std::stringstream &diskann_writer, size_t sector_len, diskann::Metric metric)
{
// amount to read or write in one shot
size_t read_blk_size = 64 * 1024 * 1024;
Expand All @@ -1039,29 +1039,6 @@ void create_disk_layout(const T *data, uint32_t npts, uint32_t ndims, const std:
std::ifstream reorder_data_reader;

uint32_t npts_reorder_file = 0, ndims_reorder_file = 0;
if (reorder_data_file != std::string(""))
jiaweizone marked this conversation as resolved.
Show resolved Hide resolved
{
append_reorder_data = true;
size_t reorder_data_file_size = get_file_size(reorder_data_file);
reorder_data_reader.exceptions(std::ofstream::failbit | std::ofstream::badbit);

try
{
reorder_data_reader.open(reorder_data_file, std::ios::binary);
reorder_data_reader.read((char *)&npts_reorder_file, sizeof(uint32_t));
reorder_data_reader.read((char *)&ndims_reorder_file, sizeof(uint32_t));
if (npts_reorder_file != npts)
throw ANNException("Mismatch in num_points between reorder "
"data file and base file",
-1, __FUNCSIG__, __FILE__, __LINE__);
if (reorder_data_file_size != 8 + sizeof(float) * (size_t)npts_reorder_file * (size_t)ndims_reorder_file)
throw ANNException("Discrepancy in reorder data file size ", -1, __FUNCSIG__, __FILE__, __LINE__);
}
catch (std::system_error &e)
{
throw FileException(reorder_data_file, e, __FUNCSIG__, __FILE__, __LINE__);
}
}

// create cached reader + writer
size_t actual_file_size = vamana_reader.str().size();
Expand Down Expand Up @@ -1166,7 +1143,9 @@ void create_disk_layout(const T *data, uint32_t npts, uint32_t ndims, const std:

// write coords of node first
memcpy(node_buf.get(), data + ((uint64_t) ndims_64 * cur_node_id), ndims_64 * sizeof(T));

if (diskann::Metric::COSINE == metric) {
normalize((float *)node_buf.get(), ndims);
}
// write nnbrs
*(uint32_t *)(node_buf.get() + ndims_64 * sizeof(T)) = (std::min)(nnbrs, width_u32);

Expand Down Expand Up @@ -1462,11 +1441,11 @@ template DISKANN_DLLEXPORT void create_disk_layout<float>(const std::string base


template DISKANN_DLLEXPORT void create_disk_layout<int8_t>(const int8_t *data, uint32_t npts, uint32_t ndims, const std::vector<size_t>& skip_locs, std::stringstream &vamana_reader, std::stringstream &diskann_writer,
size_t sector_len, const std::string reorder_data_file);
size_t sector_len, diskann::Metric metric);
template DISKANN_DLLEXPORT void create_disk_layout<uint8_t>(const uint8_t *data, uint32_t npts, uint32_t ndims, const std::vector<size_t>& skip_locs, std::stringstream &vamana_reader, std::stringstream &diskann_writer,
size_t sector_len, const std::string reorder_data_file);
size_t sector_len, diskann::Metric metric);
template DISKANN_DLLEXPORT void create_disk_layout<float>(const float *data, uint32_t npts, uint32_t ndims, const std::vector<size_t>& skip_locs, std::stringstream &vamana_reader, std::stringstream &diskann_writer,
size_t sector_len, const std::string reorder_data_file);
size_t sector_len, diskann::Metric metric);
template DISKANN_DLLEXPORT int8_t *load_warmup<int8_t>(const std::string &cache_warmup_file, uint64_t &warmup_num,
uint64_t warmup_dim, uint64_t warmup_aligned_dim);
template DISKANN_DLLEXPORT uint8_t *load_warmup<uint8_t>(const std::string &cache_warmup_file, uint64_t &warmup_num,
Expand Down
30 changes: 25 additions & 5 deletions extern/diskann/DiskANN/src/in_mem_data_store.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ namespace diskann

template <typename data_t>
InMemDataStore<data_t>::InMemDataStore(const location_t num_points, const size_t dim,
std::shared_ptr<Distance<data_t>> distance_fn)
: AbstractDataStore<data_t>(num_points, dim), _distance_fn(distance_fn)
std::shared_ptr<Distance<data_t>> distance_fn, bool compute_norms)
: AbstractDataStore<data_t>(num_points, dim), _distance_fn(distance_fn), _compute_norms(compute_norms)
{
_aligned_dim = ROUND_UP(dim, _distance_fn->get_required_alignment());
}
Expand Down Expand Up @@ -201,13 +201,19 @@ template <typename data_t> void InMemDataStore<data_t>::link_data(const data_t *
}
_data = const_cast<data_t *>(vectors);
_loc_to_memory_index.resize(num_pts);
if (_compute_norms) {
_pre_computed_norms.reset(new float[num_pts]);
}
_use_data_reference = true;
this->_capacity = num_pts;
int64_t cur = 0;
for (location_t i = 0; i < mask.size(); i++)
{
if (!mask.test(i)) continue;
_loc_to_memory_index[cur] = i;
if (_compute_norms) {
_pre_computed_norms[cur] = get_norm(vectors + i * this->_dim, this->_dim);
}
cur ++;
}
assert(num_pts == cur);
Expand Down Expand Up @@ -258,6 +264,12 @@ template <typename data_t> void InMemDataStore<data_t>::get_vector(location_t lo
{
loc = _loc_to_memory_index[loc];
memcpy(dest, _data + loc * this->_dim, this->_dim * sizeof(data_t));
if (_compute_norms) {
for (int i = 0; i < this->_dim; ++i)
{
dest[i] /= _pre_computed_norms[loc];
}
}
return;
}
memcpy(dest, _data + loc * _aligned_dim, this->_dim * sizeof(data_t));
Expand Down Expand Up @@ -290,7 +302,11 @@ template <typename data_t> float InMemDataStore<data_t>::get_distance(const data
if (_use_data_reference)
{
loc = _loc_to_memory_index[loc];
return _distance_fn->compare(query, _data + this->_dim * loc, (uint32_t)this->_dim);
auto distance = _distance_fn->compare(query, _data + this->_dim * loc, (uint32_t)this->_dim);
if (_compute_norms) {
distance /= _pre_computed_norms[loc];
}
return distance;
}
return _distance_fn->compare(query, _data + _aligned_dim * loc, (uint32_t)_aligned_dim);
}
Expand All @@ -303,8 +319,12 @@ float InMemDataStore<data_t>::get_distance(location_t loc1, location_t loc2) con
{
loc1 = _loc_to_memory_index[loc1];
loc2 = _loc_to_memory_index[loc2];
return _distance_fn->compare(_data + loc1 * this->_dim, _data + loc2 * this->_dim,
(uint32_t)this->_dim);
auto distance = _distance_fn->compare(_data + loc1 * this->_dim, _data + loc2 * this->_dim,
(uint32_t)this->_dim);
if (_compute_norms) {
distance /= (_pre_computed_norms[loc1] * _pre_computed_norms[loc2]);
}
return distance;
}
return _distance_fn->compare(_data + loc1 * _aligned_dim, _data + loc2 * _aligned_dim,
(uint32_t)this->_aligned_dim);
Expand Down
8 changes: 3 additions & 5 deletions extern/diskann/DiskANN/src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,17 @@ Index<T, TagT, LabelT>::Index(Metric m, const size_t dim, const size_t max_point
else if (m == diskann::Metric::COSINE && std::is_floating_point<T>::value)
{
// This is safe because T is float inside the if block.
this->_distance.reset((Distance<T> *)new AVXNormalizedCosineDistanceFloat());
this->_distance.reset((Distance<T> *)new VsagDistanceInnerProductFloat(dim));
this->_normalize_vecs = true;
diskann::cout << "Normalizing vectors and using L2 for cosine "
"AVXNormalizedCosineDistanceFloat()."
<< std::endl;
}
else
{
this->_distance.reset((Distance<T> *)get_distance_function<T>(m));
}
// Note: moved this to factory, keeping this for backward compatibility.
_data_store =
std::make_unique<diskann::InMemDataStore<T>>((location_t)total_internal_points, _dim, this->_distance);
std::make_unique<diskann::InMemDataStore<T>>((location_t)total_internal_points, _dim,
this->_distance, this->_normalize_vecs);
}

_locks = std::vector<non_recursive_mutex>(total_internal_points);
Expand Down
18 changes: 16 additions & 2 deletions extern/diskann/DiskANN/src/pq.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// block size for reading/processing large files and matrices in blocks
#define BLOCK_SIZE 5000000
#define MIN_SAMPLE_NUM 1000
#define MAX_SAMPLE_NUM 131072
namespace vsag {

extern PQDistanceFunc
Expand Down Expand Up @@ -1904,17 +1905,30 @@ void generate_disk_quantized_data(const T* train_data, size_t train_size, size_t
// instantiates train_data with random sample updates train_size
size_t sample_size = std::min(train_size, (size_t)(train_size * p_val));
sample_size = std::max(sample_size, std::min(train_size, (size_t)MIN_SAMPLE_NUM));
sample_size = std::min(sample_size, (size_t)MAX_SAMPLE_NUM);
auto sample_data = train_data;
std::shared_ptr<T[]> new_train_data;
if (compare_metric == diskann::Metric::COSINE) {
new_train_data.reset(new T[train_dim * sample_size]);
memcpy(new_train_data.get(), train_data, train_dim * sample_size * sizeof(T));
for (int i = 0; i < sample_size; ++i)
{
normalize(new_train_data.get() + i * train_dim, train_dim);
}
sample_data = new_train_data.get();
}

// diskann::cout << "Training data with " << sample_size << " samples loaded." << std::endl;
if (disk_pq_dims > train_dim)
disk_pq_dims = train_dim;

// diskann::cout << "Compressing base for disk-PQ into " << disk_pq_dims << " chunks " << std::endl;
std::shared_ptr<float[]> rotate;
if (use_opq) {
generate_opq_pivots((const float*)train_data, sample_size, (uint32_t)train_dim, 256, (uint32_t)disk_pq_dims,
generate_opq_pivots((const float*)sample_data, sample_size, (uint32_t)train_dim, 256, (uint32_t)disk_pq_dims,
disk_pq_pivots, rotate, false);
} else {
generate_pq_pivots((const float*)train_data, sample_size, (uint32_t)train_dim, 256, (uint32_t)disk_pq_dims, NUM_KMEANS_REPS_PQ,
generate_pq_pivots((const float*)sample_data, sample_size, (uint32_t)train_dim, 256, (uint32_t)disk_pq_dims, NUM_KMEANS_REPS_PQ,
disk_pq_pivots, false);
}

Expand Down
17 changes: 14 additions & 3 deletions extern/diskann/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1462,6 +1462,9 @@ int64_t PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint6
{
aligned_query_T[i] = (float) query1[i];
}
if (diskann::Metric::COSINE == metric) {
normalize(aligned_query_T.get(), this->data_dim);
}

// FIXME: alternative instruction on aarch64
#if defined(__i386__) || defined(__x86_64__)
Expand Down Expand Up @@ -1681,7 +1684,7 @@ int64_t PQFlashIndex<T, LabelT>::cached_beam_search(const T *query1, const uint6
if (distances != nullptr)
{
distances[result_size] = full_retset[i].distance;
if (metric == diskann::Metric::INNER_PRODUCT)
if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE)
{
// When using L2 distance to calculate IP distance, the L2
// distance is exactly twice the IP distance.
Expand Down Expand Up @@ -1762,6 +1765,10 @@ int64_t PQFlashIndex<T, LabelT>::cached_beam_search_memory(const T *query, const
aligned_query_T[i] = (float) query[i];
}

if (diskann::Metric::COSINE == metric) {
normalize(aligned_query_T.get(), this->data_dim);
}

// FIXME: alternative instruction on aarch64
#if defined(__i386__) || defined(__x86_64__)
_mm_prefetch((char *)aligned_query_T.get(), _MM_HINT_T1);
Expand Down Expand Up @@ -1888,7 +1895,7 @@ int64_t PQFlashIndex<T, LabelT>::cached_beam_search_memory(const T *query, const
char *node_disk_buf = OFFSET_TO_NODE(sorted_read_reqs[j].buf, id);
T *node_fp_coords = OFFSET_TO_NODE_COORDS(node_disk_buf);
float exact_dist;
exact_dist = dist_cmp_float->compare((float *)query, (float *)node_fp_coords, (uint32_t)data_dim);
exact_dist = dist_cmp_float->compare(aligned_query_T.get(), (float *)node_fp_coords, (uint32_t)data_dim);
reorder_retset.push_back(Neighbor(id, exact_dist));
distance_ranks.push(exact_dist);
if (distance_ranks.size() > k_search) {
Expand All @@ -1914,7 +1921,7 @@ int64_t PQFlashIndex<T, LabelT>::cached_beam_search_memory(const T *query, const
if (distances != nullptr)
{
distances[result_size] = full_retset[i].distance;
if (metric == diskann::Metric::INNER_PRODUCT)
if (metric == diskann::Metric::INNER_PRODUCT || metric == diskann::Metric::COSINE)
{
// When using L2 distance to calculate IP distance, the L2
// distance is exactly twice the IP distance.
Expand Down Expand Up @@ -1944,6 +1951,10 @@ int64_t PQFlashIndex<T, LabelT>::cached_beam_search_async(const T *query, const
aligned_query_T[i] = (float) query[i];
}

if (diskann::Metric::COSINE == metric) {
normalize(aligned_query_T.get(), this->data_dim);
}

// FIXME: alternative instruction on aarch64
#if defined(__i386__) || defined(__x86_64__)
_mm_prefetch((char *)aligned_query_T.get(), _MM_HINT_T1);
Expand Down
4 changes: 2 additions & 2 deletions src/index/diskann.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ DiskANN::build(const DatasetPtr& base) {
graph_stream_,
disk_layout_stream_,
sector_len_,
"");
metric_);
}

std::vector<int64_t> failed_ids;
Expand Down Expand Up @@ -932,7 +932,7 @@ DiskANN::continue_build(const DatasetPtr& base, const BinarySet& binary_set) {
graph_stream_,
disk_layout_stream_,
sector_len_,
"");
metric_);
load_disk_index(binary_set);
build_status = BuildStatus::FINISH;
status_ = IndexStatus::MEMORY;
Expand Down
5 changes: 4 additions & 1 deletion src/index/diskann_zparameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ CreateDiskannParameters::FromJson(const std::string& json_string) {
obj.metric = diskann::Metric::L2;
} else if (index_common_param.metric_ == MetricType::METRIC_TYPE_IP) {
obj.metric = diskann::Metric::INNER_PRODUCT;
} else if (params[PARAMETER_METRIC_TYPE] == METRIC_COSINE) {
obj.metric = diskann::Metric::COSINE;
} else {
std::string metric = params[PARAMETER_METRIC_TYPE];
throw std::invalid_argument(fmt::format("parameters[{}] must in [{}, {}], now is {}",
throw std::invalid_argument(fmt::format("parameters[{}] must in [{}, {}, {}], now is {}",
PARAMETER_METRIC_TYPE,
METRIC_L2,
METRIC_IP,
METRIC_COSINE,
metric));
}

Expand Down
15 changes: 0 additions & 15 deletions tests/test_index_old.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ TEST_CASE("index search distance", "[ft][index]") {
auto metric_type = GENERATE("ip", "cosine", "l2");
auto algorithm = GENERATE("hnsw", "diskann");

if (algorithm == std::string("diskann") and metric_type == std::string("cosine")) {
return; // TODO: support cosine for diskann
}

bool need_normalize = metric_type != std::string("cosine");
auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_vectors, dim, need_normalize);

Expand Down Expand Up @@ -368,10 +364,6 @@ TEST_CASE("serialize/deserialize with file stream", "[ft][index]") {
// deserialize from file stream
std::filesystem::resize_file(dir.path + "index.bin", size - 10);
std::fstream in_file(dir.path + "index.bin", std::ios::in | std::ios::binary);

if (metric_type == std::string("cosine")) {
return;
}
auto new_index =
vsag::Factory::CreateIndex(
"diskann",
Expand Down Expand Up @@ -484,9 +476,6 @@ TEST_CASE("search on a deserialized empty index", "[ft][index]") {
auto index_name = GENERATE("hnsw", "diskann");
auto metric_type = GENERATE("l2", "ip", "cosine");

if (index_name == std::string("diskann") and metric_type == std::string("cosine")) {
return; // TODO: support cosine for diskann
}
auto index =
vsag::Factory::CreateIndex(
index_name, vsag::generate_build_parameters(metric_type, num_vectors, dim).value())
Expand Down Expand Up @@ -547,10 +536,6 @@ TEST_CASE("remove vectors from the index", "[ft][index]") {
auto index_name = GENERATE("fresh_hnsw", "diskann");
auto metric_type = GENERATE("cosine", "ip", "l2");

if (index_name == std::string("diskann") and metric_type == std::string("cosine")) {
return; // TODO: support cosine for diskann
}

bool need_normalize = metric_type != std::string("cosine");
auto [ids, vectors] = fixtures::generate_ids_and_vectors(num_vectors, dim, need_normalize);
auto index = fixtures::generate_index(index_name, metric_type, num_vectors, dim, ids, vectors);
Expand Down