Skip to content

Commit

Permalink
[DimExpr] Add substitute DimExpr util (PaddlePaddle#60493)
Browse files Browse the repository at this point in the history
* add SubstituteDimExpr

* Fix compile error

* Code format

* Polish DimExprUtilTest

* Change namesapce

* Fix unittest

* Polish DimExprUtilTest
  • Loading branch information
jiahy0825 authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent fb78fa5 commit 4bdc02f
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 1 deletion.
4 changes: 3 additions & 1 deletion paddle/cinn/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ gather_srcs(
nvgpu_dev_info.cc
integer_set.cc
dim_expr_simplify.cc
dim_expr_converter.cc)
dim_expr_converter.cc
dim_expr_util.cc)

cinn_cc_test(test_equation_graph_topo_walker SRCS
equation_graph_topo_walker_test.cc DEPS gtest glog)
Expand All @@ -48,6 +49,7 @@ if(WITH_CUDA)
gtest glog)
endif()
if(NOT CINN_ONLY)
cinn_cc_test(dim_expr_util_test SRCS dim_expr_util_test.cc DEPS cinncore)
cinn_cc_test(dim_expr_simplify_test SRCS dim_expr_simplify_test.cc DEPS
cinncore)
cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS
Expand Down
111 changes: 111 additions & 0 deletions paddle/cinn/common/dim_expr_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/common/dim_expr_util.h"

namespace cinn::common {
using namespace symbol; // NOLINT

namespace {

class SubstituteDimExprHelper final {
public:
explicit SubstituteDimExprHelper(
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
pattern_to_replacement)
: pattern_to_replacement_(pattern_to_replacement) {}

std::optional<DimExpr> Substitute(const DimExpr& dim_expr) {
auto iter = pattern_to_replacement_.find(dim_expr);
if (iter != pattern_to_replacement_.end()) return iter->second;
return std::visit([&](const auto& impl) { return SubstituteImpl(impl); },
dim_expr.variant());
}

private:
std::optional<DimExpr> SubstituteImpl(const std::int64_t& value) {
// `Substitute` has handled the case that `value` is matched.
return std::nullopt;
}
std::optional<DimExpr> SubstituteImpl(const std::string& value) {
// `Substitute` has handled the case that `value` is matched.
return std::nullopt;
}

std::optional<DimExpr> SubstituteImpl(const Negative<DimExpr>& dim_expr) {
return SubstituteUnary(dim_expr);
}
std::optional<DimExpr> SubstituteImpl(const Reciprocal<DimExpr>& dim_expr) {
return SubstituteUnary(dim_expr);
}

template <typename T>
std::optional<DimExpr> SubstituteUnary(const T& dim_expr) {
const auto& operand = dim_expr->data;
const auto& substituted_operand = Substitute(operand);
if (!substituted_operand.has_value()) return std::nullopt;
return T{substituted_operand.value()};
}

std::optional<DimExpr> SubstituteImpl(const Add<DimExpr>& dim_expr) {
return SubstituteVariadic(dim_expr);
}

std::optional<DimExpr> SubstituteImpl(const Mul<DimExpr>& dim_expr) {
return SubstituteVariadic(dim_expr);
}

std::optional<DimExpr> SubstituteImpl(const Max<DimExpr>& dim_expr) {
return SubstituteVariadic(dim_expr);
}

std::optional<DimExpr> SubstituteImpl(const Min<DimExpr>& dim_expr) {
return SubstituteVariadic(dim_expr);
}

std::optional<DimExpr> SubstituteImpl(const Broadcast<DimExpr>& dim_expr) {
return SubstituteVariadic(dim_expr);
}

template <typename T>
std::optional<DimExpr> SubstituteVariadic(const T& dim_expr) {
const auto& operands = *(dim_expr.operands);
List<DimExpr> substituted_operands{};
size_t replace_cnt = 0;
for (const auto& operand : operands) {
const auto& substituted_operand = Substitute(operand);
replace_cnt += substituted_operand.has_value();
substituted_operands->push_back(substituted_operand.has_value()
? substituted_operand.value()
: operand);
}
if (replace_cnt == 0) return std::nullopt;
return T{substituted_operands};
}

std::unordered_map<symbol::DimExpr, symbol::DimExpr> pattern_to_replacement_;
};

} // namespace

symbol::DimExpr SubstituteDimExpr(
const symbol::DimExpr& dim_expr,
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
pattern_to_replacement) {
const auto& opt_replaced =
SubstituteDimExprHelper(pattern_to_replacement).Substitute(dim_expr);
return opt_replaced.has_value() ? opt_replaced.value() : dim_expr;
}

} // namespace cinn::common
29 changes: 29 additions & 0 deletions paddle/cinn/common/dim_expr_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <optional>
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"

namespace cinn::common {

symbol::DimExpr SubstituteDimExpr(
const symbol::DimExpr& dim_expr,
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>&
pattern_to_replacement);

}
43 changes: 43 additions & 0 deletions paddle/cinn/common/dim_expr_util_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/common/dim_expr_util.h"

#include "gtest/gtest.h"

namespace cinn::common {
using namespace symbol; // NOLINT

namespace {
DimExpr CreateExampleDimExpr() {
DimExpr sym0 = DimExpr("S0");
DimExpr sym1 = DimExpr("S1");
DimExpr constant = DimExpr(2);
return (sym0 - sym1) * constant / sym0;
}
} // namespace

TEST(DimExprUtil, Substitute) {
DimExpr dim_expr = CreateExampleDimExpr();
std::unordered_map<symbol::DimExpr, symbol::DimExpr> naive_to_full_name{
{DimExpr("S0"), DimExpr("symbol0")}, {DimExpr("S1"), DimExpr("symbol1")}};
std::unordered_map<symbol::DimExpr, symbol::DimExpr> full_name_to_naive{
{DimExpr("symbol0"), DimExpr("S0")}, {DimExpr("symbol1"), DimExpr("S1")}};

const auto& mid_expr = SubstituteDimExpr(dim_expr, naive_to_full_name);
const auto& ret_expr = SubstituteDimExpr(mid_expr, full_name_to_naive);
ASSERT_EQ(ret_expr, dim_expr);
}

} // namespace cinn::common

0 comments on commit 4bdc02f

Please sign in to comment.