From 21ace7a17b7abd98ab6aebefa3d4857b5f0e5a0d Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Tue, 22 May 2018 20:08:16 +0000 Subject: [PATCH 01/11] Add WeightedAvg aggregation --- .../stats/MatrixStatsAggregationBuilder.java | 4 +- .../matrix/stats/MatrixStatsAggregator.java | 6 +- .../stats/MatrixStatsAggregatorFactory.java | 4 +- .../matrix/stats/MatrixStatsParser.java | 4 +- ...luesSource.java => ArrayValuesSource.java} | 14 +- ... ArrayValuesSourceAggregationBuilder.java} | 32 +- ...> ArrayValuesSourceAggregatorFactory.java} | 12 +- ...rser.java => ArrayValuesSourceParser.java} | 38 +-- .../search.aggregation/260_weighted_avg.yml | 71 +++++ .../elasticsearch/search/MultiValueMode.java | 52 ++++ .../elasticsearch/search/SearchModule.java | 4 + .../aggregations/AggregationBuilders.java | 8 + .../weighted_avg/InternalWeightedAvg.java | 133 +++++++++ .../weighted_avg/ParsedWeightedAvg.java | 64 ++++ .../metrics/weighted_avg/WeightedAvg.java | 32 ++ .../WeightedAvgAggregationBuilder.java | 128 ++++++++ .../weighted_avg/WeightedAvgAggregator.java | 140 +++++++++ .../WeightedAvgAggregatorFactory.java | 63 ++++ .../support/MultiValuesSource.java | 125 ++++++++ .../MultiValuesSourceAggregationBuilder.java | 278 ++++++++++++++++++ .../MultiValuesSourceAggregatorFactory.java | 65 ++++ .../support/MultiValuesSourceConfig.java | 37 +++ .../support/MultiValuesSourceFieldConfig.java | 144 +++++++++ .../support/MultiValuesSourceParseHelper.java | 40 +++ .../aggregations/support/ValueType.java | 5 +- .../support/ValuesSourceParserHelper.java | 11 +- .../support/ValuesSourceType.java | 30 +- .../WeightedAvgAggregatorTests.java | 276 +++++++++++++++++ .../xpack/rollup/RollupRequestTranslator.java | 2 +- 29 files changed, 1756 insertions(+), 66 deletions(-) rename modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/{MultiValuesSource.java => ArrayValuesSource.java} (87%) rename 
modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/{MultiValuesSourceAggregationBuilder.java => ArrayValuesSourceAggregationBuilder.java} (92%) rename modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/{MultiValuesSourceAggregatorFactory.java => ArrayValuesSourceAggregatorFactory.java} (81%) rename modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/{MultiValuesSourceParser.java => ArrayValuesSourceParser.java} (87%) create mode 100644 rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/260_weighted_avg.yml create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvg.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java create mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java create mode 
100644 server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java create mode 100644 server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregationBuilder.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregationBuilder.java index ad8ae1a681191..75f24eb78037a 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregationBuilder.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregationBuilder.java @@ -26,7 +26,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactory; -import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder; +import org.elasticsearch.search.aggregations.support.ArrayValuesSourceAggregationBuilder; import org.elasticsearch.search.aggregations.support.ValueType; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric; @@ -38,7 +38,7 @@ import java.util.Map; public class MatrixStatsAggregationBuilder - extends MultiValuesSourceAggregationBuilder.LeafOnly { + extends ArrayValuesSourceAggregationBuilder.LeafOnly { public static final String NAME = "matrix_stats"; private MultiValueMode multiValueMode = MultiValueMode.AVG; diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java index 
578116d7b5eb2..aa19f62fedc4f 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java @@ -30,7 +30,7 @@ import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; import org.elasticsearch.search.aggregations.metrics.MetricsAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; -import org.elasticsearch.search.aggregations.support.MultiValuesSource.NumericMultiValuesSource; +import org.elasticsearch.search.aggregations.support.ArrayValuesSource.NumericArrayValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.internal.SearchContext; @@ -43,7 +43,7 @@ **/ final class MatrixStatsAggregator extends MetricsAggregator { /** Multiple ValuesSource with field names */ - private final NumericMultiValuesSource valuesSources; + private final NumericArrayValuesSource valuesSources; /** array of descriptive stats, per shard, needed to compute the correlation */ ObjectArray stats; @@ -53,7 +53,7 @@ final class MatrixStatsAggregator extends MetricsAggregator { Map metaData) throws IOException { super(name, context, parent, pipelineAggregators, metaData); if (valuesSources != null && !valuesSources.isEmpty()) { - this.valuesSources = new NumericMultiValuesSource(valuesSources, multiValueMode); + this.valuesSources = new NumericArrayValuesSource(valuesSources, multiValueMode); stats = context.bigArrays().newObjectArray(1); } else { this.valuesSources = null; diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java index 2c3ac82a0c1a8..fb456d75bb78b 100644 --- 
a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java @@ -23,7 +23,7 @@ import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; -import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory; +import org.elasticsearch.search.aggregations.support.ArrayValuesSourceAggregatorFactory; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; import org.elasticsearch.search.internal.SearchContext; @@ -33,7 +33,7 @@ import java.util.Map; final class MatrixStatsAggregatorFactory - extends MultiValuesSourceAggregatorFactory { + extends ArrayValuesSourceAggregatorFactory { private final MultiValueMode multiValueMode; diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsParser.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsParser.java index fd13037e8f922..0f48d1855ae3e 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsParser.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsParser.java @@ -21,14 +21,14 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.MultiValueMode; -import org.elasticsearch.search.aggregations.support.MultiValuesSourceParser.NumericValuesSourceParser; +import org.elasticsearch.search.aggregations.support.ArrayValuesSourceParser.NumericValuesSourceParser; import 
org.elasticsearch.search.aggregations.support.ValueType; import org.elasticsearch.search.aggregations.support.ValuesSourceType; import java.io.IOException; import java.util.Map; -import static org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder.MULTIVALUE_MODE_FIELD; +import static org.elasticsearch.search.aggregations.support.ArrayValuesSourceAggregationBuilder.MULTIVALUE_MODE_FIELD; public class MatrixStatsParser extends NumericValuesSourceParser { diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSource.java similarity index 87% rename from modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java rename to modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSource.java index 0274c1748dde5..65542ba500554 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSource.java @@ -28,13 +28,13 @@ /** * Class to encapsulate a set of ValuesSource objects labeled by field name */ -public abstract class MultiValuesSource { +public abstract class ArrayValuesSource { protected MultiValueMode multiValueMode; protected String[] names; protected VS[] values; - public static class NumericMultiValuesSource extends MultiValuesSource { - public NumericMultiValuesSource(Map valuesSources, MultiValueMode multiValueMode) { + public static class NumericArrayValuesSource extends ArrayValuesSource { + public NumericArrayValuesSource(Map valuesSources, MultiValueMode multiValueMode) { super(valuesSources, multiValueMode); if (valuesSources != null) { this.values = valuesSources.values().toArray(new ValuesSource.Numeric[0]); @@ 
-51,8 +51,8 @@ public NumericDoubleValues getField(final int ordinal, LeafReaderContext ctx) th } } - public static class BytesMultiValuesSource extends MultiValuesSource { - public BytesMultiValuesSource(Map valuesSources, MultiValueMode multiValueMode) { + public static class BytesArrayValuesSource extends ArrayValuesSource { + public BytesArrayValuesSource(Map valuesSources, MultiValueMode multiValueMode) { super(valuesSources, multiValueMode); this.values = valuesSources.values().toArray(new ValuesSource.Bytes[0]); } @@ -62,14 +62,14 @@ public Object getField(final int ordinal, LeafReaderContext ctx) throws IOExcept } } - public static class GeoPointValuesSource extends MultiValuesSource { + public static class GeoPointValuesSource extends ArrayValuesSource { public GeoPointValuesSource(Map valuesSources, MultiValueMode multiValueMode) { super(valuesSources, multiValueMode); this.values = valuesSources.values().toArray(new ValuesSource.GeoPoint[0]); } } - private MultiValuesSource(Map valuesSources, MultiValueMode multiValueMode) { + private ArrayValuesSource(Map valuesSources, MultiValueMode multiValueMode) { if (valuesSources != null) { this.names = valuesSources.keySet().toArray(new String[0]); } diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java similarity index 92% rename from modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java rename to modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java index 4cf497c9c02a5..39f5885e7c79c 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java +++ 
b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java @@ -44,13 +44,13 @@ import java.util.Map; import java.util.Objects; -public abstract class MultiValuesSourceAggregationBuilder> - extends AbstractAggregationBuilder { +public abstract class ArrayValuesSourceAggregationBuilder> + extends AbstractAggregationBuilder { public static final ParseField MULTIVALUE_MODE_FIELD = new ParseField("mode"); - public abstract static class LeafOnly> - extends MultiValuesSourceAggregationBuilder { + public abstract static class LeafOnly> + extends ArrayValuesSourceAggregationBuilder { protected LeafOnly(String name, ValuesSourceType valuesSourceType, ValueType targetValueType) { super(name, valuesSourceType, targetValueType); @@ -94,7 +94,7 @@ public AB subAggregations(Builder subFactories) { private Object missing = null; private Map missingMap = Collections.emptyMap(); - protected MultiValuesSourceAggregationBuilder(String name, ValuesSourceType valuesSourceType, ValueType targetValueType) { + protected ArrayValuesSourceAggregationBuilder(String name, ValuesSourceType valuesSourceType, ValueType targetValueType) { super(name); if (valuesSourceType == null) { throw new IllegalArgumentException("[valuesSourceType] must not be null: [" + name + "]"); @@ -103,7 +103,7 @@ protected MultiValuesSourceAggregationBuilder(String name, ValuesSourceType valu this.targetValueType = targetValueType; } - protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilder clone, + protected ArrayValuesSourceAggregationBuilder(ArrayValuesSourceAggregationBuilder clone, Builder factoriesBuilder, Map metaData) { super(clone, factoriesBuilder, metaData); this.valuesSourceType = clone.valuesSourceType; @@ -115,7 +115,7 @@ protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilde this.missing = clone.missing; } - protected MultiValuesSourceAggregationBuilder(StreamInput in, 
ValuesSourceType valuesSourceType, ValueType targetValueType) + protected ArrayValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType, ValueType targetValueType) throws IOException { super(in); assert false == serializeTargetValueType() : "Wrong read constructor called for subclass that provides its targetValueType"; @@ -124,7 +124,7 @@ protected MultiValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType v read(in); } - protected MultiValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType) throws IOException { + protected ArrayValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType) throws IOException { super(in); assert serializeTargetValueType() : "Wrong read constructor called for subclass that serializes its targetValueType"; this.valuesSourceType = valuesSourceType; @@ -239,10 +239,10 @@ public Map missingMap() { } @Override - protected final MultiValuesSourceAggregatorFactory doBuild(SearchContext context, AggregatorFactory parent, - AggregatorFactories.Builder subFactoriesBuilder) throws IOException { + protected final ArrayValuesSourceAggregatorFactory doBuild(SearchContext context, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException { Map> configs = resolveConfig(context); - MultiValuesSourceAggregatorFactory factory = innerBuild(context, configs, parent, subFactoriesBuilder); + ArrayValuesSourceAggregatorFactory factory = innerBuild(context, configs, parent, subFactoriesBuilder); return factory; } @@ -255,9 +255,9 @@ protected Map> resolveConfig(SearchContext contex return configs; } - protected abstract MultiValuesSourceAggregatorFactory innerBuild(SearchContext context, - Map> configs, AggregatorFactory parent, - AggregatorFactories.Builder subFactoriesBuilder) throws IOException; + protected abstract ArrayValuesSourceAggregatorFactory innerBuild(SearchContext context, + Map> configs, AggregatorFactory parent, + 
AggregatorFactories.Builder subFactoriesBuilder) throws IOException; public ValuesSourceConfig config(SearchContext context, String field, Script script) { @@ -355,14 +355,14 @@ public final XContentBuilder internalXContent(XContentBuilder builder, Params pa @Override protected final int doHashCode() { return Objects.hash(fields, format, missing, targetValueType, valueType, valuesSourceType, - innerHashCode()); + innerHashCode()); } protected abstract int innerHashCode(); @Override protected final boolean doEquals(Object obj) { - MultiValuesSourceAggregationBuilder other = (MultiValuesSourceAggregationBuilder) obj; + ArrayValuesSourceAggregationBuilder other = (ArrayValuesSourceAggregationBuilder) obj; if (!Objects.equals(fields, other.fields)) return false; if (!Objects.equals(format, other.format)) diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java similarity index 81% rename from modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java rename to modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java index 7d5c56a571bbe..cd6f0eb2b06bb 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java @@ -30,14 +30,14 @@ import java.util.List; import java.util.Map; -public abstract class MultiValuesSourceAggregatorFactory> - extends AggregatorFactory { +public abstract class ArrayValuesSourceAggregatorFactory> + extends AggregatorFactory { protected Map> configs; - public MultiValuesSourceAggregatorFactory(String 
name, Map> configs, - SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, - Map metaData) throws IOException { + public ArrayValuesSourceAggregatorFactory(String name, Map> configs, + SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, + Map metaData) throws IOException { super(name, context, parent, subFactoriesBuilder, metaData); this.configs = configs; } @@ -63,6 +63,6 @@ protected abstract Aggregator createUnmapped(Aggregator parent, List metaData) throws IOException; protected abstract Aggregator doCreateInternal(Map valuesSources, Aggregator parent, boolean collectsFromSingleBucket, - List pipelineAggregators, Map metaData) throws IOException; + List pipelineAggregators, Map metaData) throws IOException; } diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParser.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java similarity index 87% rename from modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParser.java rename to modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java index 22a90b552d920..c2857411c0b39 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParser.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java @@ -33,30 +33,30 @@ import java.util.List; import java.util.Map; -public abstract class MultiValuesSourceParser implements Aggregator.Parser { +public abstract class ArrayValuesSourceParser implements Aggregator.Parser { - public abstract static class AnyValuesSourceParser extends MultiValuesSourceParser { + public abstract static class AnyValuesSourceParser extends 
ArrayValuesSourceParser { protected AnyValuesSourceParser(boolean formattable) { super(formattable, ValuesSourceType.ANY, null); } } - public abstract static class NumericValuesSourceParser extends MultiValuesSourceParser { + public abstract static class NumericValuesSourceParser extends ArrayValuesSourceParser { protected NumericValuesSourceParser(boolean formattable) { super(formattable, ValuesSourceType.NUMERIC, ValueType.NUMERIC); } } - public abstract static class BytesValuesSourceParser extends MultiValuesSourceParser { + public abstract static class BytesValuesSourceParser extends ArrayValuesSourceParser { protected BytesValuesSourceParser(boolean formattable) { super(formattable, ValuesSourceType.BYTES, ValueType.STRING); } } - public abstract static class GeoPointValuesSourceParser extends MultiValuesSourceParser { + public abstract static class GeoPointValuesSourceParser extends ArrayValuesSourceParser { protected GeoPointValuesSourceParser(boolean formattable) { super(formattable, ValuesSourceType.GEOPOINT, ValueType.GEOPOINT); @@ -67,15 +67,15 @@ protected GeoPointValuesSourceParser(boolean formattable) { private ValuesSourceType valuesSourceType = null; private ValueType targetValueType = null; - private MultiValuesSourceParser(boolean formattable, ValuesSourceType valuesSourceType, ValueType targetValueType) { + private ArrayValuesSourceParser(boolean formattable, ValuesSourceType valuesSourceType, ValueType targetValueType) { this.valuesSourceType = valuesSourceType; this.targetValueType = targetValueType; this.formattable = formattable; } @Override - public final MultiValuesSourceAggregationBuilder parse(String aggregationName, XContentParser parser) - throws IOException { + public final ArrayValuesSourceAggregationBuilder parse(String aggregationName, XContentParser parser) + throws IOException { List fields = null; ValueType valueType = null; @@ -98,7 +98,7 @@ private MultiValuesSourceParser(boolean formattable, ValuesSourceType valuesSour 
"Multi-field aggregations do not support scripts."); } else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } else if (token == XContentParser.Token.START_OBJECT) { if (CommonFields.MISSING.match(currentFieldName, parser.getDeprecationHandler())) { @@ -113,7 +113,7 @@ private MultiValuesSourceParser(boolean formattable, ValuesSourceType valuesSour } else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } else if (token == XContentParser.Token.START_ARRAY) { if (Script.SCRIPT_PARSE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { @@ -127,21 +127,21 @@ private MultiValuesSourceParser(boolean formattable, ValuesSourceType valuesSour fields.add(parser.text()); } else { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } } else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + 
currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } - MultiValuesSourceAggregationBuilder factory = createFactory(aggregationName, this.valuesSourceType, this.targetValueType, - otherOptions); + ArrayValuesSourceAggregationBuilder factory = createFactory(aggregationName, this.valuesSourceType, this.targetValueType, + otherOptions); if (fields != null) { factory.fields(fields); } @@ -182,7 +182,7 @@ private void parseMissingAndAdd(final String aggregationName, final String curre /** * Creates a {@link ValuesSourceAggregationBuilder} from the information * gathered by the subclass. Options parsed in - * {@link MultiValuesSourceParser} itself will be added to the factory + * {@link ArrayValuesSourceParser} itself will be added to the factory * after it has been returned by this method. * * @param aggregationName @@ -197,11 +197,11 @@ private void parseMissingAndAdd(final String aggregationName, final String curre * method * @return the created factory */ - protected abstract MultiValuesSourceAggregationBuilder createFactory(String aggregationName, ValuesSourceType valuesSourceType, - ValueType targetValueType, Map otherOptions); + protected abstract ArrayValuesSourceAggregationBuilder createFactory(String aggregationName, ValuesSourceType valuesSourceType, + ValueType targetValueType, Map otherOptions); /** - * Allows subclasses of {@link MultiValuesSourceParser} to parse extra + * Allows subclasses of {@link ArrayValuesSourceParser} to parse extra * parameters and store them in a {@link Map} which will later be passed to * {@link #createFactory(String, ValuesSourceType, ValueType, Map)}. 
* diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/260_weighted_avg.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/260_weighted_avg.yml new file mode 100644 index 0000000000000..bccc961923394 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/260_weighted_avg.yml @@ -0,0 +1,71 @@ +setup: + - do: + indices.create: + index: test_1 + body: + settings: + number_of_replicas: 0 + mappings: + doc: + properties: + int_field: + type : integer + double_field: + type : double + string_field: + type: keyword + + - do: + bulk: + refresh: true + body: + - index: + _index: test_1 + _type: doc + _id: 1 + - int_field: 1 + double_field: 1.0 + - index: + _index: test_1 + _type: doc + _id: 2 + - int_field: 2 + double_field: 2.0 + - index: + _index: test_1 + _type: doc + _id: 3 + - int_field: 3 + double_field: 3.0 + - index: + _index: test_1 + _type: doc + _id: 4 + - int_field: 4 + double_field: 4.0 + +--- +"Basic test": + + - do: + search: + body: + aggs: + the_int_avg: + weighted_avg: + value: + field: "int_field" + weight: + field: "int_field" + the_double_avg: + weighted_avg: + value: + field: "double_field" + weight: + field: "double_field" + + - match: { hits.total: 4 } + - length: { hits.hits: 4 } + - match: { aggregations.the_int_avg.value: 3.0 } + - match: { aggregations.the_double_avg.value: 3.0 } + diff --git a/server/src/main/java/org/elasticsearch/search/MultiValueMode.java b/server/src/main/java/org/elasticsearch/search/MultiValueMode.java index b2ee4b8ffbd5f..7c2a192588c32 100644 --- a/server/src/main/java/org/elasticsearch/search/MultiValueMode.java +++ b/server/src/main/java/org/elasticsearch/search/MultiValueMode.java @@ -565,6 +565,58 @@ public double doubleValue() throws IOException { } } + /** + * Return a {@link NumericDoubleValues} instance that can be used to sort documents + * with this mode and the provided values. 
When a document has no value, + * advanceExact returns false and doubleValue() will return Double.NaN (although the + * double value should not be used). + * + * Allowed Modes: SUM, AVG, MEDIAN, MIN, MAX + */ + public NumericDoubleValues select(final SortedNumericDoubleValues values) { + final NumericDoubleValues singleton = FieldData.unwrapSingleton(values); + if (singleton != null) { + return new NumericDoubleValues() { + private double value; + + @Override + public boolean advanceExact(int doc) throws IOException { + if (singleton.advanceExact(doc)) { + this.value = singleton.doubleValue(); + return true; + } + this.value = Double.NaN; + return false; + } + + @Override + public double doubleValue() throws IOException { + return this.value; + } + }; + } else { + return new NumericDoubleValues() { + + private double value; + + @Override + public boolean advanceExact(int target) throws IOException { + if (values.advanceExact(target)) { + this.value = pick(values); + return true; + } + this.value = Double.NaN; + return false; + } + + @Override + public double doubleValue() throws IOException { + return this.value; + } + }; + } + } + protected double pick(SortedNumericDoubleValues values) throws IOException { throw new IllegalArgumentException("Unsupported sort mode: " + this); } diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 66ea407f42afd..d5c55e4128d70 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -180,6 +180,8 @@ import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.valuecount.InternalValueCount; import org.elasticsearch.search.aggregations.metrics.valuecount.ValueCountAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.weighted_avg.InternalWeightedAvg; 
+import org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.InternalSimpleValue; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.pipeline.bucketmetrics.InternalBucketMetricValue; @@ -335,6 +337,8 @@ public ParseFieldRegistry getMovingAverageModel private void registerAggregations(List plugins) { registerAggregation(new AggregationSpec(AvgAggregationBuilder.NAME, AvgAggregationBuilder::new, AvgAggregationBuilder::parse) .addResultReader(InternalAvg::new)); + registerAggregation(new AggregationSpec(WeightedAvgAggregationBuilder.NAME, WeightedAvgAggregationBuilder::new, + WeightedAvgAggregationBuilder::parse).addResultReader(InternalWeightedAvg::new)); registerAggregation(new AggregationSpec(SumAggregationBuilder.NAME, SumAggregationBuilder::new, SumAggregationBuilder::parse) .addResultReader(InternalSum::new)); registerAggregation(new AggregationSpec(MinAggregationBuilder.NAME, MinAggregationBuilder::new, MinAggregationBuilder::parse) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilders.java b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilders.java index 26d8bb1a1bdf5..b4e416f4d7789 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilders.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilders.java @@ -82,6 +82,7 @@ import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.valuecount.ValueCount; import org.elasticsearch.search.aggregations.metrics.valuecount.ValueCountAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder; import java.util.Map; @@ -107,6 +108,13 @@ public static AvgAggregationBuilder avg(String name) { return new 
AvgAggregationBuilder(name); } + /** + * Create a new weighted average aggregation with the given name; see {@link WeightedAvgAggregationBuilder}. + */ + public static WeightedAvgAggregationBuilder weightedAvg(String name) { + return new WeightedAvgAggregationBuilder(name); + } + /** * Create a new {@link Max} aggregation with the given name. */ diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java new file mode 100644 index 0000000000000..12931e4fed5fe --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java @@ -0,0 +1,133 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class InternalWeightedAvg extends InternalNumericMetricsAggregation.SingleValue implements WeightedAvg { + private final double sum; + private final double weight; + + public InternalWeightedAvg(String name, double sum, double weight, DocValueFormat format, List pipelineAggregators, + Map metaData) { + super(name, pipelineAggregators, metaData); + this.sum = sum; + this.weight = weight; + this.format = format; + } + + /** + * Read from a stream. 
+ */ + public InternalWeightedAvg(StreamInput in) throws IOException { + super(in); + format = in.readNamedWriteable(DocValueFormat.class); + sum = in.readDouble(); + weight = in.readDouble(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(format); + out.writeDouble(sum); + out.writeDouble(weight); + } + + @Override + public double value() { + return getValue(); + } + + @Override + public double getValue() { + return sum / weight; + } + + double getSum() { + return sum; + } + + double getWeight() { + return weight; + } + + DocValueFormat getFormatter() { + return format; + } + + @Override + public String getWriteableName() { + return WeightedAvgAggregationBuilder.NAME; + } + + @Override + public InternalWeightedAvg doReduce(List aggregations, ReduceContext reduceContext) { + double weight = 0; + double sum = 0; + double compensation = 0; + // Compute the sum of double values with Kahan summation algorithm which is more + // accurate than naive summation. + for (InternalAggregation aggregation : aggregations) { + InternalWeightedAvg avg = (InternalWeightedAvg) aggregation; + weight += avg.weight; + if (Double.isFinite(avg.sum) == false) { + sum += avg.sum; + } else if (Double.isFinite(sum)) { + double corrected = avg.sum - compensation; + double newSum = sum + corrected; + compensation = (newSum - sum) - corrected; + sum = newSum; + } + } + return new InternalWeightedAvg(getName(), sum, weight, format, pipelineAggregators(), getMetaData()); + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.field(CommonFields.VALUE.getPreferredName(), weight != 0 ? 
getValue() : null); + if (weight != 0 && format != DocValueFormat.RAW) { + builder.field(CommonFields.VALUE_AS_STRING.getPreferredName(), format.format(getValue())); + } + return builder; + } + + @Override + protected int doHashCode() { + return Objects.hash(sum, weight, format.getWriteableName()); + } + + @Override + protected boolean doEquals(Object obj) { + InternalWeightedAvg other = (InternalWeightedAvg) obj; + return Objects.equals(sum, other.sum) && + Objects.equals(weight, other.weight) && + Objects.equals(format.getWriteableName(), other.format.getWriteableName()); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java new file mode 100644 index 0000000000000..e558c8d6488e4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java @@ -0,0 +1,64 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.metrics.ParsedSingleValueNumericMetricsAggregation; + +import java.io.IOException; + +public class ParsedWeightedAvg extends ParsedSingleValueNumericMetricsAggregation implements WeightedAvg { + + @Override + public double getValue() { + return value(); + } + + @Override + public String getType() { + return WeightedAvgAggregationBuilder.NAME; + } + + @Override + protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + // InternalWeightedAvg renders value only if the sum of weights is not 0. + // We parse back `null` as Double.POSITIVE_INFINITY so we check for that value here to get the same xContent output + boolean hasValue = value != Double.POSITIVE_INFINITY; + builder.field(CommonFields.VALUE.getPreferredName(), hasValue ?
value : null); + if (hasValue && valueAsString != null) { + builder.field(CommonFields.VALUE_AS_STRING.getPreferredName(), valueAsString); + } + return builder; + } + + private static final ObjectParser PARSER = new ObjectParser<>(ParsedWeightedAvg.class.getSimpleName(), true, ParsedWeightedAvg::new); + + static { + declareSingleValueFields(PARSER, Double.POSITIVE_INFINITY); + } + + public static ParsedWeightedAvg fromXContent(XContentParser parser, final String name) { + ParsedWeightedAvg avg = PARSER.apply(parser, null); + avg.setName(name); + return avg; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvg.java new file mode 100644 index 0000000000000..7af48f677c1f6 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvg.java @@ -0,0 +1,32 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; + +/** + * An aggregation that computes the weighted average of the values in the current bucket. + */ +public interface WeightedAvg extends NumericMetricsAggregation.SingleValue { + + /** + * The weighted average value. + */ + double getValue(); +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java new file mode 100644 index 0000000000000..cedfcb0faeb24 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java @@ -0,0 +1,128 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregatorFactories.Builder; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceConfig; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceParseHelper; +import org.elasticsearch.search.aggregations.support.ValueType; +import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationBuilder { + public static final String NAME = "weighted_avg"; + public static final ParseField VALUE_FIELD = new ParseField("value"); + public static final ParseField WEIGHT_FIELD = new ParseField("weight"); + + private static final ObjectParser PARSER; + static { + PARSER = new ObjectParser<>(WeightedAvgAggregationBuilder.NAME); + MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC); + MultiValuesSourceParseHelper.declareField(VALUE_FIELD.getPreferredName(), 
PARSER, true, false); + MultiValuesSourceParseHelper.declareField(WEIGHT_FIELD.getPreferredName(), PARSER, true, false); + } + + public static AggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException { + return PARSER.parse(parser, new WeightedAvgAggregationBuilder(aggregationName), null); + } + + public WeightedAvgAggregationBuilder(String name) { + super(name, ValueType.NUMERIC); + } + + public WeightedAvgAggregationBuilder(WeightedAvgAggregationBuilder clone, Builder factoriesBuilder, Map metaData) { + super(clone, factoriesBuilder, metaData); + } + + public WeightedAvgAggregationBuilder value(MultiValuesSourceFieldConfig valueConfig) { + valueConfig = Objects.requireNonNull(valueConfig, "Configuration for field [" + VALUE_FIELD + "] cannot be null"); + field(VALUE_FIELD.getPreferredName(), valueConfig); + return this; + } + + public WeightedAvgAggregationBuilder weight(MultiValuesSourceFieldConfig weightConfig) { + weightConfig = Objects.requireNonNull(weightConfig, "Configuration for field [" + WEIGHT_FIELD + "] cannot be null"); + field(WEIGHT_FIELD.getPreferredName(), weightConfig); + return this; + } + + /** + * Read from a stream. 
+ */ + public WeightedAvgAggregationBuilder(StreamInput in) throws IOException { + super(in); + } + + @Override + protected AggregationBuilder shallowCopy(Builder factoriesBuilder, Map metaData) { + return new WeightedAvgAggregationBuilder(this, factoriesBuilder, metaData); + } + + @Override + protected void innerWriteTo(StreamOutput out) { + // Do nothing, no extra state to write to stream + } + + @Override + protected MultiValuesSourceAggregatorFactory innerBuild(SearchContext context, + MultiValuesSourceConfig configs, + DocValueFormat format, + AggregatorFactory parent, + Builder subFactoriesBuilder) throws IOException { + return new WeightedAvgAggregatorFactory(name, configs, format, context, parent, subFactoriesBuilder, metaData); + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder; + } + + @Override + protected int innerHashCode() { + return 0; + } + + @Override + protected boolean innerEquals(Object obj) { + return true; + } + + @Override + public String getType() { + return NAME; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java new file mode 100644 index 0000000000000..43568386296a1 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java @@ -0,0 +1,140 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.index.fielddata.NumericDoubleValues; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregator; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.aggregations.support.MultiValuesSource; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder.VALUE_FIELD; +import static org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder.WEIGHT_FIELD; + +public class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue { + + final MultiValuesSource.NumericMultiValuesSource valuesSources; + + DoubleArray weights; + DoubleArray sums; + DoubleArray compensations; + DocValueFormat format; + + public WeightedAvgAggregator(String name, 
MultiValuesSource.NumericMultiValuesSource valuesSources, DocValueFormat format, + SearchContext context, Aggregator parent, List pipelineAggregators, + Map metaData) throws IOException { + super(name, context, parent, pipelineAggregators, metaData); + this.valuesSources = valuesSources; + this.format = format; + if (valuesSources != null) { + final BigArrays bigArrays = context.bigArrays(); + weights = bigArrays.newDoubleArray(1, true); + sums = bigArrays.newDoubleArray(1, true); + compensations = bigArrays.newDoubleArray(1, true); + } + } + + @Override + public boolean needsScores() { + return valuesSources != null && valuesSources.needsScores(); + } + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, + final LeafBucketCollector sub) throws IOException { + if (valuesSources == null) { + return LeafBucketCollector.NO_OP_COLLECTOR; + } + final BigArrays bigArrays = context.bigArrays(); + final NumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx); + final NumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), 1.0, ctx); + + return new LeafBucketCollectorBase(sub, docValues) { + @Override + public void collect(int doc, long bucket) throws IOException { + weights = bigArrays.grow(weights, bucket + 1); + sums = bigArrays.grow(sums, bucket + 1); + compensations = bigArrays.grow(compensations, bucket + 1); + + if (docValues.advanceExact(doc)) { + docWeights.advanceExact(doc); + final double weight = docWeights.doubleValue(); + + weights.increment(bucket, weight); + // Compute the sum of double values with Kahan summation algorithm which is more + // accurate than naive summation. 
+ double sum = sums.get(bucket); + double compensation = compensations.get(bucket); + + final double value = docValues.doubleValue() * weight; + if (Double.isFinite(value) == false) { + sum += value; + } else if (Double.isFinite(sum)) { + double corrected = value - compensation; + double newSum = sum + corrected; + compensation = (newSum - sum) - corrected; + sum = newSum; + } + sums.set(bucket, sum); + compensations.set(bucket, compensation); + } + } + }; + } + + @Override + public double metric(long owningBucketOrd) { + if (valuesSources == null || owningBucketOrd >= sums.size()) { + return Double.NaN; + } + return sums.get(owningBucketOrd) / weights.get(owningBucketOrd); + } + + @Override + public InternalAggregation buildAggregation(long bucket) { + if (valuesSources == null || bucket >= sums.size()) { + return buildEmptyAggregation(); + } + return new InternalWeightedAvg(name, sums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData()); + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return new InternalWeightedAvg(name, 0.0, 0L, format, pipelineAggregators(), metaData()); + } + + @Override + public void doClose() { + Releasables.close(weights, sums, compensations); + } + +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java new file mode 100644 index 0000000000000..53e7d3e164dc7 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java @@ -0,0 +1,63 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.aggregations.support.MultiValuesSource; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceConfig; +import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class WeightedAvgAggregatorFactory extends MultiValuesSourceAggregatorFactory { + + public WeightedAvgAggregatorFactory(String name, MultiValuesSourceConfig configs, + DocValueFormat format, + SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, + Map metaData) throws IOException { + super(name, configs, format, context, parent, subFactoriesBuilder, metaData); + } + + @Override + protected Aggregator createUnmapped(Aggregator parent, List pipelineAggregators, Map metaData) + throws IOException { + 
return new WeightedAvgAggregator(name, null, format, context, parent, pipelineAggregators, metaData); + } + + @Override + protected Aggregator doCreateInternal(MultiValuesSourceConfig configs, DocValueFormat format, Aggregator parent, + boolean collectsFromSingleBucket, List pipelineAggregators, + Map metaData) throws IOException { + MultiValuesSource.NumericMultiValuesSource numericMultiVS + = new MultiValuesSource.NumericMultiValuesSource(configs, context.getQueryShardContext()); + if (numericMultiVS.areValuesSourcesEmpty()) { + return createUnmapped(parent, pipelineAggregators, metaData); + } + return new WeightedAvgAggregator(name, numericMultiVS, format, context, parent, pipelineAggregators, metaData); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java new file mode 100644 index 0000000000000..46f934948bfc7 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java @@ -0,0 +1,125 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.search.aggregations.support; + +import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.index.fielddata.NumericDoubleValues; +import org.elasticsearch.index.query.QueryShardContext; +import org.elasticsearch.search.MultiValueMode; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Predicate; + +/** + * Class to encapsulate a set of ValuesSource objects labeled by field name + */ +public abstract class MultiValuesSource { + + public static class Wrapper { + private MultiValueMode multiValueMode; + private VS valueSource; + + public Wrapper(MultiValueMode multiValueMode, VS value) { + this.multiValueMode = multiValueMode; + this.valueSource = value; + } + + public MultiValueMode getMultiValueMode() { + return multiValueMode; + } + + public VS getValueSource() { + return valueSource; + } + } + + protected Map> values; + + public static class NumericMultiValuesSource extends MultiValuesSource { + public NumericMultiValuesSource(MultiValuesSourceConfig valuesSourceConfigs, + QueryShardContext context) throws IOException { + values = new HashMap<>(valuesSourceConfigs.getMap().size()); + for (Map.Entry> entry : valuesSourceConfigs.getMap().entrySet()) { + values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(), + entry.getValue().getConfig().toValuesSource(context))); + } + } + + public NumericDoubleValues getField(String fieldName, LeafReaderContext ctx) throws IOException { + Wrapper wrapper = values.get(fieldName); + if (wrapper == null) { + throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource"); + } + return wrapper.getMultiValueMode().select(wrapper.getValueSource().doubleValues(ctx)); + } + + public NumericDoubleValues getField(String fieldName, double defaultValue, LeafReaderContext ctx) throws IOException { + Wrapper wrapper = values.get(fieldName); + if (wrapper == null) { + throw new 
IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource"); + } + return wrapper.getMultiValueMode().select(wrapper.getValueSource().doubleValues(ctx), defaultValue); + } + } + + public static class BytesMultiValuesSource extends MultiValuesSource { + public BytesMultiValuesSource(MultiValuesSourceConfig valuesSourceConfigs, + QueryShardContext context) throws IOException { + values = new HashMap<>(valuesSourceConfigs.getMap().size()); + for (Map.Entry> entry : valuesSourceConfigs.getMap().entrySet()) { + values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(), + entry.getValue().getConfig().toValuesSource(context))); + } + } + + public Object getField(String fieldName, LeafReaderContext ctx) throws IOException { + Wrapper wrapper = values.get(fieldName); + if (wrapper == null) { + throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource"); + } + return wrapper.getValueSource().bytesValues(ctx); + } + } + + public static class GeoPointValuesSource extends MultiValuesSource { + public GeoPointValuesSource(MultiValuesSourceConfig valuesSourceConfigs, + QueryShardContext context) throws IOException { + values = new HashMap<>(valuesSourceConfigs.getMap().size()); + for (Map.Entry> entry : valuesSourceConfigs.getMap().entrySet()) { + values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(), + entry.getValue().getConfig().toValuesSource(context))); + } + } + } + + + public boolean needsScores() { + return values.values().stream().anyMatch(vsWrapper -> vsWrapper.getValueSource().needsScores()); + } + + public String[] fieldNames() { + return values.keySet().toArray(new String[0]); + } + + public boolean areValuesSourcesEmpty() { + return values.values().stream().allMatch(vsWrapper -> vsWrapper.getValueSource() == null); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java 
b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java new file mode 100644 index 0000000000000..e3ac1455eb541 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java @@ -0,0 +1,278 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.elasticsearch.search.aggregations.support; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationInitializationException; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactories.Builder; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +public abstract class MultiValuesSourceAggregationBuilder> + extends AbstractAggregationBuilder { + + + public abstract static class LeafOnly> + extends MultiValuesSourceAggregationBuilder { + + protected LeafOnly(String name, ValueType targetValueType) { + super(name, targetValueType); + } + + protected LeafOnly(LeafOnly clone, Builder factoriesBuilder, Map metaData) { + super(clone, factoriesBuilder, metaData); + if (factoriesBuilder.count() > 0) { + throw new AggregationInitializationException("Aggregator [" + name + "] of type [" + + getType() + "] cannot accept sub-aggregations"); + } + } + + /** + * Read from a stream that does not serialize its targetValueType. This should be used by most subclasses. + */ + protected LeafOnly(StreamInput in, ValueType targetValueType) throws IOException { + super(in, targetValueType); + } + + /** + * Read an aggregation from a stream that serializes its targetValueType. This should only be used by subclasses that override + * {@link #serializeTargetValueType()} to return true. 
+ */ + /* + protected LeafOnly(StreamInput in, ValuesSourceType valuesSourceType) throws IOException { + super(in, valuesSourceType); + } + */ + @Override + public AB subAggregations(Builder subFactories) { + throw new AggregationInitializationException("Aggregator [" + name + "] of type [" + + getType() + "] cannot accept sub-aggregations"); + } + } + + + + private Map fields = new HashMap<>(); + private final ValueType targetValueType; + private ValueType valueType = null; + private String format = null; + + protected MultiValuesSourceAggregationBuilder(String name, ValueType targetValueType) { + super(name); + this.targetValueType = targetValueType; + } + + protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilder clone, + Builder factoriesBuilder, Map metaData) { + super(clone, factoriesBuilder, metaData); + + this.fields = new HashMap<>(clone.fields); + this.targetValueType = clone.targetValueType; + this.valueType = clone.valueType; + this.format = clone.format; + } + + protected MultiValuesSourceAggregationBuilder(StreamInput in, ValueType targetValueType) + throws IOException { + super(in); + assert false == serializeTargetValueType() : "Wrong read constructor called for subclass that provides its targetValueType"; + this.targetValueType = targetValueType; + read(in); + } + + protected MultiValuesSourceAggregationBuilder(StreamInput in) throws IOException { + super(in); + assert serializeTargetValueType() : "Wrong read constructor called for subclass that serializes its targetValueType"; + this.targetValueType = in.readOptionalWriteable(ValueType::readFromStream); + read(in); + } + + /** + * Read from a stream. 
+ */ + @SuppressWarnings("unchecked") + private void read(StreamInput in) throws IOException { + fields = in.readMap(StreamInput::readString, MultiValuesSourceFieldConfig::new); + valueType = in.readOptionalWriteable(ValueType::readFromStream); + format = in.readOptionalString(); + } + + @Override + protected final void doWriteTo(StreamOutput out) throws IOException { + if (serializeTargetValueType()) { + out.writeOptionalWriteable(targetValueType); + } + out.writeMap(fields, StreamOutput::writeString, (o, value) -> value.writeTo(o)); + out.writeOptionalWriteable(valueType); + out.writeOptionalString(format); + innerWriteTo(out); + } + + /** + * Write subclass' state to the stream + */ + protected abstract void innerWriteTo(StreamOutput out) throws IOException; + + @SuppressWarnings("unchecked") + protected AB field(String propertyName, MultiValuesSourceFieldConfig config) { + if (config == null) { + throw new IllegalArgumentException("[config] must not be null: [" + name + "]"); + } + this.fields.put(propertyName, config); + return (AB) this; + } + + public Map fields() { + return fields; + } + + /** + * Sets the {@link ValueType} for the value produced by this aggregation + */ + @SuppressWarnings("unchecked") + public AB valueType(ValueType valueType) { + if (valueType == null) { + throw new IllegalArgumentException("[valueType] must not be null: [" + name + "]"); + } + this.valueType = valueType; + return (AB) this; + } + + /** + * Gets the {@link ValueType} for the value produced by this aggregation + */ + public ValueType valueType() { + return valueType; + } + + /** + * Sets the format to use for the output of the aggregation. + */ + @SuppressWarnings("unchecked") + public AB format(String format) { + if (format == null) { + throw new IllegalArgumentException("[format] must not be null: [" + name + "]"); + } + this.format = format; + return (AB) this; + } + + /** + * Gets the format to use for the output of the aggregation. 
+ */ + public String format() { + return format; + } + + @Override + protected final MultiValuesSourceAggregatorFactory doBuild(SearchContext context, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException { + ValueType finalValueType = this.valueType != null ? this.valueType : targetValueType; + + MultiValuesSourceConfig configs = new MultiValuesSourceConfig<>(); + fields.forEach((key, value) -> { + ValuesSourceConfig config = ValuesSourceConfig.resolve(context.getQueryShardContext(), finalValueType, + value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format); + configs.addField(key, config, value.getMulti()); + }); + DocValueFormat docValueFormat = resolveFormat(format, finalValueType); + return innerBuild(context, configs, docValueFormat, parent, subFactoriesBuilder); + } + + + private static DocValueFormat resolveFormat(@Nullable String format, @Nullable ValueType valueType) { + if (valueType == null) { + return DocValueFormat.RAW; // we can't figure it out + } + DocValueFormat valueFormat = valueType.defaultFormat; + if (valueFormat instanceof DocValueFormat.Decimal && format != null) { + valueFormat = new DocValueFormat.Decimal(format); + } + return valueFormat; + } + + protected abstract MultiValuesSourceAggregatorFactory innerBuild(SearchContext context, + MultiValuesSourceConfig configs, DocValueFormat format, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException; + + + /** + * Should this builder serialize its targetValueType? Defaults to false. All subclasses that override this to true + * should use the three argument read constructor rather than the four argument version. 
+     */
+    protected boolean serializeTargetValueType() {
+        return false;
+    }
+
+    @Override
+    public final XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        if (fields != null) {
+            builder.field(CommonFields.FIELDS.getPreferredName(), fields);
+        }
+        if (format != null) {
+            builder.field(CommonFields.FORMAT.getPreferredName(), format);
+        }
+        if (valueType != null) {
+            builder.field(CommonFields.VALUE_TYPE.getPreferredName(), valueType.getPreferredName());
+        }
+        doXContentBody(builder, params);
+        builder.endObject();
+        return builder;
+    }
+    /** Writes the subclass-specific fields; invoked between startObject() and endObject() of {@link #internalXContent}. */
+    protected abstract XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException;
+
+    @Override
+    protected final int doHashCode() {
+        return Objects.hash(fields, format, targetValueType, valueType, innerHashCode());
+    }
+    /** Hash of the subclass-specific state; mixed into {@link #doHashCode()}. */
+    protected abstract int innerHashCode();
+
+    @Override
+    protected final boolean doEquals(Object other) {
+        if (this == other) {
+            return true;
+        }
+        if (other == null || getClass() != other.getClass()) {
+            return false;
+        }
+        MultiValuesSourceAggregationBuilder that = (MultiValuesSourceAggregationBuilder) other;
+        // Must cover everything doHashCode() hashes, including targetValueType and the subclass state.
+        return Objects.equals(this.fields, that.fields)
+            && Objects.equals(this.format, that.format)
+            && Objects.equals(this.targetValueType, that.targetValueType)
+            && Objects.equals(this.valueType, that.valueType)
+            && innerEquals(that);
+    }
+    /** Compares the subclass-specific state; consulted by {@link #doEquals(Object)}. */
+    protected abstract boolean innerEquals(Object obj);
+}
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java
new file mode 100644
index 0000000000000..7f37e64e6ba17
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.support; + +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public abstract class MultiValuesSourceAggregatorFactory> + extends AggregatorFactory { + + protected final MultiValuesSourceConfig configs; + protected final DocValueFormat format; + + public MultiValuesSourceAggregatorFactory(String name, MultiValuesSourceConfig configs, + DocValueFormat format, SearchContext context, + AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, + Map metaData) throws IOException { + super(name, context, parent, subFactoriesBuilder, metaData); + this.configs = configs; + this.format = format; + } + + @Override + public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBucket, List pipelineAggregators, + Map metaData) throws IOException { + + return 
doCreateInternal(configs, format, parent, collectsFromSingleBucket,
+            pipelineAggregators, metaData);
+    }
+
+    protected abstract Aggregator createUnmapped(Aggregator parent, List pipelineAggregators,
+                                                 Map metaData) throws IOException;
+
+    protected abstract Aggregator doCreateInternal(MultiValuesSourceConfig configs,
+                                                   DocValueFormat format, Aggregator parent, boolean collectsFromSingleBucket,
+                                                   List pipelineAggregators,
+                                                   Map metaData) throws IOException;
+
+}
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java
new file mode 100644
index 0000000000000..60e6f1a1b6131
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.support;
+
+import org.elasticsearch.search.MultiValueMode;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Collects, per named field, the resolved {@link ValuesSourceConfig} together with the
+ * {@link MultiValueMode} that was supplied for that field.
+ */
+public class MultiValuesSourceConfig {
+    private final Map> map = new HashMap<>();
+
+    /** Pairs one field's {@link ValuesSourceConfig} with its {@link MultiValueMode}. */
+    public static class Wrapper {
+        private final MultiValueMode multi;
+        private final ValuesSourceConfig config;
+
+        public Wrapper(MultiValueMode multi, ValuesSourceConfig config) {
+            this.multi = multi;
+            this.config = config;
+        }
+
+        public MultiValueMode getMulti() {
+            return multi;
+        }
+
+        public ValuesSourceConfig getConfig() {
+            return config;
+        }
+    }
+
+    /** Stores the configuration under {@code fieldName}, replacing any entry already stored under that name. */
+    public void addField(String fieldName, ValuesSourceConfig config, MultiValueMode multiValueMode) {
+        map.put(fieldName, new Wrapper<>(multiValueMode, config));
+    }
+
+    /** Returns the backing map itself (not a copy), keyed by the names passed to {@link #addField}. */
+    public Map> getMap() {
+        return map;
+    }
+
+}
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java
new file mode 100644
index 0000000000000..8936b834ee90e
--- /dev/null
+++
b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.support;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.common.xcontent.ConstructingObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.ToXContentFragment;
+import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentParser;
+import org.elasticsearch.script.Script;
+import org.elasticsearch.search.MultiValueMode;
+import org.joda.time.DateTimeZone;
+
+import java.io.IOException;
+import java.util.Objects;
+import java.util.function.BiFunction;
+
+/**
+ * Settings for one named input of a multi-values-source aggregation: the field to read plus the optional
+ * missing value, script, time zone and {@link MultiValueMode} (defaults to {@code AVG}) that go with it.
+ */
+public class MultiValuesSourceFieldConfig implements Writeable, ToXContentFragment {
+    private String fieldName;
+    private Object missing = null;
+    private Script script = null;
+    private DateTimeZone timeZone = null;
+    private MultiValueMode multi = MultiValueMode.AVG;
+
+    private static final String NAME = "field_config";
+    private static final ParseField MULTI = new ParseField("multi");
+
+    /**
+     * Builds a parser for a single field configuration. The first argument controls whether a script may be
+     * supplied, the second whether a time zone may be supplied; field, missing and multi are always declared.
+     */
+    public static final BiFunction> PARSER
+        = (scriptable, timezoneAware) -> {
+
+        ConstructingObjectParser parser
+            = new ConstructingObjectParser<>(MultiValuesSourceFieldConfig.NAME, false, o -> new MultiValuesSourceFieldConfig((String)o[0]));
+
+        parser.declareString(ConstructingObjectParser.constructorArg(), ParseField.CommonFields.FIELD);
+        parser.declareField(MultiValuesSourceFieldConfig::setMissing, XContentParser::objectText,
+            ParseField.CommonFields.MISSING, ObjectParser.ValueType.VALUE);
+        parser.declareField(MultiValuesSourceFieldConfig::setMulti, p -> MultiValueMode.fromString(p.text()), MULTI,
+            ObjectParser.ValueType.STRING);
+
+        if (scriptable) {
+            parser.declareField(MultiValuesSourceFieldConfig::setScript,
+                (p, context) -> Script.parse(p),
+                Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING);
+        }
+
+        if (timezoneAware) {
+            parser.declareField(MultiValuesSourceFieldConfig::setTimeZone, p -> {
+                if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
+                    return DateTimeZone.forID(p.text());
+                } else {
+                    return DateTimeZone.forOffsetHours(p.intValue());
+                }
+            }, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG);
+        }
+        return parser;
+    };
+
+
+    public MultiValuesSourceFieldConfig(String fieldName) {
+        this.fieldName = fieldName;
+    }
+
+    /** Reads the values in the same order {@link #writeTo(StreamOutput)} writes them. */
+    public MultiValuesSourceFieldConfig(StreamInput in) throws IOException {
+        this.fieldName = in.readString();
+        this.missing = in.readGenericValue();
+        this.script = in.readOptionalWriteable(Script::new);
+        this.timeZone = in.readOptionalTimeZone();
+        this.multi = MultiValueMode.readMultiValueModeFrom(in);
+    }
+
+    public Object getMissing() {
+        return missing;
+    }
+
+    public MultiValuesSourceFieldConfig setMissing(Object missing) {
+        this.missing = missing;
+        return this;
+    }
+
+    public Script getScript() {
+        return script;
+    }
+
+    public MultiValuesSourceFieldConfig setScript(Script script) {
+        this.script = script;
+        return this;
+    }
+
+    public DateTimeZone getTimeZone() {
+        return timeZone;
+    }
+
+    public MultiValuesSourceFieldConfig setTimeZone(DateTimeZone timeZone) {
+        this.timeZone = timeZone;
+        return this;
+    }
+
+    public String getFieldName() {
+        return fieldName;
+    }
+
+    public MultiValuesSourceFieldConfig setFieldName(String fieldName) {
+        this.fieldName = fieldName;
+        return this;
+    }
+
+    public MultiValueMode getMulti() {
+        return multi;
+    }
+
+    public MultiValuesSourceFieldConfig setMulti(MultiValueMode multi) {
+        this.multi = multi;
+        return this;
+    }
+
+    @Override
+    public void writeTo(StreamOutput out) throws IOException {
+        out.writeString(fieldName);
+        out.writeGenericValue(missing);
+        out.writeOptionalWriteable(script);
+        out.writeOptionalTimeZone(timeZone);
+        multi.writeTo(out);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        if (missing != null) {
+            builder.field(ParseField.CommonFields.MISSING.getPreferredName(), missing);
+        }
+        if (script != null) {
+            builder.field(Script.SCRIPT_PARSE_FIELD.getPreferredName(), script);
+        }
+        if (fieldName != null) {
+            builder.field(ParseField.CommonFields.FIELD.getPreferredName(), fieldName);
+        }
+        if (timeZone != null) {
+            builder.field(ParseField.CommonFields.TIME_ZONE.getPreferredName(), timeZone);
+        }
+        if (multi != null) {
+            builder.field(MULTI.getPreferredName(), multi);
+        }
+        return builder;
+    }
+
+    // Value semantics are needed: MultiValuesSourceAggregationBuilder compares its map of these configs
+    // with Objects.equals, which would otherwise fall back to identity for every entry.
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        MultiValuesSourceFieldConfig that = (MultiValuesSourceFieldConfig) o;
+        return Objects.equals(fieldName, that.fieldName)
+            && Objects.equals(missing, that.missing)
+            && Objects.equals(script, that.script)
+            && Objects.equals(timeZone, that.timeZone)
+            && multi == that.multi;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(fieldName, missing, script, timeZone, multi);
+    }
+}
diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java
new file mode 100644
index 0000000000000..d693190741b5d
--- /dev/null
+++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to Elasticsearch under one or more contributor
+ * license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright
+ * ownership. Elasticsearch licenses this file to you under
+ * the Apache License, Version 2.0 (the "License"); you may
+ * not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.elasticsearch.search.aggregations.support;
+
+import org.elasticsearch.common.ParseField;
+import org.elasticsearch.common.ParsingException;
+import org.elasticsearch.common.xcontent.AbstractObjectParser;
+import org.elasticsearch.common.xcontent.ObjectParser;
+import org.elasticsearch.common.xcontent.XContentParser;
+
+public final class MultiValuesSourceParseHelper {
+
+    private MultiValuesSourceParseHelper() {} // utility class, no instantiation
+
+    public static void declareCommon(
+            AbstractObjectParser, T> objectParser, boolean formattable,
+            ValueType targetValueType) {
+
+        objectParser.declareField(MultiValuesSourceAggregationBuilder::valueType, p -> {
+            ValueType valueType = ValueType.resolveForScript(p.text());
+            if (targetValueType != null && valueType.isNotA(targetValueType)) {
+                throw new ParsingException(p.getTokenLocation(),
+                    "Aggregation [" + objectParser.getName() + "] was configured with an incompatible value type ["
+                        + valueType + "].
It can only work on value of type [" + + targetValueType + "]"); + } + return valueType; + }, ValueType.VALUE_TYPE, ObjectParser.ValueType.STRING); + + if (formattable) { + objectParser.declareField(MultiValuesSourceAggregationBuilder::format, XContentParser::text, + ParseField.CommonFields.FORMAT, ObjectParser.ValueType.STRING); + } + } + + public static void declareField(String fieldName, + AbstractObjectParser, T> objectParser, + boolean scriptable, boolean timezoneAware) { + + objectParser.declareField((o, fieldConfig) -> o.field(fieldName, fieldConfig), + (p, c) -> MultiValuesSourceFieldConfig.PARSER.apply(scriptable, timezoneAware).parse(p, null), + new ParseField(fieldName), ObjectParser.ValueType.OBJECT); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java index 318540e3e5806..5c4e65d2bd4df 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.aggregations.support; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -95,6 +96,8 @@ public boolean isNumeric() { private final byte id; private String preferredName; + public final static ParseField VALUE_TYPE = new ParseField("value_type", "valueType"); + ValueType(byte id, String description, String preferredName, ValuesSourceType valuesSourceType, Class fieldDataType, DocValueFormat defaultFormat) { this.id = id; @@ -112,7 +115,7 @@ public String description() { public String getPreferredName() { return preferredName; } - + public ValuesSourceType getValuesSourceType() { return valuesSourceType; } diff --git 
a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceParserHelper.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceParserHelper.java index f3f8fa3056898..2a009f8d2002d 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceParserHelper.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceParserHelper.java @@ -29,7 +29,6 @@ import org.joda.time.DateTimeZone; public final class ValuesSourceParserHelper { - static final ParseField TIME_ZONE = new ParseField("time_zone"); private ValuesSourceParserHelper() {} // utility class, no instantiation @@ -63,10 +62,10 @@ private static void declareFields( objectParser.declareField(ValuesSourceAggregationBuilder::field, XContentParser::text, - new ParseField("field"), ObjectParser.ValueType.STRING); + ParseField.CommonFields.FIELD, ObjectParser.ValueType.STRING); objectParser.declareField(ValuesSourceAggregationBuilder::missing, XContentParser::objectText, - new ParseField("missing"), ObjectParser.ValueType.VALUE); + ParseField.CommonFields.MISSING, ObjectParser.ValueType.VALUE); objectParser.declareField(ValuesSourceAggregationBuilder::valueType, p -> { ValueType valueType = ValueType.resolveForScript(p.text()); @@ -77,11 +76,11 @@ private static void declareFields( + targetValueType + "]"); } return valueType; - }, new ParseField("value_type", "valueType"), ObjectParser.ValueType.STRING); + }, ValueType.VALUE_TYPE, ObjectParser.ValueType.STRING); if (formattable) { objectParser.declareField(ValuesSourceAggregationBuilder::format, XContentParser::text, - new ParseField("format"), ObjectParser.ValueType.STRING); + ParseField.CommonFields.FORMAT, ObjectParser.ValueType.STRING); } if (scriptable) { @@ -97,7 +96,7 @@ private static void declareFields( } else { return DateTimeZone.forOffsetHours(p.intValue()); } - }, TIME_ZONE, ObjectParser.ValueType.LONG); + }, 
ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG); } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java index a6b252d6903d9..a68e24da81bf6 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java @@ -19,9 +19,37 @@ package org.elasticsearch.search.aggregations.support; -public enum ValuesSourceType { +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +public enum ValuesSourceType implements Writeable { ANY, NUMERIC, BYTES, GEOPOINT; + + public final static ParseField VALUE_SOURCE_TYPE = new ParseField("value_source_type"); + + public static ValuesSourceType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static ValuesSourceType fromStream(StreamInput in) throws IOException { + return in.readEnum(ValuesSourceType.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + ValuesSourceType state = this; + out.writeEnum(state); + } + + public String value() { + return name().toLowerCase(Locale.ROOT); + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java new file mode 100644 index 0000000000000..956acd3d9023d --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java @@ -0,0 +1,276 @@ +/* + * Licensed to Elasticsearch under 
one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.DocValuesFieldExistsQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.NumericUtils; +import org.elasticsearch.common.CheckedConsumer; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.search.aggregations.AggregatorTestCase; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; + +import java.io.IOException; +import java.util.Arrays; +import java.util.function.Consumer; + +import static java.util.Collections.singleton; + +public class WeightedAvgAggregatorTests extends AggregatorTestCase { + + public void testNoDocs() 
throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + // Intentionally not writing any docs + }, avg -> { + assertEquals(Double.NaN, avg.getValue(), 0); + }); + } + + public void testNoMatchingField() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(singleton(new SortedNumericDocValuesField("wrong_number", 7))); + iw.addDocument(singleton(new SortedNumericDocValuesField("wrong_number", 3))); + }, avg -> { + assertEquals(Double.NaN, avg.getValue(), 0); + }); + } + + public void testSomeMatchesSortedNumericDocValuesNoWeight() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(singleton(new SortedNumericDocValuesField("value_field", 7))); + iw.addDocument(singleton(new SortedNumericDocValuesField("value_field", 2))); + iw.addDocument(singleton(new SortedNumericDocValuesField("value_field", 3))); + }, avg -> { + assertEquals(4, avg.getValue(), 0); + }); + 
} + + public void testSomeMatchesSortedNumericDocValuesWeights() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 2))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 3))); + + }, avg -> { + // (7*2 + 2*3 + 3*3) / (2+3+3) == 3.625 + assertEquals(3.625, avg.getValue(), 0); + }); + } + + public void testSomeMatchesNumericDocValues() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new DocValuesFieldExistsQuery("value_field"), aggregationBuilder, iw -> { + iw.addDocument(singleton(new NumericDocValuesField("value_field", 7))); + iw.addDocument(singleton(new NumericDocValuesField("value_field", 2))); + iw.addDocument(singleton(new NumericDocValuesField("value_field", 3))); + }, avg -> { + assertEquals(4, avg.getValue(), 0); + }); + } + + public void testQueryFiltering() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new 
MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(IntPoint.newRangeQuery("value_field", 0, 3), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new IntPoint("value_field", 7), new SortedNumericDocValuesField("value_field", 7))); + iw.addDocument(Arrays.asList(new IntPoint("value_field", 1), new SortedNumericDocValuesField("value_field", 2))); + iw.addDocument(Arrays.asList(new IntPoint("value_field", 3), new SortedNumericDocValuesField("value_field", 3))); + }, avg -> { + assertEquals(2.5, avg.getValue(), 0); + }); + } + + public void testQueryFilteringWeights() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(IntPoint.newRangeQuery("filter_field", 0, 3), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new IntPoint("filter_field", 7), new SortedNumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 2))); + iw.addDocument(Arrays.asList(new IntPoint("filter_field", 2), new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList(new IntPoint("filter_field", 3), new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 4))); + }, avg -> { + double value = (2.0*3.0 + 3.0*4.0) / (3.0+4.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testQueryFiltersAll() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new 
MultiValuesSourceFieldConfig("weight_field");
+        WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
+            .value(valueConfig)
+            .weight(weightConfig);
+        testCase(IntPoint.newRangeQuery("value_field", -1, 0), aggregationBuilder, iw -> {
+            iw.addDocument(Arrays.asList(new IntPoint("value_field", 7), new SortedNumericDocValuesField("value_field", 7)));
+            iw.addDocument(Arrays.asList(new IntPoint("value_field", 1), new SortedNumericDocValuesField("value_field", 2)));
+            iw.addDocument(Arrays.asList(new IntPoint("value_field", 3), new SortedNumericDocValuesField("value_field", 7)));
+        }, avg -> {
+            assertEquals(Double.NaN, avg.getValue(), 0);
+        });
+    }
+    // Points are indexed under filter_field (7, 2, 3), so the query must target that field; [-1, 0] matches none.
+    public void testQueryFiltersAllWeights() throws IOException {
+        MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field");
+        MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field");
+        WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name")
+            .value(valueConfig)
+            .weight(weightConfig);
+        testCase(IntPoint.newRangeQuery("filter_field", -1, 0), aggregationBuilder, iw -> {
+            iw.addDocument(Arrays.asList(new IntPoint("filter_field", 7), new SortedNumericDocValuesField("value_field", 7),
+                new SortedNumericDocValuesField("weight_field", 2)));
+            iw.addDocument(Arrays.asList(new IntPoint("filter_field", 2), new SortedNumericDocValuesField("value_field", 2),
+                new SortedNumericDocValuesField("weight_field", 3)));
+            iw.addDocument(Arrays.asList(new IntPoint("filter_field", 3), new SortedNumericDocValuesField("value_field", 3),
+                new SortedNumericDocValuesField("weight_field", 4)));
+        }, avg -> {
+            assertEquals(Double.NaN, avg.getValue(), 0);
+        });
+    }
+
+    public void testSummationAccuracy() throws IOException {
+        // Summing up a normal array and expect an accurate value
+        double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7};
verifyAvgOfDoubles(values, 0.9, 0d); + + // Summing up an array which contains NaN and infinities and expect a result same as naive summation + int n = randomIntBetween(5, 10); + values = new double[n]; + double sum = 0; + for (int i = 0; i < n; i++) { + values[i] = frequently() + ? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY) + : randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true); + sum += values[i]; + } + verifyAvgOfDoubles(values, sum / n, 1e-10); + + // Summing up some big double values and expect infinity result + n = randomIntBetween(5, 10); + double[] largeValues = new double[n]; + for (int i = 0; i < n; i++) { + largeValues[i] = Double.MAX_VALUE; + } + verifyAvgOfDoubles(largeValues, Double.POSITIVE_INFINITY, 0d); + + for (int i = 0; i < n; i++) { + largeValues[i] = -Double.MAX_VALUE; + } + verifyAvgOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d); + } + + private void verifyAvgOfDoubles(double[] values, double expected, double delta) throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, + iw -> { + for (double value : values) { + iw.addDocument(singleton(new NumericDocValuesField("value_field", NumericUtils.doubleToSortableLong(value)))); + } + }, + avg -> assertEquals(expected, avg.getValue(), delta), + NumberFieldMapper.NumberType.DOUBLE + ); + } + + private void testCase(Query query, WeightedAvgAggregationBuilder aggregationBuilder, + CheckedConsumer buildIndex, + Consumer verify) throws IOException { + testCase(query, aggregationBuilder, buildIndex, verify, NumberFieldMapper.NumberType.LONG); + } + + private void testCase(Query query, 
WeightedAvgAggregationBuilder aggregationBuilder, + CheckedConsumer buildIndex, + Consumer verify, + NumberFieldMapper.NumberType fieldNumberType) throws IOException { + Directory directory = newDirectory(); + RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); + buildIndex.accept(indexWriter); + indexWriter.close(); + + IndexReader indexReader = DirectoryReader.open(directory); + IndexSearcher indexSearcher = newSearcher(indexReader, true, true); + + + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType); + fieldType.setName("value_field"); + fieldType.setHasDocValues(true); + + MappedFieldType fieldType2 = new NumberFieldMapper.NumberFieldType(fieldNumberType); + fieldType2.setName("weight_field"); + fieldType2.setHasDocValues(true); + + WeightedAvgAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType, fieldType2); + aggregator.preCollection(); + indexSearcher.search(query, aggregator); + aggregator.postCollection(); + verify.accept((InternalWeightedAvg) aggregator.buildAggregation(0L)); + + indexReader.close(); + directory.close(); + } +} diff --git a/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupRequestTranslator.java b/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupRequestTranslator.java index dc2fac776c6c0..538babf4fbced 100644 --- a/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupRequestTranslator.java +++ b/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupRequestTranslator.java @@ -208,7 +208,7 @@ public static List translateAggregation(AggregationBuilder s private static List translateDateHistogram(DateHistogramAggregationBuilder source, List filterConditions, NamedWriteableRegistry registry) { - + return translateVSAggBuilder(source, filterConditions, registry, () -> { DateHistogramAggregationBuilder rolledDateHisto = new DateHistogramAggregationBuilder(source.getName()); 
From 2d89f4c639fb60cd462717cb6b15efc63e287df0 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Tue, 5 Jun 2018 17:09:53 +0000 Subject: [PATCH 02/11] Use Kahan Summation for weights, use new multivalue mode, assert --- .../elasticsearch/search/MultiValueMode.java | 52 ------------------- .../weighted_avg/InternalWeightedAvg.java | 17 ++++-- .../weighted_avg/WeightedAvgAggregator.java | 51 ++++++++++-------- .../support/MultiValuesSource.java | 3 +- 4 files changed, 43 insertions(+), 80 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/MultiValueMode.java b/server/src/main/java/org/elasticsearch/search/MultiValueMode.java index 207b76174d188..eaaa5f74fa4d5 100644 --- a/server/src/main/java/org/elasticsearch/search/MultiValueMode.java +++ b/server/src/main/java/org/elasticsearch/search/MultiValueMode.java @@ -539,58 +539,6 @@ public double doubleValue() throws IOException { } } - /** - * Return a {@link NumericDoubleValues} instance that can be used to sort documents - * with this mode and the provided values. When a document has no value, - * advanceExact returns false and doubleValue() will return Double.NaN (although the - * double value should not be used). 
- * - * Allowed Modes: SUM, AVG, MEDIAN, MIN, MAX - */ - public NumericDoubleValues select(final SortedNumericDoubleValues values) { - final NumericDoubleValues singleton = FieldData.unwrapSingleton(values); - if (singleton != null) { - return new NumericDoubleValues() { - private double value; - - @Override - public boolean advanceExact(int doc) throws IOException { - if (singleton.advanceExact(doc)) { - this.value = singleton.doubleValue(); - return true; - } - this.value = Double.NaN; - return false; - } - - @Override - public double doubleValue() throws IOException { - return this.value; - } - }; - } else { - return new NumericDoubleValues() { - - private double value; - - @Override - public boolean advanceExact(int target) throws IOException { - if (values.advanceExact(target)) { - this.value = pick(values); - return true; - } - this.value = Double.NaN; - return false; - } - - @Override - public double doubleValue() throws IOException { - return this.value; - } - }; - } - } - protected double pick(SortedNumericDoubleValues values) throws IOException { throw new IllegalArgumentException("Unsupported sort mode: " + this); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java index 12931e4fed5fe..3a19de46325a9 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java @@ -91,24 +91,31 @@ public String getWriteableName() { public InternalWeightedAvg doReduce(List aggregations, ReduceContext reduceContext) { double weight = 0; double sum = 0; - double compensation = 0; + double sumCompensation = 0; + double weightCompensation = 0; // Compute the sum of double values with Kahan summation algorithm which is more // accurate than 
naive summation. for (InternalAggregation aggregation : aggregations) { InternalWeightedAvg avg = (InternalWeightedAvg) aggregation; - weight += avg.weight; + if (Double.isFinite(avg.weight) == false) { + weight += avg.weight; + } else if (Double.isFinite(weight)) { + double corrected = avg.weight - weightCompensation; + double newWeight = weight + corrected; + weightCompensation = (newWeight - weight) - corrected; + weight = newWeight; + } if (Double.isFinite(avg.sum) == false) { sum += avg.sum; } else if (Double.isFinite(sum)) { - double corrected = avg.sum - compensation; + double corrected = avg.sum - sumCompensation; double newSum = sum + corrected; - compensation = (newSum - sum) - corrected; + sumCompensation = (newSum - sum) - corrected; sum = newSum; } } return new InternalWeightedAvg(getName(), sum, weight, format, pipelineAggregators(), getMetaData()); } - @Override public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { builder.field(CommonFields.VALUE.getPreferredName(), weight != 0 ? 
getValue() : null); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java index 43568386296a1..40ea390cb1b6a 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java @@ -46,7 +46,8 @@ public class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue DoubleArray weights; DoubleArray sums; - DoubleArray compensations; + DoubleArray sumCompensations; + DoubleArray weightCompensations; DocValueFormat format; public WeightedAvgAggregator(String name, MultiValuesSource.NumericMultiValuesSource valuesSources, DocValueFormat format, @@ -59,7 +60,8 @@ public WeightedAvgAggregator(String name, MultiValuesSource.NumericMultiValuesSo final BigArrays bigArrays = context.bigArrays(); weights = bigArrays.newDoubleArray(1, true); sums = bigArrays.newDoubleArray(1, true); - compensations = bigArrays.newDoubleArray(1, true); + sumCompensations = bigArrays.newDoubleArray(1, true); + weightCompensations = bigArrays.newDoubleArray(1, true); } } @@ -83,34 +85,39 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, public void collect(int doc, long bucket) throws IOException { weights = bigArrays.grow(weights, bucket + 1); sums = bigArrays.grow(sums, bucket + 1); - compensations = bigArrays.grow(compensations, bucket + 1); + sumCompensations = bigArrays.grow(sumCompensations, bucket + 1); + weightCompensations = bigArrays.grow(weightCompensations, bucket + 1); if (docValues.advanceExact(doc)) { - docWeights.advanceExact(doc); + boolean advanced = docWeights.advanceExact(doc); + assert advanced; final double weight = docWeights.doubleValue(); - weights.increment(bucket, weight); - // Compute the sum of double 
values with Kahan summation algorithm which is more - // accurate than naive summation. - double sum = sums.get(bucket); - double compensation = compensations.get(bucket); - - final double value = docValues.doubleValue() * weight; - if (Double.isFinite(value) == false) { - sum += value; - } else if (Double.isFinite(sum)) { - double corrected = value - compensation; - double newSum = sum + corrected; - compensation = (newSum - sum) - corrected; - sum = newSum; - } - sums.set(bucket, sum); - compensations.set(bucket, compensation); + kahanSum(docValues.doubleValue() * weight, sums, sumCompensations, bucket); + kahanSum(weight, weights, weightCompensations, bucket); } } }; } + private static void kahanSum(double value, DoubleArray values, DoubleArray compensations, long bucket) { + // Compute the sum of double values with Kahan summation algorithm which is more + // accurate than naive summation. + double sum = values.get(bucket); + double compensation = compensations.get(bucket); + + if (Double.isFinite(value) == false) { + sum += value; + } else if (Double.isFinite(sum)) { + double corrected = value - compensation; + double newSum = sum + corrected; + compensation = (newSum - sum) - corrected; + sum = newSum; + } + values.set(bucket, sum); + compensations.set(bucket, compensation); + } + @Override public double metric(long owningBucketOrd) { if (valuesSources == null || owningBucketOrd >= sums.size()) { @@ -134,7 +141,7 @@ public InternalAggregation buildEmptyAggregation() { @Override public void doClose() { - Releasables.close(weights, sums, compensations); + Releasables.close(weights, sums, sumCompensations, weightCompensations); } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java index 46f934948bfc7..71117879e3a82 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java +++ 
b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.aggregations.support; import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.index.fielddata.FieldData; import org.elasticsearch.index.fielddata.NumericDoubleValues; import org.elasticsearch.index.query.QueryShardContext; import org.elasticsearch.search.MultiValueMode; @@ -76,7 +77,7 @@ public NumericDoubleValues getField(String fieldName, double defaultValue, LeafR if (wrapper == null) { throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource"); } - return wrapper.getMultiValueMode().select(wrapper.getValueSource().doubleValues(ctx), defaultValue); + return FieldData.replaceMissing(wrapper.getMultiValueMode().select(wrapper.getValueSource().doubleValues(ctx)), defaultValue); } } From f25249321dd4bad7c3cb05b2bbec853757b407e2 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Mon, 18 Jun 2018 20:08:58 +0000 Subject: [PATCH 03/11] More tests --- .../support/MultiValuesSourceFieldConfig.java | 3 +- .../WeightedAvgAggregatorTests.java | 279 +++++++++++++++++- 2 files changed, 265 insertions(+), 17 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java index 8936b834ee90e..7b5867508157a 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java @@ -109,8 +109,9 @@ public MultiValueMode getMulti() { return multi; } - public void setMulti(MultiValueMode multi) { + public MultiValuesSourceFieldConfig setMulti(MultiValueMode multi) { this.multi = multi; + return this; } @Override diff --git 
a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java index 956acd3d9023d..e264a879f0da8 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.aggregations.metrics.weighted_avg; +import org.apache.lucene.document.Document; import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; @@ -34,14 +35,18 @@ import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.search.MultiValueMode; import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; +import org.joda.time.DateTimeZone; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.function.Consumer; import static java.util.Collections.singleton; +import static org.hamcrest.Matchers.equalTo; public class WeightedAvgAggregatorTests extends AggregatorTestCase { @@ -189,6 +194,247 @@ public void testQueryFiltersAllWeights() throws IOException { }); } + public void testValueSetMissing() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setMissing(2); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + 
testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("weight_field", 2))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("weight_field", 4))); + }, avg -> { + double value = (2.0*2.0 + 2.0*3.0 + 2.0*4.0) / (2.0+3.0+4.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testWeightSetMissing() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMissing(2); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 2))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 4))); + }, avg -> { + double value = (2.0*2.0 + 3.0*2.0 + 4.0*2.0) / (2.0+2.0+2.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testWeightSetTimezone() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setTimeZone(DateTimeZone.UTC); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Collections.singletonList(new 
SortedNumericDocValuesField("value_field", 2))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 4))); + }, avg -> { + fail("Should not have executed test case"); + })); + assertThat(e.getMessage(), equalTo("Field [weight_field] of type [long] does not support custom time zones")); + } + + public void testValueSetTimezone() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setTimeZone(DateTimeZone.UTC); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 2))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 4))); + }, avg -> { + fail("Should not have executed test case"); + })); + assertThat(e.getMessage(), equalTo("Field [value_field] of type [long] does not support custom time zones")); + } + + public void testValueSetMultiAvg() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setMulti(MultiValueMode.AVG); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new 
SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("value_field", 4))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("value_field", 5))); + }, avg -> { + double value = (((2.0+3.0)/2.0) + ((3.0+4.0)/2.0) + ((4.0+5.0)/2.0)) / (1.0+1.0+1.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testValueSetMultiMax() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setMulti(MultiValueMode.MAX); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("value_field", 4))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("value_field", 5))); + }, avg -> { + double value = (3.0 + 4.0 + 5.0) / (1.0+1.0+1.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testValueSetMultiMin() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setMulti(MultiValueMode.MIN); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + testCase(new 
MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("value_field", 4))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("value_field", 5))); + }, avg -> { + double value = (2.0 + 3.0 + 4.0) / (1.0+1.0+1.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testValueSetMultiSum() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setMulti(MultiValueMode.SUM); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("value_field", 4))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("value_field", 5))); + }, avg -> { + double value = (5.0 + 7.0 + 9.0) / (1.0+1.0+1.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testWeightSetMultiAvg() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMulti(MultiValueMode.AVG); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + 
.weight(weightConfig); + + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); + }, avg -> { + double value = ((2.0 * (2.0+3.0)/2.0) + (3.0 * (3.0+4.0)/2.0) + (4.0 * (4.0+5.0)/2.0)) + / ((2.0+3.0)/2.0 + (3.0+4.0)/2.0 + (4.0+5.0)/2.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testWeightSetMultiMax() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMulti(MultiValueMode.MAX); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); + }, avg -> { + double value = ((2.0 * 3.0) + (3.0 * 
4.0) + (4.0 * 5.0)) / (3.0+4.0+5.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testWeightSetMultiMin() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMulti(MultiValueMode.MIN); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); + }, avg -> { + double value = ((2.0 * 2.0) + (3.0 * 3.0) + (4.0 * 4.0)) / (2.0+3.0+4.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testWeightSetMultiSum() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMulti(MultiValueMode.SUM); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); + 
iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); + }, avg -> { + double value = ((2.0 * 5.0) + (3.0 * 7.0) + (4.0 * 9.0)) / (5.0+7.0+9.0); + assertEquals(value, avg.getValue(), 0); + }); + } + public void testSummationAccuracy() throws IOException { // Summing up a normal array and expect an accurate value double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7}; @@ -247,30 +493,31 @@ private void testCase(Query query, WeightedAvgAggregationBuilder aggregationBuil CheckedConsumer buildIndex, Consumer verify, NumberFieldMapper.NumberType fieldNumberType) throws IOException { + Directory directory = newDirectory(); RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); buildIndex.accept(indexWriter); indexWriter.close(); - IndexReader indexReader = DirectoryReader.open(directory); IndexSearcher indexSearcher = newSearcher(indexReader, true, true); + try { + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType); + fieldType.setName("value_field"); + fieldType.setHasDocValues(true); - MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType); - fieldType.setName("value_field"); - fieldType.setHasDocValues(true); - - MappedFieldType fieldType2 = new NumberFieldMapper.NumberFieldType(fieldNumberType); - fieldType2.setName("weight_field"); - fieldType2.setHasDocValues(true); + MappedFieldType fieldType2 = new NumberFieldMapper.NumberFieldType(fieldNumberType); + fieldType2.setName("weight_field"); + fieldType2.setHasDocValues(true); - WeightedAvgAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, 
fieldType, fieldType2); - aggregator.preCollection(); - indexSearcher.search(query, aggregator); - aggregator.postCollection(); - verify.accept((InternalWeightedAvg) aggregator.buildAggregation(0L)); - - indexReader.close(); - directory.close(); + WeightedAvgAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType, fieldType2); + aggregator.preCollection(); + indexSearcher.search(query, aggregator); + aggregator.postCollection(); + verify.accept((InternalWeightedAvg) aggregator.buildAggregation(0L)); + } finally { + indexReader.close(); + directory.close(); + } } } From fe9dd45ad105605f5c6e1687f076cff3d253cdcb Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Tue, 19 Jun 2018 16:29:50 +0000 Subject: [PATCH 04/11] Add Builder for MVSourceFieldConfig, tests --- .../support/MultiValuesSourceConfig.java | 19 +++ .../support/MultiValuesSourceFieldConfig.java | 149 +++++++++++++----- .../support/MultiValuesSourceParseHelper.java | 21 ++- .../WeightedAvgAggregatorTests.java | 125 +++++++++------ .../MultiValuesSourceFieldConfigTests.java | 44 ++++++ 5 files changed, 273 insertions(+), 85 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java index 60e6f1a1b6131..ceee40ddde4f9 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java @@ -1,3 +1,22 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. 
Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + package org.elasticsearch.search.aggregations.support; import org.elasticsearch.search.MultiValueMode; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java index 7b5867508157a..5cc350f8c623e 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java @@ -1,10 +1,29 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + package org.elasticsearch.search.aggregations.support; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentFragment; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -18,34 +37,34 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentFragment { private String fieldName; - private Object missing = null; - private Script script = null; - private DateTimeZone timeZone = null; - private MultiValueMode multi = MultiValueMode.AVG; + private Object missing; + private Script script; + private DateTimeZone timeZone; + private MultiValueMode multi; private static final String NAME = "field_config"; private static final ParseField MULTI = new ParseField("multi"); - public static final BiFunction> PARSER + public static final BiFunction> PARSER = (scriptable, timezoneAware) -> { - ConstructingObjectParser parser - = new ConstructingObjectParser<>(MultiValuesSourceFieldConfig.NAME, false, o -> new MultiValuesSourceFieldConfig((String)o[0])); + ObjectParser parser + = new ObjectParser<>(MultiValuesSourceFieldConfig.NAME, MultiValuesSourceFieldConfig.Builder::new); - parser.declareString(ConstructingObjectParser.constructorArg(), ParseField.CommonFields.FIELD); - parser.declareField(MultiValuesSourceFieldConfig::setMissing, XContentParser::objectText, + parser.declareString(MultiValuesSourceFieldConfig.Builder::setFieldName, ParseField.CommonFields.FIELD); + parser.declareField(MultiValuesSourceFieldConfig.Builder::setMissing, XContentParser::objectText, ParseField.CommonFields.MISSING, ObjectParser.ValueType.VALUE); - 
parser.declareField(MultiValuesSourceFieldConfig::setMulti, p -> MultiValueMode.fromString(p.text()), MULTI, + parser.declareField(MultiValuesSourceFieldConfig.Builder::setMulti, p -> MultiValueMode.fromString(p.text()), MULTI, ObjectParser.ValueType.STRING); if (scriptable) { - parser.declareField(MultiValuesSourceFieldConfig::setScript, + parser.declareField(MultiValuesSourceFieldConfig.Builder::setScript, (p, context) -> Script.parse(p), Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING); } if (timezoneAware) { - parser.declareField(MultiValuesSourceFieldConfig::setTimeZone, p -> { + parser.declareField(MultiValuesSourceFieldConfig.Builder::setTimeZone, p -> { if (p.currentToken() == XContentParser.Token.VALUE_STRING) { return DateTimeZone.forID(p.text()); } else { @@ -56,9 +75,12 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentFragme return parser; }; - - public MultiValuesSourceFieldConfig(String fieldName) { + private MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, DateTimeZone timeZone, MultiValueMode multi) { this.fieldName = fieldName; + this.missing = missing; + this.script = script; + this.timeZone = timeZone; + this.multi = multi; } public MultiValuesSourceFieldConfig(StreamInput in) throws IOException { @@ -73,47 +95,23 @@ public Object getMissing() { return missing; } - public MultiValuesSourceFieldConfig setMissing(Object missing) { - this.missing = missing; - return this; - } - public Script getScript() { return script; } - public MultiValuesSourceFieldConfig setScript(Script script) { - this.script = script; - return this; - } - public DateTimeZone getTimeZone() { return timeZone; } - public MultiValuesSourceFieldConfig setTimeZone(DateTimeZone timeZone) { - this.timeZone = timeZone; - return this; - } - public String getFieldName() { return fieldName; } - public MultiValuesSourceFieldConfig setFieldName(String fieldName) { - this.fieldName = fieldName; - return this; - } 
public MultiValueMode getMulti() { return multi; } - public MultiValuesSourceFieldConfig setMulti(MultiValueMode multi) { - this.multi = multi; - return this; - } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(fieldName); @@ -142,4 +140,77 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } return builder; } + + public static class Builder { + private String fieldName; + private Object missing = null; + private Script script = null; + private DateTimeZone timeZone = null; + private MultiValueMode multi = MultiValueMode.AVG; + + public String getFieldName() { + return fieldName; + } + + public Builder setFieldName(String fieldName) { + this.fieldName = fieldName; + return this; + } + + public Object getMissing() { + return missing; + } + + public Builder setMissing(Object missing) { + this.missing = missing; + return this; + } + + public Script getScript() { + return script; + } + + public Builder setScript(Script script) { + this.script = script; + return this; + } + + public DateTimeZone getTimeZone() { + return timeZone; + } + + public Builder setTimeZone(DateTimeZone timeZone) { + this.timeZone = timeZone; + return this; + } + + public MultiValueMode getMulti() { + return multi; + } + + public Builder setMulti(MultiValueMode multi) { + this.multi = multi; + return this; + } + + public MultiValuesSourceFieldConfig build() { + if (Strings.isNullOrEmpty(fieldName) && script == null) { + throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName() + + "] and [" + Script.SCRIPT_PARSE_FIELD.getPreferredName() + "] cannot both be null. " + + "Please specify one or the other."); + } + + if (Strings.isNullOrEmpty(fieldName) == false && script != null) { + throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName() + + "] and [" + Script.SCRIPT_PARSE_FIELD.getPreferredName() + "] cannot both be configured. 
" + + "Please specify one or the other."); + } + + if (multi == null) { + throw new IllegalArgumentException("[" + MULTI.getPreferredName() + "] cannot be null"); + } + + return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone, multi); + } + } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java index d693190741b5d..4888495f9d8da 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java @@ -1,3 +1,22 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + package org.elasticsearch.search.aggregations.support; import org.elasticsearch.common.ParseField; @@ -33,7 +52,7 @@ public static void declareField(String fieldName, AbstractObjectParser, T> objectParser, boolean scriptable, boolean timezoneAware) { - objectParser.declareField((o, fieldConfig) -> o.field(fieldName, fieldConfig), + objectParser.declareField((o, fieldConfig) -> o.field(fieldName, fieldConfig.build()), (p, c) -> MultiValuesSourceFieldConfig.PARSER.apply(scriptable, timezoneAware).parse(p, null), new ParseField(fieldName), ObjectParser.ValueType.OBJECT); } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java index e264a879f0da8..539f760372a67 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java @@ -19,7 +19,6 @@ package org.elasticsearch.search.aggregations.metrics.weighted_avg; -import org.apache.lucene.document.Document; import org.apache.lucene.document.IntPoint; import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.SortedNumericDocValuesField; @@ -51,8 +50,8 @@ public class WeightedAvgAggregatorTests extends AggregatorTestCase { public void testNoDocs() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); 
WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -64,8 +63,8 @@ public void testNoDocs() throws IOException { } public void testNoMatchingField() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -78,8 +77,8 @@ public void testNoMatchingField() throws IOException { } public void testSomeMatchesSortedNumericDocValuesNoWeight() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -93,8 +92,8 @@ public void testSomeMatchesSortedNumericDocValuesNoWeight() throws IOException { } public void testSomeMatchesSortedNumericDocValuesWeights() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new 
MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -113,8 +112,8 @@ public void testSomeMatchesSortedNumericDocValuesWeights() throws IOException { } public void testSomeMatchesNumericDocValues() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -128,8 +127,8 @@ public void testSomeMatchesNumericDocValues() throws IOException { } public void testQueryFiltering() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -143,8 +142,8 @@ public void testQueryFiltering() throws IOException { } public void testQueryFilteringWeights() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new 
MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -162,8 +161,8 @@ public void testQueryFilteringWeights() throws IOException { } public void testQueryFiltersAll() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -177,8 +176,8 @@ public void testQueryFiltersAll() throws IOException { } public void testQueryFiltersAllWeights() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -195,8 +194,11 @@ public void 
testQueryFiltersAllWeights() throws IOException { } public void testValueSetMissing() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setMissing(2); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("value_field") + .setMissing(2) + .build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -211,8 +213,11 @@ public void testValueSetMissing() throws IOException { } public void testWeightSetMissing() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMissing(2); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("weight_field") + .setMissing(2) + .build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -227,8 +232,11 @@ public void testWeightSetMissing() throws IOException { } public void testWeightSetTimezone() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setTimeZone(DateTimeZone.UTC); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = 
new MultiValuesSourceFieldConfig.Builder() + .setFieldName("weight_field") + .setTimeZone(DateTimeZone.UTC) + .build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -245,8 +253,11 @@ public void testWeightSetTimezone() throws IOException { } public void testValueSetTimezone() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setTimeZone(DateTimeZone.UTC); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("value_field") + .setTimeZone(DateTimeZone.UTC) + .build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -263,8 +274,11 @@ public void testValueSetTimezone() throws IOException { } public void testValueSetMultiAvg() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setMulti(MultiValueMode.AVG); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("value_field") + .setMulti(MultiValueMode.AVG) + .build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -283,8 +297,11 @@ public void testValueSetMultiAvg() throws IOException { } public void testValueSetMultiMax() throws IOException { - MultiValuesSourceFieldConfig valueConfig 
= new MultiValuesSourceFieldConfig("value_field").setMulti(MultiValueMode.MAX); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("value_field") + .setMulti(MultiValueMode.MAX) + .build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -303,8 +320,11 @@ public void testValueSetMultiMax() throws IOException { } public void testValueSetMultiMin() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setMulti(MultiValueMode.MIN); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("value_field") + .setMulti(MultiValueMode.MIN) + .build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -323,8 +343,11 @@ public void testValueSetMultiMin() throws IOException { } public void testValueSetMultiSum() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field").setMulti(MultiValueMode.SUM); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("value_field") + .setMulti(MultiValueMode.SUM) + .build(); + MultiValuesSourceFieldConfig weightConfig = new 
MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -343,8 +366,11 @@ public void testValueSetMultiSum() throws IOException { } public void testWeightSetMultiAvg() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMulti(MultiValueMode.AVG); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("weight_field") + .setMulti(MultiValueMode.AVG) + .build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -367,8 +393,11 @@ public void testWeightSetMultiAvg() throws IOException { } public void testWeightSetMultiMax() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMulti(MultiValueMode.MAX); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("weight_field") + .setMulti(MultiValueMode.MAX) + .build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -390,8 +419,11 @@ public void testWeightSetMultiMax() throws IOException { } public void testWeightSetMultiMin() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new 
MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMulti(MultiValueMode.MIN); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("weight_field") + .setMulti(MultiValueMode.MIN) + .build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -413,8 +445,11 @@ public void testWeightSetMultiMin() throws IOException { } public void testWeightSetMultiSum() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field").setMulti(MultiValueMode.SUM); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("weight_field") + .setMulti(MultiValueMode.SUM) + .build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); @@ -467,8 +502,8 @@ public void testSummationAccuracy() throws IOException { } private void verifyAvgOfDoubles(double[] values, double expected, double delta) throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig("value_field"); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig("weight_field"); + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new 
MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java new file mode 100644 index 0000000000000..3e64175ac2295 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java @@ -0,0 +1,44 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.support; + +import org.elasticsearch.script.Script; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class MultiValuesSourceFieldConfigTests extends ESTestCase { + public void testMissingFieldScript() { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new MultiValuesSourceFieldConfig.Builder().build()); + assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be null. 
Please specify one or the other.")); + } + + public void testBothFieldScript() { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> new MultiValuesSourceFieldConfig.Builder().setFieldName("foo").setScript(new Script("foo")).build()); + assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be configured. Please specify one or the other.")); + } + + public void testNullMulti() { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> new MultiValuesSourceFieldConfig.Builder().setFieldName("foo").setMulti(null).build()); + assertThat(e.getMessage(), equalTo("[multi] cannot be null")); + } +} From b24507100a19473f07018c989f24be635fad9e05 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Tue, 19 Jun 2018 16:30:05 +0000 Subject: [PATCH 05/11] Add Documentation --- docs/build.gradle | 4 +- docs/reference/aggregations/metrics.asciidoc | 2 + .../metrics/weighted-avg-aggregation.asciidoc | 185 ++++++++++++++++++ 3 files changed, 189 insertions(+), 2 deletions(-) create mode 100644 docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc diff --git a/docs/build.gradle b/docs/build.gradle index f1d1324192b16..5913156f5b2f7 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -393,9 +393,9 @@ buildRestTests.setups['exams'] = ''' refresh: true body: | {"index":{}} - {"grade": 100} + {"grade": 100, "weight": 2} {"index":{}} - {"grade": 50}''' + {"grade": 50, "weight": 3}''' buildRestTests.setups['stored_example_script'] = ''' # Simple script to load a field. Not really a good example, but a simple one. 
diff --git a/docs/reference/aggregations/metrics.asciidoc b/docs/reference/aggregations/metrics.asciidoc index ae6bee2eb7d17..96597564dac2d 100644 --- a/docs/reference/aggregations/metrics.asciidoc +++ b/docs/reference/aggregations/metrics.asciidoc @@ -13,6 +13,8 @@ bucket aggregations (some bucket aggregations enable you to sort the returned bu include::metrics/avg-aggregation.asciidoc[] +include::metrics/weighted-avg-aggregation.asciidoc[] + include::metrics/cardinality-aggregation.asciidoc[] include::metrics/extendedstats-aggregation.asciidoc[] diff --git a/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc new file mode 100644 index 0000000000000..cec387365da38 --- /dev/null +++ b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc @@ -0,0 +1,185 @@ +[[search-aggregations-metrics-weight-avg-aggregation]] +=== Weighted Avg Aggregation + +A `single-value` metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents. +These values can be extracted either from specific numeric fields in the documents, or be generated by a provided script. + +When calculating a regular average, each datapoint has an equal "weight" ... it contributes equally to the final value. Weighted averages, +on the other hand, weight each datapoint differently. The amount that each datapoint contributes to the final value is extracted from the +document, or provided by a script. + +As a formula, a weighted average is `∑(value * weight) / ∑(weight)` + +A regular average can be thought of as a weighted average where every value has an implicit weight of `1`. 
+ + +.`weighted_avg` Parameters +|=== +|Parameter Name |Description |Required |Default Value +|`value` | The configuration for the field or script that provides the values |Required | +|`weight` | The configuration for the field or script that provides the weights |Required | +|`format` | The numeric response formatter |Optional | +|`value_type` | A hint about the values for pure scripts or unmapped fields |Optional | +|=== + +The `value` and `weight` objects have per-field specific configuration: + +.`value` Parameters +|=== +|Parameter Name |Description |Required |Default Value +|`field` | The field that values should be extracted from |Required | +|`missing` | A value to use if the field is missing entirely |Optional | +|`multi` | If a document has multiple values for the field, how should the values be combined |Optional | `avg` +|`script` | A script which provides the values for the document. This is mutually exclusive with `field` |Optional +|=== + +.`weight` Parameters +|=== +|Parameter Name |Description |Required |Default Value +|`field` | The field that weights should be extracted from |Required | +|`missing` | A weight to use if the field is missing entirely |Optional | +|`multi` | If a document has multiple values for the field, how should the values be combined |Optional | `avg` +|`script` | A script which provides the values for the document. 
This is mutually exclusive with `field` |Optional +|=== + + +==== Examples + +If our documents have a `"grade"` field that holds a 0-100 numeric score, and a `"weight"` field which holds an arbitrary numeric weight, +we can calculate the weighted average using: + +[source,js] +-------------------------------------------------- +POST /exams/_search +{ + "size": 0, + "aggs" : { + "weighted_grade": { + "weighted_avg": { + "value": { + "field": "grade" + }, + "weight": { + "field": "weight" + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST[setup:exams] + +Which yields a response like: + +[source,js] +-------------------------------------------------- +{ + ... + "aggregations": { + "weighted_grade": { + "value": 70.0 + } + } +} +-------------------------------------------------- +// TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/] + + +==== Script + +Both the value and the weight can be derived from a script, instead of a field. As a simple example, the following +will add one to the grade and weight in the document using a script: + +[source,js] +-------------------------------------------------- +POST /exams/_search +{ + "size": 0, + "aggs" : { + "weighted_grade": { + "weighted_avg": { + "value": { + "script": "doc.grade.value + 1" + }, + "weight": { + "script": "doc.weight.value + 1" + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST[setup:exams] + + +==== Missing values + +The `missing` parameter defines how documents that are missing a value should be treated. +The default behavior is different for `value` and `weight`: + +By default, if the `value` field is missing the document is ignored and the aggregation moves on to the next document. +If the `weight` field is missing, it is assumed to have a weight of `1` (like a normal average). 
+ +Both of these defaults can be overridden with the `missing` parameter: + +[source,js] +-------------------------------------------------- +POST /exams/_search +{ + "size": 0, + "aggs" : { + "weighted_grade": { + "weighted_avg": { + "value": { + "field": "grade", + "missing": 2 + }, + "weight": { + "field": "weight", + "missing": 3 + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST[setup:exams] + +==== Multi-value mode + +If a document has multiple values, you can configure the `multi` mode of both `value` and `weight`. This controls +how the multiple values should be combined when calculating the average. Acceptable values are: + +- `avg`: average the multiple values together +- `min`: use the minimum value +- `max`: use the maximum value +- `sum`: sum all the values together + +The default if unspecified is `avg`. + +[source,js] +-------------------------------------------------- +POST /exams/_search +{ + "size": 0, + "aggs" : { + "weighted_grade": { + "weighted_avg": { + "value": { + "field": "grade", + "multi": "avg" + }, + "weight": { + "field": "weight", + "multi": "min" + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST[setup:exams] From b7acb878639ef29f159bf0350dd4142807e0f7e2 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Tue, 19 Jun 2018 18:42:47 +0000 Subject: [PATCH 06/11] checkstyle --- .../support/ArrayValuesSourceAggregationBuilder.java | 5 +++-- .../support/ArrayValuesSourceAggregatorFactory.java | 6 ++++-- .../aggregations/support/ArrayValuesSourceParser.java | 6 ++++-- .../metrics/weighted_avg/ParsedWeightedAvg.java | 3 ++- .../search/aggregations/support/MultiValuesSource.java | 3 +-- .../support/MultiValuesSourceAggregatorFactory.java | 1 - .../search/aggregations/support/ValueType.java | 2 +- .../search/aggregations/support/ValuesSourceType.java | 2 +- 8 files changed, 16 insertions(+), 12 deletions(-) diff --git 
a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java index 39f5885e7c79c..eb8152e0fe0b8 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java @@ -256,8 +256,9 @@ protected Map> resolveConfig(SearchContext contex } protected abstract ArrayValuesSourceAggregatorFactory innerBuild(SearchContext context, - Map> configs, AggregatorFactory parent, - AggregatorFactories.Builder subFactoriesBuilder) throws IOException; + Map> configs, + AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException; public ValuesSourceConfig config(SearchContext context, String field, Script script) { diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java index cd6f0eb2b06bb..ce8eeecd19036 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java @@ -36,7 +36,8 @@ public abstract class ArrayValuesSourceAggregatorFactory> configs; public ArrayValuesSourceAggregatorFactory(String name, Map> configs, - SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, + SearchContext context, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder, Map metaData) throws 
IOException { super(name, context, parent, subFactoriesBuilder, metaData); this.configs = configs; @@ -63,6 +64,7 @@ protected abstract Aggregator createUnmapped(Aggregator parent, List metaData) throws IOException; protected abstract Aggregator doCreateInternal(Map valuesSources, Aggregator parent, boolean collectsFromSingleBucket, - List pipelineAggregators, Map metaData) throws IOException; + List pipelineAggregators, + Map metaData) throws IOException; } diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java index c2857411c0b39..1100884cf8ace 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java @@ -197,8 +197,10 @@ private void parseMissingAndAdd(final String aggregationName, final String curre * method * @return the created factory */ - protected abstract ArrayValuesSourceAggregationBuilder createFactory(String aggregationName, ValuesSourceType valuesSourceType, - ValueType targetValueType, Map otherOptions); + protected abstract ArrayValuesSourceAggregationBuilder createFactory(String aggregationName, + ValuesSourceType valuesSourceType, + ValueType targetValueType, + Map otherOptions); /** * Allows subclasses of {@link ArrayValuesSourceParser} to parse extra diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java index e558c8d6488e4..dcda79ce33e92 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java +++ 
b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java @@ -50,7 +50,8 @@ protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) return builder; } - private static final ObjectParser PARSER = new ObjectParser<>(ParsedWeightedAvg.class.getSimpleName(), true, ParsedWeightedAvg::new); + private static final ObjectParser PARSER + = new ObjectParser<>(ParsedWeightedAvg.class.getSimpleName(), true, ParsedWeightedAvg::new); static { declareSingleValueFields(PARSER, Double.POSITIVE_INFINITY); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java index 71117879e3a82..60654cc4955ad 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java @@ -27,7 +27,6 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; -import java.util.function.Predicate; /** * Class to encapsulate a set of ValuesSource objects labeled by field name @@ -104,7 +103,7 @@ public static class GeoPointValuesSource extends MultiValuesSource valuesSourceConfigs, QueryShardContext context) throws IOException { values = new HashMap<>(valuesSourceConfigs.getMap().size()); - for (Map.Entry> entry : valuesSourceConfigs.getMap().entrySet()) { + for (Map.Entry> entry : valuesSourceConfigs.getMap().entrySet()){ values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(), entry.getValue().getConfig().toValuesSource(context))); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java index 7f37e64e6ba17..a940fabe4f026 100644 --- 
a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java @@ -27,7 +27,6 @@ import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; -import java.util.HashMap; import java.util.List; import java.util.Map; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java index 5c4e65d2bd4df..7f6e76a6611a8 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java @@ -96,7 +96,7 @@ public boolean isNumeric() { private final byte id; private String preferredName; - public final static ParseField VALUE_TYPE = new ParseField("value_type", "valueType"); + public static final ParseField VALUE_TYPE = new ParseField("value_type", "valueType"); ValueType(byte id, String description, String preferredName, ValuesSourceType valuesSourceType, Class fieldDataType, DocValueFormat defaultFormat) { diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java index a68e24da81bf6..387e807ba861e 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java @@ -33,7 +33,7 @@ public enum ValuesSourceType implements Writeable { BYTES, GEOPOINT; - public final static ParseField VALUE_SOURCE_TYPE = new ParseField("value_source_type"); + public static final ParseField VALUE_SOURCE_TYPE = new ParseField("value_source_type"); public static ValuesSourceType fromString(String name) { return 
valueOf(name.trim().toUpperCase(Locale.ROOT)); From 7883d3977c1b8a8b21e6d6d846509ec549ad2577 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Wed, 20 Jun 2018 17:30:20 +0000 Subject: [PATCH 07/11] Review cleanup --- .../metrics/weighted-avg-aggregation.asciidoc | 4 ++-- .../weighted_avg/WeightedAvgAggregator.java | 12 ++++++------ .../MultiValuesSourceAggregationBuilder.java | 15 ++++++--------- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc index cec387365da38..78d9e80bf537a 100644 --- a/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc +++ b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc @@ -38,8 +38,8 @@ The `value` and `weight` objects have per-field specific configuration: |Parameter Name |Description |Required |Default Value |`field` | The field that weights should be extracted from |Required | |`missing` | A weight to use if the field is missing entirely |Optional | -|`multi` | If a document has multiple values for the field, how should the values be combined |Optional | `avg` -|`script` | A script which provides the values for the document. This is mutually exclusive with `field` |Optional +|`multi` | If a document has multiple weights for the field, how should the weights be combined |Optional | `avg` +|`script` | A script which provides the weights for the document. 
This is mutually exclusive with `field` |Optional |=== diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java index 40ea390cb1b6a..d5b7e991e931b 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java @@ -42,13 +42,13 @@ public class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue { - final MultiValuesSource.NumericMultiValuesSource valuesSources; + private final MultiValuesSource.NumericMultiValuesSource valuesSources; - DoubleArray weights; - DoubleArray sums; - DoubleArray sumCompensations; - DoubleArray weightCompensations; - DocValueFormat format; + private DoubleArray weights; + private DoubleArray sums; + private DoubleArray sumCompensations; + private DoubleArray weightCompensations; + private DocValueFormat format; public WeightedAvgAggregator(String name, MultiValuesSource.NumericMultiValuesSource valuesSources, DocValueFormat format, SearchContext context, Aggregator parent, List pipelineAggregators, diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java index e3ac1455eb541..6d00b0c378a85 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java @@ -35,6 +35,12 @@ import java.util.Map; import java.util.Objects; +/** + * Similar to {@link ValuesSourceAggregationBuilder}, except it references multiple ValuesSources (e.g. 
so that an aggregation + * can pull values from multiple fields). + * + * A limitation of this class is that all the ValuesSources being referenced must be of the same type. + */ public abstract class MultiValuesSourceAggregationBuilder> extends AbstractAggregationBuilder { @@ -61,15 +67,6 @@ protected LeafOnly(StreamInput in, ValueType targetValueType) throws IOException super(in, targetValueType); } - /** - * Read an aggregation from a stream that serializes its targetValueType. This should only be used by subclasses that override - * {@link #serializeTargetValueType()} to return true. - */ - /* - protected LeafOnly(StreamInput in, ValuesSourceType valuesSourceType) throws IOException { - super(in, valuesSourceType); - } - */ @Override public AB subAggregations(Builder subFactories) { throw new AggregationInitializationException("Aggregator [" + name + "] of type [" + From 3baf06fb1941f1998920ecae023c7bbd334bb80b Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Fri, 6 Jul 2018 13:46:08 -0400 Subject: [PATCH 08/11] Remove unnecessary stream ctor, extend LeafOnly Fixes an issue where assertions were being tripped on REST tests due to using the wrong stream ctor --- .../weighted_avg/WeightedAvgAggregationBuilder.java | 4 ++-- .../support/MultiValuesSourceAggregationBuilder.java | 7 ------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java index cedfcb0faeb24..5e3bd31c938cf 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java @@ -43,7 +43,7 @@ import java.util.Map; import java.util.Objects; -public class WeightedAvgAggregationBuilder
extends MultiValuesSourceAggregationBuilder { +public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationBuilder.LeafOnly { public static final String NAME = "weighted_avg"; public static final ParseField VALUE_FIELD = new ParseField("value"); public static final ParseField WEIGHT_FIELD = new ParseField("weight"); @@ -84,7 +84,7 @@ public WeightedAvgAggregationBuilder weight(MultiValuesSourceFieldConfig weightC * Read from a stream. */ public WeightedAvgAggregationBuilder(StreamInput in) throws IOException { - super(in); + super(in, ValueType.NUMERIC); } @Override diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java index 6d00b0c378a85..840fff80bd70b 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java @@ -104,13 +104,6 @@ protected MultiValuesSourceAggregationBuilder(StreamInput in, ValueType targetVa read(in); } - protected MultiValuesSourceAggregationBuilder(StreamInput in) throws IOException { - super(in); - assert serializeTargetValueType() : "Wrong read constructor called for subclass that serializes its targetValueType"; - this.targetValueType = in.readOptionalWriteable(ValueType::readFromStream); - read(in); - } - /** * Read from a stream. 
*/ From 3959bb45537ff9c9876a3d56c3725101f40da5e0 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Mon, 9 Jul 2018 11:08:01 -0400 Subject: [PATCH 09/11] Add clarification comment --- .../metrics/weighted_avg/InternalWeightedAvg.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java index 3a19de46325a9..9ad1a1df78aec 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java @@ -97,6 +97,8 @@ public InternalWeightedAvg doReduce(List aggregations, Redu // accurate than naive summation. for (InternalAggregation aggregation : aggregations) { InternalWeightedAvg avg = (InternalWeightedAvg) aggregation; + // If the weight is Inf or NaN, just add it to the running tally to "convert" to + // Inf/NaN. This keeps the behavior bwc from before kahan summing if (Double.isFinite(avg.weight) == false) { weight += avg.weight; } else if (Double.isFinite(weight)) { @@ -105,6 +107,8 @@ public InternalWeightedAvg doReduce(List aggregations, Redu weightCompensation = (newWeight - weight) - corrected; weight = newWeight; } + // If the sum is Inf or NaN, just add it to the running tally to "convert" to + // Inf/NaN.
This keeps the behavior bwc from before kahan summing if (Double.isFinite(avg.sum) == false) { sum += avg.sum; } else if (Double.isFinite(sum)) { From afbfee700985aadb56eadd7ac2eb5133b4d70f60 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Tue, 17 Jul 2018 15:57:26 -0400 Subject: [PATCH 10/11] Remove MultiValueMode options and MultiValuesSourceConfig --- .../metrics/weighted-avg-aggregation.asciidoc | 47 +--- .../WeightedAvgAggregationBuilder.java | 10 +- .../weighted_avg/WeightedAvgAggregator.java | 31 ++- .../WeightedAvgAggregatorFactory.java | 13 +- .../support/MultiValuesSource.java | 80 ++---- .../MultiValuesSourceAggregationBuilder.java | 6 +- .../MultiValuesSourceAggregatorFactory.java | 6 +- .../support/MultiValuesSourceConfig.java | 56 ---- .../support/MultiValuesSourceFieldConfig.java | 34 +-- .../WeightedAvgAggregatorTests.java | 240 ++++-------------- .../MultiValuesSourceFieldConfigTests.java | 6 - 11 files changed, 130 insertions(+), 399 deletions(-) delete mode 100644 server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java diff --git a/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc index 78d9e80bf537a..4bc8b5f234c94 100644 --- a/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc +++ b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc @@ -12,6 +12,16 @@ As a formula, a weighted average is the `∑(value * weight) / ∑(weight)` A regular average can be thought of as a weighted average where every value has an implicit weight of `1`. +[NOTE] +====== +While multiple values-per-field are allowed, only one weight is allowed. If the aggregation encounters +a document that has more than one weight (e.g. the weight field is a multi-valued field) it will throw an exception. 
+If you have this situation, you will need to specify a `script` for the weight field, and use the script +to combine the multiple values into a single value to be used. + +This single weight will be applied independently to each value extracted from the `value` field. +====== + .`weighted_avg` Parameters |=== @@ -29,7 +39,6 @@ The `value` and `weight` objects have per-field specific configuration: |Parameter Name |Description |Required |Default Value |`field` | The field that values should be extracted from |Required | |`missing` | A value to use if the field is missing entirely |Optional | -|`multi` | If a document has multiple values for the field, how should the values be combined |Optional | `avg` |`script` | A script which provides the values for the document. This is mutually exclusive with `field` |Optional |=== @@ -38,7 +47,6 @@ The `value` and `weight` objects have per-field specific configuration: |Parameter Name |Description |Required |Default Value |`field` | The field that weights should be extracted from |Required | |`missing` | A weight to use if the field is missing entirely |Optional | -|`multi` | If a document has multiple weights for the field, how should the weights be combined |Optional | `avg` |`script` | A script which provides the weights for the document. This is mutually exclusive with `field` |Optional |=== @@ -148,38 +156,3 @@ POST /exams/_search // CONSOLE // TEST[setup:exams] -==== Multi-value mode - -If a document has multiple values, you can configure the `multi` mode of both `value` and `weight`. This controls -how the multiple values should be combined when calculating the average. Acceptable values are: - -- `avg`: average the multiple values together -- `min`: use the minimum value -- `max`: use the maximum value -- `sum`: sum all the values together - -The default if unspecified is `avg`. 
- -[source,js] --------------------------------------------------- -POST /exams/_search -{ - "size": 0, - "aggs" : { - "weighted_grade": { - "weighted_avg": { - "value": { - "field": "grade", - "multi": "avg" - }, - "weight": { - "field": "weight", - "multi": "min" - } - } - } - } -} --------------------------------------------------- -// CONSOLE -// TEST[setup:exams] diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java index 5e3bd31c938cf..be06f792a5e89 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java @@ -32,11 +32,11 @@ import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder; import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory; -import org.elasticsearch.search.aggregations.support.MultiValuesSourceConfig; import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; import org.elasticsearch.search.aggregations.support.MultiValuesSourceParseHelper; import org.elasticsearch.search.aggregations.support.ValueType; import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric; +import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; @@ -99,10 +99,10 @@ protected void innerWriteTo(StreamOutput out) { @Override protected MultiValuesSourceAggregatorFactory innerBuild(SearchContext context, - MultiValuesSourceConfig configs, - DocValueFormat format, - AggregatorFactory parent, - Builder subFactoriesBuilder) throws IOException 
{ + Map> configs, + DocValueFormat format, + AggregatorFactory parent, + Builder subFactoriesBuilder) throws IOException { return new WeightedAvgAggregatorFactory(name, configs, format, context, parent, subFactoriesBuilder, metaData); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java index d5b7e991e931b..d9f683548817b 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java @@ -22,8 +22,9 @@ import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.DoubleArray; -import org.elasticsearch.index.fielddata.NumericDoubleValues; +import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.AggregationExecutionException; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.InternalAggregation; import org.elasticsearch.search.aggregations.LeafBucketCollector; @@ -77,8 +78,8 @@ public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, return LeafBucketCollector.NO_OP_COLLECTOR; } final BigArrays bigArrays = context.bigArrays(); - final NumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx); - final NumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), 1.0, ctx); + final SortedNumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx); + final SortedNumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), ctx); return new LeafBucketCollectorBase(sub, 
docValues) { @Override @@ -88,13 +89,23 @@ public void collect(int doc, long bucket) throws IOException { sumCompensations = bigArrays.grow(sumCompensations, bucket + 1); weightCompensations = bigArrays.grow(weightCompensations, bucket + 1); - if (docValues.advanceExact(doc)) { - boolean advanced = docWeights.advanceExact(doc); - assert advanced; - final double weight = docWeights.doubleValue(); - - kahanSum(docValues.doubleValue() * weight, sums, sumCompensations, bucket); - kahanSum(weight, weights, weightCompensations, bucket); + if (docValues.advanceExact(doc) && docWeights.advanceExact(doc)) { + if (docWeights.docValueCount() > 1) { + throw new AggregationExecutionException("Encountered more than one weight for a " + + "single document. Use a script to combine multiple weights-per-doc into a single value."); + } + // There should always be one weight if advanceExact lands us here, either + // a real weight or a `missing` value + assert docWeights.docValueCount() == 1; + final double weight = docWeights.nextValue(); + + final int numValues = docValues.docValueCount(); + assert numValues > 0; + + for (int i = 0; i < numValues; i++) { + kahanSum(docValues.nextValue() * weight, sums, sumCompensations, bucket); + kahanSum(weight, weights, weightCompensations, bucket); + } } } }; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java index 53e7d3e164dc7..c7aab73af2867 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java @@ -26,8 +26,8 @@ import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.support.MultiValuesSource; import 
org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory; -import org.elasticsearch.search.aggregations.support.MultiValuesSourceConfig; import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric; +import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; @@ -36,9 +36,9 @@ public class WeightedAvgAggregatorFactory extends MultiValuesSourceAggregatorFactory { - public WeightedAvgAggregatorFactory(String name, MultiValuesSourceConfig configs, - DocValueFormat format, - SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, + public WeightedAvgAggregatorFactory(String name, Map> configs, + DocValueFormat format, SearchContext context, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder, Map metaData) throws IOException { super(name, configs, format, context, parent, subFactoriesBuilder, metaData); } @@ -50,8 +50,9 @@ protected Aggregator createUnmapped(Aggregator parent, List } @Override - protected Aggregator doCreateInternal(MultiValuesSourceConfig configs, DocValueFormat format, Aggregator parent, - boolean collectsFromSingleBucket, List pipelineAggregators, + protected Aggregator doCreateInternal(Map> configs, DocValueFormat format, + Aggregator parent, boolean collectsFromSingleBucket, + List pipelineAggregators, Map metaData) throws IOException { MultiValuesSource.NumericMultiValuesSource numericMultiVS = new MultiValuesSource.NumericMultiValuesSource(configs, context.getQueryShardContext()); diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java index 60654cc4955ad..9ceecd75deaf0 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java +++ 
b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java @@ -19,100 +19,68 @@ package org.elasticsearch.search.aggregations.support; import org.apache.lucene.index.LeafReaderContext; -import org.elasticsearch.index.fielddata.FieldData; -import org.elasticsearch.index.fielddata.NumericDoubleValues; +import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; import org.elasticsearch.index.query.QueryShardContext; -import org.elasticsearch.search.MultiValueMode; import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.Objects; /** * Class to encapsulate a set of ValuesSource objects labeled by field name */ public abstract class MultiValuesSource { - - public static class Wrapper { - private MultiValueMode multiValueMode; - private VS valueSource; - - public Wrapper(MultiValueMode multiValueMode, VS value) { - this.multiValueMode = multiValueMode; - this.valueSource = value; - } - - public MultiValueMode getMultiValueMode() { - return multiValueMode; - } - - public VS getValueSource() { - return valueSource; - } - } - - protected Map> values; + protected Map values; public static class NumericMultiValuesSource extends MultiValuesSource { - public NumericMultiValuesSource(MultiValuesSourceConfig valuesSourceConfigs, + public NumericMultiValuesSource(Map> valuesSourceConfigs, QueryShardContext context) throws IOException { - values = new HashMap<>(valuesSourceConfigs.getMap().size()); - for (Map.Entry> entry : valuesSourceConfigs.getMap().entrySet()) { - values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(), - entry.getValue().getConfig().toValuesSource(context))); + values = new HashMap<>(valuesSourceConfigs.size()); + for (Map.Entry> entry : valuesSourceConfigs.entrySet()) { + values.put(entry.getKey(), entry.getValue().toValuesSource(context)); } } - public NumericDoubleValues getField(String fieldName, LeafReaderContext ctx) throws IOException { - Wrapper wrapper = 
values.get(fieldName); - if (wrapper == null) { + public SortedNumericDoubleValues getField(String fieldName, LeafReaderContext ctx) throws IOException { + ValuesSource.Numeric value = values.get(fieldName); + if (value == null) { throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource"); } - return wrapper.getMultiValueMode().select(wrapper.getValueSource().doubleValues(ctx)); - } - - public NumericDoubleValues getField(String fieldName, double defaultValue, LeafReaderContext ctx) throws IOException { - Wrapper wrapper = values.get(fieldName); - if (wrapper == null) { - throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource"); - } - return FieldData.replaceMissing(wrapper.getMultiValueMode().select(wrapper.getValueSource().doubleValues(ctx)), defaultValue); + return value.doubleValues(ctx); } } public static class BytesMultiValuesSource extends MultiValuesSource { - public BytesMultiValuesSource(MultiValuesSourceConfig valuesSourceConfigs, + public BytesMultiValuesSource(Map> valuesSourceConfigs, QueryShardContext context) throws IOException { - values = new HashMap<>(valuesSourceConfigs.getMap().size()); - for (Map.Entry> entry : valuesSourceConfigs.getMap().entrySet()) { - values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(), - entry.getValue().getConfig().toValuesSource(context))); + values = new HashMap<>(valuesSourceConfigs.size()); + for (Map.Entry> entry : valuesSourceConfigs.entrySet()) { + values.put(entry.getKey(), entry.getValue().toValuesSource(context)); } } public Object getField(String fieldName, LeafReaderContext ctx) throws IOException { - Wrapper wrapper = values.get(fieldName); - if (wrapper == null) { + ValuesSource.Bytes value = values.get(fieldName); + if (value == null) { throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource"); } - return wrapper.getValueSource().bytesValues(ctx); + 
return value.bytesValues(ctx); } } public static class GeoPointValuesSource extends MultiValuesSource { - public GeoPointValuesSource(MultiValuesSourceConfig valuesSourceConfigs, + public GeoPointValuesSource(Map> valuesSourceConfigs, QueryShardContext context) throws IOException { - values = new HashMap<>(valuesSourceConfigs.getMap().size()); - for (Map.Entry> entry : valuesSourceConfigs.getMap().entrySet()){ - values.put(entry.getKey(), new Wrapper<>(entry.getValue().getMulti(), - entry.getValue().getConfig().toValuesSource(context))); + values = new HashMap<>(valuesSourceConfigs.size()); + for (Map.Entry> entry : valuesSourceConfigs.entrySet()) { + values.put(entry.getKey(), entry.getValue().toValuesSource(context)); } } } - public boolean needsScores() { - return values.values().stream().anyMatch(vsWrapper -> vsWrapper.getValueSource().needsScores()); + return values.values().stream().anyMatch(ValuesSource::needsScores); } public String[] fieldNames() { @@ -120,6 +88,6 @@ public String[] fieldNames() { } public boolean areValuesSourcesEmpty() { - return values.values().stream().allMatch(vsWrapper -> vsWrapper.getValueSource() == null); + return values.values().stream().allMatch(Objects::isNull); } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java index 840fff80bd70b..fee685346ec98 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java @@ -186,11 +186,11 @@ public String format() { AggregatorFactories.Builder subFactoriesBuilder) throws IOException { ValueType finalValueType = this.valueType != null ? 
this.valueType : targetValueType; - MultiValuesSourceConfig configs = new MultiValuesSourceConfig<>(); + Map> configs = new HashMap<>(fields.size()); fields.forEach((key, value) -> { ValuesSourceConfig config = ValuesSourceConfig.resolve(context.getQueryShardContext(), finalValueType, value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format); - configs.addField(key, config, value.getMulti()); + configs.put(key, config); }); DocValueFormat docValueFormat = resolveFormat(format, finalValueType); return innerBuild(context, configs, docValueFormat, parent, subFactoriesBuilder); @@ -209,7 +209,7 @@ private static DocValueFormat resolveFormat(@Nullable String format, @Nullable V } protected abstract MultiValuesSourceAggregatorFactory innerBuild(SearchContext context, - MultiValuesSourceConfig configs, DocValueFormat format, AggregatorFactory parent, + Map> configs, DocValueFormat format, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder) throws IOException; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java index a940fabe4f026..5de8fbd7561dc 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java @@ -33,10 +33,10 @@ public abstract class MultiValuesSourceAggregatorFactory> extends AggregatorFactory { - protected final MultiValuesSourceConfig configs; + protected final Map> configs; protected final DocValueFormat format; - public MultiValuesSourceAggregatorFactory(String name, MultiValuesSourceConfig configs, + public MultiValuesSourceAggregatorFactory(String name, Map> configs, DocValueFormat format, SearchContext context, AggregatorFactory parent, 
AggregatorFactories.Builder subFactoriesBuilder, Map metaData) throws IOException { @@ -56,7 +56,7 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu protected abstract Aggregator createUnmapped(Aggregator parent, List pipelineAggregators, Map metaData) throws IOException; - protected abstract Aggregator doCreateInternal(MultiValuesSourceConfig configs, + protected abstract Aggregator doCreateInternal(Map> configs, DocValueFormat format, Aggregator parent, boolean collectsFromSingleBucket, List pipelineAggregators, Map metaData) throws IOException; diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java deleted file mode 100644 index ceee40ddde4f9..0000000000000 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceConfig.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.elasticsearch.search.aggregations.support; - -import org.elasticsearch.search.MultiValueMode; - -import java.util.HashMap; -import java.util.Map; - -public class MultiValuesSourceConfig { - private Map> map = new HashMap<>(); - - public static class Wrapper { - private MultiValueMode multi; - private ValuesSourceConfig config; - - public Wrapper(MultiValueMode multi, ValuesSourceConfig config) { - this.multi = multi; - this.config = config; - } - - public MultiValueMode getMulti() { - return multi; - } - - public ValuesSourceConfig getConfig() { - return config; - } - } - - public void addField(String fieldName, ValuesSourceConfig config, MultiValueMode multiValueMode) { - map.put(fieldName, new Wrapper<>(multiValueMode, config)); - } - - public Map> getMap() { - return map; - } - -} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java index 5cc350f8c623e..56ceae69ff78e 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java @@ -29,7 +29,6 @@ import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.script.Script; -import org.elasticsearch.search.MultiValueMode; import org.joda.time.DateTimeZone; import java.io.IOException; @@ -40,10 +39,8 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentFragme private Object missing; private Script script; private DateTimeZone timeZone; - private MultiValueMode multi; private static final String NAME = "field_config"; - private static final ParseField MULTI = new ParseField("multi"); public static final BiFunction> PARSER = (scriptable, timezoneAware) -> { @@ -54,8 +51,6 @@ public 
class MultiValuesSourceFieldConfig implements Writeable, ToXContentFragme parser.declareString(MultiValuesSourceFieldConfig.Builder::setFieldName, ParseField.CommonFields.FIELD); parser.declareField(MultiValuesSourceFieldConfig.Builder::setMissing, XContentParser::objectText, ParseField.CommonFields.MISSING, ObjectParser.ValueType.VALUE); - parser.declareField(MultiValuesSourceFieldConfig.Builder::setMulti, p -> MultiValueMode.fromString(p.text()), MULTI, - ObjectParser.ValueType.STRING); if (scriptable) { parser.declareField(MultiValuesSourceFieldConfig.Builder::setScript, @@ -75,12 +70,11 @@ public class MultiValuesSourceFieldConfig implements Writeable, ToXContentFragme return parser; }; - private MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, DateTimeZone timeZone, MultiValueMode multi) { + private MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, DateTimeZone timeZone) { this.fieldName = fieldName; this.missing = missing; this.script = script; this.timeZone = timeZone; - this.multi = multi; } public MultiValuesSourceFieldConfig(StreamInput in) throws IOException { @@ -88,7 +82,6 @@ public MultiValuesSourceFieldConfig(StreamInput in) throws IOException { this.missing = in.readGenericValue(); this.script = in.readOptionalWriteable(Script::new); this.timeZone = in.readOptionalTimeZone(); - this.multi = MultiValueMode.readMultiValueModeFrom(in); } public Object getMissing() { @@ -107,18 +100,12 @@ public String getFieldName() { return fieldName; } - - public MultiValueMode getMulti() { - return multi; - } - @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeGenericValue(missing); out.writeOptionalWriteable(script); out.writeOptionalTimeZone(timeZone); - multi.writeTo(out); } @Override @@ -135,9 +122,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (timeZone != null) { 
builder.field(ParseField.CommonFields.TIME_ZONE.getPreferredName(), timeZone); } - if (multi != null) { - builder.field(MULTI.getPreferredName(), multi); - } return builder; } @@ -146,7 +130,6 @@ public static class Builder { private Object missing = null; private Script script = null; private DateTimeZone timeZone = null; - private MultiValueMode multi = MultiValueMode.AVG; public String getFieldName() { return fieldName; @@ -184,15 +167,6 @@ public Builder setTimeZone(DateTimeZone timeZone) { return this; } - public MultiValueMode getMulti() { - return multi; - } - - public Builder setMulti(MultiValueMode multi) { - this.multi = multi; - return this; - } - public MultiValuesSourceFieldConfig build() { if (Strings.isNullOrEmpty(fieldName) && script == null) { throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName() @@ -206,11 +180,7 @@ public MultiValuesSourceFieldConfig build() { "Please specify one or the other."); } - if (multi == null) { - throw new IllegalArgumentException("[" + MULTI.getPreferredName() + "] cannot be null"); - } - - return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone, multi); + return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone); } } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java index 539f760372a67..70b1b651723e0 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java @@ -34,7 +34,7 @@ import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.NumberFieldMapper; -import 
org.elasticsearch.search.MultiValueMode; +import org.elasticsearch.search.aggregations.AggregationExecutionException; import org.elasticsearch.search.aggregations.AggregatorTestCase; import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; import org.joda.time.DateTimeZone; @@ -45,6 +45,7 @@ import java.util.function.Consumer; import static java.util.Collections.singleton; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; public class WeightedAvgAggregatorTests extends AggregatorTestCase { @@ -83,9 +84,12 @@ public void testSomeMatchesSortedNumericDocValuesNoWeight() throws IOException { .value(valueConfig) .weight(weightConfig); testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(singleton(new SortedNumericDocValuesField("value_field", 7))); - iw.addDocument(singleton(new SortedNumericDocValuesField("value_field", 2))); - iw.addDocument(singleton(new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); }, avg -> { assertEquals(4, avg.getValue(), 0); }); @@ -118,9 +122,12 @@ public void testSomeMatchesNumericDocValues() throws IOException { .value(valueConfig) .weight(weightConfig); testCase(new DocValuesFieldExistsQuery("value_field"), aggregationBuilder, iw -> { - iw.addDocument(singleton(new NumericDocValuesField("value_field", 7))); - iw.addDocument(singleton(new NumericDocValuesField("value_field", 2))); - iw.addDocument(singleton(new NumericDocValuesField("value_field", 3))); + iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", 7), + new 
SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); }, avg -> { assertEquals(4, avg.getValue(), 0); }); @@ -133,9 +140,12 @@ public void testQueryFiltering() throws IOException { .value(valueConfig) .weight(weightConfig); testCase(IntPoint.newRangeQuery("value_field", 0, 3), aggregationBuilder, iw -> { - iw.addDocument(Arrays.asList(new IntPoint("value_field", 7), new SortedNumericDocValuesField("value_field", 7))); - iw.addDocument(Arrays.asList(new IntPoint("value_field", 1), new SortedNumericDocValuesField("value_field", 2))); - iw.addDocument(Arrays.asList(new IntPoint("value_field", 3), new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Arrays.asList(new IntPoint("value_field", 7), new SortedNumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new IntPoint("value_field", 1), new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new IntPoint("value_field", 3), new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); }, avg -> { assertEquals(2.5, avg.getValue(), 0); }); @@ -243,9 +253,12 @@ public void testWeightSetTimezone() throws IOException { IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 2))); - iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 3))); - iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 4))); + 
iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 1))); }, avg -> { fail("Should not have executed test case"); })); @@ -264,19 +277,21 @@ public void testValueSetTimezone() throws IOException { IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 2))); - iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 3))); - iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 4))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 1))); }, avg -> { fail("Should not have executed test case"); })); assertThat(e.getMessage(), equalTo("Field [value_field] of type [long] does not support custom time zones")); } - public void testValueSetMultiAvg() throws IOException { + public void testMultiValues() throws IOException { MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() .setFieldName("value_field") - .setMulti(MultiValueMode.AVG) .build(); MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); WeightedAvgAggregationBuilder 
aggregationBuilder = new WeightedAvgAggregationBuilder("_name") @@ -285,190 +300,44 @@ public void testValueSetMultiAvg() throws IOException { testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), - new SortedNumericDocValuesField("value_field", 3))); + new SortedNumericDocValuesField("value_field", 3), new SortedNumericDocValuesField("weight_field", 1))); iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), - new SortedNumericDocValuesField("value_field", 4))); + new SortedNumericDocValuesField("value_field", 4), new SortedNumericDocValuesField("weight_field", 1))); iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), - new SortedNumericDocValuesField("value_field", 5))); + new SortedNumericDocValuesField("value_field", 5), new SortedNumericDocValuesField("weight_field", 1))); }, avg -> { double value = (((2.0+3.0)/2.0) + ((3.0+4.0)/2.0) + ((4.0+5.0)/2.0)) / (1.0+1.0+1.0); assertEquals(value, avg.getValue(), 0); }); } - public void testValueSetMultiMax() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() - .setFieldName("value_field") - .setMulti(MultiValueMode.MAX) - .build(); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); - WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") - .value(valueConfig) - .weight(weightConfig); - - testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), - new SortedNumericDocValuesField("value_field", 3))); - iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), - new SortedNumericDocValuesField("value_field", 4))); - iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), - new 
SortedNumericDocValuesField("value_field", 5))); - }, avg -> { - double value = (3.0 + 4.0 + 5.0) / (1.0+1.0+1.0); - assertEquals(value, avg.getValue(), 0); - }); - } - - public void testValueSetMultiMin() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() - .setFieldName("value_field") - .setMulti(MultiValueMode.MIN) - .build(); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); - WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") - .value(valueConfig) - .weight(weightConfig); - - testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), - new SortedNumericDocValuesField("value_field", 3))); - iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), - new SortedNumericDocValuesField("value_field", 4))); - iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), - new SortedNumericDocValuesField("value_field", 5))); - }, avg -> { - double value = (2.0 + 3.0 + 4.0) / (1.0+1.0+1.0); - assertEquals(value, avg.getValue(), 0); - }); - } - - public void testValueSetMultiSum() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() - .setFieldName("value_field") - .setMulti(MultiValueMode.SUM) - .build(); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); - WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") - .value(valueConfig) - .weight(weightConfig); - - testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), - new SortedNumericDocValuesField("value_field", 3))); - iw.addDocument(Arrays.asList(new 
SortedNumericDocValuesField("value_field", 3), - new SortedNumericDocValuesField("value_field", 4))); - iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), - new SortedNumericDocValuesField("value_field", 5))); - }, avg -> { - double value = (5.0 + 7.0 + 9.0) / (1.0+1.0+1.0); - assertEquals(value, avg.getValue(), 0); - }); - } - - public void testWeightSetMultiAvg() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() - .setFieldName("weight_field") - .setMulti(MultiValueMode.AVG) - .build(); - WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") - .value(valueConfig) - .weight(weightConfig); - - testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 2), - new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 3), - new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 4), - new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); - }, avg -> { - double value = ((2.0 * (2.0+3.0)/2.0) + (3.0 * (3.0+4.0)/2.0) + (4.0 * (4.0+5.0)/2.0)) - / ((2.0+3.0)/2.0 + (3.0+4.0)/2.0 + (4.0+5.0)/2.0); - assertEquals(value, avg.getValue(), 0); - }); - } - - public void testWeightSetMultiMax() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() - 
.setFieldName("weight_field") - .setMulti(MultiValueMode.MAX) - .build(); - WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") - .value(valueConfig) - .weight(weightConfig); - - testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 2), - new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 3), - new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 4), - new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); - }, avg -> { - double value = ((2.0 * 3.0) + (3.0 * 4.0) + (4.0 * 5.0)) / (3.0+4.0+5.0); - assertEquals(value, avg.getValue(), 0); - }); - } - - public void testWeightSetMultiMin() throws IOException { + public void testMultiWeight() throws IOException { MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() .setFieldName("weight_field") - .setMulti(MultiValueMode.MIN) .build(); WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") .value(valueConfig) .weight(weightConfig); - testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 2), - new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 3), - new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); - 
iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 4), - new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); - }, avg -> { - double value = ((2.0 * 2.0) + (3.0 * 3.0) + (4.0 * 4.0)) / (2.0+3.0+4.0); - assertEquals(value, avg.getValue(), 0); - }); + AggregationExecutionException e = expectThrows(AggregationExecutionException.class, + () -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); + }, avg -> { + fail("Should have thrown exception"); + })); + assertThat(e.getMessage(), containsString("Encountered more than one weight for a single document. 
" + + "Use a script to combine multiple weights-per-doc into a single value.")); } - public void testWeightSetMultiSum() throws IOException { - MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); - MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() - .setFieldName("weight_field") - .setMulti(MultiValueMode.SUM) - .build(); - WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") - .value(valueConfig) - .weight(weightConfig); - - testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 2), - new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 3), - new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); - iw.addDocument(Arrays.asList( - new SortedNumericDocValuesField("value_field", 4), - new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); - }, avg -> { - double value = ((2.0 * 5.0) + (3.0 * 7.0) + (4.0 * 9.0)) / (5.0+7.0+9.0); - assertEquals(value, avg.getValue(), 0); - }); - } public void testSummationAccuracy() throws IOException { // Summing up a normal array and expect an accurate value @@ -510,7 +379,8 @@ private void verifyAvgOfDoubles(double[] values, double expected, double delta) testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { for (double value : values) { - iw.addDocument(singleton(new NumericDocValuesField("value_field", NumericUtils.doubleToSortableLong(value)))); + iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", NumericUtils.doubleToSortableLong(value)), + new SortedNumericDocValuesField("weight_field", NumericUtils.doubleToSortableLong(1.0)))); } }, avg 
-> assertEquals(expected, avg.getValue(), delta), diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java index 3e64175ac2295..ac1c07a40490e 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java @@ -35,10 +35,4 @@ public void testBothFieldScript() { () -> new MultiValuesSourceFieldConfig.Builder().setFieldName("foo").setScript(new Script("foo")).build()); assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be configured. Please specify one or the other.")); } - - public void testNullMulti() { - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, - () -> new MultiValuesSourceFieldConfig.Builder().setFieldName("foo").setMulti(null).build()); - assertThat(e.getMessage(), equalTo("[multi] cannot be null")); - } } From 177ff7e04c978f65466c790bc85cdb27ffcb4d13 Mon Sep 17 00:00:00 2001 From: Zachary Tong Date: Mon, 23 Jul 2018 16:50:22 -0400 Subject: [PATCH 11/11] Add example to docs for multi-valued fields, fix comment typo --- .../metrics/weighted-avg-aggregation.asciidoc | 66 +++++++++++++++---- .../weighted_avg/WeightedAvgAggregator.java | 2 +- 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc index 4bc8b5f234c94..252728a6db367 100644 --- a/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc +++ b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc @@ -12,17 +12,6 @@ As a formula, a weighted average is the `∑(value * weight) / ∑(weight)` A regular average can be thought of as a weighted average where 
every value has an implicit weight of `1`. -[NOTE] -====== -While multiple values-per-field are allowed, only one weight is allowed. If the aggregation encounters -a document that has more than one weight (e.g. the weight field is a multi-valued field) it will throw an exception. -If you have this situation, you will need to specify a `script` for the weight field, and use the script -to combine the multiple values into a single value to be used. - -This single weight will be applied independently to each value extracted from the `value` field. -====== - - .`weighted_avg` Parameters |=== |Parameter Name |Description |Required |Default Value @@ -94,6 +83,61 @@ Which yields a response like: // TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/] +While multiple values-per-field are allowed, only one weight is allowed. If the aggregation encounters +a document that has more than one weight (e.g. the weight field is a multi-valued field) it will throw an exception. +If you have this situation, you will need to specify a `script` for the weight field, and use the script +to combine the multiple values into a single value to be used. + +This single weight will be applied independently to each value extracted from the `value` field. + +This example shows how a single document with multiple values will be averaged with a single weight: + +[source,js] +-------------------------------------------------- +POST /exams/_doc?refresh +{ + "grade": [1, 2, 3], + "weight": 2 +} + +POST /exams/_search +{ + "size": 0, + "aggs" : { + "weighted_grade": { + "weighted_avg": { + "value": { + "field": "grade" + }, + "weight": { + "field": "weight" + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST + +The three values (`1`, `2`, and `3`) will be included as independent values, all with the weight of `2`: + +[source,js] +-------------------------------------------------- +{ + ... 
+ "aggregations": { + "weighted_grade": { + "value": 2.0 + } + } +} +-------------------------------------------------- +// TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/] + +The aggregation returns `2.0` as the result, which matches what we would expect when calculating by hand: +`((1*2) + (2*2) + (3*2)) / (2+2+2) == 2` + ==== Script Both the value and the weight can be derived from a script, instead of a field. As a simple example, the following diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java index d9f683548817b..7a34fe6df4a68 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java @@ -95,7 +95,7 @@ public void collect(int doc, long bucket) throws IOException { "single document. Use a script to combine multiple weights-per-doc into a single value."); } // There should always be one weight if advanceExact lands us here, either - // a real weight or a `missing` value + // a real weight or a `missing` weight assert docWeights.docValueCount() == 1; final double weight = docWeights.nextValue();