From ff7a317a81df18ac944f07d66ef770cfd4ab3b06 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Thu, 12 Sep 2024 12:56:46 -0400 Subject: [PATCH] fix case where index is reordered --- .../codecs/BufferingKnnVectorsWriter.java | 2 +- .../lucene/codecs/KnnVectorsWriter.java | 1 + .../SlowCompositeCodecReaderWrapper.java | 164 ++++++++++++------ .../lucene/index/SortingCodecReader.java | 15 +- 4 files changed, 128 insertions(+), 54 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index e9b777fb24e8..95cb1bdf533d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -108,7 +108,7 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { } /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ - private static class SortingFloatVectorValues extends FloatVectorValues { + private static class SortingFloatVectorValues extends FloatVectorValues { private final BufferedFloatVectorValues delegate; private final DocIndexIterator iterator; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index dc4544e11637..f0d6639388e3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -338,6 +338,7 @@ public long cost() { @Override public float[] vectorValue(int ord) throws IOException { + assert ord == iterator.index(); return current.values.vectorValue(current.values.iterator().index()); } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 
fcfd9da83fba..9256f9876022 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -291,7 +291,7 @@ public void checkIntegrity() throws IOException { } } - private record DocValuesSub(T sub, int docStart) {} + private record DocValuesSub(T sub, int docStart, int ordStart) {} private static class MergedDocIterator extends KnnVectorValues.DocIndexIterator { @@ -301,7 +301,6 @@ private static class MergedDocIterator DocValuesSub current; int ord = -1; int doc = -1; - int currentOrdBase = 0; MergedDocIterator(List> subs) { long cost = 0; @@ -336,10 +335,11 @@ public int nextDoc() throws IOException { } } if (it.hasNext() == false) { + ord = NO_MORE_DOCS; return doc = NO_MORE_DOCS; } current = it.next(); - currentOrdBase = ord; + ord = current.ordStart - 1; } } @@ -817,41 +817,65 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { int size = 0; for (CodecReader reader : codecReaders) { FloatVectorValues values = reader.getFloatVectorValues(field); + subs.add(new DocValuesSub<>(values, docStarts[i], size)); if (values != null) { if (dimension == -1) { dimension = values.dimension(); } size += values.size(); } - subs.add(new DocValuesSub<>(values, docStarts[i])); i++; } - final int finalDimension = dimension; - final int finalSize = size; - return new FloatVectorValues() { - - final MergedDocIterator iter = new MergedDocIterator<>(subs); + return new MergedFloatVectorValues(dimension, size, subs); + } - @Override - public MergedDocIterator iterator() { - return iter; + class MergedFloatVectorValues extends FloatVectorValues { + final int dimension; + final int size; + final DocValuesSub[] subs; + final MergedDocIterator iter; + final int[] starts; + int lastSubIndex; + + MergedFloatVectorValues(int dimension, int size, List> subs) { + this.dimension = dimension; + this.size = size; + this.subs = subs.toArray(new 
DocValuesSub[0]); + iter = new MergedDocIterator<>(subs); + // [0, start(1), ..., size] - we want the extra element + // to avoid checking for out-of-array bounds + starts = new int[subs.size() + 1]; + for (int i = 0; i < subs.size(); i++) { + starts[i] = subs.get(i).ordStart; } + starts[starts.length - 1] = size; + } - @Override - public int dimension() { - return finalDimension; - } + @Override + public MergedDocIterator iterator() { + return iter; + } - @Override - public int size() { - return finalSize; - } + @Override + public int dimension() { + return dimension; + } - @Override - public float[] vectorValue(int ord) throws IOException { - return iter.current.sub.vectorValue(ord - iter.currentOrdBase); - } - }; + @Override + public int size() { + return size; + } + + @Override + public float[] vectorValue(int ord) throws IOException { + assert ord >= 0 && ord < size; + // We need to implement fully random-access API here in order to support callers like + // SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some + // repetition. 
+ lastSubIndex = findSub(ord, lastSubIndex, starts); + return ((FloatVectorValues) subs[lastSubIndex].sub) + .vectorValue(ord - subs[lastSubIndex].ordStart); + } } @Override @@ -862,46 +886,86 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { int size = 0; for (CodecReader reader : codecReaders) { ByteVectorValues values = reader.getByteVectorValues(field); + subs.add(new DocValuesSub<>(values, docStarts[i], size)); if (values != null) { if (dimension == -1) { dimension = values.dimension(); } size += values.size(); } - subs.add(new DocValuesSub<>(values, docStarts[i])); i++; } - final int finalDimension = dimension; - final int finalSize = size; - return new ByteVectorValues() { - - final MergedDocIterator iter = new MergedDocIterator<>(subs); + return new MergedByteVectorValues(dimension, size, subs); + } - @Override - public MergedDocIterator iterator() { - return iter; + class MergedByteVectorValues extends ByteVectorValues { + final int dimension; + final int size; + final DocValuesSub[] subs; + final MergedDocIterator iter; + final int[] starts; + int lastSubIndex; + + MergedByteVectorValues(int dimension, int size, List> subs) { + this.dimension = dimension; + this.size = size; + this.subs = subs.toArray(new DocValuesSub[0]); + iter = new MergedDocIterator<>(subs); + // [0, start(1), ..., size] - we want the extra element + // to avoid checking for out-of-array bounds + starts = new int[subs.size() + 1]; + for (int i = 0; i < subs.size(); i++) { + starts[i] = subs.get(i).ordStart; } + starts[starts.length - 1] = size; + } - @Override - public int dimension() { - return finalDimension; - } + @Override + public MergedDocIterator iterator() { + return iter; + } - @Override - public int size() { - return finalSize; - } + @Override + public int dimension() { + return dimension; + } - @Override - public byte[] vectorValue(int ord) throws IOException { - return iter.current.sub.vectorValue(ord - iter.currentOrdBase); - } + @Override 
+ public int size() { + return size; + } - @Override - protected DocIndexIterator createIterator() { - return new MergedDocIterator(subs); + @Override + public byte[] vectorValue(int ord) throws IOException { + assert ord >= 0 && ord < size; + // We need to implement fully random-access API here in order to support callers like + // SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some + // repetition. + lastSubIndex = findSub(ord, lastSubIndex, starts); + return ((ByteVectorValues) subs[lastSubIndex].sub) + .vectorValue(ord - subs[lastSubIndex].ordStart); + } + } + + private static int findSub(int ord, int lastSubIndex, int[] starts) { + if (ord >= starts[lastSubIndex]) { + if (ord >= starts[lastSubIndex + 1]) { + return binarySearchStarts(starts, ord, lastSubIndex + 1, starts.length); } - }; + } else { + return binarySearchStarts(starts, ord, 0, lastSubIndex); + } + return lastSubIndex; + } + + private static int binarySearchStarts(int[] starts, int ord, int from, int to) { + int pos = Arrays.binarySearch(starts, from, to, ord); + // miss: -2 - pos is the sub just before the insertion point; hit: pos is a sub's own start + int sub = pos < 0 ? -2 - pos : pos; + while (sub < starts.length - 1 && starts[sub + 1] == ord) { + sub++; // subs with no vectors share a start with their successor; only the last holds ord + } + return sub; } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index c1699388de1a..31b79f6b0235 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -212,6 +212,7 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue * index() may skip around, not increasing monotonically as iteration proceeds. 
*/ public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterator { + private final FixedBitSet docBits; private final DocIdSetIterator docsWithValues; private final int[] docToOrd; private final int size; @@ -222,8 +223,10 @@ public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterat public SortingValuesIterator(KnnVectorValues.DocIndexIterator iter, Sorter.DocMap docMap) throws IOException { docToOrd = new int[docMap.size()]; - FixedBitSet docBits = new FixedBitSet(docMap.size()); + docBits = new FixedBitSet(docMap.size()); int count = 0; + // Note: docToOrd will contain zero for docids that have no vector. This is OK though + // because the iterator cannot be positioned on such docs for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) { int newDocId = docMap.oldToNew(doc); if (newDocId != -1) { @@ -243,6 +246,7 @@ public int docID() { @Override public int index() { + assert docBits.get(doc); return docToOrd[doc]; } @@ -266,12 +270,15 @@ private static class SortingFloatVectorValues extends FloatVectorValues { SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; + // SortingValuesIterator consumes the iterator and records the docs and ord mapping iterator = new SortingValuesIterator(delegate.iterator(), sortMap); } @Override public float[] vectorValue(int ord) throws IOException { - return delegate.vectorValue(iterator.index()); + // ords are interpreted in the delegate's ord-space. 
+ assert ord == iterator.index(); + return delegate.vectorValue(ord); } @Override @@ -300,12 +307,14 @@ private static class SortingByteVectorValues extends ByteVectorValues { SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; + // SortingValuesIterator consumes the iterator and records the docs and ord mapping iterator = new SortingValuesIterator(delegate.iterator(), sortMap); } @Override public byte[] vectorValue(int ord) throws IOException { - return delegate.vectorValue(iterator().index()); + assert ord == iterator.index(); + return delegate.vectorValue(ord); } @Override