Improve z3 trace management (#1916)
* Fix OOM and type mismatch

* Toggle prefetching

* Disable z3 prefetching for inference (temp workaround)

* Fix zero3 tracing issues

* Remove debug prints

* Enable prefetch for inference

* Code clarity

* Invalidate trace cache

* Trace cache invalidation when needed
Separate nvme prefetch from all-gather prefetch

* Track last used step id

* Use debug name in error message

* Construct param trace from module trace

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
tjruwase and jeffra authored May 6, 2022
1 parent a3b9003 commit 673cb60
Showing 3 changed files with 81 additions and 21 deletions.
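In outline, this commit turns the coordinator's trace handling into a small state machine: a trace is recorded during one pass, frozen as complete, and invalidated (then re-recorded) when execution diverges from it. The following is a minimal sketch of that lifecycle, not the real coordinator; names like TraceStateSketch, on_module, and end_of_step are hypothetical, and all fetch/prefetch logic is omitted.

from enum import Enum

class ZeRoTraceMode(Enum):
    RECORD = 1    # record module order during the first forward/backward pass
    COMPLETE = 2  # a full trace exists; use it to drive prefetching
    INVALID = 3   # observed execution diverged from the recorded trace

class TraceStateSketch:
    """Hypothetical reduction of the coordinator's trace state handling."""

    def __init__(self):
        self.mode = ZeRoTraceMode.RECORD
        self.submodule_order = []

    def on_module(self, module_id, step_id):
        if self.mode == ZeRoTraceMode.COMPLETE:
            # Any mismatch with the recorded order invalidates the trace cache.
            if self.submodule_order[step_id] != module_id:
                self.mode = ZeRoTraceMode.INVALID
                self.submodule_order = []
        elif self.mode == ZeRoTraceMode.RECORD:
            self.submodule_order.append(module_id)

    def end_of_step(self):
        if self.mode == ZeRoTraceMode.RECORD:
            self.mode = ZeRoTraceMode.COMPLETE  # freeze the recorded trace
        elif self.mode == ZeRoTraceMode.INVALID:
            self.mode = ZeRoTraceMode.RECORD    # re-record on the next pass

The RECORD → COMPLETE → INVALID → RECORD cycle sketched here is exactly what the diffs below implement.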
13 changes: 10 additions & 3 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -27,7 +27,14 @@
 from ..utils import get_only_unique_item, see_memory_usage
 from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks
 from deepspeed.utils import init_distributed, instrument_w_nvtx, logger
-from deepspeed.utils.debug import debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name, debug_param2name, debug_param2name_id_shape_status, printflock, log_rank_file
+from deepspeed.utils.debug import (debug_param2name_id_shape,
+                                   debug_param2name_id_shape_device,
+                                   debug_module2name,
+                                   debug_param2name,
+                                   debug_param2name_id,
+                                   debug_param2name_id_shape_status,
+                                   printflock,
+                                   log_rank_file)
 from deepspeed.utils.logging import logger

 from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
@@ -937,9 +944,9 @@ def item_override():
        param.all_gather()
        return param._orig_item()

-    def ds_summary(slf: torch.Tensor) -> dict:
+    def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict:
        return {
-            "id": slf.ds_id,
+            "id": debug_param2name_id(slf) if use_debug_name else slf.ds_id,
            "status": slf.ds_status.name,
            "numel": slf.numel(),
            "ds_numel": slf.ds_numel,
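The new use_debug_name flag threads the human-readable parameter name into the "id" field of ds_summary() output, which the coordinator below uses for clearer tracing-error messages. A self-contained, illustrative stand-in (ds_summary_sketch and the debug-name string format are assumptions, not the real API):

# Illustrative only: shows how the new flag changes the "id" field.
# Real code would call param.ds_summary(use_debug_name=True) on a ZeRO-3 param.
def ds_summary_sketch(ds_id, debug_name, status, use_debug_name=False):
    return {
        "id": debug_name if use_debug_name else ds_id,
        "status": status,
    }

assert ds_summary_sketch(17, "name=h.0.attn.weight id=17", "AVAILABLE")["id"] == 17
assert ds_summary_sketch(17, "name=h.0.attn.weight id=17", "AVAILABLE",
                         use_debug_name=True)["id"].startswith("name=")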
86 changes: 71 additions & 15 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -16,6 +16,7 @@
 from deepspeed.runtime.zero.partition_parameters import *
 from deepspeed.runtime.zero.offload_constants import *
 from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
+from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id


 def debug_rank0(message: str) -> None:
@@ -33,7 +34,7 @@ def iter_params(module: Module, recurse=False) -> Iterable[Parameter]:
     return map(lambda pair: pair[1], get_all_parameters(module, recurse))


-class TraceMode(Enum):
+class ZeRoTraceMode(Enum):
     # Record trace of the network during a single forward+backward (for training) or forward (for inference)
     RECORD = 1
     # Use recorded network trace to optimize current forward+backward or forward
@@ -75,12 +76,14 @@ def __init__(
         # keeps track of the number of submodules invoked so far.
         self.__step_id: int = 0
         # network tracing mode
-        self.__trace_mode: TraceMode = TraceMode.RECORD
+        self.__trace_mode: ZeRoTraceMode = ZeRoTraceMode.RECORD
         # sequence of submodules/parameters in forward pass + backward pass
         self.__submodule_order: Iterable[Module] = []
         self.__param_order: Iterable[__class__.__ParamInTrace] = []
         self.__most_recent_step_id_param_fetched_for = collections.defaultdict(
             lambda: int(-1e10))
+        self.__step_id_module_fetched_for = collections.defaultdict(
+            lambda: collections.deque())
         # number of available params, and max number of available params
         self.__n_available_params: int = 0
         self.__max_n_available_params: int = max_available_parameters_in_numel
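The two defaultdicts above carry the trace bookkeeping: a large negative sentinel that reads as "never fetched", and a per-module queue of the step ids at which the module was recorded. A small, self-contained illustration of how they behave (variable names and values are made up):

import collections

# Sentinel of -1e10 reads as "fetched infinitely long ago / never":
most_recent_step_fetched = collections.defaultdict(lambda: int(-1e10))
# Each module id maps to a queue of the step ids at which it was recorded,
# so modules invoked more than once per pass keep one entry per invocation:
step_ids_module_fetched_for = collections.defaultdict(collections.deque)

step_ids_module_fetched_for[7].append(3)   # module 7 first runs at step 3
step_ids_module_fetched_for[7].append(12)  # ...and runs again at step 12
assert most_recent_step_fetched["unseen-param"] == int(-1e10)
assert step_ids_module_fetched_for[7].popleft() == 3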
@@ -126,42 +129,57 @@ def _clear_trace_structures(self) -> None:
         self.__param_queue = None

     def is_complete_trace(self) -> bool:
-        return self.__trace_mode == TraceMode.COMPLETE
+        return self.__trace_mode == ZeRoTraceMode.COMPLETE

     def is_invalid_trace(self) -> bool:
-        return self.__trace_mode == TraceMode.INVALID
+        return self.__trace_mode == ZeRoTraceMode.INVALID

     def is_record_trace(self) -> bool:
-        return self.__trace_mode == TraceMode.RECORD
+        return self.__trace_mode == ZeRoTraceMode.RECORD

     def _invalidate_trace(self) -> None:
         if self.is_invalid_trace():
             raise RuntimeError("attempted to invalidate already invalid trace")
-        self.__trace_mode = TraceMode.INVALID
+        self.__trace_mode = ZeRoTraceMode.INVALID
         self._clear_trace_structures()

     def trace_prologue(self, sub_module: Module) -> None:
         if self.is_complete_trace():
             # sub_module must match expectation else invalidate trace cache
             if sub_module != self.__submodule_order[self.__step_id]:
+                expected_module_id = self.__submodule_order[self.__step_id].id
+                debug_rank0(
+                    f"Invalidate trace cache @ step {self.__step_id}: "
+                    f"expected module {expected_module_id}, but got module {sub_module.id}"
+                )
                 self._invalidate_trace()

     def record_module(self, sub_module: Module) -> None:
         """adds sub module to trace"""
         if not self.is_record_trace():
             raise RuntimeError(
                 f"attempted to record trace when status = {self.__trace_mode}")

         self.__submodule_order.append(sub_module)
+        self.__step_id_module_fetched_for[sub_module.id].append(self.__step_id)

     def record_parameters(self, sub_module: Module) -> None:
         """adds sub module to trace"""
         if not self.is_record_trace():
             raise RuntimeError(
                 f"attempted to record trace when status = {self.__trace_mode}")

+        step_id = self.__step_id_module_fetched_for[sub_module.id].popleft()
         for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id):
             self.__param_order.append(
                 __class__.__ParamInTrace(param=param,
-                                         step_id_last_used_at=self.__step_id))
+                                         step_id_last_used_at=step_id))

+    def construct_parameter_trace_from_module_trace(self):
+        """use module trace to construct parameter trace"""
+        self.__param_order = []
+        for sub_module in self.__submodule_order:
+            self.record_parameters(sub_module)
+
     def reset_step(self) -> None:
         """indicate that we have completed one fwd+bwd for the model"""
@@ -180,22 +198,38 @@ def reset_step(self) -> None:

         if self.is_record_trace():
             # Successfully recorded a trace
+            self.construct_parameter_trace_from_module_trace()
             self.__submodule_order = tuple(self.__submodule_order) # freeze
             self.__param_order = tuple(self.__param_order) # freeze
-            self.__trace_mode = TraceMode.COMPLETE # self.trace_complete = True
+            self.__trace_mode = ZeRoTraceMode.COMPLETE
             print_rank_0(
-                f"completed trace: {[m.id for m in self.__submodule_order]}",
+                f"completed record trace: {[m.id for m in self.__submodule_order]}",
                 force=False)
         else:
             # Enable trace recording for next forward/backward pass
-            self.__trace_mode = TraceMode.RECORD
+            self.__trace_mode = ZeRoTraceMode.RECORD

         self.__param_queue = collections.deque(self.__param_order) # reset fetch queue
         self.__most_recent_step_id_param_fetched_for = collections.defaultdict(
             lambda: int(-1e10))
+        self.__step_id_module_fetched_for = collections.defaultdict(
+            lambda: collections.deque())
         self.__step_id = 0
         self.__n_available_params = 0

+    def _dump_params(self, tag, sub_module, params, step_id=None):
+        if step_id is None:
+            step_id = self.__step_id
+        param_names = [debug_param2name_id(p) for p in params]
+        print(
+            f'{tag} step = {step_id} mod = {debug_module2name_id(sub_module)} p_names = {param_names}'
+        )
+
+    def _dump_param_ids(self, tag, mod_id, p_ids, step_id=None):
+        if step_id is None:
+            step_id = self.__step_id
+        print(f'{tag} mod = {mod_id}, step = {step_id}, p_ids = {p_ids}')
+
     """Fetch and Release
     Fetching, prefetching, and releasing parameters
     """
@@ -264,15 +298,23 @@ def fetch_sub_module(self, current_submodule: Module) -> None:
                 self.__most_recent_step_id_param_fetched_for[
                     param_in_trace.param] = param_in_trace.step_id_last_used_at
                 discarded_from_prefetch_queue.add(param_in_trace.param)

             if discarded_from_prefetch_queue != params_not_already_fetched:
                 raise RuntimeError(
                     f"tracing error at step {self.__step_id}: \n"
+                    f"module id: {current_submodule.id}, training: {current_submodule.training}\n"
                     f"expected the next {len(params_not_already_fetched)} parameters in the "
-                    f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} \n"
-                    f"but got \n {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}."
+                    f"parameter fetch queue to be {tuple(p.ds_summary(use_debug_name=True) for p in params_not_already_fetched)} \n"
+                    f"but got \n {tuple(p.ds_summary(use_debug_name=True) for p in discarded_from_prefetch_queue)}."
                 )

+        def _is_currently_on_nvme(param):
+            if param.nvme_swapper is None:
+                return False
+
+            return param.ds_tensor.final_location == OFFLOAD_NVME_DEVICE \
+                and param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE
+
         # kick off all gather for params in the next few submodules (prefetch)
         if self.__prefetch_bucket_sz > 0:
             max_params_to_prefetch = min(
@@ -283,11 +325,25 @@ def fetch_sub_module(self, current_submodule: Module) -> None:
             while self.__param_queue and numel_prefetching < max_params_to_prefetch:
                 param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft(
                 )
-                self.__most_recent_step_id_param_fetched_for[
-                    param_in_trace.param] = param_in_trace.step_id_last_used_at
-                if param_in_trace.param not in params_to_prefetch:
+
+                if _is_currently_on_nvme(param_in_trace.param):
+                    # nvme prefetch is handled elsewhere. Need to break here to preserve fetch order
+                    self.__param_queue.appendleft(param_in_trace)
+                    break
+
+                do_prefetch = param_in_trace.param.ds_status == ZeroParamStatus.NOT_AVAILABLE
+                if param_in_trace.param in params_to_prefetch:
+                    # Avoid duplicates
+                    do_prefetch = False
+
+                self.__most_recent_step_id_param_fetched_for[param_in_trace.param] = \
+                    max(self.__most_recent_step_id_param_fetched_for[param_in_trace.param],
+                        param_in_trace.step_id_last_used_at)
+
+                if do_prefetch:
                     params_to_prefetch.add(param_in_trace.param)
                     numel_prefetching += param_in_trace.param.ds_numel

         for param in params_to_prefetch:
             debug_rank0(f"-prefetch: {param.ds_summary()}")
         self.__all_gather_params(params_to_prefetch)
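Separating NVMe prefetch from all-gather prefetch means the loop above must stop at the first NVMe-resident parameter rather than skip past it; otherwise later all-gathers would be issued out of trace order. A hedged distillation of that control flow (function and parameter names are illustrative, not DeepSpeed's API):

import collections

def select_allgather_prefetch(param_queue, numel_budget, on_nvme, numel_of,
                              already_available):
    """Hypothetical distillation of the prefetch selection above.

    param_queue:       deque of parameter ids in trace order
    numel_budget:      max total elements to prefetch
    on_nvme:           predicate, True if the param is still swapped to NVMe
    numel_of:          param id -> element count
    already_available: params that need no all-gather
    """
    to_prefetch, numel = [], 0
    while param_queue and numel < numel_budget:
        p = param_queue.popleft()
        if on_nvme(p):
            # NVMe prefetch happens elsewhere; stop here so the all-gather
            # order still matches the recorded trace.
            param_queue.appendleft(p)
            break
        if p not in already_available and p not in to_prefetch:
            to_prefetch.append(p)
            numel += numel_of(p)
    return to_prefetch

queue = collections.deque(["w1", "w2_nvme", "w3"])
picked = select_allgather_prefetch(
    queue, numel_budget=10**6,
    on_nvme=lambda p: p.endswith("_nvme"),
    numel_of=lambda p: 1000,
    already_available=set(),
)
assert picked == ["w1"] and queue[0] == "w2_nvme"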
3 changes: 0 additions & 3 deletions deepspeed/runtime/zero/stage3.py
@@ -1222,8 +1222,6 @@ def post_sub_module_forward_function(self, sub_module):
             force=False)

         param_coordinator = self._get_param_coordinator(training=sub_module.training)
-        if param_coordinator.is_record_trace():
-            param_coordinator.record_parameters(sub_module)
         param_coordinator.release_sub_module(sub_module)

         see_memory_usage(
@@ -1236,7 +1234,6 @@ def pre_sub_module_backward_function(self, sub_module):
         param_coordinator.trace_prologue(sub_module)
         if param_coordinator.is_record_trace():
             param_coordinator.record_module(sub_module)
-            param_coordinator.record_parameters(sub_module)
         param_coordinator.fetch_sub_module(sub_module)

     @torch.no_grad()
