diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/CollapseResponseProcessor.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/CollapseResponseProcessor.java index b3c4e60d56949..3e6c4fef6a559 100644 --- a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/CollapseResponseProcessor.java +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/CollapseResponseProcessor.java @@ -20,10 +20,10 @@ import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil; import java.util.ArrayList; -import java.util.Comparator; -import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; /** * A simple implementation of field collapsing on search responses. Note that this is not going to work as well as @@ -35,12 +35,12 @@ public class CollapseResponseProcessor extends AbstractProcessor implements Sear * Key to reference this processor type from a search pipeline. 
 */ public static final String TYPE = "collapse"; - private static final String COLLAPSE_FIELD = "field"; + static final String COLLAPSE_FIELD = "field"; private final String collapseField; private CollapseResponseProcessor(String tag, String description, boolean ignoreFailure, String collapseField) { super(tag, description, ignoreFailure); - this.collapseField = collapseField; + this.collapseField = Objects.requireNonNull(collapseField); } @Override @@ -49,34 +49,47 @@ public String getType() { } @Override - public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + public SearchResponse processResponse(SearchRequest request, SearchResponse response) { if (response.getHits() != null) { - Map<String, SearchHit> collapsedHits = new HashMap<>(); + if (response.getHits().getCollapseField() != null) { + throw new IllegalStateException( + "Cannot collapse on " + collapseField + ". Results already collapsed on " + response.getHits().getCollapseField() + ); + } + Map<String, SearchHit> collapsedHits = new LinkedHashMap<>(); + List<Object> collapseValues = new ArrayList<>(); for (SearchHit hit : response.getHits()) { - String fieldValue = ""; + Object fieldValue = null; DocumentField docField = hit.getFields().get(collapseField); if (docField != null) { if (docField.getValues().size() > 1) { - throw new IllegalStateException("Document " + hit.getId() + " has multiple values for field " + collapseField); - } - fieldValue = docField.getValues().get(0).toString(); - } else if (hit.hasSource()) { - Object val = hit.getSourceAsMap().get(collapseField); - if (val != null) { - fieldValue = val.toString(); + throw new IllegalStateException( + "Failed to collapse " + hit.getId() + ": doc has multiple values for field " + collapseField + ); } + fieldValue = docField.getValues().get(0); + } else if (hit.getSourceAsMap() != null) { + fieldValue = hit.getSourceAsMap().get(collapseField); } - SearchHit previousHit = collapsedHits.get(fieldValue); - // TODO - Support the sort used in 
 the request, rather than just score - if (previousHit == null || hit.getScore() > previousHit.getScore()) { - collapsedHits.put(fieldValue, hit); + String fieldValueString; + if (fieldValue == null) { + fieldValueString = "__missing__"; + } else { + fieldValueString = fieldValue.toString(); + } + + // Results are already sorted by sort criterion. Only keep the first hit for each field. + if (collapsedHits.containsKey(fieldValueString) == false) { + collapsedHits.put(fieldValueString, hit); + collapseValues.add(fieldValue); + } + } + SearchHit[] newHits = new SearchHit[collapsedHits.size()]; + int i = 0; + for (SearchHit collapsedHit : collapsedHits.values()) { + newHits[i++] = collapsedHit; } - List<SearchHit> hitsToReturn = new ArrayList<>(collapsedHits.values()); - hitsToReturn.sort(Comparator.comparingDouble(SearchHit::getScore).reversed()); - SearchHit[] newHits = hitsToReturn.toArray(new SearchHit[0]); - List<String> collapseValues = new ArrayList<>(collapsedHits.keySet()); SearchHits searchHits = new SearchHits( newHits, response.getHits().getTotalHits(), diff --git a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/CollapseResponseProcessorTests.java b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/CollapseResponseProcessorTests.java new file mode 100644 index 0000000000000..cda011f24fea1 --- /dev/null +++ b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/CollapseResponseProcessorTests.java @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.search.pipeline.common; + +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.document.DocumentField; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class CollapseResponseProcessorTests extends OpenSearchTestCase { + public void testWithDocumentFields() { + testProcessor(true); + } + + public void testWithSourceField() { + testProcessor(false); + } + + private void testProcessor(boolean includeDocField) { + Map<String, Object> config = new HashMap<>(Map.of(CollapseResponseProcessor.COLLAPSE_FIELD, "groupid")); + CollapseResponseProcessor processor = new CollapseResponseProcessor.Factory().create( + Collections.emptyMap(), + null, + null, + false, + config, + null + ); + int numHits = randomIntBetween(1, 100); + SearchResponse inputResponse = generateResponse(numHits, includeDocField); + + SearchResponse processedResponse = processor.processResponse(new SearchRequest(), inputResponse); + if (numHits % 2 == 0) { + assertEquals(numHits / 2, processedResponse.getHits().getHits().length); + } else { + assertEquals(numHits / 2 + 1, processedResponse.getHits().getHits().length); + } + for (SearchHit collapsedHit : processedResponse.getHits()) { + assertEquals(0, collapsedHit.docId() % 2); + } + assertEquals("groupid", processedResponse.getHits().getCollapseField()); + assertEquals(processedResponse.getHits().getHits().length, processedResponse.getHits().getCollapseValues().length); + for (int i = 0; i < processedResponse.getHits().getHits().length; i++) { + assertEquals(i, 
 processedResponse.getHits().getCollapseValues()[i]); + } + } + + private static SearchResponse generateResponse(int numHits, boolean includeDocField) { + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + Map<String, DocumentField> docFields; + int groupValue = i / 2; + if (includeDocField) { + docFields = Map.of("groupid", new DocumentField("groupid", List.of(groupValue))); + } else { + docFields = Collections.emptyMap(); + } + SearchHit hit = new SearchHit(i, Integer.toString(i), docFields, Collections.emptyMap()); + hit.sourceRef(new BytesArray("{\"groupid\": " + groupValue + "}")); + hitsArray[i] = hit; + } + SearchHits searchHits = new SearchHits( + hitsArray, + new TotalHits(Math.max(numHits, 1000), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + 1.0f + ); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, null, null, null, false, false, 0); + return new SearchResponse(internalSearchResponse, null, 1, 1, 0, 10, null, null); + } +}