diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index ed3425167944..2bece09bffc4 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -366,13 +366,6 @@ def __init__(self, see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False) - # Record padding required for alignment - if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1: - padding = self.bit16_groups_flat[i].numel() - orig_group_numel - else: - padding = 0 - self.groups_padding.append(padding) - if dist.get_rank(group=self.real_dp_process_group[i]) == 0: see_memory_usage(f"After Flattening and after emptying param group {i} cache", force=False) @@ -384,6 +377,18 @@ def __init__(self, data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i) self.parallel_partitioned_bit16_groups.append(data_parallel_partitions) + # Record padding required for alignment + left_boundary = sum([t.numel() for t in data_parallel_partitions[:partition_id]]) + curr_partition_size = data_parallel_partitions[partition_id].numel() + + if orig_group_numel <= left_boundary: + padding = curr_partition_size + elif orig_group_numel < left_boundary + curr_partition_size: + padding = left_boundary + curr_partition_size - orig_group_numel + else: + padding = 0 + self.groups_padding.append(padding) + # verify that data partition start locations are 4-byte aligned for partitioned_data in data_parallel_partitions: assert (partitioned_data.data_ptr() % (2 * self.nccl_start_alignment_factor) == 0)