-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Search Pipelines] Add request-scoped state shared between processors
To handle cases where multiple search pipeline processors need to share information, we will allocate a Map<String, Object> for the lifetime of the request and pass it to each processor to get/set values. Signed-off-by: Michael Froh <froh@amazon.com>
- Loading branch information
Showing
14 changed files
with
516 additions
and
11 deletions.
There are no files selected for viewing
109 changes: 109 additions & 0 deletions
109
...common/src/main/java/org/opensearch/search/pipeline/common/CollapseResponseProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
|
||
package org.opensearch.search.pipeline.common; | ||
|
||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.common.document.DocumentField; | ||
import org.opensearch.ingest.ConfigurationUtils; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.SearchHits; | ||
import org.opensearch.search.pipeline.AbstractProcessor; | ||
import org.opensearch.search.pipeline.Processor; | ||
import org.opensearch.search.pipeline.SearchResponseProcessor; | ||
import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Comparator; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
/** | ||
* A simple implementation of field collapsing on search responses. Note that this is not going to work as well as | ||
* field collapsing at the shard level, as implemented with the "collapse" parameter in a search request. Mostly | ||
* just using this to demo the oversample / truncate_hits processors. | ||
*/ | ||
public class CollapseResponseProcessor extends AbstractProcessor implements SearchResponseProcessor { | ||
/** | ||
* Key to reference this processor type from a search pipeline. | ||
*/ | ||
public static final String TYPE = "collapse"; | ||
private static final String COLLAPSE_FIELD = "field"; | ||
private final String collapseField; | ||
|
||
private CollapseResponseProcessor(String tag, String description, boolean ignoreFailure, String collapseField) { | ||
super(tag, description, ignoreFailure); | ||
this.collapseField = collapseField; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
@Override | ||
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { | ||
|
||
if (response.getHits() != null) { | ||
Map<String, SearchHit> collapsedHits = new HashMap<>(); | ||
for (SearchHit hit : response.getHits()) { | ||
String fieldValue = "<missing>"; | ||
DocumentField docField = hit.getFields().get(collapseField); | ||
if (docField != null) { | ||
if (docField.getValues().size() > 1) { | ||
throw new IllegalStateException("Document " + hit.getId() + " has multiple values for field " + collapseField); | ||
} | ||
fieldValue = docField.getValues().get(0).toString(); | ||
} else if (hit.hasSource()) { | ||
Object val = hit.getSourceAsMap().get(collapseField); | ||
if (val != null) { | ||
fieldValue = val.toString(); | ||
} | ||
} | ||
SearchHit previousHit = collapsedHits.get(fieldValue); | ||
// TODO - Support the sort used in the request, rather than just score | ||
if (previousHit == null || hit.getScore() > previousHit.getScore()) { | ||
collapsedHits.put(fieldValue, hit); | ||
} | ||
} | ||
List<SearchHit> hitsToReturn = new ArrayList<>(collapsedHits.values()); | ||
hitsToReturn.sort(Comparator.comparingDouble(SearchHit::getScore).reversed()); | ||
SearchHit[] newHits = hitsToReturn.toArray(new SearchHit[0]); | ||
List<String> collapseValues = new ArrayList<>(collapsedHits.keySet()); | ||
SearchHits searchHits = new SearchHits( | ||
newHits, | ||
response.getHits().getTotalHits(), | ||
response.getHits().getMaxScore(), | ||
response.getHits().getSortFields(), | ||
collapseField, | ||
collapseValues.toArray() | ||
); | ||
return SearchResponseUtil.replaceHits(searchHits, response); | ||
} | ||
return response; | ||
} | ||
|
||
static class Factory implements Processor.Factory<SearchResponseProcessor> { | ||
|
||
@Override | ||
public CollapseResponseProcessor create( | ||
Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, | ||
String tag, | ||
String description, | ||
boolean ignoreFailure, | ||
Map<String, Object> config, | ||
PipelineContext pipelineContext | ||
) { | ||
String collapseField = ConfigurationUtils.readStringProperty(TYPE, tag, config, COLLAPSE_FIELD); | ||
return new CollapseResponseProcessor(tag, description, ignoreFailure, collapseField); | ||
} | ||
} | ||
|
||
} |
73 changes: 73 additions & 0 deletions
73
...ommon/src/main/java/org/opensearch/search/pipeline/common/OversampleRequestProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
|
||
package org.opensearch.search.pipeline.common; | ||
|
||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.ingest.ConfigurationUtils; | ||
import org.opensearch.search.pipeline.AbstractProcessor; | ||
import org.opensearch.search.pipeline.Processor; | ||
import org.opensearch.search.pipeline.SearchRequestProcessor; | ||
import org.opensearch.search.pipeline.StatefulSearchRequestProcessor; | ||
|
||
import java.util.Map; | ||
|
||
/** | ||
* Multiplies the "size" parameter on the {@link SearchRequest} by the given scaling factor, storing the original value | ||
* in the request context as "original_size". | ||
*/ | ||
public class OversampleRequestProcessor extends AbstractProcessor implements StatefulSearchRequestProcessor { | ||
|
||
/** | ||
* Key to reference this processor type from a search pipeline. | ||
*/ | ||
public static final String TYPE = "oversample"; | ||
private static final String SAMPLE_FACTOR = "sample_factor"; | ||
static final String ORIGINAL_SIZE = "original_size"; | ||
private final double sampleFactor; | ||
|
||
private OversampleRequestProcessor(String tag, String description, boolean ignoreFailure, double sampleFactor) { | ||
super(tag, description, ignoreFailure); | ||
this.sampleFactor = sampleFactor; | ||
} | ||
|
||
@Override | ||
public SearchRequest processRequest(SearchRequest request, Map<String, Object> requestContext) { | ||
if (request.source() != null) { | ||
int originalSize = request.source().size(); | ||
requestContext.put(ORIGINAL_SIZE, originalSize); | ||
int newSize = (int) Math.ceil(originalSize * sampleFactor); | ||
request.source().size(newSize); | ||
} | ||
return request; | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
static class Factory implements Processor.Factory<SearchRequestProcessor> { | ||
|
||
@Override | ||
public OversampleRequestProcessor create( | ||
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories, | ||
String tag, | ||
String description, | ||
boolean ignoreFailure, | ||
Map<String, Object> config, | ||
PipelineContext pipelineContext | ||
) { | ||
double sampleFactor = ConfigurationUtils.readDoubleProperty(TYPE, tag, config, SAMPLE_FACTOR); | ||
if (sampleFactor < 1.0) { | ||
throw ConfigurationUtils.newConfigurationException(TYPE, tag, SAMPLE_FACTOR, "Value must be >= 1.0"); | ||
} | ||
return new OversampleRequestProcessor(tag, description, ignoreFailure, sampleFactor); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
...on/src/main/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
|
||
package org.opensearch.search.pipeline.common; | ||
|
||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.ingest.ConfigurationUtils; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.SearchHits; | ||
import org.opensearch.search.pipeline.AbstractProcessor; | ||
import org.opensearch.search.pipeline.Processor; | ||
import org.opensearch.search.pipeline.SearchResponseProcessor; | ||
import org.opensearch.search.pipeline.StatefulSearchResponseProcessor; | ||
import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil; | ||
|
||
import java.util.Map; | ||
|
||
/** | ||
* Truncates the returned search hits from the {@link SearchResponse}. If no target size is specified in the pipeline, then | ||
* we try using the "original_size" value from the request context, which may have been set by {@link OversampleRequestProcessor}. | ||
*/ | ||
public class TruncateHitsResponseProcessor extends AbstractProcessor implements StatefulSearchResponseProcessor { | ||
/** | ||
* Key to reference this processor type from a search pipeline. | ||
*/ | ||
public static final String TYPE = "truncate_hits"; | ||
private static final String TARGET_SIZE = "target_size"; | ||
private final int targetSize; | ||
|
||
@Override | ||
public String getType() { | ||
return TYPE; | ||
} | ||
|
||
private TruncateHitsResponseProcessor(String tag, String description, boolean ignoreFailure, int targetSize) { | ||
super(tag, description, ignoreFailure); | ||
this.targetSize = targetSize; | ||
} | ||
|
||
@Override | ||
public SearchResponse processResponse(SearchRequest request, SearchResponse response, Map<String, Object> requestContext) { | ||
|
||
int size; | ||
if (targetSize < 0) { | ||
size = (int) requestContext.get(OversampleRequestProcessor.ORIGINAL_SIZE); | ||
} else { | ||
size = targetSize; | ||
} | ||
if (response.getHits() != null && response.getHits().getHits().length > size) { | ||
SearchHit[] newHits = new SearchHit[size]; | ||
System.arraycopy(response.getHits().getHits(), 0, newHits, 0, size); | ||
SearchHits searchHits = new SearchHits( | ||
newHits, | ||
response.getHits().getTotalHits(), | ||
response.getHits().getMaxScore(), | ||
response.getHits().getSortFields(), | ||
response.getHits().getCollapseField(), | ||
response.getHits().getCollapseValues() | ||
); | ||
return SearchResponseUtil.replaceHits(searchHits, response); | ||
} | ||
return response; | ||
} | ||
|
||
static class Factory implements Processor.Factory<SearchResponseProcessor> { | ||
|
||
@Override | ||
public SearchResponseProcessor create( | ||
Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, | ||
String tag, | ||
String description, | ||
boolean ignoreFailure, | ||
Map<String, Object> config, | ||
PipelineContext pipelineContext | ||
) throws Exception { | ||
int targetSize = ConfigurationUtils.readIntProperty(TYPE, tag, config, TARGET_SIZE, -1); | ||
return new TruncateHitsResponseProcessor(tag, description, ignoreFailure, targetSize); | ||
} | ||
} | ||
|
||
} |
51 changes: 51 additions & 0 deletions
51
...ommon/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchResponseUtil.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
/* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
|
||
package org.opensearch.search.pipeline.common.helpers; | ||
|
||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.action.search.SearchResponseSections; | ||
import org.opensearch.search.SearchHits; | ||
import org.opensearch.search.profile.SearchProfileShardResults; | ||
|
||
/** | ||
* Helper methods for manipulating {@link SearchResponse}. | ||
*/ | ||
public final class SearchResponseUtil { | ||
private SearchResponseUtil() { | ||
|
||
} | ||
|
||
/** | ||
* Construct a new {@link SearchResponse} based on an existing one, replacing just the {@link SearchHits}. | ||
* @param newHits new search hits | ||
* @param response the existing search response | ||
* @return a new search response where the search hits have been replaced | ||
*/ | ||
public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) { | ||
return new SearchResponse( | ||
new SearchResponseSections( | ||
newHits, | ||
response.getAggregations(), | ||
response.getSuggest(), | ||
response.isTimedOut(), | ||
response.isTerminatedEarly(), | ||
new SearchProfileShardResults(response.getProfileResults()), | ||
response.getNumReducePhases() | ||
), | ||
response.getScrollId(), | ||
response.getTotalShards(), | ||
response.getSuccessfulShards(), | ||
response.getSkippedShards(), | ||
response.getTook().millis(), | ||
response.getShardFailures(), | ||
response.getClusters(), | ||
response.pointInTimeId() | ||
); | ||
} | ||
} |
Oops, something went wrong.