Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass through script params in scripted metric agg #29154

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -198,20 +199,34 @@ protected ScriptedMetricAggregatorFactory doBuild(SearchContext context, Aggrega
Builder subfactoriesBuilder) throws IOException {

QueryShardContext queryShardContext = context.getQueryShardContext();

// Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have
// access to them for the scripts it's given precompiled.

ExecutableScript.Factory executableInitScript;
Map<String, Object> initScriptParams;
if (initScript != null) {
executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT);
initScriptParams = initScript.getParams();
} else {
executableInitScript = p -> null;
initScriptParams = Collections.emptyMap();
}

SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT);
Map<String, Object> mapScriptParams = mapScript.getParams();

ExecutableScript.Factory executableCombineScript;
Map<String, Object> combineScriptParams;
if (combineScript != null) {
executableCombineScript =queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
combineScriptParams = combineScript.getParams();
} else {
executableCombineScript = p -> null;
combineScriptParams = Collections.emptyMap();
}
return new ScriptedMetricAggregatorFactory(name, searchMapScript, executableInitScript, executableCombineScript, reduceScript,
return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams,
executableCombineScript, combineScriptParams, reduceScript,
params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,35 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedMetricAggregatorFactory> {

private final SearchScript.Factory mapScript;
private final Map<String, Object> mapScriptParams;
private final ExecutableScript.Factory combineScript;
private final Map<String, Object> combineScriptParams;
private final Script reduceScript;
private final Map<String, Object> params;
private final Map<String, Object> aggParams;
private final SearchLookup lookup;
private final ExecutableScript.Factory initScript;
private final Map<String, Object> initScriptParams;

public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, ExecutableScript.Factory initScript,
ExecutableScript.Factory combineScript, Script reduceScript, Map<String, Object> params,
public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map<String, Object> mapScriptParams,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to combine together the compiled script factories and the params into some kind of a tuple object rather than just adding more constructor parameters? I opted not to because they're different types and AFAIK passing them around paired with parameters isn't a common case.

ExecutableScript.Factory initScript, Map<String, Object> initScriptParams,
ExecutableScript.Factory combineScript, Map<String, Object> combineScriptParams,
Script reduceScript, Map<String, Object> aggParams,
SearchLookup lookup, SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactories, Map<String, Object> metaData) throws IOException {
super(name, context, parent, subFactories, metaData);
this.mapScript = mapScript;
this.mapScriptParams = mapScriptParams;
this.initScript = initScript;
this.initScriptParams = initScriptParams;
this.combineScript = combineScript;
this.combineScriptParams = combineScriptParams;
this.reduceScript = reduceScript;
this.lookup = lookup;
this.params = params;
this.aggParams = aggParams;
}

@Override
Expand All @@ -65,26 +72,26 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu
if (collectsFromSingleBucket == false) {
return asMultiBucketAggregator(this, context, parent);
}
Map<String, Object> params = this.params;
if (params != null) {
params = deepCopyParams(params, context);
Map<String, Object> aggParams = this.aggParams;
if (aggParams != null) {
aggParams = deepCopyParams(aggParams, context);
} else {
params = new HashMap<>();
aggParams = new HashMap<>();
}
if (params.containsKey("_agg") == false) {
params.put("_agg", new HashMap<String, Object>());
if (aggParams.containsKey("_agg") == false) {
aggParams.put("_agg", new HashMap<String, Object>());
}

final ExecutableScript initScript = this.initScript.newInstance(params);
final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(params, lookup);
final ExecutableScript combineScript = this.combineScript.newInstance(params);
final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I considered merging the parameter lists at construction time (which would also move conflict detection to construction time) but hit some test failures in that case. I did not rigorously track down the problem because I wasn't sure that it was a good idea to evaluate the params list before the last minute anyway, in case they could be modified in any way between AggregatorFactory construction and calling createInternal(). This is a smaller behavior change.

final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup);
final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams));

final Script reduceScript = deepCopyScript(this.reduceScript, context);
if (initScript != null) {
initScript.run();
}
return new ScriptedMetricAggregator(name, mapScript,
combineScript, reduceScript, params, context, parent,
combineScript, reduceScript, aggParams, context, parent,
pipelineAggregators, metaData);
}

Expand Down Expand Up @@ -128,5 +135,18 @@ private static <T> T deepCopyParams(T original, SearchContext context) {
return clone;
}

