Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip final reduction if SearchRequest holds a cluster alias #37000

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -407,17 +407,18 @@ private SearchHits getHits(ReducedQueryPhase reducedQueryPhase, boolean ignoreFr
* Reduces the given query results and consumes all aggregations and profile results.
* @param queryResults a list of non-null query shard results
*/
public ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
return reducedQueryPhase(queryResults, true, true);
ReducedQueryPhase reducedScrollQueryPhase(Collection<? extends SearchPhaseResult> queryResults) {
return reducedQueryPhase(queryResults, true, true, true);
}

/**
* Reduces the given query results and consumes all aggregations and profile results.
* @param queryResults a list of non-null query shard results
*/
public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
boolean isScrollRequest, boolean trackTotalHits) {
return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHits), 0, isScrollRequest);
ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
boolean isScrollRequest, boolean trackTotalHits, boolean performFinalReduce) {
return reducedQueryPhase(queryResults, null, new ArrayList<>(), new TopDocsStats(trackTotalHits), 0, isScrollRequest,
performFinalReduce);
}

/**
Expand All @@ -433,7 +434,8 @@ public ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResul
*/
private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResult> queryResults,
List<InternalAggregations> bufferedAggs, List<TopDocs> bufferedTopDocs,
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest) {
TopDocsStats topDocsStats, int numReducePhases, boolean isScrollRequest,
boolean performFinalReduce) {
assert numReducePhases >= 0 : "num reduce phases must be >= 0 but was: " + numReducePhases;
numReducePhases++; // increment for this phase
boolean timedOut = false;
Expand Down Expand Up @@ -499,15 +501,15 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
}
}
final Suggest suggest = groupedSuggestions.isEmpty() ? null : new Suggest(Suggest.reduce(groupedSuggestions));
ReduceContext reduceContext = reduceContextFunction.apply(true);
ReduceContext reduceContext = reduceContextFunction.apply(performFinalReduce);
final InternalAggregations aggregations = aggregationsList.isEmpty() ? null : reduceAggs(aggregationsList,
firstResult.pipelineAggregators(), reduceContext);
final SearchProfileShardResults shardResults = profileResults.isEmpty() ? null : new SearchProfileShardResults(profileResults);
final SortedTopDocs sortedTopDocs = sortDocs(isScrollRequest, queryResults, bufferedTopDocs, topDocsStats, from, size);
final TotalHits totalHits = topDocsStats.getTotalHits();
return new ReducedQueryPhase(totalHits, topDocsStats.fetchHits, topDocsStats.maxScore,
timedOut, terminatedEarly, suggest, aggregations, shardResults, sortedTopDocs,
firstResult.sortValueFormats(), numReducePhases, size, from, firstResult == null);
firstResult.sortValueFormats(), numReducePhases, size, from, false);
}

