From 8f94508b18920aaf2004d1026c99fe90de1b59eb Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Thu, 14 Mar 2024 18:58:57 +0100 Subject: [PATCH] [FLINK-34675][table] Migrate AggregateReduceGroupingRule to java --- .../logical/AggregateReduceGroupingRule.java | 174 ++++++++++++++++++ .../logical/AggregateReduceGroupingRule.scala | 129 ------------- tools/maven/suppressions.xml | 2 +- 3 files changed, 175 insertions(+), 130 deletions(-) create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRule.java delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRule.scala diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRule.java new file mode 100644 index 00000000000000..2277c265ea34fa --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRule.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.plan.rules.logical; + +import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable; +import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelCollations; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.Aggregate.Group; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.core.RelFactories; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.ImmutableBitSet; +import org.immutables.value.Value; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * Planner rule that reduces useless grouping columns. + * + *

Find (minimum) unique group for the grouping columns, and use it as new grouping columns.
+ *
+ * <p>Grouping columns that are not part of the unique group are dropped from the group set and
+ * re-emitted through {@code AUXILIARY_GROUP} aggregate calls. A projection on top of the new
+ * aggregate restores the field order and field names of the original aggregate.
+ */
+@Value.Enclosing
+public class AggregateReduceGroupingRule
+        extends RelRule<AggregateReduceGroupingRule.AggregateReduceGroupingRuleConfig> {
+
+    public static final AggregateReduceGroupingRule INSTANCE =
+            AggregateReduceGroupingRule.AggregateReduceGroupingRuleConfig.DEFAULT.toRule();
+
+    protected AggregateReduceGroupingRule(AggregateReduceGroupingRuleConfig config) {
+        super(config);
+    }
+
+    /**
+     * Accepts aggregates that have more than one grouping column and whose group type is {@link
+     * Group#SIMPLE}.
+     */
+    @Override
+    public boolean matches(RelOptRuleCall call) {
+        Aggregate agg = call.rel(0);
+        return agg.getGroupCount() > 1 && agg.getGroupType() == Group.SIMPLE;
+    }
+
+    /**
+     * Replaces the matched aggregate by one that groups only on the unique group reported by
+     * {@link FlinkRelMetadataQuery#getUniqueGroups}, followed by a projection that restores the
+     * original field order. Does nothing when no grouping column can be dropped.
+     */
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        Aggregate agg = call.rel(0);
+        RelDataType aggRowType = agg.getRowType();
+        RelNode input = agg.getInput();
+        RelDataType inputRowType = input.getRowType();
+        ImmutableBitSet originalGrouping = agg.getGroupSet();
+        FlinkRelMetadataQuery fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery());
+        ImmutableBitSet newGrouping = fmq.getUniqueGroups(input, originalGrouping);
+        ImmutableBitSet uselessGrouping = originalGrouping.except(newGrouping);
+        if (uselessGrouping.isEmpty()) {
+            return;
+        }
+
+        // new agg: new grouping + aggCalls for dropped grouping + original aggCalls
+        Map<Integer, Integer> indexOldToNewMap = new HashMap<>();
+        int idxOfNewGrouping = 0;
+        int idxOfAggCallsForDroppedGrouping = newGrouping.cardinality();
+        int index = 0;
+        for (int column : originalGrouping) {
+            if (newGrouping.get(column)) {
+                // kept grouping columns stay at the front, in their original relative order
+                indexOldToNewMap.put(index, idxOfNewGrouping++);
+            } else {
+                // dropped grouping columns move right behind the kept ones
+                indexOldToNewMap.put(index, idxOfAggCallsForDroppedGrouping++);
+            }
+            index++;
+        }
+
+        // Checked explicitly instead of via `assert` so it also holds when the JVM runs
+        // without -ea.
+        if (indexOldToNewMap.size() != originalGrouping.cardinality()) {
+            throw new IllegalArgumentException(
+                    "Expected "
+                            + originalGrouping.cardinality()
+                            + " remapped grouping columns but got "
+                            + indexOldToNewMap.size());
+        }
+
+        // the indices of aggCalls (or NamedProperties for WindowAggregate) do not change
+        for (int i = originalGrouping.cardinality(); i < aggRowType.getFieldCount(); i++) {
+            indexOldToNewMap.put(i, i);
+        }
+
+        ImmutableList.Builder<AggregateCall> newAggCalls = ImmutableList.builder();
+        for (int column : uselessGrouping) {
+            newAggCalls.add(createAuxiliaryGroupCall(inputRowType, column));
+        }
+        newAggCalls.addAll(agg.getAggCallList());
+
+        Aggregate newAgg =
+                agg.copy(
+                        agg.getTraitSet(),
+                        input,
+                        newGrouping,
+                        ImmutableList.of(newGrouping),
+                        newAggCalls.build());
+        RelBuilder builder = call.builder();
+        builder.push(newAgg);
+        // one projected field per field of the original aggregate, looked up in the new aggregate
+        List<RexNode> projects =
+                IntStream.range(0, aggRowType.getFieldCount())
+                        .mapToObj(
+                                i -> {
+                                    Integer refIndex = indexOldToNewMap.get(i);
+                                    if (refIndex == null) {
+                                        throw new IllegalArgumentException("Illegal index: " + i);
+                                    }
+                                    return builder.field(refIndex);
+                                })
+                        .collect(Collectors.toList());
+        builder.project(projects, aggRowType.getFieldNames());
+        call.transformTo(builder.build());
+    }
+
+    /**
+     * Creates the {@code AUXILIARY_GROUP} aggregate call that carries a dropped grouping column,
+     * reusing the type and name of the input field at index {@code column}.
+     */
+    private static AggregateCall createAuxiliaryGroupCall(RelDataType inputRowType, int column) {
+        RelDataType fieldType = inputRowType.getFieldList().get(column).getType();
+        String fieldName = inputRowType.getFieldNames().get(column);
+        return AggregateCall.create(
+                FlinkSqlOperatorTable.AUXILIARY_GROUP,
+                false,
+                false,
+                false,
+                ImmutableList.of(column),
+                -1,
+                null,
+                RelCollations.EMPTY,
+                fieldType,
+                fieldName);
+    }
+
+    /** Rule configuration. */
+    @Value.Immutable(singleton = false)
+    public interface AggregateReduceGroupingRuleConfig extends RelRule.Config {
+        AggregateReduceGroupingRule.AggregateReduceGroupingRuleConfig DEFAULT =
+                ImmutableAggregateReduceGroupingRule.AggregateReduceGroupingRuleConfig.builder()
+                        .relBuilderFactory(RelFactories.LOGICAL_BUILDER)
+                        .operandSupplier(b0 -> b0.operand(Aggregate.class).anyInputs())
+                        .description("AggregateReduceGroupingRule")
+                        .build();
+
+        @Override
+        default AggregateReduceGroupingRule toRule() {
+            return new AggregateReduceGroupingRule(this);
+        }
+    }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRule.scala
deleted file mode 100644
index 6ee02c5e1da2f2..00000000000000
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRule.scala
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.flink.table.planner.plan.rules.logical - -import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable -import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery - -import com.google.common.collect.ImmutableList -import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} -import org.apache.calcite.plan.RelOptRule.{any, operand} -import org.apache.calcite.rel.RelCollations -import org.apache.calcite.rel.core.{Aggregate, AggregateCall, RelFactories} -import org.apache.calcite.rel.core.Aggregate.Group -import org.apache.calcite.tools.RelBuilderFactory - -import scala.collection.JavaConversions._ -import scala.collection.mutable - -/** - * Planner rule that reduces unless grouping columns. - * - * Find (minimum) unique group for the grouping columns, and use it as new grouping columns. 
- */ -class AggregateReduceGroupingRule(relBuilderFactory: RelBuilderFactory) - extends RelOptRule( - operand(classOf[Aggregate], any), - relBuilderFactory, - "AggregateReduceGroupingRule") { - - override def matches(call: RelOptRuleCall): Boolean = { - val agg: Aggregate = call.rel(0) - agg.getGroupCount > 1 && agg.getGroupType == Group.SIMPLE - } - - override def onMatch(call: RelOptRuleCall): Unit = { - val agg: Aggregate = call.rel(0) - val aggRowType = agg.getRowType - val input = agg.getInput - val inputRowType = input.getRowType - val originalGrouping = agg.getGroupSet - val fmq = FlinkRelMetadataQuery.reuseOrCreate(call.getMetadataQuery) - val newGrouping = fmq.getUniqueGroups(input, originalGrouping) - val uselessGrouping = originalGrouping.except(newGrouping) - if (uselessGrouping.isEmpty) { - return - } - - // new agg: new grouping + aggCalls for dropped grouping + original aggCalls - val indexOldToNewMap = new mutable.HashMap[Int, Int]() - val newGroupingList = newGrouping.toList - var idxOfNewGrouping = 0 - var idxOfAggCallsForDroppedGrouping = newGroupingList.size() - originalGrouping.zipWithIndex.foreach { - case (column, oldIdx) => - val newIdx = if (newGroupingList.contains(column)) { - val p = idxOfNewGrouping - idxOfNewGrouping += 1 - p - } else { - val p = idxOfAggCallsForDroppedGrouping - idxOfAggCallsForDroppedGrouping += 1 - p - } - indexOldToNewMap += (oldIdx -> newIdx) - } - require(indexOldToNewMap.size == originalGrouping.cardinality()) - - // the indices of aggCalls (or NamedProperties for WindowAggregate) do not change - (originalGrouping.cardinality() until aggRowType.getFieldCount).foreach { - index => indexOldToNewMap += (index -> index) - } - - val aggCallsForDroppedGrouping = uselessGrouping.map { - column => - val fieldType = inputRowType.getFieldList.get(column).getType - val fieldName = inputRowType.getFieldNames.get(column) - AggregateCall.create( - FlinkSqlOperatorTable.AUXILIARY_GROUP, - false, - false, - false, - 
ImmutableList.of(column), - -1, - null, - RelCollations.EMPTY, - fieldType, - fieldName) - }.toList - - val newAggCalls = aggCallsForDroppedGrouping ++ agg.getAggCallList - val newAgg = agg.copy( - agg.getTraitSet, - input, - newGrouping, - ImmutableList.of(newGrouping), - newAggCalls - ) - val builder = call.builder() - builder.push(newAgg) - val projects = (0 until aggRowType.getFieldCount).map { - index => - val refIndex = indexOldToNewMap.getOrElse( - index, - throw new IllegalArgumentException(s"Illegal index: $index")) - builder.field(refIndex) - } - builder.project(projects, aggRowType.getFieldNames) - call.transformTo(builder.build()) - } -} - -object AggregateReduceGroupingRule { - val INSTANCE = new AggregateReduceGroupingRule(RelFactories.LOGICAL_BUILDER) -} diff --git a/tools/maven/suppressions.xml b/tools/maven/suppressions.xml index e4203d8361a027..890d3170c9f408 100644 --- a/tools/maven/suppressions.xml +++ b/tools/maven/suppressions.xml @@ -54,7 +54,7 @@ under the License.