From ae348b24ffb7152145f564e0f7dd9346a4e6fa58 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Tue, 19 Mar 2024 05:46:08 +0000 Subject: [PATCH] Avoid passing empty input buckets --- torch_xla/core/xla_model.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index 093fb636d52..f1f7051dc46 100755 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -664,8 +664,9 @@ def _all_gather_coalesced(tensor_list, output_list=[]): if tensor_bytes > bucket_cap: # Flush out previous buckets even if they don't fill up if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: - out_tensors.extend( - _all_gather_coalesced(tensor_bucket, output_bucket)) + if len(tensor_bucket): + out_tensors.extend( + _all_gather_coalesced(tensor_bucket, output_bucket)) out_tensors.extend( _all_gather_coalesced([tensor], [output[idx]] if output else [])) else: @@ -682,7 +683,9 @@ def _all_gather_coalesced(tensor_list, output_list=[]): # Bucketize till the total spills over total += tensor_bytes if total > bucket_cap: - out_tensors.extend(_all_gather_coalesced(tensor_bucket, output_bucket)) + if len(tensor_bucket): + out_tensors.extend( + _all_gather_coalesced(tensor_bucket, output_bucket)) total = tensor_bytes tensor_bucket = [] output_bucket = [] @@ -933,15 +936,15 @@ def _reduce_scatter_coalesced(tensor_list, output_list=[]): if tensor_bytes > bucket_cap: # Flush out previous buckets even if they don't fill up if total >= 0.5 * bucket_cap or (total + tensor_bytes) > 2 * bucket_cap: - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + if len(tensor_bucket): + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) out_tensors.extend( _reduce_scatter_coalesced([tensor], [output[idx]] if output else [])) else: tensor_bucket.append(tensor) if output != None: - assert (output[idx] != None) output_bucket.append(output[idx]) out_tensors.extend( _reduce_scatter_coalesced(tensor_bucket, output_bucket)) @@ -953,8 +956,9 @@ def _reduce_scatter_coalesced(tensor_list, output_list=[]): # Bucketize till the total spills over total += tensor_bytes if total > bucket_cap: - out_tensors.extend( - _reduce_scatter_coalesced(tensor_bucket, output_bucket)) + if len(tensor_bucket): + out_tensors.extend( + _reduce_scatter_coalesced(tensor_bucket, output_bucket)) total = tensor_bytes tensor_bucket = [] output_bucket = []