diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.scala index 9652aae13a747f..1589f020f31e1d 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/JoinConditionEqualityTransferRule.scala @@ -26,8 +26,13 @@ import org.apache.calcite.rex.{RexBuilder, RexCall, RexInputRef, RexNode} import org.apache.calcite.sql.SqlKind import org.apache.calcite.sql.fun.SqlStdOperatorTable.EQUALS +import java.util +import java.util.Collections + +import scala.collection.{breakOut, mutable} import scala.collection.JavaConversions._ -import scala.collection.mutable +import scala.collection.JavaConverters._ +import scala.util.control.Breaks.{break, breakable} /** * Planner rule that converts Join's conditions to the left or right table's own independent filter @@ -47,54 +52,75 @@ class JoinConditionEqualityTransferRule return false } - val (optimizableFilters, _) = partitionJoinFilters(join) - val groups = getEquiFilterRelationshipGroup(optimizableFilters) + val optimizableFilters = partitionJoinFilters(join) + val groups = getEquiFilterRelationshipGroup(optimizableFilters.f0) groups.exists(_.size > 2) } override def onMatch(call: RelOptRuleCall): Unit = { val join: Join = call.rel(0) - val (optimizableFilters, remainFilters) = partitionJoinFilters(join) + val optimizableAndRemainFilters = + partitionJoinFilters(join); + val optimizableFilters = optimizableAndRemainFilters.f0 + val remainFilters = optimizableAndRemainFilters.f1 + + /* val partitioned = + getEquiFilterRelationshipGroup(optimizableFilters).stream() + .collect(java.util.stream.Collectors.partitioningBy(new java.util.function.Predicate[java.util.Set[RexInputRef]] { + override def test(t: java.util.Set[RexInputRef]): Boolean = t.size() > 2 + } ))*/ + val (equiFiltersToOpt, equiFiltersNotOpt) = getEquiFilterRelationshipGroup(optimizableFilters).partition(_.size > 2) - + /*val equiFiltersToOpt = partitioned.get(true) + val equiFiltersNotOpt = partitioned.get(false) + */ val builder = call.builder() val rexBuilder = builder.getRexBuilder - val newEquiJoinFilters = mutable.ListBuffer[RexNode]() + val newEquiJoinFilters = new java.util.ArrayList[RexNode]() // add equiFiltersNotOpt. equiFiltersNotOpt.foreach { refs => require(refs.size == 2) - newEquiJoinFilters += rexBuilder.makeCall(EQUALS, refs.head, refs.last) + newEquiJoinFilters.add(rexBuilder.makeCall(EQUALS, refs.head, refs.last)) } // new opt filters. equiFiltersToOpt.foreach { refs => // partition to InputRef to left and right. - val (leftRefs, rightRefs) = refs.partition(fromJoinLeft(join, _)) - val rexCalls = new mutable.ArrayBuffer[RexNode]() + val leftAndRightRefs = + refs + .stream() + // .collect(java.util.stream.Collectors.partitioningBy(t => fromJoinLeft(join, t))); + .collect(java.util.stream.Collectors + .partitioningBy(new java.util.function.Predicate[RexInputRef] { + override def test(t: RexInputRef): Boolean = fromJoinLeft(join, t) + })) + val leftRefs = leftAndRightRefs.get(true) + val rightRefs = leftAndRightRefs.get(false) + val rexCalls = new java.util.ArrayList[RexNode]() // equals for each other. - rexCalls ++= makeCalls(rexBuilder, leftRefs) - rexCalls ++= makeCalls(rexBuilder, rightRefs) + rexCalls.addAll(makeCalls(rexBuilder, leftRefs)) + rexCalls.addAll(makeCalls(rexBuilder, rightRefs)) // equals for left and right. if (leftRefs.nonEmpty && rightRefs.nonEmpty) { - rexCalls += rexBuilder.makeCall(EQUALS, leftRefs.head, rightRefs.head) + rexCalls.add(rexBuilder.makeCall(EQUALS, leftRefs.head, rightRefs.head)) } // add to newEquiJoinFilters with deduplication. - rexCalls.foreach(call => newEquiJoinFilters += call) + newEquiJoinFilters.addAll(rexCalls) } - val newJoinFilter = builder.and( - remainFilters :+ - FlinkRexUtil.simplify( - rexBuilder, - builder.and(newEquiJoinFilters), - join.getCluster.getPlanner.getExecutor)) + remainFilters.add( + FlinkRexUtil.simplify( + rexBuilder, + builder.and(newEquiJoinFilters), + join.getCluster.getPlanner.getExecutor)) + val newJoinFilter = builder.and(remainFilters) val newJoin = join.copy( join.getTraitSet, newJoinFilter, @@ -113,7 +139,7 @@ class JoinConditionEqualityTransferRule } /** Partition join condition to leftRef-rightRef equals and others. */ - def partitionJoinFilters(join: Join): (Seq[RexNode], Seq[RexNode]) = { + /* def partitionJoinFilters(join: Join): (Seq[RexNode], Seq[RexNode]) = { val conjunctions = RelOptUtil.conjunctions(join.getCondition) conjunctions.partition { case call: RexCall if call.isA(SqlKind.EQUALS) => @@ -127,35 +153,126 @@ class JoinConditionEqualityTransferRule case _ => false } } + */ + def partitionJoinFilters(join: Join) + : org.apache.flink.api.java.tuple.Tuple2[java.util.List[RexNode], java.util.List[RexNode]] = { + val left = new java.util.ArrayList[RexNode]() + val right = new java.util.ArrayList[RexNode]() + val conjunctions = RelOptUtil.conjunctions(join.getCondition) + + for (rexNode <- conjunctions) { + if (rexNode.isInstanceOf[RexCall]) { + val call = rexNode.asInstanceOf[RexCall] + if (call.isA(SqlKind.EQUALS)) { + if ( + call.operands.get(0).isInstanceOf[RexInputRef] && call.operands + .get(1) + .isInstanceOf[RexInputRef] + ) { + val ref1 = call.operands.get(0).asInstanceOf[RexInputRef] + val ref2 = call.operands.get(1).asInstanceOf[RexInputRef] + val isLeft1 = fromJoinLeft(join, ref1) + val isLeft2 = fromJoinLeft(join, ref2) + if (isLeft1 != isLeft2) { + left.add(rexNode) + } else { + right.add(rexNode) + } + } else { + right.add(rexNode) + } + } else { + right.add(rexNode) + } + } else { + right.add(rexNode) + } + } + + /* conjunctions.partition { + case call: RexCall if call.isA(SqlKind.EQUALS) => + (call.operands.head, call.operands.last) match { + case (ref1: RexInputRef, ref2: RexInputRef) => + val isLeft1 = fromJoinLeft(join, ref1) + val isLeft2 = fromJoinLeft(join, ref2) + isLeft1 != isLeft2 + case _ => false + } + case _ => false + }*/ + + org.apache.flink.api.java.tuple.Tuple2.of(left, right) + } /** Put fields to a group that have equivalence relationships. */ - def getEquiFilterRelationshipGroup(equiJoinFilters: Seq[RexNode]): Seq[Seq[RexInputRef]] = { - val filterSets = mutable.ArrayBuffer[mutable.HashSet[RexInputRef]]() + def getEquiFilterRelationshipGroup( + equiJoinFilters: java.util.List[RexNode]): Seq[Seq[RexInputRef]] = { + val filterSets = new java.util.ArrayList[util.HashSet[RexInputRef]]() equiJoinFilters.foreach { case call: RexCall => require(call.isA(SqlKind.EQUALS)) val left = call.operands.head.asInstanceOf[RexInputRef] val right = call.operands.last.asInstanceOf[RexInputRef] - val set = filterSets.find(set => set.contains(left) || set.contains(right)) match { + val set = filterSets + .stream() + .filter(set => set.contains(left) || set.contains(right)) + .findFirst() + .orElseGet( + () => { + val s = new util.HashSet[RexInputRef]() + filterSets.add(s) + s + }) + /*val set = filterSets.find(set => set.contains(left) || set.contains(right)) match { case Some(s) => s case None => val s = new mutable.HashSet[RexInputRef]() - filterSets += s + filterSets.add(s) s - } - set += left - set += right + }*/ + set.add(left) + set.add(right) } filterSets.map(_.toSeq) } + /* private def getEquiFilterRelationshipGroup(equiJoinFilters: java.util.List[RexNode]): util.ArrayList[util.Set[RexInputRef]] = { + val res = new java.util.ArrayList[java.util.Set[RexInputRef]] + for (rexNode <- equiJoinFilters) { + if (rexNode.isInstanceOf[RexCall]) { + val call = rexNode.asInstanceOf[RexCall] + require (call.isA(SqlKind.EQUALS)) + val left = call.operands.head.asInstanceOf[RexInputRef] + val right = call.operands.last.asInstanceOf[RexInputRef] + var found = false + /*breakable { + for (refs <- res) { + if (refs.contains(left) || refs.contains(right)) { + refs.add(left) + refs.add(right) + found = true + break; + } + } + }*/ + if (!found) { + val set = new util.HashSet[RexInputRef]() + set.add(left) + set.add(right) + res.add(set) + } + } + } + res + }*/ + /** Make calls to a number of inputRefs, make sure that they both have a relationship. */ - def makeCalls(rexBuilder: RexBuilder, nodes: Seq[RexInputRef]): Seq[RexNode] = { - val calls = new mutable.ArrayBuffer[RexNode]() + def makeCalls(rexBuilder: RexBuilder, nodes: Seq[RexInputRef]): java.util.List[RexNode] = { + val calls = new java.util.ArrayList[RexNode]() if (nodes.length > 1) { val rex = nodes.head - nodes.drop(1).foreach(calls += rexBuilder.makeCall(EQUALS, rex, _)) + nodes.drop(1).foreach(t => calls.add(rexBuilder.makeCall(EQUALS, rex, t))) } calls } diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/join/JoinReorderTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/join/JoinReorderTest.xml index 889fffa4591e5a..5ee3af181362e8 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/join/JoinReorderTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/join/JoinReorderTest.xml @@ -144,10 +144,10 @@ LogicalProject(a1=[$0], b1=[$1], c1=[$2], a2=[$3], b2=[$4], c2=[$5], a3=[$6], b3 (+(b2, b4), 100), >(*(b1, b2), 10), =(a2, a4))], select=[a2, b2, c2, a1, b1, c1, a5, b5, c5, a4, b4, c4], isBroadcast=[true], build=[right]) + :- HashJoin(joinType=[InnerJoin], where=[AND(>(+(b2, b4), 100), >(*(b1, b2), 10), =(a2, a1))], select=[a2, b2, c2, a1, b1, c1, a5, b5, c5, a4, b4, c4], isBroadcast=[true], build=[right]) : :- LegacyTableSourceScan(table=[[default_catalog, default_database, T2, source: [TestTableSource(a2, b2, c2)]]], fields=[a2, b2, c2]) : +- Exchange(distribution=[broadcast]) - : +- HashJoin(joinType=[InnerJoin], where=[=(a1, a4)], select=[a1, b1, c1, a5, b5, c5, a4, b4, c4], isBroadcast=[true], build=[right]) + : +- HashJoin(joinType=[InnerJoin], where=[=(a1, a5)], select=[a1, b1, c1, a5, b5, c5, a4, b4, c4], isBroadcast=[true], build=[right]) : :- LegacyTableSourceScan(table=[[default_catalog, default_database, T1, source: [TestTableSource(a1, b1, c1)]]], fields=[a1, b1, c1]) : +- Exchange(distribution=[broadcast]) : +- HashJoin(joinType=[InnerJoin], where=[=(a4, a5)], select=[a5, b5, c5, a4, b4, c4], isBroadcast=[true], build=[right]) @@ -2011,10 +2011,10 @@ LogicalProject(a1=[$0], b1=[$1], c1=[$2], a2=[$3], b2=[$4], c2=[$5], a3=[$6], b3