Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: SQL COUNT with GROUP BY to prevent incorrect row returns #33380

Merged
merged 13 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,30 @@

package org.apache.shardingsphere.sharding.merge.dql.groupby;

import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.table.NoSuchTableException;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.Projection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationDistinctProjection;
import org.apache.shardingsphere.infra.binder.context.segment.select.projection.impl.AggregationProjection;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.infra.exception.dialect.exception.syntax.table.NoSuchTableException;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.merge.result.impl.memory.MemoryMergedResult;
import org.apache.shardingsphere.infra.merge.result.impl.memory.MemoryQueryResultRow;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereSchema;
import org.apache.shardingsphere.infra.metadata.database.schema.model.ShardingSphereTable;
import org.apache.shardingsphere.infra.exception.core.ShardingSpherePreconditions;
import org.apache.shardingsphere.sharding.exception.data.NotImplementComparableValueException;
import org.apache.shardingsphere.sharding.merge.dql.groupby.aggregation.AggregationUnit;
import org.apache.shardingsphere.sharding.merge.dql.groupby.aggregation.AggregationUnitFactory;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;

import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -143,22 +138,10 @@ private boolean getValueCaseSensitiveFromTables(final QueryResult queryResult,
private List<MemoryQueryResultRow> getMemoryResultSetRows(final SelectStatementContext selectStatementContext,
final Map<GroupByValue, MemoryQueryResultRow> dataMap, final List<Boolean> valueCaseSensitive) {
if (dataMap.isEmpty()) {
Object[] data = generateReturnData(selectStatementContext);
return Arrays.stream(data).anyMatch(Objects::nonNull) ? Collections.singletonList(new MemoryQueryResultRow(data)) : Collections.emptyList();
return Collections.emptyList();
}
List<MemoryQueryResultRow> result = new ArrayList<>(dataMap.values());
result.sort(new GroupByRowComparator(selectStatementContext, valueCaseSensitive));
return result;
}

private Object[] generateReturnData(final SelectStatementContext selectStatementContext) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Malaydewangan09, can you add a e2e test case for this issue - #4680

List<Projection> projections = new LinkedList<>(selectStatementContext.getProjectionsContext().getExpandProjections());
Object[] result = new Object[projections.size()];
for (int i = 0; i < projections.size(); i++) {
if (projections.get(i) instanceof AggregationProjection && AggregationType.COUNT == ((AggregationProjection) projections.get(i)).getType()) {
result[i] = 0;
}
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.shardingsphere.infra.binder.context.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.database.core.DefaultDatabase;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.NullsOrderType;
import org.apache.shardingsphere.infra.database.core.type.DatabaseType;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.merge.result.MergedResult;
Expand All @@ -33,7 +34,6 @@
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.sharding.merge.dql.ShardingDQLResultMerger;
import org.apache.shardingsphere.sql.parser.statement.core.enums.AggregationType;
import org.apache.shardingsphere.infra.database.core.metadata.database.enums.NullsOrderType;
import org.apache.shardingsphere.sql.parser.statement.core.enums.OrderDirection;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionsSegment;
Expand Down Expand Up @@ -62,7 +62,6 @@
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
Expand All @@ -80,9 +79,6 @@ void assertNextForResultSetsAllEmpty() throws SQLException {
when(database.getName()).thenReturn("db_schema");
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
MergedResult actual = resultMerger.merge(Arrays.asList(createQueryResult(), createQueryResult(), createQueryResult()), createSelectStatementContext(), database, mock(ConnectionContext.class));
assertTrue(actual.next());
assertThat(actual.getValue(1, Object.class), is(0));
assertNull(actual.getValue(2, Object.class));
assertFalse(actual.next());
}

Expand Down Expand Up @@ -217,4 +213,57 @@ void assertNextForDistinctShorthandResultSetsEmpty() throws SQLException {
MergedResult actual = merger.merge(Arrays.asList(queryResult, queryResult, queryResult), createSelectStatementContext(database), database, mock(ConnectionContext.class));
assertFalse(actual.next());
}

@Test
void assertNextForEmptyResultWithCountAndGroupBy() throws SQLException {
when(database.getName()).thenReturn("db_schema");
QueryResult queryResult1 = createEmptyQueryResultWithCountGroupBy();
QueryResult queryResult2 = createEmptyQueryResultWithCountGroupBy();
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
MergedResult actual = resultMerger.merge(Arrays.asList(queryResult1, queryResult2), createSelectStatementContextForCountGroupBy(), database, mock(ConnectionContext.class));
assertFalse(actual.next());
}

@Test
void assertNextForEmptyResultWithCountGroupByDifferentOrderBy() throws SQLException {
when(database.getName()).thenReturn("db_schema");
QueryResult queryResult = createEmptyQueryResultWithCountGroupBy();
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(TypedSPILoader.getService(DatabaseType.class, "MySQL"));
MergedResult actual = resultMerger.merge(Collections.singletonList(queryResult), createSelectStatementContextForCountGroupByDifferentOrderBy(), database, mock(ConnectionContext.class));
assertFalse(actual.next());
}

private QueryResult createEmptyQueryResultWithCountGroupBy() throws SQLException {
QueryResult result = mock(QueryResult.class, RETURNS_DEEP_STUBS);
when(result.getMetaData().getColumnCount()).thenReturn(3);
when(result.getMetaData().getColumnLabel(1)).thenReturn("COUNT(*)");
when(result.getMetaData().getColumnLabel(2)).thenReturn("user_id");
when(result.getMetaData().getColumnLabel(3)).thenReturn("order_id");
when(result.next()).thenReturn(false);
return result;
}

private SelectStatementContext createSelectStatementContextForCountGroupBy() {
SelectStatement selectStatement = new MySQLSelectStatement();
ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0);
projectionsSegment.getProjections().add(new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT(*)"));
selectStatement.setGroupBy(new GroupBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 2, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setOrderBy(new OrderBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 2, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setProjections(projectionsSegment);
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(mock(ShardingSphereSchema.class));
return new SelectStatementContext(createShardingSphereMetaData(database), Collections.emptyList(), selectStatement, DefaultDatabase.LOGIC_NAME, Collections.emptyList());
}

private SelectStatementContext createSelectStatementContextForCountGroupByDifferentOrderBy() {
SelectStatement selectStatement = new MySQLSelectStatement();
ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0);
projectionsSegment.getProjections().add(new AggregationProjectionSegment(0, 0, AggregationType.COUNT, "COUNT(*)"));
selectStatement.setGroupBy(new GroupBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 2, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setOrderBy(new OrderBySegment(0, 0, Collections.singletonList(new IndexOrderByItemSegment(0, 0, 3, OrderDirection.ASC, NullsOrderType.FIRST))));
selectStatement.setProjections(projectionsSegment);
ShardingSphereDatabase database = mock(ShardingSphereDatabase.class, RETURNS_DEEP_STUBS);
when(database.getSchema(DefaultDatabase.LOGIC_NAME)).thenReturn(mock(ShardingSphereSchema.class));
return new SelectStatementContext(createShardingSphereMetaData(database), Collections.emptyList(), selectStatement, DefaultDatabase.LOGIC_NAME, Collections.emptyList());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,9 @@
scenario-comments="Test single table's LIKE operator underscore wildcard in select group by statement when use sharding feature.|Test encrypt table's LIKE operator underscore wildcard in select group by statement when use encrypt feature.">
<assertion expected-data-source-name="read_dataset" />
</test-case>

<test-case sql="SELECT COUNT(1) FROM t_order WHERE 1 = 2 GROUP BY order_id,user_id ORDER BY user_id" db-types="MySQL,PostgreSQL,openGauss" scenario-types="db,tb"
scenario-comments="Test GROUP BY with ORDER BY different fields returns no rows when WHERE condition matches no data">
<assertion expected-data-source-name="read_dataset" />
</test-case>
</e2e-test-cases>
Loading