From 1c42fa6b2d18e4febb5a50becdec8f566163eb51 Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Sun, 26 Jan 2025 10:33:28 -0600 Subject: [PATCH 1/2] fix: remove assumption that padding only occurs on last rank --- deepspeed/runtime/zero/stage_1_and_2.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index ed3425167944..65d692884a67 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -366,13 +366,6 @@ def __init__(self, see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False) - # Record padding required for alignment - if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1: - padding = self.bit16_groups_flat[i].numel() - orig_group_numel - else: - padding = 0 - self.groups_padding.append(padding) - if dist.get_rank(group=self.real_dp_process_group[i]) == 0: see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False) @@ -384,6 +377,18 @@ def __init__(self, data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i) self.parallel_partitioned_bit16_groups.append(data_parallel_partitions) + # Record padding required for alignment + left_boundary = sum([t.numel() for t in data_parallel_partitions[:partition_id]]) + curr_partition_size = data_parallel_partitions[partition_id].numel() + + if orig_group_numel <= left_boundary: + padding = curr_partition_size + elif orig_group_numel < left_boundary + curr_partition_size: + padding = orig_group_numel - left_boundary + else: + padding = 0 + self.groups_padding.append(padding) + # verify that data partition start locations are 4-byte aligned for partitioned_data in data_parallel_partitions: assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0) From 0be115164d06b14fd44f04b745e0d136e46e4893 Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Mon, 27 Jan 2025 01:32:41 -0600 Subject: [PATCH 2/2] fix issue --- deepspeed/runtime/zero/stage_1_and_2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 65d692884a67..2bece09bffc4 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -384,7 +384,7 @@ def __init__(self, if orig_group_numel <= left_boundary: padding = curr_partition_size elif orig_group_numel < left_boundary + curr_partition_size: - padding = orig_group_numel - left_boundary + padding = left_boundary + curr_partition_size - orig_group_numel else: padding = 0 self.groups_padding.append(padding)