diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementSum.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementSum.java index 10a2b2459fad..9aaf2208808b 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementSum.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/aggregation/ImplementSum.java @@ -29,15 +29,16 @@ import java.util.function.Function; import static io.trino.matching.Capture.newCapture; -import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation; import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.hasFilter; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.hasSortOrder; import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument; import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.variable; import static java.lang.String.format; import static java.util.Objects.requireNonNull; /** - * Implements {@code sum(x)} + * Implements {@code sum([DISTINCT] x)} */ public class ImplementSum implements AggregateFunctionRule @@ -54,7 +55,9 @@ public ImplementSum(Function> decimalTypeH @Override public Pattern getPattern() { - return basicAggregation() + return Pattern.typeOf(AggregateFunction.class) + .with(hasSortOrder().equalTo(false)) + .with(hasFilter().equalTo(false)) .with(functionName().equalTo("sum")) .with(singleArgument().matching(variable().capturedAs(ARGUMENT))); } @@ -81,8 +84,9 @@ else if (aggregateFunction.getOutputType() instanceof DecimalType) { } ParameterizedExpression rewrittenArgument = context.rewriteExpression(argument).orElseThrow(); + String function = aggregateFunction.isDistinct() ? "sum(DISTINCT %s)" : "sum(%s)"; return Optional.of(new JdbcExpression( - format("sum(%s)", rewrittenArgument.expression()), + format(function, rewrittenArgument.expression()), rewrittenArgument.parameters(), resultTypeHandle)); } diff --git a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java index d65289570cb0..87589b66d66a 100644 --- a/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java +++ b/plugin/trino-base-jdbc/src/test/java/io/trino/plugin/jdbc/BaseJdbcConnectorTest.java @@ -316,6 +316,7 @@ public void testCaseSensitiveAggregationPushdown() boolean supportsPushdownWithVarcharInequality = hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY); boolean supportsCountDistinctPushdown = hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT); + boolean supportsSumDistinctPushdown = hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN); PlanMatchPattern aggregationOverTableScan = node(AggregationNode.class, node(TableScanNode.class)); PlanMatchPattern groupingAggregationOverTableScan = node(AggregationNode.class, node(ProjectNode.class, node(TableScanNode.class))); @@ -325,7 +326,7 @@ public void testCaseSensitiveAggregationPushdown() "(a_string varchar(1), a_char char(1), a_bigint bigint)", ImmutableList.of( "'A', 'A', 1", - "'B', 'B', 2", + "'B', 'B', 1", "'a', 'a', 3", "'b', 'b', 4"))) { // case-sensitive functions prevent pushdown @@ -388,17 +389,36 @@ public void testCaseSensitiveAggregationPushdown() supportsPushdownWithVarcharInequality && supportsCountDistinctPushdown, node(ExchangeNode.class, node(AggregationNode.class, anyTree(node(TableScanNode.class))))) .skippingTypesCheck() - .matches("VALUES (BIGINT '4', BIGINT '4')"); + .matches("VALUES (BIGINT '4', BIGINT '3')"); assertConditionallyPushedDown(getSession(), "SELECT count(DISTINCT a_char), count(DISTINCT a_bigint) FROM " + table.getName(), supportsPushdownWithVarcharInequality && supportsCountDistinctPushdown, node(ExchangeNode.class, node(AggregationNode.class, anyTree(node(TableScanNode.class))))) .skippingTypesCheck() - .matches("VALUES (BIGINT '4', BIGINT '4')"); + .matches("VALUES (BIGINT '4', BIGINT '3')"); + + assertConditionallyPushedDown(getSession(), + "SELECT count(DISTINCT a_string), sum(DISTINCT a_bigint) FROM " + table.getName(), + supportsPushdownWithVarcharInequality && supportsSumDistinctPushdown, + node(ExchangeNode.class, node(AggregationNode.class, anyTree(node(TableScanNode.class))))) + .skippingTypesCheck() + .matches(sumDistinctAggregationPushdownExpectedResult()); + + assertConditionallyPushedDown(getSession(), + "SELECT count(DISTINCT a_char), sum(DISTINCT a_bigint) FROM " + table.getName(), + supportsPushdownWithVarcharInequality && supportsSumDistinctPushdown, + node(ExchangeNode.class, node(AggregationNode.class, anyTree(node(TableScanNode.class))))) + .skippingTypesCheck() + .matches(sumDistinctAggregationPushdownExpectedResult()); } } + protected String sumDistinctAggregationPushdownExpectedResult() + { + return "VALUES (BIGINT '4', BIGINT '8')"; + } + @Test public void testAggregationWithUnsupportedResultType() { @@ -453,12 +473,22 @@ public void testDistinctAggregationPushdown() "SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation", hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT), node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class)))))); + assertConditionallyPushedDown( + withMarkDistinct, + "SELECT sum(DISTINCT regionkey), sum(DISTINCT nationkey) FROM nation", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN), + node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class)))))); // distinct aggregation and a non-distinct aggregation assertConditionallyPushedDown( withMarkDistinct, "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation", hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT), node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class)))))); + assertConditionallyPushedDown( + withMarkDistinct, + "SELECT sum(DISTINCT regionkey), count(nationkey) FROM nation", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN), + node(MarkDistinctNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(ProjectNode.class, node(TableScanNode.class)))))); Session withoutMarkDistinct = Session.builder(getSession()) .setSystemProperty(USE_MARK_DISTINCT, "false") @@ -479,12 +509,23 @@ public void testDistinctAggregationPushdown() "SELECT count(DISTINCT regionkey), count(DISTINCT nationkey) FROM nation", hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT), node(AggregationNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(TableScanNode.class))))); + assertConditionallyPushedDown( + withoutMarkDistinct, + "SELECT sum(DISTINCT regionkey), sum(DISTINCT nationkey) FROM nation", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN), + node(AggregationNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(TableScanNode.class))))); + // distinct aggregation and a non-distinct aggregation assertConditionallyPushedDown( withoutMarkDistinct, "SELECT count(DISTINCT regionkey), sum(nationkey) FROM nation", hasBehavior(SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT), node(AggregationNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(TableScanNode.class))))); + assertConditionallyPushedDown( + withoutMarkDistinct, + "SELECT sum(DISTINCT regionkey), sum(nationkey) FROM nation", + hasBehavior(SUPPORTS_PREDICATE_PUSHDOWN), + node(AggregationNode.class, node(ExchangeNode.class, node(ExchangeNode.class, node(TableScanNode.class))))); } @Test diff --git a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteClient.java b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteClient.java index 49a2fa3019ff..0b872b95ef5f 100644 --- a/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteClient.java +++ b/plugin/trino-ignite/src/test/java/io/trino/plugin/ignite/TestIgniteClient.java @@ -131,7 +131,13 @@ public void testImplementSum() testImplementAggregation( new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), - Optional.empty()); // distinct not supported + Optional.of("sum(DISTINCT `c_bigint`)")); + + // sum(DISTINCT double) + testImplementAggregation( + new AggregateFunction("sum", DOUBLE, List.of(doubleVariable), List.of(), true, Optional.empty()), + Map.of(doubleVariable.getName(), DOUBLE_COLUMN), + Optional.of("sum(DISTINCT `c_double`)")); // sum(bigint) FILTER (WHERE ...) testImplementAggregation( diff --git a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java index 2138ef8c382f..b6d967b94e75 100644 --- a/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java +++ b/plugin/trino-mariadb/src/test/java/io/trino/plugin/mariadb/TestMariaDbClient.java @@ -134,7 +134,13 @@ public void testImplementSum() testImplementAggregation( new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), - Optional.empty()); // distinct not supported + Optional.of("sum(DISTINCT `c_bigint`)")); + + // sum(DISTINCT double) + testImplementAggregation( + new AggregateFunction("sum", DOUBLE, List.of(doubleVariable), List.of(), true, Optional.empty()), + Map.of(doubleVariable.getName(), DOUBLE_COLUMN), + Optional.of("sum(DISTINCT `c_double`)")); // sum(bigint) FILTER (WHERE ...) testImplementAggregation( diff --git a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java index 7ab9cd37e3f2..ec564f10ea5d 100644 --- a/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java +++ b/plugin/trino-mysql/src/test/java/io/trino/plugin/mysql/TestMySqlClient.java @@ -138,7 +138,13 @@ public void testImplementSum() testImplementAggregation( new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), - Optional.empty()); // distinct not supported + Optional.of("sum(DISTINCT `c_bigint`)")); + + // sum(DISTINCT double) + testImplementAggregation( + new AggregateFunction("sum", DOUBLE, List.of(bigintVariable), List.of(), true, Optional.empty()), + Map.of(bigintVariable.getName(), DOUBLE_COLUMN), + Optional.of("sum(DISTINCT `c_double`)")); // sum(bigint) FILTER (WHERE ...) testImplementAggregation( diff --git a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java index ea4b1a4f8208..274336849737 100644 --- a/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java +++ b/plugin/trino-oracle/src/test/java/io/trino/plugin/oracle/BaseOracleConnectorTest.java @@ -528,6 +528,12 @@ protected void verifyColumnNameLengthFailurePermissible(Throwable e) assertThat(e).hasMessage("ORA-00972: identifier is too long\n"); } + @Override + protected String sumDistinctAggregationPushdownExpectedResult() + { + return "VALUES (BIGINT '4', DECIMAL '8')"; + } + private void predicatePushdownTest(String oracleType, String oracleLiteral, String operator, String filterLiteral) { String tableName = ("test_pdown_" + oracleType.replaceAll("[^a-zA-Z0-9]", "")) diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index 9fb4eaff9688..8c6f4c78e317 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -194,7 +194,13 @@ public void testImplementSum() testImplementAggregation( new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), - Optional.empty()); // distinct not supported + Optional.of("sum(DISTINCT \"c_bigint\")")); + + // sum(DISTINCT double) + testImplementAggregation( + new AggregateFunction("sum", DOUBLE, List.of(bigintVariable), List.of(), true, Optional.empty()), + Map.of(bigintVariable.getName(), DOUBLE_COLUMN), + Optional.of("sum(DISTINCT \"c_double\")")); // sum(bigint) FILTER (WHERE ...) testImplementAggregation( diff --git a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java index 9d04f2de4e30..9dde9aa76aa5 100644 --- a/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java +++ b/plugin/trino-sqlserver/src/test/java/io/trino/plugin/sqlserver/TestSqlServerClient.java @@ -136,7 +136,13 @@ public void testImplementSum() testImplementAggregation( new AggregateFunction("sum", BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()), Map.of(bigintVariable.getName(), BIGINT_COLUMN), - Optional.empty()); // distinct not supported + Optional.of("sum(DISTINCT \"c_bigint\")")); + + // sum(DISTINCT double) + testImplementAggregation( + new AggregateFunction("sum", DOUBLE, List.of(doubleVariable), List.of(), true, Optional.empty()), + Map.of(doubleVariable.getName(), DOUBLE_COLUMN), + Optional.of("sum(DISTINCT \"c_double\")")); // sum(bigint) FILTER (WHERE ...) testImplementAggregation(