Skip to content

Commit

Permalink
fix: restore enum class with custom code in static init (#1699)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed Oct 8, 2022
1 parent 683c2df commit 620a177
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 68 deletions.
152 changes: 84 additions & 68 deletions jadx-core/src/main/java/jadx/core/dex/visitors/EnumVisitor.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import jadx.core.dex.nodes.BlockNode;
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.IContainer;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.dex.nodes.RootNode;
Expand All @@ -48,7 +49,6 @@
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.ListUtils;
import jadx.core.utils.Utils;
import jadx.core.utils.exceptions.JadxException;
import jadx.core.utils.exceptions.JadxRuntimeException;
Expand Down Expand Up @@ -94,27 +94,26 @@ public void init(RootNode root) {

@Override
public boolean visit(ClassNode cls) throws JadxException {
boolean converted;
try {
converted = convertToEnum(cls);
} catch (Exception e) {
cls.addWarnComment("Enum visitor error", e);
converted = false;
}
if (!converted) {
AccessInfo accessFlags = cls.getAccessFlags();
if (accessFlags.isEnum()) {
cls.setAccessFlags(accessFlags.remove(AccessFlags.ENUM));
cls.addWarnComment("Failed to restore enum class, 'enum' modifier and super class removed");
if (cls.isEnum()) {
boolean converted;
try {
converted = convertToEnum(cls);
} catch (Exception e) {
cls.addWarnComment("Enum visitor error", e);
converted = false;
}
if (!converted) {
AccessInfo accessFlags = cls.getAccessFlags();
if (accessFlags.isEnum()) {
cls.setAccessFlags(accessFlags.remove(AccessFlags.ENUM));
cls.addWarnComment("Failed to restore enum class, 'enum' modifier and super class removed");
}
}
}
return true;
}

private boolean convertToEnum(ClassNode cls) {
if (!cls.isEnum()) {
return false;
}
ArgType superType = cls.getSuperClass();
if (superType != null && superType.getObject().equals(ArgType.ENUM.getObject())) {
cls.add(AFlag.REMOVE_SUPER_CLASS);
Expand All @@ -128,60 +127,34 @@ private boolean convertToEnum(ClassNode cls) {
if (staticRegion == null || classInitMth.getBasicBlocks().isEmpty()) {
return false;
}
if (!ListUtils.allMatch(staticRegion.getSubBlocks(), BlockNode.class::isInstance)) {
cls.addWarnComment("Unexpected branching instructions in enum static init block");
return false;
}
List<BlockNode> staticBlocks = ListUtils.map(staticRegion.getSubBlocks(), BlockNode.class::cast);
ArgType clsType = cls.getClassInfo().getType();

// search "$VALUES" field (holds all enum values)
List<FieldNode> valuesCandidates = cls.getFields().stream()
.filter(f -> f.getAccessFlags().isStatic())
.filter(f -> f.getType().isArray())
.filter(f -> Objects.equals(f.getType().getArrayRootElement(), clsType))
.collect(Collectors.toList());

if (valuesCandidates.isEmpty()) {
return false;
}
if (valuesCandidates.size() > 1) {
valuesCandidates.removeIf(f -> !f.getAccessFlags().isSynthetic());
}
if (valuesCandidates.size() > 1) {
Optional<FieldNode> valuesOpt = valuesCandidates.stream().filter(f -> f.getName().equals("$VALUES")).findAny();
if (valuesOpt.isPresent()) {
valuesCandidates.clear();
valuesCandidates.add(valuesOpt.get());
// collect blocks on linear part of static method (ignore branching on method end)
List<BlockNode> staticBlocks = new ArrayList<>();
for (IContainer subBlock : staticRegion.getSubBlocks()) {
if (subBlock instanceof BlockNode) {
staticBlocks.add((BlockNode) subBlock);
} else {
break;
}
}
if (valuesCandidates.size() != 1) {
cls.addWarnComment("Found several \"values\" enum fields: " + valuesCandidates);
if (staticBlocks.isEmpty()) {
cls.addWarnComment("Unexpected branching in enum static init block");
return false;
}
FieldNode valuesField = valuesCandidates.get(0);

// search "$VALUES" array init and collect enum fields
BlockInsnPair valuesInitPair = getValuesInitInsn(classInitMth, valuesField);
if (valuesInitPair == null) {
EnumData data = new EnumData(cls, classInitMth, staticBlocks);
if (!searchValuesField(data)) {
return false;
}
BlockNode staticBlock = valuesInitPair.getBlock();
InsnNode valuesInitInsn = valuesInitPair.getInsn();

EnumData enumData = new EnumData(cls, valuesField, staticBlocks);

List<EnumField> enumFields = null;
InsnArg arrArg = valuesInitInsn.getArg(0);
InsnArg arrArg = data.valuesInitInsn.getArg(0);
if (arrArg.isInsnWrap()) {
InsnNode wrappedInsn = ((InsnWrapArg) arrArg).getWrapInsn();
enumFields = extractEnumFieldsFromInsn(enumData, wrappedInsn);
enumFields = extractEnumFieldsFromInsn(data, wrappedInsn);
}
if (enumFields == null) {
cls.addWarnComment("Unknown enum class pattern. Please report as an issue!");
return false;
}
enumData.toRemove.add(valuesInitInsn);
data.toRemove.add(data.valuesInitInsn);

// all checks complete, perform transform
EnumClassAttr attr = new EnumClassAttr(enumFields);
Expand All @@ -201,16 +174,56 @@ private boolean convertToEnum(ClassNode cls) {
fieldNode.getFieldInfo().setAlias(name);
}
fieldNode.add(AFlag.DONT_GENERATE);
processConstructorInsn(enumData, enumField, classInitMth);
processConstructorInsn(data, enumField, classInitMth);
}
valuesField.add(AFlag.DONT_GENERATE);
InsnRemover.removeAllAndUnbind(classInitMth, enumData.toRemove);
data.valuesField.add(AFlag.DONT_GENERATE);
InsnRemover.removeAllAndUnbind(classInitMth, data.toRemove);
if (classInitMth.countInsns() == 0) {
classInitMth.add(AFlag.DONT_GENERATE);
} else if (!enumData.toRemove.isEmpty()) {
} else if (!data.toRemove.isEmpty()) {
CodeShrinkVisitor.shrinkMethod(classInitMth);
}
removeEnumMethods(cls, clsType, valuesField);
removeEnumMethods(cls, data.valuesField);
return true;
}

/**
* Search "$VALUES" field (holds all enum values)
*/
private boolean searchValuesField(EnumData data) {
ArgType clsType = data.cls.getClassInfo().getType();
List<FieldNode> valuesCandidates = data.cls.getFields().stream()
.filter(f -> f.getAccessFlags().isStatic())
.filter(f -> f.getType().isArray())
.filter(f -> Objects.equals(f.getType().getArrayRootElement(), clsType))
.collect(Collectors.toList());

if (valuesCandidates.isEmpty()) {
data.cls.addWarnComment("$VALUES field not found");
return false;
}
if (valuesCandidates.size() > 1) {
valuesCandidates.removeIf(f -> !f.getAccessFlags().isSynthetic());
}
if (valuesCandidates.size() > 1) {
Optional<FieldNode> valuesOpt = valuesCandidates.stream().filter(f -> f.getName().equals("$VALUES")).findAny();
if (valuesOpt.isPresent()) {
valuesCandidates.clear();
valuesCandidates.add(valuesOpt.get());
}
}
if (valuesCandidates.size() != 1) {
data.cls.addWarnComment("Found several \"values\" enum fields: " + valuesCandidates);
return false;
}
data.valuesField = valuesCandidates.get(0);

// search "$VALUES" array init and collect enum fields
BlockInsnPair valuesInitPair = getValuesInitInsn(data);
if (valuesInitPair == null) {
return false;
}
data.valuesInitInsn = valuesInitPair.getInsn();
return true;
}

Expand Down Expand Up @@ -280,9 +293,9 @@ private List<EnumField> extractEnumFieldsFromInvoke(EnumData enumData, InvokeNod
return enumFields;
}

private BlockInsnPair getValuesInitInsn(MethodNode classInitMth, FieldNode valuesField) {
FieldInfo searchField = valuesField.getFieldInfo();
for (BlockNode blockNode : classInitMth.getBasicBlocks()) {
private BlockInsnPair getValuesInitInsn(EnumData data) {
FieldInfo searchField = data.valuesField.getFieldInfo();
for (BlockNode blockNode : data.staticBlocks) {
for (InsnNode insn : blockNode.getInstructions()) {
if (insn.getType() == InsnType.SPUT) {
IndexInsnNode indexInsnNode = (IndexInsnNode) insn;
Expand Down Expand Up @@ -449,7 +462,8 @@ private InsnNode searchFieldPutInsn(EnumData data, FieldNode enumFieldNode) {
return null;
}

private void removeEnumMethods(ClassNode cls, ArgType clsType, FieldNode valuesField) {
private void removeEnumMethods(ClassNode cls, FieldNode valuesField) {
ArgType clsType = cls.getClassInfo().getType();
String valuesMethodShortId = "values()" + TypeGen.signature(ArgType.array(clsType));
MethodNode valuesMethod = null;
// remove compiler generated methods
Expand Down Expand Up @@ -631,13 +645,15 @@ private String getConstString(RootNode root, InsnArg arg) {

private static class EnumData {
final ClassNode cls;
final FieldNode valuesField;
final MethodNode classInitMth;
final List<BlockNode> staticBlocks;
final List<InsnNode> toRemove = new ArrayList<>();
FieldNode valuesField;
InsnNode valuesInitInsn;

public EnumData(ClassNode cls, FieldNode valuesField, List<BlockNode> staticBlocks) {
public EnumData(ClassNode cls, MethodNode classInitMth, List<BlockNode> staticBlocks) {
this.cls = cls;
this.valuesField = valuesField;
this.classInitMth = classInitMth;
this.staticBlocks = staticBlocks;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package jadx.tests.integration.enums;

import java.util.HashMap;
import java.util.Map;

import org.junit.jupiter.api.Test;

import jadx.tests.api.IntegrationTest;

import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;

public class TestEnumsWithCustomInit extends IntegrationTest {

public enum TestCls {
ONE("I"),
TWO("II"),
THREE("III");

public static final Map<String, TestCls> MAP = new HashMap<>();

static {
for (TestCls value : values()) {
MAP.put(value.toString(), value);
}
}

private final String str;

TestCls(String str) {
this.str = str;
}

public String toString() {
return str;
}
}

@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("ONE(\"I\"),")
.doesNotContain("new TestEnumsWithCustomInit$TestCls(");
}
}

0 comments on commit 620a177

Please sign in to comment.