Skip to content

Commit

Permalink
adjustTopnProject
Browse files Browse the repository at this point in the history
  • Loading branch information
englefly committed Dec 25, 2024
1 parent e5add51 commit 68769f8
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType;
import org.apache.doris.nereids.rules.rewrite.AdjustNullable;
import org.apache.doris.nereids.rules.rewrite.AdjustPreAggStatus;
import org.apache.doris.nereids.rules.rewrite.AdjustTopNProject;
import org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction;
import org.apache.doris.nereids.rules.rewrite.BuildAggForUnion;
import org.apache.doris.nereids.rules.rewrite.CTEInline;
Expand Down Expand Up @@ -455,7 +456,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
new CollectCteConsumerOutput()
)
),
topic("Collect used column", custom(RuleType.COLLECT_COLUMNS, QueryColumnCollector::new)
topic("Collect used column", custom(RuleType.COLLECT_COLUMNS, QueryColumnCollector::new),
topic("Adjust topN project",
topDown(new MergeProjects(),
new AdjustTopNProject()))
)
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ public enum RuleType {
COMPRESSED_MATERIALIZE_SORT(RuleTypeClass.REWRITE),
COMPRESSED_MATERIALIZE_REPEAT(RuleTypeClass.REWRITE),
PUSH_DOWN_ENCODE_SLOT(RuleTypeClass.REWRITE),
ADJUST_TOPN_PROJECT(RuleTypeClass.REWRITE),
DECOUPLE_DECODE_ENCODE_SLOT(RuleTypeClass.REWRITE),
SIMPLIFY_ENCODE_DECODE(RuleTypeClass.REWRITE),
NORMALIZE_AGGREGATE(RuleTypeClass.REWRITE),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
*
* try to reduce shuffle cost of topN operator
*
* topn(orderKey=[a])
* --> project(a+1 as x, a+2 as y, a)
* --> any(output(a))
* =>
* project(a+1 as x, a+2 as y, a)
* --> topn(orderKey=[a])
* --> any(output(a))
*
*/
public class AdjustTopNProject extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalTopN(logicalProject())
.then(topN -> adjust(topN)).toRule(RuleType.ADJUST_TOPN_PROJECT);
}

private Plan adjust(LogicalTopN<LogicalProject<Plan>> topN) {
LogicalProject<Plan> project = topN.child();
Set<Slot> projectInputSlots = project.getInputSlots();
Map<SlotReference, SlotReference> keyAsKey = new HashMap<>();
for (NamedExpression proj : project.getProjects()) {
if (proj instanceof Alias && ((Alias) proj).child(0) instanceof SlotReference) {
keyAsKey.put((SlotReference) ((Alias) proj).toSlot(), (SlotReference) ((Alias) proj).child());
}
}
boolean match = true;
List<OrderKey> newOrderKeys = new ArrayList<>();
for (OrderKey orderKey : topN.getOrderKeys()) {
Expression orderExpr = orderKey.getExpr();
if (orderExpr instanceof SlotReference) {
if (projectInputSlots.contains(orderExpr)) {
newOrderKeys.add(orderKey);
} else if (keyAsKey.containsKey(orderExpr)) {
newOrderKeys.add(orderKey.withExpression(keyAsKey.get(orderExpr)));
} else {
match = false;
break;
}
} else {
match = false;
break;
}
}
if (match) {
if (project.getProjects().size() >= project.getInputSlots().size()) {
LogicalTopN newTopN = topN.withChildren(project.children()).withOrderKeys(newOrderKeys);
project = (LogicalProject<Plan>) project.withChildren(newTopN);
return project;
}
}
return topN;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import com.google.common.collect.ImmutableList;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DecodeAsVarchar;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.junit.jupiter.api.Test;

import java.util.List;

public class AdjustTopNProjectTest implements MemoPatternMatchSupported {
LogicalOlapScan score = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.score);
@Test
void testTopNProject() {
Slot sid = score.getOutput().get(0);
Alias decodeSid = new Alias(new DecodeAsVarchar(sid));
Alias aliasSid = new Alias(sid);
LogicalProject<LogicalOlapScan> bottomProject = new LogicalProject<>(
ImmutableList.of(decodeSid, aliasSid), score);
List<OrderKey> orderKeys = ImmutableList.of(new OrderKey(aliasSid.toSlot(), true, true));
LogicalTopN topN = new LogicalTopN(orderKeys, 1, 1, bottomProject);
LogicalProject topProject = new LogicalProject(ImmutableList.of(decodeSid.toSlot()), topN);
PlanChecker.from(MemoTestUtils.createConnectContext(), topProject)
.applyTopDown(ImmutableList.of(new MergeProjects().build(), new AdjustTopNProject().build()))
.matches(
logicalProject(
logicalTopN(logicalOlapScan())));
}
}

0 comments on commit 68769f8

Please sign in to comment.