From 8d7fa41547e16d601446501f8634f39784c85cea Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Sun, 10 Mar 2024 19:19:57 +0100 Subject: [PATCH] migrate --- .../JoinConditionEqualityTransferRule.java | 241 +++++++++++++++ .../JoinConditionEqualityTransferRule.scala | 283 ------------------ 2 files changed, 241 insertions(+), 283 deletions(-) create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.java delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.scala diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.java new file mode 100644 index 0000000000000..c134a2c4b9003 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.table.planner.plan.utils.FlinkRexUtil; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexInputRef; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.tools.RelBuilder; +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable.EQUALS; + +/** + * Planner rule that converts Join's conditions to the left or right table's own independent filter + * as much as possible, so that the rules of filter-push-down can push down the filter to below. + * + *
<p>
e.g. join condition: l_a = r_b and l_a = r_c. The l_a is a field from left input, both r_b and + * r_c are fields from the right input. After rewrite, condition will be: l_a = r_b and r_b = r_c. + * r_b = r_c can be pushed down to the right input. + */ +@Value.Enclosing +public class JoinConditionEqualityTransferRule + extends RelRule { + + public static final JoinConditionEqualityTransferRule INSTANCE = + JoinConditionEqualityTransferRuleConfig.DEFAULT.toRule(); + + protected JoinConditionEqualityTransferRule(JoinConditionEqualityTransferRuleConfig config) { + super(config); + } + + @Override + public boolean matches(RelOptRuleCall call) { + Join join = call.rel(0); + JoinRelType joinType = join.getJoinType(); + if (joinType != JoinRelType.INNER && joinType != JoinRelType.SEMI) { + return false; + } + + Tuple2, List> partitionJoinFilters = partitionJoinFilters(join); + List> groups = getEquiFilterRelationshipGroup(partitionJoinFilters.f0); + for (Set group : groups) { + if (group.size() > 2) { + return true; + } + } + return false; + } + + @Override + public void onMatch(RelOptRuleCall call) { + Join join = call.rel(0); + Tuple2, List> optimizableAndRemainFilters = + partitionJoinFilters(join); + List optimizableFilters = optimizableAndRemainFilters.f0; + List remainFilters = optimizableAndRemainFilters.f1; + Map>> partitioned = + getEquiFilterRelationshipGroup(optimizableFilters).stream() + .collect(Collectors.partitioningBy(t -> t.size() > 2)); + List> equiFiltersToOpt = partitioned.get(true); + List> equiFiltersNotOpt = partitioned.get(false); + + RelBuilder builder = call.builder(); + RexBuilder rexBuilder = builder.getRexBuilder(); + List newEquiJoinFilters = new ArrayList<>(); + + // groups of exactly two refs are kept as a single equality. + equiFiltersNotOpt.forEach( + refs -> { + assert (refs.size() == 2); // NOTE(review): skipped unless -ea is set + Iterator iterator = refs.iterator(); + newEquiJoinFilters.add( + rexBuilder.makeCall(EQUALS, iterator.next(), iterator.next())); + }); + + // groups of three or more refs get rebuilt equalities.
+ equiFiltersToOpt.forEach( + refs -> { + // split the refs into left-input (true) and right-input (false) refs. + Map> leftAndRightRefs = + refs.stream() + .collect(Collectors.partitioningBy(t -> fromJoinLeft(join, t))); + List leftRefs = leftAndRightRefs.get(true); + List rightRefs = leftAndRightRefs.get(false); + + // per side: first ref = every other ref of that side. + List rexCalls = new ArrayList<>(makeCalls(rexBuilder, leftRefs)); + rexCalls.addAll(makeCalls(rexBuilder, rightRefs)); + + // one cross-side equality keeps both sides connected. + if (!leftRefs.isEmpty() && !rightRefs.isEmpty()) { + rexCalls.add( + rexBuilder.makeCall(EQUALS, leftRefs.get(0), rightRefs.get(0))); + } + + // added as-is (no dedup); the conjunction is simplified below. + newEquiJoinFilters.addAll(rexCalls); + }); + + remainFilters.add( + FlinkRexUtil.simplify( + rexBuilder, + builder.and(newEquiJoinFilters), + join.getCluster().getPlanner().getExecutor())); + RexNode newJoinFilter = builder.and(remainFilters); + Join newJoin = + join.copy( + join.getTraitSet(), + newJoinFilter, + join.getLeft(), + join.getRight(), + join.getJoinType(), + join.isSemiJoinDone()); + + call.transformTo(newJoin); + } + + /** Returns true if the ref's index is below the left input's field count, else false. */ + private boolean fromJoinLeft(Join join, RexInputRef ref) { + assert join.getSystemFieldList().isEmpty(); + return ref.getIndex() < join.getLeft().getRowType().getFieldCount(); + } + + /** Splits join conjunctions into left-ref = right-ref equalities (f0) and the rest (f1).
+ */ + private Tuple2, List> partitionJoinFilters(Join join) { + List left = new ArrayList<>(); + List right = new ArrayList<>(); + List conjunctions = RelOptUtil.conjunctions(join.getCondition()); + for (RexNode rexNode : conjunctions) { + if (rexNode instanceof RexCall) { + RexCall call = (RexCall) rexNode; + if (call.isA(SqlKind.EQUALS)) { + if (call.operands.get(0) instanceof RexInputRef + && call.operands.get(1) instanceof RexInputRef) { + RexInputRef ref1 = (RexInputRef) call.operands.get(0); + RexInputRef ref2 = (RexInputRef) call.operands.get(1); + boolean isLeft1 = fromJoinLeft(join, ref1); + boolean isLeft2 = fromJoinLeft(join, ref2); + if (isLeft1 != isLeft2) { + left.add(rexNode); + continue; + } + } + } + } + right.add(rexNode); + } + return Tuple2.of(left, right); + } + + /** Groups refs linked by equalities; a filter joins the first group holding either operand. */ + private List> getEquiFilterRelationshipGroup(List equiJoinFilters) { + List> res = new ArrayList<>(); + for (RexNode rexNode : equiJoinFilters) { + if (rexNode instanceof RexCall) { + RexCall call = (RexCall) rexNode; + if (call.isA(SqlKind.EQUALS)) { + RexInputRef left = (RexInputRef) call.operands.get(0); + RexInputRef right = (RexInputRef) call.operands.get(1); + boolean found = false; + for (Set refs : res) { + if (refs.contains(left) || refs.contains(right)) { + refs.add(left); + refs.add(right); + found = true; + break; + } + } + if (!found) { + Set set = new HashSet<>(); + set.add(left); + set.add(right); + res.add(set); + } + } + } + } + return res; + } + + /** Builds EQUALS calls from the first ref to each other ref; empty for fewer than two refs. */ + private List makeCalls(RexBuilder rexBuilder, List nodes) { + final List calls = new ArrayList<>(); + if (nodes.size() > 1) { + RexInputRef rex = nodes.get(0); + nodes.subList(1, nodes.size()) + .forEach(t -> calls.add(rexBuilder.makeCall(EQUALS, rex, t))); + } + return calls; + } + + /** Rule configuration: matches any Join, whatever its inputs.
+ */ + @Value.Immutable(singleton = false) + public interface JoinConditionEqualityTransferRuleConfig extends RelRule.Config { + JoinConditionEqualityTransferRule.JoinConditionEqualityTransferRuleConfig DEFAULT = + ImmutableJoinConditionEqualityTransferRule.JoinConditionEqualityTransferRuleConfig + .builder() + .description("JoinConditionEqualityTransferRule") + .build() + .withOperandSupplier(b0 -> b0.operand(Join.class).anyInputs()); + + @Override + default JoinConditionEqualityTransferRule toRule() { + return new JoinConditionEqualityTransferRule(this); + } + } +} diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.scala deleted file mode 100644 index 1589f020f31e1..0000000000000 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.scala +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -package org.apache.flink.table.planner.plan.rules.logical - -import org.apache.flink.table.planner.plan.utils.FlinkRexUtil - -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil} -import org.apache.calcite.plan.RelOptRule.{any, operand} -import org.apache.calcite.rel.core.{Join, JoinRelType} -import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexNode} -import org.apache.calcite.sql.SqlKind -import org.apache.calcite.sql.fun.SqlStdOperatorTable.EQUALS - -import java.util -import java.util.Collections - -import scala.collection.{breakOut, mutable} -import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ -import scala.util.control.Breaks.{break, breakable} - -/** - * Planner rule that converts Join's conditions to the left or right table's own independent filter - * as much as possible, so that the rules of filter-push-down can push down the filter to below. - * - *
<p>
e.g. join condition: l_a = r_b and l_a = r_c. The l_a is a field from left input, both r_b and - * r_c are fields from the right input. After rewrite, condition will be: l_a = r_b and r_b = r_c. - * r_b = r_c can be pushed down to the right input. - */ -class JoinConditionEqualityTransferRule - extends RelOptRule(operand(classOf[Join], any), "JoinConditionEqualityTransferRule") { - - override def matches(call: RelOptRuleCall): Boolean = { - val join: Join = call.rel(0) - val joinType = join.getJoinType - if (joinType != JoinRelType.INNER && joinType != JoinRelType.SEMI) { - return false - } - - val optimizableFilters = partitionJoinFilters(join) - val groups = getEquiFilterRelationshipGroup(optimizableFilters.f0) - groups.exists(_.size > 2) - } - - override def onMatch(call: RelOptRuleCall): Unit = { - val join: Join = call.rel(0) - val optimizableAndRemainFilters = - partitionJoinFilters(join); - val optimizableFilters = optimizableAndRemainFilters.f0 - val remainFilters = optimizableAndRemainFilters.f1 - - /* val partitioned = - getEquiFilterRelationshipGroup(optimizableFilters).stream() - .collect(java.util.stream.Collectors.partitioningBy(new java.util.function.Predicate[java.util.Set[RexInputRef]] { - override def test(t: java.util.Set[RexInputRef]): Boolean = t.size() > 2 - } ))*/ - - val (equiFiltersToOpt, equiFiltersNotOpt) = - getEquiFilterRelationshipGroup(optimizableFilters).partition(_.size > 2) - /*val equiFiltersToOpt = partitioned.get(true) - val equiFiltersNotOpt = partitioned.get(false) - */ - val builder = call.builder() - val rexBuilder = builder.getRexBuilder - val newEquiJoinFilters = new java.util.ArrayList[RexNode]() - - // add equiFiltersNotOpt. - equiFiltersNotOpt.foreach { - refs => - require(refs.size == 2) - newEquiJoinFilters.add(rexBuilder.makeCall(EQUALS, refs.head, refs.last)) - } - - // new opt filters. - equiFiltersToOpt.foreach { - refs => - // partition to InputRef to left and right. 
- val leftAndRightRefs = - refs - .stream() - // .collect(java.util.stream.Collectors.partitioningBy(t => fromJoinLeft(join, t))); - .collect(java.util.stream.Collectors - .partitioningBy(new java.util.function.Predicate[RexInputRef] { - override def test(t: RexInputRef): Boolean = fromJoinLeft(join, t) - })) - val leftRefs = leftAndRightRefs.get(true) - val rightRefs = leftAndRightRefs.get(false) - val rexCalls = new java.util.ArrayList[RexNode]() - - // equals for each other. - rexCalls.addAll(makeCalls(rexBuilder, leftRefs)) - rexCalls.addAll(makeCalls(rexBuilder, rightRefs)) - - // equals for left and right. - if (leftRefs.nonEmpty && rightRefs.nonEmpty) { - rexCalls.add(rexBuilder.makeCall(EQUALS, leftRefs.head, rightRefs.head)) - } - - // add to newEquiJoinFilters with deduplication. - newEquiJoinFilters.addAll(rexCalls) - } - - remainFilters.add( - FlinkRexUtil.simplify( - rexBuilder, - builder.and(newEquiJoinFilters), - join.getCluster.getPlanner.getExecutor)) - val newJoinFilter = builder.and(remainFilters) - val newJoin = join.copy( - join.getTraitSet, - newJoinFilter, - join.getLeft, - join.getRight, - join.getJoinType, - join.isSemiJoinDone) - - call.transformTo(newJoin) - } - - /** Returns true if the given input ref is from join left, else false. */ - private def fromJoinLeft(join: Join, ref: RexInputRef): Boolean = { - require(join.getSystemFieldList.size() == 0) - ref.getIndex < join.getLeft.getRowType.getFieldCount - } - - /** Partition join condition to leftRef-rightRef equals and others. 
*/ - /* def partitionJoinFilters(join: Join): (Seq[RexNode], Seq[RexNode]) = { - val conjunctions = RelOptUtil.conjunctions(join.getCondition) - conjunctions.partition { - case call: RexCall if call.isA(SqlKind.EQUALS) => - (call.operands.head, call.operands.last) match { - case (ref1: RexInputRef, ref2: RexInputRef) => - val isLeft1 = fromJoinLeft(join, ref1) - val isLeft2 = fromJoinLeft(join, ref2) - isLeft1 != isLeft2 - case _ => false - } - case _ => false - } - } - */ - def partitionJoinFilters(join: Join) - : org.apache.flink.api.java.tuple.Tuple2[java.util.List[RexNode], java.util.List[RexNode]] = { - val left = new java.util.ArrayList[RexNode]() - val right = new java.util.ArrayList[RexNode]() - val conjunctions = RelOptUtil.conjunctions(join.getCondition) - - for (rexNode <- conjunctions) { - if (rexNode.isInstanceOf[RexCall]) { - val call = rexNode.asInstanceOf[RexCall] - if (call.isA(SqlKind.EQUALS)) { - if ( - call.operands.get(0).isInstanceOf[RexInputRef] && call.operands - .get(1) - .isInstanceOf[RexInputRef] - ) { - val ref1 = call.operands.get(0).asInstanceOf[RexInputRef] - val ref2 = call.operands.get(1).asInstanceOf[RexInputRef] - val isLeft1 = fromJoinLeft(join, ref1) - val isLeft2 = fromJoinLeft(join, ref2) - if (isLeft1 != isLeft2) { - left.add(rexNode) - } else { - right.add(rexNode) - } - } else { - right.add(rexNode) - } - } else { - right.add(rexNode) - } - } else { - right.add(rexNode) - } - } - - /* conjunctions.partition { - case call: RexCall if call.isA(SqlKind.EQUALS) => - (call.operands.head, call.operands.last) match { - case (ref1: RexInputRef, ref2: RexInputRef) => - val isLeft1 = fromJoinLeft(join, ref1) - val isLeft2 = fromJoinLeft(join, ref2) - isLeft1 != isLeft2 - case _ => false - } - case _ => false - }*/ - - org.apache.flink.api.java.tuple.Tuple2.of(left, right) - } - - /** Put fields to a group that have equivalence relationships. 
*/ - def getEquiFilterRelationshipGroup( - equiJoinFilters: java.util.List[RexNode]): Seq[Seq[RexInputRef]] = { - val filterSets = new java.util.ArrayList[util.HashSet[RexInputRef]]() - equiJoinFilters.foreach { - case call: RexCall => - require(call.isA(SqlKind.EQUALS)) - val left = call.operands.head.asInstanceOf[RexInputRef] - val right = call.operands.last.asInstanceOf[RexInputRef] - val set = filterSets - .stream() - .filter(set => set.contains(left) || set.contains(right)) - .findFirst() - .orElseGet( - () => { - val s = new util.HashSet[RexInputRef]() - filterSets.add(s) - s - }) - /*val set = filterSets.find(set => set.contains(left) || set.contains(right)) match { - case Some(s) => s - case None => - val s = new mutable.HashSet[RexInputRef]() - filterSets.add(s) - s - }*/ - set.add(left) - set.add(right) - } - - filterSets.map(_.toSeq) - } - - /* private def getEquiFilterRelationshipGroup(equiJoinFilters: java.util.List[RexNode]): util.ArrayList[util.Set[RexInputRef]] = { - val res = new java.util.ArrayList[java.util.Set[RexInputRef]] - for (rexNode <- equiJoinFilters) { - if (rexNode.isInstanceOf[RexCall]) { - val call = rexNode.asInstanceOf[RexCall] - require (call.isA(SqlKind.EQUALS)) - val left = call.operands.head.asInstanceOf[RexInputRef] - val right = call.operands.last.asInstanceOf[RexInputRef] - var found = false - /*breakable { - for (refs <- res) { - if (refs.contains(left) || refs.contains(right)) { - refs.add(left) - refs.add(right) - found = true - break; - } - } - }*/ - if (!found) { - val set = new util.HashSet[RexInputRef]() - set.add(left) - set.add(right) - res.add(set) - } - } - } - res - }*/ - - /** Make calls to a number of inputRefs, make sure that they both have a relationship. 
*/ - def makeCalls(rexBuilder: RexBuilder, nodes: Seq[RexInputRef]): java.util.List[RexNode] = { - val calls = new java.util.ArrayList[RexNode]() - if (nodes.length > 1) { - val rex = nodes.head - nodes.drop(1).foreach(t => calls.add(rexBuilder.makeCall(EQUALS, rex, t))) - } - calls - } -} - -object JoinConditionEqualityTransferRule { - val INSTANCE = new JoinConditionEqualityTransferRule -}