diff --git a/docs/build.gradle b/docs/build.gradle index 829db4381b046..a67c0217490b3 100644 --- a/docs/build.gradle +++ b/docs/build.gradle @@ -379,9 +379,9 @@ buildRestTests.setups['exams'] = ''' refresh: true body: | {"index":{}} - {"grade": 100} + {"grade": 100, "weight": 2} {"index":{}} - {"grade": 50}''' + {"grade": 50, "weight": 3}''' buildRestTests.setups['stored_example_script'] = ''' # Simple script to load a field. Not really a good example, but a simple one. diff --git a/docs/reference/aggregations/metrics.asciidoc b/docs/reference/aggregations/metrics.asciidoc index ae6bee2eb7d17..96597564dac2d 100644 --- a/docs/reference/aggregations/metrics.asciidoc +++ b/docs/reference/aggregations/metrics.asciidoc @@ -13,6 +13,8 @@ bucket aggregations (some bucket aggregations enable you to sort the returned bu include::metrics/avg-aggregation.asciidoc[] +include::metrics/weighted-avg-aggregation.asciidoc[] + include::metrics/cardinality-aggregation.asciidoc[] include::metrics/extendedstats-aggregation.asciidoc[] diff --git a/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc new file mode 100644 index 0000000000000..252728a6db367 --- /dev/null +++ b/docs/reference/aggregations/metrics/weighted-avg-aggregation.asciidoc @@ -0,0 +1,202 @@ +[[search-aggregations-metrics-weight-avg-aggregation]] +=== Weighted Avg Aggregation + +A `single-value` metrics aggregation that computes the weighted average of numeric values that are extracted from the aggregated documents. +These values can be extracted either from specific numeric fields in the documents. + +When calculating a regular average, each datapoint has an equal "weight" ... it contributes equally to the final value. Weighted averages, +on the other hand, weight each datapoint differently. The amount that each datapoint contributes to the final value is extracted from the +document, or provided by a script. + +As a formula, a weighted average is the `∑(value * weight) / ∑(weight)` + +A regular average can be thought of as a weighted average where every value has an implicit weight of `1`. + +.`weighted_avg` Parameters +|=== +|Parameter Name |Description |Required |Default Value +|`value` | The configuration for the field or script that provides the values |Required | +|`weight` | The configuration for the field or script that provides the weights |Required | +|`format` | The numeric response formatter |Optional | +|`value_type` | A hint about the values for pure scripts or unmapped fields |Optional | +|=== + +The `value` and `weight` objects have per-field specific configuration: + +.`value` Parameters +|=== +|Parameter Name |Description |Required |Default Value +|`field` | The field that values should be extracted from |Required | +|`missing` | A value to use if the field is missing entirely |Optional | +|`script` | A script which provides the values for the document. This is mutually exclusive with `field` |Optional +|=== + +.`weight` Parameters +|=== +|Parameter Name |Description |Required |Default Value +|`field` | The field that weights should be extracted from |Required | +|`missing` | A weight to use if the field is missing entirely |Optional | +|`script` | A script which provides the weights for the document. This is mutually exclusive with `field` |Optional +|=== + + +==== Examples + +If our documents have a `"grade"` field that holds a 0-100 numeric score, and a `"weight"` field which holds an arbitrary numeric weight, +we can calculate the weighted average using: + +[source,js] +-------------------------------------------------- +POST /exams/_search +{ + "size": 0, + "aggs" : { + "weighted_grade": { + "weighted_avg": { + "value": { + "field": "grade" + }, + "weight": { + "field": "weight" + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST[setup:exams] + +Which yields a response like: + +[source,js] +-------------------------------------------------- +{ + ... + "aggregations": { + "weighted_grade": { + "value": 70.0 + } + } +} +-------------------------------------------------- +// TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/] + + +While multiple values-per-field are allowed, only one weight is allowed. If the aggregation encounters +a document that has more than one weight (e.g. the weight field is a multi-valued field) it will throw an exception. +If you have this situation, you will need to specify a `script` for the weight field, and use the script +to combine the multiple values into a single value to be used. + +This single weight will be applied independently to each value extracted from the `value` field. + +This example show how a single document with multiple values will be averaged with a single weight: + +[source,js] +-------------------------------------------------- +POST /exams/_doc?refresh +{ + "grade": [1, 2, 3], + "weight": 2 +} + +POST /exams/_search +{ + "size": 0, + "aggs" : { + "weighted_grade": { + "weighted_avg": { + "value": { + "field": "grade" + }, + "weight": { + "field": "weight" + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST + +The three values (`1`, `2`, and `3`) will be included as independent values, all with the weight of `2`: + +[source,js] +-------------------------------------------------- +{ + ... + "aggregations": { + "weighted_grade": { + "value": 2.0 + } + } +} +-------------------------------------------------- +// TESTRESPONSE[s/\.\.\./"took": $body.took,"timed_out": false,"_shards": $body._shards,"hits": $body.hits,/] + +The aggregation returns `2.0` as the result, which matches what we would expect when calculating by hand: +`((1*2) + (2*2) + (3*2)) / (2+2+2) == 2` + +==== Script + +Both the value and the weight can be derived from a script, instead of a field. As a simple example, the following +will add one to the grade and weight in the document using a script: + +[source,js] +-------------------------------------------------- +POST /exams/_search +{ + "size": 0, + "aggs" : { + "weighted_grade": { + "weighted_avg": { + "value": { + "script": "doc.grade.value + 1" + }, + "weight": { + "script": "doc.weight.value + 1" + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST[setup:exams] + + +==== Missing values + +The `missing` parameter defines how documents that are missing a value should be treated. +The default behavior is different for `value` and `weight`: + +By default, if the `value` field is missing the document is ignored and the aggregation moves on to the next document. +If the `weight` field is missing, it is assumed to have a weight of `1` (like a normal average). + +Both of these defaults can be overridden with the `missing` parameter: + +[source,js] +-------------------------------------------------- +POST /exams/_search +{ + "size": 0, + "aggs" : { + "weighted_grade": { + "weighted_avg": { + "value": { + "field": "grade", + "missing": 2 + }, + "weight": { + "field": "weight", + "missing": 3 + } + } + } + } +} +-------------------------------------------------- +// CONSOLE +// TEST[setup:exams] + diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregationBuilder.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregationBuilder.java index ad8ae1a681191..75f24eb78037a 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregationBuilder.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregationBuilder.java @@ -26,7 +26,7 @@ import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactory; -import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder; +import org.elasticsearch.search.aggregations.support.ArrayValuesSourceAggregationBuilder; import org.elasticsearch.search.aggregations.support.ValueType; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric; @@ -38,7 +38,7 @@ import java.util.Map; public class MatrixStatsAggregationBuilder - extends MultiValuesSourceAggregationBuilder.LeafOnly { + extends ArrayValuesSourceAggregationBuilder.LeafOnly { public static final String NAME = "matrix_stats"; private MultiValueMode multiValueMode = MultiValueMode.AVG; diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java index 578116d7b5eb2..aa19f62fedc4f 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java @@ -30,7 +30,7 @@ import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; import org.elasticsearch.search.aggregations.metrics.MetricsAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; -import org.elasticsearch.search.aggregations.support.MultiValuesSource.NumericMultiValuesSource; +import org.elasticsearch.search.aggregations.support.ArrayValuesSource.NumericArrayValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.internal.SearchContext; @@ -43,7 +43,7 @@ **/ final class MatrixStatsAggregator extends MetricsAggregator { /** Multiple ValuesSource with field names */ - private final NumericMultiValuesSource valuesSources; + private final NumericArrayValuesSource valuesSources; /** array of descriptive stats, per shard, needed to compute the correlation */ ObjectArray stats; @@ -53,7 +53,7 @@ final class MatrixStatsAggregator extends MetricsAggregator { Map metaData) throws IOException { super(name, context, parent, pipelineAggregators, metaData); if (valuesSources != null && !valuesSources.isEmpty()) { - this.valuesSources = new NumericMultiValuesSource(valuesSources, multiValueMode); + this.valuesSources = new NumericArrayValuesSource(valuesSources, multiValueMode); stats = context.bigArrays().newObjectArray(1); } else { this.valuesSources = null; diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java index 2c3ac82a0c1a8..fb456d75bb78b 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java @@ -23,7 +23,7 @@ import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.AggregatorFactory; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; -import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory; +import org.elasticsearch.search.aggregations.support.ArrayValuesSourceAggregatorFactory; import org.elasticsearch.search.aggregations.support.ValuesSource; import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; import org.elasticsearch.search.internal.SearchContext; @@ -33,7 +33,7 @@ import java.util.Map; final class MatrixStatsAggregatorFactory - extends MultiValuesSourceAggregatorFactory { + extends ArrayValuesSourceAggregatorFactory { private final MultiValueMode multiValueMode; diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsParser.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsParser.java index fd13037e8f922..0f48d1855ae3e 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsParser.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsParser.java @@ -21,14 +21,14 @@ import org.elasticsearch.common.ParseField; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.MultiValueMode; -import org.elasticsearch.search.aggregations.support.MultiValuesSourceParser.NumericValuesSourceParser; +import org.elasticsearch.search.aggregations.support.ArrayValuesSourceParser.NumericValuesSourceParser; import org.elasticsearch.search.aggregations.support.ValueType; import org.elasticsearch.search.aggregations.support.ValuesSourceType; import java.io.IOException; import java.util.Map; -import static org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder.MULTIVALUE_MODE_FIELD; +import static org.elasticsearch.search.aggregations.support.ArrayValuesSourceAggregationBuilder.MULTIVALUE_MODE_FIELD; public class MatrixStatsParser extends NumericValuesSourceParser { diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSource.java similarity index 87% rename from modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java rename to modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSource.java index 86d1836721f10..94bf68c7ae489 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSource.java @@ -28,13 +28,13 @@ /** * Class to encapsulate a set of ValuesSource objects labeled by field name */ -public abstract class MultiValuesSource { +public abstract class ArrayValuesSource { protected MultiValueMode multiValueMode; protected String[] names; protected VS[] values; - public static class NumericMultiValuesSource extends MultiValuesSource { - public NumericMultiValuesSource(Map valuesSources, MultiValueMode multiValueMode) { + public static class NumericArrayValuesSource extends ArrayValuesSource { + public NumericArrayValuesSource(Map valuesSources, MultiValueMode multiValueMode) { super(valuesSources, multiValueMode); if (valuesSources != null) { this.values = valuesSources.values().toArray(new ValuesSource.Numeric[0]); @@ -51,8 +51,8 @@ public NumericDoubleValues getField(final int ordinal, LeafReaderContext ctx) th } } - public static class BytesMultiValuesSource extends MultiValuesSource { - public BytesMultiValuesSource(Map valuesSources, MultiValueMode multiValueMode) { + public static class BytesArrayValuesSource extends ArrayValuesSource { + public BytesArrayValuesSource(Map valuesSources, MultiValueMode multiValueMode) { super(valuesSources, multiValueMode); this.values = valuesSources.values().toArray(new ValuesSource.Bytes[0]); } @@ -62,14 +62,14 @@ public Object getField(final int ordinal, LeafReaderContext ctx) throws IOExcept } } - public static class GeoPointValuesSource extends MultiValuesSource { + public static class GeoPointValuesSource extends ArrayValuesSource { public GeoPointValuesSource(Map valuesSources, MultiValueMode multiValueMode) { super(valuesSources, multiValueMode); this.values = valuesSources.values().toArray(new ValuesSource.GeoPoint[0]); } } - private MultiValuesSource(Map valuesSources, MultiValueMode multiValueMode) { + private ArrayValuesSource(Map valuesSources, MultiValueMode multiValueMode) { if (valuesSources != null) { this.names = valuesSources.keySet().toArray(new String[0]); } diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java similarity index 91% rename from modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java rename to modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java index 4cf497c9c02a5..eb8152e0fe0b8 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregationBuilder.java @@ -44,13 +44,13 @@ import java.util.Map; import java.util.Objects; -public abstract class MultiValuesSourceAggregationBuilder> - extends AbstractAggregationBuilder { +public abstract class ArrayValuesSourceAggregationBuilder> + extends AbstractAggregationBuilder { public static final ParseField MULTIVALUE_MODE_FIELD = new ParseField("mode"); - public abstract static class LeafOnly> - extends MultiValuesSourceAggregationBuilder { + public abstract static class LeafOnly> + extends ArrayValuesSourceAggregationBuilder { protected LeafOnly(String name, ValuesSourceType valuesSourceType, ValueType targetValueType) { super(name, valuesSourceType, targetValueType); @@ -94,7 +94,7 @@ public AB subAggregations(Builder subFactories) { private Object missing = null; private Map missingMap = Collections.emptyMap(); - protected MultiValuesSourceAggregationBuilder(String name, ValuesSourceType valuesSourceType, ValueType targetValueType) { + protected ArrayValuesSourceAggregationBuilder(String name, ValuesSourceType valuesSourceType, ValueType targetValueType) { super(name); if (valuesSourceType == null) { throw new IllegalArgumentException("[valuesSourceType] must not be null: [" + name + "]"); @@ -103,7 +103,7 @@ protected MultiValuesSourceAggregationBuilder(String name, ValuesSourceType valu this.targetValueType = targetValueType; } - protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilder clone, + protected ArrayValuesSourceAggregationBuilder(ArrayValuesSourceAggregationBuilder clone, Builder factoriesBuilder, Map metaData) { super(clone, factoriesBuilder, metaData); this.valuesSourceType = clone.valuesSourceType; @@ -115,7 +115,7 @@ protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilde this.missing = clone.missing; } - protected MultiValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType, ValueType targetValueType) + protected ArrayValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType, ValueType targetValueType) throws IOException { super(in); assert false == serializeTargetValueType() : "Wrong read constructor called for subclass that provides its targetValueType"; @@ -124,7 +124,7 @@ protected MultiValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType v read(in); } - protected MultiValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType) throws IOException { + protected ArrayValuesSourceAggregationBuilder(StreamInput in, ValuesSourceType valuesSourceType) throws IOException { super(in); assert serializeTargetValueType() : "Wrong read constructor called for subclass that serializes its targetValueType"; this.valuesSourceType = valuesSourceType; @@ -239,10 +239,10 @@ public Map missingMap() { } @Override - protected final MultiValuesSourceAggregatorFactory doBuild(SearchContext context, AggregatorFactory parent, - AggregatorFactories.Builder subFactoriesBuilder) throws IOException { + protected final ArrayValuesSourceAggregatorFactory doBuild(SearchContext context, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException { Map> configs = resolveConfig(context); - MultiValuesSourceAggregatorFactory factory = innerBuild(context, configs, parent, subFactoriesBuilder); + ArrayValuesSourceAggregatorFactory factory = innerBuild(context, configs, parent, subFactoriesBuilder); return factory; } @@ -255,9 +255,10 @@ protected Map> resolveConfig(SearchContext contex return configs; } - protected abstract MultiValuesSourceAggregatorFactory innerBuild(SearchContext context, - Map> configs, AggregatorFactory parent, - AggregatorFactories.Builder subFactoriesBuilder) throws IOException; + protected abstract ArrayValuesSourceAggregatorFactory innerBuild(SearchContext context, + Map> configs, + AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException; public ValuesSourceConfig config(SearchContext context, String field, Script script) { @@ -355,14 +356,14 @@ public final XContentBuilder internalXContent(XContentBuilder builder, Params pa @Override protected final int doHashCode() { return Objects.hash(fields, format, missing, targetValueType, valueType, valuesSourceType, - innerHashCode()); + innerHashCode()); } protected abstract int innerHashCode(); @Override protected final boolean doEquals(Object obj) { - MultiValuesSourceAggregationBuilder other = (MultiValuesSourceAggregationBuilder) obj; + ArrayValuesSourceAggregationBuilder other = (ArrayValuesSourceAggregationBuilder) obj; if (!Objects.equals(fields, other.fields)) return false; if (!Objects.equals(format, other.format)) diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java similarity index 78% rename from modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java rename to modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java index 7d5c56a571bbe..ce8eeecd19036 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceAggregatorFactory.java @@ -30,14 +30,15 @@ import java.util.List; import java.util.Map; -public abstract class MultiValuesSourceAggregatorFactory> - extends AggregatorFactory { +public abstract class ArrayValuesSourceAggregatorFactory> + extends AggregatorFactory { protected Map> configs; - public MultiValuesSourceAggregatorFactory(String name, Map> configs, - SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, - Map metaData) throws IOException { + public ArrayValuesSourceAggregatorFactory(String name, Map> configs, + SearchContext context, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder, + Map metaData) throws IOException { super(name, context, parent, subFactoriesBuilder, metaData); this.configs = configs; } @@ -63,6 +64,7 @@ protected abstract Aggregator createUnmapped(Aggregator parent, List metaData) throws IOException; protected abstract Aggregator doCreateInternal(Map valuesSources, Aggregator parent, boolean collectsFromSingleBucket, - List pipelineAggregators, Map metaData) throws IOException; + List pipelineAggregators, + Map metaData) throws IOException; } diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParser.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java similarity index 86% rename from modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParser.java rename to modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java index 22a90b552d920..1100884cf8ace 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParser.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/support/ArrayValuesSourceParser.java @@ -33,30 +33,30 @@ import java.util.List; import java.util.Map; -public abstract class MultiValuesSourceParser implements Aggregator.Parser { +public abstract class ArrayValuesSourceParser implements Aggregator.Parser { - public abstract static class AnyValuesSourceParser extends MultiValuesSourceParser { + public abstract static class AnyValuesSourceParser extends ArrayValuesSourceParser { protected AnyValuesSourceParser(boolean formattable) { super(formattable, ValuesSourceType.ANY, null); } } - public abstract static class NumericValuesSourceParser extends MultiValuesSourceParser { + public abstract static class NumericValuesSourceParser extends ArrayValuesSourceParser { protected NumericValuesSourceParser(boolean formattable) { super(formattable, ValuesSourceType.NUMERIC, ValueType.NUMERIC); } } - public abstract static class BytesValuesSourceParser extends MultiValuesSourceParser { + public abstract static class BytesValuesSourceParser extends ArrayValuesSourceParser { protected BytesValuesSourceParser(boolean formattable) { super(formattable, ValuesSourceType.BYTES, ValueType.STRING); } } - public abstract static class GeoPointValuesSourceParser extends MultiValuesSourceParser { + public abstract static class GeoPointValuesSourceParser extends ArrayValuesSourceParser { protected GeoPointValuesSourceParser(boolean formattable) { super(formattable, ValuesSourceType.GEOPOINT, ValueType.GEOPOINT); @@ -67,15 +67,15 @@ protected GeoPointValuesSourceParser(boolean formattable) { private ValuesSourceType valuesSourceType = null; private ValueType targetValueType = null; - private MultiValuesSourceParser(boolean formattable, ValuesSourceType valuesSourceType, ValueType targetValueType) { + private ArrayValuesSourceParser(boolean formattable, ValuesSourceType valuesSourceType, ValueType targetValueType) { this.valuesSourceType = valuesSourceType; this.targetValueType = targetValueType; this.formattable = formattable; } @Override - public final MultiValuesSourceAggregationBuilder parse(String aggregationName, XContentParser parser) - throws IOException { + public final ArrayValuesSourceAggregationBuilder parse(String aggregationName, XContentParser parser) + throws IOException { List fields = null; ValueType valueType = null; @@ -98,7 +98,7 @@ private MultiValuesSourceParser(boolean formattable, ValuesSourceType valuesSour "Multi-field aggregations do not support scripts."); } else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } else if (token == XContentParser.Token.START_OBJECT) { if (CommonFields.MISSING.match(currentFieldName, parser.getDeprecationHandler())) { @@ -113,7 +113,7 @@ private MultiValuesSourceParser(boolean formattable, ValuesSourceType valuesSour } else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } else if (token == XContentParser.Token.START_ARRAY) { if (Script.SCRIPT_PARSE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { @@ -127,21 +127,21 @@ private MultiValuesSourceParser(boolean formattable, ValuesSourceType valuesSour fields.add(parser.text()); } else { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } } else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } else if (!token(aggregationName, currentFieldName, token, parser, otherOptions)) { throw new ParsingException(parser.getTokenLocation(), - "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); + "Unexpected token " + token + " [" + currentFieldName + "] in [" + aggregationName + "]."); } } - MultiValuesSourceAggregationBuilder factory = createFactory(aggregationName, this.valuesSourceType, this.targetValueType, - otherOptions); + ArrayValuesSourceAggregationBuilder factory = createFactory(aggregationName, this.valuesSourceType, this.targetValueType, + otherOptions); if (fields != null) { factory.fields(fields); } @@ -182,7 +182,7 @@ private void parseMissingAndAdd(final String aggregationName, final String curre /** * Creates a {@link ValuesSourceAggregationBuilder} from the information * gathered by the subclass. Options parsed in - * {@link MultiValuesSourceParser} itself will be added to the factory + * {@link ArrayValuesSourceParser} itself will be added to the factory * after it has been returned by this method. * * @param aggregationName @@ -197,11 +197,13 @@ private void parseMissingAndAdd(final String aggregationName, final String curre * method * @return the created factory */ - protected abstract MultiValuesSourceAggregationBuilder createFactory(String aggregationName, ValuesSourceType valuesSourceType, - ValueType targetValueType, Map otherOptions); + protected abstract ArrayValuesSourceAggregationBuilder createFactory(String aggregationName, + ValuesSourceType valuesSourceType, + ValueType targetValueType, + Map otherOptions); /** - * Allows subclasses of {@link MultiValuesSourceParser} to parse extra + * Allows subclasses of {@link ArrayValuesSourceParser} to parse extra * parameters and store them in a {@link Map} which will later be passed to * {@link #createFactory(String, ValuesSourceType, ValueType, Map)}. * diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/260_weighted_avg.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/260_weighted_avg.yml new file mode 100644 index 0000000000000..bccc961923394 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search.aggregation/260_weighted_avg.yml @@ -0,0 +1,71 @@ +setup: + - do: + indices.create: + index: test_1 + body: + settings: + number_of_replicas: 0 + mappings: + doc: + properties: + int_field: + type : integer + double_field: + type : double + string_field: + type: keyword + + - do: + bulk: + refresh: true + body: + - index: + _index: test_1 + _type: doc + _id: 1 + - int_field: 1 + double_field: 1.0 + - index: + _index: test_1 + _type: doc + _id: 2 + - int_field: 2 + double_field: 2.0 + - index: + _index: test_1 + _type: doc + _id: 3 + - int_field: 3 + double_field: 3.0 + - index: + _index: test_1 + _type: doc + _id: 4 + - int_field: 4 + double_field: 4.0 + +--- +"Basic test": + + - do: + search: + body: + aggs: + the_int_avg: + weighted_avg: + value: + field: "int_field" + weight: + field: "int_field" + the_double_avg: + weighted_avg: + value: + field: "double_field" + weight: + field: "double_field" + + - match: { hits.total: 4 } + - length: { hits.hits: 4 } + - match: { aggregations.the_int_avg.value: 3.0 } + - match: { aggregations.the_double_avg.value: 3.0 } + diff --git a/server/src/main/java/org/elasticsearch/search/SearchModule.java b/server/src/main/java/org/elasticsearch/search/SearchModule.java index 199d2278bf76b..87494a71c8d4d 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchModule.java +++ b/server/src/main/java/org/elasticsearch/search/SearchModule.java @@ -179,6 +179,8 @@ import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.valuecount.InternalValueCount; import org.elasticsearch.search.aggregations.metrics.valuecount.ValueCountAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.weighted_avg.InternalWeightedAvg; +import org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder; import org.elasticsearch.search.aggregations.pipeline.InternalSimpleValue; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.pipeline.bucketmetrics.InternalBucketMetricValue; @@ -333,6 +335,8 @@ public ParseFieldRegistry getMovingAverageModel private void registerAggregations(List plugins) { registerAggregation(new AggregationSpec(AvgAggregationBuilder.NAME, AvgAggregationBuilder::new, AvgAggregationBuilder::parse) .addResultReader(InternalAvg::new)); + registerAggregation(new AggregationSpec(WeightedAvgAggregationBuilder.NAME, WeightedAvgAggregationBuilder::new, + WeightedAvgAggregationBuilder::parse).addResultReader(InternalWeightedAvg::new)); registerAggregation(new AggregationSpec(SumAggregationBuilder.NAME, SumAggregationBuilder::new, SumAggregationBuilder::parse) .addResultReader(InternalSum::new)); registerAggregation(new AggregationSpec(MinAggregationBuilder.NAME, MinAggregationBuilder::new, MinAggregationBuilder::parse) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilders.java b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilders.java index 26d8bb1a1bdf5..b4e416f4d7789 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilders.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/AggregationBuilders.java @@ -82,6 +82,7 @@ import org.elasticsearch.search.aggregations.metrics.tophits.TopHitsAggregationBuilder; import org.elasticsearch.search.aggregations.metrics.valuecount.ValueCount; import org.elasticsearch.search.aggregations.metrics.valuecount.ValueCountAggregationBuilder; +import org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder; import java.util.Map; @@ -107,6 +108,13 @@ public static AvgAggregationBuilder avg(String name) { return new AvgAggregationBuilder(name); } + /** + * Create a new {@link Avg} aggregation with the given name. + */ + public static WeightedAvgAggregationBuilder weightedAvg(String name) { + return new WeightedAvgAggregationBuilder(name); + } + /** * Create a new {@link Max} aggregation with the given name. */ diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java new file mode 100644 index 0000000000000..9ad1a1df78aec --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/InternalWeightedAvg.java @@ -0,0 +1,144 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class InternalWeightedAvg extends InternalNumericMetricsAggregation.SingleValue implements WeightedAvg { + private final double sum; + private final double weight; + + public InternalWeightedAvg(String name, double sum, double weight, DocValueFormat format, List pipelineAggregators, + Map metaData) { + super(name, pipelineAggregators, metaData); + this.sum = sum; + this.weight = weight; + this.format = format; + } + + /** + * Read from a stream. + */ + public InternalWeightedAvg(StreamInput in) throws IOException { + super(in); + format = in.readNamedWriteable(DocValueFormat.class); + sum = in.readDouble(); + weight = in.readDouble(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(format); + out.writeDouble(sum); + out.writeDouble(weight); + } + + @Override + public double value() { + return getValue(); + } + + @Override + public double getValue() { + return sum / weight; + } + + double getSum() { + return sum; + } + + double getWeight() { + return weight; + } + + DocValueFormat getFormatter() { + return format; + } + + @Override + public String getWriteableName() { + return WeightedAvgAggregationBuilder.NAME; + } + + @Override + public InternalWeightedAvg doReduce(List aggregations, ReduceContext reduceContext) { + double weight = 0; + double sum = 0; + double sumCompensation = 0; + double weightCompensation = 0; + // Compute the sum of double values with Kahan summation algorithm which is more + // accurate than naive summation. + for (InternalAggregation aggregation : aggregations) { + InternalWeightedAvg avg = (InternalWeightedAvg) aggregation; + // If the weight is Inf or NaN, just add it to the running tally to "convert" to + // Inf/NaN. This keeps the behavior bwc from before kahan summing + if (Double.isFinite(avg.weight) == false) { + weight += avg.weight; + } else if (Double.isFinite(weight)) { + double corrected = avg.weight - weightCompensation; + double newWeight = weight + corrected; + weightCompensation = (newWeight - weight) - corrected; + weight = newWeight; + } + // If the avg is Inf or NaN, just add it to the running tally to "convert" to + // Inf/NaN. This keeps the behavior bwc from before kahan summing + if (Double.isFinite(avg.sum) == false) { + sum += avg.sum; + } else if (Double.isFinite(sum)) { + double corrected = avg.sum - sumCompensation; + double newSum = sum + corrected; + sumCompensation = (newSum - sum) - corrected; + sum = newSum; + } + } + return new InternalWeightedAvg(getName(), sum, weight, format, pipelineAggregators(), getMetaData()); + } + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + builder.field(CommonFields.VALUE.getPreferredName(), weight != 0 ? getValue() : null); + if (weight != 0 && format != DocValueFormat.RAW) { + builder.field(CommonFields.VALUE_AS_STRING.getPreferredName(), format.format(getValue())); + } + return builder; + } + + @Override + protected int doHashCode() { + return Objects.hash(sum, weight, format.getWriteableName()); + } + + @Override + protected boolean doEquals(Object obj) { + InternalWeightedAvg other = (InternalWeightedAvg) obj; + return Objects.equals(sum, other.sum) && + Objects.equals(weight, other.weight) && + Objects.equals(format.getWriteableName(), other.format.getWriteableName()); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java new file mode 100644 index 0000000000000..dcda79ce33e92 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/ParsedWeightedAvg.java @@ -0,0 +1,65 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.aggregations.metrics.ParsedSingleValueNumericMetricsAggregation; + +import java.io.IOException; + +public class ParsedWeightedAvg extends ParsedSingleValueNumericMetricsAggregation implements WeightedAvg { + + @Override + public double getValue() { + return value(); + } + + @Override + public String getType() { + return WeightedAvgAggregationBuilder.NAME; + } + + @Override + protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { + // InternalWeightedAvg renders value only if the avg normalizer (count) is not 0. + // We parse back `null` as Double.POSITIVE_INFINITY so we check for that value here to get the same xContent output + boolean hasValue = value != Double.POSITIVE_INFINITY; + builder.field(CommonFields.VALUE.getPreferredName(), hasValue ? value : null); + if (hasValue && valueAsString != null) { + builder.field(CommonFields.VALUE_AS_STRING.getPreferredName(), valueAsString); + } + return builder; + } + + private static final ObjectParser PARSER + = new ObjectParser<>(ParsedWeightedAvg.class.getSimpleName(), true, ParsedWeightedAvg::new); + + static { + declareSingleValueFields(PARSER, Double.POSITIVE_INFINITY); + } + + public static ParsedWeightedAvg fromXContent(XContentParser parser, final String name) { + ParsedWeightedAvg avg = PARSER.apply(parser, null); + avg.setName(name); + return avg; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvg.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvg.java new file mode 100644 index 0000000000000..7af48f677c1f6 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvg.java @@ -0,0 +1,32 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregation; + +/** + * An aggregation that computes the average of the values in the current bucket. + */ +public interface WeightedAvg extends NumericMetricsAggregation.SingleValue { + + /** + * The average value. + */ + double getValue(); +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java new file mode 100644 index 0000000000000..be06f792a5e89 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregationBuilder.java @@ -0,0 +1,128 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.AggregationBuilder; +import org.elasticsearch.search.aggregations.AggregatorFactories.Builder; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregationBuilder; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceParseHelper; +import org.elasticsearch.search.aggregations.support.ValueType; +import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric; +import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class WeightedAvgAggregationBuilder extends MultiValuesSourceAggregationBuilder.LeafOnly { + public static final String NAME = "weighted_avg"; + public static final ParseField VALUE_FIELD = new ParseField("value"); + public static final ParseField WEIGHT_FIELD = new ParseField("weight"); + + private static final ObjectParser PARSER; + static { + PARSER = new ObjectParser<>(WeightedAvgAggregationBuilder.NAME); + MultiValuesSourceParseHelper.declareCommon(PARSER, true, ValueType.NUMERIC); + MultiValuesSourceParseHelper.declareField(VALUE_FIELD.getPreferredName(), PARSER, true, false); + MultiValuesSourceParseHelper.declareField(WEIGHT_FIELD.getPreferredName(), PARSER, true, false); + } + + public static AggregationBuilder parse(String aggregationName, XContentParser parser) throws IOException { + return PARSER.parse(parser, new WeightedAvgAggregationBuilder(aggregationName), null); + } + + public WeightedAvgAggregationBuilder(String name) { + super(name, ValueType.NUMERIC); + } + + public WeightedAvgAggregationBuilder(WeightedAvgAggregationBuilder clone, Builder factoriesBuilder, Map metaData) { + super(clone, factoriesBuilder, metaData); + } + + public WeightedAvgAggregationBuilder value(MultiValuesSourceFieldConfig valueConfig) { + valueConfig = Objects.requireNonNull(valueConfig, "Configuration for field [" + VALUE_FIELD + "] cannot be null"); + field(VALUE_FIELD.getPreferredName(), valueConfig); + return this; + } + + public WeightedAvgAggregationBuilder weight(MultiValuesSourceFieldConfig weightConfig) { + weightConfig = Objects.requireNonNull(weightConfig, "Configuration for field [" + WEIGHT_FIELD + "] cannot be null"); + field(WEIGHT_FIELD.getPreferredName(), weightConfig); + return this; + } + + /** + * Read from a stream. + */ + public WeightedAvgAggregationBuilder(StreamInput in) throws IOException { + super(in, ValueType.NUMERIC); + } + + @Override + protected AggregationBuilder shallowCopy(Builder factoriesBuilder, Map metaData) { + return new WeightedAvgAggregationBuilder(this, factoriesBuilder, metaData); + } + + @Override + protected void innerWriteTo(StreamOutput out) { + // Do nothing, no extra state to write to stream + } + + @Override + protected MultiValuesSourceAggregatorFactory innerBuild(SearchContext context, + Map> configs, + DocValueFormat format, + AggregatorFactory parent, + Builder subFactoriesBuilder) throws IOException { + return new WeightedAvgAggregatorFactory(name, configs, format, context, parent, subFactoriesBuilder, metaData); + } + + @Override + public XContentBuilder doXContentBody(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder; + } + + @Override + protected int innerHashCode() { + return 0; + } + + @Override + protected boolean innerEquals(Object obj) { + return true; + } + + @Override + public String getType() { + return NAME; + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java new file mode 100644 index 0000000000000..7a34fe6df4a68 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregator.java @@ -0,0 +1,158 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.DoubleArray; +import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.AggregationExecutionException; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.InternalAggregation; +import org.elasticsearch.search.aggregations.LeafBucketCollector; +import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregator; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.aggregations.support.MultiValuesSource; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder.VALUE_FIELD; +import static org.elasticsearch.search.aggregations.metrics.weighted_avg.WeightedAvgAggregationBuilder.WEIGHT_FIELD; + +public class WeightedAvgAggregator extends NumericMetricsAggregator.SingleValue { + + private final MultiValuesSource.NumericMultiValuesSource valuesSources; + + private DoubleArray weights; + private DoubleArray sums; + private DoubleArray sumCompensations; + private DoubleArray weightCompensations; + private DocValueFormat format; + + public WeightedAvgAggregator(String name, MultiValuesSource.NumericMultiValuesSource valuesSources, DocValueFormat format, + SearchContext context, Aggregator parent, List pipelineAggregators, + Map metaData) throws IOException { + super(name, context, parent, pipelineAggregators, metaData); + this.valuesSources = valuesSources; + this.format = format; + if (valuesSources != null) { + final BigArrays bigArrays = context.bigArrays(); + weights = bigArrays.newDoubleArray(1, true); + sums = bigArrays.newDoubleArray(1, true); + sumCompensations = bigArrays.newDoubleArray(1, true); + weightCompensations = bigArrays.newDoubleArray(1, true); + } + } + + @Override + public boolean needsScores() { + return valuesSources != null && valuesSources.needsScores(); + } + + @Override + public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, + final LeafBucketCollector sub) throws IOException { + if (valuesSources == null) { + return LeafBucketCollector.NO_OP_COLLECTOR; + } + final BigArrays bigArrays = context.bigArrays(); + final SortedNumericDoubleValues docValues = valuesSources.getField(VALUE_FIELD.getPreferredName(), ctx); + final SortedNumericDoubleValues docWeights = valuesSources.getField(WEIGHT_FIELD.getPreferredName(), ctx); + + return new LeafBucketCollectorBase(sub, docValues) { + @Override + public void collect(int doc, long bucket) throws IOException { + weights = bigArrays.grow(weights, bucket + 1); + sums = bigArrays.grow(sums, bucket + 1); + sumCompensations = bigArrays.grow(sumCompensations, bucket + 1); + weightCompensations = bigArrays.grow(weightCompensations, bucket + 1); + + if (docValues.advanceExact(doc) && docWeights.advanceExact(doc)) { + if (docWeights.docValueCount() > 1) { + throw new AggregationExecutionException("Encountered more than one weight for a " + + "single document. Use a script to combine multiple weights-per-doc into a single value."); + } + // There should always be one weight if advanceExact lands us here, either + // a real weight or a `missing` weight + assert docWeights.docValueCount() == 1; + final double weight = docWeights.nextValue(); + + final int numValues = docValues.docValueCount(); + assert numValues > 0; + + for (int i = 0; i < numValues; i++) { + kahanSum(docValues.nextValue() * weight, sums, sumCompensations, bucket); + kahanSum(weight, weights, weightCompensations, bucket); + } + } + } + }; + } + + private static void kahanSum(double value, DoubleArray values, DoubleArray compensations, long bucket) { + // Compute the sum of double values with Kahan summation algorithm which is more + // accurate than naive summation. + double sum = values.get(bucket); + double compensation = compensations.get(bucket); + + if (Double.isFinite(value) == false) { + sum += value; + } else if (Double.isFinite(sum)) { + double corrected = value - compensation; + double newSum = sum + corrected; + compensation = (newSum - sum) - corrected; + sum = newSum; + } + values.set(bucket, sum); + compensations.set(bucket, compensation); + } + + @Override + public double metric(long owningBucketOrd) { + if (valuesSources == null || owningBucketOrd >= sums.size()) { + return Double.NaN; + } + return sums.get(owningBucketOrd) / weights.get(owningBucketOrd); + } + + @Override + public InternalAggregation buildAggregation(long bucket) { + if (valuesSources == null || bucket >= sums.size()) { + return buildEmptyAggregation(); + } + return new InternalWeightedAvg(name, sums.get(bucket), weights.get(bucket), format, pipelineAggregators(), metaData()); + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return new InternalWeightedAvg(name, 0.0, 0L, format, pipelineAggregators(), metaData()); + } + + @Override + public void doClose() { + Releasables.close(weights, sums, sumCompensations, weightCompensations); + } + +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java new file mode 100644 index 0000000000000..c7aab73af2867 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.aggregations.support.MultiValuesSource; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory; +import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric; +import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public class WeightedAvgAggregatorFactory extends MultiValuesSourceAggregatorFactory { + + public WeightedAvgAggregatorFactory(String name, Map> configs, + DocValueFormat format, SearchContext context, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder, + Map metaData) throws IOException { + super(name, configs, format, context, parent, subFactoriesBuilder, metaData); + } + + @Override + protected Aggregator createUnmapped(Aggregator parent, List pipelineAggregators, Map metaData) + throws IOException { + return new WeightedAvgAggregator(name, null, format, context, parent, pipelineAggregators, metaData); + } + + @Override + protected Aggregator doCreateInternal(Map> configs, DocValueFormat format, + Aggregator parent, boolean collectsFromSingleBucket, + List pipelineAggregators, + Map metaData) throws IOException { + MultiValuesSource.NumericMultiValuesSource numericMultiVS + = new MultiValuesSource.NumericMultiValuesSource(configs, context.getQueryShardContext()); + if (numericMultiVS.areValuesSourcesEmpty()) { + return createUnmapped(parent, pipelineAggregators, metaData); + } + return new WeightedAvgAggregator(name, numericMultiVS, format, context, parent, pipelineAggregators, metaData); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java new file mode 100644 index 0000000000000..9ceecd75deaf0 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSource.java @@ -0,0 +1,93 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.search.aggregations.support; + +import org.apache.lucene.index.LeafReaderContext; +import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; +import org.elasticsearch.index.query.QueryShardContext; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Class to encapsulate a set of ValuesSource objects labeled by field name + */ +public abstract class MultiValuesSource { + protected Map values; + + public static class NumericMultiValuesSource extends MultiValuesSource { + public NumericMultiValuesSource(Map> valuesSourceConfigs, + QueryShardContext context) throws IOException { + values = new HashMap<>(valuesSourceConfigs.size()); + for (Map.Entry> entry : valuesSourceConfigs.entrySet()) { + values.put(entry.getKey(), entry.getValue().toValuesSource(context)); + } + } + + public SortedNumericDoubleValues getField(String fieldName, LeafReaderContext ctx) throws IOException { + ValuesSource.Numeric value = values.get(fieldName); + if (value == null) { + throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource"); + } + return value.doubleValues(ctx); + } + } + + public static class BytesMultiValuesSource extends MultiValuesSource { + public BytesMultiValuesSource(Map> valuesSourceConfigs, + QueryShardContext context) throws IOException { + values = new HashMap<>(valuesSourceConfigs.size()); + for (Map.Entry> entry : valuesSourceConfigs.entrySet()) { + values.put(entry.getKey(), entry.getValue().toValuesSource(context)); + } + } + + public Object getField(String fieldName, LeafReaderContext ctx) throws IOException { + ValuesSource.Bytes value = values.get(fieldName); + if (value == null) { + throw new IllegalArgumentException("Could not find field name [" + fieldName + "] in multiValuesSource"); + } + return value.bytesValues(ctx); + } + } + + public static class GeoPointValuesSource extends MultiValuesSource { + public GeoPointValuesSource(Map> valuesSourceConfigs, + QueryShardContext context) throws IOException { + values = new HashMap<>(valuesSourceConfigs.size()); + for (Map.Entry> entry : valuesSourceConfigs.entrySet()) { + values.put(entry.getKey(), entry.getValue().toValuesSource(context)); + } + } + } + + public boolean needsScores() { + return values.values().stream().anyMatch(ValuesSource::needsScores); + } + + public String[] fieldNames() { + return values.keySet().toArray(new String[0]); + } + + public boolean areValuesSourcesEmpty() { + return values.values().stream().allMatch(Objects::isNull); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java new file mode 100644 index 0000000000000..fee685346ec98 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregationBuilder.java @@ -0,0 +1,268 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.search.aggregations.support; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.AggregationInitializationException; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactories.Builder; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Similar to {@link ValuesSourceAggregationBuilder}, except it references multiple ValuesSources (e.g. so that an aggregation + * can pull values from multiple fields). + * + * A limitation of this class is that all the ValuesSource's being refereenced must be of the same type. + */ +public abstract class MultiValuesSourceAggregationBuilder> + extends AbstractAggregationBuilder { + + + public abstract static class LeafOnly> + extends MultiValuesSourceAggregationBuilder { + + protected LeafOnly(String name, ValueType targetValueType) { + super(name, targetValueType); + } + + protected LeafOnly(LeafOnly clone, Builder factoriesBuilder, Map metaData) { + super(clone, factoriesBuilder, metaData); + if (factoriesBuilder.count() > 0) { + throw new AggregationInitializationException("Aggregator [" + name + "] of type [" + + getType() + "] cannot accept sub-aggregations"); + } + } + + /** + * Read from a stream that does not serialize its targetValueType. This should be used by most subclasses. + */ + protected LeafOnly(StreamInput in, ValueType targetValueType) throws IOException { + super(in, targetValueType); + } + + @Override + public AB subAggregations(Builder subFactories) { + throw new AggregationInitializationException("Aggregator [" + name + "] of type [" + + getType() + "] cannot accept sub-aggregations"); + } + } + + + + private Map fields = new HashMap<>(); + private final ValueType targetValueType; + private ValueType valueType = null; + private String format = null; + + protected MultiValuesSourceAggregationBuilder(String name, ValueType targetValueType) { + super(name); + this.targetValueType = targetValueType; + } + + protected MultiValuesSourceAggregationBuilder(MultiValuesSourceAggregationBuilder clone, + Builder factoriesBuilder, Map metaData) { + super(clone, factoriesBuilder, metaData); + + this.fields = new HashMap<>(clone.fields); + this.targetValueType = clone.targetValueType; + this.valueType = clone.valueType; + this.format = clone.format; + } + + protected MultiValuesSourceAggregationBuilder(StreamInput in, ValueType targetValueType) + throws IOException { + super(in); + assert false == serializeTargetValueType() : "Wrong read constructor called for subclass that provides its targetValueType"; + this.targetValueType = targetValueType; + read(in); + } + + /** + * Read from a stream. + */ + @SuppressWarnings("unchecked") + private void read(StreamInput in) throws IOException { + fields = in.readMap(StreamInput::readString, MultiValuesSourceFieldConfig::new); + valueType = in.readOptionalWriteable(ValueType::readFromStream); + format = in.readOptionalString(); + } + + @Override + protected final void doWriteTo(StreamOutput out) throws IOException { + if (serializeTargetValueType()) { + out.writeOptionalWriteable(targetValueType); + } + out.writeMap(fields, StreamOutput::writeString, (o, value) -> value.writeTo(o)); + out.writeOptionalWriteable(valueType); + out.writeOptionalString(format); + innerWriteTo(out); + } + + /** + * Write subclass' state to the stream + */ + protected abstract void innerWriteTo(StreamOutput out) throws IOException; + + @SuppressWarnings("unchecked") + protected AB field(String propertyName, MultiValuesSourceFieldConfig config) { + if (config == null) { + throw new IllegalArgumentException("[config] must not be null: [" + name + "]"); + } + this.fields.put(propertyName, config); + return (AB) this; + } + + public Map fields() { + return fields; + } + + /** + * Sets the {@link ValueType} for the value produced by this aggregation + */ + @SuppressWarnings("unchecked") + public AB valueType(ValueType valueType) { + if (valueType == null) { + throw new IllegalArgumentException("[valueType] must not be null: [" + name + "]"); + } + this.valueType = valueType; + return (AB) this; + } + + /** + * Gets the {@link ValueType} for the value produced by this aggregation + */ + public ValueType valueType() { + return valueType; + } + + /** + * Sets the format to use for the output of the aggregation. + */ + @SuppressWarnings("unchecked") + public AB format(String format) { + if (format == null) { + throw new IllegalArgumentException("[format] must not be null: [" + name + "]"); + } + this.format = format; + return (AB) this; + } + + /** + * Gets the format to use for the output of the aggregation. + */ + public String format() { + return format; + } + + @Override + protected final MultiValuesSourceAggregatorFactory doBuild(SearchContext context, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException { + ValueType finalValueType = this.valueType != null ? this.valueType : targetValueType; + + Map> configs = new HashMap<>(fields.size()); + fields.forEach((key, value) -> { + ValuesSourceConfig config = ValuesSourceConfig.resolve(context.getQueryShardContext(), finalValueType, + value.getFieldName(), value.getScript(), value.getMissing(), value.getTimeZone(), format); + configs.put(key, config); + }); + DocValueFormat docValueFormat = resolveFormat(format, finalValueType); + return innerBuild(context, configs, docValueFormat, parent, subFactoriesBuilder); + } + + + private static DocValueFormat resolveFormat(@Nullable String format, @Nullable ValueType valueType) { + if (valueType == null) { + return DocValueFormat.RAW; // we can't figure it out + } + DocValueFormat valueFormat = valueType.defaultFormat; + if (valueFormat instanceof DocValueFormat.Decimal && format != null) { + valueFormat = new DocValueFormat.Decimal(format); + } + return valueFormat; + } + + protected abstract MultiValuesSourceAggregatorFactory innerBuild(SearchContext context, + Map> configs, DocValueFormat format, AggregatorFactory parent, + AggregatorFactories.Builder subFactoriesBuilder) throws IOException; + + + /** + * Should this builder serialize its targetValueType? Defaults to false. All subclasses that override this to true + * should use the three argument read constructor rather than the four argument version. + */ + protected boolean serializeTargetValueType() { + return false; + } + + @Override + public final XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (fields != null) { + builder.field(CommonFields.FIELDS.getPreferredName(), fields); + } + if (format != null) { + builder.field(CommonFields.FORMAT.getPreferredName(), format); + } + if (valueType != null) { + builder.field(CommonFields.VALUE_TYPE.getPreferredName(), valueType.getPreferredName()); + } + doXContentBody(builder, params); + builder.endObject(); + return builder; + } + + protected abstract XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException; + + @Override + protected final int doHashCode() { + return Objects.hash(fields, format, targetValueType, valueType, innerHashCode()); + } + + protected abstract int innerHashCode(); + + @Override + protected final boolean doEquals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + MultiValuesSourceAggregationBuilder that = (MultiValuesSourceAggregationBuilder) other; + + return Objects.equals(this.fields, that.fields) + && Objects.equals(this.format, that.format) + && Objects.equals(this.valueType, that.valueType); + } + + protected abstract boolean innerEquals(Object obj); +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java new file mode 100644 index 0000000000000..5de8fbd7561dc --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceAggregatorFactory.java @@ -0,0 +1,64 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.support; + +import org.elasticsearch.search.DocValueFormat; +import org.elasticsearch.search.aggregations.Aggregator; +import org.elasticsearch.search.aggregations.AggregatorFactories; +import org.elasticsearch.search.aggregations.AggregatorFactory; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.search.internal.SearchContext; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +public abstract class MultiValuesSourceAggregatorFactory> + extends AggregatorFactory { + + protected final Map> configs; + protected final DocValueFormat format; + + public MultiValuesSourceAggregatorFactory(String name, Map> configs, + DocValueFormat format, SearchContext context, + AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, + Map metaData) throws IOException { + super(name, context, parent, subFactoriesBuilder, metaData); + this.configs = configs; + this.format = format; + } + + @Override + public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBucket, List pipelineAggregators, + Map metaData) throws IOException { + + return doCreateInternal(configs, format, parent, collectsFromSingleBucket, + pipelineAggregators, metaData); + } + + protected abstract Aggregator createUnmapped(Aggregator parent, List pipelineAggregators, + Map metaData) throws IOException; + + protected abstract Aggregator doCreateInternal(Map> configs, + DocValueFormat format, Aggregator parent, boolean collectsFromSingleBucket, + List pipelineAggregators, + Map metaData) throws IOException; + +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java new file mode 100644 index 0000000000000..56ceae69ff78e --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfig.java @@ -0,0 +1,186 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.support; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.ToXContentFragment; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.script.Script; +import org.joda.time.DateTimeZone; + +import java.io.IOException; +import java.util.function.BiFunction; + +public class MultiValuesSourceFieldConfig implements Writeable, ToXContentFragment { + private String fieldName; + private Object missing; + private Script script; + private DateTimeZone timeZone; + + private static final String NAME = "field_config"; + + public static final BiFunction> PARSER + = (scriptable, timezoneAware) -> { + + ObjectParser parser + = new ObjectParser<>(MultiValuesSourceFieldConfig.NAME, MultiValuesSourceFieldConfig.Builder::new); + + parser.declareString(MultiValuesSourceFieldConfig.Builder::setFieldName, ParseField.CommonFields.FIELD); + parser.declareField(MultiValuesSourceFieldConfig.Builder::setMissing, XContentParser::objectText, + ParseField.CommonFields.MISSING, ObjectParser.ValueType.VALUE); + + if (scriptable) { + parser.declareField(MultiValuesSourceFieldConfig.Builder::setScript, + (p, context) -> Script.parse(p), + Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING); + } + + if (timezoneAware) { + parser.declareField(MultiValuesSourceFieldConfig.Builder::setTimeZone, p -> { + if (p.currentToken() == XContentParser.Token.VALUE_STRING) { + return DateTimeZone.forID(p.text()); + } else { + return DateTimeZone.forOffsetHours(p.intValue()); + } + }, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG); + } + return parser; + }; + + private MultiValuesSourceFieldConfig(String fieldName, Object missing, Script script, DateTimeZone timeZone) { + this.fieldName = fieldName; + this.missing = missing; + this.script = script; + this.timeZone = timeZone; + } + + public MultiValuesSourceFieldConfig(StreamInput in) throws IOException { + this.fieldName = in.readString(); + this.missing = in.readGenericValue(); + this.script = in.readOptionalWriteable(Script::new); + this.timeZone = in.readOptionalTimeZone(); + } + + public Object getMissing() { + return missing; + } + + public Script getScript() { + return script; + } + + public DateTimeZone getTimeZone() { + return timeZone; + } + + public String getFieldName() { + return fieldName; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(fieldName); + out.writeGenericValue(missing); + out.writeOptionalWriteable(script); + out.writeOptionalTimeZone(timeZone); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (missing != null) { + builder.field(ParseField.CommonFields.MISSING.getPreferredName(), missing); + } + if (script != null) { + builder.field(Script.SCRIPT_PARSE_FIELD.getPreferredName(), script); + } + if (fieldName != null) { + builder.field(ParseField.CommonFields.FIELD.getPreferredName(), fieldName); + } + if (timeZone != null) { + builder.field(ParseField.CommonFields.TIME_ZONE.getPreferredName(), timeZone); + } + return builder; + } + + public static class Builder { + private String fieldName; + private Object missing = null; + private Script script = null; + private DateTimeZone timeZone = null; + + public String getFieldName() { + return fieldName; + } + + public Builder setFieldName(String fieldName) { + this.fieldName = fieldName; + return this; + } + + public Object getMissing() { + return missing; + } + + public Builder setMissing(Object missing) { + this.missing = missing; + return this; + } + + public Script getScript() { + return script; + } + + public Builder setScript(Script script) { + this.script = script; + return this; + } + + public DateTimeZone getTimeZone() { + return timeZone; + } + + public Builder setTimeZone(DateTimeZone timeZone) { + this.timeZone = timeZone; + return this; + } + + public MultiValuesSourceFieldConfig build() { + if (Strings.isNullOrEmpty(fieldName) && script == null) { + throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName() + + "] and [" + Script.SCRIPT_PARSE_FIELD.getPreferredName() + "] cannot both be null. " + + "Please specify one or the other."); + } + + if (Strings.isNullOrEmpty(fieldName) == false && script != null) { + throw new IllegalArgumentException("[" + ParseField.CommonFields.FIELD.getPreferredName() + + "] and [" + Script.SCRIPT_PARSE_FIELD.getPreferredName() + "] cannot both be configured. " + + "Please specify one or the other."); + } + + return new MultiValuesSourceFieldConfig(fieldName, missing, script, timeZone); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java new file mode 100644 index 0000000000000..4888495f9d8da --- /dev/null +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceParseHelper.java @@ -0,0 +1,59 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.support; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.ParsingException; +import org.elasticsearch.common.xcontent.AbstractObjectParser; +import org.elasticsearch.common.xcontent.ObjectParser; +import org.elasticsearch.common.xcontent.XContentParser; + +public final class MultiValuesSourceParseHelper { + + public static void declareCommon( + AbstractObjectParser, T> objectParser, boolean formattable, + ValueType targetValueType) { + + objectParser.declareField(MultiValuesSourceAggregationBuilder::valueType, p -> { + ValueType valueType = ValueType.resolveForScript(p.text()); + if (targetValueType != null && valueType.isNotA(targetValueType)) { + throw new ParsingException(p.getTokenLocation(), + "Aggregation [" + objectParser.getName() + "] was configured with an incompatible value type [" + + valueType + "]. It can only work on value of type [" + + targetValueType + "]"); + } + return valueType; + }, ValueType.VALUE_TYPE, ObjectParser.ValueType.STRING); + + if (formattable) { + objectParser.declareField(MultiValuesSourceAggregationBuilder::format, XContentParser::text, + ParseField.CommonFields.FORMAT, ObjectParser.ValueType.STRING); + } + } + + public static void declareField(String fieldName, + AbstractObjectParser, T> objectParser, + boolean scriptable, boolean timezoneAware) { + + objectParser.declareField((o, fieldConfig) -> o.field(fieldName, fieldConfig.build()), + (p, c) -> MultiValuesSourceFieldConfig.PARSER.apply(scriptable, timezoneAware).parse(p, null), + new ParseField(fieldName), ObjectParser.ValueType.OBJECT); + } +} diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java index 318540e3e5806..7f6e76a6611a8 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValueType.java @@ -19,6 +19,7 @@ package org.elasticsearch.search.aggregations.support; +import org.elasticsearch.common.ParseField; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -95,6 +96,8 @@ public boolean isNumeric() { private final byte id; private String preferredName; + public static final ParseField VALUE_TYPE = new ParseField("value_type", "valueType"); + ValueType(byte id, String description, String preferredName, ValuesSourceType valuesSourceType, Class fieldDataType, DocValueFormat defaultFormat) { this.id = id; @@ -112,7 +115,7 @@ public String description() { public String getPreferredName() { return preferredName; } - + public ValuesSourceType getValuesSourceType() { return valuesSourceType; } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceParserHelper.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceParserHelper.java index 365233122c43e..fc0a2f3a9fefe 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceParserHelper.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceParserHelper.java @@ -28,7 +28,6 @@ import org.joda.time.DateTimeZone; public final class ValuesSourceParserHelper { - static final ParseField TIME_ZONE = new ParseField("time_zone"); private ValuesSourceParserHelper() {} // utility class, no instantiation @@ -62,10 +61,10 @@ private static void declareFields( objectParser.declareField(ValuesSourceAggregationBuilder::field, XContentParser::text, - new ParseField("field"), ObjectParser.ValueType.STRING); + ParseField.CommonFields.FIELD, ObjectParser.ValueType.STRING); objectParser.declareField(ValuesSourceAggregationBuilder::missing, XContentParser::objectText, - new ParseField("missing"), ObjectParser.ValueType.VALUE); + ParseField.CommonFields.MISSING, ObjectParser.ValueType.VALUE); objectParser.declareField(ValuesSourceAggregationBuilder::valueType, p -> { ValueType valueType = ValueType.resolveForScript(p.text()); @@ -76,11 +75,11 @@ private static void declareFields( + targetValueType + "]"); } return valueType; - }, new ParseField("value_type", "valueType"), ObjectParser.ValueType.STRING); + }, ValueType.VALUE_TYPE, ObjectParser.ValueType.STRING); if (formattable) { objectParser.declareField(ValuesSourceAggregationBuilder::format, XContentParser::text, - new ParseField("format"), ObjectParser.ValueType.STRING); + ParseField.CommonFields.FORMAT, ObjectParser.ValueType.STRING); } if (scriptable) { @@ -96,7 +95,7 @@ private static void declareFields( } else { return DateTimeZone.forOffsetHours(p.intValue()); } - }, TIME_ZONE, ObjectParser.ValueType.LONG); + }, ParseField.CommonFields.TIME_ZONE, ObjectParser.ValueType.LONG); } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java index a6b252d6903d9..387e807ba861e 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSourceType.java @@ -19,9 +19,37 @@ package org.elasticsearch.search.aggregations.support; -public enum ValuesSourceType { +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; +import java.util.Locale; + +public enum ValuesSourceType implements Writeable { ANY, NUMERIC, BYTES, GEOPOINT; + + public static final ParseField VALUE_SOURCE_TYPE = new ParseField("value_source_type"); + + public static ValuesSourceType fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + public static ValuesSourceType fromStream(StreamInput in) throws IOException { + return in.readEnum(ValuesSourceType.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + ValuesSourceType state = this; + out.writeEnum(state); + } + + public String value() { + return name().toLowerCase(Locale.ROOT); + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java new file mode 100644 index 0000000000000..70b1b651723e0 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/weighted_avg/WeightedAvgAggregatorTests.java @@ -0,0 +1,428 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.metrics.weighted_avg; + +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.DocValuesFieldExistsQuery; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.NumericUtils; +import org.elasticsearch.common.CheckedConsumer; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.search.aggregations.AggregationExecutionException; +import org.elasticsearch.search.aggregations.AggregatorTestCase; +import org.elasticsearch.search.aggregations.support.MultiValuesSourceFieldConfig; +import org.joda.time.DateTimeZone; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Consumer; + +import static java.util.Collections.singleton; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +public class WeightedAvgAggregatorTests extends AggregatorTestCase { + + public void testNoDocs() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + // Intentionally not writing any docs + }, avg -> { + assertEquals(Double.NaN, avg.getValue(), 0); + }); + } + + public void testNoMatchingField() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(singleton(new SortedNumericDocValuesField("wrong_number", 7))); + iw.addDocument(singleton(new SortedNumericDocValuesField("wrong_number", 3))); + }, avg -> { + assertEquals(Double.NaN, avg.getValue(), 0); + }); + } + + public void testSomeMatchesSortedNumericDocValuesNoWeight() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); + }, avg -> { + assertEquals(4, avg.getValue(), 0); + }); + } + + public void testSomeMatchesSortedNumericDocValuesWeights() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 2))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 3))); + + }, avg -> { + // (7*2 + 2*3 + 3*3) / (2+3+3) == 3.625 + assertEquals(3.625, avg.getValue(), 0); + }); + } + + public void testSomeMatchesNumericDocValues() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new DocValuesFieldExistsQuery("value_field"), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); + }, avg -> { + assertEquals(4, avg.getValue(), 0); + }); + } + + public void testQueryFiltering() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(IntPoint.newRangeQuery("value_field", 0, 3), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new IntPoint("value_field", 7), new SortedNumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new IntPoint("value_field", 1), new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new IntPoint("value_field", 3), new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); + }, avg -> { + assertEquals(2.5, avg.getValue(), 0); + }); + } + + public void testQueryFilteringWeights() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(IntPoint.newRangeQuery("filter_field", 0, 3), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new IntPoint("filter_field", 7), new SortedNumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 2))); + iw.addDocument(Arrays.asList(new IntPoint("filter_field", 2), new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList(new IntPoint("filter_field", 3), new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 4))); + }, avg -> { + double value = (2.0*3.0 + 3.0*4.0) / (3.0+4.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testQueryFiltersAll() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(IntPoint.newRangeQuery("value_field", -1, 0), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new IntPoint("value_field", 7), new SortedNumericDocValuesField("value_field", 7))); + iw.addDocument(Arrays.asList(new IntPoint("value_field", 1), new SortedNumericDocValuesField("value_field", 2))); + iw.addDocument(Arrays.asList(new IntPoint("value_field", 3), new SortedNumericDocValuesField("value_field", 7))); + }, avg -> { + assertEquals(Double.NaN, avg.getValue(), 0); + }); + } + + public void testQueryFiltersAllWeights() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(IntPoint.newRangeQuery("value_field", -1, 0), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new IntPoint("filter_field", 7), new SortedNumericDocValuesField("value_field", 7), + new SortedNumericDocValuesField("weight_field", 2))); + iw.addDocument(Arrays.asList(new IntPoint("filter_field", 2), new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList(new IntPoint("filter_field", 3), new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 4))); + }, avg -> { + assertEquals(Double.NaN, avg.getValue(), 0); + }); + } + + public void testValueSetMissing() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("value_field") + .setMissing(2) + .build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("weight_field", 2))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("weight_field", 4))); + }, avg -> { + double value = (2.0*2.0 + 2.0*3.0 + 2.0*4.0) / (2.0+3.0+4.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testWeightSetMissing() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("weight_field") + .setMissing(2) + .build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 2))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 3))); + iw.addDocument(Collections.singletonList(new SortedNumericDocValuesField("value_field", 4))); + }, avg -> { + double value = (2.0*2.0 + 3.0*2.0 + 4.0*2.0) / (2.0+2.0+2.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testWeightSetTimezone() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("weight_field") + .setTimeZone(DateTimeZone.UTC) + .build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 1))); + }, avg -> { + fail("Should not have executed test case"); + })); + assertThat(e.getMessage(), equalTo("Field [weight_field] of type [long] does not support custom time zones")); + } + + public void testValueSetTimezone() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("value_field") + .setTimeZone(DateTimeZone.UTC) + .build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 1))); + }, avg -> { + fail("Should not have executed test case"); + })); + assertThat(e.getMessage(), equalTo("Field [value_field] of type [long] does not support custom time zones")); + } + + public void testMultiValues() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("value_field") + .build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("value_field", 3), new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("value_field", 4), new SortedNumericDocValuesField("weight_field", 1))); + iw.addDocument(Arrays.asList(new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("value_field", 5), new SortedNumericDocValuesField("weight_field", 1))); + }, avg -> { + double value = (((2.0+3.0)/2.0) + ((3.0+4.0)/2.0) + ((4.0+5.0)/2.0)) / (1.0+1.0+1.0); + assertEquals(value, avg.getValue(), 0); + }); + } + + public void testMultiWeight() throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder() + .setFieldName("weight_field") + .build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + + AggregationExecutionException e = expectThrows(AggregationExecutionException.class, + () -> testCase(new MatchAllDocsQuery(), aggregationBuilder, iw -> { + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 2), + new SortedNumericDocValuesField("weight_field", 2), new SortedNumericDocValuesField("weight_field", 3))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 3), + new SortedNumericDocValuesField("weight_field", 3), new SortedNumericDocValuesField("weight_field", 4))); + iw.addDocument(Arrays.asList( + new SortedNumericDocValuesField("value_field", 4), + new SortedNumericDocValuesField("weight_field", 4), new SortedNumericDocValuesField("weight_field", 5))); + }, avg -> { + fail("Should have thrown exception"); + })); + assertThat(e.getMessage(), containsString("Encountered more than one weight for a single document. " + + "Use a script to combine multiple weights-per-doc into a single value.")); + } + + + public void testSummationAccuracy() throws IOException { + // Summing up a normal array and expect an accurate value + double[] values = new double[]{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.9, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7}; + verifyAvgOfDoubles(values, 0.9, 0d); + + // Summing up an array which contains NaN and infinities and expect a result same as naive summation + int n = randomIntBetween(5, 10); + values = new double[n]; + double sum = 0; + for (int i = 0; i < n; i++) { + values[i] = frequently() + ? randomFrom(Double.NaN, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY) + : randomDoubleBetween(Double.MIN_VALUE, Double.MAX_VALUE, true); + sum += values[i]; + } + verifyAvgOfDoubles(values, sum / n, 1e-10); + + // Summing up some big double values and expect infinity result + n = randomIntBetween(5, 10); + double[] largeValues = new double[n]; + for (int i = 0; i < n; i++) { + largeValues[i] = Double.MAX_VALUE; + } + verifyAvgOfDoubles(largeValues, Double.POSITIVE_INFINITY, 0d); + + for (int i = 0; i < n; i++) { + largeValues[i] = -Double.MAX_VALUE; + } + verifyAvgOfDoubles(largeValues, Double.NEGATIVE_INFINITY, 0d); + } + + private void verifyAvgOfDoubles(double[] values, double expected, double delta) throws IOException { + MultiValuesSourceFieldConfig valueConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("value_field").build(); + MultiValuesSourceFieldConfig weightConfig = new MultiValuesSourceFieldConfig.Builder().setFieldName("weight_field").build(); + WeightedAvgAggregationBuilder aggregationBuilder = new WeightedAvgAggregationBuilder("_name") + .value(valueConfig) + .weight(weightConfig); + testCase(new MatchAllDocsQuery(), aggregationBuilder, + iw -> { + for (double value : values) { + iw.addDocument(Arrays.asList(new NumericDocValuesField("value_field", NumericUtils.doubleToSortableLong(value)), + new SortedNumericDocValuesField("weight_field", NumericUtils.doubleToSortableLong(1.0)))); + } + }, + avg -> assertEquals(expected, avg.getValue(), delta), + NumberFieldMapper.NumberType.DOUBLE + ); + } + + private void testCase(Query query, WeightedAvgAggregationBuilder aggregationBuilder, + CheckedConsumer buildIndex, + Consumer verify) throws IOException { + testCase(query, aggregationBuilder, buildIndex, verify, NumberFieldMapper.NumberType.LONG); + } + + private void testCase(Query query, WeightedAvgAggregationBuilder aggregationBuilder, + CheckedConsumer buildIndex, + Consumer verify, + NumberFieldMapper.NumberType fieldNumberType) throws IOException { + + Directory directory = newDirectory(); + RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory); + buildIndex.accept(indexWriter); + indexWriter.close(); + IndexReader indexReader = DirectoryReader.open(directory); + IndexSearcher indexSearcher = newSearcher(indexReader, true, true); + + try { + MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(fieldNumberType); + fieldType.setName("value_field"); + fieldType.setHasDocValues(true); + + MappedFieldType fieldType2 = new NumberFieldMapper.NumberFieldType(fieldNumberType); + fieldType2.setName("weight_field"); + fieldType2.setHasDocValues(true); + + WeightedAvgAggregator aggregator = createAggregator(aggregationBuilder, indexSearcher, fieldType, fieldType2); + aggregator.preCollection(); + indexSearcher.search(query, aggregator); + aggregator.postCollection(); + verify.accept((InternalWeightedAvg) aggregator.buildAggregation(0L)); + } finally { + indexReader.close(); + directory.close(); + } + } +} diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java new file mode 100644 index 0000000000000..ac1c07a40490e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/aggregations/support/MultiValuesSourceFieldConfigTests.java @@ -0,0 +1,38 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.search.aggregations.support; + +import org.elasticsearch.script.Script; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; + +public class MultiValuesSourceFieldConfigTests extends ESTestCase { + public void testMissingFieldScript() { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> new MultiValuesSourceFieldConfig.Builder().build()); + assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be null. Please specify one or the other.")); + } + + public void testBothFieldScript() { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, + () -> new MultiValuesSourceFieldConfig.Builder().setFieldName("foo").setScript(new Script("foo")).build()); + assertThat(e.getMessage(), equalTo("[field] and [script] cannot both be configured. Please specify one or the other.")); + } +} diff --git a/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupRequestTranslator.java b/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupRequestTranslator.java index dc2fac776c6c0..538babf4fbced 100644 --- a/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupRequestTranslator.java +++ b/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/RollupRequestTranslator.java @@ -208,7 +208,7 @@ public static List translateAggregation(AggregationBuilder s private static List translateDateHistogram(DateHistogramAggregationBuilder source, List filterConditions, NamedWriteableRegistry registry) { - + return translateVSAggBuilder(source, filterConditions, registry, () -> { DateHistogramAggregationBuilder rolledDateHisto = new DateHistogramAggregationBuilder(source.getName());