Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor build attribute #54968

Merged
merged 9 commits into from
Jun 29, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -969,11 +969,16 @@ void BuildOpFuncList(

VLOG(6) << "op name" << op_func_node.phi_op_name_;
dialect::OpYamlInfoParser op_yaml_info_parser(impl->get_op_info_());
::ir::BuildInferMetaContext((*it),
value_2_name_map,
scope,
op_yaml_info_parser,
&(op_func_node.infer_meta_context_));
::ir::BuildPhiContext<
phi::InferMetaContext,
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it),
value_2_name_map,
scope,
op_yaml_info_parser,
&(op_func_node.infer_meta_context_));

auto kernel_name =
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
Expand All @@ -990,7 +995,11 @@ void BuildOpFuncList(
true,
"not found kernel for [%s]",
kernel_name);
::ir::BuildPhiKernelContext((*it),
::ir::BuildPhiContext<phi::KernelContext,
const phi::TensorBase*,
phi::TensorBase*,
paddle::small_vector<const phi::TensorBase*>,
true>((*it),
value_2_name_map,
scope,
op_yaml_info_parser,
Expand Down
57 changes: 29 additions & 28 deletions paddle/fluid/ir/interface/op_yaml_info_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,21 @@ size_t OpYamlInfoParser::InputTensorNumber() const {

const std::string& OpYamlInfoParser::AttrTypeName(
const std::string& name) const {
auto it = map_attr_info_.find(name);
auto it = attr_info_.find(name);

PADDLE_ENFORCE_NE(
it,
map_attr_info_.end(),
attr_info_.end(),
phi::errors::NotFound("Not found [%s] in attribute map", name));
return it->second.type_name;
}

const std::string& OpYamlInfoParser::TensorAttrTypeName(
const std::string& name) const {
auto it = map_input_info_.find(name);
auto it = input_info_.find(name);

PADDLE_ENFORCE_NE(it,
map_input_info_.end(),
input_info_.end(),
phi::errors::NotFound("Not found [%s] in input map", name));

PADDLE_ENFORCE_EQ(
Expand All @@ -63,26 +63,29 @@ const std::string& OpYamlInfoParser::TensorAttrTypeName(
return it->second.type_name;
}

const std::vector<std::string>& OpYamlInfoParser::InferMetaTensorParams()
const {
return vec_infer_meta_tensor_params_;
}
const std::vector<std::string>& OpYamlInfoParser::InferMetaAttrParams() const {
return vec_infer_meta_attr_params_;
}
const std::vector<std::string>& OpYamlInfoParser::KernelFnTensorParams() const {
return vec_kernel_fn_tensor_params_;
const std::vector<std::string>& OpYamlInfoParser::TensorParams(
bool is_kernel) const {
if (is_kernel) {
return kernel_fn_tensor_params_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里其实可以用三目运算符来替换 if-else

} else {
return infer_meta_tensor_params_;
}
}
const std::vector<std::string>& OpYamlInfoParser::KernelFnAttrParams() const {
return vec_kernel_fn_attr_params_;
const std::vector<std::string>& OpYamlInfoParser::AttrParams(
bool is_kernel) const {
if (is_kernel) {
return kernel_fn_attr_params_;
} else {
return infer_meta_attr_params_;
}
}

const OpRunTimeInfo& OpYamlInfoParser::OpRuntimeInfo() const {
return std::get<3>(op_info_tuple_);
}

const std::map<std::string, int>& OpYamlInfoParser::Name2Id() const {
return map_name2id_;
return name2id_;
}

void OpYamlInfoParser::parse() {
Expand All @@ -91,43 +94,41 @@ void OpYamlInfoParser::parse() {
int start_index = 0;

for (size_t i = 0; i < input_info.size(); ++i) {
map_name2id_[input_info[i].name] = start_index++;
name2id_[input_info[i].name] = start_index++;

if (!input_info[i].is_mutable_attribute) {
input_tensor_number_++;
}

map_input_info_[input_info[i].name] = input_info[i];
input_info_[input_info[i].name] = input_info[i];
}

auto attribute_info = std::get<1>(op_info_tuple_);
for (size_t i = 0; i < attribute_info.size(); ++i) {
map_attr_info_[attribute_info[i].name] = attribute_info[i];
attr_info_[attribute_info[i].name] = attribute_info[i];
}

auto output_info = std::get<2>(op_info_tuple_);

for (size_t i = 0; i < output_info.size(); ++i) {
map_output_info_[output_info[i].name] = output_info[i];
output_info_[output_info[i].name] = output_info[i];
}

auto runtime_info = std::get<3>(op_info_tuple_);

for (auto& name : runtime_info.infer_meta_param) {
if (map_name2id_.count(name) &&
!map_input_info_[name].is_mutable_attribute) {
vec_infer_meta_tensor_params_.push_back(name);
if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
infer_meta_tensor_params_.push_back(name);
} else {
vec_infer_meta_attr_params_.push_back(name);
infer_meta_attr_params_.push_back(name);
}
}

for (auto& name : runtime_info.kernel_param) {
if (map_name2id_.count(name) &&
!map_input_info_[name].is_mutable_attribute) {
vec_kernel_fn_tensor_params_.push_back(name);
if (name2id_.count(name) && !input_info_[name].is_mutable_attribute) {
kernel_fn_tensor_params_.push_back(name);
} else {
vec_kernel_fn_attr_params_.push_back(name);
kernel_fn_attr_params_.push_back(name);
}
}
}
Expand Down
22 changes: 10 additions & 12 deletions paddle/fluid/ir/interface/op_yaml_info_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ class OpYamlInfoParser {
const std::string& AttrTypeName(const std::string& name) const;
const std::string& TensorAttrTypeName(const std::string& name) const;

const std::vector<std::string>& InferMetaTensorParams() const;
const std::vector<std::string>& InferMetaAttrParams() const;
const std::vector<std::string>& KernelFnTensorParams() const;
const std::vector<std::string>& KernelFnAttrParams() const;
const std::vector<std::string>& TensorParams(bool is_kernel = false) const;
const std::vector<std::string>& AttrParams(bool is_kernel = false) const;
const OpRunTimeInfo& OpRuntimeInfo() const;
const std::map<std::string, int>& Name2Id() const;

Expand All @@ -46,16 +44,16 @@ class OpYamlInfoParser {

OpInfoTuple op_info_tuple_;

std::map<std::string, int> map_name2id_;
std::map<std::string, int> name2id_;

std::map<std::string, OpInputInfo> map_input_info_;
std::map<std::string, OpAttributeInfo> map_attr_info_;
std::map<std::string, OpOutputInfo> map_output_info_;
std::map<std::string, OpInputInfo> input_info_;
std::map<std::string, OpAttributeInfo> attr_info_;
std::map<std::string, OpOutputInfo> output_info_;

std::vector<std::string> vec_infer_meta_tensor_params_;
std::vector<std::string> vec_infer_meta_attr_params_;
std::vector<std::string> vec_kernel_fn_tensor_params_;
std::vector<std::string> vec_kernel_fn_attr_params_;
std::vector<std::string> infer_meta_tensor_params_;
std::vector<std::string> infer_meta_attr_params_;
std::vector<std::string> kernel_fn_tensor_params_;
std::vector<std::string> kernel_fn_attr_params_;

int input_tensor_number_{0};
};
Expand Down
14 changes: 11 additions & 3 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ class PhiKernelAdaptor {
phi::InferMetaContext ctx;

paddle::dialect::OpYamlInfoParser op_yaml_info_parser(yaml_info);
ir::BuildInferMetaContext(
(*it), name_map, scope_, op_yaml_info_parser, &ctx);
ir::BuildPhiContext<
phi::InferMetaContext,
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it), name_map, scope_, op_yaml_info_parser, &ctx);

infer_meta_impl->infer_meta_(&ctx);

Expand All @@ -98,7 +102,11 @@ class PhiKernelAdaptor {

phi::KernelContext kernel_ctx(dev_ctx);

ir::BuildPhiKernelContext(
ir::BuildPhiContext<phi::KernelContext,
const phi::TensorBase*,
phi::TensorBase*,
paddle::small_vector<const phi::TensorBase*>,
true>(
(*it), name_map, scope_, op_yaml_info_parser, &kernel_ctx);
kernel_fn(&kernel_ctx);

Expand Down
Loading