Skip to content

Commit

Permalink
split first (PaddlePaddle#57281)
Browse files Browse the repository at this point in the history
* split first

* modify test_ir_backward
  • Loading branch information
xiaoguoguo626807 authored Sep 14, 2023
1 parent 8d766d2 commit 4698b3d
Show file tree
Hide file tree
Showing 9 changed files with 277 additions and 70 deletions.
106 changes: 88 additions & 18 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,29 @@ def _gen_api_inputs(self, op_info):
ret.append(f'{self._type_map[type]} {name}')
return ', '.join(ret)

def _gen_api_attrs(self, op_info, with_default, is_mutable_attr):
def _gen_api_attrs(
self, op_info, with_default, is_mutable_attr, is_vector_mutable_sttr
):
name_list = op_info.attribute_name_list
type_list = op_info.attribute_build_arg_type_list
default_value_list = op_info.attribute_default_value_list
mutable_name_list = op_info.mutable_attribute_name_list
mutable_type_list = op_info.mutable_attribute_type_list
assert len(name_list) == len(type_list) == len(default_value_list)
no_mutable_attr = []
mutable_attr = []
for name, type, default_value in zip(
name_list, type_list, default_value_list
):
if is_mutable_attr and name in mutable_name_list:
mutable_attr.append(f'{OP_RESULT} {name}')
if (
mutable_type_list[mutable_name_list.index(name)][0]
== "paddle::dialect::IntArrayAttribute"
and is_vector_mutable_sttr
):
mutable_attr.append(f'std::vector<{OP_RESULT}> {name}')
else:
mutable_attr.append(f'{OP_RESULT} {name}')
continue
if with_default and default_value is not None:
if type in ['float', 'double']:
Expand All @@ -158,9 +168,17 @@ def _gen_api_attrs(self, op_info, with_default, is_mutable_attr):
no_mutable_attr.append(f'{type} {name}')
return ', '.join(mutable_attr + no_mutable_attr)

def _gen_api_args(self, op_info, with_default_attr, is_mutable_attr):
def _gen_api_args(
self,
op_info,
with_default_attr,
is_mutable_attr,
is_vector_mutable_attr,
):
inputs = self._gen_api_inputs(op_info)
attrs = self._gen_api_attrs(op_info, with_default_attr, is_mutable_attr)
attrs = self._gen_api_attrs(
op_info, with_default_attr, is_mutable_attr, is_vector_mutable_attr
)
return (inputs + ', ' + attrs).strip(', ')

def _gen_ret_type(self, op_info):
Expand All @@ -187,11 +205,15 @@ def _gen_ret_type(self, op_info):
elif output_num == 0:
return 'void'

def _gen_one_declare(self, op_info, op_name, is_mutable_attr):
def _gen_one_declare(
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
):
return API_DECLARE_TEMPLATE.format(
ret_type=self._gen_ret_type(op_info),
api_name=op_name,
args=self._gen_api_args(op_info, True, is_mutable_attr),
args=self._gen_api_args(
op_info, True, is_mutable_attr, is_vector_mutable_attr
),
)

def _gen_h_file(self, op_info_items, namespaces, h_file_path):
Expand All @@ -202,10 +224,19 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path):
# is wrong, so temporarily skip the automatic generation of these APIs
if self._need_skip(op_info, op_name):
continue
declare_str += self._gen_one_declare(op_info, op_name, False)
declare_str += self._gen_one_declare(
op_info, op_name, False, False
)
if len(op_info.mutable_attribute_name_list) > 0:
declare_str += self._gen_one_declare(op_info, op_name, True)

declare_str += self._gen_one_declare(
op_info, op_name, True, False
)
if "paddle::dialect::IntArrayAttribute" in {
type[0] for type in op_info.mutable_attribute_type_list
}:
declare_str += self._gen_one_declare(
op_info, op_name, True, True
)
body = declare_str
for namespace in reversed(namespaces):
body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
Expand All @@ -215,7 +246,7 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path):
# =====================================
# Gen impl functions
# =====================================
def _gen_in_combine(self, op_info):
def _gen_in_combine(self, op_info, is_mutable_attr, is_vector_mutable_attr):
name_list = op_info.input_name_list
type_list = op_info.input_type_list
assert len(name_list) == len(type_list)
Expand All @@ -230,6 +261,24 @@ def _gen_in_combine(self, op_info):
combine_op_list.append(op_name)
else:
combine_op_list.append(None)

