Skip to content

Commit

Permalink
Refactor match rewriting & cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Dec 12, 2024
1 parent 1d3151a commit c427143
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.elasticsearch.features.FeatureSpecification;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor;
import org.elasticsearch.xpack.inference.queries.SemanticQueryBuilder;
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ public record SemanticTextField(String fieldName, List<String> originalValues, I
ToXContentObject {

static final String TEXT_FIELD = "text";
static final String INFERENCE_FIELD = "inference";
public static final String INFERENCE_FIELD = "inference";
static final String INFERENCE_ID_FIELD = "inference_id";
static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id";
static final String CHUNKS_FIELD = "chunks";
static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings";
public static final String CHUNKS_FIELD = "chunks";
public static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings";
public static final String CHUNKED_TEXT_FIELD = "text";
static final String MODEL_SETTINGS_FIELD = "model_settings";
static final String TASK_TYPE_FIELD = "task_type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,19 @@

package org.elasticsearch.xpack.inference.queries;

import org.elasticsearch.action.ResolvedIndices;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.mapper.IndexFieldMapper;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.MatchQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.TermQueryBuilder;
import org.elasticsearch.index.query.TermsQueryBuilder;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.InferenceIndexInformationForField;

public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
Expand All @@ -37,58 +33,47 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
ResolvedIndices resolvedIndices = context.getResolvedIndices();
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<String> inferenceIndices = new ArrayList<>();
List<String> nonInferenceIndices = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(matchQueryBuilder.fieldName());
if (inferenceFieldMetadata != null) {
inferenceIndices.add(indexName);
} else {
nonInferenceIndices.add(indexName);
}
}
InferenceIndexInformationForField inferenceIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField(
matchQueryBuilder.fieldName(),
context.getResolvedIndices()
);

if (inferenceIndices.isEmpty()) {
return rewritten;
} else if (nonInferenceIndices.isEmpty() == false) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
for (String inferenceIndexName : inferenceIndices) {
// Add a separate clause for each semantic query, because they may be using different inference endpoints
// TODO - consolidate this to a single clause once the semantic query supports multiple inference endpoints
boolQueryBuilder.should(
createSemanticSubQuery(inferenceIndexName, matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value())
);
}
boolQueryBuilder.should(createMatchSubQuery(nonInferenceIndices, matchQueryBuilder));
rewritten = boolQueryBuilder;
} else {
rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
}
if (inferenceIndexInformationForField == null || inferenceIndexInformationForField.inferenceIndices().isEmpty()) {
// No inference fields, return original query
return rewritten;
} else if (inferenceIndexInformationForField.nonInferenceIndices().isEmpty() == false) {
// Combined inference and non inference fields
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(
createSemanticSubQuery(
inferenceIndexInformationForField.inferenceIndices(),
matchQueryBuilder.fieldName(),
(String) matchQueryBuilder.value()
)
);
boolQueryBuilder.should(
SemanticQueryInterceptionUtils.createSubQueryForIndices(
inferenceIndexInformationForField.nonInferenceIndices(),
matchQueryBuilder
)
);
rewritten = boolQueryBuilder;
} else {
// Only inference fields
rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
}

return rewritten;

}

@Override
public String getQueryName() {
return MatchQueryBuilder.NAME;
}

private QueryBuilder createSemanticSubQuery(String indexName, String fieldName, String value) {
private QueryBuilder createSemanticSubQuery(List<String> indices, String fieldName, String value) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermQueryBuilder(IndexFieldMapper.NAME, indexName));
return boolQueryBuilder;
}

private QueryBuilder createMatchSubQuery(List<String> indices, MatchQueryBuilder matchQueryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(matchQueryBuilder);
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,9 @@

public class SemanticQueryInterceptionUtils {


private SemanticQueryInterceptionUtils() {}

public static SemanticTextIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
public static InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<String> inferenceIndices = new ArrayList<>();
Expand All @@ -39,25 +38,17 @@ public static SemanticTextIndexInformationForField resolveIndicesForField(String
}
}

return new SemanticTextIndexInformationForField(inferenceIndices, nonInferenceIndices);
return new InferenceIndexInformationForField(inferenceIndices, nonInferenceIndices);
}
return null;
}

public static QueryBuilder createSemanticSubQueryForIndices(List<String> indices, String fieldName, String value) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(new SemanticQueryBuilder(fieldName, value, true));
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}

public static QueryBuilder createSubQueryForIndices(List<String> indices, QueryBuilder queryBuilder) {
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.must(queryBuilder);
boolQueryBuilder.filter(new TermsQueryBuilder(IndexFieldMapper.NAME, indices));
return boolQueryBuilder;
}

public record SemanticTextIndexInformationForField(List<String> semanticMappedIndices, List<String> otherIndices) {}

public record InferenceIndexInformationForField(List<String> inferenceIndices, List<String> nonInferenceIndices) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,51 +15,51 @@
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;

import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.InferenceIndexInformationForField;

public class SemanticSparseVectorQueryRewriteInterceptor implements QueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
"search.semantic_sparse_vector_query_rewrite_interception_supported"
);

private static final String NESTED_FIELD_PATH = ".inference.chunks";
private static final String NESTED_EMBEDDINGS_FIELD = NESTED_FIELD_PATH + ".embeddings";

public SemanticSparseVectorQueryRewriteInterceptor() {}

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField semanticTextIndexInformationForField =
SemanticQueryInterceptionUtils.resolveIndicesForField(sparseVectorQueryBuilder.getFieldName(), context.getResolvedIndices());
InferenceIndexInformationForField inferenceIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField(
sparseVectorQueryBuilder.getFieldName(),
context.getResolvedIndices()
);

if (semanticTextIndexInformationForField == null || semanticTextIndexInformationForField.semanticMappedIndices().isEmpty()) {
// No semantic text fields, return original query
if (inferenceIndexInformationForField == null || inferenceIndexInformationForField.inferenceIndices().isEmpty()) {
// No inference fields, return original query
return rewritten;
} else if (semanticTextIndexInformationForField.otherIndices().isEmpty() == false) {
// Combined semantic and sparse vector fields
} else if (inferenceIndexInformationForField.nonInferenceIndices().isEmpty() == false) {
// Combined inference and non inference fields
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
// sparse_vector fields should be passed in as their own clause
boolQueryBuilder.should(
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.otherIndices(),
inferenceIndexInformationForField.nonInferenceIndices(),
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.otherIndices(),
inferenceIndexInformationForField.nonInferenceIndices(),
sparseVectorQueryBuilder
)
)
);
// semantic text fields should be passed in as nested sub queries
// We always perform nested subqueries on semantic_text fields, to support
// sparse_vector queries using query vectors
boolQueryBuilder.should(
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.semanticMappedIndices(),
inferenceIndexInformationForField.inferenceIndices(),
buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder)
)

);

rewritten = boolQueryBuilder;
} else {
// Only semantic text fields
Expand All @@ -85,11 +85,11 @@ private QueryBuilder buildNestedQueryFromSparseVectorQuery(SparseVectorQueryBuil
}

private static String getNestedFieldPath(String fieldName) {
return fieldName + NESTED_FIELD_PATH;
return fieldName + SemanticTextField.INFERENCE_FIELD + SemanticTextField.CHUNKS_FIELD;
}

private static String getNestedEmbeddingsField(String fieldName) {
return fieldName + NESTED_EMBEDDINGS_FIELD;
return getNestedFieldPath(fieldName) + SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
}

@Override
Expand Down

0 comments on commit c427143

Please sign in to comment.