From e26dd213e7ea149352c2da91687cd3ea67d4c8f5 Mon Sep 17 00:00:00 2001 From: Ivan Brusic Date: Mon, 25 Sep 2023 15:38:14 -0700 Subject: [PATCH] Add support for show_only_intersecting Signed-off-by: Ivan Brusic --- .../70_adjacency_matrix.yml | 32 +++++++++ .../AdjacencyMatrixAggregationBuilder.java | 66 +++++++++++++++++-- .../adjacency/AdjacencyMatrixAggregator.java | 17 +++-- .../AdjacencyMatrixAggregatorFactory.java | 16 ++++- .../metrics/AdjacencyMatrixTests.java | 9 +++ 5 files changed, 127 insertions(+), 13 deletions(-) diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/70_adjacency_matrix.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/70_adjacency_matrix.yml index f8fa537ed91bf..c3e83f20d58e8 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/70_adjacency_matrix.yml +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/70_adjacency_matrix.yml @@ -125,3 +125,35 @@ setup: - match: { aggregations.conns.buckets.3.doc_count: 1 } - match: { aggregations.conns.buckets.3.key: "4" } + + +--- +"Show only intersections": + + - do: + search: + index: test + rest_total_hits_as_int: true + body: + size: 0 + aggs: + conns: + adjacency_matrix: + show_only_intersecting: true + filters: + 1: + term: + num: 1 + 2: + term: + num: 2 + 4: + term: + num: 4 + + - match: { hits.total: 3 } + + - length: { aggregations.conns.buckets: 1 } + + - match: { aggregations.conns.buckets.0.doc_count: 1 } + - match: { aggregations.conns.buckets.0.key: "1&2" } diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregationBuilder.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregationBuilder.java index 743d0023364fa..b3d14f3763dca 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregationBuilder.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregationBuilder.java @@ -71,7 +71,11 @@ public class AdjacencyMatrixAggregationBuilder extends AbstractAggregationBuilde private static final ParseField SEPARATOR_FIELD = new ParseField("separator"); private static final ParseField FILTERS_FIELD = new ParseField("filters"); + + public static final ParseField SHOW_ONLY_INTERSECTING = new ParseField("show_only_intersecting"); + private List filters; + private boolean showOnlyIntersecting = false; private String separator = DEFAULT_SEPARATOR; private static final ObjectParser PARSER = ObjectParser.fromBuilder( @@ -81,6 +85,10 @@ public class AdjacencyMatrixAggregationBuilder extends AbstractAggregationBuilde static { PARSER.declareString(AdjacencyMatrixAggregationBuilder::separator, SEPARATOR_FIELD); PARSER.declareNamedObjects(AdjacencyMatrixAggregationBuilder::setFiltersAsList, KeyedFilter.PARSER, FILTERS_FIELD); + PARSER.declareBoolean( + AdjacencyMatrixAggregationBuilder::showOnlyIntersecting, + AdjacencyMatrixAggregationBuilder.SHOW_ONLY_INTERSECTING + ); } public static AggregationBuilder parse(XContentParser parser, String name) throws IOException { @@ -115,6 +123,7 @@ protected AdjacencyMatrixAggregationBuilder( super(clone, factoriesBuilder, metadata); this.filters = new ArrayList<>(clone.filters); this.separator = clone.separator; + this.showOnlyIntersecting = clone.showOnlyIntersecting; } @Override @@ -138,6 +147,28 @@ public AdjacencyMatrixAggregationBuilder(String name, String separator, Map filters, + boolean showOnlyIntersecting + ) { + this(name, separator, filters); + this.showOnlyIntersecting = showOnlyIntersecting; + } + /** * Read from a stream. */ @@ -145,6 +176,7 @@ public AdjacencyMatrixAggregationBuilder(StreamInput in) throws IOException { super(in); int filtersSize = in.readVInt(); separator = in.readString(); + showOnlyIntersecting = in.readBoolean(); filters = new ArrayList<>(filtersSize); for (int i = 0; i < filtersSize; i++) { filters.add(new KeyedFilter(in)); @@ -155,6 +187,7 @@ public AdjacencyMatrixAggregationBuilder(StreamInput in) throws IOException { protected void doWriteTo(StreamOutput out) throws IOException { out.writeVInt(filters.size()); out.writeString(separator); + out.writeBoolean(showOnlyIntersecting); for (KeyedFilter keyedFilter : filters) { keyedFilter.writeTo(out); } @@ -185,6 +218,11 @@ private AdjacencyMatrixAggregationBuilder setFiltersAsList(List fil return this; } + public AdjacencyMatrixAggregationBuilder showOnlyIntersecting(boolean showOnlyIntersecting) { + this.showOnlyIntersecting = showOnlyIntersecting; + return this; + } + /** * Set the separator used to join pairs of bucket keys */ @@ -214,6 +252,10 @@ public Map filters() { return result; } + public boolean isShowOnlyIntersecting() { + return showOnlyIntersecting; + } + @Override protected AdjacencyMatrixAggregationBuilder doRewrite(QueryRewriteContext queryShardContext) throws IOException { boolean modified = false; @@ -224,7 +266,9 @@ protected AdjacencyMatrixAggregationBuilder doRewrite(QueryRewriteContext queryS rewrittenFilters.add(new KeyedFilter(kf.key(), rewritten)); } if (modified) { - return new AdjacencyMatrixAggregationBuilder(name).separator(separator).setFiltersAsList(rewrittenFilters); + return new AdjacencyMatrixAggregationBuilder(name).separator(separator) + .setFiltersAsList(rewrittenFilters) + .showOnlyIntersecting(showOnlyIntersecting); } return this; } @@ -245,7 +289,16 @@ protected AggregatorFactory doBuild(QueryShardContext queryShardContext, Aggrega + "] index level setting." ); } - return new AdjacencyMatrixAggregatorFactory(name, filters, separator, queryShardContext, parent, subFactoriesBuilder, metadata); + return new AdjacencyMatrixAggregatorFactory( + name, + filters, + showOnlyIntersecting, + separator, + queryShardContext, + parent, + subFactoriesBuilder, + metadata + ); } @Override @@ -257,7 +310,8 @@ public BucketCardinality bucketCardinality() { protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(SEPARATOR_FIELD.getPreferredName(), separator); - builder.startObject(AdjacencyMatrixAggregator.FILTERS_FIELD.getPreferredName()); + builder.field(SHOW_ONLY_INTERSECTING.getPreferredName(), showOnlyIntersecting); + builder.startObject(FILTERS_FIELD.getPreferredName()); for (KeyedFilter keyedFilter : filters) { builder.field(keyedFilter.key(), keyedFilter.filter()); } @@ -268,7 +322,7 @@ protected XContentBuilder internalXContent(XContentBuilder builder, Params param @Override public int hashCode() { - return Objects.hash(super.hashCode(), filters, separator); + return Objects.hash(super.hashCode(), filters, showOnlyIntersecting, separator); } @Override @@ -277,7 +331,9 @@ public boolean equals(Object obj) { if (obj == null || getClass() != obj.getClass()) return false; if (super.equals(obj) == false) return false; AdjacencyMatrixAggregationBuilder other = (AdjacencyMatrixAggregationBuilder) obj; - return Objects.equals(filters, other.filters) && Objects.equals(separator, other.separator); + return Objects.equals(filters, other.filters) + && Objects.equals(separator, other.separator) + && Objects.equals(showOnlyIntersecting, other.showOnlyIntersecting); } @Override diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java index ef1795f425240..323c1cd4ec241 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregator.java @@ -36,7 +36,6 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.util.Bits; import org.opensearch.common.lucene.Lucene; -import org.opensearch.core.ParseField; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; @@ -70,8 +69,6 @@ */ public class AdjacencyMatrixAggregator extends BucketsAggregator { - public static final ParseField FILTERS_FIELD = new ParseField("filters"); - /** * A keyed filter * @@ -145,6 +142,8 @@ public boolean equals(Object obj) { private final String[] keys; private final Weight[] filters; + + private final boolean showOnlyIntersecting; private final int totalNumKeys; private final int totalNumIntersections; private final String separator; @@ -155,6 +154,7 @@ public AdjacencyMatrixAggregator( String separator, String[] keys, Weight[] filters, + boolean showOnlyIntersecting, SearchContext context, Aggregator parent, Map metadata @@ -163,6 +163,7 @@ public AdjacencyMatrixAggregator( this.separator = separator; this.keys = keys; this.filters = filters; + this.showOnlyIntersecting = showOnlyIntersecting; this.totalNumIntersections = ((keys.length * keys.length) - keys.length) / 2; this.totalNumKeys = keys.length + totalNumIntersections; } @@ -177,10 +178,12 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, final LeafBuc return new LeafBucketCollectorBase(sub, null) { @Override public void collect(int doc, long bucket) throws IOException { - // Check each of the provided filters - for (int i = 0; i < bits.length; i++) { - if (bits[i].get(doc)) { - collectBucket(sub, doc, bucketOrd(bucket, i)); + if (!showOnlyIntersecting) { + // Check each of the provided filters + for (int i = 0; i < bits.length; i++) { + if (bits[i].get(doc)) { + collectBucket(sub, doc, bucketOrd(bucket, i)); + } } } // Check all the possible intersections of the provided filters diff --git a/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregatorFactory.java b/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregatorFactory.java index 99ffb563ba2a8..bae86f3fcdfc1 100644 --- a/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregatorFactory.java +++ b/server/src/main/java/org/opensearch/search/aggregations/bucket/adjacency/AdjacencyMatrixAggregatorFactory.java @@ -57,11 +57,14 @@ public class AdjacencyMatrixAggregatorFactory extends AggregatorFactory { private final String[] keys; private final Weight[] weights; + + private final boolean showOnlyIntersecting; private final String separator; public AdjacencyMatrixAggregatorFactory( String name, List filters, + boolean showOnlyIntersecting, String separator, QueryShardContext queryShardContext, AggregatorFactory parent, @@ -79,6 +82,7 @@ public AdjacencyMatrixAggregatorFactory( Query filter = keyedFilter.filter().toQuery(queryShardContext); weights[i] = contextSearcher.createWeight(contextSearcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1f); } + this.showOnlyIntersecting = showOnlyIntersecting; } @Override @@ -88,7 +92,17 @@ public Aggregator createInternal( CardinalityUpperBound cardinality, Map metadata ) throws IOException { - return new AdjacencyMatrixAggregator(name, factories, separator, keys, weights, searchContext, parent, metadata); + return new AdjacencyMatrixAggregator( + name, + factories, + separator, + keys, + weights, + showOnlyIntersecting, + searchContext, + parent, + metadata + ); } @Override diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/AdjacencyMatrixTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/AdjacencyMatrixTests.java index c5cf56f6caff7..84ab691a91e76 100644 --- a/server/src/test/java/org/opensearch/search/aggregations/metrics/AdjacencyMatrixTests.java +++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/AdjacencyMatrixTests.java @@ -68,4 +68,13 @@ public void testFiltersSameMap() { assertEquals(original, builder.filters()); assert original != builder.filters(); } + + public void testShowOnlyIntersecting() { + Map original = new HashMap<>(); + original.put("bbb", new MatchNoneQueryBuilder()); + original.put("aaa", new MatchNoneQueryBuilder()); + AdjacencyMatrixAggregationBuilder builder; + builder = new AdjacencyMatrixAggregationBuilder("my-agg", "&", original, true); + assertEquals(true, builder.isShowOnlyIntersecting()); + } }