From 673cb60808a23f1f925bc445047720df3f0ff689 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Fri, 6 May 2022 05:28:30 -0700 Subject: [PATCH] Improve z3 trace management (#1916) * Fix OOM and type mismatch * Toggle prefetching * Disable z3 prefetching for inference (temp workaround) * Fix zero3 tracing issues * Remove debug prints * Enable prefetch for inference * Code clarity * Invalidate trace cache * Trace cache invalidation when needed Separate nvme prefetch from all-gather prefetch * Track last used step id * Use debug name in error message * Construct param trace from module trace Co-authored-by: Jeff Rasley --- .../runtime/zero/partition_parameters.py | 13 ++- .../zero/partitioned_param_coordinator.py | 86 +++++++++++++++---- deepspeed/runtime/zero/stage3.py | 3 - 3 files changed, 81 insertions(+), 21 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 80b1ee34bcec..daaf0d1dd3f1 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -27,7 +27,14 @@ from ..utils import get_only_unique_item, see_memory_usage from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks from deepspeed.utils import init_distributed, instrument_w_nvtx, logger -from deepspeed.utils.debug import debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name, debug_param2name, debug_param2name_id_shape_status, printflock, log_rank_file +from deepspeed.utils.debug import (debug_param2name_id_shape, + debug_param2name_id_shape_device, + debug_module2name, + debug_param2name, + debug_param2name_id, + debug_param2name_id_shape_status, + printflock, + log_rank_file) from deepspeed.utils.logging import logger from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus @@ -937,9 +944,9 @@ def item_override(): param.all_gather() return param._orig_item() - def ds_summary(slf: torch.Tensor) -> dict: + def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict: return { - "id": slf.ds_id, + "id": debug_param2name_id(slf) if use_debug_name else slf.ds_id, "status": slf.ds_status.name, "numel": slf.numel(), "ds_numel": slf.ds_numel, diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 4958dcf3c3a4..e4064dd03d3e 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -16,6 +16,7 @@ from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.offload_constants import * from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus +from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id def debug_rank0(message: str) -> None: @@ -33,7 +34,7 @@ def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: return map(lambda pair: pair[1], get_all_parameters(module, recurse)) -class TraceMode(Enum): +class ZeRoTraceMode(Enum): # Record trace of the network during a single forward+backward (for training) or forward (for inference) RECORD = 1 # Use recorded network trace to optimize current forward+backward or forward @@ -75,12 +76,14 @@ def __init__( # keeps track of the number of submodules invoked so far. self.__step_id: int = 0 # network tracing mode - self.__trace_mode: TraceMode = TraceMode.RECORD + self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD # sequence of submodules/parameters in forward pass + backward pass self.__submodule_order: Iterable[Module] = [] self.__param_order: Iterable[__class__.__ParamInTrace] = [] self.__most_recent_step_id_param_fetched_for = collections.defaultdict( lambda: int(-1e10)) + self.__step_id_module_fetched_for = collections.defaultdict( + lambda: collections.deque()) # number of available params, and max number of available params self.__n_available_params: int = 0 self.__max_n_available_params: int = max_available_parameters_in_numel @@ -126,24 +129,29 @@ def _clear_trace_structures(self) -> None: self.__param_queue = None def is_complete_trace(self) -> bool: - return self.__trace_mode == TraceMode.COMPLETE + return self.__trace_mode == ZeRoTraceMode.COMPLETE def is_invalid_trace(self) -> bool: - return self.__trace_mode == TraceMode.INVALID + return self.__trace_mode == ZeRoTraceMode.INVALID def is_record_trace(self) -> bool: - return self.__trace_mode == TraceMode.RECORD + return self.__trace_mode == ZeRoTraceMode.RECORD def _invalidate_trace(self) -> None: if self.is_invalid_trace(): raise RuntimeError("attempted to invalidate already invalid trace") - self.__trace_mode = TraceMode.INVALID + self.__trace_mode = ZeRoTraceMode.INVALID self._clear_trace_structures() def trace_prologue(self, sub_module: Module) -> None: if self.is_complete_trace(): # sub_module must match expectation else invalidate trace cache if sub_module != self.__submodule_order[self.__step_id]: + expected_module_id = self.__submodule_order[self.__step_id].id + debug_rank0( + f"Invalidate trace cache @ step {self.__step_id}: " + f"expected module {expected_module_id}, but got module {sub_module.id}" + ) self._invalidate_trace() def record_module(self, sub_module: Module) -> None: @@ -151,17 +159,27 @@ def record_module(self, sub_module: Module) -> None: if not self.is_record_trace(): raise RuntimeError( f"attempted to record trace when status = {self.__trace_mode}") + self.__submodule_order.append(sub_module) + self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id) def record_parameters(self, sub_module: Module) -> None: """adds sub module to trace""" if not self.is_record_trace(): raise RuntimeError( f"attempted to record trace when status = {self.__trace_mode}") + + step_id = self.__step_id_module_fetched_for[sub_module.id].popleft() for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): self.__param_order.append( __class__.__ParamInTrace(param=param, - step_id_last_used_at=self.__step_id)) + step_id_last_used_at=step_id)) + + def construct_parameter_trace_from_module_trace(self): + """use module trace to construct parameter trace""" + self.__param_order = [] + for sub_module in self.__submodule_order: + self.record_parameters(sub_module) def reset_step(self) -> None: """indicate that we have completed one fwd+bwd for the model""" @@ -180,22 +198,38 @@ def reset_step(self) -> None: if self.is_record_trace(): # Successfully recorded a trace + self.construct_parameter_trace_from_module_trace() self.__submodule_order = tuple(self.__submodule_order) # freeze self.__param_order = tuple(self.__param_order) # freeze - self.__trace_mode = TraceMode.COMPLETE # self.trace_complete = True + self.__trace_mode = ZeRoTraceMode.COMPLETE print_rank_0( - f"completed trace: {[m.id for m in self.__submodule_order]}", + f"completed record trace: {[m.id for m in self.__submodule_order]}", force=False) else: # Enable trace recording for next forward/backward pass - self.__trace_mode = TraceMode.RECORD + self.__trace_mode = ZeRoTraceMode.RECORD self.__param_queue = collections.deque(self.__param_order) # reset fetch queue self.__most_recent_step_id_param_fetched_for = collections.defaultdict( lambda: int(-1e10)) + self.__step_id_module_fetched_for = collections.defaultdict( + lambda: collections.deque()) self.__step_id = 0 self.__n_available_params = 0 + def _dump_params(self, tag, sub_module, params, step_id=None): + if step_id is None: + step_id = self.__step_id + param_names = [debug_param2name_id(p) for p in params] + print( + f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}' + ) + + def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None): + if step_id is None: + step_id = self.__step_id + print(f'{tag} mod = {mod_id}, step = {step_id}, p_ids = {p_ids}') + """Fetch and Release Fetching, prefetching, and releasing parameters """ @@ -264,15 +298,23 @@ def fetch_sub_module(self, current_submodule: Module) -> None: self.__most_recent_step_id_param_fetched_for[ param_in_trace.param] = param_in_trace.step_id_last_used_at discarded_from_prefetch_queue.add(param_in_trace.param) + if discarded_from_prefetch_queue != params_not_already_fetched: raise RuntimeError( f"tracing error at step {self.__step_id}: \n" f"module id: {current_submodule.id}, training: {current_submodule.training}\n" f"expected the next {len(params_not_already_fetched)} parameters in the " - f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} \n" - f"but got \n {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." + f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n" + f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}." ) + def _is_currently_on_nvme(param): + if param.nvme_swapper is None: + return False + + return param.ds_tensor.final_location == OFFLOAD_NVME_DEVICE \ + and param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE + # kick off all gather for params in the next few submodules (prefetch) if self.__prefetch_bucket_sz > 0: max_params_to_prefetch = min( @@ -283,11 +325,25 @@ def fetch_sub_module(self, current_submodule: Module) -> None: while self.__param_queue and numel_prefetching < max_params_to_prefetch: param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft( ) - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - if param_in_trace.param not in params_to_prefetch: + + if _is_currently_on_nvme(param_in_trace.param): + # nvme prefetch is handled elsewhere. Need to break here to preserve fetch order + self.__param_queue.appendleft(param_in_trace) + break + + do_prefetch = param_in_trace.param.ds_status == ZeroParamStatus.NOT_AVAILABLE + if param_in_trace.param in params_to_prefetch: + # Avoid duplicates + do_prefetch = False + + self.__most_recent_step_id_param_fetched_for[param_in_trace.param] = \ + max(self.__most_recent_step_id_param_fetched_for[param_in_trace.param], + param_in_trace.step_id_last_used_at) + + if do_prefetch: params_to_prefetch.add(param_in_trace.param) numel_prefetching += param_in_trace.param.ds_numel + for param in params_to_prefetch: debug_rank0(f"-prefetch: {param.ds_summary()}") self.__all_gather_params(params_to_prefetch) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 6796beee22d6..6212a1993db8 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1222,8 +1222,6 @@ def post_sub_module_forward_function(self, sub_module): force=False) param_coordinator = self._get_param_coordinator(training=sub_module.training) - if param_coordinator.is_record_trace(): - param_coordinator.record_parameters(sub_module) param_coordinator.release_sub_module(sub_module) see_memory_usage( @@ -1236,7 +1234,6 @@ def pre_sub_module_backward_function(self, sub_module): param_coordinator.trace_prologue(sub_module) if param_coordinator.is_record_trace(): param_coordinator.record_module(sub_module) - param_coordinator.record_parameters(sub_module) param_coordinator.fetch_sub_module(sub_module) @torch.no_grad()