Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Dec 12, 2024
1 parent ef1b8f1 commit 7652f5c
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ public record SemanticTextField(String fieldName, List<String> originalValues, I
ToXContentObject {

static final String TEXT_FIELD = "text";
static final String INFERENCE_FIELD = "inference";
public static final String INFERENCE_FIELD = "inference";
static final String INFERENCE_ID_FIELD = "inference_id";
static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id";
static final String CHUNKS_FIELD = "chunks";
static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings";
public static final String CHUNKS_FIELD = "chunks";
public static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings";
public static final String CHUNKED_TEXT_FIELD = "text";
static final String MODEL_SETTINGS_FIELD = "model_settings";
static final String TASK_TYPE_FIELD = "task_type";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import java.util.List;

import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField;
import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.InferenceIndexInformationForField;

public class SemanticMatchQueryRewriteInterceptor implements QueryRewriteInterceptor {

Expand All @@ -33,33 +33,33 @@ public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilde
assert (queryBuilder instanceof MatchQueryBuilder);
MatchQueryBuilder matchQueryBuilder = (MatchQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
SemanticTextIndexInformationForField semanticTextIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField(
InferenceIndexInformationForField inferenceIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField(
matchQueryBuilder.fieldName(),
context.getResolvedIndices()
);

if (semanticTextIndexInformationForField == null || semanticTextIndexInformationForField.semanticMappedIndices().isEmpty()) {
// No semantic text fields, return original query
if (inferenceIndexInformationForField == null || inferenceIndexInformationForField.inferenceIndices().isEmpty()) {
// No inference fields, return original query
return rewritten;
} else if (semanticTextIndexInformationForField.otherIndices().isEmpty() == false) {
// Combined semantic_text and other text fields
} else if (inferenceIndexInformationForField.nonInferenceIndices().isEmpty() == false) {
// Combined inference and non inference fields
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
boolQueryBuilder.should(
createSemanticSubQuery(
semanticTextIndexInformationForField.semanticMappedIndices(),
inferenceIndexInformationForField.inferenceIndices(),
matchQueryBuilder.fieldName(),
(String) matchQueryBuilder.value()
)
);
boolQueryBuilder.should(
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.otherIndices(),
inferenceIndexInformationForField.nonInferenceIndices(),
matchQueryBuilder
)
);
rewritten = boolQueryBuilder;
} else {
// Only semantic_text fields
// Only inference fields
rewritten = new SemanticQueryBuilder(matchQueryBuilder.fieldName(), (String) matchQueryBuilder.value(), false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public class SemanticQueryInterceptionUtils {

private SemanticQueryInterceptionUtils() {}

public static SemanticTextIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
public static InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
List<String> inferenceIndices = new ArrayList<>();
Expand All @@ -38,7 +38,7 @@ public static SemanticTextIndexInformationForField resolveIndicesForField(String
}
}

return new SemanticTextIndexInformationForField(inferenceIndices, nonInferenceIndices);
return new InferenceIndexInformationForField(inferenceIndices, nonInferenceIndices);
}
return null;
}
Expand All @@ -50,5 +50,5 @@ public static QueryBuilder createSubQueryForIndices(List<String> indices, QueryB
return boolQueryBuilder;
}

public record SemanticTextIndexInformationForField(List<String> semanticMappedIndices, List<String> otherIndices) {}
public record InferenceIndexInformationForField(List<String> inferenceIndices, List<String> nonInferenceIndices) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,55 +15,51 @@
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.plugins.internal.rewriter.QueryRewriteInterceptor;
import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder;
import org.elasticsearch.xpack.inference.mapper.SemanticTextField;

import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.SemanticTextIndexInformationForField;
import static org.elasticsearch.xpack.inference.queries.SemanticQueryInterceptionUtils.InferenceIndexInformationForField;

public class SemanticSparseVectorQueryRewriteInterceptor implements QueryRewriteInterceptor {

public static final NodeFeature SEMANTIC_SPARSE_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED = new NodeFeature(
"search.semantic_sparse_vector_query_rewrite_interception_supported"
);

private static final String NESTED_FIELD_PATH = ".inference.chunks";
private static final String NESTED_EMBEDDINGS_FIELD = NESTED_FIELD_PATH + ".embeddings";

public SemanticSparseVectorQueryRewriteInterceptor() {}

@Override
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
assert (queryBuilder instanceof SparseVectorQueryBuilder);
SparseVectorQueryBuilder sparseVectorQueryBuilder = (SparseVectorQueryBuilder) queryBuilder;
QueryBuilder rewritten = queryBuilder;
SemanticTextIndexInformationForField semanticTextIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField(
InferenceIndexInformationForField inferenceIndexInformationForField = SemanticQueryInterceptionUtils.resolveIndicesForField(
sparseVectorQueryBuilder.getFieldName(),
context.getResolvedIndices()
);

if (semanticTextIndexInformationForField == null || semanticTextIndexInformationForField.semanticMappedIndices().isEmpty()) {
// No semantic text fields, return original query
if (inferenceIndexInformationForField == null || inferenceIndexInformationForField.inferenceIndices().isEmpty()) {
// No inference fields, return original query
return rewritten;
} else if (semanticTextIndexInformationForField.otherIndices().isEmpty() == false) {
// Combined semantic and sparse vector fields
} else if (inferenceIndexInformationForField.nonInferenceIndices().isEmpty() == false) {
// Combined inference and non inference fields
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
// sparse_vector fields should be passed in as their own clause
boolQueryBuilder.should(
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.otherIndices(),
inferenceIndexInformationForField.nonInferenceIndices(),
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.otherIndices(),
inferenceIndexInformationForField.nonInferenceIndices(),
sparseVectorQueryBuilder
)
)
);
// semantic text fields should be passed in as nested sub queries
// We always perform nested subqueries on semantic_text fields, to support
// sparse_vector queries using query vectors
boolQueryBuilder.should(
SemanticQueryInterceptionUtils.createSubQueryForIndices(
semanticTextIndexInformationForField.semanticMappedIndices(),
inferenceIndexInformationForField.inferenceIndices(),
buildNestedQueryFromSparseVectorQuery(sparseVectorQueryBuilder)
)

);

rewritten = boolQueryBuilder;
} else {
// Only semantic text fields
Expand All @@ -89,11 +85,11 @@ private QueryBuilder buildNestedQueryFromSparseVectorQuery(SparseVectorQueryBuil
}

private static String getNestedFieldPath(String fieldName) {
return fieldName + NESTED_FIELD_PATH;
return fieldName + SemanticTextField.INFERENCE_FIELD + SemanticTextField.CHUNKS_FIELD;
}

private static String getNestedEmbeddingsField(String fieldName) {
return fieldName + NESTED_EMBEDDINGS_FIELD;
return getNestedFieldPath(fieldName) + SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
}

@Override
Expand Down

0 comments on commit 7652f5c

Please sign in to comment.