From cadd61a4a0f3277054d924a97ee5166f6dfe8d56 Mon Sep 17 00:00:00 2001 From: Evgeniy Cheban Date: Mon, 13 Jan 2025 12:03:36 +0200 Subject: [PATCH] Consider supporting Pageable for String query Closes gh-41 --- .../query/ProjectingResultIterator.java | 184 ++++++++++++++ .../query/ReindexerParameterAccessor.java | 45 ++++ .../query/ReindexerQueryCreator.java | 5 +- .../query/ReindexerQueryMethod.java | 3 +- .../query/ReindexerRepositoryQuery.java | 152 ------------ .../StringBasedReindexerRepositoryQuery.java | 227 +++++++++++++----- .../repository/ReindexerRepositoryTests.java | 154 ++++++++++++ 7 files changed, 547 insertions(+), 223 deletions(-) create mode 100644 src/main/java/org/springframework/data/reindexer/repository/query/ProjectingResultIterator.java create mode 100644 src/main/java/org/springframework/data/reindexer/repository/query/ReindexerParameterAccessor.java diff --git a/src/main/java/org/springframework/data/reindexer/repository/query/ProjectingResultIterator.java b/src/main/java/org/springframework/data/reindexer/repository/query/ProjectingResultIterator.java new file mode 100644 index 0000000..207e4c4 --- /dev/null +++ b/src/main/java/org/springframework/data/reindexer/repository/query/ProjectingResultIterator.java @@ -0,0 +1,184 @@ +/* + * Copyright 2022 evgeniycheban + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.reindexer.repository.query; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +import ru.rt.restream.reindexer.AggregationResult; +import ru.rt.restream.reindexer.AggregationResult.Facet; +import ru.rt.restream.reindexer.Query; +import ru.rt.restream.reindexer.ResultIterator; +import ru.rt.restream.reindexer.util.BeanPropertyUtils; + +import org.springframework.core.convert.ConversionService; +import org.springframework.core.convert.support.DefaultConversionService; +import org.springframework.data.mapping.PreferredConstructor; +import org.springframework.data.mapping.model.PreferredConstructorDiscoverer; +import org.springframework.data.repository.query.ReturnedType; +import org.springframework.data.util.ReflectionUtils; +import org.springframework.util.Assert; + +/** + * For internal use only, as this contract is likely to change. + * + * @author Evgeniy Cheban + */ +final class ProjectingResultIterator implements ResultIterator { + + private static final Map, Constructor> cache = new ConcurrentHashMap<>(); + + private final ResultIterator delegate; + + private final ReturnedType projectionType; + + private final AggregationResult aggregationFacet; + + private final Map> distinctAggregationResults; + + private final ConversionService conversionService = DefaultConversionService.getSharedInstance(); + + private int aggregationPosition; + + ProjectingResultIterator(Query query, ReturnedType projectionType) { + this(query.execute(), projectionType); + } + + ProjectingResultIterator(ResultIterator delegate, ReturnedType projectionType) { + this.delegate = delegate; + this.projectionType = projectionType; + this.aggregationFacet = getAggregationFacet(); + this.distinctAggregationResults = getDistinctAggregationResults(); + } + + @Override + public long getTotalCount() { + return this.delegate.getTotalCount(); + } + + @Override + public long size() { + return this.delegate.size(); + } + + @Override + public List aggResults() { + return this.delegate.aggResults(); + } + + @Override + public void close() { + this.delegate.close(); + } + + @Override + public boolean hasNext() { + return this.delegate.hasNext() || this.aggregationFacet != null && this.aggregationPosition < this.aggregationFacet.getFacets().size(); + } + + @Override + public Object next() { + if (this.aggregationFacet != null && !this.distinctAggregationResults.isEmpty()) { + Object item = null; + Object[] arguments = null; + List fields = this.aggregationFacet.getFields(); + if (this.projectionType.needsCustomConstruction() && !this.projectionType.getReturnedType().isInterface()) { + arguments = new Object[fields.size()]; + } + else { + try { + item = this.projectionType.getDomainType().getDeclaredConstructor().newInstance(); + } + catch (NoSuchMethodException | InvocationTargetException | + InstantiationException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + for (int i = 0; i < fields.size(); i++) { + String field = fields.get(i); + Facet facet = this.aggregationFacet.getFacets().get(this.aggregationPosition); + if (i < facet.getValues().size() && this.distinctAggregationResults.get(field).remove(facet.getValues().get(i))) { + Object value = this.conversionService.convert(facet.getValues().get(i), ReflectionUtils.findRequiredField(this.projectionType.getDomainType(), field).getType()); + if (arguments != null) { + arguments[i] = value; + } + else { + BeanPropertyUtils.setProperty(item, field, value); + } + } + else { + this.aggregationPosition++; + return null; + } + } + this.aggregationPosition++; + if (item != null) { + return item; + } + return constructTargetObject(arguments); + } + Object item = this.delegate.next(); + if (this.projectionType.needsCustomConstruction() && !this.projectionType.getReturnedType().isInterface()) { + List properties = this.projectionType.getInputProperties(); + Object[] values = new Object[properties.size()]; + for (int i = 0; i < properties.size(); i++) { + values[i] = BeanPropertyUtils.getProperty(item, properties.get(i)); + } + return constructTargetObject(values); + } + return item; + } + + private Object constructTargetObject(Object[] values) { + Constructor constructor = cache.computeIfAbsent(this.projectionType.getReturnedType(), (type) -> { + PreferredConstructor preferredConstructor = PreferredConstructorDiscoverer.discover(type); + Assert.state(preferredConstructor != null, () -> "No preferred constructor found for " + type); + return preferredConstructor.getConstructor(); + }); + try { + return constructor.newInstance(values); + } + catch (InvocationTargetException | InstantiationException | + IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private Map> getDistinctAggregationResults() { + Map> result = new HashMap<>(); + for (AggregationResult aggregationResult : aggResults()) { + if ("distinct".equals(aggregationResult.getType())) { + result.put(aggregationResult.getFields().get(0), new HashSet<>(aggregationResult.getDistincts())); + } + } + return result; + } + + private AggregationResult getAggregationFacet() { + for (AggregationResult aggregationResult : aggResults()) { + if ("facet".equals(aggregationResult.getType())) { + return aggregationResult; + } + } + return null; + } +} diff --git a/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerParameterAccessor.java b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerParameterAccessor.java new file mode 100644 index 0000000..c90bdee --- /dev/null +++ b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerParameterAccessor.java @@ -0,0 +1,45 @@ +/* + * Copyright 2022 evgeniycheban + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.reindexer.repository.query; + +import org.springframework.data.repository.query.Parameters; +import org.springframework.data.repository.query.ParametersParameterAccessor; + +/** + * For internal use only, as this contract is likely to change. + * + * @author Evgeniy Cheban + */ +final class ReindexerParameterAccessor extends ParametersParameterAccessor { + + /** + * Creates a new {@link ParametersParameterAccessor}. + * + * @param parameters must not be {@literal null}. + * @param values must not be {@literal null}. + */ + ReindexerParameterAccessor(Parameters parameters, Object[] values) { + super(parameters, values); + } + + /** + * {@inheritDoc} + */ + @Override + protected Object[] getValues() { + return super.getValues(); + } +} diff --git a/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerQueryCreator.java b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerQueryCreator.java index ff27e4d..6b18dd6 100644 --- a/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerQueryCreator.java +++ b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerQueryCreator.java @@ -174,9 +174,6 @@ protected Query complete(Query criteria, Sort sort) { } criteria.aggregateFacet(this.returnedType.getInputProperties().toArray(String[]::new)); } - else { - criteria.aggregateDistinct(this.entityInformation.getIdFieldName()); - } } if (this.returnedType.needsCustomConstruction()) { criteria.select(this.returnedType.getInputProperties().toArray(String[]::new)); @@ -214,7 +211,7 @@ else if (this.tree.isExistsProjection()) { return criteria; } - private int getOffsetAsInteger(Pageable pageable) { + static int getOffsetAsInteger(Pageable pageable) { if (pageable.getOffset() > Integer.MAX_VALUE) { throw new InvalidDataAccessApiUsageException("Page offset exceeds Integer.MAX_VALUE (" + Integer.MAX_VALUE + ")"); } diff --git a/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerQueryMethod.java b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerQueryMethod.java index 58365e1..f35d4cb 100644 --- a/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerQueryMethod.java +++ b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerQueryMethod.java @@ -130,7 +130,8 @@ public String getQuery() { * * @return true, if the query is for UPDATE */ - public boolean isUpdateQuery() { + @Override + public boolean isModifyingQuery() { Query query = this.queryAnnotationExtractor.get(); return query.update(); } diff --git a/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerRepositoryQuery.java b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerRepositoryQuery.java index 7370b41..b7241af 100644 --- a/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerRepositoryQuery.java +++ b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerRepositoryQuery.java @@ -15,47 +15,31 @@ */ package org.springframework.data.reindexer.repository.query; -import java.lang.reflect.Constructor; -import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.Spliterator; import java.util.Spliterators; -import java.util.concurrent.ConcurrentHashMap; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.StreamSupport; -import ru.rt.restream.reindexer.AggregationResult; -import ru.rt.restream.reindexer.AggregationResult.Facet; import ru.rt.restream.reindexer.Namespace; -import ru.rt.restream.reindexer.Query; import ru.rt.restream.reindexer.Reindexer; import ru.rt.restream.reindexer.ReindexerIndex; import ru.rt.restream.reindexer.ReindexerNamespace; import ru.rt.restream.reindexer.ResultIterator; -import ru.rt.restream.reindexer.util.BeanPropertyUtils; -import org.springframework.core.convert.ConversionService; -import org.springframework.core.convert.support.DefaultConversionService; import org.springframework.data.domain.Pageable; -import org.springframework.data.mapping.PreferredConstructor; -import org.springframework.data.mapping.model.PreferredConstructorDiscoverer; import org.springframework.data.reindexer.repository.support.TransactionalNamespace; import org.springframework.data.repository.query.ParameterAccessor; import org.springframework.data.repository.query.ParametersParameterAccessor; import org.springframework.data.repository.query.RepositoryQuery; import org.springframework.data.repository.query.ResultProcessor; -import org.springframework.data.repository.query.ReturnedType; import org.springframework.data.repository.query.parser.PartTree; import org.springframework.data.support.PageableExecutionUtils; import org.springframework.data.util.Lazy; -import org.springframework.data.util.ReflectionUtils; import org.springframework.util.Assert; /** @@ -226,140 +210,4 @@ public Object execute(ReindexerQueryCreator queryCreator) { } } } - - private static final class ProjectingResultIterator implements ResultIterator { - - private static final Map, Constructor> cache = new ConcurrentHashMap<>(); - - private final ResultIterator delegate; - - private final ReturnedType projectionType; - - private final AggregationResult aggregationFacet; - - private final Map> distinctAggregationResults; - - private final ConversionService conversionService = DefaultConversionService.getSharedInstance(); - - private int aggregationPosition; - - private ProjectingResultIterator(Query query, ReturnedType projectionType) { - this.delegate = query.execute(); - this.projectionType = projectionType; - this.aggregationFacet = getAggregationFacet(); - this.distinctAggregationResults = getDistinctAggregationResults(); - } - - @Override - public long getTotalCount() { - return this.delegate.getTotalCount(); - } - - @Override - public long size() { - return this.delegate.size(); - } - - @Override - public List aggResults() { - return this.delegate.aggResults(); - } - - @Override - public void close() { - this.delegate.close(); - } - - @Override - public boolean hasNext() { - return this.delegate.hasNext() || this.aggregationFacet != null && this.aggregationPosition < this.aggregationFacet.getFacets().size(); - } - - @Override - public Object next() { - if (!this.distinctAggregationResults.isEmpty() && this.aggregationFacet != null && this.projectionType != null) { - Object item = null; - Object[] arguments = null; - List fields = this.aggregationFacet.getFields(); - if (this.projectionType.getReturnedType().isInterface()) { - try { - item = this.projectionType.getDomainType().getDeclaredConstructor().newInstance(); - } - catch (NoSuchMethodException | InvocationTargetException | - InstantiationException | IllegalAccessException e) { - throw new RuntimeException(e); - } - } - else { - arguments = new Object[fields.size()]; - } - for (int i = 0; i < fields.size(); i++) { - String field = fields.get(i); - Facet facet = this.aggregationFacet.getFacets().get(this.aggregationPosition); - if (i < facet.getValues().size() && this.distinctAggregationResults.get(field).remove(facet.getValues().get(i))) { - Object value = this.conversionService.convert(facet.getValues().get(i), ReflectionUtils.findRequiredField(this.projectionType.getDomainType(), field).getType()); - if (arguments != null) { - arguments[i] = value; - } - else { - BeanPropertyUtils.setProperty(item, field, value); - } - } - else { - this.aggregationPosition++; - return null; - } - } - this.aggregationPosition++; - if (item != null) { - return item; - } - return constructTargetObject(arguments); - } - Object item = this.delegate.next(); - if (this.projectionType != null && this.projectionType.needsCustomConstruction() - && !this.projectionType.getReturnedType().isInterface()) { - List properties = this.projectionType.getInputProperties(); - Object[] values = new Object[properties.size()]; - for (int i = 0; i < properties.size(); i++) { - values[i] = BeanPropertyUtils.getProperty(item, properties.get(i)); - } - return constructTargetObject(values); - } - return item; - } - - private Object constructTargetObject(Object[] values) { - Constructor constructor = cache.computeIfAbsent(this.projectionType.getReturnedType(), (type) -> { - PreferredConstructor preferredConstructor = PreferredConstructorDiscoverer.discover(type); - Assert.state(preferredConstructor != null, () -> "No preferred constructor found for " + type); - return preferredConstructor.getConstructor(); - }); - try { - return constructor.newInstance(values); - } - catch (InvocationTargetException | InstantiationException | IllegalAccessException e) { - throw new RuntimeException(e); - } - } - - private Map> getDistinctAggregationResults() { - Map> result = new HashMap<>(); - for (AggregationResult aggregationResult : aggResults()) { - if ("distinct".equals(aggregationResult.getType())) { - result.put(aggregationResult.getFields().get(0), new HashSet<>(aggregationResult.getDistincts())); - } - } - return result; - } - - private AggregationResult getAggregationFacet() { - for (AggregationResult aggregationResult : aggResults()) { - if ("facet".equals(aggregationResult.getType())) { - return aggregationResult; - } - } - return null; - } - } } diff --git a/src/main/java/org/springframework/data/reindexer/repository/query/StringBasedReindexerRepositoryQuery.java b/src/main/java/org/springframework/data/reindexer/repository/query/StringBasedReindexerRepositoryQuery.java index 87a6b2d..307857a 100644 --- a/src/main/java/org/springframework/data/reindexer/repository/query/StringBasedReindexerRepositoryQuery.java +++ b/src/main/java/org/springframework/data/reindexer/repository/query/StringBasedReindexerRepositoryQuery.java @@ -15,29 +15,40 @@ */ package org.springframework.data.reindexer.repository.query; +import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.HashSet; +import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Spliterator; import java.util.Spliterators; -import java.util.function.Supplier; +import java.util.StringJoiner; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Stream; import java.util.stream.StreamSupport; import ru.rt.restream.reindexer.Namespace; import ru.rt.restream.reindexer.Reindexer; -import ru.rt.restream.reindexer.ResultIterator; +import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Sort.Order; import org.springframework.data.expression.ValueExpressionParser; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.QueryMethodValueEvaluationContextAccessor; import org.springframework.data.repository.query.RepositoryQuery; +import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.repository.query.ReturnedType; import org.springframework.data.repository.query.ValueExpressionQueryRewriter; import org.springframework.data.repository.query.ValueExpressionQueryRewriter.QueryExpressionEvaluator; +import org.springframework.data.support.PageableExecutionUtils; +import org.springframework.data.util.Lazy; import org.springframework.util.Assert; /** @@ -49,6 +60,8 @@ public class StringBasedReindexerRepositoryQuery implements RepositoryQuery { private static final String EXPRESSION_PARAMETER_PREFIX = "__$synthetic$__"; + private static final Pattern LIMIT_PATTERN = Pattern.compile("(?i)\\bLIMIT\\s+(\\d+)"); + private final ReindexerQueryMethod queryMethod; private final Namespace namespace; @@ -57,6 +70,8 @@ public class StringBasedReindexerRepositoryQuery implements RepositoryQuery { private final Map namedParameters; + private final Lazy queryExecution; + /** * Creates an instance. * @@ -67,6 +82,7 @@ public class StringBasedReindexerRepositoryQuery implements RepositoryQuery { */ public StringBasedReindexerRepositoryQuery(ReindexerQueryMethod queryMethod, ReindexerEntityInformation entityInformation, QueryMethodValueEvaluationContextAccessor accessor, Reindexer reindexer) { + validate(queryMethod); this.queryMethod = queryMethod; this.namespace = reindexer.openNamespace(entityInformation.getNamespaceName(), entityInformation.getNamespaceOptions(), entityInformation.getJavaType()); @@ -77,6 +93,47 @@ public StringBasedReindexerRepositoryQuery(ReindexerQueryMethod queryMethod, Rei parameter.getName().ifPresent(name -> this.namedParameters.put(name, parameter.getIndex())); } } + this.queryExecution = Lazy.of(() -> { + if (queryMethod.isCollectionQuery()) { + return this::toList; + } + if (queryMethod.isPageQuery()) { + return (parameters, returnedType) -> { + try (ProjectingResultIterator it = toIterator(parameters, returnedType)) { + return PageableExecutionUtils.getPage(toList(it), parameters.getPageable(), it::getTotalCount); + } + }; + } + if (queryMethod.isStreamQuery()) { + return this::toStream; + } + if (queryMethod.isIteratorQuery()) { + return this::toIterator; + } + if (this.queryMethod.isModifyingQuery()) { + return (parameters, returnedType) -> { + this.namespace.updateSql(prepareQuery(parameters)); + return null; + }; + } + return (parameters, returnedType) -> { + Object entity = toEntity(parameters, returnedType); + if (queryMethod.isOptionalQuery()) { + return Optional.ofNullable(entity); + } + Assert.state(entity != null, "Exactly one item expected, but there is zero"); + return entity; + }; + }); + } + + private void validate(ReindexerQueryMethod queryMethod) { + if (queryMethod.isPageQuery()) { + String query = queryMethod.getQuery().toLowerCase(); + if (!query.contains("count(*)") && !query.contains("count_cached(*)")) { + throw new InvalidDataAccessApiUsageException("Page query must contain COUNT or COUNT_CACHED expression for method: " + queryMethod); + } + } } private QueryExpressionEvaluator createQueryExpressionEvaluator(QueryMethodValueEvaluationContextAccessor accessor) { @@ -87,34 +144,53 @@ private QueryExpressionEvaluator createQueryExpressionEvaluator(QueryMethodValue @Override public Object execute(Object[] parameters) { - if (this.queryMethod.isUpdateQuery()) { - this.namespace.updateSql(prepareQuery(parameters)); - return null; - } - if (this.queryMethod.isIteratorQuery()) { - return this.namespace.execSql(prepareQuery(parameters)); - } - if (this.queryMethod.isStreamQuery()) { - return toStream(this.namespace.execSql(prepareQuery(parameters))); - } - if (this.queryMethod.isListQuery()) { - return toCollection(this.namespace.execSql(prepareQuery(parameters)), ArrayList::new); - } - if (this.queryMethod.isSetQuery()) { - return toCollection(this.namespace.execSql(prepareQuery(parameters)), HashSet::new); + ReindexerParameterAccessor parameterAccessor = new ReindexerParameterAccessor(this.queryMethod.getParameters(), parameters); + ResultProcessor resultProcessor = this.queryMethod.getResultProcessor().withDynamicProjection(parameterAccessor); + Object result = this.queryExecution.get().execute(parameterAccessor, resultProcessor.getReturnedType()); + return resultProcessor.processResult(result); + } + + private Stream toStream(ReindexerParameterAccessor parameters, ReturnedType returnedType) { + ProjectingResultIterator iterator = toIterator(parameters, returnedType); + Spliterator spliterator = Spliterators.spliterator(iterator, iterator.size(), Spliterator.NONNULL); + return StreamSupport.stream(spliterator, false); + } + + private List toList(ReindexerParameterAccessor parameters, ReturnedType returnedType) { + try (ProjectingResultIterator it = toIterator(parameters, returnedType)) { + return toList(it); } - if (this.queryMethod.isOptionalQuery()) { - return toOptionalEntity(this.namespace.execSql(prepareQuery(parameters))); + } + + private List toList(ProjectingResultIterator iterator) { + List result = new ArrayList<>(); + while (iterator.hasNext()) { + result.add(iterator.next()); } - if (this.queryMethod.isQueryForEntity()) { - return toEntity(this.namespace.execSql(prepareQuery(parameters))); + return result; + } + + private Object toEntity(ReindexerParameterAccessor parameters, ReturnedType returnedType) { + Object item = null; + try (ProjectingResultIterator it = toIterator(parameters, returnedType)) { + if (it.hasNext()) { + item = it.next(); + } + if (it.hasNext()) { + throw new IllegalStateException("Exactly one item expected, but there are more"); + } } - throw new IllegalStateException("Unsupported method return type " + this.queryMethod.getReturnedObjectType()); + return item; + } + + private ProjectingResultIterator toIterator(ReindexerParameterAccessor parameters, ReturnedType returnedType) { + String preparedQuery = prepareQuery(parameters); + return new ProjectingResultIterator(this.namespace.execSql(preparedQuery), returnedType); } - private String prepareQuery(Object[] parameters) { - Map parameterMap = this.queryEvaluator.evaluate(parameters); - StringBuilder result = new StringBuilder(this.queryEvaluator.getQueryString()); + private String prepareQuery(ReindexerParameterAccessor parameters) { + Map parameterMap = this.queryEvaluator.evaluate(parameters.getValues()); + StringBuilder result = new StringBuilder(this.queryEvaluator.getQueryString().toLowerCase()); char[] queryParts = this.queryEvaluator.getQueryString().toCharArray(); int offset = 0; for (int i = 1; i < queryParts.length; i++) { @@ -134,7 +210,7 @@ private String prepareQuery(Object[] parameters) { if (c == ':') { Integer index = this.namedParameters.get(parameterReference); Assert.notNull(index, () -> "No parameter found for name: " + parameterReference); - value = parameters[index]; + value = parameters.getBindableValue(index); } else { int index; @@ -144,7 +220,7 @@ private String prepareQuery(Object[] parameters) { catch (NumberFormatException e) { throw new IllegalStateException("Invalid parameter reference: " + parameterReference + " at index: " + i); } - value = parameters[index - 1]; + value = parameters.getBindableValue(index - 1); } } String valueString = getParameterValuePart(value); @@ -154,6 +230,46 @@ private String prepareQuery(Object[] parameters) { } } } + if (result.indexOf("order by") == -1) { + Sort sort = parameters.getSort(); + if (sort.isSorted()) { + result.append(" order by "); + for (Iterator orderIterator = sort.iterator(); orderIterator.hasNext(); ) { + Order order = orderIterator.next(); + result.append(order.getProperty()).append(" ").append(order.getDirection()); + if (orderIterator.hasNext()) { + result.append(", "); + } + } + } + } + Pageable pageable = parameters.getPageable(); + if (pageable.isPaged()) { + Matcher limitMatcher = LIMIT_PATTERN.matcher(result); + int maxResults; + if (limitMatcher.find()) { + maxResults = Integer.parseInt(limitMatcher.group(1)); + } + else { + maxResults = pageable.getPageSize(); + result.append(" limit ").append(maxResults); + } + if (result.indexOf("offset") == -1) { + int firstResult = ReindexerQueryCreator.getOffsetAsInteger(pageable); + if (firstResult > 0) { + /* + * In order to return the correct results, we have to adjust the first result offset to be returned if: + * - a Pageable parameter is present + * - AND the requested page number > 0 + * - AND the requested page size was bigger than the derived result limitation via the First/Top keyword. + */ + if (pageable.getPageSize() > maxResults) { + firstResult = firstResult - (pageable.getPageSize() - maxResults); + } + result.append(" offset ").append(firstResult); + } + } + } return result.toString(); } @@ -161,52 +277,31 @@ private String getParameterValuePart(Object value) { if (value instanceof String) { return "'" + value + "'"; } - return String.valueOf(value); - } - - private Stream toStream(ResultIterator iterator) { - Spliterator spliterator = Spliterators.spliterator(iterator, iterator.size(), Spliterator.NONNULL); - return StreamSupport.stream(spliterator, false); - } - - private Collection toCollection(ResultIterator iterator, Supplier> collectionSupplier) { - Collection result = collectionSupplier.get(); - try (ResultIterator it = iterator) { - while (it.hasNext()) { - result.add(it.next()); + if (value instanceof Collection values) { + StringJoiner joiner = new StringJoiner(", ", "(", ")"); + for (Object object : values) { + joiner.add(getParameterValuePart(object)); } + return joiner.toString(); } - return result; - } - - private Optional toOptionalEntity(ResultIterator iterator) { - T item = getOneInternal(iterator); - return Optional.ofNullable(item); - } - - private T toEntity(ResultIterator iterator) { - T item = getOneInternal(iterator); - if (item == null) { - throw new IllegalStateException("Exactly one item expected, but there is zero"); - } - return item; - } - - private T getOneInternal(ResultIterator iterator) { - T item = null; - try (ResultIterator it = iterator) { - if (it.hasNext()) { - item = it.next(); - } - if (it.hasNext()) { - throw new IllegalStateException("Exactly one item expected, but there are more"); + if (value != null && value.getClass().isArray()) { + StringJoiner joiner = new StringJoiner(", ", "(", ")"); + int length = Array.getLength(value); + for (int i = 0; i < length; i++) { + joiner.add(getParameterValuePart(Array.get(value, i))); } + return joiner.toString(); } - return item; + return String.valueOf(value); } @Override public QueryMethod getQueryMethod() { return this.queryMethod; } + + @FunctionalInterface + private interface QueryExecution { + Object execute(ReindexerParameterAccessor parameters, ReturnedType returnedType); + } } diff --git a/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java b/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java index e7945eb..c650c47 100644 --- a/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java +++ b/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java @@ -66,6 +66,7 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Direction; +import org.springframework.data.domain.Sort.Order; import org.springframework.data.reindexer.ReindexerTransactionManager; import org.springframework.data.reindexer.core.mapping.Namespace; import org.springframework.data.reindexer.core.mapping.Query; @@ -1215,6 +1216,135 @@ public void findByActiveIsFalse() { assertThat(foundItems.stream().map(TestItem::getId).toList()).containsOnly(3L, 4L); } + @Test + public void findAllItemProjectionByIdIn() { + List expectedItems = new ArrayList<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findAllItemProjectionByIdIn(expectedItems.stream() + .map(TestItem::getId) + .toList(), Sort.by(Direction.DESC, "id")); + assertThat(foundItems).hasSameSizeAs(expectedItems); + for (int i = 0; i < foundItems.size(); i++) { + TestItemProjection foundItem = foundItems.get(i); + TestItem expectedItem = expectedItems.get(expectedItems.size() - 1 - i); + assertThat(foundItem.getId()).isEqualTo(expectedItem.getId()); + assertThat(foundItem.getName()).isEqualTo(expectedItem.getName()); + } + } + + @Test + public void findAllItemDtoByIdIn() { + List expectedItems = new ArrayList<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findAllItemDtoByIdIn(expectedItems.stream() + .map(TestItem::getId) + .toList(), Sort.by(Direction.ASC, "id")); + assertThat(foundItems).hasSameSizeAs(expectedItems); + for (int i = 0; i < foundItems.size(); i++) { + TestItemDto foundItem = foundItems.get(i); + TestItem expectedItem = expectedItems.get(expectedItems.size() - 1 - i); + assertThat(foundItem.getId()).isEqualTo(expectedItem.getId()); + assertThat(foundItem.getName()).isEqualTo(expectedItem.getName()); + } + } + + @Test + public void findAllItemRecordByIdIn() { + List expectedItems = new ArrayList<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findAllItemRecordByIdIn(expectedItems.stream() + .map(TestItem::getId) + .toList()); + assertThat(foundItems).hasSameSizeAs(expectedItems); + for (int i = 0; i < foundItems.size(); i++) { + TestItemRecord foundItem = foundItems.get(i); + TestItem expectedItem = expectedItems.get(i); + assertThat(foundItem.id()).isEqualTo(expectedItem.getId()); + assertThat(foundItem.name()).isEqualTo(expectedItem.getName()); + } + } + + @Test + public void findAllCountByIdInPageable() { + Set expectedItems = new HashSet<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + Pageable pageable = PageRequest.of(0, 5, Sort.by(Order.desc("id"), Order.asc("name"))); + List expectedIds = expectedItems.stream() + .map(TestItem::getId) + .toList(); + do { + Page foundItems = this.repository.findAllCountByIdIn(expectedIds, pageable); + for (TestItem item : foundItems) { + assertTrue(expectedItems.remove(item)); + } + pageable = foundItems.nextPageable(); + } while (pageable.isPaged()); + assertEquals(0, expectedItems.size()); + } + + @Test + public void findAllCountCachedByIdInPageable() { + Set expectedItems = new HashSet<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + Pageable pageable = PageRequest.of(0, 5, Sort.by(Order.desc("id"), Order.asc("name"))); + List expectedIds = expectedItems.stream() + .map(TestItem::getId) + .toList(); + do { + Page foundItems = this.repository.findAllCountCachedByIdIn(expectedIds, pageable); + for (TestItem item : foundItems) { + assertTrue(expectedItems.remove(item)); + } + pageable = foundItems.nextPageable(); + } while (pageable.isPaged()); + assertEquals(0, expectedItems.size()); + } + + @Test + public void findFirst2Sql() { + TestItem item1 = this.repository.save(new TestItem(1L, "TestName1", "TestValue1")); + TestItem item2 = this.repository.save(new TestItem(2L, "TestName2", "TestValue2")); + TestItem item3 = this.repository.save(new TestItem(3L, "TestName3", "TestValue3")); + Page firstPage = this.repository.findFirst2Sql(PageRequest.of(0, 3, Direction.ASC, "id")); + assertThat(firstPage.getContent()).contains(item1, item2); + Page secondPage = this.repository.findFirst2Sql(PageRequest.of(1, 3, Direction.ASC, "id")); + assertThat(secondPage).contains(item3); + } + + @Test + public void findFirst3Sql() { + TestItem item1 = this.repository.save(new TestItem(1L, "TestName1", "TestValue1")); + TestItem item2 = this.repository.save(new TestItem(2L, "TestName2", "TestValue2")); + TestItem item3 = this.repository.save(new TestItem(3L, "TestName3", "TestValue3")); + Page firstPage = this.repository.findFirst3Sql(PageRequest.of(0, 2, Direction.ASC, "id")); + assertThat(firstPage.getContent()).contains(item1, item2); + Page secondPage = this.repository.findFirst3Sql(PageRequest.of(1, 2, Direction.ASC, "id")); + assertThat(secondPage).contains(item3); + } + + @Test + public void findAllSqlLimit() { + Set expectedItems = new HashSet<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findAllSqlLimit(Limit.of(10)); + for (TestItem item : foundItems) { + assertTrue(expectedItems.remove(item)); + } + assertEquals(90, expectedItems.size()); + } + @Configuration @EnableReindexerRepositories(basePackageClasses = TestItemReindexerRepository.class, considerNestedRepositories = true) @EnableTransactionManagement @@ -1410,6 +1540,30 @@ Optional findOneSqlByNameAndValueManyParams(String name1, String name2 List findByActiveIsTrue(); List findByActiveIsFalse(); + + @Query("SELECT id, name FROM items WHERE id IN :ids") + List findAllItemProjectionByIdIn(List ids, Sort sort); + + @Query("SELECT id, name FROM items WHERE id IN :ids ORDER BY id DESC") + List findAllItemDtoByIdIn(List ids, Sort sort); + + @Query("SELECT id, name FROM items WHERE id IN :ids") + List findAllItemRecordByIdIn(List ids); + + @Query("SELECT *, COUNT(*) FROM items WHERE id IN :ids") + Page findAllCountByIdIn(List ids, Pageable pageable); + + @Query("SELECT *, COUNT_CACHED(*) FROM items WHERE id IN :ids") + Page findAllCountCachedByIdIn(List ids, Pageable pageable); + + @Query("SELECT *, COUNT_CACHED(*) FROM items LIMIT 2") + Page findFirst2Sql(Pageable pageable); + + @Query("SELECT *, COUNT_CACHED(*) FROM items LIMIT 3") + Page findFirst3Sql(Pageable pageable); + + @Query("SELECT * FROM items") + List findAllSqlLimit(Limit limit); } @Namespace(name = NAMESPACE_NAME)