From 0028563aace400e1af380cbf44d30bfb6aee4171 Mon Sep 17 00:00:00 2001 From: rationull Date: Tue, 3 Apr 2018 01:57:49 -0700 Subject: [PATCH] Pass through script params in scripted metric agg (#29154) * Pass script level params into scripted metric aggs (#28819) Now params that are passed at the script level and at the aggregation level are merged and can both be used in the aggregation scripts. If there are any conflicts, aggregation level params will win. This may be followed by another change detecting that case and throwing an exception to disallow such conflicts. * Disallow duplicate parameter names between scripted agg and script (#28819) If a scripted metric aggregation has aggregation params and script params which have the same name, throw an IllegalArgumentException when merging the parameter lists. --- .../ScriptedMetricAggregationBuilder.java | 19 ++++- .../ScriptedMetricAggregatorFactory.java | 50 +++++++++---- .../metrics/ScriptedMetricIT.java | 31 +++++--- .../ScriptedMetricAggregatorTests.java | 72 ++++++++++++++++++- 4 files changed, 146 insertions(+), 26 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java index 69ac175c419c8..c11c68f9b2524 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java @@ -37,6 +37,7 @@ import org.elasticsearch.search.internal.SearchContext; import java.io.IOException; +import java.util.Collections; import java.util.Map; import java.util.Objects; @@ -198,20 +199,34 @@ protected ScriptedMetricAggregatorFactory doBuild(SearchContext context, Aggrega Builder subfactoriesBuilder) throws IOException { QueryShardContext queryShardContext = context.getQueryShardContext(); + + // Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have + // access to them for the scripts it's given precompiled. + ExecutableScript.Factory executableInitScript; + Map initScriptParams; if (initScript != null) { executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT); + initScriptParams = initScript.getParams(); } else { executableInitScript = p -> null; + initScriptParams = Collections.emptyMap(); } + SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT); + Map mapScriptParams = mapScript.getParams(); + ExecutableScript.Factory executableCombineScript; + Map combineScriptParams; if (combineScript != null) { - executableCombineScript =queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT); + executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT); + combineScriptParams = combineScript.getParams(); } else { executableCombineScript = p -> null; + combineScriptParams = Collections.emptyMap(); } - return new ScriptedMetricAggregatorFactory(name, searchMapScript, executableInitScript, executableCombineScript, reduceScript, + return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams, + executableCombineScript, combineScriptParams, reduceScript, params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData); } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java index aa7de3e1ab6e1..0bc6a614e541f 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java @@ -35,28 +35,35 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; public class ScriptedMetricAggregatorFactory extends AggregatorFactory { private final SearchScript.Factory mapScript; + private final Map mapScriptParams; private final ExecutableScript.Factory combineScript; + private final Map combineScriptParams; private final Script reduceScript; - private final Map params; + private final Map aggParams; private final SearchLookup lookup; private final ExecutableScript.Factory initScript; + private final Map initScriptParams; - public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, ExecutableScript.Factory initScript, - ExecutableScript.Factory combineScript, Script reduceScript, Map params, + public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map mapScriptParams, + ExecutableScript.Factory initScript, Map initScriptParams, + ExecutableScript.Factory combineScript, Map combineScriptParams, + Script reduceScript, Map aggParams, SearchLookup lookup, SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactories, Map metaData) throws IOException { super(name, context, parent, subFactories, metaData); this.mapScript = mapScript; + this.mapScriptParams = mapScriptParams; this.initScript = initScript; + this.initScriptParams = initScriptParams; this.combineScript = combineScript; + this.combineScriptParams = combineScriptParams; this.reduceScript = reduceScript; this.lookup = lookup; - this.params = params; + this.aggParams = aggParams; } @Override @@ -65,26 +72,26 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu if (collectsFromSingleBucket == false) { return asMultiBucketAggregator(this, context, parent); } - Map params = this.params; - if (params != null) { - params = deepCopyParams(params, context); + Map aggParams = this.aggParams; + if (aggParams != null) { + aggParams = deepCopyParams(aggParams, context); } else { - params = new HashMap<>(); + aggParams = new HashMap<>(); } - if (params.containsKey("_agg") == false) { - params.put("_agg", new HashMap()); + if (aggParams.containsKey("_agg") == false) { + aggParams.put("_agg", new HashMap()); } - final ExecutableScript initScript = this.initScript.newInstance(params); - final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(params, lookup); - final ExecutableScript combineScript = this.combineScript.newInstance(params); + final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams)); + final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup); + final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams)); final Script reduceScript = deepCopyScript(this.reduceScript, context); if (initScript != null) { initScript.run(); } return new ScriptedMetricAggregator(name, mapScript, - combineScript, reduceScript, params, context, parent, + combineScript, reduceScript, aggParams, context, parent, pipelineAggregators, metaData); } @@ -128,5 +135,18 @@ private static T deepCopyParams(T original, SearchContext context) { return clone; } + private static Map mergeParams(Map agg, Map script) { + // Start with script params + Map combined = new HashMap<>(script); + // Add in agg params, throwing an exception if any conflicts are detected + for (Map.Entry aggEntry : agg.entrySet()) { + if (combined.putIfAbsent(aggEntry.getKey(), aggEntry.getValue()) != null) { + throw new IllegalArgumentException("Parameter name \"" + aggEntry.getKey() + + "\" used in both aggregation and script parameters"); + } + } + + return combined; + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java index 24d94d5a4643c..9db5b237a858c 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java @@ -20,6 +20,8 @@ package org.elasticsearch.search.aggregations.metrics; import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.search.SearchPhaseExecutionException; +import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; @@ -62,6 +64,7 @@ import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse; import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; @@ -322,11 +325,11 @@ public void testMap() { assertThat(numShardsRun, greaterThan(0)); } - public void testMapWithParams() { + public void testExplicitAggParam() { Map params = new HashMap<>(); params.put("_agg", new ArrayList<>()); - Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params); + Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", Collections.emptyMap()); SearchResponse response = client().prepareSearch("idx") .setQuery(matchAllQuery()) @@ -361,17 +364,17 @@ public void testMapWithParams() { } public void testMapWithParamsAndImplicitAggMap() { - Map params = new HashMap<>(); - // don't put any _agg map in params - params.put("param1", "12"); - params.put("param2", 1); + // Split the params up between the script and the aggregation. + // Don't put any _agg map in params. + Map scriptParams = Collections.singletonMap("param1", "12"); + Map aggregationParams = Collections.singletonMap("param2", 1); // The _agg hashmap will be available even if not declared in the params map - Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", params); + Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", scriptParams); SearchResponse response = client().prepareSearch("idx") .setQuery(matchAllQuery()) - .addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript)) + .addAggregation(scriptedMetric("scripted").params(aggregationParams).mapScript(mapScript)) .get(); assertSearchResponse(response); assertThat(response.getHits().getTotalHits(), equalTo(numDocs)); @@ -1001,4 +1004,16 @@ public void testDontCacheScripts() throws Exception { assertThat(client().admin().indices().prepareStats("cache_test_idx").setRequestCache(true).get().getTotal().getRequestCache() .getMissCount(), equalTo(0L)); } + + public void testConflictingAggAndScriptParams() { + Map params = Collections.singletonMap("param1", "12"); + Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params); + + SearchRequestBuilder builder = client().prepareSearch("idx") + .setQuery(matchAllQuery()) + .addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript)); + + SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, builder::get); + assertThat(ex.getCause().getMessage(), containsString("Parameter name \"param1\" used in both aggregation and script parameters")); + } } diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java index db2feafe6c4a3..0989b1ce6a3fa 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java @@ -64,8 +64,16 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase { Collections.emptyMap()); private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore", Collections.emptyMap()); - private static final Map, Object>> SCRIPTS = new HashMap<>(); + private static final Script INIT_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptParams", + Collections.singletonMap("initialValue", 24)); + private static final Script MAP_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptParams", + Collections.singletonMap("itemValue", 12)); + private static final Script COMBINE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptParams", + Collections.singletonMap("divisor", 4)); + private static final String CONFLICTING_PARAM_NAME = "initialValue"; + + private static final Map, Object>> SCRIPTS = new HashMap<>(); @BeforeClass @SuppressWarnings("unchecked") @@ -99,6 +107,26 @@ public static void initMockScripts() { Map agg = (Map) params.get("_agg"); return ((List) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum(); }); + + SCRIPTS.put("initScriptParams", params -> { + Map agg = (Map) params.get("_agg"); + Integer initialValue = (Integer)params.get("initialValue"); + ArrayList collector = new ArrayList(); + collector.add(initialValue); + agg.put("collector", collector); + return agg; + }); + SCRIPTS.put("mapScriptParams", params -> { + Map agg = (Map) params.get("_agg"); + Integer itemValue = (Integer) params.get("itemValue"); + ((List) agg.get("collector")).add(itemValue); + return agg; + }); + SCRIPTS.put("combineScriptParams", params -> { + Map agg = (Map) params.get("_agg"); + int divisor = ((Integer) params.get("divisor")); + return ((List) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum(); + }); } @SuppressWarnings("unchecked") @@ -187,6 +215,48 @@ public void testScriptedMetricWithCombineAccessesScores() throws IOException { } } + public void testScriptParamsPassedThrough() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + for (int i = 0; i < 100; i++) { + indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i))); + } + } + + try (IndexReader indexReader = DirectoryReader.open(directory)) { + ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME); + aggregationBuilder.initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).combineScript(COMBINE_SCRIPT_PARAMS); + ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder); + + // The result value depends on the script params. + assertEquals(306, scriptedMetric.aggregation()); + } + } + } + + public void testConflictingAggAndScriptParams() throws IOException { + try (Directory directory = newDirectory()) { + try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + for (int i = 0; i < 100; i++) { + indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i))); + } + } + + try (IndexReader indexReader = DirectoryReader.open(directory)) { + ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME); + Map aggParams = Collections.singletonMap(CONFLICTING_PARAM_NAME, "blah"); + aggregationBuilder.params(aggParams).initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS). + combineScript(COMBINE_SCRIPT_PARAMS); + + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> + search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder) + ); + assertEquals("Parameter name \"" + CONFLICTING_PARAM_NAME + "\" used in both aggregation and script parameters", + ex.getMessage()); + } + } + } + /** * We cannot use Mockito for mocking QueryShardContext in this case because * script-related methods (e.g. QueryShardContext#getLazyExecutableScript)