diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 4ea6265d7ae1..3ba407b2b4de 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -254,6 +254,8 @@ New Features * GITHUB#13268: Add ability for UnifiedHighlighter to highlight a field based on combined matches from multiple fields. (Mayya Sharipova, Jim Ferenczi) +* GITHUB#13288: Make HNSW and Flat storage vector formats easier to extend with new FlatVectorScorer interface. Add + new HNSW format for binary quantized vectors. (Ben Trent) Improvements --------------------- diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java index 6719639b67d9..f43d33ca2747 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java @@ -29,7 +29,7 @@ * * @lucene.experimental */ -public class Word2VecModel implements RandomAccessVectorValues { +public class Word2VecModel implements RandomAccessVectorValues.Floats { private final int dictionarySize; private final int vectorDimension; @@ -88,7 +88,7 @@ public int size() { } @Override - public RandomAccessVectorValues copy() throws IOException { + public Word2VecModel copy() throws IOException { return new Word2VecModel( this.dictionarySize, this.vectorDimension, this.termsAndVectors, this.word2Vec); } diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java index d7a37bbba9b8..68ea5ec555ea 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecSynonymProvider.java 
@@ -23,6 +23,7 @@ import java.io.IOException; import java.util.LinkedList; import java.util.List; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopDocs; @@ -44,6 +45,7 @@ public class Word2VecSynonymProvider { VectorSimilarityFunction.DOT_PRODUCT; private final Word2VecModel word2VecModel; private final OnHeapHnswGraph hnswGraph; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); /** * Word2VecSynonymProvider constructor @@ -53,7 +55,7 @@ public class Word2VecSynonymProvider { public Word2VecSynonymProvider(Word2VecModel model) throws IOException { this.word2VecModel = model; RandomVectorScorerSupplier scorerSupplier = - RandomVectorScorerSupplier.createFloats(word2VecModel, SIMILARITY_FUNCTION); + defaultFlatVectorScorer.getRandomVectorScorerSupplier(SIMILARITY_FUNCTION, word2VecModel); HnswGraphBuilder builder = HnswGraphBuilder.create( scorerSupplier, @@ -75,7 +77,7 @@ public List getSynonyms( float[] query = word2VecModel.vectorValue(term); if (query != null) { RandomVectorScorer scorer = - RandomVectorScorer.createFloats(word2VecModel, SIMILARITY_FUNCTION, query); + defaultFlatVectorScorer.getRandomVectorScorer(SIMILARITY_FUNCTION, word2VecModel, query); KnnCollector synonyms = HnswGraphSearcher.search( scorer, diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index 196e3b24b795..52972e9dcda4 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -49,7 +49,7 @@ public final class Lucene90HnswGraphBuilder { private final 
Lucene90NeighborArray scratch; private final VectorSimilarityFunction similarityFunction; - private final RandomAccessVectorValues vectorValues; + private final RandomAccessVectorValues.Floats vectorValues; private final SplittableRandom random; private final Lucene90BoundsChecker bound; final Lucene90OnHeapHnswGraph hnsw; @@ -58,7 +58,7 @@ public final class Lucene90HnswGraphBuilder { // we need two sources of vectors in order to perform diversity check comparisons without // colliding - private final RandomAccessVectorValues buildVectors; + private final RandomAccessVectorValues.Floats buildVectors; /** * Reads all the vectors from vector values, builds a graph connecting them by their dense @@ -73,7 +73,7 @@ public final class Lucene90HnswGraphBuilder { * to ensure repeatable construction. */ public Lucene90HnswGraphBuilder( - RandomAccessVectorValues vectors, + RandomAccessVectorValues.Floats vectors, VectorSimilarityFunction similarityFunction, int maxConn, int beamWidth, @@ -104,8 +104,7 @@ public Lucene90HnswGraphBuilder( * @param vectors the vectors for which to build a nearest neighbors graph. 
Must be an independet * accessor for the vectors */ - public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) - throws IOException { + public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException { if (vectors == vectorValues) { throw new IllegalArgumentException( "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); @@ -231,7 +230,7 @@ private boolean diversityCheck( float[] candidate, float score, Lucene90NeighborArray neighbors, - RandomAccessVectorValues vectorValues) + RandomAccessVectorValues.Floats vectorValues) throws IOException { bound.set(score); for (int i = 0; i < neighbors.size(); i++) { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 69ff0608f097..ea63c926ac1d 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -350,7 +350,7 @@ int size() { /** Read the vector values from the index input. This supports both iterated and random access. 
*/ static class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues { + implements RandomAccessVectorValues.Floats { final int dimension; final int[] ordToDoc; @@ -419,7 +419,7 @@ public int advance(int target) { } @Override - public RandomAccessVectorValues copy() { + public OffHeapFloatVectorValues copy() { return new OffHeapFloatVectorValues(dimension, ordToDoc, dataIn.clone()); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java index a8bab2c14755..52f2146e836b 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java @@ -74,7 +74,7 @@ public static NeighborQueue search( float[] query, int topK, int numSeed, - RandomAccessVectorValues vectors, + RandomAccessVectorValues.Floats vectors, VectorSimilarityFunction similarityFunction, HnswGraph graphValues, Bits acceptOrds, diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 9ff575624d34..f29c04ed10ca 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -26,6 +26,7 @@ import java.util.function.IntUnaryOperator; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.CorruptIndexException; import 
org.apache.lucene.index.FieldInfo; @@ -56,6 +57,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader { private final Map fields = new HashMap<>(); private final IndexInput vectorData; private final IndexInput vectorIndex; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); Lucene91HnswVectorsReader(SegmentReadState state) throws IOException { int versionMeta = readMetadata(state); @@ -233,7 +235,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits OffHeapFloatVectorValues vectorValues = getOffHeapVectorValues(fieldEntry); RandomVectorScorer scorer = - RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target); + defaultFlatVectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, vectorValues, target); HnswGraphSearcher.search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), @@ -387,7 +390,7 @@ int ordToDoc(int ord) { /** Read the vector values from the index input. This supports both iterated and random access. 
*/ static class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues { + implements RandomAccessVectorValues.Floats { private final int dimension; private final int size; @@ -464,7 +467,7 @@ public int advance(int target) { } @Override - public RandomAccessVectorValues copy() { + public OffHeapFloatVectorValues copy() { return new OffHeapFloatVectorValues(dimension, size, ordToDoc, dataIn.clone()); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index 76c554eef446..833efdf80259 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -25,6 +25,7 @@ import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.FieldInfo; @@ -55,6 +56,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader { private final Map fields = new HashMap<>(); private final IndexInput vectorData; private final IndexInput vectorIndex; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); Lucene92HnswVectorsReader(SegmentReadState state) throws IOException { int versionMeta = readMetadata(state); @@ -232,7 +234,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData); RandomVectorScorer scorer = - RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target); 
+ defaultFlatVectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, vectorValues, target); HnswGraphSearcher.search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 0a9e29695930..ec0cbf7379a4 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -28,7 +28,7 @@ /** Read the vector values from the index input. This supports both iterated and random access. */ abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues { + implements RandomAccessVectorValues.Floats { protected final int dimension; protected final int size; @@ -114,7 +114,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, slice.clone()); } @@ -173,7 +173,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public OffHeapFloatVectorValues copy() throws IOException { return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone()); } @@ -240,7 +240,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public OffHeapFloatVectorValues copy() throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index b7bacf7da672..a948ab7bee3f 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -25,6 +25,7 @@ import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.FieldInfo; @@ -56,6 +57,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader { private final Map fields = new HashMap<>(); private final IndexInput vectorData; private final IndexInput vectorIndex; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); Lucene94HnswVectorsReader(SegmentReadState state) throws IOException { int versionMeta = readMetadata(state); @@ -269,7 +271,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits OffHeapFloatVectorValues vectorValues = OffHeapFloatVectorValues.load(fieldEntry, vectorData); RandomVectorScorer scorer = - RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target); + defaultFlatVectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, vectorValues, target); HnswGraphSearcher.search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), @@ -288,7 +291,8 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData); RandomVectorScorer scorer = - RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target); + 
defaultFlatVectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, vectorValues, target); HnswGraphSearcher.search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index a75e5b287889..b961bafabb16 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -30,7 +30,7 @@ /** Read the vector values from the index input. This supports both iterated and random access. */ abstract class OffHeapByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues { + implements RandomAccessVectorValues.Bytes { protected final int dimension; protected final int size; @@ -124,7 +124,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public OffHeapByteVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); } @@ -186,7 +186,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public OffHeapByteVectorValues copy() throws IOException { return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize); } @@ -253,7 +253,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public OffHeapByteVectorValues copy() throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 7c1f5afc07d0..95abedf2d872 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -28,7 +28,7 @@ /** Read the vector values from the index input. This supports both iterated and random access. */ abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues { + implements RandomAccessVectorValues.Floats { protected final int dimension; protected final int size; @@ -120,7 +120,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public OffHeapFloatVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); } @@ -182,7 +182,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public OffHeapFloatVectorValues copy() throws IOException { return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize); } @@ -249,7 +249,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public OffHeapFloatVectorValues copy() throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index 6137d9762e88..72b7cfc82f21 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -24,8 +24,9 @@ import java.util.HashMap; import java.util.Map; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.HnswGraphProvider; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; @@ -65,6 +66,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements private final Map fields = new HashMap<>(); private final IndexInput vectorData; private final IndexInput vectorIndex; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); Lucene95HnswVectorsReader(SegmentReadState state) throws IOException { this.fieldInfos = state.fieldInfos; @@ -300,7 +302,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits fieldEntry.vectorDataLength, vectorData); RandomVectorScorer scorer = - RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target); + defaultFlatVectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, vectorValues, target); HnswGraphSearcher.search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), @@ -328,7 +331,8 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits fieldEntry.vectorDataLength, vectorData); RandomVectorScorer scorer = - RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target); + defaultFlatVectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, vectorValues, target); HnswGraphSearcher.search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, 
vectorValues::ordToDoc), diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java index b90fc03253d8..ce07b2548495 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -231,7 +231,7 @@ private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOExcepti private void writeGraph( IndexOutput graphData, - RandomAccessVectorValues vectorValues, + RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction, long graphDataOffset, long[] offsets, diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java index 99dd1c1ebe6d..dbb9a71b4218 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java @@ -24,6 +24,7 @@ import java.util.Objects; import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; @@ -54,8 +55,9 @@ public final class Lucene91HnswGraphBuilder { private final double ml; private final Lucene91NeighborArray scratch; + private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private final VectorSimilarityFunction similarityFunction; - private final RandomAccessVectorValues vectorValues; + 
private final RandomAccessVectorValues.Floats vectorValues; private final SplittableRandom random; private final Lucene91BoundsChecker bound; private final HnswGraphSearcher graphSearcher; @@ -66,7 +68,7 @@ public final class Lucene91HnswGraphBuilder { // we need two sources of vectors in order to perform diversity check comparisons without // colliding - private RandomAccessVectorValues buildVectors; + private RandomAccessVectorValues.Floats buildVectors; /** * Reads all the vectors from vector values, builds a graph connecting them by their dense @@ -81,7 +83,7 @@ public final class Lucene91HnswGraphBuilder { * to ensure repeatable construction. */ public Lucene91HnswGraphBuilder( - RandomAccessVectorValues vectors, + RandomAccessVectorValues.Floats vectors, VectorSimilarityFunction similarityFunction, int maxConn, int beamWidth, @@ -118,8 +120,7 @@ public Lucene91HnswGraphBuilder( * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet * accessor for the vectors */ - public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues vectors) - throws IOException { + public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException { if (vectors == vectorValues) { throw new IllegalArgumentException( "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); @@ -146,7 +147,7 @@ public void setInfoStream(InfoStream infoStream) { /** Inserts a doc with vector value to the graph */ void addGraphNode(int node, float[] value) throws IOException { RandomVectorScorer scorer = - RandomVectorScorer.createFloats(vectorValues, similarityFunction, value); + defaultFlatVectorScorer.getRandomVectorScorer(similarityFunction, vectorValues, value); HnswGraphBuilder.GraphBuilderKnnCollector candidates; final int nodeLevel = getRandomGraphLevel(ml, random); int curMaxLevel = hnsw.numLevels() - 1; @@ -253,7 +254,7 @@ private boolean diversityCheck( float[] candidate, float score, 
Lucene91NeighborArray neighbors, - RandomAccessVectorValues vectorValues) + RandomAccessVectorValues.Floats vectorValues) throws IOException { bound.set(score); for (int i = 0; i < neighbors.size(); i++) { diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index 22234c3f9ec7..b58b2e21a4f4 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -239,7 +239,7 @@ private void writeMeta( } private Lucene91OnHeapHnswGraph writeGraph( - RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction) + RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction) throws IOException { // build graph diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index 09ab197c6d5c..4dd8f1f3054d 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -26,6 +26,7 @@ import java.util.Arrays; import org.apache.lucene.codecs.BufferingKnnVectorsWriter; import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; @@ -273,12 +274,12 @@ private void writeMeta( } private OnHeapHnswGraph writeGraph( - RandomAccessVectorValues vectorValues, 
VectorSimilarityFunction similarityFunction) + RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction) throws IOException { - + DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); // build graph RandomVectorScorerSupplier scorerSupplier = - RandomVectorScorerSupplier.createFloats(vectorValues, similarityFunction); + defaultFlatVectorScorer.getRandomVectorScorerSupplier(similarityFunction, vectorValues); HnswGraphBuilder hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index bea2a22343e2..d8cdb1739f7c 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -29,6 +29,7 @@ import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; @@ -409,6 +410,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE // TODO: separate random access vector values from DocIdSetIterator? 
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize; OnHeapHnswGraph graph = null; + DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); if (docsWithField.cardinality() != 0) { // build graph graph = @@ -421,8 +423,8 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE vectorDataInput, byteSize); RandomVectorScorerSupplier scorerSupplier = - RandomVectorScorerSupplier.createBytes( - vectorValues, fieldInfo.getVectorSimilarityFunction()); + defaultFlatVectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), vectorValues); HnswGraphBuilder hnswGraphBuilder = HnswGraphBuilder.create( scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); @@ -437,8 +439,8 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE vectorDataInput, byteSize); RandomVectorScorerSupplier scorerSupplier = - RandomVectorScorerSupplier.createFloats( - vectorValues, fieldInfo.getVectorSimilarityFunction()); + defaultFlatVectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), vectorValues); HnswGraphBuilder hnswGraphBuilder = HnswGraphBuilder.create( scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); @@ -656,15 +658,15 @@ public float[] copyValue(float[] value) { this.dim = fieldInfo.getVectorDimension(); this.docsWithField = new DocsWithFieldSet(); vectors = new ArrayList<>(); - RandomAccessVectorValues raVectors = new RAVectorValues<>(vectors, dim); + DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); RandomVectorScorerSupplier scorerSupplier = switch (fieldInfo.getVectorEncoding()) { - case BYTE -> RandomVectorScorerSupplier.createBytes( - (RandomAccessVectorValues) raVectors, - fieldInfo.getVectorSimilarityFunction()); - case FLOAT32 -> RandomVectorScorerSupplier.createFloats( - (RandomAccessVectorValues) raVectors, - fieldInfo.getVectorSimilarityFunction()); + case BYTE -> 
defaultFlatVectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + RandomAccessVectorValues.fromBytes((List) vectors, dim)); + case FLOAT32 -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + RandomAccessVectorValues.fromFloats((List) vectors, dim)); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); @@ -708,34 +710,4 @@ public long ramBytesUsed() { + hnswGraphBuilder.getGraph().ramBytesUsed(); } } - - private static class RAVectorValues implements RandomAccessVectorValues { - private final List vectors; - private final int dim; - - RAVectorValues(List vectors, int dim) { - this.vectors = vectors; - this.dim = dim; - } - - @Override - public int size() { - return vectors.size(); - } - - @Override - public int dimension() { - return dim; - } - - @Override - public T vectorValue(int targetOrd) throws IOException { - return vectors.get(targetOrd); - } - - @Override - public RAVectorValues copy() throws IOException { - return this; - } - } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index 32e2d1a426a0..c4e315a72e22 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -29,6 +29,7 @@ import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import 
org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; @@ -436,27 +437,28 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE OnHeapHnswGraph graph = null; int[][] vectorIndexNodeOffsets = null; if (docsWithField.cardinality() != 0) { + DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); final RandomVectorScorerSupplier scorerSupplier; switch (fieldInfo.getVectorEncoding()) { case BYTE: scorerSupplier = - RandomVectorScorerSupplier.createBytes( + defaultFlatVectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), new OffHeapByteVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, - byteSize), - fieldInfo.getVectorSimilarityFunction()); + byteSize)); break; case FLOAT32: scorerSupplier = - RandomVectorScorerSupplier.createFloats( + defaultFlatVectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), new OffHeapFloatVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), vectorDataInput, - byteSize), - fieldInfo.getVectorSimilarityFunction()); + byteSize)); break; default: throw new IllegalArgumentException( @@ -695,15 +697,15 @@ public float[] copyValue(float[] value) { this.dim = fieldInfo.getVectorDimension(); this.docsWithField = new DocsWithFieldSet(); vectors = new ArrayList<>(); - RAVectorValues raVectors = new RAVectorValues<>(vectors, dim); + DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); RandomVectorScorerSupplier scorerSupplier = switch (fieldInfo.getVectorEncoding()) { - case BYTE -> RandomVectorScorerSupplier.createBytes( - (RandomAccessVectorValues) raVectors, - fieldInfo.getVectorSimilarityFunction()); - case FLOAT32 -> RandomVectorScorerSupplier.createFloats( - (RandomAccessVectorValues) raVectors, - fieldInfo.getVectorSimilarityFunction()); + case BYTE -> 
defaultFlatVectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + RandomAccessVectorValues.fromBytes((List) vectors, dim)); + case FLOAT32 -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + RandomAccessVectorValues.fromFloats((List) vectors, dim)); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); @@ -746,34 +748,4 @@ public long ramBytesUsed() { + hnswGraphBuilder.getGraph().ramBytesUsed(); } } - - private static class RAVectorValues implements RandomAccessVectorValues { - private final List vectors; - private final int dim; - - RAVectorValues(List vectors, int dim) { - this.vectors = vectors; - this.dim = dim; - } - - @Override - public int size() { - return vectors.size(); - } - - @Override - public int dimension() { - return dim; - } - - @Override - public T vectorValue(int targetOrd) throws IOException { - return vectors.get(targetOrd); - } - - @Override - public RandomAccessVectorValues copy() throws IOException { - return this; - } - } } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99RWHnswScalarQuantizationVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99RWHnswScalarQuantizationVectorsFormat.java index e5b7b91a4eaa..a1d7648d0775 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99RWHnswScalarQuantizationVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene99/Lucene99RWHnswScalarQuantizationVectorsFormat.java @@ -18,9 +18,11 @@ package org.apache.lucene.backward_codecs.lucene99; import java.io.IOException; -import org.apache.lucene.codecs.FlatVectorsFormat; -import org.apache.lucene.codecs.FlatVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import 
org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer; import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; @@ -56,12 +58,16 @@ public int getMaxDimensions(String fieldName) { } static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat { - private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(); + private static final FlatVectorsFormat rawVectorFormat = + new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()); @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new Lucene99ScalarQuantizedVectorsWriter( - state, null, rawVectorFormat.fieldsWriter(state)); + state, + null, + rawVectorFormat.fieldsWriter(state), + new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer())); } } } diff --git a/lucene/codecs/src/java/module-info.java b/lucene/codecs/src/java/module-info.java index 73f53fbf96b9..8c8c2e83b94a 100644 --- a/lucene/codecs/src/java/module-info.java +++ b/lucene/codecs/src/java/module-info.java @@ -15,10 +15,13 @@ * limitations under the License. 
*/ +import org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat; + /** Lucene codecs and postings formats */ module org.apache.lucene.codecs { requires org.apache.lucene.core; + exports org.apache.lucene.codecs.bitvectors; exports org.apache.lucene.codecs.blockterms; exports org.apache.lucene.codecs.blocktreeords; exports org.apache.lucene.codecs.bloom; @@ -27,6 +30,8 @@ exports org.apache.lucene.codecs.uniformsplit; exports org.apache.lucene.codecs.uniformsplit.sharedterms; + provides org.apache.lucene.codecs.KnnVectorsFormat with + HnswBitVectorsFormat; provides org.apache.lucene.codecs.PostingsFormat with org.apache.lucene.codecs.blocktreeords.BlockTreeOrdsPostingsFormat, org.apache.lucene.codecs.bloom.BloomFilteringPostingsFormat, diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java new file mode 100644 index 000000000000..b8ff37c2654a --- /dev/null +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.codecs.bitvectors; + +import java.io.IOException; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +/** A bit vector scorer for scoring byte vectors. */ +public class FlatBitVectorsScorer implements FlatVectorsScorer { + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + throws IOException { + assert vectorValues instanceof RandomAccessVectorValues.Bytes; + if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) { + return new BitRandomVectorScorerSupplier(byteVectorValues); + } + throw new IllegalArgumentException( + "vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + float[] target) + throws IOException { + throw new IllegalArgumentException("bit vectors do not support float[] targets"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + byte[] target) + throws IOException { + assert vectorValues instanceof RandomAccessVectorValues.Bytes; + if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) { + return new BitRandomVectorScorer(byteVectorValues, target); + } + throw new IllegalArgumentException( + "vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + } + + static class BitRandomVectorScorer implements RandomVectorScorer { + private final 
RandomAccessVectorValues.Bytes vectorValues; + private final int bitDimensions; + private final byte[] query; + + BitRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) { + this.query = query; + this.bitDimensions = vectorValues.dimension() * Byte.SIZE; + this.vectorValues = vectorValues; + } + + @Override + public float score(int node) throws IOException { + return (bitDimensions - VectorUtil.xorBitCount(query, vectorValues.vectorValue(node))) + / (float) bitDimensions; + } + + @Override + public int maxOrd() { + return vectorValues.size(); + } + + @Override + public int ordToDoc(int ord) { + return vectorValues.ordToDoc(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return vectorValues.getAcceptOrds(acceptDocs); + } + } + + static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier { + protected final RandomAccessVectorValues.Bytes vectorValues; + protected final RandomAccessVectorValues.Bytes vectorValues1; + protected final RandomAccessVectorValues.Bytes vectorValues2; + + public BitRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) + throws IOException { + this.vectorValues = vectorValues; + this.vectorValues1 = vectorValues.copy(); + this.vectorValues2 = vectorValues.copy(); + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + byte[] query = vectorValues1.vectorValue(ord); + return new BitRandomVectorScorer(vectorValues2, query); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new BitRandomVectorScorerSupplier(vectorValues.copy()); + } + } + + @Override + public String toString() { + return "FlatBitVectorsScorer()"; + } +} diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/HnswBitVectorsFormat.java b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/HnswBitVectorsFormat.java new file mode 100644 index 000000000000..f5888c0a03c1 --- /dev/null +++ 
b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/HnswBitVectorsFormat.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.bitvectors; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; 
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.TaskExecutor; +import org.apache.lucene.util.hnsw.HnswGraph; + +/** + * Encodes bit vector values into an associated graph connecting the documents having values. The + * graph is used to power HNSW search. The format consists of two files, and uses {@link + * Lucene99FlatVectorsFormat} to store the actual vectors, but with a custom scorer implementation. + * For details on graph storage and file extensions, see {@link Lucene99HnswVectorsFormat}. + * + * @lucene.experimental + */ +public final class HnswBitVectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "HnswBitVectorsFormat"; + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link HnswGraph} + * for details. + */ + private final int beamWidth; + + /** The format for storing, reading, merging vectors on disk */ + private final FlatVectorsFormat flatVectorsFormat; + + private final int numMergeWorkers; + private final TaskExecutor mergeExec; + + /** Constructs a format using default graph construction parameters */ + public HnswBitVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters. 
+ * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public HnswBitVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); + } + + /** + * Constructs a format using the given graph construction parameters and merge settings. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public HnswBitVectorsFormat( + int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { + super(NAME); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to " + + MAXIMUM_MAX_CONN + + "; maxConn=" + + maxConn); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to " + + MAXIMUM_BEAM_WIDTH + + "; beamWidth=" + + beamWidth); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException( + "No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + if (mergeExec != null) { + this.mergeExec = new TaskExecutor(mergeExec); + } else { + this.mergeExec = null; + } + this.flatVectorsFormat = new Lucene99FlatVectorsFormat(new FlatBitVectorsScorer()); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { 
+ return new FlatBitVectorsWriter( + new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec)); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "HnswBitVectorsFormat(name=HnswBitVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } + + private static class FlatBitVectorsWriter extends KnnVectorsWriter { + private final KnnVectorsWriter delegate; + + public FlatBitVectorsWriter(KnnVectorsWriter delegate) { + this.delegate = delegate; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + delegate.mergeOneField(fieldInfo, mergeState); + } + + @Override + public void finish() throws IOException { + delegate.finish(); + } + + @Override + public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) { + throw new IllegalArgumentException("HnswBitVectorsFormat only supports BYTE encoding"); + } + return delegate.addField(fieldInfo); + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + delegate.flush(maxDoc, sortMap); + } + + @Override + public void close() throws IOException { + delegate.close(); + } + + @Override + public long ramBytesUsed() { + return delegate.ramBytesUsed(); + } + } +} diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/package-info.java b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/package-info.java new file mode 100644 index 000000000000..b1788f1ec7ff --- /dev/null +++ 
b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * A simple bit-vector format that supports hamming distance and storing vectors in an HNSW graph + */ +package org.apache.lucene.codecs.bitvectors; diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 6afa02e071bc..8ea9b22b35a3 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -47,7 +47,6 @@ import org.apache.lucene.util.BytesRefBuilder; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.StringHelper; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Reads vector values from a simple text format. 
All vectors are read up front and cached in RAM in @@ -282,8 +281,7 @@ int size() { } } - private static class SimpleTextFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues { + private static class SimpleTextFloatVectorValues extends FloatVectorValues { private final BytesRefBuilder scratch = new BytesRefBuilder(); private final FieldEntry entry; @@ -315,11 +313,6 @@ public float[] vectorValue() { return values[curOrd]; } - @Override - public RandomAccessVectorValues copy() { - return this; - } - @Override public int docID() { if (curOrd == -1) { @@ -364,15 +357,9 @@ private void readVector(float[] value) throws IOException { value[i] = Float.parseFloat(floatStrings[i]); } } - - @Override - public float[] vectorValue(int targetOrd) throws IOException { - return values[targetOrd]; - } } - private static class SimpleTextByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues { + private static class SimpleTextByteVectorValues extends ByteVectorValues { private final BytesRefBuilder scratch = new BytesRefBuilder(); private final FieldEntry entry; @@ -408,11 +395,6 @@ public byte[] vectorValue() { return binaryValue.bytes; } - @Override - public RandomAccessVectorValues copy() { - return this; - } - @Override public int docID() { if (curOrd == -1) { @@ -457,12 +439,6 @@ private void readVector(byte[] value) throws IOException { value[i] = (byte) Float.parseFloat(floatStrings[i]); } } - - @Override - public BytesRef vectorValue(int targetOrd) throws IOException { - binaryValue.bytes = values[curOrd]; - return binaryValue; - } } private int readInt(IndexInput in, BytesRef field) throws IOException { diff --git a/lucene/codecs/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/codecs/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat new file mode 100644 index 000000000000..27f66d2fc1e5 --- /dev/null +++ 
b/lucene/codecs/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat diff --git a/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java b/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java new file mode 100644 index 000000000000..71456b9cf8bd --- /dev/null +++ b/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.bitvectors; + +import static org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase.randomVector8; + +import java.io.IOException; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99Codec; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.StoredFields; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseIndexFileFormatTestCase; + +public class TestHnswBitVectorsFormat extends BaseIndexFileFormatTestCase { + @Override + protected Codec getCodec() { + return new Lucene99Codec() { + @Override + public KnnVectorsFormat getKnnVectorsFormatForField(String field) { + return new HnswBitVectorsFormat(); + } + }; + } + + @Override + protected void addRandomFields(Document doc) { + doc.add(new KnnByteVectorField("v2", randomVector8(30), VectorSimilarityFunction.DOT_PRODUCT)); + } + + public void 
testFloatVectorFails() throws IOException { + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); + IllegalArgumentException e = + expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc)); + assertTrue(e.getMessage().contains("HnswBitVectorsFormat only supports BYTE encoding")); + } + } + + public void testIndexAndSearchBitVectors() throws IOException { + byte[][] vectors = + new byte[][] { + new byte[] {(byte) 0b10101110, (byte) 0b01010111}, + new byte[] {(byte) 0b11110000, (byte) 0b00001111}, + new byte[] {(byte) 0b11001100, (byte) 0b00110011}, + new byte[] {(byte) 0b11111111, (byte) 0b00000000}, + new byte[] {(byte) 0b00000000, (byte) 0b00000000} + }; + try (Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + int id = 0; + for (byte[] vector : vectors) { + Document doc = new Document(); + doc.add(new KnnByteVectorField("v1", vector, VectorSimilarityFunction.DOT_PRODUCT)); + doc.add(new StringField("id", Integer.toString(id++), Field.Store.YES)); + w.addDocument(doc); + } + w.commit(); + w.forceMerge(1); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + TopKnnCollector collector = new TopKnnCollector(3, Integer.MAX_VALUE); + r.searchNearestVectors("v1", vectors[0], collector, null); + TopDocs topDocs = collector.topDocs(); + assertEquals(3, topDocs.scoreDocs.length); + + StoredFields fields = r.storedFields(); + assertEquals("0", fields.document(topDocs.scoreDocs[0].doc).get("id")); + assertEquals(1.0, topDocs.scoreDocs[0].score, 1e-12); + assertEquals("2", fields.document(topDocs.scoreDocs[1].doc).get("id")); + assertEquals(0.625, topDocs.scoreDocs[1].score, 1e-12); + assertEquals("1", fields.document(topDocs.scoreDocs[2].doc).get("id")); + assertEquals(0.5, topDocs.scoreDocs[2].score, 
1e-12); + } + } + } + + public void testToString() { + FilterCodec customCodec = + new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new HnswBitVectorsFormat(10, 20); + } + }; + String expectedString = + "HnswBitVectorsFormat(name=HnswBitVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=FlatBitVectorsScorer()))"; + assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(-1, 20)); + expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(0, 20)); + expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(20, 0)); + expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(20, -1)); + expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(512 + 1, 20)); + expectThrows(IllegalArgumentException.class, () -> new HnswBitVectorsFormat(20, 3201)); + } +} diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index 04ee07275411..94ff818c499d 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -63,6 +63,7 @@ org.apache.lucene.test_framework; exports org.apache.lucene.util.quantization; + exports org.apache.lucene.codecs.hnsw; provides org.apache.lucene.analysis.TokenizerFactory with org.apache.lucene.analysis.standard.StandardTokenizerFactory; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java new file mode 100644 index 000000000000..e5496f3e10e0 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.hnsw; + +import java.io.IOException; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +/** + * Default implementation of {@link FlatVectorsScorer}. 
+ * + * @lucene.experimental + */ +public class DefaultFlatVectorScorer implements FlatVectorsScorer { + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof RandomAccessVectorValues.Floats floatVectorValues) { + return new FloatScoringSupplier(floatVectorValues, similarityFunction); + } else if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) { + return new ByteScoringSupplier(byteVectorValues, similarityFunction); + } + throw new IllegalArgumentException( + "vectorValues must be an instance of RandomAccessVectorValues.Floats or RandomAccessVectorValues.Bytes"); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + float[] target) + throws IOException { + assert vectorValues instanceof RandomAccessVectorValues.Floats; + if (target.length != vectorValues.dimension()) { + throw new IllegalArgumentException( + "vector query dimension: " + + target.length + + " differs from field dimension: " + + vectorValues.dimension()); + } + return new FloatVectorScorer( + (RandomAccessVectorValues.Floats) vectorValues, target, similarityFunction); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + byte[] target) + throws IOException { + assert vectorValues instanceof RandomAccessVectorValues.Bytes; + if (target.length != vectorValues.dimension()) { + throw new IllegalArgumentException( + "vector query dimension: " + + target.length + + " differs from field dimension: " + + vectorValues.dimension()); + } + return new ByteVectorScorer( + (RandomAccessVectorValues.Bytes) vectorValues, target, similarityFunction); + } + + @Override + public String toString() { + return 
"DefaultFlatVectorScorer()"; + } + + /** RandomVectorScorerSupplier for bytes vector */ + private static final class ByteScoringSupplier implements RandomVectorScorerSupplier { + private final RandomAccessVectorValues.Bytes vectors; + private final RandomAccessVectorValues.Bytes vectors1; + private final RandomAccessVectorValues.Bytes vectors2; + private final VectorSimilarityFunction similarityFunction; + + private ByteScoringSupplier( + RandomAccessVectorValues.Bytes vectors, VectorSimilarityFunction similarityFunction) + throws IOException { + this.vectors = vectors; + vectors1 = vectors.copy(); + vectors2 = vectors.copy(); + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + return new ByteVectorScorer(vectors2, vectors1.vectorValue(ord), similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new ByteScoringSupplier(vectors, similarityFunction); + } + } + + /** RandomVectorScorerSupplier for Float vector */ + private static final class FloatScoringSupplier implements RandomVectorScorerSupplier { + private final RandomAccessVectorValues.Floats vectors; + private final RandomAccessVectorValues.Floats vectors1; + private final RandomAccessVectorValues.Floats vectors2; + private final VectorSimilarityFunction similarityFunction; + + private FloatScoringSupplier( + RandomAccessVectorValues.Floats vectors, VectorSimilarityFunction similarityFunction) + throws IOException { + this.vectors = vectors; + vectors1 = vectors.copy(); + vectors2 = vectors.copy(); + this.similarityFunction = similarityFunction; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + return new FloatVectorScorer(vectors2, vectors1.vectorValue(ord), similarityFunction); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return new FloatScoringSupplier(vectors, similarityFunction); + } + } + + /** A 
{@link RandomVectorScorer} for float vectors. */ + private static class FloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final RandomAccessVectorValues.Floats values; + private final float[] query; + private final VectorSimilarityFunction similarityFunction; + + public FloatVectorScorer( + RandomAccessVectorValues.Floats values, + float[] query, + VectorSimilarityFunction similarityFunction) { + super(values); + this.values = values; + this.query = query; + this.similarityFunction = similarityFunction; + } + + @Override + public float score(int node) throws IOException { + return similarityFunction.compare(query, values.vectorValue(node)); + } + } + + /** A {@link RandomVectorScorer} for byte vectors. */ + private static class ByteVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { + private final RandomAccessVectorValues.Bytes values; + private final byte[] query; + private final VectorSimilarityFunction similarityFunction; + + public ByteVectorScorer( + RandomAccessVectorValues.Bytes values, + byte[] query, + VectorSimilarityFunction similarityFunction) { + super(values); + this.values = values; + this.query = query; + this.similarityFunction = similarityFunction; + } + + @Override + public float score(int node) throws IOException { + return similarityFunction.compare(query, values.vectorValue(node)); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/FlatFieldVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatFieldVectorsWriter.java similarity index 94% rename from lucene/core/src/java/org/apache/lucene/codecs/FlatFieldVectorsWriter.java rename to lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatFieldVectorsWriter.java index 679b3d3af2e2..313ccccd4eb8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/FlatFieldVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatFieldVectorsWriter.java @@ -15,7 +15,9 @@ * limitations under the 
License. */ -package org.apache.lucene.codecs; +package org.apache.lucene.codecs.hnsw; + +import org.apache.lucene.codecs.KnnFieldVectorsWriter; /** * Vectors' writer for a field diff --git a/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsFormat.java similarity index 88% rename from lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsFormat.java rename to lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsFormat.java index 3bfb19ced57c..39d4bf01c6ee 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsFormat.java @@ -15,14 +15,15 @@ * limitations under the License. */ -package org.apache.lucene.codecs; +package org.apache.lucene.codecs.hnsw; import java.io.IOException; +import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; /** - * Encodes/decodes per-document vectors + * Encodes/decodes per-document vectors and provides a scoring interface for the flat stored vectors * * @lucene.experimental */ diff --git a/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java similarity index 89% rename from lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsReader.java rename to lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java index eca0fc97209a..04e379c10fe5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java @@ -15,7 +15,7 @@ * limitations under the License. 
*/ -package org.apache.lucene.codecs; +package org.apache.lucene.codecs.hnsw; import java.io.Closeable; import java.io.IOException; @@ -41,8 +41,20 @@ */ public abstract class FlatVectorsReader implements Closeable, Accountable { + /** Scorer for flat vectors */ + protected final FlatVectorsScorer vectorScorer; + /** Sole constructor */ - protected FlatVectorsReader() {} + protected FlatVectorsReader(FlatVectorsScorer vectorsScorer) { + this.vectorScorer = vectorsScorer; + } + + /** + * @return the {@link FlatVectorsScorer} for this reader. + */ + public FlatVectorsScorer getFlatVectorScorer() { + return vectorScorer; + } /** * Returns a {@link RandomVectorScorer} for the given field and target vector. diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java new file mode 100644 index 000000000000..17430c24f276 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.lucene.codecs.hnsw; + +import java.io.IOException; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +/** + * Provides mechanisms to score vectors that are stored in a flat file. The purpose of this class + * is to provide flexibility to the codec utilizing the vectors. + * + * @lucene.experimental + */ +public interface FlatVectorsScorer { + + /** + * Returns a {@link RandomVectorScorerSupplier} that can be used to score vectors. + * + * @param similarityFunction the similarity function to use + * @param vectorValues the vector values to score + * @return a {@link RandomVectorScorerSupplier} that can be used to score vectors + * @throws IOException if an I/O error occurs + */ + RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + throws IOException; + + /** + * Returns a {@link RandomVectorScorer} for the given set of vectors and target vector. + * + * @param similarityFunction the similarity function to use + * @param vectorValues the vector values to score + * @param target the target vector + * @return a {@link RandomVectorScorer} for the given vector values and target vector. + * @throws IOException if an I/O error occurs when reading from the index. + */ + RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + float[] target) + throws IOException; + + /** + * Returns a {@link RandomVectorScorer} for the given set of vectors and target vector. + * + * @param similarityFunction the similarity function to use + * @param vectorValues the vector values to score + * @param target the target vector + * @return a {@link RandomVectorScorer} for the given vector values and target vector. 
+ * @throws IOException if an I/O error occurs when reading from the index. + */ + RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + byte[] target) + throws IOException; +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsWriter.java similarity index 87% rename from lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsWriter.java rename to lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsWriter.java index 07cca250c0f7..96af676762fe 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsWriter.java @@ -15,10 +15,11 @@ * limitations under the License. */ -package org.apache.lucene.codecs; +package org.apache.lucene.codecs.hnsw; import java.io.Closeable; import java.io.IOException; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; @@ -32,9 +33,20 @@ * @lucene.experimental */ public abstract class FlatVectorsWriter implements Accountable, Closeable { + /** Scorer for flat vectors */ + protected final FlatVectorsScorer vectorsScorer; /** Sole constructor */ - protected FlatVectorsWriter() {} + protected FlatVectorsWriter(FlatVectorsScorer vectorsScorer) { + this.vectorsScorer = vectorsScorer; + } + + /** + * @return the {@link FlatVectorsScorer} for this writer. 
+ */ + public FlatVectorsScorer getFlatVectorScorer() { + return vectorsScorer; + } /** * Add a new field for indexing, allowing the user to provide a writer that the flat vectors diff --git a/lucene/core/src/java/org/apache/lucene/codecs/HnswGraphProvider.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/HnswGraphProvider.java similarity index 97% rename from lucene/core/src/java/org/apache/lucene/codecs/HnswGraphProvider.java rename to lucene/core/src/java/org/apache/lucene/codecs/hnsw/HnswGraphProvider.java index 000432aff3f7..91658fe763c6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/HnswGraphProvider.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/HnswGraphProvider.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.lucene.codecs; +package org.apache.lucene.codecs.hnsw; import java.io.IOException; import org.apache.lucene.util.hnsw.HnswGraph; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java new file mode 100644 index 000000000000..47dc18afc0af --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.hnsw; + +import java.io.IOException; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer; +import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier; +import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; +import org.apache.lucene.util.quantization.ScalarQuantizer; + +/** + * Default scalar quantized implementation of {@link FlatVectorsScorer}. 
+ * + * @lucene.experimental + */ +public class ScalarQuantizedVectorScorer implements FlatVectorsScorer { + + private final FlatVectorsScorer nonQuantizedDelegate; + + public ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) { + nonQuantizedDelegate = flatVectorsScorer; + } + + @Override + public RandomVectorScorerSupplier getRandomVectorScorerSupplier( + VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + return new ScalarQuantizedRandomVectorScorerSupplier( + similarityFunction, + quantizedByteVectorValues.getScalarQuantizer(), + quantizedByteVectorValues); + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorerSupplier(similarityFunction, vectorValues); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + float[] target) + throws IOException { + if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); + byte[] targetBytes = new byte[target.length]; + float offsetCorrection = + ScalarQuantizedRandomVectorScorer.quantizeQuery( + target, targetBytes, similarityFunction, scalarQuantizer); + ScalarQuantizedVectorSimilarity scalarQuantizedVectorSimilarity = + ScalarQuantizedVectorSimilarity.fromVectorSimilarity( + similarityFunction, + scalarQuantizer.getConstantMultiplier(), + scalarQuantizer.getBits()); + return new ScalarQuantizedRandomVectorScorer( + scalarQuantizedVectorSimilarity, + quantizedByteVectorValues, + targetBytes, + offsetCorrection); + } + // It is possible to get to this branch during initial indexing and flush + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, 
vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer( + VectorSimilarityFunction similarityFunction, + RandomAccessVectorValues vectorValues, + byte[] target) + throws IOException { + return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); + } + + @Override + public String toString() { + return "ScalarQuantizedVectorScorer(" + "nonQuantizedDelegate=" + nonQuantizedDelegate + ')'; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/package-info.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/package-info.java new file mode 100644 index 000000000000..4907ad67cd40 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/package-info.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * HNSW vector helper classes. The classes in this package provide a scoring and storing mechanism + * for vectors stored in a flat file. This allows for HNSW formats to be extended with other flat + * storage formats or scoring without significant changes to the HNSW code. 
Some examples for + * scoring include {@link org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer} and {@link + * org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer}. Some examples for storing include {@link + * org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat} and {@link + * org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat}. + */ +package org.apache.lucene.codecs.hnsw; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index c11ed70f0b87..da11df1e2518 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -30,7 +30,7 @@ /** Read the vector values from the index input. This supports both iterated and random access. */ public abstract class OffHeapByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues { + implements RandomAccessVectorValues.Bytes { protected final int dimension; protected final int size; @@ -85,12 +85,11 @@ public static OffHeapByteVectorValues load( return new EmptyOffHeapVectorValues(dimension); } IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength); - int byteSize = dimension; if (configuration.isDense()) { - return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize); + return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, dimension); } else { return new SparseOffHeapVectorValues( - configuration, vectorData, bytesSlice, dimension, byteSize); + configuration, vectorData, bytesSlice, dimension, dimension); } } @@ -131,7 +130,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public DenseOffHeapVectorValues copy() throws IOException { return new 
DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); } @@ -194,7 +193,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( configuration, dataIn, slice.clone(), dimension, byteSize); } @@ -262,7 +261,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public EmptyOffHeapVectorValues copy() throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 93cca6262d17..a13c9f55bb12 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -29,7 +29,7 @@ /** Read the vector values from the index input. This supports both iterated and random access. 
*/ public abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues { + implements RandomAccessVectorValues.Floats { protected final int dimension; protected final int size; @@ -125,7 +125,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); } @@ -188,7 +188,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( configuration, dataIn, slice.clone(), dimension, byteSize); } @@ -256,7 +256,7 @@ public int advance(int target) throws IOException { } @Override - public RandomAccessVectorValues copy() throws IOException { + public EmptyOffHeapVectorValues copy() throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsFormat.java index 39fddb33653b..dc6fe4e7178f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsFormat.java @@ -18,9 +18,10 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; -import org.apache.lucene.codecs.FlatVectorsFormat; -import org.apache.lucene.codecs.FlatVectorsReader; -import org.apache.lucene.codecs.FlatVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import 
org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -75,24 +76,25 @@ public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat { public static final int VERSION_CURRENT = VERSION_START; static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + private final FlatVectorsScorer vectorsScorer; /** Constructs a format */ - public Lucene99FlatVectorsFormat() { - super(); + public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) { + this.vectorsScorer = vectorsScorer; } @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new Lucene99FlatVectorsWriter(state); + return new Lucene99FlatVectorsWriter(state, vectorsScorer); } @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new Lucene99FlatVectorsReader(state); + return new Lucene99FlatVectorsReader(state, vectorsScorer); } @Override public String toString() { - return "Lucene99FlatVectorsFormat()"; + return "Lucene99FlatVectorsFormat(" + "vectorsScorer=" + vectorsScorer + ')'; } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java index 63b5f2c4fe64..0311a3b0cf85 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java @@ -24,7 +24,8 @@ import java.util.HashMap; import java.util.Map; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import 
org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; @@ -59,7 +60,9 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader { private final Map fields = new HashMap<>(); private final IndexInput vectorData; - public Lucene99FlatVectorsReader(SegmentReadState state) throws IOException { + public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer) + throws IOException { + super(scorer); int versionMeta = readMetadata(state); boolean success = false; try { @@ -217,7 +220,8 @@ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) th if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { return null; } - return RandomVectorScorer.createFloats( + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, OffHeapFloatVectorValues.load( fieldEntry.ordToDoc, fieldEntry.vectorEncoding, @@ -225,7 +229,6 @@ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) th fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, vectorData), - fieldEntry.similarityFunction, target); } @@ -235,7 +238,8 @@ public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) thr if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { return null; } - return RandomVectorScorer.createBytes( + return vectorScorer.getRandomVectorScorer( + fieldEntry.similarityFunction, OffHeapByteVectorValues.load( fieldEntry.ordToDoc, fieldEntry.vectorEncoding, @@ -243,7 +247,6 @@ public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) thr fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, vectorData), - fieldEntry.similarityFunction, target); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index 0491507aaf49..0f4a4114e708 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -27,10 +27,11 @@ import java.util.ArrayList; import java.util.List; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.FlatFieldVectorsWriter; -import org.apache.lucene.codecs.FlatVectorsWriter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; @@ -71,7 +72,9 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter { private final List> fields = new ArrayList<>(); private boolean finished; - public Lucene99FlatVectorsWriter(SegmentWriteState state) throws IOException { + public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer) + throws IOException { + super(scorer); segmentWriteState = state; String metaFileName = IndexFileNames.segmentFileName( @@ -305,20 +308,20 @@ public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( final IndexInput finalVectorDataInput = vectorDataInput; final RandomVectorScorerSupplier randomVectorScorerSupplier = switch (fieldInfo.getVectorEncoding()) { - case BYTE -> RandomVectorScorerSupplier.createBytes( + case BYTE -> vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), new OffHeapByteVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), finalVectorDataInput, - fieldInfo.getVectorDimension() * Byte.BYTES), - fieldInfo.getVectorSimilarityFunction()); - case FLOAT32 -> 
RandomVectorScorerSupplier.createFloats( + fieldInfo.getVectorDimension() * Byte.BYTES)); + case FLOAT32 -> vectorsScorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), new OffHeapFloatVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), finalVectorDataInput, - fieldInfo.getVectorDimension() * Float.BYTES), - fieldInfo.getVectorSimilarityFunction()); + fieldInfo.getVectorDimension() * Float.BYTES)); }; return new FlatCloseableRandomVectorScorerSupplier( () -> { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java index b9705295c60c..ff41ac01a961 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java @@ -25,10 +25,10 @@ import java.io.IOException; import java.util.concurrent.ExecutorService; -import org.apache.lucene.codecs.FlatVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.search.TaskExecutor; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java index bcea40170865..8c78a0cb0a06 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java @@ -19,10 +19,11 @@ import java.io.IOException; import 
java.util.concurrent.ExecutorService; -import org.apache.lucene.codecs.FlatVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.index.MergePolicy; import org.apache.lucene.index.MergeScheduler; @@ -101,7 +102,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { *

NOTE: We eagerly populate `float[MAX_CONN*2]` and `int[MAX_CONN*2]`, so exceptionally large * numbers here will use an inordinate amount of heap */ - static final int MAXIMUM_MAX_CONN = 512; + public static final int MAXIMUM_MAX_CONN = 512; /** Default number of maximum connections per node */ public static final int DEFAULT_MAX_CONN = 16; @@ -111,7 +112,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { * maximum value preserves the ratio of the DEFAULT_BEAM_WIDTH/DEFAULT_MAX_CONN i.e. `6.25 * 16 = * 3200` */ - static final int MAXIMUM_BEAM_WIDTH = 3200; + public static final int MAXIMUM_BEAM_WIDTH = 3200; /** * Default number of the size of the queue maintained while searching during a graph construction. @@ -137,7 +138,8 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { private final int beamWidth; /** The format for storing, reading, merging vectors on disk */ - private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(); + private static final FlatVectorsFormat flatVectorsFormat = + new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()); private final int numMergeWorkers; private final TaskExecutor mergeExec; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index 6aaa3bb0611e..6323870367f2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -25,9 +25,9 @@ import java.util.List; import java.util.Map; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.FlatVectorsReader; -import org.apache.lucene.codecs.HnswGraphProvider; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import 
org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.FieldInfo; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index 1bd2c53b6d63..8f715993a2b0 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -25,9 +25,10 @@ import java.util.Arrays; import java.util.List; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.FlatVectorsWriter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexFileNames; @@ -127,7 +128,12 @@ public Lucene99HnswVectorsWriter( @Override public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { FieldWriter newField = - FieldWriter.create(fieldInfo, M, beamWidth, segmentWriteState.infoStream); + FieldWriter.create( + flatVectorWriter.getFlatVectorScorer(), + fieldInfo, + M, + beamWidth, + segmentWriteState.infoStream); fields.add(newField); return flatVectorWriter.addField(fieldInfo, newField); } @@ -542,29 +548,32 @@ private static class FieldWriter extends KnnFieldVectorsWriter { private int lastDocID = -1; private int node = 0; - static FieldWriter create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) + static FieldWriter create( + FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) throws IOException { return switch (fieldInfo.getVectorEncoding()) { 
- case BYTE -> new FieldWriter(fieldInfo, M, beamWidth, infoStream); - case FLOAT32 -> new FieldWriter(fieldInfo, M, beamWidth, infoStream); + case BYTE -> new FieldWriter(scorer, fieldInfo, M, beamWidth, infoStream); + case FLOAT32 -> new FieldWriter(scorer, fieldInfo, M, beamWidth, infoStream); }; } @SuppressWarnings("unchecked") - FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) + FieldWriter( + FlatVectorsScorer scorer, FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) throws IOException { this.fieldInfo = fieldInfo; this.docsWithField = new DocsWithFieldSet(); vectors = new ArrayList<>(); - RAVectorValues raVectors = new RAVectorValues<>(vectors, fieldInfo.getVectorDimension()); RandomVectorScorerSupplier scorerSupplier = switch (fieldInfo.getVectorEncoding()) { - case BYTE -> RandomVectorScorerSupplier.createBytes( - (RandomAccessVectorValues) raVectors, - fieldInfo.getVectorSimilarityFunction()); - case FLOAT32 -> RandomVectorScorerSupplier.createFloats( - (RandomAccessVectorValues) raVectors, - fieldInfo.getVectorSimilarityFunction()); + case BYTE -> scorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + RandomAccessVectorValues.fromBytes( + (List) vectors, fieldInfo.getVectorDimension())); + case FLOAT32 -> scorer.getRandomVectorScorerSupplier( + fieldInfo.getVectorSimilarityFunction(), + RandomAccessVectorValues.fromFloats( + (List) vectors, fieldInfo.getVectorDimension())); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); @@ -609,34 +618,4 @@ public long ramBytesUsed() { + hnswGraphBuilder.getGraph().ramBytesUsed(); } } - - private static class RAVectorValues implements RandomAccessVectorValues { - private final List vectors; - private final int dim; - - RAVectorValues(List vectors, int dim) { - this.vectors = vectors; - this.dim = dim; - } - - @Override - public int size() { - return vectors.size(); - } - - @Override - public 
int dimension() { - return dim; - } - - @Override - public T vectorValue(int targetOrd) throws IOException { - return vectors.get(targetOrd); - } - - @Override - public RandomAccessVectorValues copy() throws IOException { - return this; - } - } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index 8e36fad52555..a3d894e64e99 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -18,9 +18,11 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; -import org.apache.lucene.codecs.FlatVectorsFormat; -import org.apache.lucene.codecs.FlatVectorsReader; -import org.apache.lucene.codecs.FlatVectorsWriter; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; @@ -46,7 +48,8 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { static final String META_EXTENSION = "vemq"; static final String VECTOR_DATA_EXTENSION = "veq"; - private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(); + private static final FlatVectorsFormat rawVectorFormat = + new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer()); /** The minimum confidence interval */ private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f; @@ -62,6 +65,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { final byte bits; final boolean compress; + 
final ScalarQuantizedVectorScorer flatVectorScorer; /** Constructs a format using default graph construction parameters */ public Lucene99ScalarQuantizedVectorsFormat() { @@ -98,6 +102,7 @@ public Lucene99ScalarQuantizedVectorsFormat( this.bits = (byte) bits; this.confidenceInterval = confidenceInterval; this.compress = compress; + this.flatVectorScorer = new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()); } public static float calculateDefaultConfidenceInterval(int vectorDimension) { @@ -115,6 +120,8 @@ public String toString() { + bits + ", compress=" + compress + + ", flatVectorScorer=" + + flatVectorScorer + ", rawVectorFormat=" + rawVectorFormat + ")"; @@ -123,11 +130,17 @@ public String toString() { @Override public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new Lucene99ScalarQuantizedVectorsWriter( - state, confidenceInterval, bits, compress, rawVectorFormat.fieldsWriter(state)); + state, + confidenceInterval, + bits, + compress, + rawVectorFormat.fieldsWriter(state), + flatVectorScorer); } @Override public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new Lucene99ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state)); + return new Lucene99ScalarQuantizedVectorsReader( + state, rawVectorFormat.fieldsReader(state), flatVectorScorer); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 6fce634cfeea..8aaa2cca7b5b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -24,7 +24,8 @@ import java.util.HashMap; import java.util.Map; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.FlatVectorsReader; +import 
org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.CorruptIndexException; @@ -45,7 +46,6 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.QuantizedVectorsReader; -import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorer; import org.apache.lucene.util.quantization.ScalarQuantizer; /** @@ -64,7 +64,9 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade private final FlatVectorsReader rawVectorsReader; public Lucene99ScalarQuantizedVectorsReader( - SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { + SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer) + throws IOException { + super(scorer); this.rawVectorsReader = rawVectorsReader; int versionMeta = -1; String metaFileName = @@ -224,13 +226,12 @@ public RandomVectorScorer getRandomVectorScorer(String field, float[] target) th fieldEntry.ordToDoc, fieldEntry.dimension, fieldEntry.size, - fieldEntry.bits, + fieldEntry.scalarQuantizer, fieldEntry.compress, fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, quantizedVectorData); - return new ScalarQuantizedRandomVectorScorer( - fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target); + return vectorScorer.getRandomVectorScorer(fieldEntry.similarityFunction, vectorValues, target); } @Override @@ -280,7 +281,7 @@ public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) thro fieldEntry.ordToDoc, fieldEntry.dimension, fieldEntry.size, - fieldEntry.bits, + fieldEntry.scalarQuantizer, fieldEntry.compress, fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength, diff --git 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 9064841dc335..c4067b7fd78d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -30,11 +30,12 @@ import java.util.ArrayList; import java.util.List; import org.apache.lucene.codecs.CodecUtil; -import org.apache.lucene.codecs.FlatFieldVectorsWriter; -import org.apache.lucene.codecs.FlatVectorsWriter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.DocIDMerger; @@ -59,7 +60,6 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.QuantizedVectorsReader; -import org.apache.lucene.util.quantization.ScalarQuantizedRandomVectorScorerSupplier; import org.apache.lucene.util.quantization.ScalarQuantizer; /** @@ -102,7 +102,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite private boolean finished; public Lucene99ScalarQuantizedVectorsWriter( - SegmentWriteState state, Float confidenceInterval, FlatVectorsWriter rawVectorDelegate) + SegmentWriteState state, + Float confidenceInterval, + FlatVectorsWriter rawVectorDelegate, + FlatVectorsScorer scorer) throws IOException { this( state, @@ -110,7 +113,8 @@ public 
Lucene99ScalarQuantizedVectorsWriter( confidenceInterval, (byte) 7, false, - rawVectorDelegate); + rawVectorDelegate, + scorer); } public Lucene99ScalarQuantizedVectorsWriter( @@ -118,7 +122,8 @@ public Lucene99ScalarQuantizedVectorsWriter( Float confidenceInterval, byte bits, boolean compress, - FlatVectorsWriter rawVectorDelegate) + FlatVectorsWriter rawVectorDelegate, + FlatVectorsScorer scorer) throws IOException { this( state, @@ -126,7 +131,8 @@ public Lucene99ScalarQuantizedVectorsWriter( confidenceInterval, bits, compress, - rawVectorDelegate); + rawVectorDelegate, + scorer); } private Lucene99ScalarQuantizedVectorsWriter( @@ -135,8 +141,10 @@ private Lucene99ScalarQuantizedVectorsWriter( Float confidenceInterval, byte bits, boolean compress, - FlatVectorsWriter rawVectorDelegate) + FlatVectorsWriter rawVectorDelegate, + FlatVectorsScorer scorer) throws IOException { + super(scorer); this.confidenceInterval = confidenceInterval; this.bits = bits; this.compress = compress; @@ -511,13 +519,12 @@ private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex( segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName()); }, docsWithField.cardinality(), - new ScalarQuantizedRandomVectorScorerSupplier( + vectorsScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - mergedQuantizationState, new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues( fieldInfo.getVectorDimension(), docsWithField.cardinality(), - bits, + mergedQuantizationState, compress, quantizationDataInput))); } finally { @@ -1091,12 +1098,12 @@ private void quantize() throws IOException { static final class ScalarQuantizedCloseableRandomVectorScorerSupplier implements CloseableRandomVectorScorerSupplier { - private final ScalarQuantizedRandomVectorScorerSupplier supplier; + private final RandomVectorScorerSupplier supplier; private final Closeable onClose; private final int numVectors; ScalarQuantizedCloseableRandomVectorScorerSupplier( 
- Closeable onClose, int numVectors, ScalarQuantizedRandomVectorScorerSupplier supplier) { + Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) { this.onClose = onClose; this.supplier = supplier; this.numVectors = numVectors; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 872dc4fb23ff..08e666d51ff8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -26,6 +26,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.ScalarQuantizer; /** * Read the quantized vector values and their score correction values from the index input. 
This @@ -37,7 +38,7 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect protected final int dimension; protected final int size; protected final int numBytes; - protected final byte bits; + protected final ScalarQuantizer scalarQuantizer; protected final boolean compress; protected final IndexInput slice; @@ -81,13 +82,17 @@ static void compressBytes(byte[] raw, byte[] compressed) { } OffHeapQuantizedByteVectorValues( - int dimension, int size, byte bits, boolean compress, IndexInput slice) { + int dimension, + int size, + ScalarQuantizer scalarQuantizer, + boolean compress, + IndexInput slice) { this.dimension = dimension; this.size = size; this.slice = slice; - this.bits = bits; + this.scalarQuantizer = scalarQuantizer; this.compress = compress; - if (bits <= 4 && compress) { + if (scalarQuantizer.getBits() <= 4 && compress) { this.numBytes = (dimension + 1) >> 1; } else { this.numBytes = dimension; @@ -97,6 +102,11 @@ static void compressBytes(byte[] raw, byte[] compressed) { binaryValue = byteBuffer.array(); } + @Override + public ScalarQuantizer getScalarQuantizer() { + return scalarQuantizer; + } + @Override public int dimension() { return dimension; @@ -129,7 +139,7 @@ public static OffHeapQuantizedByteVectorValues load( OrdToDocDISIReaderConfiguration configuration, int dimension, int size, - byte bits, + ScalarQuantizer scalarQuantizer, boolean compress, long quantizedVectorDataOffset, long quantizedVectorDataLength, @@ -142,10 +152,10 @@ public static OffHeapQuantizedByteVectorValues load( vectorData.slice( "quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength); if (configuration.isDense()) { - return new DenseOffHeapVectorValues(dimension, size, bits, compress, bytesSlice); + return new DenseOffHeapVectorValues(dimension, size, scalarQuantizer, compress, bytesSlice); } else { return new SparseOffHeapVectorValues( - configuration, dimension, size, bits, compress, vectorData, bytesSlice); + 
configuration, dimension, size, scalarQuantizer, compress, vectorData, bytesSlice); } } @@ -158,8 +168,12 @@ public static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorV private int doc = -1; public DenseOffHeapVectorValues( - int dimension, int size, byte bits, boolean compress, IndexInput slice) { - super(dimension, size, bits, compress, slice); + int dimension, + int size, + ScalarQuantizer scalarQuantizer, + boolean compress, + IndexInput slice) { + super(dimension, size, scalarQuantizer, compress, slice); } @Override @@ -188,7 +202,8 @@ public int advance(int target) throws IOException { @Override public DenseOffHeapVectorValues copy() throws IOException { - return new DenseOffHeapVectorValues(dimension, size, bits, compress, slice.clone()); + return new DenseOffHeapVectorValues( + dimension, size, scalarQuantizer, compress, slice.clone()); } @Override @@ -208,12 +223,12 @@ public SparseOffHeapVectorValues( OrdToDocDISIReaderConfiguration configuration, int dimension, int size, - byte bits, + ScalarQuantizer scalarQuantizer, boolean compress, IndexInput dataIn, IndexInput slice) throws IOException { - super(dimension, size, bits, compress, slice); + super(dimension, size, scalarQuantizer, compress, slice); this.configuration = configuration; this.dataIn = dataIn; this.ordToDoc = configuration.getDirectMonotonicReader(dataIn); @@ -244,7 +259,7 @@ public int advance(int target) throws IOException { @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( - configuration, dimension, size, bits, compress, dataIn, slice.clone()); + configuration, dimension, size, scalarQuantizer, compress, dataIn, slice.clone()); } @Override @@ -274,7 +289,7 @@ public int length() { private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { public EmptyOffHeapVectorValues(int dimension) { - super(dimension, 0, (byte) 7, false, null); + super(dimension, 0, new ScalarQuantizer(-1, 1, 
(byte) 7), false, null); } private int doc = -1; diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 97f8c0473835..7409a1de4cff 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -182,6 +182,30 @@ public static int int4DotProduct(byte[] a, byte[] b) { return IMPL.int4DotProduct(a, b); } + /** + * XOR bit count computed over signed bytes. + * + * @param a bytes containing a vector + * @param b bytes containing another vector, of the same dimension + * @return the value of the XOR bit count of the two vectors + */ + public static int xorBitCount(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + int distance = 0, i = 0; + for (final int upperBound = a.length & ~(Long.BYTES - 1); i < upperBound; i += Long.BYTES) { + distance += + Long.bitCount( + (long) BitUtil.VH_NATIVE_LONG.get(a, i) ^ (long) BitUtil.VH_NATIVE_LONG.get(b, i)); + } + // tail: + for (; i < a.length; i++) { + distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); + } + return distance; + } + /** * Dot product score computed over signed bytes, scaled to be in [0, 1]. 
* diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java index 38ecc38467b4..392d83fa262c 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java @@ -17,7 +17,7 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; -import org.apache.lucene.codecs.HnswGraphProvider; +import org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.TaskExecutor; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java index 9c547a8b7037..1909af420156 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java @@ -19,8 +19,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import org.apache.lucene.codecs.HnswGraphProvider; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java index 5fc531d9f64f..ecf5339cd21a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java @@ -18,6 +18,7 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; +import 
java.util.List; import org.apache.lucene.util.Bits; /** @@ -26,7 +27,7 @@ * * @lucene.experimental */ -public interface RandomAccessVectorValues { +public interface RandomAccessVectorValues { /** Return the number of vector values */ int size(); @@ -34,19 +35,11 @@ public interface RandomAccessVectorValues { /** Return the dimension of the returned vector values */ int dimension(); - /** - * Return the vector value indexed at the given ordinal. - * - * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. - */ - T vectorValue(int targetOrd) throws IOException; - /** * Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to - * access different values at once, to avoid overwriting the underlying float vector returned by - * {@link RandomAccessVectorValues#vectorValue}. + * access different values at once, to avoid overwriting the underlying vector returned. */ - RandomAccessVectorValues copy() throws IOException; + RandomAccessVectorValues copy() throws IOException; /** * Translates vector ordinal to the correct document ID. By default, this is an identity function. @@ -67,4 +60,92 @@ default int ordToDoc(int ord) { default Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + + /** Float vector values. */ + interface Floats extends RandomAccessVectorValues { + @Override + RandomAccessVectorValues.Floats copy() throws IOException; + + /** + * Return the vector value indexed at the given ordinal. + * + * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. + */ + float[] vectorValue(int targetOrd) throws IOException; + } + + /** Byte vector values. */ + interface Bytes extends RandomAccessVectorValues { + @Override + RandomAccessVectorValues.Bytes copy() throws IOException; + + /** + * Return the vector value indexed at the given ordinal. + * + * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. 
+ */ + byte[] vectorValue(int targetOrd) throws IOException; + } + + /** + * Creates a {@link RandomAccessVectorValues.Floats} from a list of float arrays. + * + * @param vectors the list of float arrays + * @param dim the dimension of the vectors + * @return a {@link RandomAccessVectorValues.Floats} instance + */ + static RandomAccessVectorValues.Floats fromFloats(List vectors, int dim) { + return new RandomAccessVectorValues.Floats() { + @Override + public int size() { + return vectors.size(); + } + + @Override + public int dimension() { + return dim; + } + + @Override + public float[] vectorValue(int targetOrd) throws IOException { + return vectors.get(targetOrd); + } + + @Override + public RandomAccessVectorValues.Floats copy() throws IOException { + return this; + } + }; + } + + /** + * Creates a {@link RandomAccessVectorValues.Bytes} from a list of byte arrays. + * + * @param vectors the list of byte arrays + * @param dim the dimension of the vectors + * @return a {@link RandomAccessVectorValues.Bytes} instance + */ + static RandomAccessVectorValues.Bytes fromBytes(List vectors, int dim) { + return new RandomAccessVectorValues.Bytes() { + @Override + public int size() { + return vectors.size(); + } + + @Override + public int dimension() { + return dim; + } + + @Override + public byte[] vectorValue(int targetOrd) throws IOException { + return vectors.get(targetOrd); + } + + @Override + public RandomAccessVectorValues.Bytes copy() throws IOException { + return this; + } + }; + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java index 76c7bb910f73..fc8ed3d004a1 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java @@ -18,7 +18,6 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; -import 
org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.Bits; /** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */ @@ -56,82 +55,16 @@ default Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } - /** - * Creates a default scorer for float vectors. - * - *

WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid - * using it after calling this function. If you plan to use it again outside the returned {@link - * RandomVectorScorer}, think about passing a copied version ({@link - * RandomAccessVectorValues#copy}). - * - * @param vectors the underlying storage for vectors - * @param similarityFunction the similarity function to score vectors - * @param query the actual query - */ - static RandomVectorScorer createFloats( - final RandomAccessVectorValues vectors, - final VectorSimilarityFunction similarityFunction, - final float[] query) { - if (query.length != vectors.dimension()) { - throw new IllegalArgumentException( - "vector query dimension: " - + query.length - + " differs from field dimension: " - + vectors.dimension()); - } - return new AbstractRandomVectorScorer<>(vectors) { - @Override - public float score(int node) throws IOException { - return similarityFunction.compare(query, vectors.vectorValue(node)); - } - }; - } - - /** - * Creates a default scorer for byte vectors. - * - *

WARNING: The {@link RandomAccessVectorValues} given can contain stateful buffers. Avoid - * using it after calling this function. If you plan to use it again outside the returned {@link - * RandomVectorScorer}, think about passing a copied version ({@link - * RandomAccessVectorValues#copy}). - * - * @param vectors the underlying storage for vectors - * @param similarityFunction the similarity function to use to score vectors - * @param query the actual query - */ - static RandomVectorScorer createBytes( - final RandomAccessVectorValues vectors, - final VectorSimilarityFunction similarityFunction, - final byte[] query) { - if (query.length != vectors.dimension()) { - throw new IllegalArgumentException( - "vector query dimension: " - + query.length - + " differs from field dimension: " - + vectors.dimension()); - } - return new AbstractRandomVectorScorer<>(vectors) { - @Override - public float score(int node) throws IOException { - return similarityFunction.compare(query, vectors.vectorValue(node)); - } - }; - } - - /** - * Creates a default scorer for random access vectors. - * - * @param the type of the vector values - */ - abstract class AbstractRandomVectorScorer implements RandomVectorScorer { - private final RandomAccessVectorValues values; + /** Creates a default scorer for random access vectors. */ + abstract class AbstractRandomVectorScorer implements RandomVectorScorer { + private final RandomAccessVectorValues values; /** * Creates a new scorer for the given vector values. 
* * @param values the vector values */ - public AbstractRandomVectorScorer(RandomAccessVectorValues values) { + public AbstractRandomVectorScorer(RandomAccessVectorValues values) { this.values = values; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java index 2b8099675523..f8436f061d6a 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java @@ -18,7 +18,6 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; -import org.apache.lucene.index.VectorSimilarityFunction; /** A supplier that creates {@link RandomVectorScorer} from an ordinal. */ public interface RandomVectorScorerSupplier { @@ -36,100 +35,4 @@ public interface RandomVectorScorerSupplier { * be used in other threads. */ RandomVectorScorerSupplier copy() throws IOException; - - /** - * Creates a {@link RandomVectorScorerSupplier} to compare float vectors. The vectorValues passed - * in will be copied and the original copy will not be used. - * - * @param vectors the underlying storage for vectors - * @param similarityFunction the similarity function to score vectors - */ - static RandomVectorScorerSupplier createFloats( - final RandomAccessVectorValues vectors, - final VectorSimilarityFunction similarityFunction) - throws IOException { - // We copy the provided random accessor just once during the supplier's initialization - // and then reuse it consistently across all scorers for conducting vector comparisons. - return new FloatScoringSupplier(vectors, similarityFunction); - } - - /** - * Creates a {@link RandomVectorScorerSupplier} to compare byte vectors. The vectorValues passed - * in will be copied and the original copy will not be used. 
- * - * @param vectors the underlying storage for vectors - * @param similarityFunction the similarity function to score vectors - */ - static RandomVectorScorerSupplier createBytes( - final RandomAccessVectorValues vectors, - final VectorSimilarityFunction similarityFunction) - throws IOException { - // We copy the provided random accessor only during the supplier's initialization - // and then reuse it consistently across all scorers for conducting vector comparisons. - return new ByteScoringSupplier(vectors, similarityFunction); - } - - /** RandomVectorScorerSupplier for bytes vector */ - final class ByteScoringSupplier implements RandomVectorScorerSupplier { - private final RandomAccessVectorValues vectors; - private final RandomAccessVectorValues vectors1; - private final RandomAccessVectorValues vectors2; - private final VectorSimilarityFunction similarityFunction; - - private ByteScoringSupplier( - RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction) - throws IOException { - this.vectors = vectors; - vectors1 = vectors.copy(); - vectors2 = vectors.copy(); - this.similarityFunction = similarityFunction; - } - - @Override - public RandomVectorScorer scorer(int ord) throws IOException { - return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) { - @Override - public float score(int cand) throws IOException { - return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand)); - } - }; - } - - @Override - public RandomVectorScorerSupplier copy() throws IOException { - return new ByteScoringSupplier(vectors, similarityFunction); - } - } - - /** RandomVectorScorerSupplier for Float vector */ - final class FloatScoringSupplier implements RandomVectorScorerSupplier { - private final RandomAccessVectorValues vectors; - private final RandomAccessVectorValues vectors1; - private final RandomAccessVectorValues vectors2; - private final VectorSimilarityFunction similarityFunction; - - private 
FloatScoringSupplier( - RandomAccessVectorValues vectors, VectorSimilarityFunction similarityFunction) - throws IOException { - this.vectors = vectors; - vectors1 = vectors.copy(); - vectors2 = vectors.copy(); - this.similarityFunction = similarityFunction; - } - - @Override - public RandomVectorScorer scorer(int ord) throws IOException { - return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) { - @Override - public float score(int cand) throws IOException { - return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand)); - } - }; - } - - @Override - public RandomVectorScorerSupplier copy() throws IOException { - return new FloatScoringSupplier(vectors, similarityFunction); - } - } } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java index 1c8e3d15a672..08b0b6e5a7ae 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java @@ -25,7 +25,10 @@ * * @lucene.experimental */ -public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues { +public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues.Bytes { + + ScalarQuantizer getScalarQuantizer(); + float getScoreCorrectionConstant(); @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java index 0c07e1e971c6..41fcc3f97d01 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorer.java @@ -28,7 +28,7 @@ * 
@lucene.experimental */ public class ScalarQuantizedRandomVectorScorer - extends RandomVectorScorer.AbstractRandomVectorScorer { + extends RandomVectorScorer.AbstractRandomVectorScorer { public static float quantizeQuery( float[] query, @@ -64,22 +64,6 @@ public ScalarQuantizedRandomVectorScorer( this.values = values; } - public ScalarQuantizedRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - ScalarQuantizer scalarQuantizer, - RandomAccessQuantizedByteVectorValues values, - float[] query) { - super(values); - byte[] quantizedQuery = new byte[query.length]; - float correction = quantizeQuery(query, quantizedQuery, similarityFunction, scalarQuantizer); - this.quantizedQuery = quantizedQuery; - this.queryOffset = correction; - this.similarity = - ScalarQuantizedVectorSimilarity.fromVectorSimilarity( - similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits()); - this.values = values; - } - @Override public float score(int node) throws IOException { byte[] storedVectorValue = values.vectorValue(node); diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java index b3b1d4adee37..baf89df326db 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedRandomVectorScorerSupplier.java @@ -30,6 +30,7 @@ public class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorSc private final RandomAccessQuantizedByteVectorValues values; private final ScalarQuantizedVectorSimilarity similarity; + private final VectorSimilarityFunction vectorSimilarityFunction; public ScalarQuantizedRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, @@ -39,12 +40,16 @@ public ScalarQuantizedRandomVectorScorerSupplier( 
ScalarQuantizedVectorSimilarity.fromVectorSimilarity( similarityFunction, scalarQuantizer.getConstantMultiplier(), scalarQuantizer.getBits()); this.values = values; + this.vectorSimilarityFunction = similarityFunction; } private ScalarQuantizedRandomVectorScorerSupplier( - ScalarQuantizedVectorSimilarity similarity, RandomAccessQuantizedByteVectorValues values) { + ScalarQuantizedVectorSimilarity similarity, + VectorSimilarityFunction vectorSimilarityFunction, + RandomAccessQuantizedByteVectorValues values) { this.similarity = similarity; this.values = values; + this.vectorSimilarityFunction = vectorSimilarityFunction; } @Override @@ -57,6 +62,7 @@ public RandomVectorScorer scorer(int ord) throws IOException { @Override public RandomVectorScorerSupplier copy() throws IOException { - return new ScalarQuantizedRandomVectorScorerSupplier(similarity, values.copy()); + return new ScalarQuantizedRandomVectorScorerSupplier( + similarity, vectorSimilarityFunction, values.copy()); } } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 65ef03d2ef7c..fb8ffe369f4b 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -182,7 +182,7 @@ public KnnVectorsFormat knnVectorsFormat() { } }; String expectedString = - "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, rawVectorFormat=Lucene99FlatVectorsFormat()))"; + "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, 
flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer())))"; assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java index 493b2cd5b921..0f84f8ab4aeb 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java @@ -38,7 +38,7 @@ public KnnVectorsFormat knnVectorsFormat() { } }; String expectedString = - "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat())"; + "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=DefaultFlatVectorScorer()))"; assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); } diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index 2546115ff4fe..a37bb4a4dc05 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -42,6 +42,7 @@ import org.apache.lucene.index.QueryTimeout; import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.knn.KnnCollectorManager; import 
org.apache.lucene.search.knn.TopKnnCollectorManager; @@ -1031,8 +1032,8 @@ public void testSameFieldDifferentFormats() throws IOException { try (Directory directory = newDirectory()) { MockAnalyzer mockAnalyzer = new MockAnalyzer(random()); IndexWriterConfig iwc = newIndexWriterConfig(mockAnalyzer); - KnnVectorsFormat format1 = randomVectorFormat(); - KnnVectorsFormat format2 = randomVectorFormat(); + KnnVectorsFormat format1 = randomVectorFormat(VectorEncoding.FLOAT32); + KnnVectorsFormat format2 = randomVectorFormat(VectorEncoding.FLOAT32); iwc.setCodec( new AssertingCodec() { @Override diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java index 166ee00dcf7c..54de3919b516 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java @@ -22,7 +22,7 @@ import java.io.IOException; import org.apache.lucene.util.BytesRef; -abstract class AbstractMockVectorValues implements RandomAccessVectorValues { +abstract class AbstractMockVectorValues implements RandomAccessVectorValues { protected final int dimension; protected final T[] denseValues; @@ -52,7 +52,6 @@ public int dimension() { return dimension; } - @Override public T vectorValue(int targetOrd) { return denseValues[targetOrd]; } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 1b25f74ffad9..4770bdf98abf 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -40,6 +40,7 @@ import java.util.stream.Collectors; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; 
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; @@ -87,6 +88,7 @@ abstract class HnswGraphTestCase extends LuceneTestCase { VectorSimilarityFunction similarityFunction; + DefaultFlatVectorScorer flatVectorScorer = new DefaultFlatVectorScorer(); abstract VectorEncoding getVectorEncoding(); @@ -109,30 +111,23 @@ abstract AbstractMockVectorValues vectorValues( abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction); - abstract RandomAccessVectorValues circularVectorValues(int nDoc); + abstract RandomAccessVectorValues circularVectorValues(int nDoc); abstract T getTargetVector(); - @SuppressWarnings("unchecked") - protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues vectors) + protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues vectors) throws IOException { - return switch (getVectorEncoding()) { - case BYTE -> RandomVectorScorerSupplier.createBytes( - (RandomAccessVectorValues) vectors, similarityFunction); - case FLOAT32 -> RandomVectorScorerSupplier.createFloats( - (RandomAccessVectorValues) vectors, similarityFunction); - }; + return flatVectorScorer.getRandomVectorScorerSupplier(similarityFunction, vectors); } - @SuppressWarnings("unchecked") - protected RandomVectorScorer buildScorer(RandomAccessVectorValues vectors, T query) + protected RandomVectorScorer buildScorer(RandomAccessVectorValues vectors, T query) throws IOException { - RandomAccessVectorValues vectorsCopy = vectors.copy(); + RandomAccessVectorValues vectorsCopy = vectors.copy(); return switch (getVectorEncoding()) { - case BYTE -> RandomVectorScorer.createBytes( - (RandomAccessVectorValues) vectorsCopy, similarityFunction, (byte[]) query); - case FLOAT32 -> RandomVectorScorer.createFloats( - (RandomAccessVectorValues) vectorsCopy, 
similarityFunction, (float[]) query); + case BYTE -> flatVectorScorer.getRandomVectorScorer( + similarityFunction, vectorsCopy, (byte[]) query); + case FLOAT32 -> flatVectorScorer.getRandomVectorScorer( + similarityFunction, vectorsCopy, (float[]) query); }; } @@ -464,7 +459,7 @@ void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException { public void testAknnDiverse() throws IOException { int nDoc = 100; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + RandomAccessVectorValues vectors = circularVectorValues(nDoc); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); @@ -496,7 +491,7 @@ public void testAknnDiverse() throws IOException { @SuppressWarnings("unchecked") public void testSearchWithAcceptOrds() throws IOException { int nDoc = 100; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + RandomAccessVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); @@ -521,7 +516,7 @@ public void testSearchWithAcceptOrds() throws IOException { @SuppressWarnings("unchecked") public void testSearchWithSelectiveAcceptOrds() throws IOException { int nDoc = 100; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + RandomAccessVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); @@ -714,7 +709,7 @@ private int[] createOffsetOrdinalMap( public void 
testVisitedLimit() throws IOException { int nDoc = 500; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + RandomAccessVectorValues vectors = circularVectorValues(nDoc); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); @@ -749,7 +744,7 @@ public void testRamUsageEstimate() throws IOException { int M = randomIntBetween(4, 96); similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values()); - RandomAccessVectorValues vectors = vectorValues(size, dim); + RandomAccessVectorValues vectors = vectorValues(size, dim); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = @@ -1078,7 +1073,7 @@ private int computeOverlap(int[] a, int[] b) { /** Returns vectors evenly distributed around the upper unit semicircle. */ static class CircularFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues { + implements RandomAccessVectorValues.Floats { private final int size; private final float[] value; @@ -1137,7 +1132,7 @@ public float[] vectorValue(int ord) { /** Returns vectors evenly distributed around the upper unit semicircle. 
*/ static class CircularByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues { + implements RandomAccessVectorValues.Bytes { private final int size; private final float[] value; private final byte[] bValue; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java index 453948924ed4..fdc2566c022e 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java @@ -20,7 +20,8 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.ArrayUtil; -class MockByteVectorValues extends AbstractMockVectorValues { +class MockByteVectorValues extends AbstractMockVectorValues + implements RandomAccessVectorValues.Bytes { private final byte[] scratch; static MockByteVectorValues fromValues(byte[][] values) { @@ -55,6 +56,11 @@ public MockByteVectorValues copy() { numVectors); } + @Override + public byte[] vectorValue(int ord) { + return values[ord]; + } + @Override public byte[] vectorValue() { if (LuceneTestCase.random().nextBoolean()) { diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java index 0892195d5eb2..c10f80e20a85 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java @@ -20,7 +20,8 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.ArrayUtil; -class MockVectorValues extends AbstractMockVectorValues { +class MockVectorValues extends AbstractMockVectorValues + implements RandomAccessVectorValues.Floats { private final float[] scratch; static MockVectorValues fromValues(float[][] values) { diff --git 
a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java index 3a2d92ff92a6..649bc1a64519 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java @@ -132,7 +132,7 @@ Field knnVectorField(String name, byte[] vector, VectorSimilarityFunction simila } @Override - RandomAccessVectorValues circularVectorValues(int nDoc) { + CircularByteVectorValues circularVectorValues(int nDoc) { return new CircularByteVectorValues(nDoc); } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 5cab4cf3256b..5621edc4b35e 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -117,7 +117,7 @@ Field knnVectorField(String name, float[] vector, VectorSimilarityFunction simil } @Override - RandomAccessVectorValues circularVectorValues(int nDoc) { + CircularFloatVectorValues circularVectorValues(int nDoc) { return new CircularFloatVectorValues(nDoc); } @@ -129,7 +129,7 @@ float[] getTargetVector() { public void testSearchWithSkewedAcceptOrds() throws IOException { int nDoc = 1000; similarityFunction = VectorSimilarityFunction.EUCLIDEAN; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + RandomAccessVectorValues.Floats vectors = circularVectorValues(nDoc); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java 
b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index 73ec496b5629..501e2e5616f0 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -18,11 +18,11 @@ package org.apache.lucene.tests.codecs.asserting; import java.io.IOException; -import org.apache.lucene.codecs.HnswGraphProvider; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfos; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseIndexFileFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseIndexFileFormatTestCase.java index 2dcb713128e6..10e748363d73 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseIndexFileFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseIndexFileFormatTestCase.java @@ -101,7 +101,7 @@ import org.apache.lucene.util.Version; /** Common tests to all index formats. 
*/ -abstract class BaseIndexFileFormatTestCase extends LuceneTestCase { +public abstract class BaseIndexFileFormatTestCase extends LuceneTestCase { private static final IndexWriterAccess INDEX_WRITER_ACCESS = TestSecrets.getIndexWriterAccess(); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index b6e3da77224f..4fb5e95247d5 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -1244,7 +1244,7 @@ private void add( iw.updateDocument(idTerm, doc); } - protected float[] randomVector(int dim) { + public static float[] randomVector(int dim) { assert dim > 0; float[] v = new float[dim]; double squareSum = 0.0; @@ -1259,13 +1259,13 @@ protected float[] randomVector(int dim) { return v; } - protected float[] randomNormalizedVector(int dim) { + public static float[] randomNormalizedVector(int dim) { float[] v = randomVector(dim); VectorUtil.l2normalize(v); return v; } - protected byte[] randomVector8(int dim) { + public static byte[] randomVector8(int dim) { assert dim > 0; float[] v = randomNormalizedVector(dim); byte[] b = new byte[dim]; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/LuceneTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/LuceneTestCase.java index 9b96eb1ca037..594a19cc66a9 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/LuceneTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/LuceneTestCase.java @@ -102,6 +102,7 @@ import junit.framework.AssertionFailedError; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat; import 
org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.Field.Store; @@ -152,6 +153,7 @@ import org.apache.lucene.index.TermsEnum; import org.apache.lucene.index.TermsEnum.SeekStatus; import org.apache.lucene.index.TieredMergePolicy; +import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.internal.tests.IndexPackageAccess; import org.apache.lucene.internal.tests.TestSecrets; import org.apache.lucene.search.DocIdSetIterator; @@ -3213,11 +3215,17 @@ public static BytesRef newBytesRef(byte[] bytesIn, int offset, int length) { return it; } - protected KnnVectorsFormat randomVectorFormat() { + protected KnnVectorsFormat randomVectorFormat(VectorEncoding vectorEncoding) { ServiceLoader formats = java.util.ServiceLoader.load(KnnVectorsFormat.class); List availableFormats = new ArrayList<>(); for (KnnVectorsFormat f : formats) { - availableFormats.add(f); + if (f.getName().equals(HnswBitVectorsFormat.NAME)) { + if (vectorEncoding.equals(VectorEncoding.BYTE)) { + availableFormats.add(f); + } + } else { + availableFormats.add(f); + } } return RandomPicks.randomFrom(random(), availableFormats); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java index ddd85a68d562..955665544bcc 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/util/hnsw/HnswTestUtil.java @@ -23,7 +23,7 @@ import java.util.ArrayList; import java.util.Deque; import java.util.List; -import org.apache.lucene.codecs.HnswGraphProvider; +import org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.FilterLeafReader;