Skip to content

Commit

Permalink
[fix](Nereids): NullSafeEqual should be in HashJoinCondition #27127 (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener authored Nov 18, 2023
1 parent 3b84627 commit 203f072
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.AggregateExpression;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
Expand Down Expand Up @@ -1114,7 +1114,7 @@ public PlanFragment visitPhysicalHashJoin(
JoinType joinType = hashJoin.getJoinType();

List<Expr> execEqConjuncts = hashJoin.getHashJoinConjuncts().stream()
.map(EqualTo.class::cast)
.map(EqualPredicate.class::cast)
.map(e -> JoinUtils.swapEqualToForChildrenOrder(e, hashJoin.left().getOutputSet()))
.map(e -> ExpressionTranslator.translate(e, context))
.collect(Collectors.toList());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,10 @@ private void pushDownRuntimeFilterCommon(PhysicalHashJoin<? extends Plan, ? exte
List<TRuntimeFilterType> legalTypes = Arrays.stream(TRuntimeFilterType.values())
.filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0)
.collect(Collectors.toList());
// TODO: some complex situation cannot be handled now, see testPushDownThroughJoin.
// we will support it in later version.
for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) {
List<EqualTo> hashJoinConjuncts = join.getEqualToConjuncts();
for (int i = 0; i < hashJoinConjuncts.size(); i++) {
EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder(
(EqualTo) join.getHashJoinConjuncts().get(i), join.left().getOutputSet()));
hashJoinConjuncts.get(i), join.left().getOutputSet()));
for (TRuntimeFilterType type : legalTypes) {
//bitmap rf is generated by nested loop join.
if (type == TRuntimeFilterType.BITMAP) {
Expand Down Expand Up @@ -525,7 +524,7 @@ private void analyzeRuntimeFilterPushDownIntoCTEInfos(PhysicalHashJoin<? extends
|| !(join.getHashJoinConjuncts().get(0) instanceof EqualTo)) {
break;
} else {
EqualTo equalTo = (EqualTo) join.getHashJoinConjuncts().get(0);
EqualTo equalTo = (EqualTo) join.getEqualToConjuncts().get(0);
equalTos.add(equalTo);
equalCondToJoinMap.put(equalTo, join);
}
Expand Down Expand Up @@ -561,12 +560,11 @@ private void analyzeRuntimeFilterPushDownIntoCTEInfos(PhysicalHashJoin<? extends
// check further whether the upper side of the join can bring an equal set, which
// indicates that it is actually the same runtime filter build side;
// see case 2 above for reference
List<Expression> conditions = curJoin.getHashJoinConjuncts();
boolean inSameEqualSet = false;
for (Expression e : conditions) {
for (EqualTo e : curJoin.getEqualToConjuncts()) {
if (e instanceof EqualTo) {
SlotReference oneSide = (SlotReference) ((EqualTo) e).left();
SlotReference anotherSide = (SlotReference) ((EqualTo) e).right();
SlotReference oneSide = (SlotReference) e.left();
SlotReference anotherSide = (SlotReference) e.right();
if (anotherSideSlotSet.contains(oneSide) && anotherSideSlotSet.contains(anotherSide)) {
inSameEqualSet = true;
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,23 @@

import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.nereids.util.TypeUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSet.Builder;
import com.google.common.collect.Sets;

import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
Expand Down Expand Up @@ -63,6 +69,45 @@ public Rule build() {
}

JoinType newJoinType = tryEliminateOuterJoin(join.getJoinType(), canFilterLeftNull, canFilterRightNull);
Set<Expression> conjuncts = Sets.newHashSet();
conjuncts.addAll(filter.getConjuncts());
boolean conjunctsChanged = false;
if (!notNullSlots.isEmpty()) {
for (Slot slot : notNullSlots) {
Not isNotNull = new Not(new IsNull(slot));
isNotNull.isGeneratedIsNotNull = true;
conjunctsChanged |= conjuncts.add(isNotNull);
}
}
if (newJoinType.isInnerJoin()) {
/*
* for example: (A left join B on A.a=B.b) join C on B.x=C.x
* inner join condition B.x=C.x implies 'B.x is not null',
* by which the left outer join could be eliminated. Finally, the join transformed to
* (A join B on A.a=B.b) join C on B.x=C.x.
* This elimination can be processed recursively.
*
* TODO: is_not_null can also be inferred from A < B and so on
*/
conjunctsChanged |= join.getHashJoinConjuncts().stream()
.map(EqualPredicate.class::cast)
.map(equalTo -> JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet()))
.anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts));

JoinUtils.JoinSlotCoverageChecker checker = new JoinUtils.JoinSlotCoverageChecker(
join.left().getOutput(),
join.right().getOutput());
conjunctsChanged |= join.getOtherJoinConjuncts().stream()
.filter(EqualPredicate.class::isInstance)
.filter(equalTo -> checker.isHashJoinCondition((EqualPredicate) equalTo))
.map(equalTo -> JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) equalTo,
join.left().getOutputSet()))
.anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts));
}
if (conjunctsChanged) {
return filter.withConjuncts(conjuncts.stream().collect(ImmutableSet.toImmutableSet()))
.withChildren(join.withJoinType(newJoinType));
}
return filter.withChildren(join.withJoinType(newJoinType));
}).toRule(RuleType.ELIMINATE_OUTER_JOIN);
}
Expand All @@ -85,4 +130,19 @@ private JoinType tryEliminateOuterJoin(JoinType joinType, boolean canFilterLeftN
}
return joinType;
}

