Skip to content

Commit

Permalink
Pass script level params into scripted metric aggs (elastic#28819)
Browse files Browse the repository at this point in the history
Now params that are passed at the script level and at the aggregation level
are merged and can both be used in the aggregation scripts. If there are
any conflicts, aggregation level params will win. This may be followed
by another change detecting that case and throwing an exception to
disallow such conflicts.
  • Loading branch information
rationull committed Mar 19, 2018
1 parent 0abf51a commit ccf2417
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.elasticsearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;

Expand Down Expand Up @@ -198,20 +199,34 @@ protected ScriptedMetricAggregatorFactory doBuild(SearchContext context, Aggrega
Builder subfactoriesBuilder) throws IOException {

QueryShardContext queryShardContext = context.getQueryShardContext();

// Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have
// access to them for the scripts it's given precompiled.

ExecutableScript.Factory executableInitScript;
Map<String, Object> initScriptParams;
if (initScript != null) {
executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT);
initScriptParams = initScript.getParams();
} else {
executableInitScript = p -> null;
initScriptParams = Collections.emptyMap();
}

SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT);
Map<String, Object> mapScriptParams = mapScript.getParams();

ExecutableScript.Factory executableCombineScript;
Map<String, Object> combineScriptParams;
if (combineScript != null) {
executableCombineScript =queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
combineScriptParams = combineScript.getParams();
} else {
executableCombineScript = p -> null;
combineScriptParams = Collections.emptyMap();
}
return new ScriptedMetricAggregatorFactory(name, searchMapScript, executableInitScript, executableCombineScript, reduceScript,
return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams,
executableCombineScript, combineScriptParams, reduceScript,
params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,35 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedMetricAggregatorFactory> {

private final SearchScript.Factory mapScript;
private final Map<String, Object> mapScriptParams;
private final ExecutableScript.Factory combineScript;
private final Map<String, Object> combineScriptParams;
private final Script reduceScript;
private final Map<String, Object> params;
private final Map<String, Object> aggParams;
private final SearchLookup lookup;
private final ExecutableScript.Factory initScript;
private final Map<String, Object> initScriptParams;

public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, ExecutableScript.Factory initScript,
ExecutableScript.Factory combineScript, Script reduceScript, Map<String, Object> params,
public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map<String, Object> mapScriptParams,
ExecutableScript.Factory initScript, Map<String, Object> initScriptParams,
ExecutableScript.Factory combineScript, Map<String, Object> combineScriptParams,
Script reduceScript, Map<String, Object> aggParams,
SearchLookup lookup, SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactories, Map<String, Object> metaData) throws IOException {
super(name, context, parent, subFactories, metaData);
this.mapScript = mapScript;
this.mapScriptParams = mapScriptParams;
this.initScript = initScript;
this.initScriptParams = initScriptParams;
this.combineScript = combineScript;
this.combineScriptParams = combineScriptParams;
this.reduceScript = reduceScript;
this.lookup = lookup;
this.params = params;
this.aggParams = aggParams;
}

@Override
Expand All @@ -65,26 +72,26 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu
if (collectsFromSingleBucket == false) {
return asMultiBucketAggregator(this, context, parent);
}
Map<String, Object> params = this.params;
if (params != null) {
params = deepCopyParams(params, context);
Map<String, Object> aggParams = this.aggParams;
if (aggParams != null) {
aggParams = deepCopyParams(aggParams, context);
} else {
params = new HashMap<>();
aggParams = new HashMap<>();
}
if (params.containsKey("_agg") == false) {
params.put("_agg", new HashMap<String, Object>());
if (aggParams.containsKey("_agg") == false) {
aggParams.put("_agg", new HashMap<String, Object>());
}

final ExecutableScript initScript = this.initScript.newInstance(params);
final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(params, lookup);
final ExecutableScript combineScript = this.combineScript.newInstance(params);
final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams));
final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup);
final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams));

final Script reduceScript = deepCopyScript(this.reduceScript, context);
if (initScript != null) {
initScript.run();
}
return new ScriptedMetricAggregator(name, mapScript,
combineScript, reduceScript, params, context, parent,
combineScript, reduceScript, aggParams, context, parent,
pipelineAggregators, metaData);
}

Expand Down Expand Up @@ -128,5 +135,16 @@ private static <T> T deepCopyParams(T original, SearchContext context) {
return clone;
}

private static Map<String, Object> mergeParams(Map<String, Object> agg, Map<String, Object> script) {
// TODO Should we throw an exception when param names conflict between aggregation and script? Need to add test coverage
// for error or override behavior depending on the decision. Should this check be added at call time or at
// construction?

// Aggregation level commands need to win in case of conflict so that params can keep the same identity and
// content across all the scripts that are run in the aggregation.
Map<String, Object> combined = new HashMap<>();
combined.putAll(script);
combined.putAll(agg);
return combined;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -361,17 +361,17 @@ public void testMapWithParams() {
}

public void testMapWithParamsAndImplicitAggMap() {
Map<String, Object> params = new HashMap<>();
// don't put any _agg map in params
params.put("param1", "12");
params.put("param2", 1);
// Split the params up between the script and the aggregation.
// Don't put any _agg map in params.
Map<String, Object> scriptParams = Collections.singletonMap("param1", "12");
Map<String, Object> aggregationParams = Collections.singletonMap("param2", 1);

// The _agg hashmap will be available even if not declared in the params map
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", params);
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", scriptParams);

SearchResponse response = client().prepareSearch("idx")
.setQuery(matchAllQuery())
.addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript))
.addAggregation(scriptedMetric("scripted").params(aggregationParams).mapScript(mapScript))
.get();
assertSearchResponse(response);
assertThat(response.getHits().getTotalHits(), equalTo(numDocs));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,15 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
Collections.emptyMap());
private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore",
Collections.emptyMap());
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();

private static final Script INIT_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptParams",
Collections.singletonMap("initialValue", 24));
private static final Script MAP_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptParams",
Collections.singletonMap("itemValue", 12));
private static final Script COMBINE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptParams",
Collections.singletonMap("divisor", 4));

private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();

@BeforeClass
@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -99,6 +106,26 @@ public static void initMockScripts() {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
return ((List<Double>) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum();
});

SCRIPTS.put("initScriptParams", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
Integer initialValue = (Integer)params.get("initialValue");
ArrayList<Integer> collector = new ArrayList();
collector.add(initialValue);
agg.put("collector", collector);
return agg;
});
SCRIPTS.put("mapScriptParams", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
Integer itemValue = (Integer) params.get("itemValue");
((List<Integer>) agg.get("collector")).add(itemValue);
return agg;
});
SCRIPTS.put("combineScriptParams", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
int divisor = ((Integer) params.get("divisor"));
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum();
});
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -187,6 +214,25 @@ public void testScriptedMetricWithCombineAccessesScores() throws IOException {
}
}

public void testScriptParamsPassedThrough() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
for (int i = 0; i < 100; i++) {
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
}
}

try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
aggregationBuilder.initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).combineScript(COMBINE_SCRIPT_PARAMS);
ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder);

// The result value depends on the script params.
assertEquals(306, scriptedMetric.aggregation());
}
}
}

/**
* We cannot use Mockito for mocking QueryShardContext in this case because
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
Expand Down

0 comments on commit ccf2417

Please sign in to comment.