diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java index 59f48bd7fbaba..07edf487af670 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregator.java @@ -33,6 +33,7 @@ import org.opensearch.search.aggregations.InternalAggregation; import org.opensearch.search.aggregations.InternalOrder; import org.opensearch.search.aggregations.LeafBucketCollector; +import org.opensearch.search.aggregations.bucket.BucketsAggregator; import org.opensearch.search.aggregations.bucket.DeferableBucketAggregator; import org.opensearch.search.aggregations.bucket.LocalBucketCountThresholds; import org.opensearch.search.aggregations.support.AggregationPath; @@ -215,19 +216,11 @@ public InternalAggregation buildEmptyAggregation() { @Override protected LeafBucketCollector getLeafCollector(LeafReaderContext ctx, LeafBucketCollector sub) throws IOException { - MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx); + MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, this, sub); return new LeafBucketCollector() { @Override public void collect(int doc, long owningBucketOrd) throws IOException { - for (BytesRef compositeKey : collector.apply(doc)) { - long bucketOrd = bucketOrds.add(owningBucketOrd, compositeKey); - if (bucketOrd < 0) { - bucketOrd = -1 - bucketOrd; - collectExistingBucket(sub, doc, bucketOrd); - } else { - collectBucket(sub, doc, bucketOrd); - } - } + collector.apply(doc, owningBucketOrd); } }; } @@ -268,12 +261,10 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept } // we need to fill-in the blanks for (LeafReaderContext ctx : context.searcher().getTopReaderContext().leaves()) { - MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx); // brute force + MultiTermsValuesSourceCollector collector = multiTermsValue.getValues(ctx, bucketOrds, null, null); for (int docId = 0; docId < ctx.reader().maxDoc(); ++docId) { - for (BytesRef compositeKey : collector.apply(docId)) { - bucketOrds.add(owningBucketOrd, compositeKey); - } + collector.apply(docId, owningBucketOrd); } } } @@ -284,10 +275,11 @@ private void collectZeroDocEntriesIfNeeded(long owningBucketOrd) throws IOExcept @FunctionalInterface interface MultiTermsValuesSourceCollector { /** - * Collect a list values of multi_terms on each doc. - * Each terms could have multi_values, so the result is the cartesian product of each term's values. + * Generates the cartesian product of all fields used in aggregation and + * collects them in buckets using the composite key of their field values. */ - List apply(int doc) throws IOException; + void apply(int doc, long owningBucketOrd) throws IOException; + } @FunctionalInterface @@ -361,47 +353,72 @@ public MultiTermsValuesSource(List valuesSources) { this.valuesSources = valuesSources; } - public MultiTermsValuesSourceCollector getValues(LeafReaderContext ctx) throws IOException { + public MultiTermsValuesSourceCollector getValues( + LeafReaderContext ctx, + BytesKeyedBucketOrds bucketOrds, + BucketsAggregator aggregator, + LeafBucketCollector sub + ) throws IOException { List collectors = new ArrayList<>(); for (InternalValuesSource valuesSource : valuesSources) { collectors.add(valuesSource.apply(ctx)); } + boolean collectBucketOrds = aggregator != null && sub != null; return new MultiTermsValuesSourceCollector() { + + /** + * This method does the following :
+ *
  • Fetches the values of every field present in the doc List>> via @{@link InternalValuesSourceCollector}
  • + *
  • Generates Composite keys from the fetched values for all fields present in the aggregation.
  • + *
  • Adds every composite key to the @{@link BytesKeyedBucketOrds} and Optionally collects them via @{@link BucketsAggregator#collectBucket(LeafBucketCollector, int, long)}
  • + */ @Override - public List apply(int doc) throws IOException { + public void apply(int doc, long owningBucketOrd) throws IOException { + // TODO A new list creation can be avoided for every doc. List>> collectedValues = new ArrayList<>(); for (InternalValuesSourceCollector collector : collectors) { collectedValues.add(collector.apply(doc)); } - List result = new ArrayList<>(); scratch.seek(0); scratch.writeVInt(collectors.size()); // number of fields per composite key - cartesianProduct(result, scratch, collectedValues, 0); - return result; + generateAndCollectCompositeKeys(collectedValues, 0, owningBucketOrd, doc); } /** - * Cartesian product using depth first search. - * - *

    - * Composite keys are encoded to a {@link BytesRef} in a format compatible with {@link StreamOutput::writeGenericValue}, - * but reuses the encoding of the shared prefixes from the previous levels to avoid wasteful work. + * This generates and collects all Composite keys in their buckets by performing a cartesian product
    + * of all the values in all the fields ( used in agg ) for the given doc recursively. + * @param collectedValues : Values of all fields present in the aggregation for the @doc + * @param index : Points to the field being added to generate the composite key */ - private void cartesianProduct( - List compositeKeys, - BytesStreamOutput scratch, + private void generateAndCollectCompositeKeys( List>> collectedValues, - int index + int index, + long owningBucketOrd, + int doc ) throws IOException { if (collectedValues.size() == index) { - compositeKeys.add(BytesRef.deepCopyOf(scratch.bytes().toBytesRef())); + // Avoid performing a deep copy of the composite key by inlining. + long bucketOrd = bucketOrds.add(owningBucketOrd, scratch.bytes().toBytesRef()); + if (collectBucketOrds) { + if (bucketOrd < 0) { + bucketOrd = -1 - bucketOrd; + aggregator.collectExistingBucket(sub, doc, bucketOrd); + } else { + aggregator.collectBucket(sub, doc, bucketOrd); + } + } return; } long position = scratch.position(); - for (TermValue value : collectedValues.get(index)) { + List> values = collectedValues.get(index); + int numIterations = values.size(); + // For each loop is not done to reduce the allocations done for Iterator objects + // once for every field in every doc. + for (int i = 0; i < numIterations; i++) { + TermValue value = values.get(i); value.writeTo(scratch); // encode the value - cartesianProduct(compositeKeys, scratch, collectedValues, index + 1); // dfs + generateAndCollectCompositeKeys(collectedValues, index + 1, owningBucketOrd, doc); // dfs scratch.seek(position); // backtrack } } @@ -441,9 +458,14 @@ static InternalValuesSource bytesValuesSource(ValuesSource valuesSource, Include if (i > 0 && bytes.equals(previous)) { continue; } - BytesRef copy = BytesRef.deepCopyOf(bytes); - termValues.add(TermValue.of(copy)); - previous = copy; + // Performing a deep copy is not required for field containing only one value. + if (valuesCount > 1) { + BytesRef copy = BytesRef.deepCopyOf(bytes); + termValues.add(TermValue.of(copy)); + previous = copy; + } else { + termValues.add(TermValue.of(bytes)); + } } return termValues; }; diff --git a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregatorTests.java index d550c4c354c0f..bb46c5607a4a7 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregatorTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/bucket/terms/MultiTermsAggregatorTests.java @@ -126,6 +126,19 @@ public class MultiTermsAggregatorTests extends AggregatorTestCase { private static final Consumer NONE_DECORATOR = null; + private static final Consumer IP_AND_KEYWORD_DESC_ORDER_VERIFY = h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo("192.168.0.0"))); + MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("a|192.168.0.0")); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("b"), equalTo("192.168.0.1"))); + MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("b|192.168.0.1")); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo("192.168.0.2"))); + MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("c|192.168.0.2")); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); + }; + @Override protected List getSupportedValuesSourceTypes() { return Collections.unmodifiableList( @@ -672,8 +685,48 @@ public void testDatesFieldFormat() throws IOException { ); } - public void testIpAndKeyword() throws IOException { - testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, IP_FIELD)), NONE_DECORATOR, iw -> { + public void testIpAndKeywordDefaultDescOrder() throws IOException { + ipAndKeywordTest(NONE_DECORATOR, IP_AND_KEYWORD_DESC_ORDER_VERIFY); + } + + public void testIpAndKeywordWithBucketCountSameAsSize() throws IOException { + ipAndKeywordTest(multiTermsAggregationBuilder -> { + multiTermsAggregationBuilder.minDocCount(0); + multiTermsAggregationBuilder.size(3); + multiTermsAggregationBuilder.order(BucketOrder.compound(BucketOrder.count(false))); + }, IP_AND_KEYWORD_DESC_ORDER_VERIFY); + } + + public void testIpAndKeywordWithBucketCountGreaterThanSize() throws IOException { + ipAndKeywordTest(multiTermsAggregationBuilder -> { + multiTermsAggregationBuilder.minDocCount(0); + multiTermsAggregationBuilder.size(10); + multiTermsAggregationBuilder.order(BucketOrder.compound(BucketOrder.count(false))); + }, IP_AND_KEYWORD_DESC_ORDER_VERIFY); + } + + public void testIpAndKeywordAscOrder() throws IOException { + ipAndKeywordTest(multiTermsAggregationBuilder -> { + multiTermsAggregationBuilder.minDocCount(0); + multiTermsAggregationBuilder.size(3); + multiTermsAggregationBuilder.order(BucketOrder.compound(BucketOrder.count(true))); + }, h -> { + MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); + MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("b"), equalTo("192.168.0.1"))); + MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("b|192.168.0.1")); + MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("c"), equalTo("192.168.0.2"))); + MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("c|192.168.0.2")); + MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); + MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("a"), equalTo("192.168.0.0"))); + MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("a|192.168.0.0")); + MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(2L)); + }); + } + + private void ipAndKeywordTest(Consumer builderDecorator, Consumer verify) + throws IOException { + testAggregation(new MatchAllDocsQuery(), fieldConfigs(asList(KEYWORD_FIELD, IP_FIELD)), builderDecorator, iw -> { iw.addDocument( asList( new SortedDocValuesField(KEYWORD_FIELD, new BytesRef("a")), @@ -698,18 +751,7 @@ public void testIpAndKeyword() throws IOException { new SortedDocValuesField(IP_FIELD, new BytesRef(InetAddressPoint.encode(InetAddresses.forString("192.168.0.0")))) ) ); - }, h -> { - MatcherAssert.assertThat(h.getBuckets(), hasSize(3)); - MatcherAssert.assertThat(h.getBuckets().get(0).getKey(), contains(equalTo("a"), equalTo("192.168.0.0"))); - MatcherAssert.assertThat(h.getBuckets().get(0).getKeyAsString(), equalTo("a|192.168.0.0")); - MatcherAssert.assertThat(h.getBuckets().get(0).getDocCount(), equalTo(2L)); - MatcherAssert.assertThat(h.getBuckets().get(1).getKey(), contains(equalTo("b"), equalTo("192.168.0.1"))); - MatcherAssert.assertThat(h.getBuckets().get(1).getKeyAsString(), equalTo("b|192.168.0.1")); - MatcherAssert.assertThat(h.getBuckets().get(1).getDocCount(), equalTo(1L)); - MatcherAssert.assertThat(h.getBuckets().get(2).getKey(), contains(equalTo("c"), equalTo("192.168.0.2"))); - MatcherAssert.assertThat(h.getBuckets().get(2).getKeyAsString(), equalTo("c|192.168.0.2")); - MatcherAssert.assertThat(h.getBuckets().get(2).getDocCount(), equalTo(1L)); - }); + }, verify); } public void testEmpty() throws IOException {