private boolean createIsNotNullIfNecessary(EqualPredicate swapedEqualTo, Collection<Expression> container) {
boolean containerChanged = false;
if (swapedEqualTo.left().nullable()) {
Not not = new Not(new IsNull(swapedEqualTo.left()));
not.isGeneratedIsNotNull = true;
containerChanged |= container.add(not);
}
if (swapedEqualTo.right().nullable()) {
Not not = new Not(new IsNull(swapedEqualTo.right()));
not.isGeneratedIsNotNull = true;
containerChanged |= container.add(not);
}
return containerChanged;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
Expand Down Expand Up @@ -77,11 +77,10 @@ public Rule build() {
Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
join.getHashJoinConjuncts().forEach(conjunct -> {
Preconditions.checkArgument(conjunct instanceof EqualTo);
Preconditions.checkArgument(conjunct instanceof EqualPredicate);
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it
// doesn't swap the two sides.
conjunct = JoinUtils.swapEqualToForChildrenOrder(
(EqualTo) conjunct, join.left().getOutputSet());
conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet());
generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs);
generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,12 @@
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WhenClause;
Expand Down Expand Up @@ -306,7 +305,7 @@ public PrefixIndexCheckResult visitInPredicate(InPredicate in, Map<ExprId, Strin

@Override
public PrefixIndexCheckResult visitComparisonPredicate(ComparisonPredicate cp, Map<ExprId, String> context) {
if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
if (cp instanceof EqualPredicate) {
return check(cp, context, PrefixIndexCheckResult::createEqual);
} else {
return check(cp, context, PrefixIndexCheckResult::createNonEqual);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
Expand All @@ -33,7 +34,6 @@
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Like;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
Expand Down Expand Up @@ -210,7 +210,7 @@ private Statistics calculateWhenLiteralRight(ComparisonPredicate cp,
return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
}

if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
if (cp instanceof EqualPredicate) {
return estimateEqualTo(cp, statsForLeft, statsForRight, context);
} else {
if (cp instanceof LessThan || cp instanceof LessThanEqual) {
Expand Down Expand Up @@ -255,7 +255,7 @@ private Statistics calculateWhenBothColumn(ComparisonPredicate cp, EstimationCon
ColumnStatistic statsForLeft, ColumnStatistic statsForRight) {
Expression left = cp.left();
Expression right = cp.right();
if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
if (cp instanceof EqualPredicate) {
return estimateColumnEqualToColumn(left, statsForLeft, right, statsForRight, context);
}
if (cp instanceof GreaterThan || cp instanceof GreaterThanEqual) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.EqualPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
Expand All @@ -45,14 +45,14 @@
public class JoinEstimation {
private static double DEFAULT_ANTI_JOIN_SELECTIVITY_COEFFICIENT = 0.3;

private static EqualTo normalizeHashJoinCondition(EqualTo equalTo, Statistics leftStats, Statistics rightStats) {
boolean changeOrder = equalTo.left().getInputSlots().stream().anyMatch(
slot -> rightStats.findColumnStatistics(slot) != null
);
private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, Statistics leftStats,
Statistics rightStats) {
boolean changeOrder = equal.left().getInputSlots().stream()
.anyMatch(slot -> rightStats.findColumnStatistics(slot) != null);
if (changeOrder) {
return new EqualTo(equalTo.right(), equalTo.left());
return equal.commute();
} else {
return equalTo;
return equal;
}
}

Expand Down Expand Up @@ -81,18 +81,18 @@ private static Statistics estimateHashJoin(Statistics leftStats, Statistics righ
* In order to avoid error propagation, for unTrustEquations, we only use the biggest selectivity.
*/
List<Double> unTrustEqualRatio = Lists.newArrayList();
List<EqualTo> unTrustableCondition = Lists.newArrayList();
List<EqualPredicate> unTrustableCondition = Lists.newArrayList();
boolean leftBigger = leftStats.getRowCount() > rightStats.getRowCount();
double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());
List<EqualTo> trustableConditions = join.getHashJoinConjuncts().stream()
.map(expression -> (EqualTo) expression)
List<EqualPredicate> trustableConditions = join.getHashJoinConjuncts().stream()
.map(expression -> (EqualPredicate) expression)
.filter(
expression -> {
// since ndv is not accurate, if ndv/rowcount < almostUniqueThreshold,
// this column is regarded as unique.
double almostUniqueThreshold = 0.9;
EqualTo equal = normalizeHashJoinCondition(expression, leftStats, rightStats);
EqualPredicate equal = normalizeHashJoinCondition(expression, leftStats, rightStats);
ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats);
ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats);
boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold
Expand Down Expand Up @@ -189,7 +189,7 @@ private static double estimateJoinConditionSel(Statistics crossJoinStats, Expres
}

private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStats,
Statistics rightStats, Join join, EqualTo equalTo) {
Statistics rightStats, Join join, EqualPredicate equalTo) {
Expression eqLeft = equalTo.left();
Expression eqRight = equalTo.right();
ColumnStatistic probColStats = leftStats.findColumnStatistics(eqLeft);
Expand Down Expand Up @@ -246,7 +246,7 @@ private static Statistics estimateSemiOrAnti(Statistics leftStats, Statistics ri
double rowCount = Double.POSITIVE_INFINITY;
for (Expression conjunct : join.getHashJoinConjuncts()) {
double eqRowCount = estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats,
join, (EqualTo) conjunct);
join, (EqualPredicate) conjunct);
if (rowCount > eqRowCount) {
rowCount = eqRowCount;
}
Expand Down Expand Up @@ -321,7 +321,7 @@ public static Statistics estimate(Statistics leftStats, Statistics rightStats, J
private static Statistics updateJoinResultStatsByHashJoinCondition(Statistics innerStats, Join join) {
Map<Expression, ColumnStatistic> updatedCols = new HashMap<>();
for (Expression expr : join.getHashJoinConjuncts()) {
EqualTo equalTo = (EqualTo) expr;
EqualPredicate equalTo = (EqualPredicate) expr;
ColumnStatistic leftColStats = ExpressionEstimation.estimate(equalTo.left(), innerStats);
ColumnStatistic rightColStats = ExpressionEstimation.estimate(equalTo.right(), innerStats);
double minNdv = Math.min(leftColStats.ndv, rightColStats.ndv);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions;

import java.util.List;

/**
* EqualPredicate
*/
public abstract class EqualPredicate extends ComparisonPredicate {

protected EqualPredicate(List<Expression> children, String symbol) {
super(children, symbol);
}

@Override
public EqualPredicate commute() {
return null;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
/**
* Equal to expression: a = b.
*/
public class EqualTo extends ComparisonPredicate implements PropagateNullable {
public class EqualTo extends EqualPredicate implements PropagateNullable {

public EqualTo(Expression left, Expression right) {
super(ImmutableList.of(left, right), "=");
Expand Down Expand Up @@ -60,7 +60,7 @@ public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
}

@Override
public ComparisonPredicate commute() {
public EqualTo commute() {
return new EqualTo(right(), left());
}
}
Loading

0 comments on commit 203f072

Please sign in to comment.