Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Call ensureNoSelfReferences() on _agg state variable after scripted metric agg script executions #31044

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public Iterator<Object> iterator() {

@Override
public String stringify(Object object) {
CollectionUtils.ensureNoSelfReferences(object);
CollectionUtils.ensureNoSelfReferences(object, "CustomReflectionObjectHandler stringify");
return super.stringify(object);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.IdentityHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.RandomAccess;
Expand All @@ -40,6 +41,7 @@
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.apache.lucene.util.IntroSorter;
import org.elasticsearch.common.Strings;

/** Collections-related utility methods. */
public class CollectionUtils {
Expand Down Expand Up @@ -225,10 +227,17 @@ public static int[] toArray(Collection<Integer> ints) {
return ints.stream().mapToInt(s -> s).toArray();
}

public static void ensureNoSelfReferences(Object value) {
/**
* Deeply inspects a Map, Iterable, or Object array looking for references back to itself.
* @throws IllegalArgumentException if a self-reference is found
* @param value The object to evaluate looking for self references
* @param messageHint A string to be included in the exception message if the call fails, to provide
* more context to the handler of the exception
*/
public static void ensureNoSelfReferences(Object value, String messageHint) {
Iterable<?> it = convert(value);
if (it != null) {
ensureNoSelfReferences(it, value, Collections.newSetFromMap(new IdentityHashMap<>()));
ensureNoSelfReferences(it, value, Collections.newSetFromMap(new IdentityHashMap<>()), messageHint);
}
}

Expand All @@ -247,13 +256,15 @@ private static Iterable<?> convert(Object value) {
}
}

private static void ensureNoSelfReferences(final Iterable<?> value, Object originalReference, final Set<Object> ancestors) {
private static void ensureNoSelfReferences(final Iterable<?> value, Object originalReference, final Set<Object> ancestors,
String messageHint) {
if (value != null) {
if (ancestors.add(originalReference) == false) {
throw new IllegalArgumentException("Iterable object is self-referencing itself");
String suffix = Strings.isNullOrEmpty(messageHint) ? "" : String.format(Locale.ROOT, " (%s)", messageHint);
throw new IllegalArgumentException("Iterable object is self-referencing itself" + suffix);
}
for (Object o : value) {
ensureNoSelfReferences(convert(o), o, ancestors);
ensureNoSelfReferences(convert(o), o, ancestors, messageHint);
}
ancestors.remove(originalReference);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.script.ExecutableScript;
import org.elasticsearch.script.Script;
Expand Down Expand Up @@ -97,7 +98,11 @@ public InternalAggregation doReduce(List<InternalAggregation> aggregations, Redu
ExecutableScript.Factory factory = reduceContext.scriptService().compile(
firstAggregation.reduceScript, ExecutableScript.AGGS_CONTEXT);
ExecutableScript script = factory.newInstance(vars);
aggregation = Collections.singletonList(script.run());

Object scriptResult = script.run();
CollectionUtils.ensureNoSelfReferences(scriptResult, "reduce script");

aggregation = Collections.singletonList(scriptResult);
} else if (reduceContext.isFinalReduce()) {
aggregation = Collections.singletonList(aggregationObjects);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ public void collect(int doc, long bucket) throws IOException {
assert bucket == 0 : bucket;
leafMapScript.setDocument(doc);
leafMapScript.run();
CollectionUtils.ensureNoSelfReferences(params, "Scripted metric aggs map script");
}
};
}
Expand All @@ -78,7 +79,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) {
Object aggregation;
if (combineScript != null) {
aggregation = combineScript.run();
CollectionUtils.ensureNoSelfReferences(aggregation);
CollectionUtils.ensureNoSelfReferences(aggregation, "Scripted metric aggs combine script");
} else {
aggregation = params.get("_agg");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.search.aggregations.metrics.scripted;

import org.elasticsearch.common.util.CollectionUtils;
import org.elasticsearch.script.ExecutableScript;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.SearchScript;
Expand Down Expand Up @@ -89,6 +90,7 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu
final Script reduceScript = deepCopyScript(this.reduceScript, context);
if (initScript != null) {
initScript.run();
CollectionUtils.ensureNoSelfReferences(aggParams.get("_agg"), "Scripted metric aggs init script");
}
return new ScriptedMetricAggregator(name, mapScript,
combineScript, reduceScript, aggParams, context, parent,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ public boolean advanceExact(int doc) throws IOException {
final BytesRef value = bytesValues.nextValue();
script.setNextAggregationValue(value.utf8ToString());
Object run = script.run();
CollectionUtils.ensureNoSelfReferences(run);
CollectionUtils.ensureNoSelfReferences(run, "ValuesSource.BytesValues script");
values[i].copyChars(run.toString());
}
sort();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private void set(int i, Object o) {
if (o == null) {
values[i].clear();
} else {
CollectionUtils.ensureNoSelfReferences(o);
CollectionUtils.ensureNoSelfReferences(o, "ScriptBytesValues value");
values[i].copyChars(o.toString());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOExcept
final Object value;
try {
value = leafScripts[i].run();
CollectionUtils.ensureNoSelfReferences(value);
CollectionUtils.ensureNoSelfReferences(value, "ScriptFieldsFetchSubPhase leaf script " + i);
} catch (RuntimeException e) {
if (scriptFields.get(i).ignoreException()) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ public boolean advanceExact(int doc) throws IOException {
@Override
public BytesRef binaryValue() {
final Object run = leafScript.run();
CollectionUtils.ensureNoSelfReferences(run);
CollectionUtils.ensureNoSelfReferences(run, "ScriptSortBuilder leaf script");
spare.copyChars(run.toString());
return spare.get();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,14 @@ public void testPerfectPartition() {
}

public void testEnsureNoSelfReferences() {
CollectionUtils.ensureNoSelfReferences(emptyMap());
CollectionUtils.ensureNoSelfReferences(null);
CollectionUtils.ensureNoSelfReferences(emptyMap(), "test with empty map");
CollectionUtils.ensureNoSelfReferences(null, "test with null");

Map<String, Object> map = new HashMap<>();
map.put("field", map);

IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> CollectionUtils.ensureNoSelfReferences(map));
assertThat(e.getMessage(), containsString("Iterable object is self-referencing itself"));
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
() -> CollectionUtils.ensureNoSelfReferences(map, "test with self ref"));
assertThat(e.getMessage(), containsString("Iterable object is self-referencing itself (test with self ref)"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -843,8 +843,8 @@ public void testEnsureNotNull() {
}

public void testEnsureNoSelfReferences() throws IOException {
CollectionUtils.ensureNoSelfReferences(emptyMap());
CollectionUtils.ensureNoSelfReferences(null);
builder().map(emptyMap());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fixed these to call the functions that are supposed to be under test here, but I did not go so far as to introduce a message hint to the XContentBuilder version of ensureNoSelfReferences(), just because that's where I arbitrarily drew the line on scope. Happy to revisit if you'd like.

builder().map(null);

Map<String, Object> map = new HashMap<>();
map.put("field", map);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
Collections.singletonMap("divisor", 4));
private static final String CONFLICTING_PARAM_NAME = "initialValue";

private static final Script INIT_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptSelfRef",
Collections.emptyMap());
private static final Script MAP_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptSelfRef",
Collections.emptyMap());
private static final Script COMBINE_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptSelfRef",
Collections.emptyMap());

private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();

@BeforeClass
Expand Down Expand Up @@ -127,6 +134,25 @@ public static void initMockScripts() {
int divisor = ((Integer) params.get("divisor"));
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum();
});

SCRIPTS.put("initScriptSelfRef", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
agg.put("collector", new ArrayList<Integer>());
agg.put("selfRef", agg);
return agg;
});

SCRIPTS.put("mapScriptSelfRef", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
agg.put("selfRef", agg);
return agg;
});

SCRIPTS.put("combineScriptSelfRef", params -> {
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
agg.put("selfRef", agg);
return agg;
});
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -257,6 +283,60 @@ public void testConflictingAggAndScriptParams() throws IOException {
}
}

public void testSelfReferencingAggStateAfterInit() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
// No need to add docs for this test
}
try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
aggregationBuilder.initScript(INIT_SCRIPT_SELF_REF).mapScript(MAP_SCRIPT);

IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
);
assertEquals("Iterable object is self-referencing itself (Scripted metric aggs init script)", ex.getMessage());
}
}
}

public void testSelfReferencingAggStateAfterMap() throws IOException {
try (Directory directory = newDirectory()) {
Integer numDocs = randomInt(100);
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
for (int i = 0; i < numDocs; i++) {
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
}
}
try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
aggregationBuilder.initScript(INIT_SCRIPT).mapScript(MAP_SCRIPT_SELF_REF);

IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
);
assertEquals("Iterable object is self-referencing itself (Scripted metric aggs map script)", ex.getMessage());
}
}
}

public void testSelfReferencingAggStateAfterCombine() throws IOException {
try (Directory directory = newDirectory()) {
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
// No need to add docs for this test
}
try (IndexReader indexReader = DirectoryReader.open(directory)) {
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
aggregationBuilder.initScript(INIT_SCRIPT).mapScript(MAP_SCRIPT).combineScript(COMBINE_SCRIPT_SELF_REF);

IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
);
assertEquals("Iterable object is self-referencing itself (Scripted metric aggs combine script)", ex.getMessage());
}
}
}

/**
* We cannot use Mockito for mocking QueryShardContext in this case because
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)
Expand Down