Skip to content

Commit

Permalink
fix case where index is reordered
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael Sokolov committed Sep 12, 2024
1 parent c2ae86b commit ff7a317
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
}

/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
private static class SortingFloatVectorValues extends FloatVectorValues {
private static class SortingFloatVectorValues extends FloatVectorValues {
private final BufferedFloatVectorValues delegate;
private final DocIndexIterator iterator;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,7 @@ public long cost() {

@Override
public float[] vectorValue(int ord) throws IOException {
  // Callers are expected to ask for the ordinal the merged iterator is positioned on.
  assert ord == iterator.index();
  // The vector itself is read from the active sub, addressed by that sub's own iterator index.
  final int subOrd = current.values.iterator().index();
  return current.values.vectorValue(subOrd);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ public void checkIntegrity() throws IOException {
}
}

private record DocValuesSub<T extends KnnVectorValues>(T sub, int docStart) {}
private record DocValuesSub<T extends KnnVectorValues>(T sub, int docStart, int ordStart) {}

private static class MergedDocIterator<T extends KnnVectorValues>
extends KnnVectorValues.DocIndexIterator {
Expand All @@ -301,7 +301,6 @@ private static class MergedDocIterator<T extends KnnVectorValues>
DocValuesSub<T> current;
int ord = -1;
int doc = -1;
int currentOrdBase = 0;

MergedDocIterator(List<DocValuesSub<T>> subs) {
long cost = 0;
Expand Down Expand Up @@ -336,10 +335,11 @@ public int nextDoc() throws IOException {
}
}
if (it.hasNext() == false) {
ord = NO_MORE_DOCS;
return doc = NO_MORE_DOCS;
}
current = it.next();
currentOrdBase = ord;
ord = current.ordStart - 1;
}
}

Expand Down Expand Up @@ -817,41 +817,65 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {
int size = 0;
for (CodecReader reader : codecReaders) {
FloatVectorValues values = reader.getFloatVectorValues(field);
subs.add(new DocValuesSub<>(values, docStarts[i], size));
if (values != null) {
if (dimension == -1) {
dimension = values.dimension();
}
size += values.size();
}
subs.add(new DocValuesSub<>(values, docStarts[i]));
i++;
}
final int finalDimension = dimension;
final int finalSize = size;
return new FloatVectorValues() {

final MergedDocIterator<FloatVectorValues> iter = new MergedDocIterator<>(subs);
return new MergedFloatVectorValues(dimension, size, subs);
}

@Override
public MergedDocIterator<FloatVectorValues> iterator() {
return iter;
class MergedFloatVectorValues extends FloatVectorValues {
final int dimension;
final int size;
final DocValuesSub<?>[] subs;
final MergedDocIterator<FloatVectorValues> iter;
final int[] starts;
int lastSubIndex;

MergedFloatVectorValues(int dimension, int size, List<DocValuesSub<FloatVectorValues>> subs) {
this.dimension = dimension;
this.size = size;
this.subs = subs.toArray(new DocValuesSub<?>[0]);
iter = new MergedDocIterator<>(subs);
// [0, start(1), ..., size] - we want the extra element
// to avoid checking for out-of-array bounds
starts = new int[subs.size() + 1];
for (int i = 0; i < subs.size(); i++) {
starts[i] = subs.get(i).ordStart;
}
starts[starts.length - 1] = size;
}

@Override
public int dimension() {
return finalDimension;
}
@Override
public MergedDocIterator<FloatVectorValues> iterator() {
return iter;
}

@Override
public int size() {
return finalSize;
}
@Override
public int dimension() {
return dimension;
}

@Override
public float[] vectorValue(int ord) throws IOException {
return iter.current.sub.vectorValue(ord - iter.currentOrdBase);
}
};
@Override
public int size() {
return size;
}

@Override
public float[] vectorValue(int ord) throws IOException {
assert ord >= 0 && ord < size;
// We need to implement a fully random-access API here in order to support callers like
// SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some
// repetition.
lastSubIndex = findSub(ord, lastSubIndex, starts);
return ((FloatVectorValues) subs[lastSubIndex].sub)
.vectorValue(ord - subs[lastSubIndex].ordStart);
}
}

@Override
Expand All @@ -862,46 +886,86 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
int size = 0;
for (CodecReader reader : codecReaders) {
ByteVectorValues values = reader.getByteVectorValues(field);
subs.add(new DocValuesSub<>(values, docStarts[i], size));
if (values != null) {
if (dimension == -1) {
dimension = values.dimension();
}
size += values.size();
}
subs.add(new DocValuesSub<>(values, docStarts[i]));
i++;
}
final int finalDimension = dimension;
final int finalSize = size;
return new ByteVectorValues() {

final MergedDocIterator<ByteVectorValues> iter = new MergedDocIterator<>(subs);
return new MergedByteVectorValues(dimension, size, subs);
}

@Override
public MergedDocIterator<ByteVectorValues> iterator() {
return iter;
class MergedByteVectorValues extends ByteVectorValues {
final int dimension;
final int size;
final DocValuesSub<?>[] subs;
final MergedDocIterator<ByteVectorValues> iter;
final int[] starts;
int lastSubIndex;

MergedByteVectorValues(int dimension, int size, List<DocValuesSub<ByteVectorValues>> subs) {
this.dimension = dimension;
this.size = size;
this.subs = subs.toArray(new DocValuesSub<?>[0]);
iter = new MergedDocIterator<>(subs);
// [0, start(1), ..., size] - we want the extra element
// to avoid checking for out-of-array bounds
starts = new int[subs.size() + 1];
for (int i = 0; i < subs.size(); i++) {
starts[i] = subs.get(i).ordStart;
}
starts[starts.length - 1] = size;
}

@Override
public int dimension() {
return finalDimension;
}
@Override
public MergedDocIterator<ByteVectorValues> iterator() {
return iter;
}

@Override
public int size() {
return finalSize;
}
@Override
public int dimension() {
return dimension;
}

@Override
public byte[] vectorValue(int ord) throws IOException {
return iter.current.sub.vectorValue(ord - iter.currentOrdBase);
}
@Override
public int size() {
return size;
}

@Override
protected DocIndexIterator createIterator() {
return new MergedDocIterator<ByteVectorValues>(subs);
@Override
public byte[] vectorValue(int ord) throws IOException {
assert ord >= 0 && ord < size;
// We need to implement fully random-access API here in order to support callers like
// SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some
// repetition.
lastSubIndex = findSub(ord, lastSubIndex, starts);
return ((ByteVectorValues) subs[lastSubIndex].sub)
.vectorValue(ord - subs[lastSubIndex].ordStart);
}
}

/**
 * Picks the index into {@code starts} to use for {@code ord}, trying the previously used index
 * first.
 *
 * @param ord the ordinal being looked up
 * @param lastSubIndex the index returned by the previous lookup, used as a starting guess
 * @param starts start ordinals, with one extra trailing element so that {@code starts[lastSubIndex
 *     + 1]} can be read without a bounds check
 * @return {@code lastSubIndex} when {@code starts[lastSubIndex] <= ord < starts[lastSubIndex +
 *     1]}; otherwise the result of {@code binarySearchStarts} over the entries after or before it
 */
private static int findSub(int ord, int lastSubIndex, int[] starts) {
if (ord >= starts[lastSubIndex]) {
if (ord >= starts[lastSubIndex + 1]) {
// ord lies beyond the previous range: search only the entries after it
return binarySearchStarts(starts, ord, lastSubIndex + 1, starts.length);
}
};
} else {
// ord lies before the previous range: search only the entries preceding it
return binarySearchStarts(starts, ord, 0, lastSubIndex);
}
// ord is still inside the previous range, so no search is needed
return lastSubIndex;
}

/**
 * Finds the index {@code i} such that {@code starts[i] <= ord < starts[i + 1]}, searching only
 * {@code starts[from..to)}.
 *
 * <p>{@code starts[i]} is the first ordinal belonging to sub {@code i}, so an exact match means
 * {@code ord} is the first ordinal of that sub and the matching index itself is the answer.
 * Consecutive entries may be equal when a sub contributes no ordinals; such subs are skipped so
 * that the returned sub actually contains {@code ord}.
 *
 * @param starts ascending start ordinals, one per sub, followed by the total size
 * @param ord the ordinal to locate
 * @param from first index to search (inclusive)
 * @param to last index to search (exclusive)
 * @return the index of the sub whose ordinal range contains {@code ord}
 */
private static int binarySearchStarts(int[] starts, int ord, int from, int to) {
  int pos = Arrays.binarySearch(starts, from, to, ord);
  if (pos < 0) {
    // Not found: binarySearch returned -(insertionPoint) - 1. The owning sub is the entry just
    // before the insertion point, i.e. insertionPoint - 1 == -2 - pos.
    return -2 - pos;
  }
  // Exact match: binarySearch may land on any of several equal entries. Empty subs share their
  // start with the following sub, so advance to the last equal entry.
  while (pos < starts.length - 1 && starts[pos + 1] == ord) {
    ++pos;
  }
  return pos;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue
* index() may skip around, not increasing monotonically as iteration proceeds.
*/
public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterator {
private final FixedBitSet docBits;
private final DocIdSetIterator docsWithValues;
private final int[] docToOrd;
private final int size;
Expand All @@ -222,8 +223,10 @@ public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterat
public SortingValuesIterator(KnnVectorValues.DocIndexIterator iter, Sorter.DocMap docMap)
throws IOException {
docToOrd = new int[docMap.size()];
FixedBitSet docBits = new FixedBitSet(docMap.size());
docBits = new FixedBitSet(docMap.size());
int count = 0;
// Note: docToOrd will contain zero for docids that have no vector. This is OK though
// because the iterator cannot be positioned on such docs
for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) {
int newDocId = docMap.oldToNew(doc);
if (newDocId != -1) {
Expand All @@ -243,6 +246,7 @@ public int docID() {

@Override
public int index() {
  // Look up the ordinal recorded for the doc this iterator is on; the assertion checks that
  // docBits has that doc set before docToOrd is consulted.
  final int currentDoc = doc;
  assert docBits.get(currentDoc);
  return docToOrd[currentDoc];
}

Expand All @@ -266,12 +270,15 @@ private static class SortingFloatVectorValues extends FloatVectorValues {

SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
  // Constructing the SortingValuesIterator walks the delegate's iterator to the end, recording
  // which docs have values and the ordinal of each.
  this.iterator = new SortingValuesIterator(delegate.iterator(), sortMap);
  this.delegate = delegate;
}

@Override
public float[] vectorValue(int ord) throws IOException {
return delegate.vectorValue(iterator.index());
// ords are interpreted in the delegate's ord-space.
assert ord == iterator.index();
return delegate.vectorValue(ord);
}

@Override
Expand Down Expand Up @@ -300,12 +307,14 @@ private static class SortingByteVectorValues extends ByteVectorValues {

SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
  // Constructing the SortingValuesIterator walks the delegate's iterator to the end, recording
  // which docs have values and the ordinal of each.
  this.iterator = new SortingValuesIterator(delegate.iterator(), sortMap);
  this.delegate = delegate;
}

@Override
public byte[] vectorValue(int ord) throws IOException {
return delegate.vectorValue(iterator().index());
assert ord == iterator.index();
return delegate.vectorValue(ord);
}

@Override
Expand Down

0 comments on commit ff7a317

Please sign in to comment.