[ML] explicitly disallow partial results in datafeed extractors (#55537)
Instead of doing our own checks against REST status, shard counts, and shard failures, this commit changes all our extractor search requests to set `.setAllowPartialSearchResults(false)`.

- Scroll contexts are automatically cleared on the server when a search failure occurs and `.setAllowPartialSearchResults(false)` is set.
- Error-handling code is simplified.

closes #40793
benwtrent authored Apr 22, 2020
1 parent ed57adb commit 5074f2c
Showing 10 changed files with 111 additions and 217 deletions.
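
The change swaps manual response inspection for a request-level guarantee: if any shard fails, the search throws instead of returning a partial response. Below is a minimal sketch of that pattern; the helper class, client wiring, and index name are hypothetical illustrations, not code from this commit:

```java
import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.search.builder.SearchSourceBuilder;

class AllShardsOrFailSketch {

    // Either every shard contributed to the response, or the search throws;
    // there is no partial response left to inspect.
    static SearchResponse searchAllShardsOrFail(Client client, SearchSourceBuilder source) {
        SearchRequestBuilder request = new SearchRequestBuilder(client, SearchAction.INSTANCE)
            .setIndices("my-index")               // hypothetical index name
            .setSource(source)
            .setAllowPartialSearchResults(false); // the setting this commit applies to every extractor search
        try {
            return request.get();
        } catch (SearchPhaseExecutionException e) {
            // At least one shard failed. For scroll searches the server has
            // already cleared the scroll context, so the caller only decides
            // whether to retry or to propagate.
            throw e;
        }
    }
}
```

With that guarantee in place, the manual `checkSearchWasSuccessful` helper and its call sites reduce to ordinary exception handling, which is what the diffs below remove.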
ExtractorUtils.java
@@ -5,27 +5,20 @@
*/
package org.elasticsearch.xpack.core.ml.datafeed.extractor;

-import org.apache.logging.log4j.LogManager;
-import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
-import org.elasticsearch.action.search.SearchResponse;
-import org.elasticsearch.action.search.ShardSearchFailure;
import org.elasticsearch.common.Rounding;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.RangeQueryBuilder;
-import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.bucket.histogram.DateHistogramAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.histogram.HistogramAggregationBuilder;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

-import java.io.IOException;
import java.time.ZoneOffset;
-import java.util.Arrays;
import java.util.Collection;
import java.util.concurrent.TimeUnit;

@@ -34,7 +27,6 @@
*/
public final class ExtractorUtils {

-private static final Logger LOGGER = LogManager.getLogger(ExtractorUtils.class);
private static final String EPOCH_MILLIS = "epoch_millis";

private ExtractorUtils() {}
@@ -47,25 +39,6 @@ public static QueryBuilder wrapInTimeRangeQuery(QueryBuilder userQuery, String t
return new BoolQueryBuilder().filter(userQuery).filter(timeQuery);
}

-/**
- * Checks that a {@link SearchResponse} has an OK status code and no shard failures
- */
-public static void checkSearchWasSuccessful(String jobId, SearchResponse searchResponse) throws IOException {
-if (searchResponse.status() != RestStatus.OK) {
-throw new IOException("[" + jobId + "] Search request returned status code: " + searchResponse.status()
-+ ". Response was:\n" + searchResponse.toString());
-}
-ShardSearchFailure[] shardFailures = searchResponse.getShardFailures();
-if (shardFailures != null && shardFailures.length > 0) {
-LOGGER.error("[{}] Search request returned shard failures: {}", jobId, Arrays.toString(shardFailures));
-throw new IOException(ExceptionsHelper.shardFailuresToErrorMsg(jobId, shardFailures));
-}
-int unavailableShards = searchResponse.getTotalShards() - searchResponse.getSuccessfulShards();
-if (unavailableShards > 0) {
-throw new IOException("[" + jobId + "] Search request encountered [" + unavailableShards + "] unavailable shards");
-}
-}
-
/**
* Find the (date) histogram in {@code aggFactory} and extract its interval.
* Throws if there is no (date) histogram or if the histogram has sibling
AbstractAggregationDataExtractor.java
@@ -107,12 +107,13 @@ public Optional<InputStream> next() throws IOException {
return Optional.ofNullable(processNextBatch());
}

-private Aggregations search() throws IOException {
+private Aggregations search() {
LOGGER.debug("[{}] Executing aggregated search", context.jobId);
-SearchResponse searchResponse = executeSearchRequest(buildSearchRequest(buildBaseSearchSource()));
+T searchRequest = buildSearchRequest(buildBaseSearchSource());
+assert searchRequest.request().allowPartialSearchResults() == false;
+SearchResponse searchResponse = executeSearchRequest(searchRequest);
LOGGER.debug("[{}] Search response was obtained", context.jobId);
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
-ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
return validateAggs(searchResponse.getAggregations());
}

@@ -166,10 +167,6 @@ private InputStream processNextBatch() throws IOException {
return new ByteArrayInputStream(outputStream.toByteArray());
}

-protected long getHistogramInterval() {
-return ExtractorUtils.getHistogramIntervalMillis(context.aggs);
-}
-
public AggregationDataExtractorContext getContext() {
return context;
}
AggregationDataExtractor.java
@@ -29,6 +29,7 @@ protected SearchRequestBuilder buildSearchRequest(SearchSourceBuilder searchSour
return new SearchRequestBuilder(client, SearchAction.INSTANCE)
.setSource(searchSourceBuilder)
.setIndicesOptions(context.indicesOptions)
+.setAllowPartialSearchResults(false)
.setIndices(context.indices);
}
}
AggregationToJsonProcessor.java
@@ -63,8 +63,7 @@ class AggregationToJsonProcessor {
* @param includeDocCount whether to include the doc_count
* @param startTime buckets with a timestamp before this time are discarded
*/
-AggregationToJsonProcessor(String timeField, Set<String> fields, boolean includeDocCount, long startTime)
-throws IOException {
+AggregationToJsonProcessor(String timeField, Set<String> fields, boolean includeDocCount, long startTime) {
this.timeField = Objects.requireNonNull(timeField);
this.fields = Objects.requireNonNull(fields);
this.includeDocCount = includeDocCount;
@@ -279,7 +278,7 @@ private void processBucket(MultiBucketsAggregation bucketAgg, boolean addField)
* Adds a leaf key-value. It returns {@code true} if the key added or {@code false} when nothing was added.
* Non-finite metric values are not added.
*/
-private boolean processLeaf(Aggregation agg) throws IOException {
+private boolean processLeaf(Aggregation agg) {
if (agg instanceof NumericMetricsAggregation.SingleValue) {
return processSingleValue((NumericMetricsAggregation.SingleValue) agg);
} else if (agg instanceof Percentiles) {
@@ -291,7 +290,7 @@ private boolean processLeaf(Aggregation agg) throws IOException {
}
}

-private boolean processSingleValue(NumericMetricsAggregation.SingleValue singleValue) throws IOException {
+private boolean processSingleValue(NumericMetricsAggregation.SingleValue singleValue) {
return addMetricIfFinite(singleValue.getName(), singleValue.value());
}

@@ -311,7 +310,7 @@ private boolean processGeoCentroid(GeoCentroid agg) {
return false;
}

-private boolean processPercentiles(Percentiles percentiles) throws IOException {
+private boolean processPercentiles(Percentiles percentiles) {
Iterator<Percentile> percentileIterator = percentiles.iterator();
boolean aggregationAdded = addMetricIfFinite(percentiles.getName(), percentileIterator.next().getValue());
if (percentileIterator.hasNext()) {
RollupDataExtractor.java
@@ -28,6 +28,7 @@ class RollupDataExtractor extends AbstractAggregationDataExtractor<RollupSearchA
protected RollupSearchAction.RequestBuilder buildSearchRequest(SearchSourceBuilder searchSourceBuilder) {
SearchRequest searchRequest = new SearchRequest().indices(context.indices)
.indicesOptions(context.indicesOptions)
+.allowPartialSearchResults(false)
.source(searchSourceBuilder);

return new RollupSearchAction.RequestBuilder(client, searchRequest);
ChunkedDataExtractor.java
@@ -114,7 +114,7 @@ public Optional<InputStream> next() throws IOException {
return getNextStream();
}

-private void setUpChunkedSearch() throws IOException {
+private void setUpChunkedSearch() {
DataSummary dataSummary = dataSummaryFactory.buildDataSummary();
if (dataSummary.hasData()) {
currentStart = context.timeAligner.alignToFloor(dataSummary.earliestTime());
@@ -196,21 +196,18 @@ private class DataSummaryFactory {
* So, if we need to gather an appropriate chunked time for aggregations, we can utilize the AggregatedDataSummary
*
* @return DataSummary object
- * @throws IOException when timefield range search fails
*/
-private DataSummary buildDataSummary() throws IOException {
+private DataSummary buildDataSummary() {
return context.hasAggregations ? newAggregatedDataSummary() : newScrolledDataSummary();
}

-private DataSummary newScrolledDataSummary() throws IOException {
+private DataSummary newScrolledDataSummary() {
SearchRequestBuilder searchRequestBuilder = rangeSearchRequest();

SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
LOGGER.debug("[{}] Scrolling Data summary response was obtained", context.jobId);
timingStatsReporter.reportSearchDuration(searchResponse.getTook());

-ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
-
Aggregations aggregations = searchResponse.getAggregations();
long earliestTime = 0;
long latestTime = 0;
@@ -224,16 +221,14 @@ private DataSummary newScrolledDataSummary() throws IOException {
return new ScrolledDataSummary(earliestTime, latestTime, totalHits);
}

-private DataSummary newAggregatedDataSummary() throws IOException {
+private DataSummary newAggregatedDataSummary() {
// TODO: once RollupSearchAction is changed from indices:admin* to indices:data/read/* this branch is not needed
ActionRequestBuilder<SearchRequest, SearchResponse> searchRequestBuilder =
dataExtractorFactory instanceof RollupDataExtractorFactory ? rollupRangeSearchRequest() : rangeSearchRequest();
SearchResponse searchResponse = executeSearchRequest(searchRequestBuilder);
LOGGER.debug("[{}] Aggregating Data summary response was obtained", context.jobId);
timingStatsReporter.reportSearchDuration(searchResponse.getTook());

-ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
-
Aggregations aggregations = searchResponse.getAggregations();
Min min = aggregations.get(EARLIEST_TIME);
Max max = aggregations.get(LATEST_TIME);
@@ -253,12 +248,14 @@ private SearchRequestBuilder rangeSearchRequest() {
.setIndices(context.indices)
.setIndicesOptions(context.indicesOptions)
.setSource(rangeSearchBuilder())
+.setAllowPartialSearchResults(false)
.setTrackTotalHits(true);
}

private RollupSearchAction.RequestBuilder rollupRangeSearchRequest() {
SearchRequest searchRequest = new SearchRequest().indices(context.indices)
.indicesOptions(context.indicesOptions)
+.allowPartialSearchResults(false)
.source(rangeSearchBuilder());
return new RollupSearchAction.RequestBuilder(client, searchRequest);
}
ScrollDataExtractor.java
@@ -102,9 +102,13 @@ private Optional<InputStream> tryNextStream() throws IOException {
return scrollId == null ?
Optional.ofNullable(initScroll(context.start)) : Optional.ofNullable(continueScroll());
} catch (Exception e) {
-// In case of error make sure we clear the scroll context
-clearScroll();
-throw e;
+scrollId = null;
+if (searchHasShardFailure) {
+throw e;
+}
+LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
+markScrollAsErrored();
+return Optional.ofNullable(initScroll(lastTimestamp == null ? context.start : lastTimestamp));
}
}

@@ -127,6 +131,7 @@ private SearchRequestBuilder buildSearchRequest(long start) {
.setIndices(context.indices)
.setIndicesOptions(context.indicesOptions)
.setSize(context.scrollSize)
+.setAllowPartialSearchResults(false)
.setQuery(ExtractorUtils.wrapInTimeRangeQuery(
context.query, context.extractedFields.timeField(), start, context.end));

@@ -147,14 +152,6 @@ private SearchRequestBuilder buildSearchRequest(long start) {
private InputStream processSearchResponse(SearchResponse searchResponse) throws IOException {

scrollId = searchResponse.getScrollId();

-if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) {
-LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
-markScrollAsErrored();
-return initScroll(lastTimestamp == null ? context.start : lastTimestamp);
-}
-
-ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
if (searchResponse.getHits().getHits().length == 0) {
hasNext = false;
clearScroll();
@@ -190,24 +187,23 @@ private InputStream continueScroll() throws IOException {
try {
searchResponse = executeSearchScrollRequest(scrollId);
} catch (SearchPhaseExecutionException searchExecutionException) {
-if (searchHasShardFailure == false) {
-LOGGER.debug("[{}] Reinitializing scroll due to SearchPhaseExecutionException", context.jobId);
-markScrollAsErrored();
-searchResponse =
-executeSearchRequest(buildSearchRequest(lastTimestamp == null ? context.start : lastTimestamp));
-} else {
+if (searchHasShardFailure) {
throw searchExecutionException;
}
LOGGER.debug("[{}] search failed due to SearchPhaseExecutionException. Will attempt again with new scroll",
context.jobId);
markScrollAsErrored();
searchResponse = executeSearchRequest(buildSearchRequest(lastTimestamp == null ? context.start : lastTimestamp));
}
LOGGER.debug("[{}] Search response was obtained", context.jobId);
timingStatsReporter.reportSearchDuration(searchResponse.getTook());
return processSearchResponse(searchResponse);
}

-private void markScrollAsErrored() {
+void markScrollAsErrored() {
// This could be a transient error with the scroll Id.
// Reinitialise the scroll and try again but only once.
-clearScroll();
+scrollId = null;
if (lastTimestamp != null) {
lastTimestamp++;
}
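
For context, the hunks above collapse the scroll error handling into a single retry-once path. Here is a minimal sketch of that flow; the `ScrollSearcher` abstraction and field wiring are hypothetical stand-ins, not the extractor's actual code:

```java
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchResponse;

// Sketch of the retry-once flow after a transient scroll failure.
class ScrollRetrySketch {

    private String scrollId;               // null forces a fresh (non-scroll) search
    private boolean searchHasShardFailure; // true once a scroll has already errored
    private Long lastTimestamp;            // resume point for the reinitialized search

    SearchResponse continueScrollWithRetry(ScrollSearcher searcher, long configuredStart) {
        try {
            return searcher.scroll(scrollId);
        } catch (SearchPhaseExecutionException e) {
            if (searchHasShardFailure) {
                throw e; // second failure in a row: propagate
            }
            markScrollAsErrored();
            long start = lastTimestamp == null ? configuredStart : lastTimestamp;
            return searcher.newSearch(start); // retry exactly once from the resume point
        }
    }

    private void markScrollAsErrored() {
        // With partial results disallowed, the server cleared the scroll
        // context when the search failed; only local state needs resetting.
        scrollId = null;
        searchHasShardFailure = true;
        if (lastTimestamp != null) {
            lastTimestamp++; // do not re-read the record processed last
        }
    }

    // Hypothetical abstraction over the extractor's two search paths.
    interface ScrollSearcher {
        SearchResponse scroll(String scrollId);
        SearchResponse newSearch(long start);
    }
}
```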
AggregationDataExtractorTests.java
@@ -5,6 +5,7 @@
*/
package org.elasticsearch.xpack.ml.datafeed.extractor.aggregation;

+import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
@@ -64,6 +65,7 @@ public class AggregationDataExtractorTests extends ESTestCase {
private class TestDataExtractor extends AggregationDataExtractor {

private SearchResponse nextResponse;
+private SearchPhaseExecutionException ex;

TestDataExtractor(long start, long end) {
super(testClient, createContext(start, end), timingStatsReporter);
@@ -72,12 +74,19 @@ private class TestDataExtractor extends AggregationDataExtractor {
@Override
protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
capturedSearchRequests.add(searchRequestBuilder);
+if (ex != null) {
+throw ex;
+}
return nextResponse;
}

void setNextResponse(SearchResponse searchResponse) {
nextResponse = searchResponse;
}

+void setNextResponseToError(SearchPhaseExecutionException ex) {
+this.ex = ex;
+}
}

@Before
@@ -246,29 +255,12 @@ public void testExtractionGivenCancelHalfWay() throws IOException {
assertThat(capturedSearchRequests.size(), equalTo(1));
}

-public void testExtractionGivenSearchResponseHasError() throws IOException {
-TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
-extractor.setNextResponse(createErrorResponse());
-
-assertThat(extractor.hasNext(), is(true));
-expectThrows(IOException.class, extractor::next);
-}
-
-public void testExtractionGivenSearchResponseHasShardFailures() {
+public void testExtractionGivenSearchResponseHasError() {
TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
-extractor.setNextResponse(createResponseWithShardFailures());
+extractor.setNextResponseToError(new SearchPhaseExecutionException("phase 1", "boom", ShardSearchFailure.EMPTY_ARRAY));

assertThat(extractor.hasNext(), is(true));
-expectThrows(IOException.class, extractor::next);
-}
-
-public void testExtractionGivenInitSearchResponseEncounteredUnavailableShards() {
-TestDataExtractor extractor = new TestDataExtractor(1000L, 2000L);
-extractor.setNextResponse(createResponseWithUnavailableShards(2));
-
-assertThat(extractor.hasNext(), is(true));
-IOException e = expectThrows(IOException.class, extractor::next);
-assertThat(e.getMessage(), equalTo("[" + jobId + "] Search request encountered [2] unavailable shards"));
+expectThrows(SearchPhaseExecutionException.class, extractor::next);
}

private AggregationDataExtractorContext createContext(long start, long end) {
@@ -295,29 +287,6 @@ private SearchResponse createSearchResponse(Aggregations aggregations) {
return searchResponse;
}

-private SearchResponse createErrorResponse() {
-SearchResponse searchResponse = mock(SearchResponse.class);
-when(searchResponse.status()).thenReturn(RestStatus.INTERNAL_SERVER_ERROR);
-return searchResponse;
-}
-
-private SearchResponse createResponseWithShardFailures() {
-SearchResponse searchResponse = mock(SearchResponse.class);
-when(searchResponse.status()).thenReturn(RestStatus.OK);
-when(searchResponse.getShardFailures()).thenReturn(
-new ShardSearchFailure[] { new ShardSearchFailure(new RuntimeException("shard failed"))});
-return searchResponse;
-}
-
-private SearchResponse createResponseWithUnavailableShards(int unavailableShards) {
-SearchResponse searchResponse = mock(SearchResponse.class);
-when(searchResponse.status()).thenReturn(RestStatus.OK);
-when(searchResponse.getSuccessfulShards()).thenReturn(3);
-when(searchResponse.getTotalShards()).thenReturn(3 + unavailableShards);
-when(searchResponse.getTook()).thenReturn(TimeValue.timeValueMillis(randomNonNegativeLong()));
-return searchResponse;
-}

private static String asString(InputStream inputStream) throws IOException {
try (BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
return reader.lines().collect(Collectors.joining("\n"));