diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index b5506e3e519e..19d2288fe8f3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.Bag; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -122,26 +123,24 @@ public String toString() { private static final class FloatScoringSupplier implements RandomVectorScorerSupplier { private final FloatVectorValues vectorValues; private final VectorSimilarityFunction similarityFunction; - private final FloatVectorValues.Floats queryVectors; - private final FloatVectorValues.Floats targetVectors; + private final Bag pool = new Bag<>(); private FloatScoringSupplier( FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction) throws IOException { this.vectorValues = vectorValues; this.similarityFunction = similarityFunction; - this.queryVectors = vectorValues.vectors(); - this.targetVectors = vectorValues.vectors(); } @Override - public RandomVectorScorer scorer(int ord) { - return new RandomVectorScorer.AbstractRandomVectorScorer(vectorValues) { - @Override - public float score(int node) throws IOException { - return similarityFunction.compare(queryVectors.get(ord), targetVectors.get(node)); - } - }; + public RandomVectorScorer scorer(int ord) throws IOException { + FloatVectorScorer scorer = (FloatVectorScorer) pool.poll(); + if (scorer != null) { + scorer.setQuery(ord); + } else { + scorer = new FloatVectorScorer(vectorValues, ord, similarityFunction, pool); + } + return scorer; } @Override @@ -152,22 +151,40 @@ public String toString() { /** A {@link RandomVectorScorer} for float vectors. */ private static class FloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { - private final float[] query; + private final FloatVectorValues.Floats vectors, queryVectors; private final VectorSimilarityFunction similarityFunction; - private final FloatVectorValues.Floats targetVectors; + private float[] query; + + FloatVectorScorer( + FloatVectorValues vectorValues, + int ord, + VectorSimilarityFunction similarityFunction, + Bag pool) + throws IOException { + super(vectorValues, pool); + this.similarityFunction = similarityFunction; + vectors = vectorValues.vectors(); + queryVectors = vectorValues.vectors(); + query = queryVectors.get(ord); + } public FloatVectorScorer( FloatVectorValues vectorValues, float[] query, VectorSimilarityFunction similarityFunction) throws IOException { super(vectorValues); - this.query = query; this.similarityFunction = similarityFunction; - this.targetVectors = vectorValues.vectors(); + vectors = vectorValues.vectors(); + queryVectors = null; + this.query = query; + } + + private void setQuery(int ord) throws IOException { + query = queryVectors.get(ord); } @Override public float score(int node) throws IOException { - return similarityFunction.compare(query, targetVectors.get(node)); + return similarityFunction.compare(query, vectors.get(node)); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 736a552ff04e..c7af020dc90b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -28,12 +28,13 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.Bag; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ public abstract class OffHeapFloatVectorValues extends FloatVectorValues implements HasIndexSlice { - + private final Bag pool = new Bag<>(); protected final int dimension; protected final int size; protected final IndexInput slice; @@ -73,6 +74,10 @@ public IndexInput getSlice() { @Override public Floats vectors() { + Floats floats = pool.poll(); + if (floats != null) { + return floats; + } IndexInput sliceCopy = slice.clone(); float[] value = new float[dimension]; return new Floats() { @@ -88,6 +93,11 @@ public float[] get(int targetOrd) throws IOException { lastOrd = targetOrd; return value; } + + @Override + public void close() throws IOException { + pool.offer(this); + } }; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 86225d69e557..34f6a2b4f9eb 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -428,6 +428,11 @@ public Floats vectors() throws IOException { public float[] get(int ord) throws IOException { return rawVectors.get(ord); } + + @Override + public void close() throws IOException { + rawVectors.close(); + } }; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index fc401c935292..209218707d38 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -1216,11 +1216,9 @@ public DocIndexIterator iterator() throws IOException { static final class NormalizedFloatVectorValues extends FloatVectorValues { private final FloatVectorValues vectorValues; - private final Floats floats; public NormalizedFloatVectorValues(FloatVectorValues vectorValues) throws IOException { this.vectorValues = vectorValues; - floats = vectorValues.vectors(); } @Override @@ -1239,15 +1237,22 @@ public int ordToDoc(int ord) { } @Override - public Floats vectors() { + public Floats vectors() throws IOException { float[] normalizedVector = new float[vectorValues.dimension()]; return new Floats() { + Floats delegate = vectorValues.vectors(); + @Override public float[] get(int ord) throws IOException { - System.arraycopy(floats.get(ord), 0, normalizedVector, 0, normalizedVector.length); + System.arraycopy(delegate.get(ord), 0, normalizedVector, 0, normalizedVector.length); VectorUtil.l2normalize(normalizedVector); return normalizedVector; } + + @Override + public void close() throws IOException { + delegate.close(); + } }; } diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 6a3ebe4a1afc..38fd4a2d8cf2 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -33,7 +33,7 @@ public abstract class FloatVectorValues extends KnnVectorValues { protected FloatVectorValues() {} /** A random access (lookup by ord) provider of the vector values */ - public abstract static class Floats { + public abstract static class Floats implements AutoCloseable { /** * Return the vector value for the given vector ordinal which must be in [0, size() - 1], * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. @@ -42,6 +42,11 @@ public abstract static class Floats { */ public abstract float[] get(int ord) throws IOException; + @Override + public void close() throws IOException { + // by default do nothing. Some implementations do more interesting resource management. + } + /** A Floats containing no vectors. Throws UnsupportedOperationException if get() is called. */ public static final Floats EMPTY = new Floats() { @@ -118,6 +123,9 @@ public Floats vectors() { public float[] get(int ord) throws IOException { return vectors.get(ord); } + + @Override + public void close() {} }; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/Bag.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/Bag.java new file mode 100644 index 000000000000..df322093f532 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/Bag.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.util.hnsw; + +/** + * A collection of objects that is threadsafe, providing offer(T) that tries to add an element and + * poll() that removes and returns an element or null. The storage will never grow. There are no + * guarantees about which object will be returned from poll(), just that it will be one that was + * added by offer(). + */ +public class Bag { + private static final int DEFAULT_CAPACITY = 64; + + private final Object[] elements; + private int writeTo; + private int readFrom; + + public Bag() { + this(DEFAULT_CAPACITY); + } + + public Bag(int capacity) { + elements = new Object[capacity]; + } + + public synchronized boolean offer(T element) { + if (full()) { + return false; + } + elements[writeTo] = element; + writeTo = (writeTo + 1) % elements.length; + return true; + } + + @SuppressWarnings("unchecked") + public synchronized T poll() { + if (empty()) { + return null; + } + T result = (T) elements[readFrom]; + readFrom = (readFrom + 1) % elements.length; + return result; + } + + private boolean full() { + int headroom = readFrom - 1 - writeTo; + if (headroom < 0) { + headroom += elements.length; + } + return headroom == 0; + } + + private boolean empty() { + return readFrom == writeTo; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index bed1480e9262..0cbadeaa7a56 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -233,74 +233,75 @@ to the newly introduced levels (repeating step 2,3 for new levels) and again try if (frozen) { throw new IllegalStateException("Graph builder is already frozen"); } - RandomVectorScorer scorer = scorerSupplier.scorer(node); - final int nodeLevel = getRandomGraphLevel(ml, random); - // first add nodes to all levels - for (int level = nodeLevel; level >= 0; level--) { - hnsw.addNode(level, node); - } - // then promote itself as entry node if entry node is not set - if (hnsw.trySetNewEntryNode(node, nodeLevel)) { - return; - } - // if the entry node is already set, then we have to do all connections first before we can - // promote ourselves as entry node - - int lowestUnsetLevel = 0; - int curMaxLevel; - do { - curMaxLevel = hnsw.numLevels() - 1; - // NOTE: the entry node and max level may not be paired, but because we get the level first - // we ensure that the entry node we get later will always exist on the curMaxLevel - int[] eps = new int[] {hnsw.entryNode()}; - - // we first do the search from top to bottom - // for levels > nodeLevel search with topk = 1 - GraphBuilderKnnCollector candidates = entryCandidates; - for (int level = curMaxLevel; level > nodeLevel; level--) { - candidates.clear(); - graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null); - eps[0] = candidates.popNode(); - } - - // for levels <= nodeLevel search with topk = beamWidth, and add connections - candidates = beamCandidates; - NeighborArray[] scratchPerLevel = - new NeighborArray[Math.min(nodeLevel, curMaxLevel) - lowestUnsetLevel + 1]; - for (int i = scratchPerLevel.length - 1; i >= 0; i--) { - int level = i + lowestUnsetLevel; - candidates.clear(); - graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null); - eps = candidates.popUntilNearestKNodes(); - scratchPerLevel[i] = new NeighborArray(Math.max(beamCandidates.k(), M + 1), false); - popToScratch(candidates, scratchPerLevel[i]); + try (RandomVectorScorer scorer = scorerSupplier.scorer(node)) { + final int nodeLevel = getRandomGraphLevel(ml, random); + // first add nodes to all levels + for (int level = nodeLevel; level >= 0; level--) { + hnsw.addNode(level, node); } - - // then do connections from bottom up - for (int i = 0; i < scratchPerLevel.length; i++) { - addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]); - } - lowestUnsetLevel += scratchPerLevel.length; - assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1; - if (lowestUnsetLevel > nodeLevel) { + // then promote itself as entry node if entry node is not set + if (hnsw.trySetNewEntryNode(node, nodeLevel)) { return; } - assert lowestUnsetLevel == curMaxLevel + 1 && nodeLevel > curMaxLevel; - if (hnsw.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) { - return; - } - if (hnsw.numLevels() == curMaxLevel + 1) { - // This should never happen if all the calculations are correct - throw new IllegalStateException( - "We're not able to promote node " - + node - + " at level " - + nodeLevel - + " as entry node. But the max graph level " - + curMaxLevel - + " has not changed while we are inserting the node."); - } - } while (true); + // if the entry node is already set, then we have to do all connections first before we can + // promote ourselves as entry node + + int lowestUnsetLevel = 0; + int curMaxLevel; + do { + curMaxLevel = hnsw.numLevels() - 1; + // NOTE: the entry node and max level may not be paired, but because we get the level first + // we ensure that the entry node we get later will always exist on the curMaxLevel + int[] eps = new int[] {hnsw.entryNode()}; + + // we first do the search from top to bottom + // for levels > nodeLevel search with topk = 1 + GraphBuilderKnnCollector candidates = entryCandidates; + for (int level = curMaxLevel; level > nodeLevel; level--) { + candidates.clear(); + graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null); + eps[0] = candidates.popNode(); + } + + // for levels <= nodeLevel search with topk = beamWidth, and add connections + candidates = beamCandidates; + NeighborArray[] scratchPerLevel = + new NeighborArray[Math.min(nodeLevel, curMaxLevel) - lowestUnsetLevel + 1]; + for (int i = scratchPerLevel.length - 1; i >= 0; i--) { + int level = i + lowestUnsetLevel; + candidates.clear(); + graphSearcher.searchLevel(candidates, scorer, level, eps, hnsw, null); + eps = candidates.popUntilNearestKNodes(); + scratchPerLevel[i] = new NeighborArray(Math.max(beamCandidates.k(), M + 1), false); + popToScratch(candidates, scratchPerLevel[i]); + } + + // then do connections from bottom up + for (int i = 0; i < scratchPerLevel.length; i++) { + addDiverseNeighbors(i + lowestUnsetLevel, node, scratchPerLevel[i]); + } + lowestUnsetLevel += scratchPerLevel.length; + assert lowestUnsetLevel == Math.min(nodeLevel, curMaxLevel) + 1; + if (lowestUnsetLevel > nodeLevel) { + return; + } + assert lowestUnsetLevel == curMaxLevel + 1 && nodeLevel > curMaxLevel; + if (hnsw.tryPromoteNewEntryNode(node, nodeLevel, curMaxLevel)) { + return; + } + if (hnsw.numLevels() == curMaxLevel + 1) { + // This should never happen if all the calculations are correct + throw new IllegalStateException( + "We're not able to promote node " + + node + + " at level " + + nodeLevel + + " as entry node. But the max graph level " + + curMaxLevel + + " has not changed while we are inserting the node."); + } + } while (true); + } } private long printGraphBuildStatus(int node, long start, long t) { @@ -393,11 +394,12 @@ private static void popToScratch(GraphBuilderKnnCollector candidates, NeighborAr */ private boolean diversityCheck(int candidate, float score, NeighborArray neighbors) throws IOException { - RandomVectorScorer scorer = scorerSupplier.scorer(candidate); - for (int i = 0; i < neighbors.size(); i++) { - float neighborSimilarity = scorer.score(neighbors.nodes()[i]); - if (neighborSimilarity >= score) { - return false; + try (RandomVectorScorer scorer = scorerSupplier.scorer(candidate)) { + for (int i = 0; i < neighbors.size(); i++) { + float neighborSimilarity = scorer.score(neighbors.nodes()[i]); + if (neighborSimilarity >= score) { + return false; + } } } return true; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java index 716364a39dc2..592aedc57e6b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/NeighborArray.java @@ -237,8 +237,10 @@ private int descSortFindRightMostInsertionPoint(float newScore, int bound) { */ private int findWorstNonDiverse(int nodeOrd, RandomVectorScorerSupplier scorerSupplier) throws IOException { - RandomVectorScorer scorer = scorerSupplier.scorer(nodeOrd); - int[] uncheckedIndexes = sort(scorer); + int[] uncheckedIndexes; + try (RandomVectorScorer scorer = scorerSupplier.scorer(nodeOrd)) { + uncheckedIndexes = sort(scorer); + } assert uncheckedIndexes != null : "We will always have something unchecked"; int uncheckedCursor = uncheckedIndexes.length - 1; for (int i = size - 1; i > 0; i--) { @@ -263,25 +265,26 @@ private boolean isWorstNonDiverse( RandomVectorScorerSupplier scorerSupplier) throws IOException { float minAcceptedSimilarity = scores[candidateIndex]; - RandomVectorScorer scorer = scorerSupplier.scorer(nodes[candidateIndex]); - if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { - // the candidate itself is unchecked - for (int i = candidateIndex - 1; i >= 0; i--) { - float neighborSimilarity = scorer.score(nodes[i]); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { - return true; + try (RandomVectorScorer scorer = scorerSupplier.scorer(nodes[candidateIndex])) { + if (candidateIndex == uncheckedIndexes[uncheckedCursor]) { + // the candidate itself is unchecked + for (int i = candidateIndex - 1; i >= 0; i--) { + float neighborSimilarity = scorer.score(nodes[i]); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } } - } - } else { - // else we just need to make sure candidate does not violate diversity with the (newly - // inserted) unchecked nodes - assert candidateIndex > uncheckedIndexes[uncheckedCursor]; - for (int i = uncheckedCursor; i >= 0; i--) { - float neighborSimilarity = scorer.score(nodes[uncheckedIndexes[i]]); - // candidate node is too similar to node i given its score relative to the base node - if (neighborSimilarity >= minAcceptedSimilarity) { - return true; + } else { + // else we just need to make sure candidate does not violate diversity with the (newly + // inserted) unchecked nodes + assert candidateIndex > uncheckedIndexes[uncheckedCursor]; + for (int i = uncheckedCursor; i >= 0; i--) { + float neighborSimilarity = scorer.score(nodes[uncheckedIndexes[i]]); + // candidate node is too similar to node i given its score relative to the base node + if (neighborSimilarity >= minAcceptedSimilarity) { + return true; + } } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java index a135df436991..7d3b6c1b7f9b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java @@ -22,7 +22,7 @@ import org.apache.lucene.util.Bits; /** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */ -public interface RandomVectorScorer { +public interface RandomVectorScorer extends AutoCloseable { /** * Returns the score between the query and the provided node. * @@ -56,9 +56,13 @@ default Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + @Override + default void close() {} + /** Creates a default scorer for random access vectors. */ abstract class AbstractRandomVectorScorer implements RandomVectorScorer { private final KnnVectorValues values; + private final Bag scorerPool; /** * Creates a new scorer for the given vector values. @@ -66,7 +70,12 @@ abstract class AbstractRandomVectorScorer implements RandomVectorScorer { * @param values the vector values */ public AbstractRandomVectorScorer(KnnVectorValues values) { + this(values, null); + } + + public AbstractRandomVectorScorer(KnnVectorValues values, Bag scorerPool) { this.values = values; + this.scorerPool = scorerPool; } @Override @@ -83,5 +92,12 @@ public int ordToDoc(int ord) { public Bits getAcceptOrds(Bits acceptDocs) { return values.getAcceptOrds(acceptDocs); } + + @Override + public void close() { + if (scorerPool != null) { + scorerPool.offer(this); + } + } } }