Skip to content

Commit

Permalink
update flat vectors scorer to use only two vector dictionaries
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisHegarty committed Nov 8, 2024
1 parent 4284360 commit f1e0007
Showing 1 changed file with 18 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,24 @@ public String toString() {
private static final class ByteScoringSupplier implements RandomVectorScorerSupplier {
private final ByteVectorValues vectorValues;
private final VectorSimilarityFunction similarityFunction;
private final ByteVectorValues.Bytes queryVectors;
private final ByteVectorValues.Bytes targetVectors;

private ByteScoringSupplier(
ByteVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
this.vectorValues = vectorValues;
this.similarityFunction = similarityFunction;
this.queryVectors = vectorValues.vectors();
this.targetVectors = vectorValues.vectors();
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
byte[] query = vectorValues.vectors().get(ord);
ByteVectorValues.Bytes vectors = vectorValues.vectors();
public RandomVectorScorer scorer(int ord) {
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorValues) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, vectors.get(node));
return similarityFunction.compare(queryVectors.get(ord), targetVectors.get(node));
}
};
}
Expand All @@ -120,22 +122,24 @@ public String toString() {
private static final class FloatScoringSupplier implements RandomVectorScorerSupplier {
private final FloatVectorValues vectorValues;
private final VectorSimilarityFunction similarityFunction;
private final FloatVectorValues.Floats queryVectors;
private final FloatVectorValues.Floats targetVectors;

private FloatScoringSupplier(
FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
this.vectorValues = vectorValues;
this.similarityFunction = similarityFunction;
this.queryVectors = vectorValues.vectors();
this.targetVectors = vectorValues.vectors();
}

@Override
public RandomVectorScorer scorer(int ord) throws IOException {
float[] query = vectorValues.vectors().get(ord);
FloatVectorValues.Floats vectors = vectorValues.vectors();
public RandomVectorScorer scorer(int ord) {
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorValues) {
@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, vectors.get(node));
return similarityFunction.compare(queryVectors.get(ord), targetVectors.get(node));
}
};
}
Expand All @@ -148,43 +152,43 @@ public String toString() {

/** A {@link RandomVectorScorer} for float vectors. */
private static class FloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
private final FloatVectorValues.Floats vectors;
private final float[] query;
private final VectorSimilarityFunction similarityFunction;
private final FloatVectorValues.Floats targetVectors;

public FloatVectorScorer(
FloatVectorValues vectorValues, float[] query, VectorSimilarityFunction similarityFunction)
throws IOException {
super(vectorValues);
this.vectors = vectorValues.vectors();
this.query = query;
this.similarityFunction = similarityFunction;
this.targetVectors = vectorValues.vectors();
}

@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, vectors.get(node));
return similarityFunction.compare(query, targetVectors.get(node));
}
}

/** A {@link RandomVectorScorer} for byte vectors. */
private static class ByteVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
private final ByteVectorValues.Bytes vectors;
private final byte[] query;
private final VectorSimilarityFunction similarityFunction;
private final ByteVectorValues.Bytes targetVectors;

public ByteVectorScorer(
ByteVectorValues vectorValues, byte[] query, VectorSimilarityFunction similarityFunction)
throws IOException {
super(vectorValues);
vectors = vectorValues.vectors();
this.query = query;
this.similarityFunction = similarityFunction;
targetVectors = vectorValues.vectors();
}

@Override
public float score(int node) throws IOException {
return similarityFunction.compare(query, vectors.get(node));
return similarityFunction.compare(query, targetVectors.get(node));
}
}
}

0 comments on commit f1e0007

Please sign in to comment.