diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 80ecad93997db8..e5e237bf7fe340 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -638,7 +638,7 @@ Operation *BuildOpFrom( std::back_inserter(to_create_argument.inputs), [&value_map](const pir::OpOperand &operand) { // Operand -> OpResult - return OpResult::dyn_cast_from(value_map[operand.source()]); + return value_map[operand.source()]; }); auto *cloned_op = Operation::Create(std::move(to_create_argument)); @@ -834,11 +834,8 @@ SplitedResult ForwardBackwardSplit( pir::StrAttribute::get( ctx, std::string("output_") + std::to_string(counter))}, }; - pir::Operation *operation = - pir::Operation::Create({OpResult::dyn_cast_from(forward_value_map[v])}, - attribute_map, - {}, - op_info); + pir::Operation *operation = pir::Operation::Create( + {forward_value_map[v]}, attribute_map, {}, op_info); forward_program->block()->push_back(operation); counter += 1; }; @@ -857,10 +854,7 @@ SplitedResult ForwardBackwardSplit( ctx, std::string("output_") + std::to_string(counter))}, }; pir::Operation *operation = pir::Operation::Create( - {OpResult::dyn_cast_from(backward_value_map.at(v))}, - attribute_map, - {}, - op_info); + {backward_value_map.at(v)}, attribute_map, {}, op_info); backward_program->block()->push_back(operation); counter += 1; }; diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h index 538b48bed6a9c9..314dbe3f3706e8 100644 --- a/paddle/pir/core/op_base.h +++ b/paddle/pir/core/op_base.h @@ -103,6 +103,18 @@ class OpInterfaceBase : public OpBase { } }; +template +struct VerifyTraitOrInterface { + static void call(Operation *) {} +}; + +template +struct VerifyTraitOrInterface()))> { + static void call(Operation *op) { T::Verify(op); } +}; + template class Op : public OpBase { public: @@ -139,12 +151,13 @@ class Op : public OpBase { class EmptyOp : public Op {}; return sizeof(ConcreteOp) == sizeof(EmptyOp); } - // Implementation of `VerifyInvariantsFn` OperationName hook. static void VerifyInvariants(Operation *op) { static_assert(HasNoDataMembers(), "Op class shouldn't define new data members"); op->dyn_cast().Verify(); + (void)std::initializer_list{ + 0, (VerifyTraitOrInterface::call(op), 0)...}; } }; diff --git a/paddle/pir/core/op_result.h b/paddle/pir/core/op_result.h index 781ed931481038..8860473fe33395 100644 --- a/paddle/pir/core/op_result.h +++ b/paddle/pir/core/op_result.h @@ -32,7 +32,6 @@ class IR_API OpResult : public Value { Operation *owner() const; uint32_t index() const; bool operator==(const OpResult &other) const; - static OpResult dyn_cast_from(Value value); private: friend Operation; @@ -40,6 +39,7 @@ class IR_API OpResult : public Value { // Access classof annd dyn_cast_from. friend Value; static bool classof(Value value); + static OpResult dyn_cast_from(Value value); }; } // namespace pir diff --git a/paddle/pir/core/operation_utils.h b/paddle/pir/core/operation_utils.h index c868731ca4753f..36dcca7bd0d531 100644 --- a/paddle/pir/core/operation_utils.h +++ b/paddle/pir/core/operation_utils.h @@ -84,6 +84,12 @@ struct OperationArgument { /// Add an array of named attributes. template void AddAttributes(InputIt first, InputIt last); + + template + void AddAttributes(const AttrContainer& attr_container) { + AddAttributes(std::begin(attr_container), std::end(attr_container)); + } + /// Get the context held by this operation state. IrContext* getContext() const { return info.ir_context(); } diff --git a/test/cpp/pir/core/CMakeLists.txt b/test/cpp/pir/core/CMakeLists.txt index a131f84fe313cb..355738d3baef53 100644 --- a/test/cpp/pir/core/CMakeLists.txt +++ b/test/cpp/pir/core/CMakeLists.txt @@ -8,7 +8,14 @@ cc_test_old( pd_op_dialect) cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS pir gtest) cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS pir gtest) -cc_test_old(ir_op_test SRCS ir_op_test.cc DEPS pir gtest) +cc_test_old( + ir_op_test + SRCS + ir_op_test.cc + DEPS + pir + gtest + test_dialect) cc_test_old(ir_region_test SRCS ir_region_test.cc DEPS pir gtest) cc_test_old(ir_builder_test SRCS ir_builder_test.cc DEPS pir gtest) cc_test_old( diff --git a/test/cpp/pir/core/ir_op_test.cc b/test/cpp/pir/core/ir_op_test.cc index bfc03e66944e9c..0a5317c36cc4fa 100644 --- a/test/cpp/pir/core/ir_op_test.cc +++ b/test/cpp/pir/core/ir_op_test.cc @@ -27,49 +27,8 @@ #include "paddle/pir/core/op_base.h" #include "paddle/pir/core/program.h" #include "paddle/pir/core/region.h" - -/// \brief Define built-in Trait, derived from OpTraitBase. -class ReadOnlyTrait : public pir::OpTraitBase { - public: - explicit ReadOnlyTrait(pir::Operation *op) - : pir::OpTraitBase(op) {} -}; -IR_DECLARE_EXPLICIT_TYPE_ID(ReadOnlyTrait) -IR_DEFINE_EXPLICIT_TYPE_ID(ReadOnlyTrait) - -/// \brief Define built-in Interface, derived from OpInterfaceBase. Concepts and -/// Models need to be defined within the class. Concept defines abstract -/// interface functions, and Model is a template class that defines the specific -/// implementation of interface functions based on template parameters. -class InferShapeInterface : public pir::OpInterfaceBase { - public: - struct Concept { - explicit Concept(void (*infer_shape)(pir::Operation *)) - : infer_shape_(infer_shape) {} - void (*infer_shape_)(pir::Operation *); - }; - - template - struct Model : public Concept { - static void InferShape(pir::Operation *op) { - ConcreteOp concret_op = ConcreteOp(op); - if (concret_op == nullptr) throw("concret_op is nullptr"); - concret_op.InferShape(); - } - - Model() : Concept(InferShape) {} - }; - - InferShapeInterface(pir::Operation *op, Concept *impl) - : pir::OpInterfaceBase(op), impl_(impl) {} - - void InferShape() { impl_->infer_shape_(operation()); } - - private: - Concept *impl_; -}; -IR_DECLARE_EXPLICIT_TYPE_ID(InferShapeInterface) -IR_DEFINE_EXPLICIT_TYPE_ID(InferShapeInterface) +#include "test/cpp/pir/tools/test_dialect.h" +#include "test/cpp/pir/tools/test_op.h" pir::AttributeMap CreateAttributeMap( const std::vector &attribute_names, @@ -84,139 +43,15 @@ pir::AttributeMap CreateAttributeMap( return attr_map; } -// Define op1. -class Operation1 : public pir::Op { - public: - using Op::Op; - static const char *name() { return "test.operation1"; } - static constexpr uint32_t attributes_num = 2; - static const char *attributes_name[attributes_num]; // NOLINT - void Verify() { - auto &attributes = this->attributes(); - if (attributes.count("op1_attr1") == 0 || - !attributes.at("op1_attr1").isa()) { - throw("Type of attribute: parameter_name is not right."); - } - if (attributes.count("op1_attr2") == 0 || - !attributes.at("op1_attr2").isa()) { - throw("Type of attribute: parameter_name is not right."); - } - } - static void Build(const pir::Builder &builder, - pir::OperationArgument &argument) { // NOLINT - std::vector output_types = { - pir::Float32Type::get(builder.ir_context())}; - std::unordered_map attributes = - CreateAttributeMap({"op1_attr1", "op1_attr2"}, - {"op1_attr1", "op1_attr2"}); - argument.AddOutputs(output_types.begin(), output_types.end()); - argument.AddAttributes(attributes.begin(), attributes.end()); - } -}; -const char *Operation1::attributes_name[attributes_num] = { // NOLINT - "op1_attr1", - "op1_attr2"}; - -IR_DECLARE_EXPLICIT_TYPE_ID(Operation1) -IR_DEFINE_EXPLICIT_TYPE_ID(Operation1) - -// Define op2. -class Operation2 - : public pir::Op { - public: - using Op::Op; - static const char *name() { return "test.operation2"; } - static constexpr uint32_t attributes_num = 2; - static const char *attributes_name[attributes_num]; // NOLINT - void Verify() { - auto &attributes = this->attributes(); - if (attributes.count("op2_attr1") == 0 || - (!attributes.at("op2_attr1").isa())) { - throw("Type of attribute: parameter_name is not right."); - } - if (attributes.count("op2_attr2") == 0 || - (!attributes.at("op2_attr2").isa())) { - throw("Type of attribute: parameter_name is not right."); - } - } - static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } -}; -const char *Operation2::attributes_name[attributes_num] = { // NOLINT - "op2_attr1", - "op2_attr2"}; -IR_DECLARE_EXPLICIT_TYPE_ID(Operation2) -IR_DEFINE_EXPLICIT_TYPE_ID(Operation2) - -// Define a dialect, op1 and op2 will be registered by this dialect. -class TestDialect : public pir::Dialect { - public: - explicit TestDialect(pir::IrContext *context) - : pir::Dialect(name(), context, pir::TypeId::get()) { - initialize(); - } - static const char *name() { return "test"; } - - void PrintOperation(pir::Operation *op, - pir::IrPrinter &printer) const override { - printer.PrintOpResult(op); - printer.os << " ="; - - printer.os << " \"" << op->name() << "\""; - printer.PrintOpOperands(op); - } - - private: - void initialize() { RegisterOps(); } -}; -IR_DECLARE_EXPLICIT_TYPE_ID(TestDialect) -IR_DEFINE_EXPLICIT_TYPE_ID(TestDialect) - -TEST(op_test, op_test) { - // (1) Register Dialect, Operation1, Operation2 into IrContext. - pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Dialect *test_dialect = ctx->GetOrRegisterDialect(); - EXPECT_EQ(test_dialect != nullptr, true); - - // (2) Get registered operations. - std::string op1_name = Operation1::name(); - pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); - EXPECT_TRUE(op1_info); - std::string op2_name = Operation2::name(); - pir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); - EXPECT_TRUE(op2_info); - EXPECT_EQ(op1_info.HasTrait(), false); - EXPECT_EQ(op1_info.HasInterface(), false); - EXPECT_EQ(op2_info.HasTrait(), true); - EXPECT_EQ(op2_info.HasInterface(), true); - - // (3) Test uses for op. - std::vector op_inputs = {}; - std::vector op_output_types = {pir::Float32Type::get(ctx)}; - pir::Operation *op2 = - pir::Operation::Create(op_inputs, - CreateAttributeMap({"op2_attr1", "op2_attr2"}, - {"op2_attr1", "op2_attr2"}), - op_output_types, - op2_info); - - ReadOnlyTrait trait = op2->dyn_cast(); - EXPECT_EQ(trait.operation(), op2); - InferShapeInterface interface = op2->dyn_cast(); - interface.InferShape(); - Operation2 Op2 = op2->dyn_cast(); - EXPECT_EQ(Op2.operation(), op2); - op2->Destroy(); -} - TEST(op_test, region_test) { // (1) Register Dialect, Operation1, Operation2 into IrContext. pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Dialect *test_dialect = ctx->GetOrRegisterDialect(); + pir::Dialect *test_dialect = ctx->GetOrRegisterDialect(); EXPECT_EQ(test_dialect != nullptr, true); // (2) Get registered operations. - pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(Operation1::name()); - pir::OpInfo op2_info = ctx->GetRegisteredOpInfo(Operation2::name()); + pir::OpInfo op1_info = ctx->GetRegisteredOpInfo(test::Operation1::name()); + pir::OpInfo op2_info = ctx->GetRegisteredOpInfo(test::Operation2::name()); pir::Operation *op1 = pir::Operation::Create({}, @@ -224,16 +59,10 @@ TEST(op_test, region_test) { {"op1_attr1", "op1_attr2"}), {pir::Float32Type::get(ctx)}, op1_info); - pir::Operation *op1_2 = - pir::Operation::Create({}, - CreateAttributeMap({"op1_attr1", "op1_attr2"}, - {"op1_attr1", "op1_attr2"}), - {pir::Float32Type::get(ctx)}, - op1_info); + pir::Operation *op_2 = + pir::Operation::Create({}, {}, {pir::Float32Type::get(ctx)}, op2_info); pir::OperationArgument argument(op2_info); - argument.attributes = CreateAttributeMap({"op2_attr1", "op2_attr2"}, - {"op2_attr1", "op2_attr2"}); argument.output_types = {pir::Float32Type::get(ctx)}; argument.num_regions = 1; @@ -252,7 +81,7 @@ TEST(op_test, region_test) { region.insert(region.begin(), new pir::Block()); pir::Block *block = region.front(); block->push_front(op1); - block->insert(block->begin(), op1_2); + block->insert(block->begin(), op_2); op3->Destroy(); } @@ -279,3 +108,22 @@ TEST(op_test, module_op_death) { program.module_op()->set_attribute("program", pir::PointerAttribute::get(ctx, &program)); } + +TEST(op_test, trait_and_interface) { + pir::IrContext ctx; + ctx.GetOrRegisterDialect(); + pir::Program program(&ctx); + auto block = program.block(); + pir::Builder builder(&ctx, block); + auto op1 = builder.Build(); + auto op2 = builder.Build(); + + EXPECT_EQ(op1->HasTrait(), false); + EXPECT_EQ(op1->HasInterface(), false); + EXPECT_EQ(op2->HasTrait(), true); + EXPECT_EQ(op2->HasInterface(), true); + + pir::OperationArgument argument(&ctx, "test.region"); + argument.num_regions = 2u; + EXPECT_THROW(builder.Build(argument), pir::IrNotMetException); +} diff --git a/test/cpp/pir/tools/CMakeLists.txt b/test/cpp/pir/tools/CMakeLists.txt index 64e5b972436203..5a1f0736988333 100644 --- a/test/cpp/pir/tools/CMakeLists.txt +++ b/test/cpp/pir/tools/CMakeLists.txt @@ -1,4 +1,4 @@ cc_library( test_dialect - SRCS test_dialect.cc test_op.cc + SRCS test_dialect.cc test_op.cc test_trait.cc test_interface.cc DEPS pir) diff --git a/test/cpp/pir/tools/test_dialect.cc b/test/cpp/pir/tools/test_dialect.cc index bf94e8db3dce13..49fb4a6951dd79 100644 --- a/test/cpp/pir/tools/test_dialect.cc +++ b/test/cpp/pir/tools/test_dialect.cc @@ -12,8 +12,25 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "test/cpp/pir/tools/test_dialect.h" +#include "paddle/pir/core/ir_printer.h" #include "test/cpp/pir/tools/test_op.h" namespace test { -void TestDialect::initialize() { RegisterOps(); } + +TestDialect::TestDialect(pir::IrContext *context) + : pir::Dialect(name(), context, pir::TypeId::get()) { + initialize(); +} +void TestDialect::initialize() { + RegisterOps(); +} + +void TestDialect::PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const { + printer.PrintOpResult(op); + printer.os << " ="; + + printer.os << " \"" << op->name() << "\""; + printer.PrintOpOperands(op); +} } // namespace test IR_DEFINE_EXPLICIT_TYPE_ID(test::TestDialect) diff --git a/test/cpp/pir/tools/test_dialect.h b/test/cpp/pir/tools/test_dialect.h index 8b259c5563c4bb..c3594273b53558 100644 --- a/test/cpp/pir/tools/test_dialect.h +++ b/test/cpp/pir/tools/test_dialect.h @@ -19,11 +19,10 @@ namespace test { class TestDialect : public pir::Dialect { public: - explicit TestDialect(pir::IrContext *context) - : pir::Dialect(name(), context, pir::TypeId::get()) { - initialize(); - } + explicit TestDialect(pir::IrContext *context); static const char *name() { return "test"; } + void PrintOperation(pir::Operation *op, + pir::IrPrinter &printer) const override; private: void initialize(); diff --git a/test/cpp/pir/tools/test_interface.cc b/test/cpp/pir/tools/test_interface.cc new file mode 100644 index 00000000000000..b0d72b48baa20a --- /dev/null +++ b/test/cpp/pir/tools/test_interface.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "test/cpp/pir/tools/test_interface.h" +IR_DEFINE_EXPLICIT_TYPE_ID(test::InferShapeInterface) diff --git a/test/cpp/pir/tools/test_interface.h b/test/cpp/pir/tools/test_interface.h new file mode 100644 index 00000000000000..a2de7e1bb6972e --- /dev/null +++ b/test/cpp/pir/tools/test_interface.h @@ -0,0 +1,65 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include + +#include "paddle/pir/core/block.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/ir_printer.h" +#include "paddle/pir/core/op_base.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/region.h" + +namespace test { +/// \brief Define built-in Interface, derived from OpInterfaceBase. Concepts and +/// Models need to be defined within the class. Concept defines abstract +/// interface functions, and Model is a template class that defines the specific +/// implementation of interface functions based on template parameters. +class InferShapeInterface : public pir::OpInterfaceBase { + public: + struct Concept { + explicit Concept(void (*infer_shape)(pir::Operation *)) + : infer_shape(infer_shape) {} + void (*infer_shape)(pir::Operation *); + }; + + template + struct Model : public Concept { + static void InferShape(pir::Operation *op) { + ConcreteOp concret_op = ConcreteOp(op); + if (concret_op == nullptr) throw("concret_op is nullptr"); + concret_op.InferShape(); + } + + Model() : Concept(InferShape) {} + }; + + InferShapeInterface(pir::Operation *op, Concept *impl) + : pir::OpInterfaceBase(op), impl_(impl) {} + + void InferShape() { impl_->infer_shape(operation()); } + + private: + Concept *impl_; +}; + +} // namespace test +IR_DECLARE_EXPLICIT_TYPE_ID(test::InferShapeInterface) diff --git a/test/cpp/pir/tools/test_op.cc b/test/cpp/pir/tools/test_op.cc index 9adce7ea402e98..99515ecf2e2e10 100644 --- a/test/cpp/pir/tools/test_op.cc +++ b/test/cpp/pir/tools/test_op.cc @@ -12,17 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "test/cpp/pir/tools/test_op.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/enforce.h" namespace test { + void RegionOp::Build(pir::Builder &builder, pir::OperationArgument &argument) { argument.num_regions = 1; } -void RegionOp::Verify() const { - auto num_regions = (*this)->num_regions(); - IR_ENFORCE(num_regions == 1u, - "The region's number in Region Op must be 1, but current is %d", - num_regions); -} void BranchOp::Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument, @@ -38,7 +35,32 @@ void BranchOp::Verify() const { IR_ENFORCE((*this)->successor(0), "successor[0] can't be nullptr"); } +const char *Operation1::attributes_name[2] = { // NOLINT + "op1_attr1", + "op1_attr2"}; + +void Operation1::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument) { // NOLINT + std::unordered_map attributes{ + {"op1_attr1", builder.str_attr("op1_attr2")}, + {"op1_attr2", builder.str_attr("op1_attr2")}}; + argument.AddOutput(builder.float32_type()); + argument.AddAttributes(attributes); +} +void Operation1::Verify() const { + auto &attributes = this->attributes(); + if (attributes.count("op1_attr1") == 0 || + !attributes.at("op1_attr1").isa()) { + throw("Type of attribute: parameter_name is not right."); + } + if (attributes.count("op1_attr2") == 0 || + !attributes.at("op1_attr2").isa()) { + throw("Type of attribute: parameter_name is not right."); + } +} } // namespace test IR_DEFINE_EXPLICIT_TYPE_ID(test::RegionOp) IR_DEFINE_EXPLICIT_TYPE_ID(test::BranchOp) +IR_DEFINE_EXPLICIT_TYPE_ID(test::Operation1) +IR_DEFINE_EXPLICIT_TYPE_ID(test::Operation2) diff --git a/test/cpp/pir/tools/test_op.h b/test/cpp/pir/tools/test_op.h index 9e0f9f1e933b21..8d4ccd49a38edb 100644 --- a/test/cpp/pir/tools/test_op.h +++ b/test/cpp/pir/tools/test_op.h @@ -15,13 +15,17 @@ #pragma once #include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/op_base.h" +#include "paddle/pir/core/operation_utils.h" +#include "test/cpp/pir/tools/test_interface.h" +#include "test/cpp/pir/tools/test_trait.h" namespace test { /// /// \brief TestRegionOp /// -class RegionOp : public pir::Op { +class RegionOp : public pir::Op { public: using Op::Op; static const char *name() { return "test.region"; } @@ -29,7 +33,7 @@ class RegionOp : public pir::Op { static constexpr const char **attributes_name = nullptr; static void Build(pir::Builder &builder, // NOLINT pir::OperationArgument &argument); // NOLINT - void Verify() const; + void Verify() const {} }; /// @@ -48,7 +52,35 @@ class BranchOp : public pir::Op { void Verify() const; }; +// Define case op1. +class Operation1 : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.operation1"; } + static constexpr uint32_t attributes_num = 2; + static const char *attributes_name[attributes_num]; // NOLINT + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument); // NOLINT + void Verify() const; +}; + +// Define op2. +class Operation2 + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "test.operation2"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; // NOLINT + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument) {} // NOLINT + void Verify() const {} + static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } +}; + } // namespace test IR_DECLARE_EXPLICIT_TYPE_ID(test::RegionOp) IR_DECLARE_EXPLICIT_TYPE_ID(test::BranchOp) +IR_DECLARE_EXPLICIT_TYPE_ID(test::Operation1) +IR_DECLARE_EXPLICIT_TYPE_ID(test::Operation2) diff --git a/test/cpp/pir/tools/test_trait.cc b/test/cpp/pir/tools/test_trait.cc new file mode 100644 index 00000000000000..1fa5dd0bba9118 --- /dev/null +++ b/test/cpp/pir/tools/test_trait.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "test/cpp/pir/tools/test_trait.h" +#include "glog/logging.h" + +#include "paddle/pir/core/enforce.h" + +namespace test { +void OneRegionTrait::Verify(pir::Operation *op) { + VLOG(1) << "here"; + IR_ENFORCE(op->num_regions() == 1u, + "%s op has one region trait, but its region size is %u", + op->name(), + op->num_regions()); +} +} // namespace test + +IR_DEFINE_EXPLICIT_TYPE_ID(test::ReadOnlyTrait) +IR_DEFINE_EXPLICIT_TYPE_ID(test::OneRegionTrait) diff --git a/test/cpp/pir/tools/test_trait.h b/test/cpp/pir/tools/test_trait.h new file mode 100644 index 00000000000000..cc002081dddc2e --- /dev/null +++ b/test/cpp/pir/tools/test_trait.h @@ -0,0 +1,37 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include +#include + +#include "paddle/pir/core/op_base.h" + +namespace test { + +class ReadOnlyTrait : public pir::OpTraitBase { + public: + explicit ReadOnlyTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} +}; + +class OneRegionTrait : public pir::OpTraitBase { + public: + explicit OneRegionTrait(pir::Operation *op) + : pir::OpTraitBase(op) {} + static void Verify(pir::Operation *op); +}; + +} // namespace test +IR_DECLARE_EXPLICIT_TYPE_ID(test::ReadOnlyTrait) +IR_DECLARE_EXPLICIT_TYPE_ID(test::OneRegionTrait)