Skip to content

Commit

Permalink
fix: inline assign in complex conditions (#699)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed Nov 30, 2019
1 parent 600842a commit d1a6841
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 71 deletions.
20 changes: 14 additions & 6 deletions jadx-core/src/main/java/jadx/core/codegen/InsnGen.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,15 +104,26 @@ public void addArg(CodeWriter code, InsnArg arg, boolean wrap) throws CodegenExc
} else if (arg.isLiteral()) {
code.add(lit((LiteralArg) arg));
} else if (arg.isInsnWrap()) {
Flags flag = wrap ? Flags.BODY_ONLY : Flags.BODY_ONLY_NOWRAP;
makeInsn(((InsnWrapArg) arg).getWrapInsn(), code, flag);
addWrappedArg(code, (InsnWrapArg) arg, wrap);
} else if (arg.isNamed()) {
code.add(((Named) arg).getName());
} else {
throw new CodegenException("Unknown arg type " + arg);
}
}

private void addWrappedArg(CodeWriter code, InsnWrapArg arg, boolean wrap) throws CodegenException {
InsnNode wrapInsn = arg.getWrapInsn();
if (wrapInsn.contains(AFlag.FORCE_ASSIGN_INLINE)) {
code.add('(');
makeInsn(wrapInsn, code, Flags.INLINE);
code.add(')');
} else {
Flags flags = wrap ? Flags.BODY_ONLY : Flags.BODY_ONLY_NOWRAP;
makeInsn(wrapInsn, code, flags);
}
}

