Skip to content

Commit

Permalink
Integration with base OpenSearch 2.0 (#328)
Browse files Browse the repository at this point in the history
* Initial integration with OS 2.0 alpha1 snapshot

Signed-off-by: Martin Gaievski <gaievski@amazon.com>
  • Loading branch information
martin-gaievski authored Mar 28, 2022
1 parent 8203524 commit 6121c27
Show file tree
Hide file tree
Showing 29 changed files with 544 additions and 252 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
- name: Run build
run: |
./gradlew build -Dopensearch.version=2.0.0-SNAPSHOT
./gradlew build -Dopensearch.version=2.0.0-alpha1-SNAPSHOT
- name: Run k-NN Backwards Compatibility Tests
run: |
Expand Down
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ buildscript {
ext {
// build.version_qualifier parameter applies to knn plugin artifacts only. OpenSearch version must be set
// explicitly as 'opensearch.version' property, for instance opensearch.version=2.0.0-alpha1-SNAPSHOT
opensearch_version = System.getProperty("opensearch.version", "2.0.0-SNAPSHOT")
opensearch_version = System.getProperty("opensearch.version", "2.0.0-alpha1-SNAPSHOT")
knn_bwc_version = System.getProperty("bwc.version", "1.2.0.0-SNAPSHOT")
version_qualifier = System.getProperty("build.version_qualifier", "")
version_qualifier = System.getProperty("build.version_qualifier", "alpha1")
opensearch_bwc_version = "${knn_bwc_version}" - ".0-SNAPSHOT"
opensearch_group = "org.opensearch"
}
Expand Down
13 changes: 10 additions & 3 deletions src/main/java/org/opensearch/knn/index/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

Expand Down Expand Up @@ -41,7 +42,9 @@ public int getK() {
return this.k;
}

public String getIndexName() { return this.indexName; }
public String getIndexName() {
return this.indexName;
}

/**
* Constructs Weight implementation for this query
Expand All @@ -59,6 +62,11 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
return new KNNWeight(this, boost);
}

/**
 * Reports this query to the supplied visitor so that code walking the query tree can see it.
 *
 * The query is announced as a leaf. An empty implementation would make it invisible to every
 * {@link QueryVisitor}, which is why the visitor is notified explicitly here.
 * NOTE(review): assumes this query wraps no child queries that would need visiting - confirm.
 *
 * @param visitor visitor to notify about this query
 */
@Override
public void visit(QueryVisitor visitor) {
    visitor.visitLeaf(this);
}

@Override
public String toString(String field) {
return field;
Expand All @@ -71,8 +79,7 @@ public int hashCode() {

@Override
public boolean equals(Object other) {
return sameClassAs(other) &&
equalsTo(getClass().cast(other));
return sameClassAs(other) && equalsTo(getClass().cast(other));
}

private boolean equalsTo(KNNQuery other) {
Expand Down
216 changes: 107 additions & 109 deletions src/main/java/org/opensearch/knn/index/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.opensearch.knn.index;

import com.google.common.collect.ImmutableMap;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
Expand All @@ -19,7 +18,6 @@
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.Scorer;
Expand All @@ -36,10 +34,8 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -77,114 +73,118 @@ public Explanation explain(LeafReaderContext context, int doc) {
return Explanation.match(1.0f, "No Explanation");
}

@Override
public void extractTerms(Set<Term> terms) {
}

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());

if (fieldInfo == null) {
logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(),
reader.getSegmentName());
return null;
}

KNNEngine knnEngine;
SpaceType spaceType;

// Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's
// metadata.
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
}

knnEngine = modelMetadata.getKnnEngine();
spaceType = modelMetadata.getSpaceType();
} else {
String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName());
knnEngine = KNNEngine.getEngine(engineName);
String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue());
spaceType = SpaceType.getSpace(spaceTypeName);
}

/*
* For a compound file the extension is <engine-extension> + c; otherwise it is <engine-extension>.
*/
String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile()
? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION : knnEngine.getExtension();
String engineSuffix = knnQuery.getField() + engineExtension;
List<String> engineFiles = reader.getSegmentInfo().files().stream()
.filter(fileName -> fileName.endsWith(engineSuffix))
.collect(Collectors.toList());

