Skip to content

Commit

Permalink
reuse Floats and RandomVectorScorers
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Sokolov committed Nov 8, 2024
1 parent f1e0007 commit ef13bad
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.hnsw.Bag;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;

Expand Down Expand Up @@ -122,26 +123,24 @@ public String toString() {
private static final class FloatScoringSupplier implements RandomVectorScorerSupplier {
private final FloatVectorValues vectorValues;
private final VectorSimilarityFunction similarityFunction;
private final FloatVectorValues.Floats queryVectors;
private final FloatVectorValues.Floats targetVectors;
private final Bag<RandomVectorScorer> pool = new Bag<>();

private FloatScoringSupplier(
FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
this.vectorValues = vectorValues;
this.similarityFunction = similarityFunction;
this.queryVectors = vectorValues.vectors();
this.targetVectors = vectorValues.vectors();
}

@Override
public RandomVectorScorer scorer(int ord) {
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorValues) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(queryVectors.get(ord), targetVectors.get(node));
}
};
public RandomVectorScorer scorer(int ord) throws IOException {
FloatVectorScorer scorer = (FloatVectorScorer) pool.poll();
if (scorer != null) {
scorer.setQuery(ord);
} else {
scorer = new FloatVectorScorer(vectorValues, ord, similarityFunction, pool);
}
return scorer;
}

@Override
Expand All @@ -152,22 +151,40 @@ public String toString() {

/** A {@link RandomVectorScorer} for float vectors. */
private static class FloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
private final float[] query;
private final FloatVectorValues.Floats vectors, queryVectors;
private final VectorSimilarityFunction similarityFunction;
private final FloatVectorValues.Floats targetVectors;
private float[] query;

FloatVectorScorer(
FloatVectorValues vectorValues,
int ord,
VectorSimilarityFunction similarityFunction,
Bag<RandomVectorScorer> pool)
throws IOException {
super(vectorValues, pool);
this.similarityFunction = similarityFunction;
vectors = vectorValues.vectors();
queryVectors = vectorValues.vectors();
query = queryVectors.get(ord);
}

public FloatVectorScorer(
FloatVectorValues vectorValues, float[] query, VectorSimilarityFunction similarityFunction)
throws IOException {
super(vectorValues);
this.query = query;
this.similarityFunction = similarityFunction;
this.targetVectors = vectorValues.vectors();
vectors = vectorValues.vectors();
queryVectors = null;
this.query = query;
}

private void setQuery(int ord) throws IOException {
query = queryVectors.get(ord);
}

@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, targetVectors.get(node));
return similarityFunction.compare(query, vectors.get(node));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.Bag;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;

/** Read the vector values from the index input. This supports both iterated and random access. */
public abstract class OffHeapFloatVectorValues extends FloatVectorValues implements HasIndexSlice {

private final Bag<Floats> pool = new Bag<>();
protected final int dimension;
protected final int size;
protected final IndexInput slice;
Expand Down Expand Up @@ -73,6 +74,10 @@ public IndexInput getSlice() {

@Override
public Floats vectors() {
Floats floats = pool.poll();
if (floats != null) {
return floats;
}
IndexInput sliceCopy = slice.clone();
float[] value = new float[dimension];
return new Floats() {
Expand All @@ -88,6 +93,11 @@ public float[] get(int targetOrd) throws IOException {
lastOrd = targetOrd;
return value;
}

@Override
public void close() throws IOException {
pool.offer(this);
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,11 @@ public Floats vectors() throws IOException {
public float[] get(int ord) throws IOException {
return rawVectors.get(ord);
}

@Override
public void close() throws IOException {
rawVectors.close();
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1216,11 +1216,9 @@ public DocIndexIterator iterator() throws IOException {

static final class NormalizedFloatVectorValues extends FloatVectorValues {
private final FloatVectorValues vectorValues;
private final Floats floats;

public NormalizedFloatVectorValues(FloatVectorValues vectorValues) throws IOException {
this.vectorValues = vectorValues;
floats = vectorValues.vectors();
}

@Override
Expand All @@ -1239,15 +1237,22 @@ public int ordToDoc(int ord) {
}

@Override
public Floats vectors() {
public Floats vectors() throws IOException {
float[] normalizedVector = new float[vectorValues.dimension()];
return new Floats() {
Floats delegate = vectorValues.vectors();

@Override
public float[] get(int ord) throws IOException {
System.arraycopy(floats.get(ord), 0, normalizedVector, 0, normalizedVector.length);
System.arraycopy(delegate.get(ord), 0, normalizedVector, 0, normalizedVector.length);
VectorUtil.l2normalize(normalizedVector);
return normalizedVector;
}

@Override
public void close() throws IOException {
delegate.close();
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public abstract class FloatVectorValues extends KnnVectorValues {
protected FloatVectorValues() {}

/** A random access (lookup by ord) provider of the vector values */
public abstract static class Floats {
public abstract static class Floats implements AutoCloseable {
/**
* Return the vector value for the given vector ordinal which must be in [0, size() - 1],
* otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls.
Expand All @@ -42,6 +42,11 @@ public abstract static class Floats {
*/
public abstract float[] get(int ord) throws IOException;

@Override
public void close() throws IOException {
// by default do nothing. Some implementations do more interesting resource management.
}

/** A Floats containing no vectors. Throws UnsupportedOperationException if get() is called. */
public static final Floats EMPTY =
new Floats() {
Expand Down Expand Up @@ -118,6 +123,9 @@ public Floats vectors() {
public float[] get(int ord) throws IOException {
return vectors.get(ord);
}

@Override
public void close() {}
};
}

Expand Down
71 changes: 71 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/hnsw/Bag.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.lucene.util.hnsw;

/**
* A collection of objects that is threadsafe, providing offer(T) that tries to add an element and
* poll() that removes and returns an element or null. The storage will never grow. There are no
* guarantees about which object will be returned from poll(), just that it will be one that was
* added by offer().
*/
public class Bag<T> {
private static final int DEFAULT_CAPACITY = 64;

private final Object[] elements;
private int writeTo;
private int readFrom;

public Bag() {
this(DEFAULT_CAPACITY);
}

public Bag(int capacity) {
elements = new Object[capacity];
}

public synchronized boolean offer(T element) {
if (full()) {
return false;
}
elements[writeTo] = element;
writeTo = (writeTo + 1) % elements.length;
return true;
}

@SuppressWarnings("unchecked")
public synchronized T poll() {
if (empty()) {
return null;
}
T result = (T) elements[readFrom];
readFrom = (readFrom + 1) % elements.length;
return result;
}

private boolean full() {
int headroom = readFrom - 1 - writeTo;
if (headroom < 0) {
headroom += elements.length;
}
return headroom == 0;
}

private boolean empty() {
return readFrom == writeTo;
}
}
Loading

0 comments on commit ef13bad

Please sign in to comment.