/*
 * Post-image of patch e60388f115e78dde9e3ce13c697314a7c557f7f5 ("[PATCH] fix", Sergey Nuyanzin,
 * Thu, 14 Mar 2024 18:54:24 +0100), which finishes porting ConvertToNotInOrInRule from Scala to
 * Java. The patch hunk starts at file line 18, so the license header that precedes the package
 * declaration is outside the visible range and is not reproduced here.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.types.logical.LogicalTypeRoot;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlBinaryOperator;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;
import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

import static org.apache.calcite.sql.fun.SqlStdOperatorTable.AND;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.EQUALS;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IN;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.NOT;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.NOT_EQUALS;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.NOT_IN;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.OR;

/**
 * Rule for converting a cascade of predicates to {@code IN} or {@code NOT IN}.
 *
 * <p>For example:
 *
 * <ol>
 *   <li>convert predicate {@code (x = 1 OR x = 2 OR x = 3 OR x = 4) AND y = 5} to predicate
 *       {@code x IN (1, 2, 3, 4) AND y = 5};
 *   <li>convert predicate {@code (x <> 1 AND x <> 2 AND x <> 3 AND x <> 4) AND y = 5} to
 *       predicate {@code x NOT IN (1, 2, 3, 4) AND y = 5}.
 * </ol>
 */
