Skip to content

Commit

Permalink
Wrap potentially conditional preaggregation with IF expression
Browse files Browse the repository at this point in the history
  • Loading branch information
weiatwork committed Mar 1, 2024
1 parent 4fcaada commit 29fe492
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.trino.Session;
import io.trino.matching.Capture;
Expand All @@ -40,12 +39,13 @@
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.Literal;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.SearchedCaseExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.tree.WhenClause;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand All @@ -54,6 +54,7 @@

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.ImmutableSetMultimap.toImmutableSetMultimap;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.SystemSessionProperties.isPreAggregateCaseAggregationsEnabled;
import static io.trino.matching.Capture.newCapture;
Expand All @@ -63,6 +64,7 @@
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TinyintType.TINYINT;
import static io.trino.sql.ExpressionUtils.or;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE;
Expand Down Expand Up @@ -282,30 +284,36 @@ private ProjectNode createPreProjection(

private Map<PreAggregationKey, PreAggregation> getPreAggregations(List<CaseAggregation> aggregations, Context context)
{
Set<PreAggregationKey> keys = new HashSet<>();
ImmutableMap.Builder<PreAggregationKey, PreAggregation> preAggregations = ImmutableMap.builder();
for (CaseAggregation aggregation : aggregations) {
PreAggregationKey preAggregationKey = new PreAggregationKey(aggregation);
if (keys.contains(preAggregationKey)) {
continue;
}

// Cast pre-projection if needed to match aggregation input type.
// This is because entire "CASE WHEN" expression could be wrapped in CAST.
Expression preProjection = aggregation.getResult();
Type preProjectionType = getType(context, preProjection);
Type aggregationInputType = getOnlyElement(aggregation.getFunction().getSignature().getArgumentTypes());
if (!preProjectionType.equals(aggregationInputType)) {
preProjection = new Cast(preProjection, toSqlType(aggregationInputType));
preProjectionType = aggregationInputType;
}

Symbol preProjectionSymbol = context.getSymbolAllocator().newSymbol(preProjection, preProjectionType);
Symbol preAggregationSymbol = context.getSymbolAllocator().newSymbol(aggregation.getAggregationSymbol());
preAggregations.put(preAggregationKey, new PreAggregation(preAggregationSymbol, preProjection, preProjectionSymbol));
keys.add(preAggregationKey);
}
return ImmutableMap.copyOf(preAggregations.buildOrThrow());
return aggregations.stream()
.collect(toImmutableSetMultimap(PreAggregationKey::new, identity()))
.asMap().entrySet().stream().collect(toImmutableMap(
Map.Entry::getKey,
entry -> {
PreAggregationKey key = entry.getKey();
Set<CaseAggregation> caseAggregations = (Set<CaseAggregation>) entry.getValue();
Expression preProjection = key.projection;

// Cast pre-projection if needed to match aggregation input type.
// This is because entire "CASE WHEN" expression could be wrapped in CAST.
Type preProjectionType = getType(context, preProjection);
Type aggregationInputType = getOnlyElement(key.getFunction().getSignature().getArgumentTypes());
if (!preProjectionType.equals(aggregationInputType)) {
preProjection = new Cast(preProjection, toSqlType(aggregationInputType));
preProjectionType = aggregationInputType;
}

// Wrap the preProjection with IF to retain the conditional nature on the CASE aggregation(s) during pre-aggregation
if (!(preProjection instanceof SymbolReference || preProjection instanceof Literal)) {
Expression unionConditions = or(caseAggregations.stream()
.map(CaseAggregation::getOperand)
.collect(toImmutableSet()));
preProjection = new IfExpression(unionConditions, preProjection, null);
}

Symbol preProjectionSymbol = context.getSymbolAllocator().newSymbol(preProjection, preProjectionType);
Symbol preAggregationSymbol = context.getSymbolAllocator().newSymbol(caseAggregations.iterator().next().getAggregationSymbol());
return new PreAggregation(preAggregationSymbol, preProjection, preProjectionSymbol);
}));
}

private Optional<List<CaseAggregation>> extractCaseAggregations(AggregationNode aggregationNode, ProjectNode projectNode, Context context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,18 @@ public void testPreAggregatesCaseAggregations()
ImmutableMap.of(
Optional.of("SUM_BIGINT"), functionCall("sum", ImmutableList.of("VALUE_BIGINT")),
Optional.of("SUM_INT_CAST"), functionCall("sum", ImmutableList.of("VALUE_INT_CAST")),
Optional.of("MIN_BIGINT"), functionCall("min", ImmutableList.of("VALUE_BIGINT")),
Optional.of("MIN_BIGINT"), functionCall("min", ImmutableList.of("VALUE_2_BIGINT")),
Optional.of("SUM_DECIMAL"), functionCall("sum", ImmutableList.of("COL_DECIMAL")),
Optional.of("SUM_DECIMAL_CAST"), functionCall("sum", ImmutableList.of("VALUE_DECIMAL_CAST"))),
Optional.empty(),
SINGLE,
exchange(
project(ImmutableMap.of(
"KEY", expression("CONCAT(COL_VARCHAR, VARCHAR 'a')"),
"VALUE_BIGINT", expression("COL_BIGINT * BIGINT '2'"),
"VALUE_INT_CAST", expression("CAST(CAST(COL_BIGINT * BIGINT '2' AS INTEGER) AS BIGINT)"),
"VALUE_DECIMAL_CAST", expression("CAST(COL_DECIMAL * CAST(DECIMAL '2' AS DECIMAL(10, 0)) AS BIGINT)")),
"VALUE_BIGINT", expression("(CASE WHEN (COL_BIGINT IN (BIGINT '1', BIGINT '2')) THEN (COL_BIGINT * BIGINT '2') END)"),
"VALUE_INT_CAST", expression("(CASE WHEN (COL_BIGINT = BIGINT '1') THEN CAST(CAST((COL_BIGINT * BIGINT '2') AS INTEGER) AS bigint) END)"),
"VALUE_2_BIGINT", expression("(CASE WHEN ((COL_BIGINT % BIGINT '2') > BIGINT '1') THEN (COL_BIGINT * BIGINT '2') END)"),
"VALUE_DECIMAL_CAST", expression("(CASE WHEN (COL_BIGINT = BIGINT '4') THEN CAST((COL_DECIMAL * CAST(DECIMAL '2' AS decimal(10, 0))) AS bigint) END)")),
tableScan(
"t",
ImmutableMap.of(
Expand Down Expand Up @@ -194,16 +195,17 @@ public void testGlobalPreAggregatesCaseAggregations()
ImmutableMap.of(
Optional.of("SUM_BIGINT"), functionCall("sum", ImmutableList.of("VALUE_BIGINT")),
Optional.of("SUM_INT_CAST"), functionCall("sum", ImmutableList.of("VALUE_INT_CAST")),
Optional.of("MIN_BIGINT"), functionCall("min", ImmutableList.of("VALUE_BIGINT")),
Optional.of("MIN_BIGINT"), functionCall("min", ImmutableList.of("VALUE_2_INT_CAST")),
Optional.of("SUM_DECIMAL"), functionCall("sum", ImmutableList.of("COL_DECIMAL")),
Optional.of("SUM_DECIMAL_CAST"), functionCall("sum", ImmutableList.of("VALUE_DECIMAL_CAST"))),
Optional.empty(),
SINGLE,
exchange(
project(ImmutableMap.of(
"VALUE_BIGINT", expression("COL_BIGINT * BIGINT '2'"),
"VALUE_INT_CAST", expression("CAST(CAST(COL_BIGINT * BIGINT '2' AS INTEGER) AS BIGINT)"),
"VALUE_DECIMAL_CAST", expression("CAST(COL_DECIMAL * CAST(DECIMAL '2' AS DECIMAL(10, 0)) AS BIGINT)")),
"VALUE_BIGINT", expression("(CASE WHEN (COL_BIGINT IN (BIGINT '1', BIGINT '2')) THEN (COL_BIGINT * BIGINT '2') END)"),
"VALUE_INT_CAST", expression("(CASE WHEN (COL_BIGINT = BIGINT '1') THEN CAST(CAST((COL_BIGINT * BIGINT '2') AS INTEGER) AS bigint) END)"),
"VALUE_2_INT_CAST", expression("(CASE WHEN ((COL_BIGINT % BIGINT '2') > BIGINT '1') THEN (COL_BIGINT * BIGINT '2') END)"),
"VALUE_DECIMAL_CAST", expression("(CASE WHEN (COL_BIGINT = BIGINT '4') THEN CAST((COL_DECIMAL * CAST(DECIMAL '2' AS decimal(10, 0))) AS bigint) END)")),
tableScan(
"t",
ImmutableMap.of(
Expand Down Expand Up @@ -275,8 +277,8 @@ public void testPreAggregatesWithDefaultValues()
SINGLE,
exchange(
project(ImmutableMap.of(
"VALUE_INT_CAST", expression("CAST(CAST(COL_BIGINT AS INTEGER) AS BIGINT)"),
"VALUE_TINYINT_CAST", expression("CAST(COL_TINYINT AS BIGINT)")),
"VALUE_INT_CAST", expression("(CASE WHEN (COL_BIGINT = BIGINT '2') THEN CAST(CAST(COL_BIGINT AS INTEGER) AS bigint) END)"),
"VALUE_TINYINT_CAST", expression("(CASE WHEN (COL_BIGINT = BIGINT '3') THEN CAST(COL_TINYINT AS bigint) END)")),
tableScan(
"t",
ImmutableMap.of(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.query;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT;

@TestInstance(PER_CLASS)
@Execution(CONCURRENT)
public class TestPreAggregateCaseAggregations
{
private final QueryAssertions assertions = new QueryAssertions();

@AfterAll
public void teardown()
{
assertions.close();
}

@Test
public void testCastExpression()
{
assertThat(assertions.query("SELECT " +
"MAX(CASE WHEN c1 = 1 THEN CAST(c2 AS int) END) AS m1, " +
"MAX(CASE WHEN c1 = 2 THEN c2 END) AS m2, " +
"MAX(CASE WHEN c1 = 3 THEN c2 END) AS m3, " +
"MAX(CASE WHEN c1 = 4 THEN c2 END) AS m4 " +
"FROM (VALUES (1, '1'), (2, '2'), (3, '3'), (4, 'd')) t(c1, c2)"))
.matches("VALUES (1, '2', '3', 'd')");

assertThat(assertions.query("SELECT " +
"MAX(CAST(CASE WHEN c1 = 1 THEN C2 END AS INT)) AS m1, " +
"MAX(CASE WHEN c1 = 2 THEN c2 END) AS m2, " +
"MAX(CASE WHEN c1 = 3 THEN c2 END) AS m3, " +
"MAX(CASE WHEN c1 = 4 THEN c2 END) AS m4 " +
"FROM (VALUES (1, '1'), (2, '2'), (3, '3'), (4, 'd')) t(c1, c2)"))
.matches("VALUES (1, '2', '3', 'd')");

assertThat(assertions.query("SELECT " +
"CAST(MAX(CASE WHEN c1 = 1 THEN C2 END) AS INT) AS m1, " +
"MAX(CASE WHEN c1 = 2 THEN c2 END) AS m2, " +
"MAX(CASE WHEN c1 = 3 THEN c2 END) AS m3, " +
"MAX(CASE WHEN c1 = 4 THEN c2 END) AS m4 " +
"FROM (VALUES (1, '1'), (2, '2'), (3, '3'), (4, 'd')) t(c1, c2)"))
.matches("VALUES (1, '2', '3', 'd')");
}

@Test
public void testDivisionByZero()
{
assertThat(assertions.query("SELECT " +
"MAX(CASE WHEN c1 = '1' THEN 10 / c2 END) AS m1, " +
"MAX(CASE WHEN c1 = '2' THEN 10 / c2 END) AS m2, " +
"MAX(CASE WHEN c1 = '3' THEN c2 END) AS m3, " +
"MAX(CASE WHEN c1 = '4' THEN c2 END) AS m4 " +
"FROM (VALUES ('1', 1), ('2', 2), ('3', 3), ('4', 0)) t(c1, c2)"))
.matches("VALUES (10, 5, 3, 0)");
}
}

0 comments on commit 29fe492

Please sign in to comment.