private static Map<String, Object> mergeParams(Map<String, Object> agg, Map<String, Object> script) {
// Start with script params
Map<String, Object> combined = new HashMap<>(script);

// Add in agg params, throwing an exception if any conflicts are detected
for (Map.Entry<String, Object> aggEntry : agg.entrySet()) {
if (combined.putIfAbsent(aggEntry.getKey(), aggEntry.getValue()) != null) {
throw new IllegalArgumentException("Parameter name \"" + aggEntry.getKey() +
"\" used in both aggregation and script parameters");
}
}

return combined;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package org.elasticsearch.search.aggregations.metrics;

import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchPhaseExecutionException;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.settings.Settings;
Expand Down Expand Up @@ -62,6 +64,7 @@
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
import static org.hamcrest.Matchers.allOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
Expand Down Expand Up @@ -322,11 +325,11 @@ public void testMap() {
assertThat(numShardsRun, greaterThan(0));
}

public void testMapWithParams() {
public void testExplicitAggParam() {
Copy link
Contributor Author

@rationull rationull Mar 20, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testMapWithParams was previously actually working because the params were on the aggregation, not because the params were on the map script. This test may be extraneous now but I renamed it instead of deleting it since it does provide specific coverage for an explicit _agg param, even if that may be kind of a weird case. I wouldn't presume to delete it without an opinion from a dev more familiar with this code..

Map<String, Object> params = new HashMap<>();
params.put("_agg", new ArrayList<>());

Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params);
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", Collections.emptyMap());

SearchResponse response = client().prepareSearch("idx")
.setQuery(matchAllQuery())
Expand Down Expand Up @@ -361,17 +364,17 @@ public void testMapWithParams() {
}

public void testMapWithParamsAndImplicitAggMap() {
Map<String, Object> params = new HashMap<>();
// don't put any _agg map in params
params.put("param1", "12");
params.put("param2", 1);
// Split the params up between the script and the aggregation.
// Don't put any _agg map in params.
Map<String, Object> scriptParams = Collections.singletonMap("param1", "12");
Map<String, Object> aggregationParams = Collections.singletonMap("param2", 1);

// The _agg hashmap will be available even if not declared in the params map
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", params);
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", scriptParams);

SearchResponse response = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript))
.addAggregation(scriptedMetric("scripted").params(aggregationParams).mapScript(mapScript))
.get();
assertSearchResponse(response);
assertThat(response.getHits().getTotalHits(), equalTo(numDocs));
Expand Down Expand Up @@ -1001,4 +1004,16 @@ public void testDontCacheScripts() throws Exception {
assertThat(client().admin().indices().prepareStats("cache_test_idx").setRequestCache(true).get().getTotal().getRequestCache()
.getMissCount(), equalTo(0L));
}

public void testConflictingAggAndScriptParams() {
Map<String, Object> params = Collections.singletonMap("param1", "12");
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params);

SearchRequestBuilder builder = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript));

SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, builder::get);
assertThat(ex.getCause().getMessage(), containsString("Parameter name \"param1\" used in both aggregation and script parameters"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,16 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
Collections.emptyMap());
private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore",
Collections.emptyMap());
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();

private static final Script INIT_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptParams",
Collections.singletonMap("initialValue", 24));
private static final Script MAP_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptParams",
Collections.singletonMap("itemValue", 12));
private static final Script COMBINE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptParams",
Collections.singletonMap("divisor", 4));
private static final String CONFLICTING_PARAM_NAME = "initialValue";

private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();

@BeforeClass
@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -99,6 +107,26 @@ public static void initMockScripts() {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
return ((List<Double>) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum();
});

SCRIPTS.put("initScriptParams", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
Integer initialValue = (Integer)params.get("initialValue");
ArrayList<Integer> collector = new ArrayList();
collector.add(initialValue);
agg.put("collector", collector);
return agg;
});
SCRIPTS.put("mapScriptParams", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
Integer itemValue = (Integer) params.get("itemValue");
((List<Integer>) agg.get("collector")).add(itemValue);
return agg;
});
SCRIPTS.put("combineScriptParams", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
int divisor = ((Integer) params.get("divisor"));
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum();
});
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -187,6 +215,48 @@ public void testScriptedMetricWithCombineAccessesScores() throws IOException {
}
}

public void testScriptParamsPassedThrough() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
for (int i = 0; i < 100; i++) {
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
}
}

try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
aggregationBuilder.initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).combineScript(COMBINE_SCRIPT_PARAMS);
ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder);

// The result value depends on the script params.
assertEquals(306, scriptedMetric.aggregation());
}
}
}

public void testConflictingAggAndScriptParams() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
for (int i = 0; i < 100; i++) {
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
}
}

try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
Map<String, Object> aggParams = Collections.singletonMap(CONFLICTING_PARAM_NAME, "blah");
aggregationBuilder.params(aggParams).initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).
combineScript(COMBINE_SCRIPT_PARAMS);

IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
);
assertEquals("Parameter name \"" + CONFLICTING_PARAM_NAME + "\" used in both aggregation and script parameters",
ex.getMessage());
}
}
}

/**
* We cannot use Mockito for mocking QueryShardContext in this case because
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
Expand Down