Skip to content

Commit

Permalink
Call ensureNoSelfReferences() on _agg state variable after scripted m…
Browse files Browse the repository at this point in the history
…etric agg script executions (#31044)

Previously this was called for the combine script only. This change checks for self references for
init, map, and reduce scripts as well, and adds unit test coverage for the init, map, and combine cases.
  • Loading branch information
rationull authored and s1monw committed Jun 11, 2018
1 parent bd5c1a4 commit 85c26d6
Show file tree
Hide file tree
Showing 12 changed files with 118 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public Iterator<Object> iterator() {

@Override
public String stringify(Object object) {
CollectionUtils.ensureNoSelfReferences(object);
CollectionUtils.ensureNoSelfReferences(object, "CustomReflectionObjectHandler stringify");
return super.stringify(object);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.IdentityHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.RandomAccess;
Expand All @@ -40,6 +41,7 @@
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.apache.lucene.util.IntroSorter;
import org.elasticsearch.common.Strings;

/** Collections-related utility methods. */
public class CollectionUtils {
Expand Down Expand Up @@ -225,10 +227,17 @@ public static int[] toArray(Collection<Integer> ints) {
return ints.stream().mapToInt(s -> s).toArray();
}

public static void ensureNoSelfReferences(Object value) {
/**
* Deeply inspects a Map, Iterable, or Object array looking for references back to itself.
* @throws IllegalArgumentException if a self-reference is found
* @param value The object to evaluate looking for self references
* @param messageHint A string to be included in the exception message if the call fails, to provide
* more context to the handler of the exception
*/
public static void ensureNoSelfReferences(Object value, String messageHint) {
Iterable<?> it = convert(value);
if (it != null) {
ensureNoSelfReferences(it, value, Collections.newSetFromMap(new IdentityHashMap<>()));
ensureNoSelfReferences(it, value, Collections.newSetFromMap(new IdentityHashMap<>()), messageHint);
}
}

Expand All @@ -247,13 +256,15 @@ private static Iterable<?> convert(Object value) {
}
}

private static void ensureNoSelfReferences(final Iterable<?> value, Object originalReference, final Set<Object> ancestors) {
private static void ensureNoSelfReferences(final Iterable<?> value, Object originalReference, final Set<Object> ancestors,
String messageHint) {
if (value != null) {
if (ancestors.add(originalReference) == false) {
throw new IllegalArgumentException("Iterable object is self-referencing itself");
String suffix = Strings.isNullOrEmpty(messageHint) ? "" : String.format(Locale.ROOT, " (%s)", messageHint);
throw new IllegalArgumentException("Iterable object is self-referencing itself" + suffix);
}
for (Object o : value) {
ensureNoSelfReferences(convert(o), o, ancestors);
ensureNoSelfReferences(convert(o), o, ancestors, messageHint);
}
ancestors.remove(originalReference);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.script.ExecutableScript;
import org.elasticsearch.script.Script;
Expand Down Expand Up @@ -97,7 +98,11 @@ public InternalAggregation doReduce(List<InternalAggregation> aggregations, Redu
ExecutableScript.Factory factory = reduceContext.scriptService().compile(
firstAggregation.reduceScript, ExecutableScript.AGGS_CONTEXT);
ExecutableScript script = factory.newInstance(vars);
aggregation = Collections.singletonList(script.run());

Object scriptResult = script.run();
CollectionUtils.ensureNoSelfReferences(scriptResult, "reduce script");

aggregation = Collections.singletonList(scriptResult);
} else if (reduceContext.isFinalReduce()) {
aggregation = Collections.singletonList(aggregationObjects);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public void collect(int doc, long bucket) throws IOException {
assert bucket == 0 : bucket;
leafMapScript.setDocument(doc);
leafMapScript.run();
CollectionUtils.ensureNoSelfReferences(params, "Scripted metric aggs map script");
}
};
}
Expand All @@ -78,7 +79,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) {
Object aggregation;
if (combineScript != null) {
aggregation = combineScript.run();
CollectionUtils.ensureNoSelfReferences(aggregation);
CollectionUtils.ensureNoSelfReferences(aggregation, "Scripted metric aggs combine script");
} else {
aggregation = params.get("_agg");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.search.aggregations.metrics.scripted;

import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.script.ExecutableScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.SearchScript;
Expand Down Expand Up @@ -89,6 +90,7 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu
final Script reduceScript = deepCopyScript(this.reduceScript, context);
if (initScript != null) {
initScript.run();
CollectionUtils.ensureNoSelfReferences(aggParams.get("_agg"), "Scripted metric aggs init script");
}
return new ScriptedMetricAggregator(name, mapScript,
combineScript, reduceScript, aggParams, context, parent,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ public boolean advanceExact(int doc) throws IOException {
final BytesRef value = bytesValues.nextValue();
script.setNextAggregationValue(value.utf8ToString());
Object run = script.run();
CollectionUtils.ensureNoSelfReferences(run);
CollectionUtils.ensureNoSelfReferences(run, "ValuesSource.BytesValues script");
values[i].copyChars(run.toString());
}
sort();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private void set(int i, Object o) {
if (o == null) {
values[i].clear();
} else {
CollectionUtils.ensureNoSelfReferences(o);
CollectionUtils.ensureNoSelfReferences(o, "ScriptBytesValues value");
values[i].copyChars(o.toString());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOExcept
final Object value;
try {
value = leafScripts[i].run();
CollectionUtils.ensureNoSelfReferences(value);
CollectionUtils.ensureNoSelfReferences(value, "ScriptFieldsFetchSubPhase leaf script " + i);
} catch (RuntimeException e) {
if (scriptFields.get(i).ignoreException()) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ public boolean advanceExact(int doc) throws IOException {
@Override
public BytesRef binaryValue() {
final Object run = leafScript.run();
CollectionUtils.ensureNoSelfReferences(run);
CollectionUtils.ensureNoSelfReferences(run, "ScriptSortBuilder leaf script");
spare.copyChars(run.toString());
return spare.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,14 @@ public void testPerfectPartition() {
}

public void testEnsureNoSelfReferences() {
CollectionUtils.ensureNoSelfReferences(emptyMap());
CollectionUtils.ensureNoSelfReferences(null);
CollectionUtils.ensureNoSelfReferences(emptyMap(), "test with empty map");
CollectionUtils.ensureNoSelfReferences(null, "test with null");

Map<String, Object> map = new HashMap<>();
map.put("field", map);

IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> CollectionUtils.ensureNoSelfReferences(map));
assertThat(e.getMessage(), containsString("Iterable object is self-referencing itself"));
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> CollectionUtils.ensureNoSelfReferences(map, "test with self ref"));
assertThat(e.getMessage(), containsString("Iterable object is self-referencing itself (test with self ref)"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -843,8 +843,8 @@ public void testEnsureNotNull() {
}

public void testEnsureNoSelfReferences() throws IOException {
CollectionUtils.ensureNoSelfReferences(emptyMap());
CollectionUtils.ensureNoSelfReferences(null);
builder().map(emptyMap());
builder().map(null);

Map<String, Object> map = new HashMap<>();
map.put("field", map);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
Collections.singletonMap("divisor", 4));
private static final String CONFLICTING_PARAM_NAME = "initialValue";

private static final Script INIT_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptSelfRef",
Collections.emptyMap());
private static final Script MAP_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptSelfRef",
Collections.emptyMap());
private static final Script COMBINE_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptSelfRef",
Collections.emptyMap());

private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();

@BeforeClass
Expand Down Expand Up @@ -127,6 +134,25 @@ public static void initMockScripts() {
int divisor = ((Integer) params.get("divisor"));
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum();
});

SCRIPTS.put("initScriptSelfRef", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
agg.put("collector", new ArrayList<Integer>());
agg.put("selfRef", agg);
return agg;
});

SCRIPTS.put("mapScriptSelfRef", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
agg.put("selfRef", agg);
return agg;
});

SCRIPTS.put("combineScriptSelfRef", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
agg.put("selfRef", agg);
return agg;
});
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -257,6 +283,60 @@ public void testConflictingAggAndScriptParams() throws IOException {
}
}

public void testSelfReferencingAggStateAfterInit() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
// No need to add docs for this test
}
try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
aggregationBuilder.initScript(INIT_SCRIPT_SELF_REF).mapScript(MAP_SCRIPT);

IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
);
assertEquals("Iterable object is self-referencing itself (Scripted metric aggs init script)", ex.getMessage());
}
}
}

public void testSelfReferencingAggStateAfterMap() throws IOException {
try (Directory directory = newDirectory()) {
Integer numDocs = randomInt(100);
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
for (int i = 0; i < numDocs; i++) {
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
}
}
try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
aggregationBuilder.initScript(INIT_SCRIPT).mapScript(MAP_SCRIPT_SELF_REF);

IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
);
assertEquals("Iterable object is self-referencing itself (Scripted metric aggs map script)", ex.getMessage());
}
}
}

public void testSelfReferencingAggStateAfterCombine() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
// No need to add docs for this test
}
try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
aggregationBuilder.initScript(INIT_SCRIPT).mapScript(MAP_SCRIPT).combineScript(COMBINE_SCRIPT_SELF_REF);

IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
);
assertEquals("Iterable object is self-referencing itself (Scripted metric aggs combine script)", ex.getMessage());
}
}
}

/**
* We cannot use Mockito for mocking QueryShardContext in this case because
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
Expand Down

0 comments on commit 85c26d6

Please sign in to comment.