diff --git a/archunit/src/main/java/com/tngtech/archunit/core/importer/ClassFileImportRecord.java b/archunit/src/main/java/com/tngtech/archunit/core/importer/ClassFileImportRecord.java index 601321194b..857ca2d69e 100644 --- a/archunit/src/main/java/com/tngtech/archunit/core/importer/ClassFileImportRecord.java +++ b/archunit/src/main/java/com/tngtech/archunit/core/importer/ClassFileImportRecord.java @@ -50,11 +50,12 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static com.tngtech.archunit.base.Optionals.stream; import static com.tngtech.archunit.core.importer.JavaClassDescriptorImporter.isLambdaMethodName; import static com.tngtech.archunit.core.importer.JavaClassDescriptorImporter.isSyntheticAccessMethodName; import static com.tngtech.archunit.core.importer.JavaClassDescriptorImporter.isSyntheticEnumSwitchMapFieldName; import static java.util.Collections.emptyList; +import static java.util.Collections.singleton; +import static java.util.stream.Collectors.toSet; class ClassFileImportRecord { private static final Logger log = LoggerFactory.getLogger(ClassFileImportRecord.class); @@ -288,7 +289,7 @@ private Stream fixSyntheticOrigins( Stream result = rawAccessRecordsIncludingSyntheticAccesses.stream(); for (SyntheticAccessRecorder syntheticAccessRecorder : syntheticAccessRecorders) { - result = result.flatMap(access -> stream(syntheticAccessRecorder.fixSyntheticAccess(access, createAccessWithNewOrigin))); + result = result.flatMap(access -> syntheticAccessRecorder.fixSyntheticAccess(access, createAccessWithNewOrigin).stream()); } return result; } @@ -374,7 +375,7 @@ Optional getEnclosingCodeUnit(String ownerName) { } private static class SyntheticAccessRecorder { - private final Map rawSyntheticMethodInvocationRecordsByTarget = new HashMap<>(); + private final SetMultimap rawSyntheticMethodInvocationRecordsByTarget = HashMultimap.create(); private final Predicate isSyntheticOrigin; private final BiConsumer, CodeUnit> fixOrigin; @@ -390,41 +391,41 @@ void registerSyntheticMethodInvocation(RawAccessRecord record) { rawSyntheticMethodInvocationRecordsByTarget.put(getMemberKey(record.target), record); } - Optional fixSyntheticAccess( + Set fixSyntheticAccess( ACCESS access, Function> copyAccess ) { return isSyntheticOrigin.test(access.caller) ? replaceOriginByFixedOrigin(access, copyAccess) - : Optional.of(access); + : singleton(access); } - private Optional replaceOriginByFixedOrigin( + private Set replaceOriginByFixedOrigin( ACCESS accessFromSyntheticMethod, Function> copyAccess ) { - RawAccessRecord accessWithCorrectOrigin = findNonSyntheticOriginOf(accessFromSyntheticMethod); - - if (accessWithCorrectOrigin != null) { - RawAccessRecord.BaseBuilder copiedBuilder = copyAccess.apply(accessFromSyntheticMethod); - fixOrigin.accept(copiedBuilder, accessWithCorrectOrigin.caller); - return Optional.of(copiedBuilder.build()); - } else { + Set result = findNonSyntheticOriginOf(accessFromSyntheticMethod) + .map(accessWithCorrectOrigin -> { + RawAccessRecord.BaseBuilder copiedBuilder = copyAccess.apply(accessFromSyntheticMethod); + fixOrigin.accept(copiedBuilder, accessWithCorrectOrigin.caller); + return copiedBuilder.build(); + }) + .collect(toSet()); + + if (result.isEmpty()) { log.warn("Could not find matching origin for synthetic method {}.{}|{}", accessFromSyntheticMethod.target.getDeclaringClassName(), accessFromSyntheticMethod.target.name, accessFromSyntheticMethod.target.getDescriptor()); - return Optional.empty(); } - } - - private RawAccessRecord findNonSyntheticOriginOf(ACCESS accessFromSyntheticMethod) { - RawAccessRecord result = accessFromSyntheticMethod; - do { - result = rawSyntheticMethodInvocationRecordsByTarget.get(getMemberKey(result.caller)); - } while (result != null && isSyntheticOrigin.test(result.caller)); return result; } + + private Stream findNonSyntheticOriginOf(ACCESS access) { + return isSyntheticOrigin.test(access.caller) + ? rawSyntheticMethodInvocationRecordsByTarget.get(getMemberKey(access.caller)).stream().flatMap(this::findNonSyntheticOriginOf) + : Stream.of(access); + } } } diff --git a/archunit/src/test/java/com/tngtech/archunit/core/importer/ClassFileImporterSyntheticPrivateAccessesTest.java b/archunit/src/test/java/com/tngtech/archunit/core/importer/ClassFileImporterSyntheticPrivateAccessesTest.java index 002b5eb1fa..9ff4fc9c53 100644 --- a/archunit/src/test/java/com/tngtech/archunit/core/importer/ClassFileImporterSyntheticPrivateAccessesTest.java +++ b/archunit/src/test/java/com/tngtech/archunit/core/importer/ClassFileImporterSyntheticPrivateAccessesTest.java @@ -4,6 +4,7 @@ import java.util.function.Supplier; import com.tngtech.archunit.core.domain.JavaAccess; +import com.tngtech.archunit.core.domain.JavaClass; import org.junit.Test; import static com.tngtech.archunit.core.domain.JavaConstructor.CONSTRUCTOR_NAME; @@ -334,6 +335,65 @@ Supplier access(Target target) { ); } + @Test + public void imports_multiple_accesses_to_same_private_field() { + @SuppressWarnings("unused") + class Target { + private String field; + } + @SuppressWarnings("unused") + class Origin { + String first(Target target) { + return target.field; + } + + String second(Target target) { + return target.field; + } + } + + JavaClass origin = new ClassFileImporter().importClasses(Target.class, Origin.class).get(Origin.class); + + assertThatAccesses(origin.getAccessesFromSelf()) + .contain(expectedAccess() + .from(Origin.class, "first") + .toField(GET, Target.class, "field") + ) + .contain(expectedAccess() + .from(Origin.class, "second") + .toField(GET, Target.class, "field")); + } + + @Test + public void imports_multiple_calls_to_same_private_method() { + @SuppressWarnings("unused") + class Target { + private void method() { + } + } + @SuppressWarnings("unused") + class Origin { + void first(Target target) { + target.method(); + } + + void second(Target target) { + target.method(); + } + } + + JavaClass origin = new ClassFileImporter().importClasses(Target.class, Origin.class).get(Origin.class); + + assertThatAccesses(origin.getAccessesFromSelf()) + .contain(expectedAccess() + .from(Origin.class, "first") + .to(Target.class, "method") + ) + .contain(expectedAccess() + .from(Origin.class, "second") + .to(Target.class, "method")); + } + private static class Data_of_imports_private_constructor_reference_from_lambda { @SuppressWarnings("unused") static class Target {