Add all-gather and reduce-scatter coalescence support for FSDP.
Also allow using reduce-scatter's scale param in FSDP.
(revived #4145)
jeffhataws committed Oct 20, 2023
1 parent 8f45cae commit 2e861ff
Showing 15 changed files with 613 additions and 211 deletions.
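
The user-visible change is in torch_xla/core/xla_model.py below: `xm.all_gather` and `xm.reduce_scatter` now accept a list of tensors and lower to a single coalesced collective per element type, and `xm.reduce_scatter`'s `scale` argument becomes usable from FSDP. A minimal sketch of how FSDP-style sharded code might call the new API; the shard and gradient names here are illustrative, not taken from the FSDP implementation.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
world_size = xm.xrt_world_size()

# Parameter shards for two layers: one coalesced all-gather instead of two.
shard_a = torch.zeros(128, device=device)
shard_b = torch.zeros(64, device=device)
full_a, full_b = xm.all_gather([shard_a, shard_b], dim=0)

# Matching local gradients: one coalesced reduce-scatter, with `scale`
# averaging the summed gradients across replicas.
grad_a = torch.randn(128 * world_size, device=device)
grad_b = torch.randn(64 * world_size, device=device)
red_a, red_b = xm.reduce_scatter(
    xm.REDUCE_SUM, [grad_a, grad_b],
    scale=1.0 / world_size,
    scatter_dim=0,
    shard_count=world_size)

xm.mark_step()
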
44 changes: 44 additions & 0 deletions test/test_torch_distributed_xla_backend.py
@@ -87,6 +87,26 @@ def test_allgather(self):
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
hlo_matches(hlo, all_gather_pattern)

def test_allgather_coalesced(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
pg_xla = get_process_group_xla(rank=3, size=8)
output_tensors = [torch.zeros_like(tensor)] * 8
output_tensors2 = [torch.zeros_like(tensor2)] * 8
# because we set os.environ[xenv.WORLD_SIZE] = '1', here the outputs'
# shapes will be the same as the inputs' shapes.
all_gather_pattern = (
r'%all-gather\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
r'all-gather\(s64\[2]\{0} %.+\.\d+, s64\[5]\{0} %.+\.\d+, '
r's64\[] %.+\.\d+\)')
pg_xla.allgather_coalesced([output_tensors, output_tensors2],
[tensor, tensor2])
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_tensors)
hlo_matches(hlo, all_gather_pattern)
# purge all computations attached to the device.
xm.mark_step()

def test_broadcast(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
@@ -106,6 +126,30 @@ def test_reduce_scatter(self):
hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
hlo_matches(hlo, reduce_scatter_pattern)

def test_reduce_scatter_coalesced(self):
device = xm.xla_device()
tensor = torch.arange(2, device=device) + 1 + 2 * dist.get_rank()
tensor2 = torch.arange(5, device=device) + 1 + 2 * dist.get_rank()
input_tensors_list = [[tensor, tensor], [tensor2, tensor2]]
output_list = [torch.zeros_like(tensor), torch.zeros_like(tensor2)]
pg_xla = get_process_group_xla(rank=0, size=len(input_tensors_list[0]))
opts = dist.ReduceScatterOptions()
opts.reduceOp = dist.ReduceOp.SUM
reduce_scatter_pattern = (
r'%reduce\-scatter\.\d+ = \(s64\[2]\{0}, s64\[5]\{0}, s64\[]\) '
r'reduce\-scatter\(s64\[4]\{0} %.+\.\d+, s64\[10]\{0} %.+\.\d+, '
r's64\[] %.+\.\d+\)')
with self.assertRaises(RuntimeError) as cm:
pg_xla.reduce_scatter_coalesced(output_list, input_tensors_list, opts)
hlo = torch_xla._XLAC._get_xla_tensors_hlo(output_list)
hlo_matches(hlo, reduce_scatter_pattern)
# purge all computations attached to the device.
xm.mark_step()
assert 'UNIMPLEMENTED: ReduceScatter is not implemented on CPU.' in str(
cm.exception), str(cm.exception)
# reset token to clean up the mess after the RuntimeError.
xm.set_replication(device, [])

@patch_world(0, 6)
def test_send(self):
device = xm.xla_device()
85 changes: 69 additions & 16 deletions torch_xla/core/xla_model.py
@@ -485,6 +485,8 @@ def _all_gather_using_all_reduce(value, dim=0, groups=None, pin_layout=True):
Args:
value (torch.Tensor): The input tensor.
value (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
it will also be the output.
dim (int): The gather dimension.
Default: 0
groups (list, optional): A list of list, representing the replica groups for
@@ -562,6 +564,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
shard_count = xrt_world_size()

token, devctx = _get_all_reduce_token()

if output != None:
# Call the out of place version of the all_gather
new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
@@ -574,6 +577,31 @@
[], pin_layout)
return result

if isinstance(value, torch.Tensor):
if output != None:
# Call the out of place version of the all_gather
new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_all_gather(value, token, dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]

# Now the input should be a list of Tensors.
if not isinstance(value, list) or any(
not isinstance(v, torch.Tensor) for v in value):
raise TypeError("`value` needs to be a Tensor or a list of Tensors, but "
f"given {type(value)}.")
result = torch_xla._XLAC._xla_all_gather_coalesced(value, token, dim,
shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]
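
A small usage sketch of the list path added above; tensor names and shapes are illustrative. The inputs may differ in shape and element type (the backend groups them per type, see cross_replica_reduces.cpp below), and the returned list preserves the input order.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.arange(4, device=device)                      # int64 input
b = torch.ones(6, dtype=torch.float32, device=device)   # float32 input

# One call instead of two separate all_gather() calls; each element type
# becomes one tuple-shaped all-gather in the lowered HLO.
gathered = xm.all_gather([a, b], dim=0)
assert len(gathered) == 2  # outputs come back in input order
xm.mark_step()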


def all_to_all(value,
split_dimension,
@@ -718,16 +746,19 @@ def reduce_scatter(reduce_type,
reduce_type (string): One of ``xm.REDUCE_SUM``, ``xm.REDUCE_MUL``,
``xm.REDUCE_AND``, ``xm.REDUCE_OR``, ``xm.REDUCE_MIN`` and
``xm.REDUCE_MAX``.
input: A single `torch.Tensor` all reduce + scatter op to.
input (torch.Tensor or a list of torch.Tensor): The input. If it's a list, then
it will also be the output.
scale (float): A default scaling value to be applied after the reduce.
scatter_dim (int): Dimension number to which the scatter operation is applied.
shard_count (int): The number of shards the scatter_dim is split into.
groups (list): A list of list, representing the replica groups for
the `all_reduce()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
the `reduce_scatter()` operation. Example: `[[0, 1, 2, 3], [4, 5, 6, 7]]`
defines two groups, one with the `[0, 1, 2, 3]` replicas and one with
the `[4, 5, 6, 7]` replicas. If `None` there will be only one group with
all the replicas in it.
output: Optional output tensor
output: Optional output tensor if `input` is a torch.Tensor, or an optional
list of torch.Tensor if `input` is a list of torch.Tensor.
pin_layout (bool, optional): whether to pin the layout for this communication op.
Layout pinning can prevent potential data corruption when each process that
participates in the communication has a slightly different program, but it might
@@ -740,21 +771,43 @@
the same as the input.
"""
token, devctx = _get_all_reduce_token()
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(reduce_type, output,
input, token, scale,
scatter_dim,
shard_count, groups or
[], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token, scale,
scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[1])
return result[0]
if isinstance(input, torch.Tensor):
if output != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_out(
reduce_type, output, input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
devctx.all_reduce_token = new_token
return output

result = torch_xla._XLAC._xla_reduce_scatter(reduce_type, input, token,
scale, scatter_dim,
shard_count, groups or [],
pin_layout)
devctx.all_reduce_token = result[1]
return result[0]

# Now the input should be a list of Tensors.
if not isinstance(input, list) or any(
not isinstance(v, torch.Tensor) for v in input):
raise TypeError("`input` needs to be a Tensor or a list of Tensors, but "
f"given {type(input)}.")
if output != None:
if not isinstance(output, list) or any(
not isinstance(v, torch.Tensor) for v in output):
raise TypeError(
f"`output` needs to be a list of Tensors, but given {type(output)}."
)
if len(output) != len(input):
raise ValueError("`output` length doesn't match `input` length: "
f"{len(output)} vs {len(input)}.")

result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
reduce_type, output or [], input, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
devctx.all_reduce_token = result[-1]
return result[:-1]
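
A short illustrative sketch of the validation added above (names are hypothetical): for a list `input`, every element must be a Tensor, and an `output` list, if supplied, must match `input` in length; the check fires before any collective is added to the graph.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
world = xm.xrt_world_size()
grads = [torch.randn(8, device=device), torch.randn(4, device=device)]

try:
  # Only one output tensor for two inputs: rejected up front.
  xm.reduce_scatter(xm.REDUCE_SUM, grads,
                    scale=1.0 / world,
                    scatter_dim=0,
                    shard_count=world,
                    output=[torch.zeros_like(grads[0])])
except ValueError as err:
  print(err)  # `output` length doesn't match `input` length: 1 vs 2.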


def add_step_closure(closure, args=(), run_async=False):
151 changes: 95 additions & 56 deletions torch_xla/csrc/cross_replica_reduces.cpp
@@ -210,31 +210,46 @@ AllToAllResult BuildAllToAll(xla::XlaOp input, xla::XlaOp token,
return {reduce_result, token_handler.GetNewToken(reduce_result)};
}

AllGatherResult BuildAllGather(xla::XlaOp input, xla::XlaOp token, int64_t dim,
int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups,
bool pin_layout) {
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
TokenHandler token_handler(token);
xla::XlaOp all_gather_result;
if (pin_layout) {
torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
input_shape.dimensions(), input_shape.dynamic_dimensions(),
input_shape.element_type(),
static_cast<XlaDeviceType>(xla_device.type()));
all_gather_result =
xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
shard_count, reduce_groups, /*channel_id=*/absl::nullopt,
/*layout=*/reduce_shape.layout());
} else {
all_gather_result =
xla::AllGather(token_handler.GetInput(input, &input_shape), dim,
shard_count, reduce_groups);
std::vector<xla::XlaOp> BuildAllGather(
absl::Span<const xla::XlaOp> inputs, xla::XlaOp token, int64_t dim,
int64_t shard_count, const std::vector<std::vector<int64_t>>& groups,
bool pin_layout) {
std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
// TODO: We use pseudo-tokens ATM, which are real values. This needs to be
// switched to use the real XLA Token once support has been added to XLA
// AllGather().
xla::XlaOp chained_token = token;
ReduceContext cc_ctx = GetReduceContext(inputs);
std::vector<xla::XlaOp> result(inputs.size());
for (auto& type_ctx : cc_ctx.contexts) {
xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first);
type_ctx.second.ops.push_back(token_op);
type_ctx.second.operand_shapes.push_back(
ShapeHelper::ShapeOfXlaOp(token_op));

xla::XlaOp all_gather_result;
if (pin_layout) {
all_gather_result = xla::AllGather(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops), dim,
shard_count, cc_groups, /*channel_id=*/absl::nullopt,
/*layout=*/
MakeReduceShape(type_ctx.second.operand_shapes).layout());
} else {
all_gather_result =
xla::AllGather(xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
dim, shard_count, cc_groups);
}
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
size_t op_idx = type_ctx.second.indices[i];
result[op_idx] = xla::GetTupleElement(all_gather_result, i);
}
chained_token =
xla::GetTupleElement(all_gather_result, type_ctx.second.indices.size());
}
return {all_gather_result, token_handler.GetNewToken(all_gather_result)};
}
result.push_back(
MaybeConvertTo(chained_token, XlaHelpers::TypeOfXlaOp(token)));
return result;
}
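
The rewritten BuildAllGather (and BuildReduceScatter below) share one coalescing pattern: inputs are bucketed by element type, the running pseudo-token is appended to each bucket as an extra operand, one tuple-shaped collective is emitted per bucket, the tuple elements are written back to the result vector at the inputs' original positions, and the token is chained through the buckets and returned as the last element. A stand-alone Python model of that bookkeeping, for illustration only (no XLA calls; `gather` stands in for the per-bucket collective):

from collections import defaultdict

def coalesce(inputs, token, gather):
  """Model of the per-type coalescing in BuildAllGather/BuildReduceScatter.

  inputs: list of (dtype, value) pairs; token: opaque running token;
  gather: callable emulating one tuple-shaped collective over a bucket.
  Returns the per-input results in input order, plus the chained token.
  """
  # Bucket input indices by element type (the role of GetReduceContext).
  buckets = defaultdict(list)
  for idx, (dtype, _) in enumerate(inputs):
    buckets[dtype].append(idx)

  result = [None] * len(inputs)
  chained_token = token
  for dtype, indices in buckets.items():
    # Operands for this bucket, with the token appended as the last operand.
    operands = [inputs[i][1] for i in indices] + [chained_token]
    tuple_out = gather(dtype, operands)      # one collective per element type
    for i, op_idx in enumerate(indices):
      result[op_idx] = tuple_out[i]          # restore original input order
    chained_token = tuple_out[len(indices)]  # new token is the last element
  return result + [chained_token]

# Tiny demo with an identity "collective".
print(coalesce([("f32", 1.0), ("s64", 2), ("f32", 3.0)], "tok0",
               lambda dtype, ops: ops))
# -> [1.0, 2, 3.0, 'tok0']  (values in input order, chained token last)

In BuildReduceScatter the same loop additionally multiplies each tuple element by `scale` before writing it back.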

CollectivePermuteResult BuildCollectivePermute(
xla::XlaOp input, xla::XlaOp token,
Expand Down Expand Up @@ -274,39 +289,63 @@ RecvResult BuildRecvWithToken(xla::XlaOp token, const xla::Shape& recv_shape,
return {result, new_token};
}

ReduceScatterResult BuildReduceScatter(
AllReduceType reduce_type, xla::XlaOp input, xla::XlaOp token, double scale,
int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
std::vector<xla::ReplicaGroup> reduce_groups = CreateReduceGroups(groups);
TokenHandler token_handler(token);
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
xla::XlaOp reduce_result;
if (pin_layout) {
torch::lazy::BackendDevice xla_device = bridge::GetCurrentDevice();
xla::Shape reduce_shape = MakeArrayShapeFromDimensions(
input_shape.dimensions(), input_shape.dynamic_dimensions(),
input_shape.element_type(),
static_cast<XlaDeviceType>(xla_device.type()));
reduce_result = xla::ReduceScatter(
token_handler.GetInput(input, &input_shape),
GetReduceComutation(reduce_type, input_shape.element_type()),
scatter_dim, shard_count, reduce_groups, /*channel_id=*/absl::nullopt,
/*layout=*/reduce_shape.layout());
} else {
reduce_result = xla::ReduceScatter(
token_handler.GetInput(input, &input_shape),
GetReduceComutation(reduce_type, input_shape.element_type()),
scatter_dim, shard_count, reduce_groups);
}

if (scale != 1.0) {
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
scale, input_shape.element_type(), input.builder());
reduce_result = reduce_result * scaling_value;
}
std::vector<xla::XlaOp> BuildReduceScatter(
AllReduceType reduce_type, absl::Span<const xla::XlaOp> inputs,
xla::XlaOp token, double scale, int64_t scatter_dim, int64_t shard_count,
const std::vector<std::vector<int64_t>>& groups, bool pin_layout) {
std::vector<xla::ReplicaGroup> cc_groups = CreateReduceGroups(groups);
// TODO: We use pseudo-tokens ATM, which are real values. This needs to be
// switched to use the real XLA Token once support has been added to XLA
// ReduceScatter().
xla::XlaOp chained_token = token;
ReduceContext cc_ctx = GetReduceContext(inputs);
std::vector<xla::XlaOp> result(inputs.size());
for (auto& type_ctx : cc_ctx.contexts) {
xla::XlaOp token_op = MaybeConvertTo(chained_token, type_ctx.first);
type_ctx.second.ops.push_back(token_op);
type_ctx.second.operand_shapes.push_back(
ShapeHelper::ShapeOfXlaOp(token_op));
xla::XlaOp reduce_result;
if (pin_layout) {
reduce_result = xla::ReduceScatter(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
shard_count, cc_groups, /*channel_id=*/absl::nullopt,
/*layout=*/
MakeReduceShape(type_ctx.second.operand_shapes).layout());
} else {
reduce_result = xla::ReduceScatter(
xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
shard_count, cc_groups);
}
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
size_t op_idx = type_ctx.second.indices[i];
xla::XlaOp gte = xla::GetTupleElement(reduce_result, i);
if (scale != 1.0) {
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
scale, type_ctx.second.operand_shapes[i].element_type(),
gte.builder());
gte = gte * scaling_value;
}
result[op_idx] = gte;
}
chained_token =
xla::GetTupleElement(reduce_result, type_ctx.second.indices.size());
}
result.push_back(
MaybeConvertTo(chained_token, XlaHelpers::TypeOfXlaOp(token)));
return result;
}

return {reduce_result, token_handler.GetNewToken(reduce_result)};
// Moved from torch_xla/csrc/ops/all_reduce.cpp
std::vector<torch::lazy::Value> GetOperandList(
c10::ArrayRef<torch::lazy::Value> operands,
const torch::lazy::Value& token) {
std::vector<torch::lazy::Value> operand_list(operands.begin(),
operands.end());
operand_list.push_back(token);
return operand_list;
}

const torch::lazy::Value& GetAllReduceToken(