Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Load vector data directly from the memory segment #12703

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ public int size() {
return dictionarySize;
}

/** Returns the size in bytes of one vector: {@code Float.BYTES} for each vector dimension. */
@Override
public int byteSize() {
  return Float.BYTES * vectorDimension;
}

@Override
public RandomAccessVectorValues<float[]> copy() throws IOException {
return new Word2VecModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,11 @@ public int size() {
return 0;
}

@Override
public int byteSize() {
// This implementation always reports zero bytes per vector.
return 0;
}

@Override
public byte[] vectorValue() throws IOException {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,8 @@ public float[] copyValue(float[] value) {
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
RandomAccessVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
RandomAccessVectorValues<T> raVectors =
new RAVectorValues<>(vectors, dim, fieldInfo.getVectorEncoding().byteSize);
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
Expand Down Expand Up @@ -712,17 +713,24 @@ public long ramBytesUsed() {
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors;
private final int dim;
private final int byteSize;

RAVectorValues(List<T> vectors, int dim) {
RAVectorValues(List<T> vectors, int dim, int elementSize) {
this.vectors = vectors;
this.dim = dim;
this.byteSize = dim * elementSize;
}

/** Returns the number of vectors currently held in the backing list. */
@Override
public int size() {
  return this.vectors.size();
}

/** Returns the per-vector size in bytes that was fixed when this instance was constructed. */
@Override
public int byteSize() {
  return this.byteSize;
}

@Override
public int dimension() {
return dim;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,20 @@
*/
package org.apache.lucene.benchmark.jmh;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.openjdk.jmh.annotations.*;

@BenchmarkMode(Mode.Throughput)
Expand All @@ -33,11 +44,15 @@ public class VectorUtilBenchmark {
private float[] floatsA;
private float[] floatsB;

private Directory dir;
private RandomAccessVectorValues<float[]> floatValuesA;
private RandomAccessVectorValues<float[]> floatValuesB;

@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
int size;

@Setup(Level.Trial)
public void init() {
public void init() throws IOException {
ThreadLocalRandom random = ThreadLocalRandom.current();

// random byte arrays for binary methods
Expand All @@ -53,6 +68,14 @@ public void init() {
floatsA[i] = random.nextFloat();
floatsB[i] = random.nextFloat();
}

dir = new MMapDirectory(Files.createTempDirectory("benchmark-floats"));
var aIndex = inputForFloats(floatsA, 0, 1, dir, "vector-a");
var bIndex = inputForFloats(floatsB, 0, 3, dir, "vector-b");
floatValuesA =
newDenseOffHeapFloatVectorValues(floatsA.length, 1, aIndex, floatsA.length * Float.BYTES);
floatValuesB =
newDenseOffHeapFloatVectorValues(floatsB.length, 1, bIndex, floatsB.length * Float.BYTES);
}

@Benchmark
Expand Down Expand Up @@ -121,8 +144,25 @@ public float floatDotProductScalar() {
@Fork(
value = 1,
jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float floatDotProductVector() {
return VectorUtil.dotProduct(floatsA, floatsB);
public float floatDotProductVector() throws IOException {
// REVERT/REMOVE - this change is required to ensure fair comparison with MS1 and MS2
return VectorUtil.dotProduct(floatValuesA.vectorValue(0), floatValuesB.vectorValue(0));
}

/**
 * Benchmarks the dot product between a {@code float[]} read from {@code floatValuesA} (ordinal 0)
 * and the vector at ordinal 0 of {@code floatValuesB}, which is passed by reference plus ordinal
 * rather than as an array.
 */
@Benchmark
@Fork(
    value = 1,
    jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float floatDotProductVectorMS1() throws IOException {
  final float[] query = floatValuesA.vectorValue(0);
  return VectorUtil.dotProduct(query, floatValuesB, 0);
}

/**
 * Benchmarks the dot product where both operands are passed as vector-values instances plus an
 * ordinal (ordinal 0 of {@code floatValuesA} against ordinal 0 of {@code floatValuesB}).
 */
@Benchmark
@Fork(
    value = 1,
    jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public float floatDotProductVectorMS2() throws IOException {
  // Both sides address the single vector stored at ordinal 0.
  final int ord = 0;
  return VectorUtil.dotProduct(floatValuesA, ord, floatValuesB, ord);
}

@Benchmark
Expand All @@ -138,4 +178,35 @@ public float floatSquareScalar() {
public float floatSquareVector() {
  // Squared distance between the two on-heap float arrays prepared for this benchmark.
  final float distance = VectorUtil.squareDistance(floatsA, floatsB);
  return distance;
}

// ---

/**
 * Wraps {@code slice} in a {@code DenseOffHeapVectorValues} instance.
 *
 * <p>All four arguments are forwarded unchanged to the {@code DenseOffHeapVectorValues}
 * constructor. The benchmark setup in this file calls it with the vector length as {@code
 * dimension}, {@code 1} as {@code size}, and {@code length * Float.BYTES} as {@code byteSize}.
 *
 * @param dimension number of float components per vector
 * @param size value forwarded as the constructor's size argument
 * @param slice input the vector data is read from
 * @param byteSize number of bytes occupied by one vector
 * @return the newly created vector values
 */
public static RandomAccessVectorValues<float[]> newDenseOffHeapFloatVectorValues(
    int dimension, int size, IndexInput slice, int byteSize) {
  final RandomAccessVectorValues<float[]> values =
      new OffHeapFloatVectorValues.DenseOffHeapVectorValues(dimension, size, slice, byteSize);
  return values;
}

/**
 * Writes {@code vector} into a new file called {@code name + ".data"} inside {@code dir} and
 * returns a slice over the reopened file.
 *
 * <p>File layout, in order: {@code initialBytes} zero bytes, then {@code vectorPosition} vectors'
 * worth of zero bytes, then the vector itself as written by {@code writeFloat32}. The returned
 * slice starts at offset {@code initialBytes} and spans {@code vectorPosition + 1} vectors, so the
 * written vector is the last one in the slice.
 *
 * <p>NOTE(review): the parent input opened here is not closed by this method; presumably the
 * returned slice needs it to stay open. Confirm, and release it when the directory is torn down.
 *
 * @param vector values to store
 * @param vectorPosition number of zero-filled vector slots written before {@code vector}
 * @param initialBytes number of zero bytes written at the start of the file, before the slice
 * @param dir directory in which the file is created
 * @param name base file name, also used as the slice description
 * @return a slice of the file that skips the initial bytes
 * @throws IOException if creating, writing, or reopening the file fails
 */
static IndexInput inputForFloats(
    float[] vector, int vectorPosition, int initialBytes, Directory dir, String name)
    throws IOException {
  final String fileName = name + ".data";
  final int bytesPerVector = vector.length * Float.BYTES;
  final int leadingVectorBytes = vectorPosition * bytesPerVector;

  try (IndexOutput output = dir.createOutput(fileName, IOContext.DEFAULT)) {
    if (initialBytes != 0) {
      output.writeBytes(new byte[initialBytes], initialBytes);
    }
    if (vectorPosition != 0) {
      output.writeBytes(new byte[leadingVectorBytes], leadingVectorBytes);
    }
    writeFloat32(vector, output);
  }

  // Computed in long arithmetic, as the slice length is a long.
  final long sliceLength = (long) (vectorPosition + 1) * bytesPerVector;
  final IndexInput input = dir.openInput(fileName, IOContext.DEFAULT);
  return input.slice(name, initialBytes, sliceLength);
}

/**
 * Writes every element of {@code arr} to {@code out} as a little-endian 32-bit float.
 *
 * @param arr values to encode
 * @param out destination for the {@code arr.length * Float.BYTES} encoded bytes
 * @throws IOException if writing to {@code out} fails
 */
static void writeFloat32(float[] arr, IndexOutput out) throws IOException {
  final int lenBytes = arr.length * Float.BYTES;
  final ByteBuffer encoded = ByteBuffer.allocate(lenBytes).order(ByteOrder.LITTLE_ENDIAN);
  for (float component : arr) {
    encoded.putFloat(component);
  }
  // The heap buffer's backing array now holds the encoded floats from index 0.
  out.writeBytes(encoded.array(), lenBytes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,8 @@ public float[] copyValue(float[] value) {
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>();
RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
RAVectorValues<T> raVectors =
new RAVectorValues<>(vectors, dim, fieldInfo.getVectorEncoding().byteSize);
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
Expand Down Expand Up @@ -731,17 +732,24 @@ public long ramBytesUsed() {
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors;
private final int dim;
private final int byteSize;

RAVectorValues(List<T> vectors, int dim) {
RAVectorValues(List<T> vectors, int dim, int elementSize) {
this.vectors = vectors;
this.dim = dim;
this.byteSize = dim * elementSize;
}

/** Returns the number of vectors currently held in the backing list. */
@Override
public int size() {
  return this.vectors.size();
}

/** Returns the per-vector size in bytes that was fixed when this instance was constructed. */
@Override
public int byteSize() {
  return this.byteSize;
}

@Override
public int dimension() {
return dim;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public byte[] vectorValue(int targetOrd) throws IOException {
return binaryValue;
}

/** Exposes the input slice that this instance seeks into and reads vector bytes from. */
@Override
public IndexInput getIndexInput() {
  return this.slice;
}

private void readValue(int targetOrd) throws IOException {
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,21 @@ public int size() {

@Override
public float[] vectorValue(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
return value;
}
// REVERT/REMOVE - hack for bench, to ensure fair comparison with MS1 and MS2
// if (lastOrd == targetOrd) {
// return value;
// }
slice.seek((long) targetOrd * byteSize);
slice.readFloats(value, 0, value.length);
lastOrd = targetOrd;
return value;
}

/** Exposes the input slice that this instance seeks into and reads float vectors from. */
@Override
public IndexInput getIndexInput() {
  return this.slice;
}

public static OffHeapFloatVectorValues load(
OrdToDocDISIReaderConfiguration configuration,
VectorEncoding vectorEncoding,
Expand All @@ -90,7 +96,9 @@ public static OffHeapFloatVectorValues load(

abstract Bits getAcceptOrds(Bits acceptDocs);

static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
// REVERT/REMOVE - hack for bench, to ensure fair comparison with MS1 and MS2
/** Stub doc to keep build happy. */
public static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {

private int doc = -1;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ protected ByteVectorValues() {}
*/
public abstract int size();

/**
 * Returns the number of bytes occupied by a single vector, computed as {@code Byte.BYTES} for each
 * of the {@link #dimension()} components.
 */
public int byteSize() {
  return Byte.BYTES * dimension();
}

@Override
public final long cost() {
return size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ protected FloatVectorValues() {}
/** Return the dimension of the vectors */
public abstract int dimension();

/**
 * Returns the number of bytes occupied by a single vector, computed as {@code Float.BYTES} for
 * each of the {@link #dimension()} components.
 */
public int byteSize() {
  return Float.BYTES * dimension();
}

/**
* Return the number of vectors for this field.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import static org.apache.lucene.util.VectorUtil.scaleMaxInnerProductScore;
import static org.apache.lucene.util.VectorUtil.squareDistance;

import java.io.IOException;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* Vector similarity function; used in search to return top K most similar vectors to a target
* vector. This is a label describing the method used during indexing and searching of the vectors
Expand Down Expand Up @@ -55,6 +58,22 @@ public float compare(float[] v1, float[] v2) {
return (1 + dotProduct(v1, v2)) / 2;
}

@Override
public float compare(float[] v1, RandomAccessVectorValues<float[]> v2, int v2TargetOrd)
    throws IOException {
  // Apply the same (1 + dot) / 2 scaling as the array-based overload, but let dotProduct
  // resolve the second operand from v2 by ordinal.
  final float dot = dotProduct(v1, v2, v2TargetOrd);
  return (1 + dot) / 2;
}

@Override
public float compare(
    RandomAccessVectorValues<float[]> v1,
    int v1TargetOrd,
    RandomAccessVectorValues<float[]> v2,
    int v2TargetOrd)
    throws IOException {
  // Apply the same (1 + dot) / 2 scaling as the array-based overload, with both operands
  // resolved by ordinal inside dotProduct.
  final float dot = dotProduct(v1, v1TargetOrd, v2, v2TargetOrd);
  return (1 + dot) / 2;
}

@Override
public float compare(byte[] v1, byte[] v2) {
return dotProductScore(v1, v2);
Expand Down Expand Up @@ -106,6 +125,24 @@ public float compare(byte[] v1, byte[] v2) {
*/
public abstract float compare(float[] v1, float[] v2);

/**
 * Calculates a similarity score between {@code v1} and the vector stored at ordinal {@code
 * v2TargetOrd} of {@code v2}.
 *
 * <p>This default implementation reads the second vector with {@code v2.vectorValue(v2TargetOrd)}
 * and delegates to {@link #compare(float[], float[])}.
 *
 * @param v1 the first vector
 * @param v2 the vector values holding the second vector
 * @param v2TargetOrd ordinal of the second vector within {@code v2}
 * @return the similarity score computed by {@link #compare(float[], float[])}
 * @throws IOException if reading the vector from {@code v2} fails
 */
public float compare(float[] v1, RandomAccessVectorValues<float[]> v2, int v2TargetOrd)
    throws IOException {
  final float[] other = v2.vectorValue(v2TargetOrd);
  return compare(v1, other);
}

/**
 * Calculates a similarity score between the vector at ordinal {@code v1TargetOrd} of {@code v1}
 * and the vector at ordinal {@code v2TargetOrd} of {@code v2}.
 *
 * <p>This default implementation reads both vectors with {@code vectorValue(int)}, first from
 * {@code v1} and then from {@code v2}, and delegates to {@link #compare(float[], float[])}.
 *
 * <p>NOTE(review): if {@code v1} and {@code v2} hand out the same reusable array (for example
 * when they are the same instance), the second read would overwrite the first operand. Confirm
 * that callers pass independent instances.
 *
 * @param v1 the vector values holding the first vector
 * @param v1TargetOrd ordinal of the first vector within {@code v1}
 * @param v2 the vector values holding the second vector
 * @param v2TargetOrd ordinal of the second vector within {@code v2}
 * @return the similarity score computed by {@link #compare(float[], float[])}
 * @throws IOException if reading either vector fails
 */
public float compare(
    RandomAccessVectorValues<float[]> v1,
    int v1TargetOrd,
    RandomAccessVectorValues<float[]> v2,
    int v2TargetOrd)
    throws IOException {
  final float[] first = v1.vectorValue(v1TargetOrd);
  final float[] second = v2.vectorValue(v2TargetOrd);
  return compare(first, second);
}

// TODO: add equivalent RandomAccessVectorValues-based compare overloads for byte vectors.

/**
* Calculates a similarity score between the two vectors with a specified function. Higher
* similarity scores correspond to closer vectors. Each (signed) byte represents a vector
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.lucene.internal.vectorization;

import java.io.IOException;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

final class DefaultVectorUtilSupport implements VectorUtilSupport {

DefaultVectorUtilSupport() {}
Expand Down Expand Up @@ -87,6 +90,22 @@ public float dotProduct(float[] a, float[] b) {
return res;
}

@Override
public float dotProduct(float[] a, RandomAccessVectorValues<float[]> b, int bOffset)
    throws IOException {
  // Default path: read the second operand into an array, then reuse the array-based dot product.
  final float[] other = b.vectorValue(bOffset);
  return dotProduct(a, other);
}

@Override
public float dotProduct(
    RandomAccessVectorValues<float[]> a,
    int aOffset,
    RandomAccessVectorValues<float[]> b,
    int bOffset)
    throws IOException {
  // Default path: read both operands into arrays (a first, then b) and reuse the array-based
  // dot product.
  // NOTE(review): if a and b return the same reusable array, the second read would overwrite
  // the first operand - confirm callers pass independent instances.
  final float[] first = a.vectorValue(aOffset);
  final float[] second = b.vectorValue(bOffset);
  return dotProduct(first, second);
}

@Override
public float cosine(float[] a, float[] b) {
float sum = 0.0f;
Expand Down
Loading