diff --git a/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java b/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java index c8b207e9b..dbfc0ab33 100644 --- a/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java +++ b/src/main/java/org/openrewrite/staticanalysis/InstanceOfPatternMatch.java @@ -21,6 +21,7 @@ import org.jspecify.annotations.Nullable; import org.openrewrite.*; import org.openrewrite.java.JavaVisitor; +import org.openrewrite.java.MethodMatcher; import org.openrewrite.java.VariableNameUtils; import org.openrewrite.java.search.SemanticallyEqual; import org.openrewrite.java.search.UsesJavaVersion; @@ -45,6 +46,8 @@ @EqualsAndHashCode(callSuper = false) public class InstanceOfPatternMatch extends Recipe { + private static MethodMatcher STREAM_COLLECT_MATCHER = new MethodMatcher("java.util.stream.Stream collect(..)"); + @Override public String getDisplayName() { return "Changes code to use Java 17's `instanceof` pattern matching"; @@ -74,7 +77,8 @@ public TreeVisitor getVisitor() { public @Nullable J postVisit(J tree, ExecutionContext ctx) { J result = super.postVisit(tree, ctx); InstanceOfPatternReplacements original = getCursor().getMessage("flowTypeScope"); - if (original != null && !original.isEmpty()) { + boolean exclusion = getCursor().getNearestMessage("exclusionScope", false); + if (original != null && !original.isEmpty() && !exclusion) { return UseInstanceOfPatternMatching.refactor(result, original, getCursor().getParentOrThrow()); } return result; @@ -144,6 +148,55 @@ public J visitTypeCast(J.TypeCast typeCast, ExecutionContext ctx) { } return result; } + + @Override + public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) { + J j = super.visitMethodInvocation(method, ctx); + if (j instanceof J.MethodInvocation) { + J.MethodInvocation m = (J.MethodInvocation) j; + if (STREAM_COLLECT_MATCHER.matches(m, false)) { + Cursor cursorWithFlowTypeScope = getNearestCursorWithMessage(getCursor(), "flowTypeScope"); + InstanceOfPatternReplacements replacements = cursorWithFlowTypeScope.getMessage("flowTypeScope"); + if (replacements != null) { + Expression originalSelect = selectFromMethodInvocationChain(m); + while (originalSelect instanceof J.Parentheses) { + originalSelect = originalSelect.unwrap(); + } + if (originalSelect instanceof J.Identifier) { + J.Identifier varRef = (J.Identifier) originalSelect; + if (replacements.variablesToDelete.values().stream().anyMatch(nv -> nv.getSimpleName().equals(varRef.getSimpleName()) && nv.getType().equals(varRef.getType()))) { + cursorWithFlowTypeScope.putMessage("exclusionScope", true); + } + } else if (originalSelect instanceof J.TypeCast) { + J.TypeCast typeCast = (J.TypeCast) originalSelect; + if (replacements.replacements.containsKey(typeCast)) { + cursorWithFlowTypeScope.putMessage("exclusionScope", true); + } + } + } + } + } + return j; + } + + private Expression selectFromMethodInvocationChain(J.MethodInvocation method) { + J.MethodInvocation m = method; + for (; m.getSelect() instanceof J.MethodInvocation; m = (J.MethodInvocation) m.getSelect()) {} + return m.getSelect(); + } + + public @Nullable Cursor getNearestCursorWithMessage(Cursor cursor, String key) { + if (cursor == null) { + return null; + } + Object msg = cursor.getMessage(key); + if (msg == null) { + return getNearestCursorWithMessage(cursor.getParent(), key); + } else { + return cursor; + } + } + }); } @@ -215,6 +268,7 @@ public J.InstanceOf processInstanceOf(J.InstanceOf instanceOf, Cursor cursor) { if (!contextScopes.containsKey(instanceOf)) { return instanceOf; } + @Nullable JavaType type = ((TypedTree) instanceOf.getClazz()).getType(); String name = patternVariableName(instanceOf, cursor); J.InstanceOf result = instanceOf.withPattern(new J.Identifier( @@ -225,25 +279,6 @@ public J.InstanceOf processInstanceOf(J.InstanceOf instanceOf, Cursor cursor) { name, type, null)); - JavaType.FullyQualified fqType = TypeUtils.asFullyQualified(type); - if (fqType != null && !fqType.getTypeParameters().isEmpty() && !(instanceOf.getClazz() instanceof J.ParameterizedType)) { - TypedTree oldTypeTree = (TypedTree) instanceOf.getClazz(); - - // Each type parameter is turned into a wildcard, i.e. `List` -> `List` or `Map.Entry` -> `Map.Entry` - List wildcardsList = IntStream.range(0, fqType.getTypeParameters().size()) - .mapToObj(i -> new J.Wildcard(randomId(), Space.EMPTY, Markers.EMPTY, null, null)) - .collect(Collectors.toList()); - - J.ParameterizedType newTypeTree = new J.ParameterizedType( - randomId(), - oldTypeTree.getPrefix(), - Markers.EMPTY, - oldTypeTree.withPrefix(Space.EMPTY), - null, - oldTypeTree.getType() - ).withTypeParameters(wildcardsList); - result = result.withClazz(newTypeTree); - } // update entry in replacements to share the pattern variable name for (Map.Entry entry : replacements.entrySet()) { diff --git a/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java b/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java index 8279511d9..6cd91d9dc 100644 --- a/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/InstanceOfPatternMatchTest.java @@ -133,7 +133,7 @@ void test(Object o) { } @Test - void genericsWithoutParameters() { + void genericsWithoutParameters_1() { rewriteRun( //language=java java( @@ -142,21 +142,72 @@ void genericsWithoutParameters() { import java.util.List; import java.util.Map; import java.util.stream.Collectors; + import java.util.stream.Stream; public class A { @SuppressWarnings("unchecked") - public static List> applyRoutesType(Object routes) { + public static Stream> applyRoutesType(Object routes) { if (routes instanceof List) { List routesList = (List) routes; if (routesList.isEmpty()) { - return Collections.emptyList(); + return Stream.empty(); } if (routesList.stream() .anyMatch(route -> !(route instanceof Map))) { - return Collections.emptyList(); + return Stream.empty(); } return routesList.stream() - .map(route -> (Map) route) - .collect(Collectors.toList()); + .map(route -> (Map) route); + } + return Stream.empty(); + } + } + """, + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + import java.util.stream.Stream; + public class A { + @SuppressWarnings("unchecked") + public static Stream> applyRoutesType(Object routes) { + if (routes instanceof List routesList) { + if (routesList.isEmpty()) { + return Stream.empty(); + } + if (routesList.stream() + .anyMatch(route -> !(route instanceof Map))) { + return Stream.empty(); + } + return routesList.stream() + .map(route -> (Map) route); + } + return Stream.empty(); + } + } + """ + ) + ); + } + + @Test + void genericsWithoutParameters_2() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List) { + List routesList = (List) routes; + if (routesList.isEmpty()) { + return Collections.emptyList(); + } } return Collections.emptyList(); } @@ -170,7 +221,109 @@ public static List> applyRoutesType(Object routes) { public class A { @SuppressWarnings("unchecked") public static List> applyRoutesType(Object routes) { - if (routes instanceof List routesList) { + if (routes instanceof List routesList) { + if (routesList.isEmpty()) { + return Collections.emptyList(); + } + } + return Collections.emptyList(); + } + } + """ + ) + ); + } + + @Test + void genericsWithoutParameters_3() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List) { + List routesList = (List) routes; + String.join(",", (List) routes); + } + return Collections.emptyList(); + } + } + """, """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List routesList) { + String.join(",", routesList); + } + return Collections.emptyList(); + } + } + """ + ) + ); + } + + @Test + void genericsWithoutParameters_4() { + rewriteRun( + //language=java + java( + """ + import java.util.Arrays; + import java.util.Collection; + import java.util.List; + public class A { + @SuppressWarnings("unchecked") + private Collection addValueToList(List previousValues, Object value) { + if (previousValues == null) { + return (value instanceof Collection) ? (Collection) value : Arrays.asList(value); + } + return List.of(); + } + } + """, """ + import java.util.Arrays; + import java.util.Collection; + import java.util.List; + public class A { + @SuppressWarnings("unchecked") + private Collection addValueToList(List previousValues, Object value) { + if (previousValues == null) { + return (value instanceof Collection c) ? c : Arrays.asList(value); + } + return List.of(); + } + } + """ + ) + ); + } + + @Test + void genericsWithoutParameters_5() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List) { + List routesList = (List) routes; if (routesList.isEmpty()) { return Collections.emptyList(); } @@ -190,6 +343,32 @@ public static List> applyRoutesType(Object routes) { ); } + @Test + void genericsWithoutParameters_6() { + rewriteRun( + //language=java + java( + """ + import java.util.Collections; + import java.util.List; + import java.util.Map; + import java.util.stream.Collectors; + public class A { + @SuppressWarnings("unchecked") + public static List> applyRoutesType(Object routes) { + if (routes instanceof List) { + return ((List) routes).stream() + .map(route -> (Map) route) + .collect(Collectors.toList()); + } + return Collections.emptyList(); + } + } + """ + ) + ); + } + @Test void primitiveArray() { rewriteRun( @@ -302,7 +481,7 @@ void test(Object o) { public class A { void test(Object o) { Map.Entry entry = null; - if (o instanceof Map.Entry entry1) { + if (o instanceof Map.Entry entry1) { entry = entry1; } System.out.println(entry); @@ -849,7 +1028,7 @@ Object test(Object o) { import java.util.List; public class A { Object test(Object o) { - return o instanceof List l ? l.get(0) : o.toString(); + return o instanceof List l ? l.get(0) : o.toString(); } } """ @@ -874,7 +1053,7 @@ Object test(Object o) { import java.util.List; public class A { Object test(Object o) { - return o instanceof List l ? l.get(0) : o.toString(); + return o instanceof List l ? l.get(0) : o.toString(); } } """ @@ -975,6 +1154,52 @@ String test(Object o) { ) ); } + @Test + void iterableParameter() { + rewriteRun( + //language=java + java( + """ + import java.util.HashMap; + import java.util.List; + import java.util.Map; + + public class ApplicationSecurityGroupsParameterHelper { + + static final String APPLICATION_SECURITY_GROUPS = "application-security-groups"; + + public Map transformGatewayParameters(Map parameters) { + Map environment = new HashMap<>(); + Object applicationSecurityGroups = parameters.get(APPLICATION_SECURITY_GROUPS); + if (applicationSecurityGroups instanceof List) { + environment.put(APPLICATION_SECURITY_GROUPS, String.join(",", (List) applicationSecurityGroups)); + } + return environment; + } + } + """, + """ + import java.util.HashMap; + import java.util.List; + import java.util.Map; + + public class ApplicationSecurityGroupsParameterHelper { + + static final String APPLICATION_SECURITY_GROUPS = "application-security-groups"; + + public Map transformGatewayParameters(Map parameters) { + Map environment = new HashMap<>(); + Object applicationSecurityGroups = parameters.get(APPLICATION_SECURITY_GROUPS); + if (applicationSecurityGroups instanceof List list) { + environment.put(APPLICATION_SECURITY_GROUPS, String.join(",", list)); + } + return environment; + } + } + """ + ) + ); + } } @Nested