Skip to content

Commit

Permalink
Refactor match rewriting
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Dec 12, 2024
1 parent 1d3151a commit ef1b8f1
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,19 @@

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField;

public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
Expand All @@ -37,58 +33,47 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
ResolvedIndices resolvedIndices = context.getResolvedIndices();
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<String> inferenceIndices = new ArrayList<>();
List<String> nonInferenceIndices = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName());
if (inferenceFieldMetadata != null) {
inferenceIndices.add(indexName);
} else {
nonInferenceIndices.add(indexName);
}
}
SemanticTextIndexInformationForField semanticTextIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField(
matchQueryBuilder.fieldName(),
context.getResolvedIndices()
);

if (inferenceIndices.isEmpty()) {
return rewritten;
} else if (nonInferenceIndices.isEmpty() == false) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
for (String inferenceIndexName : inferenceIndices) {
// Add a separate clause for each semantic query, because they may be using different inference endpoints
// TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints
boolQueryBuilder.should(
createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value())
);
}
boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder));
rewritten = boolQueryBuilder;
} else {
rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
}
if (semanticTextIndexInformationForField == null || semanticTextIndexInformationForField.semanticMappedIndices().isEmpty()) {
// No semantic text fields, return original query
return rewritten;
} else if (semanticTextIndexInformationForField.otherIndices().isEmpty() == false) {
// Combined semantic_text and other text fields
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(
createSemanticSubQuery(
semanticTextIndexInformationForField.semanticMappedIndices(),
matchQueryBuilder.fieldName(),
(String) matchQueryBuilder.value()
)
);
boolQueryBuilder.should(
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.otherIndices(),
matchQueryBuilder
)
);
rewritten = boolQueryBuilder;
} else {
// Only semantic_text fields
rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
}

return rewritten;

}

/**
 * Returns {@link MatchQueryBuilder#NAME}, the name constant of the query builder type
 * that this interceptor's {@code interceptAndRewrite} asserts it receives.
 *
 * @return the value of {@code MatchQueryBuilder.NAME}
 */
@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
}

private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) {
private QueryBuilder createSemanticSubQuery(List<String> indices, String fieldName, String value) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
return boolQueryBuilder;
}

private QueryBuilder createMatchSubQuery(List<String> indices, MatchQueryBuilder matchQueryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(matchQueryBuilder);
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

public class SemanticQueryInterceptionUtils {


// Private no-arg constructor: the helpers visible in this class are all static, so no instance is needed.
private SemanticQueryInterceptionUtils() {}

public static SemanticTextIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
Expand All @@ -44,13 +43,6 @@ public static SemanticTextIndexInformationForField resolveIndicesForField(String
return null;
}

public static QueryBuilder createSemanticSubQueryForIndices(List<String> indices, String fieldName, String value) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}

public static QueryBuilder createSubQueryForIndices(List<String> indices, QueryBuilder queryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(queryBuilder);
Expand All @@ -59,5 +51,4 @@ public static QueryBuilder createSubQueryForIndices(List<String> indices, QueryB
}

/**
 * Holds the two index-name lists that {@code resolveIndicesForField} returns for a single field.
 * Callers treat an empty {@code semanticMappedIndices} as "no semantic text fields" and a
 * non-empty {@code otherIndices} as a mix of semantic_text and other field mappings.
 *
 * @param semanticMappedIndices presumably the names of the resolved indices in which the field
 *                              has inference (semantic_text) metadata; confirm against
 *                              {@code resolveIndicesForField}
 * @param otherIndices          presumably the names of the remaining resolved indices; confirm
 *                              against {@code resolveIndicesForField}
 */
public record SemanticTextIndexInformationForField(List<String> semanticMappedIndices, List<String> otherIndices) {}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;

import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField;

public class SemanticSparseVectorQueryRewriteInterceptor implements QueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
Expand All @@ -32,8 +34,10 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField semanticTextIndexInformationForField =
SemanticQueryInterceptionUtils.resolveIndicesForField(sparseVectorQueryBuilder.getFieldName(), context.getResolvedIndices());
SemanticTextIndexInformationForField semanticTextIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField(
sparseVectorQueryBuilder.getFieldName(),
context.getResolvedIndices()
);

if (semanticTextIndexInformationForField == null || semanticTextIndexInformationForField.semanticMappedIndices().isEmpty()) {
// No semantic text fields, return original query
Expand Down

0 comments on commit ef1b8f1

Please sign in to comment.