Skip to content

Commit

Permalink
【PIR Dist Op Reg No.5】 reg partial_allgather (#62735)
Browse files Browse the repository at this point in the history
* feat(pir): reg partial_allgather

* feat(pir): reg partial_allgather

* feat(pir): reg partial_allgather

* feat(pir): reg partial_allgather

* feat(pir): reg partial_allgather
  • Loading branch information
xiaoyewww authored Mar 19, 2024
1 parent a29a754 commit 23c9830
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 0 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@
'push_sparse_v2_',
'partial_send',
'partial_recv',
'partial_allgather',
'partial_allgather_',
'nop',
'nop_',
]
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,15 @@
backward : pad_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : partial_allgather
args : (Tensor x, int nranks, int rank, int ring_id = 0, bool use_calc_stream = false)
output : Tensor(out)
infer_meta :
func : PartialAllgatherInferMeta
kernel :
func : partial_allgather
inplace : (x -> out)

- op : partial_recv
args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, bool use_calc_stream = false, int num = 1, int id = 0)
output : Tensor(out)
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2469,6 +2469,12 @@
extra :
attrs : [bool use_mkldnn = false]

- op : partial_allgather
inputs :
x : X
outputs :
out : Out

- op : partial_recv
outputs :
out : Out
Expand Down
23 changes: 23 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2932,6 +2932,29 @@ void Pad3dInferMeta(const MetaTensor& x,
out->share_lod(x);
}

void PartialAllgatherInferMeta(const MetaTensor& x,
                               int nranks,
                               int rank,
                               int ring_id,
                               bool use_calc_stream,
                               MetaTensor* out) {
  // Validate the collective configuration before shaping the output:
  // a partial all-gather only makes sense across at least two ranks.
  PADDLE_ENFORCE_GE(
      nranks,
      2,
      phi::errors::InvalidArgument("The value of nranks should be >=2."));

  // The local rank must address a valid slot inside the group.
  const bool rank_in_range = rank >= 0 && rank < nranks;
  PADDLE_ENFORCE_EQ(
      rank_in_range,
      true,
      phi::errors::InvalidArgument(
          "The rank (%d) for partial_allgather op must >=0 and <nranks (%d)",
          rank,
          nranks));

  // The op supports inplace (x -> out), so the output simply mirrors the
  // input's shape and dtype. ring_id/use_calc_stream only affect execution,
  // not meta inference.
  out->set_dims(x.dims());
  out->set_dtype(x.dtype());
}

void PartialSendInferMeta(const MetaTensor& x,
int ring_id,
int peer,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,13 @@ void Pad3dInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

// Infers the output meta for partial_allgather: enforces nranks >= 2 and
// 0 <= rank < nranks, then gives `out` the same dims and dtype as `x`
// (the op supports inplace x -> out). ring_id and use_calc_stream do not
// affect meta inference.
void PartialAllgatherInferMeta(const MetaTensor& x,
                               int nranks,
                               int rank,
                               int ring_id,
                               bool use_calc_stream,
                               MetaTensor* out);

void PartialSendInferMeta(const MetaTensor& x,
int ring_id,
int peer,
Expand Down
1 change: 1 addition & 0 deletions test/ir/pir/translator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST
test_distributed_push_sparse_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_distributed_fused_lamb_init)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_nop_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_allgather_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_send_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST test_partial_recv_translator)
list(APPEND DISTRIBUTED_OP_TRANSLATOR_TEST
Expand Down
47 changes: 47 additions & 0 deletions test/ir/pir/translator/test_partial_allgather_translator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import test_op_translator

import paddle
from paddle.base.layer_helper import LayerHelper


class TestPartialAllgatherOpTranslator(test_op_translator.TestOpTranslator):
    """Checks that the legacy ``partial_allgather`` op translates to PIR.

    NOTE(review): renamed from ``TestPartialAllgetherOpTranslator`` to fix
    the "Allgether" typo; unittest discovers test cases by the ``Test``
    prefix, so discovery behavior is unchanged.
    """

    def append_op(self):
        # Build a program containing a single partial_allgather op. Input
        # and output share the same shape/dtype, matching the op's
        # inplace (x -> out) contract.
        self.op_type = "partial_allgather"
        x = paddle.ones(shape=(100, 2, 3), dtype='float32')
        out = paddle.ones(shape=(100, 2, 3), dtype='float32')
        attrs = {
            'nranks': 2,  # collective world size; InferMeta requires >= 2
            'rank': 0,  # this worker's rank; must satisfy 0 <= rank < nranks
            'ring_id': 0,
            'use_calc_stream': False,
        }
        helper = LayerHelper(self.op_type)
        helper.append_op(
            type=self.op_type,
            inputs={"X": x},
            outputs={"Out": out},
            attrs=attrs,
        )

    def test_translator(self):
        # Delegates to TestOpTranslator.check(), which runs the old-IR ->
        # PIR translation for the program built in append_op().
        self.check()


# Run the translator test when this file is executed directly.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 23c9830

Please sign in to comment.