Skip to content

Commit

Permalink
[Search Pipelines] Add request-scoped state shared between processors
Browse files Browse the repository at this point in the history
To handle cases where multiple search pipeline processors need to share
information, we will allocate a Map<String, Object> for the lifetime of
the request and pass it to each processor to get/set values.

Signed-off-by: Michael Froh <froh@amazon.com>
  • Loading branch information
msfroh committed Aug 16, 2023
1 parent 7ffcd65 commit 39d7d25
Show file tree
Hide file tree
Showing 14 changed files with 516 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.pipeline.common;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.document.DocumentField;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* A simple implementation of field collapsing on search responses. Note that this is not going to work as well as
* field collapsing at the shard level, as implemented with the "collapse" parameter in a search request. Mostly
* just using this to demo the oversample / truncate_hits processors.
*/
public class CollapseResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {
/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "collapse";
private static final String COLLAPSE_FIELD = "field";
private final String collapseField;

private CollapseResponseProcessor(String tag, String description, boolean ignoreFailure, String collapseField) {
super(tag, description, ignoreFailure);
this.collapseField = collapseField;
}

@Override
public String getType() {
return TYPE;
}

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {

if (response.getHits() != null) {
Map<String, SearchHit> collapsedHits = new HashMap<>();
for (SearchHit hit : response.getHits()) {
String fieldValue = "<missing>";
DocumentField docField = hit.getFields().get(collapseField);
if (docField != null) {
if (docField.getValues().size() > 1) {
throw new IllegalStateException("Document " + hit.getId() + " has multiple values for field " + collapseField);
}
fieldValue = docField.getValues().get(0).toString();
} else if (hit.hasSource()) {
Object val = hit.getSourceAsMap().get(collapseField);
if (val != null) {
fieldValue = val.toString();
}
}
SearchHit previousHit = collapsedHits.get(fieldValue);
// TODO - Support the sort used in the request, rather than just score
if (previousHit == null || hit.getScore() > previousHit.getScore()) {
collapsedHits.put(fieldValue, hit);
}
}
List<SearchHit> hitsToReturn = new ArrayList<>(collapsedHits.values());
hitsToReturn.sort(Comparator.comparingDouble(SearchHit::getScore).reversed());
SearchHit[] newHits = hitsToReturn.toArray(new SearchHit[0]);
List<String> collapseValues = new ArrayList<>(collapsedHits.keySet());
SearchHits searchHits = new SearchHits(
newHits,
response.getHits().getTotalHits(),
response.getHits().getMaxScore(),
response.getHits().getSortFields(),
collapseField,
collapseValues.toArray()
);
return SearchResponseUtil.replaceHits(searchHits, response);
}
return response;
}

static class Factory implements Processor.Factory<SearchResponseProcessor> {

@Override
public CollapseResponseProcessor create(
Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) {
String collapseField = ConfigurationUtils.readStringProperty(TYPE, tag, config, COLLAPSE_FIELD);
return new CollapseResponseProcessor(tag, description, ignoreFailure, collapseField);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.pipeline.common;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.StatefulSearchRequestProcessor;

import java.util.Map;

/**
* Multiplies the "size" parameter on the {@link SearchRequest} by the given scaling factor, storing the original value
* in the request context as "original_size".
*/
public class OversampleRequestProcessor extends AbstractProcessor implements StatefulSearchRequestProcessor {

/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "oversample";
private static final String SAMPLE_FACTOR = "sample_factor";
static final String ORIGINAL_SIZE = "original_size";
private final double sampleFactor;

private OversampleRequestProcessor(String tag, String description, boolean ignoreFailure, double sampleFactor) {
super(tag, description, ignoreFailure);
this.sampleFactor = sampleFactor;
}

@Override
public SearchRequest processRequest(SearchRequest request, Map<String, Object> requestContext) {
if (request.source() != null) {
int originalSize = request.source().size();
requestContext.put(ORIGINAL_SIZE, originalSize);
int newSize = (int) Math.ceil(originalSize * sampleFactor);
request.source().size(newSize);
}
return request;
}

@Override
public String getType() {
return TYPE;
}

static class Factory implements Processor.Factory<SearchRequestProcessor> {

@Override
public OversampleRequestProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) {
double sampleFactor = ConfigurationUtils.readDoubleProperty(TYPE, tag, config, SAMPLE_FACTOR);
if (sampleFactor < 1.0) {
throw ConfigurationUtils.newConfigurationException(TYPE, tag, SAMPLE_FACTOR, "Value must be >= 1.0");
}
return new OversampleRequestProcessor(tag, description, ignoreFailure, sampleFactor);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,21 @@ public Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcesso
FilterQueryRequestProcessor.TYPE,
new FilterQueryRequestProcessor.Factory(parameters.namedXContentRegistry),
ScriptRequestProcessor.TYPE,
new ScriptRequestProcessor.Factory(parameters.scriptService)
new ScriptRequestProcessor.Factory(parameters.scriptService),
OversampleRequestProcessor.TYPE,
new OversampleRequestProcessor.Factory()
);
}

@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Parameters parameters) {
return Map.of(RenameFieldResponseProcessor.TYPE, new RenameFieldResponseProcessor.Factory());
return Map.of(
RenameFieldResponseProcessor.TYPE,
new RenameFieldResponseProcessor.Factory(),
TruncateHitsResponseProcessor.TYPE,
new TruncateHitsResponseProcessor.Factory(),
CollapseResponseProcessor.TYPE,
new CollapseResponseProcessor.Factory()
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.pipeline.common;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.pipeline.StatefulSearchResponseProcessor;
import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil;

import java.util.Map;

/**
* Truncates the returned search hits from the {@link SearchResponse}. If no target size is specified in the pipeline, then
* we try using the "original_size" value from the request context, which may have been set by {@link OversampleRequestProcessor}.
*/
public class TruncateHitsResponseProcessor extends AbstractProcessor implements StatefulSearchResponseProcessor {
/**
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "truncate_hits";
private static final String TARGET_SIZE = "target_size";
private final int targetSize;

@Override
public String getType() {
return TYPE;
}

private TruncateHitsResponseProcessor(String tag, String description, boolean ignoreFailure, int targetSize) {
super(tag, description, ignoreFailure);
this.targetSize = targetSize;
}

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response, Map<String, Object> requestContext) {

int size;
if (targetSize < 0) {
size = (int) requestContext.get(OversampleRequestProcessor.ORIGINAL_SIZE);
} else {
size = targetSize;
}
if (response.getHits() != null && response.getHits().getHits().length > size) {
SearchHit[] newHits = new SearchHit[size];
System.arraycopy(response.getHits().getHits(), 0, newHits, 0, size);
SearchHits searchHits = new SearchHits(
newHits,
response.getHits().getTotalHits(),
response.getHits().getMaxScore(),
response.getHits().getSortFields(),
response.getHits().getCollapseField(),
response.getHits().getCollapseValues()
);
return SearchResponseUtil.replaceHits(searchHits, response);
}
return response;
}

static class Factory implements Processor.Factory<SearchResponseProcessor> {

@Override
public SearchResponseProcessor create(
Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws Exception {
int targetSize = ConfigurationUtils.readIntProperty(TYPE, tag, config, TARGET_SIZE, -1);
return new TruncateHitsResponseProcessor(tag, description, ignoreFailure, targetSize);
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.pipeline.common.helpers;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.search.SearchHits;
import org.opensearch.search.profile.SearchProfileShardResults;

/**
* Helper methods for manipulating {@link SearchResponse}.
*/
public final class SearchResponseUtil {
private SearchResponseUtil() {

}

/**
* Construct a new {@link SearchResponse} based on an existing one, replacing just the {@link SearchHits}.
* @param newHits new search hits
* @param response the existing search response
* @return a new search response where the search hits have been replaced
*/
public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) {
return new SearchResponse(
new SearchResponseSections(
newHits,
response.getAggregations(),
response.getSuggest(),
response.isTimedOut(),
response.isTerminatedEarly(),
new SearchProfileShardResults(response.getProfileResults()),
response.getNumReducePhases()
),
response.getScrollId(),
response.getTotalShards(),
response.getSuccessfulShards(),
response.getSkippedShards(),
response.getTook().millis(),
response.getShardFailures(),
response.getClusters(),
response.pointInTimeId()
);
}
}
Loading

0 comments on commit 39d7d25

Please sign in to comment.