Skip to content

Commit

Permalink
fix: merge const block before return (#699)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed Oct 31, 2019
1 parent 11db454 commit bae36f9
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ public boolean sameRegAndSVar(InsnArg arg) {
&& Objects.equals(sVar, reg.getSVar());
}

public boolean sameReg(InsnArg arg) {
if (!arg.isRegister()) {
return false;
}
return regNum == ((RegisterArg) arg).getRegNum();
}

public boolean sameCodeVar(RegisterArg arg) {
return this.getSVar().getCodeVar() == arg.getSVar().getCodeVar();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.LoopInfo;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
Expand All @@ -29,6 +30,7 @@
import jadx.core.dex.trycatch.TryCatchBlock;
import jadx.core.dex.visitors.AbstractVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxRuntimeException;

import static jadx.core.dex.visitors.blocksmaker.BlockSplitter.connect;
Expand Down Expand Up @@ -413,7 +415,48 @@ private static boolean modifyBlocksTree(MethodNode mth) {
return true;
}
}
return splitReturn(mth);
if (mergeConstReturn(mth)) {
return true;
}
return splitReturnBlocks(mth);
}

private static boolean mergeConstReturn(MethodNode mth) {
if (mth.getReturnType() == ArgType.VOID) {
return false;
}

boolean changed = false;
for (BlockNode exitBlock : new ArrayList<>(mth.getExitBlocks())) {
BlockNode pred = Utils.getOne(exitBlock.getPredecessors());
if (pred != null) {
InsnNode constInsn = Utils.getOne(pred.getInstructions());
if (constInsn != null && constInsn.isConstInsn()) {
RegisterArg constArg = constInsn.getResult();
InsnNode returnInsn = BlockUtils.getLastInsn(exitBlock);
if (returnInsn != null) {
InsnArg retArg = returnInsn.getArg(0);
if (constArg.sameReg(retArg)) {
mergeConstAndReturnBlocks(mth, exitBlock, pred);
changed = true;
}
}
}
}
}
if (changed) {
removeMarkedBlocks(mth);
cleanExitNodes(mth);
}
return changed;
}

private static void mergeConstAndReturnBlocks(MethodNode mth, BlockNode exitBlock, BlockNode pred) {
pred.getInstructions().addAll(exitBlock.getInstructions());
pred.copyAttributesFrom(exitBlock);
BlockSplitter.removeConnection(pred, exitBlock);
exitBlock.getInstructions().clear();
exitBlock.add(AFlag.REMOVE);
}

