Skip to content

Commit

Permalink
Fix some test errors, and do some cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
kderusso committed Dec 13, 2024
1 parent 78cafbb commit 0bab375
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,6 @@ public SparseVectorQueryBuilder(
: (this.shouldPruneTokens ? new TokenPruningConfig() : null));
this.weightedTokensSupplier = null;

if (queryVectors == null ^ query == null == false) {
throw new IllegalArgumentException(
"[" + NAME + "] requires one of [" + QUERY_VECTOR_FIELD.getPreferredName() + "] or [" + QUERY_FIELD.getPreferredName() + "]"
);
}
if (inferenceId != null && query == null) {
throw new IllegalArgumentException(
"["
Expand All @@ -106,6 +101,12 @@ public SparseVectorQueryBuilder(
+ "] is specified"
);
}

if (queryVectors == null ^ query == null == false) {
throw new IllegalArgumentException(
"[" + NAME + "] requires one of [" + QUERY_VECTOR_FIELD.getPreferredName() + "] or [" + QUERY_FIELD.getPreferredName() + "]"
);
}
}

public SparseVectorQueryBuilder(StreamInput in) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ public void testIllegalValues() {
{
IllegalArgumentException e = expectThrows(
IllegalArgumentException.class,
() -> new SparseVectorQueryBuilder("field name", null, "model id")
() -> new SparseVectorQueryBuilder("field name", null, null)
);
assertEquals("[sparse_vector] requires one of [query_vector] or [inference_id]", e.getMessage());
assertEquals("[sparse_vector] requires one of [query_vector] or [query]", e.getMessage());
}
{
IllegalArgumentException e = expectThrows(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,22 @@ public SemanticQueryRewriteInterceptor() {}
public QueryBuilder interceptAndRewrite(QueryRewriteContext context, QueryBuilder queryBuilder) {
QueryBuilder rewritten = queryBuilder;
String fieldName = getFieldName(queryBuilder);
InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, context.getResolvedIndices());

if (indexInformation == null || indexInformation.getInferenceIndices().isEmpty()) {
// No inference fields were identified, so return the original query.
return rewritten;
} else if (indexInformation.nonInferenceIndices().isEmpty() == false) {
// Combined case where the field name requested by this query contains both
// semantic_text and non-inference fields, so we have to combine queries per index
// containing each field type.
rewritten = buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
} else {
// The only fields we've identified are inference fields (e.g. semantic_text),
// so rewrite the entire query to work on a semantic_text field.
rewritten = buildInferenceQuery(queryBuilder, indexInformation);
ResolvedIndices resolvedIndices = context.getResolvedIndices();
if (resolvedIndices != null) {
InferenceIndexInformationForField indexInformation = resolveIndicesForField(fieldName, resolvedIndices);
if (indexInformation.getInferenceIndices().isEmpty()) {
// No inference fields were identified, so return the original query.
return rewritten;
} else if (indexInformation.nonInferenceIndices().isEmpty() == false) {
// Combined case where the field name requested by this query contains both
// semantic_text and non-inference fields, so we have to combine queries per index
// containing each field type.
rewritten = buildCombinedInferenceAndNonInferenceQuery(queryBuilder, indexInformation);
} else {
// The only fields we've identified are inference fields (e.g. semantic_text),
// so rewrite the entire query to work on a semantic_text field.
rewritten = buildInferenceQuery(queryBuilder, indexInformation);
}
}

return rewritten;
Expand Down Expand Up @@ -87,23 +89,20 @@ protected abstract QueryBuilder buildCombinedInferenceAndNonInferenceQuery(
);

private InferenceIndexInformationForField resolveIndicesForField(String fieldName, ResolvedIndices resolvedIndices) {
if (resolvedIndices != null) {
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
Map<String, InferenceFieldMetadata> inferenceIndicesMetadata = new HashMap<>();
List<String> nonInferenceIndices = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
if (inferenceFieldMetadata != null) {
inferenceIndicesMetadata.put(indexName, inferenceFieldMetadata);
} else {
nonInferenceIndices.add(indexName);
}
Collection<IndexMetadata> indexMetadataCollection = resolvedIndices.getConcreteLocalIndicesMetadata().values();
Map<String, InferenceFieldMetadata> inferenceIndicesMetadata = new HashMap<>();
List<String> nonInferenceIndices = new ArrayList<>();
for (IndexMetadata indexMetadata : indexMetadataCollection) {
String indexName = indexMetadata.getIndex().getName();
InferenceFieldMetadata inferenceFieldMetadata = indexMetadata.getInferenceFields().get(fieldName);
if (inferenceFieldMetadata != null) {
inferenceIndicesMetadata.put(indexName, inferenceFieldMetadata);
} else {
nonInferenceIndices.add(indexName);
}

return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices);
}
return null;

return new InferenceIndexInformationForField(fieldName, inferenceIndicesMetadata, nonInferenceIndices);
}

protected QueryBuilder createSubQueryForIndices(Collection<String> indices, QueryBuilder queryBuilder) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ setup:

---
"Test sparse_vector only allows one of query or query_vector":
- requires:
cluster_features: [ "search.semantic_sparse_vector_query_rewrite_interception_supported" ]
reason: "sparse vector inference checks updated in 8.18 to support sparse_vector on semantic_text fields"
- do:
catch: /\[sparse_vector\] requires one of \[query_vector\] or \[query\]/
search:
Expand Down

0 comments on commit 0bab375

Please sign in to comment.