if(engineFiles.isEmpty()) {
logger.debug("[KNN] No engine index found for field {} for segment {}",
knnQuery.getField(), reader.getSegmentName());
return null;
}

Path indexPath = PathUtils.get(directory, engineFiles.get(0));
final KNNQueryResult[] results;
KNNCounter.GRAPH_QUERY_REQUESTS.increment();

// We need to first get index allocation
NativeMemoryAllocation indexAllocation;
try {
indexAllocation = nativeMemoryCacheManager.get(
new NativeMemoryEntryContext.IndexEntryContext(
indexPath.toString(),
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()),
knnQuery.getIndexName()
), true);
} catch (ExecutionException e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
}

// Now that we have the allocation, we need to readLock it
indexAllocation.readLock();

try {
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
}

results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), knnQuery.getQueryVector(), knnQuery.getK(), knnEngine.getName());
} catch (Exception e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
} finally {
indexAllocation.readUnlock();
SegmentReader reader = (SegmentReader) FilterLeafReader.unwrap(context.reader());
String directory = ((FSDirectory) FilterDirectory.unwrap(reader.directory())).getDirectory().toString();

FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField());

if (fieldInfo == null) {
logger.debug("[KNN] Field info not found for {}:{}", knnQuery.getField(), reader.getSegmentName());
return null;
}

KNNEngine knnEngine;
SpaceType spaceType;

// Check if a modelId exists. If so, the space type and engine will need to be picked up from the model's
// metadata.
String modelId = fieldInfo.getAttribute(MODEL_ID);
if (modelId != null) {
ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
if (modelMetadata == null) {
throw new RuntimeException("Model \"" + modelId + "\" does not exist.");
}

/*
* Scores represent the distance of each document from the given query vector.
* The lower the score, the closer the document is to the query vector.
* Because results are retrieved in descending order of score by default, the scores are
* inverted so that the nearest neighbors come first.
*/
if (results.length == 0) {
logger.debug("[KNN] Query yielded 0 results");
return null;
knnEngine = modelMetadata.getKnnEngine();
spaceType = modelMetadata.getSpaceType();
} else {
String engineName = fieldInfo.attributes().getOrDefault(KNN_ENGINE, KNNEngine.NMSLIB.getName());
knnEngine = KNNEngine.getEngine(engineName);
String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue());
spaceType = SpaceType.getSpace(spaceTypeName);
}

/*
* For a compound file the extension is <engine-extension> + c; otherwise it is <engine-extension>.
*/
String engineExtension = reader.getSegmentInfo().info.getUseCompoundFile()
? knnEngine.getExtension() + KNNConstants.COMPOUND_EXTENSION
: knnEngine.getExtension();
String engineSuffix = knnQuery.getField() + engineExtension;
List<String> engineFiles = reader.getSegmentInfo()
.files()
.stream()
.filter(fileName -> fileName.endsWith(engineSuffix))
.collect(Collectors.toList());

if (engineFiles.isEmpty()) {
logger.debug("[KNN] No engine index found for field {} for segment {}", knnQuery.getField(), reader.getSegmentName());
return null;
}

Path indexPath = PathUtils.get(directory, engineFiles.get(0));
final KNNQueryResult[] results;
KNNCounter.GRAPH_QUERY_REQUESTS.increment();

// We need to first get index allocation
NativeMemoryAllocation indexAllocation;
try {
indexAllocation = nativeMemoryCacheManager.get(
new NativeMemoryEntryContext.IndexEntryContext(
indexPath.toString(),
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(),
getParametersAtLoading(spaceType, knnEngine, knnQuery.getIndexName()),
knnQuery.getIndexName()
),
true
);
} catch (ExecutionException e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
}

// Now that we have the allocation, we need to readLock it
indexAllocation.readLock();

