Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix getAllInvolvedRawTypes() recursing infinitely #1276

Merged
merged 4 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NOTICE
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
ArchUnit
Copyright 2016 and onwards Peter Gafert <peter.gafert@tngtech.com>
Copyright 2016 and onwards Peter Gafert <peter.gafert@archunit.org>

This product includes software developed at
TNG Technology Consulting GmbH (https://www.tngtech.com/).
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.collect.Sets.immutableEnumSet;
import static com.google.common.collect.Sets.union;
import static com.tngtech.archunit.PublicAPI.Usage.ACCESS;
import static com.tngtech.archunit.base.ClassLoaders.getCurrentClassLoader;
Expand Down Expand Up @@ -138,7 +139,7 @@ public final class JavaClass
isRecord = builder.isRecord();
isAnonymousClass = builder.isAnonymousClass();
isMemberClass = builder.isMemberClass();
modifiers = checkNotNull(builder.getModifiers());
modifiers = immutableEnumSet(builder.getModifiers());
reflectSupplier = Suppliers.memoize(new ReflectClassSupplier());
sourceCodeLocation = SourceCodeLocation.of(this);
javaPackage = JavaPackage.simple(this);
Expand Down Expand Up @@ -657,8 +658,8 @@ public JavaClass toErasure() {
}

@Override
public Set<JavaClass> getAllInvolvedRawTypes() {
return ImmutableSet.of(getBaseComponentType());
public void traverseSignature(SignatureVisitor visitor) {
SignatureTraversal.from(visitor).visitClass(this);
}

@PublicAPI(usage = ACCESS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*/
package com.tngtech.archunit.core.domain;

import java.util.Set;

import com.tngtech.archunit.PublicAPI;

import static com.google.common.base.Preconditions.checkNotNull;
Expand Down Expand Up @@ -70,8 +68,8 @@ public JavaClass toErasure() {
}

@Override
public Set<JavaClass> getAllInvolvedRawTypes() {
return this.componentType.getAllInvolvedRawTypes();
public void traverseSignature(SignatureVisitor visitor) {
SignatureTraversal.from(visitor).visitGenericArrayType(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,9 @@ public interface JavaParameterizedType extends JavaType {
*/
@PublicAPI(usage = ACCESS)
List<JavaType> getActualTypeArguments();

@Override
default void traverseSignature(SignatureVisitor visitor) {
SignatureTraversal.from(visitor).visitParameterizedType(this);
}
}
194 changes: 193 additions & 1 deletion archunit/src/main/java/com/tngtech/archunit/core/domain/JavaType.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,23 @@
package com.tngtech.archunit.core.domain;

import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.tngtech.archunit.PublicAPI;
import com.tngtech.archunit.base.ChainableFunction;
import com.tngtech.archunit.core.domain.properties.HasName;

import static com.tngtech.archunit.PublicAPI.Usage.ACCESS;
import static com.tngtech.archunit.PublicAPI.Usage.INHERITANCE;
import static com.tngtech.archunit.core.domain.JavaType.SignatureVisitor.Result.CONTINUE;
import static com.tngtech.archunit.core.domain.JavaType.SignatureVisitor.Result.STOP;
import static java.util.Collections.singleton;

/**
* Represents a general Java type. This can e.g. be a class like {@code java.lang.String}, a parameterized type
Expand Down Expand Up @@ -82,7 +92,111 @@ public interface JavaType extends HasName {
* @return All raw types involved in this {@link JavaType}
*/
@PublicAPI(usage = ACCESS)
Set<JavaClass> getAllInvolvedRawTypes();
default Set<JavaClass> getAllInvolvedRawTypes() {
codecholeric marked this conversation as resolved.
Show resolved Hide resolved
ImmutableSet.Builder<JavaClass> result = ImmutableSet.builder();
traverseSignature(new SignatureVisitor() {
@Override
public Result visitClass(JavaClass type) {
result.add(type.getBaseComponentType());
return CONTINUE;
}

@Override
public Result visitParameterizedType(JavaParameterizedType type) {
result.add(type.toErasure());
return CONTINUE;
}
});
return result.build();
}

/**
* Traverses through the signature of this {@link JavaType}.<br>
* This method considers the type signature as a tree,
* where e.g. a {@link JavaClass} is a simple leaf,
* but a {@link JavaParameterizedType} has the type as root and then
* branches out into its actual type arguments, which in turn can have type arguments
* or upper/lower bounds in case of {@link JavaTypeVariable} or {@link JavaWildcardType}.<br>
* The following is a simple visualization of such a signature tree:
* <pre><code>
* List&lt;Map&lt;? extends Serializable, String[]&gt;&gt;
* |
* Map&lt;? extends Serializable, String[]&gt;
* / \
* ? extends Serializable String[]
* |
* Serializable
* </code></pre>
* For every node visited the respective method of the provided {@code visitor}
hankem marked this conversation as resolved.
Show resolved Hide resolved
* will be invoked. The traversal happens depth first, i.e. in this case the {@code visitor}
* would be invoked for all types down to {@code Serializable} before visiting the {@code String[]}
* array type of the second branch. At every step it is possible to continue the traversal
* by returning {@link SignatureVisitor.Result#CONTINUE CONTINUE} or stop at that point by
* returning {@link SignatureVisitor.Result#STOP STOP}.<br><br>
* Note that the traversal will continue to traverse bounds of type variables,
* even if that type variable isn't declared in this signature itself.<br>
* E.g. take the following scenario
* <pre><code>
* class Example&lt;T extends String&gt; {
* T field;
* }</code></pre>
* Traversing the {@link JavaField#getType() field type} of {@code field} will continue
* down to the upper bounds of the type variable {@code T} and thus end at the type {@code String}.<br><br>
* Also, note that the traversal will not continue down the type parameters of a raw type
* declared in a signature.<br>
* E.g. given the signature {@code class Example<T extends Map>} the traversal would stop at
* {@code Map} and not traverse down the type parameters {@code K} and {@code V} of {@code Map}.
*
* @param visitor A {@link SignatureVisitor} to invoke for every encountered {@link JavaType}
* while traversing this signature.
*/
@PublicAPI(usage = ACCESS)
void traverseSignature(SignatureVisitor visitor);

/**
* @see #traverseSignature(SignatureVisitor)
*/
@PublicAPI(usage = INHERITANCE)
interface SignatureVisitor {
default Result visitClass(JavaClass type) {
return CONTINUE;
}

default Result visitParameterizedType(JavaParameterizedType type) {
return CONTINUE;
}

default Result visitTypeVariable(JavaTypeVariable<?> type) {
return CONTINUE;
}

default Result visitGenericArrayType(JavaGenericArrayType type) {
return CONTINUE;
}

default Result visitWildcardType(JavaWildcardType type) {
return CONTINUE;
}

/**
* Result of a single step {@link #traverseSignature(SignatureVisitor) traversing a signature}.
* After each step it's possible to either {@link #STOP stop} or {@link #CONTINUE continue}
* the traversal.
*/
@PublicAPI(usage = ACCESS)
enum Result {
/**
* Causes the traversal to continue
*/
@PublicAPI(usage = ACCESS)
CONTINUE,
/**
* Causes the traversal to stop
*/
@PublicAPI(usage = ACCESS)
STOP
}
}

/**
* Predefined {@link ChainableFunction functions} to transform {@link JavaType}.
Expand All @@ -101,3 +215,81 @@ public JavaClass apply(JavaType input) {
};
}
}

class SignatureTraversal implements JavaType.SignatureVisitor {
private final Set<JavaType> visited = new HashSet<>();
private final JavaType.SignatureVisitor delegate;
private Result lastResult;

private SignatureTraversal(JavaType.SignatureVisitor delegate) {
this.delegate = delegate;
}

@Override
public Result visitClass(JavaClass type) {
// We only traverse type parameters of a JavaClass if the traversal was started *at the JavaClass* itself.
// Otherwise, we can only encounter a regular class as a raw type in a type signature.
// In these cases we don't want to traverse further down, as that would be surprising behavior
// (consider `class MyClass<T extends Map>`, traversing into the type variables `K` and `V` of `Map` would be surprising).
Supplier<Iterable<JavaTypeVariable<JavaClass>>> getFurtherTypesToTraverse = visited.isEmpty() ? type::getTypeParameters : Collections::emptyList;
return visit(type, delegate::visitClass, getFurtherTypesToTraverse);
}

@Override
public Result visitParameterizedType(JavaParameterizedType type) {
return visit(type, delegate::visitParameterizedType, type::getActualTypeArguments);
}

@Override
public Result visitTypeVariable(JavaTypeVariable<?> type) {
return visit(type, delegate::visitTypeVariable, type::getUpperBounds);
}

@Override
public Result visitGenericArrayType(JavaGenericArrayType type) {
return visit(type, delegate::visitGenericArrayType, () -> singleton(type.getComponentType()));
}

@Override
public Result visitWildcardType(JavaWildcardType type) {
return visit(type, delegate::visitWildcardType, () -> Iterables.concat(type.getUpperBounds(), type.getLowerBounds()));
}

private <CURRENT extends JavaType, NEXT extends JavaType> Result visit(
CURRENT type,
Function<CURRENT, Result> visitCurrent,
Supplier<Iterable<NEXT>> nextTypes
) {
if (visited.contains(type)) {
// if we've encountered this type already we continue traversing the siblings,
// but we won't descend further into this type signature
return setLast(CONTINUE);
}
visited.add(type);
if (visitCurrent.apply(type) == CONTINUE) {
Result result = visit(nextTypes.get());
return setLast(result);
} else {
return setLast(STOP);
hankem marked this conversation as resolved.
Show resolved Hide resolved
}
}

private Result visit(Iterable<? extends JavaType> types) {
for (JavaType nextType : types) {
nextType.traverseSignature(this);
if (lastResult == STOP) {
return STOP;
}
}
return CONTINUE;
}

private Result setLast(Result result) {
lastResult = result;
return result;
}

static SignatureTraversal from(JavaType.SignatureVisitor visitor) {
return visitor instanceof SignatureTraversal ? (SignatureTraversal) visitor : new SignatureTraversal(visitor);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import java.lang.reflect.TypeVariable;
import java.util.List;
import java.util.Set;

import com.tngtech.archunit.PublicAPI;
import com.tngtech.archunit.base.HasDescription;
Expand All @@ -29,7 +28,6 @@
import static com.tngtech.archunit.core.domain.properties.HasName.Functions.GET_NAME;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toSet;

/**
* Represents a type variable used by generic types and members.<br>
Expand Down Expand Up @@ -122,11 +120,8 @@ public JavaClass toErasure() {
}

@Override
public Set<JavaClass> getAllInvolvedRawTypes() {
return this.upperBounds.stream()
.map(JavaType::getAllInvolvedRawTypes)
.flatMap(Set::stream)
.collect(toSet());
public void traverseSignature(SignatureVisitor visitor) {
SignatureTraversal.from(visitor).visitTypeVariable(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

import java.lang.reflect.WildcardType;
import java.util.List;
import java.util.Set;
import java.util.stream.Stream;

import com.tngtech.archunit.PublicAPI;
import com.tngtech.archunit.core.domain.properties.HasUpperBounds;
Expand All @@ -27,7 +25,6 @@
import static com.tngtech.archunit.PublicAPI.Usage.ACCESS;
import static com.tngtech.archunit.core.domain.Formatters.ensureCanonicalArrayTypeName;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toSet;

/**
* Represents a wildcard type in a type signature (compare the JLS).
Expand Down Expand Up @@ -99,11 +96,8 @@ public JavaClass toErasure() {
}

@Override
public Set<JavaClass> getAllInvolvedRawTypes() {
return Stream.concat(upperBounds.stream(), lowerBounds.stream())
.map(JavaType::getAllInvolvedRawTypes)
.flatMap(Set::stream)
.collect(toSet());
public void traverseSignature(SignatureVisitor visitor) {
SignatureTraversal.from(visitor).visitWildcardType(this);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@
import static com.tngtech.archunit.core.domain.properties.HasName.Utils.namesOf;
import static java.util.Collections.emptyList;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toSet;

@Internal
@SuppressWarnings("UnusedReturnValue")
Expand Down Expand Up @@ -1217,14 +1216,6 @@ public JavaClass toErasure() {
return type.toErasure();
}

@Override
public Set<JavaClass> getAllInvolvedRawTypes() {
return Stream.concat(
type.getAllInvolvedRawTypes().stream(),
typeArguments.stream().map(JavaType::getAllInvolvedRawTypes).flatMap(Set::stream)
).collect(toSet());
}

@Override
public List<JavaType> getActualTypeArguments() {
return typeArguments;
Expand Down
Loading