From 56cd4776d3188da62435783d503be14a1cab9c47 Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Wed, 1 Nov 2023 16:57:26 +0800 Subject: [PATCH 1/7] [AutoParallel] Support vector and optional InferSPMD input and output. --- paddle/phi/api/lib/api_gen_utils.cc | 9 ++ paddle/phi/api/lib/api_gen_utils.h | 3 + paddle/phi/api/lib/data_transform.cc | 114 +++++++++++------- paddle/phi/api/lib/data_transform.h | 41 ++++--- paddle/phi/api/yaml/generator/dist_api_gen.py | 80 +++++++++--- paddle/phi/infermeta/spmd_rules/utils.h | 7 +- 6 files changed, 174 insertions(+), 80 deletions(-) diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 3f62c52eaed1c..ada755e1c9130 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -536,6 +536,15 @@ phi::distributed::DistMetaTensor MakeDistMetaTensor( return phi::distributed::DistMetaTensor(tensor); } +std::vector MakeDistMetaTensor( + const std::vector& tensors) { + std::vector out; + for (auto t : tensors) { + out.push_back(MakeDistMetaTensor(*t.impl())); + } + return out; +} + phi::distributed::DistTensor* SetKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) { if (out) { diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index a57d951ce738f..6cc6c467af27e 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -139,6 +139,9 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, phi::distributed::DistMetaTensor MakeDistMetaTensor( const phi::TensorBase& tensor); +std::vector MakeDistMetaTensor( + const std::vector& tensors); + phi::distributed::DistTensor* SetKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr = diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 1a8d92c2d9040..658a623c80fa2 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -661,6 +661,39 @@ ReshardApiInputToReplicatedKernelInput( return nullptr; } +std::vector> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs) { + PADDLE_ENFORCE_EQ(tensors.size(), + dist_attrs.size(), + phi::errors::InvalidArgument( + "Tensor's size should be equal to dist_attrs' size.")); + + std::vector> out; + for (int i = 0; i < tensors.size(); i++) { + auto tensor_in = tensors[i].impl(); + auto dist_attr = dist_attrs[i]; + if (tensor_in) { + phi::distributed::DistTensor* dist_tensor = + static_cast(tensor_in.get()); + if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) { + VLOG(6) << "Vector ApiIn to Replicated KernelIn - " + << ReshardDebugInfo(*dist_tensor, dist_attr); + auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, + dist_attr); + out.push_back(func->Eval(dev_ctx, *dist_tensor, dist_attr)); + } + out.push_back( + std::static_pointer_cast(tensor_in)); + } else { + out.push_back(nullptr); + } + } + return out; +} + paddle::optional> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, @@ -674,6 +707,20 @@ ReshardApiInputToReplicatedKernelInput( return paddle::none; } +paddle::optional>> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const paddle::optional>& tensors, + const std::vector& dist_attrs) { + if (tensors) { + VLOG(6) << "Optional ApiIn to Replicated KernelIn."; + return paddle::make_optional< + std::vector>>( + ReshardApiInputToReplicatedKernelInput(dev_ctx, 
*tensors, dist_attrs)); + } + return paddle::none; +} + void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) { if (out_tensor->dist_attr().is_partial()) { @@ -753,43 +800,14 @@ std::shared_ptr PrepareDataForDistTensor( return nullptr; } -paddle::optional> +std::vector> PrepareDataForDistTensor( - const paddle::optional>& - input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel) { - if (input) { - VLOG(6) << "PrepareDataForDistTensor for optional return transformed dist " - "tensor"; - return paddle::make_optional>( - PrepareDataForDistTensor( - *input, target_args_def, transform_flag, is_stride_kernel)); - } - return paddle::none; -} - -std::shared_ptr PrepareDataForDistTensor( - const Tensor& input, + const std::vector>& input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel) { - return PrepareDataForDistTensor( - std::static_pointer_cast(input.impl()), - target_args_def, - transform_flag, - is_stride_kernel); -} - -std::vector> -PrepareDataForDistTensor(const std::vector& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel) { std::vector> out; - for (auto& x : input) { - const auto& tensor_in = x.impl(); + for (auto tensor_in : input) { if (tensor_in) { phi::distributed::DistTensor* dist_tensor = static_cast(tensor_in.get()); @@ -825,26 +843,38 @@ PrepareDataForDistTensor(const std::vector& input, return out; } -paddle::optional PrepareDataForDistTensor( - const paddle::optional& input, +paddle::optional> +PrepareDataForDistTensor( + const paddle::optional>& + input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel) { if (input) { - return {*PrepareDataForDistTensor( - *input, target_args_def, transform_flag, is_stride_kernel)}; + VLOG(6) << "PrepareDataForDistTensor for optional return transformed dist " + "tensor"; + return paddle::make_optional>( + PrepareDataForDistTensor( + *input, target_args_def, transform_flag, is_stride_kernel)); } return paddle::none; } paddle::optional>> -PrepareDataForDistTensor(const paddle::optional>& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel) { +PrepareDataForDistTensor( + const paddle::optional< + std::vector>>& input, + const phi::TensorArgDef& target_args_def, + const TransformFlag& transform_flag, + bool is_stride_kernel) { if (input) { - return PrepareDataForDistTensor( - *input, target_args_def, transform_flag, is_stride_kernel); + VLOG(6) << "PrepareDataForDistTensor for optional vector return " + "transformed dist " + "tensor"; + return paddle::make_optional< + std::vector>>( + PrepareDataForDistTensor( + *input, target_args_def, transform_flag, is_stride_kernel)); } return paddle::none; } diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 712f568479d2e..7529924b003c7 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -186,12 +186,24 @@ ReshardApiInputToReplicatedKernelInput( const Tensor& tensor, const phi::distributed::TensorDistAttr& dist_attr); +std::vector> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attr); + paddle::optional> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const paddle::optional& tensor, 
const phi::distributed::TensorDistAttr& dist_attr); +paddle::optional>> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const paddle::optional>& tensors, + const std::vector& dist_attr); + void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor); @@ -206,37 +218,28 @@ std::shared_ptr PrepareDataForDistTensor( const TransformFlag& transform_flag, bool is_stride_kernel); -paddle::optional> +std::vector> PrepareDataForDistTensor( - const paddle::optional>& - input, + const std::vector>& input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); -std::shared_ptr PrepareDataForDistTensor( - const Tensor& input, +paddle::optional> +PrepareDataForDistTensor( + const paddle::optional>& + input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); -std::vector> -PrepareDataForDistTensor(const std::vector& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel); - -paddle::optional PrepareDataForDistTensor( - const paddle::optional& input, +paddle::optional>> +PrepareDataForDistTensor( + const paddle::optional< + std::vector>>& input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); -paddle::optional>> -PrepareDataForDistTensor(const paddle::optional>& input, - const phi::TensorArgDef& target_args_def, - const TransformFlag& transform_flag, - bool is_stride_kernel); - } // namespace experimental } // namespace paddle diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index d6c90584cb540..f0f0f6a076ab5 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -81,8 +81,12 @@ # 1. InferSPMD SINGLE_DIST_META_IN_TEMPLATE = """ auto meta_dist_input_{name} = MakeDistMetaTensor(*{name}.impl());""" +VECTOR_DIST_META_IN_TEMPLATE = """ + auto meta_dist_input_{name} = MakeDistMetaTensor({name});""" OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE = """ auto meta_dist_input_{name} = {name} ? MakeDistMetaTensor(*(*{name}).impl()) : phi::distributed::DistMetaTensor();""" +OPTIONAL_VECTOR_DIST_META_IN_TEMPLATE = """ + auto meta_dist_input_{name} = {name} ? MakeDistMetaTensor(*{name}) : std::vector(1);""" INFER_SPMD_TEMPLATE = """ auto spmd_info = phi::distributed::{}({}); """ @@ -230,9 +234,21 @@ # 5. 
Reshard Input SINGLE_INPUT_RESHARD_TEMPLATE = """ - auto dist_input_{arg} = ReshardApiInputToKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);""" + auto dist_input_{arg} = ReshardApiInputToKernelInput(dev_ctx, {arg}, spmd_info.first[idx++]);""" +VECTOR_INPUT_RESHARD_TEMPLATE = """ + auto dist_input_{arg}_vec = ReshardApiInputToKernelInput(dev_ctx, {arg}, + std::vector(spmd_info.first.begin() + idx, + spmd_info.first.begin() + idx + size)); + idx += size; + VLOG(4) << "After reshard dist_input_{arg}_vec";""" SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE = """ - auto dist_input_{arg} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);""" + auto dist_input_{arg} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, spmd_info.first[idx++]);""" +VECTOR_GENERAL_INPUT_RESHARD_TEMPLATE = """ + auto dist_input_{arg}_vec = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, + std::vector(spmd_info.first.begin() + idx, + spmd_info.first.begin() + idx + size)); + idx += size; + VLOG(4) << "After reshard dist_input_{arg}_vec";""" UNSUPPORTED_RESHARD_INPUT_COMMENT_TEMPLATE = """ // API `{}` does not need to support ReshardInput at this time """ @@ -243,11 +259,11 @@ auto input_{arg} = &dist_input_{arg}->value(); """ SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ - auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); + dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); auto input_{arg} = &dist_input_{arg}->value(); """ VECTOR_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}_vec, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); std::vector dense_input_{name}_vec; for (auto tmp : dist_input_{name}_vec) {{ dense_input_{name}_vec.emplace_back(&tmp->value()); @@ -263,11 +279,11 @@ paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional((*dist_input_{name})->value()) : paddle::none; """ OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ - auto dist_input_{name} = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); paddle::optional input_{name} = dist_input_{name} ? 
paddle::make_optional(dist_input_{name}->value()) : paddle::none; """ OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name}_vec = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}_vec, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); std::vector dense_input_{name}_vec; if ({name}) {{ for (auto tmp : *dist_input_{name}_vec) {{ @@ -747,14 +763,19 @@ def generate_general_infer_spmd_code(self) -> str: elif ( self.inputs['input_info'][param] == "const std::vector&" - or self.inputs['input_info'][param] + ): + input_decl_code += VECTOR_DIST_META_IN_TEMPLATE.format( + name=param + ) + input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] == "const paddle::optional>&" ): - # TODO(chenweihang): support other input type later, - # now only support single tensor input api - input_decl_code = "" - input_args_code = "" - break + input_decl_code += ( + OPTIONAL_VECTOR_DIST_META_IN_TEMPLATE.format(name=param) + ) + input_args_code += "meta_dist_input_" + param + ", " else: raise ValueError( f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported." @@ -996,6 +1017,7 @@ def generate_reshard_input_code(self) -> str: else input_names ) + input_reshard_code = "\n int idx = 0, size = 0;" for i, param in enumerate(kernel_params): if param in input_names: if ( @@ -1015,10 +1037,36 @@ def generate_reshard_input_code(self) -> str: arg=param, idx=i ) ) - else: - raise ValueError( - f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported." - ) + elif ( + self.inputs['input_info'][param] + == "const std::vector&" + or self.inputs['input_info'][param] + == "const paddle::optional>&" + ): + if ( + self.inputs['input_info'][param] + == "const std::vector&" + ): + input_reshard_code += ( + f"\n size = {param}.size();" + ) + else: + input_reshard_code += f"\n size = {param} ? (*{param}).size() : 1;" + + if self.generate_general_infer_spmd is True: + input_reshard_code += ( + VECTOR_GENERAL_INPUT_RESHARD_TEMPLATE.format( + arg=param + ) + ) + else: + input_reshard_code += ( + VECTOR_INPUT_RESHARD_TEMPLATE.format(arg=param) + ) + pass + # raise ValueError( + # f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported." + # ) else: # do nothing pass diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h index cd16a95bceac7..88d8e1cc2a861 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.h +++ b/paddle/phi/infermeta/spmd_rules/utils.h @@ -106,9 +106,10 @@ struct VariadicSpmdRuleArgumentParser // deal with inputs void operator()(const DistMetaTensor& x) { inputs.emplace_back(&x); } - void operator()(const std::vector& x) { - for (auto t : x) { - inputs.emplace_back(t); + void operator()(const std::vector& x) { + for (int i = 0; i < x.size(); i++) { + std::cout << "i: " << i << std::endl; + inputs.emplace_back(&x[i]); } } From be41f748ca32ce4f7a48a66973c732673fa26cbd Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Wed, 1 Nov 2023 17:22:37 +0800 Subject: [PATCH 2/7] Fix some problems. 
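
Fixes the misspelled ReshardApiInputToReplicatedKernelInput declaration in
paddle/phi/api/lib/data_transform.h (see the one-line diff below).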
--- paddle/phi/api/lib/data_transform.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 10a4318d0f6b9..7529924b003c7 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -199,7 +199,7 @@ ReshardApiInputToReplicatedKernelInput( const phi::distributed::TensorDistAttr& dist_attr); paddle::optional>> -ReshardApiInputToReplicatedKerneg'ilInput( +ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const paddle::optional>& tensors, const std::vector& dist_attr); From 4f08ef04934b24974486ebf7f9636c0afabca78a Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Thu, 2 Nov 2023 17:52:17 +0800 Subject: [PATCH 3/7] Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into support_vector_inferspmd --- .../framework/ir/trt_support_nhwc_pass.cc | 27 +- .../ir_adaptor/translator/op_translator.cc | 454 +++++------------ paddle/fluid/pybind/auto_parallel_py.cc | 9 +- paddle/phi/api/lib/api_custom_impl.cc | 19 +- paddle/phi/api/lib/api_gen_utils.cc | 38 +- paddle/phi/api/lib/api_gen_utils.h | 11 +- paddle/phi/api/lib/data_transform.cc | 136 ++++- paddle/phi/api/lib/data_transform.h | 41 +- paddle/phi/api/lib/tensor_method.cc | 28 +- paddle/phi/api/yaml/generator/dist_api_gen.py | 166 +++--- paddle/phi/api/yaml/ops.yaml | 1 + .../distributed/auto_parallel/dist_attr.cc | 2 +- .../distributed/auto_parallel/dist_attr.h | 2 + .../auto_parallel/inferspmd_utils.h | 22 + paddle/phi/core/distributed/type_defs.h | 8 +- paddle/phi/infermeta/spmd_rules/concat.cc | 187 +++++++ paddle/phi/infermeta/spmd_rules/concat.h | 34 ++ .../spmd_rules/default_data_parallel.cc | 6 +- paddle/phi/infermeta/spmd_rules/layer_norm.cc | 2 +- paddle/phi/infermeta/spmd_rules/matmul.cc | 15 +- paddle/phi/infermeta/spmd_rules/replicated.cc | 51 +- paddle/phi/infermeta/spmd_rules/replicated.h | 13 + paddle/phi/infermeta/spmd_rules/rules.h | 5 + paddle/phi/infermeta/spmd_rules/split.cc | 11 +- paddle/phi/infermeta/spmd_rules/utils.cc | 93 ++++ paddle/phi/infermeta/spmd_rules/utils.h | 48 +- .../executor/executor_cache.py | 3 + .../executor/function_graph.py | 5 +- .../executor/opcode_executor.py | 1 + .../paddle/jit/sot/symbolic/statement_ir.py | 20 +- python/paddle/tensor/creation.py | 18 +- .../semi_auto_parallel_for_concat.py | 62 +++ .../semi_auto_parallel_for_matmul.py | 2 +- test/auto_parallel/semi_auto_parallel_util.py | 133 +++++ test/auto_parallel/spmd_rules/CMakeLists.txt | 1 + .../spmd_rules/test_concat_rule.py | 58 +++ .../test_semi_auto_parallel_basic.py | 10 + test/cpp/auto_parallel/spmd_rule_test.cc | 481 ++++++++++-------- .../inference/test_trt_support_nhwc_pass.py | 45 +- test/legacy_test/test_complex_op.py | 7 +- test/legacy_test/test_logspace.py | 7 +- test/legacy_test/test_meshgrid_op.py | 79 +-- test/legacy_test/test_sgd_op.py | 74 ++- test/sot/test_simulate_initialize.py | 20 + 44 files changed, 1708 insertions(+), 747 deletions(-) create mode 100644 paddle/phi/infermeta/spmd_rules/concat.cc create mode 100644 paddle/phi/infermeta/spmd_rules/concat.h create mode 100644 test/auto_parallel/semi_auto_parallel_for_concat.py create mode 100644 test/auto_parallel/semi_auto_parallel_util.py create mode 100644 test/auto_parallel/spmd_rules/test_concat_rule.py diff --git a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc index 0403330f77cd1..5a086acd7cac2 100644 --- 
a/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc +++ b/paddle/fluid/framework/ir/trt_support_nhwc_pass.cc @@ -126,6 +126,26 @@ bool ModelLayoutIsNHWC(const std::vector &op_nodes) { return false; } +// Do additional check if OP's weight is not persistable +typedef std::string OP_NAME; +typedef std::string WEIGHT_NAME; +typedef std::unordered_map OP_WEIGHT_NAME; +bool IsWeight(ir::Node *op_node, + ir::Node *var_node, + const OP_WEIGHT_NAME &op_weight_pair) { + if (var_node->Var()->Persistable()) return true; + auto *op_desc = op_node->Op(); + std::string op_type = op_desc->Type(); + std::string var_name = var_node->Var()->Name(); + if (op_weight_pair.count(op_type)) { + if (var_name == + op_desc->Input(op_weight_pair.find(op_type)->second).front()) { + return true; + } + } + return false; +} + } // namespace void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { @@ -155,6 +175,9 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { "bilinear_interp_v2", "nearest_interp", "nearest_interp_v2"}; + // Op's weight could be temporary variable, so we save the name of OP's weight + // input + OP_WEIGHT_NAME op_weight_pair{{"conv2d", "Filter"}}; // Ops must run under the original layout even though it has // data_format/data_layout attribute, otherwise it will be very troublesome! std::unordered_set must_original_layout_ops{ @@ -193,7 +216,7 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { auto op_inputs = op_node->inputs; for (auto *in_var_node : op_inputs) { CHECK_EQ(in_var_node->IsVar(), true); - if (in_var_node->Var()->Persistable()) continue; + if (IsWeight(op_node, in_var_node, op_weight_pair)) continue; auto input_shape = in_var_node->Var()->GetShape(); input_shape_4 &= (input_shape.size() == 4); @@ -326,7 +349,7 @@ void TrtSupportNHWCPass::ApplyImpl(Graph *graph) const { for (auto *in_var_node : op_inputs) { CHECK_EQ(in_var_node->IsVar(), true); - if (in_var_node->Var()->Persistable()) continue; + if (IsWeight(op_node, in_var_node, op_weight_pair)) continue; if (vars_to_nchw.count(in_var_node)) continue; DoInsertTransposeOp(graph, diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index b0af5d58b3d2e..4b0407d257b1e 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -1247,10 +1247,46 @@ struct TrilAndTriuOpTranscriber : public OpTranscriber { } }; +using ValueInfo = + std::tuple, dialect::DenseTensorType, pir::OpResult>; + +ValueInfo GetTensorInfoByVarName(const OpDesc& op_desc, + const std::vector& names, + TranslationContext* param_map, + const std::string& var_name) { + IR_ENFORCE(names.size() == 1, + "Expected op[%s]'s input %s has only 1 variable, but got %d", + op_desc.Type(), + var_name, + names.size()); + const auto& name = names[0]; + IR_ENFORCE(param_map->count(name) > 0, + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + name); + const auto& defining_info = param_map->at(name); + + pir::OpResult value = defining_info.value.dyn_cast(); + IR_ENFORCE( + value, "Expected op[%s]'s input %s is not null", op_desc.Type(), name); + const pir::Type& type = value.type(); + IR_ENFORCE(type.isa(), + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + name, + type); + dialect::DenseTensorType tensor_type = + type.dyn_cast(); + + std::vector shape = phi::vectorize(tensor_type.dims()); + + return std::make_tuple(shape, tensor_type, value); +} + struct MulOpTranscriber : public OpTranscriber { 
pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { - const std::string& target_op_name = "pd_op.matmul"; + const std::string& target_op_name = paddle::dialect::MatmulOp::name(); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { IR_THROW("Op %d should have corresponding OpInfo %d", @@ -1280,39 +1316,16 @@ struct MulOpTranscriber : public OpTranscriber { const std::string& normalized_op_name, const OpInputInfoList& input_infos, pir::Block* block) override { - int x_num_col_dims = paddle::get(op_desc.GetAttr("x_num_col_dims")); - int y_num_col_dims = paddle::get(op_desc.GetAttr("y_num_col_dims")); + const int x_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("x_num_col_dims")); + const int y_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("y_num_col_dims")); + + ValueInfo x_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("X", true), param_map, "X"); + + const auto& [x_shape, x_tensor_type, x_value] = x_info; - auto x_names = op_desc.Input("X", true); - IR_ENFORCE(x_names.size() == 1, - "Expected op[%s]'s input X has only 1 variable, but got %d", - op_desc.Type(), - x_names.size()); - auto x_name = x_names[0]; - IR_ENFORCE(param_map->count(x_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - x_name); - auto x_defining_info = param_map->at(x_name); - if (x_defining_info.generated_by_vector) { - InsertSliceOperationForTarget( - ctx, param_map, block, x_defining_info, x_name); - x_defining_info = param_map->at(x_name); - } - pir::OpResult x_value = x_defining_info.value.dyn_cast(); - IR_ENFORCE(x_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - x_name); - pir::Type x_type = x_value.type(); - IR_ENFORCE(x_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - x_name, - x_type); - dialect::DenseTensorType x_tensor_type = - x_type.dyn_cast(); - std::vector x_shape = phi::vectorize(x_tensor_type.dims()); IR_ENFORCE(x_num_col_dims <= static_cast(x_shape.size()), "Expected op[%s]'s attr `x_num_col_dims` less than or equal to " "dim of input X %s, but got %d", @@ -1320,36 +1333,11 @@ struct MulOpTranscriber : public OpTranscriber { x_shape.size(), x_num_col_dims); - auto y_names = op_desc.Input("Y", true); - IR_ENFORCE(y_names.size() == 1, - "Expected op[%s]'s input Y has only 1 variable, but got %d", - op_desc.Type(), - y_names.size()); - auto y_name = y_names[0]; - IR_ENFORCE(param_map->count(y_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - y_name); - auto y_defining_info = param_map->at(y_name); - if (y_defining_info.generated_by_vector) { - InsertSliceOperationForTarget( - ctx, param_map, block, y_defining_info, y_name); - y_defining_info = param_map->at(y_name); - } - pir::OpResult y_value = y_defining_info.value.dyn_cast(); - IR_ENFORCE(y_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - y_name); - pir::Type y_type = y_value.type(); - IR_ENFORCE(y_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - y_name, - y_type); - dialect::DenseTensorType y_tensor_type = - y_type.dyn_cast(); - std::vector y_shape = phi::vectorize(y_tensor_type.dims()); + ValueInfo y_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("Y", true), param_map, "Y"); + + const auto& [y_shape, y_tensor_type, y_value] = y_info; + IR_ENFORCE(y_num_col_dims <= static_cast(y_shape.size()), "Expected op[%s]'s attr `y_num_col_dims` less than or equal to " "dim of input Y %s, but 
got %d", @@ -1406,102 +1394,29 @@ struct MulOpTranscriber : public OpTranscriber { OpTranscriber::RecordOpResultMapping( ctx, param_map, op_desc, operation, arg_to_idx); if (op_desc.HasOutput("Out")) { + ValueInfo out_info = GetTensorInfoByVarName( + op_desc, op_desc.Output("Out"), param_map, "Out"); + + const dialect::DenseTensorType& out_tensor_type = std::get<1>(out_info); + pir::OpResult& out_value = std::get<2>(out_info); + const auto& output_vars = op_desc.Output("Out"); - IR_ENFORCE(output_vars.size() == 1, - "Expected op[%s]'s Out has only 1 var but got %s", - op_desc.Type(), - output_vars.size()); - auto output_name = output_vars[0]; - auto out_defining_info = param_map->at(output_name); - - if (out_defining_info.generated_by_vector) { - InsertSliceOperationForTarget(ctx, - param_map, - operation->GetParent(), - out_defining_info, - output_name); - out_defining_info = param_map->at(output_name); - } + const auto& output_name = output_vars[0]; - pir::OpResult out_value = - out_defining_info.value.dyn_cast(); - IR_ENFORCE(out_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - output_name); - pir::Type out_type = out_value.type(); - IR_ENFORCE(out_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - output_name, - out_type); - dialect::DenseTensorType out_tensor_type = - out_type.dyn_cast(); + const int x_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("x_num_col_dims")); + const int y_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("y_num_col_dims")); - int x_num_col_dims = paddle::get(op_desc.GetAttr("x_num_col_dims")); - int y_num_col_dims = paddle::get(op_desc.GetAttr("y_num_col_dims")); + ValueInfo x_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("X", true), param_map, "X"); - auto x_names = op_desc.Input("X", true); - IR_ENFORCE(x_names.size() == 1, - "Expected op[%s]'s input X has only 1 variable, but got %d", - op_desc.Type(), - x_names.size()); - auto x_name = x_names[0]; - IR_ENFORCE(param_map->count(x_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - x_name); - auto x_defining_info = param_map->at(x_name); - if (x_defining_info.generated_by_vector) { - InsertSliceOperationForTarget( - ctx, param_map, operation->GetParent(), x_defining_info, x_name); - x_defining_info = param_map->at(x_name); - } - pir::OpResult x_value = x_defining_info.value.dyn_cast(); - IR_ENFORCE(x_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - x_name); - pir::Type x_type = x_value.type(); - IR_ENFORCE(x_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - x_name, - x_type); - dialect::DenseTensorType x_tensor_type = - x_type.dyn_cast(); - std::vector x_shape = phi::vectorize(x_tensor_type.dims()); - - auto y_names = op_desc.Input("Y", true); - IR_ENFORCE(y_names.size() == 1, - "Expected op[%s]'s input Y has only 1 variable, but got %d", - op_desc.Type(), - y_names.size()); - auto y_name = y_names[0]; - IR_ENFORCE(param_map->count(y_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - y_name); - auto y_defining_info = param_map->at(y_name); - if (y_defining_info.generated_by_vector) { - InsertSliceOperationForTarget( - ctx, param_map, operation->GetParent(), y_defining_info, y_name); - y_defining_info = param_map->at(y_name); - } - pir::OpResult y_value = y_defining_info.value.dyn_cast(); - IR_ENFORCE(y_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - y_name); - pir::Type y_type = 
y_value.type(); - IR_ENFORCE(y_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - y_name, - y_type); - dialect::DenseTensorType y_tensor_type = - y_type.dyn_cast(); - std::vector y_shape = phi::vectorize(y_tensor_type.dims()); + const std::vector& x_shape = std::get<0>(x_info); + + ValueInfo y_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("Y", true), param_map, "Y"); + + const std::vector& y_shape = std::get<0>(y_info); std::vector out_new_shape(x_shape.begin(), x_shape.begin() + x_num_col_dims); @@ -1525,10 +1440,10 @@ struct MulOpTranscriber : public OpTranscriber { struct MulGradOpTranscriber : public OpTranscriber { pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { - const std::string& target_op_name = "pd_op.matmul_grad"; + const std::string& target_op_name = paddle::dialect::MatmulGradOp::name(); VLOG(6) << "[op name normalizing: " << op_desc.Type() << " to " << target_op_name; - auto op_info = ctx->GetRegisteredOpInfo(target_op_name); + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { IR_THROW("Op %d should have corresponding OpInfo %d", op_desc.Type(), @@ -1557,39 +1472,16 @@ struct MulGradOpTranscriber : public OpTranscriber { const std::string& normalized_op_name, const OpInputInfoList& input_infos, pir::Block* block) override { - int x_num_col_dims = paddle::get(op_desc.GetAttr("x_num_col_dims")); - int y_num_col_dims = paddle::get(op_desc.GetAttr("y_num_col_dims")); + const int x_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("x_num_col_dims")); + const int y_num_col_dims = + PADDLE_GET_CONST(int, op_desc.GetAttr("y_num_col_dims")); + + ValueInfo x_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("X", true), param_map, "X"); + + const auto& [x_shape, x_tensor_type, x_value] = x_info; - auto x_names = op_desc.Input("X", true); - IR_ENFORCE(x_names.size() == 1, - "Expected op[%s]'s input X has only 1 variable, but got %d", - op_desc.Type(), - x_names.size()); - auto x_name = x_names[0]; - IR_ENFORCE(param_map->count(x_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - x_name); - auto x_defining_info = param_map->at(x_name); - if (x_defining_info.generated_by_vector) { - InsertSliceOperationForTarget( - ctx, param_map, block, x_defining_info, x_name); - x_defining_info = param_map->at(x_name); - } - pir::OpResult x_value = x_defining_info.value.dyn_cast(); - IR_ENFORCE(x_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - x_name); - pir::Type x_type = x_value.type(); - IR_ENFORCE(x_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - x_name, - x_type); - dialect::DenseTensorType x_tensor_type = - x_type.dyn_cast(); - std::vector x_shape = phi::vectorize(x_tensor_type.dims()); IR_ENFORCE(x_num_col_dims <= static_cast(x_shape.size()), "Expected op[%s]'s attr `x_num_col_dims` less than or equal to " "dim of input X %s, but got %d", @@ -1597,36 +1489,11 @@ struct MulGradOpTranscriber : public OpTranscriber { x_shape.size(), x_num_col_dims); - auto y_names = op_desc.Input("Y", true); - IR_ENFORCE(y_names.size() == 1, - "Expected op[%s]'s input Y has only 1 variable, but got %d", - op_desc.Type(), - y_names.size()); - auto y_name = y_names[0]; - IR_ENFORCE(param_map->count(y_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - y_name); - auto y_defining_info = param_map->at(y_name); - if (y_defining_info.generated_by_vector) { - 
InsertSliceOperationForTarget( - ctx, param_map, block, y_defining_info, y_name); - y_defining_info = param_map->at(y_name); - } - pir::OpResult y_value = y_defining_info.value.dyn_cast(); - IR_ENFORCE(y_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - y_name); - pir::Type y_type = y_value.type(); - IR_ENFORCE(y_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - y_name, - y_type); - dialect::DenseTensorType y_tensor_type = - y_type.dyn_cast(); - std::vector y_shape = phi::vectorize(y_tensor_type.dims()); + ValueInfo y_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("Y", true), param_map, "Y"); + + const auto& [y_shape, y_tensor_type, y_value] = y_info; + IR_ENFORCE(y_num_col_dims <= static_cast(y_shape.size()), "Expected op[%s]'s attr `y_num_col_dims` less than or equal to " "dim of input Y %s, but got %d", @@ -1634,38 +1501,12 @@ struct MulGradOpTranscriber : public OpTranscriber { y_shape.size(), y_num_col_dims); - auto out_grad_names = op_desc.Input("Out@GRAD", true); - IR_ENFORCE(out_grad_names.size() == 1, - "Expected op[%s]'s input X has only 1 variable, but got %d", - op_desc.Type(), - out_grad_names.size()); - auto out_grad_name = out_grad_names[0]; - IR_ENFORCE(param_map->count(out_grad_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - out_grad_name); - auto out_grad_defining_info = param_map->at(out_grad_name); - if (out_grad_defining_info.generated_by_vector) { - InsertSliceOperationForTarget( - ctx, param_map, block, out_grad_defining_info, out_grad_name); - out_grad_defining_info = param_map->at(out_grad_name); - } - pir::OpResult out_grad_value = - out_grad_defining_info.value.dyn_cast(); - IR_ENFORCE(out_grad_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - out_grad_name); - pir::Type out_grad_type = out_grad_value.type(); - IR_ENFORCE(out_grad_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - out_grad_name, - out_grad_type); - dialect::DenseTensorType out_grad_tensor_type = - out_grad_type.dyn_cast(); - std::vector out_grad_shape = - phi::vectorize(out_grad_tensor_type.dims()); + ValueInfo out_grad_info = GetTensorInfoByVarName( + op_desc, op_desc.Input("Out@GRAD", true), param_map, "Out@GRAD"); + + const dialect::DenseTensorType& out_grad_tensor_type = + std::get<1>(out_grad_info); + pir::OpResult& out_grad_value = std::get<2>(out_grad_info); pir::Builder builder(ctx, block); @@ -1734,96 +1575,63 @@ struct MulGradOpTranscriber : public OpTranscriber { pir::Builder builder(ctx, operation->GetParent()); - if (x_grad_output.size()) { - IR_ENFORCE( - x_grad_output.size() == 1, - "Expected op[%s]'s output X@GRAD has only 1 variable, but got %d", - op_desc.Type(), - x_grad_output.size()); - const auto& x_grad_var_name = x_grad_output[0]; - - auto idx_iter_x = arg_to_idx.find(x_grad_var_name); - if (idx_iter_x == arg_to_idx.end()) { - IR_THROW("op[%s] should have got its x_grad", op_desc.Type()); + auto gradReshape = [&](const std::string& var_name) { + const auto& grad_output = op_desc.Output(var_name); + IR_ENFORCE(grad_output.size() == 1, + "Expected op[%s]'s output %s has only 1 variable, but got %d", + op_desc.Type(), + var_name, + grad_output.size()); + const auto& grad_var_name = grad_output[0]; + + auto idx_iter = arg_to_idx.find(grad_var_name); + if (idx_iter == arg_to_idx.end()) { + IR_THROW("op[%s] should have got its %s", op_desc.Type(), var_name); } - auto [idx_in_op_x, idx_in_vec_x] = idx_iter_x->second; + 
auto [idx_in_op, idx_in_vec] = idx_iter->second; VLOG(10) << "[output recording]" - << "[" << op_desc.Type() << "]" << x_grad_var_name << " " - << idx_in_op_x << " " << idx_in_vec_x; + << "[" << op_desc.Type() << "]" << grad_var_name << " " + << idx_in_op << " " << idx_in_vec; - VarDesc* var_desc_x = op_desc.Block()->FindVarRecursive("X"); - std::vector x_shape = var_desc_x->GetShape(); - DenseTensorTypeStorage::Dim dim_x = phi::make_ddim(x_shape); + VarDesc* var_desc = + op_desc.Block()->FindVarRecursive(var_name.substr(0, 1)); + std::vector shape = var_desc->GetShape(); + DenseTensorTypeStorage::Dim dim = phi::make_ddim(shape); - pir::OpResult x_value_res = operation->result(idx_in_op_x); - auto reshape_op_x = - builder.Build(x_value_res, x_shape); + pir::OpResult value_res = operation->result(idx_in_op); + auto reshape_op = builder.Build(value_res, shape); - IR_ENFORCE(x_value_res, + IR_ENFORCE(value_res, "Expected op[%s]'s input %s is not null", op_desc.Type(), - x_grad_var_name); - pir::Type x_grad_type = x_value_res.type(); - IR_ENFORCE(x_grad_type.isa(), + grad_var_name); + pir::Type grad_type = value_res.type(); + IR_ENFORCE(grad_type.isa(), "Expected op[%s]'s input %s is DenseTensor but got %s", op_desc.Type(), - x_grad_var_name, - x_grad_type); - dialect::DenseTensorType x_grad_tensor_type = - x_grad_type.dyn_cast(); + grad_var_name, + grad_type); + dialect::DenseTensorType grad_tensor_type = + grad_type.dyn_cast(); - VLOG(10) << "[" << op_desc.Type() << "] x_grad_shape change from " - << x_grad_tensor_type.dims() << " to " << dim_x; + VLOG(10) << "[" << op_desc.Type() << "] shape of " << var_name + << " change from " << grad_tensor_type.dims() << " to " << dim; + + param_map->PushValue(grad_var_name, + VariableDefiningInfo(reshape_op.out(), false, -1)); + }; - param_map->PushValue(x_grad_var_name, - VariableDefiningInfo(reshape_op_x.out(), false, -1)); + if (x_grad_output.size()) { + gradReshape("X@GRAD"); } if (y_grad_output.size() < 1) { return; } - IR_ENFORCE( - y_grad_output.size() == 1, - "Expected op[%s]'s output Y@GRAD has only 1 variable, but got %d", - op_desc.Type(), - y_grad_output.size()); - const auto& y_grad_var_name = y_grad_output[0]; - auto idx_iter_y = arg_to_idx.find(y_grad_var_name); - if (idx_iter_y == arg_to_idx.end()) { - IR_THROW("op[%s] should have got its y_grad", op_desc.Type()); + if (y_grad_output.size()) { + gradReshape("Y@GRAD"); } - auto [idx_in_op_y, idx_in_vec_y] = idx_iter_y->second; - VLOG(10) << "[output recording]" - << "[" << op_desc.Type() << "]" << y_grad_var_name << " " - << idx_in_op_y << " " << idx_in_vec_y; - - VarDesc* var_desc_y = op_desc.Block()->FindVarRecursive("Y"); - std::vector y_shape = var_desc_y->GetShape(); - DenseTensorTypeStorage::Dim dim_y = phi::make_ddim(y_shape); - - pir::OpResult y_value_res = operation->result(idx_in_op_y); - - auto reshape_op_y = builder.Build(y_value_res, y_shape); - - IR_ENFORCE(y_value_res, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - y_grad_var_name); - pir::Type y_grad_type = y_value_res.type(); - IR_ENFORCE(y_grad_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - y_grad_var_name, - y_grad_type); - dialect::DenseTensorType y_grad_tensor_type = - y_grad_type.dyn_cast(); - - VLOG(10) << "[" << op_desc.Type() << "] y_grad_shape change from " - << y_grad_tensor_type.dims() << " to " << dim_y; - - param_map->PushValue(y_grad_var_name, - VariableDefiningInfo(reshape_op_y.out(), false, -1)); } }; diff --git 
a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 785e80a3abeaa..ac02cd0fc87ac 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -74,6 +74,7 @@ using paddle::distributed::auto_parallel::SPMDRuleMap; using paddle::framework::BlockDesc; using paddle::framework::OpDesc; using paddle::framework::VarDesc; +using phi::distributed::ArgDistAttr; using phi::distributed::ProcessMesh; using phi::distributed::TensorDistAttr; using phi::distributed::auto_parallel::Device; @@ -143,9 +144,9 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) { dist_attr->clear_annotated(); } -static std::pair, std::vector> +static std::pair, std::vector> infer_forward(const phi::distributed::SpmdRule &self, const py::args &args); -static std::pair, std::vector> +static std::pair, std::vector> infer_backward(const phi::distributed::SpmdRule &self, const py::args &args); void BindAutoParallel(py::module *m) { @@ -703,7 +704,7 @@ static void prepare_ctx(phi::distributed::InferSpmdContext *ctx, parse_single_pyobject(obj, ctx, i); } } -static std::pair, std::vector> +static std::pair, std::vector> infer_forward(const phi::distributed::SpmdRule &self, const py::args &args) { VLOG(6) << "infer_forward "; phi::distributed::InferSpmdContext ctx; @@ -711,7 +712,7 @@ infer_forward(const phi::distributed::SpmdRule &self, const py::args &args) { return self.InferForward(ctx); } -static std::pair, std::vector> +static std::pair, std::vector> infer_backward(const phi::distributed::SpmdRule &self, const py::args &args) { VLOG(6) << "infer_backward "; phi::distributed::InferSpmdContext ctx; diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index acb45d058038e..8c82df3e83969 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -117,9 +117,13 @@ Tensor add_n_impl(const std::vector& x) { input_x[i] = x[i].impl().get(); } - auto meta_dist_input_x = MakeDistMetaTensor(input_x); - auto spmd_info = - phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x); + // auto meta_dist_input_x = MakeDistMetaTensor(input_x); + std::vector meta_dist_input_x; + for (auto& e : input_x) { + meta_dist_input_x.push_back(MakeDistMetaTensor(*e)); + } + auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic( + meta_dist_input_x); auto dist_out = SetKernelDistOutput(&api_output); auto dense_out = dist_out->unsafe_mutable_value(); @@ -165,7 +169,14 @@ Tensor add_n_impl(const std::vector& x) { auto* kernel_fn = kernel.GetVariadicKernelFn(); (*kernel_fn)(*dev_ctx, input_x, dense_out); } - auto current_process_mesh = spmd_info.first[0].process_mesh(); + PADDLE_ENFORCE_EQ( + paddle::holds_alternative( + spmd_info.first[0]), + true, + phi::errors::PreconditionNotMet( + "Arg must be a single TensorDistAttr")); + auto current_process_mesh = + paddle::get<0>(spmd_info.first[0]).process_mesh(); SetReplicatedDistAttrForOutput(dist_out, current_process_mesh); return api_output; } diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index ada755e1c9130..25371c3ec4ca7 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -536,14 +536,14 @@ phi::distributed::DistMetaTensor MakeDistMetaTensor( return phi::distributed::DistMetaTensor(tensor); } -std::vector MakeDistMetaTensor( - const std::vector& tensors) { - std::vector out; - for (auto t : tensors) { - 
out.push_back(MakeDistMetaTensor(*t.impl())); - } - return out; -} +// std::vector MakeDistMetaTensor( +// const std::vector& tensors) { +// std::vector out; +// for (auto t : tensors) { +// out.push_back(MakeDistMetaTensor(*t.impl())); +// } +// return out; +// } phi::distributed::DistTensor* SetKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) { @@ -558,6 +558,15 @@ phi::distributed::DistTensor* SetKernelDistOutput( return nullptr; } +phi::distributed::DistTensor* SetKernelDistOutput( + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr")); + return SetKernelDistOutput(out, paddle::get<0>(dist_attr)); +} + std::shared_ptr CreateKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) { if (out) { @@ -567,6 +576,19 @@ std::shared_ptr CreateKernelDistOutput( return nullptr; } +std::shared_ptr CreateKernelDistOutput( + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr) { + if (out) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr")); + return std::make_shared( + phi::DDim(), paddle::get<0>(dist_attr)); + } + return nullptr; +} + std::vector SetKernelDistOutput( std::vector out) { std::vector result; diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index 6cc6c467af27e..378be88824067 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" @@ -139,19 +140,25 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, phi::distributed::DistMetaTensor MakeDistMetaTensor( const phi::TensorBase& tensor); -std::vector MakeDistMetaTensor( - const std::vector& tensors); +// std::vector MakeDistMetaTensor( +// const std::vector& tensors); phi::distributed::DistTensor* SetKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr = phi::distributed::TensorDistAttr()); +phi::distributed::DistTensor* SetKernelDistOutput( + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr); + std::shared_ptr CreateKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr = phi::distributed::TensorDistAttr()); +std::shared_ptr CreateKernelDistOutput( + Tensor* out, const phi::distributed::ArgDistAttr& dist_attr); + std::vector SetKernelDistOutput( std::vector out); diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 658a623c80fa2..6e1a6416c8a53 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -640,21 +640,114 @@ std::shared_ptr ReshardApiInputToKernelInput( return nullptr; } +std::shared_ptr ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const Tensor& tensor, + const phi::distributed::ArgDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr")); + const auto& 
tensor_dist_attr = paddle::get<0>(dist_attr); + return ReshardApiInputToKernelInput(dev_ctx, tensor, tensor_dist_attr); +} + +std::vector> +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const phi::distributed::ArgDistAttr& dist_attrs) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative>( + dist_attrs), + true, + phi::errors::PreconditionNotMet( + "Arg must be a vector of TensorDistAttr")); + const auto& tensor_dist_attrs = paddle::get<1>(dist_attrs); + return ReshardApiInputToKernelInput(dev_ctx, tensors, tensor_dist_attrs); +} + +std::vector> +ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs) { + std::vector> output; + PADDLE_ENFORCE_EQ(tensors.size(), + dist_attrs.size(), + phi::errors::PreconditionNotMet( + "tensors size and dist_attrs size not equal: %d vs %d", + tensors.size(), + dist_attrs.size())); + for (size_t i = 0; i < dist_attrs.size(); i++) { + output.push_back( + ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i])); + } + return output; +} + +std::vector> +ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs) { + std::vector> output; + PADDLE_ENFORCE_EQ(tensors.size(), + dist_attrs.size(), + phi::errors::PreconditionNotMet( + "tensors size and dist_attrs size not equal: %d vs %d", + tensors.size(), + dist_attrs.size())); + for (size_t i = 0; i < dist_attrs.size(); i++) { + output.push_back( + ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i])); + } + return output; +} + +// std::shared_ptr +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const Tensor& tensor, +// const phi::distributed::ArgDistAttr& dist_attr) { +// auto tensor_in = tensor.impl(); +// if (tensor_in) { +// phi::distributed::DistTensor* dist_tensor = +// static_cast(tensor_in.get()); +// if (ReshardIsNeeded(dist_tensor->dist_attr(), paddle::get<0>(dist_attr))) +// { +// VLOG(6) << "ApiIn to Replicated KernelIn - " +// << ReshardDebugInfo(*dist_tensor, paddle::get<0>(dist_attr)); +// auto* func = +// phi::distributed::ChooseProperReshardFunction(*dist_tensor, +// paddle::get<0>(dist_attr)); +// return func->Eval(dev_ctx, *dist_tensor, dist_attr); +// } +// return std::static_pointer_cast(tensor_in); +// } +// return nullptr; +// } + std::shared_ptr ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const Tensor& tensor, - const phi::distributed::TensorDistAttr& dist_attr) { + const phi::distributed::ArgDistAttr& dist_attr) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative(dist_attr), + true, + phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); + auto tensor_in = tensor.impl(); + const auto& tensor_dist_attr = paddle::get<0>(dist_attr); if (tensor_in) { phi::distributed::DistTensor* dist_tensor = static_cast(tensor_in.get()); - if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) { + if (ReshardIsNeeded(dist_tensor->dist_attr(), tensor_dist_attr)) { VLOG(6) << "ApiIn to Replicated KernelIn - " - << ReshardDebugInfo(*dist_tensor, dist_attr); - auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, - dist_attr); - return func->Eval(dev_ctx, *dist_tensor, dist_attr); + << ReshardDebugInfo(*dist_tensor, tensor_dist_attr); + auto* func = phi::distributed::ChooseProperReshardFunction( + *dist_tensor, tensor_dist_attr); + return func->Eval(dev_ctx, *dist_tensor, tensor_dist_attr); } return 
std::static_pointer_cast(tensor_in); } @@ -665,16 +758,24 @@ std::vector> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const std::vector& tensors, - const std::vector& dist_attrs) { + const phi::distributed::ArgDistAttr& dist_attrs) { + PADDLE_ENFORCE_EQ( + paddle::holds_alternative>( + dist_attrs), + true, + phi::errors::PreconditionNotMet( + "Arg must be a vector of TensorDistAttr")); + const auto& tensor_dist_attrs = paddle::get<1>(dist_attrs); + PADDLE_ENFORCE_EQ(tensors.size(), - dist_attrs.size(), + tensor_dist_attrs.size(), phi::errors::InvalidArgument( "Tensor's size should be equal to dist_attrs' size.")); std::vector> out; for (int i = 0; i < tensors.size(); i++) { auto tensor_in = tensors[i].impl(); - auto dist_attr = dist_attrs[i]; + auto dist_attr = tensor_dist_attrs[i]; if (tensor_in) { phi::distributed::DistTensor* dist_tensor = static_cast(tensor_in.get()); @@ -694,11 +795,24 @@ ReshardApiInputToReplicatedKernelInput( return out; } +std::vector> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs) { + std::vector> outputs; + for (size_t i = 0; i < tensors.size(); ++i) { + outputs.push_back(ReshardApiInputToReplicatedKernelInput( + dev_ctx, tensors[i], dist_attrs[i])); + } + return outputs; +} + paddle::optional> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const paddle::optional& tensor, - const phi::distributed::TensorDistAttr& dist_attr) { + const phi::distributed::ArgDistAttr& dist_attr) { if (tensor) { VLOG(6) << "Optional ApiIn to Replicated KernelIn."; return paddle::make_optional>( @@ -711,7 +825,7 @@ paddle::optional>> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const paddle::optional>& tensors, - const std::vector& dist_attrs) { + const phi::distributed::ArgDistAttr& dist_attrs) { if (tensors) { VLOG(6) << "Optional ApiIn to Replicated KernelIn."; return paddle::make_optional< diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 7529924b003c7..d795b827417e8 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" @@ -180,29 +181,57 @@ std::shared_ptr ReshardApiInputToKernelInput( const Tensor& tensor, const phi::distributed::TensorDistAttr& dist_attr); -std::shared_ptr -ReshardApiInputToReplicatedKernelInput( +std::shared_ptr ReshardApiInputToKernelInput( phi::DeviceContext* dev_ctx, const Tensor& tensor, - const phi::distributed::TensorDistAttr& dist_attr); + const phi::distributed::ArgDistAttr& dist_attr); + +std::vector> +ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs); + +std::vector> +ReshardApiInputToKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const std::vector& dist_attrs); + +std::vector> +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const std::vector& tensors, + const phi::distributed::ArgDistAttr& dist_attrs); + +// std::shared_ptr +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const Tensor& tensor, +// const phi::distributed::ArgDistAttr& dist_attr); + +std::vector> +ReshardApiInputToReplicatedKernelInput( + phi::DeviceContext* dev_ctx, + const std::vector& tensor, + const phi::distributed::ArgDistAttr& dist_attr); std::vector> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const std::vector& tensors, - const std::vector& dist_attr); + const std::vector& dist_attrs); paddle::optional> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const paddle::optional& tensor, - const phi::distributed::TensorDistAttr& dist_attr); + const phi::distributed::ArgDistAttr& dist_attr); paddle::optional>> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const paddle::optional>& tensors, - const std::vector& dist_attr); + const phi::distributed::ArgDistAttr& dist_attr); void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor); diff --git a/paddle/phi/api/lib/tensor_method.cc b/paddle/phi/api/lib/tensor_method.cc index 64cc4f2ae505c..61fe33fc5ad80 100644 --- a/paddle/phi/api/lib/tensor_method.cc +++ b/paddle/phi/api/lib/tensor_method.cc @@ -147,19 +147,21 @@ void Tensor::copy_(const Tensor &src, auto meta_dist_input_x = MakeDistMetaTensor(*src.impl()); - auto this_dist_attr = - std::static_pointer_cast( - this->impl())->dist_attr(); - PADDLE_ENFORCE_EQ((meta_dist_input_x.dist_attr() == this_dist_attr - || this_dist_attr.empty()), - true, - phi::errors::PreconditionNotMet( - "DistAttr is different of dst " - "tensor and args %s, which " - "current tensor holds %s " - "Copy cannot be performed!", - meta_dist_input_x.dist_attr(), - this_dist_attr)); + if (this->initialized()) { + auto this_dist_attr = + std::static_pointer_cast( + this->impl())->dist_attr(); + PADDLE_ENFORCE_EQ((meta_dist_input_x.dist_attr() == this_dist_attr + || this_dist_attr.empty()), + true, + phi::errors::PreconditionNotMet( + "DistAttr is different of dst " + "tensor and args %s, which " + "current tensor holds %s " + "Copy cannot be performed!", + meta_dist_input_x.dist_attr(), + this_dist_attr)); + } auto dist_out = SetKernelDistOutput(this, meta_dist_input_x.dist_attr()); auto dense_out = dist_out->unsafe_mutable_value(); diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py 
b/paddle/phi/api/yaml/generator/dist_api_gen.py index f0f0f6a076ab5..1c84bd7c409e6 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -82,7 +82,10 @@ SINGLE_DIST_META_IN_TEMPLATE = """ auto meta_dist_input_{name} = MakeDistMetaTensor(*{name}.impl());""" VECTOR_DIST_META_IN_TEMPLATE = """ - auto meta_dist_input_{name} = MakeDistMetaTensor({name});""" + std::vector meta_dist_input_{name}; + for(auto& e: {name}){{ + meta_dist_input_{name}.push_back(MakeDistMetaTensor(*e.impl())); + }}""" OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE = """ auto meta_dist_input_{name} = {name} ? MakeDistMetaTensor(*(*{name}).impl()) : phi::distributed::DistMetaTensor();""" OPTIONAL_VECTOR_DIST_META_IN_TEMPLATE = """ @@ -91,7 +94,7 @@ auto spmd_info = phi::distributed::{}({}); """ GENERAL_INFER_SPMD_TEMPLATE = """ - auto spmd_info = phi::distributed::VariadicReplicatedInferSpmd({}); + auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic({}); """ UNSUPPORTED_INFER_SPMD_COMMENT_TEMPLATE = """ // API `{}` does not support InferSpmd now @@ -234,36 +237,36 @@ # 5. Reshard Input SINGLE_INPUT_RESHARD_TEMPLATE = """ - auto dist_input_{arg} = ReshardApiInputToKernelInput(dev_ctx, {arg}, spmd_info.first[idx++]);""" -VECTOR_INPUT_RESHARD_TEMPLATE = """ - auto dist_input_{arg}_vec = ReshardApiInputToKernelInput(dev_ctx, {arg}, - std::vector(spmd_info.first.begin() + idx, - spmd_info.first.begin() + idx + size)); - idx += size; - VLOG(4) << "After reshard dist_input_{arg}_vec";""" + auto new_dist_input_{arg} = ReshardApiInputToKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);""" +# VECTOR_INPUT_RESHARD_TEMPLATE = """ +# auto dist_input_{arg}_vec = ReshardApiInputToKernelInput(dev_ctx, {arg}, +# std::vector(spmd_info.first.begin() + idx, +# spmd_info.first.begin() + idx + size)); +# idx += size; +# VLOG(4) << "After reshard dist_input_{arg}_vec";""" SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE = """ - auto dist_input_{arg} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, spmd_info.first[idx++]);""" -VECTOR_GENERAL_INPUT_RESHARD_TEMPLATE = """ - auto dist_input_{arg}_vec = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, - std::vector(spmd_info.first.begin() + idx, - spmd_info.first.begin() + idx + size)); - idx += size; - VLOG(4) << "After reshard dist_input_{arg}_vec";""" + auto new_dist_input_{arg} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);""" +# VECTOR_GENERAL_INPUT_RESHARD_TEMPLATE = """ +# auto dist_input_{arg}_vec = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, +# std::vector(spmd_info.first.begin() + idx, +# spmd_info.first.begin() + idx + size)); +# idx += size; +# VLOG(4) << "After reshard dist_input_{arg}_vec";""" UNSUPPORTED_RESHARD_INPUT_COMMENT_TEMPLATE = """ // API `{}` does not need to support ReshardInput at this time """ # 6. 
PrepareData SINGLE_PREPARE_DATA_TEMPLATE = """ - dist_input_{arg} = PrepareDataForDistTensor(dist_input_{arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); + auto dist_input_{arg} = PrepareDataForDistTensor(new_dist_input_{arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); auto input_{arg} = &dist_input_{arg}->value(); """ SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ - dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); + auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); auto input_{arg} = &dist_input_{arg}->value(); """ VECTOR_PREPARE_DATA_TEMPLATE = """ - dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}_vec, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto dist_input_{name}_vec = PrepareDataForDistTensor(new_dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); std::vector dense_input_{name}_vec; for (auto tmp : dist_input_{name}_vec) {{ dense_input_{name}_vec.emplace_back(&tmp->value()); @@ -275,15 +278,15 @@ }} """ OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE = """ - dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto dist_input_{name} = PrepareDataForDistTensor(new_dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional((*dist_input_{name})->value()) : paddle::none; """ OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ - dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto dist_input_{name} = PrepareDataForDistTensor(new_dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional(dist_input_{name}->value()) : paddle::none; """ OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE = """ - dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}_vec, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto dist_input_{name}_vec = PrepareDataForDistTensor(new_dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); std::vector dense_input_{name}_vec; if ({name}) {{ for (auto tmp : *dist_input_{name}_vec) {{ @@ -373,7 +376,8 @@ # 10. Set Output DistAttr for Default impl # Dist Branch will not generated in the API that doesn't have input tensor. CURRENT_PROCESS_MESH_TEMPLATE = """ - auto current_process_mesh = spmd_info.first[0].process_mesh();""" + auto current_process_mesh = paddle::holds_alternative(spmd_info.first[0]) ? 
+ paddle::get<0>(spmd_info.first[0]).process_mesh() : paddle::get<1>(spmd_info.first[0]).at(0).process_mesh();""" SET_SINGLE_OUT_REPLICATED_DIST_ATTR_TEMPLATE = """ SetReplicatedDistAttrForOutput({}, current_process_mesh);""" SET_VECTOR_OUT_REPLICATED_DIST_ATTR_TEMPLATE = """ @@ -711,6 +715,15 @@ def generate_specialized_infer_spmd_code(self) -> str: name=param ) input_args_code += "meta_dist_input_" + param + ", " + elif ( + self.inputs['input_info'][param] + == "const std::vector&" + ): + input_decl_code += VECTOR_DIST_META_IN_TEMPLATE.format( + name=param + ) + input_args_code += "meta_dist_input_" + param + ", " + else: raise ValueError( f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported." @@ -787,6 +800,9 @@ def generate_general_infer_spmd_code(self) -> str: if input_decl_code == "": return UNSUPPORTED_INFER_SPMD_COMMENT_TEMPLATE.format(self.api) + print( + f"kernel_name: {self.kernel['func'][0]}, input_args_code: {input_args_code}" + ) infer_spmd_code = GENERAL_INFER_SPMD_TEMPLATE.format( input_args_code[:-2] ) @@ -1017,14 +1033,15 @@ def generate_reshard_input_code(self) -> str: else input_names ) - input_reshard_code = "\n int idx = 0, size = 0;" + input_reshard_code = "" for i, param in enumerate(kernel_params): if param in input_names: - if ( - self.inputs['input_info'][param] == "const Tensor&" - or self.inputs['input_info'][param] - == "const paddle::optional&" - ): + if self.inputs['input_info'][param] in [ + "const Tensor&", + "const std::vector&", + "const paddle::optional&", + "const paddle::optional>&", + ]: if self.generate_general_infer_spmd is True: input_reshard_code += ( SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format( @@ -1037,36 +1054,65 @@ def generate_reshard_input_code(self) -> str: arg=param, idx=i ) ) - elif ( - self.inputs['input_info'][param] - == "const std::vector&" - or self.inputs['input_info'][param] - == "const paddle::optional>&" - ): - if ( - self.inputs['input_info'][param] - == "const std::vector&" - ): - input_reshard_code += ( - f"\n size = {param}.size();" - ) - else: - input_reshard_code += f"\n size = {param} ? (*{param}).size() : 1;" - - if self.generate_general_infer_spmd is True: - input_reshard_code += ( - VECTOR_GENERAL_INPUT_RESHARD_TEMPLATE.format( - arg=param - ) - ) - else: - input_reshard_code += ( - VECTOR_INPUT_RESHARD_TEMPLATE.format(arg=param) - ) - pass - # raise ValueError( - # f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported." 
- # ) + # if ( + # self.inputs['input_info'][param] == "const Tensor&" + # or self.inputs['input_info'][param] + # == "const paddle::optional&" + # ): + # if self.generate_general_infer_spmd is True: + # input_reshard_code += ( + # SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format( + # arg=param, idx=i + # ) + # ) + # else: + # input_reshard_code += ( + # SINGLE_INPUT_RESHARD_TEMPLATE.format( + # arg=param, idx=i + # ) + # ) + # elif ( + # self.inputs['input_info'][param] + # == "const std::vector&" + # or self.inputs['input_info'][param] + # == "const paddle::optional>&" + # ): + # if self.generate_general_infer_spmd is True: + # input_reshard_code += ( + # SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format( + # arg=param, idx=i + # ) + # ) + # else: + # input_reshard_code += ( + # SINGLE_INPUT_RESHARD_TEMPLATE.format( + # arg=param, idx=i + # ) + # ) + # if ( + # self.inputs['input_info'][param] + # == "const std::vector&" + # ): + # input_reshard_code += ( + # f"\n size = {param}.size();" + # ) + # else: + # input_reshard_code += f"\n size = {param} ? (*{param}).size() : 1;" + + # if self.generate_general_infer_spmd is True: + # input_reshard_code += ( + # VECTOR_GENERAL_INPUT_RESHARD_TEMPLATE.format( + # arg=param + # ) + # ) + # else: + # input_reshard_code += ( + # VECTOR_INPUT_RESHARD_TEMPLATE.format(arg=param) + # ) + else: + raise ValueError( + f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported." + ) else: # do nothing pass @@ -1115,7 +1161,6 @@ def generate_vector_dense_input( kernel_param = self.kernel['param'] if kernel_param is None: kernel_param = input_names + attr_names - input_tensor_code += VECTOR_PREPARE_DATA_TEMPLATE.format( name=input_name, index=kernel_param.index(input_name), @@ -1164,7 +1209,6 @@ def generate_optional_vector_dense_input( kernel_param = self.kernel['param'] if kernel_param is None: kernel_param = input_names + attr_names - input_tensor_code += OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE.format( name=input_name, index=kernel_param.index(input_name), diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index e6b11884f74eb..5a0c6abc7688b 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -505,6 +505,7 @@ infer_meta : func : ConcatInferMeta param : [x, axis] + spmd_rule : ConcatInferSpmdDynamic kernel : func : concat data_type : x diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc index 3c95f2c3ff66f..052a6d457ca8b 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.cc @@ -399,7 +399,7 @@ bool TensorDistAttr::is_replicated(int64_t mesh_axis) const { bool TensorDistAttr::is_shard(int64_t mesh_axis, int64_t tensor_axis) const { auto placement = to_placement(); if (mesh_axis == -1) { - return std::all_of(placement.begin(), + return std::any_of(placement.begin(), placement.end(), [tensor_axis](std::shared_ptr status) { return status->is_shard(tensor_axis); diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.h b/paddle/phi/core/distributed/auto_parallel/dist_attr.h index f051592b7bf7e..6689750d24ad9 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.h @@ -32,6 +32,8 @@ limitations under the License. 
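Note on the `is_shard` change above: switching `std::all_of` to `std::any_of` makes the `mesh_axis == -1` query mean "this tensor axis is sharded along at least one mesh dimension" rather than "along every mesh dimension", which is the reading the new concat rule relies on when it calls `IsDimSharded(dist_attr, dim)`. A simplified sketch of that semantics; the `Placement` struct below is an illustrative stand-in for the real `PlacementStatus` hierarchy, not code from this patch:

#include <algorithm>
#include <vector>

struct Placement {
  int shard_axis;  // tensor axis this mesh dimension shards, -1 if replicated
};

// Mirrors is_shard(mesh_axis = -1, tensor_axis): true if ANY mesh dimension
// shards the given tensor axis.
bool IsTensorAxisSharded(const std::vector<Placement>& placements,
                         int tensor_axis) {
  return std::any_of(placements.begin(), placements.end(),
                     [tensor_axis](const Placement& p) {
                       return p.shard_axis == tensor_axis;
                     });
}

// Example: mesh dim 0 shards tensor axis 0, mesh dim 1 replicates it.
// IsTensorAxisSharded({{0}, {-1}}, 0) is true; under the old all_of it was not.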
*/ namespace phi { namespace distributed { +constexpr int kReplicateDim = -1; + class PlacementStatus { public: virtual ~PlacementStatus() = default; diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h index 4781b5d872001..2d444decf640a 100644 --- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h @@ -125,6 +125,28 @@ struct InferSpmdFnImpl { } }; + // direct vector + template + struct InferSpmdFnCallHelper&, Tail...> { + template + static SpmdInfo Call(const InferSpmdContext& ctx, PreviousArgs&... pargs) { + static_assert(attr_idx == 0, + "InferSpmd's Input should appear before Attributes."); + // TODO(liuzhenhai): parse input list as vector directly + const std::pair range = ctx.InputRangeAt(in_idx); + std::vector tmp_arg = + ctx.InputsBetween(range.first, range.second); + std::vector arg; + std::transform(tmp_arg.begin(), + tmp_arg.end(), + std::back_inserter(arg), + [](const DistMetaTensor* arg_ptr) { return *arg_ptr; }); + return InferSpmdFnCallHelper::template Call( + ctx, pargs..., arg); + } + }; + #define PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(attr_type) \ template \ struct InferSpmdFnCallHelper { \ diff --git a/paddle/phi/core/distributed/type_defs.h b/paddle/phi/core/distributed/type_defs.h index cd201ac5c5aaf..1b7035c1a4528 100644 --- a/paddle/phi/core/distributed/type_defs.h +++ b/paddle/phi/core/distributed/type_defs.h @@ -18,12 +18,16 @@ #include #include +#include "paddle/utils/variant.h" + namespace phi { namespace distributed { class TensorDistAttr; -using SpmdInfo = - std::pair, std::vector>; +using ArgDistAttr = + paddle::variant>; + +using SpmdInfo = std::pair, std::vector>; } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/concat.cc b/paddle/phi/infermeta/spmd_rules/concat.cc new file mode 100644 index 0000000000000..fd036cfad603a --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/concat.cc @@ -0,0 +1,187 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/concat.h" + +#include +#include + +#include "paddle/phi/infermeta/spmd_rules/elementwise.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +static bool IsEmpty(const std::vector& shape) { + return shape.empty() || shape.at(0) == 0; +} + +SpmdInfo ConcatInferSpmd(const std::vector& x, int axis) { + /* +# paddle.concat requires all tensors must either have the same shape (except +# in the concatenating dimension) or be "empty". "Empty" here strictly means +# tensor.shape is torch.Size([0]). When tensor.ndim > 1, it will be treated +# as a non-empty tensor and the shape must match on non-cat dimensions. 
+ */ + + // 1、check tensors shapes + std::vector> tensor_shapes; + std::transform(x.begin(), + x.end(), + std::back_inserter(tensor_shapes), + [](const DistMetaTensor& meta) { + return phi::vectorize(meta.dims()); + }); + bool all_empty = + std::all_of(tensor_shapes.begin(), tensor_shapes.end(), IsEmpty); + if (all_empty) { + return SpmdInfo(); + } + + auto non_empty_iter = + std::find_if(tensor_shapes.begin(), tensor_shapes.end(), [](auto& shape) { + return !IsEmpty(shape); + }); + auto non_empty_index = non_empty_iter - tensor_shapes.begin(); + int64_t ndim = static_cast(tensor_shapes[non_empty_index].size()); + // normlize dim + int64_t dim = axis; + dim = dim < 0 ? dim + ndim : dim; + + std::vector input_attrs; + // 2、make sure all tensors replicated on concat dim + auto n_inputs = x.size(); + for (size_t i = 0; i < n_inputs; ++i) { + const auto& dist_attr = x[i].dist_attr(); + if ((!IsEmpty(tensor_shapes[i])) && IsDimSharded(dist_attr, dim)) { + auto sharded_dist_attr = ReplicateTensorDim(dist_attr, dim); + input_attrs.emplace_back(sharded_dist_attr); + } else { + input_attrs.emplace_back(dist_attr); + } + } + // 3、align non-concat dimensions according to cost + std::vector>> inputs_placements; + std::transform( + input_attrs.begin(), + input_attrs.end(), + std::back_inserter(inputs_placements), + [](const TensorDistAttr& attr) { return attr.to_placement(); }); + const auto& process_mess = input_attrs[non_empty_index].process_mesh(); + auto has_mismatch = [&](int32_t mesh_dim) { + bool mismatch = false; + for (size_t i = 0; i < n_inputs; i++) { + if ((!IsEmpty(tensor_shapes[i])) && + !PlacementEqual(inputs_placements[non_empty_index][mesh_dim], + inputs_placements[i][mesh_dim])) { + mismatch = true; + break; + } + } + return mismatch; + }; + bool need_reshard = false; + int32_t n_mesh_dim = process_mess.ndim(); + std::vector> best_placements( + n_mesh_dim, std::make_shared()); + // a dim can not be sharded twice along diffrent mesh_dim + std::set sharded_dims = {dim}; + + for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { + if (!has_mismatch(mesh_dim)) { + // use the old placement + auto& best = inputs_placements[non_empty_index][mesh_dim]; + if (best->is_shard()) { + auto shard_placement = std::dynamic_pointer_cast(best); + sharded_dims.insert(shard_placement->get_axis()); + } + best_placements[mesh_dim] = best; + } + } + + for (int32_t mesh_dim = 0; mesh_dim < process_mess.ndim(); ++mesh_dim) { + if (!has_mismatch(mesh_dim)) { + continue; + } + need_reshard = true; + std::vector costs; + for (int32_t shard_dim = 0; shard_dim < ndim; shard_dim++) { + double cost = std::numeric_limits::infinity(); + if (!sharded_dims.count(shard_dim)) { + cost = 0.0; + for (size_t i = 0; i < n_inputs; i++) { + auto& tensor_shape = tensor_shapes[i]; + auto& tensor_dist_attr = input_attrs[i]; + if (IsEmpty(tensor_shape)) { + continue; + } + + if (tensor_shape[shard_dim] < process_mess.dim_size(mesh_dim)) { + // should not be selected + cost += std::numeric_limits::infinity(); + continue; + } + if (IsDimSharded(tensor_dist_attr, shard_dim)) { + continue; + } + int64_t num = std::accumulate(tensor_shape.begin(), + tensor_shape.end(), + 1, + std::multiplies()); + if (num == static_cast(0)) { + continue; + } + std::vector local_shape = + GetLocalShape(tensor_shape, process_mess, inputs_placements[i]); + cost += std::accumulate(local_shape.begin(), + local_shape.end(), + 1, + std::multiplies()) * + process_mess.dim_size(mesh_dim); + } + } + costs.push_back(cost); + } + auto min_itr = 
std::min_element(costs.begin(), costs.end()); + auto min_dim = min_itr - costs.begin(); + if (!sharded_dims.count(min_dim)) { + best_placements[mesh_dim] = std::make_shared(min_dim); + sharded_dims.insert(min_dim); + } + } + // set placement to the best placements + if (need_reshard) { + std::vector new_input_attrs; + for (auto& e : input_attrs) { + new_input_attrs.emplace_back(FromPlacements(e, best_placements)); + } + std::swap(input_attrs, new_input_attrs); + } + return {{input_attrs}, {input_attrs[non_empty_index]}}; +} + +SpmdInfo ConcatInferSpmdReverse(const std::vector& x, + const DistMetaTensor& output, + int axis) { + // TODO(liuzhenhai): add latter + return SpmdInfo(); +} +SpmdInfo ConcatInferSpmdDynamic(const std::vector& x, + const Scalar& axis) { + return ConcatInferSpmd(x, axis.to()); +} +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/concat.h b/paddle/phi/infermeta/spmd_rules/concat.h new file mode 100644 index 0000000000000..0f7435bec0b23 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/concat.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { +SpmdInfo ConcatInferSpmd(const std::vector& x, int axis); + +SpmdInfo ConcatInferSpmdReverse(const std::vector& x, + const DistMetaTensor& output, + int axis); + +SpmdInfo ConcatInferSpmdDynamic(const std::vector& x, + const Scalar& axis); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc b/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc index eb469200a7ec8..7a3639147f1ee 100644 --- a/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc +++ b/paddle/phi/infermeta/spmd_rules/default_data_parallel.cc @@ -95,7 +95,8 @@ SpmdInfo DefaultDataParallelInferSpmd( << str_join(output_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; } SpmdInfo DefaultDataParallelInferSpmdReverse( const std::vector& ins, @@ -157,7 +158,8 @@ SpmdInfo DefaultDataParallelInferSpmdReverse( << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index 6befef19cfef1..1dfe8bf19c296 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -275,7 +275,7 @@ SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x, } VLOG(4) << std::endl; - return {input_dist_attrs, output_dist_attrs}; + return 
{ToArgDistAttr(input_dist_attrs), ToArgDistAttr(output_dist_attrs)}; } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/matmul.cc b/paddle/phi/infermeta/spmd_rules/matmul.cc index 4893c7071f19e..60c7acacf0478 100644 --- a/paddle/phi/infermeta/spmd_rules/matmul.cc +++ b/paddle/phi/infermeta/spmd_rules/matmul.cc @@ -291,17 +291,22 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& out_grad, bool trans_x, bool trans_y) { - auto confirm_dist_attr_same_fn = [&](const TensorDistAttr& x_dist_attr, + auto get_attr = [](const ArgDistAttr& attr) -> const TensorDistAttr& { + return paddle::get(attr); + }; + + auto confirm_dist_attr_same_fn = [&](const ArgDistAttr& x_dist_attr, const DistMetaTensor& y, const char* debug_msg) { + const auto& x_single_dist_attr = get_attr(x_dist_attr); PADDLE_ENFORCE_EQ( - DistAttrsAreBasicallyEqual(x_dist_attr, y.dist_attr()), + DistAttrsAreBasicallyEqual(x_single_dist_attr, y.dist_attr()), true, phi::errors::Unavailable("The matmul grad infer spmd `%s` verify " "error: left dist attr is %s, " "right dist attr is %s.", debug_msg, - x_dist_attr, + x_single_dist_attr, y.dist_attr())); }; @@ -313,8 +318,8 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x, // so it cannot be handled correctly in the backward for the time being // For this case, we uniformly transition the input to the Replicated state. auto fwd_spmd_info = MatmulInferSpmd(x, y, trans_x, trans_y); - if (x.dist_attr() != fwd_spmd_info.first[0] || - y.dist_attr() != fwd_spmd_info.first[1]) { + if (x.dist_attr() != get_attr(fwd_spmd_info.first[0]) || + y.dist_attr() != get_attr(fwd_spmd_info.first[1])) { auto x_r_dist_attr = GetReplicatedDistAttr(x.dist_attr()); auto y_r_dist_attr = GetReplicatedDistAttr(y.dist_attr()); return {{x_r_dist_attr, diff --git a/paddle/phi/infermeta/spmd_rules/replicated.cc b/paddle/phi/infermeta/spmd_rules/replicated.cc index b2b3b019be039..d0c90f7b2d2a9 100644 --- a/paddle/phi/infermeta/spmd_rules/replicated.cc +++ b/paddle/phi/infermeta/spmd_rules/replicated.cc @@ -86,7 +86,8 @@ SpmdInfo ReplicatedInferSpmd(const std::vector& ins, << str_join(output_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; } SpmdInfo ReplicatedInferSpmdReverse( @@ -135,7 +136,53 @@ SpmdInfo ReplicatedInferSpmdReverse( << str_join(dst_input_dist_attrs[i].dims_mapping()) << "]"; } - return {dst_input_dist_attrs, output_dist_attrs}; + return {ToArgDistAttr(dst_input_dist_attrs), + ToArgDistAttr(output_dist_attrs)}; +} + +SpmdInfo ReplicatedInferDynamic( + const std::vector*>>& + inputs) { + std::vector nonnull_inputs; + int64_t ninputs = inputs.size(); + SpmdInfo spmd_info; + + auto build_tensor_dist_attr = + [&nonnull_inputs](const DistMetaTensor& dist_meta_tensor) { + int ndim = dist_meta_tensor.dims().size(); + TensorDistAttr dist_attr_dst = + CopyTensorDistAttrForOutput(dist_meta_tensor.dist_attr()); + // `ndim == -1` means input is nullptr + if (ndim >= 0) { + std::vector dst_dims_maping = GetReplicatedDimsmapping(ndim); + dist_attr_dst.set_dims_mapping(dst_dims_maping); + nonnull_inputs.push_back(&dist_meta_tensor); + } + return dist_attr_dst; + }; + + for (int64_t i = 0; i < ninputs; i++) { + if (paddle::holds_alternative(inputs[i])) { + auto dist_meta_tensor_ptr = paddle::get<0>(inputs[i]); + auto& dist_meta_tensor = *dist_meta_tensor_ptr; + auto dist_attr_dst = build_tensor_dist_attr(dist_meta_tensor); + VLOG(4) << 
"input " << i << ": dist attr: " << dist_attr_dst.to_string(); + spmd_info.first.emplace_back(dist_attr_dst); + } else { + std::vector list_dist_attr; + auto dist_meta_tensors_ptr = paddle::get<1>(inputs[i]); + auto& dist_meta_tensors = *dist_meta_tensors_ptr; + for (const auto& dist_meta_tensor : dist_meta_tensors) { + auto dist_attr_dst = build_tensor_dist_attr(dist_meta_tensor); + VLOG(4) << "input " << i + << ": dist attr: " << dist_attr_dst.to_string(); + list_dist_attr.emplace_back(std::move(dist_attr_dst)); + } + spmd_info.first.emplace_back(std::move(list_dist_attr)); + } + } + return spmd_info; } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/replicated.h b/paddle/phi/infermeta/spmd_rules/replicated.h index a8d6c0719f2ec..1f3a26cb426d4 100644 --- a/paddle/phi/infermeta/spmd_rules/replicated.h +++ b/paddle/phi/infermeta/spmd_rules/replicated.h @@ -41,6 +41,19 @@ SpmdInfo ReplicatedInferSpmdReverse( const std::vector& ins, const std::vector& outs); +SpmdInfo ReplicatedInferDynamic( + const std::vector*>>& + inputs); + +// For phi api +template +SpmdInfo VariadicReplicatedInferSpmdDynamic(const Args&... args) { + return detail::ReplicateInferSpmdDynamicHelper() + .apply(args...) + .Infer(); +} + // For phi api template SpmdInfo VariadicReplicatedInferSpmd(const Args&... args) { diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 3eb63b5e7d0ee..eb3a97ce053c3 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/infermeta/spmd_rules/concat.h" #include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h" #include "paddle/phi/infermeta/spmd_rules/elementwise.h" #include "paddle/phi/infermeta/spmd_rules/embedding.h" @@ -523,6 +524,10 @@ PD_REGISTER_SPMD_RULE(slice, PD_INFER_SPMD(phi::distributed::SliceInferSpmd), PD_INFER_SPMD(phi::distributed::SliceInferSpmdReverse)); +PD_REGISTER_SPMD_RULE(concat, + PD_INFER_SPMD(phi::distributed::ConcatInferSpmd), + PD_INFER_SPMD(phi::distributed::ConcatInferSpmdReverse)); + // transpose rule PD_REGISTER_SPMD_RULE( transpose, diff --git a/paddle/phi/infermeta/spmd_rules/split.cc b/paddle/phi/infermeta/spmd_rules/split.cc index 4bc2a9ce0bdb1..0856fec2e89df 100644 --- a/paddle/phi/infermeta/spmd_rules/split.cc +++ b/paddle/phi/infermeta/spmd_rules/split.cc @@ -92,8 +92,10 @@ SpmdInfo SplitWithNumInferSpmd(const DistMetaTensor& x, int num, int axis) { << str_join(out_dims_mapping) << "]"; } VLOG(4) << std::endl; - - return {{x_dist_attr_dst}, out_dist_attrs}; + // TODO(liuzhenhai): remedy this + // should return list in list [] + // return {{x_dist_attr_dst}, {out_dist_attrs}}; + return {{x_dist_attr_dst}, ToArgDistAttr(out_dist_attrs)}; } SpmdInfo SplitWithNumInferSpmdReverse( @@ -193,8 +195,9 @@ SpmdInfo SplitWithNumInferSpmdReverse( } VLOG(4) << "Input shape: [" << str_join(x_shape) << "] " << "dims_mapping: [" << str_join(x_dims_mapping) << "]\n\n"; - - return {{x_dist_attr}, out_dist_attrs}; + // TODO(liuzhenhai): remedy this + // return {{x_dist_attr}, {out_dist_attrs}}; + return {{x_dist_attr}, ToArgDistAttr(out_dist_attrs)}; } SpmdInfo SplitInferSpmd(const DistMetaTensor& x, diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index dc6141f3ec0ce..42bbc659b2f2b 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ 
b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -164,6 +164,99 @@ TensorDistAttr GetReplicatedDistAttr(const TensorDistAttr& dist_attr) { return dst_dist_attr; } +TensorDistAttr ReplicateTensorDim(const TensorDistAttr& dist_attr, int dim) { + TensorDistAttr dst_dist_attr = CopyTensorDistAttrForOutput(dist_attr); + std::vector dims_mapping = dist_attr.dims_mapping(); + dims_mapping[dim] = kReplicateDim; + dst_dist_attr.set_dims_mapping(dims_mapping); + return dst_dist_attr; +} + +bool IsDimSharded(const TensorDistAttr& dist_attr, int dim) { + return dist_attr.is_shard(-1, dim); +} + +bool PlacementEqual(const std::shared_ptr& a, + const std::shared_ptr& b) { + if (a->is_partial()) { + if (!b->is_partial()) { + return false; + } + auto a_partial = std::dynamic_pointer_cast(a); + auto b_partial = std::dynamic_pointer_cast(b); + return a_partial->get_reduce_type() == b_partial->get_reduce_type(); + } + if (a->is_replicated()) { + if (b->is_replicated()) { + return true; + } + return false; + } + if (!b->is_shard()) { + return false; + } + + auto a_shard = std::dynamic_pointer_cast(a); + auto b_shard = std::dynamic_pointer_cast(b); + return a_shard->get_axis() == b_shard->get_axis(); +} + +TensorDistAttr FromPlacements( + const TensorDistAttr& dist_attr, + const std::vector>& placements) { + TensorDistAttr dst_dist_attr = CopyTensorDistAttrForOutput(dist_attr); + std::vector dims_mapping(dist_attr.dims_mapping().size(), -1); + paddle::flat_hash_map partial_status; + + for (size_t mesh_dim = 0; mesh_dim < placements.size(); mesh_dim++) { + auto& placement = placements[mesh_dim]; + if (placement->is_shard()) { + auto shard_placement = std::dynamic_pointer_cast(placement); + dims_mapping[shard_placement->get_axis()] = mesh_dim; + } + if (placement->is_partial()) { + auto partial_placement = + std::dynamic_pointer_cast(placement); + auto reduce_type = partial_placement->get_reduce_type(); + partial_status[mesh_dim] = reduce_type; + } + } + dst_dist_attr.set_dims_mapping(dims_mapping); + dst_dist_attr.set_partial_status(partial_status); + return dst_dist_attr; +} + +std::vector ToArgDistAttr( + const std::vector& dist_attrs) { + std::vector items_dist_attrs; + std::transform( + dist_attrs.begin(), + dist_attrs.end(), + std::back_inserter(items_dist_attrs), + [](const TensorDistAttr& attr) -> ArgDistAttr { return {attr}; }); + return items_dist_attrs; +} + +std::vector GetLocalShape( + const std::vector shape, + const ProcessMesh& mesh, + const std::vector>& placements) { + auto local_shape = shape; + auto n_placement = placements.size(); + for (size_t i = 0; i < n_placement; i++) { + auto& placement = placements.at(i); + if (placement->is_shard()) { + auto mesh_dim_size = mesh.dim_size(i); + auto shard_dim = + std::dynamic_pointer_cast(placement)->get_axis(); + auto split_size = + (shape.at(shard_dim) + mesh_dim_size - 1) / mesh_dim_size; + local_shape[shard_dim] = split_size; + } + } + return local_shape; +} + std::vector GetDimsMappingForAxes( const std::string& axes, const std::unordered_map& axis_to_dim_map, diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h index 15245f741b70b..2d52a58bdbb24 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.h +++ b/paddle/phi/infermeta/spmd_rules/utils.h @@ -69,6 +69,25 @@ std::vector ResoluteOutputPartialDimension( // Repliacated state TensorDistAttr GetReplicatedDistAttr(const TensorDistAttr& dist_attr); +bool IsDimSharded(const TensorDistAttr& dist_attr, int dim); + +std::vector GetLocalShape( + const std::vector 
shape, + const ProcessMesh& mesh, + const std::vector>& placements); + +TensorDistAttr FromPlacements( + const TensorDistAttr& dist_attr, + const std::vector>& placements); + +std::vector ToArgDistAttr( + const std::vector& dist_attrs); + +TensorDistAttr ReplicateTensorDim(const TensorDistAttr& dist_attr, int dim); + +bool PlacementEqual(const std::shared_ptr& a, + const std::shared_ptr& b); + // Adaptor for variadic arguments template struct ArgsIterator { @@ -106,13 +125,6 @@ struct VariadicSpmdRuleArgumentParser // deal with inputs void operator()(const DistMetaTensor& x) { inputs.emplace_back(&x); } - void operator()(const std::vector& x) { - for (int i = 0; i < x.size(); i++) { - std::cout << "i: " << i << std::endl; - inputs.emplace_back(&x[i]); - } - } - void operator()(const std::vector& x) { for (auto& t : x) { inputs.emplace_back(&t); @@ -132,6 +144,28 @@ struct VariadicSpmdRuleArgumentParser SpmdInfo InferBackward() { return Fn(inputs, outputs); } }; + +using DynamicSpmdFn = SpmdInfo (*)( + const std::vector*>>&); + +template +struct ReplicateInferSpmdDynamicHelper + : public ArgsIterator> { + SpmdInfo Infer() { return Fn(inputs); } + + void operator()(const DistMetaTensor& x) { inputs.emplace_back(&x); } + void operator()(const std::vector& x) { + inputs.emplace_back(&x); + } + + void operator()(std::vector&& x) = delete; + void operator()(DistMetaTensor&& x) = delete; + + std::vector*>> + inputs; +}; } // namespace detail // Get dims mapping for the given axes according to sharding information of diff --git a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py index 67d656f4dcd75..c99ce0c552a8a 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py +++ b/python/paddle/jit/sot/opcode_translator/executor/executor_cache.py @@ -14,6 +14,7 @@ from __future__ import annotations +import gc import traceback import types from typing import List, Tuple @@ -228,3 +229,5 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction: raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e finally: simulator.cleanup() + del simulator + gc.collect() diff --git a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py index 1fd89009200a4..1dbe77cd48052 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/function_graph.py +++ b/python/paddle/jit/sot/opcode_translator/executor/function_graph.py @@ -26,7 +26,7 @@ from ...infer_meta import InferMetaCache, LayerInferMetaCache, MetaInfo from ...profiler import EventGuard, event_register -from ...symbolic.statement_ir import Symbol +from ...symbolic.statement_ir import Reference, Symbol from ...symbolic.symbolic_context import SymbolicTraceContext from ...utils import ( ENV_SHOW_TRACKERS, @@ -426,6 +426,7 @@ def get_opcode_executor_stack(): def call_layer( self, layer: PaddleLayerVariable, + weak_ref: bool, *args: VariableBase, **kwargs: VariableBase, ): @@ -442,7 +443,7 @@ def infer_meta_fn(layer, *metas, **kwmetas): def compute_fn(layer, inputs, outputs, stacks): self.sir_ctx.call_LAYER( - layer.value, + Reference(layer.value, weak_ref), inputs=inputs, outputs=outputs, stacks=stacks, diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index d9947579dc7d4..ea76642a671db 100644 --- 
a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -1460,6 +1460,7 @@ def __init__(self, frame: types.FrameType, **kwargs): def cleanup(self): self._graph.pycode_gen = None Dispatcher.graph = None + self.call_stack[:] = [] @event_register("OpcodeExecutor: _prepare_virtual_env", event_level=2) def _prepare_virtual_env(self): diff --git a/python/paddle/jit/sot/symbolic/statement_ir.py b/python/paddle/jit/sot/symbolic/statement_ir.py index 11a08f36acd9d..1e0ab465e0bd8 100644 --- a/python/paddle/jit/sot/symbolic/statement_ir.py +++ b/python/paddle/jit/sot/symbolic/statement_ir.py @@ -22,12 +22,26 @@ import weakref from typing import Any, Callable -import paddle from paddle.utils import is_sequence, map_structure from ..utils import NameGenerator, OrderedSet, Singleton, flatten_extend +class Reference: # to unify weak_ref and strong_ref + def __init__(self, value, is_weak): + self.is_weak = is_weak + if is_weak is True: + self.ref = weakref.ref(value) + else: + self.ref = value + + def __call__(self): + if self.is_weak is True: + return self.ref() + else: + return self.ref + + class Symbol: """ Symbol is used to distinguish a string and a `math variable`. @@ -139,7 +153,7 @@ def __init__( class LayerStatement(Statement): def __init__( self, - layer: paddle.nn.Layer, + layer: Reference, # Reference of paddle.nn.Layer inputs: list[Symbol], outputs: list[Symbol], stacks: list[str], @@ -147,7 +161,7 @@ def __init__( super().__init__( "layer", layer.__class__.__name__, inputs, outputs, stacks ) - self.layer = weakref.ref(layer) + self.layer = layer class StatementIR: diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 0f49006737bab..06932f5c9b567 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -441,23 +441,23 @@ def logspace(start, stop, num, base=10.0, dtype=None, name=None): tensor_start = start tensor_stop = stop tensor_base = base - if not isinstance(num, Variable): + if not isinstance(num, (Variable, paddle.pir.OpResult)): check_type(num, 'num', (int), 'logspace') - if not isinstance(dtype, core.VarDesc.VarType): + if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)): dtype = convert_np_dtype_to_dtype_(dtype) - if not isinstance(start, Variable): + if not isinstance(start, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_start = fill_constant([1], dtype, start) - if not isinstance(stop, Variable): + if not isinstance(stop, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_stop = fill_constant([1], dtype, stop) - if not isinstance(num, Variable): + if not isinstance(num, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_num = fill_constant([1], 'int32', num) - if not isinstance(base, Variable): + if not isinstance(base, (Variable, paddle.pir.OpResult)): with device_guard("cpu"): tensor_base = fill_constant([1], dtype, base) - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.logspace( tensor_start, tensor_stop, @@ -1648,7 +1648,7 @@ def meshgrid(*args, **kwargs): if len(args) == 1 and isinstance(args[0], (list, tuple)): args = args[0] - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.meshgrid(list(args)) else: name = kwargs.get("name", None) @@ -2608,7 +2608,7 @@ def complex(real, imag, name=None): [[0j , 1j , 2j ], [(1+0j), (1+1j), (1+2j)]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.complex(real, 
imag) else: check_variable_and_dtype( diff --git a/test/auto_parallel/semi_auto_parallel_for_concat.py b/test/auto_parallel/semi_auto_parallel_for_concat.py new file mode 100644 index 0000000000000..24605825d5f15 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_for_concat.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from semi_auto_parallel_util import SemiAutoParallelTestBase + +import paddle +import paddle.distributed as dist + + +class TestSplitAndConcatSemiAutoParallel(SemiAutoParallelTestBase): + def __init__(self): + super().__init__() + + def test_concat_forward(self): + shapes = [[16, 4, 4], [64, 4, 4]] + specs = [[None, None, 'x'], [None, None, 'x']] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=paddle.concat, + with_backward=False, + axis=0, + ) + + def test_concat_forward_reshard(self): + shapes = [[16, 4, 4], [64, 4, 4]] + specs = [['x', None, None], [None, None, 'x']] + inputs, outputs = self.runfunc_and_check( + inputs_shape=shapes, + inputs_specs=specs, + op_func=paddle.concat, + with_backward=False, + axis=0, + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_concat_forward() + # all to all is not supported yet for cpu + if self._backend == "gpu": + self.test_concat_forward_reshard() + + +if __name__ == '__main__': + TestSplitAndConcatSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/semi_auto_parallel_for_matmul.py b/test/auto_parallel/semi_auto_parallel_for_matmul.py index 279062f483058..470100e9c3bc8 100644 --- a/test/auto_parallel/semi_auto_parallel_for_matmul.py +++ b/test/auto_parallel/semi_auto_parallel_for_matmul.py @@ -30,7 +30,7 @@ def __init__(self): def check_tensor_eq(self, a, b): np1 = a.numpy() np2 = b.numpy() - np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + np.testing.assert_allclose(np1, np2, rtol=1e-04, verbose=True) def test_body( self, x_shape, y_shape, x_specs, y_specs, trans_x=False, trans_y=False diff --git a/test/auto_parallel/semi_auto_parallel_util.py b/test/auto_parallel/semi_auto_parallel_util.py new file mode 100644 index 0000000000000..cfb905e8382a2 --- /dev/null +++ b/test/auto_parallel/semi_auto_parallel_util.py @@ -0,0 +1,133 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class SemiAutoParallelTestBase: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def flatten(self, inputs, terminal_cond): + """ + inputs may be single tensor、tuple + """ + + if terminal_cond(inputs): + return [inputs], "i" + + assert isinstance(inputs, (tuple, list)) + flattened = [] + structure = [] + for i in range(len(inputs)): + tmp, tmp_structure = self.flatten(inputs[i], terminal_cond) + flattened.extend(tmp) + structure.append(tmp_structure) + + if isinstance(inputs, tuple): + structure = tuple(structure) + return flattened, structure + + def unflatten(self, inputs, structure, offset=0): + """ + inputs may be single tensor + """ + assert isinstance(inputs, list) + assert offset < len(inputs) + if structure == "i": + offset = offset + 1 + # return a list + return inputs[offset - 1], offset + assert isinstance(structure, (tuple, list)) + unflattened = [] + for i in range(len(structure)): + tmp, offset = self.unflatten(inputs, structure[i], offset) + unflattened.append(tmp) + if isinstance(structure, tuple): + unflattened = tuple(unflattened) + return unflattened, offset + + def runfunc_and_check( + self, inputs_shape, inputs_specs, op_func, with_backward, **kwargs + ): + paddle.seed(self._seed) + np.random.seed(self._seed) + + flat_inputs = [] + flat_dist_inputs = [] + + def terminal_cond(x): + return isinstance(x, list) and all( + not isinstance(e, (list, tuple)) for e in x + ) + + flat_inputs_specs, inputs_structure = self.flatten( + inputs_specs, terminal_cond + ) + flat_inputs_shape, _ = self.flatten(inputs_shape, terminal_cond) + assert len(flat_inputs_specs) == len(flat_inputs_shape) + + for shape, spec in zip(flat_inputs_shape, flat_inputs_specs): + input_np = np.random.random(size=shape).astype(self._dtype) + input = paddle.to_tensor(input_np) + input.stop_gradient = False + input_dist_attr = dist.DistAttr( + mesh=self._mesh, sharding_specs=spec + ) + dist_input = dist.shard_tensor(input, dist_attr=input_dist_attr) + dist_input.stop_gradient = False + flat_inputs.append(input) + flat_dist_inputs.append(dist_input) + inputs, _ = self.unflatten(flat_inputs, inputs_structure) + dist_inputs, _ = self.unflatten(flat_dist_inputs, inputs_structure) + + def wrap_tuple(e): + return e if isinstance(e, tuple) else (e,) + + op_inputs = wrap_tuple(inputs) + op_dist_input = wrap_tuple(dist_inputs) + + out = op_func(*op_inputs, **kwargs) + dist_out = op_func(*op_dist_input, **kwargs) + + if with_backward: + + def terminal_cond2(x): + return not isinstance(x, (list, tuple)) + + flat_out, _ = self.flatten(out, terminal_cond2) + flat_dist_out, _ = self.flatten(dist_out, terminal_cond2) + assert len(flat_out) == len(flat_dist_out) + for output, dist_output in zip(flat_out, flat_dist_out): + self.check_tensor_eq(out, dist_out) + output.backward() + dist_output.backward() + + for x, dist_x in zip(flat_inputs, flat_dist_inputs): + self.check_tensor_eq(x.grad, dist_x.grad) + + return dist_inputs, dist_out diff --git a/test/auto_parallel/spmd_rules/CMakeLists.txt b/test/auto_parallel/spmd_rules/CMakeLists.txt index 
f5d45ecaafc3f..5c8f78b6c6544 100644 --- a/test/auto_parallel/spmd_rules/CMakeLists.txt +++ b/test/auto_parallel/spmd_rules/CMakeLists.txt @@ -20,6 +20,7 @@ if(WITH_DISTRIBUTE) py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule) py_test_modules(test_slice_rule MODULES test_slice_rule) py_test_modules(test_flatten_rule MODULES test_flatten_rule) + py_test_modules(test_concat_rule MODULES test_concat_rule) # End of unittests WITH single card WITHOUT timeout endif() diff --git a/test/auto_parallel/spmd_rules/test_concat_rule.py b/test/auto_parallel/spmd_rules/test_concat_rule.py new file mode 100644 index 0000000000000..b1e1c11a0622e --- /dev/null +++ b/test/auto_parallel/spmd_rules/test_concat_rule.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddle.distributed.auto_parallel.static.dist_attribute import ( + DistTensorSpec, + TensorDistAttr, +) +from paddle.distributed.fleet import auto +from paddle.framework import core + + +class TestConcatSPMDRule(unittest.TestCase): + """ + Unit tests for split spmd rule. + """ + + def setUp(self): + self.process_mesh = auto.ProcessMesh(mesh=[[0, 1], [2, 3]]) + self.shapes = [[16, 16, 16], [4, 16, 16], [2, 16, 16]] + self.dim_mappings = [[-1, 0, 1], [-1, 1, 0], [-1, -1, 0]] + + def build_inputs(self): + inputs = [] + for shape, dim_mapping in zip(self.shapes, self.dim_mappings): + tensor_dist_attr = TensorDistAttr() + tensor_dist_attr.dims_mapping = dim_mapping + tensor_dist_attr.process_mesh = self.process_mesh + inputs.append(DistTensorSpec(shape, tensor_dist_attr)) + return inputs + + def test_infer_forward(self): + inputs = self.build_inputs() + rule = core.get_phi_spmd_rule("concat") + infered_dist_attrs = rule.infer_forward(inputs, 0) + infered_input_dist_attrs = infered_dist_attrs[0] + self.assertEqual(len(infered_input_dist_attrs), 1) + infered_output_dist_attrs = infered_dist_attrs[1] + self.assertEqual(len(infered_output_dist_attrs), 1) + for input_dist_attr in infered_input_dist_attrs[0]: + self.assertEqual(input_dist_attr.dims_mapping, [-1, 1, 0]) + self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, 0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 3730e019f7506..b1132a6a3a8dc 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -46,6 +46,16 @@ def test_elementwise_api(self): user_defined_envs=envs, ) + def test_concat_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_concat.py", + user_defined_envs=envs, + ) + def test_reduction_api(self): envs_list = test_base.gen_product_envs_list( self._default_envs, self._changeable_envs diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc 
b/test/cpp/auto_parallel/spmd_rule_test.cc index 42476d7bb323f..eb6d08542b04a 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include +#include #include "glog/logging.h" #include "gtest/gtest.h" @@ -23,6 +24,7 @@ limitations under the License. */ #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" +#include "paddle/phi/core/distributed/type_defs.h" #include "paddle/phi/infermeta/spmd_rules/replicated.h" #include "paddle/phi/infermeta/spmd_rules/rules.h" @@ -30,6 +32,68 @@ namespace paddle { namespace distributed { namespace auto_parallel { +auto& get_dims_mapping(const phi::distributed::ArgDistAttr& dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)); + const auto& tensor_attr = paddle::get<0>(dist_attr); + return tensor_attr.dims_mapping(); +} + +bool is_partial(const phi::distributed::ArgDistAttr& dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)); + const auto& tensor_attr = paddle::get<0>(dist_attr); + return tensor_attr.is_partial(); +} + +auto get_partial_dims(const phi::distributed::ArgDistAttr& dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)); + const auto& tensor_attr = paddle::get<0>(dist_attr); + return tensor_attr.partial_dims(); +} + +void check_dim_mapping(const phi::distributed::ArgDistAttr& dist_attr, + const std::vector& dim_mapping, + const std::string& line = "") { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)) + << line; + EXPECT_EQ(get_dims_mapping(dist_attr), dim_mapping) << line; +} + +void check_partial_dims(const phi::distributed::ArgDistAttr& dist_attr, + const std::set& dims, + const std::string& line = "") { + EXPECT_TRUE( + paddle::holds_alternative(dist_attr)) + << line; + EXPECT_EQ(get_partial_dims(dist_attr), dims) << line; +} + +void clean_partial_status(phi::distributed::ArgDistAttr* dist_attr) { + EXPECT_TRUE( + paddle::holds_alternative(*dist_attr)); + auto& tensor_attr = paddle::get<0>(*dist_attr); + tensor_attr.clean_partial_status(); +} + +void clean_partial_dims(phi::distributed::ArgDistAttr* dist_attr, + std::vector dims) { + EXPECT_TRUE( + paddle::holds_alternative(*dist_attr)); + auto& tensor_attr = paddle::get<0>(*dist_attr); + tensor_attr.clean_partial_dims(dims); +} + +void set_partial_status(phi::distributed::ArgDistAttr* dist_attr, + std::vector dims) { + EXPECT_TRUE( + paddle::holds_alternative(*dist_attr)); + auto& tensor_attr = paddle::get<0>(*dist_attr); + tensor_attr.set_partial_status(dims); +} + TEST(MatmulSPMDRule, Ctor) { // build input data class std::vector x_shape = {64, 32}; @@ -66,14 +130,10 @@ TEST(MatmulSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test1 done." 
<< std::endl << std::endl << std::endl; // mk[-1,-1],kn[-1,0] --> mk[-1,-1],kn[-1,0] = nm[-1,0] partial[] @@ -84,15 +144,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, 0}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, 0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test2 done." << std::endl << std::endl << std::endl; - // mk[1, 0],kn[-1,-1] --> mk[1, 0],kn[0, -1] = nm[1, -1] partial[0]: done x_dist_attr.set_dims_mapping({1, 0}); y_dist_attr.set_dims_mapping({-1, -1}); @@ -101,15 +157,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); + check_dim_mapping(infered_dist_attrs.first[0], {1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); VLOG(4) << "test3 done." << std::endl << std::endl << std::endl; // mk[-1,-1],kn[1,0] --> mk[-1, 1],kn[1, 0] = nm[-1, 0] partial[1]: done @@ -120,15 +172,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, 1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, 0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({1})); + check_dim_mapping(infered_dist_attrs.first[0], {-1, 1}); + check_dim_mapping(infered_dist_attrs.first[1], {1, 0}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, 0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {1}); VLOG(4) << "test4 done." 
<< std::endl << std::endl << std::endl; // abcmk[1, 0, -1, -1],kn[-1, -1] --> abcmk[1, 0, -1, -1],kn[-1, -1] = @@ -141,13 +189,10 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({0, 1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({0, 1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + check_dim_mapping(infered_dist_attrs.first[0], {0, 1, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {0, 1, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); VLOG(4) << "test5 done." << std::endl << std::endl << std::endl; // abcmk[1, -1, -1, 0],kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[0, -1] = abcmn[1, @@ -159,15 +204,11 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/false, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); VLOG(4) << "test6 done." << std::endl << std::endl << std::endl; // abcmk[1, -1, -1, 0], kn[-1, -1] --> abcmk[1, -1, -1, 0],kn[-1, -1] = @@ -179,13 +220,12 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/false}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1, 0})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, 0, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); + + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1, 0}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, 0, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); + VLOG(4) << "test7 done." 
           << std::endl << std::endl << std::endl;

   // abcmk[-1, -1, -1, -1], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] =
@@ -197,17 +237,13 @@ TEST(MatmulSPMDRule, Ctor) {
   ctx = phi::distributed::InferSpmdContext(
       {x, y}, {/*trans_x=*/false, /*trans_x=*/true});
   infered_dist_attrs = matmul_spmd_rule.InferForward(ctx);
-  EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
-            std::vector<int64_t>({-1, -1, -1, 0}));
-  EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
-            std::vector<int64_t>({1, 0}));
-  EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
-            std::vector<int64_t>({-1, -1, -1, 1}));
-  EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true);
-  EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
-            std::set<int64_t>({0}));
-  infered_dist_attrs.second[0].clean_partial_dims(std::vector<int64_t>({0}));
-  EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
+  check_dim_mapping(infered_dist_attrs.first[0], {-1, -1, -1, 0});
+  check_dim_mapping(infered_dist_attrs.first[1], {1, 0});
+  check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, -1, 1});
+  EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true);
+  check_partial_dims(infered_dist_attrs.second[0], {0});
+  clean_partial_dims(&infered_dist_attrs.second[0], {0});
+  EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false);
   VLOG(4) << "test8 done." << std::endl << std::endl << std::endl;

   // abcmk[-1, -1, 0, 1]+trans_x=true, kn[1, 0]+trans_y=true --> abcmk[-1, -1,
@@ -219,20 +255,16 @@ TEST(MatmulSPMDRule, Ctor) {
   ctx = phi::distributed::InferSpmdContext(
       {x, y}, {/*trans_x=*/true, /*trans_x=*/true});
   infered_dist_attrs = matmul_spmd_rule.InferForward(ctx);
-  EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
-            std::vector<int64_t>({-1, -1, 0, 1}));
-  EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
-            std::vector<int64_t>(
-                {-1, 0}));  // confilct and should be changed to [-1, 0]
-  EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(),
-            std::vector<int64_t>({-1, -1, 1, -1}));
-  EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(),
-            std::set<int64_t>({0}));
-  VLOG(4) << infered_dist_attrs.second[0].to_string();
-  infered_dist_attrs.second[0].clean_partial_status();
-  EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false);
-  EXPECT_ANY_THROW(infered_dist_attrs.second[0].set_partial_status(
-      std::vector<int64_t>({1})));
+
+  check_dim_mapping(infered_dist_attrs.first[0], {-1, -1, 0, 1});
+  check_dim_mapping(infered_dist_attrs.first[1],
+                    {-1, 0});  // conflict and should be changed to [-1, 0]
+  check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, 1, -1});
+  check_partial_dims(infered_dist_attrs.second[0], {0});
+
+  clean_partial_status(&infered_dist_attrs.second[0]);
+  EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false);
+  EXPECT_ANY_THROW(set_partial_status(&infered_dist_attrs.second[0], {1}));
   VLOG(4) << "test9 done."
<< std::endl << std::endl << std::endl; // abcmk[-1, -1, 1, 0], kn[1, 0] --> abcmk[-1, -1, -1, 0],kn[1, 0] = @@ -256,29 +288,21 @@ TEST(MatmulSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext( {x, y}, {/*trans_x=*/true, /*trans_x=*/true}); infered_dist_attrs = matmul_spmd_rule.InferForward(ctx); - - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, 1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); // try to clean partial on a dim which is not partial - EXPECT_ANY_THROW(infered_dist_attrs.second[0].clean_partial_dims( - std::vector({1}))); - + EXPECT_ANY_THROW(clean_partial_dims(&infered_dist_attrs.second[0], {1})); // try to clean partial on a dims which is sharded - EXPECT_ANY_THROW(infered_dist_attrs.second[0].set_partial_status( - std::vector({1}))); + EXPECT_ANY_THROW(set_partial_status(&infered_dist_attrs.second[0], {1})); // clean partial and then re-set again - infered_dist_attrs.second[0].clean_partial_dims(std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), false); - infered_dist_attrs.second[0].set_partial_status(std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - EXPECT_EQ(infered_dist_attrs.second[0].partial_dims(), - std::set({0})); - + clean_partial_dims(&infered_dist_attrs.second[0], {0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), false); + set_partial_status(&infered_dist_attrs.second[0], {0}); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); + check_partial_dims(infered_dist_attrs.second[0], {0}); VLOG(4) << "test11 done." 
<< std::endl << std::endl << std::endl; } @@ -328,26 +352,18 @@ TEST(LayerNormSPMDRule, Ctor) { bias_dist_attr); phi::distributed::InferSpmdContext ctx({x, scale, bias}, {epsilon, begin_norm_axis}); - std::pair, std::vector> - infered_dist_attrs = layer_norm_rule.InferForward(ctx); + auto infered_dist_attrs = layer_norm_rule.InferForward(ctx); size_t input_size = 3; size_t output_size = 3; EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(), - std::vector({1})); - EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(), - std::vector({1})); + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1}); + check_dim_mapping(infered_dist_attrs.first[2], {-1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[1], {1}); + check_dim_mapping(infered_dist_attrs.second[2], {1}); VLOG(4) << "test1 done."; // ijk[1, 0, -1],k[0],k[0] --> ijk[1, -1, -1],z[1],z[1], @@ -364,18 +380,13 @@ TEST(LayerNormSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext({x, scale, bias}, {epsilon, begin_norm_axis}); infered_dist_attrs = layer_norm_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(), - std::vector({1})); - EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(), - std::vector({1})); + + check_dim_mapping(infered_dist_attrs.first[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1}); + check_dim_mapping(infered_dist_attrs.first[2], {-1}); + check_dim_mapping(infered_dist_attrs.second[0], {1, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[1], {1}); + check_dim_mapping(infered_dist_attrs.second[2], {1}); VLOG(4) << "test2 done."; // ijk[0, -1, -1],y[-1],y[1] --> ijk[0, 1, -1], i[0], i[0], y=jk, @@ -392,18 +403,13 @@ TEST(LayerNormSPMDRule, Ctor) { ctx = phi::distributed::InferSpmdContext({x, scale, bias}, {epsilon, begin_norm_axis}); infered_dist_attrs = layer_norm_rule.InferForward(ctx); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({0, -1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.first[2].dims_mapping(), - std::vector({-1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({0, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[1].dims_mapping(), - std::vector({0})); - EXPECT_EQ(infered_dist_attrs.second[2].dims_mapping(), - std::vector({0})); + + check_dim_mapping(infered_dist_attrs.first[0], {0, -1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1}); + check_dim_mapping(infered_dist_attrs.first[2], {-1}); + check_dim_mapping(infered_dist_attrs.second[0], {0, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[1], {0}); + 
check_dim_mapping(infered_dist_attrs.second[2], {0}); VLOG(4) << "test3 done."; } @@ -449,24 +455,19 @@ TEST(MatmulSPMDRuleInferBackward, Ctor) { // -1] phi::distributed::InferSpmdContext ctx( {x, y, out}, {/*trans_x=*/false, /*trans_x=*/false}); - std::pair, std::vector> - infered_dist_attrs = matmul_spmd_rule.InferBackward(ctx); + auto infered_dist_attrs = matmul_spmd_rule.InferBackward(ctx); size_t input_size = 2; size_t output_size = 1; EXPECT_EQ(infered_dist_attrs.first.size(), input_size); EXPECT_EQ(infered_dist_attrs.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs.second[0].dims_mapping(), - std::vector({-1, -1, 1, -1})); - EXPECT_EQ(infered_dist_attrs.first[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs.first[1].is_partial(), false); - EXPECT_EQ(infered_dist_attrs.second[0].is_partial(), true); - + check_dim_mapping(infered_dist_attrs.first[0], {-1, -1, 1, -1}); + check_dim_mapping(infered_dist_attrs.first[1], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs.second[0], {-1, -1, 1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs.first[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs.first[1]), false); + EXPECT_EQ(is_partial(infered_dist_attrs.second[0]), true); VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; } @@ -524,18 +525,14 @@ TEST(ReplicatedSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs_st.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[1].dims_mapping(), - std::vector({-1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.first[1].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[1].is_partial(), false); + check_dim_mapping(infered_dist_attrs_st.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs_st.second[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.second[1], {-1, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[1]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[1]), false); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test1 done." 
<< std::endl << std::endl << std::endl; @@ -554,15 +551,10 @@ TEST(ReplicatedSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_st.second.size(), output_size); EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.first[1].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.first[2].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({-1, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.first[1], {-1, -1}); + check_dim_mapping(infered_dist_attrs_dy.first[2], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {-1, -1, -1}); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test2 done." << std::endl << std::endl << std::endl; @@ -582,14 +574,10 @@ TEST(ReplicatedSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[1].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[2].dims_mapping(), - std::vector({-1, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {-1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[1], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[2], {-1, -1, -1}); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test3 done." 
<< std::endl << std::endl << std::endl; @@ -649,19 +637,15 @@ TEST(DefaultDataParallelSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_st.second.size(), output_size); EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); + check_dim_mapping(infered_dist_attrs_st.first[0], {0, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.first[1], {0, -1}); + check_dim_mapping(infered_dist_attrs_st.second[0], {0, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_st.second[1], {0, -1, -1}); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.first[1]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[0]), false); + EXPECT_EQ(is_partial(infered_dist_attrs_st.second[1]), false); - EXPECT_EQ(infered_dist_attrs_st.first[0].dims_mapping(), - std::vector({0, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[1].dims_mapping(), - std::vector({0, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[0].dims_mapping(), - std::vector({0, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.second[1].dims_mapping(), - std::vector({0, -1, -1})); - EXPECT_EQ(infered_dist_attrs_st.first[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.first[1].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[0].is_partial(), false); - EXPECT_EQ(infered_dist_attrs_st.second[1].is_partial(), false); EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test1 done." << std::endl << std::endl << std::endl; @@ -682,14 +666,11 @@ TEST(DefaultDataParallelSPMDRule, Ctor) { EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size); EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size); - EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(), - std::vector({-1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[1].dims_mapping(), - std::vector({-1, -1, -1, -1})); - EXPECT_EQ(infered_dist_attrs_dy.second[2].dims_mapping(), - std::vector({-1, -1, -1})); + check_dim_mapping(infered_dist_attrs_dy.first[0], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[0], {-1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[1], {-1, -1, -1, -1}); + check_dim_mapping(infered_dist_attrs_dy.second[2], {-1, -1, -1}); + EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first); EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second); VLOG(4) << "test2 done." 
          << std::endl << std::endl << std::endl;
@@ -735,19 +716,101 @@ TEST(DefaultDataParallelSPMDRule, Ctor) {
   EXPECT_EQ(infered_dist_attrs_st.second.size(), output_size);
   EXPECT_EQ(infered_dist_attrs_dy.first.size(), input_size);
   EXPECT_EQ(infered_dist_attrs_dy.second.size(), output_size);
-
-  EXPECT_EQ(infered_dist_attrs_dy.first[0].dims_mapping(),
-            std::vector<int64_t>({0, -1, -1, -1}));
-  EXPECT_EQ(infered_dist_attrs_dy.first[1].dims_mapping(),
-            std::vector<int64_t>({0, -1}));
-  EXPECT_EQ(infered_dist_attrs_dy.second[0].dims_mapping(),
-            std::vector<int64_t>({0, -1, -1, -1}));
-  EXPECT_EQ(infered_dist_attrs_dy.second[1].dims_mapping(),
-            std::vector<int64_t>({0, -1, -1}));
+  check_dim_mapping(infered_dist_attrs_dy.first[0], {0, -1, -1, -1});
+  check_dim_mapping(infered_dist_attrs_dy.first[1], {0, -1});
+  check_dim_mapping(infered_dist_attrs_dy.second[0], {0, -1, -1, -1});
+  check_dim_mapping(infered_dist_attrs_dy.second[1], {0, -1, -1});
   EXPECT_EQ(infered_dist_attrs_st.first, infered_dist_attrs_dy.first);
   EXPECT_EQ(infered_dist_attrs_st.second, infered_dist_attrs_dy.second);
   VLOG(4) << "test4 done." << std::endl << std::endl << std::endl;
 }
+TEST(ConcatRule, Ctor) {
+  std::vector<int64_t> mesh_shape = {2, 2};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3};
+  std::vector<std::string> dim_names = {"x", "y"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  std::vector<std::vector<int64_t>> shapes = {
+      {16, 16, 16}, {4, 16, 16}, {2, 16, 16}};
+  std::vector<std::vector<int64_t>> dim_mappings = {
+      {-1, 0, 1}, {-1, 1, 0}, {-1, -1, 0}};
+  std::vector<std::set<int64_t>> partial_status = {{}, {}, {1}};
+
+  auto build_inputs = [&] {
+    std::vector<phi::distributed::DistMetaTensor> inputs;
+    for (int i = 0; i < 3; i++) {
+      auto t_dist_attr = TensorDistAttr();
+      t_dist_attr.set_process_mesh(process_mesh);
+      t_dist_attr.set_dims_mapping(dim_mappings[i]);
+      t_dist_attr.set_dynamic_dims({false, false, false});
+      auto input = phi::distributed::DistMetaTensor(phi::make_ddim(shapes[i]),
+                                                    t_dist_attr);
+      inputs.push_back(input);
+    }
+    return inputs;
+  };
+
+  // test 1, inputs are aligned according to cost, and partial status is cleared
+  auto inputs = build_inputs();
+  auto infered_dist_attrs = phi::distributed::ConcatInferSpmd(inputs, 0);
+  // list of tensor => single tensor
+  EXPECT_EQ(infered_dist_attrs.first.size(), static_cast<size_t>(1));
+  EXPECT_EQ(infered_dist_attrs.second.size(), static_cast<size_t>(1));
+  EXPECT_TRUE(
+      paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
+          infered_dist_attrs.first[0]));
+  EXPECT_TRUE(paddle::holds_alternative<phi::distributed::TensorDistAttr>(
+      infered_dist_attrs.second[0]));
+  auto& inputs_infer1 = paddle::get<1>(infered_dist_attrs.first[0]);
+  for (auto e : inputs_infer1) {
+    check_dim_mapping(e, {-1, 1, 0});
+    check_partial_dims(e, {});
+  }
+  check_dim_mapping(infered_dist_attrs.second[0], {-1, 1, 0});
+  check_partial_dims(infered_dist_attrs.second[0], {});
+
+  // test 2, force replicate along concat axis
+  inputs = build_inputs();
+  infered_dist_attrs = phi::distributed::ConcatInferSpmd(inputs, 1);
+  // list of tensor => single tensor
+  EXPECT_EQ(infered_dist_attrs.first.size(), static_cast<size_t>(1));
+  EXPECT_EQ(infered_dist_attrs.second.size(), static_cast<size_t>(1));
+  EXPECT_TRUE(
+      paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
+          infered_dist_attrs.first[0]));
+  EXPECT_TRUE(paddle::holds_alternative<phi::distributed::TensorDistAttr>(
+      infered_dist_attrs.second[0]));
+  auto& inputs_infer2 = paddle::get<1>(infered_dist_attrs.first[0]);
+  for (auto e : inputs_infer2) {
+    check_dim_mapping(e, {1, -1, 0});
+    check_partial_dims(e, {});
+  }
+  check_dim_mapping(infered_dist_attrs.second[0], {1, -1, 0});
+  check_partial_dims(infered_dist_attrs.second[0], {});
+}
+TEST(Util, Ctor) {
+  // test equal test not equal
+  using
phi::distributed::PartialStatus; + using phi::distributed::PlacementEqual; + using phi::distributed::ReplicatedStatus; + using phi::distributed::ShardStatus; + auto a = std::make_shared(phi::ReduceType::kRedSum); + auto b = std::make_shared(phi::ReduceType::kRedMin); + EXPECT_TRUE(PlacementEqual(a, a)); + EXPECT_TRUE(!PlacementEqual(a, b)); + auto c = std::make_shared(0); + auto d = std::make_shared(1); + EXPECT_TRUE(!PlacementEqual(a, c)); + EXPECT_TRUE(!PlacementEqual(b, c)); + EXPECT_TRUE(PlacementEqual(c, c)); + EXPECT_TRUE(!PlacementEqual(c, d)); + auto e = std::make_shared(); + EXPECT_TRUE(PlacementEqual(e, e)); + EXPECT_TRUE(!PlacementEqual(a, e)); + EXPECT_TRUE(!PlacementEqual(b, e)); + EXPECT_TRUE(!PlacementEqual(c, e)); + EXPECT_TRUE(!PlacementEqual(d, e)); +} } // namespace auto_parallel } // namespace distributed diff --git a/test/ir/inference/test_trt_support_nhwc_pass.py b/test/ir/inference/test_trt_support_nhwc_pass.py index 0648202aba30c..bd585d1b5b850 100644 --- a/test/ir/inference/test_trt_support_nhwc_pass.py +++ b/test/ir/inference/test_trt_support_nhwc_pass.py @@ -93,6 +93,10 @@ def setUp(self): self.temp_dir.name, 'inference_pass', 'nhwc_converter', '' ) self.model_prefix = self.path + 'infer_model' + self.set_args() + + def set_args(self): + self.precision_mode = inference.PrecisionType.Float32 def create_model(self): image = static.data( @@ -115,7 +119,7 @@ def create_predictor(self): workspace_size=1 << 30, max_batch_size=1, min_subgraph_size=3, - precision_mode=inference.PrecisionType.Float32, + precision_mode=self.precision_mode, use_static=False, use_calib_mode=False, ) @@ -147,5 +151,44 @@ def tearDown(self): shutil.rmtree(self.path) +class TRTNHWCConvertAMPTest(TRTNHWCConvertTest): + def set_args(self): + self.precision_mode = inference.PrecisionType.Half + + def create_model(self): + train_prog = paddle.static.Program() + with paddle.static.program_guard(train_prog): + with paddle.static.amp.fp16_guard(): + image = paddle.static.data( + name='image', shape=[None, 224, 224, 4], dtype='float32' + ) + label = paddle.static.data( + name='label', shape=[None, 1], dtype='int64' + ) + predict = SimpleNet()(image) + cost = paddle.nn.functional.loss.cross_entropy( + input=predict, label=label + ) + avg_cost = paddle.mean(x=cost) + optimizer = paddle.optimizer.Momentum( + momentum=0.9, + learning_rate=0.01, + weight_decay=paddle.regularizer.L2Decay(4e-5), + ) + optimizer = paddle.static.amp.decorate( + optimizer, + use_dynamic_loss_scaling=False, + use_pure_fp16=False, + ) + optimizer.minimize(avg_cost) + val_prog = train_prog.clone(for_test=True) + + exe = paddle.static.Executor(self.place) + exe.run(paddle.static.default_startup_program()) + paddle.static.save_inference_model( + self.model_prefix, [image], [predict], exe, program=val_prog + ) + + if __name__ == '__main__': unittest.main() diff --git a/test/legacy_test/test_complex_op.py b/test/legacy_test/test_complex_op.py index 151ecfbdb6524..e0388b0c560d3 100644 --- a/test/legacy_test/test_complex_op.py +++ b/test/legacy_test/test_complex_op.py @@ -20,6 +20,7 @@ import paddle from paddle import static from paddle.base import dygraph +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -45,12 +46,13 @@ def setUp(self): self.outputs = {'Out': out_ref} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): self.check_grad( ['X', 'Y'], 'Out', + check_pir=True, ) def test_check_grad_ignore_x(self): @@ -58,6 +60,7 @@ def 
test_check_grad_ignore_x(self): ['Y'], 'Out', no_grad_set=set('X'), + check_pir=True, ) def test_check_grad_ignore_y(self): @@ -65,6 +68,7 @@ def test_check_grad_ignore_y(self): ['X'], 'Out', no_grad_set=set('Y'), + check_pir=True, ) @@ -102,6 +106,7 @@ def test_dygraph(self): out_np = paddle.complex(x, y).numpy() np.testing.assert_allclose(self.out, out_np, rtol=1e-05) + @test_with_pir_api def test_static(self): mp, sp = static.Program(), static.Program() with static.program_guard(mp, sp): diff --git a/test/legacy_test/test_logspace.py b/test/legacy_test/test_logspace.py index 9edd4aef71788..857a6411b869f 100644 --- a/test/legacy_test/test_logspace.py +++ b/test/legacy_test/test_logspace.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api class TestLogspaceOpCommonCase(OpTest): @@ -39,7 +40,7 @@ def init_data(self): self.outputs = {'Out': np.power(2, np.arange(0, 11)).astype(dtype)} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) class TestLogspaceFP16Op(TestLogspaceOpCommonCase): @@ -87,7 +88,7 @@ def init_data(self): self.place = core.CUDAPlace(0) def test_check_output(self): - self.check_output_with_place(self.place) + self.check_output_with_place(self.place, check_pir=True) class TestLogspaceOpReverseCase(TestLogspaceOpCommonCase): @@ -143,6 +144,7 @@ def init_data(self): class TestLogspaceAPI(unittest.TestCase): + @test_with_pir_api def test_variable_input1(self): paddle.enable_static() prog = paddle.static.Program() @@ -170,6 +172,7 @@ def test_variable_input2(self): self.assertEqual((out.numpy() == np_res).all(), True) paddle.enable_static() + @test_with_pir_api def test_dtype(self): paddle.enable_static() prog = paddle.static.Program() diff --git a/test/legacy_test/test_meshgrid_op.py b/test/legacy_test/test_meshgrid_op.py index d8324612e78e4..215424b9c9236 100644 --- a/test/legacy_test/test_meshgrid_op.py +++ b/test/legacy_test/test_meshgrid_op.py @@ -20,6 +20,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api def meshgrid_wrapper(x): @@ -41,10 +42,12 @@ def init_data_type(self): self.dtype = np.float64 def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_pir=True) def test_check_grad(self): - self.check_grad(['x0'], ['out0', 'out1'], check_prim=True) + self.check_grad( + ['x0'], ['out0', 'out1'], check_prim=True, check_pir=True + ) def init_inputs_and_outputs(self): self.shape = self.get_x_shape() @@ -122,19 +125,21 @@ def if_enable_cinn(self): self.enable_cinn = False def test_check_output(self): - self.check_output_with_place(place=paddle.CUDAPlace(0)) + self.check_output_with_place(place=paddle.CUDAPlace(0), check_pir=True) def test_check_grad(self): self.check_grad_with_place( - paddle.CUDAPlace(0), ['x0'], ['out0', 'out1'], check_prim=True + paddle.CUDAPlace(0), + ['x0'], + ['out0', 'out1'], + check_prim=True, + check_pir=True, ) class TestMeshgridOp3(unittest.TestCase): + @test_with_pir_api def test_api(self): - x = paddle.static.data(shape=[100], dtype='int32', name='x') - y = paddle.static.data(shape=[200], dtype='int32', name='y') - input_1 = np.random.randint( 0, 100, @@ -155,22 +160,24 @@ def test_api(self): out_2 = np.reshape(input_2, [1, 200]) out_2 = np.broadcast_to(out_2, [100, 200]) - exe = base.Executor(place=base.CPUPlace()) - grid_x, grid_y = paddle.tensor.meshgrid(x, y) - res_1, res_2 = exe.run( - base.default_main_program(), - feed={'x': 
input_1, 'y': input_2}, - fetch_list=[grid_x, grid_y], - ) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[100], dtype='int32', name='x') + y = paddle.static.data(shape=[200], dtype='int32', name='y') + + exe = base.Executor(place=base.CPUPlace()) + grid_x, grid_y = paddle.tensor.meshgrid(x, y) + res_1, res_2 = exe.run( + paddle.static.default_main_program(), + feed={'x': input_1, 'y': input_2}, + fetch_list=[grid_x, grid_y], + ) np.testing.assert_array_equal(res_1, out_1) np.testing.assert_array_equal(res_2, out_2) class TestMeshgridOp4(unittest.TestCase): + @test_with_pir_api def test_list_input(self): - x = paddle.static.data(shape=[100], dtype='int32', name='x') - y = paddle.static.data(shape=[200], dtype='int32', name='y') - input_1 = np.random.randint( 0, 100, @@ -191,23 +198,24 @@ def test_list_input(self): out_2 = np.reshape(input_2, [1, 200]) out_2 = np.broadcast_to(out_2, [100, 200]) - exe = base.Executor(place=base.CPUPlace()) - grid_x, grid_y = paddle.tensor.meshgrid([x, y]) - res_1, res_2 = exe.run( - base.default_main_program(), - feed={'x': input_1, 'y': input_2}, - fetch_list=[grid_x, grid_y], - ) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[100], dtype='int32', name='x') + y = paddle.static.data(shape=[200], dtype='int32', name='y') + exe = base.Executor(place=base.CPUPlace()) + grid_x, grid_y = paddle.tensor.meshgrid([x, y]) + res_1, res_2 = exe.run( + paddle.static.default_main_program(), + feed={'x': input_1, 'y': input_2}, + fetch_list=[grid_x, grid_y], + ) np.testing.assert_array_equal(res_1, out_1) np.testing.assert_array_equal(res_2, out_2) class TestMeshgridOp5(unittest.TestCase): + @test_with_pir_api def test_tuple_input(self): - x = paddle.static.data(shape=[100], dtype='int32', name='x') - y = paddle.static.data(shape=[200], dtype='int32', name='y') - input_1 = np.random.randint( 0, 100, @@ -228,14 +236,17 @@ def test_tuple_input(self): out_2 = np.reshape(input_2, [1, 200]) out_2 = np.broadcast_to(out_2, [100, 200]) - exe = base.Executor(place=base.CPUPlace()) - grid_x, grid_y = paddle.tensor.meshgrid((x, y)) - res_1, res_2 = exe.run( - base.default_main_program(), - feed={'x': input_1, 'y': input_2}, - fetch_list=[grid_x, grid_y], - ) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data(shape=[100], dtype='int32', name='x') + y = paddle.static.data(shape=[200], dtype='int32', name='y') + exe = base.Executor(place=base.CPUPlace()) + grid_x, grid_y = paddle.tensor.meshgrid((x, y)) + res_1, res_2 = exe.run( + paddle.static.default_main_program(), + feed={'x': input_1, 'y': input_2}, + fetch_list=[grid_x, grid_y], + ) np.testing.assert_array_equal(res_1, out_1) np.testing.assert_array_equal(res_2, out_2) diff --git a/test/legacy_test/test_sgd_op.py b/test/legacy_test/test_sgd_op.py index 23511dc2a5371..d71b297185892 100644 --- a/test/legacy_test/test_sgd_op.py +++ b/test/legacy_test/test_sgd_op.py @@ -17,6 +17,7 @@ import numpy as np from op import Operator from op_test import OpTest +from utils import dygraph_guard import paddle from paddle import base @@ -428,43 +429,62 @@ def test_main(self): class TestSGDSimple(unittest.TestCase): + def setUp(self) -> None: + self.data = np.random.random(size=(2, 2)).astype('float32') + def run_static(self): - paddle.enable_static() - paddle.seed(10) - np.random.seed(10) + with paddle.pir_utils.IrGuard(): + paddle.seed(10) + np.random.seed(10) - exe = paddle.static.Executor('gpu') - train_program = 
paddle.static.Program() - startup_program = paddle.static.Program() - data = np.random.random(size=(2, 2)).astype('float32') + exe = paddle.static.Executor('gpu') + train_program = paddle.static.Program() + startup_program = paddle.static.Program() - with paddle.static.program_guard(train_program, startup_program): - input = paddle.static.data( - shape=[2, 2], name='input', dtype='float32' - ) - model = paddle.nn.Linear(2, 2) - output = model(input) - loss = paddle.mean(output) + with paddle.static.program_guard(train_program, startup_program): + input = paddle.static.data( + shape=[2, 2], name='input', dtype='float32' + ) + model = paddle.nn.Linear(2, 2) + output = model(input) + loss = paddle.mean(output) - optimizer = paddle.optimizer.SGD() - optimizer.minimize(loss) + optimizer = paddle.optimizer.SGD() + optimizer.minimize(loss) - exe.run(startup_program) + exe.run(startup_program) - out = [] - for _ in range(5): - (loss_data,) = exe.run( - train_program, feed={"input": data}, fetch_list=[loss] - ) - out.append(loss_data) - return out + out = [] + for _ in range(5): + (loss_data,) = exe.run( + train_program, feed={"input": self.data}, fetch_list=[loss] + ) + out.append(loss_data) + return out + + def run_dygraph(self): + with dygraph_guard(): + paddle.seed(10) + np.random.seed(10) + + out = [] + model = paddle.nn.Linear(2, 2) + optimizer = paddle.optimizer.SGD(parameters=model.parameters()) + for _ in range(5): + output = model(paddle.to_tensor(self.data)) + loss = paddle.mean(output) + out.append(loss.numpy()) + loss.backward() + optimizer.step() + optimizer.clear_grad() + + return out def test_main(self): if not paddle.is_compiled_with_cuda(): return - out1 = self.run_static() - with paddle.pir_utils.IrGuard(): - out2 = self.run_static() + out1 = self.run_dygraph() + out2 = self.run_static() np.testing.assert_allclose(out1, out2) diff --git a/test/sot/test_simulate_initialize.py b/test/sot/test_simulate_initialize.py index 495e06ac1dbda..08a30dfc5a696 100644 --- a/test/sot/test_simulate_initialize.py +++ b/test/sot/test_simulate_initialize.py @@ -31,6 +31,18 @@ def foo(x, y): return out +def foo2(x, y): + t = nn.Softmax() + out1 = t(paddle.to_tensor([x, y], dtype="float32")) + out2 = t(paddle.to_tensor([x, y], dtype="float32")) + return out1 + out2 + + +def error_foo(x): + t = nn.Linear(10, 10) + return t(x) + + def bar(x): a = A(x) t = paddle.to_tensor(x) @@ -40,12 +52,20 @@ def bar(x): class TestInit(TestCaseBase): def test_init_paddle_layer(self): self.assert_results(foo, 1, 2) + self.assert_results(foo2, 1, 2) def test_init_python_object(self): sot_output = symbolic_translate(bar)([1.0, 2.0]) dyn_output = bar([1.0, 2.0]) self.assert_nest_match(sot_output, dyn_output) + def test_error(self): + def run(): + inputs = paddle.randn((10, 10)) + symbolic_translate(error_foo)(inputs) + + self.assertRaises(paddle.jit.sot.utils.exceptions.InnerError, run) + if __name__ == "__main__": unittest.main() From 89acd77c613191492961c86f5adc1ebbe41a670e Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Thu, 2 Nov 2023 19:43:19 +0800 Subject: [PATCH 4/7] Fix conflicts. 
--- paddle/phi/api/lib/api_custom_impl.cc | 10 +- paddle/phi/api/lib/api_gen_utils.cc | 17 +-- paddle/phi/api/lib/api_gen_utils.h | 4 +- paddle/phi/api/lib/data_transform.cc | 117 +++++++++--------- paddle/phi/api/lib/data_transform.h | 57 +++++---- paddle/phi/api/yaml/generator/dist_api_gen.py | 9 +- 6 files changed, 110 insertions(+), 104 deletions(-) diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 8c82df3e83969..2f1333042fe68 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -117,13 +117,9 @@ Tensor add_n_impl(const std::vector& x) { input_x[i] = x[i].impl().get(); } - // auto meta_dist_input_x = MakeDistMetaTensor(input_x); - std::vector meta_dist_input_x; - for (auto& e : input_x) { - meta_dist_input_x.push_back(MakeDistMetaTensor(*e)); - } - auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic( - meta_dist_input_x); + auto meta_dist_input_x = MakeDistMetaTensor(input_x); + auto spmd_info = + phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x); auto dist_out = SetKernelDistOutput(&api_output); auto dense_out = dist_out->unsafe_mutable_value(); diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 25371c3ec4ca7..a39010ac2f73b 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -536,14 +536,15 @@ phi::distributed::DistMetaTensor MakeDistMetaTensor( return phi::distributed::DistMetaTensor(tensor); } -// std::vector MakeDistMetaTensor( -// const std::vector& tensors) { -// std::vector out; -// for (auto t : tensors) { -// out.push_back(MakeDistMetaTensor(*t.impl())); -// } -// return out; -// } +std::vector MakeDistMetaTensor( + const std::vector& tensors) { + std::vector meta_tensors; + meta_tensors.reserve(tensors.size()); + for (const auto* t : tensors) { + meta_tensors.emplace_back(*t); + } + return meta_tensors; +} phi::distributed::DistTensor* SetKernelDistOutput( Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) { diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h index 378be88824067..13f68ab7defbb 100644 --- a/paddle/phi/api/lib/api_gen_utils.h +++ b/paddle/phi/api/lib/api_gen_utils.h @@ -140,8 +140,8 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, phi::distributed::DistMetaTensor MakeDistMetaTensor( const phi::TensorBase& tensor); -// std::vector MakeDistMetaTensor( -// const std::vector& tensors); +std::vector MakeDistMetaTensor( + const std::vector& tensors); phi::distributed::DistTensor* SetKernelDistOutput( Tensor* out, diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 1044b9d7046e9..00a9419aac7bf 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -835,62 +835,62 @@ ReshardApiInputToReplicatedKernelInput( return paddle::none; } -std::shared_ptr -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const Tensor& tensor, - const phi::distributed::ArgDistAttr& dist_attr) { - PADDLE_ENFORCE_EQ( - paddle::holds_alternative(dist_attr), - true, - phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); - const auto& tensor_dist_attr = paddle::get<0>(dist_attr); - return ReshardApiInputToReplicatedKernelInput( - dev_ctx, tensor, tensor_dist_attr); -} +// std::shared_ptr +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const Tensor& tensor, +// const 
phi::distributed::ArgDistAttr& dist_attr) { +// PADDLE_ENFORCE_EQ( +// paddle::holds_alternative(dist_attr), +// true, +// phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); +// const auto& tensor_dist_attr = paddle::get<0>(dist_attr); +// return ReshardApiInputToReplicatedKernelInput( +// dev_ctx, tensor, tensor_dist_attr); +// } -paddle::optional> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const paddle::optional& tensor, - const phi::distributed::ArgDistAttr& dist_attr) { - PADDLE_ENFORCE_EQ( - paddle::holds_alternative(dist_attr), - true, - phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); - const auto& tensor_dist_attr = paddle::get<0>(dist_attr); - return ReshardApiInputToReplicatedKernelInput( - dev_ctx, tensor, tensor_dist_attr); -} +// paddle::optional> +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const paddle::optional& tensor, +// const phi::distributed::ArgDistAttr& dist_attr) { +// PADDLE_ENFORCE_EQ( +// paddle::holds_alternative(dist_attr), +// true, +// phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); +// const auto& tensor_dist_attr = paddle::get<0>(dist_attr); +// return ReshardApiInputToReplicatedKernelInput( +// dev_ctx, tensor, tensor_dist_attr); +// } -std::vector> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const std::vector& dist_attrs) { - std::vector> outputs; - for (size_t i = 0; i < tensors.size(); ++i) { - outputs.push_back(ReshardApiInputToReplicatedKernelInput( - dev_ctx, tensors[i], dist_attrs[i])); - } - return outputs; -} +// std::vector> +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const std::vector& tensors, +// const std::vector& dist_attrs) { +// std::vector> outputs; +// for (size_t i = 0; i < tensors.size(); ++i) { +// outputs.push_back(ReshardApiInputToReplicatedKernelInput( +// dev_ctx, tensors[i], dist_attrs[i])); +// } +// return outputs; +// } -std::vector> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const phi::distributed::ArgDistAttr& dist_attr) { - PADDLE_ENFORCE_EQ( - paddle::holds_alternative>( - dist_attr), - true, - phi::errors::PreconditionNotMet( - "Arg must be a vector of TensorDistAttr")); - const auto& tensor_dist_attrs = paddle::get<1>(dist_attr); - return ReshardApiInputToReplicatedKernelInput( - dev_ctx, tensors, tensor_dist_attrs); -} +// std::vector> +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const std::vector& tensors, +// const phi::distributed::ArgDistAttr& dist_attr) { +// PADDLE_ENFORCE_EQ( +// paddle::holds_alternative>( +// dist_attr), +// true, +// phi::errors::PreconditionNotMet( +// "Arg must be a vector of TensorDistAttr")); +// const auto& tensor_dist_attrs = paddle::get<1>(dist_attr); +// return ReshardApiInputToReplicatedKernelInput( +// dev_ctx, tensors, tensor_dist_attrs); +// } void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) { @@ -937,7 +937,7 @@ void ReshardKernelOutputToApiOutput( } std::shared_ptr PrepareDataForDistTensor( - const std::shared_ptr& input, + std::shared_ptr input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel) { @@ -973,7 +973,7 @@ std::shared_ptr PrepareDataForDistTensor( std::vector> PrepareDataForDistTensor( - const std::vector>& input, + std::vector> input, const 
phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel) { @@ -1016,8 +1016,7 @@ PrepareDataForDistTensor( paddle::optional> PrepareDataForDistTensor( - const paddle::optional>& - input, + paddle::optional> input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel) { @@ -1033,8 +1032,8 @@ PrepareDataForDistTensor( paddle::optional>> PrepareDataForDistTensor( - const paddle::optional< - std::vector>>& input, + paddle::optional>> + input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel) { diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index ba456776ab1ce..1a38986f4b030 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -203,22 +203,16 @@ ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, const std::vector& tensors, const phi::distributed::ArgDistAttr& dist_attrs); -// std::shared_ptr -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const Tensor& tensor, -// const phi::distributed::ArgDistAttr& dist_attr); - -std::vector> +std::shared_ptr ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, - const std::vector& tensor, + const Tensor& tensor, const phi::distributed::ArgDistAttr& dist_attr); -paddle::optional> +std::vector> ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, - const paddle::optional& tensor, + const std::vector& tensor, const phi::distributed::ArgDistAttr& dist_attr); std::vector> @@ -239,17 +233,29 @@ ReshardApiInputToReplicatedKernelInput( const paddle::optional>& tensors, const phi::distributed::ArgDistAttr& dist_attr); -std::vector> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const std::vector& dist_attrs); +// paddle::optional> +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const paddle::optional& tensor, +// const phi::distributed::ArgDistAttr& dist_attr); -std::vector> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensor, - const phi::distributed::ArgDistAttr& dist_attr); +// paddle::optional>> +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const paddle::optional>& tensors, +// const phi::distributed::ArgDistAttr& dist_attr); + +// std::vector> +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const std::vector& tensors, +// const std::vector& dist_attrs); + +// std::vector> +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const std::vector& tensor, +// const phi::distributed::ArgDistAttr& dist_attr); void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor); @@ -260,30 +266,29 @@ void ReshardKernelOutputToApiOutput( Tensor* dst_tensor); std::shared_ptr PrepareDataForDistTensor( - const std::shared_ptr& input, + std::shared_ptr input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); std::vector> PrepareDataForDistTensor( - const std::vector>& input, + std::vector> input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); paddle::optional> PrepareDataForDistTensor( - const paddle::optional>& - input, + paddle::optional> input, const phi::TensorArgDef& target_args_def, const 
TransformFlag& transform_flag, bool is_stride_kernel); paddle::optional>> PrepareDataForDistTensor( - const paddle::optional< - std::vector>>& input, + paddle::optional>> + input, const phi::TensorArgDef& target_args_def, const TransformFlag& transform_flag, bool is_stride_kernel); diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index 790994f943068..e48a9f0d0a0b4 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -83,13 +83,18 @@ auto meta_dist_input_{name} = MakeDistMetaTensor(*{name}.impl());""" VECTOR_DIST_META_IN_TEMPLATE = """ std::vector meta_dist_input_{name}; - for(auto& e: {name}){{ + for(auto& e : {name}) {{ meta_dist_input_{name}.push_back(MakeDistMetaTensor(*e.impl())); }}""" OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE = """ auto meta_dist_input_{name} = {name} ? MakeDistMetaTensor(*(*{name}).impl()) : phi::distributed::DistMetaTensor();""" OPTIONAL_VECTOR_DIST_META_IN_TEMPLATE = """ - auto meta_dist_input_{name} = {name} ? MakeDistMetaTensor(*{name}) : std::vector(1);""" + std::vector meta_dist_input_{name}; + if ({name}) {{ + for(auto& e : *{name}) {{ + meta_dist_input_{name}.push_back(MakeDistMetaTensor(*e.impl())); + }} + }}""" INFER_SPMD_TEMPLATE = """ auto spmd_info = phi::distributed::{}({}); """ From 788f3084f4c0fe224c356da78cb63f2def160257 Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Thu, 2 Nov 2023 20:29:45 +0800 Subject: [PATCH 5/7] Polish code. --- paddle/phi/api/lib/data_transform.cc | 82 +----------- paddle/phi/api/lib/data_transform.h | 24 ---- paddle/phi/api/yaml/generator/dist_api_gen.py | 124 ++++-------------- 3 files changed, 28 insertions(+), 202 deletions(-) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 00a9419aac7bf..b9572c6a3c933 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -704,29 +704,6 @@ ReshardApiInputToKernelInput( return output; } -// std::shared_ptr -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const Tensor& tensor, -// const phi::distributed::ArgDistAttr& dist_attr) { -// auto tensor_in = tensor.impl(); -// if (tensor_in) { -// phi::distributed::DistTensor* dist_tensor = -// static_cast(tensor_in.get()); -// if (ReshardIsNeeded(dist_tensor->dist_attr(), paddle::get<0>(dist_attr))) -// { -// VLOG(6) << "ApiIn to Replicated KernelIn - " -// << ReshardDebugInfo(*dist_tensor, paddle::get<0>(dist_attr)); -// auto* func = -// phi::distributed::ChooseProperReshardFunction(*dist_tensor, -// paddle::get<0>(dist_attr)); -// return func->Eval(dev_ctx, *dist_tensor, dist_attr); -// } -// return std::static_pointer_cast(tensor_in); -// } -// return nullptr; -// } - std::shared_ptr ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, @@ -773,7 +750,7 @@ ReshardApiInputToReplicatedKernelInput( "Tensor's size should be equal to dist_attrs' size.")); std::vector> out; - for (int i = 0; i < tensors.size(); i++) { + for (size_t i = 0; i < tensors.size(); i++) { auto tensor_in = tensors[i].impl(); auto dist_attr = tensor_dist_attrs[i]; if (tensor_in) { @@ -835,63 +812,6 @@ ReshardApiInputToReplicatedKernelInput( return paddle::none; } -// std::shared_ptr -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const Tensor& tensor, -// const phi::distributed::ArgDistAttr& dist_attr) { -// PADDLE_ENFORCE_EQ( -// paddle::holds_alternative(dist_attr), -// true, -// 
phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); -// const auto& tensor_dist_attr = paddle::get<0>(dist_attr); -// return ReshardApiInputToReplicatedKernelInput( -// dev_ctx, tensor, tensor_dist_attr); -// } - -// paddle::optional> -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const paddle::optional& tensor, -// const phi::distributed::ArgDistAttr& dist_attr) { -// PADDLE_ENFORCE_EQ( -// paddle::holds_alternative(dist_attr), -// true, -// phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr")); -// const auto& tensor_dist_attr = paddle::get<0>(dist_attr); -// return ReshardApiInputToReplicatedKernelInput( -// dev_ctx, tensor, tensor_dist_attr); -// } - -// std::vector> -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const std::vector& tensors, -// const std::vector& dist_attrs) { -// std::vector> outputs; -// for (size_t i = 0; i < tensors.size(); ++i) { -// outputs.push_back(ReshardApiInputToReplicatedKernelInput( -// dev_ctx, tensors[i], dist_attrs[i])); -// } -// return outputs; -// } - -// std::vector> -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const std::vector& tensors, -// const phi::distributed::ArgDistAttr& dist_attr) { -// PADDLE_ENFORCE_EQ( -// paddle::holds_alternative>( -// dist_attr), -// true, -// phi::errors::PreconditionNotMet( -// "Arg must be a vector of TensorDistAttr")); -// const auto& tensor_dist_attrs = paddle::get<1>(dist_attr); -// return ReshardApiInputToReplicatedKernelInput( -// dev_ctx, tensors, tensor_dist_attrs); -// } - void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) { if (out_tensor->dist_attr().is_partial()) { diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 1a38986f4b030..9972f3838992d 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -233,30 +233,6 @@ ReshardApiInputToReplicatedKernelInput( const paddle::optional>& tensors, const phi::distributed::ArgDistAttr& dist_attr); -// paddle::optional> -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const paddle::optional& tensor, -// const phi::distributed::ArgDistAttr& dist_attr); - -// paddle::optional>> -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const paddle::optional>& tensors, -// const phi::distributed::ArgDistAttr& dist_attr); - -// std::vector> -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const std::vector& tensors, -// const std::vector& dist_attrs); - -// std::vector> -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const std::vector& tensor, -// const phi::distributed::ArgDistAttr& dist_attr); - void ReshardOutputPartialAxisToReplicated( phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor); diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index e48a9f0d0a0b4..9040bc703a2f0 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -48,7 +48,7 @@ // 1. InferSpmd (Infer DistAttr of Inputs&Outputs){} // 2. Create API Output & Prepare Dist and Dense Output{} // 3. Infer DistTensor's Global Shape{}\n - if (rank_is_in_current_mesh){{ + if (rank_is_in_current_mesh) {{ // 4. Select Kernel{} // 5. Reshard Input{}\n // 6. 
PrepareData (DataTransform & Prepare Dense Input){} @@ -241,38 +241,27 @@ """ # 5. Reshard Input -SINGLE_INPUT_RESHARD_TEMPLATE = """ - auto new_dist_input_{arg} = ReshardApiInputToKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);""" -# VECTOR_INPUT_RESHARD_TEMPLATE = """ -# auto dist_input_{arg}_vec = ReshardApiInputToKernelInput(dev_ctx, {arg}, -# std::vector(spmd_info.first.begin() + idx, -# spmd_info.first.begin() + idx + size)); -# idx += size; -# VLOG(4) << "After reshard dist_input_{arg}_vec";""" -SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE = """ - auto new_dist_input_{arg} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);""" -# VECTOR_GENERAL_INPUT_RESHARD_TEMPLATE = """ -# auto dist_input_{arg}_vec = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, -# std::vector(spmd_info.first.begin() + idx, -# spmd_info.first.begin() + idx + size)); -# idx += size; -# VLOG(4) << "After reshard dist_input_{arg}_vec";""" +# Both Tensor, std::vector, paddle::optional and +# paddle::optional> use the same template +INPUT_RESHARD_TEMPLATE = """ + auto dist_input_{name} = ReshardApiInputToKernelInput(dev_ctx, {name}, spmd_info.first[{idx}]);""" +GENERAL_INPUT_RESHARD_TEMPLATE = """ + auto dist_input_{name} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {name}, spmd_info.first[{idx}]);""" UNSUPPORTED_RESHARD_INPUT_COMMENT_TEMPLATE = """ // API `{}` does not need to support ReshardInput at this time """ # 6. PrepareData SINGLE_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{arg} = PrepareDataForDistTensor(new_dist_input_{arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); - auto input_{arg} = &dist_input_{arg}->value(); + dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto input_{name} = &dist_input_{name}->value(); """ SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ - auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel); - auto input_{arg} = &dist_input_{arg}->value(); + auto dist_input_{name} = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto input_{name} = &dist_input_{name}->value(); """ -# dist_input_ prefix VECTOR_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name}_vec = PrepareDataForDistTensor(new_dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); std::vector dense_input_{name}_vec; for (auto tmp : dist_input_{name}_vec) {{ dense_input_{name}_vec.emplace_back(&tmp->value()); @@ -283,19 +272,16 @@ dense_input_{name}_meta_ptr_vec[i] = &dense_input_{name}_meta_vec[i]; }} """ - OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name} = PrepareDataForDistTensor(new_dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); paddle::optional input_{name} = dist_input_{name} ? 
paddle::make_optional((*dist_input_{name})->value()) : paddle::none; """ OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """ - auto dist_input_{name} = PrepareDataForDistTensor(new_dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); paddle::optional input_{name} = dist_input_{name} ? paddle::make_optional(dist_input_{name}->value()) : paddle::none; """ - -# dist_input_ prefix OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE = """ - auto dist_input_{name}_vec = PrepareDataForDistTensor(new_dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); + auto dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel); std::vector dense_input_{name}_vec; if ({name}) {{ for (auto tmp : *dist_input_{name}_vec) {{ @@ -311,6 +297,7 @@ paddle::optional> dense_input_{name}_meta_ptr_vec = {name} ? paddle::make_optional>(dense_input_{name}_meta_ptr_vec_tmp) : paddle::none; """ + INFER_META_SINGLE_INPUT_TEMPLATE = """ auto dist_input_{} = {}.impl(); auto input_{} = &(static_cast(dist_input_{}.get())->value()); @@ -1053,71 +1040,14 @@ def generate_reshard_input_code(self) -> str: ]: if self.generate_general_infer_spmd is True: input_reshard_code += ( - SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format( - arg=param, idx=i + GENERAL_INPUT_RESHARD_TEMPLATE.format( + name=param, idx=i ) ) else: - input_reshard_code += ( - SINGLE_INPUT_RESHARD_TEMPLATE.format( - arg=param, idx=i - ) + input_reshard_code += INPUT_RESHARD_TEMPLATE.format( + name=param, idx=i ) - # if ( - # self.inputs['input_info'][param] == "const Tensor&" - # or self.inputs['input_info'][param] - # == "const paddle::optional&" - # ): - # if self.generate_general_infer_spmd is True: - # input_reshard_code += ( - # SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format( - # arg=param, idx=i - # ) - # ) - # else: - # input_reshard_code += ( - # SINGLE_INPUT_RESHARD_TEMPLATE.format( - # arg=param, idx=i - # ) - # ) - # elif ( - # self.inputs['input_info'][param] - # == "const std::vector&" - # or self.inputs['input_info'][param] - # == "const paddle::optional>&" - # ): - # if self.generate_general_infer_spmd is True: - # input_reshard_code += ( - # SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format( - # arg=param, idx=i - # ) - # ) - # else: - # input_reshard_code += ( - # SINGLE_INPUT_RESHARD_TEMPLATE.format( - # arg=param, idx=i - # ) - # ) - # if ( - # self.inputs['input_info'][param] - # == "const std::vector&" - # ): - # input_reshard_code += ( - # f"\n size = {param}.size();" - # ) - # else: - # input_reshard_code += f"\n size = {param} ? (*{param}).size() : 1;" - - # if self.generate_general_infer_spmd is True: - # input_reshard_code += ( - # VECTOR_GENERAL_INPUT_RESHARD_TEMPLATE.format( - # arg=param - # ) - # ) - # else: - # input_reshard_code += ( - # VECTOR_INPUT_RESHARD_TEMPLATE.format(arg=param) - # ) else: raise ValueError( f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported." 
@@ -1146,15 +1076,15 @@ def generate_single_dense_input( if self.generate_infer_spmd is True: input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE.format( - arg=input_name, + name=input_name, idx=kernel_param.index(input_name), - flag=trans_flag, + trans_flag=trans_flag, ) else: input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format( arg=input_name, idx=kernel_param.index(input_name), - flag=trans_flag, + trans_flag=trans_flag, ) return input_tensor_code @@ -1172,7 +1102,7 @@ def generate_vector_dense_input( kernel_param = input_names + attr_names input_tensor_code += VECTOR_PREPARE_DATA_TEMPLATE.format( name=input_name, - index=kernel_param.index(input_name), + idx=kernel_param.index(input_name), trans_flag=trans_flag, ) @@ -1193,14 +1123,14 @@ def generate_optional_single_dense_input( if self.generate_infer_spmd is True: input_tensor_code += OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE.format( name=input_name, - index=kernel_param.index(input_name), + idx=kernel_param.index(input_name), trans_flag=trans_flag, ) else: input_tensor_code += ( OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format( name=input_name, - index=kernel_param.index(input_name), + idx=kernel_param.index(input_name), trans_flag=trans_flag, ) ) @@ -1220,7 +1150,7 @@ def generate_optional_vector_dense_input( kernel_param = input_names + attr_names input_tensor_code += OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE.format( name=input_name, - index=kernel_param.index(input_name), + idx=kernel_param.index(input_name), trans_flag=trans_flag, ) From 2128a569eb0b936faa8a047b425d53c508a54941 Mon Sep 17 00:00:00 2001 From: GhostScreaming Date: Fri, 3 Nov 2023 12:01:46 +0800 Subject: [PATCH 6/7] Polish code. --- paddle/phi/api/lib/api_custom_impl.cc | 21 +-- paddle/phi/api/lib/data_transform.cc | 129 +++--------------- paddle/phi/api/lib/data_transform.h | 48 +------ paddle/phi/api/yaml/generator/dist_api_gen.py | 27 ++-- paddle/phi/infermeta/spmd_rules/utils.h | 6 + test/auto_parallel/test_api_dist_branch.py | 44 +++--- 6 files changed, 79 insertions(+), 196 deletions(-) diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index 2f1333042fe68..2f05194e1b708 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -118,8 +118,8 @@ Tensor add_n_impl(const std::vector& x) { } auto meta_dist_input_x = MakeDistMetaTensor(input_x); - auto spmd_info = - phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x); + auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic( + meta_dist_input_x); auto dist_out = SetKernelDistOutput(&api_output); auto dense_out = dist_out->unsafe_mutable_value(); @@ -139,7 +139,7 @@ Tensor add_n_impl(const std::vector& x) { phi::AddNInferMeta(x_metas, &meta_dist_out); if (rank_is_in_current_mesh) { auto dist_input_x = - ReshardApiInputToReplicatedKernelInput(dev_ctx, x, spmd_info.first); + ReshardApiInputToKernelInput(dev_ctx, x, spmd_info.first[0]); dist_input_x = PrepareDataForDistTensor( dist_input_x, GetKernelInputArgDef(kernel.InputAt(0), kernel_backend), @@ -165,14 +165,15 @@ Tensor add_n_impl(const std::vector& x) { auto* kernel_fn = kernel.GetVariadicKernelFn(); (*kernel_fn)(*dev_ctx, input_x, dense_out); } - PADDLE_ENFORCE_EQ( - paddle::holds_alternative( - spmd_info.first[0]), - true, - phi::errors::PreconditionNotMet( - "Arg must be a single TensorDistAttr")); + PADDLE_ENFORCE_EQ(paddle::holds_alternative< + std::vector>( + spmd_info.first[0]), + true, + phi::errors::PreconditionNotMet( + "Arg must be a 
vector of TensorDistAttr")); + auto current_process_mesh = - paddle::get<0>(spmd_info.first[0]).process_mesh(); + paddle::get<1>(spmd_info.first[0]).at(0).process_mesh(); SetReplicatedDistAttrForOutput(dist_out, current_process_mesh); return api_output; } diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index b9572c6a3c933..32bf65c2d6b63 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -621,91 +621,6 @@ std::string ReshardDebugInfo( } std::shared_ptr ReshardApiInputToKernelInput( - phi::DeviceContext* dev_ctx, - const Tensor& tensor, - const phi::distributed::TensorDistAttr& dist_attr) { - auto tensor_in = tensor.impl(); - if (tensor_in) { - phi::distributed::DistTensor* dist_tensor = - static_cast(tensor_in.get()); - if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) { - VLOG(6) << "ApiIn to KernelIn - " - << ReshardDebugInfo(*dist_tensor, dist_attr); - auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor, - dist_attr); - return func->Eval(dev_ctx, *dist_tensor, dist_attr); - } - return std::static_pointer_cast(tensor_in); - } - return nullptr; -} - -std::shared_ptr ReshardApiInputToKernelInput( - phi::DeviceContext* dev_ctx, - const Tensor& tensor, - const phi::distributed::ArgDistAttr& dist_attr) { - PADDLE_ENFORCE_EQ( - paddle::holds_alternative(dist_attr), - true, - phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr")); - const auto& tensor_dist_attr = paddle::get<0>(dist_attr); - return ReshardApiInputToKernelInput(dev_ctx, tensor, tensor_dist_attr); -} - -std::vector> -ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const phi::distributed::ArgDistAttr& dist_attrs) { - PADDLE_ENFORCE_EQ( - paddle::holds_alternative>( - dist_attrs), - true, - phi::errors::PreconditionNotMet( - "Arg must be a vector of TensorDistAttr")); - const auto& tensor_dist_attrs = paddle::get<1>(dist_attrs); - return ReshardApiInputToKernelInput(dev_ctx, tensors, tensor_dist_attrs); -} - -std::vector> -ReshardApiInputToKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const std::vector& dist_attrs) { - std::vector> output; - PADDLE_ENFORCE_EQ(tensors.size(), - dist_attrs.size(), - phi::errors::PreconditionNotMet( - "tensors size and dist_attrs size not equal: %d vs %d", - tensors.size(), - dist_attrs.size())); - for (size_t i = 0; i < dist_attrs.size(); i++) { - output.push_back( - ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i])); - } - return output; -} - -std::vector> -ReshardApiInputToKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const std::vector& dist_attrs) { - std::vector> output; - PADDLE_ENFORCE_EQ(tensors.size(), - dist_attrs.size(), - phi::errors::PreconditionNotMet( - "tensors size and dist_attrs size not equal: %d vs %d", - tensors.size(), - dist_attrs.size())); - for (size_t i = 0; i < dist_attrs.size(); i++) { - output.push_back( - ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i])); - } - return output; -} - -std::shared_ptr -ReshardApiInputToReplicatedKernelInput( phi::DeviceContext* dev_ctx, const Tensor& tensor, const phi::distributed::ArgDistAttr& dist_attr) { @@ -732,10 +647,9 @@ ReshardApiInputToReplicatedKernelInput( } std::vector> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const phi::distributed::ArgDistAttr& dist_attrs) { +ReshardApiInputToKernelInput(phi::DeviceContext* 
dev_ctx, + const std::vector& tensors, + const phi::distributed::ArgDistAttr& dist_attrs) { PADDLE_ENFORCE_EQ( paddle::holds_alternative>( dist_attrs), @@ -772,34 +686,33 @@ ReshardApiInputToReplicatedKernelInput( return out; } -std::vector> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const std::vector& dist_attrs) { - std::vector> outputs; - for (size_t i = 0; i < tensors.size(); ++i) { - outputs.push_back(ReshardApiInputToReplicatedKernelInput( - dev_ctx, tensors[i], dist_attrs[i])); - } - return outputs; -} +// std::vector> +// ReshardApiInputToReplicatedKernelInput( +// phi::DeviceContext* dev_ctx, +// const std::vector& tensors, +// const std::vector& dist_attrs) { +// std::vector> outputs; +// for (size_t i = 0; i < tensors.size(); ++i) { +// outputs.push_back(ReshardApiInputToReplicatedKernelInput( +// dev_ctx, tensors[i], dist_attrs[i])); +// } +// return outputs; +// } paddle::optional> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const paddle::optional& tensor, - const phi::distributed::ArgDistAttr& dist_attr) { +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const paddle::optional& tensor, + const phi::distributed::ArgDistAttr& dist_attr) { if (tensor) { VLOG(6) << "Optional ApiIn to Replicated KernelIn."; return paddle::make_optional>( - ReshardApiInputToReplicatedKernelInput(dev_ctx, *tensor, dist_attr)); + ReshardApiInputToKernelInput(dev_ctx, *tensor, dist_attr)); } return paddle::none; } paddle::optional>> -ReshardApiInputToReplicatedKernelInput( +ReshardApiInputToKernelInput( phi::DeviceContext* dev_ctx, const paddle::optional>& tensors, const phi::distributed::ArgDistAttr& dist_attrs) { @@ -807,7 +720,7 @@ ReshardApiInputToReplicatedKernelInput( VLOG(6) << "Optional ApiIn to Replicated KernelIn."; return paddle::make_optional< std::vector>>( - ReshardApiInputToReplicatedKernelInput(dev_ctx, *tensors, dist_attrs)); + ReshardApiInputToKernelInput(dev_ctx, *tensors, dist_attrs)); } return paddle::none; } diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h index 9972f3838992d..2eba71c7295c8 100644 --- a/paddle/phi/api/lib/data_transform.h +++ b/paddle/phi/api/lib/data_transform.h @@ -176,59 +176,23 @@ inline bool NeedTransformPlace(const phi::Place& src_place, /* ------------------ for auto parallel ----------------------- */ -std::shared_ptr ReshardApiInputToKernelInput( - phi::DeviceContext* dev_ctx, - const Tensor& tensor, - const phi::distributed::TensorDistAttr& dist_attr); - std::shared_ptr ReshardApiInputToKernelInput( phi::DeviceContext* dev_ctx, const Tensor& tensor, const phi::distributed::ArgDistAttr& dist_attr); -std::vector> -ReshardApiInputToKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const std::vector& dist_attrs); - -std::vector> -ReshardApiInputToKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const std::vector& dist_attrs); - std::vector> ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const phi::distributed::ArgDistAttr& dist_attrs); - -std::shared_ptr -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const Tensor& tensor, - const phi::distributed::ArgDistAttr& dist_attr); - -std::vector> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const std::vector& tensor, - const phi::distributed::ArgDistAttr& dist_attr); - -std::vector> -ReshardApiInputToReplicatedKernelInput( - 
phi::DeviceContext* dev_ctx, - const std::vector& tensors, - const std::vector& dist_attrs); + const std::vector& tensor, + const phi::distributed::ArgDistAttr& dist_attr); paddle::optional> -ReshardApiInputToReplicatedKernelInput( - phi::DeviceContext* dev_ctx, - const paddle::optional& tensor, - const phi::distributed::ArgDistAttr& dist_attr); +ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, + const paddle::optional& tensor, + const phi::distributed::ArgDistAttr& dist_attr); paddle::optional>> -ReshardApiInputToReplicatedKernelInput( +ReshardApiInputToKernelInput( phi::DeviceContext* dev_ctx, const paddle::optional>& tensors, const phi::distributed::ArgDistAttr& dist_attr); diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index 9040bc703a2f0..e62c269ccfcac 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -796,9 +796,6 @@ def generate_general_infer_spmd_code(self) -> str: if input_decl_code == "": return UNSUPPORTED_INFER_SPMD_COMMENT_TEMPLATE.format(self.api) - print( - f"kernel_name: {self.kernel['func'][0]}, input_args_code: {input_args_code}" - ) infer_spmd_code = GENERAL_INFER_SPMD_TEMPLATE.format( input_args_code[:-2] ) @@ -1029,7 +1026,6 @@ def generate_reshard_input_code(self) -> str: else input_names ) - input_reshard_code = "" for i, param in enumerate(kernel_params): if param in input_names: if self.inputs['input_info'][param] in [ @@ -1038,16 +1034,19 @@ def generate_reshard_input_code(self) -> str: "const paddle::optional&", "const paddle::optional>&", ]: - if self.generate_general_infer_spmd is True: - input_reshard_code += ( - GENERAL_INPUT_RESHARD_TEMPLATE.format( - name=param, idx=i - ) - ) - else: - input_reshard_code += INPUT_RESHARD_TEMPLATE.format( - name=param, idx=i - ) + input_reshard_code += INPUT_RESHARD_TEMPLATE.format( + name=param, idx=i + ) + # if self.generate_general_infer_spmd is True: + # input_reshard_code += ( + # GENERAL_INPUT_RESHARD_TEMPLATE.format( + # name=param, idx=i + # ) + # ) + # else: + # input_reshard_code += INPUT_RESHARD_TEMPLATE.format( + # name=param, idx=i + # ) else: raise ValueError( f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported." diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h index 2d52a58bdbb24..b5b5e207a0ee6 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.h +++ b/paddle/phi/infermeta/spmd_rules/utils.h @@ -125,6 +125,12 @@ struct VariadicSpmdRuleArgumentParser // deal with inputs void operator()(const DistMetaTensor& x) { inputs.emplace_back(&x); } + void operator()(const std::vector& x) { + for (auto t : x) { + inputs.emplace_back(t); + } + } + void operator()(const std::vector& x) { for (auto& t : x) { inputs.emplace_back(&t); diff --git a/test/auto_parallel/test_api_dist_branch.py b/test/auto_parallel/test_api_dist_branch.py index fa2adc046422e..323010166db24 100644 --- a/test/auto_parallel/test_api_dist_branch.py +++ b/test/auto_parallel/test_api_dist_branch.py @@ -114,28 +114,28 @@ def test_concat_for_dist_tensor(self): self.check_tensor_eq(local_in3.grad, dist_in3.grad) # TODO(GhostScreaming): Support paddle.concat backward later. 
-    # # input: std::vector<Tensor>
-    # # output: std::vector<Tensor>
-    # def test_broadcast_tensors_for_dist_tensor(self):
-    #     x1 = np.random.random(size=[4, 4]).astype("float32")
-    #     x2 = np.random.random(size=[4, 4]).astype("float32")
-    #     local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1)
-    #     local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2)
-
-    #     local_out1, local_out2 = paddle.broadcast_tensors(
-    #         [local_in1, local_in2]
-    #     )
-    #     dist_out1, dist_out2 = paddle.broadcast_tensors([dist_in1, dist_in2])
-    #     self.check_tensor_eq(local_out1, dist_out1)
-    #     self.check_tensor_eq(local_out2, dist_out2)
-
-    #     local_out = paddle.concat([local_out1, local_out2])
-    #     dist_out = paddle.concat([dist_out1, dist_out2])
-
-    #     local_out.backward()
-    #     dist_out.backward()
-    #     self.check_tensor_eq(local_in1.grad, dist_in1.grad)
-    #     self.check_tensor_eq(local_in2.grad, dist_in2.grad)
+    # input: std::vector<Tensor>
+    # output: std::vector<Tensor>
+    def test_broadcast_tensors_for_dist_tensor(self):
+        x1 = np.random.random(size=[4, 4]).astype("float32")
+        x2 = np.random.random(size=[4, 4]).astype("float32")
+        local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1)
+        local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2)
+
+        local_out1, local_out2 = paddle.broadcast_tensors(
+            [local_in1, local_in2]
+        )
+        dist_out1, dist_out2 = paddle.broadcast_tensors([dist_in1, dist_in2])
+        self.check_tensor_eq(local_out1, dist_out1)
+        self.check_tensor_eq(local_out2, dist_out2)
+
+        local_out = paddle.concat([local_out1, local_out2])
+        dist_out = paddle.concat([dist_out1, dist_out2])
+
+        local_out.backward()
+        dist_out.backward()
+        self.check_tensor_eq(local_in1.grad, dist_in1.grad)
+        self.check_tensor_eq(local_in2.grad, dist_in2.grad)
 
     # input: paddle::optional<Tensor>
     # output: phi::Tensor

From 0e14c0e6f82e34245cc22f612df616ba75dab9b9 Mon Sep 17 00:00:00 2001
From: GhostScreaming
Date: Fri, 3 Nov 2023 14:07:46 +0800
Subject: [PATCH 7/7] Polish code.
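
Remove the code that the previous commit only commented out: the unused
ReshardApiInputToReplicatedKernelInput overload for std::vector<Tensor>
inputs in data_transform.cc, and the commented-out general-InferSPMD
branch in generate_reshard_input_code of dist_api_gen.py. After this
cleanup the generator emits INPUT_RESHARD_TEMPLATE for every supported
input type, so all inputs go through the unified
ReshardApiInputToKernelInput overload set.

For illustration, the template expands roughly as follows for a
hypothetical single input named `x` at kernel-input index 0 (the name,
the index, and the symbolic trans_flag are placeholders, not taken from
a real API):

    // INPUT_RESHARD_TEMPLATE: reshard the API input according to the
    // dist_attr inferred by InferSPMD for this argument.
    auto dist_input_x =
        ReshardApiInputToKernelInput(dev_ctx, x, spmd_info.first[0]);
    // SINGLE_PREPARE_DATA_TEMPLATE: transform the resharded DistTensor
    // into the dense input the selected kernel expects.
    dist_input_x = PrepareDataForDistTensor(
        dist_input_x,
        GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
        trans_flag,
        kernel_result.is_stride_kernel);
    auto input_x = &dist_input_x->value();

The entry points that remain after this series, sketched with their
template arguments written out (an approximation for reference, not a
verbatim copy of data_transform.h):

    // Single input. ArgDistAttr is a variant: paddle::get<0>() holds a
    // TensorDistAttr, paddle::get<1>() a std::vector<TensorDistAttr>.
    std::shared_ptr<phi::distributed::DistTensor>
    ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
                                 const Tensor& tensor,
                                 const phi::distributed::ArgDistAttr& dist_attr);

    // Vector input: expects the std::vector<TensorDistAttr> alternative
    // and reshards the tensors element by element.
    std::vector<std::shared_ptr<phi::distributed::DistTensor>>
    ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
                                 const std::vector<Tensor>& tensors,
                                 const phi::distributed::ArgDistAttr& dist_attrs);

    // Optional single and optional vector inputs forward to the
    // overloads above and return paddle::none when the input is absent.
    paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
    ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
                                 const paddle::optional<Tensor>& tensor,
                                 const phi::distributed::ArgDistAttr& dist_attr);

    paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
    ReshardApiInputToKernelInput(
        phi::DeviceContext* dev_ctx,
        const paddle::optional<std::vector<Tensor>>& tensors,
        const phi::distributed::ArgDistAttr& dist_attr);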
--- paddle/phi/api/lib/data_transform.cc | 13 ------------- paddle/phi/api/yaml/generator/dist_api_gen.py | 10 ---------- 2 files changed, 23 deletions(-) diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 32bf65c2d6b63..404abb901b883 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -686,19 +686,6 @@ ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, return out; } -// std::vector> -// ReshardApiInputToReplicatedKernelInput( -// phi::DeviceContext* dev_ctx, -// const std::vector& tensors, -// const std::vector& dist_attrs) { -// std::vector> outputs; -// for (size_t i = 0; i < tensors.size(); ++i) { -// outputs.push_back(ReshardApiInputToReplicatedKernelInput( -// dev_ctx, tensors[i], dist_attrs[i])); -// } -// return outputs; -// } - paddle::optional> ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx, const paddle::optional& tensor, diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index e62c269ccfcac..dd7258ef8ae98 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -1037,16 +1037,6 @@ def generate_reshard_input_code(self) -> str: input_reshard_code += INPUT_RESHARD_TEMPLATE.format( name=param, idx=i ) - # if self.generate_general_infer_spmd is True: - # input_reshard_code += ( - # GENERAL_INPUT_RESHARD_TEMPLATE.format( - # name=param, idx=i - # ) - # ) - # else: - # input_reshard_code += INPUT_RESHARD_TEMPLATE.format( - # name=param, idx=i - # ) else: raise ValueError( f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported."