diff --git a/deepspeed/runtime/bf16_optimizer.py b/deepspeed/runtime/bf16_optimizer.py index 0aefd1946c36..6c1ae345ebb6 100644 --- a/deepspeed/runtime/bf16_optimizer.py +++ b/deepspeed/runtime/bf16_optimizer.py @@ -18,7 +18,7 @@ align_dense_tensors, all_gather_dp_groups, bwc_tensor_model_parallel_rank, is_model_parallel_parameter, see_memory_usage, graph_process) -from deepspeed.utils import link_hp_params, fragment_address +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.checkpoint.constants import (DS_VERSION, PARTITION_COUNT, BASE_OPTIMIZER_STATE, SINGLE_PARTITION_OF_FP32_GROUPS, CLIP_GRAD, GROUP_PADDINGS, @@ -165,6 +165,7 @@ def _setup_for_real_optimizer(self): # Need optimizer states initialized before linking lp to optimizer state self._link_all_hp_params() + self._hp_optimizer_states_linked = False self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() @@ -199,9 +200,15 @@ def _link_all_hp_params(self): param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, - partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_dp_process_group[i]) + def _lazy_init_hp_params_optimizer_state(self): + if not self._hp_optimizer_states_linked: + for i, _ in enumerate(self.optimizer.param_groups): + lazy_init_hp_params_optimizer_state(self.bf16_groups[i], self.fp32_groups_flat_partition[i], + self.optimizer.state) + self._hp_optimizer_states_linked = True + def initialize_optimizer_states(self): """Take an optimizer step with zero-valued gradients to allocate internal optimizer state. 
@@ -215,8 +222,6 @@ def initialize_optimizer_states(self): param_partition.grad = grad_partition.to( param_partition.dtype) if grad_partition.dtype != param_partition.dtype else grad_partition - self.optimizer.step() - if self.grad_acc_dtype is not torch.float32: for param_partition in self.fp32_groups_flat_partition: param_partition.grad = None @@ -263,6 +268,9 @@ def step(self, closure=None): self.optimizer.step() + # We need to link optimizer state after the first step() call + self._lazy_init_hp_params_optimizer_state() + self.update_lp_params() self.clear_hp_grads() diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d971092ebd17..42008236a9ea 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1015,10 +1015,6 @@ def initialize_optimizer_states(self): else: self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow(0, 0, num_elements) - # Initialize the optimizer states with the flattened fp32 partition. 
- if not is_adagrad: - self._optimizer_step(i) - if swappable_param_subgroup: self._partitioned_params_swap_out(i) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 3e579422b26d..18b58403f1d7 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -28,7 +28,7 @@ from deepspeed.checkpoint.constants import (DS_VERSION, GROUP_PADDINGS, PARTITION_COUNT, LOSS_SCALER, SINGLE_PARTITION_OF_FP32_GROUPS, BASE_OPTIMIZER_STATE, BASE_OPTIMIZER_STATE_STEP, CLIP_GRAD, ZERO_STAGE, PARAM_SLICE_MAPPINGS) -from deepspeed.utils import link_hp_params +from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state from deepspeed.checkpoint import enable_universal_checkpoint from deepspeed.utils import groups @@ -88,6 +88,12 @@ def _get_padded_tensor(src_tensor, size): return padded_tensor +def _pad_tensor_by_size(src_tensor, pad_size, dtype, device): + padded_tensor = torch.zeros(src_tensor.numel() + pad_size, dtype=dtype, device=device) + padded_tensor.data[:src_tensor.numel()].copy_(src_tensor.data) + return padded_tensor + + class DeepSpeedZeroOptimizer(ZeROOptimizer): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -536,6 +542,8 @@ def __init__(self, see_memory_usage(f"After initializing ZeRO optimizer", force=True) self._link_all_hp_params() + self._hp_optimizer_states_linked = False + self._enable_universal_checkpoint() self._param_slice_mappings = self._create_param_mapping() @@ -578,9 +586,15 @@ def _link_all_hp_params(self): param_group_index=i, partition_start=partition_id * partition_size, partition_size=partition_size, - partition_optimizer_state=self.optimizer.state[flat_hp_partition], dp_group=self.real_dp_process_group[i]) + def _lazy_init_hp_params_optimizer_state(self): + if not self._hp_optimizer_states_linked: + for i, _ in enumerate(self.optimizer.param_groups): + lazy_init_hp_params_optimizer_state(self.bit16_groups[i], 
self.single_partition_of_fp32_groups[i], + self.optimizer.state) + self._hp_optimizer_states_linked = True + def is_moe_group(self, group): return 'moe' in group and group['moe'] @@ -664,8 +678,6 @@ def initialize_optimizer_states(self): # which do lazy initialization of the state at the first call to step. if isinstance(self.optimizer, torch.optim.Adagrad): self.optimizer = torch.optim.Adagrad(self.single_partition_of_fp32_groups, **self.optimizer.defaults) - else: - self.optimizer.step() if not self.cpu_offload: for group in self.single_partition_of_fp32_groups: @@ -1793,6 +1805,9 @@ def _optimizer_step(self, group_no): self.optimizer.step() self.optimizer.param_groups = original_param_groups + # We need to link optimizer state after the first step() call + self._lazy_init_hp_params_optimizer_state() + def step(self, closure=None): """ Not supporting closure. @@ -2208,19 +2223,39 @@ def _partition_base_optimizer_state(self, state_key, all_partition_states, group # Assume non-tensor states are not partitioned and equal across ranks, so return first one return all_partition_states[0] - def _restore_base_optimizer_state(self, base_optimizer_group_states): + def _restore_step_from_elastic_checkpoint(self, all_state_dict): + assert BASE_OPTIMIZER_STATE_STEP in all_state_dict[0] + assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + for sd in all_state_dict), "State dicts of all partitions must have the same step value" + return all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] + + def _restore_base_optimizer_state(self, base_optimizer_group_states, base_optimizer_state_step, group_paddings): if type(base_optimizer_group_states) == dict: base_optimizer_group_states = base_optimizer_group_states['state'] + + saved_keys = base_optimizer_group_states[0].keys() + for i, group in enumerate(self.optimizer.param_groups): p = group['params'][0] - for key, saved in base_optimizer_group_states[i].items(): - if 
torch.is_tensor(self.optimizer.state[p][key]): - dst_tensor = self.optimizer.state[p][key] - src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) - self.optimizer.state[p][key].data.copy_(src_tensor.data) + padding = 0 if group_paddings is None else group_paddings[i] + for key in saved_keys: + saved = base_optimizer_group_states[i][key] + + if torch.is_tensor(saved): + if key in self.optimizer.state[p]: + dst_tensor = self.optimizer.state[p][key] + src_tensor = _get_padded_tensor(saved, dst_tensor.numel()) + self.optimizer.state[p][key].data.copy_(src_tensor.data) + else: + self.optimizer.state[p][key] = _pad_tensor_by_size( + saved, padding, torch.float32, + torch.device('cpu') if self.cpu_offload else self.device) else: self.optimizer.state[p][key] = saved + for param_group in self.optimizer.param_groups: + param_group['step'] = base_optimizer_state_step + def get_ep_ranks(self, rank=0, group_name=None): from deepspeed.utils import groups expert_parallel_size_ = groups._get_expert_parallel_world_size(group_name) @@ -2248,15 +2283,8 @@ def _restore_elastic_base_optimizer_state(self, all_state_dict): partition_states[key] = self._partition_base_optimizer_state(key, all_partition_states, i) base_optimizer_group_states.append(partition_states) - self._restore_base_optimizer_state(base_optimizer_group_states) - - # Restore step - if BASE_OPTIMIZER_STATE_STEP in all_state_dict[0]: - assert all(sd[BASE_OPTIMIZER_STATE_STEP] == all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] - for sd in all_state_dict), "State dicts of all partitions must have the same step value" - loaded_param_groups_step = all_state_dict[0][BASE_OPTIMIZER_STATE_STEP] - for param_group in self.optimizer.param_groups: - param_group['step'] = loaded_param_groups_step + self._restore_base_optimizer_state(base_optimizer_group_states, + self._restore_step_from_elastic_checkpoint(all_state_dict), None) def load_state_dict(self, state_dict_list, @@ -2368,7 +2396,9 @@ def _load_legacy_checkpoint(self, 
state_dict_list, load_optimizer_states=True, l self._restore_elastic_base_optimizer_state(state_dict_list) else: # loading an elastic checkpoint into rigid exec - self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE]) + self._restore_base_optimizer_state(current_rank_sd[BASE_OPTIMIZER_STATE], + current_rank_sd[BASE_OPTIMIZER_STATE_STEP], + current_rank_sd[GROUP_PADDINGS]) # At this point, the optimizer's references to the model's fp32 parameters are up to date. # The optimizer's hyperparameters and internal buffers are also up to date. diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 1f86306aefec..33ea8ba60818 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -17,6 +17,6 @@ from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_optimizer_state from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter -from .mixed_precision_linkage import link_hp_params +from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state from deepspeed.runtime.dataloader import RepeatingLoader from .numa import get_numactl_cmd diff --git a/deepspeed/utils/mixed_precision_linkage.py b/deepspeed/utils/mixed_precision_linkage.py index b1afa8f00aa3..7dea6ba322db 100644 --- a/deepspeed/utils/mixed_precision_linkage.py +++ b/deepspeed/utils/mixed_precision_linkage.py @@ -9,13 +9,19 @@ def link_hp_params(lp_param_list, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, - param_group_index, partition_start, partition_size, partition_optimizer_state, dp_group): + param_group_index, partition_start, partition_size, dp_group): local_lp_param_and_offset = _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group) for lp_param, lp_start in local_lp_param_and_offset: 
lp_param._hp_mapping = get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, param_group_index, - partition_start, partition_size, partition_optimizer_state) + partition_start, partition_size) + + +def lazy_init_hp_params_optimizer_state(lp_param_list, flat_hp_partition, optimizer_state): + for lp in lp_param_list: + if lp._hp_mapping is not None: + lp._hp_mapping.set_optim_state_fragment(flat_hp_partition, optimizer_state[flat_hp_partition]) def _init_lp_to_hp_mapping(lp_param_list, partition_start, partition_size, dp_group): diff --git a/deepspeed/utils/tensor_fragment.py b/deepspeed/utils/tensor_fragment.py index 5f94070dc4c7..49eefafcfbcc 100644 --- a/deepspeed/utils/tensor_fragment.py +++ b/deepspeed/utils/tensor_fragment.py @@ -21,11 +21,11 @@ class tensor_fragment: lp_fragment_address: fragment_address hp_fragment: torch.Tensor hp_fragment_address: fragment_address - optim_fragment: Dict gradient_dict: Dict offload_gradient_dict: Dict use_offload: bool param_group_index: int + optim_fragment: Dict = None def update_hp(self): self.hp_fragment.data.copy_(self.lp_fragment.data) @@ -39,6 +39,13 @@ def get_optim_state_fragment(self, key): else: raise ValueError(f'{key} not found in optimizer state fragment') + def set_optim_state_fragment(self, flat_hp_partition, optim_fragment): + self.optim_fragment = { + key: value.narrow(0, self.hp_fragment_address.start, self.hp_fragment_address.numel) + for key, value in optim_fragment.items() + if torch.is_tensor(value) and value.shape == flat_hp_partition.shape + } + def get_hp_fragment_address(self): return self.hp_fragment_address @@ -255,7 +262,7 @@ def safe_set_local_fp32_param(param, value): def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict, offload_gradient_dict, use_offload, - param_group_index, partition_start, partition_size, optimizer_state_dict): + param_group_index, partition_start, partition_size): lp_end = 
lp_param.numel() + lp_start hp_start = partition_start hp_end = partition_start + partition_size @@ -268,11 +275,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict fragment_numel = fragment_end - fragment_start hp_frag_address = fragment_address(start=fragment_start - hp_start, numel=fragment_numel) hp_fragment_tensor = flat_hp_partition.narrow(0, hp_frag_address.start, hp_frag_address.numel) - optim_fragment = { - key: value.narrow(0, hp_frag_address.start, hp_frag_address.numel) - for key, value in optimizer_state_dict.items() - if torch.is_tensor(value) and value.shape == flat_hp_partition.shape - } lp_frag_address = fragment_address(start=fragment_start - lp_start, numel=fragment_numel) lp_fragment_tensor = lp_param.flatten().narrow(0, lp_frag_address.start, lp_frag_address.numel) @@ -281,7 +283,6 @@ def get_hp_fragment_mapping(lp_param, lp_start, flat_hp_partition, gradient_dict lp_fragment_address=lp_frag_address, hp_fragment=hp_fragment_tensor, hp_fragment_address=hp_frag_address, - optim_fragment=optim_fragment, gradient_dict=gradient_dict, offload_gradient_dict=offload_gradient_dict, use_offload=use_offload, diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py index bc31e3b9a968..2594d910acff 100644 --- a/tests/unit/runtime/zero/test_zero.py +++ b/tests/unit/runtime/zero/test_zero.py @@ -1370,6 +1370,11 @@ class TestZeroAdamOptimizerStepCount(DistributedTest): world_size = 1 def test(self, zero_stage): + # We verify three conditions: + # 1. global_steps starts at 0 + # 2. All subgroups have the same step count + # 3.
The global step count is the same as the step count of the first subgroup + # force all params to be partitioned by forcing threshold=0 config_dict = { "train_micro_batch_size_per_gpu": 2, @@ -1399,24 +1404,31 @@ def test(self, zero_stage): model_parameters=model.parameters()) data_loader = random_dataloader(model=model, total_samples=16, hidden_dim=hidden_dim, device=model.device) - for i, batch in enumerate(data_loader): + assert model.global_steps == 0 + + for batch in data_loader: loss = model(batch[0], batch[1]) model.backward(loss) + + is_gradient_accumulation_boundary = model.is_gradient_accumulation_boundary() model.step() - step_counts = [] - if zero_stage == 3: - for sub_group_id, _ in enumerate(optimizer.fp16_groups): - fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id] - state = optimizer.optimizer.state[fp32_param] - step_counts.append(state["step"]) - assert all(step == step_counts[0] for step in step_counts) - elif zero_stage == 1 or zero_stage == 2: - for param_group in optimizer.optimizer.param_groups: - for param in param_group["params"]: - state = optimizer.optimizer.state[param] + if is_gradient_accumulation_boundary: + step_counts = [] + + if zero_stage == 3: + for sub_group_id, _ in enumerate(optimizer.fp16_groups): + fp32_param = optimizer.fp32_partitioned_groups_flat[sub_group_id] + state = optimizer.optimizer.state[fp32_param] step_counts.append(state["step"]) + elif zero_stage == 1 or zero_stage == 2: + for param_group in optimizer.optimizer.param_groups: + for param in param_group["params"]: + state = optimizer.optimizer.state[param] + step_counts.append(state["step"]) + assert all(step == step_counts[0] for step in step_counts) + assert model.global_steps == step_counts[0] @pytest.mark.parametrize("zero_stage", [1, 2, 3]) diff --git a/tests/unit/runtime/zero/test_zero_tensor_fragment.py b/tests/unit/runtime/zero/test_zero_tensor_fragment.py index c223e67af697..b3adfdf96c50 100644 --- 
a/tests/unit/runtime/zero/test_zero_tensor_fragment.py +++ b/tests/unit/runtime/zero/test_zero_tensor_fragment.py @@ -24,35 +24,26 @@ SECOND_ORDER_KEY = 'exp_avg_sq' -def validate_full_tensors(model): +def validate_tensor(model, api_type, opt_states): + assert api_type in ["full", "local"] for _, lp in model.named_parameters(): - hp = safe_get_full_fp32_param(lp) - exp_avg = safe_get_full_optimizer_state(lp, 'exp_avg') - exp_avg_sq = safe_get_full_optimizer_state(lp, 'exp_avg_sq') - hp_grad = safe_get_full_grad(lp) - param_list = [hp, hp_grad, exp_avg, exp_avg_sq] - if lp.requires_grad: - assert all([p is not None for p in param_list]) + param_list = [] + if opt_states: + param_list.append( + safe_get_full_optimizer_state(lp, 'exp_avg') if api_type == + "full" else safe_get_local_optimizer_state(lp, 'exp_avg')) + param_list.append( + safe_get_full_optimizer_state(lp, 'exp_avg_sq') if api_type == + "full" else safe_get_local_optimizer_state(lp, 'exp_avg_sq')) else: - assert all([p is None for p in param_list]) - - -def validate_local_tensors(model): - for _, lp in model.named_parameters(): - hp = safe_get_local_fp32_param(lp) - exp_avg = safe_get_local_optimizer_state(lp, 'exp_avg') - exp_avg_sq = safe_get_local_optimizer_state(lp, 'exp_avg_sq') - hp_grad = safe_get_local_grad(lp) - param_list = [hp, hp_grad, exp_avg, exp_avg_sq] + param_list.append(safe_get_full_fp32_param(lp) if api_type == "full" else safe_get_local_fp32_param(lp)) + param_list.append(safe_get_full_grad(lp) if api_type == "full" else safe_get_local_grad(lp)) if lp.requires_grad: assert all([p is not None for p in param_list]) else: assert all([p is None for p in param_list]) -validate_funcs_mapping = {"full": validate_full_tensors, "local": validate_local_tensors} - - class MyModel(torch.nn.Module): def __init__(self, hidden_dim, frozen_weights): @@ -71,12 +62,10 @@ def forward(self, x, y): for l in self.linears: x = l(x) x = self.act(x) - loss = self.cel(x, y) - val = (x, loss) - return val + 
return self.cel(x, y) -def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_func): +def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_after_bwd, validate_after_step): model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) data_loader = random_dataloader(model=model, total_samples=10, @@ -86,10 +75,10 @@ def run_fragmented_model(model, config_dict, hidden_dim, dtype, validate_func): dist.barrier() for n, batch in enumerate(data_loader): loss = model(batch[0], batch[1]) - loss = loss[1] model.backward(loss) - validate_func(model) + validate_after_bwd(model) model.step() + validate_after_step(model) # Needed in ZeRO 3. Not doing so can give memory leak model.destroy() @@ -147,9 +136,10 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, froz else: model = MyModel(hidden_dim, frozen_weights) - validate_func = validate_funcs_mapping[api_type] + validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False) + validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True) - run_fragmented_model(model, config_dict, hidden_dim, torch.float16, validate_func) + run_fragmented_model(model, config_dict, hidden_dim, torch.float16, validate_after_bwd, validate_after_step) def test_bf16_fragments(self, frozen_weights): if frozen_weights: @@ -178,7 +168,12 @@ def test_bf16_fragments(self, frozen_weights): hidden_dim = 128 model = MyModel(hidden_dim, frozen_weights) - run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16, validate_full_tensors) + + api_type = "full" + validate_after_bwd = lambda model: validate_tensor(model, api_type, opt_states=False) + validate_after_step = lambda model: validate_tensor(model, api_type, opt_states=True) + + run_fragmented_model(model, config_dict, hidden_dim, torch.bfloat16, validate_after_bwd, validate_after_step) def create_random_values(model, key_list, group, 
use_cuda=True): @@ -315,23 +310,21 @@ def test_zero_fragments(self, tmpdir, api_type, zero_stage, offload_device, dtyp if zero_stage == 3: config_dict["zero_optimization"]["param_persistence_threshold"] = hidden_dim with deepspeed.zero.Init(config_dict_or_path=config_dict): - model = SimpleModel(hidden_dim, nlayers=4) + model = SimpleModel(hidden_dim) else: - model = SimpleModel(hidden_dim, nlayers=4) + model = SimpleModel(hidden_dim) - model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict) world = dist.get_world_size() group = dist.new_group(ranks=list(range(world))) dist.barrier() - optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY] - helper_funcs = helper_funcs_mapping[api_type] - optim_state_values = helper_funcs["create_random_values"](model, - optim_keys, - group, - use_cuda=offload_device == OffloadDeviceEnum.none) - helper_funcs["set_param_values_with_dict"](model, optim_state_values) - helper_funcs["validate_param_values_with_dict"](model, optim_state_values) - - # Needed in ZeRO 3. Not doing so can leak memory. - model.destroy() + + def validate_func(model): + optim_keys = [WEIGHT_KEY, FIRST_ORDER_KEY, SECOND_ORDER_KEY] + helper_funcs = helper_funcs_mapping[api_type] + optim_state_values = helper_funcs["create_random_values"]( + model, optim_keys, group, use_cuda=offload_device == OffloadDeviceEnum.none) + helper_funcs["set_param_values_with_dict"](model, optim_state_values) + helper_funcs["validate_param_values_with_dict"](model, optim_state_values) + + run_fragmented_model(model, config_dict, hidden_dim, dtype, lambda _: None, validate_func)