Skip to content

Commit

Permalink
[AutoParallel] Support vector and optional<vector> InferSPMD input an…
Browse files Browse the repository at this point in the history
…d output. (PaddlePaddle#58573)

* [AutoParallel] Support vector and optional<vector> InferSPMD input and output.

* Fix some problems.

* Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into support_vector_inferspmd

* Fix conflicts.

* Polish code.

* Polish code.

* Polish code.
  • Loading branch information
GhostScreaming authored and zeroRains committed Nov 8, 2023
1 parent e302c60 commit f1d36ed
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 426 deletions.
21 changes: 11 additions & 10 deletions paddle/phi/api/lib/api_custom_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
}

auto meta_dist_input_x = MakeDistMetaTensor(input_x);
auto spmd_info =
phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x);
auto spmd_info = phi::distributed::VariadicReplicatedInferSpmdDynamic(
meta_dist_input_x);

auto dist_out = SetKernelDistOutput(&api_output);
auto dense_out = dist_out->unsafe_mutable_value();
Expand All @@ -139,7 +139,7 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
phi::AddNInferMeta(x_metas, &meta_dist_out);
if (rank_is_in_current_mesh) {
auto dist_input_x =
ReshardApiInputToReplicatedKernelInput(dev_ctx, x, spmd_info.first);
ReshardApiInputToKernelInput(dev_ctx, x, spmd_info.first[0]);
dist_input_x = PrepareDataForDistTensor(
dist_input_x,
GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
Expand All @@ -165,14 +165,15 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx, input_x, dense_out);
}
PADDLE_ENFORCE_EQ(
paddle::holds_alternative<phi::distributed::TensorDistAttr>(
spmd_info.first[0]),
true,
phi::errors::PreconditionNotMet(
"Arg must be a single TensorDistAttr"));
PADDLE_ENFORCE_EQ(paddle::holds_alternative<
std::vector<phi::distributed::TensorDistAttr>>(
spmd_info.first[0]),
true,
phi::errors::PreconditionNotMet(
"Arg must be a vector of TensorDistAttr"));

auto current_process_mesh =
paddle::get<0>(spmd_info.first[0]).process_mesh();
paddle::get<1>(spmd_info.first[0]).at(0).process_mesh();
SetReplicatedDistAttrForOutput(dist_out, current_process_mesh);
return api_output;
}
Expand Down
Loading

0 comments on commit f1d36ed

Please sign in to comment.