From 604475023300e3947b28de5cb514d1c913313e29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 13 Dec 2023 17:23:20 +0100 Subject: [PATCH 1/3] Fix the angular distance --- src/distance/angular.rs | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/distance/angular.rs b/src/distance/angular.rs index 1616a7f7..c610b5bd 100644 --- a/src/distance/angular.rs +++ b/src/distance/angular.rs @@ -34,19 +34,25 @@ impl Distance for Angular { } fn built_distance(p: &Leaf, q: &Leaf) -> f32 { - let pp = p.header.norm; - let qq = q.header.norm; + let pn = p.header.norm; + let qn = q.header.norm; let pq = dot_product(&p.vector, &q.vector); - let ppqq = pp * qq; - if ppqq >= f32::MIN_POSITIVE { - 2.0 - 2.0 * pq / ppqq.sqrt() - } else { - 2.0 // cos is 0 - } + let pnqn = pn * qn; + let cos = pq / pnqn; + + // cos is [-1; 1] + // cos = 0. -> 0.5 + // cos = -1. -> 1.0 + // cos = 1. -> 0.0 + (1.0 - cos) / 2.0 + } + + fn normalized_distance(d: f32) -> f32 { + d } fn init(node: &mut Leaf) { - node.header.norm = dot_product(&node.vector, &node.vector); + node.header.norm = dot_product(&node.vector, &node.vector).sqrt(); } fn create_split( From 7cada5918bdb5cac8c1192e6bca776c7c077adef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 13 Dec 2023 17:28:39 +0100 Subject: [PATCH 2/3] Fix the windows CI --- .github/workflows/rust.yml | 2 +- src/parallel.rs | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index e04bae8d..02ea113f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -11,7 +11,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [ubuntu-latest, macos-latest-xlarge] + os: [ubuntu-latest, macos-latest-xlarge, windows-latest] rust: - stable - beta diff --git a/src/parallel.rs b/src/parallel.rs index 3b722073..c6698688 100644 --- a/src/parallel.rs +++ b/src/parallel.rs @@ -7,7 +7,7 @@ use std::sync::atomic::{AtomicU64, Ordering}; use heed::types::Bytes; use heed::{BytesDecode, BytesEncode, RoTxn}; -use memmap2::{Advice, Mmap}; +use memmap2::Mmap; use rand::seq::index; use rand::Rng; use roaring::RoaringBitmap; @@ -64,7 +64,8 @@ impl<'a, DE: BytesEncode<'a>> TmpNodes { pub fn into_bytes_reader(self) -> Result { let file = self.file.into_inner().map_err(|iie| iie.into_error())?; let mmap = unsafe { Mmap::map(&file)? }; - mmap.advise(Advice::Sequential)?; + #[cfg(unix)] + mmap.advise(memmap2::Advice::Sequential)?; Ok(TmpNodesReader { mmap, ids: self.ids, bounds: self.bounds }) } } From a73cf33a3f1dfd4043862345474df6bdd7309c83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 13 Dec 2023 17:56:15 +0100 Subject: [PATCH 3/3] Make sure not to divide by zero --- src/distance/angular.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/distance/angular.rs b/src/distance/angular.rs index c610b5bd..b0d93d8f 100644 --- a/src/distance/angular.rs +++ b/src/distance/angular.rs @@ -38,13 +38,16 @@ impl Distance for Angular { let qn = q.header.norm; let pq = dot_product(&p.vector, &q.vector); let pnqn = pn * qn; - let cos = pq / pnqn; - - // cos is [-1; 1] - // cos = 0. -> 0.5 - // cos = -1. -> 1.0 - // cos = 1. -> 0.0 - (1.0 - cos) / 2.0 + if pnqn != 0.0 { + let cos = pq / pnqn; + // cos is [-1; 1] + // cos = 0. -> 0.5 + // cos = -1. -> 1.0 + // cos = 1. -> 0.0 + (1.0 - cos) / 2.0 + } else { + 0.0 + } } fn normalized_distance(d: f32) -> f32 {