Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Joda-Time to Java time: Add support for Method Parameter Migration #605

Merged
merged 8 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import static org.openrewrite.java.migrate.joda.templates.TimeClassNames.JODA_CLASS_PATTERN;

public class JodaTimeFlowSpec extends DataFlowSpec {
class JodaTimeFlowSpec extends DataFlowSpec {

@Override
public boolean isSource(@NonNull DataFlowNode srcNode) {
Expand All @@ -36,6 +36,12 @@ public boolean isSource(@NonNull DataFlowNode srcNode) {
if (value instanceof J.VariableDeclarations.NamedVariable) {
return isJodaType(((J.VariableDeclarations.NamedVariable) value).getType());
}

if (value instanceof J.VariableDeclarations) {
if (srcNode.getCursor().getParentTreeCursor().getParentTreeCursor().getValue() instanceof J.MethodDeclaration) {
return isJodaType(((J.VariableDeclarations) value).getType());
}
}
return false;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,20 @@
*/
package org.openrewrite.java.migrate.joda;

import lombok.Getter;
import org.jspecify.annotations.Nullable;
import org.openrewrite.ExecutionContext;
import org.openrewrite.ScanningRecipe;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.J.VariableDeclarations.NamedVariable;
import org.openrewrite.java.tree.JavaType;

import java.util.HashSet;
import java.util.Set;
import java.util.*;

public class JodaTimeRecipe extends ScanningRecipe<Set<NamedVariable>> {
public class JodaTimeRecipe extends ScanningRecipe<JodaTimeRecipe.Accumulator> {
@Override
public String getDisplayName() {
return "Migrate Joda Time to Java Time";
return "Migrate Joda-Time to Java time";
}

@Override
Expand All @@ -34,17 +37,46 @@ public String getDescription() {
}

@Override
public Set<NamedVariable> getInitialValue(ExecutionContext ctx) {
return new HashSet<>();
public Accumulator getInitialValue(ExecutionContext ctx) {
return new Accumulator();
}

@Override
public JodaTimeScanner getScanner(Set<NamedVariable> acc) {
public JodaTimeScanner getScanner(Accumulator acc) {
return new JodaTimeScanner(acc);
}

@Override
public JodaTimeVisitor getVisitor(Set<NamedVariable> acc) {
return new JodaTimeVisitor(acc);
public JodaTimeVisitor getVisitor(Accumulator acc) {
return new JodaTimeVisitor(acc, true, new LinkedList<>());
}

@Getter
public static class Accumulator {
private final Set<NamedVariable> unsafeVars = new HashSet<>();
private final VarTable varTable = new VarTable();
}

static class VarTable {
private final Map<JavaType, List<NamedVariable>> vars = new HashMap<>();

public void addVars(J.MethodDeclaration methodDeclaration) {
JavaType type = methodDeclaration.getMethodType();
assert type != null;
methodDeclaration.getParameters().forEach(p -> {
if (!(p instanceof J.VariableDeclarations)) {
return;
}
J.VariableDeclarations.NamedVariable namedVariable = ((J.VariableDeclarations) p).getVariables().get(0);
vars.computeIfAbsent(type, k -> new ArrayList<>()).add(namedVariable);
});
}

public @Nullable NamedVariable getVarByName(@Nullable JavaType declaringType, String varName) {
return vars.getOrDefault(declaringType, Collections.emptyList()).stream()
.filter(v -> v.getSimpleName().equals(varName))
.findFirst() // there should be only one variable with the same name
.orElse(null);
}
}
}
109 changes: 75 additions & 34 deletions src/main/java/org/openrewrite/java/migrate/joda/JodaTimeScanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.jspecify.annotations.Nullable;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.analysis.dataflow.Dataflow;
Expand All @@ -28,36 +29,34 @@
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.J.VariableDeclarations.NamedVariable;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.MethodCall;

import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.openrewrite.java.migrate.joda.templates.TimeClassNames.JODA_CLASS_PATTERN;

public class JodaTimeScanner extends ScopeAwareVisitor {
class JodaTimeScanner extends ScopeAwareVisitor {

@Getter
private final Set<NamedVariable> unsafeVars;
private final JodaTimeRecipe.Accumulator acc;

private final Map<NamedVariable, Set<NamedVariable>> varDependencies = new HashMap<>();
private final Map<JavaType, Set<String>> unsafeVarsByType = new HashMap<>();

public JodaTimeScanner(Set<NamedVariable> unsafeVars, LinkedList<VariablesInScope> scopes) {
super(scopes);
this.unsafeVars = unsafeVars;
}

public JodaTimeScanner(Set<NamedVariable> unsafeVars) {
this(unsafeVars, new LinkedList<>());
public JodaTimeScanner(JodaTimeRecipe.Accumulator acc) {
super(new LinkedList<>());
this.acc = acc;
}

@Override
public J visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) {
super.visitCompilationUnit(cu, ctx);
Set<NamedVariable> allReachable = new HashSet<>();
for (NamedVariable var : unsafeVars) {
for (NamedVariable var : acc.getUnsafeVars()) {
dfs(var, allReachable);
}
unsafeVars.addAll(allReachable);
acc.getUnsafeVars().addAll(allReachable);
return cu;
}

Expand All @@ -66,21 +65,32 @@ public NamedVariable visitVariable(NamedVariable variable, ExecutionContext ctx)
if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return variable;
}
// TODO: handle class variables && method parameters
if (!isLocalVar(variable)) {
unsafeVars.add(variable);
// TODO: handle class variables
if (isClassVar(variable)) {
acc.getUnsafeVars().add(variable);
return variable;
}
variable = (NamedVariable) super.visitVariable(variable, ctx);

if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN) || variable.getInitializer() == null) {
if (!variable.getType().isAssignableFrom(JODA_CLASS_PATTERN)) {
return variable;
}
boolean isMethodParam = getCursor().getParentTreeCursor() // VariableDeclaration
.getParentTreeCursor() // MethodDeclaration
.getValue() instanceof J.MethodDeclaration;
Cursor cursor = null;
if (isMethodParam) {
cursor = getCursor();
} else if (variable.getInitializer() != null) {
cursor = new Cursor(getCursor(), variable.getInitializer());
}
if (cursor == null) {
return variable;
}
List<Expression> sinks = findSinks(variable.getInitializer());
List<Expression> sinks = findSinks(cursor);

Cursor currentScope = getCurrentScope();
J.Block block = currentScope.getValue();
new AddSafeCheckMarker(sinks).visit(block, ctx, currentScope.getParent());
new AddSafeCheckMarker(sinks).visit(currentScope.getValue(), ctx, currentScope.getParentOrThrow());
processMarkersOnExpression(sinks, variable);
return variable;
}
Expand All @@ -99,12 +109,24 @@ public J.Assignment visitAssignment(J.Assignment assignment, ExecutionContext ct
}
NamedVariable variable = mayBeVar.get();
Cursor varScope = findScope(variable);
List<Expression> sinks = findSinks(assignment.getAssignment());
new AddSafeCheckMarker(sinks).visit(varScope.getValue(), ctx, varScope.getParent());
List<Expression> sinks = findSinks(new Cursor(getCursor(), assignment.getAssignment()));
new AddSafeCheckMarker(sinks).visit(varScope.getValue(), ctx, varScope.getParentOrThrow());
processMarkersOnExpression(sinks, variable);
return assignment;
}

@Override
public J.MethodDeclaration visitMethodDeclaration(J.MethodDeclaration method, ExecutionContext ctx) {
acc.getVarTable().addVars(method);
unsafeVarsByType.getOrDefault(method.getMethodType(), Collections.emptySet()).forEach(varName -> {
NamedVariable var = acc.getVarTable().getVarByName(method.getMethodType(), varName);
if (var != null) { // var can only be null if method is not correctly type attributed
acc.getUnsafeVars().add(var);
}
});
return (J.MethodDeclaration) super.visitMethodDeclaration(method, ctx);
}

private void processMarkersOnExpression(List<Expression> expressions, NamedVariable var) {
for (Expression expr : expressions) {
Optional<SafeCheckMarker> mayBeMarker = expr.getMarkers().findFirst(SafeCheckMarker.class);
Expand All @@ -113,7 +135,7 @@ private void processMarkersOnExpression(List<Expression> expressions, NamedVaria
}
SafeCheckMarker marker = mayBeMarker.get();
if (!marker.isSafe()) {
unsafeVars.add(var);
acc.getUnsafeVars().add(var);
}
if (!marker.getReferences().isEmpty()) {
varDependencies.compute(var, (k, v) -> v == null ? new HashSet<>() : v).addAll(marker.getReferences());
Expand All @@ -128,21 +150,16 @@ private boolean isJodaExpr(Expression expression) {
return expression.getType() != null && expression.getType().isAssignableFrom(JODA_CLASS_PATTERN);
}

private List<Expression> findSinks(Expression expr) {
Cursor cursor = new Cursor(getCursor(), expr);
private List<Expression> findSinks(Cursor cursor) {
Option<SinkFlowSummary> mayBeSinks = Dataflow.startingAt(cursor).findSinks(new JodaTimeFlowSpec());
if (mayBeSinks.isNone()) {
return Collections.emptyList();
}
return mayBeSinks.some().getExpressionSinks();
}

private boolean isLocalVar(NamedVariable variable) {
if (!(variable.getVariableType().getOwner() instanceof JavaType.Method)) {
return false;
}
J j = getCursor().dropParentUntil(t -> t instanceof J.Block || t instanceof J.MethodDeclaration).getValue();
return j instanceof J.Block;
private boolean isClassVar(NamedVariable variable) {
return variable.getVariableType().getOwner() instanceof JavaType.Class;
}

private void dfs(NamedVariable root, Set<NamedVariable> visited) {
Expand All @@ -167,7 +184,17 @@ public Expression visitExpression(Expression expression, ExecutionContext ctx) {
if (index == -1) {
return super.visitExpression(expression, ctx);
}
Expression withMarker = expression.withMarkers(expression.getMarkers().addIfAbsent(getMarker(expression, ctx)));
SafeCheckMarker marker = getMarker(expression, ctx);
if (!marker.isSafe()) {
Optional<Cursor> mayBeArgCursor = findArgumentExprCursor();
if (mayBeArgCursor.isPresent()) {
MethodCall parentMethod = mayBeArgCursor.get().getParentTreeCursor().getValue();
int argPos = parentMethod.getArguments().indexOf(mayBeArgCursor.get().getValue());
String paramName = parentMethod.getMethodType().getParameterNames().get(argPos);
unsafeVarsByType.computeIfAbsent(parentMethod.getMethodType(), k -> new HashSet<>()).add(paramName);
}
}
Expression withMarker = expression.withMarkers(expression.getMarkers().addIfAbsent(marker));
expressions.set(index, withMarker);
return withMarker;
}
Expand All @@ -185,8 +212,9 @@ private SafeCheckMarker getMarker(Expression expr, ExecutionContext ctx) {
isSafe = false;
}
Expression boundaryExpr = boundary.getValue();
J j = new JodaTimeVisitor(new HashSet<>(), scopes).visit(boundaryExpr, ctx, boundary.getParentTreeCursor());
Set<NamedVariable> referencedVars = new HashSet<>();
J j = new JodaTimeVisitor(new JodaTimeRecipe.Accumulator(), false, scopes)
.visit(boundaryExpr, ctx, boundary.getParentTreeCursor());
Set<@Nullable NamedVariable> referencedVars = new HashSet<>();
new FindVarReferences().visit(expr, referencedVars, getCursor().getParentTreeCursor());
AtomicBoolean hasJodaType = new AtomicBoolean();
new HasJodaType().visit(j, hasJodaType);
Expand All @@ -211,12 +239,25 @@ private Cursor findBoundaryCursorForJodaExpr() {
}
return cursor;
}

private Optional<Cursor> findArgumentExprCursor() {
Cursor cursor = getCursor();
while (cursor.getValue() instanceof Expression && isJodaExpr(cursor.getValue())) {
Cursor parentCursor = cursor.getParentTreeCursor();
if (parentCursor.getValue() instanceof MethodCall &&
((MethodCall) parentCursor.getValue()).getArguments().contains(cursor.getValue())) {
return Optional.of(cursor);
}
cursor = parentCursor;
}
return Optional.empty();
}
}

private class FindVarReferences extends JavaIsoVisitor<Set<NamedVariable>> {
private class FindVarReferences extends JavaIsoVisitor<Set<@Nullable NamedVariable>> {

@Override
public J.Identifier visitIdentifier(J.Identifier ident, Set<NamedVariable> vars) {
public J.Identifier visitIdentifier(J.Identifier ident, Set<@Nullable NamedVariable> vars) {
if (!isJodaExpr(ident) || ident.getFieldType() == null) {
return ident;
}
Expand Down
Loading