From 6022c917543cee06f779c0f66fce51a19a8e5067 Mon Sep 17 00:00:00 2001
From: Jeffrey Huynh
Date: Wed, 6 Mar 2024 16:45:49 +0000
Subject: [PATCH] yapf lint fixes

---
 torch_xla/core/xla_model.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 8a629d53aa94..e18de98a16a5 100755
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -628,7 +628,8 @@ def _all_gather_coalesced(tensor_list, output_list=None):
     if output_list != None:
       # Call the out of place version of the reduce_scatter
       new_token = torch_xla._XLAC._xla_all_gather_coalesced_out(
-          output_list, tensor_list, token, dim, shard_count, groups or [], pin_layout)
+          output_list, tensor_list, token, dim, shard_count, groups or [],
+          pin_layout)
       torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
       return output_list
 
@@ -650,26 +651,31 @@ def _all_gather_coalesced(tensor_list, output_list=None):
     divisor = xrt_world_size()
     bucket_cap = bucket_cap / divisor
     for idx, tensor in enumerate(value):
-      
+
       tensor_bytes = tensor.numel() * tensor.element_size()
       output_selected = None
       if output != None:
         output_selected = output[idx]
         if tensor.numel() != output_selected.numel():
-          raise ValueError(f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: "
-                           f"{output_selected.numel() vs tensor.numel().")
+          raise ValueError(
+              f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: "
+              f"{output_selected.numel() vs tensor.numel().")
       # Tensor is larger than bucket_cap, don't bucketize
       if tensor_bytes > bucket_cap:
         # Flush out previous buckets even if they don't fill up
         if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap:
-          out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket))
-          out_tensors.extend(_all_gather_coalesced([tensor], [output_selected] if output else []))
+          out_tensors.extend(
+              _all_gather_coalesced(tensor_bucket, output_bucket))
+          out_tensors.extend(
+              _all_gather_coalesced([tensor],
+                                    [output_selected] if output else []))
       else:
         tensor_bucket.append(tensor)
         if output != None:
           output_bucket.append(output[i])
 
-        out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket))
+        out_tensors.extend(
+            _all_gather_coalesced(tensor_bucket, output_bucket))
         total = 0
         tensor_bucket = []
         output_bucket = []
@@ -901,14 +907,14 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None):
     if output_list != None:
       # Call the out of place version of the reduce_scatter
       new_token = torch_xla._XLAC._xla_reduce_scatter_coalesced_out(
-          reduce_type, output_list, tensor_list, token, scale, scatter_dim, shard_count,
-          groups or [], pin_layout)
+          reduce_type, output_list, tensor_list, token, scale, scatter_dim,
+          shard_count, groups or [], pin_layout)
       torch_xla._XLAC._set_all_reduce_token(devctx.device, new_token)
       return output_list
 
     result = torch_xla._XLAC._xla_reduce_scatter_coalesced(
-        reduce_type, tensor_list, token, scale,
-        scatter_dim, shard_count, groups or [], pin_layout)
+        reduce_type, tensor_list, token, scale, scatter_dim, shard_count,
+        groups or [], pin_layout)
     torch_xla._XLAC._set_all_reduce_token(devctx.device, result[-1])
     return result[:-1]
 
@@ -925,8 +931,9 @@ def _reduce_scatter_coalesced(tensor_list, output_list=None):
       if output != None:
         output_selected = output[idx]
         if tensor.numel() != output_selected.numel():
-          raise ValueError(f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: "
-                           f"{output_selected.numel() vs tensor.numel().")
+          raise ValueError(
+              f"`output` tensor size doesn't match `input` tensor size for tensor list index {idx}: "
+              f"{output_selected.numel() vs tensor.numel().")
       # Tensor is larger than bucket_cap, don't bucketize
       if tensor_bytes > bucket_cap:
         # Flush out previous buckets even if they don't fill up