Skip to content

Commit

Permalink
Scripting: Fix painless compiler loader to know about context classes (
Browse files Browse the repository at this point in the history
…#32385)

This commit fixes the painless compiler classloader to know about the
classes from the script context. This fixes an issue when a custom
context is used from a plugin which caused a ClassNotFoundException for
the script class and its factory classes.
  • Loading branch information
rjernst authored Jul 31, 2018
1 parent 2245957 commit 2ed9782
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,14 @@ final class Compiler {
/**
* A secure class loader used to define Painless scripts.
*/
static final class Loader extends SecureClassLoader {
final class Loader extends SecureClassLoader {
private final AtomicInteger lambdaCounter = new AtomicInteger(0);
private final PainlessLookup painlessLookup;

/**
* @param parent The parent ClassLoader.
*/
Loader(ClassLoader parent, PainlessLookup painlessLookup) {
Loader(ClassLoader parent) {
super(parent);

this.painlessLookup = painlessLookup;
}

/**
Expand All @@ -90,6 +87,15 @@ static final class Loader extends SecureClassLoader {
*/
@Override
public Class<?> findClass(String name) throws ClassNotFoundException {
if (scriptClass.getName().equals(name)) {
return scriptClass;
}
if (factoryClass != null && factoryClass.getName().equals(name)) {
return factoryClass;
}
if (statefulFactoryClass != null && statefulFactoryClass.getName().equals(name)) {
return statefulFactoryClass;
}
Class<?> found = painlessLookup.getClassFromBinaryName(name);

return found != null ? found : super.findClass(name);
Expand Down Expand Up @@ -139,13 +145,23 @@ int newLambdaIdentifier() {
* {@link Compiler}'s specified {@link PainlessLookup}.
*/
public Loader createLoader(ClassLoader parent) {
return new Loader(parent, painlessLookup);
return new Loader(parent);
}

/**
* The class/interface the script is guaranteed to derive/implement.
* The class/interface the script will implement.
*/
private final Class<?> scriptClass;

/**
* The class/interface to create the {@code scriptClass} instance.
*/
private final Class<?> factoryClass;

/**
* An optional class/interface to create the {@code factoryClass} instance.
*/
private final Class<?> base;
private final Class<?> statefulFactoryClass;

/**
* The whitelist the script will use.
Expand All @@ -154,11 +170,15 @@ public Loader createLoader(ClassLoader parent) {

/**
* Standard constructor.
* @param base The class/interface the script is guaranteed to derive/implement.
* @param scriptClass The class/interface the script will implement.
* @param factoryClass An optional class/interface to create the {@code scriptClass} instance.
* @param statefulFactoryClass An optional class/interface to create the {@code factoryClass} instance.
* @param painlessLookup The whitelist the script will use.
*/
Compiler(Class<?> base, PainlessLookup painlessLookup) {
this.base = base;
Compiler(Class<?> scriptClass, Class<?> factoryClass, Class<?> statefulFactoryClass, PainlessLookup painlessLookup) {
this.scriptClass = scriptClass;
this.factoryClass = factoryClass;
this.statefulFactoryClass = statefulFactoryClass;
this.painlessLookup = painlessLookup;
}

Expand All @@ -177,7 +197,7 @@ Constructor<?> compile(Loader loader, MainMethodReserved reserved, String name,
" plugin if a script longer than this length is a requirement.");
}

ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, base);
ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, scriptClass);
SSource root = Walker.buildPainlessTree(scriptClassInfo, reserved, name, source, settings, painlessLookup,
null);
root.analyze(painlessLookup);
Expand Down Expand Up @@ -209,7 +229,7 @@ byte[] compile(String name, String source, CompilerSettings settings, Printer de
" plugin if a script longer than this length is a requirement.");
}

ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, base);
ScriptClassInfo scriptClassInfo = new ScriptClassInfo(painlessLookup, scriptClass);
SSource root = Walker.buildPainlessTree(scriptClassInfo, new MainMethodReserved(), name, source, settings, painlessLookup,
debugStream);
root.analyze(painlessLookup);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import java.security.PrivilegedAction;

import static java.lang.invoke.MethodHandles.Lookup;
import static org.elasticsearch.painless.Compiler.Loader;
import static org.elasticsearch.painless.WriterConstants.CLASS_VERSION;
import static org.elasticsearch.painless.WriterConstants.CTOR_METHOD_NAME;
import static org.elasticsearch.painless.WriterConstants.DELEGATE_BOOTSTRAP_HANDLE;
Expand Down Expand Up @@ -207,7 +206,7 @@ public static CallSite lambdaBootstrap(
MethodType delegateMethodType,
int isDelegateInterface)
throws LambdaConversionException {
Loader loader = (Loader)lookup.lookupClass().getClassLoader();
Compiler.Loader loader = (Compiler.Loader)lookup.lookupClass().getClassLoader();
String lambdaClassName = Type.getInternalName(lookup.lookupClass()) + "$$Lambda" + loader.newLambdaIdentifier();
Type lambdaClassType = Type.getObjectType(lambdaClassName);
Type delegateClassType = Type.getObjectType(delegateClassName.replace('.', '/'));
Expand Down Expand Up @@ -457,11 +456,11 @@ private static void endLambdaClass(ClassWriter cw) {
}

/**
* Defines the {@link Class} for the lambda class using the same {@link Loader}
* Defines the {@link Class} for the lambda class using the same {@link Compiler.Loader}
* that originally defined the class for the Painless script.
*/
private static Class<?> createLambdaClass(
Loader loader,
Compiler.Loader loader,
ClassWriter cw,
Type lambdaClassType) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ public PainlessScriptEngine(Settings settings, Map<ScriptContext<?>, List<Whitel
for (Map.Entry<ScriptContext<?>, List<Whitelist>> entry : contexts.entrySet()) {
ScriptContext<?> context = entry.getKey();
if (context.instanceClazz.equals(SearchScript.class) || context.instanceClazz.equals(ExecutableScript.class)) {
contextsToCompilers.put(context, new Compiler(GenericElasticsearchScript.class,
contextsToCompilers.put(context, new Compiler(GenericElasticsearchScript.class, null, null,
PainlessLookupBuilder.buildFromWhitelists(entry.getValue())));
} else {
contextsToCompilers.put(context, new Compiler(context.instanceClazz,
contextsToCompilers.put(context, new Compiler(context.instanceClazz, context.factoryClazz, context.statefulFactoryClazz,
PainlessLookupBuilder.buildFromWhitelists(entry.getValue())));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public Map<String, Object> getTestMap() {
}

public void testGets() {
Compiler compiler = new Compiler(Gets.class, painlessLookup);
Compiler compiler = new Compiler(Gets.class, null, null, painlessLookup);
Map<String, Object> map = new HashMap<>();
map.put("s", 1);

Expand All @@ -87,7 +87,7 @@ public abstract static class NoArgs {
public abstract Object execute();
}
public void testNoArgs() {
Compiler compiler = new Compiler(NoArgs.class, painlessLookup);
Compiler compiler = new Compiler(NoArgs.class, null, null, painlessLookup);
assertEquals(1, ((NoArgs)scriptEngine.compile(compiler, null, "1", emptyMap())).execute());
assertEquals("foo", ((NoArgs)scriptEngine.compile(compiler, null, "'foo'", emptyMap())).execute());

Expand All @@ -111,13 +111,13 @@ public abstract static class OneArg {
public abstract Object execute(Object arg);
}
public void testOneArg() {
Compiler compiler = new Compiler(OneArg.class, painlessLookup);
Compiler compiler = new Compiler(OneArg.class, null, null, painlessLookup);
Object rando = randomInt();
assertEquals(rando, ((OneArg)scriptEngine.compile(compiler, null, "arg", emptyMap())).execute(rando));
rando = randomAlphaOfLength(5);
assertEquals(rando, ((OneArg)scriptEngine.compile(compiler, null, "arg", emptyMap())).execute(rando));

Compiler noargs = new Compiler(NoArgs.class, painlessLookup);
Compiler noargs = new Compiler(NoArgs.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, () ->
scriptEngine.compile(noargs, null, "doc", emptyMap()));
assertEquals("Variable [doc] is not defined.", e.getMessage());
Expand All @@ -132,7 +132,7 @@ public abstract static class ArrayArg {
public abstract Object execute(String[] arg);
}
public void testArrayArg() {
Compiler compiler = new Compiler(ArrayArg.class, painlessLookup);
Compiler compiler = new Compiler(ArrayArg.class, null, null, painlessLookup);
String rando = randomAlphaOfLength(5);
assertEquals(rando, ((ArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new String[] {rando, "foo"}));
}
Expand All @@ -142,7 +142,7 @@ public abstract static class PrimitiveArrayArg {
public abstract Object execute(int[] arg);
}
public void testPrimitiveArrayArg() {
Compiler compiler = new Compiler(PrimitiveArrayArg.class, painlessLookup);
Compiler compiler = new Compiler(PrimitiveArrayArg.class, null, null, painlessLookup);
int rando = randomInt();
assertEquals(rando, ((PrimitiveArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new int[] {rando, 10}));
}
Expand All @@ -152,7 +152,7 @@ public abstract static class DefArrayArg {
public abstract Object execute(Object[] arg);
}
public void testDefArrayArg() {
Compiler compiler = new Compiler(DefArrayArg.class, painlessLookup);
Compiler compiler = new Compiler(DefArrayArg.class, null, null, painlessLookup);
Object rando = randomInt();
assertEquals(rando, ((DefArrayArg)scriptEngine.compile(compiler, null, "arg[0]", emptyMap())).execute(new Object[] {rando, 10}));
rando = randomAlphaOfLength(5);
Expand All @@ -170,7 +170,7 @@ public abstract static class ManyArgs {
public abstract boolean needsD();
}
public void testManyArgs() {
Compiler compiler = new Compiler(ManyArgs.class, painlessLookup);
Compiler compiler = new Compiler(ManyArgs.class, null, null, painlessLookup);
int rando = randomInt();
assertEquals(rando, ((ManyArgs)scriptEngine.compile(compiler, null, "a", emptyMap())).execute(rando, 0, 0, 0));
assertEquals(10, ((ManyArgs)scriptEngine.compile(compiler, null, "a + b + c + d", emptyMap())).execute(1, 2, 3, 4));
Expand Down Expand Up @@ -198,7 +198,7 @@ public abstract static class VarargTest {
public abstract Object execute(String... arg);
}
public void testVararg() {
Compiler compiler = new Compiler(VarargTest.class, painlessLookup);
Compiler compiler = new Compiler(VarargTest.class, null, null, painlessLookup);
assertEquals("foo bar baz", ((VarargTest)scriptEngine.compile(compiler, null, "String.join(' ', Arrays.asList(arg))", emptyMap()))
.execute("foo", "bar", "baz"));
}
Expand All @@ -214,7 +214,7 @@ public Object executeWithASingleOne(int a, int b, int c) {
}
}
public void testDefaultMethods() {
Compiler compiler = new Compiler(DefaultMethods.class, painlessLookup);
Compiler compiler = new Compiler(DefaultMethods.class, null, null, painlessLookup);
int rando = randomInt();
assertEquals(rando, ((DefaultMethods)scriptEngine.compile(compiler, null, "a", emptyMap())).execute(rando, 0, 0, 0));
assertEquals(rando, ((DefaultMethods)scriptEngine.compile(compiler, null, "a", emptyMap())).executeWithASingleOne(rando, 0, 0));
Expand All @@ -228,7 +228,7 @@ public abstract static class ReturnsVoid {
public abstract void execute(Map<String, Object> map);
}
public void testReturnsVoid() {
Compiler compiler = new Compiler(ReturnsVoid.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsVoid.class, null, null, painlessLookup);
Map<String, Object> map = new HashMap<>();
((ReturnsVoid)scriptEngine.compile(compiler, null, "map.a = 'foo'", emptyMap())).execute(map);
assertEquals(singletonMap("a", "foo"), map);
Expand All @@ -247,7 +247,7 @@ public abstract static class ReturnsPrimitiveBoolean {
public abstract boolean execute();
}
public void testReturnsPrimitiveBoolean() {
Compiler compiler = new Compiler(ReturnsPrimitiveBoolean.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsPrimitiveBoolean.class, null, null, painlessLookup);

assertEquals(true, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "true", emptyMap())).execute());
assertEquals(false, ((ReturnsPrimitiveBoolean)scriptEngine.compile(compiler, null, "false", emptyMap())).execute());
Expand Down Expand Up @@ -289,7 +289,7 @@ public abstract static class ReturnsPrimitiveInt {
public abstract int execute();
}
public void testReturnsPrimitiveInt() {
Compiler compiler = new Compiler(ReturnsPrimitiveInt.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsPrimitiveInt.class, null, null, painlessLookup);

assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "1", emptyMap())).execute());
assertEquals(1, ((ReturnsPrimitiveInt)scriptEngine.compile(compiler, null, "(int) 1L", emptyMap())).execute());
Expand Down Expand Up @@ -331,7 +331,7 @@ public abstract static class ReturnsPrimitiveFloat {
public abstract float execute();
}
public void testReturnsPrimitiveFloat() {
Compiler compiler = new Compiler(ReturnsPrimitiveFloat.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsPrimitiveFloat.class, null, null, painlessLookup);

assertEquals(1.1f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "1.1f", emptyMap())).execute(), 0);
assertEquals(1.1f, ((ReturnsPrimitiveFloat)scriptEngine.compile(compiler, null, "(float) 1.1d", emptyMap())).execute(), 0);
Expand Down Expand Up @@ -362,7 +362,7 @@ public abstract static class ReturnsPrimitiveDouble {
public abstract double execute();
}
public void testReturnsPrimitiveDouble() {
Compiler compiler = new Compiler(ReturnsPrimitiveDouble.class, painlessLookup);
Compiler compiler = new Compiler(ReturnsPrimitiveDouble.class, null, null, painlessLookup);

assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1", emptyMap())).execute(), 0);
assertEquals(1.0, ((ReturnsPrimitiveDouble)scriptEngine.compile(compiler, null, "1L", emptyMap())).execute(), 0);
Expand Down Expand Up @@ -396,7 +396,7 @@ public abstract static class NoArgumentsConstant {
public abstract Object execute(String foo);
}
public void testNoArgumentsConstant() {
Compiler compiler = new Compiler(NoArgumentsConstant.class, painlessLookup);
Compiler compiler = new Compiler(NoArgumentsConstant.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertThat(e.getMessage(), startsWith(
Expand All @@ -409,7 +409,7 @@ public abstract static class WrongArgumentsConstant {
public abstract Object execute(String foo);
}
public void testWrongArgumentsConstant() {
Compiler compiler = new Compiler(WrongArgumentsConstant.class, painlessLookup);
Compiler compiler = new Compiler(WrongArgumentsConstant.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertThat(e.getMessage(), startsWith(
Expand All @@ -422,7 +422,7 @@ public abstract static class WrongLengthOfArgumentConstant {
public abstract Object execute(String foo);
}
public void testWrongLengthOfArgumentConstant() {
Compiler compiler = new Compiler(WrongLengthOfArgumentConstant.class, painlessLookup);
Compiler compiler = new Compiler(WrongLengthOfArgumentConstant.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertThat(e.getMessage(), startsWith("[" + WrongLengthOfArgumentConstant.class.getName() + "#ARGUMENTS] has length [2] but ["
Expand All @@ -434,7 +434,7 @@ public abstract static class UnknownArgType {
public abstract Object execute(UnknownArgType foo);
}
public void testUnknownArgType() {
Compiler compiler = new Compiler(UnknownArgType.class, painlessLookup);
Compiler compiler = new Compiler(UnknownArgType.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertEquals("[foo] is of unknown type [" + UnknownArgType.class.getName() + ". Painless interfaces can only accept arguments "
Expand All @@ -446,7 +446,7 @@ public abstract static class UnknownReturnType {
public abstract UnknownReturnType execute(String foo);
}
public void testUnknownReturnType() {
Compiler compiler = new Compiler(UnknownReturnType.class, painlessLookup);
Compiler compiler = new Compiler(UnknownReturnType.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertEquals("Painless can only implement execute methods returning a whitelisted type but [" + UnknownReturnType.class.getName()
Expand All @@ -458,7 +458,7 @@ public abstract static class UnknownArgTypeInArray {
public abstract Object execute(UnknownArgTypeInArray[] foo);
}
public void testUnknownArgTypeInArray() {
Compiler compiler = new Compiler(UnknownArgTypeInArray.class, painlessLookup);
Compiler compiler = new Compiler(UnknownArgTypeInArray.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "1", emptyMap()));
assertEquals("[foo] is of unknown type [" + UnknownArgTypeInArray.class.getName() + ". Painless interfaces can only accept "
Expand All @@ -470,7 +470,7 @@ public abstract static class TwoExecuteMethods {
public abstract Object execute(boolean foo);
}
public void testTwoExecuteMethods() {
Compiler compiler = new Compiler(TwoExecuteMethods.class, painlessLookup);
Compiler compiler = new Compiler(TwoExecuteMethods.class, null, null, painlessLookup);
Exception e = expectScriptThrows(IllegalArgumentException.class, false, () ->
scriptEngine.compile(compiler, null, "null", emptyMap()));
assertEquals("Painless can only implement interfaces that have a single method named [execute] but ["
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ static String toString(Class<?> iface, String source, CompilerSettings settings)
PrintWriter outputWriter = new PrintWriter(output);
Textifier textifier = new Textifier();
try {
new Compiler(iface, PainlessLookupBuilder.buildFromWhitelists(Whitelist.BASE_WHITELISTS))
new Compiler(iface, null, null, PainlessLookupBuilder.buildFromWhitelists(Whitelist.BASE_WHITELISTS))
.compile("<debugging>", source, settings, textifier);
} catch (RuntimeException e) {
textifier.print(outputWriter);
Expand Down

0 comments on commit 2ed9782

Please sign in to comment.