Skip to content

Commit

Permalink
Clean up and test CollapseResponseProcessor
Browse files Browse the repository at this point in the history
After realizing that we just need to keep the first hit for each group
(since results are already sorted by the sort criteria), I think
CollapseResponseProcessor might be worth including.

Combining it with the oversample + truncate processors, it can provide a
workaround for the lack of support for collapse + rescore.

Signed-off-by: Michael Froh <froh@amazon.com>
  • Loading branch information
msfroh committed Aug 25, 2023
1 parent 923b12d commit b995556
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
* A simple implementation of field collapsing on search responses. Note that this is not going to work as well as
Expand All @@ -35,12 +35,12 @@ public class CollapseResponseProcessor extends AbstractProcessor implements Sear
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "collapse";
private static final String COLLAPSE_FIELD = "field";
static final String COLLAPSE_FIELD = "field";
private final String collapseField;

private CollapseResponseProcessor(String tag, String description, boolean ignoreFailure, String collapseField) {
super(tag, description, ignoreFailure);
this.collapseField = collapseField;
this.collapseField = Objects.requireNonNull(collapseField);
}

@Override
Expand All @@ -49,34 +49,43 @@ public String getType() {
}

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
public SearchResponse processResponse(SearchRequest request, SearchResponse response) {

if (response.getHits() != null) {
Map<String, SearchHit> collapsedHits = new HashMap<>();
if (response.getHits().getCollapseField() != null) {
throw new IllegalStateException("Cannot collapse on " + collapseField + ". Results already collapsed on " + response.getHits().getCollapseField());
}
Map<String, SearchHit> collapsedHits = new LinkedHashMap<>();
List<Object> collapseValues = new ArrayList<>();
for (SearchHit hit : response.getHits()) {
String fieldValue = "<missing>";
Object fieldValue = null;
DocumentField docField = hit.getFields().get(collapseField);
if (docField != null) {
if (docField.getValues().size() > 1) {
throw new IllegalStateException("Document " + hit.getId() + " has multiple values for field " + collapseField);
}
fieldValue = docField.getValues().get(0).toString();
} else if (hit.hasSource()) {
Object val = hit.getSourceAsMap().get(collapseField);
if (val != null) {
fieldValue = val.toString();
throw new IllegalStateException("Failed to collapse " + hit.getId() + ": doc has multiple values for field " + collapseField);
}
fieldValue = docField.getValues().get(0);
} else if (hit.getSourceAsMap() != null) {
fieldValue = hit.getSourceAsMap().get(collapseField);
}
SearchHit previousHit = collapsedHits.get(fieldValue);
// TODO - Support the sort used in the request, rather than just score
if (previousHit == null || hit.getScore() > previousHit.getScore()) {
collapsedHits.put(fieldValue, hit);
String fieldValueString;
if (fieldValue == null) {
fieldValueString = "__missing__";
} else {
fieldValueString = fieldValue.toString();
}

// Results are already sorted by sort criterion. Only keep the first hit for each field.
if (collapsedHits.containsKey(fieldValueString) == false) {
collapsedHits.put(fieldValueString, hit);
collapseValues.add(fieldValue);
}
}
SearchHit[] newHits = new SearchHit[collapsedHits.size()];
int i = 0;
for (SearchHit collapsedHit : collapsedHits.values()) {
newHits[i++] = collapsedHit;
}
List<SearchHit> hitsToReturn = new ArrayList<>(collapsedHits.values());
hitsToReturn.sort(Comparator.comparingDouble(SearchHit::getScore).reversed());
SearchHit[] newHits = hitsToReturn.toArray(new SearchHit[0]);
List<String> collapseValues = new ArrayList<>(collapsedHits.keySet());
SearchHits searchHits = new SearchHits(
newHits,
response.getHits().getTotalHits(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.pipeline.common;

import org.apache.lucene.search.TotalHits;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.document.DocumentField;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.test.OpenSearchTestCase;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CollapseResponseProcessorTests extends OpenSearchTestCase {
public void testWithDocumentFields() {
testProcessor(true);
}

public void testWithSourceField() {
testProcessor(false);
}

private void testProcessor(boolean includeDocField) {
Map<String, Object> config = new HashMap<>(Map.of(CollapseResponseProcessor.COLLAPSE_FIELD, "groupid"));
CollapseResponseProcessor processor = new CollapseResponseProcessor.Factory().create(Collections.emptyMap(), null, null, false, config, null);
int numHits = randomIntBetween(1, 100);
SearchResponse inputResponse = generateResponse(numHits, includeDocField);

SearchResponse processedResponse = processor.processResponse(new SearchRequest(), inputResponse);
if (numHits % 2 == 0) {
assertEquals(numHits / 2, processedResponse.getHits().getHits().length);
} else {
assertEquals(numHits / 2 + 1, processedResponse.getHits().getHits().length);
}
for (SearchHit collapsedHit : processedResponse.getHits()) {
assertEquals(0, collapsedHit.docId() % 2);
}
assertEquals("groupid", processedResponse.getHits().getCollapseField());
assertEquals(processedResponse.getHits().getHits().length, processedResponse.getHits().getCollapseValues().length);
for (int i = 0; i < processedResponse.getHits().getHits().length; i++) {
assertEquals(i, processedResponse.getHits().getCollapseValues()[i]);
}
}

private static SearchResponse generateResponse(int numHits, boolean includeDocField) {
SearchHit[] hitsArray = new SearchHit[numHits];
for (int i = 0; i < numHits; i++) {
Map<String, DocumentField> docFields;
int groupValue = i / 2;
if (includeDocField) {
docFields = Map.of("groupid", new DocumentField("groupid", List.of(groupValue)));
} else {
docFields = Collections.emptyMap();
}
SearchHit hit = new SearchHit(i, Integer.toString(i), docFields, Collections.emptyMap());
hit.sourceRef(new BytesArray("{\"groupid\": "+ groupValue + "}"));
hitsArray[i] = hit;
}
SearchHits searchHits = new SearchHits(
hitsArray,
new TotalHits(Math.max(numHits, 1000), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
1.0f
);
InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, null, null, null, false, false, 0);
return new SearchResponse(internalSearchResponse, null, 1, 1, 0, 10, null, null);
}
}

0 comments on commit b995556

Please sign in to comment.