diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index cbd859202ca604..2b5082f3d9a1e8 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -1021,6 +1021,7 @@ def smp_gather(tensor):
                 f"Can't gather the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
             )
         all_tensors = smp.allgather(tensor, smp.CommGroup.DP_GROUP)
+        all_tensors = [t if len(t.shape) > 0 else t[None] for t in all_tensors]
         return torch.cat([t.cpu() for t in all_tensors], dim=0)

    def smp_nested_concat(tensor):
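
For context, a minimal standalone sketch (not part of the patch, and independent of the SageMaker `smp` API) of the failure mode the added line works around: `torch.cat` raises a RuntimeError on zero-dimensional tensors, which is what a gathered scalar (e.g. a per-rank loss) looks like. Promoting each 0-d tensor to shape `(1,)` with `t[None]` before concatenation, as the new list comprehension does, avoids that while leaving higher-dimensional tensors untouched:

```python
import torch

# Two 0-d tensors, standing in for scalar values gathered from two ranks.
gathered = [torch.tensor(1.0), torch.tensor(2.0)]

# torch.cat(gathered, dim=0)  # would fail: zero-dimensional tensors cannot be concatenated

# Same transformation as the patched line: expand 0-d tensors to shape (1,).
gathered = [t if len(t.shape) > 0 else t[None] for t in gathered]
print(torch.cat([t.cpu() for t in gathered], dim=0))  # tensor([1., 2.])
```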