if is_mutable_attr:
name_list = op_info.mutable_attribute_name_list
type_list = op_info.mutable_attribute_type_list
assert len(name_list) == len(type_list)
for name, type in zip(name_list, type_list):
if (
type[0] == "paddle::dialect::IntArrayAttribute"
and is_vector_mutable_attr
):
op_name = f'{name}_combine_op'
combine_op += COMBINE_OP_TEMPLATE.format(
op_name=op_name, in_name=name
)
combine_op_list.append(op_name)
else:
combine_op_list.append(None)

return combine_op, combine_op_list

def _gen_compute_op_args(
Expand All @@ -239,15 +288,22 @@ def _gen_compute_op_args(
all_attr_list = op_info.attribute_name_list
no_mutable_attr_list = op_info.non_mutable_attribute_name_list
mutable_attr_list = op_info.mutable_attribute_name_list
assert len(input_name_list) == len(in_combine_op_list)
assert len(input_name_list) + len(mutable_attr_list) == len(
in_combine_op_list
) or len(input_name_list) == len(in_combine_op_list)
ret = []
for input_name, combine_op in zip(input_name_list, in_combine_op_list):
if is_mutable_attr:
name_list = input_name_list + mutable_attr_list
else:
name_list = input_name_list

for input_name, combine_op in zip(name_list, in_combine_op_list):
if combine_op is None:
ret.append(input_name)
else:
ret.append(f'{combine_op}.out()')
if is_mutable_attr:
ret += list(mutable_attr_list + no_mutable_attr_list)
ret += list(no_mutable_attr_list)
else:
ret += list(all_attr_list)
return ', '.join(ret)
Expand Down Expand Up @@ -299,9 +355,13 @@ def _gen_return_result(self, ret_list):
elif len(ret_list) == 0:
return 'return;'

def _gen_one_impl(self, op_info, op_name, is_mutable_attr):
def _gen_one_impl(
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
):
ret_type = self._gen_ret_type(op_info)
in_combine, in_combine_op_list = self._gen_in_combine(op_info)
in_combine, in_combine_op_list = self._gen_in_combine(
op_info, is_mutable_attr, is_vector_mutable_attr
)
compute_op, op_inst_name = self._gen_compute_op(
op_info, op_name, in_combine_op_list, is_mutable_attr
)
Expand All @@ -315,7 +375,9 @@ def _gen_one_impl(self, op_info, op_name, is_mutable_attr):
ret = API_IMPL_TEMPLATE.format(
ret_type=ret_type,
api_name=op_name,
args=self._gen_api_args(op_info, False, is_mutable_attr),
args=self._gen_api_args(
op_info, False, is_mutable_attr, is_vector_mutable_attr
),
in_combine=in_combine,
compute_op=compute_op,
out_split=out_split,
Expand All @@ -333,9 +395,17 @@ def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path):
# is wrong, so temporarily skip the automatic generation of these APIs
if self._need_skip(op_info, op_name):
continue
impl_str += self._gen_one_impl(op_info, op_name, False)
impl_str += self._gen_one_impl(op_info, op_name, False, False)
if len(op_info.mutable_attribute_name_list) > 0:
impl_str += self._gen_one_impl(op_info, op_name, True)
impl_str += self._gen_one_impl(
op_info, op_name, True, False
)
if "paddle::dialect::IntArrayAttribute" in {
type[0] for type in op_info.mutable_attribute_type_list
}:
impl_str += self._gen_one_impl(
op_info, op_name, True, True
)
body = impl_str
for namespace in reversed(namespaces):
body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
Expand Down
17 changes: 15 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,22 @@
# limitations under the License.

# generator build function
_INFERMETA_NEED_META_CONFIG = {'SplitInferMeta'}
_INFERMETA_NEED_META_CONFIG = {
'SplitInferMeta',
'SumInferMeta',
'SplitWithNumInferMeta',
'ConcatInferMeta',
'ReduceIntArrayAxisInferMeta',
}

_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {
'SplitOp',
'SumOp',
'SplitWithNumOp',
'ConcatOp',
'MeanOp',
}

_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {'SplitOp'}

OP_BUILD_TEMPLATE = """
void {op_name}::Build({build_args}) {{
Expand Down
Loading

0 comments on commit 4698b3d

Please sign in to comment.