diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java index dde9ce76ac3d..f88eabe0fcac 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractVectorSimilarityQuery.java @@ -25,6 +25,7 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; @@ -248,9 +249,12 @@ public int nextDoc() { @Override public int advance(int target) { + assert index >= -1 : "index must >= -1 but got " + index; index = - Arrays.binarySearch( + ArrayUtil.exponentialSearch( scoreDocs, + Math.min(index + 1, scoreDocs.length), + scoreDocs.length, new ScoreDoc(target, 0), Comparator.comparingInt(scoreDoc -> scoreDoc.doc)); if (index < 0) { diff --git a/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java b/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java index 2cd3cb63cfe5..ca88e275593c 100644 --- a/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java @@ -801,4 +801,44 @@ public static int compareUnsigned4(byte[] a, int aOffset, byte[] b, int bOffset) return Integer.compareUnsigned( (int) BitUtil.VH_BE_INT.get(a, aOffset), (int) BitUtil.VH_BE_INT.get(b, bOffset)); } + + /** + * Run an exponential search for the target in an array + * + * @param arr the array + * @param fromIndex the start index of the search (inclusive) + * @param toIndex the end index of the search (exclusive) + * @param target the target to search for + * @return index of the search key, if it is contained in the array; otherwise, (-(insertion + * point) - 1) + */ + public static int exponentialSearch(int[] arr, int fromIndex, int toIndex, int target) { + int bound = 1; + while (fromIndex + bound < toIndex && arr[fromIndex + bound] < target) { + bound *= 2; + } + return Arrays.binarySearch( + arr, fromIndex + bound / 2, Math.min(fromIndex + bound + 1, toIndex), target); + } + + /** + * Run an exponential search for the target in an array + * + * @param arr the array + * @param fromIndex the start index of the search (inclusive) + * @param toIndex the end index of the search (exclusive) + * @param target the target to search for + * @param comp the comparator + * @return index of the search key, if it is contained in the array; otherwise, (-(insertion + * point) - 1) + */ + public static int exponentialSearch( + T[] arr, int fromIndex, int toIndex, T target, Comparator comp) { + int bound = 1; + while (fromIndex + bound < toIndex && comp.compare(arr[fromIndex + bound], target) < 0) { + bound *= 2; + } + return Arrays.binarySearch( + arr, fromIndex + bound / 2, Math.min(fromIndex + bound + 1, toIndex), target, comp); + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java b/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java index d44cc7839233..c1ac33d08b88 100644 --- a/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java +++ b/lucene/core/src/java/org/apache/lucene/util/IntArrayDocIdSet.java @@ -83,12 +83,7 @@ public int nextDoc() throws IOException { @Override public int advance(int target) throws IOException { - int bound = 1; - // given that we use this for small arrays only, this is very unlikely to overflow - while (i + bound < length && docs[i + bound] < target) { - bound *= 2; - } - i = Arrays.binarySearch(docs, i + bound / 2, Math.min(i + bound + 1, length), target); + i = ArrayUtil.exponentialSearch(docs, i, length, target); if (i < 0) { i = -1 - i; } diff --git a/lucene/core/src/test/org/apache/lucene/util/TestArrayUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestArrayUtil.java index 41f320ba0071..e0e17de10bcd 100644 --- a/lucene/core/src/test/org/apache/lucene/util/TestArrayUtil.java +++ b/lucene/core/src/test/org/apache/lucene/util/TestArrayUtil.java @@ -508,4 +508,58 @@ public void testCompareUnsigned8() { assertEquals(0, ArrayUtil.compareUnsigned8(a, aOffset, b, bOffset)); } + + public void testExponentialSearchForIntArray() { + final Random rnd = random(); + final int[] arr = new int[rnd.nextInt(2000) + 1]; + int last = 0; + for (int i = 0; i < arr.length; i++) { + arr[i] = last; + last += rnd.nextInt(1, 10); + } + + // case 1: random number, may not be in the array + int target = random().nextInt(arr[arr.length - 1]); + int idx = ArrayUtil.exponentialSearch(arr, 0, arr.length, target); + assertEquals(Arrays.binarySearch(arr, 0, arr.length, target), idx); + + // case 2: search for a number in the array + assertExponentialSearch(arr, random().nextInt(arr.length)); + assertExponentialSearch(arr, 0); + assertExponentialSearch(arr, arr.length - 1); + } + + private static void assertExponentialSearch(int[] arr, int expectedIndex) { + int idx = ArrayUtil.exponentialSearch(arr, 0, arr.length, arr[expectedIndex]); + assertEquals(expectedIndex, idx); + } + + public void testExponentialSearchForObjectArray() { + final Random rnd = random(); + final Integer[] arr = new Integer[rnd.nextInt(2000) + 1]; + int last = 0; + for (int i = 0; i < arr.length; i++) { + arr[i] = last; + last += rnd.nextInt(1, 10); + } + + // case 1: random number, may not be in the array + int target = random().nextInt(arr[arr.length - 1]); + int idx = + ArrayUtil.exponentialSearch( + arr, 0, arr.length, target, Comparator.comparingInt(Integer::intValue)); + assertEquals(Arrays.binarySearch(arr, 0, arr.length, target), idx); + + // case 2: search for a number in the array + assertExponentialSearch(arr, random().nextInt(arr.length)); + assertExponentialSearch(arr, 0); + assertExponentialSearch(arr, arr.length - 1); + } + + private static void assertExponentialSearch(Integer[] arr, int expectedIndex) { + int idx = + ArrayUtil.exponentialSearch( + arr, 0, arr.length, arr[expectedIndex], Comparator.comparingInt(Integer::intValue)); + assertEquals(expectedIndex, idx); + } }