Skip to content

Commit

Permalink
fix: Sort spel parameters by the length of their name before replacin…
Browse files Browse the repository at this point in the history
…g them with literals.

Fixes #2947.
  • Loading branch information
michael-simons committed Sep 4, 2024
1 parent d57f94b commit b401c30
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -245,37 +245,36 @@ static class QueryContext {

final Map<String, Object> boundParameters;

String query;
final String query;

private boolean hasLiteralReplacementForSort = false;

QueryContext(String repositoryMethodName, String template, Map<String, Object> boundParameters) {
this.repositoryMethodName = repositoryMethodName;
this.template = template;
this.query = this.template;
this.boundParameters = boundParameters;
}
}

void replaceLiteralsIn(QueryContext queryContext) {

String cypherQuery = queryContext.template;
Iterator<Map.Entry<String, Object>> iterator = queryContext.boundParameters.entrySet().iterator();
while (iterator.hasNext()) {
Map.Entry<String, Object> entry = iterator.next();
Object value = entry.getValue();
if (!(value instanceof Neo4jSpelSupport.LiteralReplacement)) {
continue;
String cypherQuery = this.template;
Comparator<Map.Entry<String, Object>> byLengthDescending = Comparator.comparing(e -> e.getKey().length());
byLengthDescending = byLengthDescending.reversed();
List<Map.Entry<String, Object>> entries = this.boundParameters.entrySet()
.stream().sorted(byLengthDescending)
.toList();
for (var entry : entries) {
Object value = entry.getValue();
if (!(value instanceof Neo4jSpelSupport.LiteralReplacement)) {
continue;
}
this.boundParameters.remove(entry.getKey());

String key = entry.getKey();
cypherQuery = cypherQuery.replace("$" + key, ((Neo4jSpelSupport.LiteralReplacement) value).getValue());
this.hasLiteralReplacementForSort =
this.hasLiteralReplacementForSort ||
((Neo4jSpelSupport.LiteralReplacement) value).getTarget() == Neo4jSpelSupport.LiteralReplacement.Target.SORT;
}
iterator.remove();

String key = entry.getKey();
cypherQuery = cypherQuery.replace("$" + key, ((Neo4jSpelSupport.LiteralReplacement) value).getValue());
queryContext.hasLiteralReplacementForSort =
queryContext.hasLiteralReplacementForSort ||
((Neo4jSpelSupport.LiteralReplacement) value).getTarget() == Neo4jSpelSupport.LiteralReplacement.Target.SORT;
this.query = cypherQuery;
}
queryContext.query = cypherQuery;
}

void logWarningsIfNecessary(QueryContext queryContext, Neo4jParameterAccessor parameterAccessor) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ protected <T extends Object> PreparedQuery<T> prepareQuery(Class<T> returnedType
boundParameters
);

replaceLiteralsIn(queryContext);
logWarningsIfNecessary(queryContext, parameterAccessor);

return PreparedQuery.queryFor(returnedType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,6 @@ protected <T extends Object> PreparedQuery<T> prepareQuery(Class<T> returnedType
boundParameters
);

replaceLiteralsIn(queryContext);
logWarningsIfNecessary(queryContext, parameterAccessor);

return PreparedQuery.queryFor(returnedType)
Expand Down Expand Up @@ -247,8 +246,6 @@ protected Optional<PreparedQuery<Long>> getCountQuery(Neo4jParameterAccessor par
boundParameters
);

replaceLiteralsIn(queryContext);

return PreparedQuery.queryFor(Long.class)
.withCypherQuery(queryContext.query)
.withParameters(boundParameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.concurrent.Callable;
Expand All @@ -43,6 +44,7 @@
import org.springframework.data.neo4j.core.schema.Node;
import org.springframework.data.neo4j.repository.query.Neo4jSpelSupport.LiteralReplacement;
import org.springframework.data.repository.core.EntityMetadata;
import org.springframework.data.repository.query.SpelQueryContext;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.util.ReflectionUtils;

Expand Down Expand Up @@ -183,6 +185,46 @@ void shouldUnquoteParameterExpressionsCorrectly(String quoted, String expected)
assertThat(query).isEqualTo(expected);
}

@Test
void moreThan10SpelEntriesShouldWork() {

SpelQueryContext spelQueryContext = StringBasedNeo4jQuery.SPEL_QUERY_CONTEXT;

StringBuilder template = new StringBuilder("MATCH (user:User) WHERE ");
String query;
SpelQueryContext.SpelExtractor spelExtractor;

class R implements LiteralReplacement {
private final String value;

R(String value) {
this.value = value;
}

@Override
public String getValue() {
return value;
}

@Override
public Target getTarget() {
return Target.UNSPECIFIED;
}
}

Map<String, Object> parameters = new HashMap<>();
for (int i = 0; i <= 20; ++i) {
template.append("user.name = :#{#searchUser.name} OR ");
parameters.put("__SpEL__" + i, new R("'x" + i + "'"));
}
template.delete(template.length() - 4, template.length());
spelExtractor = spelQueryContext.parse(template.toString());
query = spelExtractor.getQueryString();
Neo4jQuerySupport.QueryContext qc = new Neo4jQuerySupport.QueryContext("n/a", query, parameters);
assertThat(qc.query).isEqualTo(
"MATCH (user:User) WHERE user.name = 'x0' OR user.name = 'x1' OR user.name = 'x2' OR user.name = 'x3' OR user.name = 'x4' OR user.name = 'x5' OR user.name = 'x6' OR user.name = 'x7' OR user.name = 'x8' OR user.name = 'x9' OR user.name = 'x10' OR user.name = 'x11' OR user.name = 'x12' OR user.name = 'x13' OR user.name = 'x14' OR user.name = 'x15' OR user.name = 'x16' OR user.name = 'x17' OR user.name = 'x18' OR user.name = 'x19' OR user.name = 'x20'");
}

@Test // GH-2279
void shouldQuoteParameterExpressionsCorrectly() {

Expand Down

0 comments on commit b401c30

Please sign in to comment.