From 60b2353afe9993257d49fb9c1e2449fd4a7a405f Mon Sep 17 00:00:00 2001 From: Skylot Date: Fri, 11 Sep 2020 19:29:55 +0100 Subject: [PATCH] fix: adjust types for arithmetic instructions (#921) --- .../jadx/core/dex/instructions/ArithNode.java | 77 +++++++++++-------- .../jadx/core/dex/instructions/ArithOp.java | 11 +++ .../core/dex/instructions/InsnDecoder.java | 12 +-- .../core/dex/instructions/args/ArgType.java | 1 + .../visitors/typeinference/TypeUpdate.java | 17 +++- .../tests/integration/arith/TestArith4.java | 37 +++++++++ .../jadx/tests/integration/arith/TestXor.java | 63 +++++++++++++++ .../tests/integration/conditions/TestXor.java | 61 --------------- .../smali/{conditions => arith}/TestXor.smali | 18 +---- 9 files changed, 176 insertions(+), 121 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/arith/TestArith4.java create mode 100644 jadx-core/src/test/java/jadx/tests/integration/arith/TestXor.java delete mode 100644 jadx-core/src/test/java/jadx/tests/integration/conditions/TestXor.java rename jadx-core/src/test/smali/{conditions => arith}/TestXor.smali (51%) diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/ArithNode.java b/jadx-core/src/main/java/jadx/core/dex/instructions/ArithNode.java index a360a6678bb..23ee39515ff 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/ArithNode.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/ArithNode.java @@ -1,5 +1,7 @@ package jadx.core.dex.instructions; +import org.jetbrains.annotations.Nullable; + import jadx.api.plugins.input.insns.InsnData; import jadx.core.dex.attributes.AFlag; import jadx.core.dex.instructions.args.ArgType; @@ -8,41 +10,54 @@ import jadx.core.dex.instructions.args.RegisterArg; import jadx.core.dex.nodes.InsnNode; import jadx.core.utils.InsnUtils; +import jadx.core.utils.exceptions.JadxRuntimeException; public class ArithNode extends InsnNode { - private final ArithOp op; + public static ArithNode build(InsnData insn, ArithOp op, ArgType type) { + RegisterArg resArg = InsnArg.reg(insn, 0, fixResultType(op, type)); + ArgType argType = fixArgType(op, type); + switch (insn.getRegsCount()) { + case 2: + return new ArithNode(op, resArg, InsnArg.reg(insn, 0, argType), InsnArg.reg(insn, 1, argType)); + case 3: + return new ArithNode(op, resArg, InsnArg.reg(insn, 1, argType), InsnArg.reg(insn, 2, argType)); + default: + throw new JadxRuntimeException("Unexpected registers count in " + insn); + } + } - public ArithNode(InsnData insn, ArithOp op, ArgType type, boolean literal) { - super(InsnType.ARITH, 2); - this.op = op; - setResult(InsnArg.reg(insn, 0, type)); + public static ArithNode buildLit(InsnData insn, ArithOp op, ArgType type) { + RegisterArg resArg = InsnArg.reg(insn, 0, fixResultType(op, type)); + ArgType argType = fixArgType(op, type); + LiteralArg litArg = InsnArg.lit(insn, argType); + switch (insn.getRegsCount()) { + case 1: + return new ArithNode(op, resArg, InsnArg.reg(insn, 0, argType), litArg); + case 2: + return new ArithNode(op, resArg, InsnArg.reg(insn, 1, argType), litArg); + default: + throw new JadxRuntimeException("Unexpected registers count in " + insn); + } + } - int rc = insn.getRegsCount(); - if (literal) { - if (rc == 1) { - // self - addReg(insn, 0, type); - addLit(insn, type); - } else if (rc == 2) { - // normal - addReg(insn, 1, type); - addLit(insn, type); - } - } else { - if (rc == 2) { - // self - addReg(insn, 0, type); - addReg(insn, 1, type); - } else if (rc == 3) { - // normal - addReg(insn, 1, type); - addReg(insn, 2, type); - } + private static ArgType fixResultType(ArithOp op, ArgType type) { + if (type == ArgType.INT && op.isBitOp()) { + return ArgType.INT_BOOLEAN; } + return type; } - public ArithNode(ArithOp op, RegisterArg res, InsnArg a, InsnArg b) { + private static ArgType fixArgType(ArithOp op, ArgType type) { + if (type == ArgType.INT && op.isBitOp()) { + return ArgType.NARROW_NUMBERS_NO_FLOAT; + } + return type; + } + + private final ArithOp op; + + public ArithNode(ArithOp op, @Nullable RegisterArg res, InsnArg a, InsnArg b) { super(InsnType.ARITH, 2); this.op = op; setResult(res); @@ -50,10 +65,6 @@ public ArithNode(ArithOp op, RegisterArg res, InsnArg a, InsnArg b) { addArg(b); } - public ArithNode(ArithOp op, InsnArg a, InsnArg b) { - this(op, null, a, b); - } - /** * Create one argument arithmetic instructions (a+=2). * Result is not set (null). @@ -61,7 +72,7 @@ public ArithNode(ArithOp op, InsnArg a, InsnArg b) { * @param res argument to change */ public static ArithNode oneArgOp(ArithOp op, InsnArg res, InsnArg a) { - ArithNode insn = new ArithNode(op, res, a); + ArithNode insn = new ArithNode(op, null, res, a); insn.add(AFlag.ARITH_ONEARG); return insn; } @@ -100,7 +111,7 @@ private boolean isSameLiteral(ArithNode other) { @Override public InsnNode copy() { - ArithNode copy = new ArithNode(op, getArg(0).duplicate(), getArg(1).duplicate()); + ArithNode copy = new ArithNode(op, null, getArg(0).duplicate(), getArg(1).duplicate()); return copyCommonParams(copy); } diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/ArithOp.java b/jadx-core/src/main/java/jadx/core/dex/instructions/ArithOp.java index 7e4ce9648ef..f34e6bb22ba 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/ArithOp.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/ArithOp.java @@ -24,4 +24,15 @@ public enum ArithOp { public String getSymbol() { return this.symbol; } + + public boolean isBitOp() { + switch (this) { + case AND: + case OR: + case XOR: + return true; + default: + return false; + } + } } diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/InsnDecoder.java b/jadx-core/src/main/java/jadx/core/dex/instructions/InsnDecoder.java index 3a0b798f3fd..0c54bf5b7a9 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/InsnDecoder.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/InsnDecoder.java @@ -542,19 +542,11 @@ private InsnNode arrayPut(InsnData insn, ArgType argType) { } private InsnNode arith(InsnData insn, ArithOp op, ArgType type) { - return new ArithNode(insn, op, fixTypeForBitOps(op, type), false); + return ArithNode.build(insn, op, type); } private InsnNode arithLit(InsnData insn, ArithOp op, ArgType type) { - return new ArithNode(insn, op, fixTypeForBitOps(op, type), true); - } - - private ArgType fixTypeForBitOps(ArithOp op, ArgType type) { - if (type == ArgType.INT - && (op == ArithOp.AND || op == ArithOp.OR || op == ArithOp.XOR)) { - return ArgType.NARROW_NUMBERS_NO_FLOAT; - } - return type; + return ArithNode.buildLit(insn, op, type); } private InsnNode neg(InsnData insn, ArgType type) { diff --git a/jadx-core/src/main/java/jadx/core/dex/instructions/args/ArgType.java b/jadx-core/src/main/java/jadx/core/dex/instructions/args/ArgType.java index b68df3463fd..aa3b8c6186a 100644 --- a/jadx-core/src/main/java/jadx/core/dex/instructions/args/ArgType.java +++ b/jadx-core/src/main/java/jadx/core/dex/instructions/args/ArgType.java @@ -65,6 +65,7 @@ public abstract class ArgType { public static final ArgType WIDE = unknown(PrimitiveType.LONG, PrimitiveType.DOUBLE); public static final ArgType INT_FLOAT = unknown(PrimitiveType.INT, PrimitiveType.FLOAT); + public static final ArgType INT_BOOLEAN = unknown(PrimitiveType.INT, PrimitiveType.BOOLEAN); protected int hash; diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java index 074faf84532..b2a9266744f 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/typeinference/TypeUpdate.java @@ -14,6 +14,7 @@ import org.slf4j.LoggerFactory; import jadx.core.Consts; +import jadx.core.dex.instructions.ArithNode; import jadx.core.dex.instructions.BaseInvokeNode; import jadx.core.dex.instructions.IndexInsnNode; import jadx.core.dex.instructions.InsnType; @@ -286,7 +287,7 @@ private Map initListenerRegistry() { registry.put(InsnType.AGET, this::arrayGetListener); registry.put(InsnType.APUT, this::arrayPutListener); registry.put(InsnType.IF, this::ifListener); - registry.put(InsnType.ARITH, this::suggestAllSameListener); + registry.put(InsnType.ARITH, this::arithListener); registry.put(InsnType.NEG, this::suggestAllSameListener); registry.put(InsnType.NOT, this::suggestAllSameListener); registry.put(InsnType.CHECK_CAST, this::checkCastListener); @@ -441,12 +442,24 @@ private TypeUpdateResult allSameListener(TypeUpdateInfo updateInfo, InsnNode ins return allSame ? SAME : CHANGED; } + private TypeUpdateResult arithListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) { + ArithNode arithInsn = (ArithNode) insn; + if (candidateType == ArgType.BOOLEAN && arithInsn.getOp().isBitOp()) { + // force all args to boolean + return allSameListener(updateInfo, insn, arg, candidateType); + } + return suggestAllSameListener(updateInfo, insn, arg, candidateType); + } + /** * Try to set candidate type to all args, don't fail on reject */ private TypeUpdateResult suggestAllSameListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) { if (!isAssign(insn, arg)) { - updateTypeChecked(updateInfo, insn.getResult(), candidateType); + RegisterArg resultArg = insn.getResult(); + if (resultArg != null) { + updateTypeChecked(updateInfo, resultArg, candidateType); + } } boolean allSame = true; for (InsnArg insnArg : insn.getArguments()) { diff --git a/jadx-core/src/test/java/jadx/tests/integration/arith/TestArith4.java b/jadx-core/src/test/java/jadx/tests/integration/arith/TestArith4.java new file mode 100644 index 00000000000..430abccb1bc --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/arith/TestArith4.java @@ -0,0 +1,37 @@ +package jadx.tests.integration.arith; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.IntegrationTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestArith4 extends IntegrationTest { + + public static class TestCls { + public static byte test(byte b) { + int k = b & 7; + return (byte) (((b & 255) >>> (8 - k)) | (b << k)); + } + + public static int test2(String str) { + int k = 'a' | str.charAt(0); + return (1 - k) & (1 + k); + } + } + + @Test + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .containsOne("int k = b & 7;") + .containsOne("return (1 - k) & (k + 1);"); + } + + @Test + public void testNoDebug() { + noDebugInfo(); + assertThat(getClassNode(TestCls.class)) + .code(); + } +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/arith/TestXor.java b/jadx-core/src/test/java/jadx/tests/integration/arith/TestXor.java new file mode 100644 index 00000000000..03d75c037fc --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/arith/TestXor.java @@ -0,0 +1,63 @@ +package jadx.tests.integration.arith; + +import org.junit.jupiter.api.Test; + +import jadx.tests.api.SmaliTest; + +import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat; + +public class TestXor extends SmaliTest { + + @SuppressWarnings("PointlessBooleanExpression") + public static class TestCls { + public boolean test1() { + return test() ^ true; + } + + public boolean test2(boolean v) { + return v ^ true; + } + + public boolean test() { + return true; + } + + public void check() { + assertThat(test1()).isFalse(); + assertThat(test2(true)).isFalse(); + assertThat(test2(false)).isTrue(); + } + } + + @Test + public void test() { + assertThat(getClassNode(TestCls.class)) + .code() + .containsOne("return !test();") + .containsOne("return !v;"); + } + + @Test + public void smali() { + // @formatter:off + /* + public boolean test1() { + return test() ^ true; + } + + public boolean test2() { + return test() ^ false; + } + + public boolean test() { + return true; + } + */ + // @formatter:on + assertThat(getClassNodeFromSmali()) + .code() + .containsOne("return !test();") + .containsOne("return test();"); + } + +} diff --git a/jadx-core/src/test/java/jadx/tests/integration/conditions/TestXor.java b/jadx-core/src/test/java/jadx/tests/integration/conditions/TestXor.java deleted file mode 100644 index 7375a6a4040..00000000000 --- a/jadx-core/src/test/java/jadx/tests/integration/conditions/TestXor.java +++ /dev/null @@ -1,61 +0,0 @@ -package jadx.tests.integration.conditions; - -import org.junit.jupiter.api.Test; - -import jadx.core.dex.nodes.ClassNode; -import jadx.tests.api.SmaliTest; - -import static jadx.tests.api.utils.JadxMatchers.containsOne; -import static org.hamcrest.MatcherAssert.assertThat; - -public class TestXor extends SmaliTest { - - public static class TestCls { - public boolean test1() { - return test() ^ true; - } - - public boolean test2(boolean v) { - return v ^ true; - } - - public boolean test() { - return true; - } - } - - @Test - public void test() { - ClassNode cls = getClassNode(TestCls.class); - String code = cls.getCode().toString(); - - assertThat(code, containsOne("return !test();")); - assertThat(code, containsOne("return !v;")); - } - - @Test - public void smali() { - // @formatter:off - /* - public boolean test1() { - return test() ^ true; - } - - public boolean test2() { - return test() ^ false; - } - - public boolean test() { - return true; - } - */ - // @formatter:on - - ClassNode cls = getClassNodeFromSmaliWithPath("conditions", "TestXor"); - String code = cls.getCode().toString(); - - assertThat(code, containsOne("return !test();")); - assertThat(code, containsOne("return test();")); - } - -} diff --git a/jadx-core/src/test/smali/conditions/TestXor.smali b/jadx-core/src/test/smali/arith/TestXor.smali similarity index 51% rename from jadx-core/src/test/smali/conditions/TestXor.smali rename to jadx-core/src/test/smali/arith/TestXor.smali index 16e8aee0a95..2cd375526cd 100644 --- a/jadx-core/src/test/smali/conditions/TestXor.smali +++ b/jadx-core/src/test/smali/arith/TestXor.smali @@ -1,19 +1,7 @@ -.class public LTestXor; +.class public Larith/TestXor; .super Ljava/lang/Object; -# direct methods -.method public constructor ()V - .locals 0 - - .line 9 - invoke-direct {p0}, Ljava/lang/Object;->()V - - return-void -.end method - - -# virtual methods .method public test()Z .locals 1 @@ -27,7 +15,7 @@ .locals 1 .line 12 - invoke-virtual {p0}, Lcom/example/myapplication/MainActivity;->test()Z + invoke-virtual {p0}, Larith/TestXor;->test()Z move-result v0 @@ -40,7 +28,7 @@ .locals 1 .line 16 - invoke-virtual {p0}, Lcom/example/myapplication/MainActivity;->test()Z + invoke-virtual {p0}, Larith/TestXor;->test()Z move-result v0