diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc
index 2f1333042fe68e..2f05194e1b7088 100644
--- a/paddle/phi/api/lib/api_custom_impl.cc
+++ b/paddle/phi/api/lib/api_custom_impl.cc
@@ -118,8 +118,8 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
     }
     auto meta_dist_input_x = MakeDistMetaTensor(input_x);
-    auto spmd_info =
-        phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x);
+    auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic(
+        meta_dist_input_x);
     auto dist_out = SetKernelDistOutput(&api_output);
     auto dense_out = dist_out->unsafe_mutable_value();
@@ -139,7 +139,7 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
     phi::AddNInferMeta(x_metas, &meta_dist_out);
     if (rank_is_in_current_mesh) {
       auto dist_input_x =
-          ReshardApiInputToReplicatedKernelInput(dev_ctx, x, spmd_info.first);
+          ReshardApiInputToKernelInput(dev_ctx, x, spmd_info.first[0]);
       dist_input_x = PrepareDataForDistTensor(
           dist_input_x,
           GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
@@ -165,14 +165,15 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
       auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
       (*kernel_fn)(*dev_ctx, input_x, dense_out);
     }
-    PADDLE_ENFORCE_EQ(
-        paddle::holds_alternative<phi::distributed::TensorDistAttr>(
-            spmd_info.first[0]),
-        true,
-        phi::errors::PreconditionNotMet(
-            "Arg must be a single TensorDistAttr"));
+    PADDLE_ENFORCE_EQ(paddle::holds_alternative<
+                          std::vector<phi::distributed::TensorDistAttr>>(
+                          spmd_info.first[0]),
+                      true,
+                      phi::errors::PreconditionNotMet(
+                          "Arg must be a vector of TensorDistAttr"));
+
     auto current_process_mesh =
-        paddle::get<0>(spmd_info.first[0]).process_mesh();
+        paddle::get<1>(spmd_info.first[0]).at(0).process_mesh();
     SetReplicatedDistAttrForOutput(dist_out, current_process_mesh);
     return api_output;
   }
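
Note: VariadicReplicatedInferSpmdDynamic packs the dist attrs for a
std::vector<Tensor> argument into a single ArgDistAttr slot holding a
std::vector<TensorDistAttr>, which is why the check above switches to the
vector alternative and the process mesh is now read through
paddle::get<1>(...).at(0). A minimal sketch of that access pattern, assuming
ArgDistAttr is a paddle::variant whose alternative 0 is TensorDistAttr and
alternative 1 is std::vector<TensorDistAttr>, as the holds_alternative checks
in this patch suggest (the helper name below is illustrative, not part of the
patch):

    // Fetch the process mesh behind either variant alternative: alternative 0
    // holds a single TensorDistAttr, alternative 1 holds one TensorDistAttr
    // per element of a std::vector<Tensor> argument.
    const phi::distributed::ProcessMesh& GetArgProcessMesh(
        const phi::distributed::ArgDistAttr& arg) {
      if (paddle::holds_alternative<phi::distributed::TensorDistAttr>(arg)) {
        return paddle::get<0>(arg).process_mesh();
      }
      return paddle::get<1>(arg).at(0).process_mesh();
    }
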
diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc
index 7c88bd3df44b0f..404abb901b883a 100644
--- a/paddle/phi/api/lib/data_transform.cc
+++ b/paddle/phi/api/lib/data_transform.cc
@@ -623,35 +623,29 @@ std::string ReshardDebugInfo(
 std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
     phi::DeviceContext* dev_ctx,
     const Tensor& tensor,
-    const phi::distributed::TensorDistAttr& dist_attr) {
+    const phi::distributed::ArgDistAttr& dist_attr) {
+  PADDLE_ENFORCE_EQ(
+      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
+      true,
+      phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr"));
+
   auto tensor_in = tensor.impl();
+  const auto& tensor_dist_attr = paddle::get<0>(dist_attr);
   if (tensor_in) {
     phi::distributed::DistTensor* dist_tensor =
         static_cast<phi::distributed::DistTensor*>(tensor_in.get());
-    if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
-      VLOG(6) << "ApiIn to KernelIn - "
-              << ReshardDebugInfo(*dist_tensor, dist_attr);
-      auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
-                                                                 dist_attr);
-      return func->Eval(dev_ctx, *dist_tensor, dist_attr);
+    if (ReshardIsNeeded(dist_tensor->dist_attr(), tensor_dist_attr)) {
+      VLOG(6) << "ApiIn to Replicated KernelIn - "
+              << ReshardDebugInfo(*dist_tensor, tensor_dist_attr);
+      auto* func = phi::distributed::ChooseProperReshardFunction(
+          *dist_tensor, tensor_dist_attr);
+      return func->Eval(dev_ctx, *dist_tensor, tensor_dist_attr);
     }
     return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
   }
   return nullptr;
 }
 
-std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const Tensor& tensor,
-    const phi::distributed::ArgDistAttr& dist_attr) {
-  PADDLE_ENFORCE_EQ(
-      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
-      true,
-      phi::errors::PreconditionNotMet("Arg must be a single TensorDistAttr"));
-  const auto& tensor_dist_attr = paddle::get<0>(dist_attr);
-  return ReshardApiInputToKernelInput(dev_ctx, tensor, tensor_dist_attr);
-}
-
 std::vector<std::shared_ptr<phi::distributed::DistTensor>>
 ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
                              const std::vector<Tensor>& tensors,
@@ -661,152 +655,61 @@ ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
                         dist_attrs),
                     true,
                     phi::errors::PreconditionNotMet(
-                        "Arg must be a vector of TensorDistAttr"));
+                        "Arg must be a vector of TensorDistAttr"));
   const auto& tensor_dist_attrs = paddle::get<1>(dist_attrs);
-  return ReshardApiInputToKernelInput(dev_ctx, tensors, tensor_dist_attrs);
-}
-
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensors,
-    const std::vector<phi::distributed::TensorDistAttr>& dist_attrs) {
-  std::vector<std::shared_ptr<phi::distributed::DistTensor>> output;
   PADDLE_ENFORCE_EQ(tensors.size(),
-                    dist_attrs.size(),
-                    phi::errors::PreconditionNotMet(
-                        "tensors size and dist_attrs size not equal: %d vs %d",
-                        tensors.size(),
-                        dist_attrs.size()));
-  for (size_t i = 0; i < dist_attrs.size(); i++) {
-    output.push_back(
-        ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i]));
-  }
-  return output;
-}
+                    tensor_dist_attrs.size(),
+                    phi::errors::InvalidArgument(
+                        "Tensor's size should be equal to dist_attrs' size."));
 
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensors,
-    const std::vector<phi::distributed::ArgDistAttr>& dist_attrs) {
-  std::vector<std::shared_ptr<phi::distributed::DistTensor>> output;
-  PADDLE_ENFORCE_EQ(tensors.size(),
-                    dist_attrs.size(),
-                    phi::errors::PreconditionNotMet(
-                        "tensors size and dist_attrs size not equal: %d vs %d",
-                        tensors.size(),
-                        dist_attrs.size()));
-  for (size_t i = 0; i < dist_attrs.size(); i++) {
-    output.push_back(
-        ReshardApiInputToKernelInput(dev_ctx, tensors[i], dist_attrs[i]));
-  }
-  return output;
-}
-
-std::shared_ptr<phi::distributed::DistTensor>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const Tensor& tensor,
-    const phi::distributed::TensorDistAttr& dist_attr) {
-  auto tensor_in = tensor.impl();
-  if (tensor_in) {
-    phi::distributed::DistTensor* dist_tensor =
-        static_cast<phi::distributed::DistTensor*>(tensor_in.get());
-    if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
-      VLOG(6) << "ApiIn to Replicated KernelIn - "
-              << ReshardDebugInfo(*dist_tensor, dist_attr);
-      auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
-                                                                 dist_attr);
-      return func->Eval(dev_ctx, *dist_tensor, dist_attr);
+  std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
+  for (size_t i = 0; i < tensors.size(); i++) {
+    auto tensor_in = tensors[i].impl();
+    auto dist_attr = tensor_dist_attrs[i];
+    if (tensor_in) {
+      phi::distributed::DistTensor* dist_tensor =
+          static_cast<phi::distributed::DistTensor*>(tensor_in.get());
+      if (ReshardIsNeeded(dist_tensor->dist_attr(), dist_attr)) {
+        VLOG(6) << "Vector ApiIn to Replicated KernelIn - "
+                << ReshardDebugInfo(*dist_tensor, dist_attr);
+        auto* func = phi::distributed::ChooseProperReshardFunction(*dist_tensor,
+                                                                   dist_attr);
+        out.push_back(func->Eval(dev_ctx, *dist_tensor, dist_attr));
+      } else {
+        out.push_back(
+            std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
+      }
+    } else {
+      out.push_back(nullptr);
     }
-    return std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in);
   }
-  return nullptr;
+  return out;
 }
 
 paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const paddle::optional<Tensor>& tensor,
-    const phi::distributed::TensorDistAttr& dist_attr) {
+ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
+                             const paddle::optional<Tensor>& tensor,
+                             const phi::distributed::ArgDistAttr& dist_attr) {
   if (tensor) {
     VLOG(6) << "Optional ApiIn to Replicated KernelIn.";
     return paddle::make_optional<std::shared_ptr<phi::distributed::DistTensor>>(
-        ReshardApiInputToReplicatedKernelInput(dev_ctx, *tensor, dist_attr));
+        ReshardApiInputToKernelInput(dev_ctx, *tensor, dist_attr));
   }
   return paddle::none;
 }
 
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensors,
-    const std::vector<phi::distributed::TensorDistAttr>& dist_attrs) {
-  std::vector<std::shared_ptr<phi::distributed::DistTensor>> result;
-  result.reserve(tensors.size());
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    result.emplace_back(ReshardApiInputToReplicatedKernelInput(
-        dev_ctx, tensors[i], dist_attrs[i]));
-  }
-  return result;
-}
-
-std::shared_ptr<phi::distributed::DistTensor>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const Tensor& tensor,
-    const phi::distributed::ArgDistAttr& dist_attr) {
-  PADDLE_ENFORCE_EQ(
-      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
-      true,
-      phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr"));
-  const auto& tensor_dist_attr = paddle::get<0>(dist_attr);
-  return ReshardApiInputToReplicatedKernelInput(
-      dev_ctx, tensor, tensor_dist_attr);
-}
-
-paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const paddle::optional<Tensor>& tensor,
-    const phi::distributed::ArgDistAttr& dist_attr) {
-  PADDLE_ENFORCE_EQ(
-      paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr),
-      true,
-      phi::errors::PreconditionNotMet("Arg must be a TensorDistAttr"));
-  const auto& tensor_dist_attr = paddle::get<0>(dist_attr);
-  return ReshardApiInputToReplicatedKernelInput(
-      dev_ctx, tensor, tensor_dist_attr);
-}
-
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
+paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
+ReshardApiInputToKernelInput(
     phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensors,
-    const std::vector<phi::distributed::ArgDistAttr>& dist_attrs) {
-  std::vector<std::shared_ptr<phi::distributed::DistTensor>> outputs;
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    outputs.push_back(ReshardApiInputToReplicatedKernelInput(
-        dev_ctx, tensors[i], dist_attrs[i]));
+    const paddle::optional<std::vector<Tensor>>& tensors,
+    const phi::distributed::ArgDistAttr& dist_attrs) {
+  if (tensors) {
+    VLOG(6) << "Optional ApiIn to Replicated KernelIn.";
+    return paddle::make_optional<
+        std::vector<std::shared_ptr<phi::distributed::DistTensor>>>(
+        ReshardApiInputToKernelInput(dev_ctx, *tensors, dist_attrs));
   }
-  return outputs;
-}
-
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensors,
-    const phi::distributed::ArgDistAttr& dist_attr) {
-  PADDLE_ENFORCE_EQ(
-      paddle::holds_alternative<std::vector<phi::distributed::TensorDistAttr>>(
-          dist_attr),
-      true,
-      phi::errors::PreconditionNotMet(
-          "Arg must be a vector of TensorDistAttr"));
-  const auto& tensor_dist_attrs = paddle::get<1>(dist_attr);
-  return ReshardApiInputToReplicatedKernelInput(
-      dev_ctx, tensors, tensor_dist_attrs);
+  return paddle::none;
 }
 
 void ReshardOutputPartialAxisToReplicated(
@@ -854,7 +757,7 @@ void ReshardKernelOutputToApiOutput(
 }
 
 std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
-    const std::shared_ptr<phi::distributed::DistTensor>& input,
+    std::shared_ptr<phi::distributed::DistTensor> input,
     const phi::TensorArgDef& target_args_def,
     const TransformFlag& transform_flag,
     bool is_stride_kernel) {
@@ -888,87 +791,14 @@ std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
   return nullptr;
 }
 
-paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
-PrepareDataForDistTensor(
-    const paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>&
-        input,
-    const phi::TensorArgDef& target_args_def,
-    const TransformFlag& transform_flag,
-    bool is_stride_kernel) {
-  if (input) {
-    VLOG(6) << "PrepareDataForDistTensor for optional return transformed dist "
-               "tensor";
-    return paddle::make_optional<std::shared_ptr<phi::distributed::DistTensor>>(
-        PrepareDataForDistTensor(
-            *input, target_args_def, transform_flag, is_stride_kernel));
-  }
-  return paddle::none;
-}
-
-std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
-    const Tensor& input,
-    const phi::TensorArgDef& target_args_def,
-    const TransformFlag& transform_flag,
-    bool is_stride_kernel) {
-  return PrepareDataForDistTensor(
-      std::static_pointer_cast<phi::distributed::DistTensor>(input.impl()),
-      target_args_def,
-      transform_flag,
-      is_stride_kernel);
-}
-
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-PrepareDataForDistTensor(const std::vector<Tensor>& input,
-                         const phi::TensorArgDef& target_args_def,
-                         const TransformFlag& transform_flag,
-                         bool is_stride_kernel) {
-  std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
-  for (auto& x : input) {
-    const auto& tensor_in = x.impl();
-    if (tensor_in) {
-      phi::distributed::DistTensor* dist_tensor =
-          static_cast<phi::distributed::DistTensor*>(tensor_in.get());
-      const phi::DenseTensor& dense_tensor = dist_tensor->value();
-      if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
-          (!NeedTransformPlace(
-               dense_tensor.place(), target_args_def.backend, transform_flag) &&
-           !NeedTransformDataType(
-               dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
-           !NeedTransformLayout(dense_tensor.layout(),
-                                target_args_def.layout,
-                                dense_tensor.place(),
-                                transform_flag) &&
-           !NeedTransform2Contiguous(is_stride_kernel,
-                                     dense_tensor.meta().is_contiguous()))) {
-        out.push_back(
-            std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
-      } else {
-        phi::DenseTensor trans_in_tensor = TransformData(
-            dense_tensor, target_args_def, transform_flag, is_stride_kernel);
-        // TODO(GhostScreaming): The global meta in DistTensor is not changed,
-        // but the local meta in DenseTensor maybe changed, such as layout
-        // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
-        VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
-        out.push_back(std::make_shared<phi::distributed::DistTensor>(
-            std::make_shared<phi::DenseTensor>(trans_in_tensor),
-            dist_tensor->dist_attr()));
-      }
-    } else {
-      out.push_back(nullptr);
-    }
-  }
-  return out;
-}
-
 std::vector<std::shared_ptr<phi::distributed::DistTensor>>
 PrepareDataForDistTensor(
-    const std::vector<std::shared_ptr<phi::distributed::DistTensor>>& input,
+    std::vector<std::shared_ptr<phi::distributed::DistTensor>> input,
     const phi::TensorArgDef& target_args_def,
     const TransformFlag& transform_flag,
     bool is_stride_kernel) {
   std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
-  for (auto& x : input) {
-    const auto& tensor_in = x;
+  for (auto tensor_in : input) {
     if (tensor_in) {
       phi::distributed::DistTensor* dist_tensor =
           static_cast<phi::distributed::DistTensor*>(tensor_in.get());
@@ -1004,26 +834,37 @@ PrepareDataForDistTensor(
   return out;
 }
 
-paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
-    const paddle::optional<Tensor>& input,
+paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
+PrepareDataForDistTensor(
+    paddle::optional<std::shared_ptr<phi::distributed::DistTensor>> input,
     const phi::TensorArgDef& target_args_def,
     const TransformFlag& transform_flag,
     bool is_stride_kernel) {
   if (input) {
-    return {*PrepareDataForDistTensor(
-        *input, target_args_def, transform_flag, is_stride_kernel)};
+    VLOG(6) << "PrepareDataForDistTensor for optional return transformed dist "
+               "tensor";
+    return paddle::make_optional<std::shared_ptr<phi::distributed::DistTensor>>(
+        PrepareDataForDistTensor(
+            *input, target_args_def, transform_flag, is_stride_kernel));
   }
   return paddle::none;
 }
 
 paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
-PrepareDataForDistTensor(const paddle::optional<std::vector<Tensor>>& input,
-                         const phi::TensorArgDef& target_args_def,
-                         const TransformFlag& transform_flag,
-                         bool is_stride_kernel) {
+PrepareDataForDistTensor(
+    paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
+        input,
+    const phi::TensorArgDef& target_args_def,
+    const TransformFlag& transform_flag,
+    bool is_stride_kernel) {
   if (input) {
-    return PrepareDataForDistTensor(
-        *input, target_args_def, transform_flag, is_stride_kernel);
+    VLOG(6) << "PrepareDataForDistTensor for optional vector return "
+               "transformed dist "
+               "tensor";
+    return paddle::make_optional<
+        std::vector<std::shared_ptr<phi::distributed::DistTensor>>>(
+        PrepareDataForDistTensor(
+            *input, target_args_def, transform_flag, is_stride_kernel));
  }
  return paddle::none;
}
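
Note: after this change every generated API body goes through the same two
calls, whichever of the four input shapes the slot has. A minimal usage sketch
mirroring the add_n_impl hunk above, assuming dev_ctx, x, spmd_info, kernel,
kernel_backend, transform_flag and kernel_result are already in scope as they
are in a generated API body; the by-value std::shared_ptr parameter of
PrepareDataForDistTensor lets the result be reassigned over the same variable:

    // Reshard the API-level input to the dist attr chosen by InferSpmd, then
    // run the data transform on the resharded DistTensor.
    auto dist_input_x =
        ReshardApiInputToKernelInput(dev_ctx, x, spmd_info.first[0]);
    dist_input_x = PrepareDataForDistTensor(
        dist_input_x,
        GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
        transform_flag,
        kernel_result.is_stride_kernel);
    auto input_x = &dist_input_x->value();
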
diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h
index 02d86622e2aa6a..2eba71c7295c80 100644
--- a/paddle/phi/api/lib/data_transform.h
+++ b/paddle/phi/api/lib/data_transform.h
@@ -176,79 +176,25 @@ inline bool NeedTransformPlace(const phi::Place& src_place,
 
 /* ------------------ for auto parallel ----------------------- */
 
-std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const Tensor& tensor,
-    const phi::distributed::TensorDistAttr& dist_attr);
-
 std::shared_ptr<phi::distributed::DistTensor> ReshardApiInputToKernelInput(
     phi::DeviceContext* dev_ctx,
     const Tensor& tensor,
     const phi::distributed::ArgDistAttr& dist_attr);
 
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensors,
-    const std::vector<phi::distributed::TensorDistAttr>& dist_attrs);
-
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensors,
-    const std::vector<phi::distributed::ArgDistAttr>& dist_attrs);
-
 std::vector<std::shared_ptr<phi::distributed::DistTensor>>
 ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
-                             const std::vector<Tensor>& tensors,
-                             const phi::distributed::ArgDistAttr& dist_attrs);
-
-std::shared_ptr<phi::distributed::DistTensor>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const Tensor& tensor,
-    const phi::distributed::ArgDistAttr& dist_attr);
-
-std::shared_ptr<phi::distributed::DistTensor>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const Tensor& tensor,
-    const phi::distributed::TensorDistAttr& dist_attr);
-
-std::shared_ptr<phi::distributed::DistTensor>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const Tensor& tensor,
-    const phi::distributed::ArgDistAttr& dist_attr);
-
-paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const paddle::optional<Tensor>& tensor,
-    const phi::distributed::TensorDistAttr& dist_attr);
+                             const std::vector<Tensor>& tensor,
+                             const phi::distributed::ArgDistAttr& dist_attr);
 
 paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const paddle::optional<Tensor>& tensor,
-    const phi::distributed::ArgDistAttr& dist_attr);
-
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensors,
-    const std::vector<phi::distributed::TensorDistAttr>& dist_attrs);
-
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
-    phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensors,
-    const std::vector<phi::distributed::ArgDistAttr>& dist_attrs);
+ReshardApiInputToKernelInput(phi::DeviceContext* dev_ctx,
+                             const paddle::optional<Tensor>& tensor,
+                             const phi::distributed::ArgDistAttr& dist_attr);
 
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-ReshardApiInputToReplicatedKernelInput(
+paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
+ReshardApiInputToKernelInput(
     phi::DeviceContext* dev_ctx,
-    const std::vector<Tensor>& tensor,
+    const paddle::optional<std::vector<Tensor>>& tensors,
     const phi::distributed::ArgDistAttr& dist_attr);
 
 void ReshardOutputPartialAxisToReplicated(
@@ -260,49 +206,32 @@ void ReshardKernelOutputToApiOutput(
     Tensor* dst_tensor);
 
 std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
-    const std::shared_ptr<phi::distributed::DistTensor>& input,
+    std::shared_ptr<phi::distributed::DistTensor> input,
     const phi::TensorArgDef& target_args_def,
     const TransformFlag& transform_flag,
     bool is_stride_kernel);
 
-paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
+std::vector<std::shared_ptr<phi::distributed::DistTensor>>
 PrepareDataForDistTensor(
-    const paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>&
-        input,
-    const phi::TensorArgDef& target_args_def,
-    const TransformFlag& transform_flag,
-    bool is_stride_kernel);
-
-std::shared_ptr<phi::distributed::DistTensor> PrepareDataForDistTensor(
-    const Tensor& input,
+    std::vector<std::shared_ptr<phi::distributed::DistTensor>> input,
     const phi::TensorArgDef& target_args_def,
     const TransformFlag& transform_flag,
     bool is_stride_kernel);
 
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
-PrepareDataForDistTensor(const std::vector<Tensor>& input,
-                         const phi::TensorArgDef& target_args_def,
-                         const TransformFlag& transform_flag,
-                         bool is_stride_kernel);
-
-std::vector<std::shared_ptr<phi::distributed::DistTensor>>
+paddle::optional<std::shared_ptr<phi::distributed::DistTensor>>
 PrepareDataForDistTensor(
-    const std::vector<std::shared_ptr<phi::distributed::DistTensor>>& input,
+    paddle::optional<std::shared_ptr<phi::distributed::DistTensor>> input,
     const phi::TensorArgDef& target_args_def,
     const TransformFlag& transform_flag,
     bool is_stride_kernel);
 
-paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
-    const paddle::optional<Tensor>& input,
+paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
+PrepareDataForDistTensor(
+    paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
+        input,
     const phi::TensorArgDef& target_args_def,
     const TransformFlag& transform_flag,
     bool is_stride_kernel);
 
-paddle::optional<std::vector<std::shared_ptr<phi::distributed::DistTensor>>>
-PrepareDataForDistTensor(const paddle::optional<std::vector<Tensor>>& input,
-                         const phi::TensorArgDef& target_args_def,
-                         const TransformFlag& transform_flag,
-                         bool is_stride_kernel);
-
 }  // namespace experimental
 }  // namespace paddle
diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py
index 059320c0058edc..dd7258ef8ae981 100644
--- a/paddle/phi/api/yaml/generator/dist_api_gen.py
+++ b/paddle/phi/api/yaml/generator/dist_api_gen.py
@@ -48,7 +48,7 @@
     // 1. InferSpmd (Infer DistAttr of Inputs&Outputs){}
     // 2. Create API Output & Prepare Dist and Dense Output{}
     // 3. Infer DistTensor's Global Shape{}\n
-    if (rank_is_in_current_mesh){{
+    if (rank_is_in_current_mesh) {{
       // 4. Select Kernel{}
       // 5. Reshard Input{}\n
       // 6. PrepareData (DataTransform & Prepare Dense Input){}
@@ -81,16 +81,20 @@
 # 1. InferSPMD
 SINGLE_DIST_META_IN_TEMPLATE = """
     auto meta_dist_input_{name} = MakeDistMetaTensor(*{name}.impl());"""
-
-LIST_DIST_META_IN_TEMPLATE = """
+VECTOR_DIST_META_IN_TEMPLATE = """
     std::vector<phi::distributed::DistMetaTensor> meta_dist_input_{name};
-    for(auto& e: {name}){{
+    for(auto& e : {name}) {{
        meta_dist_input_{name}.push_back(MakeDistMetaTensor(*e.impl()));
-    }}
-"""
-
+    }}"""
 OPTIONAL_SINGLE_DIST_META_IN_TEMPLATE = """
    auto meta_dist_input_{name} = {name} ? MakeDistMetaTensor(*(*{name}).impl()) : phi::distributed::DistMetaTensor();"""
+OPTIONAL_VECTOR_DIST_META_IN_TEMPLATE = """
+    std::vector<phi::distributed::DistMetaTensor> meta_dist_input_{name};
+    if ({name}) {{
+      for(auto& e : *{name}) {{
+        meta_dist_input_{name}.push_back(MakeDistMetaTensor(*e.impl()));
+      }}
+    }}"""
 INFER_SPMD_TEMPLATE = """
    auto spmd_info = phi::distributed::{}({});
 """
@@ -237,26 +241,27 @@
 """
 
 # 5. Reshard Input
-SINGLE_INPUT_RESHARD_TEMPLATE = """
-      auto dist_input_{arg} = ReshardApiInputToKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);"""
-SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE = """
-      auto dist_input_{arg} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {arg}, spmd_info.first[{idx}]);"""
+# Both Tensor, std::vector<Tensor>, paddle::optional<Tensor> and
+# paddle::optional<std::vector<Tensor>> use the same template
+INPUT_RESHARD_TEMPLATE = """
+      auto dist_input_{name} = ReshardApiInputToKernelInput(dev_ctx, {name}, spmd_info.first[{idx}]);"""
+GENERAL_INPUT_RESHARD_TEMPLATE = """
+      auto dist_input_{name} = ReshardApiInputToReplicatedKernelInput(dev_ctx, {name}, spmd_info.first[{idx}]);"""
 UNSUPPORTED_RESHARD_INPUT_COMMENT_TEMPLATE = """
      // API `{}` does not need to support ReshardInput at this time
 """
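
Note: taken together, the new OPTIONAL_VECTOR_DIST_META_IN_TEMPLATE and the
shared INPUT_RESHARD_TEMPLATE above expand roughly as follows for a
hypothetical paddle::optional<std::vector<Tensor>> input named ids at kernel
input slot 0 (an illustrative expansion, not literal generator output):

    // 1. InferSPMD: build one DistMetaTensor per element, if the input is set.
    std::vector<phi::distributed::DistMetaTensor> meta_dist_input_ids;
    if (ids) {
      for (auto& e : *ids) {
        meta_dist_input_ids.push_back(MakeDistMetaTensor(*e.impl()));
      }
    }
    // 5. Reshard Input: one call regardless of the input shape.
    auto dist_input_ids =
        ReshardApiInputToKernelInput(dev_ctx, ids, spmd_info.first[0]);
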
 
 # 6. PrepareData
 SINGLE_PREPARE_DATA_TEMPLATE = """
-      dist_input_{arg} = PrepareDataForDistTensor(dist_input_{arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel);
-      auto input_{arg} = &dist_input_{arg}->value();
+      dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
+      auto input_{name} = &dist_input_{name}->value();
 """
 SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """
-      auto dist_input_{arg} = PrepareDataForDistTensor({arg}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {flag}, kernel_result.is_stride_kernel);
-      auto input_{arg} = &dist_input_{arg}->value();
+      auto dist_input_{name} = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
+      auto input_{name} = &dist_input_{name}->value();
 """
-# dist_input_ prefix
 VECTOR_PREPARE_DATA_TEMPLATE = """
-      auto dist_input_{name}_vec = PrepareDataForDistTensor({prefix}{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
+      auto dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
       std::vector<const phi::DenseTensor*> dense_input_{name}_vec;
       for (auto tmp : dist_input_{name}_vec) {{
         dense_input_{name}_vec.emplace_back(&tmp->value());
       }}
@@ -267,19 +272,16 @@
         dense_input_{name}_meta_ptr_vec[i] = &dense_input_{name}_meta_vec[i];
       }}
 """
-
 OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE = """
-      dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
+      dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
       paddle::optional<phi::DenseTensor> input_{name} = dist_input_{name} ? paddle::make_optional<phi::DenseTensor>((*dist_input_{name})->value()) : paddle::none;
 """
 OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD = """
-      auto dist_input_{name} = PrepareDataForDistTensor({name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
+      auto dist_input_{name} = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
       paddle::optional<phi::DenseTensor> input_{name} = dist_input_{name} ? paddle::make_optional<phi::DenseTensor>(dist_input_{name}->value()) : paddle::none;
 """
-
-# dist_input_ prefix
 OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE = """
-      auto dist_input_{name}_vec = PrepareDataForDistTensor({prefix}{name}, GetKernelInputArgDef(kernel.InputAt({index}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
+      auto dist_input_{name}_vec = PrepareDataForDistTensor(dist_input_{name}, GetKernelInputArgDef(kernel.InputAt({idx}), kernel_backend), {trans_flag}, kernel_result.is_stride_kernel);
       std::vector<const phi::DenseTensor*> dense_input_{name}_vec;
       if ({name}) {{
         for (auto tmp : *dist_input_{name}_vec) {{
@@ -295,6 +297,7 @@
       paddle::optional<std::vector<const phi::MetaTensor*>> dense_input_{name}_meta_ptr_vec =
          {name} ? paddle::make_optional<std::vector<const phi::MetaTensor*>>(dense_input_{name}_meta_ptr_vec_tmp) : paddle::none;
 """
+
 INFER_META_SINGLE_INPUT_TEMPLATE = """
    auto dist_input_{} = {}.impl();
    auto input_{} = &(static_cast<phi::distributed::DistTensor*>(dist_input_{}.get())->value());
 """
@@ -712,7 +715,7 @@ def generate_specialized_infer_spmd_code(self) -> str:
                     self.inputs['input_info'][param]
                     == "const std::vector<Tensor>&"
                 ):
-                    input_decl_code += LIST_DIST_META_IN_TEMPLATE.format(
+                    input_decl_code += VECTOR_DIST_META_IN_TEMPLATE.format(
                         name=param
                     )
                     input_args_code += "meta_dist_input_" + param + ", "
@@ -770,7 +773,7 @@ def generate_general_infer_spmd_code(self) -> str:
                     self.inputs['input_info'][param]
                     == "const std::vector<Tensor>&"
                 ):
-                    input_decl_code += LIST_DIST_META_IN_TEMPLATE.format(
+                    input_decl_code += VECTOR_DIST_META_IN_TEMPLATE.format(
                         name=param
                     )
                     input_args_code += "meta_dist_input_" + param + ", "
@@ -778,11 +781,10 @@ def generate_general_infer_spmd_code(self) -> str:
                     self.inputs['input_info'][param]
                     == "const paddle::optional<std::vector<Tensor>>&"
                 ):
-                    # TODO(chenweihang): support other input type later,
-                    # now only support single tensor input api
-                    input_decl_code = ""
-                    input_args_code = ""
-                    break
+                    input_decl_code += (
+                        OPTIONAL_VECTOR_DIST_META_IN_TEMPLATE.format(name=param)
+                    )
+                    input_args_code += "meta_dist_input_" + param + ", "
                 else:
                     raise ValueError(
                         f"{self.api} : Param of infer_spmd error : {self.inputs['input_info'][param]} type is not supported."
@@ -1026,39 +1028,15 @@ def generate_reshard_input_code(self) -> str:
             for i, param in enumerate(kernel_params):
                 if param in input_names:
-                    if (
-                        self.inputs['input_info'][param] == "const Tensor&"
-                        or self.inputs['input_info'][param]
-                        == "const paddle::optional<Tensor>&"
-                    ):
-                        if self.generate_general_infer_spmd is True:
-                            input_reshard_code += (
-                                SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format(
-                                    arg=param, idx=i
-                                )
-                            )
-                        else:
-                            input_reshard_code += (
-                                SINGLE_INPUT_RESHARD_TEMPLATE.format(
-                                    arg=param, idx=i
-                                )
-                            )
-                    elif (
-                        self.inputs['input_info'][param]
-                        == "const std::vector<Tensor>&"
-                    ):
-                        if self.generate_general_infer_spmd is True:
-                            input_reshard_code += (
-                                SINGLE_GENERAL_INPUT_RESHARD_TEMPLATE.format(
-                                    arg=param, idx=i
-                                )
-                            )
-                        else:
-                            input_reshard_code += (
-                                SINGLE_INPUT_RESHARD_TEMPLATE.format(
-                                    arg=param, idx=i
-                                )
-                            )
+                    if self.inputs['input_info'][param] in [
+                        "const Tensor&",
+                        "const std::vector<Tensor>&",
+                        "const paddle::optional<Tensor>&",
+                        "const paddle::optional<std::vector<Tensor>>&",
+                    ]:
+                        input_reshard_code += INPUT_RESHARD_TEMPLATE.format(
+                            name=param, idx=i
+                        )
                     else:
                         raise ValueError(
                             f"{self.api} : Param of reshard input error : {self.inputs['input_info'][param]} type is not supported."
@@ -1087,15 +1065,15 @@ def generate_single_dense_input(
         if self.generate_infer_spmd is True:
             input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE.format(
-                arg=input_name,
+                name=input_name,
                 idx=kernel_param.index(input_name),
-                flag=trans_flag,
+                trans_flag=trans_flag,
             )
         else:
             input_tensor_code += SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format(
-                arg=input_name,
+                name=input_name,
                 idx=kernel_param.index(input_name),
-                flag=trans_flag,
+                trans_flag=trans_flag,
             )
 
         return input_tensor_code
@@ -1111,11 +1089,9 @@ def generate_vector_dense_input(
         kernel_param = self.kernel['param']
         if kernel_param is None:
             kernel_param = input_names + attr_names
-        prefix = "dist_input_" if self.generate_infer_spmd else ""
         input_tensor_code += VECTOR_PREPARE_DATA_TEMPLATE.format(
-            prefix=prefix,
             name=input_name,
-            index=kernel_param.index(input_name),
+            idx=kernel_param.index(input_name),
             trans_flag=trans_flag,
         )
@@ -1136,14 +1112,14 @@ def generate_optional_single_dense_input(
         if self.generate_infer_spmd is True:
             input_tensor_code += OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE.format(
                 name=input_name,
-                index=kernel_param.index(input_name),
+                idx=kernel_param.index(input_name),
                 trans_flag=trans_flag,
             )
         else:
             input_tensor_code += (
                 OPTIONAL_SINGLE_PREPARE_DATA_TEMPLATE_NO_RESHARD.format(
                     name=input_name,
-                    index=kernel_param.index(input_name),
+                    idx=kernel_param.index(input_name),
                     trans_flag=trans_flag,
                 )
             )
@@ -1161,11 +1137,9 @@ def generate_optional_vector_dense_input(
         kernel_param = self.kernel['param']
         if kernel_param is None:
             kernel_param = input_names + attr_names
-        prefix = "dist_input_" if self.generate_infer_spmd else ""
         input_tensor_code += OPTIONAL_VECTOR_PREPARE_DATA_TEMPLATE.format(
-            prefix=prefix,
             name=input_name,
-            index=kernel_param.index(input_name),
+            idx=kernel_param.index(input_name),
             trans_flag=trans_flag,
         )
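
Note: with the renamed placeholders, INPUT_RESHARD_TEMPLATE followed by
VECTOR_PREPARE_DATA_TEMPLATE expands roughly as follows for a hypothetical
std::vector<Tensor> input named x at kernel input slot 0, with transform_flag
standing in for {trans_flag} (an illustrative expansion, not literal generator
output):

    auto dist_input_x =
        ReshardApiInputToKernelInput(dev_ctx, x, spmd_info.first[0]);
    auto dist_input_x_vec = PrepareDataForDistTensor(
        dist_input_x,
        GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
        transform_flag,
        kernel_result.is_stride_kernel);
    // Collect the dense values the kernel actually consumes.
    std::vector<const phi::DenseTensor*> dense_input_x_vec;
    for (auto tmp : dist_input_x_vec) {
      dense_input_x_vec.emplace_back(&tmp->value());
    }
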
diff --git a/test/auto_parallel/test_api_dist_branch.py b/test/auto_parallel/test_api_dist_branch.py
index fa2adc046422e7..323010166db248 100644
--- a/test/auto_parallel/test_api_dist_branch.py
+++ b/test/auto_parallel/test_api_dist_branch.py
@@ -114,28 +114,28 @@ def test_concat_for_dist_tensor(self):
         self.check_tensor_eq(local_in3.grad, dist_in3.grad)
 
     # TODO(GhostScreaming): Support paddle.concat backward later.
-    # # input: std::vector<Tensor>
-    # # output: std::vector<Tensor>
-    # def test_broadcast_tensors_for_dist_tensor(self):
-    #     x1 = np.random.random(size=[4, 4]).astype("float32")
-    #     x2 = np.random.random(size=[4, 4]).astype("float32")
-    #     local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1)
-    #     local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2)
-
-    #     local_out1, local_out2 = paddle.broadcast_tensors(
-    #         [local_in1, local_in2]
-    #     )
-    #     dist_out1, dist_out2 = paddle.broadcast_tensors([dist_in1, dist_in2])
-    #     self.check_tensor_eq(local_out1, dist_out1)
-    #     self.check_tensor_eq(local_out2, dist_out2)
-
-    #     local_out = paddle.concat([local_out1, local_out2])
-    #     dist_out = paddle.concat([dist_out1, dist_out2])
-
-    #     local_out.backward()
-    #     dist_out.backward()
-    #     self.check_tensor_eq(local_in1.grad, dist_in1.grad)
-    #     self.check_tensor_eq(local_in2.grad, dist_in2.grad)
+    # input: std::vector<Tensor>
+    # output: std::vector<Tensor>
+    def test_broadcast_tensors_for_dist_tensor(self):
+        x1 = np.random.random(size=[4, 4]).astype("float32")
+        x2 = np.random.random(size=[4, 4]).astype("float32")
+        local_in1, dist_in1 = self.create_local_and_dist_tensor_pair(x1)
+        local_in2, dist_in2 = self.create_local_and_dist_tensor_pair(x2)
+
+        local_out1, local_out2 = paddle.broadcast_tensors(
+            [local_in1, local_in2]
+        )
+        dist_out1, dist_out2 = paddle.broadcast_tensors([dist_in1, dist_in2])
+        self.check_tensor_eq(local_out1, dist_out1)
+        self.check_tensor_eq(local_out2, dist_out2)
+
+        local_out = paddle.concat([local_out1, local_out2])
+        dist_out = paddle.concat([dist_out1, dist_out2])
+
+        local_out.backward()
+        dist_out.backward()
+        self.check_tensor_eq(local_in1.grad, dist_in1.grad)
+        self.check_tensor_eq(local_in2.grad, dist_in2.grad)
 
     # input: paddle::optional<Tensor>
     # output: phi::Tensor