Skip to content

Commit

Permalink
yapf lint fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffhataws committed Mar 6, 2024
1 parent ed96076 commit 6022c91
Showing 1 changed file with 20 additions and 13 deletions.
33 changes: 20 additions & 13 deletions torch_xla/core/xla_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,8 @@ def _all_gather_coalesced(tensor_list, output_list=None):
if output_list != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_all_gather_coalesced_out(
output_list, tensor_list, token, dim, shard_count, groups or [], pin_layout)
output_list, tensor_list, token, dim, shard_count, groups or [],
pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output_list

Expand All @@ -650,26 +651,31 @@ def _all_gather_coalesced(tensor_list, output_list=None):
divisor = xrt_world_size()
bucket_cap = bucket_cap / divisor
for idx, tensor in enumerate(value):

tensor_bytes = tensor.numel() * tensor.element_size()
output_selected = None
if output != None:
output_selected = output[idx]
if tensor.numel() != output_selected.numel():
raise ValueError(f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: "
f"{output_selected.numel() vs tensor.numel().")
raise ValueError(
f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: "
f"{output_selected.numel() vs tensor.numel().")

# Tensor is larger than bucket_cap, don't bucketize
if tensor_bytes > bucket_cap:
# Flush out previous buckets even if they don't fill up
if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap:
out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket))
out_tensors.extend(_all_gather_coalesced([tensor], [output_selected] if output else []))
out_tensors.extend(
_all_gather_coalesced(tensor_bucket, output_bucket))
out_tensors.extend(
_all_gather_coalesced([tensor],
[output_selected] if output else []))
else:
tensor_bucket.append(tensor)
if output != None:
output_bucket.append(output[i])
out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket))
out_tensors.extend(
_all_gather_coalesced(tensor_bucket, output_bucket))
total = 0
tensor_bucket = []
output_bucket = []
Expand Down Expand Up @@ -901,14 +907,14 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None):
if output_list != None:
# Call the out of place version of the reduce_scatter
new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out(
reduce_type, output_list, tensor_list, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
reduce_type, output_list, tensor_list, token, scale, scatter_dim,
shard_count, groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
return output_list

result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
reduce_type, tensor_list, token, scale,
scatter_dim, shard_count, groups or [], pin_layout)
reduce_type, tensor_list, token, scale, scatter_dim, shard_count,
groups or [], pin_layout)
torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
return result[:-1]

Expand All @@ -925,8 +931,9 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None):
if output != None:
output_selected = output[idx]
if tensor.numel() != output_selected.numel():
raise ValueError(f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: "
f"{output_selected.numel() vs tensor.numel().")
raise ValueError(
f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: "
f"{output_selected.numel() vs tensor.numel().")

# Tensor is larger than bucket_cap, don't bucketize
if tensor_bytes > bucket_cap:
Expand Down

0 comments on commit 6022c91

Please sign in to comment.