try {
if (indexAllocation.isClosed()) {
throw new RuntimeException("Index has already been closed");
}

Map<Integer, Float> scores = Arrays.stream(results).collect(
Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
int maxDoc = Collections.max(scores.keySet()) + 1;
DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc);
Arrays.stream(results).forEach(result -> setAdder.add(result.getId()));
DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
return new KNNScorer(this, docIdSetIter, scores, boost);
results = JNIService.queryIndex(
indexAllocation.getMemoryAddress(),
knnQuery.getQueryVector(),
knnQuery.getK(),
knnEngine.getName()
);
} catch (Exception e) {
GRAPH_QUERY_ERRORS.increment();
throw new RuntimeException(e);
} finally {
indexAllocation.readUnlock();
}

/*
* Scores represent the distance of each document from the given query vector.
* The lower the score, the closer the document is to the query vector.
* Because results are retrieved in descending order of score by default, the scores are
* inverted so that the nearest neighbors come first.
*/
if (results.length == 0) {
logger.debug("[KNN] Query yielded 0 results");
return null;
}

Map<Integer, Float> scores = Arrays.stream(results)
.collect(Collectors.toMap(KNNQueryResult::getId, result -> knnEngine.score(result.getScore(), spaceType)));
int maxDoc = Collections.max(scores.keySet()) + 1;
DocIdSetBuilder docIdSetBuilder = new DocIdSetBuilder(maxDoc);
DocIdSetBuilder.BulkAdder setAdder = docIdSetBuilder.grow(maxDoc);
Arrays.stream(results).forEach(result -> setAdder.add(result.getId()));
DocIdSetIterator docIdSetIter = docIdSetBuilder.build().iterator();
return new KNNScorer(this, docIdSetIter, scores, boost);
}

@Override
Expand All @@ -193,9 +193,7 @@ public boolean isCacheable(LeafReaderContext context) {
}

/**
 * Converts a raw score into its normalized form.
 *
 * A non-negative score is mapped to {@code 1 / (1 + score)}; a negative score is mapped to
 * {@code -score + 1}.
 *
 * @param score raw score to normalize
 * @return the normalized score
 */
public static float normalizeScore(float score) {
    if (score >= 0) {
        return 1 / (1 + score);
    }
    // Negative input: negate it and shift by one.
    return -score + 1;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.apache.lucene.codecs.CompoundFormat;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.FieldInfosFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.LiveDocsFormat;
import org.apache.lucene.codecs.NormsFormat;
import org.apache.lucene.codecs.PointsFormat;
Expand Down Expand Up @@ -52,8 +53,7 @@ public DocValuesFormat getDocValuesFormatForField(String field) {
* This function returns the Lucene80 Codec.
*/
public Codec getDelegatee() {
    // Reuse the codec once it has been looked up.
    if (lucene80Codec != null) {
        return lucene80Codec;
    }
    // First call: look the codec up by name and remember it for later calls.
    lucene80Codec = Codec.forName(LUCENE_80);
    return lucene80Codec;
}

Expand Down Expand Up @@ -112,4 +112,9 @@ public CompoundFormat compoundFormat() {
public PointsFormat pointsFormat() {
    // The points format is taken unchanged from the delegate codec.
    final Codec delegate = getDelegatee();
    return delegate.pointsFormat();
}

/**
 * Always fails: this codec does not provide a {@link KnnVectorsFormat}.
 *
 * @return never returns normally
 * @throws UnsupportedOperationException on every call
 */
@Override
public KnnVectorsFormat knnVectorsFormat() {
    final String message = "Codec does not support knn vector format";
    throw new UnsupportedOperationException(message);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.index.codec.KNN80Codec;

import org.apache.lucene.codecs.lucene50.Lucene50CompoundFormat;
import org.apache.lucene.backward_codecs.lucene50.Lucene50CompoundFormat;
import org.opensearch.knn.common.KNNConstants;
import org.apache.lucene.codecs.CompoundDirectory;
import org.apache.lucene.codecs.CompoundFormat;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import org.apache.lucene.codecs.DocValuesConsumer;
import org.apache.lucene.codecs.DocValuesFormat;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.lucene80.Lucene80DocValuesFormat;
import org.apache.lucene.backward_codecs.lucene80.Lucene80DocValuesFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;

Expand Down
Loading

0 comments on commit 6121c27

Please sign in to comment.