
Commit

Fix OOM and type mismatch (#1884)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
tjruwase and jeffra authored Apr 25, 2022
1 parent 4575b2b commit 32d9797
Showing 4 changed files with 164 additions and 120 deletions.
17 changes: 11 additions & 6 deletions deepspeed/runtime/swap_tensor/partitioned_param_swapper.py
@@ -6,6 +6,7 @@
"""

import os
import shutil
from enum import Enum
import torch
import torch.distributed as dist
@@ -34,10 +35,11 @@ class PartitionedParamStatus(Enum):


class AsyncPartitionedParameterSwapper(object):
def __init__(self, ds_config):
def __init__(self, ds_config, model_dtype):

aio_op = AsyncIOBuilder().load(verbose=False)
self.aio_handle = aio_op.aio_handle
self.dtype = model_dtype

#set swap buffers, create aio handles
self._configure_aio(ds_config)
@@ -83,13 +85,15 @@ def available_swap_in_buffers(self):

def _configure_aio(self, ds_config):
self.swap_config = ds_config.zero_config.offload_param
torch_dtype_string = str(self.dtype).split(".")[1]
self.swap_folder = os.path.join(self.swap_config[OFFLOAD_PARAM_NVME_PATH],
'zero_stage_3',
'fp16params',
f'{torch_dtype_string}params',
f'rank{dist.get_rank()}')
shutil.rmtree(self.swap_folder, ignore_errors=True)
os.makedirs(self.swap_folder, exist_ok=True)

self.swap_element_size = torch.tensor([], dtype=torch.half).element_size()
self.swap_element_size = torch.tensor([], dtype=self.dtype).element_size()

self.aio_config = ds_config.aio_config

@@ -107,7 +111,7 @@ def _configure_aio(self, ds_config):
self.reserved_buffer_ids = []
self.buffers = torch.empty(int(self.aligned_elements_per_buffer *
self.param_buffer_count),
dtype=torch.half,
dtype=self.dtype,
pin_memory=True,
requires_grad=False)

@@ -293,8 +297,9 @@ def swap_in(self, params, async_op=True, swap_in_buffers=None):

if swap_in_buffers is None:
if len(self.available_buffer_ids) < len(swap_in_paths):
ids = [p.ds_id for p in params]
print_rank_0(
f'Not enough swap in buffers {len(self.available_buffer_ids)} for params {len(swap_in_paths)}',
f'Not enough swap in buffers {len(self.available_buffer_ids)} for {len(swap_in_paths)} params, ids = {ids}',
force=True)
print_rank_0(
f'Num inflight: params {len(self.inflight_params)}, buffers {len(self.inflight_swap_in_buffers)}, numel = {self.inflight_numel}',
@@ -392,7 +397,7 @@ def reserve_partitioned_swap_space(self, partition_num_elems):
[self._io_aligned_numel(numel) for numel in partition_num_elems])
self.partitioned_swap_buffer = torch.zeros(aligned_numel,
device='cpu',
dtype=torch.half).pin_memory()
dtype=self.dtype).pin_memory()
self.partitioned_swap_pool = SwapBufferPool([self.partitioned_swap_buffer])

def swap_out_partitioned_params(self, dst_fp16_params, src_fp32_params):
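
The net effect of the changes in this file is that the swapper is constructed with the model dtype and no longer hard-codes torch.half for its swap folder, pinned buffers, or element-size math. A minimal standalone sketch of the dtype-derived layout follows; the swap_layout_for helper is hypothetical, not DeepSpeed API, and only mirrors the logic added above.

import torch

# Hypothetical helper (not part of DeepSpeed): derive the swap-folder suffix
# and the per-element byte size from the model dtype instead of assuming fp16.
def swap_layout_for(model_dtype: torch.dtype, nvme_path: str, rank: int):
    dtype_string = str(model_dtype).split(".")[1]   # "torch.bfloat16" -> "bfloat16"
    swap_folder = f"{nvme_path}/zero_stage_3/{dtype_string}params/rank{rank}"
    element_size = torch.tensor([], dtype=model_dtype).element_size()
    return swap_folder, element_size

print(swap_layout_for(torch.bfloat16, "/local_nvme", 0))
# ('/local_nvme/zero_stage_3/bfloat16params/rank0', 2)
print(swap_layout_for(torch.float32, "/local_nvme", 0))
# ('/local_nvme/zero_stage_3/float32params/rank0', 4)
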
3 changes: 1 addition & 2 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -699,7 +699,7 @@ def get_model():

# Enable fp16 param swapping to NVMe
if self.remote_device == OFFLOAD_NVME_DEVICE:
self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config)
self.param_swapper = AsyncPartitionedParameterSwapper(_ds_config, self.dtype)
else:
self.param_swapper = None

@@ -877,7 +877,6 @@ def all_gather_coalesced(params: Iterable[Parameter],
instrument_w_nvtx(torch.cat)(
[p.ds_tensor.to(torch.cuda.current_device()) for p in params],
out=partitions[self.rank])

handle = torch_allgather_fn(partitions[self.rank],
flat_tensor,
self.ds_process_group)
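
Combined with the swapper changes above, the Init context now passes its model dtype into AsyncPartitionedParameterSwapper. One plausible reading of the "type mismatch" in the commit title, shown as a standalone, assumption-labeled sketch (plain PyTorch, not DeepSpeed code):

import torch

# Assumption for illustration: copying a bf16 parameter through a buffer that
# is hard-coded to torch.half silently casts it, so later consumers see the
# wrong dtype; allocating the buffer from the model dtype keeps the round
# trip exact, which is what threading self.dtype through enables.
param = torch.randn(8, dtype=torch.bfloat16)

fp16_buffer = torch.empty(8, dtype=torch.half)
fp16_buffer.copy_(param)             # implicit bf16 -> fp16 cast
print(fp16_buffer.dtype)             # torch.float16

typed_buffer = torch.empty(8, dtype=param.dtype)
typed_buffer.copy_(param)            # dtype preserved
print(typed_buffer.dtype)            # torch.bfloat16
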
125 changes: 86 additions & 39 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
@@ -33,6 +33,15 @@ def iter_params(module: Module, recurse=False) -> Iterable[Parameter]:
return map(lambda pair: pair[1], get_all_parameters(module, recurse))


class TraceMode(Enum):
# Record trace of the network during a single forward+backward (for training) or forward (for inference)
RECORD = 1
# Use recorded network trace to optimize current forward+backward or forward
COMPLETE = 2
# Recorded trace does not match current forward+backward or forward pass.
INVALID = 3


class PartitionedParameterCoordinator:
"""Handles partitioning and gathering of parameters."""
class __InflightParamRegistry(UserDict):
@@ -65,9 +74,8 @@ def __init__(
self.__inflight_param_registry = __class__.__InflightParamRegistry()
# keeps track of the number of submodules invoked so far.
self.__step_id: int = 0
# whether or not we have completed a trace of the entire network. This should
# always be true after the first forward pass + backward pass.
self.trace_complete: bool = False
# network tracing mode
self.__trace_mode: TraceMode = TraceMode.RECORD
# sequence of submodules/parameters in forward pass + backward pass
self.__submodule_order: Iterable[Module] = []
self.__param_order: Iterable[__class__.__ParamInTrace] = []
@@ -110,13 +118,46 @@ def __init__(
Bookkeeping operations used to track where we are in the forward/backward pass
"""

def record_trace(self, sub_module: Module) -> None:
def _clear_trace_structures(self) -> None:
self.__submodule_order = []
self.__param_order = []
self.__most_recent_step_id_param_fetched_for = collections.defaultdict(
lambda: int(-1e10))
self.__param_queue = None

def is_complete_trace(self) -> bool:
return self.__trace_mode == TraceMode.COMPLETE

def is_invalid_trace(self) -> bool:
return self.__trace_mode == TraceMode.INVALID

def is_record_trace(self) -> bool:
return self.__trace_mode == TraceMode.RECORD

def _invalidate_trace(self) -> None:
if self.is_invalid_trace():
raise RuntimeError("attempted to invalidate already invalid trace")
self.__trace_mode = TraceMode.INVALID
self._clear_trace_structures()

def trace_prologue(self, sub_module: Module) -> None:
if self.is_complete_trace():
# sub_module must match expectation else invalidate trace cache
if sub_module != self.__submodule_order[self.__step_id]:
self._invalidate_trace()

def record_module(self, sub_module: Module) -> None:
"""adds sub module to trace"""
if self.trace_complete:
if not self.is_record_trace():
raise RuntimeError(
"attempted to record trace when trace was already complete")

f"attempted to record trace when status = {self.__trace_mode}")
self.__submodule_order.append(sub_module)

def record_parameters(self, sub_module: Module) -> None:
"""adds sub module to trace"""
if not self.is_record_trace():
raise RuntimeError(
f"attempted to record trace when status = {self.__trace_mode}")
for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id):
self.__param_order.append(
__class__.__ParamInTrace(param=param,
@@ -129,19 +170,25 @@ def reset_step(self) -> None:
f"still have inflight params "
f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}")

if not self.trace_complete:
# make sure that recorded parameter and submodule orders are
if not self.is_complete_trace(): # not self.trace_complete:
# Make sure that recorded parameter and submodule orders are
# identical across ranks
assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order])
assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order])
assert_ints_same_as_other_ranks(
[p.step_id_last_used_at for p in self.__param_order])

self.__submodule_order = tuple(self.__submodule_order) # freeze
self.__param_order = tuple(self.__param_order) # freeze
self.trace_complete = True
print_rank_0(f"completed trace: {[m.id for m in self.__submodule_order]}",
force=False)
if self.is_record_trace():
# Successfully recorded a trace
self.__submodule_order = tuple(self.__submodule_order) # freeze
self.__param_order = tuple(self.__param_order) # freeze
self.__trace_mode = TraceMode.COMPLETE # self.trace_complete = True
print_rank_0(
f"completed trace: {[m.id for m in self.__submodule_order]}",
force=False)
else:
# Enable trace recording for next forward/backward pass
self.__trace_mode = TraceMode.RECORD

self.__param_queue = collections.deque(self.__param_order) # reset fetch queue
self.__most_recent_step_id_param_fetched_for = collections.defaultdict(
@@ -199,9 +246,8 @@ def fetch_sub_module(self, current_submodule: Module) -> None:
torch.cuda.current_stream().wait_stream(self.__allgather_stream)

# kick off parameter prefetches for upcoming modules
# don't prefetch if we dont have a completed model trace, or if we aren't
# training (throws off the tracing and don't want to prefetch modules for bwd)
if self.trace_complete and current_submodule.training:
# don't prefetch if we dont have a completed model trace
if self.is_complete_trace():
# go through the parameters we need for the current module and pop them
# off the fetch queue so that they aren't prefetched later.
# if params have already been popped off the fetch queue by earlier
@@ -228,24 +274,26 @@ def fetch_sub_module(self, current_submodule: Module) -> None:
)

# kick off all gather for params in the next few submodules (prefetch)
max_params_to_prefetch = min(
self.__max_n_available_params - self.__n_available_params,
self.__prefetch_bucket_sz)
params_to_prefetch = set()
numel_prefetching = 0
while self.__param_queue and numel_prefetching < max_params_to_prefetch:
param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft()
self.__most_recent_step_id_param_fetched_for[
param_in_trace.param] = param_in_trace.step_id_last_used_at
if param_in_trace.param not in params_to_prefetch:
params_to_prefetch.add(param_in_trace.param)
numel_prefetching += param_in_trace.param.ds_numel
for param in params_to_prefetch:
debug_rank0(f"-prefetch: {param.ds_summary()}")
self.__all_gather_params(params_to_prefetch)

if self.__prefetch_nvme:
self.__prefetch_nvme_param_partitions()
if self.__prefetch_bucket_sz > 0:
max_params_to_prefetch = min(
self.__max_n_available_params - self.__n_available_params,
self.__prefetch_bucket_sz)
params_to_prefetch = set()
numel_prefetching = 0
while self.__param_queue and numel_prefetching < max_params_to_prefetch:
param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft(
)
self.__most_recent_step_id_param_fetched_for[
param_in_trace.param] = param_in_trace.step_id_last_used_at
if param_in_trace.param not in params_to_prefetch:
params_to_prefetch.add(param_in_trace.param)
numel_prefetching += param_in_trace.param.ds_numel
for param in params_to_prefetch:
debug_rank0(f"-prefetch: {param.ds_summary()}")
self.__all_gather_params(params_to_prefetch)

if self.__prefetch_nvme:
self.__prefetch_nvme_param_partitions()

self.__step_id += 1

@@ -256,9 +304,8 @@ def release_sub_module(self, submodule: Module) -> None:
be released."""
params_to_release = (self.__params_to_release(submodule,
self.__step_id)
if self.trace_complete else set(
if self.is_complete_trace() else set(
p.ds_id for p in iter_params(submodule)))

for param in iter_params(submodule):
param.ds_active_sub_modules.discard(submodule.id)
if param.ds_id in params_to_release and not param.is_external_param:
@@ -311,7 +358,7 @@ def __release_param(self, param: Parameter) -> None:
def __params_to_release(self,
submodule_to_release: Module,
step_id: int) -> Set[int]:
if not self.trace_complete:
if not self.is_complete_trace():
raise RuntimeError("expected trace to be complete")

params_to_release = set(p.ds_id for p in iter_params(submodule_to_release)
@@ -335,7 +382,7 @@ def __prefetch_nvme_param_partitions(self) -> None:
"""swap in parameter partitions from nvme for those parameters that will be used
after the ones that are already being prefetched into full parameters
"""
if not self.trace_complete:
if not self.is_complete_trace():
return

numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry)
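
The bulk of the changes in this file replace the single trace_complete flag with an explicit RECORD/COMPLETE/INVALID mode, so a cached trace that no longer matches the executed graph is invalidated and re-recorded on the next pass instead of driving stale prefetch decisions. The sketch below is a simplified, self-contained rendering of that state machine; TinyTracer and its attribute names are illustrative assumptions, not DeepSpeed internals.

from enum import Enum

class TraceMode(Enum):
    RECORD = 1     # building a trace of the current forward(+backward) pass
    COMPLETE = 2   # replaying a frozen trace to drive fetch/prefetch decisions
    INVALID = 3    # the executed pass diverged from the recorded trace

class TinyTracer:
    def __init__(self):
        self.mode = TraceMode.RECORD
        self.order = []      # recorded submodule identifiers
        self.step = 0        # position within the current pass

    def prologue(self, module_id):
        if self.mode is TraceMode.COMPLETE:
            # Any deviation from the recorded order (e.g. data-dependent
            # control flow) invalidates the cached trace.
            if self.step >= len(self.order) or module_id != self.order[self.step]:
                self.mode = TraceMode.INVALID
                self.order = []
        elif self.mode is TraceMode.RECORD:
            self.order.append(module_id)
        self.step += 1

    def reset_step(self):
        if self.mode is TraceMode.RECORD:
            self.mode = TraceMode.COMPLETE   # freeze a successful recording
        elif self.mode is TraceMode.INVALID:
            self.mode = TraceMode.RECORD     # re-record on the next pass
        self.step = 0

tracer = TinyTracer()
for mid in ("embed", "block0", "head"):   # first pass is recorded
    tracer.prologue(mid)
tracer.reset_step()                       # trace frozen: COMPLETE
for mid in ("embed", "blockX", "head"):   # second pass deviates at "blockX"
    tracer.prologue(mid)
tracer.reset_step()                       # invalid trace discarded
print(tracer.mode)                        # TraceMode.RECORD

The other notable change above is that prefetching is now guarded by "if self.__prefetch_bucket_sz > 0:", so configuring a zero prefetch bucket size cleanly disables parameter prefetching and its memory cost rather than issuing empty prefetch work.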