private static boolean independentBlockTreeMod(MethodNode mth) {
Expand Down Expand Up @@ -604,16 +647,25 @@ private static boolean mergeHandlers(MethodNode mth, List<BlockNode> blocksForMe
return true;
}

private static boolean splitReturnBlocks(MethodNode mth) {
boolean changed = false;
for (BlockNode exitBlock : mth.getExitBlocks()) {
if (splitReturn(mth, exitBlock)) {
changed = true;
}
}
if (changed) {
cleanExitNodes(mth);
}
return changed;
}

/**
* Splice return block if several predecessors presents
*/
private static boolean splitReturn(MethodNode mth) {
if (mth.getExitBlocks().size() != 1) {
return false;
}
BlockNode exitBlock = mth.getExitBlocks().get(0);
if (exitBlock.getInstructions().size() != 1
|| exitBlock.contains(AFlag.SYNTHETIC)
private static boolean splitReturn(MethodNode mth, BlockNode exitBlock) {
if (exitBlock.contains(AFlag.SYNTHETIC)
|| exitBlock.contains(AFlag.ORIG_RETURN)
|| exitBlock.contains(AType.SPLITTER_BLOCK)) {
return false;
}
Expand All @@ -625,37 +677,45 @@ private static boolean splitReturn(MethodNode mth) {
if (preds.size() < 2) {
return false;
}
InsnNode returnInsn = exitBlock.getInstructions().get(0);
if (returnInsn.getArgsCount() != 0 && !isReturnArgAssignInPred(preds, returnInsn)) {
InsnNode returnInsn = BlockUtils.getLastInsn(exitBlock);
if (returnInsn == null) {
return false;
}
if (returnInsn.getArgsCount() == 1
&& exitBlock.getInstructions().size() == 1
&& !isReturnArgAssignInPred(preds, returnInsn)) {
return false;
}

boolean first = true;
for (BlockNode pred : preds) {
BlockNode newRetBlock = BlockSplitter.startNewBlock(mth, -1);
newRetBlock.add(AFlag.SYNTHETIC);
InsnNode newRetInsn;
if (first) {
newRetInsn = returnInsn;
newRetBlock.add(AFlag.ORIG_RETURN);
newRetBlock.getInstructions().addAll(exitBlock.getInstructions());
first = false;
} else {
newRetInsn = duplicateReturnInsn(returnInsn);
for (InsnNode oldInsn : exitBlock.getInstructions()) {
newRetBlock.getInstructions().add(oldInsn.copy());
}
}
newRetBlock.getInstructions().add(newRetInsn);
BlockSplitter.replaceConnection(pred, exitBlock, newRetBlock);
}
cleanExitNodes(mth);
return true;
}

private static boolean isReturnArgAssignInPred(List<BlockNode> preds, InsnNode returnInsn) {
RegisterArg arg = (RegisterArg) returnInsn.getArg(0);
int regNum = arg.getRegNum();
for (BlockNode pred : preds) {
for (InsnNode insnNode : pred.getInstructions()) {
RegisterArg result = insnNode.getResult();
if (result != null && result.getRegNum() == regNum) {
return true;
InsnArg retArg = returnInsn.getArg(0);
if (retArg.isRegister()) {
RegisterArg arg = (RegisterArg) retArg;
int regNum = arg.getRegNum();
for (BlockNode pred : preds) {
for (InsnNode insnNode : pred.getInstructions()) {
RegisterArg result = insnNode.getResult();
if (result != null && result.getRegNum() == regNum) {
return true;
}
}
}
}
Expand All @@ -673,18 +733,6 @@ private static void cleanExitNodes(MethodNode mth) {
}
}

private static InsnNode duplicateReturnInsn(InsnNode returnInsn) {
InsnNode insn = new InsnNode(returnInsn.getType(), returnInsn.getArgsCount());
if (returnInsn.getArgsCount() == 1) {
RegisterArg arg = (RegisterArg) returnInsn.getArg(0);
insn.addArg(arg.duplicate());
}
insn.copyAttributesFrom(returnInsn);
insn.setOffset(returnInsn.getOffset());
insn.setSourceLine(returnInsn.getSourceLine());
return insn;
}

private static void removeMarkedBlocks(MethodNode mth) {
mth.getBasicBlocks().removeIf(block -> {
if (block.contains(AFlag.REMOVE)) {
Expand Down
8 changes: 8 additions & 0 deletions jadx-core/src/main/java/jadx/core/utils/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,14 @@ public static Map<String, String> newConstStringMap(String... parameters) {
return Collections.unmodifiableMap(result);
}

@Nullable
public static <T> T getOne(@Nullable List<T> list) {
if (list == null || list.size() != 1) {
return null;
}
return list.get(0);
}

@Nullable
public static <T> T last(List<T> list) {
if (list.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import org.junit.jupiter.api.Test;

import jadx.NotYetImplemented;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;

import static jadx.tests.api.utils.JadxMatchers.containsLines;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.MatcherAssert.assertThat;

Expand All @@ -31,7 +33,20 @@ public void test() {
ClassNode cls = getClassNodeFromSmali();
String code = cls.getCode().toString();

assertThat(code, containsOne("return this == obj"
+ " || ((obj instanceof TestConditions18) && st(this.map, ((TestConditions18) obj).map));"));
assertThat(code, containsLines(2,
"if (this != obj) {",
indent() + "return (obj instanceof TestConditions18) && st(this.map, ((TestConditions18) obj).map);",
"}",
"return true;"));
}

@Test
@NotYetImplemented
public void testNYI() {
ClassNode cls = getClassNodeFromSmali();
String code = cls.getCode().toString();

assertThat(code,
containsOne("return this == obj || ((obj instanceof TestConditions18) && st(this.map, ((TestConditions18) obj).map));"));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package jadx.tests.integration.conditions;

import org.junit.jupiter.api.Test;

import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;

import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.MatcherAssert.assertThat;

public class TestConditions21 extends SmaliTest {

// @formatter:off
/*
public boolean check(Object obj) {
if (this == obj) {
return true;
}
if (obj instanceof List) {
List list = (List) obj;
if (!list.isEmpty() && list.contains(this)) {
return true;
}
}
return false;
}
*/
// @formatter:on

@Test
public void test() {
ClassNode cls = getClassNodeFromSmali();
String code = cls.getCode().toString();

assertThat(code, containsOne("!list.isEmpty() && list.contains(this)"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@

import org.junit.jupiter.api.Test;

import jadx.NotYetImplemented;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;

import static jadx.tests.api.utils.JadxMatchers.containsLines;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;

public class TestTernaryInIf2 extends SmaliTest {

public static class TestCls {
private String a;
private String b;
private String a = "a";
private String b = "b";

public boolean equals(TestCls other) {
if (this.a == null ? other.a == null : this.a.equals(other.a)) {
Expand All @@ -22,6 +24,22 @@ public boolean equals(TestCls other) {
}
return false;
}

public void check() {
TestCls other = new TestCls();
other.a = "a";
other.b = "b";
assertThat(this.equals(other), is(true));

other.b = "not-b";
assertThat(this.equals(other), is(false));

other.b = null;
assertThat(this.equals(other), is(false));

this.b = null;
assertThat(this.equals(other), is(true));
}
}

@Test
Expand All @@ -30,9 +48,20 @@ public void test() {
String code = cls.getCode().toString();

assertThat(code, containsLines(2, "if (this.a != null ? this.a.equals(other.a) : other.a == null) {"));
assertThat(code, containsLines(3, "if (this.b != null ? this.b.equals(other.b) : other.b == null) {"));
assertThat(code, containsLines(4, "return true;"));
assertThat(code, containsLines(2, "return false;"));
// assertThat(code, containsLines(3, "if (this.b != null ? this.b.equals(other.b) : other.b == null)
// {"));
// assertThat(code, containsLines(4, "return true;"));
// assertThat(code, containsLines(2, "return false;"));
}

@Test
@NotYetImplemented
public void testNYI() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();

assertThat(code, containsLines(2, "return (this.a != null ? this.a.equals(other.a) : other.a == null) "
+ "&& (this.b == null ? other.b == null : this.b.equals(other.b));"));
}

@Test
Expand Down
33 changes: 33 additions & 0 deletions jadx-core/src/test/smali/conditions/TestConditions21.smali
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
.class public final Lconditions/TestConditions21;
.super Ljava/lang/Object;

.method public check(Ljava/lang/Object;)Z
.locals 2

if-eq p0, p1, :ret_true

instance-of v0, p1, Ljava/util/List;
if-eqz v0, :ret_false

check-cast p1, Ljava/util/List;

invoke-interface {p1}, Ljava/util/List;->isEmpty()Z
move-result v0

if-nez v0, :ret_false

invoke-interface {p1, p0}, Ljava/util/List;->contains(Ljava/lang/Object;)Z
move-result v0

if-eqz v0, :ret_false

goto :ret_true

:ret_false
const/4 p1, 0x0
return p1

:ret_true
const/4 p1, 0x1
return p1
.end method

0 comments on commit bae36f9

Please sign in to comment.