/**
Expand Down Expand Up @@ -617,6 +619,7 @@ static final class QueryPhaseResultConsumer extends InitialSearchPhase.ArraySear
private final SearchPhaseController controller;
private int numReducePhases = 0;
private final TopDocsStats topDocsStats = new TopDocsStats();
private final boolean performFinalReduce;

/**
* Creates a new {@link QueryPhaseResultConsumer}
Expand All @@ -626,7 +629,7 @@ static final class QueryPhaseResultConsumer extends InitialSearchPhase.ArraySear
* the buffer is used to incrementally reduce aggregation results before all shards responded.
*/
private QueryPhaseResultConsumer(SearchPhaseController controller, int expectedResultSize, int bufferSize,
boolean hasTopDocs, boolean hasAggs) {
boolean hasTopDocs, boolean hasAggs, boolean performFinalReduce) {
super(expectedResultSize);
if (expectedResultSize != 1 && bufferSize < 2) {
throw new IllegalArgumentException("buffer size must be >= 2 if there is more than one expected result");
Expand All @@ -644,6 +647,7 @@ private QueryPhaseResultConsumer(SearchPhaseController controller, int expectedR
this.hasTopDocs = hasTopDocs;
this.hasAggs = hasAggs;
this.bufferSize = bufferSize;
this.performFinalReduce = performFinalReduce;
}

@Override
Expand Down Expand Up @@ -693,7 +697,7 @@ private synchronized List<TopDocs> getRemainingTopDocs() {
@Override
public ReducedQueryPhase reduce() {
return controller.reducedQueryPhase(results.asList(), getRemainingAggs(), getRemainingTopDocs(), topDocsStats,
numReducePhases, false);
numReducePhases, false, performFinalReduce);
}

/**
Expand All @@ -715,18 +719,19 @@ InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResu
final boolean hasAggs = source != null && source.aggregations() != null;
final boolean hasTopDocs = source == null || source.size() != 0;
final boolean trackTotalHits = source == null || source.trackTotalHits();
final boolean finalReduce = request.getLocalClusterAlias() == null;

if (isScrollRequest == false && (hasAggs || hasTopDocs)) {
// no incremental reduce if scroll is used - we only hit a single shard or sometimes more...
if (request.getBatchedReduceSize() < numShards) {
// only use this if there are aggs and if there are more shards than we should reduce at once
return new QueryPhaseResultConsumer(this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs);
return new QueryPhaseResultConsumer(this, numShards, request.getBatchedReduceSize(), hasTopDocs, hasAggs, finalReduce);
}
}
return new InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult>(numShards) {
@Override
ReducedQueryPhase reduce() {
return reducedQueryPhase(results.asList(), isScrollRequest, trackTotalHits);
return reducedQueryPhase(results.asList(), isScrollRequest, trackTotalHits, finalReduce);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,34 @@
import org.junit.Before;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.lessThanOrEqualTo;
import static org.hamcrest.Matchers.not;

public class SearchPhaseControllerTests extends ESTestCase {
private SearchPhaseController searchPhaseController;
private List<Boolean> reductions;

@Before
public void setup() {
reductions = new CopyOnWriteArrayList<>();
searchPhaseController = new SearchPhaseController(
(b) -> new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, b));
(finalReduce) -> {
reductions.add(finalReduce);
return new InternalAggregation.ReduceContext(BigArrays.NON_RECYCLING_INSTANCE, null, finalReduce);
});
}

public void testSort() {
Expand Down Expand Up @@ -158,7 +164,7 @@ public void testMerge() {
AtomicArray<SearchPhaseResult> queryResults = generateQueryResults(nShards, suggestions, queryResultSize, false);
for (boolean trackTotalHits : new boolean[] {true, false}) {
SearchPhaseController.ReducedQueryPhase reducedQueryPhase =
searchPhaseController.reducedQueryPhase(queryResults.asList(), false, trackTotalHits);
searchPhaseController.reducedQueryPhase(queryResults.asList(), false, trackTotalHits, true);
AtomicArray<SearchPhaseResult> fetchResults = generateFetchResults(nShards,
reducedQueryPhase.sortedTopDocs.scoreDocs, reducedQueryPhase.suggest);
InternalSearchResponse mergedResponse = searchPhaseController.merge(false,
Expand Down Expand Up @@ -308,14 +314,15 @@ private static AtomicArray<SearchPhaseResult> generateFetchResults(int nShards,

public void testConsumer() {
int bufferSize = randomIntBetween(2, 3);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")));
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer = searchPhaseController.newSearchPhaseResults(request, 3);
assertEquals(0, reductions.size());
QuerySearchResult result = new QuerySearchResult(0, new SearchShardTarget("node", new Index("a", "b"), 0, null));
result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN),
new DocValueFormat[0]);
InternalAggregations aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", 1.0D, DocValueFormat.RAW,
InternalAggregations aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", 1.0D, DocValueFormat.RAW,
Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(0);
Expand All @@ -324,7 +331,7 @@ public void testConsumer() {
result = new QuerySearchResult(1, new SearchShardTarget("node", new Index("a", "b"), 0, null));
result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN),
new DocValueFormat[0]);
aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", 3.0D, DocValueFormat.RAW,
aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", 3.0D, DocValueFormat.RAW,
Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(2);
Expand All @@ -333,23 +340,29 @@ public void testConsumer() {
result = new QuerySearchResult(1, new SearchShardTarget("node", new Index("a", "b"), 0, null));
result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), Float.NaN),
new DocValueFormat[0]);
aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", 2.0D, DocValueFormat.RAW,
aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", 2.0D, DocValueFormat.RAW,
Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(1);
consumer.consumeResult(result);
int numTotalReducePhases = 1;
final int numTotalReducePhases;
if (bufferSize == 2) {
assertThat(consumer, instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class));
assertEquals(1, ((SearchPhaseController.QueryPhaseResultConsumer)consumer).getNumReducePhases());
assertEquals(2, ((SearchPhaseController.QueryPhaseResultConsumer)consumer).getNumBuffered());
numTotalReducePhases++;
assertEquals(1, reductions.size());
assertEquals(false, reductions.get(0));
numTotalReducePhases = 2;
} else {
assertThat(consumer, not(instanceOf(SearchPhaseController.QueryPhaseResultConsumer.class)));
assertEquals(0, reductions.size());
numTotalReducePhases = 1;
}

SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertEquals(numTotalReducePhases, reduce.numReducePhases);
assertEquals(numTotalReducePhases, reductions.size());
assertFinalReduction(request);
InternalMax max = (InternalMax) reduce.aggregations.asList().get(0);
assertEquals(3.0D, max.getValue(), 0.0D);
assertFalse(reduce.sortedTopDocs.isSortedByField);
Expand All @@ -362,7 +375,7 @@ public void testConsumerConcurrently() throws InterruptedException {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);

SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")));
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
Expand All @@ -378,7 +391,7 @@ public void testConsumerConcurrently() throws InterruptedException {
result.topDocs(new TopDocsAndMaxScore(
new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] {new ScoreDoc(0, number)}), number),
new DocValueFormat[0]);
InternalAggregations aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", (double) number,
InternalAggregations aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", (double) number,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(id);
Expand All @@ -392,6 +405,7 @@ public void testConsumerConcurrently() throws InterruptedException {
threads[i].join();
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
assertEquals(max.get(), internalMax.getValue(), 0.0D);
assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
Expand All @@ -407,7 +421,7 @@ public void testConsumerConcurrently() throws InterruptedException {
public void testConsumerOnlyAggs() {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
request.source(new SearchSourceBuilder().aggregation(AggregationBuilders.avg("foo")).size(0));
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
Expand All @@ -419,14 +433,15 @@ public void testConsumerOnlyAggs() {
QuerySearchResult result = new QuerySearchResult(i, new SearchShardTarget("node", new Index("a", "b"), i, null));
result.topDocs(new TopDocsAndMaxScore(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), number),
new DocValueFormat[0]);
InternalAggregations aggs = new InternalAggregations(Arrays.asList(new InternalMax("test", (double) number,
InternalAggregations aggs = new InternalAggregations(Collections.singletonList(new InternalMax("test", (double) number,
DocValueFormat.RAW, Collections.emptyList(), Collections.emptyMap())));
result.aggregations(aggs);
result.setShardIndex(i);
result.size(1);
consumer.consumeResult(result);
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
InternalMax internalMax = (InternalMax) reduce.aggregations.asList().get(0);
assertEquals(max.get(), internalMax.getValue(), 0.0D);
assertEquals(0, reduce.sortedTopDocs.scoreDocs.length);
Expand All @@ -441,7 +456,7 @@ public void testConsumerOnlyAggs() {
public void testConsumerOnlyHits() {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
if (randomBoolean()) {
request.source(new SearchSourceBuilder().size(randomIntBetween(1, 10)));
}
Expand All @@ -460,6 +475,7 @@ public void testConsumerOnlyHits() {
consumer.consumeResult(result);
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
assertEquals(1, reduce.sortedTopDocs.scoreDocs.length);
assertEquals(max.get(), reduce.maxScore, 0.0f);
assertEquals(expectedNumResults, reduce.totalHits.value);
Expand All @@ -470,6 +486,12 @@ public void testConsumerOnlyHits() {
assertNull(reduce.sortedTopDocs.collapseValues);
}

private void assertFinalReduction(SearchRequest searchRequest) {
assertThat(reductions.size(), greaterThanOrEqualTo(1));
//the last reduction step was the final one only if no cluster alias was provided with the search request
assertEquals(searchRequest.getLocalClusterAlias() == null, reductions.get(reductions.size() - 1));
}

public void testNewSearchPhaseResults() {
for (int i = 0; i < 10; i++) {
int expectedNumResults = randomIntBetween(1, 10);
Expand Down Expand Up @@ -540,7 +562,7 @@ public void testReduceTopNWithFromOffset() {
public void testConsumerSortByField() {
int expectedNumResults = randomIntBetween(1, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
int size = randomIntBetween(1, 10);
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
Expand All @@ -560,6 +582,7 @@ public void testConsumerSortByField() {
consumer.consumeResult(result);
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
assertEquals(Math.min(expectedNumResults, size), reduce.sortedTopDocs.scoreDocs.length);
assertEquals(expectedNumResults, reduce.totalHits.value);
assertEquals(max.get(), ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]);
Expand All @@ -574,7 +597,7 @@ public void testConsumerSortByField() {
public void testConsumerFieldCollapsing() {
int expectedNumResults = randomIntBetween(30, 100);
int bufferSize = randomIntBetween(2, 200);
SearchRequest request = new SearchRequest();
SearchRequest request = randomBoolean() ? new SearchRequest() : new SearchRequest("remote");
int size = randomIntBetween(5, 10);
request.setBatchedReduceSize(bufferSize);
InitialSearchPhase.ArraySearchPhaseResults<SearchPhaseResult> consumer =
Expand All @@ -596,6 +619,7 @@ public void testConsumerFieldCollapsing() {
consumer.consumeResult(result);
}
SearchPhaseController.ReducedQueryPhase reduce = consumer.reduce();
assertFinalReduction(request);
assertEquals(3, reduce.sortedTopDocs.scoreDocs.length);
assertEquals(expectedNumResults, reduce.totalHits.value);
assertEquals(a, ((FieldDoc)reduce.sortedTopDocs.scoreDocs[0]).fields[0]);
Expand Down