Avoid passing empty input buckets
jeffhataws committed Mar 19, 2024
1 parent 77b2ad1 commit ae348b2
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions torch_xla/core/xla_model.py
@@ -664,8 +664,9 @@ def _all_gather_coalesced(tensor_list, output_list=[]):
    if tensor_bytes > bucket_cap:
      # Flush out previous buckets even if they don't fill up
      if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap:
-       out_tensors.extend(
-           _all_gather_coalesced(tensor_bucket, output_bucket))
+       if len(tensor_bucket):
+         out_tensors.extend(
+             _all_gather_coalesced(tensor_bucket, output_bucket))
      out_tensors.extend(
          _all_gather_coalesced([tensor], [output[idx]] if output else []))
    else:
@@ -682,7 +683,9 @@ def _all_gather_coalesced(tensor_list, output_list=[]):
    # Bucketize till the total spills over
    total += tensor_bytes
    if total > bucket_cap:
-     out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket))
+     if len(tensor_bucket):
+       out_tensors.extend(
+           _all_gather_coalesced(tensor_bucket, output_bucket))
      total = tensor_bytes
      tensor_bucket = []
      output_bucket = []
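
The pattern guarded in both hunks of _all_gather_coalesced (and, below, _reduce_scatter_coalesced) is the same: accumulate tensors into a bucket until a byte cap is reached, flush the bucket through the coalesced collective, and dispatch oversized tensors on their own. The following is a minimal standalone sketch of that pattern with the new empty-bucket guard; it is an illustration only, not the torch_xla implementation, and it omits the optional output list the real code threads through. The names dispatch_in_buckets and coalesced_op are invented for the example.

# Illustrative sketch only (not torch_xla code): bucketize tensors by byte
# size and hand each non-empty bucket to a caller-supplied coalesced op.
from typing import Callable, List

import torch


def dispatch_in_buckets(tensors: List[torch.Tensor],
                        coalesced_op: Callable[[List[torch.Tensor]],
                                               List[torch.Tensor]],
                        bucket_cap_mb: int = 160) -> List[torch.Tensor]:
  bucket_cap = bucket_cap_mb * 1024 * 1024
  out_tensors, tensor_bucket, total = [], [], 0
  for tensor in tensors:
    tensor_bytes = tensor.numel() * tensor.element_size()
    if tensor_bytes > bucket_cap:
      # Flush out previous buckets even if they don't fill up.
      if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap:
        if len(tensor_bucket):  # mirrors the guard added by this commit
          out_tensors.extend(coalesced_op(tensor_bucket))
        tensor_bucket, total = [], 0
      # An oversized tensor goes out in a bucket of its own.
      out_tensors.extend(coalesced_op([tensor]))
      continue
    # Bucketize till the total spills over.
    total += tensor_bytes
    if total > bucket_cap:
      if len(tensor_bucket):  # same guard before flushing a full bucket
        out_tensors.extend(coalesced_op(tensor_bucket))
      tensor_bucket, total = [], tensor_bytes
    tensor_bucket.append(tensor)
  if len(tensor_bucket):  # final flush, also skipped when empty
    out_tensors.extend(coalesced_op(tensor_bucket))
  return out_tensors

With these guards, coalesced_op is never called with an empty list — the same property the commit adds at the flush sites shown above.
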
@@ -933,15 +936,15 @@ def _reduce_scatter_coalesced(tensor_list, output_list=[]):
    if tensor_bytes > bucket_cap:
      # Flush out previous buckets even if they don't fill up
      if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap:
-       out_tensors.extend(
-           _reduce_scatter_coalesced(tensor_bucket, output_bucket))
+       if len(tensor_bucket):
+         out_tensors.extend(
+             _reduce_scatter_coalesced(tensor_bucket, output_bucket))
      out_tensors.extend(
          _reduce_scatter_coalesced([tensor],
                                    [output[idx]] if output else []))
    else:
      tensor_bucket.append(tensor)
      if output != None:
        assert (output[idx] != None)
        output_bucket.append(output[idx])
      out_tensors.extend(
          _reduce_scatter_coalesced(tensor_bucket, output_bucket))
@@ -953,8 +956,9 @@ def _reduce_scatter_coalesced(tensor_list, output_list=[]):
    # Bucketize till the total spills over
    total += tensor_bytes
    if total > bucket_cap:
-     out_tensors.extend(
-         _reduce_scatter_coalesced(tensor_bucket, output_bucket))
+     if len(tensor_bucket):
+       out_tensors.extend(
+           _reduce_scatter_coalesced(tensor_bucket, output_bucket))
      total = tensor_bytes
      tensor_bucket = []
      output_bucket = []
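
For reference, the concrete case the new checks avoid: if the very first tensor already exceeds twice the bucket cap, the flush condition fires while the bucket is still empty, so the unguarded code passed an empty list to the coalesced collective. A tiny standalone illustration of that arithmetic (the sizes are made up; 160 MB is just an example cap):

bucket_cap = 160 * 1024 * 1024        # example cap in bytes
total, tensor_bucket = 0, []          # nothing accumulated yet
tensor_bytes = 400 * 1024 * 1024      # first tensor is larger than 2x the cap

flush = total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap
print(flush)               # True  -> the unguarded code flushed here
print(len(tensor_bucket))  # 0     -> ...handing an empty bucket downstream
# The added `if len(tensor_bucket):` check turns this flush into a no-op.
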
