Skip to content

Commit

Permalink
Add SemanticSparseVectorQueryRewriteInterceptor
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Dec 12, 2024
1 parent f6f57bb commit f676481
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ public List<WeightedToken> getQueryVectors() {
return queryVectors;
}

public String getInferenceId() {
return inferenceId;
}

public String getQuery() {
return query;
}

public boolean shouldPruneTokens() {
return shouldPruneTokens;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.queries.SemanticSparseVectorQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.rank.random.RandomRankBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankBuilder;
Expand Down Expand Up @@ -432,7 +433,7 @@ public List<QuerySpec<?>> getQueries() {

@Override
public List<QueryRewriteInterceptor> getQueryRewriteInterceptors() {
return List.of(new SemanticMatchQueryRewriteInterceptor());
return List.of(new SemanticMatchQueryRewriteInterceptor(), new SemanticSparseVectorQueryRewriteInterceptor());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

public class SemanticQueryInterceptionUtils {


private SemanticQueryInterceptionUtils() {}

public static SemanticTextIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<String> inferenceIndices = new ArrayList<>();
List<String> nonInferenceIndices = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
if (inferenceFieldMetadata != null) {
inferenceIndices.add(indexName);
} else {
nonInferenceIndices.add(indexName);
}
}

return new SemanticTextIndexInformationForField(inferenceIndices, nonInferenceIndices);
}
return null;
}

public static QueryBuilder createSemanticSubQueryForIndices(List<String> indices, String fieldName, String value) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}

public static QueryBuilder createSubQueryForIndices(List<String> indices, QueryBuilder queryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(queryBuilder);
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}

public record SemanticTextIndexInformationForField(List<String> semanticMappedIndices, List<String> otherIndices) {}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.queries;

import org.apache.lucene.search.join.ScoreMode;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;

public class SemanticSparseVectorQueryRewriteInterceptor implements QueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
"search.semantic_sparse_vector_query_rewrite_interception_supported"
);

private static final String NESTED_FIELD_PATH = ".inference.chunks";
private static final String NESTED_EMBEDDINGS_FIELD = NESTED_FIELD_PATH + ".embeddings";

public SemanticSparseVectorQueryRewriteInterceptor() {}

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField semanticTextIndexInformationForField =
SemanticQueryInterceptionUtils.resolveIndicesForField(sparseVectorQueryBuilder.getFieldName(), context.getResolvedIndices());

if (semanticTextIndexInformationForField == null || semanticTextIndexInformationForField.semanticMappedIndices().isEmpty()) {
// No semantic text fields, return original query
return rewritten;
} else if (semanticTextIndexInformationForField.otherIndices().isEmpty() == false) {
// Combined semantic and sparse vector fields
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
// sparse_vector fields should be passed in as their own clause
boolQueryBuilder.should(
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.otherIndices(),
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.otherIndices(),
sparseVectorQueryBuilder
)
)
);
// semantic text fields should be passed in as nested sub queries
boolQueryBuilder.should(
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.semanticMappedIndices(),
buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder)
)

);

rewritten = boolQueryBuilder;
} else {
// Only semantic text fields
rewritten = buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder);
}

return rewritten;
}

private QueryBuilder buildNestedQueryFromSparseVectorQuery(SparseVectorQueryBuilder sparseVectorQueryBuilder) {
return QueryBuilders.nestedQuery(
getNestedFieldPath(sparseVectorQueryBuilder.getFieldName()),
new SparseVectorQueryBuilder(
getNestedEmbeddingsField(sparseVectorQueryBuilder.getFieldName()),
sparseVectorQueryBuilder.getQueryVectors(),
sparseVectorQueryBuilder.getInferenceId(),
sparseVectorQueryBuilder.getQuery(),
sparseVectorQueryBuilder.shouldPruneTokens(),
sparseVectorQueryBuilder.getTokenPruningConfig()
),
ScoreMode.Max
);
}

private static String getNestedFieldPath(String fieldName) {
return fieldName + NESTED_FIELD_PATH;
}

private static String getNestedEmbeddingsField(String fieldName) {
return fieldName + NESTED_EMBEDDINGS_FIELD;
}

@Override
public String getQueryName() {
return SparseVectorQueryBuilder.NAME;
}
}

0 comments on commit f676481

Please sign in to comment.