
Commit

Reduce buffer copying by using one device to reduce and distribute
Alex Vasile committed Feb 11, 2025
1 parent: fbc69de · commit: f79044a
Showing 1 changed file with 15 additions and 13 deletions: sharktank/sharktank/ops/sharded_impls.py
@@ -60,20 +60,22 @@ def all_gather_split(
 def all_reduce_split_or_unreduced(
     input: Union[SplitPrimitiveTensor, UnreducedTensor],
 ) -> ReplicatedTensor:
     # For each device move the shards to it and do a reduction.
     # If we don't move first, common sub-expression elimination is free to collapse all
     # reductions into one and then copy to all devices, which is not what we want.
+    reduced = functools.reduce(
+        lambda x, y: elementwise(torch.add, x, y),
+        [
+            (
+                transfer_to_logical_device(shard, 0)
+                if i != 0
+                else barrier_on_logical_device(shard, 0)
+            )
+            for i, shard in enumerate(input.shards)
+        ]
+    )
     shards = [
-        functools.reduce(
-            lambda x, y: elementwise(torch.add, x, y),
-            [
-                (
-                    barrier_on_logical_device(shard, i)
-                    if i == j
-                    else transfer_to_logical_device(shard, i)
-                )
-                for j, shard in enumerate(input.shards)
-            ],
-        )
+        (
+            transfer_to_logical_device(reduced, i)
+            if i != 0
+            else barrier_on_logical_device(reduced, 0)
+        )
         for i in range(input.shard_count)
     ]
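For intuition, here is a minimal, self-contained sketch of the data-movement difference. It is plain Python with integer "shards"; count_transfers, the loop structure, and the use of functools.reduce over ints are illustrative stand-ins, not the sharktank API. The old implementation copied every shard to every device and reduced redundantly on each one, roughly N*(N-1) transfers for N shards; the new implementation copies all shards to device 0 once, reduces there, and distributes the single reduced buffer, roughly 2*(N-1) transfers.

    # Transfer-count sketch only: the "transfers" counted below are hypothetical
    # stand-ins for transfer_to_logical_device, and the reduction is integer addition.
    from functools import reduce

    def count_transfers(shard_count: int) -> tuple[int, int]:
        shards = list(range(shard_count))  # placeholder shard data, one value per device
        transfers_old = 0
        transfers_new = 0

        # Old pattern: every device pulls in all other shards and reduces locally.
        for device in range(shard_count):
            gathered = []
            for src, shard in enumerate(shards):
                if src != device:
                    transfers_old += 1  # shard copied from device src to this device
                gathered.append(shard)
            _ = reduce(lambda x, y: x + y, gathered)  # redundant reduction per device

        # New pattern: pull everything onto device 0, reduce once, then distribute.
        gathered = []
        for src, shard in enumerate(shards):
            if src != 0:
                transfers_new += 1  # shard copied from device src to device 0
            gathered.append(shard)
        reduced = reduce(lambda x, y: x + y, gathered)  # single reduction on device 0
        for device in range(shard_count):
            if device != 0:
                transfers_new += 1  # reduced buffer copied from device 0 to device

        return transfers_old, transfers_new

    print(count_transfers(8))  # (56, 14): N*(N-1) copies versus 2*(N-1)

The trade-off is that the reduction now runs once on device 0 instead of redundantly on every device, which is the "one device to reduce and distribute" strategy named in the commit title.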
