From 14303de4d85da55a7d61f9a87e8e3808918badbf Mon Sep 17 00:00:00 2001 From: Mstrutov <41866740+Mstrutov@users.noreply.github.com> Date: Mon, 5 Jun 2023 20:46:18 +0300 Subject: [PATCH] fix: add several missing fields in MultisearchBody.Builder (#506) (#516) * fix: add several missing fields in MultisearchBody.Builder (#506) - add minScore, postFilter, searchAfter, sort, trackScores to MultisearchBody Signed-off-by: Maksim Strutovskii * update CHANGELOG.md Signed-off-by: Maksim Strutovskii --------- Signed-off-by: Maksim Strutovskii --- CHANGELOG.md | 2 +- .../core/msearch/MultisearchBody.java | 200 ++++++++++++++++++ .../AbstractMultiSearchRequestIT.java | 160 +++++++++++++- 3 files changed, 352 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5779a013e2..9fd7523f53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -61,7 +61,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Fix search failure with missing required property HitsMetadata.total when trackTotalHits is disabled ([#372](https://github.com/opensearch-project/opensearch-java/pull/372)) - Fix failure when deserialing response for tasks API ([#463](https://github.com/opensearch-project/opensearch-java/pull/463)) - Fix failure when deserializing boolean types for enums ([#463](https://github.com/opensearch-project/opensearch-java/pull/482)) - +- Fix missing minScore, postFilter, searchAfter, sort, trackScores in the MultisearchBody ([#516](https://github.com/opensearch-project/opensearch-java/pull/516)) ### Security ## [2.4.0] - 04/11/2023 diff --git a/java-client/src/main/java/org/opensearch/client/opensearch/core/msearch/MultisearchBody.java b/java-client/src/main/java/org/opensearch/client/opensearch/core/msearch/MultisearchBody.java index 8fb418a153..aaee0642ea 100644 --- a/java-client/src/main/java/org/opensearch/client/opensearch/core/msearch/MultisearchBody.java +++ b/java-client/src/main/java/org/opensearch/client/opensearch/core/msearch/MultisearchBody.java @@ -36,6 +36,7 @@ package org.opensearch.client.opensearch.core.msearch; +import org.opensearch.client.opensearch._types.SortOptions; import org.opensearch.client.opensearch._types.aggregations.Aggregation; import org.opensearch.client.opensearch._types.query_dsl.Query; import org.opensearch.client.opensearch.core.search.Highlight; @@ -53,6 +54,7 @@ import org.opensearch.client.util.ObjectBuilderBase; import jakarta.json.stream.JsonGenerator; +import java.util.List; import java.util.Map; import java.util.function.Function; import javax.annotation.Nullable; @@ -70,9 +72,22 @@ public class MultisearchBody implements JsonpSerializable { @Nullable private final Integer from; + @Nullable + private final Double minScore; + + @Nullable + private final Query postFilter; + + private final List searchAfter; + @Nullable private final Integer size; + private final List sort; + + @Nullable + private final Boolean trackScores; + @Nullable private final TrackHits trackTotalHits; @@ -92,7 +107,12 @@ private MultisearchBody(Builder builder) { this.aggregations = ApiTypeHelper.unmodifiable(builder.aggregations); this.query = builder.query; this.from = builder.from; + this.minScore = builder.minScore; + this.postFilter = builder.postFilter; + this.searchAfter = ApiTypeHelper.unmodifiable(builder.searchAfter); this.size = builder.size; + this.sort = ApiTypeHelper.unmodifiable(builder.sort); + this.trackScores = builder.trackScores; this.trackTotalHits = builder.trackTotalHits; this.suggest = builder.suggest; this.highlight = builder.highlight; @@ -127,6 +147,29 @@ public final Integer from() { return this.from; } + /** + * API name: {@code from} + */ + @Nullable + public final Double minScore() { + return this.minScore; + } + + /** + * API name: {@code post_filter} + */ + @Nullable + public final Query postFilter() { + return this.postFilter; + } + + /** + * API name: {@code search_after} + */ + public final List searchAfter() { + return this.searchAfter; + } + /** * API name: {@code size} */ @@ -135,6 +178,21 @@ public final Integer size() { return this.size; } + /** + * API name: {@code sort} + */ + public final List sort() { + return this.sort; + } + + /** + * API name: {@code track_scores} + */ + @Nullable + public final Boolean trackScores() { + return this.trackScores; + } + /** * API name: {@code track_total_hits} */ @@ -198,11 +256,46 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) { generator.writeKey("from"); generator.write(this.from); + } + if (this.minScore != null) { + generator.writeKey("min_score"); + generator.write(this.minScore); + + } + if (this.postFilter != null) { + generator.writeKey("post_filter"); + this.postFilter.serialize(generator, mapper); + + } + if (ApiTypeHelper.isDefined(this.searchAfter)) { + generator.writeKey("search_after"); + generator.writeStartArray(); + for (String item0 : this.searchAfter) { + generator.write(item0); + + } + generator.writeEnd(); + } if (this.size != null) { generator.writeKey("size"); generator.write(this.size); + } + if (ApiTypeHelper.isDefined(this.sort)) { + generator.writeKey("sort"); + generator.writeStartArray(); + for (SortOptions item0 : this.sort) { + item0.serialize(generator, mapper); + + } + generator.writeEnd(); + + } + if (this.trackScores != null) { + generator.writeKey("track_scores"); + generator.write(this.trackScores); + } if (this.trackTotalHits != null) { generator.writeKey("track_total_hits"); @@ -245,9 +338,24 @@ public static class Builder extends ObjectBuilderBase implements ObjectBuilder searchAfter; + @Nullable private Integer size; + @Nullable + private List sort; + + @Nullable + private Boolean trackScores; + @Nullable private TrackHits trackTotalHits; @@ -313,6 +421,52 @@ public final Builder from(@Nullable Integer value) { return this; } + /** + * Minimum _score for matching documents. Documents with a lower _score are not + * included in the search results. + *

+ * API name: {@code min_score} + */ + public final Builder minScore(@Nullable Double value) { + this.minScore = value; + return this; + } + + /** + * API name: {@code post_filter} + */ + public final Builder postFilter(@Nullable Query value) { + this.postFilter = value; + return this; + } + + /** + * API name: {@code post_filter} + */ + public final Builder postFilter(Function> fn) { + return this.postFilter(fn.apply(new Query.Builder()).build()); + } + + /** + * API name: {@code search_after} + *

+ * Adds all elements of list to searchAfter. + */ + public final Builder searchAfter(List list) { + this.searchAfter = _listAddAll(this.searchAfter, list); + return this; + } + + /** + * API name: {@code search_after} + *

+ * Adds one or more values to searchAfter. + */ + public final Builder searchAfter(String value, String... values) { + this.searchAfter = _listAdd(this.searchAfter, value, values); + return this; + } + /** * API name: {@code size} */ @@ -321,6 +475,46 @@ public final Builder size(@Nullable Integer value) { return this; } + /** + * API name: {@code sort} + *

+ * Adds all elements of list to sort. + */ + public final Builder sort(List list) { + this.sort = _listAddAll(this.sort, list); + return this; + } + + /** + * API name: {@code sort} + *

+ * Adds one or more values to sort. + */ + public final Builder sort(SortOptions value, SortOptions... values) { + this.sort = _listAdd(this.sort, value, values); + return this; + } + + /** + * API name: {@code sort} + *

+ * Adds a value to sort using a builder lambda. + */ + public final Builder sort(Function> fn) { + return sort(fn.apply(new SortOptions.Builder()).build()); + } + + /** + * If true, calculate and return document scores, even if the scores are not + * used for sorting. + *

+ * API name: {@code track_scores} + */ + public final Builder trackScores(@Nullable Boolean value) { + this.trackScores = value; + return this; + } + /** * API name: {@code track_total_hits} */ @@ -408,7 +602,13 @@ protected static void setupMultisearchBodyDeserializer(ObjectDeserializer response = sendMSearchRequest(index, List.of(sortedItemsQuery)); + assertEquals(1, response.responses().size()); + var hits = response.responses().get(0).result().hits().hits(); + assertEquals(3, hits.size()); + + assertEquals("hammer", hits.get(2).source().getName()); + } + + @Test + public void shouldReturnMultiSearchesTrackingScores() throws Exception { + String index = "multiple_searches_request_track_scores"; + createTestDocuments(index); + + RequestItem sortedItemsQuery = createMSearchSortedFuzzyRequest(); + + MsearchResponse response = sendMSearchRequest(index, List.of(sortedItemsQuery)); + assertEquals(1, response.responses().size()); + var hits = response.responses().get(0).result().hits().hits(); + assertEquals(3, hits.size()); + assertNull(hits.get(0).score()); + assertNull(hits.get(1).score()); + assertNull(hits.get(2).score()); + + RequestItem trackScoreItemsQuery = createMSearchSortedFuzzyRequest(b -> b.trackScores(true)); + + MsearchResponse responseTrackingScore = sendMSearchRequest(index, List.of(trackScoreItemsQuery)); + assertEquals(1, responseTrackingScore.responses().size()); + var hitsTrackingScore = responseTrackingScore.responses().get(0).result().hits().hits(); + assertEquals(3, hitsTrackingScore.size()); + assertNotNull(hitsTrackingScore.get(0).score()); + assertNotNull(hitsTrackingScore.get(1).score()); + assertNotNull(hitsTrackingScore.get(2).score()); + } + + @Test + public void shouldReturnMultiSearchesAboveMinScore() throws Exception { + String index = "multiple_searches_request_min_score"; + createTestDocuments(index); + + RequestItem sortedItemsQuery = createMSearchFuzzyRequest(); + + MsearchResponse response = sendMSearchRequest(index, List.of(sortedItemsQuery)); + assertEquals(1, response.responses().size()); + var hits = response.responses().get(0).result().hits().hits(); + assertEquals(3, hits.size()); + + double minScore = hits.get(2).score(); + double scoreBetweenFirstAndSecondLowest = (hits.get(1).score() + minScore) / 2; + + RequestItem minScoredItemsQuery = createMSearchFuzzyRequest(b -> b.minScore(scoreBetweenFirstAndSecondLowest)); + + MsearchResponse responseAboveMinScore = sendMSearchRequest(index, List.of(minScoredItemsQuery)); + assertEquals(1, responseAboveMinScore.responses().size()); + assertEquals(2, responseAboveMinScore.responses().get(0).result().hits().hits().size()); + } + + @Test + public void shouldReturnMultiSearchesApplyingPostFilter() throws Exception { + String index = "multiple_searches_request_post_filter"; + createTestDocuments(index); + + RequestItem filteredItemsQuery = createMSearchFuzzyRequest(b -> b.postFilter(createItemSizeSearchQuery("large"))); + + MsearchResponse response = sendMSearchRequest(index, List.of(filteredItemsQuery)); + assertEquals(1, response.responses().size()); + assertEquals(1, response.responses().get(0).result().hits().hits().size()); + } + + @Test + public void shouldReturnMultiSearchesSearchAfter() throws Exception { + String index = "multiple_searches_request_search_after"; + createTestDocuments(index); + + RequestItem sortedItemsQuery = createMSearchSortedFuzzyRequest(); + + MsearchResponse response = sendMSearchRequest(index, List.of(sortedItemsQuery)); + assertEquals(1, response.responses().size()); + assertEquals(3, response.responses().get(0).result().hits().hits().size()); + + List sorts = response.responses().get(0).result().hits().hits().get(1).sort(); + RequestItem sortedAfterItemsQuery = createMSearchSortedFuzzyRequest(b -> b.searchAfter(sorts)); + + MsearchResponse response2 = sendMSearchRequest(index, List.of(sortedAfterItemsQuery)); + assertEquals(1, response2.responses().size()); + assertEquals(1, response2.responses().get(0).result().hits().hits().size()); + } + + private void assertResponseSources(MultiSearchResponseItem response) { List> hitsWithHighlights = response.result().hits().hits(); assertEquals(2, hitsWithHighlights.size()); @@ -113,6 +215,34 @@ private RequestItem createMSearchQuery(String itemSize, String fieldName, List b); + } + + private RequestItem createMSearchSortedFuzzyRequest(Function> additional) { + return createMSearchFuzzyRequest(b -> additional.apply(b + .sort(SortOptions.of(sort -> sort.field(FieldSort.of(f -> f.field("quantity").order(SortOrder.Asc))))))); + } + + private RequestItem createMSearchFuzzyRequest() { + return createMSearchFuzzyRequest(b -> b); + } + + private RequestItem createMSearchFuzzyRequest(Function> additional) { + return RequestItem.of(item -> item.header(header -> header) + .body(b -> additional.apply(b.query(createNameSearchFuzzyQuery()))) + ); + } + + private Query createNameSearchFuzzyQuery() { + return Query.of(filter -> filter.fuzzy( + FuzzyQuery.of(term -> term.field("name") + .value(FieldValue.of("rammer")) + ) + ) + ); + } + private SourceConfig createSourcesConfig(List sources) { return sources.isEmpty() ? null : SourceConfig.of(builder -> builder.filter(filter -> filter.includes(sources))); } @@ -143,30 +273,34 @@ private MsearchResponse sendMSearchRequest(String index, List _1.index(index).id("1").document(createItem("hammer", "large", "yes")).refresh(Refresh.True)); - javaClient().create(_1 -> _1.index(index).id("2").document(createItem("drill", "large", "yes")).refresh(Refresh.True)); - javaClient().create(_1 -> _1.index(index).id("3").document(createItem("jack", "medium", "yes")).refresh(Refresh.True)); - javaClient().create(_1 -> _1.index(index).id("4").document(createItem("wrench", "medium", "no")).refresh(Refresh.True)); - javaClient().create(_1 -> _1.index(index).id("5").document(createItem("screws", "small", "no")).refresh(Refresh.True)); - javaClient().create(_1 -> _1.index(index).id("6").document(createItem("nuts", "small", "no")).refresh(Refresh.True)); + javaClient().create(_1 -> _1.index(index).id("1").document(createItem("hummer", "huge", "yes", 2)).refresh(Refresh.True)); + javaClient().create(_1 -> _1.index(index).id("2").document(createItem("jammer", "huge", "yes", 1)).refresh(Refresh.True)); + javaClient().create(_1 -> _1.index(index).id("3").document(createItem("hammer", "large", "yes", 3)).refresh(Refresh.True)); + javaClient().create(_1 -> _1.index(index).id("4").document(createItem("drill", "large", "yes", 3)).refresh(Refresh.True)); + javaClient().create(_1 -> _1.index(index).id("5").document(createItem("jack", "medium", "yes", 2)).refresh(Refresh.True)); + javaClient().create(_1 -> _1.index(index).id("6").document(createItem("wrench", "medium", "no", 3)).refresh(Refresh.True)); + javaClient().create(_1 -> _1.index(index).id("7").document(createItem("screws", "small", "no", 1)).refresh(Refresh.True)); + javaClient().create(_1 -> _1.index(index).id("8").document(createItem("nuts", "small", "no", 2)).refresh(Refresh.True)); } - private ShopItem createItem(String name, String size, String company) { - return new ShopItem(name, size, company); + private ShopItem createItem(String name, String size, String company, int quantity) { + return new ShopItem(name, size, company, quantity); } public static class ShopItem { private String name; private String size; private String company; + private int quantity; public ShopItem() { } - public ShopItem(String name, String size, String company) { + public ShopItem(String name, String size, String company, int quantity) { this.name = name; this.size = size; this.company = company; + this.quantity = quantity; } public String getName() { @@ -192,5 +326,13 @@ public String getCompany() { public void setCompany(String company) { this.company = company; } + + public int getQuantity() { + return quantity; + } + + public void setQuantity(int quantity) { + this.quantity = quantity; + } } }