@Value.Enclosing
public class ConvertToNotInOrInRule
        extends RelRule<ConvertToNotInOrInRule.ConvertToNotInOrInRuleConfig> {

    public static final ConvertToNotInOrInRule INSTANCE =
            ConvertToNotInOrInRule.ConvertToNotInOrInRuleConfig.DEFAULT.toRule();

    // these threshold values are set by OptimizableHashSet benchmark test on different type.
    // threshold for non-float and non-double type
    private static final int THRESHOLD = 4;
    // threshold for float and double type
    private static final int FRACTIONAL_THRESHOLD = 20;

    protected ConvertToNotInOrInRule(ConvertToNotInOrInRuleConfig config) {
        super(config);
    }

    /**
     * Rewrites the matched {@link Filter}'s condition and registers the rewritten filter.
     *
     * <p>The IN conversion runs first; the NOT IN conversion then runs on the result of the IN
     * conversion (or on the original condition when nothing was converted). If neither pass
     * converted anything, no transformation is registered.
     *
     * @param call the rule call whose operand 0 is the filter to rewrite
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        Filter filter = call.rel(0);
        RexNode condition = filter.getCondition();

        // convert equal expression connected by OR to IN
        Optional<RexNode> inExpr = convertToNotInOrIn(call.builder(), condition, IN);
        // convert not-equal expression connected by AND to NOT_IN
        Optional<RexNode> notInExpr =
                convertToNotInOrIn(call.builder(), inExpr.orElse(condition), NOT_IN);

        // The NOT_IN result already contains the IN rewrite because it was computed from it;
        // fall back to the IN result only when the NOT_IN pass converted nothing.
        Optional<RexNode> converted = notInExpr.isPresent() ? notInExpr : inExpr;
        if (converted.isPresent()) {
            Filter newFilter =
                    filter.copy(filter.getTraitSet(), filter.getInput(), converted.get());
            call.transformTo(newFilter);
        }
    }

    /**
     * Returns a condition decomposed by {@code AND} or {@code OR}.
     *
     * @param rex the condition to decompose
     * @param operator the connective to split on; only {@code AND} and {@code OR} are supported
     * @return the conjunctions (for {@code AND}) or disjunctions (for {@code OR}) of {@code rex}
     * @throws AssertionError if {@code operator} is neither {@code AND} nor {@code OR}
     */
    private static List<RexNode> decomposedBy(RexNode rex, SqlBinaryOperator operator) {
        final SqlKind kind = operator.getKind();
        switch (kind) {
            case AND:
                return RelOptUtil.conjunctions(rex);
            case OR:
                return RelOptUtil.disjunctions(rex);
            default:
                throw new AssertionError("Unsupported operator " + kind);
        }
    }

    /**
     * Joins the given nodes with {@code AND} or {@code OR} using the builder.
     *
     * @param builder the {@link RelBuilder} used to build the combined {@link RexNode}
     * @param operator the connective to join with; only {@code AND} and {@code OR} are supported
     * @param nodes the nodes to join
     * @return the combined expression
     * @throws AssertionError if {@code operator} is neither {@code AND} nor {@code OR}
     */
    private static RexNode connect(
            RelBuilder builder, SqlBinaryOperator operator, Iterable<? extends RexNode> nodes) {
        switch (operator.getKind()) {
            case AND:
                return builder.and(nodes);
            case OR:
                return builder.or(nodes);
            default:
                throw new AssertionError("Unsupported operator " + operator);
        }
    }

    /**
     * Convert a cascade predicates to {@code IN} or {@code NOT IN}.
     *
     * @param builder The {@link RelBuilder} to build the {@link RexNode}.
     * @param rex The predicates to be converted.
     * @param toOperator The target operator, either {@code IN} or {@code NOT_IN}.
     * @return The converted predicates, or empty if nothing was converted.
     * @throws AssertionError if {@code toOperator} is neither {@code IN} nor {@code NOT_IN}
     */
    private Optional<RexNode> convertToNotInOrIn(
            RelBuilder builder, RexNode rex, SqlBinaryOperator toOperator) {

        // For example, when convert to IN, fromOperator is EQUALS.
        // We convert a cascade of EQUALS to IN.
        // A connect operator is used to connect the fromOperator.
        // A composed operator may contain sub IN or NOT_IN.
        final SqlBinaryOperator fromOperator;
        final SqlBinaryOperator connectOperator;
        final SqlBinaryOperator composedOperator;
        switch (toOperator.getKind()) {
            case IN:
                fromOperator = EQUALS;
                connectOperator = OR;
                composedOperator = AND;
                break;
            case NOT_IN:
                fromOperator = NOT_EQUALS;
                connectOperator = AND;
                composedOperator = OR;
                break;
            default:
                throw new AssertionError("Unsupported operator " + toOperator);
        }

        // Comparisons against a literal, grouped by the string form of the non-literal operand.
        final Map<String, List<RexCall>> combineMap = new HashMap<>();
        // Everything that is passed through or has already been rebuilt.
        final List<RexNode> rexBuffer = new ArrayList<>();
        boolean beenConverted = false;

        // traverse decomposed predicates
        for (RexNode node : decomposedBy(rex, connectOperator)) {
            if (!(node instanceof RexCall)) {
                rexBuffer.add(node);
                continue;
            }
            RexCall call = (RexCall) node;

            if (call.getOperator() == fromOperator) {
                // put same predicates into combine map
                RexNode left = call.getOperands().get(0);
                RexNode right = call.getOperands().get(1);
                if (right instanceof RexLiteral) {
                    combineMap.computeIfAbsent(left.toString(), k -> new ArrayList<>()).add(call);
                } else if (left instanceof RexLiteral) {
                    // Swap the operands so the literal is always the last operand of the
                    // stored call; the grouping step below relies on that layout.
                    combineMap
                            .computeIfAbsent(right.toString(), k -> new ArrayList<>())
                            .add(call.clone(call.getType(), Arrays.asList(right, left)));
                } else {
                    rexBuffer.add(call);
                }
            } else if (call.getOperator() == composedOperator) {
                // process sub predicates
                List<RexNode> newRex = new ArrayList<>();
                for (RexNode sub : decomposedBy(call, composedOperator)) {
                    Optional<RexNode> convertedSub = convertToNotInOrIn(builder, sub, toOperator);
                    if (convertedSub.isPresent()) {
                        beenConverted = true;
                    }
                    newRex.add(convertedSub.orElse(sub));
                }
                rexBuffer.add(connect(builder, composedOperator, newRex));
            } else {
                rexBuffer.add(call);
            }
        }

        for (List<RexCall> group : combineMap.values()) {
            if (needConvert(group)) {
                RexNode inputRef = group.get(0).getOperands().get(0);
                List<RexNode> values =
                        group.stream()
                                .map(
                                        c -> {
                                            List<RexNode> operands = c.getOperands();
                                            return operands.get(operands.size() - 1);
                                        })
                                .collect(Collectors.toList());
                RexNode in = builder.getRexBuilder().makeIn(inputRef, values);
                // NOT IN is expressed as NOT(IN(...)).
                rexBuffer.add(
                        toOperator.getKind() == SqlKind.IN
                                ? in
                                : builder.getRexBuilder().makeCall(NOT, in));
                beenConverted = true;
            } else {
                // Below the threshold: keep the comparisons, re-joined with the connective.
                rexBuffer.add(connect(builder, connectOperator, group));
            }
        }

        if (!beenConverted) {
            return Optional.empty();
        }
        // return result if has been converted
        return Optional.of(connect(builder, connectOperator, rexBuffer));
    }

    /**
     * Decides whether a group of comparisons on the same operand is large enough to convert.
     *
     * @param rexNodes non-empty list of comparisons whose first operand is the shared operand
     * @return {@code true} if the group size reaches {@link #FRACTIONAL_THRESHOLD} for FLOAT and
     *     DOUBLE operands, or {@link #THRESHOLD} for every other type
     */
    private static boolean needConvert(List<RexCall> rexNodes) {
        RexNode inputRef = rexNodes.get(0).getOperands().get(0);
        LogicalTypeRoot logicalTypeRoot =
                FlinkTypeFactory.toLogicalType(inputRef.getType()).getTypeRoot();
        switch (logicalTypeRoot) {
            case FLOAT:
            case DOUBLE:
                return rexNodes.size() >= FRACTIONAL_THRESHOLD;
            default:
                return rexNodes.size() >= THRESHOLD;
        }
    }

    /** Rule configuration. */
    @Value.Immutable(singleton = false)
    public interface ConvertToNotInOrInRuleConfig extends RelRule.Config {
        ConvertToNotInOrInRule.ConvertToNotInOrInRuleConfig DEFAULT =
                ImmutableConvertToNotInOrInRule.ConvertToNotInOrInRuleConfig.builder()
                        .build()
                        .withOperandSupplier(b0 -> b0.operand(Filter.class).anyInputs())
                        .withDescription("ConvertToNotInOrInRule");

        @Override
        default ConvertToNotInOrInRule toRule() {
            return new ConvertToNotInOrInRule(this);
        }
    }
}