Skip to content

Commit

Permalink
feat: support polymorphic invoke (#384)(#1777)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed Feb 3, 2023
1 parent 4d00fed commit 540c0a8
Show file tree
Hide file tree
Showing 13 changed files with 342 additions and 10 deletions.
8 changes: 8 additions & 0 deletions jadx-core/src/main/java/jadx/core/codegen/InsnGen.java
Original file line number Diff line number Diff line change
Expand Up @@ -795,11 +795,19 @@ private void makeInvoke(InvokeNode insn, ICodeWriter code) throws CodegenExcepti
MethodInfo callMth = insn.getCallMth();
MethodNode callMthNode = mth.root().resolveMethod(callMth);

if (insn.isPolymorphicCall()) {
// add missing cast
code.add('(');
useType(code, callMth.getReturnType());
code.add(") ");
}

int k = 0;
switch (type) {
case DIRECT:
case VIRTUAL:
case INTERFACE:
case POLYMORPHIC:
InsnArg arg = insn.getArg(0);
if (needInvokeArg(arg)) {
addArgDot(code, arg);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package jadx.core.dex.instructions;

import java.util.List;
import java.util.Objects;

import org.jetbrains.annotations.NotNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import jadx.api.plugins.input.data.ICodeReader;
import jadx.api.plugins.input.data.IMethodProto;
import jadx.api.plugins.input.data.IMethodRef;
import jadx.api.plugins.input.insns.InsnData;
import jadx.api.plugins.input.insns.custom.IArrayPayload;
Expand All @@ -25,6 +27,7 @@
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.DecodeException;
import jadx.core.utils.exceptions.JadxRuntimeException;
import jadx.core.utils.input.InsnDataUtils;
Expand Down Expand Up @@ -440,6 +443,8 @@ protected InsnNode decode(InsnData insn) throws DecodeException {
return invokeCustom(insn, false);
case INVOKE_SPECIAL:
return invokeSpecial(insn);
case INVOKE_POLYMORPHIC:
return invokePolymorphic(insn, false);

case INVOKE_DIRECT_RANGE:
return invoke(insn, InvokeType.DIRECT, true);
Expand All @@ -451,6 +456,8 @@ protected InsnNode decode(InsnData insn) throws DecodeException {
return invoke(insn, InvokeType.VIRTUAL, true);
case INVOKE_CUSTOM_RANGE:
return invokeCustom(insn, true);
case INVOKE_POLYMORPHIC_RANGE:
return invokePolymorphic(insn, true);

case NEW_INSTANCE:
ArgType clsType = ArgType.parse(insn.getIndexAsType());
Expand Down Expand Up @@ -581,6 +588,22 @@ private InsnNode invokeCustom(InsnData insn, boolean isRange) {
return InvokeCustomBuilder.build(method, insn, isRange);
}

private InsnNode invokePolymorphic(InsnData insn, boolean isRange) {
IMethodRef mthRef = InsnDataUtils.getMethodRef(insn);
if (mthRef == null) {
throw new JadxRuntimeException("Failed to load method reference for insn: " + insn);
}
MethodInfo callMth = MethodInfo.fromRef(root, mthRef);
IMethodProto proto = insn.getIndexAsProto(insn.getTarget());

// expand call args
List<ArgType> args = Utils.collectionMap(proto.getArgTypes(), ArgType::parse);
ArgType returnType = ArgType.parse(proto.getReturnType());
MethodInfo effectiveCallMth = MethodInfo.fromDetails(root, callMth.getDeclClass(),
callMth.getName(), args, returnType);
return new InvokePolymorphicNode(effectiveCallMth, insn, proto, callMth, isRange);
}

private InsnNode invokeSpecial(InsnData insn) {
IMethodRef mthRef = InsnDataUtils.getMethodRef(insn);
if (mthRef == null) {
Expand Down
13 changes: 13 additions & 0 deletions jadx-core/src/main/java/jadx/core/dex/instructions/InvokeNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ public boolean isStaticCall() {
return type == InvokeType.STATIC;
}

public boolean isPolymorphicCall() {
if (type == InvokeType.POLYMORPHIC) {
return true;
}
// java bytecode uses virtual call with modified method info
if (type == InvokeType.VIRTUAL
&& mth.getDeclClass().getFullName().equals("java.lang.invoke.MethodHandle")
&& (mth.getName().equals("invoke") || mth.getName().equals("invokeExact"))) {
return true;
}
return false;
}

public int getFirstArgOffset() {
return type == InvokeType.STATIC ? 0 : 1;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package jadx.core.dex.instructions;

import jadx.api.plugins.input.data.IMethodProto;
import jadx.api.plugins.input.insns.InsnData;
import jadx.core.dex.info.MethodInfo;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils;

public class InvokePolymorphicNode extends InvokeNode {
private final IMethodProto proto;
private final MethodInfo baseCallRef;

public InvokePolymorphicNode(MethodInfo callMth, InsnData insn, IMethodProto proto, MethodInfo baseRef, boolean isRange) {
super(callMth, insn, InvokeType.POLYMORPHIC, true, isRange);
this.proto = proto;
this.baseCallRef = baseRef;
}

public InvokePolymorphicNode(MethodInfo callMth, int argsCount, IMethodProto proto, MethodInfo baseRef) {
super(callMth, InvokeType.POLYMORPHIC, argsCount);
this.proto = proto;
this.baseCallRef = baseRef;
}

public IMethodProto getProto() {
return proto;
}

public MethodInfo getBaseCallRef() {
return baseCallRef;
}

@Override
public InsnNode copy() {
InvokePolymorphicNode copy = new InvokePolymorphicNode(getCallMth(), getArgsCount(), proto, baseCallRef);
copyCommonParams(copy);
return copy;
}

@Override
public boolean isSame(InsnNode obj) {
if (this == obj) {
return true;
}
if (!(obj instanceof InvokePolymorphicNode) || !super.isSame(obj)) {
return false;
}
InvokePolymorphicNode other = (InvokePolymorphicNode) obj;
return proto.equals(other.proto);
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(InsnUtils.formatOffset(offset)).append(": INVOKE_POLYMORPHIC ");
if (getResult() != null) {
sb.append(getResult()).append(" = ");
}
if (!appendArgs(sb)) {
sb.append('\n');
}
sb.append(" base: ").append(baseCallRef).append('\n');
sb.append(" proto: ").append(proto).append('\n');
return sb.toString();
}
}
20 changes: 13 additions & 7 deletions jadx-core/src/main/java/jadx/core/dex/nodes/InsnNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -539,19 +539,25 @@ public final boolean equals(Object obj) {
return super.equals(obj);
}

protected void appendArgs(StringBuilder sb) {
/**
* Append arguments type, wrap line if too long
*
* @return true if args wrapped
*/
protected boolean appendArgs(StringBuilder sb) {
if (arguments.isEmpty()) {
return;
return false;
}
String argsStr = Utils.listToString(arguments);
if (argsStr.length() < 120) {
sb.append(argsStr);
} else {
// wrap args
String separator = ICodeWriter.NL + " ";
sb.append(separator).append(Utils.listToString(arguments, separator));
sb.append(ICodeWriter.NL);
return false;
}
// wrap args
String separator = ICodeWriter.NL + " ";
sb.append(separator).append(Utils.listToString(arguments, separator));
sb.append(ICodeWriter.NL);
return true;
}

@Override
Expand Down
3 changes: 2 additions & 1 deletion jadx-core/src/test/java/jadx/tests/api/IntegrationTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.slf4j.LoggerFactory;

import jadx.api.CommentsLevel;
import jadx.api.DecompilationMode;
import jadx.api.ICodeInfo;
import jadx.api.ICodeWriter;
import jadx.api.JadxArgs;
Expand Down Expand Up @@ -546,7 +547,7 @@ public void useTargetJavaVersion(int version) {

protected void setFallback() {
disableCompilation();
this.args.setFallbackMode(true);
this.args.setDecompilationMode(DecompilationMode.FALLBACK);
}

protected void disableCompilation() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ public JadxCodeAssertions code() {
return new JadxCodeAssertions(codeStr);
}

public JadxCodeAssertions disasmCode() {
isNotNull();
String disasmCode = actual.getDisassembledCode();
assertThat(disasmCode).isNotNull().isNotBlank();
return new JadxCodeAssertions(disasmCode);
}

public JadxCodeAssertions reloadCode(IntegrationTest testInstance) {
isNotNull();
ICodeInfo code = actual.reloadCode();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package jadx.tests.integration.invoke;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

import org.junit.jupiter.api.Test;

import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;
import jadx.tests.api.extensions.profiles.TestProfile;
import jadx.tests.api.extensions.profiles.TestWithProfiles;

import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;

public class TestPolymorphicInvoke extends SmaliTest {

public static class TestCls {
public String func(int a, int c) {
return String.valueOf(a + c);
}

public String test() {
try {
MethodType methodType = MethodType.methodType(String.class, Integer.TYPE, Integer.TYPE);
MethodHandle methodHandle = MethodHandles.lookup().findVirtual(TestCls.class, "func", methodType);
return (String) methodHandle.invoke(this, 1, 2);
} catch (Throwable e) {
fail(e);
return null;
}
}

public void check() {
assertThat(test()).isEqualTo("3");
}
}

@TestWithProfiles({ TestProfile.DX_J8, TestProfile.D8_J11 })
public void test() {
ClassNode cls = getClassNode(TestCls.class);
assertThat(cls).code()
.containsOne("return (String) methodHandle.invoke(this, 1, 2);");
assertThat(cls).disasmCode()
.containsOne("invoke-polymorphic");
}

@TestWithProfiles({ TestProfile.JAVA8, TestProfile.JAVA11 })
public void testJava() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("return (String) methodHandle.invoke(this, 1, 2);");
// java uses 'invokevirtual'
}

@Test
public void testSmali() {
assertThat(getClassNodeFromSmali())
.code()
.containsOne("String ret = (String) methodHandle.invoke(this, 10, 20);");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package jadx.tests.integration.invoke;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;

import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.IntegrationTest;
import jadx.tests.api.extensions.profiles.TestProfile;
import jadx.tests.api.extensions.profiles.TestWithProfiles;

import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
import static org.junit.jupiter.api.Assertions.fail;

public class TestPolymorphicRangeInvoke extends IntegrationTest {

public static class TestCls {
public String func2(int a, int b, int c, int d, int e, int f) {
return String.valueOf(a + b + c + d + e + f);
}

public String test() {
try {
MethodHandles.Lookup lookup = MethodHandles.lookup();
MethodType methodType = MethodType.methodType(String.class, Integer.TYPE, Integer.TYPE, Integer.TYPE, Integer.TYPE,
Integer.TYPE, Integer.TYPE);
MethodHandle methodHandle = lookup.findVirtual(TestCls.class, "func2", methodType);
return (String) methodHandle.invoke(this, 10, 20, 30, 40, 50, 60);
} catch (Throwable e) {
fail(e);
return null;
}
}

public void check() {
assertThat(test()).isEqualTo("210");
}
}

@TestWithProfiles({ TestProfile.DX_J8 })
public void test() {
ClassNode cls = getClassNode(TestCls.class);
assertThat(cls).code()
.containsOne("return (String) methodHandle.invoke(this, 10, 20, 30, 40, 50, 60);");
assertThat(cls).disasmCode()
.containsOne("invoke-polymorphic/range");
}
}
Loading

0 comments on commit 540c0a8

Please sign in to comment.