Skip to content

Commit

Permalink
fix: improve usage search, refactor java nodes creation (#1489)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed May 27, 2022
1 parent 1df217c commit cb741db
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 162 deletions.
120 changes: 21 additions & 99 deletions jadx-core/src/main/java/jadx/api/JadxDecompiler.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public final class JadxDecompiler implements Closeable {
private final Map<MethodNode, JavaMethod> methodsMap = new ConcurrentHashMap<>();
private final Map<FieldNode, JavaField> fieldsMap = new ConcurrentHashMap<>();

private final IDecompileScheduler decompileScheduler = new DecompilerScheduler(this);
private final IDecompileScheduler decompileScheduler = new DecompilerScheduler();

private final List<ILoadResult> customLoads = new ArrayList<>();

Expand Down Expand Up @@ -202,6 +202,7 @@ private void loadPlugins(JadxArgs args) {
}
}

@SuppressWarnings("unused")
public void registerPlugin(JadxPlugin plugin) {
pluginManager.register(plugin);
}
Expand Down Expand Up @@ -467,23 +468,10 @@ synchronized ProtoXMLParser getProtoXmlParser() {
return protoXmlParser;
}

private void loadJavaClass(JavaClass javaClass) {
javaClass.getMethods().forEach(mth -> methodsMap.put(mth.getMethodNode(), mth));
javaClass.getFields().forEach(fld -> fieldsMap.put(fld.getFieldNode(), fld));

for (JavaClass innerCls : javaClass.getInnerClasses()) {
classesMap.put(innerCls.getClassNode(), innerCls);
loadJavaClass(innerCls);
}
for (JavaClass inlinedCls : javaClass.getInlinedClasses()) {
classesMap.put(inlinedCls.getClassNode(), inlinedCls);
loadJavaClass(inlinedCls);
}
}

/**
* Get JavaClass by ClassNode without loading and decompilation
*/
@ApiStatus.Internal
JavaClass convertClassNode(ClassNode cls) {
return classesMap.compute(cls, (node, prevJavaCls) -> {
if (prevJavaCls != null && prevJavaCls.getClassNode() == cls) {
Expand All @@ -497,66 +485,23 @@ JavaClass convertClassNode(ClassNode cls) {
});
}

@Nullable("For not generated classes")
@ApiStatus.Internal
public JavaClass getJavaClassByNode(ClassNode cls) {
JavaClass javaClass = classesMap.get(cls);
if (javaClass != null && javaClass.getClassNode() == cls) {
return javaClass;
}
// load parent class if inner
ClassNode parentClass = cls.getTopParentClass();
if (parentClass.contains(AFlag.DONT_GENERATE)) {
return null;
}
JavaClass parentJavaClass = classesMap.get(parentClass);
if (parentJavaClass == null) {
getClasses();
parentJavaClass = classesMap.get(parentClass);
}
if (parentJavaClass != null) {
loadJavaClass(parentJavaClass);
javaClass = classesMap.get(cls);
if (javaClass != null) {
return javaClass;
}
}
// class or parent classes can be excluded from generation
if (cls.hasNotGeneratedParent()) {
return null;
}
throw new JadxRuntimeException("JavaClass not found by ClassNode: " + cls);
JavaField convertFieldNode(FieldNode field) {
return fieldsMap.computeIfAbsent(field, fldNode -> {
JavaClass parentCls = convertClassNode(fldNode.getParentClass());
return new JavaField(parentCls, fldNode);
});
}

@ApiStatus.Internal
@Nullable
public JavaMethod getJavaMethodByNode(MethodNode mth) {
JavaMethod javaMethod = methodsMap.get(mth);
if (javaMethod != null && javaMethod.getMethodNode() == mth) {
return javaMethod;
}
if (mth.contains(AFlag.DONT_GENERATE)) {
return null;
}
// parent class not loaded yet
ClassNode parentClass = mth.getParentClass();
ClassNode codeCls = getCodeParentClass(parentClass);
JavaClass javaClass = getJavaClassByNode(codeCls);
if (javaClass == null) {
return null;
}
loadJavaClass(javaClass);
javaMethod = methodsMap.get(mth);
if (javaMethod != null) {
return javaMethod;
}
if (parentClass.hasNotGeneratedParent()) {
return null;
}
throw new JadxRuntimeException("JavaMethod not found by MethodNode: " + mth);
JavaMethod convertMethodNode(MethodNode method) {
return methodsMap.computeIfAbsent(method, mthNode -> {
ClassNode codeCls = getCodeParentClass(mthNode.getParentClass());
return new JavaMethod(convertClassNode(codeCls), mthNode);
});
}

private ClassNode getCodeParentClass(ClassNode cls) {
private static ClassNode getCodeParentClass(ClassNode cls) {
ClassNode codeCls;
InlinedAttr inlinedAttr = cls.get(AType.INLINED);
if (inlinedAttr != null) {
Expand All @@ -570,35 +515,12 @@ private ClassNode getCodeParentClass(ClassNode cls) {
return getCodeParentClass(codeCls);
}

@ApiStatus.Internal
@Nullable
public JavaField getJavaFieldByNode(FieldNode fld) {
JavaField javaField = fieldsMap.get(fld);
if (javaField != null && javaField.getFieldNode() == fld) {
return javaField;
}
// parent class not loaded yet
JavaClass javaClass = getJavaClassByNode(fld.getParentClass().getTopParentClass());
if (javaClass == null) {
return null;
}
loadJavaClass(javaClass);
javaField = fieldsMap.get(fld);
if (javaField != null) {
return javaField;
}
if (fld.getParentClass().hasNotGeneratedParent()) {
return null;
}
throw new JadxRuntimeException("JavaField not found by FieldNode: " + fld);
}

@Nullable
public JavaClass searchJavaClassByOrigFullName(String fullName) {
return getRoot().getClasses().stream()
.filter(cls -> cls.getClassInfo().getFullName().equals(fullName))
.findFirst()
.map(this::getJavaClassByNode)
.map(this::convertClassNode)
.orElse(null);
}

Expand All @@ -619,9 +541,9 @@ public JavaClass searchJavaClassOrItsParentByOrigFullName(String fullName) {
.orElse(null);
if (node != null) {
if (node.contains(AFlag.DONT_GENERATE)) {
return getJavaClassByNode(node.getTopParentClass());
return convertClassNode(node.getTopParentClass());
} else {
return getJavaClassByNode(node);
return convertClassNode(node);
}
}
return null;
Expand All @@ -632,7 +554,7 @@ public JavaClass searchJavaClassByAliasFullName(String fullName) {
return getRoot().getClasses().stream()
.filter(cls -> cls.getClassInfo().getAliasFullName().equals(fullName))
.findFirst()
.map(this::getJavaClassByNode)
.map(this::convertClassNode)
.orElse(null);
}

Expand All @@ -650,9 +572,9 @@ public JavaNode getJavaNodeByCodeAnnotation(@Nullable ICodeInfo codeInfo, @Nulla
case CLASS:
return convertClassNode((ClassNode) ann);
case METHOD:
return getJavaMethodByNode((MethodNode) ann);
return convertMethodNode((MethodNode) ann);
case FIELD:
return getJavaFieldByNode((FieldNode) ann);
return convertFieldNode((FieldNode) ann);
case DECLARATION:
return getJavaNodeByCodeAnnotation(codeInfo, ((NodeDeclareRef) ann).getNode());
case VAR:
Expand All @@ -670,7 +592,7 @@ public JavaNode getJavaNodeByCodeAnnotation(@Nullable ICodeInfo codeInfo, @Nulla
@Nullable
private JavaVariable resolveVarNode(VarNode varNode) {
MethodNode mthNode = varNode.getMth();
JavaMethod mth = getJavaMethodByNode(mthNode);
JavaMethod mth = convertMethodNode(mthNode);
if (mth == null) {
return null;
}
Expand Down
58 changes: 30 additions & 28 deletions jadx-core/src/main/java/jadx/api/JavaClass.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
Expand All @@ -15,7 +14,6 @@
import org.slf4j.LoggerFactory;

import jadx.api.metadata.ICodeAnnotation;
import jadx.api.metadata.ICodeAnnotation.AnnType;
import jadx.api.metadata.ICodeNodeRef;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.attributes.AType;
Expand All @@ -24,6 +22,7 @@
import jadx.core.dex.nodes.ClassNode;
import jadx.core.dex.nodes.FieldNode;
import jadx.core.dex.nodes.MethodNode;
import jadx.core.utils.ListUtils;

public final class JavaClass implements JavaNode {
private static final Logger LOG = LoggerFactory.getLogger(JavaClass.class);
Expand Down Expand Up @@ -88,6 +87,14 @@ public synchronized String getSmali() {
return cls.getDisassembledCode();
}

@Override
public boolean isOwnCodeAnnotation(ICodeAnnotation ann) {
if (ann.getAnnType() == ICodeAnnotation.AnnType.CLASS) {
return ann.equals(cls);
}
return false;
}

/**
* Internal API. Not Stable!
*/
Expand All @@ -99,7 +106,6 @@ public ClassNode getClassNode() {
/**
* Decompile class and loads internal lists of fields, methods, etc.
* Do nothing if already loaded.
* Return not null on first call only (for actual loading)
*/
@Nullable
private synchronized void load() {
Expand Down Expand Up @@ -140,10 +146,9 @@ private synchronized void load() {
if (fieldsCount != 0) {
List<JavaField> flds = new ArrayList<>(fieldsCount);
for (FieldNode f : cls.getFields()) {
// if (!f.contains(AFlag.DONT_GENERATE)) {
JavaField javaField = new JavaField(this, f);
flds.add(javaField);
// }
if (!f.contains(AFlag.DONT_GENERATE)) {
flds.add(rootDecompiler.convertFieldNode(f));
}
}
this.fields = Collections.unmodifiableList(flds);
}
Expand All @@ -153,16 +158,15 @@ private synchronized void load() {
List<JavaMethod> mths = new ArrayList<>(methodsCount);
for (MethodNode m : cls.getMethods()) {
if (!m.contains(AFlag.DONT_GENERATE)) {
JavaMethod javaMethod = new JavaMethod(this, m);
mths.add(javaMethod);
mths.add(rootDecompiler.convertMethodNode(m));
}
}
mths.sort(Comparator.comparing(JavaMethod::getName));
this.methods = Collections.unmodifiableList(mths);
}
}

protected JadxDecompiler getRootDecompiler() {
JadxDecompiler getRootDecompiler() {
if (parent != null) {
return parent.getRootDecompiler();
}
Expand Down Expand Up @@ -193,27 +197,16 @@ public Map<Integer, JavaNode> getUsageMap() {
}

public List<Integer> getUsePlacesFor(ICodeInfo codeInfo, JavaNode javaNode) {
Map<Integer, ICodeAnnotation> map = codeInfo.getCodeMetadata().getAsMap();
if (map.isEmpty() || decompiler == null) {
if (!codeInfo.hasMetadata()) {
return Collections.emptyList();
}
JadxDecompiler rootDec = getRootDecompiler();
List<Integer> result = new ArrayList<>();
for (Map.Entry<Integer, ICodeAnnotation> entry : map.entrySet()) {
ICodeAnnotation ann = entry.getValue();
AnnType annType = ann.getAnnType();
if (annType == AnnType.DECLARATION || annType == AnnType.OFFSET) {
// ignore declarations and offset annotations
continue;
}
JavaNode annNode = rootDec.getJavaNodeByCodeAnnotation(codeInfo, ann);
if (annNode == null && LOG.isDebugEnabled()) {
LOG.debug("Failed to resolve code annotation, cls: {}, pos: {}, ann: {}", this, entry.getKey(), ann);
codeInfo.getCodeMetadata().searchDown(0, (pos, ann) -> {
if (javaNode.isOwnCodeAnnotation(ann)) {
result.add(pos);
}
if (Objects.equals(annNode, javaNode)) {
result.add(entry.getKey());
}
}
return null;
});
return result;
}

Expand Down Expand Up @@ -294,7 +287,16 @@ public JavaMethod searchMethodByShortId(String shortId) {
if (methodNode == null) {
return null;
}
return new JavaMethod(this, methodNode);
return getRootDecompiler().convertMethodNode(methodNode);
}

public List<JavaClass> getDependencies() {
JadxDecompiler d = getRootDecompiler();
return ListUtils.map(cls.getDependencies(), d::convertClassNode);
}

public int getTotalDepsCount() {
return cls.getTotalDepsCount();
}

@Override
Expand Down
9 changes: 9 additions & 0 deletions jadx-core/src/main/java/jadx/api/JavaField.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import org.jetbrains.annotations.ApiStatus;

import jadx.api.metadata.ICodeAnnotation;
import jadx.core.dex.info.AccessInfo;
import jadx.core.dex.instructions.args.ArgType;
import jadx.core.dex.nodes.FieldNode;
Expand Down Expand Up @@ -65,6 +66,14 @@ public void removeAlias() {
this.field.getFieldInfo().removeAlias();
}

@Override
public boolean isOwnCodeAnnotation(ICodeAnnotation ann) {
if (ann.getAnnType() == ICodeAnnotation.AnnType.FIELD) {
return ann.equals(field);
}
return false;
}

/**
* Internal API. Not Stable!
*/
Expand Down
12 changes: 11 additions & 1 deletion jadx-core/src/main/java/jadx/api/JavaMethod.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import jadx.api.metadata.ICodeAnnotation;
import jadx.core.dex.attributes.AType;
import jadx.core.dex.attributes.nodes.MethodOverrideAttr;
import jadx.core.dex.info.AccessInfo;
Expand All @@ -18,6 +19,7 @@

public final class JavaMethod implements JavaNode {
private static final Logger LOG = LoggerFactory.getLogger(JavaMethod.class);

private final MethodNode mth;
private final JavaClass parent;

Expand Down Expand Up @@ -78,7 +80,7 @@ public List<JavaMethod> getOverrideRelatedMethods() {
JadxDecompiler decompiler = getDeclaringClass().getRootDecompiler();
return ovrdAttr.getRelatedMthNodes().stream()
.map(m -> {
JavaMethod javaMth = decompiler.getJavaMethodByNode(m);
JavaMethod javaMth = decompiler.convertMethodNode(m);
if (javaMth == null) {
LOG.warn("Failed convert to java method: {}", m);
}
Expand Down Expand Up @@ -106,6 +108,14 @@ public void removeAlias() {
this.mth.getMethodInfo().removeAlias();
}

@Override
public boolean isOwnCodeAnnotation(ICodeAnnotation ann) {
if (ann.getAnnType() == ICodeAnnotation.AnnType.METHOD) {
return ann.equals(mth);
}
return false;
}

/**
* Internal API. Not Stable!
*/
Expand Down
Loading

0 comments on commit cb741db

Please sign in to comment.