diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 0f26f6577860f..faeac9dc1853f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -56,11 +56,11 @@ public record SemanticTextField(String fieldName, List originalValues, I ToXContentObject { static final String TEXT_FIELD = "text"; - static final String INFERENCE_FIELD = "inference"; + public static final String INFERENCE_FIELD = "inference"; static final String INFERENCE_ID_FIELD = "inference_id"; static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id"; - static final String CHUNKS_FIELD = "chunks"; - static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings"; + public static final String CHUNKS_FIELD = "chunks"; + public static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings"; public static final String CHUNKED_TEXT_FIELD = "text"; static final String MODEL_SETTINGS_FIELD = "model_settings"; static final String TASK_TYPE_FIELD = "task_type"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java index 016ae6dc83fde..705ff12ac7c7e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticMatchQueryRewriteInterceptor.java @@ -18,7 +18,7 @@ import java.util.List; -import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.InferenceIndexInformationForField; public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor { @@ -33,33 +33,33 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde assert (queryBuilder instanceof MatchQueryBuilder); MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder; QueryBuilder rewritten = queryBuilder; - SemanticTextIndexInformationForField semanticTextIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField( + InferenceIndexInformationForField inferenceIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField( matchQueryBuilder.fieldName(), context.getResolvedIndices() ); - if (semanticTextIndexInformationForField == null || semanticTextIndexInformationForField.semanticMappedIndices().isEmpty()) { - // No semantic text fields, return original query + if (inferenceIndexInformationForField == null || inferenceIndexInformationForField.inferenceIndices().isEmpty()) { + // No inference fields, return original query return rewritten; - } else if (semanticTextIndexInformationForField.otherIndices().isEmpty() == false) { - // Combined semantic_text and other text fields + } else if (inferenceIndexInformationForField.nonInferenceIndices().isEmpty() == false) { + // Combined inference and non inference fields BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); boolQueryBuilder.should( createSemanticSubQuery( - semanticTextIndexInformationForField.semanticMappedIndices(), + inferenceIndexInformationForField.inferenceIndices(), matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value() ) ); boolQueryBuilder.should( SemanticQueryInterceptionUtils.createSubQueryForIndices( - semanticTextIndexInformationForField.otherIndices(), + inferenceIndexInformationForField.nonInferenceIndices(), matchQueryBuilder ) ); rewritten = boolQueryBuilder; } else { - // Only semantic_text fields + // Only inference fields rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java index df685ce858b70..8922c7bdf5ce4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticQueryInterceptionUtils.java @@ -23,7 +23,7 @@ public class SemanticQueryInterceptionUtils { private SemanticQueryInterceptionUtils() {} - public static SemanticTextIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) { + public static InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) { if (resolvedIndices != null) { Collection indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values(); List inferenceIndices = new ArrayList<>(); @@ -38,7 +38,7 @@ public static SemanticTextIndexInformationForField resolveIndicesForField(String } } - return new SemanticTextIndexInformationForField(inferenceIndices, nonInferenceIndices); + return new InferenceIndexInformationForField(inferenceIndices, nonInferenceIndices); } return null; } @@ -50,5 +50,5 @@ public static QueryBuilder createSubQueryForIndices(List indices, QueryB return boolQueryBuilder; } - public record SemanticTextIndexInformationForField(List semanticMappedIndices, List otherIndices) {} + public record InferenceIndexInformationForField(List inferenceIndices, List nonInferenceIndices) {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java index 613394d25d7b8..6ca71decc0e6a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/queries/SemanticSparseVectorQueryRewriteInterceptor.java @@ -15,8 +15,9 @@ import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor; import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; -import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField; +import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.InferenceIndexInformationForField; public class SemanticSparseVectorQueryRewriteInterceptor implements QueryRewriteInterceptor { @@ -24,9 +25,6 @@ public class SemanticSparseVectorQueryRewriteInterceptor implements QueryRewrite "search.semantic_sparse_vector_query_rewrite_interception_supported" ); - private static final String NESTED_FIELD_PATH = ".inference.chunks"; - private static final String NESTED_EMBEDDINGS_FIELD = NESTED_FIELD_PATH + ".embeddings"; - public SemanticSparseVectorQueryRewriteInterceptor() {} @Override @@ -34,36 +32,34 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde assert (queryBuilder instanceof SparseVectorQueryBuilder); SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder; QueryBuilder rewritten = queryBuilder; - SemanticTextIndexInformationForField semanticTextIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField( + InferenceIndexInformationForField inferenceIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField( sparseVectorQueryBuilder.getFieldName(), context.getResolvedIndices() ); - if (semanticTextIndexInformationForField == null || semanticTextIndexInformationForField.semanticMappedIndices().isEmpty()) { - // No semantic text fields, return original query + if (inferenceIndexInformationForField == null || inferenceIndexInformationForField.inferenceIndices().isEmpty()) { + // No inference fields, return original query return rewritten; - } else if (semanticTextIndexInformationForField.otherIndices().isEmpty() == false) { - // Combined semantic and sparse vector fields + } else if (inferenceIndexInformationForField.nonInferenceIndices().isEmpty() == false) { + // Combined inference and non inference fields BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); - // sparse_vector fields should be passed in as their own clause boolQueryBuilder.should( SemanticQueryInterceptionUtils.createSubQueryForIndices( - semanticTextIndexInformationForField.otherIndices(), + inferenceIndexInformationForField.nonInferenceIndices(), SemanticQueryInterceptionUtils.createSubQueryForIndices( - semanticTextIndexInformationForField.otherIndices(), + inferenceIndexInformationForField.nonInferenceIndices(), sparseVectorQueryBuilder ) ) ); - // semantic text fields should be passed in as nested sub queries + // We always perform nested subqueries on semantic_text fields, to support + // sparse_vector queries using query vectors boolQueryBuilder.should( SemanticQueryInterceptionUtils.createSubQueryForIndices( - semanticTextIndexInformationForField.semanticMappedIndices(), + inferenceIndexInformationForField.inferenceIndices(), buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder) ) - ); - rewritten = boolQueryBuilder; } else { // Only semantic text fields @@ -89,11 +85,11 @@ private QueryBuilder buildNestedQueryFromSparseVectorQuery(SparseVectorQueryBuil } private static String getNestedFieldPath(String fieldName) { - return fieldName + NESTED_FIELD_PATH; + return fieldName + SemanticTextField.INFERENCE_FIELD + SemanticTextField.CHUNKS_FIELD; } private static String getNestedEmbeddingsField(String fieldName) { - return fieldName + NESTED_EMBEDDINGS_FIELD; + return getNestedFieldPath(fieldName) + SemanticTextField.CHUNKED_EMBEDDINGS_FIELD; } @Override