public void assignVar(CodeWriter code, InsnNode insn) throws CodegenException {
RegisterArg arg = insn.getResult();
if (insn.contains(AFlag.DECLARE_VAR)) {
Expand Down Expand Up @@ -922,10 +933,7 @@ private boolean forceAssign(InsnNode inlineInsn, InvokeNode parentInsn, MethodNo
if (parentInsn.contains(AFlag.WRAPPED)) {
return false;
}
if (callMthNode.getReturnType().equals(ArgType.VOID)) {
return false;
}
return true;
return !callMthNode.getReturnType().equals(ArgType.VOID);
}

private void makeTernary(TernaryInsn insn, CodeWriter code, Set<Flags> state) throws CodegenException {
Expand Down
5 changes: 5 additions & 0 deletions jadx-core/src/main/java/jadx/core/dex/attributes/AFlag.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ public enum AFlag {
*/
IMMUTABLE_TYPE,

/**
* Force inline instruction with inline assign
*/
FORCE_ASSIGN_INLINE,

CUSTOM_DECLARE, // variable for this register don't need declaration
DECLARE_VAR,

Expand Down
Original file line number Diff line number Diff line change
@@ -1,42 +1,48 @@
package jadx.core.dex.regions.conditions;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.InsnNode;

public final class IfInfo {
private final IfCondition condition;
private final Set<BlockNode> mergedBlocks;
private final BlockNode thenBlock;
private final BlockNode elseBlock;
private final Set<BlockNode> skipBlocks;
private final List<InsnNode> forceInlineInsns;
private BlockNode outBlock;
@Deprecated
private BlockNode ifBlock;

public IfInfo(IfCondition condition, BlockNode thenBlock, BlockNode elseBlock) {
this(condition, thenBlock, elseBlock, new HashSet<>(), new HashSet<>());
this(condition, thenBlock, elseBlock, new HashSet<>(), new HashSet<>(), new ArrayList<>());
}

public IfInfo(IfInfo info, BlockNode thenBlock, BlockNode elseBlock) {
this(info.getCondition(), thenBlock, elseBlock, info.getMergedBlocks(), info.getSkipBlocks());
this(info.getCondition(), thenBlock, elseBlock,
info.getMergedBlocks(), info.getSkipBlocks(), info.getForceInlineInsns());
}

private IfInfo(IfCondition condition, BlockNode thenBlock, BlockNode elseBlock,
Set<BlockNode> mergedBlocks, Set<BlockNode> skipBlocks) {
Set<BlockNode> mergedBlocks, Set<BlockNode> skipBlocks, List<InsnNode> forceInlineInsns) {
this.condition = condition;
this.thenBlock = thenBlock;
this.elseBlock = elseBlock;
this.mergedBlocks = mergedBlocks;
this.skipBlocks = skipBlocks;
this.forceInlineInsns = forceInlineInsns;
}

public static IfInfo invert(IfInfo info) {
IfCondition invertedCondition = IfCondition.invert(info.getCondition());
IfInfo tmpIf = new IfInfo(invertedCondition,
info.getElseBlock(), info.getThenBlock(),
info.getMergedBlocks(), info.getSkipBlocks());
info.getMergedBlocks(), info.getSkipBlocks(), info.getForceInlineInsns());
tmpIf.setIfBlock(info.getIfBlock());
return tmpIf;
}
Expand All @@ -45,6 +51,7 @@ public void merge(IfInfo... arr) {
for (IfInfo info : arr) {
mergedBlocks.addAll(info.getMergedBlocks());
skipBlocks.addAll(info.getSkipBlocks());
addInsnsForForcedInline(info.getForceInlineInsns());
}
}

Expand Down Expand Up @@ -84,6 +91,18 @@ public void setIfBlock(BlockNode ifBlock) {
this.ifBlock = ifBlock;
}

public List<InsnNode> getForceInlineInsns() {
return forceInlineInsns;
}

public void resetForceInlineInsns() {
forceInlineInsns.clear();
}

public void addInsnsForForcedInline(List<InsnNode> insns) {
forceInlineInsns.addAll(insns);
}

@Override
public String toString() {
return "IfInfo: then: " + thenBlock + ", else: " + elseBlock;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package jadx.core.dex.visitors.regions;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Set;

import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -35,11 +37,13 @@ public class IfMakerHelper {
private IfMakerHelper() {
}

@Nullable
static IfInfo makeIfInfo(BlockNode ifBlock) {
IfNode ifNode = (IfNode) BlockUtils.getLastInsn(ifBlock);
if (ifNode == null) {
throw new JadxRuntimeException("Empty IF block: " + ifBlock);
InsnNode lastInsn = BlockUtils.getLastInsn(ifBlock);
if (lastInsn == null || lastInsn.getType() != InsnType.IF) {
return null;
}
IfNode ifNode = (IfNode) lastInsn;
IfCondition condition = IfCondition.fromIfNode(ifNode);
IfInfo info = new IfInfo(condition, ifNode.getThenBlock(), ifNode.getElseBlock());
info.setIfBlock(ifBlock);
Expand All @@ -48,8 +52,11 @@ static IfInfo makeIfInfo(BlockNode ifBlock) {
}

static IfInfo searchNestedIf(IfInfo info) {
IfInfo tmp = mergeNestedIfNodes(info);
return tmp != null ? tmp : info;
IfInfo next = mergeNestedIfNodes(info);
if (next != null) {
return next;
}
return info;
}

static IfInfo restructureIf(MethodNode mth, BlockNode block, IfInfo info) {
Expand Down Expand Up @@ -160,12 +167,24 @@ static IfInfo mergeNestedIfNodes(IfInfo currentIf) {
return null;
}
}

boolean assignInlineNeeded = !nextIf.getForceInlineInsns().isEmpty();
if (assignInlineNeeded) {
for (BlockNode mergedBlock : currentIf.getMergedBlocks()) {
if (mergedBlock.contains(AFlag.LOOP_START)) {
// don't inline assigns into loop condition
return currentIf;
}
}
}

if (isInversionNeeded(currentIf, nextIf)) {
// invert current node for match pattern
nextIf = IfInfo.invert(nextIf);
}
if (!isEqualPaths(curThen, nextIf.getThenBlock())
&& !isEqualPaths(curElse, nextIf.getElseBlock())) {
boolean thenPathSame = isEqualPaths(curThen, nextIf.getThenBlock());
boolean elsePathSame = isEqualPaths(curElse, nextIf.getElseBlock());
if (!thenPathSame && !elsePathSame) {
// complex condition, run additional checks
if (checkConditionBranches(curThen, curElse)
|| checkConditionBranches(curElse, curThen)) {
Expand All @@ -191,6 +210,15 @@ static IfInfo mergeNestedIfNodes(IfInfo currentIf) {
} else {
return currentIf;
}
} else {
if (assignInlineNeeded) {
boolean sameOuts = (thenPathSame && !followThenBranch) || (elsePathSame && followThenBranch);
if (!sameOuts) {
// don't inline assigns inside simple condition
currentIf.resetForceInlineInsns();
return currentIf;
}
}
}

IfInfo result = mergeIfInfo(currentIf, nextIf, followThenBranch);
Expand Down Expand Up @@ -315,36 +343,32 @@ static void confirmMerge(IfInfo info) {
}
info.getSkipBlocks().clear();
}
for (InsnNode forceInlineInsn : info.getForceInlineInsns()) {
forceInlineInsn.add(AFlag.FORCE_ASSIGN_INLINE);
}
}

private static IfInfo getNextIf(IfInfo info, BlockNode block) {
if (!canSelectNext(info, block)) {
return null;
}
BlockNode nestedIfBlock = getNextIfNode(block);
if (nestedIfBlock != null) {
return makeIfInfo(nestedIfBlock);
}
return null;
return getNextIfNodeInfo(info, block);
}

private static boolean canSelectNext(IfInfo info, BlockNode block) {
if (block.getPredecessors().size() == 1) {
return true;
}
if (info.getMergedBlocks().containsAll(block.getPredecessors())) {
return true;
}
return false;
return info.getMergedBlocks().containsAll(block.getPredecessors());
}

private static BlockNode getNextIfNode(BlockNode block) {
private static IfInfo getNextIfNodeInfo(IfInfo info, BlockNode block) {
if (block == null || block.contains(AType.LOOP) || block.contains(AFlag.ADDED_TO_REGION)) {
return null;
}
InsnNode lastInsn = BlockUtils.getLastInsn(block);
if (lastInsn != null && lastInsn.getType() == InsnType.IF) {
return block;
return makeIfInfo(block);
}
// skip this block and search in successors chain
List<BlockNode> successors = block.getSuccessors();
Expand All @@ -358,6 +382,7 @@ private static BlockNode getNextIfNode(BlockNode block) {
}
List<InsnNode> insns = block.getInstructions();
boolean pass = true;
List<InsnNode> forceInlineInsns = new ArrayList<>();
if (!insns.isEmpty()) {
// check that all instructions can be inlined
for (InsnNode insn : insns) {
Expand All @@ -367,7 +392,9 @@ private static BlockNode getNextIfNode(BlockNode block) {
break;
}
List<RegisterArg> useList = res.getSVar().getUseList();
if (useList.size() != 1) {
int useCount = useList.size();
if (useCount == 0) {
// TODO?
pass = false;
break;
}
Expand All @@ -378,12 +405,20 @@ private static BlockNode getNextIfNode(BlockNode block) {
pass = false;
break;
}
if (useCount > 1) {
forceInlineInsns.add(insn);
}
}
}
if (pass) {
return getNextIfNode(next);
if (!pass) {
return null;
}
return null;
IfInfo nextInfo = makeIfInfo(next);
if (nextInfo == null) {
return getNextIfNodeInfo(info, next);
}
nextInfo.addInsnsForForcedInline(forceInlineInsns);
return nextInfo;
}

private static void skipSimplePath(BlockNode block, Set<BlockNode> skipped) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,9 @@ private BlockNode processIf(IRegion currentRegion, BlockNode block, IfNode ifnod
}

IfInfo currentIf = makeIfInfo(block);
if (currentIf == null) {
return null;
}
IfInfo mergedIf = mergeNestedIfNodes(currentIf);
if (mergedIf != null) {
currentIf = mergedIf;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jadx.core.dex.visitors.ModVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnList;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.exceptions.JadxRuntimeException;

@JadxVisitor(
Expand Down Expand Up @@ -78,12 +79,15 @@ private static void checkInline(MethodNode mth, BlockNode block, InsnList insnLi
if (sVar == null || sVar.getAssign().contains(AFlag.DONT_INLINE)) {
return;
}
// allow inline only one use arg
if (sVar.getVariableUseCount() != 1) {
InsnNode assignInsn = sVar.getAssign().getParentInsn();
if (assignInsn == null
|| assignInsn.contains(AFlag.DONT_INLINE)
|| assignInsn.contains(AFlag.WRAPPED)) {
return;
}
InsnNode assignInsn = sVar.getAssign().getParentInsn();
if (assignInsn == null || assignInsn.contains(AFlag.DONT_INLINE)) {
// allow inline only one use arg
boolean assignInline = assignInsn.contains(AFlag.FORCE_ASSIGN_INLINE);
if (!assignInline && sVar.getVariableUseCount() != 1) {
return;
}
List<RegisterArg> useList = sVar.getUseList();
Expand All @@ -96,6 +100,10 @@ private static void checkInline(MethodNode mth, BlockNode block, InsnList insnLi

int assignPos = insnList.getIndex(assignInsn);
if (assignPos != -1) {
if (assignInline) {
// TODO?
return;
}
WrapInfo wrapInfo = argsInfo.checkInline(assignPos, arg);
if (wrapInfo != null) {
wrapList.add(wrapInfo);
Expand All @@ -106,11 +114,30 @@ private static void checkInline(MethodNode mth, BlockNode block, InsnList insnLi
if (assignBlock != null
&& assignInsn != arg.getParentInsn()
&& canMoveBetweenBlocks(assignInsn, assignBlock, block, argsInfo.getInsn())) {
inline(mth, arg, assignInsn, assignBlock);
if (assignInline) {
assignInline(mth, arg, assignInsn, assignBlock);
} else {
inline(mth, arg, assignInsn, assignBlock);
}
}
}
}

private static void assignInline(MethodNode mth, RegisterArg arg, InsnNode assignInsn, BlockNode assignBlock) {
RegisterArg useArg = arg.getSVar().getUseList().get(0);
InsnNode useInsn = useArg.getParentInsn();
if (useInsn == null || useInsn.contains(AFlag.DONT_GENERATE)) {
return;
}

InsnArg replaceArg = InsnArg.wrapArg(assignInsn.copy());
useInsn.replaceArg(useArg, replaceArg);

assignInsn.add(AFlag.REMOVE);
assignInsn.add(AFlag.DONT_GENERATE);
InsnRemover.remove(mth, assignBlock, assignInsn);
}

private static boolean inline(MethodNode mth, RegisterArg arg, InsnNode insn, BlockNode block) {
InsnNode parentInsn = arg.getParentInsn();
if (parentInsn != null && parentInsn.getType() == InsnType.RETURN) {
Expand Down
Loading

0 comments on commit d1a6841

Please sign in to comment.