[AutoParallel] Support view mechanism in auto parallel dygraph mode. #59401

Merged: 13 commits merged on Nov 30, 2023

Changes from 8 commits
11 changes: 11 additions & 0 deletions paddle/fluid/pybind/eager.cc
@@ -40,6 +40,7 @@ limitations under the License. */
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/placement_types.h"
@@ -390,12 +391,22 @@ void InitDistTensorWithTensor(TensorObject* self,
if (place == src.place()) {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(src.impl());
// auto parallel in dygraph doesn't support strided kernel.
if (!tensor->meta().is_contiguous()) {
VLOG(4) << "Same place and not contiguous, trans it to contiguous";
*tensor = paddle::experimental::Trans2Contiguous(*tensor);
}
self->tensor.set_impl(std::make_shared<DistTensor>(tensor, dist_attr));
VLOG(4) << "Same place, do ShareDataWith for DistTensor.";
} else {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(
src.copy_to(place, true).impl());
// auto parallel in dygraph doesn't support strided kernel.
if (!tensor->meta().is_contiguous()) {
VLOG(4) << "Different place and not contiguous, trans it to contiguous";
*tensor = paddle::experimental::Trans2Contiguous(*tensor);
}
self->tensor.set_impl(std::make_shared<DistTensor>(tensor, dist_attr));
VLOG(4) << "Different place, do TensorCopy for DistTensor.";
}
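Both branches above only differ in whether a cross-place copy happens first; in each case a non-contiguous local DenseTensor is made contiguous before it is wrapped into a DistTensor, because the dist branch does not dispatch strided kernels. A minimal Python sketch of the situation this guards against (assumes a multi-GPU launch and Paddle's dynamic-mode auto parallel API; exact behavior of the stride mechanism may vary by version):

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

x = paddle.rand([4, 8])
# With the stride mechanism enabled, transpose returns a strided
# (non-contiguous) view of x instead of copying it.
x_t = paddle.transpose(x, [1, 0])

# Wrapping this local tensor into a DistTensor goes through the pybind path
# above: the non-contiguous DenseTensor is first transformed to a contiguous
# one (Trans2Contiguous) before the DistTensor is constructed.
d = dist.shard_tensor(x_t, mesh, [dist.Replicate()])
```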
5 changes: 5 additions & 0 deletions paddle/fluid/pybind/eager_utils.cc
@@ -39,6 +39,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/pir.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
@@ -138,6 +139,10 @@ void ConvertToDistTensor(Tensor* x, const phi::distributed::ProcessMesh* mesh) {
phi::vectorize(x->impl()->dims()));
dist_attr.set_process_mesh(*mesh);
auto dense_t = std::static_pointer_cast<phi::DenseTensor>(x->impl());
// auto parallel in dygraph doesn't support strided kernel.
if (!dense_t->meta().is_contiguous()) {
*dense_t = paddle::experimental::Trans2Contiguous(*dense_t);
}
x->set_impl(
std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr));
}
162 changes: 142 additions & 20 deletions paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -286,6 +286,13 @@
"""

# 6. PrepareData
VIEW_OUTPUT_SHARE_MEM_WITH_INPUT_TEMPLATE = """
// {dense_out} is a view output; it shares memory with its input.
// If the input is resharded, {dense_out} may hold
// different memory from the original input.
{dense_out}->ShareBufferWith({dense_input});
{dense_out}->ShareInplaceVersionCounterWith({dense_input});
"""
SINGLE_PREPARE_DATA_TEMPLATE = """
dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
auto input_{name} = &dist_input_{name}->value();
@@ -416,21 +423,21 @@
"""

SET_SINGLE_OR_VECTOR_INPLACE_OUT_TEMPLATE = """
// Set correct dist_attr for inplace output:
// If no_spmd_rules, reshard it to origin dist_attr,
// Or set correct spmd output dist_attr
SetInplaceOutputCorrectDistAttr(dev_ctx, api_output, {dist_out_attr}, {need_reshard});
"""
SET_MULTI_SINGLE_OR_VECTOR_INPLACE_OUT_TEMPLATE = """
// Set correct dist_attr for inplace output:
// If no_spmd_rules, reshard it to origin dist_attr,
// Or set correct spmd output dist_attr
auto& output_{idx} = std::get<{idx}>(api_output);
SetInplaceOutputCorrectDistAttr(dev_ctx, output_{idx}, {dist_out_attr}, {need_reshard});
"""

SET_MULTI_SINGLE_OR_VECTOR_OPTIONAL_INPLACE_OUT_TEMPLATE = """
// Set correct dist_attr for inplace output:
// If no_spmd_rules, reshard it to origin dist_attr,
// Or set correct spmd output dist_attr
auto& output_{idx} = std::get<{idx}>(api_output);
@@ -442,6 +449,29 @@
NONEED_TO_SET_DIST_ATTR_COMMENT_TEMPLATE = """
// API `{}` does not need to set DistAttr for output."""

# TODO(GhostScreaming): Support aliquant condition.
# Specialized Code, for example, reshape needs to calculate local_shape
RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE = """
std::vector<int64_t> local_shape;
for (size_t i = 0; i < shape.GetData().size(); i++) {
auto out_dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, spmd_info.second[0]);
if (out_dist_attr.dims_mapping()[i] >= 0) {
int64_t mesh_dim = out_dist_attr.process_mesh().shape()[out_dist_attr.dims_mapping()[i]];
// TODO: Support aliquant condition.
PADDLE_ENFORCE_EQ(shape.GetData()[i] % mesh_dim,
0,
phi::errors::InvalidArgument(
"Reshape only supports a local shape dim divisible "
"by the mesh dim, however local_shape[%d] is %d "
"and the shard mesh dim is %d",
i, shape.GetData()[i], mesh_dim));
local_shape.push_back(shape.GetData()[i] / mesh_dim);
} else {
local_shape.push_back(shape.GetData()[i]);
}
}
"""

# BaseAPI members:
# inputs:
# names : [], list of input names
@@ -517,6 +547,17 @@ def need_to_generate_code_for_view_impl(self, i):
and self.outputs['names'][i] in self.view_map
)

def need_to_generate_code_for_inplace_or_view_impl(self, i):
return self.need_to_generate_code_for_inplace_impl(
i
) or self.need_to_generate_code_for_view_impl(i)

def is_reshape_kernel(self):
return (
"reshape" in self.kernel['func'][0]
and 'grad' not in self.kernel['func'][0]
)

def is_inplace_output(self, i):
return self.outputs['names'][i] in self.inplace_map

@@ -760,6 +801,11 @@ def generate_specialized_infer_spmd_code(self) -> str:
if kernel_params is None:
kernel_params = input_names + attr_names

# TODO(GhostScreaming): specialized case for reshape_grad
# xshape is not a kernel param, but InferSpmd needs it.
if "reshape_grad" in self.kernel['func'][0]:
kernel_params = ["xshape"] + kernel_params

input_decl_code = ""
input_args_code = ""
for param in kernel_params:
@@ -799,7 +845,10 @@ def generate_specialized_infer_spmd_code(self) -> str:
f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported."
)
elif param in attr_names:
if self.attrs['attr_info'][param][0] == "const IntArray&":
input_args_code = input_args_code + param + ".GetData(), "
else:
input_args_code = input_args_code + param + ", "
elif isinstance(param, str):
input_args_code = input_args_code + "\"" + param + "\", "
elif isinstance(param, bool):
@@ -890,9 +939,6 @@ def generate_output_creation_code(self) -> str:
return_type = self.get_return_type_with_intermediate(self.inplace_flag)
output_creation_code = ""
output_creation_code += "\n phi::DeviceContext* dev_ctx = nullptr;"
if output_num == 1:
# api output generate
if self.need_to_generate_code_for_inplace_impl(0):
@@ -914,7 +960,7 @@
or self.outputs['types'][0] == 'const paddle::optional<Tensor>'
):
if (
self.need_to_generate_code_for_inplace_or_view_impl(0)
and self.generate_general_infer_spmd
):
output_creation_code += SINGLE_INPLACE_OUT_DIST_ATTR
@@ -925,7 +971,7 @@
elif self.outputs['types'][0] == 'std::vector<Tensor>':
# SetKernelDistOutput arg
if (
self.need_to_generate_code_for_inplace_or_view_impl(0)
and self.generate_general_infer_spmd
):
output_creation_code += VECTOR_INPLACE_OUT_DIST_ATTR
@@ -942,7 +988,7 @@
if self.inplace_flag:
inplace_assign_code = ""
for i, out_name in enumerate(self.outputs['names']):
if self.need_to_generate_code_for_inplace_or_view_impl(i):
inplace_assign_code += self.inplace_map[out_name] + ', '
else:
inplace_assign_code += 'Tensor(), '
@@ -968,7 +1014,9 @@
)
else:
if (
self.need_to_generate_code_for_inplace_or_view_impl(
i
)
and self.generate_general_infer_spmd
):
output_creation_code += (
@@ -1002,7 +1050,9 @@
)
else:
if (
self.need_to_generate_code_for_inplace_or_view_impl(
i
)
and self.generate_general_infer_spmd
):
output_creation_code += (
@@ -1306,6 +1356,30 @@ def generate_prepare_data_code(self) -> str:
)
)

for i, name in enumerate(self.outputs['names']):
if self.need_to_generate_code_for_view_impl(i):
dense_out = (
'dense_out'
if len(self.outputs['names']) == 1
else f'dense_out_{i}'
)
input_name = self.view_map[self.outputs['names'][i]]

kernel_params = self.kernel['param']
if kernel_params is None:
kernel_params = self.inputs['names'] + self.attrs['names']

if input_name in kernel_params:
dense_input = f"*input_{input_name}"
else:
dense_input = f"std::static_pointer_cast<phi::distributed::DistTensor>({input_name}.impl())->value()"
input_tensor_code += (
VIEW_OUTPUT_SHARE_MEM_WITH_INPUT_TEMPLATE.format(
dense_out=dense_out,
dense_input=dense_input,
)
)

return input_tensor_code, input_name_tensor_map

def generate_infer_meta_code(self) -> str:
@@ -1351,7 +1425,11 @@ def generate_infer_meta_code(self) -> str:
f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported."
)
elif param in attr_names:
# TODO(GhostScreaming): the reshape kernel needs specialized processing
if self.is_reshape_kernel() and param == "shape":
input_args_code = input_args_code + "local_shape" + ", "
else:
input_args_code = input_args_code + param + ", "
elif isinstance(param, str):
input_args_code = input_args_code + "\"" + param + "\", "
elif isinstance(param, bool):
@@ -1380,10 +1458,16 @@
)
output_args_code = output_args_code[:-2]

infer_meta_code = ""
# TODO(GhostScreaming): the reshape kernel needs specialized processing
if self.is_reshape_kernel():
infer_meta_code = RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE
infer_meta_code = infer_meta_code + INFER_META_TEMPLATE.format(
infer_meta_func_code, input_args_code, output_args_code
)

return output_decl_code + infer_meta_code

def generate_kernel_call_code(self) -> str:
dense_input_trans_map = {
'const Tensor&': 'const phi::DenseTensor&',
@@ -1430,7 +1514,11 @@ def generate_kernel_call_code(self) -> str:
elif arg in attr_names:
if 'IntArray' in self.attrs['attr_info'][arg][0]:
kernel_args_type_list.append('const phi::IntArray&')
# TODO(GhostScreaming): the reshape kernel needs specialized processing
if self.is_reshape_kernel() and arg == "shape":
arg = 'phi::IntArray(local_shape)'
else:
arg = 'phi::IntArray(' + arg + ')'
elif 'vector<phi::Scalar>' in self.attrs['attr_info'][arg][0]:
kernel_args_type_list.append(
'const std::vector<phi::Scalar>&'
@@ -1474,6 +1562,36 @@ def generate_kernel_call_code(self) -> str:
result += MULTI_SINGLE_SET_DIST_OUT_DIMS.format(i, i)
return result

def dist_branch_reset_view_after_fallback(
self, out_dtype_list, inplace_flag=False
):
remap_code = ''

if len(out_dtype_list) == 1:
if (
not inplace_flag
and self.view_map is not None
and self.outputs['names'][0] in self.view_map
):
remap_code += f"""
phi::DenseTensor* {self.view_map[self.outputs['names'][0]]}_remap = static_cast<phi::distributed::DistTensor*>({self.view_map[self.outputs['names'][0]]}.impl().get())->unsafe_mutable_value();
{self.view_map[self.outputs['names'][0]]}_remap->ShareBufferWith(dist_out->value());
dist_out->unsafe_mutable_value()->ShareInplaceVersionCounterWith(*{self.view_map[self.outputs['names'][0]]}_remap);
"""
elif len(out_dtype_list) > 1:
for i in range(len(out_dtype_list)):
if (
not inplace_flag
and self.view_map is not None
and self.outputs['names'][i] in self.view_map
):
remap_code += f"""
phi::DenseTensor* {self.view_map[self.outputs['names'][i]]}_remap = static_cast<phi::distributed::DistTensor*>({self.view_map[self.outputs['names'][i]]}.impl().get())->unsafe_mutable_value();
{self.view_map[self.outputs['names'][i]]}_remap->ShareBufferWith(dist_out_{i}->value());
dist_out_{i}->unsafe_mutable_value()->ShareInplaceVersionCounterWith(*{self.view_map[self.outputs['names'][i]]}_remap);
"""
return remap_code

def generate_fallback_code(self) -> str:
fallback_code = ""
fallback_code += """
@@ -1486,8 +1604,8 @@ def generate_fallback_code(self) -> str:
if len(self.inplace_map) > 0:
inplace_flag = True

fallback_code += self.dist_branch_reset_view_after_fallback(
self.outputs['types'], inplace_flag
)

fallback_code += """
@@ -1518,7 +1636,13 @@ def generate_output_dist_attr_setting(self) -> str:
# Inplace output should reshard to origin state.
if self.generate_infer_spmd:
for i, out_name in enumerate(self.dist_output_args):
# TODO(GhostScreaming): for inplace view operators like reshape,
# input and output may have different shape. If they have no specified
# InferSPMD rules, just set replicated dist_attr for them.
if (
self.need_to_generate_code_for_inplace_impl(i)
and self.outputs['names'][i] not in self.view_map
):
need_reshard = (
"true" if self.generate_general_infer_spmd else "false"
)
@@ -1639,7 +1763,6 @@ def gene_base_api_code(self, inplace_flag=False):
'sparse' not in kernel_name
and '_sr' not in kernel_name
and len(self.inputs['names']) > 0
and self.check_argument_whether_support_auto_parallel()
and not self.api.endswith("_double_grad")
and not self.api.endswith("_triple_grad")
@@ -1662,7 +1785,6 @@
dist_branch_code = ""
if (
len(self.inputs['names']) > 0
and self.check_argument_whether_support_auto_parallel()
and not self.api.endswith("_double_grad")
and not self.api.endswith("_triple_grad")
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -579,6 +579,7 @@
infer_meta :
func : KernelWithXShapeInferMeta
param : [xshape, out_grad]
spmd_rule: ReshapeGradInferSpmd
kernel :
func : reshape_grad
param : [out_grad]
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -930,6 +930,7 @@
output : Tensor(out), Tensor(xshape)
infer_meta :
func : ReshapeWithXShapeInferMeta
spmd_rule : ReshapeInferSpmdDynamic
kernel :
func : reshape
inplace : (x -> out)
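With the spmd_rule registered above, a hedged end-to-end sketch of what this PR enables (assumes a 2-GPU run; shard_tensor, Shard, and Replicate per Paddle's dynamic-mode auto parallel API):

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Global tensor of shape [4, 8], sharded along dim 0 across the 2 ranks.
x = dist.shard_tensor(paddle.rand([4, 8]), mesh, [dist.Shard(0)])

# reshape now takes the dist branch in dygraph and keeps its view semantics:
# the output's local DenseTensor shares its buffer with x's local DenseTensor.
y = paddle.reshape(x, [8, 4])
```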