Skip to content

Commit

Permalink
[PIR] Complement op defs (PaddlePaddle#60475)
Browse files Browse the repository at this point in the history
* complement translation of legacy matmul
* Complement op mappings in translation for deformable_conv_v1.
  • Loading branch information
kangguangli authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent e929eb0 commit 9318906
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
10 changes: 10 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_compat_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ def insert_new_mutable_attributes(
)

# special mapping list
op_name_mappings["deformable_conv_v1"] = "deformable_conv"
op_name_mappings["deformable_conv_v1_grad"] = "deformable_conv_grad"
op_arg_name_mappings["deformable_conv_v1"] = {
"x": "Input",
"offset": "Offset",
"filter": "Filter",
"mask": "Mask",
"out": "Output",
}

op_arg_name_mappings["set_value_grad"]["values_grad"] = "ValueTensor@GRAD"
op_arg_name_mappings["fetch"] = {"x": "X"}
op_arg_name_mappings["elementwise_add_grad_grad"] = {
Expand Down
12 changes: 11 additions & 1 deletion paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,9 @@ std::vector<pir::Value> OpTranscriber::GenerateOperationInput(
info.name,
op_desc.Type());
IR_ENFORCE(param_map->count(legacy_input_vars[0]),
"Input [%s] of op [%s] not found in param map",
"Input [%s: %s] of op [%s] not found in param map",
info.name,
legacy_input_vars[0],
op_desc.Type());
auto defining_info = (*param_map)[legacy_input_vars[0]];
op_inputs.push_back(defining_info.value);
Expand Down Expand Up @@ -2998,6 +2999,14 @@ struct LegacyMatmulOpTranscriber : public OpTranscriber {
param_map->PushValue(output_vars[0],
VariableDefiningInfo(scale_op.out(), false, -1));
}

void HandleNonexistentAttribute(pir::IrContext* ctx,
pir::AttributeMap* attribute_map,
const OpAttributeInfo& info) override {
if (info.name == "transpose_x" || info.name == "transpose_y") {
(*attribute_map)[info.name] = pir::BoolAttribute::get(ctx, false);
}
}
};

struct CEmbeddingOpTranscriber : public OpTranscriber {
Expand Down Expand Up @@ -3051,6 +3060,7 @@ OpTranslator::OpTranslator() {
special_handlers["sum"] = AddNOpTranscriber();
special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
special_handlers["tril_triu_grad"] = TrilAndTriuGradOpTranscriber();
special_handlers["matmul"] = LegacyMatmulOpTranscriber();
special_handlers["matrix_rank"] = MatrixRankOpTranscriber();
special_handlers["mul"] = MulOpTranscriber();
special_handlers["mul_grad"] = MulGradOpTranscriber();
Expand Down

0 comments on commit 9318906

Please sign in to comment.