Skip to content

Commit

Permalink
Merge pull request #1485 from 0x3878f/pir_develop_15
Browse files Browse the repository at this point in the history
Fix the issue with determining whether operators in control flow are registered.
  • Loading branch information
risemeup1 authored Jan 23, 2025
2 parents 040c051 + 3075229 commit 15ab470
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 5 deletions.
8 changes: 3 additions & 5 deletions paddle2onnx/mapper/exporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@ bool ModelExporter::IsOpsRegistered(const PaddlePirParser& pir_parser,
bool enable_experimental_op) {
OnnxHelper temp_helper;
std::set<std::string> unsupported_ops;
for (auto op : pir_parser.global_blocks_ops) {
for (auto op : pir_parser.total_blocks_ops) {
if (op->name() == "pd_op.data" || op->name() == "pd_op.fetch") {
continue;
}
if (op->name() == "pd_op.if") {
continue;
}
if (op->name() == "pd_op.while") {
if (op->name() == "pd_op.if" || op->name() == "pd_op.while" ||
op->name() == "cf.yield") {
continue;
}
std::string op_name = convert_pir_op_name(op->name());
Expand Down
21 changes: 21 additions & 0 deletions paddle2onnx/parser/pir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle2onnx/proto/p2o_paddle.pb.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"

phi::DataType TransToPhiDataType(pir::Type dtype) {
if (dtype.isa<pir::BFloat16Type>()) {
Expand Down Expand Up @@ -211,6 +212,25 @@ void PaddlePirParser::GetOpArgNameMappings() {
}
}

void PaddlePirParser::GetAllBlocksOpsSet(pir::Block* block) {
for(auto &op : block->ops()) {
std::string op_name = op->name();
if(op_name != "builtin.parameter") {
total_blocks_ops.insert(op);
if(op_name == "pd_op.if") {
auto if_op = op->dyn_cast<paddle::dialect::IfOp>();
pir::Block& true_block = if_op.true_block();
GetAllBlocksOpsSet(&true_block);
pir::Block& false_block = if_op.false_block();
GetAllBlocksOpsSet(&false_block);
} else if(op_name == "pd_op.while") {
auto while_op = op->dyn_cast<paddle::dialect::WhileOp>();
GetAllBlocksOpsSet(&while_op.body());
}
}
}
}

std::string PaddlePirParser::GetOpArgName(int64_t op_id,
std::string name,
bool if_in_sub_block) const {
Expand Down Expand Up @@ -436,6 +456,7 @@ bool PaddlePirParser::Init(const std::string& _model,
GetGlobalBlockInputOutputInfo();
GetAllOpOutputName();
GetOpArgNameMappings();
GetAllBlocksOpsSet(pir_program_->block());
return true;
}
int PaddlePirParser::NumOfBlocks() const {
Expand Down
3 changes: 3 additions & 0 deletions paddle2onnx/parser/pir_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class PaddlePirParser {
// recoring set of operators for sub block
mutable std::vector<pir::Operation*>
sub_blocks_ops; // todo(wangmingkai02): delete sub_blocks_ops
// recoring set of operators for all blocks
std::set<pir::Operation*> total_blocks_ops;
// recording args of while op body name info
std::unordered_map<pir::detail::ValueImpl*, pir::detail::ValueImpl*>
while_op_input_value_map;
Expand Down Expand Up @@ -270,6 +272,7 @@ class PaddlePirParser {
bool LoadParams(const std::string& path);
bool GetParamValueName(std::vector<std::string>* var_names);
void GetGlobalBlocksOps();
void GetAllBlocksOpsSet(pir::Block *block);
void GetGlobalBlockInputOutputInfo();
void GetGlobalBlockInputValueName();
void GetGlobalBlockOutputValueName();
Expand Down

0 comments on commit 15ab470

Please sign in to comment.