diff --git a/sharktank/sharktank/ops/sharded_impls.py b/sharktank/sharktank/ops/sharded_impls.py
index dc66b39cd..13cf2ef4b 100644
--- a/sharktank/sharktank/ops/sharded_impls.py
+++ b/sharktank/sharktank/ops/sharded_impls.py
@@ -60,20 +60,22 @@ def all_gather_split(
 def all_reduce_split_or_unreduced(
     input: Union[SplitPrimitiveTensor, UnreducedTensor],
 ) -> ReplicatedTensor:
-    # For each device move the shards to it and do a reduction.
-    # If we don't move first, common sub-expression elimination is free to collapse all
-    # reductions into one and then copy to all devices, which is not what we want.
+    reduced = functools.reduce(
+        lambda x, y: elementwise(torch.add, x, y),
+        [
+            (
+                transfer_to_logical_device(shard, 0)
+                if i != 0
+                else barrier_on_logical_device(shard, 0)
+            )
+            for i, shard in enumerate(input.shards)
+        ]
+    )
     shards = [
-        functools.reduce(
-            lambda x, y: elementwise(torch.add, x, y),
-            [
-                (
-                    barrier_on_logical_device(shard, i)
-                    if i == j
-                    else transfer_to_logical_device(shard, i)
-                )
-                for j, shard in enumerate(input.shards)
-            ],
+        (
+            transfer_to_logical_device(reduced, i)
+            if i != 0
+            else barrier_on_logical_device(reduced, 0)
         )
         for i in range(input.shard_count)
     ]
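For context, here is a minimal sketch of the communication pattern this hunk switches to: instead of every device pulling all shards and running its own reduction, the shards are reduced once on device 0 and the single result is then transferred out to every other device. This is an illustration only, using plain `torch` tensors, `torch.add` in place of `elementwise(torch.add, ...)`, and stand-in `transfer_to_logical_device` / `barrier_on_logical_device` helpers; the real ops and sharded tensor types live in sharktank.

```python
# Toy illustration of the reduce-on-device-0-then-broadcast pattern.
# The helper bodies below are placeholders, not the sharktank implementations.
import functools
from typing import List

import torch


def transfer_to_logical_device(t: torch.Tensor, device_ordinal: int) -> torch.Tensor:
    """Stand-in: pretend to move `t` to logical device `device_ordinal`."""
    return t.clone()


def barrier_on_logical_device(t: torch.Tensor, device_ordinal: int) -> torch.Tensor:
    """Stand-in: keep `t` where it is, marking it as pinned to `device_ordinal`."""
    return t


def all_reduce_then_broadcast(shards: List[torch.Tensor]) -> List[torch.Tensor]:
    # Reduce every shard onto device 0 (shard 0 stays put, the rest are transferred in).
    reduced = functools.reduce(
        torch.add,
        [
            transfer_to_logical_device(s, 0) if i != 0 else barrier_on_logical_device(s, 0)
            for i, s in enumerate(shards)
        ],
    )
    # Broadcast the single reduced tensor back out to every device.
    return [
        transfer_to_logical_device(reduced, i) if i != 0 else barrier_on_logical_device(reduced, 0)
        for i in range(len(shards))
    ]


# Example: three "shards" reduce to their sum, replicated once per device.
shards = [torch.full((2,), float(v)) for v in (1, 2, 3)]
print(all_reduce_then_broadcast(shards))  # three tensors, each [6., 6.]
```

Compared with the previous per-device reductions, this performs the additive work once and, for n shards, needs roughly 2·(n-1) inter-device transfers instead of n·(n-1), at the cost of funnelling the reduction through device 0.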