From 203f072cf2a2115135266368dc6fecb824e871e3 Mon Sep 17 00:00:00 2001 From: jakevin Date: Sat, 18 Nov 2023 23:57:01 +0800 Subject: [PATCH] [fix](Nereids): NullSafeEqual should be in HashJoinCondition #27127 (#27232) --- .../translator/PhysicalPlanTranslator.java | 4 +- .../post/RuntimeFilterGenerator.java | 16 +++-- .../rules/rewrite/EliminateOuterJoin.java | 60 +++++++++++++++++++ .../PushdownExpressionsInHashCondition.java | 7 +-- .../AbstractSelectMaterializedIndexRule.java | 5 +- .../doris/nereids/stats/FilterEstimation.java | 6 +- .../doris/nereids/stats/JoinEstimation.java | 28 ++++----- .../trees/expressions/EqualPredicate.java | 36 +++++++++++ .../nereids/trees/expressions/EqualTo.java | 4 +- .../trees/expressions/NullSafeEqual.java | 11 +--- .../plans/physical/AbstractPhysicalJoin.java | 7 +++ .../apache/doris/nereids/util/JoinUtils.java | 49 +++++---------- 12 files changed, 153 insertions(+), 80 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 4a413158b3aaeb..a3ac99ba359375 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -69,7 +69,7 @@ import org.apache.doris.nereids.trees.UnaryNode; import org.apache.doris.nereids.trees.expressions.AggregateExpression; import org.apache.doris.nereids.trees.expressions.CTEId; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -1114,7 +1114,7 @@ public PlanFragment visitPhysicalHashJoin( JoinType joinType = hashJoin.getJoinType(); List execEqConjuncts = hashJoin.getHashJoinConjuncts().stream() - .map(EqualTo.class::cast) + .map(EqualPredicate.class::cast) .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, hashJoin.left().getOutputSet())) .map(e -> ExpressionTranslator.translate(e, context)) .collect(Collectors.toList()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java index 0243326b106755..de13b7adb70579 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java @@ -281,11 +281,10 @@ private void pushDownRuntimeFilterCommon(PhysicalHashJoin legalTypes = Arrays.stream(TRuntimeFilterType.values()) .filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0) .collect(Collectors.toList()); - // TODO: some complex situation cannot be handled now, see testPushDownThroughJoin. - // we will support it in later version. - for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) { + List hashJoinConjuncts = join.getEqualToConjuncts(); + for (int i = 0; i < hashJoinConjuncts.size(); i++) { EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder( - (EqualTo) join.getHashJoinConjuncts().get(i), join.left().getOutputSet())); + hashJoinConjuncts.get(i), join.left().getOutputSet())); for (TRuntimeFilterType type : legalTypes) { //bitmap rf is generated by nested loop join. if (type == TRuntimeFilterType.BITMAP) { @@ -525,7 +524,7 @@ private void analyzeRuntimeFilterPushDownIntoCTEInfos(PhysicalHashJoin conditions = curJoin.getHashJoinConjuncts(); boolean inSameEqualSet = false; - for (Expression e : conditions) { + for (EqualTo e : curJoin.getEqualToConjuncts()) { if (e instanceof EqualTo) { - SlotReference oneSide = (SlotReference) ((EqualTo) e).left(); - SlotReference anotherSide = (SlotReference) ((EqualTo) e).right(); + SlotReference oneSide = (SlotReference) e.left(); + SlotReference anotherSide = (SlotReference) e.right(); if (anotherSideSlotSet.contains(oneSide) && anotherSideSlotSet.contains(anotherSide)) { inSameEqualSet = true; break; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java index 83cc37ed0b3d18..c2dcafbee435fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java @@ -19,17 +19,23 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.IsNull; +import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.nereids.util.TypeUtils; import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSet.Builder; +import com.google.common.collect.Sets; +import java.util.Collection; import java.util.HashSet; import java.util.Optional; import java.util.Set; @@ -63,6 +69,45 @@ public Rule build() { } JoinType newJoinType = tryEliminateOuterJoin(join.getJoinType(), canFilterLeftNull, canFilterRightNull); + Set conjuncts = Sets.newHashSet(); + conjuncts.addAll(filter.getConjuncts()); + boolean conjunctsChanged = false; + if (!notNullSlots.isEmpty()) { + for (Slot slot : notNullSlots) { + Not isNotNull = new Not(new IsNull(slot)); + isNotNull.isGeneratedIsNotNull = true; + conjunctsChanged |= conjuncts.add(isNotNull); + } + } + if (newJoinType.isInnerJoin()) { + /* + * for example: (A left join B on A.a=B.b) join C on B.x=C.x + * inner join condition B.x=C.x implies 'B.x is not null', + * by which the left outer join could be eliminated. Finally, the join transformed to + * (A join B on A.a=B.b) join C on B.x=C.x. + * This elimination can be processed recursively. + * + * TODO: is_not_null can also be inferred from A < B and so on + */ + conjunctsChanged |= join.getHashJoinConjuncts().stream() + .map(EqualPredicate.class::cast) + .map(equalTo -> JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet())) + .anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts)); + + JoinUtils.JoinSlotCoverageChecker checker = new JoinUtils.JoinSlotCoverageChecker( + join.left().getOutput(), + join.right().getOutput()); + conjunctsChanged |= join.getOtherJoinConjuncts().stream() + .filter(EqualPredicate.class::isInstance) + .filter(equalTo -> checker.isHashJoinCondition((EqualPredicate) equalTo)) + .map(equalTo -> JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) equalTo, + join.left().getOutputSet())) + .anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts)); + } + if (conjunctsChanged) { + return filter.withConjuncts(conjuncts.stream().collect(ImmutableSet.toImmutableSet())) + .withChildren(join.withJoinType(newJoinType)); + } return filter.withChildren(join.withJoinType(newJoinType)); }).toRule(RuleType.ELIMINATE_OUTER_JOIN); } @@ -85,4 +130,19 @@ private JoinType tryEliminateOuterJoin(JoinType joinType, boolean canFilterLeftN } return joinType; } + + private boolean createIsNotNullIfNecessary(EqualPredicate swapedEqualTo, Collection container) { + boolean containerChanged = false; + if (swapedEqualTo.left().nullable()) { + Not not = new Not(new IsNull(swapedEqualTo.left())); + not.isGeneratedIsNotNull = true; + containerChanged |= container.add(not); + } + if (swapedEqualTo.right().nullable()) { + Not not = new Not(new IsNull(swapedEqualTo.right())); + not.isGeneratedIsNotNull = true; + containerChanged |= container.add(not); + } + return containerChanged; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java index 05da591526cd83..df7acb4553c6ae 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java @@ -20,7 +20,7 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -77,11 +77,10 @@ public Rule build() { Set rightProjectExprs = Sets.newHashSet(); Map exprReplaceMap = Maps.newHashMap(); join.getHashJoinConjuncts().forEach(conjunct -> { - Preconditions.checkArgument(conjunct instanceof EqualTo); + Preconditions.checkArgument(conjunct instanceof EqualPredicate); // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it // doesn't swap the two sides. - conjunct = JoinUtils.swapEqualToForChildrenOrder( - (EqualTo) conjunct, join.left().getOutputSet()); + conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet()); generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs); generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs); }); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java index 012dec4c91cf49..c1550cb5bd5d5b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java @@ -27,13 +27,12 @@ import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; @@ -306,7 +305,7 @@ public PrefixIndexCheckResult visitInPredicate(InPredicate in, Map context) { - if (cp instanceof EqualTo || cp instanceof NullSafeEqual) { + if (cp instanceof EqualPredicate) { return check(cp, context, PrefixIndexCheckResult::createEqual); } else { return check(cp, context, PrefixIndexCheckResult::createNonEqual); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java index f06c9d1cc4f4ee..a412ff375fd63e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; @@ -33,7 +34,6 @@ import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.Like; import org.apache.doris.nereids.trees.expressions.Not; -import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; @@ -210,7 +210,7 @@ private Statistics calculateWhenLiteralRight(ComparisonPredicate cp, return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT); } - if (cp instanceof EqualTo || cp instanceof NullSafeEqual) { + if (cp instanceof EqualPredicate) { return estimateEqualTo(cp, statsForLeft, statsForRight, context); } else { if (cp instanceof LessThan || cp instanceof LessThanEqual) { @@ -255,7 +255,7 @@ private Statistics calculateWhenBothColumn(ComparisonPredicate cp, EstimationCon ColumnStatistic statsForLeft, ColumnStatistic statsForRight) { Expression left = cp.left(); Expression right = cp.right(); - if (cp instanceof EqualTo || cp instanceof NullSafeEqual) { + if (cp instanceof EqualPredicate) { return estimateColumnEqualToColumn(left, statsForLeft, right, statsForRight, context); } if (cp instanceof GreaterThan || cp instanceof GreaterThanEqual) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java index 800886c177f242..f9d25cab171c9e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java @@ -19,7 +19,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; @@ -45,14 +45,14 @@ public class JoinEstimation { private static double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3; - private static EqualTo normalizeHashJoinCondition(EqualTo equalTo, Statistics leftStats, Statistics rightStats) { - boolean changeOrder = equalTo.left().getInputSlots().stream().anyMatch( - slot -> rightStats.findColumnStatistics(slot) != null - ); + private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, Statistics leftStats, + Statistics rightStats) { + boolean changeOrder = equal.left().getInputSlots().stream() + .anyMatch(slot -> rightStats.findColumnStatistics(slot) != null); if (changeOrder) { - return new EqualTo(equalTo.right(), equalTo.left()); + return equal.commute(); } else { - return equalTo; + return equal; } } @@ -81,18 +81,18 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ * In order to avoid error propagation, for unTrustEquations, we only use the biggest selectivity. */ List unTrustEqualRatio = Lists.newArrayList(); - List unTrustableCondition = Lists.newArrayList(); + List unTrustableCondition = Lists.newArrayList(); boolean leftBigger = leftStats.getRowCount() > rightStats.getRowCount(); double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount()); double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount()); - List trustableConditions = join.getHashJoinConjuncts().stream() - .map(expression -> (EqualTo) expression) + List trustableConditions = join.getHashJoinConjuncts().stream() + .map(expression -> (EqualPredicate) expression) .filter( expression -> { // since ndv is not accurate, if ndv/rowcount < almostUniqueThreshold, // this column is regarded as unique. double almostUniqueThreshold = 0.9; - EqualTo equal = normalizeHashJoinCondition(expression, leftStats, rightStats); + EqualPredicate equal = normalizeHashJoinCondition(expression, leftStats, rightStats); ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats); ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats); boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold @@ -189,7 +189,7 @@ private static double estimateJoinConditionSel(Statistics crossJoinStats, Expres } private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStats, - Statistics rightStats, Join join, EqualTo equalTo) { + Statistics rightStats, Join join, EqualPredicate equalTo) { Expression eqLeft = equalTo.left(); Expression eqRight = equalTo.right(); ColumnStatistic probColStats = leftStats.findColumnStatistics(eqLeft); @@ -246,7 +246,7 @@ private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics ri double rowCount = Double.POSITIVE_INFINITY; for (Expression conjunct : join.getHashJoinConjuncts()) { double eqRowCount = estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats, - join, (EqualTo) conjunct); + join, (EqualPredicate) conjunct); if (rowCount > eqRowCount) { rowCount = eqRowCount; } @@ -321,7 +321,7 @@ public static Statistics estimate(Statistics leftStats, Statistics rightStats, J private static Statistics updateJoinResultStatsByHashJoinCondition(Statistics innerStats, Join join) { Map updatedCols = new HashMap<>(); for (Expression expr : join.getHashJoinConjuncts()) { - EqualTo equalTo = (EqualTo) expr; + EqualPredicate equalTo = (EqualPredicate) expr; ColumnStatistic leftColStats = ExpressionEstimation.estimate(equalTo.left(), innerStats); ColumnStatistic rightColStats = ExpressionEstimation.estimate(equalTo.right(), innerStats); double minNdv = Math.min(leftColStats.ndv, rightColStats.ndv); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java new file mode 100644 index 00000000000000..3f61bd3cf621a5 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions; + +import java.util.List; + +/** + * EqualPredicate + */ +public abstract class EqualPredicate extends ComparisonPredicate { + + protected EqualPredicate(List children, String symbol) { + super(children, symbol); + } + + @Override + public EqualPredicate commute() { + return null; + } +} + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java index 0fa23a57e0a310..1e72a006057462 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java @@ -29,7 +29,7 @@ /** * Equal to expression: a = b. */ -public class EqualTo extends ComparisonPredicate implements PropagateNullable { +public class EqualTo extends EqualPredicate implements PropagateNullable { public EqualTo(Expression left, Expression right) { super(ImmutableList.of(left, right), "="); @@ -60,7 +60,7 @@ public R accept(ExpressionVisitor visitor, C context) { } @Override - public ComparisonPredicate commute() { + public EqualTo commute() { return new EqualTo(right(), left()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java index c2b63aebbd793a..48d05364fa3441 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java @@ -29,13 +29,7 @@ * Null safe equal expression: a <=> b. * Unlike normal equal to expression, null <=> null is true. */ -public class NullSafeEqual extends ComparisonPredicate implements AlwaysNotNullable { - /** - * Constructor of Null Safe Equal ComparisonPredicate. - * - * @param left left child of Null Safe Equal - * @param right right child of Null Safe Equal - */ +public class NullSafeEqual extends EqualPredicate implements AlwaysNotNullable { public NullSafeEqual(Expression left, Expression right) { super(ImmutableList.of(left, right), "<=>"); } @@ -61,8 +55,7 @@ public NullSafeEqual withChildren(List children) { } @Override - public ComparisonPredicate commute() { + public NullSafeEqual commute() { return new NullSafeEqual(right(), left()); } - } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java index f67123522c3daf..a39634917aae61 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java @@ -20,6 +20,7 @@ import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; import org.apache.doris.nereids.trees.expressions.Slot; @@ -41,6 +42,7 @@ import java.util.List; import java.util.Objects; import java.util.Optional; +import java.util.stream.Collectors; /** * Abstract class for all physical join node. @@ -109,6 +111,11 @@ public List getHashJoinConjuncts() { return hashJoinConjuncts; } + public List getEqualToConjuncts() { + return hashJoinConjuncts.stream().filter(EqualTo.class::isInstance).map(EqualTo.class::cast) + .collect(Collectors.toList()); + } + public boolean isShouldTranslateOutput() { return shouldTranslateOutput; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java index eda7d2e6ad1fd3..25f84c096c8e07 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java @@ -24,7 +24,7 @@ import org.apache.doris.nereids.properties.DistributionSpecHash; import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; import org.apache.doris.nereids.properties.DistributionSpecReplicated; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Not; @@ -61,31 +61,18 @@ public static boolean couldBroadcast(Join join) { return !(join.getJoinType().isRightJoin() || join.getJoinType().isFullOuterJoin()); } - private static final class JoinSlotCoverageChecker { + /** + * JoinSlotCoverageChecker + */ + public static final class JoinSlotCoverageChecker { Set leftExprIds; Set rightExprIds; - JoinSlotCoverageChecker(List left, List right) { + public JoinSlotCoverageChecker(List left, List right) { leftExprIds = left.stream().map(Slot::getExprId).collect(Collectors.toSet()); rightExprIds = right.stream().map(Slot::getExprId).collect(Collectors.toSet()); } - JoinSlotCoverageChecker(Set left, Set right) { - leftExprIds = left; - rightExprIds = right; - } - - /** - * PushDownExpressionInHashConjuncts ensure the "slots" is only one slot. - */ - boolean isCoveredByLeftSlots(ExprId slot) { - return leftExprIds.contains(slot); - } - - boolean isCoveredByRightSlots(ExprId slot) { - return rightExprIds.contains(slot); - } - /** * consider following cases: * 1# A=1 => not for hash table @@ -94,25 +81,20 @@ boolean isCoveredByRightSlots(ExprId slot) { * 4# t1.a=t2.a or t1.b=t2.b not for hash table * 5# t1.a > 1 not for hash table * - * @param equalTo a conjunct in on clause condition + * @param equal a conjunct in on clause condition * @return true if the equal can be used as hash join condition */ - boolean isHashJoinCondition(EqualTo equalTo) { - Set equalLeft = equalTo.left().collect(Slot.class::isInstance); - if (equalLeft.isEmpty()) { + public boolean isHashJoinCondition(EqualPredicate equal) { + Set equalLeftExprIds = equal.left().getInputSlotExprIds(); + if (equalLeftExprIds.isEmpty()) { return false; } - Set equalRight = equalTo.right().collect(Slot.class::isInstance); - if (equalRight.isEmpty()) { + Set equalRightExprIds = equal.right().getInputSlotExprIds(); + if (equalRightExprIds.isEmpty()) { return false; } - List equalLeftExprIds = equalLeft.stream() - .map(Slot::getExprId).collect(Collectors.toList()); - - List equalRightExprIds = equalRight.stream() - .map(Slot::getExprId).collect(Collectors.toList()); return leftExprIds.containsAll(equalLeftExprIds) && rightExprIds.containsAll(equalRightExprIds) || leftExprIds.containsAll(equalRightExprIds) && rightExprIds.containsAll(equalLeftExprIds); } @@ -129,9 +111,8 @@ boolean isHashJoinCondition(EqualTo equalTo) { public static Pair, List> extractExpressionForHashTable(List leftSlots, List rightSlots, List onConditions) { JoinSlotCoverageChecker checker = new JoinSlotCoverageChecker(leftSlots, rightSlots); - Map> mapper = onConditions.stream() - .collect(Collectors.groupingBy( - expr -> (expr instanceof EqualTo) && checker.isHashJoinCondition((EqualTo) expr))); + Map> mapper = onConditions.stream().collect(Collectors.groupingBy( + expr -> (expr instanceof EqualPredicate) && checker.isHashJoinCondition((EqualPredicate) expr))); return Pair.of( mapper.getOrDefault(true, ImmutableList.of()), mapper.getOrDefault(false, ImmutableList.of()) @@ -187,7 +168,7 @@ public static boolean shouldNestedLoopJoin(JoinType joinType, List h * The left child of origin predicate is t2.id and the right child of origin predicate is t1.id. * In this situation, the children of predicate need to be swap => t1.id=t2.id. */ - public static Expression swapEqualToForChildrenOrder(EqualTo equalTo, Set leftOutput) { + public static EqualPredicate swapEqualToForChildrenOrder(EqualPredicate equalTo, Set leftOutput) { if (leftOutput.containsAll(equalTo.left().getInputSlots())) { return equalTo; } else {