Skip to content

Commit

Permalink
fix: additional casts at use place to help type inference (#1002)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed Oct 31, 2020
1 parent a22efc2 commit 2b7d7ce
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public enum AFlag {
*/
EXPLICIT_PRIMITIVE_TYPE,
EXPLICIT_CAST,
SOFT_CAST, // synthetic cast to help type inference
SOFT_CAST, // synthetic cast to help type inference (allow unchecked casts for generics)

INCONSISTENT_CODE, // warning about incorrect decompilation

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ public static InsnArg wrapArg(InsnNode insn) {
return arg;
}

public boolean isZeroLiteral() {
return isLiteral() && (((LiteralArg) this)).getLiteral() == 0;
}

public boolean isThis() {
return contains(AFlag.THIS);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.InsnWrapArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.instructions.args.SSAVar;
import jadx.core.dex.instructions.mods.ConstructorInsn;
Expand Down Expand Up @@ -209,7 +208,7 @@ private List<EnumField> extractEnumFieldsFromInsn(ClassNode cls, BlockNode stati

case NEW_ARRAY:
InsnArg arg = wrappedInsn.getArg(0);
if (arg.isLiteral() && ((LiteralArg) arg).getLiteral() == 0) {
if (arg.isZeroLiteral()) {
// empty enum
return Collections.emptyList();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,7 @@ private static void simplifyIf(IfNode insn) {
if (f.isInsnWrap()) {
InsnNode wi = ((InsnWrapArg) f).getWrapInsn();
if (wi.getType() == InsnType.CMP_L || wi.getType() == InsnType.CMP_G) {
if (insn.getArg(1).isLiteral()
&& ((LiteralArg) insn.getArg(1)).getLiteral() == 0) {
if (insn.getArg(1).isZeroLiteral()) {
insn.changeCondition(insn.getOp(), wi.getArg(0), wi.getArg(1));
} else {
LOG.warn("TODO: cmp {}", insn);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.Set;
import java.util.function.Function;

import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -332,6 +333,10 @@ private ITypeBound makeUseBound(RegisterArg regArg) {
return invokeUseBound;
}
}
if (insn.getType() == InsnType.CHECK_CAST && insn.contains(AFlag.SOFT_CAST)) {
// ignore
return null;
}
return new TypeBoundConst(BoundEnum.USE, regArg.getInitType(), regArg);
}

Expand Down Expand Up @@ -499,20 +504,30 @@ private int tryInsertVarCast(MethodNode mth, SSAVar var) {
if (insertAssignCast(mth, var, boundType)) {
return 1;
}
// TODO: check if use casts are needed
return 0;
return insertUseCasts(mth, var);
}
}
return 0;
}

private int insertUseCasts(MethodNode mth, SSAVar var) {
List<RegisterArg> useList = var.getUseList();
if (useList.isEmpty()) {
return 0;
}
int useCasts = 0;
for (RegisterArg useReg : new ArrayList<>(useList)) {
if (insertSoftUseCast(mth, useReg)) {
useCasts++;
}
}
return useCasts;
}

private boolean insertAssignCast(MethodNode mth, SSAVar var, ArgType castType) {
RegisterArg assignArg = var.getAssign();
InsnNode assignInsn = assignArg.getParentInsn();
if (assignInsn == null) {
return false;
}
if (assignInsn.getType() == InsnType.PHI) {
if (assignInsn == null || assignInsn.getType() == InsnType.PHI) {
return false;
}
BlockNode assignBlock = BlockUtils.getBlockByInsn(mth, assignInsn);
Expand All @@ -521,14 +536,38 @@ private boolean insertAssignCast(MethodNode mth, SSAVar var, ArgType castType) {
}
RegisterArg newAssignArg = assignArg.duplicateWithNewSSAVar(mth);
assignInsn.setResult(newAssignArg);
IndexInsnNode castInsn = makeSoftCastInsn(assignArg, newAssignArg, castType);
return BlockUtils.insertAfterInsn(assignBlock, assignInsn, castInsn);
}

private boolean insertSoftUseCast(MethodNode mth, RegisterArg useArg) {
InsnNode useInsn = useArg.getParentInsn();
if (useInsn == null || useInsn.getType() == InsnType.PHI) {
return false;
}
if (useInsn.getType() == InsnType.IF && useInsn.getArg(1).isZeroLiteral()) {
// cast not needed if compare with null
return false;
}
BlockNode useBlock = BlockUtils.getBlockByInsn(mth, useInsn);
if (useBlock == null) {
return false;
}
RegisterArg newUseArg = useArg.duplicateWithNewSSAVar(mth);
useInsn.replaceArg(useArg, newUseArg);

IndexInsnNode castInsn = makeSoftCastInsn(newUseArg, useArg, useArg.getInitType());
return BlockUtils.insertBeforeInsn(useBlock, useInsn, castInsn);
}

@NotNull
private IndexInsnNode makeSoftCastInsn(RegisterArg result, RegisterArg arg, ArgType castType) {
IndexInsnNode castInsn = new IndexInsnNode(InsnType.CHECK_CAST, castType, 1);
castInsn.setResult(assignArg.duplicate());
castInsn.addArg(newAssignArg.duplicate());
castInsn.setResult(result.duplicate());
castInsn.addArg(arg.duplicate());
castInsn.add(AFlag.SOFT_CAST);
castInsn.add(AFlag.SYNTHETIC);

return BlockUtils.insertAfterInsn(assignBlock, assignInsn, castInsn);
return castInsn;
}

private boolean trySplitConstInsns(MethodNode mth) {
Expand Down
26 changes: 21 additions & 5 deletions jadx-core/src/main/java/jadx/core/utils/BlockUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -670,17 +670,33 @@ public static boolean replaceInsn(MethodNode mth, BlockNode block, InsnNode oldI
return false;
}

public static boolean insertBeforeInsn(BlockNode block, InsnNode insn, InsnNode newInsn) {
int index = getInsnIndexInBlock(block, insn);
if (index == -1) {
return false;
}
block.getInstructions().add(index, newInsn);
return true;
}

public static boolean insertAfterInsn(BlockNode block, InsnNode insn, InsnNode newInsn) {
int index = getInsnIndexInBlock(block, insn);
if (index == -1) {
return false;
}
block.getInstructions().add(index + 1, newInsn);
return true;
}

public static int getInsnIndexInBlock(BlockNode block, InsnNode insn) {
List<InsnNode> instructions = block.getInstructions();
int size = instructions.size();
for (int i = 0; i < size; i++) {
InsnNode instruction = instructions.get(i);
if (instruction == insn) {
instructions.add(i + 1, newInsn);
return true;
if (instructions.get(i) == insn) {
return i;
}
}
return false;
return -1;
}

public static boolean replaceInsn(MethodNode mth, InsnNode oldInsn, InsnNode newInsn) {
Expand Down
4 changes: 4 additions & 0 deletions jadx-core/src/main/java/jadx/core/utils/DebugUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -169,4 +169,8 @@ public static void printMap(Map<?, ?> map, String desc) {
LOG.debug(" {}: {}", entry.getKey(), entry.getValue());
}
}

public static void printStackTrace(String label) {
LOG.debug("StackTrace: {}\n{}", label, Utils.getStackTrace(new Exception()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package jadx.tests.integration.types;

import org.junit.jupiter.api.Test;

import jadx.tests.api.SmaliTest;

import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;

/**
* Issue 1002
* Insertion of additional cast (at use place) needed for successful type inference
*/
public class TestTypeResolver16 extends SmaliTest {
// @formatter:off
/*
public final <T, K> List<T> test(List<? extends T> list, Set<? extends T> set, Function<? super T, ? extends K> function) {
checkParameterIsNotNull(function, "distinctBy");
if (set != null) {
List<? extends T> union = list != null ? union(list, set, function) : null;
if (union != null) {
list = union;
}
}
return list != null ? (List<T>) list : emptyList();
}
*/
// @formatter:on

@Test
public void test() {
assertThat(getClassNodeFromSmali())
.code()
.containsOne("(List<T>) list");
}
}
111 changes: 111 additions & 0 deletions jadx-core/src/test/smali/types/TestTypeResolver16.smali
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
.class public Ltypes/TestTypeResolver16;
.super Ljava/lang/Object;

.method public final test(Ljava/util/List;Ljava/util/Set;Ljava/util/function/Function;)Ljava/util/List;
.locals 1
.annotation system Ldalvik/annotation/Signature;
value = {
"<T:",
"Ljava/lang/Object;",
"K:",
"Ljava/lang/Object;",
">(",
"Ljava/util/List<",
"+TT;>;",
"Ljava/util/Set<",
"+TT;>;",
"Ljava/util/function/Function<",
"-TT;+TK;>;)",
"Ljava/util/List<",
"TT;>;"
}
.end annotation

const-string v0, "distinctBy"

invoke-static {p3, v0}, Ltypes/TestTypeResolver16;->checkParameterIsNotNull(Ljava/lang/Object;Ljava/lang/String;)V

if-eqz p2, :cond_1

if-eqz p1, :cond_0

.line 85
move-object v0, p1

check-cast v0, Ljava/util/Collection;

check-cast p2, Ljava/lang/Iterable;

invoke-static {v0, p2, p3}, Ltypes/TestTypeResolver16;->union(Ljava/util/Collection;Ljava/lang/Iterable;Ljava/util/function/Function;)Ljava/util/List;

move-result-object p2

goto :goto_0

:cond_0
const/4 p2, 0x0

:goto_0
if-eqz p2, :cond_1

move-object p1, p2

:cond_1
if-eqz p1, :cond_2

goto :goto_1

:cond_2
invoke-static {}, Ltypes/TestTypeResolver16;->emptyList()Ljava/util/List;

move-result-object p1

:goto_1
return-object p1
.end method


.method public static final union(Ljava/util/Collection;Ljava/lang/Iterable;Ljava/util/function/Function;)Ljava/util/List;
.locals 4
.annotation system Ldalvik/annotation/Signature;
value = {
"<T:",
"Ljava/lang/Object;",
"K:",
"Ljava/lang/Object;",
">(",
"Ljava/util/Collection<",
"+TT;>;",
"Ljava/lang/Iterable<",
"+TT;>;",
"Ljava/util/function/Function<",
"-TT;+TK;>;)",
"Ljava/util/List<",
"TT;>;"
}
.end annotation

const/4 v0, 0x0
return-object v0
.end method

.method public static checkParameterIsNotNull(Ljava/lang/Object;Ljava/lang/String;)V
.locals 0
return-void
.end method

.method public static final emptyList()Ljava/util/List;
.locals 1
.annotation system Ldalvik/annotation/Signature;
value = {
"<T:",
"Ljava/lang/Object;",
">()",
"Ljava/util/List<",
"TT;>;"
}
.end annotation

const/4 v0, 0x0
return-object v0
.end method

0 comments on commit 2b7d7ce

Please sign in to comment.