Fix gather for SageMaker model parallel
sgugger committed Sep 27, 2021
1 parent 4e0410e commit 1c96500
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions src/transformers/trainer_pt_utils.py
@@ -1021,6 +1021,7 @@ def smp_gather(tensor):
                 f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
             )
         all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
+        all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
         return torch.cat([t.cpu() for t in all_tensors], dim=0)
 
     def smp_nested_concat(tensor):
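
The `t[None]` line in `smp_gather` gives every zero-dimensional tensor (for example a scalar loss) a leading axis before concatenation, since `torch.cat` rejects 0-dim inputs. The sketch below reproduces that behavior without SageMaker: a plain Python list stands in for the result of `smp.allgather`, and the helper names `promote_scalars` and `gather_like` are hypothetical, not part of `transformers`.

```python
import torch


def promote_scalars(tensors):
    """Give each 0-dim tensor a leading axis; leave other tensors untouched."""
    return [t if len(t.shape) > 0 else t[None] for t in tensors]


def gather_like(tensors):
    # Hypothetical stand-in for the tail of smp_gather: `tensors` plays the
    # role of the list returned by smp.allgather (one entry per rank).
    return torch.cat([t.cpu() for t in promote_scalars(tensors)], dim=0)


# Two ranks each reporting a scalar loss as a 0-dim tensor.
per_rank_losses = [torch.tensor(0.5), torch.tensor(0.25)]

# Without the promotion step, torch.cat refuses 0-dim tensors.
try:
    torch.cat(per_rank_losses, dim=0)
    cat_failed = False
except RuntimeError:
    cat_failed = True

# With the promotion step, the scalars become a 1-D tensor of length 2.
gathered = gather_like(per_rank_losses)

# Tensors that already have a batch axis pass through unchanged.
batched = gather_like([torch.ones(2, 3), torch.zeros(1, 3)])
```

Indexing with `None` adds the axis as a view, so tensors that already have at least one dimension incur no extra work and keep their shape.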