From fe26423260afb3c2a3a38682fa808ee2b8c3b1d0 Mon Sep 17 00:00:00 2001 From: Rana Ali Amjad Date: Fri, 13 Aug 2021 20:58:23 -0700 Subject: [PATCH 01/59] Changes for bfloat16 Zero2 --- csrc/includes/type_shim.h | 10 ++++++++++ deepspeed/runtime/config.py | 15 +++++++++++++-- deepspeed/runtime/constants.py | 16 ++++++++++++++++ deepspeed/runtime/engine.py | 9 +++++++-- 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/csrc/includes/type_shim.h b/csrc/includes/type_shim.h index ba1e188f3e1c..4f4e7a539ac1 100644 --- a/csrc/includes/type_shim.h +++ b/csrc/includes/type_shim.h @@ -26,6 +26,11 @@ __VA_ARGS__; \ break; \ } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } @@ -46,6 +51,11 @@ __VA_ARGS__; \ break; \ } \ + case at::ScalarType::BFloat16: { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 933e2f02c2f7..d7a1e98aa5c0 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -113,6 +113,11 @@ def get_fp16_enabled(param_dict): else: return False +def get_bfloat16_enabled(param_dict): + if BFLOAT16 in param_dict.keys(): + return get_scalar_param(param_dict[BFLOAT16], BFLOAT16_ENABLED, BFLOAT16_ENABLED_DEFAULT) + else: + return False def get_fp16_master_weights_and_grads_enabled(param_dict): if get_fp16_enabled(param_dict): @@ -128,15 +133,18 @@ def get_loss_scale(param_dict): return get_scalar_param(param_dict[FP16], FP16_LOSS_SCALE, FP16_LOSS_SCALE_DEFAULT) + elif get_bfloat16_enabled(param_dict): + return 1.0 else: return FP16_LOSS_SCALE_DEFAULT - def get_initial_dynamic_scale(param_dict): if get_fp16_enabled(param_dict): initial_scale_power = get_scalar_param(param_dict[FP16], FP16_INITIAL_SCALE_POWER, FP16_INITIAL_SCALE_POWER_DEFAULT) + elif get_bfloat16_enabled(param_dict): + initial_scale_power = 0 else: initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT @@ -791,6 +799,9 @@ def _initialize_params(self, param_dict): self.fp16_enabled = get_fp16_enabled(param_dict) self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled( param_dict) + self.bfloat16_enabled = get_bfloat16_enabled(param_dict) + assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' + assert not (self.bfloat16_enabled and (self.zero_optimization_stage != 2)), 'bfloat16 mode is only enabled for Zero2 currently' self.amp_enabled = get_amp_enabled(param_dict) self.amp_params = get_amp_params(param_dict) self.loss_scale = get_loss_scale(param_dict) @@ -964,7 +975,7 @@ def _do_error_check(self): assert self.zero_enabled and self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now." def _do_warning_check(self): - fp16_enabled = self.fp16_enabled or self.zero_enabled + fp16_enabled = self.fp16_enabled vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT) if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0: diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index a88d1074d517..daf9ffeaadff 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -107,6 +107,22 @@ SPARSE_GRADIENTS = "sparse_gradients" SPARSE_GRADIENTS_DEFAULT = False +######################################### +# BFLOAT16 support +######################################### +# BFLOAT16 feature. By default, this feature is not enabled. +# Users can configure in ds_config.json as below example: +BFLOAT16_FORMAT = ''' +BFLOAT16 parameters should be of the format: +"bfloat16": { + "enabled": true +} +''' +BFLOAT16 = "bfloat16" + +BFLOAT16_ENABLED = "enabled" +BFLOAT16_ENABLED_DEFAULT = False + ######################################### # FP16 support ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index dcb0efd6f56b..dd0559a2e0f5 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -82,6 +82,7 @@ def split_half_float_double_csr(tensors): "torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", + "torch.cuda.BFloat16Tensor", CSRTensor.type() ] @@ -102,7 +103,6 @@ def print_configuration(args, name): dots = '.' * (29 - len(arg)) logger.info(' {} {} {}'.format(arg, dots, getattr(args, arg))) - class DeepSpeedEngine(Module): r"""DeepSpeed engine for training. """ @@ -527,6 +527,9 @@ def fp16_enabled(self): def fp16_master_weights_and_gradients(self): return self._config.fp16_master_weights_and_gradients + def bfloat16_enabled(self): + return self._config.bfloat16_enabled + def amp_enabled(self): return self._config.amp_enabled @@ -756,6 +759,8 @@ def _configure_distributed_model(self, model): f"fp16 is enabled but the following parameters have dtype that is not fp16: {', '.join(names)}" ) self.module.half() + elif self.bfloat16_enabled(): + self.module.bfloat16() else: if not all( [param.dtype == torch.float for param in self.module.parameters()]): @@ -882,7 +887,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters): ) self.optimizer = self._configure_zero_optimizer(basic_optimizer) elif self.amp_enabled(): - assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode" + assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable both amp with (legacy) fp16 or bfloat16 mode" amp_params = self.amp_params() if self.global_rank == 0: logger.info(f"Initializing AMP with these params: {amp_params}") From 8864f91194d59e00011ca689fecdf0f31b45a2ad Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 28 Sep 2021 19:42:12 -0700 Subject: [PATCH 02/59] ZeRO stage3 optimizations, with some bug fixes optimizations for stage3: - prefetching improvements - batching allgather calls to amortize fixed overhead and improve bandwidth utilization - batching reduce_scatter calls to amortize fixed overhead and improve bandwidth utilization - using *_base variants of allgather and reduce scatter to reduce memory allocations and data movement - more fine grained synchronization for communication that allows blocking on less work - precomputation of fetching code - using a fetch queue rather than deciding what to (pre)fetch at each iteration - limiting queued coalesced communication ops to reduce memory pressure on pytorch cuda caching allocator (not elegant solution) optimizations for stage3-offload: - made some host-device tensor copies async to improve performance bug fixes and qol improvements: - fix init context method when parent modules modify child weights - speed up model initialization by moving model to GPU before weight initialization - fixed unit test imports so that unit tests can be run from any directory - change performance logging to include memory consumption - add logging w/ model size when done partitioning model new features - bfloat16 support for ZeRO 3 --- .../runtime/comm/coalesced_collectives.py | 88 + deepspeed/runtime/config.py | 6 +- deepspeed/runtime/engine.py | 42 +- deepspeed/runtime/utils.py | 9 + .../runtime/zero/partition_parameters.py | 334 +++- deepspeed/runtime/zero/stage3.py | 1445 ++++++++--------- deepspeed/runtime/zero/utils.py | 36 + deepspeed/utils/__init__.py | 1 + deepspeed/utils/nvtx.py | 11 + deepspeed/utils/timer.py | 13 +- tests/unit/__init__.py | 0 tests/unit/megatron_model.py | 3 +- tests/unit/test_activation_checkpointing.py | 2 +- tests/unit/test_adamw.py | 4 +- tests/unit/test_aio.py | 2 +- tests/unit/test_checkpointing.py | 6 +- tests/unit/test_config.py | 4 +- tests/unit/test_configurable_parallel.py | 8 +- tests/unit/test_cuda_backward.py | 11 +- tests/unit/test_cuda_forward.py | 5 +- tests/unit/test_curriculum_learning.py | 4 +- tests/unit/test_data.py | 4 +- tests/unit/test_dist.py | 2 +- tests/unit/test_ds_initialize.py | 4 +- tests/unit/test_dynamic_loss_scale.py | 4 +- tests/unit/test_elastic.py | 4 +- tests/unit/test_flops_profiler.py | 4 +- tests/unit/test_fp16.py | 6 +- tests/unit/test_ignore_unused_parameters.py | 4 +- tests/unit/test_lr_schedulers.py | 4 +- tests/unit/test_moe.py | 6 +- tests/unit/test_multi_output_model.py | 6 +- tests/unit/test_onebit.py | 6 +- tests/unit/test_partition.py | 2 +- tests/unit/test_pipe.py | 4 +- tests/unit/test_pipe_module.py | 4 +- tests/unit/test_pld.py | 4 +- tests/unit/test_runtime_utils.py | 2 +- tests/unit/test_topology.py | 2 +- tests/unit/test_zero.py | 770 ++++++++- tests/unit/test_zero_context.py | 3 +- 41 files changed, 2003 insertions(+), 876 deletions(-) create mode 100644 deepspeed/runtime/comm/coalesced_collectives.py create mode 100644 deepspeed/utils/nvtx.py create mode 100644 tests/unit/__init__.py diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py new file mode 100644 index 000000000000..f8e40326d598 --- /dev/null +++ b/deepspeed/runtime/comm/coalesced_collectives.py @@ -0,0 +1,88 @@ +"""batched collective operations for overhead amortization and better +bandwidth utilization""" + +import math +from typing import List + +import torch +from torch import Tensor +import torch.distributed +from torch.distributed import ProcessGroup +import torch.nn.functional + +from deepspeed.utils import instrument_w_nvtx + + +@instrument_w_nvtx +@torch.no_grad() +def reduce_scatter_coalesced( + tensors: List[Tensor], + group: ProcessGroup = None, +) -> List[Tensor]: + """simultaneously reduce-scatter a list of tensors - this can be done more + efficiently than individual reduce scatter calls + + TODO. see if PyTorch team wants a c++ verson of this for ProcessGroupNCCL + """ + this_rank = torch.distributed.get_rank(group) + world_sz = torch.distributed.get_world_size(group) + + partition_lst_for_each_tensor = tuple( + torch.chunk(tensor.view(-1), + world_sz) for tensor in tensors) + padded_partition_sz_for_each_tensor = tuple( + math.ceil(t.numel() / world_sz) for t in tensors) + + if len(tensors) == 1 and tensors[0].numel() % world_sz == 0: + # if there's only one tensor being reduced and we don't need to pad + # we have an opportunity to avoid a memory allocation + tensor_partition_flat_buffer = tensors[0].view(-1) + else: + # interleave tensor partitions such that the correct reduced partitions of each tensor + # end up at each rank + tensor_partitions_lst_with_padding = [] + for rank in range(world_sz): + for tensor_idx in range(len(tensors)): + # add tensor content + tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank] + tensor_partitions_lst_with_padding.append(tensor_chunk) + + # add padding if necessary + padding_sz = padded_partition_sz_for_each_tensor[ + tensor_idx] - tensor_chunk.numel() + if padding_sz > 0: + tensor_partitions_lst_with_padding.append( + torch.empty(padding_sz, + dtype=tensor_chunk.dtype, + device=tensor_chunk.device)) + + tensor_partition_flat_buffer = instrument_w_nvtx( + torch.cat)(tensor_partitions_lst_with_padding) + + tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk( + tensor_partition_flat_buffer, + world_sz) + + # batched reduce-scatter call + instrument_w_nvtx(torch.distributed._reduce_scatter_base)( + tensor_partition_buffer_for_each_rank[this_rank], + tensor_partition_flat_buffer, + group=group, + ) + + # post-divide + tensor_partition_buffer_for_each_rank[this_rank].div_(world_sz) + + # reverse procedure of the interleaving done previously, done on the + # result of the batched reduce-scatter + output_lst: List[Tensor] = [None] * len(tensors) + offset = 0 + for tensor_idx in range(len(tensors)): + output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow( + 0, + offset, + partition_lst_for_each_tensor[tensor_idx][this_rank].numel()) + + offset += padded_partition_sz_for_each_tensor[tensor_idx] + + return output_lst diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index d7a1e98aa5c0..597d180b841b 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -115,7 +115,9 @@ def get_fp16_enabled(param_dict): def get_bfloat16_enabled(param_dict): if BFLOAT16 in param_dict.keys(): - return get_scalar_param(param_dict[BFLOAT16], BFLOAT16_ENABLED, BFLOAT16_ENABLED_DEFAULT) + return get_scalar_param(param_dict[BFLOAT16], + BFLOAT16_ENABLED, + BFLOAT16_ENABLED_DEFAULT) else: return False @@ -801,7 +803,7 @@ def _initialize_params(self, param_dict): param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict) assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' - assert not (self.bfloat16_enabled and (self.zero_optimization_stage != 2)), 'bfloat16 mode is only enabled for Zero2 currently' + assert not (self.bfloat16_enabled and (self.zero_optimization_stage not in {2, 3})), f"bfloat16 mode is only enabled for ZeRO 2 and 3 currently, got {self.zero_optimization_stage}" self.amp_enabled = get_amp_enabled(param_dict) self.amp_params = get_amp_params(param_dict) self.loss_scale = get_loss_scale(param_dict) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index dd0559a2e0f5..57ee9ff05058 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -43,7 +43,7 @@ import deepspeed.runtime.lr_schedules as lr_schedules import deepspeed.utils.groups as groups from deepspeed.runtime.utils import get_grad_norm -from deepspeed.utils import logger, log_dist, init_distributed +from deepspeed.utils import logger, log_dist, init_distributed, instrument_w_nvtx from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer from deepspeed.utils.debug import debug_extract_module_and_param_names from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop @@ -743,34 +743,32 @@ def is_replicated(p): self.broadcast_src_rank, group=self.data_parallel_group) + @staticmethod + def __check_params(model: Module, dtype: torch.dtype) -> None: + if not all(param.dtype == dtype + for param in model.parameters()) and dist.get_rank() == 0: + raise ValueError( + f"{dtype} is enabled but the following parameters have dtype that is " + f"not {dtype}: " + f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}" + ) + def _configure_distributed_model(self, model): self.module = model if self.fp16_enabled(): if self.zero_optimization_partition_weights() and any( [hasattr(param, 'ds_id') for param in self.module.parameters()]): - if not all( - [param.dtype == torch.half for param in self.module.parameters()]): - names = [ - n for n, - p in self.module.named_parameters() if p.dtype != torch.half - ] - raise ValueError( - f"fp16 is enabled but the following parameters have dtype that is not fp16: {', '.join(names)}" - ) + self.__check_params(self.module, torch.half) self.module.half() elif self.bfloat16_enabled(): + if self.zero_optimization_partition_weights() and any( + hasattr(param, + 'ds_id') for param in self.module.parameters()): + self.__check_params(self.module, torch.bfloat16) self.module.bfloat16() else: - if not all( - [param.dtype == torch.float for param in self.module.parameters()]): - names = [ - n for n, - p in self.module.named_parameters() if p.dtype != torch.float - ] - raise ValueError( - f"fp32 is enabled but the following parameters have dtype that is not fp32: {', '.join(names)}" - ) + self.__check_params(self.module, torch.float) if not self.dont_change_device: self.module.to(self.device) @@ -1283,6 +1281,7 @@ def _scale_loss_by_gas(self, prescaled_loss): return scaled_loss + @instrument_w_nvtx def forward(self, *inputs, **kwargs): r"""Execute forward propagation @@ -1323,7 +1322,8 @@ def forward(self, *inputs, **kwargs): if self.training_dataloader is None: self.tput_timer.start() - loss = self.module(*inputs, **kwargs) + with torch.cuda.nvtx.range("DeepspeedEngine.forward::module_forward"): + loss = self.module(*inputs, **kwargs) if self.zero_optimization_partition_weights(): # Reset the ZeRO-3 state if we are only doing forward-passes (ie evaluation). @@ -1352,6 +1352,7 @@ def forward(self, *inputs, **kwargs): return loss + @instrument_w_nvtx def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): # Pass (PP) gas boundary flag to optimizer (required for zero) self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary( @@ -1369,6 +1370,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): else: self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) + @instrument_w_nvtx def backward(self, loss, allreduce_gradients=True, release_loss=False): r"""Execute backward pass on the loss diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index 550d2d2e98ef..f6eb81f40e97 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -840,3 +840,12 @@ def call_to_str(base, *args, **kwargs): name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items()) name += ')' return name + + +def get_only_unique_item(items): + item_set = set(items) + if len(item_set) != 1: + raise RuntimeError(f"expected there to be only one unique element in {items}") + unique_item, = item_set + + return unique_item diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 2c02b96e79fa..fb4f8e2d6f6e 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -6,25 +6,33 @@ import os import time import types +from typing import Callable, Iterable from enum import Enum import functools import itertools +from typing import List import torch +from torch import Tensor +import torch.distributed from torch.distributed.distributed_c10d import _get_global_rank +from torch.nn import Module +from torch.nn import Parameter from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3 from .offload_constants import * -from ..utils import see_memory_usage -from deepspeed.utils import log_dist, init_distributed, logger +from ..utils import get_only_unique_item, see_memory_usage +from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks +from deepspeed.utils import init_distributed, instrument_w_nvtx, logger from deepspeed.utils.debug import debug_param2name_id_shape, debug_param2name_id_shape_device, debug_module2name, debug_param2name, debug_param2name_id_shape_status, printflock, log_rank_file +from deepspeed.utils.logging import logger from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus from ..config import DeepSpeedConfig param_count = 0 -partitioned_param_data_shape = [1] +partitioned_param_data_shape = [0] def print_rank_0(message, debug=False, force=False): @@ -38,6 +46,11 @@ def print_rank_0(message, debug=False, force=False): # log_rank_file(rank, message) +def debug_rank0(msg: str) -> None: + if torch.distributed.get_rank() == 0: + logger.debug(msg) + + def is_zero_param(parameter): if not torch.is_tensor(parameter): return False @@ -159,38 +172,35 @@ class ZeroParamStatus(Enum): _orig_torch_empty = torch.empty +_orig_torch_zeros = torch.zeros +_orig_torch_ones = torch.ones +_orig_torch_full = torch.full -def empty_cuda_tensor_half(*size, **kwargs): - if not 'device' in kwargs.keys(): - kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - tensor = _orig_torch_empty(*size, **kwargs) - if tensor.is_floating_point(): - return tensor.half() - else: - return tensor +def zero_wrapper_for_fp_tensor_constructor(fn: Callable, + target_fp_dtype: torch.dtype) -> Callable: + def wrapped_fn(*args, **kwargs) -> Tensor: + if kwargs.get("device", None) is None: + kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) + tensor: Tensor = fn(*args, **kwargs) + if tensor.is_floating_point(): + tensor = tensor.to(target_fp_dtype) - -def new_cuda_tensor_half(cls, *args): - device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - tensor = torch.ones((1, 1), device=device).new_empty(*args).half() - if tensor.is_floating_point(): - return tensor.half() - else: return tensor + return wrapped_fn -def empty_cuda_tensor(*size, **kwargs): - if not 'device' in kwargs.keys(): - kwargs['device'] = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - tensor = _orig_torch_empty(*size, **kwargs) - return tensor +def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable: + def new_tensor(cls, *args) -> Tensor: + device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) + tensor = _orig_torch_empty(0, device=device).new_empty(*args) + if tensor.is_floating_point(): + tensor = tensor.to(dtype) -def new_cuda_tensor(cls, *args): - device = torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])) - tensor = torch.ones((1, 1), device=device).new_empty(*args) - return tensor + return tensor + + return new_tensor # https://stackoverflow.com/a/63851681/9201239 @@ -207,6 +217,19 @@ def recurse(cl): return set(subclass_list) +@instrument_w_nvtx +def free_param(param: Parameter) -> None: + """Free underlying storage of a parameter.""" + assert not param.ds_active_sub_modules, param.ds_summary() + if param.data.is_cuda: + # need to make sure that we don't free the parameter while it is still + # being used for computation + param.data.record_stream(torch.cuda.current_stream()) + # param.data doesn't store anything meaningful in partitioned state + param.data = torch.empty(0, dtype=param.dtype, device=param.device) + param.ds_status = ZeroParamStatus.NOT_AVAILABLE + + reuse_buffers = False temp_contiguous_tensor = None empty_buffers = {} @@ -223,12 +246,79 @@ def __init__(self, self.mem_efficient_linear = mem_efficient_linear self.enabled = enabled self._set_dtype(ds_config, dtype) - assert self.dtype in [torch.half, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.float]" + assert self.dtype in [torch.half, torch.bfloat16, torch.float], f"Invalid data type {self.dtype}, allowed values are [torch.half, torch.bfloat16, torch.float]" def __enter__(self): if not self.enabled: return + def apply_with_gather(orig_module_apply_fn: Callable) -> Callable: + """many models make use of child modules like Linear or Embedding which + perform their own weight initialization in their __init__ methods, + but will then have more weight initialization in a parent module's __init__ + method that modifies weights of child modules, which is typically done + using the Module.apply method. + + since the Init context manager partitions child modules immediately after + they are initialized, without modifying apply we would entirely skip + any initialization done by parent modules. + + to get around this issue, we wrap the function passed to Module.apply + so that the applied function is applied to child modules correctly. + """ + def get_wrapped_fn_to_apply(fn_to_apply: Callable) -> Callable: + if hasattr(fn_to_apply, "wrapped"): + return fn_to_apply + + @functools.wraps(fn_to_apply) + def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None: + """gathers parameters before calling apply function. afterwards + parameters are broadcasted to ensure consistency across all ranks + then re-partitioned. + + takes the following steps: + 1. allgathers parameters for the current module being worked on + 2. calls the original function + 3. broadcasts root rank's parameters to the other ranks + 4. re-partitions the parameters + """ + if not all( + is_zero_param(p) + for p in module_to_apply_fn_to.parameters(recurse=False)): + raise RuntimeError( + f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, " + f"were zero params, is it possible that the parameters were " + f"overwritten after they were initialized? " + f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} " + ) + + params_to_apply_fn_to: Iterable[Parameter] = list( + sorted(module_to_apply_fn_to.parameters(recurse=False), + key=lambda p: p.ds_id)) + + for param in params_to_apply_fn_to: + param.all_gather() + + fn_to_apply(module_to_apply_fn_to) + + for param in params_to_apply_fn_to: + torch.distributed.broadcast(param.data, + 0, + group=param.ds_process_group) + + for param in params_to_apply_fn_to: + param.partition(has_been_updated=True) + + wrapped_fn_to_apply.wrapped = True + + return wrapped_fn_to_apply + + @functools.wraps(orig_module_apply_fn) + def wrapped_apply(module: Module, fn_to_apply: Callable) -> None: + orig_module_apply_fn(module, get_wrapped_fn_to_apply(fn_to_apply)) + + return wrapped_apply + def partition_after(f): @functools.wraps(f) def wrapper(module, *args, **kwargs): @@ -278,18 +368,23 @@ def _init_subclass(cls, **kwargs): # print(f"subclass={subclass.__module__}.{subclass.__qualname__}") _enable_class(subclass) - # holding on to the current __init__subclass__ for exit + # holding onto some methods so we can put them back the way they were in __exit__ torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ + torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply torch.Tensor.__old_new__ = torch.Tensor.__new__ # Replace .__init__() for future subclasses of torch.nn.Module torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) - if self.dtype == torch.half: - torch.Tensor.__new__ = new_cuda_tensor_half - torch.empty = empty_cuda_tensor_half - else: - torch.Tensor.__new__ = new_cuda_tensor - torch.empty = empty_cuda_tensor + torch.nn.modules.module.Module.apply = apply_with_gather( + torch.nn.modules.module.Module._old_apply) + + torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype) + torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, + self.dtype) + torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, + self.dtype) + torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype) + torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype) if self.mem_efficient_linear: print_rank_0( @@ -309,11 +404,15 @@ def _disable_class(cls): for subclass in get_all_subclasses(torch.nn.modules.module.Module): _disable_class(subclass) - # Replace .__init__() for future subclasses of torch.nn.Module + # putting methods back the way we found them torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass + torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply torch.Tensor.__new__ = torch.Tensor.__old_new__ torch.empty = _orig_torch_empty + torch.zeros = _orig_torch_zeros + torch.ones = _orig_torch_ones + torch.full = _orig_torch_full # un doing it here will undo it during training # if self.mem_efficient_linear: @@ -321,6 +420,10 @@ def _disable_class(cls): # if self.mem_efficient_linear: # torch.nn.functional.linear = self.linear_bk + if torch.distributed.get_rank() == 0: + logger.info("finished initializing model with %.2fB parameters", + param_count / 1e9) + # Now that we cleaned up the metaclass injection, raise the exception. if exc_type is not None: return False @@ -331,11 +434,69 @@ def _post_init_method(self, module): def _set_dtype(self, ds_config, dtype): if ds_config is not None and dtype is None: - self.dtype = torch.half if ds_config.fp16_enabled else torch.float - elif dtype is None: - self.dtype = torch.half + if ds_config.bfloat16_enabled and ds_config.fp16_enabled: + raise RuntimeError("bfloat16 and fp16 cannot be enabled at once") + + if ds_config.bfloat16_enabled: + self.dtype = torch.bfloat16 + elif ds_config.fp16_enabled: + self.dtype = torch.half + else: + self.dtype = torch.float else: - self.dtype = dtype + self.dtype = dtype or torch.half + + +class AllGatherCoalescedHandle: + def __init__( + self, + allgather_handle, + params: List[Parameter], + partitions: List[Tensor], + world_size: int, + ) -> None: + self.__allgather_handle = allgather_handle + self.__params = params + self.__partitions = partitions + self.__world_size = world_size + self.__complete = False + + for param in self.__params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError( + f"expected param {param.ds_summary()} to not be available") + + @instrument_w_nvtx + def wait(self) -> None: + if self.__complete: + return + + instrument_w_nvtx(self.__allgather_handle.wait)() + + # split the single tensor out into individual tensors + param_offset = 0 + for param in self.__params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + partitions: List[Tensor] = [] + for rank in range(self.__world_size): + param_start = rank * param.ds_tensor.ds_numel + if param_start < param.ds_numel: + part_to_copy = self.__partitions[rank].narrow( + 0, + param_offset, + min(param.ds_numel - param_start, + param.ds_tensor.ds_numel)) + partitions.append(part_to_copy) + + param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape) + param.ds_status = ZeroParamStatus.AVAILABLE + + for part_to_copy in partitions: + part_to_copy.record_stream(torch.cuda.current_stream()) + + param_offset += param.ds_tensor.ds_numel + + self.__complete = True # Replaces all parameters in module with Scattered Parameters @@ -531,6 +692,14 @@ def _post_init_method(self, module): print_rank_0( f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}" ) + + if param.is_cuda: + torch.distributed.broadcast(param, 0, self.ds_process_group) + else: + if torch.distributed.get_rank() == 0: + logger.warn(f"param in {module.__class__.__name__} " + f"not on GPU so was not broadcasted from rank 0") + param.partition() see_memory_usage( f"Param count {param_count}. After converting and partitioning parmas in {module.__class__.__name__}", @@ -554,12 +723,14 @@ def _convert_to_deepspeed_param(self, param): param.ds_tensor = None # Keeps track of how many active sub-modules need this param at any given point in time - param.ds_active_sub_modules = 0 + param.ds_active_sub_modules = set() # If this flag is true, then the parameters are replicated throughput training # And only partitioned before the step param.ds_persist = False + param.is_external_param = False + # The group that the parameter is scattered across. param.ds_process_group = self.ds_process_group @@ -577,6 +748,63 @@ def all_gather(param_list=None, async_op=False, hierarchy=0): param_list = [cls] return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy) + @instrument_w_nvtx + def all_gather_coalesced(params: Iterable[Parameter], + safe_mode: bool = False) -> AllGatherCoalescedHandle: + # fetches from nvme if the partition is not available and in nvme + self._ensure_availability_of_partitioned_params(params) + + for param in params: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(param.ds_summary()) + param.ds_status = ZeroParamStatus.INFLIGHT + + # ensure that each rank has params in same order. the allgather + # is done by flattening the parameter list into a single tensor that + # can be allgathered in a single call - this means that if each rank + # gives a list of the same parameters in a different order we will + # silently get incorrect parameter values, and have very difficult + # to debug correctness issues. + params = sorted(params, key=lambda p: p.ds_id) + + debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") + + if safe_mode: + # ensure that same list (with same ordering) of parameters are + # being allgathered across all ranks, otherwise could mix + # data between tensors. + assert_ints_same_as_other_ranks([p.ds_id for p in params]) + # ensure that tensors from each rank agree on the same ds_numel + # otherwise could mix data between tensors. + assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params]) + + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + flat_tensor = torch.empty(partition_sz * self.world_size, + dtype=get_only_unique_item(p.dtype + for p in params), + device=self.local_device, + requires_grad=False) + partitions: List[Parameter] = [] + for i in range(self.world_size): + partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz)) + + instrument_w_nvtx(torch.cat)([p.ds_tensor.data for p in params], + out=partitions[self.rank]) + + handle = instrument_w_nvtx(torch.distributed._all_gather_base)( + flat_tensor, + partitions[self.rank], + group=self.ds_process_group, + async_op=True, + ) + + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=partitions, + world_size=self.world_size, + ) + def partition(param_list=None, hierarchy=0, has_been_updated=False): cls = param print_rank_0( @@ -621,8 +849,23 @@ def padding_size(): def partitioned_size(): return self._partitioned_size(param) + def ds_summary(slf: torch.Tensor) -> dict: + return { + "id": slf.ds_id, + "status": slf.ds_status.name, + "numel": slf.numel(), + "ds_numel": slf.ds_numel, + "shape": tuple(slf.shape), + "ds_shape": tuple(slf.ds_shape), + "requires_grad": slf.requires_grad, + "grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None, + "persist": slf.ds_persist, + "active_sub_modules": slf.ds_active_sub_modules, + } + # Collectives for gathering and partitioning parameters param.all_gather = all_gather + param.all_gather_coalesced = all_gather_coalesced param.partition = partition # Collective for averaging gradients @@ -633,6 +876,7 @@ def partitioned_size(): param.aligned_size = aligned_size param.padding_size = padding_size param.partitioned_size = partitioned_size + param.ds_summary = types.MethodType(ds_summary, param) def _aligned_size(self, param): return param.ds_numel + self._padding_size(param) @@ -659,6 +903,7 @@ def _ensure_availability_of_partitioned_params(self, params): elif len(swap_in_flight) > 0: swap_in_flight[0].nvme_swapper.synchronize_reads() + @instrument_w_nvtx def _all_gather(self, param_list, async_op=False, hierarchy=None): # fetches from nvme if the partition is not available and in nvme @@ -697,6 +942,7 @@ def _partition(self, param_list, force=False, has_been_updated=False): #print_rank_0(f"After Partitioning Param {param.ds_id}") # self._param_status(param) + @instrument_w_nvtx def _partition_param(self, param, buffer=None, has_been_updated=False): assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" @@ -725,7 +971,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): f'Before partitioning param {param.ds_id} {param.shape}', force=False) # param.data does not store anything meaningful in partitioned state - param.data = torch.ones(1, dtype=self.dtype).to(param.device) + free_param(param) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) @@ -746,7 +992,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): numel=partition_size): final_location = OFFLOAD_NVME_DEVICE buffer = self.param_swapper.get_buffer(param, partition_size) - partitioned_tensor = torch.zeros(1, + partitioned_tensor = torch.empty(0, dtype=param.dtype, device=buffer.device) partitioned_tensor.data = buffer.data @@ -806,7 +1052,7 @@ def _partition_param(self, param, buffer=None, has_been_updated=False): see_memory_usage(f'Before partitioning param {param.ds_id} {param.shape}', force=False) - param.data = torch.ones(1, dtype=self.dtype).to(param.device) + free_param(param) see_memory_usage(f'After partitioning param {param.ds_id} {param.shape}', force=False) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 0f19c4902183..0e5aef16a0db 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -3,19 +3,27 @@ Licensed under the MIT license. """ -import sys +import gc +from dataclasses import dataclass +import functools import os -from collections import defaultdict, OrderedDict +import collections +from collections import OrderedDict, UserDict import itertools +from typing import Deque, Dict, Iterable, Set, Tuple import torch -from torch.distributed.distributed_c10d import _get_global_rank +from torch.cuda import Event, Stream +from torch.nn import Module, Parameter import torch.distributed as dist import math from torch._six import inf -from torch.autograd import Variable +from torch.nn import Module +from torch.nn.parameter import Parameter from deepspeed.utils.logging import logger from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced +from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter from deepspeed.runtime.zero.partition_parameters import * from deepspeed.runtime.zero.partition_parameters import _init_external_params @@ -32,7 +40,6 @@ pg_correctness_test = False FWD_MODULE_STACK = list() -from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id, debug_param2name_id_numel, debug_param2name_id_shape_device, debug_module2name_class, printflock, log_rank_file def print_rank_0(message, debug=False, force=False): @@ -73,16 +80,34 @@ def lcm(x, y): return x * y // gcd(x, y) +def debug_rank0(message: str) -> None: + if dist.get_rank() == 0: + logger.debug(message) + + +def get_cuda_mem_allocated_str() -> str: + # this is really slow. when enabled the python process becomes slow + # to the point where it can't keep the GPU fed with work, so only enable + # for memory debugging. + # return f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f}GB" + return "xGB" + + def move_to_cpu(tensor_list): for tensor in tensor_list: tensor.data = tensor.data.cpu() +@instrument_w_nvtx def get_all_parameters(sub_module, recurse=False): return itertools.chain(sub_module.named_parameters(recurse=recurse), sub_module.ds_external_parameters()) +def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: + return map(lambda pair: pair[1], get_all_parameters(module, recurse)) + + #apply torch.autograd.Function that calls a backward_function to tensors in output def _apply_to_tensors_only(module, functional, backward_function, outputs): if type(outputs) is tuple: @@ -165,384 +190,320 @@ def _inject_parameters(module, cls): module._parameters = new_param -# TODO Needs to be implemented -class PrefetchCoordinator(object): - def __init__(self): - # step_id keeps track of the number of sub-modules invoked so far - # the step_id is tracking forward and backward sequence of sub-modules - self.step_id = 0 - - # stores the sequence of sub modules in forward+backward pass - self.sub_module_trace = [] - - # maps sub_module id to submodule objects - self.id_to_sub_module_map = {} - - # stores the total number of parameters in each sub_module - self.id_to_sub_module_size_map = {} - - self.trace_completed = False - - self.most_recent_sub_module_step = {} - - # reuse distances - self.reuse_numel_for_step_id = {} - - def record_trace(self, sub_module): - if not self.trace_completed: - self.sub_module_trace.append(sub_module.id) - self.id_to_sub_module_map[sub_module.id] = sub_module - - def print_trace(self): - print_rank_0( - f"The module trace is : {[self.id_to_sub_module_map[module_id].id for module_id in self.sub_module_trace]}" - ) - - def increment_step(self, sub_module): - self.most_recent_sub_module_step[sub_module.id] = self.step_id - self.step_id += 1 - - def reset_step(self): - self.step_id = 0 - - # returns the next numel parameters that will be used next but are not available or inflight - def get_params_to_prefetch(self, sub_module, numel=2000000): - - # numel_in_sub_module = 0 - # for name, param in sub_module.named_parameters(recurse=False): - # numel_in_sub_module += param.ds_numel - - # #if numel_in_sub_module < (numel // 2): - # return [] - - # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing - if sub_module.id != self.sub_module_trace[self.step_id]: - print_rank_0( - f"Tracing failed. Prefetching is disabled at sub-module: {debug_module2name_id(sub_module)}" - ) - return [] - - params_to_prefetch = [] - total_numel_to_prefetch = 0 - - for i in range(self.step_id, len(self.sub_module_trace)): - module_id = self.sub_module_trace[i] - for _, param in get_all_parameters(self.id_to_sub_module_map[module_id]): - if param.ds_status is ZeroParamStatus.NOT_AVAILABLE and ( - param.ds_id not in [p.ds_id for p in params_to_prefetch]): - params_to_prefetch.append(param) - total_numel_to_prefetch += param.ds_numel - #print_rank_0(f"Total numel to prefetch: {total_numel_to_prefetch}. Param: {param.ds_shape} and numel {param.ds_numel}, numel limit {numel}") - if total_numel_to_prefetch >= numel: # and total_numel_to_prefetch > (numel_in_sub_module // 2): - return params_to_prefetch - - return params_to_prefetch - - # checks if this sub_module will be used again and if so then returns the number of elements - # in the parameters used between this sub_module and the reuse of this sub_module - def get_reuse_distance_in_numel(self, sub_module, sub_module_step_id=None): - #assert is_forward is not None, "is_forward must be set to True for Forward Propagation and False for backward Propagation" - is_there_reuse = False - reuse_distance_in_numel = 1000000000000 - - # set the appropriate trace - trace = self.sub_module_trace - total_steps = len(trace) - if sub_module_step_id is None: - sub_module_step_id = self.most_recent_sub_module_step[sub_module.id] - - # tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing - if sub_module.id != trace[sub_module_step_id]: - print_rank_0( - f"Tracing failed. Cannot tell if the sub_module: {sub_module.id} is reused" - ) - return reuse_distance_in_numel - - # return cached value - if sub_module_step_id in self.reuse_numel_for_step_id: - return self.reuse_numel_for_step_id[sub_module_step_id] - - start_step = self.step_id - print_rank_0(f"Step id is {self.step_id} ") - for step_id in range(start_step, total_steps): - print_rank_0(f"Trace id {trace[step_id]} and sub_module id {sub_module.id}") - if sub_module.id == trace[step_id]: - end_step = step_id - - is_there_reuse = True - reuse_distance_in_numel = self._distance_in_numel( - start_step, - end_step, - trace) - break - - self.reuse_numel_for_step_id[sub_module_step_id] = reuse_distance_in_numel - - return reuse_distance_in_numel - - def _distance_in_numel(self, start_step, end_step, trace): - distance_in_numel = 0 - for step_id in range(start_step, end_step): - module_id = trace[step_id] - for _, param in self.id_to_sub_module_map[module_id].named_parameters(recurse=False): - distance_in_numel += param.ds_numel - for _, param in self.id_to_sub_module_map[module_id].ds_external_parameters(): - distance_in_numel += param.ds_numel - return distance_in_numel - - -class PartitionedParameterCoordinator(object): - def __init__(self, - comm_stream=None, - max_reuse_distance_in_numel=500000000, - max_available_parameters_in_numel=700000000): - - self.in_flight_handles = [] - self.params_in_flight = [] - self.comm_stream = comm_stream if comm_stream is not None else torch.cuda.current_stream( - ) - self.prefetch_coordinator = PrefetchCoordinator() - self.hierarchy = 0 - - self.total_available_parameter_numel = 0 - self.max_available_parameters_in_numel = max_available_parameters_in_numel - +class PartitionedParameterCoordinator: + """Handles partitioning and gathering of parameters.""" + class __InflightParamRegistry(UserDict): + """registry for parameters in flight""" + def __setitem__(self, + param: Parameter, + handle: AllGatherCoalescedHandle) -> None: + if param in self.data: + raise RuntimeError(f"{param.ds_summary()} already in registry") + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError( + f"attempted to add non-inflight parameter to registry {param.ds_summary()}" + ) + self.data[param] = handle + + @dataclass + class __ParamInTrace: + param: Parameter + step_id_last_used_at: int + + def __init__( + self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: Stream, + prefetch_nvme: bool = False, + ) -> None: + # mapping of param -> handle for each param that is currently in flight + self.__inflight_param_registry = __class__.__InflightParamRegistry() + # keeps track of the number of submodules invoked so far. + self.__step_id: int = 0 + # whether or not we have completed a trace of the entire network. This should + # always be true after the first forward pass + backward pass. + self.trace_complete: bool = False + # sequence of submodules/parameters in forward pass + backward pass + self.__submodule_order: Iterable[Module] = [] + self.__param_order: Iterable[__class__.__ParamInTrace] = [] + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + # number of available params, and max number of available params + self.__n_available_params: int = 0 + self.__max_n_available_params: int = max_available_parameters_in_numel # max distance between two use of the module beyond which module is released - self.max_reuse_distance_in_numel = max_reuse_distance_in_numel - - def _increment_available_parameter_numel(self, increment): - self.total_available_parameter_numel += increment - - def _decrement_available_parameter_numel(self, decrement): - self.total_available_parameter_numel -= decrement + self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel + # queue for parameters to fetch. parameters will be popped off the left + # side of the dequeue as they are fetched + self.__param_queue: collections.deque = None + self.__prefetch_bucket_sz: int = prefetch_bucket_sz + self.__prefetch_nvme: bool = prefetch_nvme + self.hierarchy: int = 0 + + # stream that will be used for allgather operations + self.__allgather_stream: Stream = allgather_stream + + # limit the number of fetch events that can be queued at once + # otherwise, what happens is memory is allocated by the host thread at the + # time of the call, but not used until later by the asynchronous cuda stream. + # allowing an infinite number of these to queue up causes a lot of memory + # pressure that then becomes detrimental to performance. + # this is a much less elegant way of fixing this vs something like using + # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now + # because ideally in the future its replaced by an async allocation + # mechanism which doesnt require any configuration by the user. + self.__ongoing_fetch_events: Deque[Event] = collections.deque() + self.__max_ongoing_fetch_events: int = 2 + + """Tracing and Tracking + TODO. consider performing trace before initializing PartitionedParameterCoordinator + and passing trace results into constructor. This way all the code in here can + just assume that the trace is complete and the results can be entirely + immutable. + + Bookkeeping operations used to track where we are in the forward/backward pass + """ - '''-----------------------Tracing and Prefetching ---------------''' + def record_trace(self, sub_module: Module) -> None: + """adds sub module to trace""" + if self.trace_complete: + raise RuntimeError( + "attemted to record trace when trace was already complete") + + self.__submodule_order.append(sub_module) + for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): + self.__param_order.append( + __class__.__ParamInTrace(param=param, + step_id_last_used_at=self.__step_id)) + + def reset_step(self) -> None: + """indicate that we have completed one fwd+bwd for the model""" + if self.__inflight_param_registry: + raise RuntimeError( + f"still have inflight params " + f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") + + if not self.trace_complete: + # make sure that recorded parameter and submodule orders are + # identical across ranks + assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) + assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) + assert_ints_same_as_other_ranks( + [p.step_id_last_used_at for p in self.__param_order]) + + self.__submodule_order = tuple(self.__submodule_order) # freeze + self.__param_order = tuple(self.__param_order) # freeze + self.trace_complete = True + print_rank_0(f"completed trace: {[m.id for m in self.__submodule_order]}", + force=True) + + self.__param_queue = collections.deque(self.__param_order) # reset fetch queue + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + self.__step_id = 0 + self.__n_available_params = 0 + + """Fetch and Release + Fetching, prefetching, and releasing parameters + """ - def record_trace(self, sub_module): - self.prefetch_coordinator.record_trace(sub_module) + @instrument_w_nvtx + @torch.no_grad() + def fetch_sub_module(self, current_submodule: Module) -> None: + """This method does the following (in order): + 1. kick off fetch for parameters in immediately required sub module + 2. kick off fetch for next few parameters we will need later (prefetch) + 3. block on parameters in immediately required sub module + """ + debug_rank0( + f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " + + str({ + "avail": f"{self.__n_available_params:.1e}", + "queue_sz": f"{len(self.__param_queue or [])}", + "inflight": [p.ds_id for p in self.__inflight_param_registry], + "allocated": get_cuda_mem_allocated_str() + })) + + with torch.cuda.nvtx.range("fetch_kickoff"): + params_to_fetch = frozenset(iter_params(current_submodule)) + if self.trace_complete: + # go through the parameters we need for the current module and pop them + # off the fetch queue so that they aren't prefetched later. + # if params have already been popped off the fetch queue by earlier + # prefetches we won't look for them here + discarded_from_prefetch_queue = set() + params_not_already_fetched = set( + filter( + lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. + __step_id, + params_to_fetch)) + while self.__param_queue and len(discarded_from_prefetch_queue) < len( + params_not_already_fetched): + param_in_trace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + discarded_from_prefetch_queue.add(param_in_trace.param) + if discarded_from_prefetch_queue != params_not_already_fetched: + raise RuntimeError( + f"tracing error at step {self.__step_id}: " + f"expected the next {len(params_not_already_fetched)} parameters in the " + f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " + f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." + ) + # kick off all gather for params in the immediately required submodule + for param in params_to_fetch: + debug_rank0(f"-fetch: {param.ds_summary()}") + with torch.cuda.nvtx.range("fetch"): + self.__all_gather_params(params_to_fetch) + + # wait for parameters in the immediately needed submodule to become available + for param in iter_params(current_submodule): + param.ds_active_sub_modules.add(current_submodule.id) + debug_rank0(f"-wait: {param.ds_summary()}") + if param in self.__inflight_param_registry: + with torch.cuda.stream(self.__allgather_stream): + self.__inflight_param_registry.pop(param).wait() + assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() + torch.cuda.current_stream().wait_stream(self.__allgather_stream) + + with torch.cuda.nvtx.range("prefetch_kickoff"): + # kick off all gather for params in the next few submodules (prefetch) + max_params_to_prefetch = min( + self.__max_n_available_params - self.__n_available_params, + self.__prefetch_bucket_sz) + params_to_prefetch = set() + numel_prefetching = 0 + while self.__param_queue and numel_prefetching < max_params_to_prefetch: + param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + if param_in_trace.param not in params_to_prefetch: + params_to_prefetch.add(param_in_trace.param) + numel_prefetching += param_in_trace.param.ds_numel + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") + self.__all_gather_params(params_to_prefetch) + + if self.__prefetch_nvme: + self.__prefetch_nvme_param_partitions() + + self.__step_id += 1 + + @instrument_w_nvtx + @torch.no_grad() + def release_sub_module(self, submodule: Module) -> None: + """release the parameters of a sub module, assuming they meet conditions to + be released.""" + params_to_release = (self.__params_to_release(submodule, + self.__step_id) + if self.trace_complete else set( + p.ds_id for p in iter_params(submodule))) + + for param in iter_params(submodule): + param.ds_active_sub_modules.discard(submodule.id) + if param.ds_id in params_to_release and not param.is_external_param: + self.__release_param(param) + + @instrument_w_nvtx + @torch.no_grad() + def release_and_reset_all(self) -> None: + """release all module parameters""" + for param in map(lambda p: p.param, self.__param_order): + if param in self.__inflight_param_registry: + raise RuntimeError(f"param {param.ds_summary()} still in flight") + + # TODO. make this throw if if there are still active submodules. currently + # there's a hook execution issue + param.ds_active_sub_modules.clear() + self.__release_param(param) + + for param_in_trace in self.__param_order: + if param_in_trace.param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError( + f"{param_in_trace.param.ds_summary()} expected to be released") + + @instrument_w_nvtx + def __all_gather_params(self, params: Set[Parameter]) -> None: + """for each partitioned parameter, kick off an async allgather and store + the work handle for the in flight parameters.""" + partitioned_params = [] + for param in params: + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + partitioned_params.append(param) + self.__n_available_params += param.ds_numel + + if partitioned_params: + with torch.cuda.stream(self.__allgather_stream): + # only allow a certain number of fetch events to be queued at once + while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ + 0].query(): + self.__ongoing_fetch_events.popleft() + if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events: + self.__ongoing_fetch_events.popleft().synchronize() + handle = partitioned_params[0].all_gather_coalesced(partitioned_params) + event = Event() + event.record() + self.__ongoing_fetch_events.append(event) + + for param in partitioned_params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() + self.__inflight_param_registry[param] = handle + + @instrument_w_nvtx + def __release_param(self, param: Parameter) -> None: + if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: + debug_rank0(f"-release: {param.ds_summary()}") + param.partition() + self.__n_available_params -= param.ds_numel + + @instrument_w_nvtx + @functools.lru_cache(maxsize=None) + def __params_to_release(self, + submodule_to_release: Module, + step_id: int) -> Set[int]: + if not self.trace_complete: + raise RuntimeError("expected trace to be complete") + + params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) + if not p.ds_persist) + + # examine all modules within `max_reuse_dist_in_numel` of the current step, + # if we see any of the candidate parameters to be released reoccur while + # doing this, remove them from the set of parameters to release. + params_traversed = 0 + for module in self.__submodule_order[step_id:]: + if params_traversed > self.__max_reuse_dist_in_numel: + break + for param in iter_params(module): + params_to_release.discard(param.ds_id) + params_traversed += param.ds_numel - def finish_tracing(self, print_trace=False): - self.prefetch_coordinator.trace_completed = True + return params_to_release - if print_trace: - self.prefetch_coordinator.print_trace() + @instrument_w_nvtx + def __prefetch_nvme_param_partitions(self) -> None: + """swap in parameter partitions from nvme for those parameters that will be used + after the ones that are already being prefetched into full parameters + """ + numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) - #swap in parameter partitions from nvme for those parameters that will be used - # after the ones that are already being prefetched into full parameters - def _prefetch_nvme_param_partitions(self, sub_module, params_in_flight): - numel_in_flight = sum([param.ds_tensor.ds_numel for param in params_in_flight]) - upcoming_param_list = self.prefetch_coordinator.get_params_to_prefetch( - sub_module, - numel=2 * numel_in_flight) + numel_considered = 0 swap_in_params = [] - for param in upcoming_param_list: - if len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers(): + for _, param in self.__param_queue: + if param.nvme_swapper is None: + raise RuntimeError( + f"expected param {param.ds_summary()} to have nvme swapper") + if (numel_considered > 2 * numel_in_flight or len(swap_in_params) >= + param.nvme_swapper.available_swap_in_buffers()): break if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: swap_in_params.append(param) - if len(swap_in_params) > 0: + if swap_in_params: swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) - # Pre fetches the parameters for sub_modules that comes after - # the current sub_module. This call is asynchronous - def prefetch_next_sub_modules(self, sub_module, numel=5000000, nvme=False): - - params_to_prefetch = [] - if not self.prefetch_coordinator.trace_completed: - return params_to_prefetch - - # prefetch if there is no current prefetching in flight - if not self.in_flight_handles and self.total_available_parameter_numel < self.max_available_parameters_in_numel: - params_to_prefetch = self.prefetch_coordinator.get_params_to_prefetch( - sub_module, - numel=numel) - - self._all_gather(params_to_prefetch, async_op=True) - for param in params_to_prefetch: - param.ds_status = ZeroParamStatus.INFLIGHT - - # keeping track of number of elements consumed by available parameters - self._increment_available_parameter_numel(param.ds_numel) - - if nvme: - self._prefetch_nvme_param_partitions(sub_module, params_to_prefetch) - - self._print_prefetch_elements_info(sub_module, params_to_prefetch) - print_rank_0( - f"{'--' * self.hierarchy}--PreFetching parameters {[param.ds_id for param in params_to_prefetch]} and available {self.total_available_parameter_numel}, max limit {self.max_available_parameters_in_numel}", - force=False) - - def _print_prefetch_elements_info(self, sub_module, params_to_prefetch): - sub_module_numel = 0.0 - for name, param in sub_module.named_parameters(recurse=False): - sub_module_numel += param.ds_numel - numel_being_prefetched = 0 - for param in params_to_prefetch: - numel_being_prefetched = param.ds_numel - print_rank_0( - f"{'--' * self.hierarchy}--PreFetching {numel_being_prefetched} numels and number of numel in the next sub module is {sub_module_numel}", - force=False) - - def increment_step(self, sub_module): - self.prefetch_coordinator.increment_step(sub_module) - - def reset_step(self): - self.prefetch_coordinator.reset_step() - - '''----------------------------------------------------------------------''' - - # Fetches the parameters in the sub_module - # This call is blocking - def fetch_sub_module(self, sub_module): - partitioned_params = [] - params_in_flight = False - print_rank_0( - f"{'--' * self.hierarchy}Fetching params in module {debug_module2name_class(sub_module)}" - ) - params_to_fetch = [ - param for _, - param in sub_module.named_parameters(recurse=False) - ] - # print([n for n,p in sub_module.named_parameters(recurse=False)]) - - if hasattr(sub_module, 'ds_external_parameters'): - print_rank_0( - f"{'--' * self.hierarchy}--Fetching external parameters {sub_module.ds_external_parameters()}" - ) - params_to_fetch += [ - param for _, - param in sub_module.ds_external_parameters() - ] - # for _, param in sub_module.named_parameters(recurse=False): - for param in params_to_fetch: - param.ds_active_sub_modules += 1 - print_rank_0( - f"{'--' * self.hierarchy}--Fetching parameters {debug_param2name_id_shape(param)} with active sub modules {param.ds_active_sub_modules}" - ) - - if param.ds_status == ZeroParamStatus.AVAILABLE: - print_rank_0( - f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is already available" - ) - - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - print_rank_0( - f"{'--' * self.hierarchy}--Parameter {debug_param2name_id(param)} is being fetched" - ) - partitioned_params.append(param) - - # keeping track of number of elements consumed by available parameters - self._increment_available_parameter_numel(param.ds_numel) - print_rank_0(f"Incrementing with parameter id {param.ds_id}") - - if param.ds_status == ZeroParamStatus.INFLIGHT: - params_in_flight = True - print_rank_0( - f"{'--' * self.hierarchy}--Parameters {debug_param2name_id(param)} is already in flight (prefetched)" - ) - self.hierarchy += 1 - - # parameters are partitioned and need to be allgathered - self._all_gather(partitioned_params, async_op=True) - - # parameters are inflight and communication needs to be completed - if partitioned_params or params_in_flight: - self._synchronize_communication() - - for _, param in sub_module.named_parameters(recurse=False): - param.ds_status = ZeroParamStatus.AVAILABLE - print_rank_0( - f"Param {debug_param2name_id_shape_device(param)} norm={param.norm()}", - force=False) - #print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}") - - def release_sub_module(self, sub_module): - self.hierarchy -= 1 - print_rank_0( - f"{'--' * self.hierarchy}Releasing params in module {debug_module2name_class(sub_module)}" - ) - params_to_release = [ - param for _, - param in sub_module.named_parameters(recurse=False) - ] - - if hasattr(sub_module, 'ds_external_parameters'): - #print_rank_0(f"Releasing external parameters {sub_module.ds_external_parameters()}") - params_to_release += [ - param for _, - param in sub_module.ds_external_parameters() - ] - - # for _, param in sub_module.named_parameters(recurse=False): - for param in params_to_release: - param.ds_active_sub_modules -= 1 - if not param.ds_active_sub_modules and not self._keep_for_later( - sub_module) and not param.ds_persist: - print_rank_0( - f"{'--' * self.hierarchy}--Releasing parameter {debug_param2name_id_numel(param)} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}", - force=False) - - # Keeping track of number of elements that are consumed by available parameters - self._decrement_available_parameter_numel(param.ds_numel) - see_memory_usage( - f"Before releasing param {debug_param2name_id_numel(param)}", - force=False) - param.partition(hierarchy=self.hierarchy) - see_memory_usage( - f"After releasing param {debug_param2name_id_numel(param)}", - force=False) - - param.ds_status = ZeroParamStatus.NOT_AVAILABLE - else: - - print_rank_0( - f"{'--' * self.hierarchy}--Did not release param {debug_param2name_id_numel(param)} with active sub modules {param.ds_active_sub_modules}, keep for later={self._keep_for_later(sub_module)} and persistence={param.ds_persist}", - force=False) - - def release_and_reset_parameter(self, param): - param.ds_active_sub_modules = 0 - if param.ds_status == ZeroParamStatus.AVAILABLE: - print_rank_0( - f"Releasing unpartitioned param {debug_param2name_id_numel(param)} active sub-modules {param.ds_active_sub_modules} and persistence {param.ds_persist}" - ) - self._decrement_available_parameter_numel(param.ds_numel) - param.partition() - - def _keep_for_later(self, sub_module): - if not self.prefetch_coordinator.trace_completed: - return False - if self.max_reuse_distance_in_numel == 0: - return False - reuse_distance_in_numel = self.prefetch_coordinator.get_reuse_distance_in_numel( - sub_module) - #print_rank_0(f"Reuse distance and numel for sub_module id {sub_module.id} is {reuse_distance_in_numel}") - return reuse_distance_in_numel < self.max_reuse_distance_in_numel - - def _all_gather(self, partitioned_params, async_op=False): - with torch.cuda.stream(self.comm_stream): - handles = partitioned_params[0].all_gather( - param_list=partitioned_params, - async_op=async_op, - hierarchy=self.hierarchy) if partitioned_params else None - - if handles is not None: - self.in_flight_handles.extend(handles) - self.params_in_flight.extend(partitioned_params) - - def _synchronize_communication(self, synchronize_streams=True): - assert len(self.params_in_flight) == len(self.in_flight_handles) - for handle, param in zip(self.in_flight_handles, self.params_in_flight): - if handle is not None: - with torch.cuda.stream(self.comm_stream): - handle.wait() - param.ds_status = ZeroParamStatus.AVAILABLE - self.comm_stream.synchronize() - torch.cuda.synchronize() if synchronize_streams else None - self.in_flight_handles = [] - self.params_in_flight = [] - - class PreBackwardFunction(torch.autograd.Function): @staticmethod def forward(ctx, module, pre_backward_function, outputs): @@ -589,9 +550,6 @@ def backward(ctx, *args): return (None, None) + args -INITIAL_MICRO_STEP_ID = -1 - - class FP16_DeepSpeedZeroOptimizer_Stage3(object): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -635,6 +593,7 @@ def __init__(self, see_memory_usage("Stage 3 initialize beginning", force=False) if dist.get_rank() == 0: + logger.info(f"initialized {__class__.__name__} with args: {locals()}") logger.info(f"Reduce bucket size {reduce_bucket_size}") logger.info(f"Allgather bucket size {prefetch_bucket_size}") # The fused optimizer does all the work. We need this layer for two reason: @@ -668,14 +627,15 @@ def __init__(self, self.module = module self.elastic_checkpoint = elastic_checkpoint - self.overlap_comm = overlap_comm # Replace ._parameters with a new class to enable auto-registration of # external parameters _inject_parameters(module, ZeROOrderedDict) - if self.overlap_comm: - self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda() + self.gpu_sum: Tensor = torch.zeros(1, + dtype=torch.float, + device=torch.cuda.current_device(), + requires_grad=False) ###################### offload optimizer setup ################################## self.optimizer_swapper = None @@ -685,6 +645,9 @@ def __init__(self, self.offload_optimizer_pin_memory = False self.offload_optimizer_fast_init = False if offload_optimizer_config is not None: + if not contiguous_gradients: + raise ValueError( + "optimizer offload only available with contiguous gradients enabled") self.offload_optimizer = True self.offload_optimizer_pin_memory = offload_optimizer_config[ OFFLOAD_OPTIMIZER_PIN_MEMORY] @@ -715,18 +678,26 @@ def __init__(self, self.device = torch.cuda.current_device( ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE + ### streams used for overlapping computation with communication + self.__allgather_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + self.__reduce_and_partition_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + ############################################################################ see_memory_usage("Before Partitioned Parameter Coordinator", force=False) - - fetch_stream = torch.cuda.Stream() if self.overlap_comm else None self.param_coordinator = PartitionedParameterCoordinator( - comm_stream=fetch_stream, + prefetch_bucket_sz=int(prefetch_bucket_size), max_reuse_distance_in_numel=int(max_reuse_distance), - max_available_parameters_in_numel=int(max_live_parameters)) - + max_available_parameters_in_numel=int(max_live_parameters), + allgather_stream=self.__allgather_stream, + prefetch_nvme=self.params_in_nvme_and_cpu, + ) see_memory_usage("After Partitioned Parameter Coordinator", force=False) + self.__n_caching_allocator_flushes = 0 + #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) #-------------Stage 3 Setup-------------------# # parameters smaller than the threshold will be collectively gathered at the @@ -745,8 +716,6 @@ def __init__(self, self.timers = timers - self.reduce_scatter = reduce_scatter - self.dp_process_group = dp_process_group self.partition_count = dist.get_world_size(group=self.dp_process_group) @@ -764,12 +733,7 @@ def __init__(self, self.gradient_predivide_factor = gradient_predivide_factor self.postscale_gradients = postscale_gradients self.gradient_accumulation_steps = gradient_accumulation_steps - self.micro_step_id = INITIAL_MICRO_STEP_ID - - if self.reduce_scatter: - assert not self.allreduce_always_fp32, "allreduce_always_fp32 is not yet supported with ZeRO-2 with reduce scatter enabled" - assert self.gradient_predivide_factor == 1.0, "gradient_predivide_factor != 1.0 is not yet supported with ZeRO-2 with reduce scatter enabled" - assert self.postscale_gradients, "pre-scale gradients is not yet supported with ZeRO-2 with reduce scatter enabled" + self.micro_step_id = 0 # Holds the mode parameter # The param.data may not hold any meaningful data @@ -832,12 +796,51 @@ def __init__(self, self.reduce_bucket_size = int(reduce_bucket_size) - self.reduction_event = torch.cuda.Event(enable_timing=False, blocking=False) + # IPG + if contiguous_gradients: + self.__ipg_bucket_flat_buffer: Tensor = torch.empty( + int(reduce_bucket_size), + dtype=self.dtype, + device=torch.cuda.current_device()) + + self.__param_id_to_grad_partition: Dict[int, Tensor] = {} + + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) - self.reduction_stream = torch.cuda.Stream( - ) if self.overlap_comm else torch.cuda.current_stream() - self.callback_queued = False - self.copy_grad_stream = torch.cuda.Stream() + grad_partitions_flat_buffer: Tensor = torch.zeros( + sum(p.ds_tensor.ds_numel for p in all_params), + dtype=self.dtype, + device=self.device, + pin_memory=self.offload_optimizer_pin_memory) + + offset = 0 + for param in all_params: + self.__param_id_to_grad_partition[ + param.ds_id] = grad_partitions_flat_buffer.narrow( + 0, + offset, + param.ds_tensor.numel()) + offset += param.ds_tensor.numel() + + self.__params_in_ipg_bucket: List[Parameter] = [] + self.is_gradient_accumulation_boundary: bool = True + + self.__param_reduce_events: Deque[Event] = collections.deque() + self.__max_param_reduce_events: int = 2 + + # map each parameter to its group index and its offset within that group's + # flattened buffer + self.__param_id_to_param_group_and_offset_within_group_buffer = {} + for group_idx, group in enumerate(self.fp16_groups): + offset_within_group = 0 + for param in group: + self.__param_id_to_param_group_and_offset_within_group_buffer[ + param.ds_id] = (group_idx, + offset_within_group) + offset_within_group += param.ds_tensor.ds_numel + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") self.param_dict = {} @@ -848,7 +851,6 @@ def __init__(self, self.extra_large_param_to_reduce = None self.grads_in_ipg_bucket = [] self.params_in_ipg_bucket = [] - self.elements_in_ipg_bucket = 0 self.params_already_reduced = [] self.is_gradient_accumulation_boundary = True self._release_ipg_buffers() @@ -930,6 +932,60 @@ def __init__(self, if dist.get_rank(group=self.dp_process_group) == 0: see_memory_usage(f"After initializing ZeRO optimizer", force=False) + persistent_tensors: Set[Tensor] = set() + for param in self.module.parameters(recurse=True): + param.partition() + persistent_tensors.add(param.ds_tensor) + + FP16_DeepSpeedZeroOptimizer_Stage3.defragment(persistent_tensors) + + if dist.get_rank(group=self.dp_process_group) == 0: + see_memory_usage(f"After defragmenting", force=True) + + @staticmethod + def defragment(tensors: Set[Tensor]): + cuda_tensors_by_device_and_dtype: Dict[tuple, + Set[Tensor]] = collections.defaultdict( + set) + for tensor in filter(lambda t: t.is_cuda, tensors): + cuda_tensors_by_device_and_dtype[(tensor.device, tensor.dtype)].add(tensor) + + cpu_buffer_and_orig_device_to_tensor_infos: Dict[ + Tuple[Tensor, + torch.device], + List[Tuple[Tensor, + int, + int]]] = collections.defaultdict(list) + for (orig_device, dtype), tensorset in cuda_tensors_by_device_and_dtype.items(): + cpu_buffer = torch.empty(sum(p.numel() for p in tensorset), + dtype=dtype, + device="cpu") + + offset = 0 + for tensor in tensorset: + tensor_numel = tensor.numel() + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + + # record some data so we can restore the device tensor later + cpu_buffer_and_orig_device_to_tensor_infos[(cpu_buffer, + orig_device)].append( + (tensor, + offset, + tensor_numel)) + + offset += tensor_numel + + gc.collect() + torch.cuda.empty_cache() + + # restore device tensors + for (cpu_buffer, orig_device), tensor_offsets in cpu_buffer_and_orig_device_to_tensor_infos.items(): + device_buffer = cpu_buffer.to(orig_device) + for tensor, offset, tensor_numel in tensor_offsets: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): nvme_swap_folder = os.path.join( offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH], @@ -951,6 +1007,10 @@ def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): dtype=torch.float32, timers=self.timers) + @property + def elements_in_ipg_bucket(self): + return sum(p.ds_numel for p in self.__params_in_ipg_bucket) + def _create_fp16_partitions(self): dist.barrier() partition_id = dist.get_rank(group=self.dp_process_group) @@ -1394,21 +1454,26 @@ def _create_fp16_sub_groups(self, params_group): def setup_zero_stage3_hooks(self): self.hierarchy = 0 - self._register_hooks_recursively(self.module) - #reset step at the beginning of forward - def _pre_forward_hook(module, *args): - self.param_coordinator.reset_step() + @instrument_w_nvtx + def _pre_forward_hook(_, *args) -> None: + """makes sure all ranks start .forward() at the same time so that we + don't accidentally mix allgathers for different steps/sets of parameters + """ + torch.cuda.synchronize() + dist.barrier() #reset step if in inference mode + @instrument_w_nvtx def _end_of_forward_hook(module, *args): if not torch._C.is_grad_enabled(): self.param_coordinator.reset_step() #likely one of them should be enough but just to be safe - self.module.register_forward_hook(_end_of_forward_hook) self.module.register_forward_pre_hook(_pre_forward_hook) + self._register_hooks_recursively(self.module) + self.module.register_forward_hook(_end_of_forward_hook) # Add top module to stack trace global FWD_MODULE_STACK @@ -1440,9 +1505,11 @@ def _register_hooks_recursively(self, module, count=[0]): count[0] = count[0] + 1 self._register_hooks_recursively(child, count=count) + @instrument_w_nvtx def _pre_forward_module_hook(module, *args): self.pre_sub_module_forward_function(module) + @instrument_w_nvtx def _post_forward_module_hook(module, input, output): global FWD_MODULE_STACK FWD_MODULE_STACK.pop() @@ -1463,7 +1530,7 @@ def _post_forward_module_hook(module, input, output): for item in filter(lambda item: is_zero_param(item), output): if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): - item.ds_active_sub_modules += 1 + item.is_external_param = True module_to_register = FWD_MODULE_STACK[-1] print_rank_0( f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', @@ -1483,6 +1550,7 @@ def _post_forward_module_hook(module, input, output): self.post_sub_module_forward_function(module) def _pre_backward_module_hook(module, inputs, output): + @instrument_w_nvtx def _run_before_backward_function(sub_module): # some models (e.g. Albert) may run multiple forwards on the same layer in a loop # before doing backwards, so each backward will need a pre-fetch - using reference @@ -1524,6 +1592,7 @@ def _run_before_forward_function(input): def _post_backward_module_hook(module, inputs): module.ds_grads_remaining = 0 + @instrument_w_nvtx def _run_after_backward_function(sub_module): if sub_module.ds_grads_remaining == 0: self.post_sub_module_backward_function(sub_module) @@ -1544,6 +1613,7 @@ def _run_after_backward_function(sub_module): # post backward hook module.register_forward_pre_hook(_post_backward_module_hook) + @torch.no_grad() def pre_sub_module_forward_function(self, sub_module): see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False) @@ -1551,23 +1621,15 @@ def pre_sub_module_forward_function(self, sub_module): global FWD_MODULE_STACK FWD_MODULE_STACK.append(sub_module) - self.param_coordinator.record_trace(sub_module) + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) self.param_coordinator.fetch_sub_module(sub_module) see_memory_usage( f"Before sub module function {sub_module.__class__.__name__} after fetch", force=False) - self.param_coordinator.prefetch_next_sub_modules( - sub_module, - numel=self.prefetch_elements, - nvme=self.params_in_nvme_and_cpu) - see_memory_usage( - f"Before sub module function {sub_module.__class__.__name__} after prefetch", - force=False) - - self.param_coordinator.increment_step(sub_module) - + @torch.no_grad() def post_sub_module_forward_function(self, sub_module): see_memory_usage( f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", @@ -1579,16 +1641,13 @@ def post_sub_module_forward_function(self, sub_module): f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", force=False) + @torch.no_grad() def pre_sub_module_backward_function(self, sub_module): - self.param_coordinator.record_trace(sub_module) - + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) self.param_coordinator.fetch_sub_module(sub_module) - self.param_coordinator.prefetch_next_sub_modules(sub_module, - numel=self.prefetch_elements) - - self.param_coordinator.increment_step(sub_module) - + @torch.no_grad() def post_sub_module_backward_function(self, sub_module): see_memory_usage( f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", @@ -1755,16 +1814,13 @@ def initialize_gradient_partitioning_data_structures(self): param_group, partition_id) + @instrument_w_nvtx def independent_gradient_partition_epilogue(self): self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) - self.reduce_ipg_grads() + self.__reduce_and_partition_ipg_grads() self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) - if self.overlap_comm: - self.reduction_stream.synchronize() - - with torch.cuda.stream(self.reduction_stream): - self.partition_previous_reduced_grads() + self.__reduce_and_partition_stream.synchronize() # if dist.get_rank() == 0: # logger.info("Params already reduced %s", self.params_already_reduced) @@ -1776,10 +1832,8 @@ def independent_gradient_partition_epilogue(self): if not self.offload_optimizer: for i, sub_group in enumerate(self.fp16_groups): self.averaged_gradients[i] = [ - torch.zeros_like(param.ds_tensor) if param.grad is None else - param.grad.data.narrow(0, - 0, - param.ds_tensor.numel()) + self.__param_id_to_grad_partition[param.ds_id] + if param.requires_grad else torch.zeros_like(param.ds_tensor) for param in sub_group ] # self.averaged_gradients[i] = self.get_flat_partition( @@ -1788,82 +1842,15 @@ def independent_gradient_partition_epilogue(self): # self.fp32_partitioned_groups_flat[i].numel(), # return_tensor_list=True) - self._release_ipg_buffers() - - see_memory_usage(f"End ipg_epilogue", force=False) - - # resets all partition to no reduced - # sets remaining grads to the total number of grads in each partition - # set is grad computed to false for all grads in partition - def reset_partition_gradient_structures(self): - total_partitions = dist.get_world_size(group=self.dp_process_group) - for i, _ in enumerate(self.fp16_groups): - for partition_id in range(total_partitions): - self.is_partition_reduced[i][partition_id] = False - self.remaining_grads_in_partition[i][ - partition_id] = self.total_grads_in_partition[i][partition_id] - - for param_id in self.is_grad_computed[i][partition_id]: - self.is_grad_computed[i][partition_id][param_id] = False - - def initialize_gradient_partition(self, i, param_group, partition_id): - def set_key_value_list(dictionary, key, value): - if key in dictionary: - dictionary[key].append(value) - else: - dictionary[key] = [value] - - def increment_value(dictionary, key): - if key in dictionary: - dictionary[key] += 1 - else: - dictionary[key] = 1 - - partition_size = self.partition_size[i] - - start_index = partition_size * partition_id - end_index = partition_size * (partition_id + 1) - - current_index = 0 - first_offset = 0 - - for param in param_group: - - param_size = param.numel() - param_id = self.get_param_id(param) - - if (current_index >= start_index and current_index < end_index): - set_key_value_list(self.param_to_partition_ids[i], - param_id, - partition_id) - increment_value(self.total_grads_in_partition[i], partition_id) - - self.is_grad_computed[i][partition_id][param_id] = False - - self.grad_partition_insertion_offset[i][partition_id][ - param_id] = current_index - start_index - self.grad_start_offset[i][partition_id][param_id] = 0 - - elif start_index > current_index and start_index < (current_index + - param_size): - assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" - first_offset = start_index - current_index - - set_key_value_list(self.param_to_partition_ids[i], - param_id, - partition_id) - increment_value(self.total_grads_in_partition[i], partition_id) - - self.is_grad_computed[i][partition_id][param_id] = False - - self.grad_partition_insertion_offset[i][partition_id][param_id] = 0 - self.grad_start_offset[i][partition_id][param_id] = first_offset - - current_index = current_index + param_size + # this method gets called after every backward. need to increment + # here because if it gets incremented in backward() the micro step + # id will be off by one when we do the reduce and partition at the. + # start of this method. + # TODO. make this less error prone + self.micro_step_id += 1 def overlapping_partition_gradients_reduce_epilogue(self): self.independent_gradient_partition_epilogue() - self.zero_grad() def create_reduce_and_remove_grad_hooks(self): print_rank_0(f'[Begin] Create gradient reduction hooks') @@ -1881,6 +1868,7 @@ def wrapper(param, i): param_tmp = param.expand_as(param) grad_acc = param_tmp.grad_fn.next_functions[0][0] + @instrument_w_nvtx def reduce_partition_and_remove_grads(*notneeded): self.reduce_ready_partitions_and_remove_grads(param, i) @@ -1918,13 +1906,7 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", param.ds_numel) - self.reduce_ipg_grads() - - if self.contiguous_gradients and self.overlap_comm: - # Swap ipg_index between 0 and 1 - self.ipg_index = 1 - self.ipg_index - self.report_ipg_memory_usage("In ipg_remove_grads after reduce_ipg_grads", - param.ds_numel) + self.__reduce_and_partition_ipg_grads() param_id = self.get_param_id(param) assert self.params_already_reduced[param_id] == False, \ @@ -1932,68 +1914,93 @@ def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): Gradient computed twice for this partition. \ Multiple gradient reduction is currently not supported" - # keeping the gradients contiguous to prevent memory fragmentation, and avoid flattening - if param.ds_numel > self.reduce_bucket_size: - self.extra_large_param_to_reduce = param + self.__add_grad_to_ipg_bucket(param) - elif self.contiguous_gradients: - #print_rank_0("before new grad tensor move") - new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow( - 0, - self.elements_in_ipg_bucket, - param.ds_numel) - #print_rank_0("after new grad tensor move") - new_grad_tensor.copy_(param.grad.view(-1)) - param.grad.data = new_grad_tensor.data.view_as(param.grad) + @instrument_w_nvtx + @torch.no_grad() + def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: + self.__reduce_and_partition_stream.wait_stream(torch.cuda.default_stream()) - self.elements_in_ipg_bucket += param.ds_numel - self.grads_in_ipg_bucket.append(param.grad) - self.params_in_ipg_bucket.append((i, param, param_id)) - self.report_ipg_memory_usage("End ipg_remove_grads", 0) + if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel( + ) < self.reduce_bucket_size: + # move the gradient to a contiguous buffer + with torch.cuda.stream(self.__reduce_and_partition_stream): + # move the parameter's gradient to the contiguous flat buffer + new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow( + 0, + self.elements_in_ipg_bucket, + param.grad.numel()).view_as(param.grad) + new_grad_tensor.copy_(param.grad, non_blocking=True) + param.grad.record_stream(torch.cuda.current_stream()) + param.grad.data = new_grad_tensor + + self.__params_in_ipg_bucket.append(param) + + @instrument_w_nvtx + @torch.no_grad() + def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: + if not self.__params_in_ipg_bucket: + return - def gradient_reduction_w_predivide(self, tensor): - dp_world_size = dist.get_world_size(group=self.dp_process_group) + for param in self.__params_in_ipg_bucket: + if param.grad.numel() != param.ds_numel: + raise RuntimeError( + f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " + f"gradients whose size is not same as the params") - tensor_to_allreduce = tensor + self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id) - if self.allreduce_always_fp32: - tensor_to_allreduce = tensor.float() + assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len( + self.__params_in_ipg_bucket) - if self.postscale_gradients: - if self.gradient_predivide_factor != 1.0: - tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor) + while self.__param_reduce_events and self.__param_reduce_events[0].query(): + self.__param_reduce_events.popleft() + if len(self.__param_reduce_events) > self.__max_param_reduce_events: + self.__param_reduce_events.popleft().synchronize() - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + with torch.cuda.stream(self.__reduce_and_partition_stream): + if safe_mode: + assert_ints_same_as_other_ranks( + [p.ds_id for p in self.__params_in_ipg_bucket]) - if self.gradient_predivide_factor != dp_world_size: - tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size) - else: - tensor_to_allreduce.div_(dp_world_size) - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) + for partition in grad_partitions: + self.gpu_sum.add_(partition.float().sum()) + self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) - if self.allreduce_always_fp32 and tensor is not tensor_to_allreduce: - tensor.copy_(tensor_to_allreduce) + self.__params_in_ipg_bucket.clear() - return tensor + event = Event() + event.record() + self.__param_reduce_events.append(event) - def average_tensor(self, tensors, params_to_reduce): - with torch.cuda.stream(self.reduction_stream): - if not self.reduce_scatter: - for tensor in tensors: - self.gradient_reduction_w_predivide(tensor) - return - - for tensor in tensors: - tensor.div_(dist.get_world_size(group=self.dp_process_group)) - - # reduction resulting with each rank only holding the gradient partition it owns - # This could either be a reduce scatter or a reduce op depending on how - # parameters are partitionied. The method is implemented by the - # DeepSpeed param extensions to the pytorch parameter, so its up to - # the extension to define what happens here - params_to_reduce[0].reduce_gradients_at_owner( - param_list=params_to_reduce, - hierarchy=self.param_coordinator.hierarchy) + @instrument_w_nvtx + def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: + """average gradients and scatter partitions across ranks""" + dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce) + + full_grads_for_rank = [p.grad for p in params_to_reduce] + if self.allreduce_always_fp32: + full_grads_for_rank = [g.float() for g in full_grads_for_rank] + + if self.postscale_gradients and self.gradient_predivide_factor != 1.0: + full_grads_for_rank = [ + g.div(self.gradient_predivide_factor) for g in full_grads_for_rank + ] + + grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, + self.dp_process_group) + + if self.postscale_gradients and self.gradient_predivide_factor != dist.get_world_size( + self.dp_process_group): + grad_partitions_for_rank = [ + g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank + ] + + if self.allreduce_always_fp32: + grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank] + + return grad_partitions_for_rank def set_grad_positions(self): for i, group in enumerate(self.fp16_groups): @@ -2010,22 +2017,21 @@ def set_grad_positions(self): #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") current_offset += num_elements - def async_accumulate_grad_in_cpu_via_gpu(self, param, acc_grad_cpu_partition): + def async_accumulate_grad_in_cpu_via_gpu(self, grad, acc_grad_cpu_partition): # copy to a preexisiting buffer to avoid memory allocation penalty dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow( 0, 0, - param.ds_tensor.ds_numel) + grad.numel()) if self.micro_step_id > 0: dest_buffer.copy_(acc_grad_cpu_partition.view(-1), non_blocking=True) - param.grad.data.view(-1).add_(dest_buffer) + grad.data.view(-1).add_(dest_buffer) # at the boundary we will send 32bit directly if not self.is_gradient_accumulation_boundary: - acc_grad_cpu_partition.data.copy_(param.grad.data.view(-1), - non_blocking=True) + acc_grad_cpu_partition.data.copy_(grad.data.view(-1), non_blocking=True) def _constant_buffered_norm2(self, input, buffer_size=250000000): norm = None @@ -2086,143 +2092,78 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): return total_norm - def partition_previous_reduced_grads(self): - if not self.previous_reduced_grads: - return - - if self.offload_optimizer: - allocate_grads_in_partition = self.grads_in_partition is None\ - and self.gradient_accumulation_steps > 1 - else: - allocate_grads_in_partition = self.grads_in_partition is None - - if allocate_grads_in_partition: - self.grads_in_partition = [] - - for i, group in enumerate(self.fp16_groups): - total_size = 0 - for param_in_partition in group: - total_size += param_in_partition.ds_tensor.ds_numel - - see_memory_usage( - f"group {i} before creating {total_size} reduced gradients into partition", - force=False) - if self.offload_param_pin_memory: - self.grads_in_partition.append( - torch.zeros(int(total_size), - dtype=self.dtype, - device=self.device).pin_memory()) - else: - self.grads_in_partition.append( - torch.zeros(int(total_size), - dtype=self.dtype, - device=self.device)) - see_memory_usage( - f"group {i} after creating {total_size} reduced gradients into partition", - force=False) - - if self.offload_optimizer: - offload_fp32_gradients = {} - offload_fp32_offsets = {} - - with torch.cuda.stream(self.copy_grad_stream): - self.reduction_stream.synchronize() - for param in self.previous_reduced_grads: - - [i, - dest_offset, - num_elements] = self.grad_position[self.get_param_id(param)] - - if self.offload_optimizer: - param.partition_gradients( - partition_buffers=self.temp_grad_gpu_buffer) - #with torch.cuda.stream(self.copy_grad_stream): - # self.reduction_stream.synchronize() - - if self.gradient_accumulation_steps > 1: - # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer - fp16_grad_tensor = self.grads_in_partition[i].narrow( - 0, - dest_offset, - num_elements) - self.async_accumulate_grad_in_cpu_via_gpu( - param, - fp16_grad_tensor) - - if self.is_gradient_accumulation_boundary: - - self.set_norm_for_param_grad_in_gpu(param) - - self.update_overflow_tracker_for_param_grad(param) - - if self._swappable_optimizer_subgroup(i): - if not i in offload_fp32_gradients.keys(): - offload_fp32_gradients[i] = [] - offload_fp32_offsets[i] = [] - - offload_fp32_gradients[i].append(param.grad.view(-1).float()) - param.grad = None - offload_fp32_offsets[i].append(dest_offset) - else: - fp32_grad_tensor = self.fp32_partitioned_groups_flat[ - i].grad.narrow(0, - dest_offset, - num_elements) - - self.async_inplace_copy_grad_to_fp32_buffer_from_gpu( - param, - fp32_grad_tensor) - else: - # The allreduce buffer will be rewritten. Copy the gradients in partition to a new buffer - fp16_grad_tensor = self.grads_in_partition[i].narrow( - 0, - dest_offset, - num_elements) - param.partition_gradients( - partition_buffers=fp16_grad_tensor, - accumulate=True if self.micro_step_id > 0 else False) - - if self.offload_optimizer and self.swap_optimizer: - for i in offload_fp32_gradients.keys(): - self.optimizer_swapper.swap_out_gradients( - parameter=self.fp32_partitioned_groups_flat[i], - gradient_offsets=offload_fp32_offsets[i], - gradient_tensors=offload_fp32_gradients[i]) - - self.previous_reduced_grads = [] - def reduce_ipg_grads(self, extra_param=None): - if self.overlap_comm: - self.reduction_stream.synchronize() - - with torch.cuda.stream(self.reduction_stream): - self.partition_previous_reduced_grads() - - params_to_reduce = [param for i, param, param_id in self.params_in_ipg_bucket] - #print(f"Params in ipg bucket {self.params_in_ipg_bucket}") - #print(f"Reducing {[(debug_param2name_id_shape(param), param.grad) for param in params_to_reduce]}") - #exit(0) - if self.contiguous_gradients: - reduction_list = [self.ipg_buffer[self.ipg_index]] - if self.extra_large_param_to_reduce is not None: - reduction_list.append(self.extra_large_param_to_reduce.grad) - self.extra_large_param_to_reduce = None - self.average_tensor(reduction_list, params_to_reduce) - else: - self.buffered_reduce_fallback( - None, - self.grads_in_ipg_bucket, - elements_per_buffer=self.elements_in_ipg_bucket) - - for _, param, param_id in self.params_in_ipg_bucket: - self.params_already_reduced[param_id] = True + @instrument_w_nvtx + def __partition_grads(self, + params_to_release: List[Parameter], + grad_partitions: List[Tensor]) -> None: + for param, grad_partition in zip(params_to_release, grad_partitions): + if param.ds_tensor.ds_numel * dist.get_rank( + self.dp_process_group) > param.ds_numel: + # this grad partition is empty - don't need to do anything + continue - self.previous_reduced_grads = params_to_reduce + # move or accumulate gradient partition to target buffer + grad_buffer = (self.temp_grad_gpu_buffer if self.offload_optimizer else + self.__param_id_to_grad_partition[param.ds_id]).narrow( + 0, + 0, + grad_partition.numel()) + if self.micro_step_id == 0: # don't accumulate + with torch.cuda.nvtx.range("grad_copy"): + grad_buffer.copy_(grad_partition, non_blocking=True) + elif grad_buffer.is_cuda: + with torch.cuda.nvtx.range("grad_add_cuda_dst"): + grad_buffer.add_(grad_partition) + else: + # if dst is CPU, copy first to src device, do the addition + # there, then move back to dst. adding directly to cpu is very slow + with torch.cuda.nvtx.range("grad_add_cpu_dst"): + tmp_grad_dst = torch.empty_like(grad_partition) + tmp_grad_dst.copy_(grad_buffer, non_blocking=True) + tmp_grad_dst.add_(grad_partition) + grad_buffer.copy_(tmp_grad_dst, non_blocking=True) + + # offload the gradient partition if applicable + if self.offload_optimizer: + i, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + offload_fp32_gradients = {} + offload_fp32_offsets = {} + + if self.gradient_accumulation_steps > 1: + fp16_grad_tensor = self.__param_id_to_grad_partition[param.ds_id] + self.async_accumulate_grad_in_cpu_via_gpu(grad_buffer, + fp16_grad_tensor) + + if self.is_gradient_accumulation_boundary: + # Credit to our user David Minn + if grad_partition is not None: + self.gpu_sum.add_(grad_partition.float().sum()) + + if self._swappable_optimizer_subgroup(i): + if not i in offload_fp32_gradients.keys(): + offload_fp32_gradients[i] = [] + offload_fp32_offsets[i] = [] + + offload_fp32_gradients[i].append(grad_buffer.float()) + offload_fp32_offsets[i].append(dest_offset) + else: + fp32_grad_tensor = self.fp32_partitioned_groups_flat[ + i].grad.narrow(0, + dest_offset, + grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer.float(), non_blocking=True) + + # free the gradient + param.grad.record_stream(torch.cuda.current_stream()) + param.grad = None - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.elements_in_ipg_bucket = 0 - ##################################################################### + if self.offload_optimizer and self.swap_optimizer: + for i in offload_fp32_gradients.keys(): + self.optimizer_swapper.swap_out_gradients( + parameter=self.fp32_partitioned_groups_flat[i], + gradient_offsets=offload_fp32_offsets[i], + gradient_tensors=offload_fp32_gradients[i]) def reduce_ready_partitions_and_remove_grads(self, param, i): #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) @@ -2418,15 +2359,20 @@ def get_partition_info(self, tensor_list, partition_size, partition_id): return params_in_partition, params_not_in_partition, first_offset + @instrument_w_nvtx def zero_grad(self, set_grads_to_None=True): """ Zero FP16 parameter grads. """ + self.micro_step_id = 0 + # FP32 grad should never exist. # For speed, set model fp16 grad to None by default for group in self.fp16_groups: for p in group: if set_grads_to_None: + if p.grad is not None and p.grad.is_cuda: + p.grad.record_stream(torch.cuda.current_stream()) p.grad = None else: if p.grad is not None: @@ -2443,6 +2389,7 @@ def _model_parallel_all_reduce(self, tensor, op): op=op, group=self.model_parallel_group) + @instrument_w_nvtx def get_grad_norm_direct(self, gradients, params, norm_type=2): """Clips gradient norm of an iterable of parameters. @@ -2473,15 +2420,15 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): op=torch.distributed.ReduceOp.MAX) total_norm = total_norm_cuda[0].item() else: - total_norm = 0.0 # if dist.get_rank() == 0: # logger.info(f"Total Norm beginning {total_norm}") + grad_norms = [] for g, p in zip(gradients, params): if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - param_norm = g.data.double().norm(2) - total_norm += param_norm.item()**2 + grad_norms.append(g.cuda(non_blocking=True).double().norm(2)) + # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) torch.distributed.all_reduce(total_norm_cuda, op=torch.distributed.ReduceOp.SUM, @@ -2490,7 +2437,7 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2): self._model_parallel_all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm_cuda.item()**(1. / norm_type) if total_norm == float( 'inf') or total_norm == -float('inf') or total_norm != total_norm: @@ -2578,19 +2525,17 @@ def stop_timers(self, timer_names): self.timers(name).stop() def _pre_step(self): - self.micro_step_id = INITIAL_MICRO_STEP_ID + self.micro_step_id = 0 print_rank_0(f"Inside Step function") see_memory_usage(f"In step before checking overflow", force=False) print_rank_0("Finished Tracing at Beginning of Step") self.param_coordinator.hierarchy = 0 - self.param_coordinator.finish_tracing(print_trace=True) - - self.param_coordinator.reset_step() print_rank_0("Finished Tracing at Beginning of Step") + @instrument_w_nvtx def _get_norm_groups(self): norm_groups = [] for i, group in enumerate(self.fp16_groups): @@ -2604,6 +2549,7 @@ def _get_norm_groups(self): self.fp16_groups[i])) return norm_groups + @instrument_w_nvtx def _prepare_fp32_grad_for_sub_group(self, sub_group_id): partition_id = dist.get_rank(group=self.dp_process_group) @@ -2619,8 +2565,12 @@ def _prepare_fp32_grad_for_sub_group(self, sub_group_id): # release all the gradient since we have already created a necessary copy in dp_grad_partition self.zero_grad() + for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): + grad.record_stream(torch.cuda.current_stream()) + self.averaged_gradients[sub_group_id] = None + @instrument_w_nvtx def _prepare_sub_group(self, sub_group_id, timer_names=set()): see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', force=False) @@ -2651,6 +2601,7 @@ def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set() see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', force=False) + @instrument_w_nvtx def _release_sub_group(self, sub_group_id, timer_names=set()): see_memory_usage(f'Before release optimizer sub group {sub_group_id}', force=False) @@ -2664,6 +2615,7 @@ def _release_sub_group(self, sub_group_id, timer_names=set()): force=False) # create a flat tensor aligned at the alignment boundary + @instrument_w_nvtx def flatten_dense_tensors_aligned(self, tensor_list, alignment): num_elements = 0 for tens in tensor_list: @@ -2735,6 +2687,7 @@ def _overflow_clean_up(self, prev_scale): prev_scale, self.loss_scale)) + @instrument_w_nvtx def _overflow_check_and_loss_scale_update(self): # First compute norm for all group so we know if there is overflow @@ -2749,6 +2702,7 @@ def _overflow_check_and_loss_scale_update(self): return self.overflow + @instrument_w_nvtx def _post_step(self, timer_names=set()): if self.offload_optimizer: self.reset_cpu_buffers() @@ -2765,6 +2719,7 @@ def _post_step(self, timer_names=set()): see_memory_usage('After zero_optimizer step', force=False) print_rank_0(f"------------------Finishing Step-----------------------") + @instrument_w_nvtx def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): if self.fp16_partitioned_groups_flat[sub_group_id] is not None: self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( @@ -2775,6 +2730,7 @@ def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): else: self._partitioned_params_swap_out(sub_group_id) + @instrument_w_nvtx def step(self, closure=None): """ Not supporting closure. @@ -2816,7 +2772,18 @@ def step(self, closure=None): self.stop_timers(['optimizer_step']) self._post_step(timer_names) - return + + # warn user about caching allocator flushes + alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] + if alloc_retries > self.__n_caching_allocator_flushes: + if dist.get_rank() == 0: + logger.warning( + "%d pytorch allocator cache flushes. this happens " + "when there is high memory pressure and is highly detrimental to " + "performance. if this is happening frequently consider adjusting " + "settings to reduce memory consumption", + alloc_retries - self.__n_caching_allocator_flushes) + self.__n_caching_allocator_flushes = alloc_retries def dump_pre_step_gradients(self, debug_fp32_grads): # Dump gradient norms for debugging @@ -2851,6 +2818,7 @@ def dump_post_step_gradients(self): norm_list = [param_norm, ds_norm] + unflat_norm print(f'Post-Step Norms {i} {param_id} = {norm_list}') + @instrument_w_nvtx def unscale_and_clip_grads(self, sub_group_id, total_norm): grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] @@ -2888,14 +2856,14 @@ def has_overflow_partitioned_grads_serial(self): return True return False + @instrument_w_nvtx def has_overflow(self, partition_gradients=True): if partition_gradients: - if self.overlap_comm: + with torch.cuda.stream(self.__reduce_and_partition_stream): self.local_overflow = self._has_inf_or_nan(self.gpu_sum) - self.gpu_sum = torch.zeros(1, dtype=torch.float).cuda() + self.gpu_sum.zero_() - overflow = self.local_overflow if self.offload_optimizer else self.has_overflow_partitioned_grads_serial( - ) + overflow = self.local_overflow #overflow = self.has_overflow_partitioned_grads_serial() overflow_gpu = torch.cuda.ByteTensor([overflow]) torch.distributed.all_reduce(overflow_gpu, @@ -2941,6 +2909,7 @@ def _has_inf_or_nan(x, j=None): return True return False + @instrument_w_nvtx def backward(self, loss, retain_graph=False): """ :attr:`backward` performs the following steps: @@ -2949,31 +2918,14 @@ def backward(self, loss, retain_graph=False): 2. scaled_loss = fp32_loss*loss_scale 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves """ - self.micro_step_id += 1 - print_rank_0( - f"Total fully available parameters {self.param_coordinator.total_available_parameter_numel}" - ) - if self.swap_optimizer: self.optimizer_swapper.pre_backward() see_memory_usage(f"Before backward", force=False) - if self.contiguous_gradients: - self.ipg_buffer = [] - buf_0 = torch.empty(self.reduce_bucket_size, - dtype=self.dtype, - device=torch.cuda.current_device()) - self.ipg_buffer.append(buf_0) - - # Use double buffers to avoid data access conflict when overlap_comm is enabled. - if self.overlap_comm: - buf_1 = torch.empty(self.reduce_bucket_size, - dtype=self.dtype, - device=torch.cuda.current_device()) - self.ipg_buffer.append(buf_1) - self.ipg_index = 0 self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + + self.param_coordinator.reset_step() '''Partitioning Parameters that were not partitioned Usually if parameters of modules whose input parameters do not require grad computation do not trigger post call and will therefore will remain unpartitioned ''' @@ -2982,9 +2934,34 @@ def backward(self, loss, retain_graph=False): if self.swap_optimizer: self.optimizer_swapper.post_backward() + def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: + """get fp32 gradient partition dictionary + accessed as grad_dict[parameter_group_index][parameter_index] + """ + self.__reduce_and_partition_stream.synchronize() + grad_dict = collections.defaultdict(dict) + if self.offload_optimizer: + for group in self.fp16_groups: + for param_idx, param in enumerate(group): + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow( + 0, + dest_offset, + num_elements) + grad_dict[group_idx][param_idx] = fp32_grad + else: + for group_idx, group in self.averaged_gradients.items(): + for param_idx, gradient in enumerate(group): + grad_dict[group_idx][param_idx] = gradient.float() + + return grad_dict + + @instrument_w_nvtx def _partition_all_parameters(self): - for name, param in self.module.named_parameters(recurse=True): - self.param_coordinator.release_and_reset_parameter(param) + self.param_coordinator.release_and_reset_all() + for param in iter_params(self.module, recurse=True): + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"{param.ds_summary()} expected to be released") def check_overflow(self, partition_gradients=True): self._check_overflow(partition_gradients) diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index f90497069901..21d3291b7e80 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -1,8 +1,12 @@ +import os +from typing import List + import torch import torch.distributed as dist from deepspeed.utils import logger from deepspeed.ops.adam import DeepSpeedCPUAdam from deepspeed.ops.adam import FusedAdam +from deepspeed.utils.nvtx import instrument_w_nvtx def _initialize_parameter_parallel_groups(parameter_parallel_size=None): @@ -44,3 +48,35 @@ def is_zero_supported_optimizer(optimizer): f'Checking ZeRO support for optimizer={optimizer.__class__.__name__} type={type(optimizer)}' ) return type(optimizer) in ZERO_SUPPORTED_OPTIMIZERS + + +def get_lst_from_rank0(lst: List[int]) -> None: + """ + NOTE: creates both communication and synchronization overhead so should be used + sparingly + """ + lst_tensor = torch.tensor( + lst if dist.get_rank() == 0 else [-1] * len(lst), + dtype=int, + # device=torch.cuda.current_device(), + device=torch.device('cuda:{}'.format(os.environ["LOCAL_RANK"])), + requires_grad=False, + ) + dist.broadcast(lst_tensor, src=0, async_op=False) + + return list(lst_tensor.cpu().numpy()) + + +@instrument_w_nvtx +def assert_ints_same_as_other_ranks(ints: List[int]) -> None: + """ + NOTE: creates both communication and synchronization overhead so should be + used sparingly + + takes a list of ints from each rank and ensures that they are the same + across ranks, throwing an exception if they are not. + """ + rank0_ints = get_lst_from_rank0(ints) + if ints != rank0_ints: + raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: " + f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}") diff --git a/deepspeed/utils/__init__.py b/deepspeed/utils/__init__.py index 892eea166c0a..efa45853f0df 100644 --- a/deepspeed/utils/__init__.py +++ b/deepspeed/utils/__init__.py @@ -1,4 +1,5 @@ from .logging import logger, log_dist from .distributed import init_distributed from .groups import * +from .nvtx import instrument_w_nvtx from deepspeed.runtime.dataloader import RepeatingLoader diff --git a/deepspeed/utils/nvtx.py b/deepspeed/utils/nvtx.py new file mode 100644 index 000000000000..0b6c98704cea --- /dev/null +++ b/deepspeed/utils/nvtx.py @@ -0,0 +1,11 @@ +import torch + + +def instrument_w_nvtx(func): + """decorator that causes an NVTX range to be recorded for the duration of the + function call.""" + def wrapped_fn(*args, **kwargs): + with torch.cuda.nvtx.range(func.__qualname__): + return func(*args, **kwargs) + + return wrapped_fn diff --git a/deepspeed/utils/timer.py b/deepspeed/utils/timer.py index 1bf5edc9c4f3..e7f197449585 100755 --- a/deepspeed/utils/timer.py +++ b/deepspeed/utils/timer.py @@ -154,10 +154,15 @@ def stop(self, report_speed=True): self.total_elapsed_time += duration if self.local_step_count % self.steps_per_output == 0: if report_speed: - self.logging("{}/{}, SamplesPerSec={}".format( - self.epoch_count, - self.local_step_count, - self.avg_samples_per_sec())) + self.logging( + "{}/{}, SamplesPerSec={}, MemAllocated={}GB, MaxMemAllocated={}GB" + .format(self.epoch_count, + self.local_step_count, + self.avg_samples_per_sec(), + round(torch.cuda.memory_allocated() / 1024**3, + 2), + round(torch.cuda.max_memory_allocated() / 1024**3, + 2))) if self.monitor_memory: virt_mem = psutil.virtual_memory() swap = psutil.swap_memory() diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/unit/megatron_model.py b/tests/unit/megatron_model.py index 0957eab02ea0..fd2ef69b7259 100644 --- a/tests/unit/megatron_model.py +++ b/tests/unit/megatron_model.py @@ -1,9 +1,10 @@ +from pathlib import Path import torch import os import sys import math -from common import get_test_path +from .common import get_test_path from deepspeed.pipe import PipelineModule, LayerSpec diff --git a/tests/unit/test_activation_checkpointing.py b/tests/unit/test_activation_checkpointing.py index 73ee6a25df78..3945b4403085 100644 --- a/tests/unit/test_activation_checkpointing.py +++ b/tests/unit/test_activation_checkpointing.py @@ -9,7 +9,7 @@ import deepspeed ckpt = deepspeed.checkpointing.checkpoint -from common import distributed_test +from .common import distributed_test def _compute(module, *inputs, do_checkpoint=False): diff --git a/tests/unit/test_adamw.py b/tests/unit/test_adamw.py index 83e0b5436546..b4bfbf3c260c 100644 --- a/tests/unit/test_adamw.py +++ b/tests/unit/test_adamw.py @@ -2,10 +2,10 @@ import torch import pytest -from common import distributed_test from deepspeed.ops.adam import FusedAdam from deepspeed.ops.adam import DeepSpeedCPUAdam -from simple_model import SimpleModel, args_from_dict +from .common import distributed_test +from .simple_model import SimpleModel, args_from_dict # yapf: disable #'optimizer, zero_offload, torch_adam, adam_w_mode, resulting_optimizer diff --git a/tests/unit/test_aio.py b/tests/unit/test_aio.py index daa633fe5e44..ac0391176f9d 100755 --- a/tests/unit/test_aio.py +++ b/tests/unit/test_aio.py @@ -4,8 +4,8 @@ import torch import deepspeed import torch.distributed as dist -from common import distributed_test from deepspeed.ops.aio import AsyncIOBuilder +from .common import distributed_test MEGA_BYTE = 1024**2 BLOCK_SIZE = MEGA_BYTE diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index fd4313d83b24..0032186975cb 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -15,15 +15,15 @@ from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3 -from util import required_torch_version +from .util import required_torch_version import argparse import pytest import json import os import numbers -from common import distributed_test -from simple_model import * +from .common import distributed_test +from .simple_model import * def compare_deepspeed_states(saved_model, loaded_model): diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index ad06a851122d..e66544833c8e 100755 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -3,8 +3,8 @@ import pytest import json import argparse -from common import distributed_test, get_test_path -from simple_model import SimpleModel, create_config_from_dict, random_dataloader +from .common import distributed_test, get_test_path +from .simple_model import SimpleModel, create_config_from_dict, random_dataloader import torch.distributed as dist # A test on its own diff --git a/tests/unit/test_configurable_parallel.py b/tests/unit/test_configurable_parallel.py index e6933421089b..d31e89a7725e 100755 --- a/tests/unit/test_configurable_parallel.py +++ b/tests/unit/test_configurable_parallel.py @@ -7,10 +7,10 @@ import numpy as np import torch.multiprocessing as mp import torch.distributed as dist -from common import distributed_test -from simple_model import args_from_dict, create_deepspeed_args -from megatron_model import get_gpt2_model, get_megatron_version -from megatron_model import MockGPT2ModelPipe as GPT2ModelPipe +from .common import distributed_test +from .simple_model import args_from_dict, create_deepspeed_args +from .megatron_model import get_gpt2_model, get_megatron_version +from .megatron_model import MockGPT2ModelPipe as GPT2ModelPipe from deepspeed.utils import RepeatingLoader TORCH_MAJOR = int(torch.__version__.split('.')[0]) diff --git a/tests/unit/test_cuda_backward.py b/tests/unit/test_cuda_backward.py index 6db42e4fa1e7..47c30fb1b34f 100755 --- a/tests/unit/test_cuda_backward.py +++ b/tests/unit/test_cuda_backward.py @@ -1,20 +1,13 @@ -import argparse import numpy as np import torch import torch.nn.functional as F import pytest -import json import random -import time import copy from torch import nn -from modelingpreln import BertEncoder as BertEncoderPreln -from modeling import BertEncoder as BertEncoderPostln -from modeling import BertConfig, BertLayerNorm from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig -import deepspeed - -import sys +from .modeling import BertConfig, BertLayerNorm, BertEncoder as BertEncoderPostln +from .modelingpreln import BertEncoder as BertEncoderPreln #if not deepspeed.ops.__installed_ops__['transformer']: #pytest.skip( diff --git a/tests/unit/test_cuda_forward.py b/tests/unit/test_cuda_forward.py index 200fb5ea0af0..cfb4f04e8f45 100755 --- a/tests/unit/test_cuda_forward.py +++ b/tests/unit/test_cuda_forward.py @@ -8,9 +8,8 @@ import time import copy from torch import nn -from modelingpreln import BertEncoder as BertEncoderPreln -from modeling import BertEncoder as BertEncoderPostln -from modeling import BertLayerNorm, BertConfig +from .modelingpreln import BertEncoder as BertEncoderPreln +from .modeling import BertLayerNorm, BertConfig, BertEncoder as BertEncoderPostln from deepspeed import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig import deepspeed diff --git a/tests/unit/test_curriculum_learning.py b/tests/unit/test_curriculum_learning.py index cf0562ab61e1..3677b5966781 100644 --- a/tests/unit/test_curriculum_learning.py +++ b/tests/unit/test_curriculum_learning.py @@ -7,8 +7,8 @@ import os import numpy as np import time -from common import distributed_test -from simple_model import Curriculum_SimpleModel, random_dataloader, args_from_dict +from .common import distributed_test +from .simple_model import Curriculum_SimpleModel, random_dataloader, args_from_dict def test_curriculum_scheduler_fixed_discrete(tmpdir): diff --git a/tests/unit/test_data.py b/tests/unit/test_data.py index 290ca0c3d992..93510e557450 100644 --- a/tests/unit/test_data.py +++ b/tests/unit/test_data.py @@ -2,8 +2,8 @@ import torch import pytest import deepspeed -from common import distributed_test -from simple_model import SimpleModel, args_from_dict, random_dataset +from .common import distributed_test +from .simple_model import SimpleModel, args_from_dict, random_dataset def test_repeating_loader(): diff --git a/tests/unit/test_dist.py b/tests/unit/test_dist.py index 25a5fd22770f..d37133603ce4 100644 --- a/tests/unit/test_dist.py +++ b/tests/unit/test_dist.py @@ -1,7 +1,7 @@ import torch import torch.distributed as dist -from common import distributed_test +from .common import distributed_test import pytest diff --git a/tests/unit/test_ds_initialize.py b/tests/unit/test_ds_initialize.py index 2dfa3b481eec..27c9821c8c5c 100644 --- a/tests/unit/test_ds_initialize.py +++ b/tests/unit/test_ds_initialize.py @@ -4,8 +4,8 @@ from torch.optim import Optimizer, Adam, AdamW from torch.optim.lr_scheduler import _LRScheduler, LambdaLR -from simple_model import args_from_dict, SimpleModel -from common import distributed_test +from .simple_model import args_from_dict, SimpleModel +from .common import distributed_test import deepspeed from deepspeed.ops.adam import FusedAdam diff --git a/tests/unit/test_dynamic_loss_scale.py b/tests/unit/test_dynamic_loss_scale.py index 302de55c36a3..65a679d94de7 100755 --- a/tests/unit/test_dynamic_loss_scale.py +++ b/tests/unit/test_dynamic_loss_scale.py @@ -5,8 +5,8 @@ import json import os import numpy as np -from common import distributed_test -from simple_model import SimpleModel, args_from_dict +from .common import distributed_test +from .simple_model import SimpleModel, args_from_dict def run_model_step(model, gradient_list): diff --git a/tests/unit/test_elastic.py b/tests/unit/test_elastic.py index 62d948d599b0..353d6def37ba 100644 --- a/tests/unit/test_elastic.py +++ b/tests/unit/test_elastic.py @@ -1,8 +1,8 @@ import pytest import deepspeed -from common import distributed_test +from .common import distributed_test from deepspeed.git_version_info import version as ds_version -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict base_ds_config = { "elasticity": { diff --git a/tests/unit/test_flops_profiler.py b/tests/unit/test_flops_profiler.py index f4654f93fb07..9179b8b60cd4 100644 --- a/tests/unit/test_flops_profiler.py +++ b/tests/unit/test_flops_profiler.py @@ -2,8 +2,8 @@ import deepspeed import deepspeed.runtime.utils as ds_utils from deepspeed.profiling.flops_profiler import FlopsProfiler, get_model_profile -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict -from common import distributed_test +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from .common import distributed_test def test_flops_profiler_in_ds_training(tmpdir): diff --git a/tests/unit/test_fp16.py b/tests/unit/test_fp16.py index e5b5b8efd83d..fa21e2631fac 100755 --- a/tests/unit/test_fp16.py +++ b/tests/unit/test_fp16.py @@ -8,10 +8,10 @@ import json import os from deepspeed.ops.adam import FusedAdam -from common import distributed_test +from .common import distributed_test from deepspeed.ops.op_builder import CPUAdamBuilder -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args, SimpleMoEModel, sequence_dataloader -from util import required_torch_version +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args, SimpleMoEModel, sequence_dataloader +from .util import required_torch_version try: from apex import amp diff --git a/tests/unit/test_ignore_unused_parameters.py b/tests/unit/test_ignore_unused_parameters.py index 19fd50813872..eb26f46ca209 100644 --- a/tests/unit/test_ignore_unused_parameters.py +++ b/tests/unit/test_ignore_unused_parameters.py @@ -3,8 +3,8 @@ import json import argparse import os -from common import distributed_test -from simple_model import UnusedParametersModel, random_dataloader, args_from_dict +from .common import distributed_test +from .simple_model import UnusedParametersModel, random_dataloader, args_from_dict from deepspeed.ops.op_builder import CPUAdamBuilder import deepspeed diff --git a/tests/unit/test_lr_schedulers.py b/tests/unit/test_lr_schedulers.py index d93ac6f171bb..3878ff2b08f6 100755 --- a/tests/unit/test_lr_schedulers.py +++ b/tests/unit/test_lr_schedulers.py @@ -4,8 +4,8 @@ import pytest import json import os -from common import distributed_test -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from .common import distributed_test +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, LR_RANGE_TEST_MIN_LR, LR_RANGE_TEST_STEP_RATE, LR_RANGE_TEST_STEP_SIZE, LR_RANGE_TEST_STAIRCASE from deepspeed.runtime.lr_schedules import WARMUP_LR, WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS from deepspeed.runtime.lr_schedules import ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR, CYCLE_FIRST_STEP_SIZE, DECAY_LR_RATE, DECAY_STEP_SIZE diff --git a/tests/unit/test_moe.py b/tests/unit/test_moe.py index cc9eff01afb0..61b2fe1670fa 100644 --- a/tests/unit/test_moe.py +++ b/tests/unit/test_moe.py @@ -8,10 +8,10 @@ import json import os from deepspeed.ops.adam import FusedAdam -from common import distributed_test +from .common import distributed_test from deepspeed.ops.op_builder import CPUAdamBuilder -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args, SimpleMoEModel, sequence_dataloader -from util import required_torch_version +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args, SimpleMoEModel, sequence_dataloader +from .util import required_torch_version try: from apex import amp diff --git a/tests/unit/test_multi_output_model.py b/tests/unit/test_multi_output_model.py index ccbe7f484e29..478bdc8d383d 100755 --- a/tests/unit/test_multi_output_model.py +++ b/tests/unit/test_multi_output_model.py @@ -5,9 +5,9 @@ from pytest import approx import json import os -from common import distributed_test -from simple_model import args_from_dict -from multi_output_model import MultiOutputModel, multi_output_dataloader +from .common import distributed_test +from .simple_model import args_from_dict +from .multi_output_model import MultiOutputModel, multi_output_dataloader def create_config_dict(micro_batch_size, grad_accumulation_steps, world_size): diff --git a/tests/unit/test_onebit.py b/tests/unit/test_onebit.py index 9796a70953f8..c07f79ba7b2d 100644 --- a/tests/unit/test_onebit.py +++ b/tests/unit/test_onebit.py @@ -14,9 +14,9 @@ from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology PipeTopo = PipeDataParallelTopology from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec -from common import distributed_test -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args -from test_pipe import AlexNetPipe, train_cifar +from .common import distributed_test +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args +from .test_pipe import AlexNetPipe, train_cifar TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) diff --git a/tests/unit/test_partition.py b/tests/unit/test_partition.py index 7cd264752c6f..f766e4596509 100644 --- a/tests/unit/test_partition.py +++ b/tests/unit/test_partition.py @@ -8,7 +8,7 @@ from deepspeed.runtime.utils import prefix_sum_inc from deepspeed.runtime.utils import PartitionedTensor -from common import distributed_test +from .common import distributed_test @distributed_test(world_size=4) diff --git a/tests/unit/test_pipe.py b/tests/unit/test_pipe.py index 65ae0023b8ec..3b5d1bfc5413 100755 --- a/tests/unit/test_pipe.py +++ b/tests/unit/test_pipe.py @@ -16,7 +16,7 @@ PipeTopo = PipeDataParallelTopology from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec -from common import distributed_test +from .common import distributed_test def rel_diff(A, B): @@ -24,7 +24,7 @@ def rel_diff(A, B): # All models -from simple_model import args_from_dict +from .simple_model import args_from_dict class AlexNet(nn.Module): diff --git a/tests/unit/test_pipe_module.py b/tests/unit/test_pipe_module.py index a29d22a2a954..4fb129bdc55b 100644 --- a/tests/unit/test_pipe_module.py +++ b/tests/unit/test_pipe_module.py @@ -14,8 +14,8 @@ from deepspeed.pipe import PipelineModule, LayerSpec from deepspeed.utils import RepeatingLoader -from common import distributed_test -from simple_model import args_from_dict +from .common import distributed_test +from .simple_model import args_from_dict HIDDEN_DIM = 32 LAYERS = 8 diff --git a/tests/unit/test_pld.py b/tests/unit/test_pld.py index 784aeff0338f..0672da9177b1 100755 --- a/tests/unit/test_pld.py +++ b/tests/unit/test_pld.py @@ -2,8 +2,8 @@ import deepspeed import pytest from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop -from common import distributed_test -from simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from .common import distributed_test +from .simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict @pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) diff --git a/tests/unit/test_runtime_utils.py b/tests/unit/test_runtime_utils.py index 612fa1305181..c27f3e74e636 100644 --- a/tests/unit/test_runtime_utils.py +++ b/tests/unit/test_runtime_utils.py @@ -8,7 +8,7 @@ from deepspeed.utils.logging import log_dist import deepspeed.utils.groups as groups -from common import distributed_test +from .common import distributed_test def test_call_to_str(): diff --git a/tests/unit/test_topology.py b/tests/unit/test_topology.py index 176363688de4..89bb8ec3dc7c 100644 --- a/tests/unit/test_topology.py +++ b/tests/unit/test_topology.py @@ -7,7 +7,7 @@ from deepspeed.runtime.pipe.topology import ProcessTopology as Topo from deepspeed.runtime.pipe.topology import _prime_factors -from common import distributed_test +from .common import distributed_test def test_topology_2d(): diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index 173e60e26b81..afa597aa29d8 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -1,14 +1,20 @@ -import torch +import math +from typing import Dict, List, Set import pytest -import json -import argparse -import os import torch.distributed as dist +import torch +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn.modules.container import ModuleList +from torch.nn.modules.loss import L1Loss +from torch.nn.parameter import Parameter -from common import distributed_test -from simple_model import SimpleModel, random_dataloader, args_from_dict +from .common import distributed_test +from .simple_model import SimpleModel, random_dataloader, args_from_dict import deepspeed +from deepspeed.runtime.engine import DeepSpeedEngine +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint @@ -434,3 +440,755 @@ def _test_partition_nccl_alignment(args, model, hidden_dim): (2 * nccl_start_alignment_factor) == 0) _test_partition_nccl_alignment(args=args, model=model, hidden_dim=hidden_dim) + + +def _ds_initialize_for_param_partitioning_testing(model: Module, + cfg: dict) -> DeepSpeedEngine: + ds_engine, _, _, _ = deepspeed.initialize( + config=cfg, + model=model, + model_parameters=model.parameters() + ) + + return ds_engine + + +def _print_with_rank(msg: str) -> None: + print(f"RANK{dist.get_rank()}: {msg}") + + +def _assert_partition_status(model: Module, + valid_statuses: Set[ZeroParamStatus]) -> None: + for _, param in model.named_parameters(): + assert param.ds_status in valid_statuses, param.ds_summary() + + +def _assert_fully_available(model: Module) -> None: + for _, param in model.named_parameters(): + assert param.ds_status == ZeroParamStatus.AVAILABLE + + +class EltwiseMultiplicationModule(Module): + def __init__(self, weight: Parameter) -> None: + super().__init__() + self.weight = weight + + def forward(self, x: Tensor) -> Tensor: + _assert_fully_available(self) + result = self.weight * x + + return result + + +class EltwiseMultiplicationTestNetwork(Module): + """used for testing purposes""" + def __init__( + self, + weight1: Parameter, + weight2: Parameter, + weight3: Parameter, + ) -> None: + super().__init__() + self.__layer1 = EltwiseMultiplicationModule(weight1) + self.__layer2 = EltwiseMultiplicationModule(weight2) + self.__layer3 = EltwiseMultiplicationModule(weight3) + + self.loss = L1Loss(reduction="none") + + def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]: + _assert_partition_status( + self, + { + ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.INFLIGHT, + ZeroParamStatus.AVAILABLE + } if prefetching else {ZeroParamStatus.NOT_AVAILABLE}) + + _assert_partition_status( + self.__layer1, + {ZeroParamStatus.INFLIGHT if prefetching else ZeroParamStatus.NOT_AVAILABLE}) + hidden1 = self.__layer1(x) + _assert_partition_status(self.__layer1, {ZeroParamStatus.NOT_AVAILABLE}) + + _assert_partition_status(self.__layer2, + { + ZeroParamStatus.AVAILABLE + if prefetching else ZeroParamStatus.NOT_AVAILABLE + }) + hidden2 = self.__layer2(hidden1) + _assert_partition_status(self.__layer2, {ZeroParamStatus.NOT_AVAILABLE}) + + _assert_partition_status(self.__layer3, + { + ZeroParamStatus.AVAILABLE + if prefetching else ZeroParamStatus.NOT_AVAILABLE + }) + y_hat = self.__layer3(hidden2) + _assert_partition_status(self.__layer3, + { + ZeroParamStatus.AVAILABLE + if prefetching else ZeroParamStatus.NOT_AVAILABLE + }) + + loss = self.loss(y_hat, y) + + _assert_partition_status( + self, + { + ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.INFLIGHT, + ZeroParamStatus.AVAILABLE + } if prefetching else {ZeroParamStatus.NOT_AVAILABLE}) + + return { + "hidden1": hidden1, + "hidden2": hidden2, + "y_hat": y_hat, + "loss": loss, + } + + +@pytest.mark.parametrize("param_persistence_threshold", [0, 10]) +@pytest.mark.parametrize("fp16_enabled", [True, False]) +@pytest.mark.parametrize("contiguous_gradients", [True, False]) +@pytest.mark.parametrize("offload_optimizer", [True, False]) +@pytest.mark.parametrize("zero_grad", [True]) +@pytest.mark.parametrize("iteration", list(range(1))) +def test_zero3_param_partitioning_base( + param_persistence_threshold: int, + fp16_enabled: bool, + contiguous_gradients: bool, + offload_optimizer: bool, + zero_grad: bool, + iteration: int, +) -> None: + @distributed_test(world_size=[2]) + def _test_zero3_param_partitioning(): + if offload_optimizer and not contiguous_gradients: + return + + m = 3 + n = 5 + weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)] + model = EltwiseMultiplicationTestNetwork(*weights) + + cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "stage3_param_persistence_threshold": param_persistence_threshold, + "contiguous_gradients": contiguous_gradients, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": fp16_enabled, + "loss_scale": 1., + } + } + + if offload_optimizer: + cfg["zero_optimization"]["offload_optimizer"] = { + "device": "cpu", + "pin_memory": True, + } + + ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg) + for i, weight in enumerate(weights): + weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, + (i + 1) * (1 + dist.get_rank())) + + def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: + return torch.as_tensor(vals, + dtype=dtype + or (torch.float16 if fp16_enabled else torch.float32), + device=ds_engine.device) + + expected_hidden1 = create_tensor([ + [1, + 1, + 1, + 1, + 1], + [1, + 1, + 1, + 2, + 2], + [2, + 2, + 2, + 2, + 2], + ]) + expected_hidden2 = create_tensor([ + [2, + 2, + 2, + 2, + 2], + [2, + 2, + 2, + 8, + 8], + [8, + 8, + 8, + 8, + 8], + ]) + expected_yhat = create_tensor([[6, + 6, + 6, + 6, + 6], + [6, + 6, + 6, + 48, + 48], + [48, + 48, + 48, + 48, + 48]]) + expected_loss = create_tensor([ + [5, + 5, + 5, + 5, + 5], + [5, + 5, + 5, + 47, + 47], + [47, + 47, + 47, + 47, + 47], + ]) + + for train_iter in range(3): + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + activations = ds_engine( + x=torch.ones((m, + n), + dtype=torch.float16, + device=ds_engine.device), + y=torch.ones((m, + n), + dtype=torch.float16, + device=ds_engine.device), + prefetching=train_iter > 0, + ) + assert torch.allclose(activations["hidden1"], expected_hidden1) + assert torch.allclose(activations["hidden2"], expected_hidden2) + assert torch.allclose(activations["y_hat"], expected_yhat) + assert torch.allclose(activations["loss"], expected_loss) + + ds_engine.backward(activations["loss"].sum()) + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + + # check the gradients + grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions() + assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}" + assert set(grad_partitions[0].keys()) == {0, 1, 2} + dloss_wrt_layer1 = grad_partitions[0][0] + dloss_wrt_layer2 = grad_partitions[0][1] + dloss_wrt_layer3 = grad_partitions[0][2] + + assert dloss_wrt_layer1.dtype == torch.float + assert dloss_wrt_layer2.dtype == torch.float + assert dloss_wrt_layer3.dtype == torch.float + + # layer1 = [..., 1, 2, ...] + # layer2 = [..., 2, 4, ...] + # layer3 = [..., 3, 6, ...] + # dloss_wrt_layer3 = hidden2 + # dloss_wrt_layer2 = layer3 * hidden1 + # dloss_wrt_layer1 = layer3 * layer2 * x + + grad_multiplier = 1 if zero_grad else (train_iter + 1) + if dist.get_rank() == 0: + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * create_tensor([2] * 8, + torch.float)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * create_tensor([3 * 1] * 8, + torch.float)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * create_tensor([3 * 2 * 1] * 8, + torch.float)) + elif dist.get_rank() == 1: + # parameters dont split evenly across ranks so rank 1 has a zero-padded + # partition + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * create_tensor(([8] * 7) + [0], + torch.float)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * create_tensor(([6 * 2] * 7) + [0], + torch.float)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0], + torch.float)) + else: + raise RuntimeError("test has world size of two") + + if zero_grad: + ds_engine.optimizer.zero_grad() + + # TODO. add testing for this - for now we just call it to make sure it + # doesnt throw + ds_engine.optimizer.step() + + _test_zero3_param_partitioning() + + +@pytest.mark.parametrize("world_sz", [1, 2, 4]) +@pytest.mark.parametrize("param_sz", [8100]) +@pytest.mark.parametrize("init_context_manager", [True, False]) +def test_zero3_param_partitioning_large_param(world_sz: int, + param_sz: int, + init_context_manager: bool) -> None: + class LargeParamModel(Module): + def __init__(self): + super().__init__() + self.param = Parameter(torch.zeros((param_sz, ), dtype=torch.float32)) + + # only do weight initialization on root rank to + # make sure we are broadcasting correctly from rank 0 + if dist.get_rank() == 0: + partition_sz = math.ceil(self.param.numel() / dist.get_world_size()) + offset = 0 + for rank in range(dist.get_world_size()): + with torch.no_grad(): + self.param[offset:offset + partition_sz].fill_(rank) + offset += partition_sz + + def forward(self, x: Tensor) -> Tensor: + return x * self.param + + @distributed_test(world_size=[world_sz]) + def _distributed_test(): + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + with deepspeed.zero.Init(mem_efficient_linear=False, + enabled=init_context_manager): + model = LargeParamModel() + ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_config) + + for train_iter in range(3): # test multiple iterations to cover prefetching + activation: Tensor = ds_engine( + torch.ones(param_sz, + dtype=torch.float16, + device=ds_engine.device)) + + partition_sz = math.ceil(param_sz / world_sz) + for rank_idx, start_idx in enumerate(range(0, param_sz, partition_sz)): + activation_from_partition = activation[start_idx:start_idx + + partition_sz] + assert torch.allclose( + activation_from_partition, + torch.full_like(activation_from_partition, + rank_idx)) + + ds_engine.backward(activation.sum()) + ds_engine.allreduce_gradients() + + avgd_gradients = ds_engine.optimizer.averaged_gradients + assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" + weight_gradient, = avgd_gradients[0] + expected_weight_gradient = (train_iter + 1) * torch.full_like( + weight_gradient, + 1) + + assert torch.allclose(weight_gradient, expected_weight_gradient) + + _distributed_test() + + +@pytest.mark.parametrize("world_sz", [1, 2, 4]) +@pytest.mark.parametrize("param_sz", [100, 1_000, 10_000]) +@pytest.mark.parametrize("n_layers", [100, 1_000]) +@pytest.mark.parametrize("init_context_manager", [True, False]) +def test_zero3_param_partitioning_many_params(world_sz: int, + param_sz: int, + n_layers: int, + init_context_manager: bool) -> None: + class ManyParamModel(Module): + def __init__(self) -> None: + super().__init__() + + self.modulelist = ModuleList( + EltwiseMultiplicationModule( + weight=Parameter(torch.empty((param_sz, + ), + dtype=torch.float32))) + for _ in range(n_layers)) + + for layer_num, module in enumerate(self.modulelist): + if dist.get_rank() == 0: + param: Parameter = module.weight + partition_sz = math.ceil(param.numel() / dist.get_world_size()) + offset = 0 + for rank in range(dist.get_world_size()): + with torch.no_grad(): + param[offset:offset + partition_sz].fill_(2 * layer_num * + rank) + offset += partition_sz + + def forward(self, x: Tensor) -> Tensor: + activations = [] + + for module in self.modulelist: + print(f"{dist.get_rank()}: xval: {x.shape}") + x = module(x) + activations.append(x) + + return activations + + @distributed_test(world_size=[world_sz]) + def _distributed_test(): + ds_cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + + with deepspeed.zero.Init(config=ds_cfg, + mem_efficient_linear=False, + enabled=init_context_manager): + model = ManyParamModel() + + ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg) + + for _ in range(3): # test multiple iterations to cover prefetching + activations: List[Tensor] = ds_engine( + torch.ones((param_sz, + ), + dtype=torch.float16, + device=ds_engine.device)) + assert len(activations) == n_layers + + partition_sz = math.ceil(param_sz / world_sz) + expected_activations = torch.empty(param_sz, + dtype=torch.float16, + device=ds_engine.device) + for start_idx in range(0, param_sz, partition_sz): + expected_activations[start_idx:start_idx + + partition_sz] = dist.get_rank() + + for layer_num, activation in enumerate(activations): + expected_activations *= 2 * layer_num + assert torch.allclose(activation, expected_activations) + + # TODO. finish writing this test + ds_engine.backward(activations[-1].sum()) + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + + avgd_gradients = ds_engine.optimizer.averaged_gradients + assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" + weight_gradients: List[Tensor] = avgd_gradients[0] + + for layer_num, activation in enumerate(weight_gradients): + pass + + _distributed_test() + + +@pytest.mark.parametrize("world_sz", [1, 2, 4]) +def test_zero3_init_for_parent_weight_initialization(world_sz): + class ModelWhereParentInitializesChildWeights(Module): + def __init__(self) -> None: + super().__init__() + + self.linear = Linear(12, 1) + + self.apply(self.__init_weights) + + def __init_weights(self, module): + if isinstance(module, Linear): + with torch.no_grad(): + module.weight.fill_(1 + dist.get_rank()) + + @distributed_test(world_size=[world_sz]) + def _distributed_test(): + ds_cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + + with deepspeed.zero.Init(config=ds_cfg, + mem_efficient_linear=False, + enabled=True): + model = ModelWhereParentInitializesChildWeights() + + assert model.linear.weight.ds_tensor.numel() == math.ceil(12 / world_sz) + assert torch.allclose(model.linear.weight.ds_tensor, + torch.full_like(model.linear.weight.ds_tensor, + 1)) + + _distributed_test() + + +@pytest.mark.skip( + reason="depends on upgraded pytorch and nccl that isnt always available") +@pytest.mark.parametrize("param_persistence_threshold", [0, 10]) +@pytest.mark.parametrize("contiguous_gradients", [True, False]) +@pytest.mark.parametrize("offload_optimizer", [True, False]) +@pytest.mark.parametrize("zero_grad", [True]) +@pytest.mark.parametrize("iteration", list(range(1))) +def test_zero3_param_partitioning_base_bf16( + param_persistence_threshold: int, + contiguous_gradients: bool, + offload_optimizer: bool, + zero_grad: bool, + iteration: int, +) -> None: + @distributed_test(world_size=[2]) + def _test_zero3_param_partitioning(): + if offload_optimizer and not contiguous_gradients: + return + + m = 3 + n = 5 + weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)] + model = EltwiseMultiplicationTestNetwork(*weights) + + cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "stage3_param_persistence_threshold": param_persistence_threshold, + "contiguous_gradients": contiguous_gradients, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "bfloat16": { + "enabled": True, + "loss_scale": 1., + } + } + + if offload_optimizer: + cfg["zero_optimization"]["offload_optimizer"] = { + "device": "cpu", + "pin_memory": True, + } + + ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg) + for i, weight in enumerate(weights): + weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, + (i + 1) * (1 + dist.get_rank())) + + def create_tensor(vals): + return torch.as_tensor(vals, dtype=torch.bfloat16, device=ds_engine.device) + + expected_hidden1 = create_tensor([ + [1, + 1, + 1, + 1, + 1], + [1, + 1, + 1, + 2, + 2], + [2, + 2, + 2, + 2, + 2], + ]) + expected_hidden2 = create_tensor([ + [2, + 2, + 2, + 2, + 2], + [2, + 2, + 2, + 8, + 8], + [8, + 8, + 8, + 8, + 8], + ]) + expected_yhat = create_tensor([[6, + 6, + 6, + 6, + 6], + [6, + 6, + 6, + 48, + 48], + [48, + 48, + 48, + 48, + 48]]) + expected_loss = create_tensor([ + [5, + 5, + 5, + 5, + 5], + [5, + 5, + 5, + 47, + 47], + [47, + 47, + 47, + 47, + 47], + ]) + + for train_iter in range(3): + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + activations = ds_engine( + x=torch.ones((m, + n), + dtype=torch.bfloat16, + device=ds_engine.device), + y=torch.ones((m, + n), + dtype=torch.bfloat16, + device=ds_engine.device), + prefetching=train_iter > 0, + ) + assert torch.allclose(activations["hidden1"], expected_hidden1) + assert torch.allclose(activations["hidden2"], expected_hidden2) + assert torch.allclose(activations["y_hat"], expected_yhat) + assert torch.allclose(activations["loss"], expected_loss) + + ds_engine.backward(activations["loss"].sum()) + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + + # check the gradients + grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions() + assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}" + assert set(grad_partitions[0].keys()) == {0, 1, 2} + dloss_wrt_layer1 = grad_partitions[0][0] + dloss_wrt_layer2 = grad_partitions[0][1] + dloss_wrt_layer3 = grad_partitions[0][2] + + # layer1 = [..., 1, 2, ...] + # layer2 = [..., 2, 4, ...] + # layer3 = [..., 3, 6, ...] + # dloss_wrt_layer3 = hidden2 + # dloss_wrt_layer2 = layer3 * hidden1 + # dloss_wrt_layer1 = layer3 * layer2 * x + + expected_grad_dtype = torch.float32 if offload_optimizer else torch.bfloat16 + + grad_multiplier = 1 if zero_grad else (train_iter + 1) + if dist.get_rank() == 0: + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * create_tensor([2] * 8).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * create_tensor([3 * 1] * 8).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * + create_tensor([3 * 2 * 1] * 8).to(expected_grad_dtype)) + elif dist.get_rank() == 1: + # parameters dont split evenly across ranks so rank 1 has a zero-padded + # partition + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * + create_tensor(([8] * 7) + [0]).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * + create_tensor(([6 * 2] * 7) + [0]).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * + create_tensor(([6 * 4 * 1] * 7) + [0]).to(expected_grad_dtype)) + else: + raise RuntimeError("test has world size of two") + + if zero_grad: + ds_engine.optimizer.zero_grad() + + # TODO. add testing for this - for now we just call it to make sure it + # doesnt throw + ds_engine.optimizer.step() + + _test_zero3_param_partitioning() diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py index f0b9bf973376..830a95ced87b 100644 --- a/tests/unit/test_zero_context.py +++ b/tests/unit/test_zero_context.py @@ -1,5 +1,4 @@ import os -import sys from types import SimpleNamespace import torch @@ -8,7 +7,7 @@ import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape -from common import distributed_test +from .common import distributed_test def setup_serial_env(): From e66aedc2ecddbd18c8655d3095292c5acba25ca3 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 12 Oct 2021 15:03:14 -0700 Subject: [PATCH 03/59] fix import in ut --- tests/unit/test_sparse_grads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_sparse_grads.py b/tests/unit/test_sparse_grads.py index 458acaf13339..08584f98de0a 100644 --- a/tests/unit/test_sparse_grads.py +++ b/tests/unit/test_sparse_grads.py @@ -2,7 +2,7 @@ import torch.distributed as dist import deepspeed import pytest -from common import distributed_test +from .common import distributed_test def test_sparse_adam(tmpdir): From 350a7a0204a73fc70301a531a5d8cba25a8f3135 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 12 Oct 2021 15:04:57 -0700 Subject: [PATCH 04/59] ran yapf --- DeepSpeedExamples | 2 +- deepspeed/runtime/config.py | 3 +++ deepspeed/runtime/engine.py | 1 + deepspeed/runtime/zero/stage3.py | 2 +- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/DeepSpeedExamples b/DeepSpeedExamples index 174ae3bc8dbb..25d73cf73fb3 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit 174ae3bc8dbb688cfaccb4afa15d6e2cdbe19ce5 +Subproject commit 25d73cf73fb3dc66faefa141b7319526555be9fc diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 597d180b841b..43753c63de01 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -113,6 +113,7 @@ def get_fp16_enabled(param_dict): else: return False + def get_bfloat16_enabled(param_dict): if BFLOAT16 in param_dict.keys(): return get_scalar_param(param_dict[BFLOAT16], @@ -121,6 +122,7 @@ def get_bfloat16_enabled(param_dict): else: return False + def get_fp16_master_weights_and_grads_enabled(param_dict): if get_fp16_enabled(param_dict): return get_scalar_param(param_dict[FP16], @@ -140,6 +142,7 @@ def get_loss_scale(param_dict): else: return FP16_LOSS_SCALE_DEFAULT + def get_initial_dynamic_scale(param_dict): if get_fp16_enabled(param_dict): initial_scale_power = get_scalar_param(param_dict[FP16], diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 57ee9ff05058..f98ad0cabb91 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -103,6 +103,7 @@ def print_configuration(args, name): dots = '.' * (29 - len(arg)) logger.info(' {} {} {}'.format(arg, dots, getattr(args, arg))) + class DeepSpeedEngine(Module): r"""DeepSpeed engine for training. """ diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 0e5aef16a0db..708ffadc6325 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -504,6 +504,7 @@ def __prefetch_nvme_param_partitions(self) -> None: if swap_in_params: swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) + class PreBackwardFunction(torch.autograd.Function): @staticmethod def forward(ctx, module, pre_backward_function, outputs): @@ -2092,7 +2093,6 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): return total_norm - @instrument_w_nvtx def __partition_grads(self, params_to_release: List[Parameter], From f38394760aee31ddc1cc0d8b475754140c3370d7 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 13 Oct 2021 08:55:06 -0700 Subject: [PATCH 05/59] improvements to cache flush warn log --- deepspeed/runtime/zero/stage3.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 708ffadc6325..a6abb08b2a5d 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2778,10 +2778,13 @@ def step(self, closure=None): if alloc_retries > self.__n_caching_allocator_flushes: if dist.get_rank() == 0: logger.warning( - "%d pytorch allocator cache flushes. this happens " - "when there is high memory pressure and is highly detrimental to " + "%d pytorch allocator cache flushes since last step. this happens " + "when there is high memory pressure and is detrimental to " "performance. if this is happening frequently consider adjusting " - "settings to reduce memory consumption", + "settings to reduce memory consumption. If you are unable to " + "make the cache flushes go away consider adding " + "torch.cuda.empty_cache() calls in your training loop to ensure " + "that all ranks flush their caches at the same time", alloc_retries - self.__n_caching_allocator_flushes) self.__n_caching_allocator_flushes = alloc_retries From b2a1c954eb2527dede9d68046349aea8a48a429b Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 13 Oct 2021 17:00:55 -0700 Subject: [PATCH 06/59] backwards compatibility with older versions of pytorch --- .../runtime/comm/coalesced_collectives.py | 33 +++++- deepspeed/runtime/engine.py | 3 +- .../runtime/zero/partition_parameters.py | 35 +++++- deepspeed/runtime/zero/stage3.py | 108 +++++++++--------- deepspeed/utils/nvtx.py | 12 +- 5 files changed, 117 insertions(+), 74 deletions(-) diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py index f8e40326d598..a564e58ac21a 100644 --- a/deepspeed/runtime/comm/coalesced_collectives.py +++ b/deepspeed/runtime/comm/coalesced_collectives.py @@ -11,6 +11,31 @@ import torch.nn.functional from deepspeed.utils import instrument_w_nvtx +from deepspeed.utils.logging import logger + +if hasattr(torch.distributed, "_reduce_scatter_base"): + + def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group): + instrument_w_nvtx(torch.distributed._reduce_scatter_base)( + output_tensor, + input_tensor, + group=group, + ) +else: + logger.warning( + "unable to find torch.distributed._reduce_scatter_base. will fall back to " + "torch.distributed.reduce_scatter which will result in suboptimal performance. " + "please consider upgrading your pytorch installation.") + + def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group): + input_tensor_lst = list( + torch.chunk(input_tensor, + torch.distributed.get_world_size(group))) + instrument_w_nvtx(torch.distributed.reduce_scatter)( + output_tensor, + input_tensor_lst, + group=group, + ) @instrument_w_nvtx @@ -64,11 +89,9 @@ def reduce_scatter_coalesced( world_sz) # batched reduce-scatter call - instrument_w_nvtx(torch.distributed._reduce_scatter_base)( - tensor_partition_buffer_for_each_rank[this_rank], - tensor_partition_flat_buffer, - group=group, - ) + torch_reduce_scatter_fn(tensor_partition_flat_buffer, + tensor_partition_buffer_for_each_rank[this_rank], + group) # post-divide tensor_partition_buffer_for_each_rank[this_rank].div_(world_sz) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index f98ad0cabb91..015b52eb62a9 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1323,8 +1323,7 @@ def forward(self, *inputs, **kwargs): if self.training_dataloader is None: self.tput_timer.start() - with torch.cuda.nvtx.range("DeepspeedEngine.forward::module_forward"): - loss = self.module(*inputs, **kwargs) + loss = self.module(*inputs, **kwargs) if self.zero_optimization_partition_weights(): # Reset the ZeRO-3 state if we are only doing forward-passes (ie evaluation). diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index fb4f8e2d6f6e..18090516a459 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -34,6 +34,32 @@ param_count = 0 partitioned_param_data_shape = [0] +if hasattr(torch.distributed, "_all_gather_base"): + + def torch_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group): + return instrument_w_nvtx(torch.distributed._all_gather_base)( + output_tensor, + input_tensor, + group=group, + async_op=True, + ) +else: + logger.warning( + "unable to find torch.distributed._all_gather_base. will fall back to " + "torch.distributed.all_gather which will result in suboptimal performance. " + "please consider upgrading your pytorch installation.") + + def torch_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group): + output_tensors = list( + torch.chunk(output_tensor, + torch.distributed.get_world_size(group))) + return instrument_w_nvtx(torch.distributed.all_gather)( + output_tensors, + input_tensor, + group=group, + async_op=True, + ) + def print_rank_0(message, debug=False, force=False): rank = torch.distributed.get_rank() @@ -791,12 +817,9 @@ def all_gather_coalesced(params: Iterable[Parameter], instrument_w_nvtx(torch.cat)([p.ds_tensor.data for p in params], out=partitions[self.rank]) - handle = instrument_w_nvtx(torch.distributed._all_gather_base)( - flat_tensor, - partitions[self.rank], - group=self.ds_process_group, - async_op=True, - ) + handle = torch_allgather_fn(partitions[self.rank], + flat_tensor, + self.ds_process_group) return AllGatherCoalescedHandle( allgather_handle=handle, diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index a6abb08b2a5d..d7952cc7dc97 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -326,37 +326,35 @@ def fetch_sub_module(self, current_submodule: Module) -> None: "allocated": get_cuda_mem_allocated_str() })) - with torch.cuda.nvtx.range("fetch_kickoff"): - params_to_fetch = frozenset(iter_params(current_submodule)) - if self.trace_complete: - # go through the parameters we need for the current module and pop them - # off the fetch queue so that they aren't prefetched later. - # if params have already been popped off the fetch queue by earlier - # prefetches we won't look for them here - discarded_from_prefetch_queue = set() - params_not_already_fetched = set( - filter( - lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. - __step_id, - params_to_fetch)) - while self.__param_queue and len(discarded_from_prefetch_queue) < len( - params_not_already_fetched): - param_in_trace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - discarded_from_prefetch_queue.add(param_in_trace.param) - if discarded_from_prefetch_queue != params_not_already_fetched: - raise RuntimeError( - f"tracing error at step {self.__step_id}: " - f"expected the next {len(params_not_already_fetched)} parameters in the " - f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " - f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." - ) - # kick off all gather for params in the immediately required submodule - for param in params_to_fetch: - debug_rank0(f"-fetch: {param.ds_summary()}") - with torch.cuda.nvtx.range("fetch"): - self.__all_gather_params(params_to_fetch) + params_to_fetch = frozenset(iter_params(current_submodule)) + if self.trace_complete: + # go through the parameters we need for the current module and pop them + # off the fetch queue so that they aren't prefetched later. + # if params have already been popped off the fetch queue by earlier + # prefetches we won't look for them here + discarded_from_prefetch_queue = set() + params_not_already_fetched = set( + filter( + lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. + __step_id, + params_to_fetch)) + while self.__param_queue and len(discarded_from_prefetch_queue) < len( + params_not_already_fetched): + param_in_trace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + discarded_from_prefetch_queue.add(param_in_trace.param) + if discarded_from_prefetch_queue != params_not_already_fetched: + raise RuntimeError( + f"tracing error at step {self.__step_id}: " + f"expected the next {len(params_not_already_fetched)} parameters in the " + f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " + f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." + ) + # kick off all gather for params in the immediately required submodule + for param in params_to_fetch: + debug_rank0(f"-fetch: {param.ds_summary()}") + self.__all_gather_params(params_to_fetch) # wait for parameters in the immediately needed submodule to become available for param in iter_params(current_submodule): @@ -368,23 +366,22 @@ def fetch_sub_module(self, current_submodule: Module) -> None: assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() torch.cuda.current_stream().wait_stream(self.__allgather_stream) - with torch.cuda.nvtx.range("prefetch_kickoff"): - # kick off all gather for params in the next few submodules (prefetch) - max_params_to_prefetch = min( - self.__max_n_available_params - self.__n_available_params, - self.__prefetch_bucket_sz) - params_to_prefetch = set() - numel_prefetching = 0 - while self.__param_queue and numel_prefetching < max_params_to_prefetch: - param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - if param_in_trace.param not in params_to_prefetch: - params_to_prefetch.add(param_in_trace.param) - numel_prefetching += param_in_trace.param.ds_numel - for param in params_to_prefetch: - debug_rank0(f"-prefetch: {param.ds_summary()}") - self.__all_gather_params(params_to_prefetch) + # kick off all gather for params in the next few submodules (prefetch) + max_params_to_prefetch = min( + self.__max_n_available_params - self.__n_available_params, + self.__prefetch_bucket_sz) + params_to_prefetch = set() + numel_prefetching = 0 + while self.__param_queue and numel_prefetching < max_params_to_prefetch: + param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + if param_in_trace.param not in params_to_prefetch: + params_to_prefetch.add(param_in_trace.param) + numel_prefetching += param_in_trace.param.ds_numel + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") + self.__all_gather_params(params_to_prefetch) if self.__prefetch_nvme: self.__prefetch_nvme_param_partitions() @@ -2110,19 +2107,16 @@ def __partition_grads(self, 0, grad_partition.numel()) if self.micro_step_id == 0: # don't accumulate - with torch.cuda.nvtx.range("grad_copy"): - grad_buffer.copy_(grad_partition, non_blocking=True) + grad_buffer.copy_(grad_partition, non_blocking=True) elif grad_buffer.is_cuda: - with torch.cuda.nvtx.range("grad_add_cuda_dst"): - grad_buffer.add_(grad_partition) + grad_buffer.add_(grad_partition) else: # if dst is CPU, copy first to src device, do the addition # there, then move back to dst. adding directly to cpu is very slow - with torch.cuda.nvtx.range("grad_add_cpu_dst"): - tmp_grad_dst = torch.empty_like(grad_partition) - tmp_grad_dst.copy_(grad_buffer, non_blocking=True) - tmp_grad_dst.add_(grad_partition) - grad_buffer.copy_(tmp_grad_dst, non_blocking=True) + tmp_grad_dst = torch.empty_like(grad_partition) + tmp_grad_dst.copy_(grad_buffer, non_blocking=True) + tmp_grad_dst.add_(grad_partition) + grad_buffer.copy_(tmp_grad_dst, non_blocking=True) # offload the gradient partition if applicable if self.offload_optimizer: diff --git a/deepspeed/utils/nvtx.py b/deepspeed/utils/nvtx.py index 0b6c98704cea..11579ffbb17f 100644 --- a/deepspeed/utils/nvtx.py +++ b/deepspeed/utils/nvtx.py @@ -4,8 +4,12 @@ def instrument_w_nvtx(func): """decorator that causes an NVTX range to be recorded for the duration of the function call.""" - def wrapped_fn(*args, **kwargs): - with torch.cuda.nvtx.range(func.__qualname__): - return func(*args, **kwargs) + if hasattr(torch.cuda.nvtx, "range"): - return wrapped_fn + def wrapped_fn(*args, **kwargs): + with torch.cuda.nvtx.range(func.__qualname__): + return func(*args, **kwargs) + + return wrapped_fn + else: + return func From d8678fa7df9d7264979e20d4e5c49d133475c931 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 13 Oct 2021 17:59:46 -0700 Subject: [PATCH 07/59] handle edge case where reduced tensor smaller than world size --- .../runtime/comm/coalesced_collectives.py | 13 +++- tests/unit/test_coalesced_collectives.py | 62 +++++++++++++++++++ 2 files changed, 72 insertions(+), 3 deletions(-) create mode 100644 tests/unit/test_coalesced_collectives.py diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py index a564e58ac21a..504af45d8b9a 100644 --- a/deepspeed/runtime/comm/coalesced_collectives.py +++ b/deepspeed/runtime/comm/coalesced_collectives.py @@ -52,9 +52,16 @@ def reduce_scatter_coalesced( this_rank = torch.distributed.get_rank(group) world_sz = torch.distributed.get_world_size(group) - partition_lst_for_each_tensor = tuple( - torch.chunk(tensor.view(-1), - world_sz) for tensor in tensors) + partition_lst_for_each_tensor = [None] * len(tensors) + for tensor_idx, tensor in enumerate(tensors): + flattened_tensor = tensor.view(-1) + chunk_sz = math.ceil(tensor.numel() / world_sz) + partition_lst_for_each_tensor[tensor_idx] = [ + flattened_tensor[rank * chunk_sz:rank * chunk_sz + chunk_sz] + for rank in range(0, + world_sz) + ] + padded_partition_sz_for_each_tensor = tuple( math.ceil(t.numel() / world_sz) for t in tensors) diff --git a/tests/unit/test_coalesced_collectives.py b/tests/unit/test_coalesced_collectives.py new file mode 100644 index 000000000000..b86245c8b9bb --- /dev/null +++ b/tests/unit/test_coalesced_collectives.py @@ -0,0 +1,62 @@ +"""unit tests for coalesced collectives""" + +import pytest + +import torch +import torch.distributed as dist +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced + +from .common import distributed_test + + +@distributed_test(world_size=2) +def test_reduce_scatter_coalesced_single_input(): + input = torch.full((6, + ), + dist.get_rank(), + dtype=torch.half, + device=torch.cuda.current_device()) + + (output, ) = reduce_scatter_coalesced([input]) + + assert output.shape == (3, ) + assert torch.allclose(output, torch.full_like(output, 0.5)) + + +@distributed_test(world_size=2) +def test_reduce_scatter_coalesced_two_inputs(): + tensor_kwargs = {"device": torch.cuda.current_device(), "dtype": torch.half} + inputs = [ + dist.get_rank() * torch.arange(0, + 6, + **tensor_kwargs), + dist.get_rank() * torch.arange(6, + 9, + **tensor_kwargs), + ] + + output1, output2 = reduce_scatter_coalesced(inputs) + + if dist.get_rank() == 0: + assert output1.shape == (3, ) + assert torch.allclose(output1, torch.arange(0, 3, **tensor_kwargs) / 2) + assert output2.shape == (2, ) + assert torch.allclose(output2, torch.arange(6, 8, **tensor_kwargs) / 2) + elif dist.get_rank() == 1: + assert output1.shape == (3, ) + assert torch.allclose(output1, torch.arange(3, 6, **tensor_kwargs) / 2) + assert output2.shape == (1, ) + assert torch.allclose(output2, torch.arange(8, 9, **tensor_kwargs) / 2) + + +@distributed_test(world_size=2) +def test_reduce_scatter_coalesced_tensor_smaller_than_world_sz(): + input = torch.zeros((1, ), dtype=torch.half, device=torch.cuda.current_device()) + + (output, ) = reduce_scatter_coalesced([input]) + + if dist.get_rank() == 0: + assert output.shape == (1, ) + assert torch.allclose(output, torch.zeros_like(output)) + elif dist.get_rank() == 1: + assert output.shape == (0, ) From a0faca0be61402f06919cd210b0cce4dab0f3097 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 13 Oct 2021 21:10:36 -0700 Subject: [PATCH 08/59] moved event synchronization to allgather handle wait() call --- deepspeed/runtime/zero/stage3.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d7952cc7dc97..4af272413bce 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -362,7 +362,19 @@ def fetch_sub_module(self, current_submodule: Module) -> None: debug_rank0(f"-wait: {param.ds_summary()}") if param in self.__inflight_param_registry: with torch.cuda.stream(self.__allgather_stream): + while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ + 0].query(): + self.__ongoing_fetch_events.popleft() + if len(self.__ongoing_fetch_events + ) > self.__max_ongoing_fetch_events: + self.__ongoing_fetch_events.popleft().synchronize() + self.__inflight_param_registry.pop(param).wait() + + event = Event() + event.record() + self.__ongoing_fetch_events.append(event) + assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() torch.cuda.current_stream().wait_stream(self.__allgather_stream) @@ -433,16 +445,7 @@ def __all_gather_params(self, params: Set[Parameter]) -> None: if partitioned_params: with torch.cuda.stream(self.__allgather_stream): - # only allow a certain number of fetch events to be queued at once - while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ - 0].query(): - self.__ongoing_fetch_events.popleft() - if len(self.__ongoing_fetch_events) > self.__max_ongoing_fetch_events: - self.__ongoing_fetch_events.popleft().synchronize() handle = partitioned_params[0].all_gather_coalesced(partitioned_params) - event = Event() - event.record() - self.__ongoing_fetch_events.append(event) for param in partitioned_params: assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() From bf20c90c1435441b9db259c506cdf5dbbaf621e4 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 13 Oct 2021 21:15:18 -0700 Subject: [PATCH 09/59] removed unnecessary barrier call --- deepspeed/runtime/zero/stage3.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 4af272413bce..a3e8546a93d7 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1456,14 +1456,6 @@ def _create_fp16_sub_groups(self, params_group): def setup_zero_stage3_hooks(self): self.hierarchy = 0 - @instrument_w_nvtx - def _pre_forward_hook(_, *args) -> None: - """makes sure all ranks start .forward() at the same time so that we - don't accidentally mix allgathers for different steps/sets of parameters - """ - torch.cuda.synchronize() - dist.barrier() - #reset step if in inference mode @instrument_w_nvtx def _end_of_forward_hook(module, *args): @@ -1472,7 +1464,6 @@ def _end_of_forward_hook(module, *args): self.param_coordinator.reset_step() #likely one of them should be enough but just to be safe - self.module.register_forward_pre_hook(_pre_forward_hook) self._register_hooks_recursively(self.module) self.module.register_forward_hook(_end_of_forward_hook) From c51ba461b79e53fe4ca564f57b7c31f0d6216cfb Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Thu, 14 Oct 2021 12:45:18 -0700 Subject: [PATCH 10/59] formatting fix after resolving merge conflict --- DeepSpeedExamples | 2 +- deepspeed/runtime/engine.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/DeepSpeedExamples b/DeepSpeedExamples index 25d73cf73fb3..174ae3bc8dbb 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit 25d73cf73fb3dc66faefa141b7319526555be9fc +Subproject commit 174ae3bc8dbb688cfaccb4afa15d6e2cdbe19ce5 diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e45ed530f7fb..dd3ea38d1305 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1365,7 +1365,6 @@ def forward(self, *inputs, **kwargs): return loss - def print_forward_breakdown(self, fwd_time): gate_time = 0.0 moe_time = 0.0 @@ -1389,7 +1388,6 @@ def print_forward_breakdown(self, fwd_time): f"rank={torch.distributed.get_rank()} time (ms) | forward: {fwd_time:.2f} (forward_moe: {moe_time:.2f}, 1st alltoall: {falltoall:.2f}, 2nd alltoall: {salltoall:.2f}, top-k: {gate_time:.2f})", ranks=[0]) - @instrument_w_nvtx def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): # Pass (PP) gas boundary flag to optimizer (required for zero) From ff01f5ccbbf401fbf0cf4e68a5422d35d647e202 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Thu, 14 Oct 2021 12:46:10 -0700 Subject: [PATCH 11/59] skip nvme prefetch when trace not complete --- deepspeed/runtime/zero/stage3.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index a3e8546a93d7..623ddfc69522 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -487,6 +487,9 @@ def __prefetch_nvme_param_partitions(self) -> None: """swap in parameter partitions from nvme for those parameters that will be used after the ones that are already being prefetched into full parameters """ + if not self.trace_complete: + return + numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) numel_considered = 0 From 13093eb80bd023a1860695fa105f08934465c413 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Thu, 14 Oct 2021 21:59:08 -0700 Subject: [PATCH 12/59] opportunistically avoid memory allocation in allgather coalesced where possible --- .../runtime/zero/partition_parameters.py | 80 +++++++++++++------ 1 file changed, 57 insertions(+), 23 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 18090516a459..be6bb2943731 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -3,6 +3,7 @@ Licensed under the MIT license. """ +import math import os import time import types @@ -473,6 +474,16 @@ def _set_dtype(self, ds_config, dtype): self.dtype = dtype or torch.half +class AllGatherHandle: + def __init__(self, handle, param: Parameter) -> None: + self.__handle = handle + self.__param = param + + def wait(self) -> None: + instrument_w_nvtx(self.__handle.wait)() + self.__param.ds_status = ZeroParamStatus.AVAILABLE + + class AllGatherCoalescedHandle: def __init__( self, @@ -777,6 +788,7 @@ def all_gather(param_list=None, async_op=False, hierarchy=0): @instrument_w_nvtx def all_gather_coalesced(params: Iterable[Parameter], safe_mode: bool = False) -> AllGatherCoalescedHandle: + # fetches from nvme if the partition is not available and in nvme self._ensure_availability_of_partitioned_params(params) @@ -804,29 +816,51 @@ def all_gather_coalesced(params: Iterable[Parameter], # otherwise could mix data between tensors. assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params]) - partition_sz = sum(p.ds_tensor.ds_numel for p in params) - flat_tensor = torch.empty(partition_sz * self.world_size, - dtype=get_only_unique_item(p.dtype - for p in params), - device=self.local_device, - requires_grad=False) - partitions: List[Parameter] = [] - for i in range(self.world_size): - partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz)) - - instrument_w_nvtx(torch.cat)([p.ds_tensor.data for p in params], - out=partitions[self.rank]) - - handle = torch_allgather_fn(partitions[self.rank], - flat_tensor, - self.ds_process_group) - - return AllGatherCoalescedHandle( - allgather_handle=handle, - params=params, - partitions=partitions, - world_size=self.world_size, - ) + if len(params) == 1: + # have an opportunity to avoid some intermediate memory allocations + param, = params + param_buffer = torch.empty( + math.ceil(param.ds_numel / self.world_size) * self.world_size, + dtype=param.dtype, + device=param.device, + requires_grad=False, + ) + handle = torch_allgather_fn( + param.ds_tensor, + param_buffer, + self.ds_process_group, + ) + param.data = param_buffer.narrow(0, + 0, + param.ds_numel).view(param.ds_shape) + return AllGatherHandle(handle, param) + else: + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + flat_tensor = torch.empty(partition_sz * self.world_size, + dtype=get_only_unique_item(p.dtype + for p in params), + device=self.local_device, + requires_grad=False) + partitions: List[Parameter] = [] + for i in range(self.world_size): + partitions.append( + flat_tensor.narrow(0, + partition_sz * i, + partition_sz)) + + instrument_w_nvtx(torch.cat)([p.ds_tensor.data for p in params], + out=partitions[self.rank]) + + handle = torch_allgather_fn(partitions[self.rank], + flat_tensor, + self.ds_process_group) + + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=partitions, + world_size=self.world_size, + ) def partition(param_list=None, hierarchy=0, has_been_updated=False): cls = param From f19593d6f9c5be2ae6a4f05ae899e0c527db3cdd Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Thu, 21 Oct 2021 19:46:27 -0700 Subject: [PATCH 13/59] fix indentation after merge --- deepspeed/runtime/zero/partition_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 54eb7b284ca8..86de8d5bb416 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -923,7 +923,7 @@ def ds_summary(slf: torch.Tensor) -> dict: "active_sub_modules": slf.ds_active_sub_modules, } - def convert_to_zero_parameters(param_list): + def convert_to_zero_parameters(param_list): self._convert_to_zero_parameters(param_list) # Collectives for gathering and partitioning parameters From f72bc78a191a056c844222180a4cd91bb95f25ff Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Thu, 21 Oct 2021 19:41:40 -0700 Subject: [PATCH 14/59] fixes to account for parameter offload --- .../runtime/zero/partition_parameters.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 86de8d5bb416..673208223762 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -38,12 +38,17 @@ if hasattr(torch.distributed, "_all_gather_base"): def torch_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group): - return instrument_w_nvtx(torch.distributed._all_gather_base)( - output_tensor, - input_tensor, - group=group, - async_op=True, - ) + try: + return instrument_w_nvtx(torch.distributed._all_gather_base)( + output_tensor, + input_tensor, + group=group, + async_op=True, + ) + except RuntimeError as e: + raise RuntimeError( + f"output_tensor: {output_tensor.device}, input_tensor: {input_tensor.device}" + ) from e else: logger.warning( "unable to find torch.distributed._all_gather_base. will fall back to " @@ -825,24 +830,25 @@ def all_gather_coalesced(params: Iterable[Parameter], param_buffer = torch.empty( math.ceil(param.ds_numel / self.world_size) * self.world_size, dtype=param.dtype, - device=param.device, + device=torch.cuda.current_device(), requires_grad=False, ) handle = torch_allgather_fn( - param.ds_tensor, + param.ds_tensor.to(torch.cuda.current_device()), param_buffer, self.ds_process_group, ) param.data = param_buffer.narrow(0, 0, - param.ds_numel).view(param.ds_shape) + param.ds_numel).view(param.ds_shape).to( + param.device) return AllGatherHandle(handle, param) else: partition_sz = sum(p.ds_tensor.ds_numel for p in params) flat_tensor = torch.empty(partition_sz * self.world_size, dtype=get_only_unique_item(p.dtype for p in params), - device=self.local_device, + device=torch.cuda.current_device(), requires_grad=False) partitions: List[Parameter] = [] for i in range(self.world_size): @@ -851,8 +857,9 @@ def all_gather_coalesced(params: Iterable[Parameter], partition_sz * i, partition_sz)) - instrument_w_nvtx(torch.cat)([p.ds_tensor.data for p in params], - out=partitions[self.rank]) + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(torch.cuda.current_device()) for p in params], + out=partitions[self.rank]) handle = torch_allgather_fn(partitions[self.rank], flat_tensor, From 660df05b46845d46987e4907e2292abe0e0aba02 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Thu, 21 Oct 2021 19:51:16 -0700 Subject: [PATCH 15/59] accounting for torch.cuda.memory_stats not being available --- deepspeed/runtime/zero/stage3.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 5b29327b3bcc..7c7533b6fa59 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2773,7 +2773,9 @@ def step(self, closure=None): self._post_step(timer_names) # warn user about caching allocator flushes - alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] + alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] if hasattr( + torch.cuda, + "memory_stats") else 0 if alloc_retries > self.__n_caching_allocator_flushes: if dist.get_rank() == 0: logger.warning( From 4f9477f80207400d21c1d4c621a79faf99ede913 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Thu, 21 Oct 2021 20:19:09 -0700 Subject: [PATCH 16/59] moved partition_all_params to optimizer step --- deepspeed/runtime/zero/stage3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 7c7533b6fa59..428d4ad7183a 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2735,6 +2735,7 @@ def step(self, closure=None): Not supporting closure. """ self._pre_step() + self._partition_all_parameters() #checks for overflow, adjust the loss scale accordingly if self._overflow_check_and_loss_scale_update(): @@ -2930,10 +2931,6 @@ def backward(self, loss, retain_graph=False): self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) self.param_coordinator.reset_step() - '''Partitioning Parameters that were not partitioned - Usually if parameters of modules whose input parameters do not require - grad computation do not trigger post call and will therefore will remain unpartitioned ''' - self._partition_all_parameters() if self.swap_optimizer: self.optimizer_swapper.post_backward() @@ -2962,6 +2959,9 @@ def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: @instrument_w_nvtx def _partition_all_parameters(self): + """Partitioning Parameters that were not partitioned usually if parameters + of modules whose input parameters do not require grad computation do not + trigger post call and will therefore will remain unpartitioned""" self.param_coordinator.release_and_reset_all() for param in iter_params(self.module, recurse=True): if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: From bb34f9013a274d0b4d2aa066f8f3ada55fda2558 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Mon, 25 Oct 2021 13:29:55 -0700 Subject: [PATCH 17/59] allgathering on params before item gets called --- deepspeed/runtime/zero/partition_parameters.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 673208223762..c7d964093617 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -916,6 +916,10 @@ def padding_size(): def partitioned_size(): return self._partitioned_size(param) + def item_override(): + param.all_gather() + return param._orig_item() + def ds_summary(slf: torch.Tensor) -> dict: return { "id": slf.ds_id, @@ -933,6 +937,13 @@ def ds_summary(slf: torch.Tensor) -> dict: def convert_to_zero_parameters(param_list): self._convert_to_zero_parameters(param_list) + def allgather_before(func: Callable) -> Callable: + def wrapped(*args, **kwargs): + param.all_gather() + return func(*args, **kwargs) + + return wrapped + # Collectives for gathering and partitioning parameters param.all_gather = all_gather param.all_gather_coalesced = all_gather_coalesced @@ -948,6 +959,8 @@ def convert_to_zero_parameters(param_list): param.partitioned_size = partitioned_size param.ds_summary = types.MethodType(ds_summary, param) + param.item = allgather_before(param.item) + param.convert_to_zero_parameters = convert_to_zero_parameters def _aligned_size(self, param): From 9f3b50434ea702980ecfa0439564dcb3f1535314 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Mon, 25 Oct 2021 13:54:45 -0700 Subject: [PATCH 18/59] fix param status checks needed after moving partition_all_parameters call to optimizer step --- tests/unit/test_zero.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index afa597aa29d8..de8d281cfe65 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -504,25 +504,20 @@ def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]: ZeroParamStatus.AVAILABLE } if prefetching else {ZeroParamStatus.NOT_AVAILABLE}) - _assert_partition_status( - self.__layer1, - {ZeroParamStatus.INFLIGHT if prefetching else ZeroParamStatus.NOT_AVAILABLE}) + layerwise_expected_states = { + ZeroParamStatus.INFLIGHT if prefetching else ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.AVAILABLE, + } + + _assert_partition_status(self.__layer1, layerwise_expected_states) hidden1 = self.__layer1(x) _assert_partition_status(self.__layer1, {ZeroParamStatus.NOT_AVAILABLE}) - _assert_partition_status(self.__layer2, - { - ZeroParamStatus.AVAILABLE - if prefetching else ZeroParamStatus.NOT_AVAILABLE - }) + _assert_partition_status(self.__layer2, layerwise_expected_states) hidden2 = self.__layer2(hidden1) _assert_partition_status(self.__layer2, {ZeroParamStatus.NOT_AVAILABLE}) - _assert_partition_status(self.__layer3, - { - ZeroParamStatus.AVAILABLE - if prefetching else ZeroParamStatus.NOT_AVAILABLE - }) + _assert_partition_status(self.__layer3, layerwise_expected_states) y_hat = self.__layer3(hidden2) _assert_partition_status(self.__layer3, { @@ -677,7 +672,6 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: ]) for train_iter in range(3): - _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) activations = ds_engine( x=torch.ones((m, n), @@ -695,7 +689,6 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: assert torch.allclose(activations["loss"], expected_loss) ds_engine.backward(activations["loss"].sum()) - _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) # check the gradients grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions() @@ -754,6 +747,9 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: # TODO. add testing for this - for now we just call it to make sure it # doesnt throw ds_engine.optimizer.step() + # taking an optimizer step invalidates all parameters, make sure everything + # has been partitioned afterwards + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) _test_zero3_param_partitioning() @@ -929,7 +925,6 @@ def _distributed_test(): # TODO. finish writing this test ds_engine.backward(activations[-1].sum()) - _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) avgd_gradients = ds_engine.optimizer.averaged_gradients assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" @@ -1190,5 +1185,6 @@ def create_tensor(vals): # TODO. add testing for this - for now we just call it to make sure it # doesnt throw ds_engine.optimizer.step() + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) _test_zero3_param_partitioning() From 1772d410cf8555b3329ce379425da44412308676 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Mon, 25 Oct 2021 14:38:39 -0700 Subject: [PATCH 19/59] fix grad accumulation with optimizer offload --- deepspeed/runtime/zero/stage3.py | 53 +++++++++----------------------- tests/unit/test_zero.py | 2 +- 2 files changed, 15 insertions(+), 40 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 428d4ad7183a..7cf604ff17fa 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -886,16 +886,9 @@ def __init__(self, self.grads_in_partition = None if self.offload_optimizer: - self.accumulated_grads_in_cpu = {} self.norm_for_param_grads = {} self.local_overflow = False - self.temp_grad_buffer_for_gpu_offload = torch.zeros( - largest_partitioned_param_numel, - device=torch.cuda.current_device(), - dtype=self.dtype) - self.temp_grad_gpu_buffer = torch.zeros(largest_partitioned_param_numel, - device=torch.cuda.current_device(), - dtype=self.dtype) + see_memory_usage(f"After CPU Offload initialization", force=False) # stores if a partition has been reduced in this step @@ -2020,22 +2013,6 @@ def set_grad_positions(self): #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") current_offset += num_elements - def async_accumulate_grad_in_cpu_via_gpu(self, grad, acc_grad_cpu_partition): - - # copy to a preexisiting buffer to avoid memory allocation penalty - dest_buffer = self.temp_grad_buffer_for_gpu_offload.view(-1).narrow( - 0, - 0, - grad.numel()) - - if self.micro_step_id > 0: - dest_buffer.copy_(acc_grad_cpu_partition.view(-1), non_blocking=True) - grad.data.view(-1).add_(dest_buffer) - - # at the boundary we will send 32bit directly - if not self.is_gradient_accumulation_boundary: - acc_grad_cpu_partition.data.copy_(grad.data.view(-1), non_blocking=True) - def _constant_buffered_norm2(self, input, buffer_size=250000000): norm = None for part in input.view(-1).split(buffer_size): @@ -2106,11 +2083,10 @@ def __partition_grads(self, continue # move or accumulate gradient partition to target buffer - grad_buffer = (self.temp_grad_gpu_buffer if self.offload_optimizer else - self.__param_id_to_grad_partition[param.ds_id]).narrow( - 0, - 0, - grad_partition.numel()) + grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow( + 0, + 0, + grad_partition.numel()) if self.micro_step_id == 0: # don't accumulate grad_buffer.copy_(grad_partition, non_blocking=True) elif grad_buffer.is_cuda: @@ -2118,22 +2094,21 @@ def __partition_grads(self, else: # if dst is CPU, copy first to src device, do the addition # there, then move back to dst. adding directly to cpu is very slow - tmp_grad_dst = torch.empty_like(grad_partition) - tmp_grad_dst.copy_(grad_buffer, non_blocking=True) - tmp_grad_dst.add_(grad_partition) - grad_buffer.copy_(tmp_grad_dst, non_blocking=True) + cuda_grad_buffer = grad_buffer.to(grad_partition.device, + non_blocking=True) + cuda_grad_buffer.add_(grad_partition) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + + # use the CUDA buffer from now on so this sequence + later copy to + # fp32 buffer can be all be done async + grad_buffer = cuda_grad_buffer # offload the gradient partition if applicable if self.offload_optimizer: - i, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] offload_fp32_gradients = {} offload_fp32_offsets = {} - if self.gradient_accumulation_steps > 1: - fp16_grad_tensor = self.__param_id_to_grad_partition[param.ds_id] - self.async_accumulate_grad_in_cpu_via_gpu(grad_buffer, - fp16_grad_tensor) - if self.is_gradient_accumulation_boundary: # Credit to our user David Minn if grad_partition is not None: diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index de8d281cfe65..642e0db716cf 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -547,7 +547,7 @@ def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]: @pytest.mark.parametrize("fp16_enabled", [True, False]) @pytest.mark.parametrize("contiguous_gradients", [True, False]) @pytest.mark.parametrize("offload_optimizer", [True, False]) -@pytest.mark.parametrize("zero_grad", [True]) +@pytest.mark.parametrize("zero_grad", [True, False]) @pytest.mark.parametrize("iteration", list(range(1))) def test_zero3_param_partitioning_base( param_persistence_threshold: int, From 5f213d8c1d88f37a1edf3e0b40e05fa20a6fa7f5 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Mon, 25 Oct 2021 17:32:22 -0700 Subject: [PATCH 20/59] grad norm computation fix for optimizer offload --- deepspeed/runtime/zero/stage3.py | 3 +++ tests/unit/test_zero.py | 1 + 2 files changed, 4 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 7cf604ff17fa..a752c797964a 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2114,6 +2114,9 @@ def __partition_grads(self, if grad_partition is not None: self.gpu_sum.add_(grad_partition.float().sum()) + self.norm_for_param_grads[self.get_param_id( + param)] = self._constant_buffered_norm2(grad_buffer) + if self._swappable_optimizer_subgroup(i): if not i in offload_fp32_gradients.keys(): offload_fp32_gradients[i] = [] diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index 642e0db716cf..0389d3c366e5 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -750,6 +750,7 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: # taking an optimizer step invalidates all parameters, make sure everything # has been partitioned afterwards _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + assert not math.isclose(ds_engine.optimizer._global_grad_norm, 0.0) _test_zero3_param_partitioning() From 319880543a17bbcd23fb45b561b5e84413994672 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Mon, 25 Oct 2021 20:35:55 -0700 Subject: [PATCH 21/59] change post divide in reduce-scatter to pre divide --- deepspeed/runtime/comm/coalesced_collectives.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py index 504af45d8b9a..1ac438734813 100644 --- a/deepspeed/runtime/comm/coalesced_collectives.py +++ b/deepspeed/runtime/comm/coalesced_collectives.py @@ -91,6 +91,7 @@ def reduce_scatter_coalesced( tensor_partition_flat_buffer = instrument_w_nvtx( torch.cat)(tensor_partitions_lst_with_padding) + tensor_partition_flat_buffer.div_(world_sz) # pre-divide tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk( tensor_partition_flat_buffer, world_sz) @@ -100,9 +101,6 @@ def reduce_scatter_coalesced( tensor_partition_buffer_for_each_rank[this_rank], group) - # post-divide - tensor_partition_buffer_for_each_rank[this_rank].div_(world_sz) - # reverse procedure of the interleaving done previously, done on the # result of the batched reduce-scatter output_lst: List[Tensor] = [None] * len(tensors) From 2225659ce516255bc8ca992b9d295ac87ded51a1 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 26 Oct 2021 10:41:03 -0700 Subject: [PATCH 22/59] fix gradient race condition w/ optimizer offload --- deepspeed/runtime/zero/stage3.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index a752c797964a..af7ad0c88c6e 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2089,6 +2089,9 @@ def __partition_grads(self, grad_partition.numel()) if self.micro_step_id == 0: # don't accumulate grad_buffer.copy_(grad_partition, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) elif grad_buffer.is_cuda: grad_buffer.add_(grad_partition) else: @@ -2098,9 +2101,8 @@ def __partition_grads(self, non_blocking=True) cuda_grad_buffer.add_(grad_partition) grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) - - # use the CUDA buffer from now on so this sequence + later copy to - # fp32 buffer can be all be done async + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously grad_buffer = cuda_grad_buffer # offload the gradient partition if applicable @@ -2129,7 +2131,7 @@ def __partition_grads(self, i].grad.narrow(0, dest_offset, grad_buffer.numel()) - fp32_grad_tensor.copy_(grad_buffer.float(), non_blocking=True) + fp32_grad_tensor.copy_(grad_buffer) # free the gradient param.grad.record_stream(torch.cuda.current_stream()) From 5aa9bd5030f22975c79b7287b3627f4bcbd532e7 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 26 Oct 2021 12:20:30 -0700 Subject: [PATCH 23/59] improve inf/nan gradient tracking --- deepspeed/runtime/zero/stage3.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index af7ad0c88c6e..ce2abd8ae8cf 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -632,10 +632,11 @@ def __init__(self, # external parameters _inject_parameters(module, ZeROOrderedDict) - self.gpu_sum: Tensor = torch.zeros(1, - dtype=torch.float, - device=torch.cuda.current_device(), - requires_grad=False) + self.__inf_or_nan_tracker: Tensor = torch.zeros( + 1, + dtype=torch.bool, + device=torch.cuda.current_device(), + requires_grad=False) ###################### offload optimizer setup ################################## self.optimizer_swapper = None @@ -1960,8 +1961,6 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: [p.ds_id for p in self.__params_in_ipg_bucket]) grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) - for partition in grad_partitions: - self.gpu_sum.add_(partition.float().sum()) self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) self.__params_in_ipg_bucket.clear() @@ -2028,14 +2027,6 @@ def set_norm_for_param_grad_in_gpu(self, param): #Using a more memory efficient version self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) - def update_overflow_tracker_for_param_grad(self, param): - #Credit to our user David Minn - if param.grad is not None: - if self.overlap_comm: - self.gpu_sum = self.gpu_sum + param.grad.data.float().sum() - elif self._has_inf_or_nan(param.grad.data): - self.local_overflow = True - def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): with torch.cuda.stream(self.copy_grad_stream): param_id = self.get_param_id(param) @@ -2105,6 +2096,9 @@ def __partition_grads(self, # operations and so it can be used asynchronously grad_buffer = cuda_grad_buffer + self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) + self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) + # offload the gradient partition if applicable if self.offload_optimizer: i, dest_offset, _ = self.grad_position[self.get_param_id(param)] @@ -2112,10 +2106,6 @@ def __partition_grads(self, offload_fp32_offsets = {} if self.is_gradient_accumulation_boundary: - # Credit to our user David Minn - if grad_partition is not None: - self.gpu_sum.add_(grad_partition.float().sum()) - self.norm_for_param_grads[self.get_param_id( param)] = self._constant_buffered_norm2(grad_buffer) @@ -2814,6 +2804,8 @@ def unscale_and_clip_grads(self, sub_group_id, total_norm): clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad if clip > 1: combined_scale = clip * self.loss_scale + # to maintain behavior of averaging over accumulation steps + combined_scale *= self.micro_step_id + 1 for grad in grad_groups_flat: if isinstance(grad, list): @@ -2845,8 +2837,8 @@ def has_overflow_partitioned_grads_serial(self): def has_overflow(self, partition_gradients=True): if partition_gradients: with torch.cuda.stream(self.__reduce_and_partition_stream): - self.local_overflow = self._has_inf_or_nan(self.gpu_sum) - self.gpu_sum.zero_() + self.local_overflow = bool(self.__inf_or_nan_tracker.item()) + self.__inf_or_nan_tracker.zero_() overflow = self.local_overflow #overflow = self.has_overflow_partitioned_grads_serial() From a1a60ed44683d88db5fbf42915654a69d0501ef9 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 26 Oct 2021 15:48:47 -0700 Subject: [PATCH 24/59] don't prefetch when not in training mode --- deepspeed/runtime/zero/stage3.py | 91 +++++++++++++++++--------------- 1 file changed, 48 insertions(+), 43 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index ce2abd8ae8cf..401210091e4b 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -327,30 +327,7 @@ def fetch_sub_module(self, current_submodule: Module) -> None: })) params_to_fetch = frozenset(iter_params(current_submodule)) - if self.trace_complete: - # go through the parameters we need for the current module and pop them - # off the fetch queue so that they aren't prefetched later. - # if params have already been popped off the fetch queue by earlier - # prefetches we won't look for them here - discarded_from_prefetch_queue = set() - params_not_already_fetched = set( - filter( - lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. - __step_id, - params_to_fetch)) - while self.__param_queue and len(discarded_from_prefetch_queue) < len( - params_not_already_fetched): - param_in_trace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - discarded_from_prefetch_queue.add(param_in_trace.param) - if discarded_from_prefetch_queue != params_not_already_fetched: - raise RuntimeError( - f"tracing error at step {self.__step_id}: " - f"expected the next {len(params_not_already_fetched)} parameters in the " - f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " - f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." - ) + # kick off all gather for params in the immediately required submodule for param in params_to_fetch: debug_rank0(f"-fetch: {param.ds_summary()}") @@ -378,25 +355,53 @@ def fetch_sub_module(self, current_submodule: Module) -> None: assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() torch.cuda.current_stream().wait_stream(self.__allgather_stream) - # kick off all gather for params in the next few submodules (prefetch) - max_params_to_prefetch = min( - self.__max_n_available_params - self.__n_available_params, - self.__prefetch_bucket_sz) - params_to_prefetch = set() - numel_prefetching = 0 - while self.__param_queue and numel_prefetching < max_params_to_prefetch: - param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - if param_in_trace.param not in params_to_prefetch: - params_to_prefetch.add(param_in_trace.param) - numel_prefetching += param_in_trace.param.ds_numel - for param in params_to_prefetch: - debug_rank0(f"-prefetch: {param.ds_summary()}") - self.__all_gather_params(params_to_prefetch) - - if self.__prefetch_nvme: - self.__prefetch_nvme_param_partitions() + # kick off parameter prefetches for upcoming modules + # don't prefetch if we dont have a completed model trace, or if we aren't + # training (throws off the tracing and don't want to prefetch modules for bwd) + if self.trace_complete and current_submodule.training: + # go through the parameters we need for the current module and pop them + # off the fetch queue so that they aren't prefetched later. + # if params have already been popped off the fetch queue by earlier + # prefetches we won't look for them here + discarded_from_prefetch_queue = set() + params_not_already_fetched = set( + filter( + lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. + __step_id, + params_to_fetch)) + while self.__param_queue and len(discarded_from_prefetch_queue) < len( + params_not_already_fetched): + param_in_trace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + discarded_from_prefetch_queue.add(param_in_trace.param) + if discarded_from_prefetch_queue != params_not_already_fetched: + raise RuntimeError( + f"tracing error at step {self.__step_id}: " + f"expected the next {len(params_not_already_fetched)} parameters in the " + f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " + f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." + ) + + # kick off all gather for params in the next few submodules (prefetch) + max_params_to_prefetch = min( + self.__max_n_available_params - self.__n_available_params, + self.__prefetch_bucket_sz) + params_to_prefetch = set() + numel_prefetching = 0 + while self.__param_queue and numel_prefetching < max_params_to_prefetch: + param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + if param_in_trace.param not in params_to_prefetch: + params_to_prefetch.add(param_in_trace.param) + numel_prefetching += param_in_trace.param.ds_numel + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") + self.__all_gather_params(params_to_prefetch) + + if self.__prefetch_nvme: + self.__prefetch_nvme_param_partitions() self.__step_id += 1 From df41659349814d812e0b9656aa0c836282805c4f Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 26 Oct 2021 15:57:27 -0700 Subject: [PATCH 25/59] format fix after merging --- tests/unit/test_sparse_grads.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_sparse_grads.py b/tests/unit/test_sparse_grads.py index b5df2aa65caa..c0e7272192f1 100644 --- a/tests/unit/test_sparse_grads.py +++ b/tests/unit/test_sparse_grads.py @@ -7,7 +7,6 @@ import deepspeed.utils.groups as groups - def test_sparse_adam(tmpdir): config_dict = {"train_batch_size": 2, "steps_per_print": 1, "sparse_gradients": True} From ab3a82af5a5f62b553056a6bd9fbe464a6c8e1ba Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 27 Oct 2021 10:09:17 -0700 Subject: [PATCH 26/59] fix prefetching issue when using NVME offload --- deepspeed/runtime/zero/stage3.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 401210091e4b..af3d5096e4db 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -237,7 +237,7 @@ def __init__( self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel # queue for parameters to fetch. parameters will be popped off the left # side of the dequeue as they are fetched - self.__param_queue: collections.deque = None + self.__param_queue: Deque[__class__.__ParamInTrace] = None self.__prefetch_bucket_sz: int = prefetch_bucket_sz self.__prefetch_nvme: bool = prefetch_nvme self.hierarchy: int = 0 @@ -499,15 +499,16 @@ def __prefetch_nvme_param_partitions(self) -> None: numel_considered = 0 swap_in_params = [] - for _, param in self.__param_queue: + for param_in_trace in self.__param_queue: + param = param_in_trace.param if param.nvme_swapper is None: - raise RuntimeError( - f"expected param {param.ds_summary()} to have nvme swapper") + continue if (numel_considered > 2 * numel_in_flight or len(swap_in_params) >= param.nvme_swapper.available_swap_in_buffers()): break if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: swap_in_params.append(param) + numel_considered += param.ds_numel if swap_in_params: swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) From a26d1fb9638a4483c057d9db6fbaf476a5b38813 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Sat, 30 Oct 2021 17:53:22 -0700 Subject: [PATCH 27/59] improved defragmentation for fp16 parameters --- deepspeed/runtime/zero/stage3.py | 241 ++++++++++++++----------------- 1 file changed, 106 insertions(+), 135 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index af3d5096e4db..15aa7daffc10 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -932,59 +932,42 @@ def __init__(self, if dist.get_rank(group=self.dp_process_group) == 0: see_memory_usage(f"After initializing ZeRO optimizer", force=False) - persistent_tensors: Set[Tensor] = set() - for param in self.module.parameters(recurse=True): - param.partition() - persistent_tensors.add(param.ds_tensor) - - FP16_DeepSpeedZeroOptimizer_Stage3.defragment(persistent_tensors) - - if dist.get_rank(group=self.dp_process_group) == 0: - see_memory_usage(f"After defragmenting", force=True) - @staticmethod - def defragment(tensors: Set[Tensor]): - cuda_tensors_by_device_and_dtype: Dict[tuple, - Set[Tensor]] = collections.defaultdict( - set) - for tensor in filter(lambda t: t.is_cuda, tensors): - cuda_tensors_by_device_and_dtype[(tensor.device, tensor.dtype)].add(tensor) - - cpu_buffer_and_orig_device_to_tensor_infos: Dict[ - Tuple[Tensor, - torch.device], - List[Tuple[Tensor, - int, - int]]] = collections.defaultdict(list) - for (orig_device, dtype), tensorset in cuda_tensors_by_device_and_dtype.items(): - cpu_buffer = torch.empty(sum(p.numel() for p in tensorset), - dtype=dtype, - device="cpu") + def defragment(tensors: List[Tensor]) -> Tensor: + """move provided tensors into a contiguous flat buffer, with some additional + measures taken to reduce memory fragmentation""" + assert len(set(t.dtype for t in tensors)) == 1 + assert len(set(t.device for t in tensors)) == 1 + + cpu_buffer = torch.empty(sum(p.numel() for p in tensors), + dtype=get_only_unique_item(t.dtype for t in tensors), + device="cpu") + tensor_infos: List[Tuple[Tensor, int, int]] = [] + orig_device = get_only_unique_item(t.device for t in tensors) - offset = 0 - for tensor in tensorset: - tensor_numel = tensor.numel() - # move the tensor from device memory to host memory - cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) - tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + offset = 0 + for tensor in tensors: + tensor_numel = tensor.numel() + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) - # record some data so we can restore the device tensor later - cpu_buffer_and_orig_device_to_tensor_infos[(cpu_buffer, - orig_device)].append( - (tensor, - offset, - tensor_numel)) + # record some data so we can restore the device tensor later + tensor_infos.append((tensor, offset, tensor_numel)) - offset += tensor_numel + offset += tensor_numel gc.collect() torch.cuda.empty_cache() + # copy tensors (now flattened and contiguous) back to GPU + device_buffer = cpu_buffer.to(orig_device) + # restore device tensors - for (cpu_buffer, orig_device), tensor_offsets in cpu_buffer_and_orig_device_to_tensor_infos.items(): - device_buffer = cpu_buffer.to(orig_device) - for tensor, offset, tensor_numel in tensor_offsets: - tensor.data = device_buffer.narrow(0, offset, tensor_numel) + for tensor, offset, tensor_numel in tensor_infos: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + + return device_buffer def _convert_to_zero_parameters(self, module, mpu): non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] @@ -1161,91 +1144,70 @@ def _create_param_groups_fp16_flat_cpu_memory(self): def _create_fp16_partitions_with_defragmentation(self): dist.barrier() - partition_id = dist.get_rank(group=self.dp_process_group) - create_fp16_flat_reuse_buffer = False - largest_partition_numel = [] - max_partition_numel = 0 - - #create a flat CPU memory allocation for each param group - if self.offload_param: - self._create_param_groups_fp16_flat_cpu_memory() - - # loop to deal with groups - for j, param_group in enumerate(self.optimizer.param_groups): - - sub_groups = self._create_fp16_sub_groups(param_group['params']) - print_rank_0(f'fp16 group {j} has {len(sub_groups)} subgroups', force=False) - - flat_offset = 0 - for sub_group in sub_groups: - i = len(self.fp16_groups) - - # push this group to list before modify + param_groups: List[List[Parameter]] = tuple( + self._create_fp16_sub_groups(param_group["params"]) + for param_group in self.optimizer.param_groups) + + # bookkeeping related to param groups + for param_group_idx, param_group in enumerate(param_groups): + for sub_group_idx, sub_group in enumerate(param_group): + # record sub group and partitions self.fp16_groups.append(sub_group) - self.sub_group_to_group_id[i] = j - - # comment out for zero_to_fp32 debug - # if torch.distributed.get_rank() == 0: - # for param in self.fp16_groups[i]: - # print(f"{debug_param2name_id_shape(param)} {param.ds_shape}") - - #These are the list of the partitioned parameters self.fp16_partitioned_groups.append( - [param.ds_tensor for param in self.fp16_groups[i]]) - - total_elements = sum( - [t.ds_numel for t in self.fp16_partitioned_groups[i]]) - self.fp16_partitioned_groups_flat_numel.append(total_elements) - - if total_elements > max_partition_numel: - largest_partition_numel = [ - t.ds_numel for t in self.fp16_partitioned_groups[i] - ] - max_partition_numel = total_elements - - print_rank_0( - f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" - ) - - # Record padding required to align group to world size (only applies to last rank) - if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: - padding = [p.padding_size() for p in self.fp16_groups[i]] - else: - padding = [0] * len(self.fp16_groups[i]) - self.groups_padding.append(padding) - - #not sure why apex was cloning the weights before flattening - #removing cloning here - see_memory_usage(f"Before Flattening param subgroup {i}", force=False) - - #all partitioned parameters remain in GPU during training - if not self.offload_param: - see_memory_usage(f"Before moving param subgroup group {i} to CPU", - force=False) - #move all the parameters to cpu to free up GPU space for creating flat buffer - move_to_cpu(self.fp16_partitioned_groups[i]) - see_memory_usage(f"After moving param subgroup {i} to CPU", - force=False) - - #create flat buffer in CPU and move to GPU + [param.ds_tensor for param in sub_group]) + + # record sub group -> group mapping + self.sub_group_to_group_id[sub_group_idx] = param_group_idx + + # record total elements of parameter partitions in sub group + self.fp16_partitioned_groups_flat_numel.append( + sum(p.ds_tensor.ds_numel for p in sub_group)) + + # record padding required to align group to world size (only applies to last rank) + rank_requires_padding = dist.get_rank( + self.dp_process_group) == dist.get_world_size( + self.dp_process_group) - 1 + self.groups_padding.append([ + p.padding_size() if rank_requires_padding else 0 for p in sub_group + ]) + + # move parameters to flattened buffer + if not self.offload_param: # partitioned params remain in GPU during training + # move parameter partitions into a single contiguous flat buffer + parameter_partitions: List[Tensor] = [] + for param_group_idx, param_group in enumerate(param_groups): + for sub_group_idx, sub_group in enumerate(param_group): + for param in sub_group: + parameter_partitions.append(param.ds_tensor) + device_buffer = __class__.defragment(parameter_partitions) + + # setup flat buffers per subgroup, these are each just sections of the + # contiguous flat buffer for all parameters that we created earlier + offset = 0 + for param_group_idx, param_group in enumerate(param_groups): + for sub_group_idx, sub_group in enumerate(param_group): + sub_group_numel = sum(param.ds_tensor.ds_numel + for param in sub_group) self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - 1).cuda(torch.cuda.current_device())) - see_memory_usage( - f"After flattening and moving param subgroup {i} to GPU", - force=False) - - #all partitioned parameters are in CPU during training - else: + device_buffer.narrow(0, + offset, + sub_group_numel)) + offset += sub_group_numel + else: # partitioned params offloaded to CPU when not in use + # create a flat CPU memory allocation for each param group + self._create_param_groups_fp16_flat_cpu_memory() + for param_group_idx, param_group in enumerate(param_groups): + flat_offset = 0 + for i, sub_group in enumerate(param_group): + total_elements = sum(p.ds_tensor.ds_numel for p in sub_group) print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") #Flat buffer may not be available for parameters that reside in NVME if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ - j].numel(): + param_group_idx].numel(): fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ - j].narrow(0, - flat_offset, - total_elements) + param_group_idx].narrow(0, + flat_offset, + total_elements) print_rank_0( f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", force=False) @@ -1261,20 +1223,29 @@ def _create_fp16_partitions_with_defragmentation(self): self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) flat_offset += total_elements - # move param to flat buffer for both param offload on/off - self._move_to_flat_buffer(self.fp16_groups[i], - self.fp16_partitioned_groups_flat[i], - avoid_copy=not self.offload_param) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - #create a pinned memory to be used for swapping out params to NVME after optimizer step - if self.fp16_partitioned_groups_flat[-1] is None: - create_fp16_flat_reuse_buffer = True - - see_memory_usage(f"After Flattening param subgroup {i}", force=False) + self._move_to_flat_buffer( + self.fp16_groups[sub_group_idx], + self.fp16_partitioned_groups_flat[sub_group_idx], + avoid_copy=not self.offload_param) + + # if necessary, create a pinned memory buffer to be used for swapping out + # params to NVME after optimizer step + should_create_fp16_flat_reuse_buffer = any( + flattened_partition_group is None + for flattened_partition_group in self.fp16_partitioned_groups_flat) + if should_create_fp16_flat_reuse_buffer: + max_partition_numel = 0 + largest_partition_numel = None + for param_group in param_groups: + for sub_group in param_group: + total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) + if total_elements > max_partition_numel: + largest_partition_numel = [ + t.ds_numel + for t in self.fp16_partitioned_groups[sub_group_idx] + ] + max_partition_numel = total_elements - if create_fp16_flat_reuse_buffer: assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( largest_partition_numel) From 937f04e16f2c9fe8d8a2d8de967b601b4cb9a824 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 2 Nov 2021 12:10:27 -0700 Subject: [PATCH 28/59] relative imports for bf16 tests --- tests/unit/test_bf16.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_bf16.py b/tests/unit/test_bf16.py index 9220ce7e98c6..99d4a8c514ae 100644 --- a/tests/unit/test_bf16.py +++ b/tests/unit/test_bf16.py @@ -3,10 +3,10 @@ import deepspeed import pytest from deepspeed.ops.adam import FusedAdam -from common import distributed_test +from .common import distributed_test from deepspeed.ops.op_builder import CPUAdamBuilder -from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict -from util import bf16_required_version_check +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict +from .util import bf16_required_version_check @pytest.mark.parametrize('zero_stage, use_cpu_offload', [(2, False)]) From e74f50999a0ac59bbfe82ebe2820f200a066a0d3 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 2 Nov 2021 16:20:19 -0700 Subject: [PATCH 29/59] changes for bwd compatibility with pytorch 1.2 --- deepspeed/runtime/zero/stage3.py | 10 ++++++++-- tests/unit/test_coalesced_collectives.py | 6 +++--- tests/unit/test_zero.py | 9 ++------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 15aa7daffc10..228f254de196 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2073,8 +2073,14 @@ def __partition_grads(self, # operations and so it can be used asynchronously grad_buffer = cuda_grad_buffer - self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) - self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) + if hasattr(self.__inf_or_nan_tracker, "logical_or_"): + self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) + self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) + else: + # logical_or_ not available in older versions of pytorch + self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any() + self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any() + self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0 # offload the gradient partition if applicable if self.offload_optimizer: diff --git a/tests/unit/test_coalesced_collectives.py b/tests/unit/test_coalesced_collectives.py index b86245c8b9bb..fb6b5354a158 100644 --- a/tests/unit/test_coalesced_collectives.py +++ b/tests/unit/test_coalesced_collectives.py @@ -17,7 +17,7 @@ def test_reduce_scatter_coalesced_single_input(): dtype=torch.half, device=torch.cuda.current_device()) - (output, ) = reduce_scatter_coalesced([input]) + (output, ) = reduce_scatter_coalesced([input], dist.group.WORLD) assert output.shape == (3, ) assert torch.allclose(output, torch.full_like(output, 0.5)) @@ -35,7 +35,7 @@ def test_reduce_scatter_coalesced_two_inputs(): **tensor_kwargs), ] - output1, output2 = reduce_scatter_coalesced(inputs) + output1, output2 = reduce_scatter_coalesced(inputs, dist.group.WORLD) if dist.get_rank() == 0: assert output1.shape == (3, ) @@ -53,7 +53,7 @@ def test_reduce_scatter_coalesced_two_inputs(): def test_reduce_scatter_coalesced_tensor_smaller_than_world_sz(): input = torch.zeros((1, ), dtype=torch.half, device=torch.cuda.current_device()) - (output, ) = reduce_scatter_coalesced([input]) + (output, ) = reduce_scatter_coalesced([input], dist.group.WORLD) if dist.get_rank() == 0: assert output.shape == (1, ) diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index 7c62137d91d8..0d4053acc9ab 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -454,10 +454,6 @@ def _ds_initialize_for_param_partitioning_testing(model: Module, return ds_engine -def _print_with_rank(msg: str) -> None: - print(f"RANK{dist.get_rank()}: {msg}") - - def _assert_partition_status(model: Module, valid_statuses: Set[ZeroParamStatus]) -> None: for _, param in model.named_parameters(): @@ -676,11 +672,11 @@ def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: activations = ds_engine( x=torch.ones((m, n), - dtype=torch.float16, + dtype=torch.float16 if fp16_enabled else torch.float32, device=ds_engine.device), y=torch.ones((m, n), - dtype=torch.float16, + dtype=torch.float16 if fp16_enabled else torch.float32, device=ds_engine.device), prefetching=train_iter > 0, ) @@ -870,7 +866,6 @@ def forward(self, x: Tensor) -> Tensor: activations = [] for module in self.modulelist: - print(f"{dist.get_rank()}: xval: {x.shape}") x = module(x) activations.append(x) From 6ee558d10e02590da3cac2fb811d99cfa5f80763 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 2 Nov 2021 16:22:50 -0700 Subject: [PATCH 30/59] remove buffered_reduce_fallback --- deepspeed/runtime/zero/stage3.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 228f254de196..748a23f2b28c 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -57,20 +57,6 @@ def input(msg): return -def split_half_float_double(tensors): - dtypes = [ - "torch.cuda.HalfTensor", - "torch.cuda.FloatTensor", - "torch.cuda.DoubleTensor" - ] - buckets = [] - for i, dtype in enumerate(dtypes): - bucket = [t for t in tensors if t.type() == dtype] - if bucket: - buckets.append(bucket) - return buckets - - def isclose(a, b, rtol=1e-09, atol=0.0): return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) @@ -2240,20 +2226,6 @@ def allreduce_no_retain(self, if len(small_bucket) > 0: self.allreduce_and_copy(small_bucket, rank=rank, log=log) - # allows using reduction of gradients instead of using all_reduce - def buffered_reduce_fallback(self, - rank, - grads, - elements_per_buffer=500000000, - log=None): - split_buckets = split_half_float_double(grads) - - for i, bucket in enumerate(split_buckets): - self.allreduce_no_retain(bucket, - numel_per_bucket=elements_per_buffer, - rank=rank, - log=log) - ############################################################################# ############################################################################# ############################################################################# From 14e22a25c75878c844922084abace730d3193202 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 2 Nov 2021 17:55:28 -0700 Subject: [PATCH 31/59] removed unused parameter offset bookkeeping --- deepspeed/runtime/zero/stage3.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 748a23f2b28c..5dda0af818fa 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -821,17 +821,6 @@ def __init__(self, self.__param_reduce_events: Deque[Event] = collections.deque() self.__max_param_reduce_events: int = 2 - # map each parameter to its group index and its offset within that group's - # flattened buffer - self.__param_id_to_param_group_and_offset_within_group_buffer = {} - for group_idx, group in enumerate(self.fp16_groups): - offset_within_group = 0 - for param in group: - self.__param_id_to_param_group_and_offset_within_group_buffer[ - param.ds_id] = (group_idx, - offset_within_group) - offset_within_group += param.ds_tensor.ds_numel - if dist.get_rank() == 0: logger.info(f"optimizer state initialized") From 16281df2be8c38a3ffbd22949bf187d9b1e63dec Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 2 Nov 2021 18:32:08 -0700 Subject: [PATCH 32/59] fixed tracking for multiple param groups --- deepspeed/runtime/zero/stage3.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 5dda0af818fa..18c92a3c985d 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1124,8 +1124,9 @@ def _create_fp16_partitions_with_defragmentation(self): for param_group in self.optimizer.param_groups) # bookkeeping related to param groups + sub_group_idx = 0 for param_group_idx, param_group in enumerate(param_groups): - for sub_group_idx, sub_group in enumerate(param_group): + for sub_group in param_group: # record sub group and partitions self.fp16_groups.append(sub_group) self.fp16_partitioned_groups.append( @@ -1145,6 +1146,7 @@ def _create_fp16_partitions_with_defragmentation(self): self.groups_padding.append([ p.padding_size() if rank_requires_padding else 0 for p in sub_group ]) + sub_group_idx += 1 # move parameters to flattened buffer if not self.offload_param: # partitioned params remain in GPU during training From cc7011ec335d13a753607a830c400bc18d4c8675 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 3 Nov 2021 10:07:24 -0700 Subject: [PATCH 33/59] unbroke bfloat16 config after merge conflict --- deepspeed/runtime/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 4a2d9a1b6671..c96b547e1f93 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -805,8 +805,6 @@ def _initialize_params(self, param_dict): self.gradient_clipping = get_gradient_clipping(param_dict) self.fp16_enabled = get_fp16_enabled(param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict) - assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' - assert not (self.bfloat16_enabled and (self.zero_optimization_stage != 2)), 'bfloat16 mode is only enabled for Zero2 currently' self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled( param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict) From 806b072686d3f28c29526cd13af14a4c72b22c48 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 3 Nov 2021 13:42:05 -0700 Subject: [PATCH 34/59] using base allgather params when only 1 param --- deepspeed/runtime/zero/partition_parameters.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 3840460b0801..6632869abbea 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -1016,8 +1016,10 @@ def _all_gather(self, param_list, async_op=False, hierarchy=None): all_gather_list.append(param) if not async_op: - # ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) - ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy) + if len(param_list) == 1: + ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) + else: + ret_value = self._allgather_params_coalesced(all_gather_list, hierarchy) for param in all_gather_list: param.ds_status = ZeroParamStatus.AVAILABLE From bf0dd663cab9080a24c1972a40634aec7606f45e Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 3 Nov 2021 14:37:53 -0700 Subject: [PATCH 35/59] cleanup/fixes for fp16 partition defragmentation --- deepspeed/runtime/zero/stage3.py | 52 +++++++++++++------------------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 18c92a3c985d..151d1ae74b88 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1124,9 +1124,10 @@ def _create_fp16_partitions_with_defragmentation(self): for param_group in self.optimizer.param_groups) # bookkeeping related to param groups - sub_group_idx = 0 for param_group_idx, param_group in enumerate(param_groups): for sub_group in param_group: + sub_group_idx = len(self.fp16_groups) + # record sub group and partitions self.fp16_groups.append(sub_group) self.fp16_partitioned_groups.append( @@ -1146,30 +1147,26 @@ def _create_fp16_partitions_with_defragmentation(self): self.groups_padding.append([ p.padding_size() if rank_requires_padding else 0 for p in sub_group ]) - sub_group_idx += 1 # move parameters to flattened buffer if not self.offload_param: # partitioned params remain in GPU during training # move parameter partitions into a single contiguous flat buffer parameter_partitions: List[Tensor] = [] - for param_group_idx, param_group in enumerate(param_groups): - for sub_group_idx, sub_group in enumerate(param_group): - for param in sub_group: - parameter_partitions.append(param.ds_tensor) + for sub_group in self.fp16_groups: + for param in sub_group: + parameter_partitions.append(param.ds_tensor) device_buffer = __class__.defragment(parameter_partitions) # setup flat buffers per subgroup, these are each just sections of the # contiguous flat buffer for all parameters that we created earlier offset = 0 - for param_group_idx, param_group in enumerate(param_groups): - for sub_group_idx, sub_group in enumerate(param_group): - sub_group_numel = sum(param.ds_tensor.ds_numel - for param in sub_group) - self.fp16_partitioned_groups_flat.append( - device_buffer.narrow(0, - offset, - sub_group_numel)) - offset += sub_group_numel + for sub_group in self.fp16_groups: + sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group) + self.fp16_partitioned_groups_flat.append( + device_buffer.narrow(0, + offset, + sub_group_numel)) + offset += sub_group_numel else: # partitioned params offloaded to CPU when not in use # create a flat CPU memory allocation for each param group self._create_param_groups_fp16_flat_cpu_memory() @@ -1188,7 +1185,6 @@ def _create_fp16_partitions_with_defragmentation(self): print_rank_0( f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", force=False) - #these parameters reside in NVME and elif self.params_in_nvme_and_cpu: fp16_partitioned_group_flat = None print_rank_0( @@ -1200,10 +1196,9 @@ def _create_fp16_partitions_with_defragmentation(self): self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) flat_offset += total_elements - self._move_to_flat_buffer( - self.fp16_groups[sub_group_idx], - self.fp16_partitioned_groups_flat[sub_group_idx], - avoid_copy=not self.offload_param) + self._move_to_flat_buffer(sub_group, + fp16_partitioned_group_flat, + avoid_copy=not self.offload_param) # if necessary, create a pinned memory buffer to be used for swapping out # params to NVME after optimizer step @@ -1211,17 +1206,12 @@ def _create_fp16_partitions_with_defragmentation(self): flattened_partition_group is None for flattened_partition_group in self.fp16_partitioned_groups_flat) if should_create_fp16_flat_reuse_buffer: - max_partition_numel = 0 - largest_partition_numel = None - for param_group in param_groups: - for sub_group in param_group: - total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) - if total_elements > max_partition_numel: - largest_partition_numel = [ - t.ds_numel - for t in self.fp16_partitioned_groups[sub_group_idx] - ] - max_partition_numel = total_elements + max_partition_numel, largest_partition_numel = 0, None + for sub_group in self.fp16_groups: + total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) + if total_elements > max_partition_numel: + largest_partition_numel = [t.ds_numel for t in sub_group] + max_partition_numel = total_elements assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( From 6dc21a60e6ba1e920ed6e21061a5055fa92a9505 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Thu, 18 Nov 2021 10:41:28 -0800 Subject: [PATCH 36/59] switch to CRLF --- deepspeed/runtime/zero/stage3.py | 6658 +++++++++++++++--------------- 1 file changed, 3329 insertions(+), 3329 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 151d1ae74b88..f3ecafb6a2a8 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1,3329 +1,3329 @@ -""" -"Copyright 2020 The Microsoft DeepSpeed Team. -Licensed under the MIT license. -""" - -import gc -from dataclasses import dataclass -import functools -import os -import collections -from collections import OrderedDict, UserDict -import itertools -from typing import Deque, Dict, Iterable, Set, Tuple -import torch -from torch.cuda import Event, Stream -from torch.nn import Module, Parameter -import torch.distributed as dist -import math -from torch._six import inf -from torch.nn import Module -from torch.nn.parameter import Parameter - -from deepspeed.utils.logging import logger -from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced -from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter -from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter -from deepspeed.runtime.zero.partition_parameters import * -from deepspeed.runtime.zero.partition_parameters import _init_external_params -from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS -from deepspeed.ops.adam import DeepSpeedCPUAdam -from deepspeed.ops.op_builder import UtilsBuilder -from deepspeed.runtime.zero.offload_constants import * -from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus -from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper -from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper - -# Toggle this to true to enable correctness test -# with gradient partitioning and without -pg_correctness_test = False - -FWD_MODULE_STACK = list() - - -def print_rank_0(message, debug=False, force=False): - rank = torch.distributed.get_rank() - if rank == 0 and (debug or force): - print(message) - # other variations - # - print for all ranks w/o interleaving - # printflock(f"[{rank}] {message}") - # - print to log file per rank - # log_rank_file(rank, message) - - -def input(msg): - return - - -def isclose(a, b, rtol=1e-09, atol=0.0): - return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) - - -def lcm(x, y): - from fractions import gcd # or can import gcd from `math` in Python 3 - return x * y // gcd(x, y) - - -def debug_rank0(message: str) -> None: - if dist.get_rank() == 0: - logger.debug(message) - - -def get_cuda_mem_allocated_str() -> str: - # this is really slow. when enabled the python process becomes slow - # to the point where it can't keep the GPU fed with work, so only enable - # for memory debugging. - # return f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f}GB" - return "xGB" - - -def move_to_cpu(tensor_list): - for tensor in tensor_list: - tensor.data = tensor.data.cpu() - - -@instrument_w_nvtx -def get_all_parameters(sub_module, recurse=False): - return itertools.chain(sub_module.named_parameters(recurse=recurse), - sub_module.ds_external_parameters()) - - -def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: - return map(lambda pair: pair[1], get_all_parameters(module, recurse)) - - -#apply torch.autograd.Function that calls a backward_function to tensors in output -def _apply_to_tensors_only(module, functional, backward_function, outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_only(module, - functional, - backward_function, - output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - return functional.apply(module, backward_function, outputs) - else: - return outputs - - -#for each tensor in outputs run the forward_function and register backward_function as hook -def _apply_forward_and_backward_to_tensors_only(module, - forward_function, - backward_function, - outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_forward_and_backward_to_tensors_only( - module, - forward_function, - backward_function, - output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - forward_function(outputs) - if outputs.requires_grad: - outputs.register_hook(backward_function) - return outputs - else: - return outputs - - -class ZeROOrderedDict(OrderedDict): - def __init__(self, parent_module, *args, **kwargs): - """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. - - Args: - parent_module (``collections.OrderedDict``): the collection to replace - """ - - super().__init__(*args, **kwargs) - self._parent_module = parent_module - self._in_forward = False - - def __getitem__(self, key): - param = super().__getitem__(key) - - # Params can be registered as None (e.g., bias) - if param is None: - return param - - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - if self._parent_module._parameters._in_forward: - print_rank_0(f'Registering external parameter from getter {key}', - force=False) - register_external_parameter(FWD_MODULE_STACK[-1], param) - param.all_gather() - - return param - - -def _inject_parameters(module, cls): - for module in module.modules(): - if cls == ZeROOrderedDict: - new_param = cls(parent_module=module) - else: - new_param = cls() - - for key, param in module._parameters.items(): - new_param[key] = param - module._parameters = new_param - - -class PartitionedParameterCoordinator: - """Handles partitioning and gathering of parameters.""" - class __InflightParamRegistry(UserDict): - """registry for parameters in flight""" - def __setitem__(self, - param: Parameter, - handle: AllGatherCoalescedHandle) -> None: - if param in self.data: - raise RuntimeError(f"{param.ds_summary()} already in registry") - if param.ds_status != ZeroParamStatus.INFLIGHT: - raise RuntimeError( - f"attempted to add non-inflight parameter to registry {param.ds_summary()}" - ) - self.data[param] = handle - - @dataclass - class __ParamInTrace: - param: Parameter - step_id_last_used_at: int - - def __init__( - self, - prefetch_bucket_sz: int, - max_reuse_distance_in_numel: int, - max_available_parameters_in_numel: int, - allgather_stream: Stream, - prefetch_nvme: bool = False, - ) -> None: - # mapping of param -> handle for each param that is currently in flight - self.__inflight_param_registry = __class__.__InflightParamRegistry() - # keeps track of the number of submodules invoked so far. - self.__step_id: int = 0 - # whether or not we have completed a trace of the entire network. This should - # always be true after the first forward pass + backward pass. - self.trace_complete: bool = False - # sequence of submodules/parameters in forward pass + backward pass - self.__submodule_order: Iterable[Module] = [] - self.__param_order: Iterable[__class__.__ParamInTrace] = [] - self.__most_recent_step_id_param_fetched_for = collections.defaultdict( - lambda: int(-1e10)) - # number of available params, and max number of available params - self.__n_available_params: int = 0 - self.__max_n_available_params: int = max_available_parameters_in_numel - # max distance between two use of the module beyond which module is released - self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel - # queue for parameters to fetch. parameters will be popped off the left - # side of the dequeue as they are fetched - self.__param_queue: Deque[__class__.__ParamInTrace] = None - self.__prefetch_bucket_sz: int = prefetch_bucket_sz - self.__prefetch_nvme: bool = prefetch_nvme - self.hierarchy: int = 0 - - # stream that will be used for allgather operations - self.__allgather_stream: Stream = allgather_stream - - # limit the number of fetch events that can be queued at once - # otherwise, what happens is memory is allocated by the host thread at the - # time of the call, but not used until later by the asynchronous cuda stream. - # allowing an infinite number of these to queue up causes a lot of memory - # pressure that then becomes detrimental to performance. - # this is a much less elegant way of fixing this vs something like using - # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now - # because ideally in the future its replaced by an async allocation - # mechanism which doesnt require any configuration by the user. - self.__ongoing_fetch_events: Deque[Event] = collections.deque() - self.__max_ongoing_fetch_events: int = 2 - - """Tracing and Tracking - TODO. consider performing trace before initializing PartitionedParameterCoordinator - and passing trace results into constructor. This way all the code in here can - just assume that the trace is complete and the results can be entirely - immutable. - - Bookkeeping operations used to track where we are in the forward/backward pass - """ - - def record_trace(self, sub_module: Module) -> None: - """adds sub module to trace""" - if self.trace_complete: - raise RuntimeError( - "attemted to record trace when trace was already complete") - - self.__submodule_order.append(sub_module) - for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): - self.__param_order.append( - __class__.__ParamInTrace(param=param, - step_id_last_used_at=self.__step_id)) - - def reset_step(self) -> None: - """indicate that we have completed one fwd+bwd for the model""" - if self.__inflight_param_registry: - raise RuntimeError( - f"still have inflight params " - f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") - - if not self.trace_complete: - # make sure that recorded parameter and submodule orders are - # identical across ranks - assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) - assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) - assert_ints_same_as_other_ranks( - [p.step_id_last_used_at for p in self.__param_order]) - - self.__submodule_order = tuple(self.__submodule_order) # freeze - self.__param_order = tuple(self.__param_order) # freeze - self.trace_complete = True - print_rank_0(f"completed trace: {[m.id for m in self.__submodule_order]}", - force=True) - - self.__param_queue = collections.deque(self.__param_order) # reset fetch queue - self.__most_recent_step_id_param_fetched_for = collections.defaultdict( - lambda: int(-1e10)) - self.__step_id = 0 - self.__n_available_params = 0 - - """Fetch and Release - Fetching, prefetching, and releasing parameters - """ - - @instrument_w_nvtx - @torch.no_grad() - def fetch_sub_module(self, current_submodule: Module) -> None: - """This method does the following (in order): - 1. kick off fetch for parameters in immediately required sub module - 2. kick off fetch for next few parameters we will need later (prefetch) - 3. block on parameters in immediately required sub module - """ - debug_rank0( - f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " - + str({ - "avail": f"{self.__n_available_params:.1e}", - "queue_sz": f"{len(self.__param_queue or [])}", - "inflight": [p.ds_id for p in self.__inflight_param_registry], - "allocated": get_cuda_mem_allocated_str() - })) - - params_to_fetch = frozenset(iter_params(current_submodule)) - - # kick off all gather for params in the immediately required submodule - for param in params_to_fetch: - debug_rank0(f"-fetch: {param.ds_summary()}") - self.__all_gather_params(params_to_fetch) - - # wait for parameters in the immediately needed submodule to become available - for param in iter_params(current_submodule): - param.ds_active_sub_modules.add(current_submodule.id) - debug_rank0(f"-wait: {param.ds_summary()}") - if param in self.__inflight_param_registry: - with torch.cuda.stream(self.__allgather_stream): - while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ - 0].query(): - self.__ongoing_fetch_events.popleft() - if len(self.__ongoing_fetch_events - ) > self.__max_ongoing_fetch_events: - self.__ongoing_fetch_events.popleft().synchronize() - - self.__inflight_param_registry.pop(param).wait() - - event = Event() - event.record() - self.__ongoing_fetch_events.append(event) - - assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() - torch.cuda.current_stream().wait_stream(self.__allgather_stream) - - # kick off parameter prefetches for upcoming modules - # don't prefetch if we dont have a completed model trace, or if we aren't - # training (throws off the tracing and don't want to prefetch modules for bwd) - if self.trace_complete and current_submodule.training: - # go through the parameters we need for the current module and pop them - # off the fetch queue so that they aren't prefetched later. - # if params have already been popped off the fetch queue by earlier - # prefetches we won't look for them here - discarded_from_prefetch_queue = set() - params_not_already_fetched = set( - filter( - lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. - __step_id, - params_to_fetch)) - while self.__param_queue and len(discarded_from_prefetch_queue) < len( - params_not_already_fetched): - param_in_trace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - discarded_from_prefetch_queue.add(param_in_trace.param) - if discarded_from_prefetch_queue != params_not_already_fetched: - raise RuntimeError( - f"tracing error at step {self.__step_id}: " - f"expected the next {len(params_not_already_fetched)} parameters in the " - f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " - f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." - ) - - # kick off all gather for params in the next few submodules (prefetch) - max_params_to_prefetch = min( - self.__max_n_available_params - self.__n_available_params, - self.__prefetch_bucket_sz) - params_to_prefetch = set() - numel_prefetching = 0 - while self.__param_queue and numel_prefetching < max_params_to_prefetch: - param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - if param_in_trace.param not in params_to_prefetch: - params_to_prefetch.add(param_in_trace.param) - numel_prefetching += param_in_trace.param.ds_numel - for param in params_to_prefetch: - debug_rank0(f"-prefetch: {param.ds_summary()}") - self.__all_gather_params(params_to_prefetch) - - if self.__prefetch_nvme: - self.__prefetch_nvme_param_partitions() - - self.__step_id += 1 - - @instrument_w_nvtx - @torch.no_grad() - def release_sub_module(self, submodule: Module) -> None: - """release the parameters of a sub module, assuming they meet conditions to - be released.""" - params_to_release = (self.__params_to_release(submodule, - self.__step_id) - if self.trace_complete else set( - p.ds_id for p in iter_params(submodule))) - - for param in iter_params(submodule): - param.ds_active_sub_modules.discard(submodule.id) - if param.ds_id in params_to_release and not param.is_external_param: - self.__release_param(param) - - @instrument_w_nvtx - @torch.no_grad() - def release_and_reset_all(self) -> None: - """release all module parameters""" - for param in map(lambda p: p.param, self.__param_order): - if param in self.__inflight_param_registry: - raise RuntimeError(f"param {param.ds_summary()} still in flight") - - # TODO. make this throw if if there are still active submodules. currently - # there's a hook execution issue - param.ds_active_sub_modules.clear() - self.__release_param(param) - - for param_in_trace in self.__param_order: - if param_in_trace.param.ds_status != ZeroParamStatus.NOT_AVAILABLE: - raise RuntimeError( - f"{param_in_trace.param.ds_summary()} expected to be released") - - @instrument_w_nvtx - def __all_gather_params(self, params: Set[Parameter]) -> None: - """for each partitioned parameter, kick off an async allgather and store - the work handle for the in flight parameters.""" - partitioned_params = [] - for param in params: - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - partitioned_params.append(param) - self.__n_available_params += param.ds_numel - - if partitioned_params: - with torch.cuda.stream(self.__allgather_stream): - handle = partitioned_params[0].all_gather_coalesced(partitioned_params) - - for param in partitioned_params: - assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() - self.__inflight_param_registry[param] = handle - - @instrument_w_nvtx - def __release_param(self, param: Parameter) -> None: - if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: - debug_rank0(f"-release: {param.ds_summary()}") - param.partition() - self.__n_available_params -= param.ds_numel - - @instrument_w_nvtx - @functools.lru_cache(maxsize=None) - def __params_to_release(self, - submodule_to_release: Module, - step_id: int) -> Set[int]: - if not self.trace_complete: - raise RuntimeError("expected trace to be complete") - - params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) - if not p.ds_persist) - - # examine all modules within `max_reuse_dist_in_numel` of the current step, - # if we see any of the candidate parameters to be released reoccur while - # doing this, remove them from the set of parameters to release. - params_traversed = 0 - for module in self.__submodule_order[step_id:]: - if params_traversed > self.__max_reuse_dist_in_numel: - break - for param in iter_params(module): - params_to_release.discard(param.ds_id) - params_traversed += param.ds_numel - - return params_to_release - - @instrument_w_nvtx - def __prefetch_nvme_param_partitions(self) -> None: - """swap in parameter partitions from nvme for those parameters that will be used - after the ones that are already being prefetched into full parameters - """ - if not self.trace_complete: - return - - numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) - - numel_considered = 0 - swap_in_params = [] - for param_in_trace in self.__param_queue: - param = param_in_trace.param - if param.nvme_swapper is None: - continue - if (numel_considered > 2 * numel_in_flight or len(swap_in_params) >= - param.nvme_swapper.available_swap_in_buffers()): - break - if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: - swap_in_params.append(param) - numel_considered += param.ds_numel - - if swap_in_params: - swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) - - -class PreBackwardFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, module, pre_backward_function, outputs): - ctx.module = module - ctx.pre_backward_function = pre_backward_function - if not hasattr(module, "applied_pre_backward_ref_cnt"): - module.applied_pre_backward_ref_cnt = 0 - module.applied_pre_backward_ref_cnt += 1 - #print(f"After Forward: {ctx.module.__class__.__name__}") - outputs = outputs.detach() - return outputs - - @staticmethod - def backward(ctx, *args): - #print(f"Before Backward: {ctx.module.__class__.__name__}") - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -class PostBackwardFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, module, pre_backward_function, output): - ctx.module = module - if output.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.pre_backward_function = pre_backward_function - output = output.detach() - return output - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.pre_backward_function(ctx.module) - #print(f"After Backward: {ctx.module.__class__.__name__}") - return (None, None) + args - - -class FP16_DeepSpeedZeroOptimizer_Stage3(object): - """ - DeepSpeedZeroOptimizer designed to reduce the memory footprint - required for training large deep learning models. - - For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models - https://arxiv.org/abs/1910.02054 - - For usage examples, refer to TODO: DeepSpeed Tutorial - - """ - def __init__(self, - module, - init_optimizer, - timers, - static_loss_scale=1.0, - dynamic_loss_scale=False, - dynamic_loss_args=None, - verbose=True, - contiguous_gradients=True, - reduce_bucket_size=500000000, - prefetch_bucket_size=50000000, - max_reuse_distance=1000000000, - max_live_parameters=1000000000, - param_persistence_threshold=100000, - dp_process_group=None, - reduce_scatter=True, - overlap_comm=False, - offload_optimizer_config=None, - offload_param_config=None, - sub_group_size=1000000000000, - mpu=None, - clip_grad=0.0, - allreduce_always_fp32=False, - postscale_gradients=True, - gradient_predivide_factor=1.0, - gradient_accumulation_steps=1, - elastic_checkpoint=False, - aio_config=None): - - see_memory_usage("Stage 3 initialize beginning", force=False) - - if dist.get_rank() == 0: - logger.info(f"initialized {__class__.__name__} with args: {locals()}") - logger.info(f"Reduce bucket size {reduce_bucket_size}") - logger.info(f"Allgather bucket size {prefetch_bucket_size}") - # The fused optimizer does all the work. We need this layer for two reason: - # 1. maintain same user API from apex.fp16_utils - # 2. keep common stuff here in case we need to add ne552w fused optimizer later - - # differences from apex.fp16_utils: - # - assume all model params in fp16 - # - assume all params requires grad - # - flat by groups, not keeping state. TODO: remove state explicitly? - # - master gard and unflat master weight never exist. TODO: a way to save out unflat master? - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") - self.optimizer = init_optimizer - - # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() - self.flatten = util_ops.flatten - self.unflatten = util_ops.unflatten - self.dtype = self.optimizer.param_groups[0]['params'][0].dtype - self._global_grad_norm = 0. - - self._convert_to_zero_parameters(module, mpu) - - for m in module.modules(): - _init_external_params(m) - - self.module = module - self.elastic_checkpoint = elastic_checkpoint - - # Replace ._parameters with a new class to enable auto-registration of - # external parameters - _inject_parameters(module, ZeROOrderedDict) - - self.__inf_or_nan_tracker: Tensor = torch.zeros( - 1, - dtype=torch.bool, - device=torch.cuda.current_device(), - requires_grad=False) - - ###################### offload optimizer setup ################################## - self.optimizer_swapper = None - self.swap_optimizer = False - - self.offload_optimizer = False - self.offload_optimizer_pin_memory = False - self.offload_optimizer_fast_init = False - if offload_optimizer_config is not None: - if not contiguous_gradients: - raise ValueError( - "optimizer offload only available with contiguous gradients enabled") - self.offload_optimizer = True - self.offload_optimizer_pin_memory = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_PIN_MEMORY] - self.swap_optimizer = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE - self.offload_optimizer_fast_init = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_FAST_INIT] - - ###################### offload param setup ################################## - self.offload_param = False - self.offload_param_pin_memory = False - self.params_in_nvme_and_cpu = False - self.max_params_in_cpu = 0 - if offload_param_config is not None: - assert self.offload_optimizer, "parameter offload is only available with optimizer state offload" - self.offload_param = True - self.offload_param_pin_memory = offload_param_config[ - OFFLOAD_PARAM_PIN_MEMORY] - self.params_in_nvme_and_cpu = offload_param_config[ - OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE - self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU] - print_rank_0( - f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}", - force=False) - - self.deepspeed_adam_offload = (self.offload_optimizer - and type(init_optimizer) == DeepSpeedCPUAdam) - - self.device = torch.cuda.current_device( - ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE - ### streams used for overlapping computation with communication - self.__allgather_stream = Stream( - ) if overlap_comm else torch.cuda.default_stream() - self.__reduce_and_partition_stream = Stream( - ) if overlap_comm else torch.cuda.default_stream() - - ############################################################################ - - see_memory_usage("Before Partitioned Parameter Coordinator", force=False) - self.param_coordinator = PartitionedParameterCoordinator( - prefetch_bucket_sz=int(prefetch_bucket_size), - max_reuse_distance_in_numel=int(max_reuse_distance), - max_available_parameters_in_numel=int(max_live_parameters), - allgather_stream=self.__allgather_stream, - prefetch_nvme=self.params_in_nvme_and_cpu, - ) - see_memory_usage("After Partitioned Parameter Coordinator", force=False) - - self.__n_caching_allocator_flushes = 0 - - #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) - #-------------Stage 3 Setup-------------------# - # parameters smaller than the threshold will be collectively gathered at the - # end of the optimizer step and will be kept till the end of the backward pass - # TODO maybe worth just replicating these parameters and doing all reduce for them - self.persistence_threshold = int(param_persistence_threshold) - - self.persistent_parameters = self.persistent_parameters() - - self.setup_zero_stage3_hooks() - - #resetting ds_tensor just in case parameters have been changed after initialization - #example .half() or .to() - #self.reset_ds_tensor() - #---------------------------------------------# - - self.timers = timers - - self.dp_process_group = dp_process_group - - self.partition_count = dist.get_world_size(group=self.dp_process_group) - - if mpu is None: - self.model_parallel_group = None - self.model_parallel_rank = 0 - else: - self.model_parallel_group = mpu.get_model_parallel_group() - self.model_parallel_rank = mpu.get_model_parallel_rank() - - self.overflow = False - self.clip_grad = clip_grad - self.allreduce_always_fp32 = allreduce_always_fp32 - self.gradient_predivide_factor = gradient_predivide_factor - self.postscale_gradients = postscale_gradients - self.gradient_accumulation_steps = gradient_accumulation_steps - self.micro_step_id = 0 - - # Holds the mode parameter - # The param.data may not hold any meaningful data - # when param's status is NOT_AVAILABLE or IN_FLGHT - self.fp16_groups = [] - - # Hold partitioned parameters - self.fp16_partitioned_groups = [] - - # Holds a fused and flattened copy of the parameters - self.fp16_partitioned_groups_flat = [] - self.fp16_partitioned_groups_flat_numel = [] - - #defragmented pinned memory - self.param_groups_fp16_flat_cpu_memory = [] - - #a single 32-bit partition of the parallel partitioned parameters - #that this process will update - self.fp32_partitioned_groups_flat = [] - self.next_swappable_fp32_partitioned_groups = [] - - # number of elements per partition in each group - self.partition_size = [] - - self.all_reduce_print = False - - self.prefetch_elements = int(prefetch_bucket_size) - - # padding on each partition for alignment purposes - self.groups_padding = [] - - self.sub_group_size = sub_group_size - - self.sub_group_to_group_id = {} - see_memory_usage("Before creating fp16 partitions", force=False) - self._create_fp16_partitions_with_defragmentation() - num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) - see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", - force=False) - - # Optimizer tensor swapping - if self.swap_optimizer: - self._configure_tensor_swapping(offload_optimizer_config, aio_config) - - see_memory_usage("Before creating fp32 partitions", force=False) - self._create_fp32_partitions() - see_memory_usage("After creating fp32 partitions", force=False) - dist.barrier() - - # To support pipelined optimizer swapping - self._create_next_swappable_fp32_groups() - - see_memory_usage("Before initializing optimizer states", force=False) - self.initialize_optimizer_states() - see_memory_usage("After initializing optimizer states", force=False) - dist.barrier() - - if dist.get_rank() == 0: - logger.info(f"optimizer state initialized") - - self.reduce_bucket_size = int(reduce_bucket_size) - - # IPG - if contiguous_gradients: - self.__ipg_bucket_flat_buffer: Tensor = torch.empty( - int(reduce_bucket_size), - dtype=self.dtype, - device=torch.cuda.current_device()) - - self.__param_id_to_grad_partition: Dict[int, Tensor] = {} - - all_params = list(itertools.chain.from_iterable(self.fp16_groups)) - - grad_partitions_flat_buffer: Tensor = torch.zeros( - sum(p.ds_tensor.ds_numel for p in all_params), - dtype=self.dtype, - device=self.device, - pin_memory=self.offload_optimizer_pin_memory) - - offset = 0 - for param in all_params: - self.__param_id_to_grad_partition[ - param.ds_id] = grad_partitions_flat_buffer.narrow( - 0, - offset, - param.ds_tensor.numel()) - offset += param.ds_tensor.numel() - - self.__params_in_ipg_bucket: List[Parameter] = [] - self.is_gradient_accumulation_boundary: bool = True - - self.__param_reduce_events: Deque[Event] = collections.deque() - self.__max_param_reduce_events: int = 2 - - if dist.get_rank() == 0: - logger.info(f"optimizer state initialized") - - self.param_dict = {} - - # map between param_id and bool to specify if a param is in this partition - self.is_param_in_current_partition = {} - - self.contiguous_gradients = contiguous_gradients - self.extra_large_param_to_reduce = None - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.params_already_reduced = [] - self.is_gradient_accumulation_boundary = True - self._release_ipg_buffers() - self.previous_reduced_grads = None - - # simplified param id - self.param_id = {} - - count = 0 - for i, params_group in enumerate(self.fp16_groups): - for param in params_group: - unique_id = id(param) - self.param_id[unique_id] = count - self.param_dict[count] = param - self.params_already_reduced.append(False) - count = count + 1 - - #Largest partitioned param - largest_partitioned_param_numel = max([ - max([tensor.numel() for tensor in fp16_partitioned_group]) - for fp16_partitioned_group in self.fp16_partitioned_groups - ]) - print_rank_0( - f'Largest partitioned param numel = {largest_partitioned_param_numel}', - force=False) - - see_memory_usage(f"Before Set Grad positions", force=False) - - self.grad_position = {} - self.set_grad_positions() - see_memory_usage(f"Before CPU Offload initialization", force=False) - - self.grads_in_partition = None - - if self.offload_optimizer: - self.norm_for_param_grads = {} - self.local_overflow = False - - see_memory_usage(f"After CPU Offload initialization", force=False) - - # stores if a partition has been reduced in this step - self.is_partition_reduced = {} - - # stores if a grad in a partition has been computed or not - self.is_grad_computed = {} - - # will store the averaged gradients required by this paritition - self.averaged_gradients = {} - - #creates backward hooks for gradient partitioning - self.create_reduce_and_remove_grad_hooks() - - #exit(0) - - # we may have a way of fusing dynamic scale. Do not support for now - if self.dtype == torch.float or not dynamic_loss_scale: - loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale - - self.dynamic_loss_scale = False - self.loss_scaler = LossScaler(scale=loss_scale_value) - cur_iter = 0 - else: - if dynamic_loss_args is None: - self.loss_scaler = DynamicLossScaler() - else: - self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) - - self.dynamic_loss_scale = True - - self.debug_fp16_grads = [{} for _ in self.fp16_groups] - - if dist.get_rank(group=self.dp_process_group) == 0: - see_memory_usage(f"After initializing ZeRO optimizer", force=False) - - @staticmethod - def defragment(tensors: List[Tensor]) -> Tensor: - """move provided tensors into a contiguous flat buffer, with some additional - measures taken to reduce memory fragmentation""" - assert len(set(t.dtype for t in tensors)) == 1 - assert len(set(t.device for t in tensors)) == 1 - - cpu_buffer = torch.empty(sum(p.numel() for p in tensors), - dtype=get_only_unique_item(t.dtype for t in tensors), - device="cpu") - tensor_infos: List[Tuple[Tensor, int, int]] = [] - orig_device = get_only_unique_item(t.device for t in tensors) - - offset = 0 - for tensor in tensors: - tensor_numel = tensor.numel() - # move the tensor from device memory to host memory - cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) - tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) - - # record some data so we can restore the device tensor later - tensor_infos.append((tensor, offset, tensor_numel)) - - offset += tensor_numel - - gc.collect() - torch.cuda.empty_cache() - - # copy tensors (now flattened and contiguous) back to GPU - device_buffer = cpu_buffer.to(orig_device) - - # restore device tensors - for tensor, offset, tensor_numel in tensor_infos: - tensor.data = device_buffer.narrow(0, offset, tensor_numel) - - return device_buffer - - def _convert_to_zero_parameters(self, module, mpu): - non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] - if non_zero_params: - zero_params = [p for p in module.parameters() if is_zero_param(p)] - if zero_params: - zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) - else: - group = None - if mpu: - group = mpu.get_data_parallel_group() - Init(module=module, data_parallel_group=group, dtype=self.dtype) - - def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): - nvme_swap_folder = os.path.join( - offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH], - 'zero_stage_3') - os.makedirs(nvme_swap_folder, exist_ok=True) - if torch.distributed.get_rank() == 0: - logger.info(f'Tensor Swapping: Adding optimizer tensors') - - swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[ - OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper - - self.optimizer_swapper = swapper_type( - swap_config=offload_optimizer_config, - aio_config=aio_config, - base_folder=nvme_swap_folder, - optimizer=self.optimizer, - largest_numel=max(self.fp16_partitioned_groups_flat_numel), - device=self.device, - dtype=torch.float32, - timers=self.timers) - - @property - def elements_in_ipg_bucket(self): - return sum(p.ds_numel for p in self.__params_in_ipg_bucket) - - def _create_fp16_partitions(self): - dist.barrier() - partition_id = dist.get_rank(group=self.dp_process_group) - - # loop to deal with groups - for j, param_group in enumerate(self.optimizer.param_groups): - - sub_groups = self._create_fp16_sub_groups(param_group['params']) - for sub_group in sub_groups: - i = len(self.fp16_groups) - - # push this group to list before modify - self.fp16_groups.append(sub_group) - self.sub_group_to_group_id[i] = j - - #These are the list of the partitioned parameters - self.fp16_partitioned_groups.append( - [param.ds_tensor for param in self.fp16_groups[i]]) - - print_rank_0( - f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" - ) - - # Record padding required to align group to world size (only applies to last rank) - if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: - padding = [p.padding_size() for p in self.fp16_groups[i]] - else: - padding = [0] * len(self.fp16_groups[i]) - self.groups_padding.append(padding) - - #not sure why apex was cloning the weights before flattening - #removing cloning here - see_memory_usage(f"Before Flattening param group {i}", force=False) - - if not self.offload_param: - see_memory_usage(f"Before moving param group {i} to CPU", - force=False) - #move all the parameters to cpu to free up GPU space for creating flat buffer - move_to_cpu(self.fp16_partitioned_groups[i]) - see_memory_usage(f"After moving param group {i} to CPU", force=False) - - #create flat buffer in CPU and move to GPU - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size(group=self.dp_process_group)).cuda( - torch.cuda.current_device())) - see_memory_usage( - f"After flattening and moving param group {i} to GPU", - force=False) - else: - #Without the detach, seems like the flattening becomes part of the - #model graph causing errors downstream - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size( - group=self.dp_process_group)).detach().pin_memory()) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - #set model fp16 weight to slices of flattened buffer - updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i], - self.fp16_partitioned_groups[i]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params): - partitioned_param.data = q.data - - def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): - '''If flat buffer is None then the parameters in the param_list are - not copied to the flat buffer. This is because they excede the number of max_params_in_cpu - Some of these parameters may aready be in CPU in unflattened buffers - or they maybe in GPU, or they maybe in NVME. If they are in NVME, then - they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are - needed during training.''' - if flat_buffer is None: - # this dst buffer is on NVMe, so skip this - return - - start = 0 - for param in param_list: - src = param.ds_tensor - dest = flat_buffer.narrow(0, start, src.ds_numel) - start = start + src.ds_numel - '''if the parameter was initialized in nvme then bring it to the destination buffer directly''' - if src.status == PartitionedParamStatus.NOT_AVAILABLE: - print_rank_0( - f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU" - ) - param.nvme_swapper.swap_into_buffer(param, dest) - src.data = dest.data - src.status = PartitionedParamStatus.AVAILABLE - else: - assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here" - if not avoid_copy: - dest.data.copy_(src.data) - src.data = dest.data - - # Final location must be gpu/cpu in this case - param.ds_tensor.final_location = 'not-nvme' - - def _create_param_groups_fp16_flat_cpu_memory(self): - - aggregate_params_count = 0 - - for j, param_group in enumerate(self.optimizer.param_groups): - params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']]) - - flat_buffer_size = params_in_group - - if self.params_in_nvme_and_cpu and \ - aggregate_params_count + params_in_group > self.max_params_in_cpu: - - flat_buffer_size = max(0, - self.max_params_in_cpu - aggregate_params_count) - - aggregate_params_count += params_in_group - - if flat_buffer_size > 0: - print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", - force=False) - self.param_groups_fp16_flat_cpu_memory.append( - torch.empty(int(flat_buffer_size), - dtype=self.dtype, - pin_memory=True)) - else: - print_rank_0( - f"No flat buffer size. Param group size was {params_in_group}", - force=False) - - self.param_groups_fp16_flat_cpu_memory.append( - torch.empty(1, - dtype=self.dtype)) - - def _create_fp16_partitions_with_defragmentation(self): - dist.barrier() - param_groups: List[List[Parameter]] = tuple( - self._create_fp16_sub_groups(param_group["params"]) - for param_group in self.optimizer.param_groups) - - # bookkeeping related to param groups - for param_group_idx, param_group in enumerate(param_groups): - for sub_group in param_group: - sub_group_idx = len(self.fp16_groups) - - # record sub group and partitions - self.fp16_groups.append(sub_group) - self.fp16_partitioned_groups.append( - [param.ds_tensor for param in sub_group]) - - # record sub group -> group mapping - self.sub_group_to_group_id[sub_group_idx] = param_group_idx - - # record total elements of parameter partitions in sub group - self.fp16_partitioned_groups_flat_numel.append( - sum(p.ds_tensor.ds_numel for p in sub_group)) - - # record padding required to align group to world size (only applies to last rank) - rank_requires_padding = dist.get_rank( - self.dp_process_group) == dist.get_world_size( - self.dp_process_group) - 1 - self.groups_padding.append([ - p.padding_size() if rank_requires_padding else 0 for p in sub_group - ]) - - # move parameters to flattened buffer - if not self.offload_param: # partitioned params remain in GPU during training - # move parameter partitions into a single contiguous flat buffer - parameter_partitions: List[Tensor] = [] - for sub_group in self.fp16_groups: - for param in sub_group: - parameter_partitions.append(param.ds_tensor) - device_buffer = __class__.defragment(parameter_partitions) - - # setup flat buffers per subgroup, these are each just sections of the - # contiguous flat buffer for all parameters that we created earlier - offset = 0 - for sub_group in self.fp16_groups: - sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group) - self.fp16_partitioned_groups_flat.append( - device_buffer.narrow(0, - offset, - sub_group_numel)) - offset += sub_group_numel - else: # partitioned params offloaded to CPU when not in use - # create a flat CPU memory allocation for each param group - self._create_param_groups_fp16_flat_cpu_memory() - for param_group_idx, param_group in enumerate(param_groups): - flat_offset = 0 - for i, sub_group in enumerate(param_group): - total_elements = sum(p.ds_tensor.ds_numel for p in sub_group) - print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") - #Flat buffer may not be available for parameters that reside in NVME - if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ - param_group_idx].numel(): - fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ - param_group_idx].narrow(0, - flat_offset, - total_elements) - print_rank_0( - f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", - force=False) - elif self.params_in_nvme_and_cpu: - fp16_partitioned_group_flat = None - print_rank_0( - f"No flat buffer for sub group {i} of {total_elements} elements", - force=False) - else: - assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs" - - self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) - flat_offset += total_elements - - self._move_to_flat_buffer(sub_group, - fp16_partitioned_group_flat, - avoid_copy=not self.offload_param) - - # if necessary, create a pinned memory buffer to be used for swapping out - # params to NVME after optimizer step - should_create_fp16_flat_reuse_buffer = any( - flattened_partition_group is None - for flattened_partition_group in self.fp16_partitioned_groups_flat) - if should_create_fp16_flat_reuse_buffer: - max_partition_numel, largest_partition_numel = 0, None - for sub_group in self.fp16_groups: - total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) - if total_elements > max_partition_numel: - largest_partition_numel = [t.ds_numel for t in sub_group] - max_partition_numel = total_elements - - assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' - self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( - largest_partition_numel) - - def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): - offset = 0 - elements_in_sub_group = sum( - [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) - assert (flat_buffer.numel() == elements_in_sub_group) - for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): - dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel) - if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: - print_rank_0( - f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}" - ) - param.nvme_swapper.swap_in([param], async_op=False) - dest.data.copy_(partitioned_param.data) - param.nvme_swapper.remove_partition_and_release_buffers([param]) - print_rank_0(f"Swapping in {param.ds_id} done") - else: - dest.data.copy_(partitioned_param.data) - offset += partitioned_param.ds_numel - - def _create_next_swappable_fp32_groups(self): - reverse_order_indices = [ - i for i in range(len(self.fp32_partitioned_groups_flat)) - ] - reverse_order_indices.reverse() - - next_group = None - for i in reverse_order_indices: - self.next_swappable_fp32_partitioned_groups.append(next_group) - if self._swappable_optimizer_subgroup(i): - next_group = self.fp32_partitioned_groups_flat[i] - - self.next_swappable_fp32_partitioned_groups.reverse() - - def _get_sub_group_partitions(self, sub_group_id): - sub_group_partitions = [] - for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): - if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: - swap_path = param.nvme_swapper.get_path(param, True) - sub_group_partitions.append((partitioned_param, - param.ds_tensor.ds_numel, - swap_path)) - else: - sub_group_partitions.append((partitioned_param, - partitioned_param.ds_numel, - None)) - - return sub_group_partitions - - def _create_fp32_partitions(self): - cpu_memory_usage = 0 - cpu_memory_sub_groups = 0 - nvme_memory_usage = 0 - num_swappable_partitions = 0 - num_swap_from_nvme_partitions = 0 - num_swap_from_cpu_partitions = 0 - swap_from_nvme_memory_usage = 0 - swap_from_cpu_memory_usage = 0 - GIGA_BYTES = (1024**3) - - swappable_fp32_tensors = [] - swappable_fp16_src_tensors = [] - nvme_fp16_partitions_info = [] - nvme_fp16_num_elems = [] - nvme_fp32_dest_tensors = [] - fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() - - for i, tensor in enumerate(self.fp16_partitioned_groups_flat): - num_elements = self.fp16_partitioned_groups_flat_numel[i] - - # a partition of the fp32 master weights that will be updated by this process - if self._swappable_optimizer_subgroup(i): - self.fp32_partitioned_groups_flat.append(torch.Tensor()) - nvme_memory_usage += (fp32_element_size * num_elements) - num_swappable_partitions += 1 - - if self.params_in_nvme_and_cpu and tensor is None: - num_swap_from_nvme_partitions += 1 - swap_from_nvme_memory_usage += (fp32_element_size * num_elements) - if self.offload_optimizer_fast_init: - sub_group_partitions = self._get_sub_group_partitions(i) - nvme_fp16_partitions_info.append(sub_group_partitions) - nvme_fp16_num_elems.append(num_elements) - nvme_fp32_dest_tensors.append( - self.fp32_partitioned_groups_flat[i]) - else: - unpinned_fp32_buffer = torch.empty(num_elements, - device=self.device, - dtype=torch.float) - self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) - self.optimizer_swapper.initialize_parameters( - parameters=[self.fp32_partitioned_groups_flat[i]], - src_tensors=[unpinned_fp32_buffer]) - else: - num_swap_from_cpu_partitions += 1 - swap_from_cpu_memory_usage += (fp32_element_size * num_elements) - swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i]) - swappable_fp16_src_tensors.append( - self.fp16_partitioned_groups_flat[i]) - else: - cpu_memory_usage += (fp32_element_size * num_elements) - cpu_memory_sub_groups += 1 - - if self.params_in_nvme_and_cpu and tensor is None: - unpinned_fp32_buffer = torch.empty(num_elements, - device=self.device, - dtype=torch.float) - self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) - self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) - else: - self.fp32_partitioned_groups_flat.append( - self.fp16_partitioned_groups_flat[i].to( - self.device).clone().float().detach()) - - self.fp32_partitioned_groups_flat[ - i].requires_grad = True # keep this in case internal optimizer uses it - - if len(swappable_fp32_tensors) > 0: - self.optimizer_swapper.initialize_parameters( - parameters=swappable_fp32_tensors, - src_tensors=swappable_fp16_src_tensors) - - if len(nvme_fp32_dest_tensors) > 0: - fp16_pinned_buffers = self.fp16_groups[0][ - 0].nvme_swapper.reserve_available_buffers() - assert len(fp16_pinned_buffers) > 0 - self.optimizer_swapper.initialize_from_swapped_fp16_params( - fp16_partitions_info=nvme_fp16_partitions_info, - fp16_num_elems=nvme_fp16_num_elems, - fp16_pinned_buffers=fp16_pinned_buffers, - fp32_parameters=nvme_fp32_dest_tensors) - self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers() - - nvme_gigabytes = nvme_memory_usage / GIGA_BYTES - print_rank_0( - f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB', - force=False) - if self.params_in_nvme_and_cpu: - print_rank_0( - f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB', - force=False) - print_rank_0( - f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB', - force=False) - - cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES - print_rank_0( - f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB', - force=False) - - # Clear for on-the-fly population before the optimizer step - for param_group in self.optimizer.param_groups: - param_group['params'] = [] - - def _create_fp16_sub_groups(self, params_group): - - params_group_numel = sum([param.partitioned_size() for param in params_group]) - sub_group_size = self.sub_group_size - - if sub_group_size is None or sub_group_size >= params_group_numel: - return [params_group] - - sub_groups = [] - sub_group = [] - local_sub_group_size = 0 - for param in params_group: - - sub_group.append(param) - local_sub_group_size += param.partitioned_size() - - if local_sub_group_size >= sub_group_size or id(param) == id( - params_group[-1]): - - sub_groups.append(sub_group) - - sub_group = [] - local_sub_group_size = 0 - - return sub_groups - - # def reset_ds_tensor(self): - # for name, param in self.module.named_parameters(recurse=True): - # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible" - # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now" - # param.ds_tensor.data = param.data - - def setup_zero_stage3_hooks(self): - self.hierarchy = 0 - - #reset step if in inference mode - @instrument_w_nvtx - def _end_of_forward_hook(module, *args): - - if not torch._C.is_grad_enabled(): - self.param_coordinator.reset_step() - - #likely one of them should be enough but just to be safe - self._register_hooks_recursively(self.module) - self.module.register_forward_hook(_end_of_forward_hook) - - # Add top module to stack trace - global FWD_MODULE_STACK - FWD_MODULE_STACK.append(self.module) - - def persistent_parameters(self): - persistent_params = [] - total_persistent_parameters = 0 - params_count = 0 - for _, param in self.module.named_parameters(recurse=True): - if param.ds_numel < self.persistence_threshold: - params_count += 1 - param.ds_persist = True - persistent_params.append(param) - total_persistent_parameters += param.ds_numel - - print_rank_0( - f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params", - force=False) - return persistent_params - - def _register_hooks_recursively(self, module, count=[0]): - my_count = count[0] - module.id = my_count - - #print(f"{module.__class__} : {module.id}") - - for child in module.children(): - count[0] = count[0] + 1 - self._register_hooks_recursively(child, count=count) - - @instrument_w_nvtx - def _pre_forward_module_hook(module, *args): - self.pre_sub_module_forward_function(module) - - @instrument_w_nvtx - def _post_forward_module_hook(module, input, output): - global FWD_MODULE_STACK - FWD_MODULE_STACK.pop() - if output is None: - output = [] - elif not isinstance(output, (list, tuple)): - if torch.is_tensor(output): - output = [output] - else: - #print(f'got UNKNOWN type {type(output)}') - outputs = [] - output = output if isinstance(output, dict) else vars(output) - for name, val in output.items(): - if not name.startswith('__') and torch.is_tensor(val): - outputs.append(val) - output = outputs - #print(f'convert output to {output}') - - for item in filter(lambda item: is_zero_param(item), output): - if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): - item.is_external_param = True - module_to_register = FWD_MODULE_STACK[-1] - print_rank_0( - f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', - force=False) - register_external_parameter(module_to_register, item) - - # It's possible that the parameter was already external to the completed module. If so, remove it the - # registration as it will be covered by the outer module instead. - if id(item) in module._external_params: - print_rank_0( - f' Unregistering nested dangling parameter from module {module.__class__.__name__}', - force=False) - unregister_external_parameter(module, item) - - item.all_gather() - - self.post_sub_module_forward_function(module) - - def _pre_backward_module_hook(module, inputs, output): - @instrument_w_nvtx - def _run_before_backward_function(sub_module): - # some models (e.g. Albert) may run multiple forwards on the same layer in a loop - # before doing backwards, so each backward will need a pre-fetch - using reference - # counting to support this scenario - #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") - if sub_module.applied_pre_backward_ref_cnt > 0: - self.pre_sub_module_backward_function(sub_module) - sub_module.applied_pre_backward_ref_cnt -= 1 - #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") - - return _apply_to_tensors_only(module, - PreBackwardFunction, - _run_before_backward_function, - output) - - #This is an alternate to doing _post_backward_module_hook - #it uses tensor.register_hook instead of using torch.autograd.Function - def _alternate_post_backward_module_hook(module, inputs): - module.ds_grads_remaining = 0 - - #print(f"Before Forward {module.__class__.__name__}") - - def _run_after_backward_hook(*unused): - module.ds_grads_remaining = module.ds_grads_remaining - 1 - if module.ds_grads_remaining == 0: - #print(f"After backward {module.__class__.__name__}") - self.post_sub_module_backward_function(module) - - def _run_before_forward_function(input): - if input.requires_grad: - module.ds_grads_remaining += 1 - - return _apply_forward_and_backward_to_tensors_only( - module, - _run_before_forward_function, - _run_after_backward_hook, - inputs) - - def _post_backward_module_hook(module, inputs): - module.ds_grads_remaining = 0 - - @instrument_w_nvtx - def _run_after_backward_function(sub_module): - if sub_module.ds_grads_remaining == 0: - self.post_sub_module_backward_function(sub_module) - - return _apply_to_tensors_only(module, - PostBackwardFunction, - _run_after_backward_function, - inputs) - - # Pre forward hook - module.register_forward_pre_hook(_pre_forward_module_hook) - # Post forward hook - module.register_forward_hook(_post_forward_module_hook) - - # Pre backward hook - module.register_forward_hook(_pre_backward_module_hook) - - # post backward hook - module.register_forward_pre_hook(_post_backward_module_hook) - - @torch.no_grad() - def pre_sub_module_forward_function(self, sub_module): - see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", - force=False) - - global FWD_MODULE_STACK - FWD_MODULE_STACK.append(sub_module) - - if not self.param_coordinator.trace_complete: - self.param_coordinator.record_trace(sub_module) - - self.param_coordinator.fetch_sub_module(sub_module) - see_memory_usage( - f"Before sub module function {sub_module.__class__.__name__} after fetch", - force=False) - - @torch.no_grad() - def post_sub_module_forward_function(self, sub_module): - see_memory_usage( - f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", - force=False) - - self.param_coordinator.release_sub_module(sub_module) - - see_memory_usage( - f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", - force=False) - - @torch.no_grad() - def pre_sub_module_backward_function(self, sub_module): - if not self.param_coordinator.trace_complete: - self.param_coordinator.record_trace(sub_module) - self.param_coordinator.fetch_sub_module(sub_module) - - @torch.no_grad() - def post_sub_module_backward_function(self, sub_module): - see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", - force=False) - self.param_coordinator.release_sub_module(sub_module) - see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", - force=False) - - def _release_ipg_buffers(self): - if self.contiguous_gradients: - self.ipg_buffer = None - if not self.offload_optimizer and self.is_gradient_accumulation_boundary: - self.grads_in_partition = None - - self.grads_in_partition_offset = 0 - - def _optimizer_step(self, sub_group_id): - param_group_id = self.sub_group_to_group_id[sub_group_id] - fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] - - self.optimizer.step() - self.optimizer.param_groups[param_group_id]['params'] = [] - - def _swappable_optimizer_subgroup(self, sub_group_id): - if not self.swap_optimizer: - return False - - return self.optimizer_swapper.swappable_tensor( - None, - numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) - - def _partitioned_params_swap_out(self, i): - offset = 0 - fp32_param = self.fp32_partitioned_groups_flat[i] - assert fp32_param is not None, \ - f'fp32 parameters of sub_group {i} is None' - - swap_fp16_params = [] - swap_fp32_params = [] - for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]): - src = fp32_param.narrow(0, offset, partitioned_param.ds_numel) - if partitioned_param.status == PartitionedParamStatus.AVAILABLE: - partitioned_param.data.copy_(src.data) - else: - swap_fp32_params.append(src) - swap_fp16_params.append(param) - offset += partitioned_param.ds_numel - - if len(swap_fp16_params): - swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params( - dst_fp16_params=swap_fp16_params, - src_fp32_params=swap_fp32_params) - - def initialize_optimizer_states(self): - num_subgroups = len(self.fp16_groups) - - largest_numel = max( - [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups]) - gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype - gradient_buffer = torch.zeros(int(largest_numel), - dtype=gradient_dtype, - device=self.device) - - timers = self.timers - timer_names = set() - - if self.swap_optimizer: - self.optimizer_swapper.init_timers() - - INIT_OPTIMIZER_TIMER = 'init_optimizer_state' - timer_names.add(INIT_OPTIMIZER_TIMER) - self.start_timers([INIT_OPTIMIZER_TIMER]) - - for i, group in enumerate(self.fp16_groups): - swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) - swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None - - num_elements = int(self.fp16_partitioned_groups_flat_numel[i]) - - see_memory_usage( - f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', - force=False) - - if swappable_optimizer_subgroup: - self._optimizer_states_and_gradient_swap_in(i, timer_names) - - if self.offload_optimizer and not swappable_optimizer_subgroup: - subgroup_gradient_buffer = torch.zeros(num_elements, - dtype=gradient_dtype, - device=self.device) - if self.offload_optimizer_pin_memory: - subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory() - - self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer - else: - self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow( - 0, - 0, - num_elements) - - self._optimizer_step(i) - - if swappable_param_subgroup: - self._partitioned_params_swap_out(i) - - if swappable_optimizer_subgroup: - self._optimizer_states_and_gradient_swap_out(i, timer_names) - - see_memory_usage( - f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', - force=False) - - self.stop_timers([INIT_OPTIMIZER_TIMER]) - self.log_timers(timer_names) - - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - - if not self.offload_optimizer: - for group in self.fp32_partitioned_groups_flat: - group.grad = None - - # Reset steps - return - - ######################################################################### - #########################ZeRO Partition Gradients######################## - ######################################################################### - - def get_first_param_index(self, group_id, param_group, partition_id): - for index, param in enumerate(param_group): - param_id = self.get_param_id(param) - if partition_id in self.param_to_partition_ids[group_id][param_id]: - return index - return None - - def initialize_gradient_partitioning_data_structures(self): - - total_partitions = dist.get_world_size(group=self.dp_process_group) - - for i, param_group in enumerate(self.fp16_groups): - - self.param_to_partition_ids[i] = {} - self.is_partition_reduced[i] = {} - self.total_grads_in_partition[i] = {} - self.remaining_grads_in_partition[i] = {} - self.is_grad_computed[i] = {} - self.grad_partition_insertion_offset[i] = {} - self.grad_start_offset[i] = {} - self.first_param_index_in_partition[i] = {} - - for partition_id in range(total_partitions): - self.is_grad_computed[i][partition_id] = {} - self.grad_partition_insertion_offset[i][partition_id] = {} - self.grad_start_offset[i][partition_id] = {} - self.initialize_gradient_partition(i, param_group, partition_id) - self.is_partition_reduced[i][partition_id] = False - self.first_param_index_in_partition[i][ - partition_id] = self.get_first_param_index( - i, - param_group, - partition_id) - - @instrument_w_nvtx - def independent_gradient_partition_epilogue(self): - self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) - self.__reduce_and_partition_ipg_grads() - self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) - - self.__reduce_and_partition_stream.synchronize() - - # if dist.get_rank() == 0: - # logger.info("Params already reduced %s", self.params_already_reduced) - for i in range(len(self.params_already_reduced)): - self.params_already_reduced[i] = False - - #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad - #TODO: use a similar code path for both cpu_offload and non-cpu offload - if not self.offload_optimizer: - for i, sub_group in enumerate(self.fp16_groups): - self.averaged_gradients[i] = [ - self.__param_id_to_grad_partition[param.ds_id] - if param.requires_grad else torch.zeros_like(param.ds_tensor) - for param in sub_group - ] - # self.averaged_gradients[i] = self.get_flat_partition( - # self.fp16_groups[i], - # 0, - # self.fp32_partitioned_groups_flat[i].numel(), - # return_tensor_list=True) - - # this method gets called after every backward. need to increment - # here because if it gets incremented in backward() the micro step - # id will be off by one when we do the reduce and partition at the. - # start of this method. - # TODO. make this less error prone - self.micro_step_id += 1 - - def overlapping_partition_gradients_reduce_epilogue(self): - self.independent_gradient_partition_epilogue() - - def create_reduce_and_remove_grad_hooks(self): - print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] - for i, param_group in enumerate(self.fp16_groups): - for param in param_group: - if param.requires_grad: - #print_rank_0(f" Before all gather {param.device}, {param.shape}") - - # The hook must be created in un-partitioned parameter - param.all_gather() - - #print(f"After all gather {param.device}, {param.shape}") - def wrapper(param, i): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - - @instrument_w_nvtx - def reduce_partition_and_remove_grads(*notneeded): - self.reduce_ready_partitions_and_remove_grads(param, i) - - grad_acc.register_hook(reduce_partition_and_remove_grads) - self.grad_accs.append(grad_acc) - - #print(f"param grad fn {param.expand_as(param).grad_fn}") - wrapper(param, i) - - # Partition the parameter after creating the hook - param.partition() - print_rank_0(f'[End] Create gradient reduction hooks') - - def get_param_id(self, param): - unique_id = id(param) - return self.param_id[unique_id] - - def report_ipg_memory_usage(self, tag, param_elems): - elem_count = self.elements_in_ipg_bucket + param_elems - percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size - see_memory_usage( - f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", - force=False) - - ###############Idependent Partition Gradient ######################## - def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): - #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) - - # Because the ipg bucket is initialized with a random place holder tensor, we must - # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket > - # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a - # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be - # empty, while reduction_list will have that garbage data. - if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: - self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", - param.ds_numel) - - self.__reduce_and_partition_ipg_grads() - - param_id = self.get_param_id(param) - assert self.params_already_reduced[param_id] == False, \ - f"The parameter {param_id} has already been reduced. \ - Gradient computed twice for this partition. \ - Multiple gradient reduction is currently not supported" - - self.__add_grad_to_ipg_bucket(param) - - @instrument_w_nvtx - @torch.no_grad() - def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: - self.__reduce_and_partition_stream.wait_stream(torch.cuda.default_stream()) - - if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel( - ) < self.reduce_bucket_size: - # move the gradient to a contiguous buffer - with torch.cuda.stream(self.__reduce_and_partition_stream): - # move the parameter's gradient to the contiguous flat buffer - new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow( - 0, - self.elements_in_ipg_bucket, - param.grad.numel()).view_as(param.grad) - new_grad_tensor.copy_(param.grad, non_blocking=True) - param.grad.record_stream(torch.cuda.current_stream()) - param.grad.data = new_grad_tensor - - self.__params_in_ipg_bucket.append(param) - - @instrument_w_nvtx - @torch.no_grad() - def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: - if not self.__params_in_ipg_bucket: - return - - for param in self.__params_in_ipg_bucket: - if param.grad.numel() != param.ds_numel: - raise RuntimeError( - f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " - f"gradients whose size is not same as the params") - - self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id) - - assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len( - self.__params_in_ipg_bucket) - - while self.__param_reduce_events and self.__param_reduce_events[0].query(): - self.__param_reduce_events.popleft() - if len(self.__param_reduce_events) > self.__max_param_reduce_events: - self.__param_reduce_events.popleft().synchronize() - - with torch.cuda.stream(self.__reduce_and_partition_stream): - if safe_mode: - assert_ints_same_as_other_ranks( - [p.ds_id for p in self.__params_in_ipg_bucket]) - - grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) - self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) - - self.__params_in_ipg_bucket.clear() - - event = Event() - event.record() - self.__param_reduce_events.append(event) - - @instrument_w_nvtx - def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: - """average gradients and scatter partitions across ranks""" - dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce) - - full_grads_for_rank = [p.grad for p in params_to_reduce] - if self.allreduce_always_fp32: - full_grads_for_rank = [g.float() for g in full_grads_for_rank] - - if self.postscale_gradients and self.gradient_predivide_factor != 1.0: - full_grads_for_rank = [ - g.div(self.gradient_predivide_factor) for g in full_grads_for_rank - ] - - grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, - self.dp_process_group) - - if self.postscale_gradients and self.gradient_predivide_factor != dist.get_world_size( - self.dp_process_group): - grad_partitions_for_rank = [ - g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank - ] - - if self.allreduce_always_fp32: - grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank] - - return grad_partitions_for_rank - - def set_grad_positions(self): - for i, group in enumerate(self.fp16_groups): - current_offset = 0 - for param in group: - param_id = self.get_param_id(param) - num_elements = param.ds_tensor.ds_numel - - self.grad_position[param_id] = [ - int(i), - int(current_offset), - int(num_elements) - ] - #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") - current_offset += num_elements - - def _constant_buffered_norm2(self, input, buffer_size=250000000): - norm = None - for part in input.view(-1).split(buffer_size): - if norm is None: - norm = part.data.double().norm(2)**2.0 - else: - norm += part.data.double().norm(2)**2.0 - return norm**0.5 - - def set_norm_for_param_grad_in_gpu(self, param): - param_id = self.get_param_id(param) - #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2) - #Using a more memory efficient version - self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) - - def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): - with torch.cuda.stream(self.copy_grad_stream): - param_id = self.get_param_id(param) - src_tensor = param.grad.view(-1).float() - #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") - fp32_grad_tensor.copy_(src_tensor, non_blocking=True) - param.grad = None - - def complete_grad_norm_calculation_for_cpu_offload(self, params): - total_norm = 0.0 - norm_type = 2.0 - for p in params: - if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - param_id = self.get_param_id(p) - if param_id in self.norm_for_param_grads.keys(): - param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 - - # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=self.dp_process_group) - - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) - - total_norm = total_norm_cuda[0].item()**(1. / norm_type) - - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 - - return total_norm - - @instrument_w_nvtx - def __partition_grads(self, - params_to_release: List[Parameter], - grad_partitions: List[Tensor]) -> None: - for param, grad_partition in zip(params_to_release, grad_partitions): - if param.ds_tensor.ds_numel * dist.get_rank( - self.dp_process_group) > param.ds_numel: - # this grad partition is empty - don't need to do anything - continue - - # move or accumulate gradient partition to target buffer - grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow( - 0, - 0, - grad_partition.numel()) - if self.micro_step_id == 0: # don't accumulate - grad_buffer.copy_(grad_partition, non_blocking=True) - # ensure grad buffer is a CUDA buffer to speed up the next few - # operations and so it can be used asynchronously - grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) - elif grad_buffer.is_cuda: - grad_buffer.add_(grad_partition) - else: - # if dst is CPU, copy first to src device, do the addition - # there, then move back to dst. adding directly to cpu is very slow - cuda_grad_buffer = grad_buffer.to(grad_partition.device, - non_blocking=True) - cuda_grad_buffer.add_(grad_partition) - grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) - # ensure grad buffer is a CUDA buffer to speed up the next few - # operations and so it can be used asynchronously - grad_buffer = cuda_grad_buffer - - if hasattr(self.__inf_or_nan_tracker, "logical_or_"): - self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) - self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) - else: - # logical_or_ not available in older versions of pytorch - self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any() - self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any() - self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0 - - # offload the gradient partition if applicable - if self.offload_optimizer: - i, dest_offset, _ = self.grad_position[self.get_param_id(param)] - offload_fp32_gradients = {} - offload_fp32_offsets = {} - - if self.is_gradient_accumulation_boundary: - self.norm_for_param_grads[self.get_param_id( - param)] = self._constant_buffered_norm2(grad_buffer) - - if self._swappable_optimizer_subgroup(i): - if not i in offload_fp32_gradients.keys(): - offload_fp32_gradients[i] = [] - offload_fp32_offsets[i] = [] - - offload_fp32_gradients[i].append(grad_buffer.float()) - offload_fp32_offsets[i].append(dest_offset) - else: - fp32_grad_tensor = self.fp32_partitioned_groups_flat[ - i].grad.narrow(0, - dest_offset, - grad_buffer.numel()) - fp32_grad_tensor.copy_(grad_buffer) - - # free the gradient - param.grad.record_stream(torch.cuda.current_stream()) - param.grad = None - - if self.offload_optimizer and self.swap_optimizer: - for i in offload_fp32_gradients.keys(): - self.optimizer_swapper.swap_out_gradients( - parameter=self.fp32_partitioned_groups_flat[i], - gradient_offsets=offload_fp32_offsets[i], - gradient_tensors=offload_fp32_gradients[i]) - - def reduce_ready_partitions_and_remove_grads(self, param, i): - #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) - self.reduce_independent_p_g_buckets_and_remove_grads(param, i) - - def zero_reduced_gradients(self, partition_id, i): - def are_all_related_partitions_reduced(params_id): - for partition_id in self.param_to_partition_ids[i][params_id]: - if not self.is_partition_reduced[i][partition_id]: - return False - return True - - for params_id in self.is_grad_computed[i][partition_id]: - if are_all_related_partitions_reduced(params_id): - self.param_dict[params_id].grad = None - - def flatten_and_print(self, message, tensors, start=0, n=5): - flatten_tensor = self.flatten(tensors) - - def print_func(): - logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) - - self.sequential_execution(print_func, message) - - def get_grads_to_reduce(self, i, partition_id): - def get_reducible_portion(key): - grad = self.param_dict[key].grad - total_elements = grad.numel() - start = self.grad_start_offset[i][partition_id][key] - num_elements = min( - total_elements - start, - self.partition_size[i] - - self.grad_partition_insertion_offset[i][partition_id][key]) - if not pg_correctness_test: - if num_elements == total_elements: - return grad - else: - return grad.contiguous().view(-1).narrow(0, - int(start), - int(num_elements)) - else: - if num_elements == total_elements: - return grad.clone() - else: - return grad.clone().contiguous().view(-1).narrow( - 0, - int(start), - int(num_elements)) - - grads_to_reduce = [] - for key in self.is_grad_computed[i][partition_id]: - grad = get_reducible_portion(key) - grads_to_reduce.append(grad) - return grads_to_reduce - - def sequential_execution(self, function, message, group=None): - if group is None: - group = self.dp_process_group - if dist.get_rank(group=group) == 0: - logger.info(message) - for id in range(dist.get_world_size(group=group)): - if id == dist.get_rank(group=group): - function() - dist.barrier(group=group) - - def set_none_gradients_to_zero(self, i, partition_id): - for param_id in self.is_grad_computed[i][partition_id]: - param = self.param_dict[param_id] - if param.grad is None: - param.grad = torch.zero_like(param) - - ######################Reduction Related Methods############################## - - def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None): - rank = None - tensor = self.flatten(bucket) - - tensor_to_allreduce = tensor - - if pg_correctness_test: - allreduce_always_fp32 = True - - if allreduce_always_fp32: - tensor_to_allreduce = tensor.float() - - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) - - if rank is None: - # "All Reducing" - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - else: - global_rank = _get_global_rank(self.dp_process_group, rank) - dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) - - if allreduce_always_fp32 and tensor is not tensor_to_allreduce: - if rank is None or rank == dist.get_rank(group=self.dp_process_group): - tensor.copy_(tensor_to_allreduce) - - return tensor - - # if rank is specified do a reduction instead of an allreduce - def allreduce_and_copy(self, small_bucket, rank=None, log=None): - with torch.cuda.stream(self.reduction_stream): - allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) - if rank is None or rank == dist.get_rank(group=self.dp_process_group): - for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): - buf.copy_(synced) - - def allreduce_no_retain(self, - bucket, - numel_per_bucket=500000000, - rank=None, - log=None): - small_bucket = [] - numel = 0 - for tensor in bucket: - small_bucket.append(tensor) - numel = numel + tensor.numel() - if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, rank=rank, log=None) - small_bucket = [] - if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, rank=rank, log=log) - - ############################################################################# - ############################################################################# - ############################################################################# - - # views the tensor as multiple partitions and returns - # those partitions - def get_data_parallel_partitions(self, tensor): - partitions = [] - - dp = dist.get_world_size(group=self.dp_process_group) - dp_id = dist.get_rank(group=self.dp_process_group) - - total_num_elements = tensor.numel() - - base_size = total_num_elements // dp - remaining = total_num_elements % dp - - start = 0 - for id in range(dp): - partition_size = base_size - if id < remaining: - partition_size = partition_size + 1 - partitions.append(tensor.narrow(0, start, partition_size)) - start = start + partition_size - return partitions - - def get_partition_info(self, tensor_list, partition_size, partition_id): - params_in_partition = [] - params_not_in_partition = [] - - start_index = partition_size * partition_id - end_index = partition_size * (partition_id + 1) - - current_index = 0 - first_offset = 0 - - for tensor in tensor_list: - - tensor_size = tensor.numel() - - if (current_index >= start_index and current_index < end_index): - params_in_partition.append(tensor) - - elif start_index > current_index and start_index < (current_index + - tensor_size): - params_in_partition.append(tensor) - - assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" - first_offset = start_index - current_index - - else: - params_not_in_partition.append(tensor) - - current_index = current_index + tensor_size - - return params_in_partition, params_not_in_partition, first_offset - - @instrument_w_nvtx - def zero_grad(self, set_grads_to_None=True): - """ - Zero FP16 parameter grads. - """ - self.micro_step_id = 0 - - # FP32 grad should never exist. - # For speed, set model fp16 grad to None by default - for group in self.fp16_groups: - for p in group: - if set_grads_to_None: - if p.grad is not None and p.grad.is_cuda: - p.grad.record_stream(torch.cuda.current_stream()) - p.grad = None - else: - if p.grad is not None: - p.grad.detach_() - p.grad.zero_() - - def _model_parallel_all_reduce(self, tensor, op): - """ Perform all reduce within model parallel group, if any. - """ - if self.model_parallel_group is None: - pass - else: - torch.distributed.all_reduce(tensor=tensor, - op=op, - group=self.model_parallel_group) - - @instrument_w_nvtx - def get_grad_norm_direct(self, gradients, params, norm_type=2): - """Clips gradient norm of an iterable of parameters. - - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - - Returns: - Total norm of the parameters (viewed as a single vector). - """ - norm_type = float(norm_type) - if norm_type == inf: - total_norm = max(g.data.abs().max() for g in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.MAX, - group=self.dp_process_group) - - # Take max across all GPUs. - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() - else: - # if dist.get_rank() == 0: - # logger.info(f"Total Norm beginning {total_norm}") - grad_norms = [] - for g, p in zip(gradients, params): - if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - grad_norms.append(g.cuda(non_blocking=True).double().norm(2)) - - # Sum across all model parallel GPUs. - total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) - - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=self.dp_process_group) - - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) - - total_norm = total_norm_cuda.item()**(1. / norm_type) - - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 - - return total_norm - - # creates a flat fused tensor from the tensor list starting at the first_offset - # in the first tensor of the list. If there are not enough elements in the tensor - # list then the flat tensor will be padded with zeros - def get_flat_partition(self, - tensor_list, - first_offset, - partition_size, - return_tensor_list=False): - flat_tensor_list = [] - current_size = 0 - for i, tensor in enumerate(tensor_list): - if tensor.grad is None: - tensor.grad = torch.zeros_like(tensor) - - tensor = tensor.grad - num_elements = tensor.numel() - tensor_offset = 0 - - # we need to offset to get to the right element - if i == 0 and first_offset > 0: - tensor_offset = first_offset - num_elements = num_elements - tensor_offset - - # we dont need all elements of the tensor - if num_elements > (partition_size - current_size): - num_elements = partition_size - current_size - - # we need a narrow view of the tensor based on the tensor offset and number of elements that - # we need from this tensor - if tensor_offset > 0 or num_elements < tensor.numel(): - flat_tensor_list.append(tensor.contiguous().view(-1).narrow( - 0, - int(tensor_offset), - int(num_elements))) - else: - flat_tensor_list.append(tensor) - - current_size = current_size + num_elements - - # this means its the last partition and does not align with the dp boundary. We need to pad before flattening - if current_size < partition_size: - flat_tensor_list.append( - torch.zeros(int(partition_size - current_size), - dtype=tensor_list[0].dtype, - device=tensor_list[0].device)) - - if return_tensor_list: - return flat_tensor_list - - return self.flatten(flat_tensor_list) - - def free_grad_in_param_list(self, param_list): - for p in param_list: - p.grad = None - - def reset_cpu_buffers(self): - self.norm_for_param_grads = {} - self.local_overflow = False - - def log_timers(self, timer_names): - if self.timers is None: - return - - self.timers.log(names=list(timer_names)) - - def start_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).start() - - def stop_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).stop() - - def _pre_step(self): - self.micro_step_id = 0 - - print_rank_0(f"Inside Step function") - see_memory_usage(f"In step before checking overflow", force=False) - - print_rank_0("Finished Tracing at Beginning of Step") - self.param_coordinator.hierarchy = 0 - - print_rank_0("Finished Tracing at Beginning of Step") - - @instrument_w_nvtx - def _get_norm_groups(self): - norm_groups = [] - for i, group in enumerate(self.fp16_groups): - if self.offload_optimizer: - norm_groups.append( - self.complete_grad_norm_calculation_for_cpu_offload( - self.fp16_groups[i])) - else: - norm_groups.append( - self.get_grad_norm_direct(self.averaged_gradients[i], - self.fp16_groups[i])) - return norm_groups - - @instrument_w_nvtx - def _prepare_fp32_grad_for_sub_group(self, sub_group_id): - partition_id = dist.get_rank(group=self.dp_process_group) - - single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( - self.fp32_partitioned_groups_flat[sub_group_id].dtype) - - assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ - "averaged gradients have different number of elements that partition size {} {} {} {}".format( - single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) - - self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition - - # release all the gradient since we have already created a necessary copy in dp_grad_partition - self.zero_grad() - - for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): - grad.record_stream(torch.cuda.current_stream()) - - self.averaged_gradients[sub_group_id] = None - - @instrument_w_nvtx - def _prepare_sub_group(self, sub_group_id, timer_names=set()): - see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', - force=False) - if self._swappable_optimizer_subgroup(sub_group_id): - self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) - elif not self.offload_optimizer: - self._prepare_fp32_grad_for_sub_group(sub_group_id) - see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', - force=False) - - def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()): - param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) - assert self._swappable_optimizer_subgroup(sub_group_id), \ - f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - - OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state' - see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', - force=False) - self.start_timers([OPTIMIZER_SWAP_IN_STATE]) - - self.optimizer_swapper.swap_in_optimizer_state( - parameter=self.fp32_partitioned_groups_flat[sub_group_id], - async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) - - self.stop_timers([OPTIMIZER_SWAP_IN_STATE]) - timer_names.add(OPTIMIZER_SWAP_IN_STATE) - see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', - force=False) - - @instrument_w_nvtx - def _release_sub_group(self, sub_group_id, timer_names=set()): - see_memory_usage(f'Before release optimizer sub group {sub_group_id}', - force=False) - # get rid of the fp32 gradients. Not needed anymore - if not self.offload_optimizer: - self.fp32_partitioned_groups_flat[sub_group_id].grad = None - - if self._swappable_optimizer_subgroup(sub_group_id): - self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names) - see_memory_usage(f'After release optimizer sub group {sub_group_id}', - force=False) - - # create a flat tensor aligned at the alignment boundary - @instrument_w_nvtx - def flatten_dense_tensors_aligned(self, tensor_list, alignment): - num_elements = 0 - for tens in tensor_list: - num_elements = num_elements + tens.numel() - - remaining = num_elements % alignment - - if remaining: - elements_to_add = alignment - remaining - pad_tensor = torch.zeros(elements_to_add, - device=tensor_list[0].device, - dtype=tensor_list[0].dtype) - padded_tensor_list = tensor_list + [pad_tensor] - - num_elements = num_elements + elements_to_add - else: - padded_tensor_list = tensor_list - - return self.flatten(padded_tensor_list) - - def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()): - param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) - assert self._swappable_optimizer_subgroup(sub_group_id), \ - f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - - OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state' - see_memory_usage( - f'post-step Before swapping out optimizer tensors {sub_group_id}', - force=False) - self.start_timers([OPTIMIZER_SWAP_OUT_STATE]) - - self.optimizer_swapper.swap_out_optimizer_state( - parameter=self.fp32_partitioned_groups_flat[sub_group_id], - async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is - not None) - - self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) - see_memory_usage( - f'post-step After swapping out optimizer tensors {sub_group_id}', - force=False) - timer_names.add(OPTIMIZER_SWAP_OUT_STATE) - - # get rid of the fp32 gradients. Not needed anymore - self.fp32_partitioned_groups_flat[sub_group_id].grad = None - - def _unflatten_partitioned_parameters(self, sub_group_id): - updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], - self.fp16_partitioned_groups[sub_group_id]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): - partitioned_param.data = q.data - - def _overflow_clean_up(self, prev_scale): - see_memory_usage('After overflow before clearing gradients', force=False) - self.zero_grad() - - if self.offload_optimizer: - self.reset_cpu_buffers() - else: - self.averaged_gradients = {} - - see_memory_usage('After overflow after clearing gradients', force=False) - - if torch.distributed.get_rank() == 0: - logger.info( - "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, " - "reducing to {}".format(dist.get_rank(), - prev_scale, - self.loss_scale)) - - @instrument_w_nvtx - def _overflow_check_and_loss_scale_update(self): - - # First compute norm for all group so we know if there is overflow - self.check_overflow() - - #loss scaling related computation - prev_scale = self.loss_scale - self._update_scale(self.overflow) - - if self.overflow: - self._overflow_clean_up(prev_scale) - - return self.overflow - - @instrument_w_nvtx - def _post_step(self, timer_names=set()): - if self.offload_optimizer: - self.reset_cpu_buffers() - - #Gathering persisting parameters - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - - self.log_timers(timer_names) - - see_memory_usage('After zero_optimizer step', force=False) - print_rank_0(f"------------------Finishing Step-----------------------") - - @instrument_w_nvtx - def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): - if self.fp16_partitioned_groups_flat[sub_group_id] is not None: - self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( - self.fp32_partitioned_groups_flat[sub_group_id].data) - - #unflatten fp16 parameter subgroup - self._unflatten_partitioned_parameters(sub_group_id) - else: - self._partitioned_params_swap_out(sub_group_id) - - @instrument_w_nvtx - def step(self, closure=None): - """ - Not supporting closure. - """ - self._pre_step() - self._partition_all_parameters() - - #checks for overflow, adjust the loss scale accordingly - if self._overflow_check_and_loss_scale_update(): - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - return - - norm_groups = self._get_norm_groups() - self._global_grad_norm = get_global_norm(norm_list=norm_groups) - - timer_names = set() - - timer_names.add('optimizer_step') - self.start_timers(['optimizer_step']) - - #update parameters one sub group at a time - for sub_group_id, group in enumerate(self.fp16_groups): - - #prepare optimizer states, gradients and fp32 parameters for update - self._prepare_sub_group(sub_group_id, timer_names) - - #scale the fp32 gradients - self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) - - #apply the optimizer step on the sub group and copy fp32 parameters to fp16 - self._optimizer_step(sub_group_id) - - #put fp16 parameters in appropriate location - self._reassign_or_swap_out_partitioned_parameters(sub_group_id) - - #release memory or swap out optimizer states of fp32 parameters - self._release_sub_group(sub_group_id, timer_names) - - self.stop_timers(['optimizer_step']) - - self._post_step(timer_names) - - # warn user about caching allocator flushes - alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] if hasattr( - torch.cuda, - "memory_stats") else 0 - if alloc_retries > self.__n_caching_allocator_flushes: - if dist.get_rank() == 0: - logger.warning( - "%d pytorch allocator cache flushes since last step. this happens " - "when there is high memory pressure and is detrimental to " - "performance. if this is happening frequently consider adjusting " - "settings to reduce memory consumption. If you are unable to " - "make the cache flushes go away consider adding " - "torch.cuda.empty_cache() calls in your training loop to ensure " - "that all ranks flush their caches at the same time", - alloc_retries - self.__n_caching_allocator_flushes) - self.__n_caching_allocator_flushes = alloc_retries - - def dump_pre_step_gradients(self, debug_fp32_grads): - # Dump gradient norms for debugging - for i, _ in enumerate(self.fp16_groups): - print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC') - for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]): - param_id = self.get_param_id(fp16_param) - fp16_grad_norm = self.debug_fp16_grads[i][param_id] - - fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad] - norm_list = [fp16_grad_norm, fp32_grad_norm] - print(f'Pre-Step Norms {i} {param_id} = {norm_list}') - - def dump_post_step_gradients(self): - # Dump gradient norms for debugging - for i, group in enumerate(self.fp16_groups): - print( - f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT') - unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i]) - unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i], - self.fp16_groups[i]) - for j, p in enumerate(self.fp16_groups[i]): - param_id = self.get_param_id(p) - param_norm = float(p.data.float().norm(2)) - ds_norm = float(p.ds_tensor.data.float().norm(2)) - - unflat_norm = [ - float(t.data.float().norm(2)) - for t in [unflat_fp16[j], - unflat_fp32[j]] - ] - norm_list = [param_norm, ds_norm] + unflat_norm - print(f'Post-Step Norms {i} {param_id} = {norm_list}') - - @instrument_w_nvtx - def unscale_and_clip_grads(self, sub_group_id, total_norm): - grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] - - # compute combined scale factor for this group - combined_scale = self.loss_scale - if self.clip_grad > 0.: - # norm is in fact norm*scale - clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad - if clip > 1: - combined_scale = clip * self.loss_scale - # to maintain behavior of averaging over accumulation steps - combined_scale *= self.micro_step_id + 1 - - for grad in grad_groups_flat: - if isinstance(grad, list): - sub_partitions = grad - for g in sub_partitions: - g.data.mul_(1. / combined_scale) - else: - grad.data.mul_(1. / combined_scale) - - def _check_overflow(self, partition_gradients=True): - self.overflow = self.has_overflow(partition_gradients) - - # `params` is a list / generator of torch.Variable - def has_overflow_serial(self, params, is_grad_list=False): - for p in params: - if p.grad is not None and self._has_inf_or_nan(p.grad.data): - return True - - return False - - def has_overflow_partitioned_grads_serial(self): - for i in range(len(self.fp16_groups)): - for j, grad in enumerate(self.averaged_gradients[i]): - if grad is not None and self._has_inf_or_nan(grad.data, j): - return True - return False - - @instrument_w_nvtx - def has_overflow(self, partition_gradients=True): - if partition_gradients: - with torch.cuda.stream(self.__reduce_and_partition_stream): - self.local_overflow = bool(self.__inf_or_nan_tracker.item()) - self.__inf_or_nan_tracker.zero_() - - overflow = self.local_overflow - #overflow = self.has_overflow_partitioned_grads_serial() - overflow_gpu = torch.cuda.ByteTensor([overflow]) - torch.distributed.all_reduce(overflow_gpu, - op=torch.distributed.ReduceOp.MAX, - group=self.dp_process_group) - - else: - params = [] - for group in self.fp16_groups: - for param in group: - params.append(param) - - overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) - overflow_gpu = torch.cuda.ByteTensor([overflow]) - - # Since each model parallel GPU carries only part of the model, - # make sure overflow flag is synced across all the model parallel GPUs - self._model_parallel_all_reduce(tensor=overflow_gpu, - op=torch.distributed.ReduceOp.MAX) - - overflow = overflow_gpu[0].item() - return bool(overflow) - - # `x` is a torch.Tensor - @staticmethod - def _has_inf_or_nan(x, j=None): - try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x - # (which is true for some recent version of pytorch). - cpu_sum = float(x.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # cpu_sum = float(x.sum()) - except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. - if "value cannot be converted" not in instance.args[0]: - raise - return True - else: - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: - return True - return False - - @instrument_w_nvtx - def backward(self, loss, retain_graph=False): - """ - :attr:`backward` performs the following steps: - - 1. fp32_loss = loss.float() - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ - if self.swap_optimizer: - self.optimizer_swapper.pre_backward() - - see_memory_usage(f"Before backward", force=False) - - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - - self.param_coordinator.reset_step() - - if self.swap_optimizer: - self.optimizer_swapper.post_backward() - - def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: - """get fp32 gradient partition dictionary - accessed as grad_dict[parameter_group_index][parameter_index] - """ - self.__reduce_and_partition_stream.synchronize() - grad_dict = collections.defaultdict(dict) - if self.offload_optimizer: - for group in self.fp16_groups: - for param_idx, param in enumerate(group): - group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] - fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow( - 0, - dest_offset, - num_elements) - grad_dict[group_idx][param_idx] = fp32_grad - else: - for group_idx, group in self.averaged_gradients.items(): - for param_idx, gradient in enumerate(group): - grad_dict[group_idx][param_idx] = gradient.float() - - return grad_dict - - @instrument_w_nvtx - def _partition_all_parameters(self): - """Partitioning Parameters that were not partitioned usually if parameters - of modules whose input parameters do not require grad computation do not - trigger post call and will therefore will remain unpartitioned""" - self.param_coordinator.release_and_reset_all() - for param in iter_params(self.module, recurse=True): - if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: - raise RuntimeError(f"{param.ds_summary()} expected to be released") - - def check_overflow(self, partition_gradients=True): - self._check_overflow(partition_gradients) - - def _update_scale(self, has_overflow=False): - self.loss_scaler.update_scale(has_overflow) - - # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) - - # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" - def _get_loss_scale(self): - return self.loss_scaler.loss_scale - - def _set_loss_scale(self, value): - self.loss_scaler.cur_scale = value - - loss_scale = property(_get_loss_scale, _set_loss_scale) - cur_scale = property(_get_loss_scale, _set_loss_scale) - - def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings): - # Remove paddings from flattened tensor - individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors) - lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)] - lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)] - #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}') - return lean_tensors - - #TODO REVISIT this for stage 3 - def get_lean_optimizer_state(self): - # Return optimizer states after removing paddings. - # This method assumes that each param group contains a single flattened tensor. - optimizer_groups_state = [] - - for i, group in enumerate(self.optimizer.param_groups): - p = group['params'][0] - lean_state = {} - for key, value in self.optimizer.state[p].items(): - if torch.is_tensor(value): - padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]] - lean_state[key] = self._get_lean_tensors( - value, - self.fp16_partitioned_groups[i], - self.groups_padding[i]) - lean_flat_len = sum([t.numel() for t in lean_state[key]]) - else: - lean_state[key] = value - - optimizer_groups_state.append(lean_state) - - return optimizer_groups_state - - def get_groups_without_padding(self, groups_with_padding): - # Return group tensor after removing paddings added for alignment to DP world size. - groups_without_padding = [] - for i, group in enumerate(groups_with_padding): - lean_group = self._get_lean_tensors(group, - self.fp16_partitioned_groups[i], - self.groups_padding[i]) - groups_without_padding.append(lean_group) - - return groups_without_padding - - def _set_fp32_optimizer_param_groups(self): - for sub_group_id, _ in enumerate(self.fp16_groups): - param_group_id = self.sub_group_to_group_id[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'].append( - self.fp32_partitioned_groups_flat[sub_group_id]) - - def _clear_fp32_optimizer_param_groups(self): - for param_group in self.optimizer.param_groups: - param_group['params'] = [] - - def _rigid_state_dict(self): - state_dict = {} - state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS - state_dict['loss_scaler'] = self.loss_scaler - state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale - state_dict['overflow'] = self.overflow - state_dict['partition_count'] = self.partition_count - - self._set_fp32_optimizer_param_groups() - state_dict['optimizer_state_dict'] = self.optimizer.state_dict() - state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat - self._clear_fp32_optimizer_param_groups() - - return state_dict - - def state_dict(self): - """ - Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. - This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict - of the contained Pytorch optimizer. - Example:: - checkpoint = {} - checkpoint['model'] = model.state_dict() - checkpoint['optimizer'] = optimizer.state_dict() - torch.save(checkpoint, "saved.pth") - """ - if self.elastic_checkpoint: - raise NotImplementedError( - "ZeRO-3 does not yet support elastic checkpointing, please disable for now." - ) - - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." - ) - - return self._rigid_state_dict() - - -# Restore base optimizer fp32 weights from checkpoint by: -# 1) Merging fp32 weights from checkpoints of all partitions -# 2) Extracting fp32 weights for current partition from merged weights -# 3) Using extracted weights to update base optimizer weights directly. - - def _restore_from_fp32_weights(self, all_state_dict): - - flat_local_partition = [] - for i in range(len(self.fp32_partitioned_groups_flat)): - merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] - flat_local_partition.append(self._get_flattened_partition(merged_partitions)) - - for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): - current.data.copy_(saved.data) - - # Restore base optimizer fp32 weights from ZeRO fp16 weights - def _restore_from_fp16_weights(self): - for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): - fp32_partition.data.copy_(fp16_partitions.data) - - # Refresh the fp32 master params from the fp16 copies. - def refresh_fp32_params(self): - self._restore_from_fp16_weights() - - # Extract flattened partition for current rank from all partitions - def _get_flattened_partition(self, all_partition_states): - partition_id = dist.get_rank(group=self.dp_process_group) - alignment = dist.get_world_size(group=self.dp_process_group) - - param_partitions = [[] for _ in range(len(all_partition_states[0]))] - for i, partition in enumerate(all_partition_states): - for j, param in enumerate(partition): - param_partitions[j].append(param) - - local_state_partitions = [] - for param_index, param_slices in enumerate(param_partitions): - flattened_merged_tensor = self.flatten_dense_tensors_aligned( - param_slices, - alignment) - new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) - local_state_partitions.append(new_partitions[partition_id]) - - if torch.is_tensor(local_state_partitions[0]): - return self.flatten_dense_tensors_aligned(local_state_partitions, alignment) - - # Assume non-tensor states are not partitioned and equal across ranks, so return first one - return local_state_partitions[0] - - # Restore base optimizer state from checkpoint by - # 1) Merging optimizer state from checkpoints of all partitions - # 2) Extracting optimizer state for current partition from the merged state - # 3) Using the extracted value to directly update the base optimizer. - def _restore_base_optimizer_state(self, all_state_dict): - base_optimizer_group_states = [] - for i in range(len(self.optimizer.param_groups)): - partition_states = {} - all_partition_group_states = [ - sd['base_optimizer_state'][i] for sd in all_state_dict - ] - for key in all_partition_group_states[0].keys(): - all_partition_states = [ - all_states[key] for all_states in all_partition_group_states - ] - partition_states[key] = self._get_flattened_partition( - all_partition_states) - base_optimizer_group_states.append(partition_states) - - for i, group in enumerate(self.optimizer.param_groups): - p = group['params'][0] - for key, saved in base_optimizer_group_states[i].items(): - if torch.is_tensor(self.optimizer.state[p][key]): - self.optimizer.state[p][key].data.copy_(saved.data) - else: - self.optimizer.state[p][key] = saved - - def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): - # I think it should actually be ok to reload the optimizer before the model. - self.loss_scaler = state_dict['loss_scaler'] - self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] - self.overflow = state_dict['overflow'] - - if load_optimizer_states: - self._set_fp32_optimizer_param_groups() - self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) - self._clear_fp32_optimizer_param_groups() - - # restore fp32 partitions - for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']): - curr_param.data.copy_(saved_param.data) - - # restore fp16 partitions from fp32 - for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): - fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] - fp16_param.data.copy_(fp32_param.data) - - # update fp16 unflattened params - for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): - updated_params = self.unflatten( - self.fp16_partitioned_groups_flat[sub_group_id], - self.fp16_partitioned_groups[sub_group_id]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): - partitioned_param.data = q.data - - # TODO: Support different/changing load/save DP degree. - def load_state_dict(self, - state_dict_list, - load_optimizer_states=True, - load_from_fp32_weights=False): - r"""Loading a ZeRO checkpoint - Arguments: - state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. - Note that the number of saved partitions may differ from number of loading partitions to support - changing GPU count, specifically DP world size, between saving and loading checkpoints. - load_optimizer_states: Boolean indicating whether or not to load base optimizer states - load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 - copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). - """ - """ - Loads a state_dict created by an earlier call to state_dict(). - If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user - will call ``model.load_state_dict()`` before - ``fp16_optimizer_instance.load_state_dict()`` is called. - Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - checkpoint = torch.load("saved.pth") - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - """ - - if self.elastic_checkpoint: - raise NotImplementedError( - "ZeRO-3 does not yet support elastic checkpointing, please disable for now." - ) - - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." - ) - - self._rigid_load_state_dict( - state_dict_list[dist.get_rank(group=self.dp_process_group)], - load_optimizer_states=load_optimizer_states) - - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].partition(self.persistent_parameters) - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - def save_checkpoint_prologue(self): - self._partition_all_parameters() - - def save_checkpoint_epilogue(self): - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - -def _handle_overflow(cpu_sum, x, i): - import math - rank = torch.distributed.get_rank() - if rank == 0: - t_i = -1 - for v_i, v in enumerate(x.data.contiguous().view(-1)): - if not math.isfinite(float(v)): - t_i = v_i - break - logger.info( - f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}" - ) - - -def estimate_zero3_model_states_mem_needs(total_params, - largest_layer_params, - num_gpus_per_node=1, - num_nodes=1, - cpu_offload=True, - cpu_offload_params=True, - zero_init=True, - additional_buffer_factor=1.5): - - total_gpus = num_nodes * num_gpus_per_node - gpus_factor = 1 / num_nodes - largest_layer_memory = (4 * largest_layer_params) - - if cpu_offload: - if cpu_offload_params: - gpu_mem = largest_layer_memory - - if zero_init: - cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor - else: - cpu_mem = total_params * max(4 * num_gpus_per_node, - 18 * gpus_factor) * additional_buffer_factor - - else: - gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus) - - if zero_init: - cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor - else: - cpu_mem = total_params * max(4 * num_gpus_per_node, - 16 * gpus_factor) * additional_buffer_factor - else: - gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus) - if zero_init: - cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor - else: - cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor - - return int(cpu_mem), int(gpu_mem), largest_layer_memory - - -def model_to_params(model): - # shared params calculated only once - total_params = sum( - dict((p.data_ptr(), - p.numel()) for p in model.parameters()).values()) - - largest_layer_params = 0 - for m in model.modules(): - # assuming no shared params within a single layer - layer_params = sum(p.numel() for p in m.parameters(recurse=False)) - largest_layer_params = max(largest_layer_params, layer_params) - - return total_params, largest_layer_params - - -import math - - -def estimate_zero3_model_states_mem_needs_all_live(model, - num_gpus_per_node=1, - num_nodes=1, - additional_buffer_factor=1.5): - """ - Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients - for a given ``model`` and hardware setup. - - If you have an actual model object, use this function and everything will be derived - automatically. - - If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass - the ``total_params`` and ``largest_layer_params`` explicitly. - - Args: - - ``model``: ``nn.Module`` object - - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - - ``num_nodes``: how many nodes (defaults to 1), - - ``additional_buffer_factor``: estimation factor (defaults to 1.5): - - """ - - total_params, largest_layer_params = model_to_params(model) - - estimate_zero3_model_states_mem_needs_all_cold( - total_params=total_params, - largest_layer_params=largest_layer_params, - num_gpus_per_node=num_gpus_per_node, - num_nodes=num_nodes, - additional_buffer_factor=additional_buffer_factor) - - -def estimate_zero3_model_states_mem_needs_all_cold(total_params, - largest_layer_params, - num_gpus_per_node=1, - num_nodes=1, - additional_buffer_factor=1.5): - """ - Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients - for a given ``model`` and hardware setup. - - If it's a hypothetical model, use this function where you have to pass - the ``total_params`` and ``largest_layer_params`` explicitly. - - If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything - will be derived automatically. - - Args: - - ``total_params``: total model params - - ``largest_layer_params``: largest layer's params - - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - - ``num_nodes``: how many nodes (defaults to 1), - - ``additional_buffer_factor``: estimation factor (defaults to 1.5): - - """ - def format_options(cpu_offload, cpu_offload_params, zero_init): - enabled = [] - enabled.append(f"cpu_offload={1 if cpu_offload else 0}") - enabled.append(f"cpu_offload_params={1 if cpu_offload_params else 0}") - enabled.append(f"zero_init={1 if zero_init else 0}") - return ", ".join(enabled) - - nodes_str = "nodes" if num_nodes > 1 else "node" - gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" - print( - "Estimated memory needed for params, optim states and gradients for a:\n" - f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" - f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params." - ) - print(" per CPU | per GPU | Options") - for cpu_offload in [True, False]: - for cpu_offload_params in [True, False]: - if not cpu_offload and cpu_offload_params: - continue - for zero_init in [True, False]: - cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs( - total_params=total_params, - largest_layer_params=largest_layer_params, - num_gpus_per_node=num_gpus_per_node, - num_nodes=num_nodes, - cpu_offload=cpu_offload, - cpu_offload_params=cpu_offload_params, - zero_init=zero_init, - additional_buffer_factor=additional_buffer_factor - ) - - options_str = format_options(cpu_offload=cpu_offload, - cpu_offload_params=cpu_offload_params, - zero_init=zero_init) - print( - f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") +""" +"Copyright 2020 The Microsoft DeepSpeed Team. +Licensed under the MIT license. +""" + +import gc +from dataclasses import dataclass +import functools +import os +import collections +from collections import OrderedDict, UserDict +import itertools +from typing import Deque, Dict, Iterable, Set, Tuple +import torch +from torch.cuda import Event, Stream +from torch.nn import Module, Parameter +import torch.distributed as dist +import math +from torch._six import inf +from torch.nn import Module +from torch.nn.parameter import Parameter + +from deepspeed.utils.logging import logger +from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced +from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter +from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter +from deepspeed.runtime.zero.partition_parameters import * +from deepspeed.runtime.zero.partition_parameters import _init_external_params +from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS +from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.runtime.zero.offload_constants import * +from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus +from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper +from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper + +# Toggle this to true to enable correctness test +# with gradient partitioning and without +pg_correctness_test = False + +FWD_MODULE_STACK = list() + + +def print_rank_0(message, debug=False, force=False): + rank = torch.distributed.get_rank() + if rank == 0 and (debug or force): + print(message) + # other variations + # - print for all ranks w/o interleaving + # printflock(f"[{rank}] {message}") + # - print to log file per rank + # log_rank_file(rank, message) + + +def input(msg): + return + + +def isclose(a, b, rtol=1e-09, atol=0.0): + return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) + + +def lcm(x, y): + from fractions import gcd # or can import gcd from `math` in Python 3 + return x * y // gcd(x, y) + + +def debug_rank0(message: str) -> None: + if dist.get_rank() == 0: + logger.debug(message) + + +def get_cuda_mem_allocated_str() -> str: + # this is really slow. when enabled the python process becomes slow + # to the point where it can't keep the GPU fed with work, so only enable + # for memory debugging. + # return f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f}GB" + return "xGB" + + +def move_to_cpu(tensor_list): + for tensor in tensor_list: + tensor.data = tensor.data.cpu() + + +@instrument_w_nvtx +def get_all_parameters(sub_module, recurse=False): + return itertools.chain(sub_module.named_parameters(recurse=recurse), + sub_module.ds_external_parameters()) + + +def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: + return map(lambda pair: pair[1], get_all_parameters(module, recurse)) + + +#apply torch.autograd.Function that calls a backward_function to tensors in output +def _apply_to_tensors_only(module, functional, backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_to_tensors_only(module, + functional, + backward_function, + output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + return functional.apply(module, backward_function, outputs) + else: + return outputs + + +#for each tensor in outputs run the forward_function and register backward_function as hook +def _apply_forward_and_backward_to_tensors_only(module, + forward_function, + backward_function, + outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_forward_and_backward_to_tensors_only( + module, + forward_function, + backward_function, + output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + forward_function(outputs) + if outputs.requires_grad: + outputs.register_hook(backward_function) + return outputs + else: + return outputs + + +class ZeROOrderedDict(OrderedDict): + def __init__(self, parent_module, *args, **kwargs): + """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. + + Args: + parent_module (``collections.OrderedDict``): the collection to replace + """ + + super().__init__(*args, **kwargs) + self._parent_module = parent_module + self._in_forward = False + + def __getitem__(self, key): + param = super().__getitem__(key) + + # Params can be registered as None (e.g., bias) + if param is None: + return param + + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if self._parent_module._parameters._in_forward: + print_rank_0(f'Registering external parameter from getter {key}', + force=False) + register_external_parameter(FWD_MODULE_STACK[-1], param) + param.all_gather() + + return param + + +def _inject_parameters(module, cls): + for module in module.modules(): + if cls == ZeROOrderedDict: + new_param = cls(parent_module=module) + else: + new_param = cls() + + for key, param in module._parameters.items(): + new_param[key] = param + module._parameters = new_param + + +class PartitionedParameterCoordinator: + """Handles partitioning and gathering of parameters.""" + class __InflightParamRegistry(UserDict): + """registry for parameters in flight""" + def __setitem__(self, + param: Parameter, + handle: AllGatherCoalescedHandle) -> None: + if param in self.data: + raise RuntimeError(f"{param.ds_summary()} already in registry") + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError( + f"attempted to add non-inflight parameter to registry {param.ds_summary()}" + ) + self.data[param] = handle + + @dataclass + class __ParamInTrace: + param: Parameter + step_id_last_used_at: int + + def __init__( + self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: Stream, + prefetch_nvme: bool = False, + ) -> None: + # mapping of param -> handle for each param that is currently in flight + self.__inflight_param_registry = __class__.__InflightParamRegistry() + # keeps track of the number of submodules invoked so far. + self.__step_id: int = 0 + # whether or not we have completed a trace of the entire network. This should + # always be true after the first forward pass + backward pass. + self.trace_complete: bool = False + # sequence of submodules/parameters in forward pass + backward pass + self.__submodule_order: Iterable[Module] = [] + self.__param_order: Iterable[__class__.__ParamInTrace] = [] + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + # number of available params, and max number of available params + self.__n_available_params: int = 0 + self.__max_n_available_params: int = max_available_parameters_in_numel + # max distance between two use of the module beyond which module is released + self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel + # queue for parameters to fetch. parameters will be popped off the left + # side of the dequeue as they are fetched + self.__param_queue: Deque[__class__.__ParamInTrace] = None + self.__prefetch_bucket_sz: int = prefetch_bucket_sz + self.__prefetch_nvme: bool = prefetch_nvme + self.hierarchy: int = 0 + + # stream that will be used for allgather operations + self.__allgather_stream: Stream = allgather_stream + + # limit the number of fetch events that can be queued at once + # otherwise, what happens is memory is allocated by the host thread at the + # time of the call, but not used until later by the asynchronous cuda stream. + # allowing an infinite number of these to queue up causes a lot of memory + # pressure that then becomes detrimental to performance. + # this is a much less elegant way of fixing this vs something like using + # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now + # because ideally in the future its replaced by an async allocation + # mechanism which doesnt require any configuration by the user. + self.__ongoing_fetch_events: Deque[Event] = collections.deque() + self.__max_ongoing_fetch_events: int = 2 + + """Tracing and Tracking + TODO. consider performing trace before initializing PartitionedParameterCoordinator + and passing trace results into constructor. This way all the code in here can + just assume that the trace is complete and the results can be entirely + immutable. + + Bookkeeping operations used to track where we are in the forward/backward pass + """ + + def record_trace(self, sub_module: Module) -> None: + """adds sub module to trace""" + if self.trace_complete: + raise RuntimeError( + "attemted to record trace when trace was already complete") + + self.__submodule_order.append(sub_module) + for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): + self.__param_order.append( + __class__.__ParamInTrace(param=param, + step_id_last_used_at=self.__step_id)) + + def reset_step(self) -> None: + """indicate that we have completed one fwd+bwd for the model""" + if self.__inflight_param_registry: + raise RuntimeError( + f"still have inflight params " + f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") + + if not self.trace_complete: + # make sure that recorded parameter and submodule orders are + # identical across ranks + assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) + assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) + assert_ints_same_as_other_ranks( + [p.step_id_last_used_at for p in self.__param_order]) + + self.__submodule_order = tuple(self.__submodule_order) # freeze + self.__param_order = tuple(self.__param_order) # freeze + self.trace_complete = True + print_rank_0(f"completed trace: {[m.id for m in self.__submodule_order]}", + force=True) + + self.__param_queue = collections.deque(self.__param_order) # reset fetch queue + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + self.__step_id = 0 + self.__n_available_params = 0 + + """Fetch and Release + Fetching, prefetching, and releasing parameters + """ + + @instrument_w_nvtx + @torch.no_grad() + def fetch_sub_module(self, current_submodule: Module) -> None: + """This method does the following (in order): + 1. kick off fetch for parameters in immediately required sub module + 2. kick off fetch for next few parameters we will need later (prefetch) + 3. block on parameters in immediately required sub module + """ + debug_rank0( + f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " + + str({ + "avail": f"{self.__n_available_params:.1e}", + "queue_sz": f"{len(self.__param_queue or [])}", + "inflight": [p.ds_id for p in self.__inflight_param_registry], + "allocated": get_cuda_mem_allocated_str() + })) + + params_to_fetch = frozenset(iter_params(current_submodule)) + + # kick off all gather for params in the immediately required submodule + for param in params_to_fetch: + debug_rank0(f"-fetch: {param.ds_summary()}") + self.__all_gather_params(params_to_fetch) + + # wait for parameters in the immediately needed submodule to become available + for param in iter_params(current_submodule): + param.ds_active_sub_modules.add(current_submodule.id) + debug_rank0(f"-wait: {param.ds_summary()}") + if param in self.__inflight_param_registry: + with torch.cuda.stream(self.__allgather_stream): + while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ + 0].query(): + self.__ongoing_fetch_events.popleft() + if len(self.__ongoing_fetch_events + ) > self.__max_ongoing_fetch_events: + self.__ongoing_fetch_events.popleft().synchronize() + + self.__inflight_param_registry.pop(param).wait() + + event = Event() + event.record() + self.__ongoing_fetch_events.append(event) + + assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() + torch.cuda.current_stream().wait_stream(self.__allgather_stream) + + # kick off parameter prefetches for upcoming modules + # don't prefetch if we dont have a completed model trace, or if we aren't + # training (throws off the tracing and don't want to prefetch modules for bwd) + if self.trace_complete and current_submodule.training: + # go through the parameters we need for the current module and pop them + # off the fetch queue so that they aren't prefetched later. + # if params have already been popped off the fetch queue by earlier + # prefetches we won't look for them here + discarded_from_prefetch_queue = set() + params_not_already_fetched = set( + filter( + lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. + __step_id, + params_to_fetch)) + while self.__param_queue and len(discarded_from_prefetch_queue) < len( + params_not_already_fetched): + param_in_trace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + discarded_from_prefetch_queue.add(param_in_trace.param) + if discarded_from_prefetch_queue != params_not_already_fetched: + raise RuntimeError( + f"tracing error at step {self.__step_id}: " + f"expected the next {len(params_not_already_fetched)} parameters in the " + f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " + f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." + ) + + # kick off all gather for params in the next few submodules (prefetch) + max_params_to_prefetch = min( + self.__max_n_available_params - self.__n_available_params, + self.__prefetch_bucket_sz) + params_to_prefetch = set() + numel_prefetching = 0 + while self.__param_queue and numel_prefetching < max_params_to_prefetch: + param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + if param_in_trace.param not in params_to_prefetch: + params_to_prefetch.add(param_in_trace.param) + numel_prefetching += param_in_trace.param.ds_numel + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") + self.__all_gather_params(params_to_prefetch) + + if self.__prefetch_nvme: + self.__prefetch_nvme_param_partitions() + + self.__step_id += 1 + + @instrument_w_nvtx + @torch.no_grad() + def release_sub_module(self, submodule: Module) -> None: + """release the parameters of a sub module, assuming they meet conditions to + be released.""" + params_to_release = (self.__params_to_release(submodule, + self.__step_id) + if self.trace_complete else set( + p.ds_id for p in iter_params(submodule))) + + for param in iter_params(submodule): + param.ds_active_sub_modules.discard(submodule.id) + if param.ds_id in params_to_release and not param.is_external_param: + self.__release_param(param) + + @instrument_w_nvtx + @torch.no_grad() + def release_and_reset_all(self) -> None: + """release all module parameters""" + for param in map(lambda p: p.param, self.__param_order): + if param in self.__inflight_param_registry: + raise RuntimeError(f"param {param.ds_summary()} still in flight") + + # TODO. make this throw if if there are still active submodules. currently + # there's a hook execution issue + param.ds_active_sub_modules.clear() + self.__release_param(param) + + for param_in_trace in self.__param_order: + if param_in_trace.param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError( + f"{param_in_trace.param.ds_summary()} expected to be released") + + @instrument_w_nvtx + def __all_gather_params(self, params: Set[Parameter]) -> None: + """for each partitioned parameter, kick off an async allgather and store + the work handle for the in flight parameters.""" + partitioned_params = [] + for param in params: + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + partitioned_params.append(param) + self.__n_available_params += param.ds_numel + + if partitioned_params: + with torch.cuda.stream(self.__allgather_stream): + handle = partitioned_params[0].all_gather_coalesced(partitioned_params) + + for param in partitioned_params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() + self.__inflight_param_registry[param] = handle + + @instrument_w_nvtx + def __release_param(self, param: Parameter) -> None: + if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: + debug_rank0(f"-release: {param.ds_summary()}") + param.partition() + self.__n_available_params -= param.ds_numel + + @instrument_w_nvtx + @functools.lru_cache(maxsize=None) + def __params_to_release(self, + submodule_to_release: Module, + step_id: int) -> Set[int]: + if not self.trace_complete: + raise RuntimeError("expected trace to be complete") + + params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) + if not p.ds_persist) + + # examine all modules within `max_reuse_dist_in_numel` of the current step, + # if we see any of the candidate parameters to be released reoccur while + # doing this, remove them from the set of parameters to release. + params_traversed = 0 + for module in self.__submodule_order[step_id:]: + if params_traversed > self.__max_reuse_dist_in_numel: + break + for param in iter_params(module): + params_to_release.discard(param.ds_id) + params_traversed += param.ds_numel + + return params_to_release + + @instrument_w_nvtx + def __prefetch_nvme_param_partitions(self) -> None: + """swap in parameter partitions from nvme for those parameters that will be used + after the ones that are already being prefetched into full parameters + """ + if not self.trace_complete: + return + + numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) + + numel_considered = 0 + swap_in_params = [] + for param_in_trace in self.__param_queue: + param = param_in_trace.param + if param.nvme_swapper is None: + continue + if (numel_considered > 2 * numel_in_flight or len(swap_in_params) >= + param.nvme_swapper.available_swap_in_buffers()): + break + if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_in_params.append(param) + numel_considered += param.ds_numel + + if swap_in_params: + swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) + + +class PreBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, outputs): + ctx.module = module + ctx.pre_backward_function = pre_backward_function + if not hasattr(module, "applied_pre_backward_ref_cnt"): + module.applied_pre_backward_ref_cnt = 0 + module.applied_pre_backward_ref_cnt += 1 + #print(f"After Forward: {ctx.module.__class__.__name__}") + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + #print(f"Before Backward: {ctx.module.__class__.__name__}") + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +class PostBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, output): + ctx.module = module + if output.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." + #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.pre_backward_function = pre_backward_function + output = output.detach() + return output + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.pre_backward_function(ctx.module) + #print(f"After Backward: {ctx.module.__class__.__name__}") + return (None, None) + args + + +class FP16_DeepSpeedZeroOptimizer_Stage3(object): + """ + DeepSpeedZeroOptimizer designed to reduce the memory footprint + required for training large deep learning models. + + For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models + https://arxiv.org/abs/1910.02054 + + For usage examples, refer to TODO: DeepSpeed Tutorial + + """ + def __init__(self, + module, + init_optimizer, + timers, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + dp_process_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + offload_param_config=None, + sub_group_size=1000000000000, + mpu=None, + clip_grad=0.0, + allreduce_always_fp32=False, + postscale_gradients=True, + gradient_predivide_factor=1.0, + gradient_accumulation_steps=1, + elastic_checkpoint=False, + aio_config=None): + + see_memory_usage("Stage 3 initialize beginning", force=False) + + if dist.get_rank() == 0: + logger.info(f"initialized {__class__.__name__} with args: {locals()}") + logger.info(f"Reduce bucket size {reduce_bucket_size}") + logger.info(f"Allgather bucket size {prefetch_bucket_size}") + # The fused optimizer does all the work. We need this layer for two reason: + # 1. maintain same user API from apex.fp16_utils + # 2. keep common stuff here in case we need to add ne552w fused optimizer later + + # differences from apex.fp16_utils: + # - assume all model params in fp16 + # - assume all params requires grad + # - flat by groups, not keeping state. TODO: remove state explicitly? + # - master gard and unflat master weight never exist. TODO: a way to save out unflat master? + if not torch.cuda.is_available: + raise SystemError("Cannot use fp16 without CUDA.") + self.optimizer = init_optimizer + + # Load pre-built or JIT compile (un)flatten ops + util_ops = UtilsBuilder().load() + self.flatten = util_ops.flatten + self.unflatten = util_ops.unflatten + self.dtype = self.optimizer.param_groups[0]['params'][0].dtype + self._global_grad_norm = 0. + + self._convert_to_zero_parameters(module, mpu) + + for m in module.modules(): + _init_external_params(m) + + self.module = module + self.elastic_checkpoint = elastic_checkpoint + + # Replace ._parameters with a new class to enable auto-registration of + # external parameters + _inject_parameters(module, ZeROOrderedDict) + + self.__inf_or_nan_tracker: Tensor = torch.zeros( + 1, + dtype=torch.bool, + device=torch.cuda.current_device(), + requires_grad=False) + + ###################### offload optimizer setup ################################## + self.optimizer_swapper = None + self.swap_optimizer = False + + self.offload_optimizer = False + self.offload_optimizer_pin_memory = False + self.offload_optimizer_fast_init = False + if offload_optimizer_config is not None: + if not contiguous_gradients: + raise ValueError( + "optimizer offload only available with contiguous gradients enabled") + self.offload_optimizer = True + self.offload_optimizer_pin_memory = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_PIN_MEMORY] + self.swap_optimizer = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE + self.offload_optimizer_fast_init = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_FAST_INIT] + + ###################### offload param setup ################################## + self.offload_param = False + self.offload_param_pin_memory = False + self.params_in_nvme_and_cpu = False + self.max_params_in_cpu = 0 + if offload_param_config is not None: + assert self.offload_optimizer, "parameter offload is only available with optimizer state offload" + self.offload_param = True + self.offload_param_pin_memory = offload_param_config[ + OFFLOAD_PARAM_PIN_MEMORY] + self.params_in_nvme_and_cpu = offload_param_config[ + OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE + self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU] + print_rank_0( + f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}", + force=False) + + self.deepspeed_adam_offload = (self.offload_optimizer + and type(init_optimizer) == DeepSpeedCPUAdam) + + self.device = torch.cuda.current_device( + ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE + ### streams used for overlapping computation with communication + self.__allgather_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + self.__reduce_and_partition_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + + ############################################################################ + + see_memory_usage("Before Partitioned Parameter Coordinator", force=False) + self.param_coordinator = PartitionedParameterCoordinator( + prefetch_bucket_sz=int(prefetch_bucket_size), + max_reuse_distance_in_numel=int(max_reuse_distance), + max_available_parameters_in_numel=int(max_live_parameters), + allgather_stream=self.__allgather_stream, + prefetch_nvme=self.params_in_nvme_and_cpu, + ) + see_memory_usage("After Partitioned Parameter Coordinator", force=False) + + self.__n_caching_allocator_flushes = 0 + + #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) + #-------------Stage 3 Setup-------------------# + # parameters smaller than the threshold will be collectively gathered at the + # end of the optimizer step and will be kept till the end of the backward pass + # TODO maybe worth just replicating these parameters and doing all reduce for them + self.persistence_threshold = int(param_persistence_threshold) + + self.persistent_parameters = self.persistent_parameters() + + self.setup_zero_stage3_hooks() + + #resetting ds_tensor just in case parameters have been changed after initialization + #example .half() or .to() + #self.reset_ds_tensor() + #---------------------------------------------# + + self.timers = timers + + self.dp_process_group = dp_process_group + + self.partition_count = dist.get_world_size(group=self.dp_process_group) + + if mpu is None: + self.model_parallel_group = None + self.model_parallel_rank = 0 + else: + self.model_parallel_group = mpu.get_model_parallel_group() + self.model_parallel_rank = mpu.get_model_parallel_rank() + + self.overflow = False + self.clip_grad = clip_grad + self.allreduce_always_fp32 = allreduce_always_fp32 + self.gradient_predivide_factor = gradient_predivide_factor + self.postscale_gradients = postscale_gradients + self.gradient_accumulation_steps = gradient_accumulation_steps + self.micro_step_id = 0 + + # Holds the mode parameter + # The param.data may not hold any meaningful data + # when param's status is NOT_AVAILABLE or IN_FLGHT + self.fp16_groups = [] + + # Hold partitioned parameters + self.fp16_partitioned_groups = [] + + # Holds a fused and flattened copy of the parameters + self.fp16_partitioned_groups_flat = [] + self.fp16_partitioned_groups_flat_numel = [] + + #defragmented pinned memory + self.param_groups_fp16_flat_cpu_memory = [] + + #a single 32-bit partition of the parallel partitioned parameters + #that this process will update + self.fp32_partitioned_groups_flat = [] + self.next_swappable_fp32_partitioned_groups = [] + + # number of elements per partition in each group + self.partition_size = [] + + self.all_reduce_print = False + + self.prefetch_elements = int(prefetch_bucket_size) + + # padding on each partition for alignment purposes + self.groups_padding = [] + + self.sub_group_size = sub_group_size + + self.sub_group_to_group_id = {} + see_memory_usage("Before creating fp16 partitions", force=False) + self._create_fp16_partitions_with_defragmentation() + num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) + see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", + force=False) + + # Optimizer tensor swapping + if self.swap_optimizer: + self._configure_tensor_swapping(offload_optimizer_config, aio_config) + + see_memory_usage("Before creating fp32 partitions", force=False) + self._create_fp32_partitions() + see_memory_usage("After creating fp32 partitions", force=False) + dist.barrier() + + # To support pipelined optimizer swapping + self._create_next_swappable_fp32_groups() + + see_memory_usage("Before initializing optimizer states", force=False) + self.initialize_optimizer_states() + see_memory_usage("After initializing optimizer states", force=False) + dist.barrier() + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") + + self.reduce_bucket_size = int(reduce_bucket_size) + + # IPG + if contiguous_gradients: + self.__ipg_bucket_flat_buffer: Tensor = torch.empty( + int(reduce_bucket_size), + dtype=self.dtype, + device=torch.cuda.current_device()) + + self.__param_id_to_grad_partition: Dict[int, Tensor] = {} + + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) + + grad_partitions_flat_buffer: Tensor = torch.zeros( + sum(p.ds_tensor.ds_numel for p in all_params), + dtype=self.dtype, + device=self.device, + pin_memory=self.offload_optimizer_pin_memory) + + offset = 0 + for param in all_params: + self.__param_id_to_grad_partition[ + param.ds_id] = grad_partitions_flat_buffer.narrow( + 0, + offset, + param.ds_tensor.numel()) + offset += param.ds_tensor.numel() + + self.__params_in_ipg_bucket: List[Parameter] = [] + self.is_gradient_accumulation_boundary: bool = True + + self.__param_reduce_events: Deque[Event] = collections.deque() + self.__max_param_reduce_events: int = 2 + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") + + self.param_dict = {} + + # map between param_id and bool to specify if a param is in this partition + self.is_param_in_current_partition = {} + + self.contiguous_gradients = contiguous_gradients + self.extra_large_param_to_reduce = None + self.grads_in_ipg_bucket = [] + self.params_in_ipg_bucket = [] + self.params_already_reduced = [] + self.is_gradient_accumulation_boundary = True + self._release_ipg_buffers() + self.previous_reduced_grads = None + + # simplified param id + self.param_id = {} + + count = 0 + for i, params_group in enumerate(self.fp16_groups): + for param in params_group: + unique_id = id(param) + self.param_id[unique_id] = count + self.param_dict[count] = param + self.params_already_reduced.append(False) + count = count + 1 + + #Largest partitioned param + largest_partitioned_param_numel = max([ + max([tensor.numel() for tensor in fp16_partitioned_group]) + for fp16_partitioned_group in self.fp16_partitioned_groups + ]) + print_rank_0( + f'Largest partitioned param numel = {largest_partitioned_param_numel}', + force=False) + + see_memory_usage(f"Before Set Grad positions", force=False) + + self.grad_position = {} + self.set_grad_positions() + see_memory_usage(f"Before CPU Offload initialization", force=False) + + self.grads_in_partition = None + + if self.offload_optimizer: + self.norm_for_param_grads = {} + self.local_overflow = False + + see_memory_usage(f"After CPU Offload initialization", force=False) + + # stores if a partition has been reduced in this step + self.is_partition_reduced = {} + + # stores if a grad in a partition has been computed or not + self.is_grad_computed = {} + + # will store the averaged gradients required by this paritition + self.averaged_gradients = {} + + #creates backward hooks for gradient partitioning + self.create_reduce_and_remove_grad_hooks() + + #exit(0) + + # we may have a way of fusing dynamic scale. Do not support for now + if self.dtype == torch.float or not dynamic_loss_scale: + loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale + + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(scale=loss_scale_value) + cur_iter = 0 + else: + if dynamic_loss_args is None: + self.loss_scaler = DynamicLossScaler() + else: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + + self.dynamic_loss_scale = True + + self.debug_fp16_grads = [{} for _ in self.fp16_groups] + + if dist.get_rank(group=self.dp_process_group) == 0: + see_memory_usage(f"After initializing ZeRO optimizer", force=False) + + @staticmethod + def defragment(tensors: List[Tensor]) -> Tensor: + """move provided tensors into a contiguous flat buffer, with some additional + measures taken to reduce memory fragmentation""" + assert len(set(t.dtype for t in tensors)) == 1 + assert len(set(t.device for t in tensors)) == 1 + + cpu_buffer = torch.empty(sum(p.numel() for p in tensors), + dtype=get_only_unique_item(t.dtype for t in tensors), + device="cpu") + tensor_infos: List[Tuple[Tensor, int, int]] = [] + orig_device = get_only_unique_item(t.device for t in tensors) + + offset = 0 + for tensor in tensors: + tensor_numel = tensor.numel() + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + + # record some data so we can restore the device tensor later + tensor_infos.append((tensor, offset, tensor_numel)) + + offset += tensor_numel + + gc.collect() + torch.cuda.empty_cache() + + # copy tensors (now flattened and contiguous) back to GPU + device_buffer = cpu_buffer.to(orig_device) + + # restore device tensors + for tensor, offset, tensor_numel in tensor_infos: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + + return device_buffer + + def _convert_to_zero_parameters(self, module, mpu): + non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] + if non_zero_params: + zero_params = [p for p in module.parameters() if is_zero_param(p)] + if zero_params: + zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) + else: + group = None + if mpu: + group = mpu.get_data_parallel_group() + Init(module=module, data_parallel_group=group, dtype=self.dtype) + + def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): + nvme_swap_folder = os.path.join( + offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH], + 'zero_stage_3') + os.makedirs(nvme_swap_folder, exist_ok=True) + if torch.distributed.get_rank() == 0: + logger.info(f'Tensor Swapping: Adding optimizer tensors') + + swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[ + OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper + + self.optimizer_swapper = swapper_type( + swap_config=offload_optimizer_config, + aio_config=aio_config, + base_folder=nvme_swap_folder, + optimizer=self.optimizer, + largest_numel=max(self.fp16_partitioned_groups_flat_numel), + device=self.device, + dtype=torch.float32, + timers=self.timers) + + @property + def elements_in_ipg_bucket(self): + return sum(p.ds_numel for p in self.__params_in_ipg_bucket) + + def _create_fp16_partitions(self): + dist.barrier() + partition_id = dist.get_rank(group=self.dp_process_group) + + # loop to deal with groups + for j, param_group in enumerate(self.optimizer.param_groups): + + sub_groups = self._create_fp16_sub_groups(param_group['params']) + for sub_group in sub_groups: + i = len(self.fp16_groups) + + # push this group to list before modify + self.fp16_groups.append(sub_group) + self.sub_group_to_group_id[i] = j + + #These are the list of the partitioned parameters + self.fp16_partitioned_groups.append( + [param.ds_tensor for param in self.fp16_groups[i]]) + + print_rank_0( + f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" + ) + + # Record padding required to align group to world size (only applies to last rank) + if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: + padding = [p.padding_size() for p in self.fp16_groups[i]] + else: + padding = [0] * len(self.fp16_groups[i]) + self.groups_padding.append(padding) + + #not sure why apex was cloning the weights before flattening + #removing cloning here + see_memory_usage(f"Before Flattening param group {i}", force=False) + + if not self.offload_param: + see_memory_usage(f"Before moving param group {i} to CPU", + force=False) + #move all the parameters to cpu to free up GPU space for creating flat buffer + move_to_cpu(self.fp16_partitioned_groups[i]) + see_memory_usage(f"After moving param group {i} to CPU", force=False) + + #create flat buffer in CPU and move to GPU + self.fp16_partitioned_groups_flat.append( + self.flatten_dense_tensors_aligned( + self.fp16_partitioned_groups[i], + dist.get_world_size(group=self.dp_process_group)).cuda( + torch.cuda.current_device())) + see_memory_usage( + f"After flattening and moving param group {i} to GPU", + force=False) + else: + #Without the detach, seems like the flattening becomes part of the + #model graph causing errors downstream + self.fp16_partitioned_groups_flat.append( + self.flatten_dense_tensors_aligned( + self.fp16_partitioned_groups[i], + dist.get_world_size( + group=self.dp_process_group)).detach().pin_memory()) + + see_memory_usage(f"After Flattening param group {i}", force=False) + + see_memory_usage(f"After Flattening param group {i}", force=False) + + #set model fp16 weight to slices of flattened buffer + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i], + self.fp16_partitioned_groups[i]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params): + partitioned_param.data = q.data + + def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): + '''If flat buffer is None then the parameters in the param_list are + not copied to the flat buffer. This is because they excede the number of max_params_in_cpu + Some of these parameters may aready be in CPU in unflattened buffers + or they maybe in GPU, or they maybe in NVME. If they are in NVME, then + they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are + needed during training.''' + if flat_buffer is None: + # this dst buffer is on NVMe, so skip this + return + + start = 0 + for param in param_list: + src = param.ds_tensor + dest = flat_buffer.narrow(0, start, src.ds_numel) + start = start + src.ds_numel + '''if the parameter was initialized in nvme then bring it to the destination buffer directly''' + if src.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU" + ) + param.nvme_swapper.swap_into_buffer(param, dest) + src.data = dest.data + src.status = PartitionedParamStatus.AVAILABLE + else: + assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here" + if not avoid_copy: + dest.data.copy_(src.data) + src.data = dest.data + + # Final location must be gpu/cpu in this case + param.ds_tensor.final_location = 'not-nvme' + + def _create_param_groups_fp16_flat_cpu_memory(self): + + aggregate_params_count = 0 + + for j, param_group in enumerate(self.optimizer.param_groups): + params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']]) + + flat_buffer_size = params_in_group + + if self.params_in_nvme_and_cpu and \ + aggregate_params_count + params_in_group > self.max_params_in_cpu: + + flat_buffer_size = max(0, + self.max_params_in_cpu - aggregate_params_count) + + aggregate_params_count += params_in_group + + if flat_buffer_size > 0: + print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", + force=False) + self.param_groups_fp16_flat_cpu_memory.append( + torch.empty(int(flat_buffer_size), + dtype=self.dtype, + pin_memory=True)) + else: + print_rank_0( + f"No flat buffer size. Param group size was {params_in_group}", + force=False) + + self.param_groups_fp16_flat_cpu_memory.append( + torch.empty(1, + dtype=self.dtype)) + + def _create_fp16_partitions_with_defragmentation(self): + dist.barrier() + param_groups: List[List[Parameter]] = tuple( + self._create_fp16_sub_groups(param_group["params"]) + for param_group in self.optimizer.param_groups) + + # bookkeeping related to param groups + for param_group_idx, param_group in enumerate(param_groups): + for sub_group in param_group: + sub_group_idx = len(self.fp16_groups) + + # record sub group and partitions + self.fp16_groups.append(sub_group) + self.fp16_partitioned_groups.append( + [param.ds_tensor for param in sub_group]) + + # record sub group -> group mapping + self.sub_group_to_group_id[sub_group_idx] = param_group_idx + + # record total elements of parameter partitions in sub group + self.fp16_partitioned_groups_flat_numel.append( + sum(p.ds_tensor.ds_numel for p in sub_group)) + + # record padding required to align group to world size (only applies to last rank) + rank_requires_padding = dist.get_rank( + self.dp_process_group) == dist.get_world_size( + self.dp_process_group) - 1 + self.groups_padding.append([ + p.padding_size() if rank_requires_padding else 0 for p in sub_group + ]) + + # move parameters to flattened buffer + if not self.offload_param: # partitioned params remain in GPU during training + # move parameter partitions into a single contiguous flat buffer + parameter_partitions: List[Tensor] = [] + for sub_group in self.fp16_groups: + for param in sub_group: + parameter_partitions.append(param.ds_tensor) + device_buffer = __class__.defragment(parameter_partitions) + + # setup flat buffers per subgroup, these are each just sections of the + # contiguous flat buffer for all parameters that we created earlier + offset = 0 + for sub_group in self.fp16_groups: + sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group) + self.fp16_partitioned_groups_flat.append( + device_buffer.narrow(0, + offset, + sub_group_numel)) + offset += sub_group_numel + else: # partitioned params offloaded to CPU when not in use + # create a flat CPU memory allocation for each param group + self._create_param_groups_fp16_flat_cpu_memory() + for param_group_idx, param_group in enumerate(param_groups): + flat_offset = 0 + for i, sub_group in enumerate(param_group): + total_elements = sum(p.ds_tensor.ds_numel for p in sub_group) + print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") + #Flat buffer may not be available for parameters that reside in NVME + if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ + param_group_idx].numel(): + fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ + param_group_idx].narrow(0, + flat_offset, + total_elements) + print_rank_0( + f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", + force=False) + elif self.params_in_nvme_and_cpu: + fp16_partitioned_group_flat = None + print_rank_0( + f"No flat buffer for sub group {i} of {total_elements} elements", + force=False) + else: + assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs" + + self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) + flat_offset += total_elements + + self._move_to_flat_buffer(sub_group, + fp16_partitioned_group_flat, + avoid_copy=not self.offload_param) + + # if necessary, create a pinned memory buffer to be used for swapping out + # params to NVME after optimizer step + should_create_fp16_flat_reuse_buffer = any( + flattened_partition_group is None + for flattened_partition_group in self.fp16_partitioned_groups_flat) + if should_create_fp16_flat_reuse_buffer: + max_partition_numel, largest_partition_numel = 0, None + for sub_group in self.fp16_groups: + total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) + if total_elements > max_partition_numel: + largest_partition_numel = [t.ds_numel for t in sub_group] + max_partition_numel = total_elements + + assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' + self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( + largest_partition_numel) + + def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): + offset = 0 + elements_in_sub_group = sum( + [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) + assert (flat_buffer.numel() == elements_in_sub_group) + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): + dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}" + ) + param.nvme_swapper.swap_in([param], async_op=False) + dest.data.copy_(partitioned_param.data) + param.nvme_swapper.remove_partition_and_release_buffers([param]) + print_rank_0(f"Swapping in {param.ds_id} done") + else: + dest.data.copy_(partitioned_param.data) + offset += partitioned_param.ds_numel + + def _create_next_swappable_fp32_groups(self): + reverse_order_indices = [ + i for i in range(len(self.fp32_partitioned_groups_flat)) + ] + reverse_order_indices.reverse() + + next_group = None + for i in reverse_order_indices: + self.next_swappable_fp32_partitioned_groups.append(next_group) + if self._swappable_optimizer_subgroup(i): + next_group = self.fp32_partitioned_groups_flat[i] + + self.next_swappable_fp32_partitioned_groups.reverse() + + def _get_sub_group_partitions(self, sub_group_id): + sub_group_partitions = [] + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_path = param.nvme_swapper.get_path(param, True) + sub_group_partitions.append((partitioned_param, + param.ds_tensor.ds_numel, + swap_path)) + else: + sub_group_partitions.append((partitioned_param, + partitioned_param.ds_numel, + None)) + + return sub_group_partitions + + def _create_fp32_partitions(self): + cpu_memory_usage = 0 + cpu_memory_sub_groups = 0 + nvme_memory_usage = 0 + num_swappable_partitions = 0 + num_swap_from_nvme_partitions = 0 + num_swap_from_cpu_partitions = 0 + swap_from_nvme_memory_usage = 0 + swap_from_cpu_memory_usage = 0 + GIGA_BYTES = (1024**3) + + swappable_fp32_tensors = [] + swappable_fp16_src_tensors = [] + nvme_fp16_partitions_info = [] + nvme_fp16_num_elems = [] + nvme_fp32_dest_tensors = [] + fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() + + for i, tensor in enumerate(self.fp16_partitioned_groups_flat): + num_elements = self.fp16_partitioned_groups_flat_numel[i] + + # a partition of the fp32 master weights that will be updated by this process + if self._swappable_optimizer_subgroup(i): + self.fp32_partitioned_groups_flat.append(torch.Tensor()) + nvme_memory_usage += (fp32_element_size * num_elements) + num_swappable_partitions += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + num_swap_from_nvme_partitions += 1 + swap_from_nvme_memory_usage += (fp32_element_size * num_elements) + if self.offload_optimizer_fast_init: + sub_group_partitions = self._get_sub_group_partitions(i) + nvme_fp16_partitions_info.append(sub_group_partitions) + nvme_fp16_num_elems.append(num_elements) + nvme_fp32_dest_tensors.append( + self.fp32_partitioned_groups_flat[i]) + else: + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.optimizer_swapper.initialize_parameters( + parameters=[self.fp32_partitioned_groups_flat[i]], + src_tensors=[unpinned_fp32_buffer]) + else: + num_swap_from_cpu_partitions += 1 + swap_from_cpu_memory_usage += (fp32_element_size * num_elements) + swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i]) + swappable_fp16_src_tensors.append( + self.fp16_partitioned_groups_flat[i]) + else: + cpu_memory_usage += (fp32_element_size * num_elements) + cpu_memory_sub_groups += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) + else: + self.fp32_partitioned_groups_flat.append( + self.fp16_partitioned_groups_flat[i].to( + self.device).clone().float().detach()) + + self.fp32_partitioned_groups_flat[ + i].requires_grad = True # keep this in case internal optimizer uses it + + if len(swappable_fp32_tensors) > 0: + self.optimizer_swapper.initialize_parameters( + parameters=swappable_fp32_tensors, + src_tensors=swappable_fp16_src_tensors) + + if len(nvme_fp32_dest_tensors) > 0: + fp16_pinned_buffers = self.fp16_groups[0][ + 0].nvme_swapper.reserve_available_buffers() + assert len(fp16_pinned_buffers) > 0 + self.optimizer_swapper.initialize_from_swapped_fp16_params( + fp16_partitions_info=nvme_fp16_partitions_info, + fp16_num_elems=nvme_fp16_num_elems, + fp16_pinned_buffers=fp16_pinned_buffers, + fp32_parameters=nvme_fp32_dest_tensors) + self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers() + + nvme_gigabytes = nvme_memory_usage / GIGA_BYTES + print_rank_0( + f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB', + force=False) + if self.params_in_nvme_and_cpu: + print_rank_0( + f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + print_rank_0( + f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + + cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES + print_rank_0( + f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB', + force=False) + + # Clear for on-the-fly population before the optimizer step + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _create_fp16_sub_groups(self, params_group): + + params_group_numel = sum([param.partitioned_size() for param in params_group]) + sub_group_size = self.sub_group_size + + if sub_group_size is None or sub_group_size >= params_group_numel: + return [params_group] + + sub_groups = [] + sub_group = [] + local_sub_group_size = 0 + for param in params_group: + + sub_group.append(param) + local_sub_group_size += param.partitioned_size() + + if local_sub_group_size >= sub_group_size or id(param) == id( + params_group[-1]): + + sub_groups.append(sub_group) + + sub_group = [] + local_sub_group_size = 0 + + return sub_groups + + # def reset_ds_tensor(self): + # for name, param in self.module.named_parameters(recurse=True): + # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible" + # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now" + # param.ds_tensor.data = param.data + + def setup_zero_stage3_hooks(self): + self.hierarchy = 0 + + #reset step if in inference mode + @instrument_w_nvtx + def _end_of_forward_hook(module, *args): + + if not torch._C.is_grad_enabled(): + self.param_coordinator.reset_step() + + #likely one of them should be enough but just to be safe + self._register_hooks_recursively(self.module) + self.module.register_forward_hook(_end_of_forward_hook) + + # Add top module to stack trace + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(self.module) + + def persistent_parameters(self): + persistent_params = [] + total_persistent_parameters = 0 + params_count = 0 + for _, param in self.module.named_parameters(recurse=True): + if param.ds_numel < self.persistence_threshold: + params_count += 1 + param.ds_persist = True + persistent_params.append(param) + total_persistent_parameters += param.ds_numel + + print_rank_0( + f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params", + force=False) + return persistent_params + + def _register_hooks_recursively(self, module, count=[0]): + my_count = count[0] + module.id = my_count + + #print(f"{module.__class__} : {module.id}") + + for child in module.children(): + count[0] = count[0] + 1 + self._register_hooks_recursively(child, count=count) + + @instrument_w_nvtx + def _pre_forward_module_hook(module, *args): + self.pre_sub_module_forward_function(module) + + @instrument_w_nvtx + def _post_forward_module_hook(module, input, output): + global FWD_MODULE_STACK + FWD_MODULE_STACK.pop() + if output is None: + output = [] + elif not isinstance(output, (list, tuple)): + if torch.is_tensor(output): + output = [output] + else: + #print(f'got UNKNOWN type {type(output)}') + outputs = [] + output = output if isinstance(output, dict) else vars(output) + for name, val in output.items(): + if not name.startswith('__') and torch.is_tensor(val): + outputs.append(val) + output = outputs + #print(f'convert output to {output}') + + for item in filter(lambda item: is_zero_param(item), output): + if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): + item.is_external_param = True + module_to_register = FWD_MODULE_STACK[-1] + print_rank_0( + f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', + force=False) + register_external_parameter(module_to_register, item) + + # It's possible that the parameter was already external to the completed module. If so, remove it the + # registration as it will be covered by the outer module instead. + if id(item) in module._external_params: + print_rank_0( + f' Unregistering nested dangling parameter from module {module.__class__.__name__}', + force=False) + unregister_external_parameter(module, item) + + item.all_gather() + + self.post_sub_module_forward_function(module) + + def _pre_backward_module_hook(module, inputs, output): + @instrument_w_nvtx + def _run_before_backward_function(sub_module): + # some models (e.g. Albert) may run multiple forwards on the same layer in a loop + # before doing backwards, so each backward will need a pre-fetch - using reference + # counting to support this scenario + #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") + if sub_module.applied_pre_backward_ref_cnt > 0: + self.pre_sub_module_backward_function(sub_module) + sub_module.applied_pre_backward_ref_cnt -= 1 + #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + + return _apply_to_tensors_only(module, + PreBackwardFunction, + _run_before_backward_function, + output) + + #This is an alternate to doing _post_backward_module_hook + #it uses tensor.register_hook instead of using torch.autograd.Function + def _alternate_post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + #print(f"Before Forward {module.__class__.__name__}") + + def _run_after_backward_hook(*unused): + module.ds_grads_remaining = module.ds_grads_remaining - 1 + if module.ds_grads_remaining == 0: + #print(f"After backward {module.__class__.__name__}") + self.post_sub_module_backward_function(module) + + def _run_before_forward_function(input): + if input.requires_grad: + module.ds_grads_remaining += 1 + + return _apply_forward_and_backward_to_tensors_only( + module, + _run_before_forward_function, + _run_after_backward_hook, + inputs) + + def _post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + @instrument_w_nvtx + def _run_after_backward_function(sub_module): + if sub_module.ds_grads_remaining == 0: + self.post_sub_module_backward_function(sub_module) + + return _apply_to_tensors_only(module, + PostBackwardFunction, + _run_after_backward_function, + inputs) + + # Pre forward hook + module.register_forward_pre_hook(_pre_forward_module_hook) + # Post forward hook + module.register_forward_hook(_post_forward_module_hook) + + # Pre backward hook + module.register_forward_hook(_pre_backward_module_hook) + + # post backward hook + module.register_forward_pre_hook(_post_backward_module_hook) + + @torch.no_grad() + def pre_sub_module_forward_function(self, sub_module): + see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", + force=False) + + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(sub_module) + + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) + + self.param_coordinator.fetch_sub_module(sub_module) + see_memory_usage( + f"Before sub module function {sub_module.__class__.__name__} after fetch", + force=False) + + @torch.no_grad() + def post_sub_module_forward_function(self, sub_module): + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + + self.param_coordinator.release_sub_module(sub_module) + + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) + + @torch.no_grad() + def pre_sub_module_backward_function(self, sub_module): + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) + self.param_coordinator.fetch_sub_module(sub_module) + + @torch.no_grad() + def post_sub_module_backward_function(self, sub_module): + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + self.param_coordinator.release_sub_module(sub_module) + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) + + def _release_ipg_buffers(self): + if self.contiguous_gradients: + self.ipg_buffer = None + if not self.offload_optimizer and self.is_gradient_accumulation_boundary: + self.grads_in_partition = None + + self.grads_in_partition_offset = 0 + + def _optimizer_step(self, sub_group_id): + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + + self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] + + def _swappable_optimizer_subgroup(self, sub_group_id): + if not self.swap_optimizer: + return False + + return self.optimizer_swapper.swappable_tensor( + None, + numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) + + def _partitioned_params_swap_out(self, i): + offset = 0 + fp32_param = self.fp32_partitioned_groups_flat[i] + assert fp32_param is not None, \ + f'fp32 parameters of sub_group {i} is None' + + swap_fp16_params = [] + swap_fp32_params = [] + for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]): + src = fp32_param.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.AVAILABLE: + partitioned_param.data.copy_(src.data) + else: + swap_fp32_params.append(src) + swap_fp16_params.append(param) + offset += partitioned_param.ds_numel + + if len(swap_fp16_params): + swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params( + dst_fp16_params=swap_fp16_params, + src_fp32_params=swap_fp32_params) + + def initialize_optimizer_states(self): + num_subgroups = len(self.fp16_groups) + + largest_numel = max( + [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups]) + gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype + gradient_buffer = torch.zeros(int(largest_numel), + dtype=gradient_dtype, + device=self.device) + + timers = self.timers + timer_names = set() + + if self.swap_optimizer: + self.optimizer_swapper.init_timers() + + INIT_OPTIMIZER_TIMER = 'init_optimizer_state' + timer_names.add(INIT_OPTIMIZER_TIMER) + self.start_timers([INIT_OPTIMIZER_TIMER]) + + for i, group in enumerate(self.fp16_groups): + swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) + swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None + + num_elements = int(self.fp16_partitioned_groups_flat_numel[i]) + + see_memory_usage( + f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_in(i, timer_names) + + if self.offload_optimizer and not swappable_optimizer_subgroup: + subgroup_gradient_buffer = torch.zeros(num_elements, + dtype=gradient_dtype, + device=self.device) + if self.offload_optimizer_pin_memory: + subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory() + + self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer + else: + self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow( + 0, + 0, + num_elements) + + self._optimizer_step(i) + + if swappable_param_subgroup: + self._partitioned_params_swap_out(i) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_out(i, timer_names) + + see_memory_usage( + f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + self.stop_timers([INIT_OPTIMIZER_TIMER]) + self.log_timers(timer_names) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + if not self.offload_optimizer: + for group in self.fp32_partitioned_groups_flat: + group.grad = None + + # Reset steps + return + + ######################################################################### + #########################ZeRO Partition Gradients######################## + ######################################################################### + + def get_first_param_index(self, group_id, param_group, partition_id): + for index, param in enumerate(param_group): + param_id = self.get_param_id(param) + if partition_id in self.param_to_partition_ids[group_id][param_id]: + return index + return None + + def initialize_gradient_partitioning_data_structures(self): + + total_partitions = dist.get_world_size(group=self.dp_process_group) + + for i, param_group in enumerate(self.fp16_groups): + + self.param_to_partition_ids[i] = {} + self.is_partition_reduced[i] = {} + self.total_grads_in_partition[i] = {} + self.remaining_grads_in_partition[i] = {} + self.is_grad_computed[i] = {} + self.grad_partition_insertion_offset[i] = {} + self.grad_start_offset[i] = {} + self.first_param_index_in_partition[i] = {} + + for partition_id in range(total_partitions): + self.is_grad_computed[i][partition_id] = {} + self.grad_partition_insertion_offset[i][partition_id] = {} + self.grad_start_offset[i][partition_id] = {} + self.initialize_gradient_partition(i, param_group, partition_id) + self.is_partition_reduced[i][partition_id] = False + self.first_param_index_in_partition[i][ + partition_id] = self.get_first_param_index( + i, + param_group, + partition_id) + + @instrument_w_nvtx + def independent_gradient_partition_epilogue(self): + self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) + self.__reduce_and_partition_ipg_grads() + self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) + + self.__reduce_and_partition_stream.synchronize() + + # if dist.get_rank() == 0: + # logger.info("Params already reduced %s", self.params_already_reduced) + for i in range(len(self.params_already_reduced)): + self.params_already_reduced[i] = False + + #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad + #TODO: use a similar code path for both cpu_offload and non-cpu offload + if not self.offload_optimizer: + for i, sub_group in enumerate(self.fp16_groups): + self.averaged_gradients[i] = [ + self.__param_id_to_grad_partition[param.ds_id] + if param.requires_grad else torch.zeros_like(param.ds_tensor) + for param in sub_group + ] + # self.averaged_gradients[i] = self.get_flat_partition( + # self.fp16_groups[i], + # 0, + # self.fp32_partitioned_groups_flat[i].numel(), + # return_tensor_list=True) + + # this method gets called after every backward. need to increment + # here because if it gets incremented in backward() the micro step + # id will be off by one when we do the reduce and partition at the. + # start of this method. + # TODO. make this less error prone + self.micro_step_id += 1 + + def overlapping_partition_gradients_reduce_epilogue(self): + self.independent_gradient_partition_epilogue() + + def create_reduce_and_remove_grad_hooks(self): + print_rank_0(f'[Begin] Create gradient reduction hooks') + self.grad_accs = [] + for i, param_group in enumerate(self.fp16_groups): + for param in param_group: + if param.requires_grad: + #print_rank_0(f" Before all gather {param.device}, {param.shape}") + + # The hook must be created in un-partitioned parameter + param.all_gather() + + #print(f"After all gather {param.device}, {param.shape}") + def wrapper(param, i): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + + @instrument_w_nvtx + def reduce_partition_and_remove_grads(*notneeded): + self.reduce_ready_partitions_and_remove_grads(param, i) + + grad_acc.register_hook(reduce_partition_and_remove_grads) + self.grad_accs.append(grad_acc) + + #print(f"param grad fn {param.expand_as(param).grad_fn}") + wrapper(param, i) + + # Partition the parameter after creating the hook + param.partition() + print_rank_0(f'[End] Create gradient reduction hooks') + + def get_param_id(self, param): + unique_id = id(param) + return self.param_id[unique_id] + + def report_ipg_memory_usage(self, tag, param_elems): + elem_count = self.elements_in_ipg_bucket + param_elems + percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size + see_memory_usage( + f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", + force=False) + + ###############Idependent Partition Gradient ######################## + def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): + #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) + + # Because the ipg bucket is initialized with a random place holder tensor, we must + # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket > + # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a + # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be + # empty, while reduction_list will have that garbage data. + if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: + self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", + param.ds_numel) + + self.__reduce_and_partition_ipg_grads() + + param_id = self.get_param_id(param) + assert self.params_already_reduced[param_id] == False, \ + f"The parameter {param_id} has already been reduced. \ + Gradient computed twice for this partition. \ + Multiple gradient reduction is currently not supported" + + self.__add_grad_to_ipg_bucket(param) + + @instrument_w_nvtx + @torch.no_grad() + def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: + self.__reduce_and_partition_stream.wait_stream(torch.cuda.default_stream()) + + if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel( + ) < self.reduce_bucket_size: + # move the gradient to a contiguous buffer + with torch.cuda.stream(self.__reduce_and_partition_stream): + # move the parameter's gradient to the contiguous flat buffer + new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow( + 0, + self.elements_in_ipg_bucket, + param.grad.numel()).view_as(param.grad) + new_grad_tensor.copy_(param.grad, non_blocking=True) + param.grad.record_stream(torch.cuda.current_stream()) + param.grad.data = new_grad_tensor + + self.__params_in_ipg_bucket.append(param) + + @instrument_w_nvtx + @torch.no_grad() + def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: + if not self.__params_in_ipg_bucket: + return + + for param in self.__params_in_ipg_bucket: + if param.grad.numel() != param.ds_numel: + raise RuntimeError( + f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " + f"gradients whose size is not same as the params") + + self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id) + + assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len( + self.__params_in_ipg_bucket) + + while self.__param_reduce_events and self.__param_reduce_events[0].query(): + self.__param_reduce_events.popleft() + if len(self.__param_reduce_events) > self.__max_param_reduce_events: + self.__param_reduce_events.popleft().synchronize() + + with torch.cuda.stream(self.__reduce_and_partition_stream): + if safe_mode: + assert_ints_same_as_other_ranks( + [p.ds_id for p in self.__params_in_ipg_bucket]) + + grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) + self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) + + self.__params_in_ipg_bucket.clear() + + event = Event() + event.record() + self.__param_reduce_events.append(event) + + @instrument_w_nvtx + def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: + """average gradients and scatter partitions across ranks""" + dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce) + + full_grads_for_rank = [p.grad for p in params_to_reduce] + if self.allreduce_always_fp32: + full_grads_for_rank = [g.float() for g in full_grads_for_rank] + + if self.postscale_gradients and self.gradient_predivide_factor != 1.0: + full_grads_for_rank = [ + g.div(self.gradient_predivide_factor) for g in full_grads_for_rank + ] + + grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, + self.dp_process_group) + + if self.postscale_gradients and self.gradient_predivide_factor != dist.get_world_size( + self.dp_process_group): + grad_partitions_for_rank = [ + g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank + ] + + if self.allreduce_always_fp32: + grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank] + + return grad_partitions_for_rank + + def set_grad_positions(self): + for i, group in enumerate(self.fp16_groups): + current_offset = 0 + for param in group: + param_id = self.get_param_id(param) + num_elements = param.ds_tensor.ds_numel + + self.grad_position[param_id] = [ + int(i), + int(current_offset), + int(num_elements) + ] + #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") + current_offset += num_elements + + def _constant_buffered_norm2(self, input, buffer_size=250000000): + norm = None + for part in input.view(-1).split(buffer_size): + if norm is None: + norm = part.data.double().norm(2)**2.0 + else: + norm += part.data.double().norm(2)**2.0 + return norm**0.5 + + def set_norm_for_param_grad_in_gpu(self, param): + param_id = self.get_param_id(param) + #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2) + #Using a more memory efficient version + self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) + + def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): + with torch.cuda.stream(self.copy_grad_stream): + param_id = self.get_param_id(param) + src_tensor = param.grad.view(-1).float() + #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") + fp32_grad_tensor.copy_(src_tensor, non_blocking=True) + param.grad = None + + def complete_grad_norm_calculation_for_cpu_offload(self, params): + total_norm = 0.0 + norm_type = 2.0 + for p in params: + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_id = self.get_param_id(p) + if param_id in self.norm_for_param_grads.keys(): + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm.item()**2 + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + @instrument_w_nvtx + def __partition_grads(self, + params_to_release: List[Parameter], + grad_partitions: List[Tensor]) -> None: + for param, grad_partition in zip(params_to_release, grad_partitions): + if param.ds_tensor.ds_numel * dist.get_rank( + self.dp_process_group) > param.ds_numel: + # this grad partition is empty - don't need to do anything + continue + + # move or accumulate gradient partition to target buffer + grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow( + 0, + 0, + grad_partition.numel()) + if self.micro_step_id == 0: # don't accumulate + grad_buffer.copy_(grad_partition, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + elif grad_buffer.is_cuda: + grad_buffer.add_(grad_partition) + else: + # if dst is CPU, copy first to src device, do the addition + # there, then move back to dst. adding directly to cpu is very slow + cuda_grad_buffer = grad_buffer.to(grad_partition.device, + non_blocking=True) + cuda_grad_buffer.add_(grad_partition) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = cuda_grad_buffer + + if hasattr(self.__inf_or_nan_tracker, "logical_or_"): + self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) + self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) + else: + # logical_or_ not available in older versions of pytorch + self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any() + self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any() + self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0 + + # offload the gradient partition if applicable + if self.offload_optimizer: + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] + offload_fp32_gradients = {} + offload_fp32_offsets = {} + + if self.is_gradient_accumulation_boundary: + self.norm_for_param_grads[self.get_param_id( + param)] = self._constant_buffered_norm2(grad_buffer) + + if self._swappable_optimizer_subgroup(i): + if not i in offload_fp32_gradients.keys(): + offload_fp32_gradients[i] = [] + offload_fp32_offsets[i] = [] + + offload_fp32_gradients[i].append(grad_buffer.float()) + offload_fp32_offsets[i].append(dest_offset) + else: + fp32_grad_tensor = self.fp32_partitioned_groups_flat[ + i].grad.narrow(0, + dest_offset, + grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer) + + # free the gradient + param.grad.record_stream(torch.cuda.current_stream()) + param.grad = None + + if self.offload_optimizer and self.swap_optimizer: + for i in offload_fp32_gradients.keys(): + self.optimizer_swapper.swap_out_gradients( + parameter=self.fp32_partitioned_groups_flat[i], + gradient_offsets=offload_fp32_offsets[i], + gradient_tensors=offload_fp32_gradients[i]) + + def reduce_ready_partitions_and_remove_grads(self, param, i): + #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) + self.reduce_independent_p_g_buckets_and_remove_grads(param, i) + + def zero_reduced_gradients(self, partition_id, i): + def are_all_related_partitions_reduced(params_id): + for partition_id in self.param_to_partition_ids[i][params_id]: + if not self.is_partition_reduced[i][partition_id]: + return False + return True + + for params_id in self.is_grad_computed[i][partition_id]: + if are_all_related_partitions_reduced(params_id): + self.param_dict[params_id].grad = None + + def flatten_and_print(self, message, tensors, start=0, n=5): + flatten_tensor = self.flatten(tensors) + + def print_func(): + logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) + + self.sequential_execution(print_func, message) + + def get_grads_to_reduce(self, i, partition_id): + def get_reducible_portion(key): + grad = self.param_dict[key].grad + total_elements = grad.numel() + start = self.grad_start_offset[i][partition_id][key] + num_elements = min( + total_elements - start, + self.partition_size[i] - + self.grad_partition_insertion_offset[i][partition_id][key]) + if not pg_correctness_test: + if num_elements == total_elements: + return grad + else: + return grad.contiguous().view(-1).narrow(0, + int(start), + int(num_elements)) + else: + if num_elements == total_elements: + return grad.clone() + else: + return grad.clone().contiguous().view(-1).narrow( + 0, + int(start), + int(num_elements)) + + grads_to_reduce = [] + for key in self.is_grad_computed[i][partition_id]: + grad = get_reducible_portion(key) + grads_to_reduce.append(grad) + return grads_to_reduce + + def sequential_execution(self, function, message, group=None): + if group is None: + group = self.dp_process_group + if dist.get_rank(group=group) == 0: + logger.info(message) + for id in range(dist.get_world_size(group=group)): + if id == dist.get_rank(group=group): + function() + dist.barrier(group=group) + + def set_none_gradients_to_zero(self, i, partition_id): + for param_id in self.is_grad_computed[i][partition_id]: + param = self.param_dict[param_id] + if param.grad is None: + param.grad = torch.zero_like(param) + + ######################Reduction Related Methods############################## + + def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None): + rank = None + tensor = self.flatten(bucket) + + tensor_to_allreduce = tensor + + if pg_correctness_test: + allreduce_always_fp32 = True + + if allreduce_always_fp32: + tensor_to_allreduce = tensor.float() + + tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) + + if rank is None: + # "All Reducing" + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + else: + global_rank = _get_global_rank(self.dp_process_group, rank) + dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) + + if allreduce_always_fp32 and tensor is not tensor_to_allreduce: + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + tensor.copy_(tensor_to_allreduce) + + return tensor + + # if rank is specified do a reduction instead of an allreduce + def allreduce_and_copy(self, small_bucket, rank=None, log=None): + with torch.cuda.stream(self.reduction_stream): + allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): + buf.copy_(synced) + + def allreduce_no_retain(self, + bucket, + numel_per_bucket=500000000, + rank=None, + log=None): + small_bucket = [] + numel = 0 + for tensor in bucket: + small_bucket.append(tensor) + numel = numel + tensor.numel() + if numel > numel_per_bucket: + self.allreduce_and_copy(small_bucket, rank=rank, log=None) + small_bucket = [] + if len(small_bucket) > 0: + self.allreduce_and_copy(small_bucket, rank=rank, log=log) + + ############################################################################# + ############################################################################# + ############################################################################# + + # views the tensor as multiple partitions and returns + # those partitions + def get_data_parallel_partitions(self, tensor): + partitions = [] + + dp = dist.get_world_size(group=self.dp_process_group) + dp_id = dist.get_rank(group=self.dp_process_group) + + total_num_elements = tensor.numel() + + base_size = total_num_elements // dp + remaining = total_num_elements % dp + + start = 0 + for id in range(dp): + partition_size = base_size + if id < remaining: + partition_size = partition_size + 1 + partitions.append(tensor.narrow(0, start, partition_size)) + start = start + partition_size + return partitions + + def get_partition_info(self, tensor_list, partition_size, partition_id): + params_in_partition = [] + params_not_in_partition = [] + + start_index = partition_size * partition_id + end_index = partition_size * (partition_id + 1) + + current_index = 0 + first_offset = 0 + + for tensor in tensor_list: + + tensor_size = tensor.numel() + + if (current_index >= start_index and current_index < end_index): + params_in_partition.append(tensor) + + elif start_index > current_index and start_index < (current_index + + tensor_size): + params_in_partition.append(tensor) + + assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset = start_index - current_index + + else: + params_not_in_partition.append(tensor) + + current_index = current_index + tensor_size + + return params_in_partition, params_not_in_partition, first_offset + + @instrument_w_nvtx + def zero_grad(self, set_grads_to_None=True): + """ + Zero FP16 parameter grads. + """ + self.micro_step_id = 0 + + # FP32 grad should never exist. + # For speed, set model fp16 grad to None by default + for group in self.fp16_groups: + for p in group: + if set_grads_to_None: + if p.grad is not None and p.grad.is_cuda: + p.grad.record_stream(torch.cuda.current_stream()) + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + def _model_parallel_all_reduce(self, tensor, op): + """ Perform all reduce within model parallel group, if any. + """ + if self.model_parallel_group is None: + pass + else: + torch.distributed.all_reduce(tensor=tensor, + op=op, + group=self.model_parallel_group) + + @instrument_w_nvtx + def get_grad_norm_direct(self, gradients, params, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(g.data.abs().max() for g in gradients) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group) + + # Take max across all GPUs. + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.MAX) + total_norm = total_norm_cuda[0].item() + else: + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + grad_norms = [] + for g, p in zip(gradients, params): + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + grad_norms.append(g.cuda(non_blocking=True).double().norm(2)) + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) + + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda.item()**(1. / norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + # creates a flat fused tensor from the tensor list starting at the first_offset + # in the first tensor of the list. If there are not enough elements in the tensor + # list then the flat tensor will be padded with zeros + def get_flat_partition(self, + tensor_list, + first_offset, + partition_size, + return_tensor_list=False): + flat_tensor_list = [] + current_size = 0 + for i, tensor in enumerate(tensor_list): + if tensor.grad is None: + tensor.grad = torch.zeros_like(tensor) + + tensor = tensor.grad + num_elements = tensor.numel() + tensor_offset = 0 + + # we need to offset to get to the right element + if i == 0 and first_offset > 0: + tensor_offset = first_offset + num_elements = num_elements - tensor_offset + + # we dont need all elements of the tensor + if num_elements > (partition_size - current_size): + num_elements = partition_size - current_size + + # we need a narrow view of the tensor based on the tensor offset and number of elements that + # we need from this tensor + if tensor_offset > 0 or num_elements < tensor.numel(): + flat_tensor_list.append(tensor.contiguous().view(-1).narrow( + 0, + int(tensor_offset), + int(num_elements))) + else: + flat_tensor_list.append(tensor) + + current_size = current_size + num_elements + + # this means its the last partition and does not align with the dp boundary. We need to pad before flattening + if current_size < partition_size: + flat_tensor_list.append( + torch.zeros(int(partition_size - current_size), + dtype=tensor_list[0].dtype, + device=tensor_list[0].device)) + + if return_tensor_list: + return flat_tensor_list + + return self.flatten(flat_tensor_list) + + def free_grad_in_param_list(self, param_list): + for p in param_list: + p.grad = None + + def reset_cpu_buffers(self): + self.norm_for_param_grads = {} + self.local_overflow = False + + def log_timers(self, timer_names): + if self.timers is None: + return + + self.timers.log(names=list(timer_names)) + + def start_timers(self, timer_names): + if self.timers is None: + return + + for name in timer_names: + self.timers(name).start() + + def stop_timers(self, timer_names): + if self.timers is None: + return + + for name in timer_names: + self.timers(name).stop() + + def _pre_step(self): + self.micro_step_id = 0 + + print_rank_0(f"Inside Step function") + see_memory_usage(f"In step before checking overflow", force=False) + + print_rank_0("Finished Tracing at Beginning of Step") + self.param_coordinator.hierarchy = 0 + + print_rank_0("Finished Tracing at Beginning of Step") + + @instrument_w_nvtx + def _get_norm_groups(self): + norm_groups = [] + for i, group in enumerate(self.fp16_groups): + if self.offload_optimizer: + norm_groups.append( + self.complete_grad_norm_calculation_for_cpu_offload( + self.fp16_groups[i])) + else: + norm_groups.append( + self.get_grad_norm_direct(self.averaged_gradients[i], + self.fp16_groups[i])) + return norm_groups + + @instrument_w_nvtx + def _prepare_fp32_grad_for_sub_group(self, sub_group_id): + partition_id = dist.get_rank(group=self.dp_process_group) + + single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( + self.fp32_partitioned_groups_flat[sub_group_id].dtype) + + assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ + "averaged gradients have different number of elements that partition size {} {} {} {}".format( + single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) + + self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition + + # release all the gradient since we have already created a necessary copy in dp_grad_partition + self.zero_grad() + + for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): + grad.record_stream(torch.cuda.current_stream()) + + self.averaged_gradients[sub_group_id] = None + + @instrument_w_nvtx + def _prepare_sub_group(self, sub_group_id, timer_names=set()): + see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', + force=False) + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) + elif not self.offload_optimizer: + self._prepare_fp32_grad_for_sub_group(sub_group_id) + see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', + force=False) + + def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state' + see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', + force=False) + self.start_timers([OPTIMIZER_SWAP_IN_STATE]) + + self.optimizer_swapper.swap_in_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) + + self.stop_timers([OPTIMIZER_SWAP_IN_STATE]) + timer_names.add(OPTIMIZER_SWAP_IN_STATE) + see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', + force=False) + + @instrument_w_nvtx + def _release_sub_group(self, sub_group_id, timer_names=set()): + see_memory_usage(f'Before release optimizer sub group {sub_group_id}', + force=False) + # get rid of the fp32 gradients. Not needed anymore + if not self.offload_optimizer: + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names) + see_memory_usage(f'After release optimizer sub group {sub_group_id}', + force=False) + + # create a flat tensor aligned at the alignment boundary + @instrument_w_nvtx + def flatten_dense_tensors_aligned(self, tensor_list, alignment): + num_elements = 0 + for tens in tensor_list: + num_elements = num_elements + tens.numel() + + remaining = num_elements % alignment + + if remaining: + elements_to_add = alignment - remaining + pad_tensor = torch.zeros(elements_to_add, + device=tensor_list[0].device, + dtype=tensor_list[0].dtype) + padded_tensor_list = tensor_list + [pad_tensor] + + num_elements = num_elements + elements_to_add + else: + padded_tensor_list = tensor_list + + return self.flatten(padded_tensor_list) + + def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state' + see_memory_usage( + f'post-step Before swapping out optimizer tensors {sub_group_id}', + force=False) + self.start_timers([OPTIMIZER_SWAP_OUT_STATE]) + + self.optimizer_swapper.swap_out_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is + not None) + + self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) + see_memory_usage( + f'post-step After swapping out optimizer tensors {sub_group_id}', + force=False) + timer_names.add(OPTIMIZER_SWAP_OUT_STATE) + + # get rid of the fp32 gradients. Not needed anymore + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + def _unflatten_partitioned_parameters(self, sub_group_id): + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + def _overflow_clean_up(self, prev_scale): + see_memory_usage('After overflow before clearing gradients', force=False) + self.zero_grad() + + if self.offload_optimizer: + self.reset_cpu_buffers() + else: + self.averaged_gradients = {} + + see_memory_usage('After overflow after clearing gradients', force=False) + + if torch.distributed.get_rank() == 0: + logger.info( + "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, " + "reducing to {}".format(dist.get_rank(), + prev_scale, + self.loss_scale)) + + @instrument_w_nvtx + def _overflow_check_and_loss_scale_update(self): + + # First compute norm for all group so we know if there is overflow + self.check_overflow() + + #loss scaling related computation + prev_scale = self.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self._overflow_clean_up(prev_scale) + + return self.overflow + + @instrument_w_nvtx + def _post_step(self, timer_names=set()): + if self.offload_optimizer: + self.reset_cpu_buffers() + + #Gathering persisting parameters + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + self.log_timers(timer_names) + + see_memory_usage('After zero_optimizer step', force=False) + print_rank_0(f"------------------Finishing Step-----------------------") + + @instrument_w_nvtx + def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): + if self.fp16_partitioned_groups_flat[sub_group_id] is not None: + self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + self.fp32_partitioned_groups_flat[sub_group_id].data) + + #unflatten fp16 parameter subgroup + self._unflatten_partitioned_parameters(sub_group_id) + else: + self._partitioned_params_swap_out(sub_group_id) + + @instrument_w_nvtx + def step(self, closure=None): + """ + Not supporting closure. + """ + self._pre_step() + self._partition_all_parameters() + + #checks for overflow, adjust the loss scale accordingly + if self._overflow_check_and_loss_scale_update(): + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + return + + norm_groups = self._get_norm_groups() + self._global_grad_norm = get_global_norm(norm_list=norm_groups) + + timer_names = set() + + timer_names.add('optimizer_step') + self.start_timers(['optimizer_step']) + + #update parameters one sub group at a time + for sub_group_id, group in enumerate(self.fp16_groups): + + #prepare optimizer states, gradients and fp32 parameters for update + self._prepare_sub_group(sub_group_id, timer_names) + + #scale the fp32 gradients + self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) + + #apply the optimizer step on the sub group and copy fp32 parameters to fp16 + self._optimizer_step(sub_group_id) + + #put fp16 parameters in appropriate location + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + + #release memory or swap out optimizer states of fp32 parameters + self._release_sub_group(sub_group_id, timer_names) + + self.stop_timers(['optimizer_step']) + + self._post_step(timer_names) + + # warn user about caching allocator flushes + alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] if hasattr( + torch.cuda, + "memory_stats") else 0 + if alloc_retries > self.__n_caching_allocator_flushes: + if dist.get_rank() == 0: + logger.warning( + "%d pytorch allocator cache flushes since last step. this happens " + "when there is high memory pressure and is detrimental to " + "performance. if this is happening frequently consider adjusting " + "settings to reduce memory consumption. If you are unable to " + "make the cache flushes go away consider adding " + "torch.cuda.empty_cache() calls in your training loop to ensure " + "that all ranks flush their caches at the same time", + alloc_retries - self.__n_caching_allocator_flushes) + self.__n_caching_allocator_flushes = alloc_retries + + def dump_pre_step_gradients(self, debug_fp32_grads): + # Dump gradient norms for debugging + for i, _ in enumerate(self.fp16_groups): + print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC') + for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]): + param_id = self.get_param_id(fp16_param) + fp16_grad_norm = self.debug_fp16_grads[i][param_id] + + fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad] + norm_list = [fp16_grad_norm, fp32_grad_norm] + print(f'Pre-Step Norms {i} {param_id} = {norm_list}') + + def dump_post_step_gradients(self): + # Dump gradient norms for debugging + for i, group in enumerate(self.fp16_groups): + print( + f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT') + unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i]) + unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i], + self.fp16_groups[i]) + for j, p in enumerate(self.fp16_groups[i]): + param_id = self.get_param_id(p) + param_norm = float(p.data.float().norm(2)) + ds_norm = float(p.ds_tensor.data.float().norm(2)) + + unflat_norm = [ + float(t.data.float().norm(2)) + for t in [unflat_fp16[j], + unflat_fp32[j]] + ] + norm_list = [param_norm, ds_norm] + unflat_norm + print(f'Post-Step Norms {i} {param_id} = {norm_list}') + + @instrument_w_nvtx + def unscale_and_clip_grads(self, sub_group_id, total_norm): + grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] + + # compute combined scale factor for this group + combined_scale = self.loss_scale + if self.clip_grad > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad + if clip > 1: + combined_scale = clip * self.loss_scale + # to maintain behavior of averaging over accumulation steps + combined_scale *= self.micro_step_id + 1 + + for grad in grad_groups_flat: + if isinstance(grad, list): + sub_partitions = grad + for g in sub_partitions: + g.data.mul_(1. / combined_scale) + else: + grad.data.mul_(1. / combined_scale) + + def _check_overflow(self, partition_gradients=True): + self.overflow = self.has_overflow(partition_gradients) + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params, is_grad_list=False): + for p in params: + if p.grad is not None and self._has_inf_or_nan(p.grad.data): + return True + + return False + + def has_overflow_partitioned_grads_serial(self): + for i in range(len(self.fp16_groups)): + for j, grad in enumerate(self.averaged_gradients[i]): + if grad is not None and self._has_inf_or_nan(grad.data, j): + return True + return False + + @instrument_w_nvtx + def has_overflow(self, partition_gradients=True): + if partition_gradients: + with torch.cuda.stream(self.__reduce_and_partition_stream): + self.local_overflow = bool(self.__inf_or_nan_tracker.item()) + self.__inf_or_nan_tracker.zero_() + + overflow = self.local_overflow + #overflow = self.has_overflow_partitioned_grads_serial() + overflow_gpu = torch.cuda.ByteTensor([overflow]) + torch.distributed.all_reduce(overflow_gpu, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group) + + else: + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + + overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) + overflow_gpu = torch.cuda.ByteTensor([overflow]) + + # Since each model parallel GPU carries only part of the model, + # make sure overflow flag is synced across all the model parallel GPUs + self._model_parallel_all_reduce(tensor=overflow_gpu, + op=torch.distributed.ReduceOp.MAX) + + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + @staticmethod + def _has_inf_or_nan(x, j=None): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + @instrument_w_nvtx + def backward(self, loss, retain_graph=False): + """ + :attr:`backward` performs the following steps: + + 1. fp32_loss = loss.float() + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves + """ + if self.swap_optimizer: + self.optimizer_swapper.pre_backward() + + see_memory_usage(f"Before backward", force=False) + + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + + self.param_coordinator.reset_step() + + if self.swap_optimizer: + self.optimizer_swapper.post_backward() + + def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: + """get fp32 gradient partition dictionary + accessed as grad_dict[parameter_group_index][parameter_index] + """ + self.__reduce_and_partition_stream.synchronize() + grad_dict = collections.defaultdict(dict) + if self.offload_optimizer: + for group in self.fp16_groups: + for param_idx, param in enumerate(group): + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow( + 0, + dest_offset, + num_elements) + grad_dict[group_idx][param_idx] = fp32_grad + else: + for group_idx, group in self.averaged_gradients.items(): + for param_idx, gradient in enumerate(group): + grad_dict[group_idx][param_idx] = gradient.float() + + return grad_dict + + @instrument_w_nvtx + def _partition_all_parameters(self): + """Partitioning Parameters that were not partitioned usually if parameters + of modules whose input parameters do not require grad computation do not + trigger post call and will therefore will remain unpartitioned""" + self.param_coordinator.release_and_reset_all() + for param in iter_params(self.module, recurse=True): + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"{param.ds_summary()} expected to be released") + + def check_overflow(self, partition_gradients=True): + self._check_overflow(partition_gradients) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + cur_scale = property(_get_loss_scale, _set_loss_scale) + + def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings): + # Remove paddings from flattened tensor + individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors) + lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)] + lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)] + #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}') + return lean_tensors + + #TODO REVISIT this for stage 3 + def get_lean_optimizer_state(self): + # Return optimizer states after removing paddings. + # This method assumes that each param group contains a single flattened tensor. + optimizer_groups_state = [] + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + lean_state = {} + for key, value in self.optimizer.state[p].items(): + if torch.is_tensor(value): + padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]] + lean_state[key] = self._get_lean_tensors( + value, + self.fp16_partitioned_groups[i], + self.groups_padding[i]) + lean_flat_len = sum([t.numel() for t in lean_state[key]]) + else: + lean_state[key] = value + + optimizer_groups_state.append(lean_state) + + return optimizer_groups_state + + def get_groups_without_padding(self, groups_with_padding): + # Return group tensor after removing paddings added for alignment to DP world size. + groups_without_padding = [] + for i, group in enumerate(groups_with_padding): + lean_group = self._get_lean_tensors(group, + self.fp16_partitioned_groups[i], + self.groups_padding[i]) + groups_without_padding.append(lean_group) + + return groups_without_padding + + def _set_fp32_optimizer_param_groups(self): + for sub_group_id, _ in enumerate(self.fp16_groups): + param_group_id = self.sub_group_to_group_id[sub_group_id] + self.optimizer.param_groups[param_group_id]['params'].append( + self.fp32_partitioned_groups_flat[sub_group_id]) + + def _clear_fp32_optimizer_param_groups(self): + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _rigid_state_dict(self): + state_dict = {} + state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict['partition_count'] = self.partition_count + + self._set_fp32_optimizer_param_groups() + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat + self._clear_fp32_optimizer_param_groups() + + return state_dict + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + if self.elastic_checkpoint: + raise NotImplementedError( + "ZeRO-3 does not yet support elastic checkpointing, please disable for now." + ) + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + raise NotImplementedError( + "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." + ) + + return self._rigid_state_dict() + + +# Restore base optimizer fp32 weights from checkpoint by: +# 1) Merging fp32 weights from checkpoints of all partitions +# 2) Extracting fp32 weights for current partition from merged weights +# 3) Using extracted weights to update base optimizer weights directly. + + def _restore_from_fp32_weights(self, all_state_dict): + + flat_local_partition = [] + for i in range(len(self.fp32_partitioned_groups_flat)): + merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] + flat_local_partition.append(self._get_flattened_partition(merged_partitions)) + + for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): + current.data.copy_(saved.data) + + # Restore base optimizer fp32 weights from ZeRO fp16 weights + def _restore_from_fp16_weights(self): + for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): + fp32_partition.data.copy_(fp16_partitions.data) + + # Refresh the fp32 master params from the fp16 copies. + def refresh_fp32_params(self): + self._restore_from_fp16_weights() + + # Extract flattened partition for current rank from all partitions + def _get_flattened_partition(self, all_partition_states): + partition_id = dist.get_rank(group=self.dp_process_group) + alignment = dist.get_world_size(group=self.dp_process_group) + + param_partitions = [[] for _ in range(len(all_partition_states[0]))] + for i, partition in enumerate(all_partition_states): + for j, param in enumerate(partition): + param_partitions[j].append(param) + + local_state_partitions = [] + for param_index, param_slices in enumerate(param_partitions): + flattened_merged_tensor = self.flatten_dense_tensors_aligned( + param_slices, + alignment) + new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) + local_state_partitions.append(new_partitions[partition_id]) + + if torch.is_tensor(local_state_partitions[0]): + return self.flatten_dense_tensors_aligned(local_state_partitions, alignment) + + # Assume non-tensor states are not partitioned and equal across ranks, so return first one + return local_state_partitions[0] + + # Restore base optimizer state from checkpoint by + # 1) Merging optimizer state from checkpoints of all partitions + # 2) Extracting optimizer state for current partition from the merged state + # 3) Using the extracted value to directly update the base optimizer. + def _restore_base_optimizer_state(self, all_state_dict): + base_optimizer_group_states = [] + for i in range(len(self.optimizer.param_groups)): + partition_states = {} + all_partition_group_states = [ + sd['base_optimizer_state'][i] for sd in all_state_dict + ] + for key in all_partition_group_states[0].keys(): + all_partition_states = [ + all_states[key] for all_states in all_partition_group_states + ] + partition_states[key] = self._get_flattened_partition( + all_partition_states) + base_optimizer_group_states.append(partition_states) + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + for key, saved in base_optimizer_group_states[i].items(): + if torch.is_tensor(self.optimizer.state[p][key]): + self.optimizer.state[p][key].data.copy_(saved.data) + else: + self.optimizer.state[p][key] = saved + + def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + + if load_optimizer_states: + self._set_fp32_optimizer_param_groups() + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + self._clear_fp32_optimizer_param_groups() + + # restore fp32 partitions + for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']): + curr_param.data.copy_(saved_param.data) + + # restore fp16 partitions from fp32 + for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + fp16_param.data.copy_(fp32_param.data) + + # update fp16 unflattened params + for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): + updated_params = self.unflatten( + self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + # TODO: Support different/changing load/save DP degree. + def load_state_dict(self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False): + r"""Loading a ZeRO checkpoint + Arguments: + state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. + Note that the number of saved partitions may differ from number of loading partitions to support + changing GPU count, specifically DP world size, between saving and loading checkpoints. + load_optimizer_states: Boolean indicating whether or not to load base optimizer states + load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 + copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). + """ + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + Example:: + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + + if self.elastic_checkpoint: + raise NotImplementedError( + "ZeRO-3 does not yet support elastic checkpointing, please disable for now." + ) + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + raise NotImplementedError( + "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." + ) + + self._rigid_load_state_dict( + state_dict_list[dist.get_rank(group=self.dp_process_group)], + load_optimizer_states=load_optimizer_states) + + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].partition(self.persistent_parameters) + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + def save_checkpoint_prologue(self): + self._partition_all_parameters() + + def save_checkpoint_epilogue(self): + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + +def _handle_overflow(cpu_sum, x, i): + import math + rank = torch.distributed.get_rank() + if rank == 0: + t_i = -1 + for v_i, v in enumerate(x.data.contiguous().view(-1)): + if not math.isfinite(float(v)): + t_i = v_i + break + logger.info( + f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}" + ) + + +def estimate_zero3_model_states_mem_needs(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + cpu_offload=True, + cpu_offload_params=True, + zero_init=True, + additional_buffer_factor=1.5): + + total_gpus = num_nodes * num_gpus_per_node + gpus_factor = 1 / num_nodes + largest_layer_memory = (4 * largest_layer_params) + + if cpu_offload: + if cpu_offload_params: + gpu_mem = largest_layer_memory + + if zero_init: + cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, + 18 * gpus_factor) * additional_buffer_factor + + else: + gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus) + + if zero_init: + cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, + 16 * gpus_factor) * additional_buffer_factor + else: + gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus) + if zero_init: + cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor + else: + cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor + + return int(cpu_mem), int(gpu_mem), largest_layer_memory + + +def model_to_params(model): + # shared params calculated only once + total_params = sum( + dict((p.data_ptr(), + p.numel()) for p in model.parameters()).values()) + + largest_layer_params = 0 + for m in model.modules(): + # assuming no shared params within a single layer + layer_params = sum(p.numel() for p in m.parameters(recurse=False)) + largest_layer_params = max(largest_layer_params, layer_params) + + return total_params, largest_layer_params + + +import math + + +def estimate_zero3_model_states_mem_needs_all_live(model, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If you have an actual model object, use this function and everything will be derived + automatically. + + If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + Args: + - ``model``: ``nn.Module`` object + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + + total_params, largest_layer_params = model_to_params(model) + + estimate_zero3_model_states_mem_needs_all_cold( + total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + additional_buffer_factor=additional_buffer_factor) + + +def estimate_zero3_model_states_mem_needs_all_cold(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If it's a hypothetical model, use this function where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything + will be derived automatically. + + Args: + - ``total_params``: total model params + - ``largest_layer_params``: largest layer's params + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + def format_options(cpu_offload, cpu_offload_params, zero_init): + enabled = [] + enabled.append(f"cpu_offload={1 if cpu_offload else 0}") + enabled.append(f"cpu_offload_params={1 if cpu_offload_params else 0}") + enabled.append(f"zero_init={1 if zero_init else 0}") + return ", ".join(enabled) + + nodes_str = "nodes" if num_nodes > 1 else "node" + gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" + print( + "Estimated memory needed for params, optim states and gradients for a:\n" + f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" + f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params." + ) + print(" per CPU | per GPU | Options") + for cpu_offload in [True, False]: + for cpu_offload_params in [True, False]: + if not cpu_offload and cpu_offload_params: + continue + for zero_init in [True, False]: + cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs( + total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init, + additional_buffer_factor=additional_buffer_factor + ) + + options_str = format_options(cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init) + print( + f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") From 2a38302040fc1daa38473af034a86b98db9991ab Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Thu, 18 Nov 2021 15:32:24 -0800 Subject: [PATCH 37/59] convert to same new-line style as master --- deepspeed/runtime/config.py | 2018 +++++++++++++++++------------------ 1 file changed, 1009 insertions(+), 1009 deletions(-) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index c96b547e1f93..4f8b25802033 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -1,1009 +1,1009 @@ -""" -Copyright (c) Microsoft Corporation -Licensed under the MIT license. -""" -import os -from typing import Union - -import torch -import json -import copy - -from .constants import * -from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE -from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys, ScientificNotationEncoder -from .zero.config import DeepSpeedZeroConfig -from .zero.constants import * -from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig - -from ..git_version_info import version as __version__ -from ..utils import logger - -from ..elasticity import elasticity_enabled, compute_elastic_config, ensure_immutable_elastic_config -from ..elasticity.config import ElasticityConfigError -from ..elasticity.constants import ELASTICITY, IGNORE_NON_ELASTIC_BATCH_INFO, \ - IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT - -from ..profiling.config import DeepSpeedFlopsProfilerConfig - -from .swap_tensor.aio_config import get_aio_config - -TENSOR_CORE_ALIGN_SIZE = 8 - -ADAGRAD_OPTIMIZER = 'adagrad' -ADAM_OPTIMIZER = 'adam' -ADAMW_OPTIMIZER = 'adamw' -LAMB_OPTIMIZER = 'lamb' -ONEBIT_ADAM_OPTIMIZER = 'onebitadam' -ONEBIT_LAMB_OPTIMIZER = 'onebitlamb' -DEEPSPEED_OPTIMIZERS = [ - ADAGRAD_OPTIMIZER, - ADAM_OPTIMIZER, - ADAMW_OPTIMIZER, - LAMB_OPTIMIZER, - ONEBIT_ADAM_OPTIMIZER, - ONEBIT_LAMB_OPTIMIZER, -] - -# extra optimizer parameters for adam/adamw -TORCH_ADAM_PARAM = "torch_adam" - -# default to adamw logic for adam/adamw optimizers unless user explicitly opts out -ADAM_W_MODE = "adam_w_mode" -ADAM_W_MODE_DEFAULT = True - - -class DeepSpeedConfigError(Exception): - pass - - -def get_curriculum_enabled(param_dict): - if CURRICULUM_LEARNING in param_dict.keys(): - return get_scalar_param(param_dict[CURRICULUM_LEARNING], - CURRICULUM_ENABLED, - CURRICULUM_ENABLED_DEFAULT) - else: - return False - - -def get_curriculum_params(param_dict): - if CURRICULUM_LEARNING in param_dict.keys(): - curriculum_params = copy.copy(param_dict[CURRICULUM_LEARNING]) - curriculum_params.pop(CURRICULUM_ENABLED) - return curriculum_params - else: - return False - - -def get_pld_enabled(param_dict): - if PROGRESSIVE_LAYER_DROP in param_dict.keys(): - return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP], - PLD_ENABLED, - PLD_ENABLED_DEFAULT) - else: - return False - - -def get_pld_params(param_dict): - if PROGRESSIVE_LAYER_DROP in param_dict.keys(): - pld_params = copy.copy(param_dict[PROGRESSIVE_LAYER_DROP]) - pld_params.pop(PLD_ENABLED) - return pld_params - else: - return False - - -def get_amp_enabled(param_dict): - if AMP in param_dict.keys(): - return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT) - else: - return False - - -def get_amp_params(param_dict): - if AMP in param_dict.keys(): - amp_params = copy.copy(param_dict[AMP]) - amp_params.pop(AMP_ENABLED) - return amp_params - else: - return False - - -def get_fp16_enabled(param_dict): - if FP16 in param_dict.keys(): - return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT) - else: - return False - - -def get_bfloat16_enabled(param_dict): - if BFLOAT16 in param_dict.keys(): - return get_scalar_param(param_dict[BFLOAT16], - BFLOAT16_ENABLED, - BFLOAT16_ENABLED_DEFAULT) - else: - return False - - -def get_fp16_master_weights_and_grads_enabled(param_dict): - if get_fp16_enabled(param_dict): - return get_scalar_param(param_dict[FP16], - FP16_MASTER_WEIGHTS_AND_GRADS, - FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT) - else: - return False - - -def get_loss_scale(param_dict): - if get_fp16_enabled(param_dict): - return get_scalar_param(param_dict[FP16], - FP16_LOSS_SCALE, - FP16_LOSS_SCALE_DEFAULT) - elif get_bfloat16_enabled(param_dict): - return 1.0 - else: - return FP16_LOSS_SCALE_DEFAULT - - -def get_initial_dynamic_scale(param_dict): - if get_fp16_enabled(param_dict): - initial_scale_power = get_scalar_param(param_dict[FP16], - FP16_INITIAL_SCALE_POWER, - FP16_INITIAL_SCALE_POWER_DEFAULT) - elif get_bfloat16_enabled(param_dict): - initial_scale_power = 0 - else: - initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT - - return 2**initial_scale_power - - -def get_dynamic_loss_scale_args(param_dict): - loss_scale_args = None - if get_fp16_enabled(param_dict): - fp16_dict = param_dict[FP16] - dynamic_loss_args = [ - FP16_INITIAL_SCALE_POWER, - FP16_LOSS_SCALE_WINDOW, - FP16_MIN_LOSS_SCALE, - FP16_HYSTERESIS - ] - if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args): - init_scale = get_scalar_param(fp16_dict, - FP16_INITIAL_SCALE_POWER, - FP16_INITIAL_SCALE_POWER_DEFAULT) - scale_window = get_scalar_param(fp16_dict, - FP16_LOSS_SCALE_WINDOW, - FP16_LOSS_SCALE_WINDOW_DEFAULT) - delayed_shift = get_scalar_param(fp16_dict, - FP16_HYSTERESIS, - FP16_HYSTERESIS_DEFAULT) - min_loss_scale = get_scalar_param(fp16_dict, - FP16_MIN_LOSS_SCALE, - FP16_MIN_LOSS_SCALE_DEFAULT) - loss_scale_args = { - INITIAL_LOSS_SCALE: 2**init_scale, - SCALE_WINDOW: scale_window, - DELAYED_SHIFT: delayed_shift, - MIN_LOSS_SCALE: min_loss_scale - } - - return loss_scale_args - - -def get_gradient_accumulation_steps(param_dict): - return get_scalar_param(param_dict, - GRADIENT_ACCUMULATION_STEPS, - GRADIENT_ACCUMULATION_STEPS_DEFAULT) - - -def get_sparse_gradients_enabled(param_dict): - return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT) - - -def get_zero_optimization(param_dict): - return get_scalar_param(param_dict, ZERO_OPTIMIZATION, ZERO_OPTIMIZATION_DEFAULT) - - -def get_zero_reduce_scatter(param_dict): - return get_scalar_param(param_dict, - ZERO_OPTIMIZATION_REDUCE_SCATTER, - ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT) - - -def get_allreduce_always_fp32(param_dict): - return get_scalar_param(param_dict, FP32_ALLREDUCE, FP32_ALLREDUCE_DEFAULT) - - -def get_prescale_gradients(param_dict): - return get_scalar_param(param_dict, PRESCALE_GRADIENTS, PRESCALE_GRADIENTS_DEFAULT) - - -def get_gradient_predivide_factor(param_dict): - return get_scalar_param(param_dict, - GRADIENT_PREDIVIDE_FACTOR, - GRADIENT_PREDIVIDE_FACTOR_DEFAULT) - - -def get_quantize_enabled(param_dict): - if QUANTIZE_TRAINING in param_dict.keys(): - return get_scalar_param(param_dict[QUANTIZE_TRAINING], - QUANTIZE_TRAINING_ENABLED, - QUANTIZE_TRAINING_ENABLED_DEFAULT) - else: - return False - - -def get_quantize_training(param_dict): - if QUANTIZE_TRAINING in param_dict.keys(): - return ((param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS][TARGET_BITS]), \ - (param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS][START_BITS] if START_BITS in param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS].keys() else QUANTIZE_START_BITS_DEFAULT), \ - (param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE][QUANTIZE_PERIOD] if QUANTIZE_SCHEDULE in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_PERIOD_DEFAULT), \ - (param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE][SCHEDULE_OFFSET] if QUANTIZE_SCHEDULE in param_dict[QUANTIZE_TRAINING].keys() and SCHEDULE_OFFSET in param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE].keys() else QUANTIZE_OFFSET_DEFAULT), \ - (param_dict[QUANTIZE_TRAINING][QUANTIZE_GROUPS] if QUANTIZE_GROUPS in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_GROUPS_DEFAULT), \ - (param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE][FP16_MIXED_QUANTIZE_ENABLED] if FP16_MIXED_QUANTIZE in param_dict[QUANTIZE_TRAINING].keys() and FP16_MIXED_QUANTIZE_ENABLED in param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE].keys() else FP16_MIXED_QUANTIZE_ENABLED_DEFAULT), \ - (param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE][QUANTIZE_CHANGE_RATIO] if FP16_MIXED_QUANTIZE in param_dict[QUANTIZE_TRAINING].keys() and QUANTIZE_CHANGE_RATIO in param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE].keys() else QUANTIZE_CHANGE_RATIO_DEFAULT), \ - (1 if QUANTIZE_ALGO in param_dict[QUANTIZE_TRAINING] and QUANTIZE_TYPE in param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO].keys() and param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO][QUANTIZE_TYPE] == QUANTIZE_ASYMMETRIC else QUANTIZE_TYPE_DEFAULT), \ - (1 if QUANTIZE_ALGO in param_dict[QUANTIZE_TRAINING] and QUANTIZE_ROUNDING in param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO].keys() and param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO][QUANTIZE_ROUNDING] == STOCHASTIC_ROUNDING else QUANTIZE_ROUNDING_DEFAULT), \ - (param_dict[QUANTIZE_TRAINING][QUANTIZE_VERBOSE] if QUANTIZE_VERBOSE in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_VERBOSE_DEFAULT), \ - (param_dict[QUANTIZE_TRAINING][QUANTIZER_KERNEL] if QUANTIZER_KERNEL in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZER_KERNEL_DEFAULT)) - else: - return (QUANTIZE_TARGET_BITS_DEFAULT, \ - QUANTIZE_START_BITS_DEFAULT, \ - QUANTIZE_PERIOD_DEFAULT, \ - QUANTIZE_OFFSET_DEFAULT, \ - QUANTIZE_GROUPS_DEFAULT, \ - FP16_MIXED_QUANTIZE_ENABLED_DEFAULT, \ - QUANTIZE_CHANGE_RATIO_DEFAULT, \ - QUANTIZE_TYPE_DEFAULT, \ - QUANTIZE_ROUNDING_DEFAULT, \ - QUANTIZE_VERBOSE_DEFAULT, \ - QUANTIZER_KERNEL_DEFAULT) - - -def get_steps_per_print(param_dict): - return get_scalar_param(param_dict, STEPS_PER_PRINT, STEPS_PER_PRINT_DEFAULT) - - -def get_disable_allgather(param_dict): - return get_scalar_param(param_dict, DISABLE_ALLGATHER, DISABLE_ALLGATHER_DEFAULT) - - -def get_dump_state(param_dict): - return get_scalar_param(param_dict, DUMP_STATE, DUMP_STATE_DEFAULT) - - -def get_gradient_clipping(param_dict): - return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT) - - -def get_sparse_attention(param_dict): - if SPARSE_ATTENTION in param_dict.keys(): - sparsity = param_dict[SPARSE_ATTENTION] - mode = get_sparse_attention_mode(sparsity) - - if (mode == SPARSE_DENSE_MODE): - return get_sparse_dense_config(sparsity) - elif (mode == SPARSE_FIXED_MODE): - return get_sparse_fixed_config(sparsity) - elif (mode == SPARSE_VARIABLE_MODE): - return get_sparse_variable_config(sparsity) - elif (mode == SPARSE_BIGBIRD_MODE): - return get_sparse_bigbird_config(sparsity) - elif (mode == SPARSE_BSLONGFORMER_MODE): - return get_sparse_bslongformer_config(sparsity) - else: - raise NotImplementedError( - f'Given sparsity mode, {mode}, has not been implemented yet!') - - else: - return None - - -def get_sparse_dense_config(sparsity): - block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) - return {SPARSE_MODE: SPARSE_DENSE_MODE, SPARSE_BLOCK: block} - - -def get_sparse_fixed_config(sparsity): - block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) - different_layout_per_head = get_scalar_param( - sparsity, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT) - num_local_blocks = get_scalar_param(sparsity, - SPARSE_NUM_LOCAL_BLOCKS, - SPARSE_NUM_LOCAL_BLOCKS_DEFAULT) - num_global_blocks = get_scalar_param(sparsity, - SPARSE_NUM_GLOBAL_BLOCKS, - SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT) - attention = get_scalar_param(sparsity, - SPARSE_ATTENTION_TYPE, - SPARSE_ATTENTION_TYPE_DEFAULT) - horizontal_global_attention = get_scalar_param( - sparsity, - SPARSE_HORIZONTAL_GLOBAL_ATTENTION, - SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT) - num_different_global_patterns = get_scalar_param( - sparsity, - SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS, - SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT) - - return { - SPARSE_MODE: SPARSE_FIXED_MODE, - SPARSE_BLOCK: block, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head, - SPARSE_NUM_LOCAL_BLOCKS: num_local_blocks, - SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks, - SPARSE_ATTENTION_TYPE: attention, - SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention, - SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS: num_different_global_patterns - } - - -def get_sparse_variable_config(sparsity): - block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) - different_layout_per_head = get_scalar_param( - sparsity, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT) - num_random_blocks = get_scalar_param(sparsity, - SPARSE_NUM_RANDOM_BLOCKS, - SPARSE_NUM_RANDOM_BLOCKS_DEFAULT) - local_window_blocks = get_scalar_param(sparsity, - SPARSE_LOCAL_WINDOW_BLOCKS, - SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT) - global_block_indices = get_scalar_param(sparsity, - SPARSE_GLOBAL_BLOCK_INDICES, - SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT) - global_block_end_indices = get_scalar_param(sparsity, - SPARSE_GLOBAL_BLOCK_END_INDICES, - SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT) - attention = get_scalar_param(sparsity, - SPARSE_ATTENTION_TYPE, - SPARSE_ATTENTION_TYPE_DEFAULT) - horizontal_global_attention = get_scalar_param( - sparsity, - SPARSE_HORIZONTAL_GLOBAL_ATTENTION, - SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT) - - return { - SPARSE_MODE: SPARSE_VARIABLE_MODE, - SPARSE_BLOCK: block, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head, - SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks, - SPARSE_LOCAL_WINDOW_BLOCKS: local_window_blocks, - SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices, - SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices, - SPARSE_ATTENTION_TYPE: attention, - SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention - } - - -def get_sparse_bigbird_config(sparsity): - block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) - different_layout_per_head = get_scalar_param( - sparsity, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT) - num_random_blocks = get_scalar_param(sparsity, - SPARSE_NUM_RANDOM_BLOCKS, - SPARSE_NUM_RANDOM_BLOCKS_DEFAULT) - num_sliding_window_blocks = get_scalar_param( - sparsity, - SPARSE_NUM_SLIDING_WINDOW_BLOCKS, - SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT) - num_global_blocks = get_scalar_param(sparsity, - SPARSE_NUM_GLOBAL_BLOCKS, - SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT) - - return { - SPARSE_MODE: SPARSE_BIGBIRD_MODE, - SPARSE_BLOCK: block, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head, - SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks, - SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks, - SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks - } - - -def get_sparse_bslongformer_config(sparsity): - block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) - different_layout_per_head = get_scalar_param( - sparsity, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT) - num_sliding_window_blocks = get_scalar_param( - sparsity, - SPARSE_NUM_SLIDING_WINDOW_BLOCKS, - SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT) - global_block_indices = get_scalar_param(sparsity, - SPARSE_GLOBAL_BLOCK_INDICES, - SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT) - global_block_end_indices = get_scalar_param(sparsity, - SPARSE_GLOBAL_BLOCK_END_INDICES, - SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT) - - return { - SPARSE_MODE: SPARSE_BSLONGFORMER_MODE, - SPARSE_BLOCK: block, - SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head, - SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks, - SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices, - SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices - } - - -def get_sparse_attention_mode(param_dict): - if SPARSE_MODE in param_dict.keys(): - return param_dict[SPARSE_MODE] - else: - return SPARSE_MODE_DEFAULT - - -def get_sparse_attention_type(param_dict): - if SPARSE_ATTENTION_TYPE in param_dict.keys(): - return param_dict[SPARSE_ATTENTION_TYPE] - else: - return SPARSE_ATTENTION_TYPE_DEFAULT - - -def get_pipeline_config(param_dict): - '''Parses pipeline engine configuration. ''' - default_pipeline = { - 'stages': 'auto', - 'partition': 'best', - 'seed_layers': False, - 'activation_checkpoint_interval': 0 - } - config = default_pipeline - for key, val in param_dict.get('pipeline', {}).items(): - config[key] = val - return config - - -def get_optimizer_name(param_dict): - if OPTIMIZER in param_dict.keys() and \ - TYPE in param_dict[OPTIMIZER].keys(): - return param_dict[OPTIMIZER][TYPE] - else: - return OPTIMIZER_TYPE_DEFAULT - - -def get_optimizer_params(param_dict): - if get_optimizer_name(param_dict) is not None and \ - OPTIMIZER_PARAMS in param_dict[OPTIMIZER].keys(): - return param_dict[OPTIMIZER][OPTIMIZER_PARAMS] - else: - return None - - -def get_optimizer_gradient_clipping(param_dict): - optimizer_params = get_optimizer_params(param_dict) - if optimizer_params is not None and \ - MAX_GRAD_NORM in optimizer_params.keys(): - return optimizer_params[MAX_GRAD_NORM] - else: - return None - - -def get_optimizer_legacy_fusion(param_dict): - if OPTIMIZER in param_dict.keys() and \ - LEGACY_FUSION in param_dict[OPTIMIZER].keys(): - return param_dict[OPTIMIZER][LEGACY_FUSION] - else: - return LEGACY_FUSION_DEFAULT - - -def get_zero_allow_untested_optimizer(param_dict): - return get_scalar_param(param_dict, - ZERO_ALLOW_UNTESTED_OPTIMIZER, - ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT) - - -def get_scheduler_name(param_dict): - if SCHEDULER in param_dict.keys() and \ - TYPE in param_dict[SCHEDULER].keys(): - return param_dict[SCHEDULER][TYPE] - else: - return SCHEDULER_TYPE_DEFAULT - - -def get_scheduler_params(param_dict): - if get_scheduler_name(param_dict) is not None and \ - SCHEDULER_PARAMS in param_dict[SCHEDULER].keys(): - return param_dict[SCHEDULER][SCHEDULER_PARAMS] - else: - return None - - -def get_train_batch_size(param_dict): - return get_scalar_param(param_dict, TRAIN_BATCH_SIZE, TRAIN_BATCH_SIZE_DEFAULT) - - -def get_train_micro_batch_size_per_gpu(param_dict): - return get_scalar_param(param_dict, - TRAIN_MICRO_BATCH_SIZE_PER_GPU, - TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT) - - -def get_wall_clock_breakdown(param_dict): - return get_scalar_param(param_dict, - WALL_CLOCK_BREAKDOWN, - WALL_CLOCK_BREAKDOWN_DEFAULT) - - -def get_memory_breakdown(param_dict): - return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT) - - -def get_tensorboard_enabled(param_dict): - if TENSORBOARD in param_dict.keys(): - return get_scalar_param(param_dict[TENSORBOARD], - TENSORBOARD_ENABLED, - TENSORBOARD_ENABLED_DEFAULT) - else: - return False - - -def get_eigenvalue_config(param_dict): - if get_quantize_enabled(param_dict): - param_dict = param_dict[QUANTIZE_TRAINING] - return (get_eigenvalue_enabled(param_dict), \ - get_eigenvalue_verbose(param_dict), \ - get_eigenvalue_max_iter(param_dict), \ - get_eigenvalue_tol(param_dict), \ - get_eigenvalue_stability(param_dict), \ - get_eigenvalue_gas_boundary_resolution(param_dict), \ - get_eigenvalue_layer_name(param_dict), \ - get_eigenvalue_layer_num(param_dict)) - else: - return (EIGENVALUE_ENABLED_DEFAULT, \ - EIGENVALUE_VERBOSE_DEFAULT, \ - EIGENVALUE_MAX_ITER_DEFAULT, \ - EIGENVALUE_TOL_DEFAULT, \ - EIGENVALUE_STABILITY_DEFAULT, \ - EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT, \ - EIGENVALUE_LAYER_NAME_DEFAULT, \ - EIGENVALUE_LAYER_NUM_DEFAULT) - - -def get_eigenvalue_enabled(param_dict): - if EIGENVALUE in param_dict.keys(): - return get_scalar_param(param_dict[EIGENVALUE], - EIGENVALUE_ENABLED, - EIGENVALUE_ENABLED_DEFAULT) - else: - return EIGENVALUE_ENABLED_DEFAULT - - -def get_eigenvalue_verbose(param_dict): - if EIGENVALUE in param_dict.keys(): - return get_scalar_param(param_dict[EIGENVALUE], - EIGENVALUE_VERBOSE, - EIGENVALUE_VERBOSE_DEFAULT) - else: - return EIGENVALUE_VERBOSE_DEFAULT - - -def get_eigenvalue_max_iter(param_dict): - if EIGENVALUE in param_dict.keys(): - return get_scalar_param(param_dict[EIGENVALUE], - EIGENVALUE_MAX_ITER, - EIGENVALUE_MAX_ITER_DEFAULT) - else: - return EIGENVALUE_MAX_ITER_DEFAULT - - -def get_eigenvalue_tol(param_dict): - if EIGENVALUE in param_dict.keys(): - return get_scalar_param(param_dict[EIGENVALUE], - EIGENVALUE_TOL, - EIGENVALUE_TOL_DEFAULT) - else: - return EIGENVALUE_TOL_DEFAULT - - -def get_eigenvalue_stability(param_dict): - if EIGENVALUE in param_dict.keys(): - return get_scalar_param(param_dict[EIGENVALUE], - EIGENVALUE_STABILITY, - EIGENVALUE_STABILITY_DEFAULT) - else: - return EIGENVALUE_STABILITY_DEFAULT - - -def get_eigenvalue_gas_boundary_resolution(param_dict): - if EIGENVALUE in param_dict.keys(): - return get_scalar_param(param_dict[EIGENVALUE], - EIGENVALUE_GAS_BOUNDARY_RESOLUTION, - EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT) - else: - return EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT - - -def get_eigenvalue_layer_name(param_dict): - if EIGENVALUE in param_dict.keys(): - return get_scalar_param(param_dict[EIGENVALUE], - EIGENVALUE_LAYER_NAME, - EIGENVALUE_LAYER_NAME_DEFAULT) - else: - return EIGENVALUE_LAYER_NAME_DEFAULT - - -def get_eigenvalue_layer_num(param_dict): - if EIGENVALUE in param_dict.keys(): - return get_scalar_param(param_dict[EIGENVALUE], - EIGENVALUE_LAYER_NUM, - EIGENVALUE_LAYER_NUM_DEFAULT) - else: - return EIGENVALUE_LAYER_NUM_DEFAULT - - -def get_tensorboard_output_path(param_dict): - if get_tensorboard_enabled(param_dict): - return get_scalar_param(param_dict[TENSORBOARD], - TENSORBOARD_OUTPUT_PATH, - TENSORBOARD_OUTPUT_PATH_DEFAULT) - else: - return TENSORBOARD_OUTPUT_PATH_DEFAULT - - -def get_tensorboard_job_name(param_dict): - if get_tensorboard_enabled(param_dict): - return get_scalar_param(param_dict[TENSORBOARD], - TENSORBOARD_JOB_NAME, - TENSORBOARD_JOB_NAME_DEFAULT) - else: - return TENSORBOARD_JOB_NAME_DEFAULT - - -def get_checkpoint_params(param_dict): - return param_dict.get(CHECKPOINT, {}) - - -def get_checkpoint_tag_validation_mode(checkpoint_params): - tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION, - CHECKPOINT_TAG_VALIDATION_DEFAULT) - tag_validation_mode = tag_validation_mode.upper() - if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES: - return tag_validation_mode - else: - raise DeepSpeedConfigError("Checkpoint config contains invalid tag_validation " \ - f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}") - - -def get_dataloader_drop_last(param_dict): - return get_scalar_param(param_dict, - DATALOADER_DROP_LAST, - DATALOADER_DROP_LAST_DEFAULT) - - -'''Write deepspeed config files by modifying basic templates. -Can be used for quickly changing parameters via command line parameters.''' - - -class DeepSpeedConfigWriter: - def __init__(self, data=None): - self.data = data if data is not None else {} - - def add_config(self, key, value): - self.data[key] = value - - def load_config(self, filename): - self.data = json.load(open(filename, - 'r'), - object_pairs_hook=dict_raise_error_on_duplicate_keys) - - def write_config(self, filename): - with open(filename, 'w') as outfile: - json.dump(self.data, outfile) - - -class DeepSpeedConfig(object): - def __init__(self, config: Union[str, dict], mpu=None): - super(DeepSpeedConfig, self).__init__() - if isinstance(config, dict): - self._param_dict = config - elif os.path.exists(config): - self._param_dict = json.load( - open(config, - 'r'), - object_pairs_hook=dict_raise_error_on_duplicate_keys) - else: - raise ValueError( - f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}" - ) - try: - self.global_rank = torch.distributed.get_rank() - if mpu is None: - self.world_size = torch.distributed.get_world_size() - else: - self.world_size = mpu.get_data_parallel_world_size() - except: - self.global_rank = 0 - self.world_size = 1 - - # If elastic-mode enabled, update compute + update _param_dict - self.elasticity_enabled = elasticity_enabled(self._param_dict) - if self.elasticity_enabled: - logger.info("DeepSpeed elasticity support enabled") - final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config( - ds_config=self._param_dict, - target_deepspeed_version=__version__, - world_size=self.world_size) - - elastic_dict = self._param_dict[ELASTICITY] - - # Ensure the resource scheduler saw the same elastic config we are using at runtime - ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict) - - ignore_non_elastic_batch_info = elastic_dict.get( - IGNORE_NON_ELASTIC_BATCH_INFO, - IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT) - - if not ignore_non_elastic_batch_info: - batch_params = [ - TRAIN_BATCH_SIZE, - TRAIN_MICRO_BATCH_SIZE_PER_GPU, - GRADIENT_ACCUMULATION_STEPS - ] - if any(map(lambda t: t in self._param_dict, batch_params)): - raise ElasticityConfigError("One or more batch related parameters were found in your " \ - f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \ - f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \ - "elastic training is enabled, which takes control of these parameters. " \ - "If you want to suppress this error (the parameters will be silently ignored) " \ - f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.") - - # micro_bsz * world_size * gas = total_batch_size - # gas = total_batch_size // (micro_bsz * world_size) - gradient_accu_steps = final_batch_size // (micro_batch_size * - self.world_size) - - if TRAIN_BATCH_SIZE in self._param_dict: - logger.warning("[Elasticity] overriding training_batch_size: " \ - f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}") - if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict: - logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: " \ - f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}") - if GRADIENT_ACCUMULATION_STEPS in self._param_dict: - logger.warning("[Elasticity] overriding gradient_accumulation_steps: "\ - f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}") - - logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}") - - self._param_dict[TRAIN_BATCH_SIZE] = final_batch_size - self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = micro_batch_size - self._param_dict[GRADIENT_ACCUMULATION_STEPS] = gradient_accu_steps - - self._initialize_params(self._param_dict) - self._configure_train_batch_size() - self._do_sanity_check() - - def _initialize_params(self, param_dict): - self.train_batch_size = get_train_batch_size(param_dict) - self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu( - param_dict) - self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict) - self.steps_per_print = get_steps_per_print(param_dict) - self.dump_state = get_dump_state(param_dict) - - self.disable_allgather = get_disable_allgather(param_dict) - self.allreduce_always_fp32 = get_allreduce_always_fp32(param_dict) - self.prescale_gradients = get_prescale_gradients(param_dict) - self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict) - self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict) - - self.zero_config = DeepSpeedZeroConfig(param_dict) - self.zero_optimization_stage = self.zero_config.stage - self.zero_enabled = self.zero_optimization_stage > 0 - - self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig( - param_dict) - - self.gradient_clipping = get_gradient_clipping(param_dict) - self.fp16_enabled = get_fp16_enabled(param_dict) - self.bfloat16_enabled = get_bfloat16_enabled(param_dict) - self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled( - param_dict) - self.bfloat16_enabled = get_bfloat16_enabled(param_dict) - assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' - assert not (self.bfloat16_enabled and (self.zero_optimization_stage not in {2, 3})), f"bfloat16 mode is only enabled for ZeRO 2 and 3 currently, got {self.zero_optimization_stage}" - self.amp_enabled = get_amp_enabled(param_dict) - self.amp_params = get_amp_params(param_dict) - self.loss_scale = get_loss_scale(param_dict) - self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict) - self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) - - self.quantize_training_enabled = get_quantize_enabled(param_dict) - self.quantize_target_bits, \ - self.quantize_start_bits, \ - self.quantize_period, \ - self.quantize_offset, \ - self.quantize_groups, \ - self.fp16_mixed_quantize, \ - self.quantize_change_rate, \ - self.quantize_type, \ - self.quantize_rounding, \ - self.quantize_verbose, \ - self.use_quantizer_kernel = get_quantize_training(param_dict) - - self.optimizer_name = get_optimizer_name(param_dict) - if self.optimizer_name is not None and \ - self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS: - self.optimizer_name = self.optimizer_name.lower() - - self.optimizer_params = get_optimizer_params(param_dict) - self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict) - - self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer( - param_dict) - - self.scheduler_name = get_scheduler_name(param_dict) - self.scheduler_params = get_scheduler_params(param_dict) - - self.flops_profiler_config = DeepSpeedFlopsProfilerConfig(param_dict) - self.wall_clock_breakdown = get_wall_clock_breakdown( - param_dict) | self.flops_profiler_config.enabled - self.memory_breakdown = get_memory_breakdown(param_dict) - self.tensorboard_enabled = get_tensorboard_enabled(param_dict) - self.tensorboard_output_path = get_tensorboard_output_path(param_dict) - self.tensorboard_job_name = get_tensorboard_job_name(param_dict) - - self.eigenvalue_enabled, \ - self.eigenvalue_verbose, \ - self.eigenvalue_max_iter, \ - self.eigenvalue_tol, \ - self.eigenvalue_stability, \ - self.eigenvalue_gas_boundary_resolution, \ - self.eigenvalue_layer_name, \ - self.eigenvalue_layer_num = get_eigenvalue_config(param_dict) - - self.sparse_attention = get_sparse_attention(param_dict) - self.pipeline = get_pipeline_config(param_dict) - - self.pld_enabled = get_pld_enabled(param_dict) - self.pld_params = get_pld_params(param_dict) - - self.curriculum_enabled = get_curriculum_enabled(param_dict) - self.curriculum_params = get_curriculum_params(param_dict) - - checkpoint_params = get_checkpoint_params(param_dict) - validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params) - self.checkpoint_tag_validation_enabled = validation_mode != ValidationMode.IGNORE - self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL - - self.aio_config = get_aio_config(param_dict) - - self.dataloader_drop_last = get_dataloader_drop_last(param_dict) - - def _batch_assertion(self): - - train_batch = self.train_batch_size - micro_batch = self.train_micro_batch_size_per_gpu - grad_acc = self.gradient_accumulation_steps - - assert train_batch > 0, \ - f'Train batch size: {train_batch} has to be greater than 0' - - assert micro_batch > 0, \ - f'Micro batch size per gpu: {micro_batch} has to be greater than 0' - - assert grad_acc > 0, \ - f'Gradient accumulation steps: {grad_acc} has to be greater than 0' - - assert train_batch == micro_batch * grad_acc * self.world_size, \ - (f'Check batch related parameters. train_batch_size is not equal' - ' to micro_batch_per_gpu * gradient_acc_step * world_size' - f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}') - - def _set_batch_related_parameters(self): - - train_batch = self.train_batch_size - micro_batch = self.train_micro_batch_size_per_gpu - grad_acc = self.gradient_accumulation_steps - - #all values are provided nothing needs to be set - if train_batch is not None and \ - micro_batch is not None and \ - grad_acc is not None: - return - - #global_accumulation_steps needs to be set - elif train_batch is not None and \ - micro_batch is not None: - grad_acc = train_batch // micro_batch - grad_acc //= self.world_size - self.gradient_accumulation_steps = grad_acc - - #micro_batch_per_gpu needs to be set - elif train_batch is not None and \ - grad_acc is not None: - micro_batch = train_batch // self.world_size - micro_batch //= grad_acc - self.train_micro_batch_size_per_gpu = micro_batch - - #train_batch_size needs to be set - elif micro_batch is not None and \ - grad_acc is not None: - train_batch_size = micro_batch * grad_acc - train_batch_size *= self.world_size - self.train_batch_size = train_batch_size - - #gradient_accumulation_steps and micro_batch_per_gpus is set - elif train_batch is not None: - self.gradient_accumulation_steps = 1 - self.train_micro_batch_size_per_gpu = train_batch // self.world_size - - #train_batch_size and gradient_accumulation_step is set - elif micro_batch is not None: - self.train_batch_size = micro_batch * self.world_size - self.gradient_accumulation_steps = 1 - - #either none of the three parameters are provided or just gradient_accumulation_step is provided - else: - assert False, \ - 'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided' - - def _configure_train_batch_size(self): - self._set_batch_related_parameters() - self._batch_assertion() - - def _do_sanity_check(self): - self._do_error_check() - - self._do_warning_check() - - def print(self, name): - logger.info('{}:'.format(name)) - for arg in sorted(vars(self)): - if arg != '_param_dict': - dots = '.' * (29 - len(arg)) - logger.info(' {} {} {}'.format(arg, dots, getattr(self, arg))) - - logger.info(' json = {}'.format( - json.dumps(self._param_dict, - sort_keys=True, - indent=4, - cls=ScientificNotationEncoder, - separators=(',', - ':')))) - - def _do_error_check(self): - assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU) - - assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format( - GRADIENT_ACCUMULATION_STEPS) - - if self.zero_enabled: - assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION) - - if self.fp16_master_weights_and_gradients: - assert self.zero_enabled and self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now." - - def _do_warning_check(self): - fp16_enabled = self.fp16_enabled - - vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT) - if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0: - logger.warning( - "DeepSpeedConfig: vocabulary size {} is not aligned to {}, may import tensor core utilization." - .format(vocabulary_size, - TENSOR_CORE_ALIGN_SIZE)) - - if self.optimizer_params is not None and \ - MAX_GRAD_NORM in self.optimizer_params.keys() and \ - self.optimizer_params[MAX_GRAD_NORM] > 0: - if fp16_enabled: - if self.global_rank == 0: - logger.warning( - 'DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper' - .format(MAX_GRAD_NORM, - self.optimizer_params[MAX_GRAD_NORM])) - else: - if self.global_rank == 0: - logger.warning( - 'DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero' - .format(self.optimizer_params[MAX_GRAD_NORM])) - self.optimizer_params[MAX_GRAD_NORM] = 0.0 +""" +Copyright (c) Microsoft Corporation +Licensed under the MIT license. +""" +import os +from typing import Union + +import torch +import json +import copy + +from .constants import * +from .fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, DELAYED_SHIFT, MIN_LOSS_SCALE +from .config_utils import get_scalar_param, dict_raise_error_on_duplicate_keys, ScientificNotationEncoder +from .zero.config import DeepSpeedZeroConfig +from .zero.constants import * +from .activation_checkpointing.config import DeepSpeedActivationCheckpointingConfig + +from ..git_version_info import version as __version__ +from ..utils import logger + +from ..elasticity import elasticity_enabled, compute_elastic_config, ensure_immutable_elastic_config +from ..elasticity.config import ElasticityConfigError +from ..elasticity.constants import ELASTICITY, IGNORE_NON_ELASTIC_BATCH_INFO, \ + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT + +from ..profiling.config import DeepSpeedFlopsProfilerConfig + +from .swap_tensor.aio_config import get_aio_config + +TENSOR_CORE_ALIGN_SIZE = 8 + +ADAGRAD_OPTIMIZER = 'adagrad' +ADAM_OPTIMIZER = 'adam' +ADAMW_OPTIMIZER = 'adamw' +LAMB_OPTIMIZER = 'lamb' +ONEBIT_ADAM_OPTIMIZER = 'onebitadam' +ONEBIT_LAMB_OPTIMIZER = 'onebitlamb' +DEEPSPEED_OPTIMIZERS = [ + ADAGRAD_OPTIMIZER, + ADAM_OPTIMIZER, + ADAMW_OPTIMIZER, + LAMB_OPTIMIZER, + ONEBIT_ADAM_OPTIMIZER, + ONEBIT_LAMB_OPTIMIZER, +] + +# extra optimizer parameters for adam/adamw +TORCH_ADAM_PARAM = "torch_adam" + +# default to adamw logic for adam/adamw optimizers unless user explicitly opts out +ADAM_W_MODE = "adam_w_mode" +ADAM_W_MODE_DEFAULT = True + + +class DeepSpeedConfigError(Exception): + pass + + +def get_curriculum_enabled(param_dict): + if CURRICULUM_LEARNING in param_dict.keys(): + return get_scalar_param(param_dict[CURRICULUM_LEARNING], + CURRICULUM_ENABLED, + CURRICULUM_ENABLED_DEFAULT) + else: + return False + + +def get_curriculum_params(param_dict): + if CURRICULUM_LEARNING in param_dict.keys(): + curriculum_params = copy.copy(param_dict[CURRICULUM_LEARNING]) + curriculum_params.pop(CURRICULUM_ENABLED) + return curriculum_params + else: + return False + + +def get_pld_enabled(param_dict): + if PROGRESSIVE_LAYER_DROP in param_dict.keys(): + return get_scalar_param(param_dict[PROGRESSIVE_LAYER_DROP], + PLD_ENABLED, + PLD_ENABLED_DEFAULT) + else: + return False + + +def get_pld_params(param_dict): + if PROGRESSIVE_LAYER_DROP in param_dict.keys(): + pld_params = copy.copy(param_dict[PROGRESSIVE_LAYER_DROP]) + pld_params.pop(PLD_ENABLED) + return pld_params + else: + return False + + +def get_amp_enabled(param_dict): + if AMP in param_dict.keys(): + return get_scalar_param(param_dict[AMP], AMP_ENABLED, AMP_ENABLED_DEFAULT) + else: + return False + + +def get_amp_params(param_dict): + if AMP in param_dict.keys(): + amp_params = copy.copy(param_dict[AMP]) + amp_params.pop(AMP_ENABLED) + return amp_params + else: + return False + + +def get_fp16_enabled(param_dict): + if FP16 in param_dict.keys(): + return get_scalar_param(param_dict[FP16], FP16_ENABLED, FP16_ENABLED_DEFAULT) + else: + return False + + +def get_bfloat16_enabled(param_dict): + if BFLOAT16 in param_dict.keys(): + return get_scalar_param(param_dict[BFLOAT16], + BFLOAT16_ENABLED, + BFLOAT16_ENABLED_DEFAULT) + else: + return False + + +def get_fp16_master_weights_and_grads_enabled(param_dict): + if get_fp16_enabled(param_dict): + return get_scalar_param(param_dict[FP16], + FP16_MASTER_WEIGHTS_AND_GRADS, + FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT) + else: + return False + + +def get_loss_scale(param_dict): + if get_fp16_enabled(param_dict): + return get_scalar_param(param_dict[FP16], + FP16_LOSS_SCALE, + FP16_LOSS_SCALE_DEFAULT) + elif get_bfloat16_enabled(param_dict): + return 1.0 + else: + return FP16_LOSS_SCALE_DEFAULT + + +def get_initial_dynamic_scale(param_dict): + if get_fp16_enabled(param_dict): + initial_scale_power = get_scalar_param(param_dict[FP16], + FP16_INITIAL_SCALE_POWER, + FP16_INITIAL_SCALE_POWER_DEFAULT) + elif get_bfloat16_enabled(param_dict): + initial_scale_power = 0 + else: + initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT + + return 2**initial_scale_power + + +def get_dynamic_loss_scale_args(param_dict): + loss_scale_args = None + if get_fp16_enabled(param_dict): + fp16_dict = param_dict[FP16] + dynamic_loss_args = [ + FP16_INITIAL_SCALE_POWER, + FP16_LOSS_SCALE_WINDOW, + FP16_MIN_LOSS_SCALE, + FP16_HYSTERESIS + ] + if any(arg in list(fp16_dict.keys()) for arg in dynamic_loss_args): + init_scale = get_scalar_param(fp16_dict, + FP16_INITIAL_SCALE_POWER, + FP16_INITIAL_SCALE_POWER_DEFAULT) + scale_window = get_scalar_param(fp16_dict, + FP16_LOSS_SCALE_WINDOW, + FP16_LOSS_SCALE_WINDOW_DEFAULT) + delayed_shift = get_scalar_param(fp16_dict, + FP16_HYSTERESIS, + FP16_HYSTERESIS_DEFAULT) + min_loss_scale = get_scalar_param(fp16_dict, + FP16_MIN_LOSS_SCALE, + FP16_MIN_LOSS_SCALE_DEFAULT) + loss_scale_args = { + INITIAL_LOSS_SCALE: 2**init_scale, + SCALE_WINDOW: scale_window, + DELAYED_SHIFT: delayed_shift, + MIN_LOSS_SCALE: min_loss_scale + } + + return loss_scale_args + + +def get_gradient_accumulation_steps(param_dict): + return get_scalar_param(param_dict, + GRADIENT_ACCUMULATION_STEPS, + GRADIENT_ACCUMULATION_STEPS_DEFAULT) + + +def get_sparse_gradients_enabled(param_dict): + return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT) + + +def get_zero_optimization(param_dict): + return get_scalar_param(param_dict, ZERO_OPTIMIZATION, ZERO_OPTIMIZATION_DEFAULT) + + +def get_zero_reduce_scatter(param_dict): + return get_scalar_param(param_dict, + ZERO_OPTIMIZATION_REDUCE_SCATTER, + ZERO_OPTIMIZATION_REDUCE_SCATTER_DEFAULT) + + +def get_allreduce_always_fp32(param_dict): + return get_scalar_param(param_dict, FP32_ALLREDUCE, FP32_ALLREDUCE_DEFAULT) + + +def get_prescale_gradients(param_dict): + return get_scalar_param(param_dict, PRESCALE_GRADIENTS, PRESCALE_GRADIENTS_DEFAULT) + + +def get_gradient_predivide_factor(param_dict): + return get_scalar_param(param_dict, + GRADIENT_PREDIVIDE_FACTOR, + GRADIENT_PREDIVIDE_FACTOR_DEFAULT) + + +def get_quantize_enabled(param_dict): + if QUANTIZE_TRAINING in param_dict.keys(): + return get_scalar_param(param_dict[QUANTIZE_TRAINING], + QUANTIZE_TRAINING_ENABLED, + QUANTIZE_TRAINING_ENABLED_DEFAULT) + else: + return False + + +def get_quantize_training(param_dict): + if QUANTIZE_TRAINING in param_dict.keys(): + return ((param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS][TARGET_BITS]), \ + (param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS][START_BITS] if START_BITS in param_dict[QUANTIZE_TRAINING][QUANTIZE_BITS].keys() else QUANTIZE_START_BITS_DEFAULT), \ + (param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE][QUANTIZE_PERIOD] if QUANTIZE_SCHEDULE in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_PERIOD_DEFAULT), \ + (param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE][SCHEDULE_OFFSET] if QUANTIZE_SCHEDULE in param_dict[QUANTIZE_TRAINING].keys() and SCHEDULE_OFFSET in param_dict[QUANTIZE_TRAINING][QUANTIZE_SCHEDULE].keys() else QUANTIZE_OFFSET_DEFAULT), \ + (param_dict[QUANTIZE_TRAINING][QUANTIZE_GROUPS] if QUANTIZE_GROUPS in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_GROUPS_DEFAULT), \ + (param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE][FP16_MIXED_QUANTIZE_ENABLED] if FP16_MIXED_QUANTIZE in param_dict[QUANTIZE_TRAINING].keys() and FP16_MIXED_QUANTIZE_ENABLED in param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE].keys() else FP16_MIXED_QUANTIZE_ENABLED_DEFAULT), \ + (param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE][QUANTIZE_CHANGE_RATIO] if FP16_MIXED_QUANTIZE in param_dict[QUANTIZE_TRAINING].keys() and QUANTIZE_CHANGE_RATIO in param_dict[QUANTIZE_TRAINING][FP16_MIXED_QUANTIZE].keys() else QUANTIZE_CHANGE_RATIO_DEFAULT), \ + (1 if QUANTIZE_ALGO in param_dict[QUANTIZE_TRAINING] and QUANTIZE_TYPE in param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO].keys() and param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO][QUANTIZE_TYPE] == QUANTIZE_ASYMMETRIC else QUANTIZE_TYPE_DEFAULT), \ + (1 if QUANTIZE_ALGO in param_dict[QUANTIZE_TRAINING] and QUANTIZE_ROUNDING in param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO].keys() and param_dict[QUANTIZE_TRAINING][QUANTIZE_ALGO][QUANTIZE_ROUNDING] == STOCHASTIC_ROUNDING else QUANTIZE_ROUNDING_DEFAULT), \ + (param_dict[QUANTIZE_TRAINING][QUANTIZE_VERBOSE] if QUANTIZE_VERBOSE in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZE_VERBOSE_DEFAULT), \ + (param_dict[QUANTIZE_TRAINING][QUANTIZER_KERNEL] if QUANTIZER_KERNEL in param_dict[QUANTIZE_TRAINING].keys() else QUANTIZER_KERNEL_DEFAULT)) + else: + return (QUANTIZE_TARGET_BITS_DEFAULT, \ + QUANTIZE_START_BITS_DEFAULT, \ + QUANTIZE_PERIOD_DEFAULT, \ + QUANTIZE_OFFSET_DEFAULT, \ + QUANTIZE_GROUPS_DEFAULT, \ + FP16_MIXED_QUANTIZE_ENABLED_DEFAULT, \ + QUANTIZE_CHANGE_RATIO_DEFAULT, \ + QUANTIZE_TYPE_DEFAULT, \ + QUANTIZE_ROUNDING_DEFAULT, \ + QUANTIZE_VERBOSE_DEFAULT, \ + QUANTIZER_KERNEL_DEFAULT) + + +def get_steps_per_print(param_dict): + return get_scalar_param(param_dict, STEPS_PER_PRINT, STEPS_PER_PRINT_DEFAULT) + + +def get_disable_allgather(param_dict): + return get_scalar_param(param_dict, DISABLE_ALLGATHER, DISABLE_ALLGATHER_DEFAULT) + + +def get_dump_state(param_dict): + return get_scalar_param(param_dict, DUMP_STATE, DUMP_STATE_DEFAULT) + + +def get_gradient_clipping(param_dict): + return get_scalar_param(param_dict, GRADIENT_CLIPPING, GRADIENT_CLIPPING_DEFAULT) + + +def get_sparse_attention(param_dict): + if SPARSE_ATTENTION in param_dict.keys(): + sparsity = param_dict[SPARSE_ATTENTION] + mode = get_sparse_attention_mode(sparsity) + + if (mode == SPARSE_DENSE_MODE): + return get_sparse_dense_config(sparsity) + elif (mode == SPARSE_FIXED_MODE): + return get_sparse_fixed_config(sparsity) + elif (mode == SPARSE_VARIABLE_MODE): + return get_sparse_variable_config(sparsity) + elif (mode == SPARSE_BIGBIRD_MODE): + return get_sparse_bigbird_config(sparsity) + elif (mode == SPARSE_BSLONGFORMER_MODE): + return get_sparse_bslongformer_config(sparsity) + else: + raise NotImplementedError( + f'Given sparsity mode, {mode}, has not been implemented yet!') + + else: + return None + + +def get_sparse_dense_config(sparsity): + block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) + return {SPARSE_MODE: SPARSE_DENSE_MODE, SPARSE_BLOCK: block} + + +def get_sparse_fixed_config(sparsity): + block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) + different_layout_per_head = get_scalar_param( + sparsity, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT) + num_local_blocks = get_scalar_param(sparsity, + SPARSE_NUM_LOCAL_BLOCKS, + SPARSE_NUM_LOCAL_BLOCKS_DEFAULT) + num_global_blocks = get_scalar_param(sparsity, + SPARSE_NUM_GLOBAL_BLOCKS, + SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT) + attention = get_scalar_param(sparsity, + SPARSE_ATTENTION_TYPE, + SPARSE_ATTENTION_TYPE_DEFAULT) + horizontal_global_attention = get_scalar_param( + sparsity, + SPARSE_HORIZONTAL_GLOBAL_ATTENTION, + SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT) + num_different_global_patterns = get_scalar_param( + sparsity, + SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS, + SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS_DEFAULT) + + return { + SPARSE_MODE: SPARSE_FIXED_MODE, + SPARSE_BLOCK: block, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head, + SPARSE_NUM_LOCAL_BLOCKS: num_local_blocks, + SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks, + SPARSE_ATTENTION_TYPE: attention, + SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention, + SPARSE_NUM_DIFFERENT_GLOBAL_PATTERNS: num_different_global_patterns + } + + +def get_sparse_variable_config(sparsity): + block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) + different_layout_per_head = get_scalar_param( + sparsity, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT) + num_random_blocks = get_scalar_param(sparsity, + SPARSE_NUM_RANDOM_BLOCKS, + SPARSE_NUM_RANDOM_BLOCKS_DEFAULT) + local_window_blocks = get_scalar_param(sparsity, + SPARSE_LOCAL_WINDOW_BLOCKS, + SPARSE_LOCAL_WINDOW_BLOCKS_DEFAULT) + global_block_indices = get_scalar_param(sparsity, + SPARSE_GLOBAL_BLOCK_INDICES, + SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT) + global_block_end_indices = get_scalar_param(sparsity, + SPARSE_GLOBAL_BLOCK_END_INDICES, + SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT) + attention = get_scalar_param(sparsity, + SPARSE_ATTENTION_TYPE, + SPARSE_ATTENTION_TYPE_DEFAULT) + horizontal_global_attention = get_scalar_param( + sparsity, + SPARSE_HORIZONTAL_GLOBAL_ATTENTION, + SPARSE_HORIZONTAL_GLOBAL_ATTENTION_DEFAULT) + + return { + SPARSE_MODE: SPARSE_VARIABLE_MODE, + SPARSE_BLOCK: block, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head, + SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks, + SPARSE_LOCAL_WINDOW_BLOCKS: local_window_blocks, + SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices, + SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices, + SPARSE_ATTENTION_TYPE: attention, + SPARSE_HORIZONTAL_GLOBAL_ATTENTION: horizontal_global_attention + } + + +def get_sparse_bigbird_config(sparsity): + block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) + different_layout_per_head = get_scalar_param( + sparsity, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT) + num_random_blocks = get_scalar_param(sparsity, + SPARSE_NUM_RANDOM_BLOCKS, + SPARSE_NUM_RANDOM_BLOCKS_DEFAULT) + num_sliding_window_blocks = get_scalar_param( + sparsity, + SPARSE_NUM_SLIDING_WINDOW_BLOCKS, + SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT) + num_global_blocks = get_scalar_param(sparsity, + SPARSE_NUM_GLOBAL_BLOCKS, + SPARSE_NUM_GLOBAL_BLOCKS_DEFAULT) + + return { + SPARSE_MODE: SPARSE_BIGBIRD_MODE, + SPARSE_BLOCK: block, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head, + SPARSE_NUM_RANDOM_BLOCKS: num_random_blocks, + SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks, + SPARSE_NUM_GLOBAL_BLOCKS: num_global_blocks + } + + +def get_sparse_bslongformer_config(sparsity): + block = get_scalar_param(sparsity, SPARSE_BLOCK, SPARSE_BLOCK_DEFAULT) + different_layout_per_head = get_scalar_param( + sparsity, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD_DEFAULT) + num_sliding_window_blocks = get_scalar_param( + sparsity, + SPARSE_NUM_SLIDING_WINDOW_BLOCKS, + SPARSE_NUM_SLIDING_WINDOW_BLOCKS_DEFAULT) + global_block_indices = get_scalar_param(sparsity, + SPARSE_GLOBAL_BLOCK_INDICES, + SPARSE_GLOBAL_BLOCK_INDICES_DEFAULT) + global_block_end_indices = get_scalar_param(sparsity, + SPARSE_GLOBAL_BLOCK_END_INDICES, + SPARSE_GLOBAL_BLOCK_END_INDICES_DEFAULT) + + return { + SPARSE_MODE: SPARSE_BSLONGFORMER_MODE, + SPARSE_BLOCK: block, + SPARSE_DIFFERENT_LAYOUT_PER_HEAD: different_layout_per_head, + SPARSE_NUM_SLIDING_WINDOW_BLOCKS: num_sliding_window_blocks, + SPARSE_GLOBAL_BLOCK_INDICES: global_block_indices, + SPARSE_GLOBAL_BLOCK_END_INDICES: global_block_end_indices + } + + +def get_sparse_attention_mode(param_dict): + if SPARSE_MODE in param_dict.keys(): + return param_dict[SPARSE_MODE] + else: + return SPARSE_MODE_DEFAULT + + +def get_sparse_attention_type(param_dict): + if SPARSE_ATTENTION_TYPE in param_dict.keys(): + return param_dict[SPARSE_ATTENTION_TYPE] + else: + return SPARSE_ATTENTION_TYPE_DEFAULT + + +def get_pipeline_config(param_dict): + '''Parses pipeline engine configuration. ''' + default_pipeline = { + 'stages': 'auto', + 'partition': 'best', + 'seed_layers': False, + 'activation_checkpoint_interval': 0 + } + config = default_pipeline + for key, val in param_dict.get('pipeline', {}).items(): + config[key] = val + return config + + +def get_optimizer_name(param_dict): + if OPTIMIZER in param_dict.keys() and \ + TYPE in param_dict[OPTIMIZER].keys(): + return param_dict[OPTIMIZER][TYPE] + else: + return OPTIMIZER_TYPE_DEFAULT + + +def get_optimizer_params(param_dict): + if get_optimizer_name(param_dict) is not None and \ + OPTIMIZER_PARAMS in param_dict[OPTIMIZER].keys(): + return param_dict[OPTIMIZER][OPTIMIZER_PARAMS] + else: + return None + + +def get_optimizer_gradient_clipping(param_dict): + optimizer_params = get_optimizer_params(param_dict) + if optimizer_params is not None and \ + MAX_GRAD_NORM in optimizer_params.keys(): + return optimizer_params[MAX_GRAD_NORM] + else: + return None + + +def get_optimizer_legacy_fusion(param_dict): + if OPTIMIZER in param_dict.keys() and \ + LEGACY_FUSION in param_dict[OPTIMIZER].keys(): + return param_dict[OPTIMIZER][LEGACY_FUSION] + else: + return LEGACY_FUSION_DEFAULT + + +def get_zero_allow_untested_optimizer(param_dict): + return get_scalar_param(param_dict, + ZERO_ALLOW_UNTESTED_OPTIMIZER, + ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT) + + +def get_scheduler_name(param_dict): + if SCHEDULER in param_dict.keys() and \ + TYPE in param_dict[SCHEDULER].keys(): + return param_dict[SCHEDULER][TYPE] + else: + return SCHEDULER_TYPE_DEFAULT + + +def get_scheduler_params(param_dict): + if get_scheduler_name(param_dict) is not None and \ + SCHEDULER_PARAMS in param_dict[SCHEDULER].keys(): + return param_dict[SCHEDULER][SCHEDULER_PARAMS] + else: + return None + + +def get_train_batch_size(param_dict): + return get_scalar_param(param_dict, TRAIN_BATCH_SIZE, TRAIN_BATCH_SIZE_DEFAULT) + + +def get_train_micro_batch_size_per_gpu(param_dict): + return get_scalar_param(param_dict, + TRAIN_MICRO_BATCH_SIZE_PER_GPU, + TRAIN_MICRO_BATCH_SIZE_PER_GPU_DEFAULT) + + +def get_wall_clock_breakdown(param_dict): + return get_scalar_param(param_dict, + WALL_CLOCK_BREAKDOWN, + WALL_CLOCK_BREAKDOWN_DEFAULT) + + +def get_memory_breakdown(param_dict): + return get_scalar_param(param_dict, MEMORY_BREAKDOWN, MEMORY_BREAKDOWN_DEFAULT) + + +def get_tensorboard_enabled(param_dict): + if TENSORBOARD in param_dict.keys(): + return get_scalar_param(param_dict[TENSORBOARD], + TENSORBOARD_ENABLED, + TENSORBOARD_ENABLED_DEFAULT) + else: + return False + + +def get_eigenvalue_config(param_dict): + if get_quantize_enabled(param_dict): + param_dict = param_dict[QUANTIZE_TRAINING] + return (get_eigenvalue_enabled(param_dict), \ + get_eigenvalue_verbose(param_dict), \ + get_eigenvalue_max_iter(param_dict), \ + get_eigenvalue_tol(param_dict), \ + get_eigenvalue_stability(param_dict), \ + get_eigenvalue_gas_boundary_resolution(param_dict), \ + get_eigenvalue_layer_name(param_dict), \ + get_eigenvalue_layer_num(param_dict)) + else: + return (EIGENVALUE_ENABLED_DEFAULT, \ + EIGENVALUE_VERBOSE_DEFAULT, \ + EIGENVALUE_MAX_ITER_DEFAULT, \ + EIGENVALUE_TOL_DEFAULT, \ + EIGENVALUE_STABILITY_DEFAULT, \ + EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT, \ + EIGENVALUE_LAYER_NAME_DEFAULT, \ + EIGENVALUE_LAYER_NUM_DEFAULT) + + +def get_eigenvalue_enabled(param_dict): + if EIGENVALUE in param_dict.keys(): + return get_scalar_param(param_dict[EIGENVALUE], + EIGENVALUE_ENABLED, + EIGENVALUE_ENABLED_DEFAULT) + else: + return EIGENVALUE_ENABLED_DEFAULT + + +def get_eigenvalue_verbose(param_dict): + if EIGENVALUE in param_dict.keys(): + return get_scalar_param(param_dict[EIGENVALUE], + EIGENVALUE_VERBOSE, + EIGENVALUE_VERBOSE_DEFAULT) + else: + return EIGENVALUE_VERBOSE_DEFAULT + + +def get_eigenvalue_max_iter(param_dict): + if EIGENVALUE in param_dict.keys(): + return get_scalar_param(param_dict[EIGENVALUE], + EIGENVALUE_MAX_ITER, + EIGENVALUE_MAX_ITER_DEFAULT) + else: + return EIGENVALUE_MAX_ITER_DEFAULT + + +def get_eigenvalue_tol(param_dict): + if EIGENVALUE in param_dict.keys(): + return get_scalar_param(param_dict[EIGENVALUE], + EIGENVALUE_TOL, + EIGENVALUE_TOL_DEFAULT) + else: + return EIGENVALUE_TOL_DEFAULT + + +def get_eigenvalue_stability(param_dict): + if EIGENVALUE in param_dict.keys(): + return get_scalar_param(param_dict[EIGENVALUE], + EIGENVALUE_STABILITY, + EIGENVALUE_STABILITY_DEFAULT) + else: + return EIGENVALUE_STABILITY_DEFAULT + + +def get_eigenvalue_gas_boundary_resolution(param_dict): + if EIGENVALUE in param_dict.keys(): + return get_scalar_param(param_dict[EIGENVALUE], + EIGENVALUE_GAS_BOUNDARY_RESOLUTION, + EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT) + else: + return EIGENVALUE_GAS_BOUNDARY_RESOLUTION_DEFAULT + + +def get_eigenvalue_layer_name(param_dict): + if EIGENVALUE in param_dict.keys(): + return get_scalar_param(param_dict[EIGENVALUE], + EIGENVALUE_LAYER_NAME, + EIGENVALUE_LAYER_NAME_DEFAULT) + else: + return EIGENVALUE_LAYER_NAME_DEFAULT + + +def get_eigenvalue_layer_num(param_dict): + if EIGENVALUE in param_dict.keys(): + return get_scalar_param(param_dict[EIGENVALUE], + EIGENVALUE_LAYER_NUM, + EIGENVALUE_LAYER_NUM_DEFAULT) + else: + return EIGENVALUE_LAYER_NUM_DEFAULT + + +def get_tensorboard_output_path(param_dict): + if get_tensorboard_enabled(param_dict): + return get_scalar_param(param_dict[TENSORBOARD], + TENSORBOARD_OUTPUT_PATH, + TENSORBOARD_OUTPUT_PATH_DEFAULT) + else: + return TENSORBOARD_OUTPUT_PATH_DEFAULT + + +def get_tensorboard_job_name(param_dict): + if get_tensorboard_enabled(param_dict): + return get_scalar_param(param_dict[TENSORBOARD], + TENSORBOARD_JOB_NAME, + TENSORBOARD_JOB_NAME_DEFAULT) + else: + return TENSORBOARD_JOB_NAME_DEFAULT + + +def get_checkpoint_params(param_dict): + return param_dict.get(CHECKPOINT, {}) + + +def get_checkpoint_tag_validation_mode(checkpoint_params): + tag_validation_mode = checkpoint_params.get(CHECKPOINT_TAG_VALIDATION, + CHECKPOINT_TAG_VALIDATION_DEFAULT) + tag_validation_mode = tag_validation_mode.upper() + if tag_validation_mode in CHECKPOINT_TAG_VALIDATION_MODES: + return tag_validation_mode + else: + raise DeepSpeedConfigError("Checkpoint config contains invalid tag_validation " \ + f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}") + + +def get_dataloader_drop_last(param_dict): + return get_scalar_param(param_dict, + DATALOADER_DROP_LAST, + DATALOADER_DROP_LAST_DEFAULT) + + +'''Write deepspeed config files by modifying basic templates. +Can be used for quickly changing parameters via command line parameters.''' + + +class DeepSpeedConfigWriter: + def __init__(self, data=None): + self.data = data if data is not None else {} + + def add_config(self, key, value): + self.data[key] = value + + def load_config(self, filename): + self.data = json.load(open(filename, + 'r'), + object_pairs_hook=dict_raise_error_on_duplicate_keys) + + def write_config(self, filename): + with open(filename, 'w') as outfile: + json.dump(self.data, outfile) + + +class DeepSpeedConfig(object): + def __init__(self, config: Union[str, dict], mpu=None): + super(DeepSpeedConfig, self).__init__() + if isinstance(config, dict): + self._param_dict = config + elif os.path.exists(config): + self._param_dict = json.load( + open(config, + 'r'), + object_pairs_hook=dict_raise_error_on_duplicate_keys) + else: + raise ValueError( + f"Expected a string path to an existing deepspeed config, or a dictionary. Received: {config}" + ) + try: + self.global_rank = torch.distributed.get_rank() + if mpu is None: + self.world_size = torch.distributed.get_world_size() + else: + self.world_size = mpu.get_data_parallel_world_size() + except: + self.global_rank = 0 + self.world_size = 1 + + # If elastic-mode enabled, update compute + update _param_dict + self.elasticity_enabled = elasticity_enabled(self._param_dict) + if self.elasticity_enabled: + logger.info("DeepSpeed elasticity support enabled") + final_batch_size, valid_gpus, micro_batch_size = compute_elastic_config( + ds_config=self._param_dict, + target_deepspeed_version=__version__, + world_size=self.world_size) + + elastic_dict = self._param_dict[ELASTICITY] + + # Ensure the resource scheduler saw the same elastic config we are using at runtime + ensure_immutable_elastic_config(runtime_elastic_config_dict=elastic_dict) + + ignore_non_elastic_batch_info = elastic_dict.get( + IGNORE_NON_ELASTIC_BATCH_INFO, + IGNORE_NON_ELASTIC_BATCH_INFO_DEFAULT) + + if not ignore_non_elastic_batch_info: + batch_params = [ + TRAIN_BATCH_SIZE, + TRAIN_MICRO_BATCH_SIZE_PER_GPU, + GRADIENT_ACCUMULATION_STEPS + ] + if any(map(lambda t: t in self._param_dict, batch_params)): + raise ElasticityConfigError("One or more batch related parameters were found in your " \ + f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \ + f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \ + "elastic training is enabled, which takes control of these parameters. " \ + "If you want to suppress this error (the parameters will be silently ignored) " \ + f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.") + + # micro_bsz * world_size * gas = total_batch_size + # gas = total_batch_size // (micro_bsz * world_size) + gradient_accu_steps = final_batch_size // (micro_batch_size * + self.world_size) + + if TRAIN_BATCH_SIZE in self._param_dict: + logger.warning("[Elasticity] overriding training_batch_size: " \ + f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}") + if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict: + logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: " \ + f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}") + if GRADIENT_ACCUMULATION_STEPS in self._param_dict: + logger.warning("[Elasticity] overriding gradient_accumulation_steps: "\ + f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}") + + logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}") + + self._param_dict[TRAIN_BATCH_SIZE] = final_batch_size + self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU] = micro_batch_size + self._param_dict[GRADIENT_ACCUMULATION_STEPS] = gradient_accu_steps + + self._initialize_params(self._param_dict) + self._configure_train_batch_size() + self._do_sanity_check() + + def _initialize_params(self, param_dict): + self.train_batch_size = get_train_batch_size(param_dict) + self.train_micro_batch_size_per_gpu = get_train_micro_batch_size_per_gpu( + param_dict) + self.gradient_accumulation_steps = get_gradient_accumulation_steps(param_dict) + self.steps_per_print = get_steps_per_print(param_dict) + self.dump_state = get_dump_state(param_dict) + + self.disable_allgather = get_disable_allgather(param_dict) + self.allreduce_always_fp32 = get_allreduce_always_fp32(param_dict) + self.prescale_gradients = get_prescale_gradients(param_dict) + self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict) + self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict) + + self.zero_config = DeepSpeedZeroConfig(param_dict) + self.zero_optimization_stage = self.zero_config.stage + self.zero_enabled = self.zero_optimization_stage > 0 + + self.activation_checkpointing_config = DeepSpeedActivationCheckpointingConfig( + param_dict) + + self.gradient_clipping = get_gradient_clipping(param_dict) + self.fp16_enabled = get_fp16_enabled(param_dict) + self.bfloat16_enabled = get_bfloat16_enabled(param_dict) + self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled( + param_dict) + self.bfloat16_enabled = get_bfloat16_enabled(param_dict) + assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' + assert not (self.bfloat16_enabled and (self.zero_optimization_stage not in {2, 3})), f"bfloat16 mode is only enabled for ZeRO 2 and 3 currently, got {self.zero_optimization_stage}" + self.amp_enabled = get_amp_enabled(param_dict) + self.amp_params = get_amp_params(param_dict) + self.loss_scale = get_loss_scale(param_dict) + self.initial_dynamic_scale = get_initial_dynamic_scale(param_dict) + self.dynamic_loss_scale_args = get_dynamic_loss_scale_args(param_dict) + + self.quantize_training_enabled = get_quantize_enabled(param_dict) + self.quantize_target_bits, \ + self.quantize_start_bits, \ + self.quantize_period, \ + self.quantize_offset, \ + self.quantize_groups, \ + self.fp16_mixed_quantize, \ + self.quantize_change_rate, \ + self.quantize_type, \ + self.quantize_rounding, \ + self.quantize_verbose, \ + self.use_quantizer_kernel = get_quantize_training(param_dict) + + self.optimizer_name = get_optimizer_name(param_dict) + if self.optimizer_name is not None and \ + self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS: + self.optimizer_name = self.optimizer_name.lower() + + self.optimizer_params = get_optimizer_params(param_dict) + self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict) + + self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer( + param_dict) + + self.scheduler_name = get_scheduler_name(param_dict) + self.scheduler_params = get_scheduler_params(param_dict) + + self.flops_profiler_config = DeepSpeedFlopsProfilerConfig(param_dict) + self.wall_clock_breakdown = get_wall_clock_breakdown( + param_dict) | self.flops_profiler_config.enabled + self.memory_breakdown = get_memory_breakdown(param_dict) + self.tensorboard_enabled = get_tensorboard_enabled(param_dict) + self.tensorboard_output_path = get_tensorboard_output_path(param_dict) + self.tensorboard_job_name = get_tensorboard_job_name(param_dict) + + self.eigenvalue_enabled, \ + self.eigenvalue_verbose, \ + self.eigenvalue_max_iter, \ + self.eigenvalue_tol, \ + self.eigenvalue_stability, \ + self.eigenvalue_gas_boundary_resolution, \ + self.eigenvalue_layer_name, \ + self.eigenvalue_layer_num = get_eigenvalue_config(param_dict) + + self.sparse_attention = get_sparse_attention(param_dict) + self.pipeline = get_pipeline_config(param_dict) + + self.pld_enabled = get_pld_enabled(param_dict) + self.pld_params = get_pld_params(param_dict) + + self.curriculum_enabled = get_curriculum_enabled(param_dict) + self.curriculum_params = get_curriculum_params(param_dict) + + checkpoint_params = get_checkpoint_params(param_dict) + validation_mode = get_checkpoint_tag_validation_mode(checkpoint_params) + self.checkpoint_tag_validation_enabled = validation_mode != ValidationMode.IGNORE + self.checkpoint_tag_validation_fail = validation_mode == ValidationMode.FAIL + + self.aio_config = get_aio_config(param_dict) + + self.dataloader_drop_last = get_dataloader_drop_last(param_dict) + + def _batch_assertion(self): + + train_batch = self.train_batch_size + micro_batch = self.train_micro_batch_size_per_gpu + grad_acc = self.gradient_accumulation_steps + + assert train_batch > 0, \ + f'Train batch size: {train_batch} has to be greater than 0' + + assert micro_batch > 0, \ + f'Micro batch size per gpu: {micro_batch} has to be greater than 0' + + assert grad_acc > 0, \ + f'Gradient accumulation steps: {grad_acc} has to be greater than 0' + + assert train_batch == micro_batch * grad_acc * self.world_size, \ + (f'Check batch related parameters. train_batch_size is not equal' + ' to micro_batch_per_gpu * gradient_acc_step * world_size' + f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}') + + def _set_batch_related_parameters(self): + + train_batch = self.train_batch_size + micro_batch = self.train_micro_batch_size_per_gpu + grad_acc = self.gradient_accumulation_steps + + #all values are provided nothing needs to be set + if train_batch is not None and \ + micro_batch is not None and \ + grad_acc is not None: + return + + #global_accumulation_steps needs to be set + elif train_batch is not None and \ + micro_batch is not None: + grad_acc = train_batch // micro_batch + grad_acc //= self.world_size + self.gradient_accumulation_steps = grad_acc + + #micro_batch_per_gpu needs to be set + elif train_batch is not None and \ + grad_acc is not None: + micro_batch = train_batch // self.world_size + micro_batch //= grad_acc + self.train_micro_batch_size_per_gpu = micro_batch + + #train_batch_size needs to be set + elif micro_batch is not None and \ + grad_acc is not None: + train_batch_size = micro_batch * grad_acc + train_batch_size *= self.world_size + self.train_batch_size = train_batch_size + + #gradient_accumulation_steps and micro_batch_per_gpus is set + elif train_batch is not None: + self.gradient_accumulation_steps = 1 + self.train_micro_batch_size_per_gpu = train_batch // self.world_size + + #train_batch_size and gradient_accumulation_step is set + elif micro_batch is not None: + self.train_batch_size = micro_batch * self.world_size + self.gradient_accumulation_steps = 1 + + #either none of the three parameters are provided or just gradient_accumulation_step is provided + else: + assert False, \ + 'Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided' + + def _configure_train_batch_size(self): + self._set_batch_related_parameters() + self._batch_assertion() + + def _do_sanity_check(self): + self._do_error_check() + + self._do_warning_check() + + def print(self, name): + logger.info('{}:'.format(name)) + for arg in sorted(vars(self)): + if arg != '_param_dict': + dots = '.' * (29 - len(arg)) + logger.info(' {} {} {}'.format(arg, dots, getattr(self, arg))) + + logger.info(' json = {}'.format( + json.dumps(self._param_dict, + sort_keys=True, + indent=4, + cls=ScientificNotationEncoder, + separators=(',', + ':')))) + + def _do_error_check(self): + assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU) + + assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format( + GRADIENT_ACCUMULATION_STEPS) + + if self.zero_enabled: + assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION) + + if self.fp16_master_weights_and_gradients: + assert self.zero_enabled and self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now." + + def _do_warning_check(self): + fp16_enabled = self.fp16_enabled + + vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT) + if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0: + logger.warning( + "DeepSpeedConfig: vocabulary size {} is not aligned to {}, may import tensor core utilization." + .format(vocabulary_size, + TENSOR_CORE_ALIGN_SIZE)) + + if self.optimizer_params is not None and \ + MAX_GRAD_NORM in self.optimizer_params.keys() and \ + self.optimizer_params[MAX_GRAD_NORM] > 0: + if fp16_enabled: + if self.global_rank == 0: + logger.warning( + 'DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper' + .format(MAX_GRAD_NORM, + self.optimizer_params[MAX_GRAD_NORM])) + else: + if self.global_rank == 0: + logger.warning( + 'DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero' + .format(self.optimizer_params[MAX_GRAD_NORM])) + self.optimizer_params[MAX_GRAD_NORM] = 0.0 From 16f1d21d96979c4e669ec1bf83b8c1e1905615e8 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Thu, 18 Nov 2021 15:35:15 -0800 Subject: [PATCH 38/59] align new line with master --- tests/unit/test_zero.py | 2374 +++++++++++++++++++-------------------- 1 file changed, 1187 insertions(+), 1187 deletions(-) diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index 0d4053acc9ab..132ec4ab23d2 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -1,1187 +1,1187 @@ -import math -from typing import Dict, List, Set -import pytest -import torch.distributed as dist -import torch -from torch import Tensor -from torch.nn import Linear, Module -from torch.nn.modules.container import ModuleList -from torch.nn.modules.loss import L1Loss -from torch.nn.parameter import Parameter - -from .common import distributed_test -from .simple_model import SimpleModel, random_dataloader, args_from_dict - -import deepspeed -from deepspeed.runtime.engine import DeepSpeedEngine -from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus -from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint - - -def run_unbalanced_gradients(model, data_loader): - def drop_some_gradients(model, iter): - odd_iteration = iter % 2 - for i, p in enumerate(model.parameters()): - p.requires_grad = (i % 2) == odd_iteration - - def enable_grads(model): - for p in model.parameters(): - p.requires_grad = True - - for i, batch in enumerate(data_loader): - drop_some_gradients(model, i + 1) - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - enable_grads(model) - - -def dump_state_dict(model): - if dist.get_rank() == 0: - print("state_dict:") - for name, param in model.named_parameters(): - print(f"{name} {param.data}") - - -@pytest.mark.parametrize('zero_stage', [1, 2, 3]) -def test_zero_unbalanced_gradients(tmpdir, zero_stage): - config_dict = { - "train_micro_batch_size_per_gpu": 2, - "gradient_accumulation_steps": 2, - "steps_per_print": 1, - "zero_optimization": { - "stage": zero_stage - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-3 - } - }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } - } - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 4 - - model = SimpleModel(hidden_dim=hidden_dim) - - @distributed_test(world_size=[1]) - def _test_zero_unbalanced_gradients(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=16, - hidden_dim=hidden_dim, - device=model.device) - - run_unbalanced_gradients(model, data_loader) - - _test_zero_unbalanced_gradients(args=args, model=model, hidden_dim=hidden_dim) - - -# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227 -@pytest.mark.parametrize('zero_stage', [3]) -def test_zero3_repeat_forward_loop(tmpdir, zero_stage): - - # force all params to be partitioned by forcing threshold=0 - config_dict = { - "train_micro_batch_size_per_gpu": 2, - "gradient_accumulation_steps": 2, - "steps_per_print": 1, - "zero_optimization": { - "stage": zero_stage, - "stage3_param_persistence_threshold": 0 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-3 - } - }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } - } - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 4 - - class AlbertLikeModel(torch.nn.Module): - def __init__(self, hidden_dim): - super().__init__() - self.linear = torch.nn.Linear(hidden_dim, hidden_dim) - self.cross_entropy_loss = torch.nn.CrossEntropyLoss() - - def forward(self, x, y): - # run the same layer multiple times in a loop - to test a stack of forwards, followed by a stack of backwards - hidden = x - for i in range(3): - hidden = hidden + self.linear(hidden) - return self.cross_entropy_loss(hidden, y) - - model = AlbertLikeModel(hidden_dim=hidden_dim) - - @distributed_test(world_size=[1]) - def _test_zero3_repeat_forward_loop(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=16, - hidden_dim=hidden_dim, - device=model.device) - - for i, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - _test_zero3_repeat_forward_loop(args=args, model=model, hidden_dim=hidden_dim) - - -# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227 -# also reproduces the https://github.com/microsoft/DeepSpeed/pull/1372 -@pytest.mark.parametrize('zero_stage', [2, 3]) -def test_zero_to_fp32_1_param_group(tmpdir, zero_stage): - - # XXX: ideally refactor with the 2_param_group test as 75% is the same - - # force all params to be partitioned by forcing threshold=0 - config_dict = { - "train_micro_batch_size_per_gpu": 2, - "gradient_accumulation_steps": 2, - "steps_per_print": 1, - "zero_optimization": { - "stage": zero_stage, - "stage3_param_persistence_threshold": 0 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-3 - } - }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } - } - - @distributed_test(world_size=[2]) - def _test_zero_to_fp32(): - class MyModel(torch.nn.Module): - def __init__(self, hidden_dim, n_layers): - super().__init__() - # to reproduce https://github.com/microsoft/DeepSpeed/pull/1372 it is important that - # the number of total elements is uneven: - # (1) 4 layers of 3*(3+1)=12 elements each, 48 in total - self.ll = torch.nn.ModuleList( - torch.nn.Linear(hidden_dim, - hidden_dim) for i in range(n_layers)) - # (2) the following adds 4+1=5 elements - self.classifier = torch.nn.Linear(4, 1) - # total 48+5=53 (uneven as desired) elements - self.cross_entropy_loss = torch.nn.CrossEntropyLoss() - - def forward(self, x, y): - hidden = x - for l in self.ll: - hidden = l(hidden) - return self.cross_entropy_loss(hidden, y) - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 3 # do not change - - world_size = dist.get_world_size() - # we want at least 2x layers as there are gpus to trigger round_robin_fp16_groups reshuffle in zero2 - n_layers = world_size * 2 - model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers) - - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=16, - hidden_dim=hidden_dim, - device=model.device) - - for i, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - model.save_checkpoint(tmpdir) - - # make sure all sides saved it - dist.barrier() - - if zero_stage == 3: - with deepspeed.zero.GatheredParameters(list( - model.module.parameters(recurse=True)), - modifier_rank=None): - pass # this forces gathering the model - - #dump_state_dict(model) - - orig_state_dict = {} - for name, param in model.module.named_parameters(): - orig_state_dict[name] = param.detach().cpu() - - if dist.get_rank() == 0: - fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) - #dump_state_dict(fp32_model) - - fp32_state_dict = fp32_model.state_dict() - for name in orig_state_dict.keys(): - # float() workaround for torch<1.6 - assert torch.allclose(orig_state_dict[name].float(), - fp32_state_dict[name].float()) - - _test_zero_to_fp32() - - -@pytest.mark.parametrize('zero_stage', [2, 3]) -def test_zero_to_fp32_2_param_groups(tmpdir, zero_stage): - - # TODO: - # - need to test with multiple param groups - - # force all params to be partitioned by forcing threshold=0 - config_dict = { - "train_micro_batch_size_per_gpu": 2, - "gradient_accumulation_steps": 2, - "steps_per_print": 1, - "zero_allow_untested_optimizer": 1, - "zero_optimization": { - "stage": zero_stage, - "stage3_param_persistence_threshold": 0 - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-3 - } - }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } - } - - @distributed_test(world_size=[2]) - def _test_zero_to_fp32(): - class MyModel(torch.nn.Module): - def __init__(self, hidden_dim, n_layers): - super().__init__() - self.ll = torch.nn.ModuleList( - torch.nn.Linear(hidden_dim, - hidden_dim) for i in range(n_layers)) - self.cross_entropy_loss = torch.nn.CrossEntropyLoss() - - def forward(self, x, y): - hidden = x - for l in self.ll: - hidden = l(hidden) - return self.cross_entropy_loss(hidden, y) - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 3 - - world_size = dist.get_world_size() - n_layers = world_size * 2 - model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers) - - optim_groups = [ - { - "params": [l.weight for l in model.ll], - "weight_decay": 0.01, - }, - { - "params": [l.bias for l in model.ll], - "weight_decay": 0.0 - }, - ] - optim = torch.optim.SGD(optim_groups, lr=0.1) - - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters(), - optimizer = optim, - ) - data_loader = random_dataloader(model=model, - total_samples=16, - hidden_dim=hidden_dim, - device=model.device) - - for i, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - model.save_checkpoint(tmpdir) - - # make sure all sides saved it - dist.barrier() - - if zero_stage == 3: - with deepspeed.zero.GatheredParameters(list( - model.module.parameters(recurse=True)), - modifier_rank=None): - pass # this forces gathering the model - - #dump_state_dict(model) - - orig_state_dict = {} - for name, param in model.module.named_parameters(): - orig_state_dict[name] = param.detach().cpu() - - if dist.get_rank() == 0: - fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) - #dump_state_dict(fp32_model) - - fp32_state_dict = fp32_model.state_dict() - for name in orig_state_dict.keys(): - # float() workaround for torch<1.6 - assert torch.allclose(orig_state_dict[name].float(), - fp32_state_dict[name].float()) - - _test_zero_to_fp32() - - -@pytest.mark.parametrize('zero_stage, allgather_bucket_size', [(2, 1000), (2, 1001)]) -def test_incorrect_allgather_bucket_size(tmpdir, zero_stage, allgather_bucket_size): - config_dict = { - "train_micro_batch_size_per_gpu": 2, - "gradient_accumulation_steps": 2, - "steps_per_print": 1, - "zero_optimization": { - "stage": zero_stage, - "allgather_bucket_size": allgather_bucket_size - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-3 - } - }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } - } - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 4 - - model = SimpleModel(hidden_dim=hidden_dim) - - @distributed_test(world_size=[1]) - def _test_incorrect_allgather_bucket_size(args, model, hidden_dim): - if allgather_bucket_size % 2 == 0: - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - else: - with pytest.raises(AssertionError) as assertinfo: - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - assert "allgather_bucket_size must be a multiple of nccl_start_alignment_factor" in str( - assertinfo) - - _test_incorrect_allgather_bucket_size(args=args, model=model, hidden_dim=hidden_dim) - - -@pytest.mark.parametrize('zero_stage, world_size', [(2, 2), (2, 3), (2, 4)]) -def test_partition_nccl_alignment(tmpdir, zero_stage, world_size): - config_dict = { - "train_micro_batch_size_per_gpu": 2, - "gradient_accumulation_steps": 2, - "steps_per_print": 1, - "zero_optimization": { - "stage": zero_stage - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-3 - } - }, - "fp16": { - "enabled": True, - "initial_scale_power": 8 - } - } - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 4 - - model = SimpleModel(hidden_dim=hidden_dim) - - @distributed_test(world_size=world_size) - def _test_partition_nccl_alignment(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - - # get nccl all-gather send buffers alignment factor - nccl_start_alignment_factor = model.optimizer.nccl_start_alignment_factor - - parallel_partitioned_bit16_groups = model.optimizer.parallel_partitioned_bit16_groups if zero_stage == 2 else model.optimizer.parallel_partitioned_fp16_groups - for data_parallel_partitions in parallel_partitioned_bit16_groups: - for partition_id, partitioned_data in enumerate(data_parallel_partitions): - # verify that data partition start locations are 4-byte aligned - assert (partitioned_data.data_ptr() % - (2 * nccl_start_alignment_factor) == 0) - - _test_partition_nccl_alignment(args=args, model=model, hidden_dim=hidden_dim) - - -def _ds_initialize_for_param_partitioning_testing(model: Module, - cfg: dict) -> DeepSpeedEngine: - ds_engine, _, _, _ = deepspeed.initialize( - config=cfg, - model=model, - model_parameters=model.parameters() - ) - - return ds_engine - - -def _assert_partition_status(model: Module, - valid_statuses: Set[ZeroParamStatus]) -> None: - for _, param in model.named_parameters(): - assert param.ds_status in valid_statuses, param.ds_summary() - - -def _assert_fully_available(model: Module) -> None: - for _, param in model.named_parameters(): - assert param.ds_status == ZeroParamStatus.AVAILABLE - - -class EltwiseMultiplicationModule(Module): - def __init__(self, weight: Parameter) -> None: - super().__init__() - self.weight = weight - - def forward(self, x: Tensor) -> Tensor: - _assert_fully_available(self) - result = self.weight * x - - return result - - -class EltwiseMultiplicationTestNetwork(Module): - """used for testing purposes""" - def __init__( - self, - weight1: Parameter, - weight2: Parameter, - weight3: Parameter, - ) -> None: - super().__init__() - self.__layer1 = EltwiseMultiplicationModule(weight1) - self.__layer2 = EltwiseMultiplicationModule(weight2) - self.__layer3 = EltwiseMultiplicationModule(weight3) - - self.loss = L1Loss(reduction="none") - - def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]: - _assert_partition_status( - self, - { - ZeroParamStatus.NOT_AVAILABLE, - ZeroParamStatus.INFLIGHT, - ZeroParamStatus.AVAILABLE - } if prefetching else {ZeroParamStatus.NOT_AVAILABLE}) - - layerwise_expected_states = { - ZeroParamStatus.INFLIGHT if prefetching else ZeroParamStatus.NOT_AVAILABLE, - ZeroParamStatus.AVAILABLE, - } - - _assert_partition_status(self.__layer1, layerwise_expected_states) - hidden1 = self.__layer1(x) - _assert_partition_status(self.__layer1, {ZeroParamStatus.NOT_AVAILABLE}) - - _assert_partition_status(self.__layer2, layerwise_expected_states) - hidden2 = self.__layer2(hidden1) - _assert_partition_status(self.__layer2, {ZeroParamStatus.NOT_AVAILABLE}) - - _assert_partition_status(self.__layer3, layerwise_expected_states) - y_hat = self.__layer3(hidden2) - _assert_partition_status(self.__layer3, - { - ZeroParamStatus.AVAILABLE - if prefetching else ZeroParamStatus.NOT_AVAILABLE - }) - - loss = self.loss(y_hat, y) - - _assert_partition_status( - self, - { - ZeroParamStatus.NOT_AVAILABLE, - ZeroParamStatus.INFLIGHT, - ZeroParamStatus.AVAILABLE - } if prefetching else {ZeroParamStatus.NOT_AVAILABLE}) - - return { - "hidden1": hidden1, - "hidden2": hidden2, - "y_hat": y_hat, - "loss": loss, - } - - -@pytest.mark.parametrize("param_persistence_threshold", [0, 10]) -@pytest.mark.parametrize("fp16_enabled", [True, False]) -@pytest.mark.parametrize("contiguous_gradients", [True, False]) -@pytest.mark.parametrize("offload_optimizer", [True, False]) -@pytest.mark.parametrize("zero_grad", [True, False]) -@pytest.mark.parametrize("iteration", list(range(1))) -def test_zero3_param_partitioning_base( - param_persistence_threshold: int, - fp16_enabled: bool, - contiguous_gradients: bool, - offload_optimizer: bool, - zero_grad: bool, - iteration: int, -) -> None: - @distributed_test(world_size=[2]) - def _test_zero3_param_partitioning(): - if offload_optimizer and not contiguous_gradients: - return - - m = 3 - n = 5 - weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)] - model = EltwiseMultiplicationTestNetwork(*weights) - - cfg = { - "train_micro_batch_size_per_gpu": 1, - "zero_optimization": { - "stage": 3, - "stage3_max_reuse_distance": 0, - "stage3_param_persistence_threshold": param_persistence_threshold, - "contiguous_gradients": contiguous_gradients, - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1. - } - }, - "fp16": { - "enabled": fp16_enabled, - "loss_scale": 1., - } - } - - if offload_optimizer: - cfg["zero_optimization"]["offload_optimizer"] = { - "device": "cpu", - "pin_memory": True, - } - - ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg) - for i, weight in enumerate(weights): - weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, - (i + 1) * (1 + dist.get_rank())) - - def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: - return torch.as_tensor(vals, - dtype=dtype - or (torch.float16 if fp16_enabled else torch.float32), - device=ds_engine.device) - - expected_hidden1 = create_tensor([ - [1, - 1, - 1, - 1, - 1], - [1, - 1, - 1, - 2, - 2], - [2, - 2, - 2, - 2, - 2], - ]) - expected_hidden2 = create_tensor([ - [2, - 2, - 2, - 2, - 2], - [2, - 2, - 2, - 8, - 8], - [8, - 8, - 8, - 8, - 8], - ]) - expected_yhat = create_tensor([[6, - 6, - 6, - 6, - 6], - [6, - 6, - 6, - 48, - 48], - [48, - 48, - 48, - 48, - 48]]) - expected_loss = create_tensor([ - [5, - 5, - 5, - 5, - 5], - [5, - 5, - 5, - 47, - 47], - [47, - 47, - 47, - 47, - 47], - ]) - - for train_iter in range(3): - activations = ds_engine( - x=torch.ones((m, - n), - dtype=torch.float16 if fp16_enabled else torch.float32, - device=ds_engine.device), - y=torch.ones((m, - n), - dtype=torch.float16 if fp16_enabled else torch.float32, - device=ds_engine.device), - prefetching=train_iter > 0, - ) - assert torch.allclose(activations["hidden1"], expected_hidden1) - assert torch.allclose(activations["hidden2"], expected_hidden2) - assert torch.allclose(activations["y_hat"], expected_yhat) - assert torch.allclose(activations["loss"], expected_loss) - - ds_engine.backward(activations["loss"].sum()) - - # check the gradients - grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions() - assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}" - assert set(grad_partitions[0].keys()) == {0, 1, 2} - dloss_wrt_layer1 = grad_partitions[0][0] - dloss_wrt_layer2 = grad_partitions[0][1] - dloss_wrt_layer3 = grad_partitions[0][2] - - assert dloss_wrt_layer1.dtype == torch.float - assert dloss_wrt_layer2.dtype == torch.float - assert dloss_wrt_layer3.dtype == torch.float - - # layer1 = [..., 1, 2, ...] - # layer2 = [..., 2, 4, ...] - # layer3 = [..., 3, 6, ...] - # dloss_wrt_layer3 = hidden2 - # dloss_wrt_layer2 = layer3 * hidden1 - # dloss_wrt_layer1 = layer3 * layer2 * x - - grad_multiplier = 1 if zero_grad else (train_iter + 1) - if dist.get_rank() == 0: - assert torch.allclose( - dloss_wrt_layer3.cuda(), - grad_multiplier * create_tensor([2] * 8, - torch.float)) - assert torch.allclose( - dloss_wrt_layer2.cuda(), - grad_multiplier * create_tensor([3 * 1] * 8, - torch.float)) - assert torch.allclose( - dloss_wrt_layer1.cuda(), - grad_multiplier * create_tensor([3 * 2 * 1] * 8, - torch.float)) - elif dist.get_rank() == 1: - # parameters dont split evenly across ranks so rank 1 has a zero-padded - # partition - assert torch.allclose( - dloss_wrt_layer3.cuda(), - grad_multiplier * create_tensor(([8] * 7) + [0], - torch.float)) - assert torch.allclose( - dloss_wrt_layer2.cuda(), - grad_multiplier * create_tensor(([6 * 2] * 7) + [0], - torch.float)) - assert torch.allclose( - dloss_wrt_layer1.cuda(), - grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0], - torch.float)) - else: - raise RuntimeError("test has world size of two") - - if zero_grad: - ds_engine.optimizer.zero_grad() - - # TODO. add testing for this - for now we just call it to make sure it - # doesnt throw - ds_engine.optimizer.step() - # taking an optimizer step invalidates all parameters, make sure everything - # has been partitioned afterwards - _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) - assert not math.isclose(ds_engine.optimizer._global_grad_norm, 0.0) - - _test_zero3_param_partitioning() - - -@pytest.mark.parametrize("world_sz", [1, 2, 4]) -@pytest.mark.parametrize("param_sz", [8100]) -@pytest.mark.parametrize("init_context_manager", [True, False]) -def test_zero3_param_partitioning_large_param(world_sz: int, - param_sz: int, - init_context_manager: bool) -> None: - class LargeParamModel(Module): - def __init__(self): - super().__init__() - self.param = Parameter(torch.zeros((param_sz, ), dtype=torch.float32)) - - # only do weight initialization on root rank to - # make sure we are broadcasting correctly from rank 0 - if dist.get_rank() == 0: - partition_sz = math.ceil(self.param.numel() / dist.get_world_size()) - offset = 0 - for rank in range(dist.get_world_size()): - with torch.no_grad(): - self.param[offset:offset + partition_sz].fill_(rank) - offset += partition_sz - - def forward(self, x: Tensor) -> Tensor: - return x * self.param - - @distributed_test(world_size=[world_sz]) - def _distributed_test(): - ds_config = { - "train_micro_batch_size_per_gpu": 1, - "zero_optimization": { - "stage": 3, - "stage3_max_reuse_distance": 0, - "contiguous_gradients": True, - "overlap_comm": True, - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1. - } - }, - "fp16": { - "enabled": True, - "loss_scale": 1., - } - } - with deepspeed.zero.Init(mem_efficient_linear=False, - enabled=init_context_manager): - model = LargeParamModel() - ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_config) - - for train_iter in range(3): # test multiple iterations to cover prefetching - activation: Tensor = ds_engine( - torch.ones(param_sz, - dtype=torch.float16, - device=ds_engine.device)) - - partition_sz = math.ceil(param_sz / world_sz) - for rank_idx, start_idx in enumerate(range(0, param_sz, partition_sz)): - activation_from_partition = activation[start_idx:start_idx + - partition_sz] - assert torch.allclose( - activation_from_partition, - torch.full_like(activation_from_partition, - rank_idx)) - - ds_engine.backward(activation.sum()) - ds_engine.allreduce_gradients() - - avgd_gradients = ds_engine.optimizer.averaged_gradients - assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" - weight_gradient, = avgd_gradients[0] - expected_weight_gradient = (train_iter + 1) * torch.full_like( - weight_gradient, - 1) - - assert torch.allclose(weight_gradient, expected_weight_gradient) - - _distributed_test() - - -@pytest.mark.parametrize("world_sz", [1, 2, 4]) -@pytest.mark.parametrize("param_sz", [100, 1_000, 10_000]) -@pytest.mark.parametrize("n_layers", [100, 1_000]) -@pytest.mark.parametrize("init_context_manager", [True, False]) -def test_zero3_param_partitioning_many_params(world_sz: int, - param_sz: int, - n_layers: int, - init_context_manager: bool) -> None: - class ManyParamModel(Module): - def __init__(self) -> None: - super().__init__() - - self.modulelist = ModuleList( - EltwiseMultiplicationModule( - weight=Parameter(torch.empty((param_sz, - ), - dtype=torch.float32))) - for _ in range(n_layers)) - - for layer_num, module in enumerate(self.modulelist): - if dist.get_rank() == 0: - param: Parameter = module.weight - partition_sz = math.ceil(param.numel() / dist.get_world_size()) - offset = 0 - for rank in range(dist.get_world_size()): - with torch.no_grad(): - param[offset:offset + partition_sz].fill_(2 * layer_num * - rank) - offset += partition_sz - - def forward(self, x: Tensor) -> Tensor: - activations = [] - - for module in self.modulelist: - x = module(x) - activations.append(x) - - return activations - - @distributed_test(world_size=[world_sz]) - def _distributed_test(): - ds_cfg = { - "train_micro_batch_size_per_gpu": 1, - "zero_optimization": { - "stage": 3, - "stage3_max_reuse_distance": 0, - "contiguous_gradients": True, - "overlap_comm": True, - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1. - } - }, - "fp16": { - "enabled": True, - "loss_scale": 1., - } - } - - with deepspeed.zero.Init(config=ds_cfg, - mem_efficient_linear=False, - enabled=init_context_manager): - model = ManyParamModel() - - ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg) - - for _ in range(3): # test multiple iterations to cover prefetching - activations: List[Tensor] = ds_engine( - torch.ones((param_sz, - ), - dtype=torch.float16, - device=ds_engine.device)) - assert len(activations) == n_layers - - partition_sz = math.ceil(param_sz / world_sz) - expected_activations = torch.empty(param_sz, - dtype=torch.float16, - device=ds_engine.device) - for start_idx in range(0, param_sz, partition_sz): - expected_activations[start_idx:start_idx + - partition_sz] = dist.get_rank() - - for layer_num, activation in enumerate(activations): - expected_activations *= 2 * layer_num - assert torch.allclose(activation, expected_activations) - - # TODO. finish writing this test - ds_engine.backward(activations[-1].sum()) - - avgd_gradients = ds_engine.optimizer.averaged_gradients - assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" - weight_gradients: List[Tensor] = avgd_gradients[0] - - for layer_num, activation in enumerate(weight_gradients): - pass - - _distributed_test() - - -@pytest.mark.parametrize("world_sz", [1, 2, 4]) -def test_zero3_init_for_parent_weight_initialization(world_sz): - class ModelWhereParentInitializesChildWeights(Module): - def __init__(self) -> None: - super().__init__() - - self.linear = Linear(12, 1) - - self.apply(self.__init_weights) - - def __init_weights(self, module): - if isinstance(module, Linear): - with torch.no_grad(): - module.weight.fill_(1 + dist.get_rank()) - - @distributed_test(world_size=[world_sz]) - def _distributed_test(): - ds_cfg = { - "train_micro_batch_size_per_gpu": 1, - "zero_optimization": { - "stage": 3, - "stage3_max_reuse_distance": 0, - "contiguous_gradients": True, - "overlap_comm": True, - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1. - } - }, - "fp16": { - "enabled": True, - "loss_scale": 1., - } - } - - with deepspeed.zero.Init(config=ds_cfg, - mem_efficient_linear=False, - enabled=True): - model = ModelWhereParentInitializesChildWeights() - - assert model.linear.weight.ds_tensor.numel() == math.ceil(12 / world_sz) - assert torch.allclose(model.linear.weight.ds_tensor, - torch.full_like(model.linear.weight.ds_tensor, - 1)) - - _distributed_test() - - -@pytest.mark.skip( - reason="depends on upgraded pytorch and nccl that isnt always available") -@pytest.mark.parametrize("param_persistence_threshold", [0, 10]) -@pytest.mark.parametrize("contiguous_gradients", [True, False]) -@pytest.mark.parametrize("offload_optimizer", [True, False]) -@pytest.mark.parametrize("zero_grad", [True]) -@pytest.mark.parametrize("iteration", list(range(1))) -def test_zero3_param_partitioning_base_bf16( - param_persistence_threshold: int, - contiguous_gradients: bool, - offload_optimizer: bool, - zero_grad: bool, - iteration: int, -) -> None: - @distributed_test(world_size=[2]) - def _test_zero3_param_partitioning(): - if offload_optimizer and not contiguous_gradients: - return - - m = 3 - n = 5 - weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)] - model = EltwiseMultiplicationTestNetwork(*weights) - - cfg = { - "train_micro_batch_size_per_gpu": 1, - "zero_optimization": { - "stage": 3, - "stage3_max_reuse_distance": 0, - "stage3_param_persistence_threshold": param_persistence_threshold, - "contiguous_gradients": contiguous_gradients, - }, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1. - } - }, - "bfloat16": { - "enabled": True, - "loss_scale": 1., - } - } - - if offload_optimizer: - cfg["zero_optimization"]["offload_optimizer"] = { - "device": "cpu", - "pin_memory": True, - } - - ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg) - for i, weight in enumerate(weights): - weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, - (i + 1) * (1 + dist.get_rank())) - - def create_tensor(vals): - return torch.as_tensor(vals, dtype=torch.bfloat16, device=ds_engine.device) - - expected_hidden1 = create_tensor([ - [1, - 1, - 1, - 1, - 1], - [1, - 1, - 1, - 2, - 2], - [2, - 2, - 2, - 2, - 2], - ]) - expected_hidden2 = create_tensor([ - [2, - 2, - 2, - 2, - 2], - [2, - 2, - 2, - 8, - 8], - [8, - 8, - 8, - 8, - 8], - ]) - expected_yhat = create_tensor([[6, - 6, - 6, - 6, - 6], - [6, - 6, - 6, - 48, - 48], - [48, - 48, - 48, - 48, - 48]]) - expected_loss = create_tensor([ - [5, - 5, - 5, - 5, - 5], - [5, - 5, - 5, - 47, - 47], - [47, - 47, - 47, - 47, - 47], - ]) - - for train_iter in range(3): - _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) - activations = ds_engine( - x=torch.ones((m, - n), - dtype=torch.bfloat16, - device=ds_engine.device), - y=torch.ones((m, - n), - dtype=torch.bfloat16, - device=ds_engine.device), - prefetching=train_iter > 0, - ) - assert torch.allclose(activations["hidden1"], expected_hidden1) - assert torch.allclose(activations["hidden2"], expected_hidden2) - assert torch.allclose(activations["y_hat"], expected_yhat) - assert torch.allclose(activations["loss"], expected_loss) - - ds_engine.backward(activations["loss"].sum()) - _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) - - # check the gradients - grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions() - assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}" - assert set(grad_partitions[0].keys()) == {0, 1, 2} - dloss_wrt_layer1 = grad_partitions[0][0] - dloss_wrt_layer2 = grad_partitions[0][1] - dloss_wrt_layer3 = grad_partitions[0][2] - - # layer1 = [..., 1, 2, ...] - # layer2 = [..., 2, 4, ...] - # layer3 = [..., 3, 6, ...] - # dloss_wrt_layer3 = hidden2 - # dloss_wrt_layer2 = layer3 * hidden1 - # dloss_wrt_layer1 = layer3 * layer2 * x - - expected_grad_dtype = torch.float32 if offload_optimizer else torch.bfloat16 - - grad_multiplier = 1 if zero_grad else (train_iter + 1) - if dist.get_rank() == 0: - assert torch.allclose( - dloss_wrt_layer3.cuda(), - grad_multiplier * create_tensor([2] * 8).to(expected_grad_dtype)) - assert torch.allclose( - dloss_wrt_layer2.cuda(), - grad_multiplier * create_tensor([3 * 1] * 8).to(expected_grad_dtype)) - assert torch.allclose( - dloss_wrt_layer1.cuda(), - grad_multiplier * - create_tensor([3 * 2 * 1] * 8).to(expected_grad_dtype)) - elif dist.get_rank() == 1: - # parameters dont split evenly across ranks so rank 1 has a zero-padded - # partition - assert torch.allclose( - dloss_wrt_layer3.cuda(), - grad_multiplier * - create_tensor(([8] * 7) + [0]).to(expected_grad_dtype)) - assert torch.allclose( - dloss_wrt_layer2.cuda(), - grad_multiplier * - create_tensor(([6 * 2] * 7) + [0]).to(expected_grad_dtype)) - assert torch.allclose( - dloss_wrt_layer1.cuda(), - grad_multiplier * - create_tensor(([6 * 4 * 1] * 7) + [0]).to(expected_grad_dtype)) - else: - raise RuntimeError("test has world size of two") - - if zero_grad: - ds_engine.optimizer.zero_grad() - - # TODO. add testing for this - for now we just call it to make sure it - # doesnt throw - ds_engine.optimizer.step() - _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) - - _test_zero3_param_partitioning() +import math +from typing import Dict, List, Set +import pytest +import torch.distributed as dist +import torch +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn.modules.container import ModuleList +from torch.nn.modules.loss import L1Loss +from torch.nn.parameter import Parameter + +from .common import distributed_test +from .simple_model import SimpleModel, random_dataloader, args_from_dict + +import deepspeed +from deepspeed.runtime.engine import DeepSpeedEngine +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + + +def run_unbalanced_gradients(model, data_loader): + def drop_some_gradients(model, iter): + odd_iteration = iter % 2 + for i, p in enumerate(model.parameters()): + p.requires_grad = (i % 2) == odd_iteration + + def enable_grads(model): + for p in model.parameters(): + p.requires_grad = True + + for i, batch in enumerate(data_loader): + drop_some_gradients(model, i + 1) + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + enable_grads(model) + + +def dump_state_dict(model): + if dist.get_rank() == 0: + print("state_dict:") + for name, param in model.named_parameters(): + print(f"{name} {param.data}") + + +@pytest.mark.parametrize('zero_stage', [1, 2, 3]) +def test_zero_unbalanced_gradients(tmpdir, zero_stage): + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 4 + + model = SimpleModel(hidden_dim=hidden_dim) + + @distributed_test(world_size=[1]) + def _test_zero_unbalanced_gradients(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=16, + hidden_dim=hidden_dim, + device=model.device) + + run_unbalanced_gradients(model, data_loader) + + _test_zero_unbalanced_gradients(args=args, model=model, hidden_dim=hidden_dim) + + +# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227 +@pytest.mark.parametrize('zero_stage', [3]) +def test_zero3_repeat_forward_loop(tmpdir, zero_stage): + + # force all params to be partitioned by forcing threshold=0 + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + "stage3_param_persistence_threshold": 0 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 4 + + class AlbertLikeModel(torch.nn.Module): + def __init__(self, hidden_dim): + super().__init__() + self.linear = torch.nn.Linear(hidden_dim, hidden_dim) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + # run the same layer multiple times in a loop - to test a stack of forwards, followed by a stack of backwards + hidden = x + for i in range(3): + hidden = hidden + self.linear(hidden) + return self.cross_entropy_loss(hidden, y) + + model = AlbertLikeModel(hidden_dim=hidden_dim) + + @distributed_test(world_size=[1]) + def _test_zero3_repeat_forward_loop(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=16, + hidden_dim=hidden_dim, + device=model.device) + + for i, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_zero3_repeat_forward_loop(args=args, model=model, hidden_dim=hidden_dim) + + +# testing the fix https://github.com/microsoft/DeepSpeed/pull/1227 +# also reproduces the https://github.com/microsoft/DeepSpeed/pull/1372 +@pytest.mark.parametrize('zero_stage', [2, 3]) +def test_zero_to_fp32_1_param_group(tmpdir, zero_stage): + + # XXX: ideally refactor with the 2_param_group test as 75% is the same + + # force all params to be partitioned by forcing threshold=0 + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + "stage3_param_persistence_threshold": 0 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + @distributed_test(world_size=[2]) + def _test_zero_to_fp32(): + class MyModel(torch.nn.Module): + def __init__(self, hidden_dim, n_layers): + super().__init__() + # to reproduce https://github.com/microsoft/DeepSpeed/pull/1372 it is important that + # the number of total elements is uneven: + # (1) 4 layers of 3*(3+1)=12 elements each, 48 in total + self.ll = torch.nn.ModuleList( + torch.nn.Linear(hidden_dim, + hidden_dim) for i in range(n_layers)) + # (2) the following adds 4+1=5 elements + self.classifier = torch.nn.Linear(4, 1) + # total 48+5=53 (uneven as desired) elements + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + hidden = x + for l in self.ll: + hidden = l(hidden) + return self.cross_entropy_loss(hidden, y) + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 3 # do not change + + world_size = dist.get_world_size() + # we want at least 2x layers as there are gpus to trigger round_robin_fp16_groups reshuffle in zero2 + n_layers = world_size * 2 + model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers) + + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=16, + hidden_dim=hidden_dim, + device=model.device) + + for i, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + model.save_checkpoint(tmpdir) + + # make sure all sides saved it + dist.barrier() + + if zero_stage == 3: + with deepspeed.zero.GatheredParameters(list( + model.module.parameters(recurse=True)), + modifier_rank=None): + pass # this forces gathering the model + + #dump_state_dict(model) + + orig_state_dict = {} + for name, param in model.module.named_parameters(): + orig_state_dict[name] = param.detach().cpu() + + if dist.get_rank() == 0: + fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) + #dump_state_dict(fp32_model) + + fp32_state_dict = fp32_model.state_dict() + for name in orig_state_dict.keys(): + # float() workaround for torch<1.6 + assert torch.allclose(orig_state_dict[name].float(), + fp32_state_dict[name].float()) + + _test_zero_to_fp32() + + +@pytest.mark.parametrize('zero_stage', [2, 3]) +def test_zero_to_fp32_2_param_groups(tmpdir, zero_stage): + + # TODO: + # - need to test with multiple param groups + + # force all params to be partitioned by forcing threshold=0 + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "steps_per_print": 1, + "zero_allow_untested_optimizer": 1, + "zero_optimization": { + "stage": zero_stage, + "stage3_param_persistence_threshold": 0 + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + @distributed_test(world_size=[2]) + def _test_zero_to_fp32(): + class MyModel(torch.nn.Module): + def __init__(self, hidden_dim, n_layers): + super().__init__() + self.ll = torch.nn.ModuleList( + torch.nn.Linear(hidden_dim, + hidden_dim) for i in range(n_layers)) + self.cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward(self, x, y): + hidden = x + for l in self.ll: + hidden = l(hidden) + return self.cross_entropy_loss(hidden, y) + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 3 + + world_size = dist.get_world_size() + n_layers = world_size * 2 + model = MyModel(hidden_dim=hidden_dim, n_layers=n_layers) + + optim_groups = [ + { + "params": [l.weight for l in model.ll], + "weight_decay": 0.01, + }, + { + "params": [l.bias for l in model.ll], + "weight_decay": 0.0 + }, + ] + optim = torch.optim.SGD(optim_groups, lr=0.1) + + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters(), + optimizer = optim, + ) + data_loader = random_dataloader(model=model, + total_samples=16, + hidden_dim=hidden_dim, + device=model.device) + + for i, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + model.save_checkpoint(tmpdir) + + # make sure all sides saved it + dist.barrier() + + if zero_stage == 3: + with deepspeed.zero.GatheredParameters(list( + model.module.parameters(recurse=True)), + modifier_rank=None): + pass # this forces gathering the model + + #dump_state_dict(model) + + orig_state_dict = {} + for name, param in model.module.named_parameters(): + orig_state_dict[name] = param.detach().cpu() + + if dist.get_rank() == 0: + fp32_model = load_state_dict_from_zero_checkpoint(model.module, tmpdir) + #dump_state_dict(fp32_model) + + fp32_state_dict = fp32_model.state_dict() + for name in orig_state_dict.keys(): + # float() workaround for torch<1.6 + assert torch.allclose(orig_state_dict[name].float(), + fp32_state_dict[name].float()) + + _test_zero_to_fp32() + + +@pytest.mark.parametrize('zero_stage, allgather_bucket_size', [(2, 1000), (2, 1001)]) +def test_incorrect_allgather_bucket_size(tmpdir, zero_stage, allgather_bucket_size): + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage, + "allgather_bucket_size": allgather_bucket_size + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 4 + + model = SimpleModel(hidden_dim=hidden_dim) + + @distributed_test(world_size=[1]) + def _test_incorrect_allgather_bucket_size(args, model, hidden_dim): + if allgather_bucket_size % 2 == 0: + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + else: + with pytest.raises(AssertionError) as assertinfo: + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + assert "allgather_bucket_size must be a multiple of nccl_start_alignment_factor" in str( + assertinfo) + + _test_incorrect_allgather_bucket_size(args=args, model=model, hidden_dim=hidden_dim) + + +@pytest.mark.parametrize('zero_stage, world_size', [(2, 2), (2, 3), (2, 4)]) +def test_partition_nccl_alignment(tmpdir, zero_stage, world_size): + config_dict = { + "train_micro_batch_size_per_gpu": 2, + "gradient_accumulation_steps": 2, + "steps_per_print": 1, + "zero_optimization": { + "stage": zero_stage + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + } + }, + "fp16": { + "enabled": True, + "initial_scale_power": 8 + } + } + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 4 + + model = SimpleModel(hidden_dim=hidden_dim) + + @distributed_test(world_size=world_size) + def _test_partition_nccl_alignment(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + # get nccl all-gather send buffers alignment factor + nccl_start_alignment_factor = model.optimizer.nccl_start_alignment_factor + + parallel_partitioned_bit16_groups = model.optimizer.parallel_partitioned_bit16_groups if zero_stage == 2 else model.optimizer.parallel_partitioned_fp16_groups + for data_parallel_partitions in parallel_partitioned_bit16_groups: + for partition_id, partitioned_data in enumerate(data_parallel_partitions): + # verify that data partition start locations are 4-byte aligned + assert (partitioned_data.data_ptr() % + (2 * nccl_start_alignment_factor) == 0) + + _test_partition_nccl_alignment(args=args, model=model, hidden_dim=hidden_dim) + + +def _ds_initialize_for_param_partitioning_testing(model: Module, + cfg: dict) -> DeepSpeedEngine: + ds_engine, _, _, _ = deepspeed.initialize( + config=cfg, + model=model, + model_parameters=model.parameters() + ) + + return ds_engine + + +def _assert_partition_status(model: Module, + valid_statuses: Set[ZeroParamStatus]) -> None: + for _, param in model.named_parameters(): + assert param.ds_status in valid_statuses, param.ds_summary() + + +def _assert_fully_available(model: Module) -> None: + for _, param in model.named_parameters(): + assert param.ds_status == ZeroParamStatus.AVAILABLE + + +class EltwiseMultiplicationModule(Module): + def __init__(self, weight: Parameter) -> None: + super().__init__() + self.weight = weight + + def forward(self, x: Tensor) -> Tensor: + _assert_fully_available(self) + result = self.weight * x + + return result + + +class EltwiseMultiplicationTestNetwork(Module): + """used for testing purposes""" + def __init__( + self, + weight1: Parameter, + weight2: Parameter, + weight3: Parameter, + ) -> None: + super().__init__() + self.__layer1 = EltwiseMultiplicationModule(weight1) + self.__layer2 = EltwiseMultiplicationModule(weight2) + self.__layer3 = EltwiseMultiplicationModule(weight3) + + self.loss = L1Loss(reduction="none") + + def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]: + _assert_partition_status( + self, + { + ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.INFLIGHT, + ZeroParamStatus.AVAILABLE + } if prefetching else {ZeroParamStatus.NOT_AVAILABLE}) + + layerwise_expected_states = { + ZeroParamStatus.INFLIGHT if prefetching else ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.AVAILABLE, + } + + _assert_partition_status(self.__layer1, layerwise_expected_states) + hidden1 = self.__layer1(x) + _assert_partition_status(self.__layer1, {ZeroParamStatus.NOT_AVAILABLE}) + + _assert_partition_status(self.__layer2, layerwise_expected_states) + hidden2 = self.__layer2(hidden1) + _assert_partition_status(self.__layer2, {ZeroParamStatus.NOT_AVAILABLE}) + + _assert_partition_status(self.__layer3, layerwise_expected_states) + y_hat = self.__layer3(hidden2) + _assert_partition_status(self.__layer3, + { + ZeroParamStatus.AVAILABLE + if prefetching else ZeroParamStatus.NOT_AVAILABLE + }) + + loss = self.loss(y_hat, y) + + _assert_partition_status( + self, + { + ZeroParamStatus.NOT_AVAILABLE, + ZeroParamStatus.INFLIGHT, + ZeroParamStatus.AVAILABLE + } if prefetching else {ZeroParamStatus.NOT_AVAILABLE}) + + return { + "hidden1": hidden1, + "hidden2": hidden2, + "y_hat": y_hat, + "loss": loss, + } + + +@pytest.mark.parametrize("param_persistence_threshold", [0, 10]) +@pytest.mark.parametrize("fp16_enabled", [True, False]) +@pytest.mark.parametrize("contiguous_gradients", [True, False]) +@pytest.mark.parametrize("offload_optimizer", [True, False]) +@pytest.mark.parametrize("zero_grad", [True, False]) +@pytest.mark.parametrize("iteration", list(range(1))) +def test_zero3_param_partitioning_base( + param_persistence_threshold: int, + fp16_enabled: bool, + contiguous_gradients: bool, + offload_optimizer: bool, + zero_grad: bool, + iteration: int, +) -> None: + @distributed_test(world_size=[2]) + def _test_zero3_param_partitioning(): + if offload_optimizer and not contiguous_gradients: + return + + m = 3 + n = 5 + weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)] + model = EltwiseMultiplicationTestNetwork(*weights) + + cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "stage3_param_persistence_threshold": param_persistence_threshold, + "contiguous_gradients": contiguous_gradients, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": fp16_enabled, + "loss_scale": 1., + } + } + + if offload_optimizer: + cfg["zero_optimization"]["offload_optimizer"] = { + "device": "cpu", + "pin_memory": True, + } + + ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg) + for i, weight in enumerate(weights): + weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, + (i + 1) * (1 + dist.get_rank())) + + def create_tensor(vals, dtype: torch.dtype = None) -> Tensor: + return torch.as_tensor(vals, + dtype=dtype + or (torch.float16 if fp16_enabled else torch.float32), + device=ds_engine.device) + + expected_hidden1 = create_tensor([ + [1, + 1, + 1, + 1, + 1], + [1, + 1, + 1, + 2, + 2], + [2, + 2, + 2, + 2, + 2], + ]) + expected_hidden2 = create_tensor([ + [2, + 2, + 2, + 2, + 2], + [2, + 2, + 2, + 8, + 8], + [8, + 8, + 8, + 8, + 8], + ]) + expected_yhat = create_tensor([[6, + 6, + 6, + 6, + 6], + [6, + 6, + 6, + 48, + 48], + [48, + 48, + 48, + 48, + 48]]) + expected_loss = create_tensor([ + [5, + 5, + 5, + 5, + 5], + [5, + 5, + 5, + 47, + 47], + [47, + 47, + 47, + 47, + 47], + ]) + + for train_iter in range(3): + activations = ds_engine( + x=torch.ones((m, + n), + dtype=torch.float16 if fp16_enabled else torch.float32, + device=ds_engine.device), + y=torch.ones((m, + n), + dtype=torch.float16 if fp16_enabled else torch.float32, + device=ds_engine.device), + prefetching=train_iter > 0, + ) + assert torch.allclose(activations["hidden1"], expected_hidden1) + assert torch.allclose(activations["hidden2"], expected_hidden2) + assert torch.allclose(activations["y_hat"], expected_yhat) + assert torch.allclose(activations["loss"], expected_loss) + + ds_engine.backward(activations["loss"].sum()) + + # check the gradients + grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions() + assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}" + assert set(grad_partitions[0].keys()) == {0, 1, 2} + dloss_wrt_layer1 = grad_partitions[0][0] + dloss_wrt_layer2 = grad_partitions[0][1] + dloss_wrt_layer3 = grad_partitions[0][2] + + assert dloss_wrt_layer1.dtype == torch.float + assert dloss_wrt_layer2.dtype == torch.float + assert dloss_wrt_layer3.dtype == torch.float + + # layer1 = [..., 1, 2, ...] + # layer2 = [..., 2, 4, ...] + # layer3 = [..., 3, 6, ...] + # dloss_wrt_layer3 = hidden2 + # dloss_wrt_layer2 = layer3 * hidden1 + # dloss_wrt_layer1 = layer3 * layer2 * x + + grad_multiplier = 1 if zero_grad else (train_iter + 1) + if dist.get_rank() == 0: + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * create_tensor([2] * 8, + torch.float)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * create_tensor([3 * 1] * 8, + torch.float)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * create_tensor([3 * 2 * 1] * 8, + torch.float)) + elif dist.get_rank() == 1: + # parameters dont split evenly across ranks so rank 1 has a zero-padded + # partition + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * create_tensor(([8] * 7) + [0], + torch.float)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * create_tensor(([6 * 2] * 7) + [0], + torch.float)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * create_tensor(([6 * 4 * 1] * 7) + [0], + torch.float)) + else: + raise RuntimeError("test has world size of two") + + if zero_grad: + ds_engine.optimizer.zero_grad() + + # TODO. add testing for this - for now we just call it to make sure it + # doesnt throw + ds_engine.optimizer.step() + # taking an optimizer step invalidates all parameters, make sure everything + # has been partitioned afterwards + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + assert not math.isclose(ds_engine.optimizer._global_grad_norm, 0.0) + + _test_zero3_param_partitioning() + + +@pytest.mark.parametrize("world_sz", [1, 2, 4]) +@pytest.mark.parametrize("param_sz", [8100]) +@pytest.mark.parametrize("init_context_manager", [True, False]) +def test_zero3_param_partitioning_large_param(world_sz: int, + param_sz: int, + init_context_manager: bool) -> None: + class LargeParamModel(Module): + def __init__(self): + super().__init__() + self.param = Parameter(torch.zeros((param_sz, ), dtype=torch.float32)) + + # only do weight initialization on root rank to + # make sure we are broadcasting correctly from rank 0 + if dist.get_rank() == 0: + partition_sz = math.ceil(self.param.numel() / dist.get_world_size()) + offset = 0 + for rank in range(dist.get_world_size()): + with torch.no_grad(): + self.param[offset:offset + partition_sz].fill_(rank) + offset += partition_sz + + def forward(self, x: Tensor) -> Tensor: + return x * self.param + + @distributed_test(world_size=[world_sz]) + def _distributed_test(): + ds_config = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + with deepspeed.zero.Init(mem_efficient_linear=False, + enabled=init_context_manager): + model = LargeParamModel() + ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_config) + + for train_iter in range(3): # test multiple iterations to cover prefetching + activation: Tensor = ds_engine( + torch.ones(param_sz, + dtype=torch.float16, + device=ds_engine.device)) + + partition_sz = math.ceil(param_sz / world_sz) + for rank_idx, start_idx in enumerate(range(0, param_sz, partition_sz)): + activation_from_partition = activation[start_idx:start_idx + + partition_sz] + assert torch.allclose( + activation_from_partition, + torch.full_like(activation_from_partition, + rank_idx)) + + ds_engine.backward(activation.sum()) + ds_engine.allreduce_gradients() + + avgd_gradients = ds_engine.optimizer.averaged_gradients + assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" + weight_gradient, = avgd_gradients[0] + expected_weight_gradient = (train_iter + 1) * torch.full_like( + weight_gradient, + 1) + + assert torch.allclose(weight_gradient, expected_weight_gradient) + + _distributed_test() + + +@pytest.mark.parametrize("world_sz", [1, 2, 4]) +@pytest.mark.parametrize("param_sz", [100, 1_000, 10_000]) +@pytest.mark.parametrize("n_layers", [100, 1_000]) +@pytest.mark.parametrize("init_context_manager", [True, False]) +def test_zero3_param_partitioning_many_params(world_sz: int, + param_sz: int, + n_layers: int, + init_context_manager: bool) -> None: + class ManyParamModel(Module): + def __init__(self) -> None: + super().__init__() + + self.modulelist = ModuleList( + EltwiseMultiplicationModule( + weight=Parameter(torch.empty((param_sz, + ), + dtype=torch.float32))) + for _ in range(n_layers)) + + for layer_num, module in enumerate(self.modulelist): + if dist.get_rank() == 0: + param: Parameter = module.weight + partition_sz = math.ceil(param.numel() / dist.get_world_size()) + offset = 0 + for rank in range(dist.get_world_size()): + with torch.no_grad(): + param[offset:offset + partition_sz].fill_(2 * layer_num * + rank) + offset += partition_sz + + def forward(self, x: Tensor) -> Tensor: + activations = [] + + for module in self.modulelist: + x = module(x) + activations.append(x) + + return activations + + @distributed_test(world_size=[world_sz]) + def _distributed_test(): + ds_cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + + with deepspeed.zero.Init(config=ds_cfg, + mem_efficient_linear=False, + enabled=init_context_manager): + model = ManyParamModel() + + ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg) + + for _ in range(3): # test multiple iterations to cover prefetching + activations: List[Tensor] = ds_engine( + torch.ones((param_sz, + ), + dtype=torch.float16, + device=ds_engine.device)) + assert len(activations) == n_layers + + partition_sz = math.ceil(param_sz / world_sz) + expected_activations = torch.empty(param_sz, + dtype=torch.float16, + device=ds_engine.device) + for start_idx in range(0, param_sz, partition_sz): + expected_activations[start_idx:start_idx + + partition_sz] = dist.get_rank() + + for layer_num, activation in enumerate(activations): + expected_activations *= 2 * layer_num + assert torch.allclose(activation, expected_activations) + + # TODO. finish writing this test + ds_engine.backward(activations[-1].sum()) + + avgd_gradients = ds_engine.optimizer.averaged_gradients + assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group" + weight_gradients: List[Tensor] = avgd_gradients[0] + + for layer_num, activation in enumerate(weight_gradients): + pass + + _distributed_test() + + +@pytest.mark.parametrize("world_sz", [1, 2, 4]) +def test_zero3_init_for_parent_weight_initialization(world_sz): + class ModelWhereParentInitializesChildWeights(Module): + def __init__(self) -> None: + super().__init__() + + self.linear = Linear(12, 1) + + self.apply(self.__init_weights) + + def __init_weights(self, module): + if isinstance(module, Linear): + with torch.no_grad(): + module.weight.fill_(1 + dist.get_rank()) + + @distributed_test(world_size=[world_sz]) + def _distributed_test(): + ds_cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "contiguous_gradients": True, + "overlap_comm": True, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "fp16": { + "enabled": True, + "loss_scale": 1., + } + } + + with deepspeed.zero.Init(config=ds_cfg, + mem_efficient_linear=False, + enabled=True): + model = ModelWhereParentInitializesChildWeights() + + assert model.linear.weight.ds_tensor.numel() == math.ceil(12 / world_sz) + assert torch.allclose(model.linear.weight.ds_tensor, + torch.full_like(model.linear.weight.ds_tensor, + 1)) + + _distributed_test() + + +@pytest.mark.skip( + reason="depends on upgraded pytorch and nccl that isnt always available") +@pytest.mark.parametrize("param_persistence_threshold", [0, 10]) +@pytest.mark.parametrize("contiguous_gradients", [True, False]) +@pytest.mark.parametrize("offload_optimizer", [True, False]) +@pytest.mark.parametrize("zero_grad", [True]) +@pytest.mark.parametrize("iteration", list(range(1))) +def test_zero3_param_partitioning_base_bf16( + param_persistence_threshold: int, + contiguous_gradients: bool, + offload_optimizer: bool, + zero_grad: bool, + iteration: int, +) -> None: + @distributed_test(world_size=[2]) + def _test_zero3_param_partitioning(): + if offload_optimizer and not contiguous_gradients: + return + + m = 3 + n = 5 + weights = [Parameter(torch.zeros((m, n), dtype=torch.float32)) for _ in range(3)] + model = EltwiseMultiplicationTestNetwork(*weights) + + cfg = { + "train_micro_batch_size_per_gpu": 1, + "zero_optimization": { + "stage": 3, + "stage3_max_reuse_distance": 0, + "stage3_param_persistence_threshold": param_persistence_threshold, + "contiguous_gradients": contiguous_gradients, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1. + } + }, + "bfloat16": { + "enabled": True, + "loss_scale": 1., + } + } + + if offload_optimizer: + cfg["zero_optimization"]["offload_optimizer"] = { + "device": "cpu", + "pin_memory": True, + } + + ds_engine = _ds_initialize_for_param_partitioning_testing(model, cfg) + for i, weight in enumerate(weights): + weight.ds_tensor.data = torch.full_like(weight.ds_tensor.data, + (i + 1) * (1 + dist.get_rank())) + + def create_tensor(vals): + return torch.as_tensor(vals, dtype=torch.bfloat16, device=ds_engine.device) + + expected_hidden1 = create_tensor([ + [1, + 1, + 1, + 1, + 1], + [1, + 1, + 1, + 2, + 2], + [2, + 2, + 2, + 2, + 2], + ]) + expected_hidden2 = create_tensor([ + [2, + 2, + 2, + 2, + 2], + [2, + 2, + 2, + 8, + 8], + [8, + 8, + 8, + 8, + 8], + ]) + expected_yhat = create_tensor([[6, + 6, + 6, + 6, + 6], + [6, + 6, + 6, + 48, + 48], + [48, + 48, + 48, + 48, + 48]]) + expected_loss = create_tensor([ + [5, + 5, + 5, + 5, + 5], + [5, + 5, + 5, + 47, + 47], + [47, + 47, + 47, + 47, + 47], + ]) + + for train_iter in range(3): + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + activations = ds_engine( + x=torch.ones((m, + n), + dtype=torch.bfloat16, + device=ds_engine.device), + y=torch.ones((m, + n), + dtype=torch.bfloat16, + device=ds_engine.device), + prefetching=train_iter > 0, + ) + assert torch.allclose(activations["hidden1"], expected_hidden1) + assert torch.allclose(activations["hidden2"], expected_hidden2) + assert torch.allclose(activations["y_hat"], expected_yhat) + assert torch.allclose(activations["loss"], expected_loss) + + ds_engine.backward(activations["loss"].sum()) + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + + # check the gradients + grad_partitions = ds_engine.optimizer.get_fp32_grad_partitions() + assert set(grad_partitions.keys()) == {0}, f"should have one parameter group but got {len(grad_partitions)}" + assert set(grad_partitions[0].keys()) == {0, 1, 2} + dloss_wrt_layer1 = grad_partitions[0][0] + dloss_wrt_layer2 = grad_partitions[0][1] + dloss_wrt_layer3 = grad_partitions[0][2] + + # layer1 = [..., 1, 2, ...] + # layer2 = [..., 2, 4, ...] + # layer3 = [..., 3, 6, ...] + # dloss_wrt_layer3 = hidden2 + # dloss_wrt_layer2 = layer3 * hidden1 + # dloss_wrt_layer1 = layer3 * layer2 * x + + expected_grad_dtype = torch.float32 if offload_optimizer else torch.bfloat16 + + grad_multiplier = 1 if zero_grad else (train_iter + 1) + if dist.get_rank() == 0: + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * create_tensor([2] * 8).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * create_tensor([3 * 1] * 8).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * + create_tensor([3 * 2 * 1] * 8).to(expected_grad_dtype)) + elif dist.get_rank() == 1: + # parameters dont split evenly across ranks so rank 1 has a zero-padded + # partition + assert torch.allclose( + dloss_wrt_layer3.cuda(), + grad_multiplier * + create_tensor(([8] * 7) + [0]).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer2.cuda(), + grad_multiplier * + create_tensor(([6 * 2] * 7) + [0]).to(expected_grad_dtype)) + assert torch.allclose( + dloss_wrt_layer1.cuda(), + grad_multiplier * + create_tensor(([6 * 4 * 1] * 7) + [0]).to(expected_grad_dtype)) + else: + raise RuntimeError("test has world size of two") + + if zero_grad: + ds_engine.optimizer.zero_grad() + + # TODO. add testing for this - for now we just call it to make sure it + # doesnt throw + ds_engine.optimizer.step() + _assert_partition_status(ds_engine, {ZeroParamStatus.NOT_AVAILABLE}) + + _test_zero3_param_partitioning() From 2b5f6ea2bba91ea1a5e5b9d2febedf041926d810 Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Tue, 23 Nov 2021 18:32:05 +0000 Subject: [PATCH 39/59] Fix merge issues --- tests/unit/test_autotuning.py | 2 +- tests/unit/test_ds_initialize.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_autotuning.py b/tests/unit/test_autotuning.py index 96617f8ff53b..2a7898b8af0a 100644 --- a/tests/unit/test_autotuning.py +++ b/tests/unit/test_autotuning.py @@ -1,7 +1,7 @@ import os import pytest import torch -from simple_model import create_config_from_dict +from .simple_model import create_config_from_dict from deepspeed.launcher import runner as dsrun from deepspeed.autotuning.autotuner import Autotuner from deepspeed.autotuning.scheduler import ResourceManager diff --git a/tests/unit/test_ds_initialize.py b/tests/unit/test_ds_initialize.py index 04e6545b887e..a9756af62200 100644 --- a/tests/unit/test_ds_initialize.py +++ b/tests/unit/test_ds_initialize.py @@ -4,9 +4,9 @@ from torch.optim import Optimizer, Adam, AdamW from torch.optim.lr_scheduler import _LRScheduler, LambdaLR -from simple_model import args_from_dict, SimpleModel, random_dataloader -from common import distributed_test -from util import required_torch_version +from .simple_model import args_from_dict, SimpleModel, random_dataloader +from .common import distributed_test +from .util import required_torch_version import deepspeed from deepspeed.ops.adam import FusedAdam From 912e6f043d2b7a98e6319f21dc29c21088d6f8fe Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Mon, 29 Nov 2021 10:16:49 -0800 Subject: [PATCH 40/59] switch to CRLF --- deepspeed/runtime/zero/stage3.py | 6702 +++++++++++++++--------------- 1 file changed, 3351 insertions(+), 3351 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d8f79f69a577..2134020d52c0 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1,3351 +1,3351 @@ -""" -"Copyright 2020 The Microsoft DeepSpeed Team. -Licensed under the MIT license. -""" - -import gc -from dataclasses import dataclass -import functools -import os -import collections -from collections import OrderedDict, UserDict -import itertools -from typing import Deque, Dict, Iterable, Set, Tuple -import torch -from torch.cuda import Event, Stream -from torch.nn import Module, Parameter -import torch.distributed as dist -import math -from torch._six import inf -from torch.nn import Module -from torch.nn.parameter import Parameter - -from deepspeed.utils.logging import logger -from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced -from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim -from deepspeed.runtime.zero.partition_parameters import * -from deepspeed.runtime.zero.partition_parameters import _init_external_params -from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS -from deepspeed.ops.adam import DeepSpeedCPUAdam -from deepspeed.ops.op_builder import UtilsBuilder -from deepspeed.runtime.zero.offload_constants import * -from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus -from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper -from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper - -# Toggle this to true to enable correctness test -# with gradient partitioning and without -pg_correctness_test = False - -FWD_MODULE_STACK = list() - - -def print_rank_0(message, debug=False, force=False): - rank = torch.distributed.get_rank() - if rank == 0 and (debug or force): - print(message) - # other variations - # - print for all ranks w/o interleaving - # printflock(f"[{rank}] {message}") - # - print to log file per rank - # log_rank_file(rank, message) - - -def input(msg): - return - - -def isclose(a, b, rtol=1e-09, atol=0.0): - return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) - - -def lcm(x, y): - from fractions import gcd # or can import gcd from `math` in Python 3 - return x * y // gcd(x, y) - - -def debug_rank0(message: str) -> None: - if dist.get_rank() == 0: - logger.debug(message) - - -def get_cuda_mem_allocated_str() -> str: - # this is really slow. when enabled the python process becomes slow - # to the point where it can't keep the GPU fed with work, so only enable - # for memory debugging. - # return f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f}GB" - return "xGB" - - -def move_to_cpu(tensor_list): - for tensor in tensor_list: - tensor.data = tensor.data.cpu() - - -@instrument_w_nvtx -def get_all_parameters(sub_module, recurse=False): - return itertools.chain(sub_module.named_parameters(recurse=recurse), - sub_module.ds_external_parameters()) - - -def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: - return map(lambda pair: pair[1], get_all_parameters(module, recurse)) - - -#apply torch.autograd.Function that calls a backward_function to tensors in output -def _apply_to_tensors_only(module, functional, backward_function, outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_only(module, - functional, - backward_function, - output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - return functional.apply(module, backward_function, outputs) - else: - return outputs - - -#for each tensor in outputs run the forward_function and register backward_function as hook -def _apply_forward_and_backward_to_tensors_only(module, - forward_function, - backward_function, - outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_forward_and_backward_to_tensors_only( - module, - forward_function, - backward_function, - output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - forward_function(outputs) - if outputs.requires_grad: - outputs.register_hook(backward_function) - return outputs - else: - return outputs - - -class ZeROOrderedDict(OrderedDict): - def __init__(self, parent_module, *args, **kwargs): - """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. - - Args: - parent_module (``collections.OrderedDict``): the collection to replace - """ - - super().__init__(*args, **kwargs) - self._parent_module = parent_module - self._in_forward = False - - def __getitem__(self, key): - param = super().__getitem__(key) - - # Params can be registered as None (e.g., bias) - if param is None: - return param - - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - if self._parent_module._parameters._in_forward: - print_rank_0(f'Registering external parameter from getter {key}', - force=False) - register_external_parameter(FWD_MODULE_STACK[-1], param) - param.all_gather() - - return param - - -def _inject_parameters(module, cls): - for module in module.modules(): - if cls == ZeROOrderedDict: - new_param = cls(parent_module=module) - else: - new_param = cls() - - for key, param in module._parameters.items(): - new_param[key] = param - module._parameters = new_param - - -class PartitionedParameterCoordinator: - """Handles partitioning and gathering of parameters.""" - class __InflightParamRegistry(UserDict): - """registry for parameters in flight""" - def __setitem__(self, - param: Parameter, - handle: AllGatherCoalescedHandle) -> None: - if param in self.data: - raise RuntimeError(f"{param.ds_summary()} already in registry") - if param.ds_status != ZeroParamStatus.INFLIGHT: - raise RuntimeError( - f"attempted to add non-inflight parameter to registry {param.ds_summary()}" - ) - self.data[param] = handle - - @dataclass - class __ParamInTrace: - param: Parameter - step_id_last_used_at: int - - def __init__( - self, - prefetch_bucket_sz: int, - max_reuse_distance_in_numel: int, - max_available_parameters_in_numel: int, - allgather_stream: Stream, - prefetch_nvme: bool = False, - ) -> None: - # mapping of param -> handle for each param that is currently in flight - self.__inflight_param_registry = __class__.__InflightParamRegistry() - # keeps track of the number of submodules invoked so far. - self.__step_id: int = 0 - # whether or not we have completed a trace of the entire network. This should - # always be true after the first forward pass + backward pass. - self.trace_complete: bool = False - # sequence of submodules/parameters in forward pass + backward pass - self.__submodule_order: Iterable[Module] = [] - self.__param_order: Iterable[__class__.__ParamInTrace] = [] - self.__most_recent_step_id_param_fetched_for = collections.defaultdict( - lambda: int(-1e10)) - # number of available params, and max number of available params - self.__n_available_params: int = 0 - self.__max_n_available_params: int = max_available_parameters_in_numel - # max distance between two use of the module beyond which module is released - self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel - # queue for parameters to fetch. parameters will be popped off the left - # side of the dequeue as they are fetched - self.__param_queue: Deque[__class__.__ParamInTrace] = None - self.__prefetch_bucket_sz: int = prefetch_bucket_sz - self.__prefetch_nvme: bool = prefetch_nvme - self.hierarchy: int = 0 - - # stream that will be used for allgather operations - self.__allgather_stream: Stream = allgather_stream - - # limit the number of fetch events that can be queued at once - # otherwise, what happens is memory is allocated by the host thread at the - # time of the call, but not used until later by the asynchronous cuda stream. - # allowing an infinite number of these to queue up causes a lot of memory - # pressure that then becomes detrimental to performance. - # this is a much less elegant way of fixing this vs something like using - # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now - # because ideally in the future its replaced by an async allocation - # mechanism which doesnt require any configuration by the user. - self.__ongoing_fetch_events: Deque[Event] = collections.deque() - self.__max_ongoing_fetch_events: int = 2 - - """Tracing and Tracking - TODO. consider performing trace before initializing PartitionedParameterCoordinator - and passing trace results into constructor. This way all the code in here can - just assume that the trace is complete and the results can be entirely - immutable. - - Bookkeeping operations used to track where we are in the forward/backward pass - """ - - def record_trace(self, sub_module: Module) -> None: - """adds sub module to trace""" - if self.trace_complete: - raise RuntimeError( - "attemted to record trace when trace was already complete") - - self.__submodule_order.append(sub_module) - for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): - self.__param_order.append( - __class__.__ParamInTrace(param=param, - step_id_last_used_at=self.__step_id)) - - def reset_step(self) -> None: - """indicate that we have completed one fwd+bwd for the model""" - if self.__inflight_param_registry: - raise RuntimeError( - f"still have inflight params " - f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") - - if not self.trace_complete: - # make sure that recorded parameter and submodule orders are - # identical across ranks - assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) - assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) - assert_ints_same_as_other_ranks( - [p.step_id_last_used_at for p in self.__param_order]) - - self.__submodule_order = tuple(self.__submodule_order) # freeze - self.__param_order = tuple(self.__param_order) # freeze - self.trace_complete = True - print_rank_0(f"completed trace: {[m.id for m in self.__submodule_order]}", - force=True) - - self.__param_queue = collections.deque(self.__param_order) # reset fetch queue - self.__most_recent_step_id_param_fetched_for = collections.defaultdict( - lambda: int(-1e10)) - self.__step_id = 0 - self.__n_available_params = 0 - - """Fetch and Release - Fetching, prefetching, and releasing parameters - """ - - @instrument_w_nvtx - @torch.no_grad() - def fetch_sub_module(self, current_submodule: Module) -> None: - """This method does the following (in order): - 1. kick off fetch for parameters in immediately required sub module - 2. kick off fetch for next few parameters we will need later (prefetch) - 3. block on parameters in immediately required sub module - """ - debug_rank0( - f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " - + str({ - "avail": f"{self.__n_available_params:.1e}", - "queue_sz": f"{len(self.__param_queue or [])}", - "inflight": [p.ds_id for p in self.__inflight_param_registry], - "allocated": get_cuda_mem_allocated_str() - })) - - params_to_fetch = frozenset(iter_params(current_submodule)) - - # kick off all gather for params in the immediately required submodule - for param in params_to_fetch: - debug_rank0(f"-fetch: {param.ds_summary()}") - self.__all_gather_params(params_to_fetch) - - # wait for parameters in the immediately needed submodule to become available - for param in iter_params(current_submodule): - param.ds_active_sub_modules.add(current_submodule.id) - debug_rank0(f"-wait: {param.ds_summary()}") - if param in self.__inflight_param_registry: - with torch.cuda.stream(self.__allgather_stream): - while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ - 0].query(): - self.__ongoing_fetch_events.popleft() - if len(self.__ongoing_fetch_events - ) > self.__max_ongoing_fetch_events: - self.__ongoing_fetch_events.popleft().synchronize() - - self.__inflight_param_registry.pop(param).wait() - - event = Event() - event.record() - self.__ongoing_fetch_events.append(event) - - assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() - torch.cuda.current_stream().wait_stream(self.__allgather_stream) - - # kick off parameter prefetches for upcoming modules - # don't prefetch if we dont have a completed model trace, or if we aren't - # training (throws off the tracing and don't want to prefetch modules for bwd) - if self.trace_complete and current_submodule.training: - # go through the parameters we need for the current module and pop them - # off the fetch queue so that they aren't prefetched later. - # if params have already been popped off the fetch queue by earlier - # prefetches we won't look for them here - discarded_from_prefetch_queue = set() - params_not_already_fetched = set( - filter( - lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. - __step_id, - params_to_fetch)) - while self.__param_queue and len(discarded_from_prefetch_queue) < len( - params_not_already_fetched): - param_in_trace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - discarded_from_prefetch_queue.add(param_in_trace.param) - if discarded_from_prefetch_queue != params_not_already_fetched: - raise RuntimeError( - f"tracing error at step {self.__step_id}: " - f"expected the next {len(params_not_already_fetched)} parameters in the " - f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " - f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." - ) - - # kick off all gather for params in the next few submodules (prefetch) - max_params_to_prefetch = min( - self.__max_n_available_params - self.__n_available_params, - self.__prefetch_bucket_sz) - params_to_prefetch = set() - numel_prefetching = 0 - while self.__param_queue and numel_prefetching < max_params_to_prefetch: - param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - if param_in_trace.param not in params_to_prefetch: - params_to_prefetch.add(param_in_trace.param) - numel_prefetching += param_in_trace.param.ds_numel - for param in params_to_prefetch: - debug_rank0(f"-prefetch: {param.ds_summary()}") - self.__all_gather_params(params_to_prefetch) - - if self.__prefetch_nvme: - self.__prefetch_nvme_param_partitions() - - self.__step_id += 1 - - @instrument_w_nvtx - @torch.no_grad() - def release_sub_module(self, submodule: Module) -> None: - """release the parameters of a sub module, assuming they meet conditions to - be released.""" - params_to_release = (self.__params_to_release(submodule, - self.__step_id) - if self.trace_complete else set( - p.ds_id for p in iter_params(submodule))) - - for param in iter_params(submodule): - param.ds_active_sub_modules.discard(submodule.id) - if param.ds_id in params_to_release and not param.is_external_param: - self.__release_param(param) - - @instrument_w_nvtx - @torch.no_grad() - def release_and_reset_all(self) -> None: - """release all module parameters""" - for param in map(lambda p: p.param, self.__param_order): - if param in self.__inflight_param_registry: - raise RuntimeError(f"param {param.ds_summary()} still in flight") - - # TODO. make this throw if if there are still active submodules. currently - # there's a hook execution issue - param.ds_active_sub_modules.clear() - self.__release_param(param) - - for param_in_trace in self.__param_order: - if param_in_trace.param.ds_status != ZeroParamStatus.NOT_AVAILABLE: - raise RuntimeError( - f"{param_in_trace.param.ds_summary()} expected to be released") - - @instrument_w_nvtx - def __all_gather_params(self, params: Set[Parameter]) -> None: - """for each partitioned parameter, kick off an async allgather and store - the work handle for the in flight parameters.""" - partitioned_params = [] - for param in params: - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - partitioned_params.append(param) - self.__n_available_params += param.ds_numel - - if partitioned_params: - with torch.cuda.stream(self.__allgather_stream): - handle = partitioned_params[0].all_gather_coalesced(partitioned_params) - - for param in partitioned_params: - assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() - self.__inflight_param_registry[param] = handle - - @instrument_w_nvtx - def __release_param(self, param: Parameter) -> None: - if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: - debug_rank0(f"-release: {param.ds_summary()}") - param.partition() - self.__n_available_params -= param.ds_numel - - @instrument_w_nvtx - @functools.lru_cache(maxsize=None) - def __params_to_release(self, - submodule_to_release: Module, - step_id: int) -> Set[int]: - if not self.trace_complete: - raise RuntimeError("expected trace to be complete") - - params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) - if not p.ds_persist) - - # examine all modules within `max_reuse_dist_in_numel` of the current step, - # if we see any of the candidate parameters to be released reoccur while - # doing this, remove them from the set of parameters to release. - params_traversed = 0 - for module in self.__submodule_order[step_id:]: - if params_traversed > self.__max_reuse_dist_in_numel: - break - for param in iter_params(module): - params_to_release.discard(param.ds_id) - params_traversed += param.ds_numel - - return params_to_release - - @instrument_w_nvtx - def __prefetch_nvme_param_partitions(self) -> None: - """swap in parameter partitions from nvme for those parameters that will be used - after the ones that are already being prefetched into full parameters - """ - if not self.trace_complete: - return - - numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) - - numel_considered = 0 - swap_in_params = [] - for param_in_trace in self.__param_queue: - param = param_in_trace.param - if param.nvme_swapper is None: - continue - if (numel_considered > 2 * numel_in_flight or len(swap_in_params) >= - param.nvme_swapper.available_swap_in_buffers()): - break - if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: - swap_in_params.append(param) - numel_considered += param.ds_numel - - if swap_in_params: - swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) - - -class PreBackwardFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, module, pre_backward_function, outputs): - ctx.module = module - ctx.pre_backward_function = pre_backward_function - if not hasattr(module, "applied_pre_backward_ref_cnt"): - module.applied_pre_backward_ref_cnt = 0 - module.applied_pre_backward_ref_cnt += 1 - #print(f"After Forward: {ctx.module.__class__.__name__}") - outputs = outputs.detach() - return outputs - - @staticmethod - def backward(ctx, *args): - #print(f"Before Backward: {ctx.module.__class__.__name__}") - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -class PostBackwardFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, module, pre_backward_function, output): - ctx.module = module - if output.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.pre_backward_function = pre_backward_function - output = output.detach() - return output - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.pre_backward_function(ctx.module) - #print(f"After Backward: {ctx.module.__class__.__name__}") - return (None, None) + args - - -class FP16_DeepSpeedZeroOptimizer_Stage3(object): - """ - DeepSpeedZeroOptimizer designed to reduce the memory footprint - required for training large deep learning models. - - For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models - https://arxiv.org/abs/1910.02054 - - For usage examples, refer to TODO: DeepSpeed Tutorial - - """ - def __init__(self, - module, - init_optimizer, - timers, - ds_config, - static_loss_scale=1.0, - dynamic_loss_scale=False, - dynamic_loss_args=None, - verbose=True, - contiguous_gradients=True, - reduce_bucket_size=500000000, - prefetch_bucket_size=50000000, - max_reuse_distance=1000000000, - max_live_parameters=1000000000, - param_persistence_threshold=100000, - dp_process_group=None, - reduce_scatter=True, - overlap_comm=False, - offload_optimizer_config=None, - offload_param_config=None, - sub_group_size=1000000000000, - mpu=None, - clip_grad=0.0, - allreduce_always_fp32=False, - postscale_gradients=True, - gradient_predivide_factor=1.0, - gradient_accumulation_steps=1, - elastic_checkpoint=False, - aio_config=None): - - see_memory_usage("Stage 3 initialize beginning", force=False) - - if dist.get_rank() == 0: - logger.info(f"initialized {__class__.__name__} with args: {locals()}") - logger.info(f"Reduce bucket size {reduce_bucket_size}") - logger.info(f"Allgather bucket size {prefetch_bucket_size}") - # The fused optimizer does all the work. We need this layer for two reason: - # 1. maintain same user API from apex.fp16_utils - # 2. keep common stuff here in case we need to add ne552w fused optimizer later - - # differences from apex.fp16_utils: - # - assume all model params in fp16 - # - assume all params requires grad - # - flat by groups, not keeping state. TODO: remove state explicitly? - # - master gard and unflat master weight never exist. TODO: a way to save out unflat master? - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") - self.optimizer = init_optimizer - - # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() - self.flatten = util_ops.flatten - self.unflatten = util_ops.unflatten - self.dtype = self.optimizer.param_groups[0]['params'][0].dtype - self._global_grad_norm = 0. - - self.optimizer_swapper = None - self.swap_optimizer = False - - self.offload_optimizer = False - self.offload_optimizer_pin_memory = False - self.offload_optimizer_fast_init = False - self.offload_param = False - self.offload_param_pin_memory = False - self.params_in_nvme_and_cpu = False - self.max_params_in_cpu = 0 - - self._configure_offloading(offload_optimizer_config, offload_param_config) - - self._convert_to_zero_parameters(ds_config, module, mpu) - - for m in module.modules(): - _init_external_params(m) - - self.module = module - self.elastic_checkpoint = elastic_checkpoint - - # Replace ._parameters with a new class to enable auto-registration of - # external parameters - _inject_parameters(module, ZeROOrderedDict) - - self.__inf_or_nan_tracker: Tensor = torch.zeros( - 1, - dtype=torch.bool, - device=torch.cuda.current_device(), - requires_grad=False) - - self.deepspeed_adam_offload = (self.offload_optimizer - and type(init_optimizer) == DeepSpeedCPUAdam) - - self.device = torch.cuda.current_device( - ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE - ### streams used for overlapping computation with communication - self.__allgather_stream = Stream( - ) if overlap_comm else torch.cuda.default_stream() - self.__reduce_and_partition_stream = Stream( - ) if overlap_comm else torch.cuda.default_stream() - - ############################################################################ - - see_memory_usage("Before Partitioned Parameter Coordinator", force=False) - self.param_coordinator = PartitionedParameterCoordinator( - prefetch_bucket_sz=int(prefetch_bucket_size), - max_reuse_distance_in_numel=int(max_reuse_distance), - max_available_parameters_in_numel=int(max_live_parameters), - allgather_stream=self.__allgather_stream, - prefetch_nvme=self.params_in_nvme_and_cpu, - ) - see_memory_usage("After Partitioned Parameter Coordinator", force=False) - - self.__n_caching_allocator_flushes = 0 - - #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) - #-------------Stage 3 Setup-------------------# - # parameters smaller than the threshold will be collectively gathered at the - # end of the optimizer step and will be kept till the end of the backward pass - # TODO maybe worth just replicating these parameters and doing all reduce for them - self.persistence_threshold = int(param_persistence_threshold) - - self.persistent_parameters = self.persistent_parameters() - - self.setup_zero_stage3_hooks() - - #resetting ds_tensor just in case parameters have been changed after initialization - #example .half() or .to() - #self.reset_ds_tensor() - #---------------------------------------------# - - self.timers = timers - - self.dp_process_group = dp_process_group - - self.partition_count = dist.get_world_size(group=self.dp_process_group) - - if mpu is None: - self.model_parallel_group = None - self.model_parallel_rank = 0 - else: - self.model_parallel_group = mpu.get_model_parallel_group() - self.model_parallel_rank = mpu.get_model_parallel_rank() - - self.overflow = False - self.clip_grad = clip_grad - self.allreduce_always_fp32 = allreduce_always_fp32 - self.gradient_predivide_factor = gradient_predivide_factor - self.postscale_gradients = postscale_gradients - self.gradient_accumulation_steps = gradient_accumulation_steps - self.micro_step_id = 0 - - # Holds the mode parameter - # The param.data may not hold any meaningful data - # when param's status is NOT_AVAILABLE or IN_FLGHT - self.fp16_groups = [] - - # Hold partitioned parameters - self.fp16_partitioned_groups = [] - - # Holds a fused and flattened copy of the parameters - self.fp16_partitioned_groups_flat = [] - self.fp16_partitioned_groups_flat_numel = [] - - #defragmented pinned memory - self.param_groups_fp16_flat_cpu_memory = [] - - #a single 32-bit partition of the parallel partitioned parameters - #that this process will update - self.fp32_partitioned_groups_flat = [] - self.next_swappable_fp32_partitioned_groups = [] - - # number of elements per partition in each group - self.partition_size = [] - - self.all_reduce_print = False - - self.prefetch_elements = int(prefetch_bucket_size) - - # padding on each partition for alignment purposes - self.groups_padding = [] - - self.sub_group_size = sub_group_size - - self.sub_group_to_group_id = {} - see_memory_usage("Before creating fp16 partitions", force=False) - self._create_fp16_partitions_with_defragmentation() - num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) - see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", - force=False) - - # Optimizer tensor swapping - if self.swap_optimizer: - self._configure_tensor_swapping(offload_optimizer_config, aio_config) - - see_memory_usage("Before creating fp32 partitions", force=False) - if not isinstance(self.optimizer, DummyOptim): - self._create_fp32_partitions() - see_memory_usage("After creating fp32 partitions", force=False) - dist.barrier() - - # To support pipelined optimizer swapping - if not isinstance(init_optimizer, DummyOptim): - self._create_next_swappable_fp32_groups() - - see_memory_usage("Before initializing optimizer states", force=False) - if not isinstance(init_optimizer, DummyOptim): - self.initialize_optimizer_states() - see_memory_usage("After initializing optimizer states", force=False) - dist.barrier() - - if dist.get_rank() == 0: - logger.info(f"optimizer state initialized") - - self.reduce_bucket_size = int(reduce_bucket_size) - - # IPG - if contiguous_gradients: - self.__ipg_bucket_flat_buffer: Tensor = torch.empty( - int(reduce_bucket_size), - dtype=self.dtype, - device=torch.cuda.current_device()) - - self.__param_id_to_grad_partition: Dict[int, Tensor] = {} - - all_params = list(itertools.chain.from_iterable(self.fp16_groups)) - - grad_partitions_flat_buffer: Tensor = torch.zeros( - sum(p.ds_tensor.ds_numel for p in all_params), - dtype=self.dtype, - device=self.device, - pin_memory=self.offload_optimizer_pin_memory) - - offset = 0 - for param in all_params: - self.__param_id_to_grad_partition[ - param.ds_id] = grad_partitions_flat_buffer.narrow( - 0, - offset, - param.ds_tensor.numel()) - offset += param.ds_tensor.numel() - - self.__params_in_ipg_bucket: List[Parameter] = [] - self.is_gradient_accumulation_boundary: bool = True - - self.__param_reduce_events: Deque[Event] = collections.deque() - self.__max_param_reduce_events: int = 2 - - if dist.get_rank() == 0: - logger.info(f"optimizer state initialized") - - self.param_dict = {} - - # map between param_id and bool to specify if a param is in this partition - self.is_param_in_current_partition = {} - - self.contiguous_gradients = contiguous_gradients - self.extra_large_param_to_reduce = None - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.params_already_reduced = [] - self.is_gradient_accumulation_boundary = True - self._release_ipg_buffers() - self.previous_reduced_grads = None - - # simplified param id - self.param_id = {} - - count = 0 - for i, params_group in enumerate(self.fp16_groups): - for param in params_group: - unique_id = id(param) - self.param_id[unique_id] = count - self.param_dict[count] = param - self.params_already_reduced.append(False) - count = count + 1 - - #Largest partitioned param - largest_partitioned_param_numel = max([ - max([tensor.numel() for tensor in fp16_partitioned_group]) - for fp16_partitioned_group in self.fp16_partitioned_groups - ]) - print_rank_0( - f'Largest partitioned param numel = {largest_partitioned_param_numel}', - force=False) - - see_memory_usage(f"Before Set Grad positions", force=False) - - self.grad_position = {} - self.set_grad_positions() - see_memory_usage(f"Before CPU Offload initialization", force=False) - - self.grads_in_partition = None - - if self.offload_optimizer: - self.norm_for_param_grads = {} - self.local_overflow = False - - see_memory_usage(f"After CPU Offload initialization", force=False) - - # stores if a partition has been reduced in this step - self.is_partition_reduced = {} - - # stores if a grad in a partition has been computed or not - self.is_grad_computed = {} - - # will store the averaged gradients required by this paritition - self.averaged_gradients = {} - - #creates backward hooks for gradient partitioning - self.create_reduce_and_remove_grad_hooks() - - #exit(0) - - # we may have a way of fusing dynamic scale. Do not support for now - if self.dtype == torch.float or not dynamic_loss_scale: - loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale - - self.dynamic_loss_scale = False - self.loss_scaler = LossScaler(scale=loss_scale_value) - cur_iter = 0 - else: - if dynamic_loss_args is None: - self.loss_scaler = DynamicLossScaler() - else: - self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) - - self.dynamic_loss_scale = True - - self.debug_fp16_grads = [{} for _ in self.fp16_groups] - - if dist.get_rank(group=self.dp_process_group) == 0: - see_memory_usage(f"After initializing ZeRO optimizer", force=False) - - @staticmethod - def defragment(tensors: List[Tensor]) -> Tensor: - """move provided tensors into a contiguous flat buffer, with some additional - measures taken to reduce memory fragmentation""" - assert len(set(t.dtype for t in tensors)) == 1 - assert len(set(t.device for t in tensors)) == 1 - - cpu_buffer = torch.empty(sum(p.numel() for p in tensors), - dtype=get_only_unique_item(t.dtype for t in tensors), - device="cpu") - tensor_infos: List[Tuple[Tensor, int, int]] = [] - orig_device = get_only_unique_item(t.device for t in tensors) - - offset = 0 - for tensor in tensors: - tensor_numel = tensor.numel() - # move the tensor from device memory to host memory - cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) - tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) - - # record some data so we can restore the device tensor later - tensor_infos.append((tensor, offset, tensor_numel)) - - offset += tensor_numel - - gc.collect() - torch.cuda.empty_cache() - - # copy tensors (now flattened and contiguous) back to GPU - device_buffer = cpu_buffer.to(orig_device) - - # restore device tensors - for tensor, offset, tensor_numel in tensor_infos: - tensor.data = device_buffer.narrow(0, offset, tensor_numel) - - return device_buffer - - def _configure_offloading(self, offload_optimizer_config, offload_param_config): - ###################### offload optimizer setup ################################## - if offload_optimizer_config is not None: - self.offload_optimizer = True - self.offload_optimizer_pin_memory = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_PIN_MEMORY] - self.swap_optimizer = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE - self.offload_optimizer_fast_init = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_FAST_INIT] - - ###################### offload param setup ################################## - if offload_param_config is not None: - if not isinstance(self.optimizer, DummyOptim): - assert self.offload_optimizer, "parameter offload is only available with optimizer state offload" - self.offload_param = True - self.offload_param_pin_memory = offload_param_config[ - OFFLOAD_PARAM_PIN_MEMORY] - self.params_in_nvme_and_cpu = offload_param_config[ - OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE - self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU] - print_rank_0( - f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}", - force=False) - - def _convert_to_zero_parameters(self, ds_config, module, mpu): - non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] - if non_zero_params: - zero_params = [p for p in module.parameters() if is_zero_param(p)] - if zero_params: - zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) - else: - group = None - if mpu: - group = mpu.get_data_parallel_group() - - if self.params_in_nvme_and_cpu: - remote_device = OFFLOAD_NVME_DEVICE - elif self.offload_param: - remote_device = OFFLOAD_CPU_DEVICE - else: - remote_device = None - - Init(module=module, - data_parallel_group=group, - dtype=self.dtype, - config_dict_or_path=ds_config, - remote_device=remote_device, - pin_memory=self.offload_param_pin_memory, - mpu=mpu) - - def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): - nvme_swap_folder = os.path.join( - offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH], - 'zero_stage_3') - os.makedirs(nvme_swap_folder, exist_ok=True) - if torch.distributed.get_rank() == 0: - logger.info(f'Tensor Swapping: Adding optimizer tensors') - - swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[ - OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper - - self.optimizer_swapper = swapper_type( - swap_config=offload_optimizer_config, - aio_config=aio_config, - base_folder=nvme_swap_folder, - optimizer=self.optimizer, - largest_numel=max(self.fp16_partitioned_groups_flat_numel), - device=self.device, - dtype=torch.float32, - timers=self.timers) - - @property - def elements_in_ipg_bucket(self): - return sum(p.ds_numel for p in self.__params_in_ipg_bucket) - - def _create_fp16_partitions(self): - dist.barrier() - partition_id = dist.get_rank(group=self.dp_process_group) - - # loop to deal with groups - for j, param_group in enumerate(self.optimizer.param_groups): - - sub_groups = self._create_fp16_sub_groups(param_group['params']) - for sub_group in sub_groups: - i = len(self.fp16_groups) - - # push this group to list before modify - self.fp16_groups.append(sub_group) - self.sub_group_to_group_id[i] = j - - #These are the list of the partitioned parameters - self.fp16_partitioned_groups.append( - [param.ds_tensor for param in self.fp16_groups[i]]) - - print_rank_0( - f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" - ) - - # Record padding required to align group to world size (only applies to last rank) - if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: - padding = [p.padding_size() for p in self.fp16_groups[i]] - else: - padding = [0] * len(self.fp16_groups[i]) - self.groups_padding.append(padding) - - #not sure why apex was cloning the weights before flattening - #removing cloning here - see_memory_usage(f"Before Flattening param group {i}", force=False) - - if not self.offload_param: - see_memory_usage(f"Before moving param group {i} to CPU", - force=False) - #move all the parameters to cpu to free up GPU space for creating flat buffer - move_to_cpu(self.fp16_partitioned_groups[i]) - see_memory_usage(f"After moving param group {i} to CPU", force=False) - - #create flat buffer in CPU and move to GPU - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size(group=self.dp_process_group)).cuda( - torch.cuda.current_device())) - see_memory_usage( - f"After flattening and moving param group {i} to GPU", - force=False) - else: - #Without the detach, seems like the flattening becomes part of the - #model graph causing errors downstream - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size( - group=self.dp_process_group)).detach().pin_memory()) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - #set model fp16 weight to slices of flattened buffer - updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i], - self.fp16_partitioned_groups[i]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params): - partitioned_param.data = q.data - - def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): - '''If flat buffer is None then the parameters in the param_list are - not copied to the flat buffer. This is because they excede the number of max_params_in_cpu - Some of these parameters may aready be in CPU in unflattened buffers - or they maybe in GPU, or they maybe in NVME. If they are in NVME, then - they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are - needed during training.''' - if flat_buffer is None: - # this dst buffer is on NVMe, so skip this - return - - start = 0 - for param in param_list: - src = param.ds_tensor - dest = flat_buffer.narrow(0, start, src.ds_numel) - start = start + src.ds_numel - '''if the parameter was initialized in nvme then bring it to the destination buffer directly''' - if src.status == PartitionedParamStatus.NOT_AVAILABLE: - print_rank_0( - f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU" - ) - param.nvme_swapper.swap_into_buffer(param, dest) - src.data = dest.data - src.status = PartitionedParamStatus.AVAILABLE - else: - assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here" - if not avoid_copy: - dest.data.copy_(src.data) - src.data = dest.data - - # Final location must be gpu/cpu in this case - param.ds_tensor.final_location = 'not-nvme' - - def _create_param_groups_fp16_flat_cpu_memory(self): - - aggregate_params_count = 0 - - for j, param_group in enumerate(self.optimizer.param_groups): - params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']]) - - flat_buffer_size = params_in_group - - if self.params_in_nvme_and_cpu and \ - aggregate_params_count + params_in_group > self.max_params_in_cpu: - - flat_buffer_size = max(0, - self.max_params_in_cpu - aggregate_params_count) - - aggregate_params_count += params_in_group - - if flat_buffer_size > 0: - print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", - force=False) - self.param_groups_fp16_flat_cpu_memory.append( - torch.empty(int(flat_buffer_size), - dtype=self.dtype, - pin_memory=True)) - else: - print_rank_0( - f"No flat buffer size. Param group size was {params_in_group}", - force=False) - - self.param_groups_fp16_flat_cpu_memory.append( - torch.empty(1, - dtype=self.dtype)) - - def _create_fp16_partitions_with_defragmentation(self): - dist.barrier() - param_groups: List[List[Parameter]] = tuple( - self._create_fp16_sub_groups(param_group["params"]) - for param_group in self.optimizer.param_groups) - - # bookkeeping related to param groups - for param_group_idx, param_group in enumerate(param_groups): - for sub_group in param_group: - sub_group_idx = len(self.fp16_groups) - - # record sub group and partitions - self.fp16_groups.append(sub_group) - self.fp16_partitioned_groups.append( - [param.ds_tensor for param in sub_group]) - - # record sub group -> group mapping - self.sub_group_to_group_id[sub_group_idx] = param_group_idx - - # record total elements of parameter partitions in sub group - self.fp16_partitioned_groups_flat_numel.append( - sum(p.ds_tensor.ds_numel for p in sub_group)) - - # record padding required to align group to world size (only applies to last rank) - rank_requires_padding = dist.get_rank( - self.dp_process_group) == dist.get_world_size( - self.dp_process_group) - 1 - self.groups_padding.append([ - p.padding_size() if rank_requires_padding else 0 for p in sub_group - ]) - - # move parameters to flattened buffer - if not self.offload_param: # partitioned params remain in GPU during training - # move parameter partitions into a single contiguous flat buffer - parameter_partitions: List[Tensor] = [] - for sub_group in self.fp16_groups: - for param in sub_group: - parameter_partitions.append(param.ds_tensor) - device_buffer = __class__.defragment(parameter_partitions) - - # setup flat buffers per subgroup, these are each just sections of the - # contiguous flat buffer for all parameters that we created earlier - offset = 0 - for sub_group in self.fp16_groups: - sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group) - self.fp16_partitioned_groups_flat.append( - device_buffer.narrow(0, - offset, - sub_group_numel)) - offset += sub_group_numel - else: # partitioned params offloaded to CPU when not in use - # create a flat CPU memory allocation for each param group - self._create_param_groups_fp16_flat_cpu_memory() - for param_group_idx, param_group in enumerate(param_groups): - flat_offset = 0 - for i, sub_group in enumerate(param_group): - total_elements = sum(p.ds_tensor.ds_numel for p in sub_group) - print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") - #Flat buffer may not be available for parameters that reside in NVME - if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ - param_group_idx].numel(): - fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ - param_group_idx].narrow(0, - flat_offset, - total_elements) - print_rank_0( - f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", - force=False) - elif self.params_in_nvme_and_cpu: - fp16_partitioned_group_flat = None - print_rank_0( - f"No flat buffer for sub group {i} of {total_elements} elements", - force=False) - else: - assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs" - - self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) - flat_offset += total_elements - - self._move_to_flat_buffer(sub_group, - fp16_partitioned_group_flat, - avoid_copy=not self.offload_param) - - # if necessary, create a pinned memory buffer to be used for swapping out - # params to NVME after optimizer step - should_create_fp16_flat_reuse_buffer = any( - flattened_partition_group is None - for flattened_partition_group in self.fp16_partitioned_groups_flat) - if should_create_fp16_flat_reuse_buffer: - max_partition_numel, largest_partition_numel = 0, None - for sub_group in self.fp16_groups: - total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) - if total_elements > max_partition_numel: - largest_partition_numel = [t.ds_numel for t in sub_group] - max_partition_numel = total_elements - - assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' - self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( - largest_partition_numel) - - def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): - offset = 0 - elements_in_sub_group = sum( - [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) - assert (flat_buffer.numel() == elements_in_sub_group) - for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): - dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel) - if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: - print_rank_0( - f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}" - ) - param.nvme_swapper.swap_in([param], async_op=False) - dest.data.copy_(partitioned_param.data) - param.nvme_swapper.remove_partition_and_release_buffers([param]) - print_rank_0(f"Swapping in {param.ds_id} done") - else: - dest.data.copy_(partitioned_param.data) - offset += partitioned_param.ds_numel - - def _create_next_swappable_fp32_groups(self): - reverse_order_indices = [ - i for i in range(len(self.fp32_partitioned_groups_flat)) - ] - reverse_order_indices.reverse() - - next_group = None - for i in reverse_order_indices: - self.next_swappable_fp32_partitioned_groups.append(next_group) - if self._swappable_optimizer_subgroup(i): - next_group = self.fp32_partitioned_groups_flat[i] - - self.next_swappable_fp32_partitioned_groups.reverse() - - def _get_sub_group_partitions(self, sub_group_id): - sub_group_partitions = [] - for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): - if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: - swap_path = param.nvme_swapper.get_path(param, True) - sub_group_partitions.append((partitioned_param, - param.ds_tensor.ds_numel, - swap_path)) - else: - sub_group_partitions.append((partitioned_param, - partitioned_param.ds_numel, - None)) - - return sub_group_partitions - - def _create_fp32_partitions(self): - cpu_memory_usage = 0 - cpu_memory_sub_groups = 0 - nvme_memory_usage = 0 - num_swappable_partitions = 0 - num_swap_from_nvme_partitions = 0 - num_swap_from_cpu_partitions = 0 - swap_from_nvme_memory_usage = 0 - swap_from_cpu_memory_usage = 0 - GIGA_BYTES = (1024**3) - - swappable_fp32_tensors = [] - swappable_fp16_src_tensors = [] - nvme_fp16_partitions_info = [] - nvme_fp16_num_elems = [] - nvme_fp32_dest_tensors = [] - fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() - - for i, tensor in enumerate(self.fp16_partitioned_groups_flat): - num_elements = self.fp16_partitioned_groups_flat_numel[i] - - # a partition of the fp32 master weights that will be updated by this process - if self._swappable_optimizer_subgroup(i): - self.fp32_partitioned_groups_flat.append(torch.Tensor()) - nvme_memory_usage += (fp32_element_size * num_elements) - num_swappable_partitions += 1 - - if self.params_in_nvme_and_cpu and tensor is None: - num_swap_from_nvme_partitions += 1 - swap_from_nvme_memory_usage += (fp32_element_size * num_elements) - if self.offload_optimizer_fast_init: - sub_group_partitions = self._get_sub_group_partitions(i) - nvme_fp16_partitions_info.append(sub_group_partitions) - nvme_fp16_num_elems.append(num_elements) - nvme_fp32_dest_tensors.append( - self.fp32_partitioned_groups_flat[i]) - else: - unpinned_fp32_buffer = torch.empty(num_elements, - device=self.device, - dtype=torch.float) - self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) - self.optimizer_swapper.initialize_parameters( - parameters=[self.fp32_partitioned_groups_flat[i]], - src_tensors=[unpinned_fp32_buffer]) - else: - num_swap_from_cpu_partitions += 1 - swap_from_cpu_memory_usage += (fp32_element_size * num_elements) - swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i]) - swappable_fp16_src_tensors.append( - self.fp16_partitioned_groups_flat[i]) - else: - cpu_memory_usage += (fp32_element_size * num_elements) - cpu_memory_sub_groups += 1 - - if self.params_in_nvme_and_cpu and tensor is None: - unpinned_fp32_buffer = torch.empty(num_elements, - device=self.device, - dtype=torch.float) - self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) - self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) - else: - self.fp32_partitioned_groups_flat.append( - self.fp16_partitioned_groups_flat[i].to( - self.device).clone().float().detach()) - - self.fp32_partitioned_groups_flat[ - i].requires_grad = True # keep this in case internal optimizer uses it - - if len(swappable_fp32_tensors) > 0: - self.optimizer_swapper.initialize_parameters( - parameters=swappable_fp32_tensors, - src_tensors=swappable_fp16_src_tensors) - - if len(nvme_fp32_dest_tensors) > 0: - fp16_pinned_buffers = self.fp16_groups[0][ - 0].nvme_swapper.reserve_available_buffers() - assert len(fp16_pinned_buffers) > 0 - self.optimizer_swapper.initialize_from_swapped_fp16_params( - fp16_partitions_info=nvme_fp16_partitions_info, - fp16_num_elems=nvme_fp16_num_elems, - fp16_pinned_buffers=fp16_pinned_buffers, - fp32_parameters=nvme_fp32_dest_tensors) - self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers() - - nvme_gigabytes = nvme_memory_usage / GIGA_BYTES - print_rank_0( - f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB', - force=False) - if self.params_in_nvme_and_cpu: - print_rank_0( - f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB', - force=False) - print_rank_0( - f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB', - force=False) - - cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES - print_rank_0( - f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB', - force=False) - - # Clear for on-the-fly population before the optimizer step - for param_group in self.optimizer.param_groups: - param_group['params'] = [] - - def _create_fp16_sub_groups(self, params_group): - - params_group_numel = sum([param.partitioned_size() for param in params_group]) - sub_group_size = self.sub_group_size - - if sub_group_size is None or sub_group_size >= params_group_numel: - return [params_group] - - sub_groups = [] - sub_group = [] - local_sub_group_size = 0 - for param in params_group: - - sub_group.append(param) - local_sub_group_size += param.partitioned_size() - - if local_sub_group_size >= sub_group_size or id(param) == id( - params_group[-1]): - - sub_groups.append(sub_group) - - sub_group = [] - local_sub_group_size = 0 - - return sub_groups - - # def reset_ds_tensor(self): - # for name, param in self.module.named_parameters(recurse=True): - # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible" - # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now" - # param.ds_tensor.data = param.data - - def setup_zero_stage3_hooks(self): - self.hierarchy = 0 - - #reset step if in inference mode - @instrument_w_nvtx - def _end_of_forward_hook(module, *args): - - if not torch._C.is_grad_enabled(): - self.param_coordinator.reset_step() - - #likely one of them should be enough but just to be safe - self._register_hooks_recursively(self.module) - self.module.register_forward_hook(_end_of_forward_hook) - - # Add top module to stack trace - global FWD_MODULE_STACK - FWD_MODULE_STACK.append(self.module) - - def persistent_parameters(self): - persistent_params = [] - total_persistent_parameters = 0 - params_count = 0 - for _, param in self.module.named_parameters(recurse=True): - if param.ds_numel < self.persistence_threshold: - params_count += 1 - param.ds_persist = True - persistent_params.append(param) - total_persistent_parameters += param.ds_numel - - print_rank_0( - f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params", - force=False) - return persistent_params - - def _register_hooks_recursively(self, module, count=[0]): - my_count = count[0] - module.id = my_count - - #print(f"{module.__class__} : {module.id}") - - for child in module.children(): - count[0] = count[0] + 1 - self._register_hooks_recursively(child, count=count) - - @instrument_w_nvtx - def _pre_forward_module_hook(module, *args): - self.pre_sub_module_forward_function(module) - - @instrument_w_nvtx - def _post_forward_module_hook(module, input, output): - global FWD_MODULE_STACK - FWD_MODULE_STACK.pop() - if output is None: - output = [] - elif not isinstance(output, (list, tuple)): - if torch.is_tensor(output): - output = [output] - else: - #print(f'got UNKNOWN type {type(output)}') - outputs = [] - output = output if isinstance(output, dict) else vars(output) - for name, val in output.items(): - if not name.startswith('__') and torch.is_tensor(val): - outputs.append(val) - output = outputs - #print(f'convert output to {output}') - - for item in filter(lambda item: is_zero_param(item), output): - if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): - item.is_external_param = True - module_to_register = FWD_MODULE_STACK[-1] - print_rank_0( - f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', - force=False) - register_external_parameter(module_to_register, item) - - # It's possible that the parameter was already external to the completed module. If so, remove it the - # registration as it will be covered by the outer module instead. - if id(item) in module._external_params: - print_rank_0( - f' Unregistering nested dangling parameter from module {module.__class__.__name__}', - force=False) - unregister_external_parameter(module, item) - - item.all_gather() - - self.post_sub_module_forward_function(module) - - def _pre_backward_module_hook(module, inputs, output): - @instrument_w_nvtx - def _run_before_backward_function(sub_module): - # some models (e.g. Albert) may run multiple forwards on the same layer in a loop - # before doing backwards, so each backward will need a pre-fetch - using reference - # counting to support this scenario - #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") - if sub_module.applied_pre_backward_ref_cnt > 0: - self.pre_sub_module_backward_function(sub_module) - sub_module.applied_pre_backward_ref_cnt -= 1 - #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") - - return _apply_to_tensors_only(module, - PreBackwardFunction, - _run_before_backward_function, - output) - - #This is an alternate to doing _post_backward_module_hook - #it uses tensor.register_hook instead of using torch.autograd.Function - def _alternate_post_backward_module_hook(module, inputs): - module.ds_grads_remaining = 0 - - #print(f"Before Forward {module.__class__.__name__}") - - def _run_after_backward_hook(*unused): - module.ds_grads_remaining = module.ds_grads_remaining - 1 - if module.ds_grads_remaining == 0: - #print(f"After backward {module.__class__.__name__}") - self.post_sub_module_backward_function(module) - - def _run_before_forward_function(input): - if input.requires_grad: - module.ds_grads_remaining += 1 - - return _apply_forward_and_backward_to_tensors_only( - module, - _run_before_forward_function, - _run_after_backward_hook, - inputs) - - def _post_backward_module_hook(module, inputs): - module.ds_grads_remaining = 0 - - @instrument_w_nvtx - def _run_after_backward_function(sub_module): - if sub_module.ds_grads_remaining == 0: - self.post_sub_module_backward_function(sub_module) - - return _apply_to_tensors_only(module, - PostBackwardFunction, - _run_after_backward_function, - inputs) - - # Pre forward hook - module.register_forward_pre_hook(_pre_forward_module_hook) - # Post forward hook - module.register_forward_hook(_post_forward_module_hook) - - # Pre backward hook - module.register_forward_hook(_pre_backward_module_hook) - - # post backward hook - module.register_forward_pre_hook(_post_backward_module_hook) - - @torch.no_grad() - def pre_sub_module_forward_function(self, sub_module): - see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", - force=False) - - global FWD_MODULE_STACK - FWD_MODULE_STACK.append(sub_module) - - if not self.param_coordinator.trace_complete: - self.param_coordinator.record_trace(sub_module) - - self.param_coordinator.fetch_sub_module(sub_module) - see_memory_usage( - f"Before sub module function {sub_module.__class__.__name__} after fetch", - force=False) - - @torch.no_grad() - def post_sub_module_forward_function(self, sub_module): - see_memory_usage( - f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", - force=False) - - self.param_coordinator.release_sub_module(sub_module) - - see_memory_usage( - f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", - force=False) - - @torch.no_grad() - def pre_sub_module_backward_function(self, sub_module): - if not self.param_coordinator.trace_complete: - self.param_coordinator.record_trace(sub_module) - self.param_coordinator.fetch_sub_module(sub_module) - - @torch.no_grad() - def post_sub_module_backward_function(self, sub_module): - see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", - force=False) - self.param_coordinator.release_sub_module(sub_module) - see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", - force=False) - - def _release_ipg_buffers(self): - if self.contiguous_gradients: - self.ipg_buffer = None - if not self.offload_optimizer and self.is_gradient_accumulation_boundary: - self.grads_in_partition = None - - self.grads_in_partition_offset = 0 - - def _optimizer_step(self, sub_group_id): - param_group_id = self.sub_group_to_group_id[sub_group_id] - fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] - - self.optimizer.step() - self.optimizer.param_groups[param_group_id]['params'] = [] - - def _swappable_optimizer_subgroup(self, sub_group_id): - if not self.swap_optimizer: - return False - - return self.optimizer_swapper.swappable_tensor( - None, - numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) - - def _partitioned_params_swap_out(self, i): - offset = 0 - fp32_param = self.fp32_partitioned_groups_flat[i] - assert fp32_param is not None, \ - f'fp32 parameters of sub_group {i} is None' - - swap_fp16_params = [] - swap_fp32_params = [] - for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]): - src = fp32_param.narrow(0, offset, partitioned_param.ds_numel) - if partitioned_param.status == PartitionedParamStatus.AVAILABLE: - partitioned_param.data.copy_(src.data) - else: - swap_fp32_params.append(src) - swap_fp16_params.append(param) - offset += partitioned_param.ds_numel - - if len(swap_fp16_params): - swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params( - dst_fp16_params=swap_fp16_params, - src_fp32_params=swap_fp32_params) - - def initialize_optimizer_states(self): - num_subgroups = len(self.fp16_groups) - - largest_numel = max( - [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups]) - gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype - gradient_buffer = torch.zeros(int(largest_numel), - dtype=gradient_dtype, - device=self.device) - - timers = self.timers - timer_names = set() - - if self.swap_optimizer: - self.optimizer_swapper.init_timers() - - INIT_OPTIMIZER_TIMER = 'init_optimizer_state' - timer_names.add(INIT_OPTIMIZER_TIMER) - self.start_timers([INIT_OPTIMIZER_TIMER]) - - for i, group in enumerate(self.fp16_groups): - swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) - swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None - - num_elements = int(self.fp16_partitioned_groups_flat_numel[i]) - - see_memory_usage( - f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', - force=False) - - if swappable_optimizer_subgroup: - self._optimizer_states_and_gradient_swap_in(i, timer_names) - - if self.offload_optimizer and not swappable_optimizer_subgroup: - subgroup_gradient_buffer = torch.zeros(num_elements, - dtype=gradient_dtype, - device=self.device) - if self.offload_optimizer_pin_memory: - subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory() - - self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer - else: - self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow( - 0, - 0, - num_elements) - - self._optimizer_step(i) - - if swappable_param_subgroup: - self._partitioned_params_swap_out(i) - - if swappable_optimizer_subgroup: - self._optimizer_states_and_gradient_swap_out(i, timer_names) - - see_memory_usage( - f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', - force=False) - - self.stop_timers([INIT_OPTIMIZER_TIMER]) - self.log_timers(timer_names) - - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - - if not self.offload_optimizer: - for group in self.fp32_partitioned_groups_flat: - group.grad = None - - # Reset steps - return - - ######################################################################### - #########################ZeRO Partition Gradients######################## - ######################################################################### - - def get_first_param_index(self, group_id, param_group, partition_id): - for index, param in enumerate(param_group): - param_id = self.get_param_id(param) - if partition_id in self.param_to_partition_ids[group_id][param_id]: - return index - return None - - def initialize_gradient_partitioning_data_structures(self): - - total_partitions = dist.get_world_size(group=self.dp_process_group) - - for i, param_group in enumerate(self.fp16_groups): - - self.param_to_partition_ids[i] = {} - self.is_partition_reduced[i] = {} - self.total_grads_in_partition[i] = {} - self.remaining_grads_in_partition[i] = {} - self.is_grad_computed[i] = {} - self.grad_partition_insertion_offset[i] = {} - self.grad_start_offset[i] = {} - self.first_param_index_in_partition[i] = {} - - for partition_id in range(total_partitions): - self.is_grad_computed[i][partition_id] = {} - self.grad_partition_insertion_offset[i][partition_id] = {} - self.grad_start_offset[i][partition_id] = {} - self.initialize_gradient_partition(i, param_group, partition_id) - self.is_partition_reduced[i][partition_id] = False - self.first_param_index_in_partition[i][ - partition_id] = self.get_first_param_index( - i, - param_group, - partition_id) - - @instrument_w_nvtx - def independent_gradient_partition_epilogue(self): - self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) - self.__reduce_and_partition_ipg_grads() - self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) - - self.__reduce_and_partition_stream.synchronize() - - # if dist.get_rank() == 0: - # logger.info("Params already reduced %s", self.params_already_reduced) - for i in range(len(self.params_already_reduced)): - self.params_already_reduced[i] = False - - #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad - #TODO: use a similar code path for both cpu_offload and non-cpu offload - if not self.offload_optimizer: - for i, sub_group in enumerate(self.fp16_groups): - self.averaged_gradients[i] = [ - self.__param_id_to_grad_partition[param.ds_id] - if param.requires_grad else torch.zeros_like(param.ds_tensor) - for param in sub_group - ] - # self.averaged_gradients[i] = self.get_flat_partition( - # self.fp16_groups[i], - # 0, - # self.fp32_partitioned_groups_flat[i].numel(), - # return_tensor_list=True) - - # this method gets called after every backward. need to increment - # here because if it gets incremented in backward() the micro step - # id will be off by one when we do the reduce and partition at the. - # start of this method. - # TODO. make this less error prone - self.micro_step_id += 1 - - def overlapping_partition_gradients_reduce_epilogue(self): - self.independent_gradient_partition_epilogue() - - def create_reduce_and_remove_grad_hooks(self): - print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] - for i, param_group in enumerate(self.fp16_groups): - for param in param_group: - if param.requires_grad: - #print_rank_0(f" Before all gather {param.device}, {param.shape}") - - # The hook must be created in un-partitioned parameter - param.all_gather() - - #print(f"After all gather {param.device}, {param.shape}") - def wrapper(param, i): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - - @instrument_w_nvtx - def reduce_partition_and_remove_grads(*notneeded): - self.reduce_ready_partitions_and_remove_grads(param, i) - - grad_acc.register_hook(reduce_partition_and_remove_grads) - self.grad_accs.append(grad_acc) - - #print(f"param grad fn {param.expand_as(param).grad_fn}") - wrapper(param, i) - - # Partition the parameter after creating the hook - param.partition() - print_rank_0(f'[End] Create gradient reduction hooks') - - def get_param_id(self, param): - unique_id = id(param) - return self.param_id[unique_id] - - def report_ipg_memory_usage(self, tag, param_elems): - elem_count = self.elements_in_ipg_bucket + param_elems - percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size - see_memory_usage( - f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", - force=False) - - ###############Idependent Partition Gradient ######################## - def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): - #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) - - # Because the ipg bucket is initialized with a random place holder tensor, we must - # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket > - # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a - # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be - # empty, while reduction_list will have that garbage data. - if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: - self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", - param.ds_numel) - - self.__reduce_and_partition_ipg_grads() - - param_id = self.get_param_id(param) - assert self.params_already_reduced[param_id] == False, \ - f"The parameter {param_id} has already been reduced. \ - Gradient computed twice for this partition. \ - Multiple gradient reduction is currently not supported" - - self.__add_grad_to_ipg_bucket(param) - - @instrument_w_nvtx - @torch.no_grad() - def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: - self.__reduce_and_partition_stream.wait_stream(torch.cuda.default_stream()) - - if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel( - ) < self.reduce_bucket_size: - # move the gradient to a contiguous buffer - with torch.cuda.stream(self.__reduce_and_partition_stream): - # move the parameter's gradient to the contiguous flat buffer - new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow( - 0, - self.elements_in_ipg_bucket, - param.grad.numel()).view_as(param.grad) - new_grad_tensor.copy_(param.grad, non_blocking=True) - param.grad.record_stream(torch.cuda.current_stream()) - param.grad.data = new_grad_tensor - - self.__params_in_ipg_bucket.append(param) - - @instrument_w_nvtx - @torch.no_grad() - def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: - if not self.__params_in_ipg_bucket: - return - - for param in self.__params_in_ipg_bucket: - if param.grad.numel() != param.ds_numel: - raise RuntimeError( - f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " - f"gradients whose size is not same as the params") - - self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id) - - assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len( - self.__params_in_ipg_bucket) - - while self.__param_reduce_events and self.__param_reduce_events[0].query(): - self.__param_reduce_events.popleft() - if len(self.__param_reduce_events) > self.__max_param_reduce_events: - self.__param_reduce_events.popleft().synchronize() - - with torch.cuda.stream(self.__reduce_and_partition_stream): - if safe_mode: - assert_ints_same_as_other_ranks( - [p.ds_id for p in self.__params_in_ipg_bucket]) - - grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) - self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) - - self.__params_in_ipg_bucket.clear() - - event = Event() - event.record() - self.__param_reduce_events.append(event) - - @instrument_w_nvtx - def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: - """average gradients and scatter partitions across ranks""" - dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce) - - full_grads_for_rank = [p.grad for p in params_to_reduce] - if self.allreduce_always_fp32: - full_grads_for_rank = [g.float() for g in full_grads_for_rank] - - if self.postscale_gradients and self.gradient_predivide_factor != 1.0: - full_grads_for_rank = [ - g.div(self.gradient_predivide_factor) for g in full_grads_for_rank - ] - - grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, - self.dp_process_group) - - if self.postscale_gradients and self.gradient_predivide_factor != dist.get_world_size( - self.dp_process_group): - grad_partitions_for_rank = [ - g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank - ] - - if self.allreduce_always_fp32: - grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank] - - return grad_partitions_for_rank - - def set_grad_positions(self): - for i, group in enumerate(self.fp16_groups): - current_offset = 0 - for param in group: - param_id = self.get_param_id(param) - num_elements = param.ds_tensor.ds_numel - - self.grad_position[param_id] = [ - int(i), - int(current_offset), - int(num_elements) - ] - #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") - current_offset += num_elements - - def _constant_buffered_norm2(self, input, buffer_size=250000000): - norm = None - for part in input.view(-1).split(buffer_size): - if norm is None: - norm = part.data.double().norm(2)**2.0 - else: - norm += part.data.double().norm(2)**2.0 - return norm**0.5 - - def set_norm_for_param_grad_in_gpu(self, param): - param_id = self.get_param_id(param) - #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2) - #Using a more memory efficient version - self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) - - def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): - with torch.cuda.stream(self.copy_grad_stream): - param_id = self.get_param_id(param) - src_tensor = param.grad.view(-1).float() - #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") - fp32_grad_tensor.copy_(src_tensor, non_blocking=True) - param.grad = None - - def complete_grad_norm_calculation_for_cpu_offload(self, params): - total_norm = 0.0 - norm_type = 2.0 - for p in params: - if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - param_id = self.get_param_id(p) - if param_id in self.norm_for_param_grads.keys(): - param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 - - # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=self.dp_process_group) - - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) - - total_norm = total_norm_cuda[0].item()**(1. / norm_type) - - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 - - return total_norm - - @instrument_w_nvtx - def __partition_grads(self, - params_to_release: List[Parameter], - grad_partitions: List[Tensor]) -> None: - for param, grad_partition in zip(params_to_release, grad_partitions): - if param.ds_tensor.ds_numel * dist.get_rank( - self.dp_process_group) > param.ds_numel: - # this grad partition is empty - don't need to do anything - continue - - # move or accumulate gradient partition to target buffer - grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow( - 0, - 0, - grad_partition.numel()) - if self.micro_step_id == 0: # don't accumulate - grad_buffer.copy_(grad_partition, non_blocking=True) - # ensure grad buffer is a CUDA buffer to speed up the next few - # operations and so it can be used asynchronously - grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) - elif grad_buffer.is_cuda: - grad_buffer.add_(grad_partition) - else: - # if dst is CPU, copy first to src device, do the addition - # there, then move back to dst. adding directly to cpu is very slow - cuda_grad_buffer = grad_buffer.to(grad_partition.device, - non_blocking=True) - cuda_grad_buffer.add_(grad_partition) - grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) - # ensure grad buffer is a CUDA buffer to speed up the next few - # operations and so it can be used asynchronously - grad_buffer = cuda_grad_buffer - - if hasattr(self.__inf_or_nan_tracker, "logical_or_"): - self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) - self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) - else: - # logical_or_ not available in older versions of pytorch - self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any() - self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any() - self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0 - - # offload the gradient partition if applicable - if self.offload_optimizer: - i, dest_offset, _ = self.grad_position[self.get_param_id(param)] - offload_fp32_gradients = {} - offload_fp32_offsets = {} - - if self.is_gradient_accumulation_boundary: - self.norm_for_param_grads[self.get_param_id( - param)] = self._constant_buffered_norm2(grad_buffer) - - if self._swappable_optimizer_subgroup(i): - if not i in offload_fp32_gradients.keys(): - offload_fp32_gradients[i] = [] - offload_fp32_offsets[i] = [] - - offload_fp32_gradients[i].append(grad_buffer.float()) - offload_fp32_offsets[i].append(dest_offset) - else: - fp32_grad_tensor = self.fp32_partitioned_groups_flat[ - i].grad.narrow(0, - dest_offset, - grad_buffer.numel()) - fp32_grad_tensor.copy_(grad_buffer) - - # free the gradient - param.grad.record_stream(torch.cuda.current_stream()) - param.grad = None - - if self.offload_optimizer and self.swap_optimizer: - for i in offload_fp32_gradients.keys(): - self.optimizer_swapper.swap_out_gradients( - parameter=self.fp32_partitioned_groups_flat[i], - gradient_offsets=offload_fp32_offsets[i], - gradient_tensors=offload_fp32_gradients[i]) - - def reduce_ready_partitions_and_remove_grads(self, param, i): - #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) - self.reduce_independent_p_g_buckets_and_remove_grads(param, i) - - def zero_reduced_gradients(self, partition_id, i): - def are_all_related_partitions_reduced(params_id): - for partition_id in self.param_to_partition_ids[i][params_id]: - if not self.is_partition_reduced[i][partition_id]: - return False - return True - - for params_id in self.is_grad_computed[i][partition_id]: - if are_all_related_partitions_reduced(params_id): - self.param_dict[params_id].grad = None - - def flatten_and_print(self, message, tensors, start=0, n=5): - flatten_tensor = self.flatten(tensors) - - def print_func(): - logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) - - self.sequential_execution(print_func, message) - - def get_grads_to_reduce(self, i, partition_id): - def get_reducible_portion(key): - grad = self.param_dict[key].grad - total_elements = grad.numel() - start = self.grad_start_offset[i][partition_id][key] - num_elements = min( - total_elements - start, - self.partition_size[i] - - self.grad_partition_insertion_offset[i][partition_id][key]) - if not pg_correctness_test: - if num_elements == total_elements: - return grad - else: - return grad.contiguous().view(-1).narrow(0, - int(start), - int(num_elements)) - else: - if num_elements == total_elements: - return grad.clone() - else: - return grad.clone().contiguous().view(-1).narrow( - 0, - int(start), - int(num_elements)) - - grads_to_reduce = [] - for key in self.is_grad_computed[i][partition_id]: - grad = get_reducible_portion(key) - grads_to_reduce.append(grad) - return grads_to_reduce - - def sequential_execution(self, function, message, group=None): - if group is None: - group = self.dp_process_group - if dist.get_rank(group=group) == 0: - logger.info(message) - for id in range(dist.get_world_size(group=group)): - if id == dist.get_rank(group=group): - function() - dist.barrier(group=group) - - def set_none_gradients_to_zero(self, i, partition_id): - for param_id in self.is_grad_computed[i][partition_id]: - param = self.param_dict[param_id] - if param.grad is None: - param.grad = torch.zero_like(param) - - ######################Reduction Related Methods############################## - - def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None): - rank = None - tensor = self.flatten(bucket) - - tensor_to_allreduce = tensor - - if pg_correctness_test: - allreduce_always_fp32 = True - - if allreduce_always_fp32: - tensor_to_allreduce = tensor.float() - - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) - - if rank is None: - # "All Reducing" - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - else: - global_rank = _get_global_rank(self.dp_process_group, rank) - dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) - - if allreduce_always_fp32 and tensor is not tensor_to_allreduce: - if rank is None or rank == dist.get_rank(group=self.dp_process_group): - tensor.copy_(tensor_to_allreduce) - - return tensor - - # if rank is specified do a reduction instead of an allreduce - def allreduce_and_copy(self, small_bucket, rank=None, log=None): - with torch.cuda.stream(self.reduction_stream): - allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) - if rank is None or rank == dist.get_rank(group=self.dp_process_group): - for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): - buf.copy_(synced) - - def allreduce_no_retain(self, - bucket, - numel_per_bucket=500000000, - rank=None, - log=None): - small_bucket = [] - numel = 0 - for tensor in bucket: - small_bucket.append(tensor) - numel = numel + tensor.numel() - if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, rank=rank, log=None) - small_bucket = [] - if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, rank=rank, log=log) - - ############################################################################# - ############################################################################# - ############################################################################# - - # views the tensor as multiple partitions and returns - # those partitions - def get_data_parallel_partitions(self, tensor): - partitions = [] - - dp = dist.get_world_size(group=self.dp_process_group) - dp_id = dist.get_rank(group=self.dp_process_group) - - total_num_elements = tensor.numel() - - base_size = total_num_elements // dp - remaining = total_num_elements % dp - - start = 0 - for id in range(dp): - partition_size = base_size - if id < remaining: - partition_size = partition_size + 1 - partitions.append(tensor.narrow(0, start, partition_size)) - start = start + partition_size - return partitions - - def get_partition_info(self, tensor_list, partition_size, partition_id): - params_in_partition = [] - params_not_in_partition = [] - - start_index = partition_size * partition_id - end_index = partition_size * (partition_id + 1) - - current_index = 0 - first_offset = 0 - - for tensor in tensor_list: - - tensor_size = tensor.numel() - - if (current_index >= start_index and current_index < end_index): - params_in_partition.append(tensor) - - elif start_index > current_index and start_index < (current_index + - tensor_size): - params_in_partition.append(tensor) - - assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" - first_offset = start_index - current_index - - else: - params_not_in_partition.append(tensor) - - current_index = current_index + tensor_size - - return params_in_partition, params_not_in_partition, first_offset - - @instrument_w_nvtx - def zero_grad(self, set_grads_to_None=True): - """ - Zero FP16 parameter grads. - """ - self.micro_step_id = 0 - - # FP32 grad should never exist. - # For speed, set model fp16 grad to None by default - for group in self.fp16_groups: - for p in group: - if set_grads_to_None: - if p.grad is not None and p.grad.is_cuda: - p.grad.record_stream(torch.cuda.current_stream()) - p.grad = None - else: - if p.grad is not None: - p.grad.detach_() - p.grad.zero_() - - def _model_parallel_all_reduce(self, tensor, op): - """ Perform all reduce within model parallel group, if any. - """ - if self.model_parallel_group is None: - pass - else: - torch.distributed.all_reduce(tensor=tensor, - op=op, - group=self.model_parallel_group) - - @instrument_w_nvtx - def get_grad_norm_direct(self, gradients, params, norm_type=2): - """Clips gradient norm of an iterable of parameters. - - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - - Returns: - Total norm of the parameters (viewed as a single vector). - """ - norm_type = float(norm_type) - if norm_type == inf: - total_norm = max(g.data.abs().max() for g in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.MAX, - group=self.dp_process_group) - - # Take max across all GPUs. - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() - else: - # if dist.get_rank() == 0: - # logger.info(f"Total Norm beginning {total_norm}") - grad_norms = [] - for g, p in zip(gradients, params): - if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - grad_norms.append(g.cuda(non_blocking=True).double().norm(2)) - - # Sum across all model parallel GPUs. - total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) - - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=self.dp_process_group) - - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) - - total_norm = total_norm_cuda.item()**(1. / norm_type) - - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 - - return total_norm - - # creates a flat fused tensor from the tensor list starting at the first_offset - # in the first tensor of the list. If there are not enough elements in the tensor - # list then the flat tensor will be padded with zeros - def get_flat_partition(self, - tensor_list, - first_offset, - partition_size, - return_tensor_list=False): - flat_tensor_list = [] - current_size = 0 - for i, tensor in enumerate(tensor_list): - if tensor.grad is None: - tensor.grad = torch.zeros_like(tensor) - - tensor = tensor.grad - num_elements = tensor.numel() - tensor_offset = 0 - - # we need to offset to get to the right element - if i == 0 and first_offset > 0: - tensor_offset = first_offset - num_elements = num_elements - tensor_offset - - # we dont need all elements of the tensor - if num_elements > (partition_size - current_size): - num_elements = partition_size - current_size - - # we need a narrow view of the tensor based on the tensor offset and number of elements that - # we need from this tensor - if tensor_offset > 0 or num_elements < tensor.numel(): - flat_tensor_list.append(tensor.contiguous().view(-1).narrow( - 0, - int(tensor_offset), - int(num_elements))) - else: - flat_tensor_list.append(tensor) - - current_size = current_size + num_elements - - # this means its the last partition and does not align with the dp boundary. We need to pad before flattening - if current_size < partition_size: - flat_tensor_list.append( - torch.zeros(int(partition_size - current_size), - dtype=tensor_list[0].dtype, - device=tensor_list[0].device)) - - if return_tensor_list: - return flat_tensor_list - - return self.flatten(flat_tensor_list) - - def free_grad_in_param_list(self, param_list): - for p in param_list: - p.grad = None - - def reset_cpu_buffers(self): - self.norm_for_param_grads = {} - self.local_overflow = False - - def log_timers(self, timer_names): - if self.timers is None: - return - - self.timers.log(names=list(timer_names)) - - def start_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).start() - - def stop_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).stop() - - def _pre_step(self): - self.micro_step_id = 0 - - print_rank_0(f"Inside Step function") - see_memory_usage(f"In step before checking overflow", force=False) - - print_rank_0("Finished Tracing at Beginning of Step") - self.param_coordinator.hierarchy = 0 - - print_rank_0("Finished Tracing at Beginning of Step") - - @instrument_w_nvtx - def _get_norm_groups(self): - norm_groups = [] - for i, group in enumerate(self.fp16_groups): - if self.offload_optimizer: - norm_groups.append( - self.complete_grad_norm_calculation_for_cpu_offload( - self.fp16_groups[i])) - else: - norm_groups.append( - self.get_grad_norm_direct(self.averaged_gradients[i], - self.fp16_groups[i])) - return norm_groups - - @instrument_w_nvtx - def _prepare_fp32_grad_for_sub_group(self, sub_group_id): - partition_id = dist.get_rank(group=self.dp_process_group) - - single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( - self.fp32_partitioned_groups_flat[sub_group_id].dtype) - - assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ - "averaged gradients have different number of elements that partition size {} {} {} {}".format( - single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) - - self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition - - # release all the gradient since we have already created a necessary copy in dp_grad_partition - self.zero_grad() - - for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): - grad.record_stream(torch.cuda.current_stream()) - - self.averaged_gradients[sub_group_id] = None - - @instrument_w_nvtx - def _prepare_sub_group(self, sub_group_id, timer_names=set()): - see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', - force=False) - if self._swappable_optimizer_subgroup(sub_group_id): - self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) - elif not self.offload_optimizer: - self._prepare_fp32_grad_for_sub_group(sub_group_id) - see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', - force=False) - - def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()): - param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) - assert self._swappable_optimizer_subgroup(sub_group_id), \ - f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - - OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state' - see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', - force=False) - self.start_timers([OPTIMIZER_SWAP_IN_STATE]) - - self.optimizer_swapper.swap_in_optimizer_state( - parameter=self.fp32_partitioned_groups_flat[sub_group_id], - async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) - - self.stop_timers([OPTIMIZER_SWAP_IN_STATE]) - timer_names.add(OPTIMIZER_SWAP_IN_STATE) - see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', - force=False) - - @instrument_w_nvtx - def _release_sub_group(self, sub_group_id, timer_names=set()): - see_memory_usage(f'Before release optimizer sub group {sub_group_id}', - force=False) - # get rid of the fp32 gradients. Not needed anymore - if not self.offload_optimizer: - self.fp32_partitioned_groups_flat[sub_group_id].grad = None - - if self._swappable_optimizer_subgroup(sub_group_id): - self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names) - see_memory_usage(f'After release optimizer sub group {sub_group_id}', - force=False) - - # create a flat tensor aligned at the alignment boundary - @instrument_w_nvtx - def flatten_dense_tensors_aligned(self, tensor_list, alignment): - num_elements = 0 - for tens in tensor_list: - num_elements = num_elements + tens.numel() - - remaining = num_elements % alignment - - if remaining: - elements_to_add = alignment - remaining - pad_tensor = torch.zeros(elements_to_add, - device=tensor_list[0].device, - dtype=tensor_list[0].dtype) - padded_tensor_list = tensor_list + [pad_tensor] - - num_elements = num_elements + elements_to_add - else: - padded_tensor_list = tensor_list - - return self.flatten(padded_tensor_list) - - def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()): - param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) - assert self._swappable_optimizer_subgroup(sub_group_id), \ - f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - - OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state' - see_memory_usage( - f'post-step Before swapping out optimizer tensors {sub_group_id}', - force=False) - self.start_timers([OPTIMIZER_SWAP_OUT_STATE]) - - self.optimizer_swapper.swap_out_optimizer_state( - parameter=self.fp32_partitioned_groups_flat[sub_group_id], - async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is - not None) - - self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) - see_memory_usage( - f'post-step After swapping out optimizer tensors {sub_group_id}', - force=False) - timer_names.add(OPTIMIZER_SWAP_OUT_STATE) - - # get rid of the fp32 gradients. Not needed anymore - self.fp32_partitioned_groups_flat[sub_group_id].grad = None - - def _unflatten_partitioned_parameters(self, sub_group_id): - updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], - self.fp16_partitioned_groups[sub_group_id]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): - partitioned_param.data = q.data - - def _overflow_clean_up(self, prev_scale): - see_memory_usage('After overflow before clearing gradients', force=False) - self.zero_grad() - - if self.offload_optimizer: - self.reset_cpu_buffers() - else: - self.averaged_gradients = {} - - see_memory_usage('After overflow after clearing gradients', force=False) - - if torch.distributed.get_rank() == 0: - logger.info( - "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, " - "reducing to {}".format(dist.get_rank(), - prev_scale, - self.loss_scale)) - - @instrument_w_nvtx - def _overflow_check_and_loss_scale_update(self): - - # First compute norm for all group so we know if there is overflow - self.check_overflow() - - #loss scaling related computation - prev_scale = self.loss_scale - self._update_scale(self.overflow) - - if self.overflow: - self._overflow_clean_up(prev_scale) - - return self.overflow - - @instrument_w_nvtx - def _post_step(self, timer_names=set()): - if self.offload_optimizer: - self.reset_cpu_buffers() - - #Gathering persisting parameters - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - - self.log_timers(timer_names) - - see_memory_usage('After zero_optimizer step', force=False) - print_rank_0(f"------------------Finishing Step-----------------------") - - @instrument_w_nvtx - def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): - if self.fp16_partitioned_groups_flat[sub_group_id] is not None: - self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( - self.fp32_partitioned_groups_flat[sub_group_id].data) - - #unflatten fp16 parameter subgroup - self._unflatten_partitioned_parameters(sub_group_id) - else: - self._partitioned_params_swap_out(sub_group_id) - - @instrument_w_nvtx - def step(self, closure=None): - """ - Not supporting closure. - """ - self._pre_step() - self._partition_all_parameters() - - #checks for overflow, adjust the loss scale accordingly - if self._overflow_check_and_loss_scale_update(): - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - return - - norm_groups = self._get_norm_groups() - self._global_grad_norm = get_global_norm(norm_list=norm_groups) - - timer_names = set() - - timer_names.add('optimizer_step') - self.start_timers(['optimizer_step']) - - #update parameters one sub group at a time - for sub_group_id, group in enumerate(self.fp16_groups): - - #prepare optimizer states, gradients and fp32 parameters for update - self._prepare_sub_group(sub_group_id, timer_names) - - #scale the fp32 gradients - self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) - - #apply the optimizer step on the sub group and copy fp32 parameters to fp16 - self._optimizer_step(sub_group_id) - - #put fp16 parameters in appropriate location - self._reassign_or_swap_out_partitioned_parameters(sub_group_id) - - #release memory or swap out optimizer states of fp32 parameters - self._release_sub_group(sub_group_id, timer_names) - - self.stop_timers(['optimizer_step']) - - self._post_step(timer_names) - - # warn user about caching allocator flushes - alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] if hasattr( - torch.cuda, - "memory_stats") else 0 - if alloc_retries > self.__n_caching_allocator_flushes: - if dist.get_rank() == 0: - logger.warning( - "%d pytorch allocator cache flushes since last step. this happens " - "when there is high memory pressure and is detrimental to " - "performance. if this is happening frequently consider adjusting " - "settings to reduce memory consumption. If you are unable to " - "make the cache flushes go away consider adding " - "torch.cuda.empty_cache() calls in your training loop to ensure " - "that all ranks flush their caches at the same time", - alloc_retries - self.__n_caching_allocator_flushes) - self.__n_caching_allocator_flushes = alloc_retries - - def dump_pre_step_gradients(self, debug_fp32_grads): - # Dump gradient norms for debugging - for i, _ in enumerate(self.fp16_groups): - print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC') - for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]): - param_id = self.get_param_id(fp16_param) - fp16_grad_norm = self.debug_fp16_grads[i][param_id] - - fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad] - norm_list = [fp16_grad_norm, fp32_grad_norm] - print(f'Pre-Step Norms {i} {param_id} = {norm_list}') - - def dump_post_step_gradients(self): - # Dump gradient norms for debugging - for i, group in enumerate(self.fp16_groups): - print( - f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT') - unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i]) - unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i], - self.fp16_groups[i]) - for j, p in enumerate(self.fp16_groups[i]): - param_id = self.get_param_id(p) - param_norm = float(p.data.float().norm(2)) - ds_norm = float(p.ds_tensor.data.float().norm(2)) - - unflat_norm = [ - float(t.data.float().norm(2)) - for t in [unflat_fp16[j], - unflat_fp32[j]] - ] - norm_list = [param_norm, ds_norm] + unflat_norm - print(f'Post-Step Norms {i} {param_id} = {norm_list}') - - @instrument_w_nvtx - def unscale_and_clip_grads(self, sub_group_id, total_norm): - grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] - - # compute combined scale factor for this group - combined_scale = self.loss_scale - if self.clip_grad > 0.: - # norm is in fact norm*scale - clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad - if clip > 1: - combined_scale = clip * self.loss_scale - # to maintain behavior of averaging over accumulation steps - combined_scale *= self.micro_step_id + 1 - - for grad in grad_groups_flat: - if isinstance(grad, list): - sub_partitions = grad - for g in sub_partitions: - g.data.mul_(1. / combined_scale) - else: - grad.data.mul_(1. / combined_scale) - - def _check_overflow(self, partition_gradients=True): - self.overflow = self.has_overflow(partition_gradients) - - # `params` is a list / generator of torch.Variable - def has_overflow_serial(self, params, is_grad_list=False): - for p in params: - if p.grad is not None and self._has_inf_or_nan(p.grad.data): - return True - - return False - - def has_overflow_partitioned_grads_serial(self): - for i in range(len(self.fp16_groups)): - for j, grad in enumerate(self.averaged_gradients[i]): - if grad is not None and self._has_inf_or_nan(grad.data, j): - return True - return False - - @instrument_w_nvtx - def has_overflow(self, partition_gradients=True): - if partition_gradients: - with torch.cuda.stream(self.__reduce_and_partition_stream): - self.local_overflow = bool(self.__inf_or_nan_tracker.item()) - self.__inf_or_nan_tracker.zero_() - - overflow = self.local_overflow - #overflow = self.has_overflow_partitioned_grads_serial() - overflow_gpu = torch.cuda.ByteTensor([overflow]) - torch.distributed.all_reduce(overflow_gpu, - op=torch.distributed.ReduceOp.MAX, - group=self.dp_process_group) - - else: - params = [] - for group in self.fp16_groups: - for param in group: - params.append(param) - - overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) - overflow_gpu = torch.cuda.ByteTensor([overflow]) - - # Since each model parallel GPU carries only part of the model, - # make sure overflow flag is synced across all the model parallel GPUs - self._model_parallel_all_reduce(tensor=overflow_gpu, - op=torch.distributed.ReduceOp.MAX) - - overflow = overflow_gpu[0].item() - return bool(overflow) - - # `x` is a torch.Tensor - @staticmethod - def _has_inf_or_nan(x, j=None): - try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x - # (which is true for some recent version of pytorch). - cpu_sum = float(x.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # cpu_sum = float(x.sum()) - except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. - if "value cannot be converted" not in instance.args[0]: - raise - return True - else: - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: - return True - return False - - @instrument_w_nvtx - def backward(self, loss, retain_graph=False): - """ - :attr:`backward` performs the following steps: - - 1. fp32_loss = loss.float() - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ - if self.swap_optimizer: - self.optimizer_swapper.pre_backward() - - see_memory_usage(f"Before backward", force=False) - - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - - self.param_coordinator.reset_step() - - if self.swap_optimizer: - self.optimizer_swapper.post_backward() - - def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: - """get fp32 gradient partition dictionary - accessed as grad_dict[parameter_group_index][parameter_index] - """ - self.__reduce_and_partition_stream.synchronize() - grad_dict = collections.defaultdict(dict) - if self.offload_optimizer: - for group in self.fp16_groups: - for param_idx, param in enumerate(group): - group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] - fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow( - 0, - dest_offset, - num_elements) - grad_dict[group_idx][param_idx] = fp32_grad - else: - for group_idx, group in self.averaged_gradients.items(): - for param_idx, gradient in enumerate(group): - grad_dict[group_idx][param_idx] = gradient.float() - - return grad_dict - - @instrument_w_nvtx - def _partition_all_parameters(self): - """Partitioning Parameters that were not partitioned usually if parameters - of modules whose input parameters do not require grad computation do not - trigger post call and will therefore will remain unpartitioned""" - self.param_coordinator.release_and_reset_all() - for param in iter_params(self.module, recurse=True): - if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: - raise RuntimeError(f"{param.ds_summary()} expected to be released") - - def check_overflow(self, partition_gradients=True): - self._check_overflow(partition_gradients) - - def _update_scale(self, has_overflow=False): - self.loss_scaler.update_scale(has_overflow) - - # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) - - # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" - def _get_loss_scale(self): - return self.loss_scaler.loss_scale - - def _set_loss_scale(self, value): - self.loss_scaler.cur_scale = value - - loss_scale = property(_get_loss_scale, _set_loss_scale) - cur_scale = property(_get_loss_scale, _set_loss_scale) - - def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings): - # Remove paddings from flattened tensor - individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors) - lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)] - lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)] - #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}') - return lean_tensors - - #TODO REVISIT this for stage 3 - def get_lean_optimizer_state(self): - # Return optimizer states after removing paddings. - # This method assumes that each param group contains a single flattened tensor. - optimizer_groups_state = [] - - for i, group in enumerate(self.optimizer.param_groups): - p = group['params'][0] - lean_state = {} - for key, value in self.optimizer.state[p].items(): - if torch.is_tensor(value): - padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]] - lean_state[key] = self._get_lean_tensors( - value, - self.fp16_partitioned_groups[i], - self.groups_padding[i]) - lean_flat_len = sum([t.numel() for t in lean_state[key]]) - else: - lean_state[key] = value - - optimizer_groups_state.append(lean_state) - - return optimizer_groups_state - - def get_groups_without_padding(self, groups_with_padding): - # Return group tensor after removing paddings added for alignment to DP world size. - groups_without_padding = [] - for i, group in enumerate(groups_with_padding): - lean_group = self._get_lean_tensors(group, - self.fp16_partitioned_groups[i], - self.groups_padding[i]) - groups_without_padding.append(lean_group) - - return groups_without_padding - - def _set_fp32_optimizer_param_groups(self): - for sub_group_id, _ in enumerate(self.fp16_groups): - param_group_id = self.sub_group_to_group_id[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'].append( - self.fp32_partitioned_groups_flat[sub_group_id]) - - def _clear_fp32_optimizer_param_groups(self): - for param_group in self.optimizer.param_groups: - param_group['params'] = [] - - def _rigid_state_dict(self): - state_dict = {} - state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS - state_dict['loss_scaler'] = self.loss_scaler - state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale - state_dict['overflow'] = self.overflow - state_dict['partition_count'] = self.partition_count - - self._set_fp32_optimizer_param_groups() - state_dict['optimizer_state_dict'] = self.optimizer.state_dict() - state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat - self._clear_fp32_optimizer_param_groups() - - return state_dict - - def state_dict(self): - """ - Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. - This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict - of the contained Pytorch optimizer. - Example:: - checkpoint = {} - checkpoint['model'] = model.state_dict() - checkpoint['optimizer'] = optimizer.state_dict() - torch.save(checkpoint, "saved.pth") - """ - if self.elastic_checkpoint: - raise NotImplementedError( - "ZeRO-3 does not yet support elastic checkpointing, please disable for now." - ) - - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." - ) - - return self._rigid_state_dict() - - -# Restore base optimizer fp32 weights from checkpoint by: -# 1) Merging fp32 weights from checkpoints of all partitions -# 2) Extracting fp32 weights for current partition from merged weights -# 3) Using extracted weights to update base optimizer weights directly. - - def _restore_from_fp32_weights(self, all_state_dict): - - flat_local_partition = [] - for i in range(len(self.fp32_partitioned_groups_flat)): - merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] - flat_local_partition.append(self._get_flattened_partition(merged_partitions)) - - for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): - current.data.copy_(saved.data) - - # Restore base optimizer fp32 weights from ZeRO fp16 weights - def _restore_from_fp16_weights(self): - for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): - fp32_partition.data.copy_(fp16_partitions.data) - - # Refresh the fp32 master params from the fp16 copies. - def refresh_fp32_params(self): - self._restore_from_fp16_weights() - - # Extract flattened partition for current rank from all partitions - def _get_flattened_partition(self, all_partition_states): - partition_id = dist.get_rank(group=self.dp_process_group) - alignment = dist.get_world_size(group=self.dp_process_group) - - param_partitions = [[] for _ in range(len(all_partition_states[0]))] - for i, partition in enumerate(all_partition_states): - for j, param in enumerate(partition): - param_partitions[j].append(param) - - local_state_partitions = [] - for param_index, param_slices in enumerate(param_partitions): - flattened_merged_tensor = self.flatten_dense_tensors_aligned( - param_slices, - alignment) - new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) - local_state_partitions.append(new_partitions[partition_id]) - - if torch.is_tensor(local_state_partitions[0]): - return self.flatten_dense_tensors_aligned(local_state_partitions, alignment) - - # Assume non-tensor states are not partitioned and equal across ranks, so return first one - return local_state_partitions[0] - - # Restore base optimizer state from checkpoint by - # 1) Merging optimizer state from checkpoints of all partitions - # 2) Extracting optimizer state for current partition from the merged state - # 3) Using the extracted value to directly update the base optimizer. - def _restore_base_optimizer_state(self, all_state_dict): - base_optimizer_group_states = [] - for i in range(len(self.optimizer.param_groups)): - partition_states = {} - all_partition_group_states = [ - sd['base_optimizer_state'][i] for sd in all_state_dict - ] - for key in all_partition_group_states[0].keys(): - all_partition_states = [ - all_states[key] for all_states in all_partition_group_states - ] - partition_states[key] = self._get_flattened_partition( - all_partition_states) - base_optimizer_group_states.append(partition_states) - - for i, group in enumerate(self.optimizer.param_groups): - p = group['params'][0] - for key, saved in base_optimizer_group_states[i].items(): - if torch.is_tensor(self.optimizer.state[p][key]): - self.optimizer.state[p][key].data.copy_(saved.data) - else: - self.optimizer.state[p][key] = saved - - def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): - # I think it should actually be ok to reload the optimizer before the model. - self.loss_scaler = state_dict['loss_scaler'] - self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] - self.overflow = state_dict['overflow'] - - if load_optimizer_states: - self._set_fp32_optimizer_param_groups() - self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) - self._clear_fp32_optimizer_param_groups() - - # restore fp32 partitions - for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']): - curr_param.data.copy_(saved_param.data) - - # restore fp16 partitions from fp32 - for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): - fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] - fp16_param.data.copy_(fp32_param.data) - - # update fp16 unflattened params - for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): - updated_params = self.unflatten( - self.fp16_partitioned_groups_flat[sub_group_id], - self.fp16_partitioned_groups[sub_group_id]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): - partitioned_param.data = q.data - - # TODO: Support different/changing load/save DP degree. - def load_state_dict(self, - state_dict_list, - load_optimizer_states=True, - load_from_fp32_weights=False): - r"""Loading a ZeRO checkpoint - Arguments: - state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. - Note that the number of saved partitions may differ from number of loading partitions to support - changing GPU count, specifically DP world size, between saving and loading checkpoints. - load_optimizer_states: Boolean indicating whether or not to load base optimizer states - load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 - copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). - """ - """ - Loads a state_dict created by an earlier call to state_dict(). - If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user - will call ``model.load_state_dict()`` before - ``fp16_optimizer_instance.load_state_dict()`` is called. - Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - checkpoint = torch.load("saved.pth") - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - """ - - if self.elastic_checkpoint: - raise NotImplementedError( - "ZeRO-3 does not yet support elastic checkpointing, please disable for now." - ) - - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." - ) - - self._rigid_load_state_dict( - state_dict_list[dist.get_rank(group=self.dp_process_group)], - load_optimizer_states=load_optimizer_states) - - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].partition(self.persistent_parameters) - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - def save_checkpoint_prologue(self): - self._partition_all_parameters() - - def save_checkpoint_epilogue(self): - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - -def _handle_overflow(cpu_sum, x, i): - import math - rank = torch.distributed.get_rank() - if rank == 0: - t_i = -1 - for v_i, v in enumerate(x.data.contiguous().view(-1)): - if not math.isfinite(float(v)): - t_i = v_i - break - logger.info( - f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}" - ) - - -def estimate_zero3_model_states_mem_needs(total_params, - largest_layer_params, - num_gpus_per_node=1, - num_nodes=1, - cpu_offload=True, - cpu_offload_params=True, - zero_init=True, - additional_buffer_factor=1.5): - - total_gpus = num_nodes * num_gpus_per_node - gpus_factor = 1 / num_nodes - largest_layer_memory = (4 * largest_layer_params) - - if cpu_offload: - if cpu_offload_params: - gpu_mem = largest_layer_memory - - if zero_init: - cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor - else: - cpu_mem = total_params * max(4 * num_gpus_per_node, - 18 * gpus_factor) * additional_buffer_factor - - else: - gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus) - - if zero_init: - cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor - else: - cpu_mem = total_params * max(4 * num_gpus_per_node, - 16 * gpus_factor) * additional_buffer_factor - else: - gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus) - if zero_init: - cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor - else: - cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor - - return int(cpu_mem), int(gpu_mem), largest_layer_memory - - -def model_to_params(model): - # shared params calculated only once - total_params = sum( - dict((p.data_ptr(), - p.numel()) for p in model.parameters()).values()) - - largest_layer_params = 0 - for m in model.modules(): - # assuming no shared params within a single layer - layer_params = sum(p.numel() for p in m.parameters(recurse=False)) - largest_layer_params = max(largest_layer_params, layer_params) - - return total_params, largest_layer_params - - -import math - - -def estimate_zero3_model_states_mem_needs_all_live(model, - num_gpus_per_node=1, - num_nodes=1, - additional_buffer_factor=1.5): - """ - Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients - for a given ``model`` and hardware setup. - - If you have an actual model object, use this function and everything will be derived - automatically. - - If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass - the ``total_params`` and ``largest_layer_params`` explicitly. - - Args: - - ``model``: ``nn.Module`` object - - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - - ``num_nodes``: how many nodes (defaults to 1), - - ``additional_buffer_factor``: estimation factor (defaults to 1.5): - - """ - - total_params, largest_layer_params = model_to_params(model) - - estimate_zero3_model_states_mem_needs_all_cold( - total_params=total_params, - largest_layer_params=largest_layer_params, - num_gpus_per_node=num_gpus_per_node, - num_nodes=num_nodes, - additional_buffer_factor=additional_buffer_factor) - - -def estimate_zero3_model_states_mem_needs_all_cold(total_params, - largest_layer_params, - num_gpus_per_node=1, - num_nodes=1, - additional_buffer_factor=1.5): - """ - Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients - for a given ``model`` and hardware setup. - - If it's a hypothetical model, use this function where you have to pass - the ``total_params`` and ``largest_layer_params`` explicitly. - - If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything - will be derived automatically. - - Args: - - ``total_params``: total model params - - ``largest_layer_params``: largest layer's params - - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - - ``num_nodes``: how many nodes (defaults to 1), - - ``additional_buffer_factor``: estimation factor (defaults to 1.5): - - """ - def format_options(cpu_offload, cpu_offload_params, zero_init): - enabled = [] - padded_cpu_str = f'{OFFLOAD_CPU_DEVICE:4}' - param_device = padded_cpu_str if cpu_offload_params else "none" - enabled.append(f"{OFFLOAD_PARAM}={param_device}") - optimizer_device = padded_cpu_str if cpu_offload else "none" - enabled.append(f"{OFFLOAD_OPTIMIZER}={optimizer_device}") - enabled.append(f"zero_init={1 if zero_init else 0}") - return ", ".join(enabled) - - nodes_str = "nodes" if num_nodes > 1 else "node" - gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" - print( - "Estimated memory needed for params, optim states and gradients for a:\n" - f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" - f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params." - ) - print(" per CPU | per GPU | Options") - for cpu_offload in [True, False]: - for cpu_offload_params in [True, False]: - if not cpu_offload and cpu_offload_params: - continue - for zero_init in [True, False]: - cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs( - total_params=total_params, - largest_layer_params=largest_layer_params, - num_gpus_per_node=num_gpus_per_node, - num_nodes=num_nodes, - cpu_offload=cpu_offload, - cpu_offload_params=cpu_offload_params, - zero_init=zero_init, - additional_buffer_factor=additional_buffer_factor - ) - - options_str = format_options(cpu_offload=cpu_offload, - cpu_offload_params=cpu_offload_params, - zero_init=zero_init) - print( - f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") +""" +"Copyright 2020 The Microsoft DeepSpeed Team. +Licensed under the MIT license. +""" + +import gc +from dataclasses import dataclass +import functools +import os +import collections +from collections import OrderedDict, UserDict +import itertools +from typing import Deque, Dict, Iterable, Set, Tuple +import torch +from torch.cuda import Event, Stream +from torch.nn import Module, Parameter +import torch.distributed as dist +import math +from torch._six import inf +from torch.nn import Module +from torch.nn.parameter import Parameter + +from deepspeed.utils.logging import logger +from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced +from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim +from deepspeed.runtime.zero.partition_parameters import * +from deepspeed.runtime.zero.partition_parameters import _init_external_params +from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS +from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.runtime.zero.offload_constants import * +from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus +from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper +from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper + +# Toggle this to true to enable correctness test +# with gradient partitioning and without +pg_correctness_test = False + +FWD_MODULE_STACK = list() + + +def print_rank_0(message, debug=False, force=False): + rank = torch.distributed.get_rank() + if rank == 0 and (debug or force): + print(message) + # other variations + # - print for all ranks w/o interleaving + # printflock(f"[{rank}] {message}") + # - print to log file per rank + # log_rank_file(rank, message) + + +def input(msg): + return + + +def isclose(a, b, rtol=1e-09, atol=0.0): + return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) + + +def lcm(x, y): + from fractions import gcd # or can import gcd from `math` in Python 3 + return x * y // gcd(x, y) + + +def debug_rank0(message: str) -> None: + if dist.get_rank() == 0: + logger.debug(message) + + +def get_cuda_mem_allocated_str() -> str: + # this is really slow. when enabled the python process becomes slow + # to the point where it can't keep the GPU fed with work, so only enable + # for memory debugging. + # return f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f}GB" + return "xGB" + + +def move_to_cpu(tensor_list): + for tensor in tensor_list: + tensor.data = tensor.data.cpu() + + +@instrument_w_nvtx +def get_all_parameters(sub_module, recurse=False): + return itertools.chain(sub_module.named_parameters(recurse=recurse), + sub_module.ds_external_parameters()) + + +def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: + return map(lambda pair: pair[1], get_all_parameters(module, recurse)) + + +#apply torch.autograd.Function that calls a backward_function to tensors in output +def _apply_to_tensors_only(module, functional, backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_to_tensors_only(module, + functional, + backward_function, + output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + return functional.apply(module, backward_function, outputs) + else: + return outputs + + +#for each tensor in outputs run the forward_function and register backward_function as hook +def _apply_forward_and_backward_to_tensors_only(module, + forward_function, + backward_function, + outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_forward_and_backward_to_tensors_only( + module, + forward_function, + backward_function, + output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + forward_function(outputs) + if outputs.requires_grad: + outputs.register_hook(backward_function) + return outputs + else: + return outputs + + +class ZeROOrderedDict(OrderedDict): + def __init__(self, parent_module, *args, **kwargs): + """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. + + Args: + parent_module (``collections.OrderedDict``): the collection to replace + """ + + super().__init__(*args, **kwargs) + self._parent_module = parent_module + self._in_forward = False + + def __getitem__(self, key): + param = super().__getitem__(key) + + # Params can be registered as None (e.g., bias) + if param is None: + return param + + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if self._parent_module._parameters._in_forward: + print_rank_0(f'Registering external parameter from getter {key}', + force=False) + register_external_parameter(FWD_MODULE_STACK[-1], param) + param.all_gather() + + return param + + +def _inject_parameters(module, cls): + for module in module.modules(): + if cls == ZeROOrderedDict: + new_param = cls(parent_module=module) + else: + new_param = cls() + + for key, param in module._parameters.items(): + new_param[key] = param + module._parameters = new_param + + +class PartitionedParameterCoordinator: + """Handles partitioning and gathering of parameters.""" + class __InflightParamRegistry(UserDict): + """registry for parameters in flight""" + def __setitem__(self, + param: Parameter, + handle: AllGatherCoalescedHandle) -> None: + if param in self.data: + raise RuntimeError(f"{param.ds_summary()} already in registry") + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError( + f"attempted to add non-inflight parameter to registry {param.ds_summary()}" + ) + self.data[param] = handle + + @dataclass + class __ParamInTrace: + param: Parameter + step_id_last_used_at: int + + def __init__( + self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: Stream, + prefetch_nvme: bool = False, + ) -> None: + # mapping of param -> handle for each param that is currently in flight + self.__inflight_param_registry = __class__.__InflightParamRegistry() + # keeps track of the number of submodules invoked so far. + self.__step_id: int = 0 + # whether or not we have completed a trace of the entire network. This should + # always be true after the first forward pass + backward pass. + self.trace_complete: bool = False + # sequence of submodules/parameters in forward pass + backward pass + self.__submodule_order: Iterable[Module] = [] + self.__param_order: Iterable[__class__.__ParamInTrace] = [] + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + # number of available params, and max number of available params + self.__n_available_params: int = 0 + self.__max_n_available_params: int = max_available_parameters_in_numel + # max distance between two use of the module beyond which module is released + self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel + # queue for parameters to fetch. parameters will be popped off the left + # side of the dequeue as they are fetched + self.__param_queue: Deque[__class__.__ParamInTrace] = None + self.__prefetch_bucket_sz: int = prefetch_bucket_sz + self.__prefetch_nvme: bool = prefetch_nvme + self.hierarchy: int = 0 + + # stream that will be used for allgather operations + self.__allgather_stream: Stream = allgather_stream + + # limit the number of fetch events that can be queued at once + # otherwise, what happens is memory is allocated by the host thread at the + # time of the call, but not used until later by the asynchronous cuda stream. + # allowing an infinite number of these to queue up causes a lot of memory + # pressure that then becomes detrimental to performance. + # this is a much less elegant way of fixing this vs something like using + # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now + # because ideally in the future its replaced by an async allocation + # mechanism which doesnt require any configuration by the user. + self.__ongoing_fetch_events: Deque[Event] = collections.deque() + self.__max_ongoing_fetch_events: int = 2 + + """Tracing and Tracking + TODO. consider performing trace before initializing PartitionedParameterCoordinator + and passing trace results into constructor. This way all the code in here can + just assume that the trace is complete and the results can be entirely + immutable. + + Bookkeeping operations used to track where we are in the forward/backward pass + """ + + def record_trace(self, sub_module: Module) -> None: + """adds sub module to trace""" + if self.trace_complete: + raise RuntimeError( + "attemted to record trace when trace was already complete") + + self.__submodule_order.append(sub_module) + for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): + self.__param_order.append( + __class__.__ParamInTrace(param=param, + step_id_last_used_at=self.__step_id)) + + def reset_step(self) -> None: + """indicate that we have completed one fwd+bwd for the model""" + if self.__inflight_param_registry: + raise RuntimeError( + f"still have inflight params " + f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") + + if not self.trace_complete: + # make sure that recorded parameter and submodule orders are + # identical across ranks + assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) + assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) + assert_ints_same_as_other_ranks( + [p.step_id_last_used_at for p in self.__param_order]) + + self.__submodule_order = tuple(self.__submodule_order) # freeze + self.__param_order = tuple(self.__param_order) # freeze + self.trace_complete = True + print_rank_0(f"completed trace: {[m.id for m in self.__submodule_order]}", + force=True) + + self.__param_queue = collections.deque(self.__param_order) # reset fetch queue + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + self.__step_id = 0 + self.__n_available_params = 0 + + """Fetch and Release + Fetching, prefetching, and releasing parameters + """ + + @instrument_w_nvtx + @torch.no_grad() + def fetch_sub_module(self, current_submodule: Module) -> None: + """This method does the following (in order): + 1. kick off fetch for parameters in immediately required sub module + 2. kick off fetch for next few parameters we will need later (prefetch) + 3. block on parameters in immediately required sub module + """ + debug_rank0( + f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " + + str({ + "avail": f"{self.__n_available_params:.1e}", + "queue_sz": f"{len(self.__param_queue or [])}", + "inflight": [p.ds_id for p in self.__inflight_param_registry], + "allocated": get_cuda_mem_allocated_str() + })) + + params_to_fetch = frozenset(iter_params(current_submodule)) + + # kick off all gather for params in the immediately required submodule + for param in params_to_fetch: + debug_rank0(f"-fetch: {param.ds_summary()}") + self.__all_gather_params(params_to_fetch) + + # wait for parameters in the immediately needed submodule to become available + for param in iter_params(current_submodule): + param.ds_active_sub_modules.add(current_submodule.id) + debug_rank0(f"-wait: {param.ds_summary()}") + if param in self.__inflight_param_registry: + with torch.cuda.stream(self.__allgather_stream): + while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ + 0].query(): + self.__ongoing_fetch_events.popleft() + if len(self.__ongoing_fetch_events + ) > self.__max_ongoing_fetch_events: + self.__ongoing_fetch_events.popleft().synchronize() + + self.__inflight_param_registry.pop(param).wait() + + event = Event() + event.record() + self.__ongoing_fetch_events.append(event) + + assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() + torch.cuda.current_stream().wait_stream(self.__allgather_stream) + + # kick off parameter prefetches for upcoming modules + # don't prefetch if we dont have a completed model trace, or if we aren't + # training (throws off the tracing and don't want to prefetch modules for bwd) + if self.trace_complete and current_submodule.training: + # go through the parameters we need for the current module and pop them + # off the fetch queue so that they aren't prefetched later. + # if params have already been popped off the fetch queue by earlier + # prefetches we won't look for them here + discarded_from_prefetch_queue = set() + params_not_already_fetched = set( + filter( + lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. + __step_id, + params_to_fetch)) + while self.__param_queue and len(discarded_from_prefetch_queue) < len( + params_not_already_fetched): + param_in_trace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + discarded_from_prefetch_queue.add(param_in_trace.param) + if discarded_from_prefetch_queue != params_not_already_fetched: + raise RuntimeError( + f"tracing error at step {self.__step_id}: " + f"expected the next {len(params_not_already_fetched)} parameters in the " + f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " + f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." + ) + + # kick off all gather for params in the next few submodules (prefetch) + max_params_to_prefetch = min( + self.__max_n_available_params - self.__n_available_params, + self.__prefetch_bucket_sz) + params_to_prefetch = set() + numel_prefetching = 0 + while self.__param_queue and numel_prefetching < max_params_to_prefetch: + param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + if param_in_trace.param not in params_to_prefetch: + params_to_prefetch.add(param_in_trace.param) + numel_prefetching += param_in_trace.param.ds_numel + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") + self.__all_gather_params(params_to_prefetch) + + if self.__prefetch_nvme: + self.__prefetch_nvme_param_partitions() + + self.__step_id += 1 + + @instrument_w_nvtx + @torch.no_grad() + def release_sub_module(self, submodule: Module) -> None: + """release the parameters of a sub module, assuming they meet conditions to + be released.""" + params_to_release = (self.__params_to_release(submodule, + self.__step_id) + if self.trace_complete else set( + p.ds_id for p in iter_params(submodule))) + + for param in iter_params(submodule): + param.ds_active_sub_modules.discard(submodule.id) + if param.ds_id in params_to_release and not param.is_external_param: + self.__release_param(param) + + @instrument_w_nvtx + @torch.no_grad() + def release_and_reset_all(self) -> None: + """release all module parameters""" + for param in map(lambda p: p.param, self.__param_order): + if param in self.__inflight_param_registry: + raise RuntimeError(f"param {param.ds_summary()} still in flight") + + # TODO. make this throw if if there are still active submodules. currently + # there's a hook execution issue + param.ds_active_sub_modules.clear() + self.__release_param(param) + + for param_in_trace in self.__param_order: + if param_in_trace.param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError( + f"{param_in_trace.param.ds_summary()} expected to be released") + + @instrument_w_nvtx + def __all_gather_params(self, params: Set[Parameter]) -> None: + """for each partitioned parameter, kick off an async allgather and store + the work handle for the in flight parameters.""" + partitioned_params = [] + for param in params: + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + partitioned_params.append(param) + self.__n_available_params += param.ds_numel + + if partitioned_params: + with torch.cuda.stream(self.__allgather_stream): + handle = partitioned_params[0].all_gather_coalesced(partitioned_params) + + for param in partitioned_params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() + self.__inflight_param_registry[param] = handle + + @instrument_w_nvtx + def __release_param(self, param: Parameter) -> None: + if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: + debug_rank0(f"-release: {param.ds_summary()}") + param.partition() + self.__n_available_params -= param.ds_numel + + @instrument_w_nvtx + @functools.lru_cache(maxsize=None) + def __params_to_release(self, + submodule_to_release: Module, + step_id: int) -> Set[int]: + if not self.trace_complete: + raise RuntimeError("expected trace to be complete") + + params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) + if not p.ds_persist) + + # examine all modules within `max_reuse_dist_in_numel` of the current step, + # if we see any of the candidate parameters to be released reoccur while + # doing this, remove them from the set of parameters to release. + params_traversed = 0 + for module in self.__submodule_order[step_id:]: + if params_traversed > self.__max_reuse_dist_in_numel: + break + for param in iter_params(module): + params_to_release.discard(param.ds_id) + params_traversed += param.ds_numel + + return params_to_release + + @instrument_w_nvtx + def __prefetch_nvme_param_partitions(self) -> None: + """swap in parameter partitions from nvme for those parameters that will be used + after the ones that are already being prefetched into full parameters + """ + if not self.trace_complete: + return + + numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) + + numel_considered = 0 + swap_in_params = [] + for param_in_trace in self.__param_queue: + param = param_in_trace.param + if param.nvme_swapper is None: + continue + if (numel_considered > 2 * numel_in_flight or len(swap_in_params) >= + param.nvme_swapper.available_swap_in_buffers()): + break + if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_in_params.append(param) + numel_considered += param.ds_numel + + if swap_in_params: + swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) + + +class PreBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, outputs): + ctx.module = module + ctx.pre_backward_function = pre_backward_function + if not hasattr(module, "applied_pre_backward_ref_cnt"): + module.applied_pre_backward_ref_cnt = 0 + module.applied_pre_backward_ref_cnt += 1 + #print(f"After Forward: {ctx.module.__class__.__name__}") + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + #print(f"Before Backward: {ctx.module.__class__.__name__}") + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +class PostBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, output): + ctx.module = module + if output.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." + #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.pre_backward_function = pre_backward_function + output = output.detach() + return output + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.pre_backward_function(ctx.module) + #print(f"After Backward: {ctx.module.__class__.__name__}") + return (None, None) + args + + +class FP16_DeepSpeedZeroOptimizer_Stage3(object): + """ + DeepSpeedZeroOptimizer designed to reduce the memory footprint + required for training large deep learning models. + + For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models + https://arxiv.org/abs/1910.02054 + + For usage examples, refer to TODO: DeepSpeed Tutorial + + """ + def __init__(self, + module, + init_optimizer, + timers, + ds_config, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + dp_process_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + offload_param_config=None, + sub_group_size=1000000000000, + mpu=None, + clip_grad=0.0, + allreduce_always_fp32=False, + postscale_gradients=True, + gradient_predivide_factor=1.0, + gradient_accumulation_steps=1, + elastic_checkpoint=False, + aio_config=None): + + see_memory_usage("Stage 3 initialize beginning", force=False) + + if dist.get_rank() == 0: + logger.info(f"initialized {__class__.__name__} with args: {locals()}") + logger.info(f"Reduce bucket size {reduce_bucket_size}") + logger.info(f"Allgather bucket size {prefetch_bucket_size}") + # The fused optimizer does all the work. We need this layer for two reason: + # 1. maintain same user API from apex.fp16_utils + # 2. keep common stuff here in case we need to add ne552w fused optimizer later + + # differences from apex.fp16_utils: + # - assume all model params in fp16 + # - assume all params requires grad + # - flat by groups, not keeping state. TODO: remove state explicitly? + # - master gard and unflat master weight never exist. TODO: a way to save out unflat master? + if not torch.cuda.is_available: + raise SystemError("Cannot use fp16 without CUDA.") + self.optimizer = init_optimizer + + # Load pre-built or JIT compile (un)flatten ops + util_ops = UtilsBuilder().load() + self.flatten = util_ops.flatten + self.unflatten = util_ops.unflatten + self.dtype = self.optimizer.param_groups[0]['params'][0].dtype + self._global_grad_norm = 0. + + self.optimizer_swapper = None + self.swap_optimizer = False + + self.offload_optimizer = False + self.offload_optimizer_pin_memory = False + self.offload_optimizer_fast_init = False + self.offload_param = False + self.offload_param_pin_memory = False + self.params_in_nvme_and_cpu = False + self.max_params_in_cpu = 0 + + self._configure_offloading(offload_optimizer_config, offload_param_config) + + self._convert_to_zero_parameters(ds_config, module, mpu) + + for m in module.modules(): + _init_external_params(m) + + self.module = module + self.elastic_checkpoint = elastic_checkpoint + + # Replace ._parameters with a new class to enable auto-registration of + # external parameters + _inject_parameters(module, ZeROOrderedDict) + + self.__inf_or_nan_tracker: Tensor = torch.zeros( + 1, + dtype=torch.bool, + device=torch.cuda.current_device(), + requires_grad=False) + + self.deepspeed_adam_offload = (self.offload_optimizer + and type(init_optimizer) == DeepSpeedCPUAdam) + + self.device = torch.cuda.current_device( + ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE + ### streams used for overlapping computation with communication + self.__allgather_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + self.__reduce_and_partition_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + + ############################################################################ + + see_memory_usage("Before Partitioned Parameter Coordinator", force=False) + self.param_coordinator = PartitionedParameterCoordinator( + prefetch_bucket_sz=int(prefetch_bucket_size), + max_reuse_distance_in_numel=int(max_reuse_distance), + max_available_parameters_in_numel=int(max_live_parameters), + allgather_stream=self.__allgather_stream, + prefetch_nvme=self.params_in_nvme_and_cpu, + ) + see_memory_usage("After Partitioned Parameter Coordinator", force=False) + + self.__n_caching_allocator_flushes = 0 + + #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) + #-------------Stage 3 Setup-------------------# + # parameters smaller than the threshold will be collectively gathered at the + # end of the optimizer step and will be kept till the end of the backward pass + # TODO maybe worth just replicating these parameters and doing all reduce for them + self.persistence_threshold = int(param_persistence_threshold) + + self.persistent_parameters = self.persistent_parameters() + + self.setup_zero_stage3_hooks() + + #resetting ds_tensor just in case parameters have been changed after initialization + #example .half() or .to() + #self.reset_ds_tensor() + #---------------------------------------------# + + self.timers = timers + + self.dp_process_group = dp_process_group + + self.partition_count = dist.get_world_size(group=self.dp_process_group) + + if mpu is None: + self.model_parallel_group = None + self.model_parallel_rank = 0 + else: + self.model_parallel_group = mpu.get_model_parallel_group() + self.model_parallel_rank = mpu.get_model_parallel_rank() + + self.overflow = False + self.clip_grad = clip_grad + self.allreduce_always_fp32 = allreduce_always_fp32 + self.gradient_predivide_factor = gradient_predivide_factor + self.postscale_gradients = postscale_gradients + self.gradient_accumulation_steps = gradient_accumulation_steps + self.micro_step_id = 0 + + # Holds the mode parameter + # The param.data may not hold any meaningful data + # when param's status is NOT_AVAILABLE or IN_FLGHT + self.fp16_groups = [] + + # Hold partitioned parameters + self.fp16_partitioned_groups = [] + + # Holds a fused and flattened copy of the parameters + self.fp16_partitioned_groups_flat = [] + self.fp16_partitioned_groups_flat_numel = [] + + #defragmented pinned memory + self.param_groups_fp16_flat_cpu_memory = [] + + #a single 32-bit partition of the parallel partitioned parameters + #that this process will update + self.fp32_partitioned_groups_flat = [] + self.next_swappable_fp32_partitioned_groups = [] + + # number of elements per partition in each group + self.partition_size = [] + + self.all_reduce_print = False + + self.prefetch_elements = int(prefetch_bucket_size) + + # padding on each partition for alignment purposes + self.groups_padding = [] + + self.sub_group_size = sub_group_size + + self.sub_group_to_group_id = {} + see_memory_usage("Before creating fp16 partitions", force=False) + self._create_fp16_partitions_with_defragmentation() + num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) + see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", + force=False) + + # Optimizer tensor swapping + if self.swap_optimizer: + self._configure_tensor_swapping(offload_optimizer_config, aio_config) + + see_memory_usage("Before creating fp32 partitions", force=False) + if not isinstance(self.optimizer, DummyOptim): + self._create_fp32_partitions() + see_memory_usage("After creating fp32 partitions", force=False) + dist.barrier() + + # To support pipelined optimizer swapping + if not isinstance(init_optimizer, DummyOptim): + self._create_next_swappable_fp32_groups() + + see_memory_usage("Before initializing optimizer states", force=False) + if not isinstance(init_optimizer, DummyOptim): + self.initialize_optimizer_states() + see_memory_usage("After initializing optimizer states", force=False) + dist.barrier() + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") + + self.reduce_bucket_size = int(reduce_bucket_size) + + # IPG + if contiguous_gradients: + self.__ipg_bucket_flat_buffer: Tensor = torch.empty( + int(reduce_bucket_size), + dtype=self.dtype, + device=torch.cuda.current_device()) + + self.__param_id_to_grad_partition: Dict[int, Tensor] = {} + + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) + + grad_partitions_flat_buffer: Tensor = torch.zeros( + sum(p.ds_tensor.ds_numel for p in all_params), + dtype=self.dtype, + device=self.device, + pin_memory=self.offload_optimizer_pin_memory) + + offset = 0 + for param in all_params: + self.__param_id_to_grad_partition[ + param.ds_id] = grad_partitions_flat_buffer.narrow( + 0, + offset, + param.ds_tensor.numel()) + offset += param.ds_tensor.numel() + + self.__params_in_ipg_bucket: List[Parameter] = [] + self.is_gradient_accumulation_boundary: bool = True + + self.__param_reduce_events: Deque[Event] = collections.deque() + self.__max_param_reduce_events: int = 2 + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") + + self.param_dict = {} + + # map between param_id and bool to specify if a param is in this partition + self.is_param_in_current_partition = {} + + self.contiguous_gradients = contiguous_gradients + self.extra_large_param_to_reduce = None + self.grads_in_ipg_bucket = [] + self.params_in_ipg_bucket = [] + self.params_already_reduced = [] + self.is_gradient_accumulation_boundary = True + self._release_ipg_buffers() + self.previous_reduced_grads = None + + # simplified param id + self.param_id = {} + + count = 0 + for i, params_group in enumerate(self.fp16_groups): + for param in params_group: + unique_id = id(param) + self.param_id[unique_id] = count + self.param_dict[count] = param + self.params_already_reduced.append(False) + count = count + 1 + + #Largest partitioned param + largest_partitioned_param_numel = max([ + max([tensor.numel() for tensor in fp16_partitioned_group]) + for fp16_partitioned_group in self.fp16_partitioned_groups + ]) + print_rank_0( + f'Largest partitioned param numel = {largest_partitioned_param_numel}', + force=False) + + see_memory_usage(f"Before Set Grad positions", force=False) + + self.grad_position = {} + self.set_grad_positions() + see_memory_usage(f"Before CPU Offload initialization", force=False) + + self.grads_in_partition = None + + if self.offload_optimizer: + self.norm_for_param_grads = {} + self.local_overflow = False + + see_memory_usage(f"After CPU Offload initialization", force=False) + + # stores if a partition has been reduced in this step + self.is_partition_reduced = {} + + # stores if a grad in a partition has been computed or not + self.is_grad_computed = {} + + # will store the averaged gradients required by this paritition + self.averaged_gradients = {} + + #creates backward hooks for gradient partitioning + self.create_reduce_and_remove_grad_hooks() + + #exit(0) + + # we may have a way of fusing dynamic scale. Do not support for now + if self.dtype == torch.float or not dynamic_loss_scale: + loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale + + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(scale=loss_scale_value) + cur_iter = 0 + else: + if dynamic_loss_args is None: + self.loss_scaler = DynamicLossScaler() + else: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + + self.dynamic_loss_scale = True + + self.debug_fp16_grads = [{} for _ in self.fp16_groups] + + if dist.get_rank(group=self.dp_process_group) == 0: + see_memory_usage(f"After initializing ZeRO optimizer", force=False) + + @staticmethod + def defragment(tensors: List[Tensor]) -> Tensor: + """move provided tensors into a contiguous flat buffer, with some additional + measures taken to reduce memory fragmentation""" + assert len(set(t.dtype for t in tensors)) == 1 + assert len(set(t.device for t in tensors)) == 1 + + cpu_buffer = torch.empty(sum(p.numel() for p in tensors), + dtype=get_only_unique_item(t.dtype for t in tensors), + device="cpu") + tensor_infos: List[Tuple[Tensor, int, int]] = [] + orig_device = get_only_unique_item(t.device for t in tensors) + + offset = 0 + for tensor in tensors: + tensor_numel = tensor.numel() + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + + # record some data so we can restore the device tensor later + tensor_infos.append((tensor, offset, tensor_numel)) + + offset += tensor_numel + + gc.collect() + torch.cuda.empty_cache() + + # copy tensors (now flattened and contiguous) back to GPU + device_buffer = cpu_buffer.to(orig_device) + + # restore device tensors + for tensor, offset, tensor_numel in tensor_infos: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + + return device_buffer + + def _configure_offloading(self, offload_optimizer_config, offload_param_config): + ###################### offload optimizer setup ################################## + if offload_optimizer_config is not None: + self.offload_optimizer = True + self.offload_optimizer_pin_memory = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_PIN_MEMORY] + self.swap_optimizer = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE + self.offload_optimizer_fast_init = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_FAST_INIT] + + ###################### offload param setup ################################## + if offload_param_config is not None: + if not isinstance(self.optimizer, DummyOptim): + assert self.offload_optimizer, "parameter offload is only available with optimizer state offload" + self.offload_param = True + self.offload_param_pin_memory = offload_param_config[ + OFFLOAD_PARAM_PIN_MEMORY] + self.params_in_nvme_and_cpu = offload_param_config[ + OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE + self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU] + print_rank_0( + f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}", + force=False) + + def _convert_to_zero_parameters(self, ds_config, module, mpu): + non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] + if non_zero_params: + zero_params = [p for p in module.parameters() if is_zero_param(p)] + if zero_params: + zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) + else: + group = None + if mpu: + group = mpu.get_data_parallel_group() + + if self.params_in_nvme_and_cpu: + remote_device = OFFLOAD_NVME_DEVICE + elif self.offload_param: + remote_device = OFFLOAD_CPU_DEVICE + else: + remote_device = None + + Init(module=module, + data_parallel_group=group, + dtype=self.dtype, + config_dict_or_path=ds_config, + remote_device=remote_device, + pin_memory=self.offload_param_pin_memory, + mpu=mpu) + + def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): + nvme_swap_folder = os.path.join( + offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH], + 'zero_stage_3') + os.makedirs(nvme_swap_folder, exist_ok=True) + if torch.distributed.get_rank() == 0: + logger.info(f'Tensor Swapping: Adding optimizer tensors') + + swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[ + OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper + + self.optimizer_swapper = swapper_type( + swap_config=offload_optimizer_config, + aio_config=aio_config, + base_folder=nvme_swap_folder, + optimizer=self.optimizer, + largest_numel=max(self.fp16_partitioned_groups_flat_numel), + device=self.device, + dtype=torch.float32, + timers=self.timers) + + @property + def elements_in_ipg_bucket(self): + return sum(p.ds_numel for p in self.__params_in_ipg_bucket) + + def _create_fp16_partitions(self): + dist.barrier() + partition_id = dist.get_rank(group=self.dp_process_group) + + # loop to deal with groups + for j, param_group in enumerate(self.optimizer.param_groups): + + sub_groups = self._create_fp16_sub_groups(param_group['params']) + for sub_group in sub_groups: + i = len(self.fp16_groups) + + # push this group to list before modify + self.fp16_groups.append(sub_group) + self.sub_group_to_group_id[i] = j + + #These are the list of the partitioned parameters + self.fp16_partitioned_groups.append( + [param.ds_tensor for param in self.fp16_groups[i]]) + + print_rank_0( + f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" + ) + + # Record padding required to align group to world size (only applies to last rank) + if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: + padding = [p.padding_size() for p in self.fp16_groups[i]] + else: + padding = [0] * len(self.fp16_groups[i]) + self.groups_padding.append(padding) + + #not sure why apex was cloning the weights before flattening + #removing cloning here + see_memory_usage(f"Before Flattening param group {i}", force=False) + + if not self.offload_param: + see_memory_usage(f"Before moving param group {i} to CPU", + force=False) + #move all the parameters to cpu to free up GPU space for creating flat buffer + move_to_cpu(self.fp16_partitioned_groups[i]) + see_memory_usage(f"After moving param group {i} to CPU", force=False) + + #create flat buffer in CPU and move to GPU + self.fp16_partitioned_groups_flat.append( + self.flatten_dense_tensors_aligned( + self.fp16_partitioned_groups[i], + dist.get_world_size(group=self.dp_process_group)).cuda( + torch.cuda.current_device())) + see_memory_usage( + f"After flattening and moving param group {i} to GPU", + force=False) + else: + #Without the detach, seems like the flattening becomes part of the + #model graph causing errors downstream + self.fp16_partitioned_groups_flat.append( + self.flatten_dense_tensors_aligned( + self.fp16_partitioned_groups[i], + dist.get_world_size( + group=self.dp_process_group)).detach().pin_memory()) + + see_memory_usage(f"After Flattening param group {i}", force=False) + + see_memory_usage(f"After Flattening param group {i}", force=False) + + #set model fp16 weight to slices of flattened buffer + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i], + self.fp16_partitioned_groups[i]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params): + partitioned_param.data = q.data + + def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): + '''If flat buffer is None then the parameters in the param_list are + not copied to the flat buffer. This is because they excede the number of max_params_in_cpu + Some of these parameters may aready be in CPU in unflattened buffers + or they maybe in GPU, or they maybe in NVME. If they are in NVME, then + they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are + needed during training.''' + if flat_buffer is None: + # this dst buffer is on NVMe, so skip this + return + + start = 0 + for param in param_list: + src = param.ds_tensor + dest = flat_buffer.narrow(0, start, src.ds_numel) + start = start + src.ds_numel + '''if the parameter was initialized in nvme then bring it to the destination buffer directly''' + if src.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU" + ) + param.nvme_swapper.swap_into_buffer(param, dest) + src.data = dest.data + src.status = PartitionedParamStatus.AVAILABLE + else: + assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here" + if not avoid_copy: + dest.data.copy_(src.data) + src.data = dest.data + + # Final location must be gpu/cpu in this case + param.ds_tensor.final_location = 'not-nvme' + + def _create_param_groups_fp16_flat_cpu_memory(self): + + aggregate_params_count = 0 + + for j, param_group in enumerate(self.optimizer.param_groups): + params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']]) + + flat_buffer_size = params_in_group + + if self.params_in_nvme_and_cpu and \ + aggregate_params_count + params_in_group > self.max_params_in_cpu: + + flat_buffer_size = max(0, + self.max_params_in_cpu - aggregate_params_count) + + aggregate_params_count += params_in_group + + if flat_buffer_size > 0: + print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", + force=False) + self.param_groups_fp16_flat_cpu_memory.append( + torch.empty(int(flat_buffer_size), + dtype=self.dtype, + pin_memory=True)) + else: + print_rank_0( + f"No flat buffer size. Param group size was {params_in_group}", + force=False) + + self.param_groups_fp16_flat_cpu_memory.append( + torch.empty(1, + dtype=self.dtype)) + + def _create_fp16_partitions_with_defragmentation(self): + dist.barrier() + param_groups: List[List[Parameter]] = tuple( + self._create_fp16_sub_groups(param_group["params"]) + for param_group in self.optimizer.param_groups) + + # bookkeeping related to param groups + for param_group_idx, param_group in enumerate(param_groups): + for sub_group in param_group: + sub_group_idx = len(self.fp16_groups) + + # record sub group and partitions + self.fp16_groups.append(sub_group) + self.fp16_partitioned_groups.append( + [param.ds_tensor for param in sub_group]) + + # record sub group -> group mapping + self.sub_group_to_group_id[sub_group_idx] = param_group_idx + + # record total elements of parameter partitions in sub group + self.fp16_partitioned_groups_flat_numel.append( + sum(p.ds_tensor.ds_numel for p in sub_group)) + + # record padding required to align group to world size (only applies to last rank) + rank_requires_padding = dist.get_rank( + self.dp_process_group) == dist.get_world_size( + self.dp_process_group) - 1 + self.groups_padding.append([ + p.padding_size() if rank_requires_padding else 0 for p in sub_group + ]) + + # move parameters to flattened buffer + if not self.offload_param: # partitioned params remain in GPU during training + # move parameter partitions into a single contiguous flat buffer + parameter_partitions: List[Tensor] = [] + for sub_group in self.fp16_groups: + for param in sub_group: + parameter_partitions.append(param.ds_tensor) + device_buffer = __class__.defragment(parameter_partitions) + + # setup flat buffers per subgroup, these are each just sections of the + # contiguous flat buffer for all parameters that we created earlier + offset = 0 + for sub_group in self.fp16_groups: + sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group) + self.fp16_partitioned_groups_flat.append( + device_buffer.narrow(0, + offset, + sub_group_numel)) + offset += sub_group_numel + else: # partitioned params offloaded to CPU when not in use + # create a flat CPU memory allocation for each param group + self._create_param_groups_fp16_flat_cpu_memory() + for param_group_idx, param_group in enumerate(param_groups): + flat_offset = 0 + for i, sub_group in enumerate(param_group): + total_elements = sum(p.ds_tensor.ds_numel for p in sub_group) + print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") + #Flat buffer may not be available for parameters that reside in NVME + if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ + param_group_idx].numel(): + fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ + param_group_idx].narrow(0, + flat_offset, + total_elements) + print_rank_0( + f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", + force=False) + elif self.params_in_nvme_and_cpu: + fp16_partitioned_group_flat = None + print_rank_0( + f"No flat buffer for sub group {i} of {total_elements} elements", + force=False) + else: + assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs" + + self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) + flat_offset += total_elements + + self._move_to_flat_buffer(sub_group, + fp16_partitioned_group_flat, + avoid_copy=not self.offload_param) + + # if necessary, create a pinned memory buffer to be used for swapping out + # params to NVME after optimizer step + should_create_fp16_flat_reuse_buffer = any( + flattened_partition_group is None + for flattened_partition_group in self.fp16_partitioned_groups_flat) + if should_create_fp16_flat_reuse_buffer: + max_partition_numel, largest_partition_numel = 0, None + for sub_group in self.fp16_groups: + total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) + if total_elements > max_partition_numel: + largest_partition_numel = [t.ds_numel for t in sub_group] + max_partition_numel = total_elements + + assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' + self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( + largest_partition_numel) + + def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): + offset = 0 + elements_in_sub_group = sum( + [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) + assert (flat_buffer.numel() == elements_in_sub_group) + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): + dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}" + ) + param.nvme_swapper.swap_in([param], async_op=False) + dest.data.copy_(partitioned_param.data) + param.nvme_swapper.remove_partition_and_release_buffers([param]) + print_rank_0(f"Swapping in {param.ds_id} done") + else: + dest.data.copy_(partitioned_param.data) + offset += partitioned_param.ds_numel + + def _create_next_swappable_fp32_groups(self): + reverse_order_indices = [ + i for i in range(len(self.fp32_partitioned_groups_flat)) + ] + reverse_order_indices.reverse() + + next_group = None + for i in reverse_order_indices: + self.next_swappable_fp32_partitioned_groups.append(next_group) + if self._swappable_optimizer_subgroup(i): + next_group = self.fp32_partitioned_groups_flat[i] + + self.next_swappable_fp32_partitioned_groups.reverse() + + def _get_sub_group_partitions(self, sub_group_id): + sub_group_partitions = [] + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_path = param.nvme_swapper.get_path(param, True) + sub_group_partitions.append((partitioned_param, + param.ds_tensor.ds_numel, + swap_path)) + else: + sub_group_partitions.append((partitioned_param, + partitioned_param.ds_numel, + None)) + + return sub_group_partitions + + def _create_fp32_partitions(self): + cpu_memory_usage = 0 + cpu_memory_sub_groups = 0 + nvme_memory_usage = 0 + num_swappable_partitions = 0 + num_swap_from_nvme_partitions = 0 + num_swap_from_cpu_partitions = 0 + swap_from_nvme_memory_usage = 0 + swap_from_cpu_memory_usage = 0 + GIGA_BYTES = (1024**3) + + swappable_fp32_tensors = [] + swappable_fp16_src_tensors = [] + nvme_fp16_partitions_info = [] + nvme_fp16_num_elems = [] + nvme_fp32_dest_tensors = [] + fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() + + for i, tensor in enumerate(self.fp16_partitioned_groups_flat): + num_elements = self.fp16_partitioned_groups_flat_numel[i] + + # a partition of the fp32 master weights that will be updated by this process + if self._swappable_optimizer_subgroup(i): + self.fp32_partitioned_groups_flat.append(torch.Tensor()) + nvme_memory_usage += (fp32_element_size * num_elements) + num_swappable_partitions += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + num_swap_from_nvme_partitions += 1 + swap_from_nvme_memory_usage += (fp32_element_size * num_elements) + if self.offload_optimizer_fast_init: + sub_group_partitions = self._get_sub_group_partitions(i) + nvme_fp16_partitions_info.append(sub_group_partitions) + nvme_fp16_num_elems.append(num_elements) + nvme_fp32_dest_tensors.append( + self.fp32_partitioned_groups_flat[i]) + else: + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.optimizer_swapper.initialize_parameters( + parameters=[self.fp32_partitioned_groups_flat[i]], + src_tensors=[unpinned_fp32_buffer]) + else: + num_swap_from_cpu_partitions += 1 + swap_from_cpu_memory_usage += (fp32_element_size * num_elements) + swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i]) + swappable_fp16_src_tensors.append( + self.fp16_partitioned_groups_flat[i]) + else: + cpu_memory_usage += (fp32_element_size * num_elements) + cpu_memory_sub_groups += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) + else: + self.fp32_partitioned_groups_flat.append( + self.fp16_partitioned_groups_flat[i].to( + self.device).clone().float().detach()) + + self.fp32_partitioned_groups_flat[ + i].requires_grad = True # keep this in case internal optimizer uses it + + if len(swappable_fp32_tensors) > 0: + self.optimizer_swapper.initialize_parameters( + parameters=swappable_fp32_tensors, + src_tensors=swappable_fp16_src_tensors) + + if len(nvme_fp32_dest_tensors) > 0: + fp16_pinned_buffers = self.fp16_groups[0][ + 0].nvme_swapper.reserve_available_buffers() + assert len(fp16_pinned_buffers) > 0 + self.optimizer_swapper.initialize_from_swapped_fp16_params( + fp16_partitions_info=nvme_fp16_partitions_info, + fp16_num_elems=nvme_fp16_num_elems, + fp16_pinned_buffers=fp16_pinned_buffers, + fp32_parameters=nvme_fp32_dest_tensors) + self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers() + + nvme_gigabytes = nvme_memory_usage / GIGA_BYTES + print_rank_0( + f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB', + force=False) + if self.params_in_nvme_and_cpu: + print_rank_0( + f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + print_rank_0( + f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + + cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES + print_rank_0( + f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB', + force=False) + + # Clear for on-the-fly population before the optimizer step + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _create_fp16_sub_groups(self, params_group): + + params_group_numel = sum([param.partitioned_size() for param in params_group]) + sub_group_size = self.sub_group_size + + if sub_group_size is None or sub_group_size >= params_group_numel: + return [params_group] + + sub_groups = [] + sub_group = [] + local_sub_group_size = 0 + for param in params_group: + + sub_group.append(param) + local_sub_group_size += param.partitioned_size() + + if local_sub_group_size >= sub_group_size or id(param) == id( + params_group[-1]): + + sub_groups.append(sub_group) + + sub_group = [] + local_sub_group_size = 0 + + return sub_groups + + # def reset_ds_tensor(self): + # for name, param in self.module.named_parameters(recurse=True): + # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible" + # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now" + # param.ds_tensor.data = param.data + + def setup_zero_stage3_hooks(self): + self.hierarchy = 0 + + #reset step if in inference mode + @instrument_w_nvtx + def _end_of_forward_hook(module, *args): + + if not torch._C.is_grad_enabled(): + self.param_coordinator.reset_step() + + #likely one of them should be enough but just to be safe + self._register_hooks_recursively(self.module) + self.module.register_forward_hook(_end_of_forward_hook) + + # Add top module to stack trace + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(self.module) + + def persistent_parameters(self): + persistent_params = [] + total_persistent_parameters = 0 + params_count = 0 + for _, param in self.module.named_parameters(recurse=True): + if param.ds_numel < self.persistence_threshold: + params_count += 1 + param.ds_persist = True + persistent_params.append(param) + total_persistent_parameters += param.ds_numel + + print_rank_0( + f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params", + force=False) + return persistent_params + + def _register_hooks_recursively(self, module, count=[0]): + my_count = count[0] + module.id = my_count + + #print(f"{module.__class__} : {module.id}") + + for child in module.children(): + count[0] = count[0] + 1 + self._register_hooks_recursively(child, count=count) + + @instrument_w_nvtx + def _pre_forward_module_hook(module, *args): + self.pre_sub_module_forward_function(module) + + @instrument_w_nvtx + def _post_forward_module_hook(module, input, output): + global FWD_MODULE_STACK + FWD_MODULE_STACK.pop() + if output is None: + output = [] + elif not isinstance(output, (list, tuple)): + if torch.is_tensor(output): + output = [output] + else: + #print(f'got UNKNOWN type {type(output)}') + outputs = [] + output = output if isinstance(output, dict) else vars(output) + for name, val in output.items(): + if not name.startswith('__') and torch.is_tensor(val): + outputs.append(val) + output = outputs + #print(f'convert output to {output}') + + for item in filter(lambda item: is_zero_param(item), output): + if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): + item.is_external_param = True + module_to_register = FWD_MODULE_STACK[-1] + print_rank_0( + f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', + force=False) + register_external_parameter(module_to_register, item) + + # It's possible that the parameter was already external to the completed module. If so, remove it the + # registration as it will be covered by the outer module instead. + if id(item) in module._external_params: + print_rank_0( + f' Unregistering nested dangling parameter from module {module.__class__.__name__}', + force=False) + unregister_external_parameter(module, item) + + item.all_gather() + + self.post_sub_module_forward_function(module) + + def _pre_backward_module_hook(module, inputs, output): + @instrument_w_nvtx + def _run_before_backward_function(sub_module): + # some models (e.g. Albert) may run multiple forwards on the same layer in a loop + # before doing backwards, so each backward will need a pre-fetch - using reference + # counting to support this scenario + #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") + if sub_module.applied_pre_backward_ref_cnt > 0: + self.pre_sub_module_backward_function(sub_module) + sub_module.applied_pre_backward_ref_cnt -= 1 + #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + + return _apply_to_tensors_only(module, + PreBackwardFunction, + _run_before_backward_function, + output) + + #This is an alternate to doing _post_backward_module_hook + #it uses tensor.register_hook instead of using torch.autograd.Function + def _alternate_post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + #print(f"Before Forward {module.__class__.__name__}") + + def _run_after_backward_hook(*unused): + module.ds_grads_remaining = module.ds_grads_remaining - 1 + if module.ds_grads_remaining == 0: + #print(f"After backward {module.__class__.__name__}") + self.post_sub_module_backward_function(module) + + def _run_before_forward_function(input): + if input.requires_grad: + module.ds_grads_remaining += 1 + + return _apply_forward_and_backward_to_tensors_only( + module, + _run_before_forward_function, + _run_after_backward_hook, + inputs) + + def _post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + @instrument_w_nvtx + def _run_after_backward_function(sub_module): + if sub_module.ds_grads_remaining == 0: + self.post_sub_module_backward_function(sub_module) + + return _apply_to_tensors_only(module, + PostBackwardFunction, + _run_after_backward_function, + inputs) + + # Pre forward hook + module.register_forward_pre_hook(_pre_forward_module_hook) + # Post forward hook + module.register_forward_hook(_post_forward_module_hook) + + # Pre backward hook + module.register_forward_hook(_pre_backward_module_hook) + + # post backward hook + module.register_forward_pre_hook(_post_backward_module_hook) + + @torch.no_grad() + def pre_sub_module_forward_function(self, sub_module): + see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", + force=False) + + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(sub_module) + + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) + + self.param_coordinator.fetch_sub_module(sub_module) + see_memory_usage( + f"Before sub module function {sub_module.__class__.__name__} after fetch", + force=False) + + @torch.no_grad() + def post_sub_module_forward_function(self, sub_module): + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + + self.param_coordinator.release_sub_module(sub_module) + + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) + + @torch.no_grad() + def pre_sub_module_backward_function(self, sub_module): + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) + self.param_coordinator.fetch_sub_module(sub_module) + + @torch.no_grad() + def post_sub_module_backward_function(self, sub_module): + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + self.param_coordinator.release_sub_module(sub_module) + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) + + def _release_ipg_buffers(self): + if self.contiguous_gradients: + self.ipg_buffer = None + if not self.offload_optimizer and self.is_gradient_accumulation_boundary: + self.grads_in_partition = None + + self.grads_in_partition_offset = 0 + + def _optimizer_step(self, sub_group_id): + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + + self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] + + def _swappable_optimizer_subgroup(self, sub_group_id): + if not self.swap_optimizer: + return False + + return self.optimizer_swapper.swappable_tensor( + None, + numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) + + def _partitioned_params_swap_out(self, i): + offset = 0 + fp32_param = self.fp32_partitioned_groups_flat[i] + assert fp32_param is not None, \ + f'fp32 parameters of sub_group {i} is None' + + swap_fp16_params = [] + swap_fp32_params = [] + for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]): + src = fp32_param.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.AVAILABLE: + partitioned_param.data.copy_(src.data) + else: + swap_fp32_params.append(src) + swap_fp16_params.append(param) + offset += partitioned_param.ds_numel + + if len(swap_fp16_params): + swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params( + dst_fp16_params=swap_fp16_params, + src_fp32_params=swap_fp32_params) + + def initialize_optimizer_states(self): + num_subgroups = len(self.fp16_groups) + + largest_numel = max( + [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups]) + gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype + gradient_buffer = torch.zeros(int(largest_numel), + dtype=gradient_dtype, + device=self.device) + + timers = self.timers + timer_names = set() + + if self.swap_optimizer: + self.optimizer_swapper.init_timers() + + INIT_OPTIMIZER_TIMER = 'init_optimizer_state' + timer_names.add(INIT_OPTIMIZER_TIMER) + self.start_timers([INIT_OPTIMIZER_TIMER]) + + for i, group in enumerate(self.fp16_groups): + swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) + swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None + + num_elements = int(self.fp16_partitioned_groups_flat_numel[i]) + + see_memory_usage( + f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_in(i, timer_names) + + if self.offload_optimizer and not swappable_optimizer_subgroup: + subgroup_gradient_buffer = torch.zeros(num_elements, + dtype=gradient_dtype, + device=self.device) + if self.offload_optimizer_pin_memory: + subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory() + + self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer + else: + self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow( + 0, + 0, + num_elements) + + self._optimizer_step(i) + + if swappable_param_subgroup: + self._partitioned_params_swap_out(i) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_out(i, timer_names) + + see_memory_usage( + f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + self.stop_timers([INIT_OPTIMIZER_TIMER]) + self.log_timers(timer_names) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + if not self.offload_optimizer: + for group in self.fp32_partitioned_groups_flat: + group.grad = None + + # Reset steps + return + + ######################################################################### + #########################ZeRO Partition Gradients######################## + ######################################################################### + + def get_first_param_index(self, group_id, param_group, partition_id): + for index, param in enumerate(param_group): + param_id = self.get_param_id(param) + if partition_id in self.param_to_partition_ids[group_id][param_id]: + return index + return None + + def initialize_gradient_partitioning_data_structures(self): + + total_partitions = dist.get_world_size(group=self.dp_process_group) + + for i, param_group in enumerate(self.fp16_groups): + + self.param_to_partition_ids[i] = {} + self.is_partition_reduced[i] = {} + self.total_grads_in_partition[i] = {} + self.remaining_grads_in_partition[i] = {} + self.is_grad_computed[i] = {} + self.grad_partition_insertion_offset[i] = {} + self.grad_start_offset[i] = {} + self.first_param_index_in_partition[i] = {} + + for partition_id in range(total_partitions): + self.is_grad_computed[i][partition_id] = {} + self.grad_partition_insertion_offset[i][partition_id] = {} + self.grad_start_offset[i][partition_id] = {} + self.initialize_gradient_partition(i, param_group, partition_id) + self.is_partition_reduced[i][partition_id] = False + self.first_param_index_in_partition[i][ + partition_id] = self.get_first_param_index( + i, + param_group, + partition_id) + + @instrument_w_nvtx + def independent_gradient_partition_epilogue(self): + self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) + self.__reduce_and_partition_ipg_grads() + self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) + + self.__reduce_and_partition_stream.synchronize() + + # if dist.get_rank() == 0: + # logger.info("Params already reduced %s", self.params_already_reduced) + for i in range(len(self.params_already_reduced)): + self.params_already_reduced[i] = False + + #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad + #TODO: use a similar code path for both cpu_offload and non-cpu offload + if not self.offload_optimizer: + for i, sub_group in enumerate(self.fp16_groups): + self.averaged_gradients[i] = [ + self.__param_id_to_grad_partition[param.ds_id] + if param.requires_grad else torch.zeros_like(param.ds_tensor) + for param in sub_group + ] + # self.averaged_gradients[i] = self.get_flat_partition( + # self.fp16_groups[i], + # 0, + # self.fp32_partitioned_groups_flat[i].numel(), + # return_tensor_list=True) + + # this method gets called after every backward. need to increment + # here because if it gets incremented in backward() the micro step + # id will be off by one when we do the reduce and partition at the. + # start of this method. + # TODO. make this less error prone + self.micro_step_id += 1 + + def overlapping_partition_gradients_reduce_epilogue(self): + self.independent_gradient_partition_epilogue() + + def create_reduce_and_remove_grad_hooks(self): + print_rank_0(f'[Begin] Create gradient reduction hooks') + self.grad_accs = [] + for i, param_group in enumerate(self.fp16_groups): + for param in param_group: + if param.requires_grad: + #print_rank_0(f" Before all gather {param.device}, {param.shape}") + + # The hook must be created in un-partitioned parameter + param.all_gather() + + #print(f"After all gather {param.device}, {param.shape}") + def wrapper(param, i): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + + @instrument_w_nvtx + def reduce_partition_and_remove_grads(*notneeded): + self.reduce_ready_partitions_and_remove_grads(param, i) + + grad_acc.register_hook(reduce_partition_and_remove_grads) + self.grad_accs.append(grad_acc) + + #print(f"param grad fn {param.expand_as(param).grad_fn}") + wrapper(param, i) + + # Partition the parameter after creating the hook + param.partition() + print_rank_0(f'[End] Create gradient reduction hooks') + + def get_param_id(self, param): + unique_id = id(param) + return self.param_id[unique_id] + + def report_ipg_memory_usage(self, tag, param_elems): + elem_count = self.elements_in_ipg_bucket + param_elems + percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size + see_memory_usage( + f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", + force=False) + + ###############Idependent Partition Gradient ######################## + def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): + #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) + + # Because the ipg bucket is initialized with a random place holder tensor, we must + # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket > + # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a + # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be + # empty, while reduction_list will have that garbage data. + if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: + self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", + param.ds_numel) + + self.__reduce_and_partition_ipg_grads() + + param_id = self.get_param_id(param) + assert self.params_already_reduced[param_id] == False, \ + f"The parameter {param_id} has already been reduced. \ + Gradient computed twice for this partition. \ + Multiple gradient reduction is currently not supported" + + self.__add_grad_to_ipg_bucket(param) + + @instrument_w_nvtx + @torch.no_grad() + def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: + self.__reduce_and_partition_stream.wait_stream(torch.cuda.default_stream()) + + if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel( + ) < self.reduce_bucket_size: + # move the gradient to a contiguous buffer + with torch.cuda.stream(self.__reduce_and_partition_stream): + # move the parameter's gradient to the contiguous flat buffer + new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow( + 0, + self.elements_in_ipg_bucket, + param.grad.numel()).view_as(param.grad) + new_grad_tensor.copy_(param.grad, non_blocking=True) + param.grad.record_stream(torch.cuda.current_stream()) + param.grad.data = new_grad_tensor + + self.__params_in_ipg_bucket.append(param) + + @instrument_w_nvtx + @torch.no_grad() + def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: + if not self.__params_in_ipg_bucket: + return + + for param in self.__params_in_ipg_bucket: + if param.grad.numel() != param.ds_numel: + raise RuntimeError( + f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " + f"gradients whose size is not same as the params") + + self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id) + + assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len( + self.__params_in_ipg_bucket) + + while self.__param_reduce_events and self.__param_reduce_events[0].query(): + self.__param_reduce_events.popleft() + if len(self.__param_reduce_events) > self.__max_param_reduce_events: + self.__param_reduce_events.popleft().synchronize() + + with torch.cuda.stream(self.__reduce_and_partition_stream): + if safe_mode: + assert_ints_same_as_other_ranks( + [p.ds_id for p in self.__params_in_ipg_bucket]) + + grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) + self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) + + self.__params_in_ipg_bucket.clear() + + event = Event() + event.record() + self.__param_reduce_events.append(event) + + @instrument_w_nvtx + def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: + """average gradients and scatter partitions across ranks""" + dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce) + + full_grads_for_rank = [p.grad for p in params_to_reduce] + if self.allreduce_always_fp32: + full_grads_for_rank = [g.float() for g in full_grads_for_rank] + + if self.postscale_gradients and self.gradient_predivide_factor != 1.0: + full_grads_for_rank = [ + g.div(self.gradient_predivide_factor) for g in full_grads_for_rank + ] + + grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, + self.dp_process_group) + + if self.postscale_gradients and self.gradient_predivide_factor != dist.get_world_size( + self.dp_process_group): + grad_partitions_for_rank = [ + g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank + ] + + if self.allreduce_always_fp32: + grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank] + + return grad_partitions_for_rank + + def set_grad_positions(self): + for i, group in enumerate(self.fp16_groups): + current_offset = 0 + for param in group: + param_id = self.get_param_id(param) + num_elements = param.ds_tensor.ds_numel + + self.grad_position[param_id] = [ + int(i), + int(current_offset), + int(num_elements) + ] + #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") + current_offset += num_elements + + def _constant_buffered_norm2(self, input, buffer_size=250000000): + norm = None + for part in input.view(-1).split(buffer_size): + if norm is None: + norm = part.data.double().norm(2)**2.0 + else: + norm += part.data.double().norm(2)**2.0 + return norm**0.5 + + def set_norm_for_param_grad_in_gpu(self, param): + param_id = self.get_param_id(param) + #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2) + #Using a more memory efficient version + self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) + + def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): + with torch.cuda.stream(self.copy_grad_stream): + param_id = self.get_param_id(param) + src_tensor = param.grad.view(-1).float() + #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") + fp32_grad_tensor.copy_(src_tensor, non_blocking=True) + param.grad = None + + def complete_grad_norm_calculation_for_cpu_offload(self, params): + total_norm = 0.0 + norm_type = 2.0 + for p in params: + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_id = self.get_param_id(p) + if param_id in self.norm_for_param_grads.keys(): + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm.item()**2 + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + @instrument_w_nvtx + def __partition_grads(self, + params_to_release: List[Parameter], + grad_partitions: List[Tensor]) -> None: + for param, grad_partition in zip(params_to_release, grad_partitions): + if param.ds_tensor.ds_numel * dist.get_rank( + self.dp_process_group) > param.ds_numel: + # this grad partition is empty - don't need to do anything + continue + + # move or accumulate gradient partition to target buffer + grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow( + 0, + 0, + grad_partition.numel()) + if self.micro_step_id == 0: # don't accumulate + grad_buffer.copy_(grad_partition, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + elif grad_buffer.is_cuda: + grad_buffer.add_(grad_partition) + else: + # if dst is CPU, copy first to src device, do the addition + # there, then move back to dst. adding directly to cpu is very slow + cuda_grad_buffer = grad_buffer.to(grad_partition.device, + non_blocking=True) + cuda_grad_buffer.add_(grad_partition) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = cuda_grad_buffer + + if hasattr(self.__inf_or_nan_tracker, "logical_or_"): + self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) + self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) + else: + # logical_or_ not available in older versions of pytorch + self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any() + self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any() + self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0 + + # offload the gradient partition if applicable + if self.offload_optimizer: + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] + offload_fp32_gradients = {} + offload_fp32_offsets = {} + + if self.is_gradient_accumulation_boundary: + self.norm_for_param_grads[self.get_param_id( + param)] = self._constant_buffered_norm2(grad_buffer) + + if self._swappable_optimizer_subgroup(i): + if not i in offload_fp32_gradients.keys(): + offload_fp32_gradients[i] = [] + offload_fp32_offsets[i] = [] + + offload_fp32_gradients[i].append(grad_buffer.float()) + offload_fp32_offsets[i].append(dest_offset) + else: + fp32_grad_tensor = self.fp32_partitioned_groups_flat[ + i].grad.narrow(0, + dest_offset, + grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer) + + # free the gradient + param.grad.record_stream(torch.cuda.current_stream()) + param.grad = None + + if self.offload_optimizer and self.swap_optimizer: + for i in offload_fp32_gradients.keys(): + self.optimizer_swapper.swap_out_gradients( + parameter=self.fp32_partitioned_groups_flat[i], + gradient_offsets=offload_fp32_offsets[i], + gradient_tensors=offload_fp32_gradients[i]) + + def reduce_ready_partitions_and_remove_grads(self, param, i): + #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) + self.reduce_independent_p_g_buckets_and_remove_grads(param, i) + + def zero_reduced_gradients(self, partition_id, i): + def are_all_related_partitions_reduced(params_id): + for partition_id in self.param_to_partition_ids[i][params_id]: + if not self.is_partition_reduced[i][partition_id]: + return False + return True + + for params_id in self.is_grad_computed[i][partition_id]: + if are_all_related_partitions_reduced(params_id): + self.param_dict[params_id].grad = None + + def flatten_and_print(self, message, tensors, start=0, n=5): + flatten_tensor = self.flatten(tensors) + + def print_func(): + logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) + + self.sequential_execution(print_func, message) + + def get_grads_to_reduce(self, i, partition_id): + def get_reducible_portion(key): + grad = self.param_dict[key].grad + total_elements = grad.numel() + start = self.grad_start_offset[i][partition_id][key] + num_elements = min( + total_elements - start, + self.partition_size[i] - + self.grad_partition_insertion_offset[i][partition_id][key]) + if not pg_correctness_test: + if num_elements == total_elements: + return grad + else: + return grad.contiguous().view(-1).narrow(0, + int(start), + int(num_elements)) + else: + if num_elements == total_elements: + return grad.clone() + else: + return grad.clone().contiguous().view(-1).narrow( + 0, + int(start), + int(num_elements)) + + grads_to_reduce = [] + for key in self.is_grad_computed[i][partition_id]: + grad = get_reducible_portion(key) + grads_to_reduce.append(grad) + return grads_to_reduce + + def sequential_execution(self, function, message, group=None): + if group is None: + group = self.dp_process_group + if dist.get_rank(group=group) == 0: + logger.info(message) + for id in range(dist.get_world_size(group=group)): + if id == dist.get_rank(group=group): + function() + dist.barrier(group=group) + + def set_none_gradients_to_zero(self, i, partition_id): + for param_id in self.is_grad_computed[i][partition_id]: + param = self.param_dict[param_id] + if param.grad is None: + param.grad = torch.zero_like(param) + + ######################Reduction Related Methods############################## + + def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None): + rank = None + tensor = self.flatten(bucket) + + tensor_to_allreduce = tensor + + if pg_correctness_test: + allreduce_always_fp32 = True + + if allreduce_always_fp32: + tensor_to_allreduce = tensor.float() + + tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) + + if rank is None: + # "All Reducing" + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + else: + global_rank = _get_global_rank(self.dp_process_group, rank) + dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) + + if allreduce_always_fp32 and tensor is not tensor_to_allreduce: + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + tensor.copy_(tensor_to_allreduce) + + return tensor + + # if rank is specified do a reduction instead of an allreduce + def allreduce_and_copy(self, small_bucket, rank=None, log=None): + with torch.cuda.stream(self.reduction_stream): + allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): + buf.copy_(synced) + + def allreduce_no_retain(self, + bucket, + numel_per_bucket=500000000, + rank=None, + log=None): + small_bucket = [] + numel = 0 + for tensor in bucket: + small_bucket.append(tensor) + numel = numel + tensor.numel() + if numel > numel_per_bucket: + self.allreduce_and_copy(small_bucket, rank=rank, log=None) + small_bucket = [] + if len(small_bucket) > 0: + self.allreduce_and_copy(small_bucket, rank=rank, log=log) + + ############################################################################# + ############################################################################# + ############################################################################# + + # views the tensor as multiple partitions and returns + # those partitions + def get_data_parallel_partitions(self, tensor): + partitions = [] + + dp = dist.get_world_size(group=self.dp_process_group) + dp_id = dist.get_rank(group=self.dp_process_group) + + total_num_elements = tensor.numel() + + base_size = total_num_elements // dp + remaining = total_num_elements % dp + + start = 0 + for id in range(dp): + partition_size = base_size + if id < remaining: + partition_size = partition_size + 1 + partitions.append(tensor.narrow(0, start, partition_size)) + start = start + partition_size + return partitions + + def get_partition_info(self, tensor_list, partition_size, partition_id): + params_in_partition = [] + params_not_in_partition = [] + + start_index = partition_size * partition_id + end_index = partition_size * (partition_id + 1) + + current_index = 0 + first_offset = 0 + + for tensor in tensor_list: + + tensor_size = tensor.numel() + + if (current_index >= start_index and current_index < end_index): + params_in_partition.append(tensor) + + elif start_index > current_index and start_index < (current_index + + tensor_size): + params_in_partition.append(tensor) + + assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset = start_index - current_index + + else: + params_not_in_partition.append(tensor) + + current_index = current_index + tensor_size + + return params_in_partition, params_not_in_partition, first_offset + + @instrument_w_nvtx + def zero_grad(self, set_grads_to_None=True): + """ + Zero FP16 parameter grads. + """ + self.micro_step_id = 0 + + # FP32 grad should never exist. + # For speed, set model fp16 grad to None by default + for group in self.fp16_groups: + for p in group: + if set_grads_to_None: + if p.grad is not None and p.grad.is_cuda: + p.grad.record_stream(torch.cuda.current_stream()) + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + def _model_parallel_all_reduce(self, tensor, op): + """ Perform all reduce within model parallel group, if any. + """ + if self.model_parallel_group is None: + pass + else: + torch.distributed.all_reduce(tensor=tensor, + op=op, + group=self.model_parallel_group) + + @instrument_w_nvtx + def get_grad_norm_direct(self, gradients, params, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(g.data.abs().max() for g in gradients) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group) + + # Take max across all GPUs. + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.MAX) + total_norm = total_norm_cuda[0].item() + else: + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + grad_norms = [] + for g, p in zip(gradients, params): + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + grad_norms.append(g.cuda(non_blocking=True).double().norm(2)) + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) + + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda.item()**(1. / norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + # creates a flat fused tensor from the tensor list starting at the first_offset + # in the first tensor of the list. If there are not enough elements in the tensor + # list then the flat tensor will be padded with zeros + def get_flat_partition(self, + tensor_list, + first_offset, + partition_size, + return_tensor_list=False): + flat_tensor_list = [] + current_size = 0 + for i, tensor in enumerate(tensor_list): + if tensor.grad is None: + tensor.grad = torch.zeros_like(tensor) + + tensor = tensor.grad + num_elements = tensor.numel() + tensor_offset = 0 + + # we need to offset to get to the right element + if i == 0 and first_offset > 0: + tensor_offset = first_offset + num_elements = num_elements - tensor_offset + + # we dont need all elements of the tensor + if num_elements > (partition_size - current_size): + num_elements = partition_size - current_size + + # we need a narrow view of the tensor based on the tensor offset and number of elements that + # we need from this tensor + if tensor_offset > 0 or num_elements < tensor.numel(): + flat_tensor_list.append(tensor.contiguous().view(-1).narrow( + 0, + int(tensor_offset), + int(num_elements))) + else: + flat_tensor_list.append(tensor) + + current_size = current_size + num_elements + + # this means its the last partition and does not align with the dp boundary. We need to pad before flattening + if current_size < partition_size: + flat_tensor_list.append( + torch.zeros(int(partition_size - current_size), + dtype=tensor_list[0].dtype, + device=tensor_list[0].device)) + + if return_tensor_list: + return flat_tensor_list + + return self.flatten(flat_tensor_list) + + def free_grad_in_param_list(self, param_list): + for p in param_list: + p.grad = None + + def reset_cpu_buffers(self): + self.norm_for_param_grads = {} + self.local_overflow = False + + def log_timers(self, timer_names): + if self.timers is None: + return + + self.timers.log(names=list(timer_names)) + + def start_timers(self, timer_names): + if self.timers is None: + return + + for name in timer_names: + self.timers(name).start() + + def stop_timers(self, timer_names): + if self.timers is None: + return + + for name in timer_names: + self.timers(name).stop() + + def _pre_step(self): + self.micro_step_id = 0 + + print_rank_0(f"Inside Step function") + see_memory_usage(f"In step before checking overflow", force=False) + + print_rank_0("Finished Tracing at Beginning of Step") + self.param_coordinator.hierarchy = 0 + + print_rank_0("Finished Tracing at Beginning of Step") + + @instrument_w_nvtx + def _get_norm_groups(self): + norm_groups = [] + for i, group in enumerate(self.fp16_groups): + if self.offload_optimizer: + norm_groups.append( + self.complete_grad_norm_calculation_for_cpu_offload( + self.fp16_groups[i])) + else: + norm_groups.append( + self.get_grad_norm_direct(self.averaged_gradients[i], + self.fp16_groups[i])) + return norm_groups + + @instrument_w_nvtx + def _prepare_fp32_grad_for_sub_group(self, sub_group_id): + partition_id = dist.get_rank(group=self.dp_process_group) + + single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( + self.fp32_partitioned_groups_flat[sub_group_id].dtype) + + assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ + "averaged gradients have different number of elements that partition size {} {} {} {}".format( + single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) + + self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition + + # release all the gradient since we have already created a necessary copy in dp_grad_partition + self.zero_grad() + + for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): + grad.record_stream(torch.cuda.current_stream()) + + self.averaged_gradients[sub_group_id] = None + + @instrument_w_nvtx + def _prepare_sub_group(self, sub_group_id, timer_names=set()): + see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', + force=False) + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) + elif not self.offload_optimizer: + self._prepare_fp32_grad_for_sub_group(sub_group_id) + see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', + force=False) + + def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state' + see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', + force=False) + self.start_timers([OPTIMIZER_SWAP_IN_STATE]) + + self.optimizer_swapper.swap_in_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) + + self.stop_timers([OPTIMIZER_SWAP_IN_STATE]) + timer_names.add(OPTIMIZER_SWAP_IN_STATE) + see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', + force=False) + + @instrument_w_nvtx + def _release_sub_group(self, sub_group_id, timer_names=set()): + see_memory_usage(f'Before release optimizer sub group {sub_group_id}', + force=False) + # get rid of the fp32 gradients. Not needed anymore + if not self.offload_optimizer: + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names) + see_memory_usage(f'After release optimizer sub group {sub_group_id}', + force=False) + + # create a flat tensor aligned at the alignment boundary + @instrument_w_nvtx + def flatten_dense_tensors_aligned(self, tensor_list, alignment): + num_elements = 0 + for tens in tensor_list: + num_elements = num_elements + tens.numel() + + remaining = num_elements % alignment + + if remaining: + elements_to_add = alignment - remaining + pad_tensor = torch.zeros(elements_to_add, + device=tensor_list[0].device, + dtype=tensor_list[0].dtype) + padded_tensor_list = tensor_list + [pad_tensor] + + num_elements = num_elements + elements_to_add + else: + padded_tensor_list = tensor_list + + return self.flatten(padded_tensor_list) + + def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state' + see_memory_usage( + f'post-step Before swapping out optimizer tensors {sub_group_id}', + force=False) + self.start_timers([OPTIMIZER_SWAP_OUT_STATE]) + + self.optimizer_swapper.swap_out_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is + not None) + + self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) + see_memory_usage( + f'post-step After swapping out optimizer tensors {sub_group_id}', + force=False) + timer_names.add(OPTIMIZER_SWAP_OUT_STATE) + + # get rid of the fp32 gradients. Not needed anymore + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + def _unflatten_partitioned_parameters(self, sub_group_id): + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + def _overflow_clean_up(self, prev_scale): + see_memory_usage('After overflow before clearing gradients', force=False) + self.zero_grad() + + if self.offload_optimizer: + self.reset_cpu_buffers() + else: + self.averaged_gradients = {} + + see_memory_usage('After overflow after clearing gradients', force=False) + + if torch.distributed.get_rank() == 0: + logger.info( + "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, " + "reducing to {}".format(dist.get_rank(), + prev_scale, + self.loss_scale)) + + @instrument_w_nvtx + def _overflow_check_and_loss_scale_update(self): + + # First compute norm for all group so we know if there is overflow + self.check_overflow() + + #loss scaling related computation + prev_scale = self.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self._overflow_clean_up(prev_scale) + + return self.overflow + + @instrument_w_nvtx + def _post_step(self, timer_names=set()): + if self.offload_optimizer: + self.reset_cpu_buffers() + + #Gathering persisting parameters + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + self.log_timers(timer_names) + + see_memory_usage('After zero_optimizer step', force=False) + print_rank_0(f"------------------Finishing Step-----------------------") + + @instrument_w_nvtx + def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): + if self.fp16_partitioned_groups_flat[sub_group_id] is not None: + self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + self.fp32_partitioned_groups_flat[sub_group_id].data) + + #unflatten fp16 parameter subgroup + self._unflatten_partitioned_parameters(sub_group_id) + else: + self._partitioned_params_swap_out(sub_group_id) + + @instrument_w_nvtx + def step(self, closure=None): + """ + Not supporting closure. + """ + self._pre_step() + self._partition_all_parameters() + + #checks for overflow, adjust the loss scale accordingly + if self._overflow_check_and_loss_scale_update(): + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + return + + norm_groups = self._get_norm_groups() + self._global_grad_norm = get_global_norm(norm_list=norm_groups) + + timer_names = set() + + timer_names.add('optimizer_step') + self.start_timers(['optimizer_step']) + + #update parameters one sub group at a time + for sub_group_id, group in enumerate(self.fp16_groups): + + #prepare optimizer states, gradients and fp32 parameters for update + self._prepare_sub_group(sub_group_id, timer_names) + + #scale the fp32 gradients + self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) + + #apply the optimizer step on the sub group and copy fp32 parameters to fp16 + self._optimizer_step(sub_group_id) + + #put fp16 parameters in appropriate location + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + + #release memory or swap out optimizer states of fp32 parameters + self._release_sub_group(sub_group_id, timer_names) + + self.stop_timers(['optimizer_step']) + + self._post_step(timer_names) + + # warn user about caching allocator flushes + alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] if hasattr( + torch.cuda, + "memory_stats") else 0 + if alloc_retries > self.__n_caching_allocator_flushes: + if dist.get_rank() == 0: + logger.warning( + "%d pytorch allocator cache flushes since last step. this happens " + "when there is high memory pressure and is detrimental to " + "performance. if this is happening frequently consider adjusting " + "settings to reduce memory consumption. If you are unable to " + "make the cache flushes go away consider adding " + "torch.cuda.empty_cache() calls in your training loop to ensure " + "that all ranks flush their caches at the same time", + alloc_retries - self.__n_caching_allocator_flushes) + self.__n_caching_allocator_flushes = alloc_retries + + def dump_pre_step_gradients(self, debug_fp32_grads): + # Dump gradient norms for debugging + for i, _ in enumerate(self.fp16_groups): + print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC') + for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]): + param_id = self.get_param_id(fp16_param) + fp16_grad_norm = self.debug_fp16_grads[i][param_id] + + fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad] + norm_list = [fp16_grad_norm, fp32_grad_norm] + print(f'Pre-Step Norms {i} {param_id} = {norm_list}') + + def dump_post_step_gradients(self): + # Dump gradient norms for debugging + for i, group in enumerate(self.fp16_groups): + print( + f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT') + unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i]) + unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i], + self.fp16_groups[i]) + for j, p in enumerate(self.fp16_groups[i]): + param_id = self.get_param_id(p) + param_norm = float(p.data.float().norm(2)) + ds_norm = float(p.ds_tensor.data.float().norm(2)) + + unflat_norm = [ + float(t.data.float().norm(2)) + for t in [unflat_fp16[j], + unflat_fp32[j]] + ] + norm_list = [param_norm, ds_norm] + unflat_norm + print(f'Post-Step Norms {i} {param_id} = {norm_list}') + + @instrument_w_nvtx + def unscale_and_clip_grads(self, sub_group_id, total_norm): + grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] + + # compute combined scale factor for this group + combined_scale = self.loss_scale + if self.clip_grad > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad + if clip > 1: + combined_scale = clip * self.loss_scale + # to maintain behavior of averaging over accumulation steps + combined_scale *= self.micro_step_id + 1 + + for grad in grad_groups_flat: + if isinstance(grad, list): + sub_partitions = grad + for g in sub_partitions: + g.data.mul_(1. / combined_scale) + else: + grad.data.mul_(1. / combined_scale) + + def _check_overflow(self, partition_gradients=True): + self.overflow = self.has_overflow(partition_gradients) + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params, is_grad_list=False): + for p in params: + if p.grad is not None and self._has_inf_or_nan(p.grad.data): + return True + + return False + + def has_overflow_partitioned_grads_serial(self): + for i in range(len(self.fp16_groups)): + for j, grad in enumerate(self.averaged_gradients[i]): + if grad is not None and self._has_inf_or_nan(grad.data, j): + return True + return False + + @instrument_w_nvtx + def has_overflow(self, partition_gradients=True): + if partition_gradients: + with torch.cuda.stream(self.__reduce_and_partition_stream): + self.local_overflow = bool(self.__inf_or_nan_tracker.item()) + self.__inf_or_nan_tracker.zero_() + + overflow = self.local_overflow + #overflow = self.has_overflow_partitioned_grads_serial() + overflow_gpu = torch.cuda.ByteTensor([overflow]) + torch.distributed.all_reduce(overflow_gpu, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group) + + else: + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + + overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) + overflow_gpu = torch.cuda.ByteTensor([overflow]) + + # Since each model parallel GPU carries only part of the model, + # make sure overflow flag is synced across all the model parallel GPUs + self._model_parallel_all_reduce(tensor=overflow_gpu, + op=torch.distributed.ReduceOp.MAX) + + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + @staticmethod + def _has_inf_or_nan(x, j=None): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + @instrument_w_nvtx + def backward(self, loss, retain_graph=False): + """ + :attr:`backward` performs the following steps: + + 1. fp32_loss = loss.float() + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves + """ + if self.swap_optimizer: + self.optimizer_swapper.pre_backward() + + see_memory_usage(f"Before backward", force=False) + + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + + self.param_coordinator.reset_step() + + if self.swap_optimizer: + self.optimizer_swapper.post_backward() + + def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: + """get fp32 gradient partition dictionary + accessed as grad_dict[parameter_group_index][parameter_index] + """ + self.__reduce_and_partition_stream.synchronize() + grad_dict = collections.defaultdict(dict) + if self.offload_optimizer: + for group in self.fp16_groups: + for param_idx, param in enumerate(group): + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow( + 0, + dest_offset, + num_elements) + grad_dict[group_idx][param_idx] = fp32_grad + else: + for group_idx, group in self.averaged_gradients.items(): + for param_idx, gradient in enumerate(group): + grad_dict[group_idx][param_idx] = gradient.float() + + return grad_dict + + @instrument_w_nvtx + def _partition_all_parameters(self): + """Partitioning Parameters that were not partitioned usually if parameters + of modules whose input parameters do not require grad computation do not + trigger post call and will therefore will remain unpartitioned""" + self.param_coordinator.release_and_reset_all() + for param in iter_params(self.module, recurse=True): + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"{param.ds_summary()} expected to be released") + + def check_overflow(self, partition_gradients=True): + self._check_overflow(partition_gradients) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + cur_scale = property(_get_loss_scale, _set_loss_scale) + + def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings): + # Remove paddings from flattened tensor + individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors) + lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)] + lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)] + #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}') + return lean_tensors + + #TODO REVISIT this for stage 3 + def get_lean_optimizer_state(self): + # Return optimizer states after removing paddings. + # This method assumes that each param group contains a single flattened tensor. + optimizer_groups_state = [] + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + lean_state = {} + for key, value in self.optimizer.state[p].items(): + if torch.is_tensor(value): + padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]] + lean_state[key] = self._get_lean_tensors( + value, + self.fp16_partitioned_groups[i], + self.groups_padding[i]) + lean_flat_len = sum([t.numel() for t in lean_state[key]]) + else: + lean_state[key] = value + + optimizer_groups_state.append(lean_state) + + return optimizer_groups_state + + def get_groups_without_padding(self, groups_with_padding): + # Return group tensor after removing paddings added for alignment to DP world size. + groups_without_padding = [] + for i, group in enumerate(groups_with_padding): + lean_group = self._get_lean_tensors(group, + self.fp16_partitioned_groups[i], + self.groups_padding[i]) + groups_without_padding.append(lean_group) + + return groups_without_padding + + def _set_fp32_optimizer_param_groups(self): + for sub_group_id, _ in enumerate(self.fp16_groups): + param_group_id = self.sub_group_to_group_id[sub_group_id] + self.optimizer.param_groups[param_group_id]['params'].append( + self.fp32_partitioned_groups_flat[sub_group_id]) + + def _clear_fp32_optimizer_param_groups(self): + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _rigid_state_dict(self): + state_dict = {} + state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict['partition_count'] = self.partition_count + + self._set_fp32_optimizer_param_groups() + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat + self._clear_fp32_optimizer_param_groups() + + return state_dict + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + if self.elastic_checkpoint: + raise NotImplementedError( + "ZeRO-3 does not yet support elastic checkpointing, please disable for now." + ) + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + raise NotImplementedError( + "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." + ) + + return self._rigid_state_dict() + + +# Restore base optimizer fp32 weights from checkpoint by: +# 1) Merging fp32 weights from checkpoints of all partitions +# 2) Extracting fp32 weights for current partition from merged weights +# 3) Using extracted weights to update base optimizer weights directly. + + def _restore_from_fp32_weights(self, all_state_dict): + + flat_local_partition = [] + for i in range(len(self.fp32_partitioned_groups_flat)): + merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] + flat_local_partition.append(self._get_flattened_partition(merged_partitions)) + + for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): + current.data.copy_(saved.data) + + # Restore base optimizer fp32 weights from ZeRO fp16 weights + def _restore_from_fp16_weights(self): + for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): + fp32_partition.data.copy_(fp16_partitions.data) + + # Refresh the fp32 master params from the fp16 copies. + def refresh_fp32_params(self): + self._restore_from_fp16_weights() + + # Extract flattened partition for current rank from all partitions + def _get_flattened_partition(self, all_partition_states): + partition_id = dist.get_rank(group=self.dp_process_group) + alignment = dist.get_world_size(group=self.dp_process_group) + + param_partitions = [[] for _ in range(len(all_partition_states[0]))] + for i, partition in enumerate(all_partition_states): + for j, param in enumerate(partition): + param_partitions[j].append(param) + + local_state_partitions = [] + for param_index, param_slices in enumerate(param_partitions): + flattened_merged_tensor = self.flatten_dense_tensors_aligned( + param_slices, + alignment) + new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) + local_state_partitions.append(new_partitions[partition_id]) + + if torch.is_tensor(local_state_partitions[0]): + return self.flatten_dense_tensors_aligned(local_state_partitions, alignment) + + # Assume non-tensor states are not partitioned and equal across ranks, so return first one + return local_state_partitions[0] + + # Restore base optimizer state from checkpoint by + # 1) Merging optimizer state from checkpoints of all partitions + # 2) Extracting optimizer state for current partition from the merged state + # 3) Using the extracted value to directly update the base optimizer. + def _restore_base_optimizer_state(self, all_state_dict): + base_optimizer_group_states = [] + for i in range(len(self.optimizer.param_groups)): + partition_states = {} + all_partition_group_states = [ + sd['base_optimizer_state'][i] for sd in all_state_dict + ] + for key in all_partition_group_states[0].keys(): + all_partition_states = [ + all_states[key] for all_states in all_partition_group_states + ] + partition_states[key] = self._get_flattened_partition( + all_partition_states) + base_optimizer_group_states.append(partition_states) + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + for key, saved in base_optimizer_group_states[i].items(): + if torch.is_tensor(self.optimizer.state[p][key]): + self.optimizer.state[p][key].data.copy_(saved.data) + else: + self.optimizer.state[p][key] = saved + + def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + + if load_optimizer_states: + self._set_fp32_optimizer_param_groups() + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + self._clear_fp32_optimizer_param_groups() + + # restore fp32 partitions + for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']): + curr_param.data.copy_(saved_param.data) + + # restore fp16 partitions from fp32 + for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + fp16_param.data.copy_(fp32_param.data) + + # update fp16 unflattened params + for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): + updated_params = self.unflatten( + self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + # TODO: Support different/changing load/save DP degree. + def load_state_dict(self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False): + r"""Loading a ZeRO checkpoint + Arguments: + state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. + Note that the number of saved partitions may differ from number of loading partitions to support + changing GPU count, specifically DP world size, between saving and loading checkpoints. + load_optimizer_states: Boolean indicating whether or not to load base optimizer states + load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 + copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). + """ + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + Example:: + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + + if self.elastic_checkpoint: + raise NotImplementedError( + "ZeRO-3 does not yet support elastic checkpointing, please disable for now." + ) + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + raise NotImplementedError( + "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." + ) + + self._rigid_load_state_dict( + state_dict_list[dist.get_rank(group=self.dp_process_group)], + load_optimizer_states=load_optimizer_states) + + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].partition(self.persistent_parameters) + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + def save_checkpoint_prologue(self): + self._partition_all_parameters() + + def save_checkpoint_epilogue(self): + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + +def _handle_overflow(cpu_sum, x, i): + import math + rank = torch.distributed.get_rank() + if rank == 0: + t_i = -1 + for v_i, v in enumerate(x.data.contiguous().view(-1)): + if not math.isfinite(float(v)): + t_i = v_i + break + logger.info( + f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}" + ) + + +def estimate_zero3_model_states_mem_needs(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + cpu_offload=True, + cpu_offload_params=True, + zero_init=True, + additional_buffer_factor=1.5): + + total_gpus = num_nodes * num_gpus_per_node + gpus_factor = 1 / num_nodes + largest_layer_memory = (4 * largest_layer_params) + + if cpu_offload: + if cpu_offload_params: + gpu_mem = largest_layer_memory + + if zero_init: + cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, + 18 * gpus_factor) * additional_buffer_factor + + else: + gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus) + + if zero_init: + cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, + 16 * gpus_factor) * additional_buffer_factor + else: + gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus) + if zero_init: + cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor + else: + cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor + + return int(cpu_mem), int(gpu_mem), largest_layer_memory + + +def model_to_params(model): + # shared params calculated only once + total_params = sum( + dict((p.data_ptr(), + p.numel()) for p in model.parameters()).values()) + + largest_layer_params = 0 + for m in model.modules(): + # assuming no shared params within a single layer + layer_params = sum(p.numel() for p in m.parameters(recurse=False)) + largest_layer_params = max(largest_layer_params, layer_params) + + return total_params, largest_layer_params + + +import math + + +def estimate_zero3_model_states_mem_needs_all_live(model, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If you have an actual model object, use this function and everything will be derived + automatically. + + If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + Args: + - ``model``: ``nn.Module`` object + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + + total_params, largest_layer_params = model_to_params(model) + + estimate_zero3_model_states_mem_needs_all_cold( + total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + additional_buffer_factor=additional_buffer_factor) + + +def estimate_zero3_model_states_mem_needs_all_cold(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If it's a hypothetical model, use this function where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything + will be derived automatically. + + Args: + - ``total_params``: total model params + - ``largest_layer_params``: largest layer's params + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + def format_options(cpu_offload, cpu_offload_params, zero_init): + enabled = [] + padded_cpu_str = f'{OFFLOAD_CPU_DEVICE:4}' + param_device = padded_cpu_str if cpu_offload_params else "none" + enabled.append(f"{OFFLOAD_PARAM}={param_device}") + optimizer_device = padded_cpu_str if cpu_offload else "none" + enabled.append(f"{OFFLOAD_OPTIMIZER}={optimizer_device}") + enabled.append(f"zero_init={1 if zero_init else 0}") + return ", ".join(enabled) + + nodes_str = "nodes" if num_nodes > 1 else "node" + gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" + print( + "Estimated memory needed for params, optim states and gradients for a:\n" + f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" + f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params." + ) + print(" per CPU | per GPU | Options") + for cpu_offload in [True, False]: + for cpu_offload_params in [True, False]: + if not cpu_offload and cpu_offload_params: + continue + for zero_init in [True, False]: + cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs( + total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init, + additional_buffer_factor=additional_buffer_factor + ) + + options_str = format_options(cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init) + print( + f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") From 4b0133b6afe13570baa9aadc46f8d2e522636653 Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Mon, 29 Nov 2021 16:27:24 -0800 Subject: [PATCH 41/59] fix to LF line endings --- .clang-format | 310 +- .github/workflows/formatting.yml | 70 +- .pre-commit-config.yaml | 3 + CODE_OF_CONDUCT.md | 18 +- LICENSE | 42 +- SECURITY.md | 82 +- csrc/adagrad/cpu_adagrad.cpp | 454 +- csrc/adam/cpu_adam.cpp | 584 +- csrc/aio/common/deepspeed_aio_common.cpp | 666 +- csrc/aio/common/deepspeed_aio_common.h | 72 +- csrc/aio/common/deepspeed_aio_types.cpp | 148 +- csrc/aio/common/deepspeed_aio_types.h | 114 +- csrc/aio/common/deepspeed_aio_utils.cpp | 246 +- csrc/aio/common/deepspeed_aio_utils.h | 154 +- csrc/aio/py_lib/deepspeed_aio_thread.cpp | 168 +- csrc/aio/py_lib/deepspeed_aio_thread.h | 114 +- csrc/aio/py_lib/deepspeed_py_aio.cpp | 242 +- csrc/aio/py_lib/deepspeed_py_aio.h | 54 +- csrc/aio/py_lib/deepspeed_py_aio_handle.cpp | 564 +- csrc/aio/py_lib/deepspeed_py_aio_handle.h | 136 +- csrc/aio/py_lib/deepspeed_py_copy.cpp | 266 +- csrc/aio/py_lib/deepspeed_py_copy.h | 84 +- csrc/aio/py_lib/py_ds_aio.cpp | 82 +- csrc/aio/py_test/ds_aio_basic.py | 288 +- csrc/aio/py_test/ds_aio_handle.py | 352 +- csrc/aio/py_test/parse_aio_stats.py | 308 +- csrc/aio/py_test/test_ds_aio.py | 202 +- csrc/aio/py_test/test_ds_aio_utils.py | 118 +- csrc/common/custom_cuda_kernel.cu | 78 +- csrc/includes/Timer.h | 94 +- csrc/includes/cpu_adagrad.h | 294 +- csrc/includes/cpu_adam.h | 444 +- csrc/includes/dropout.h | 152 +- csrc/includes/feed_forward.h | 186 +- csrc/includes/gelu.h | 72 +- csrc/includes/gemm_test.h | 586 +- csrc/includes/general_kernels.h | 94 +- csrc/includes/normalize_layer.h | 404 +- csrc/includes/simd.h | 274 +- csrc/includes/softmax.h | 120 +- csrc/includes/strided_batch_gemm.h | 358 +- csrc/transformer/dropout_kernels.cu | 1736 ++--- csrc/transformer/gelu_kernels.cu | 660 +- csrc/transformer/general_kernels.cu | 822 +- csrc/transformer/inference/csrc/dequantize.cu | 220 +- csrc/transformer/inference/csrc/gelu.cu | 532 +- csrc/transformer/inference/csrc/normalize.cu | 852 +-- .../transformer/inference/csrc/pt_binding.cpp | 1260 ++-- csrc/transformer/inference/csrc/softmax.cu | 864 +-- csrc/transformer/inference/includes/context.h | 224 +- .../inference/includes/cublas_wrappers.h | 416 +- .../inference/includes/custom_cuda_layers.h | 158 +- csrc/transformer/normalize_kernels.cu | 4206 +++++------ csrc/transformer/softmax_kernels.cu | 1190 +-- csrc/transformer/transform_kernels.cu | 1150 +-- deepspeed/launcher/constants.py | 18 +- deepspeed/launcher/multinode_runner.py | 454 +- deepspeed/module_inject/module_quantize.py | 160 +- deepspeed/module_inject/replace_policy.py | 478 +- deepspeed/ops/adagrad/cpu_adagrad.py | 270 +- deepspeed/ops/adam/__init__.py | 4 +- deepspeed/ops/adam/cpu_adam.py | 372 +- deepspeed/ops/aio/__init__.py | 12 +- .../bert_sparse_self_attention.py | 156 +- .../activation_checkpointing/checkpointing.py | 1824 ++--- .../activation_checkpointing/config.py | 206 +- deepspeed/runtime/config_utils.py | 160 +- deepspeed/runtime/eigenvalue.py | 304 +- deepspeed/runtime/progressive_layer_drop.py | 66 +- deepspeed/runtime/quantize.py | 448 +- deepspeed/runtime/zero/stage3.py | 6702 ++++++++--------- docs/README.md | 98 +- docs/_posts/2021-03-08-zero3-offload.md | 200 +- ...021-05-05-inference-kernel-optimization.md | 146 +- docs/_tutorials/mixture-of-experts.md | 394 +- docs/_tutorials/progressive_layer_dropping.md | 310 +- docs/_tutorials/zero-offload.md | 150 +- docs/_tutorials/zero.md | 602 +- docs/code-docs/source/schedulers.rst | 50 +- tests/perf/adam_test.py | 48 +- tests/perf/adam_test1.py | 44 +- tests/unit/ds_batch_config.json | 30 +- tests/unit/modelingpreln.py | 3384 ++++----- tests/unit/test_aio.py | 670 +- tests/unit/test_cpu_adagrad.py | 250 +- tests/unit/test_cpu_adam.py | 124 +- tests/unit/test_onebit.py | 1840 ++--- tests/unit/test_pld.py | 234 +- 88 files changed, 21799 insertions(+), 21796 deletions(-) diff --git a/.clang-format b/.clang-format index 9f90836e1f2a..38790238fcd2 100755 --- a/.clang-format +++ b/.clang-format @@ -1,155 +1,155 @@ ---- -# Refer to the following link for the explanation of each params: -# http://releases.llvm.org/8.0.0/tools/clang/docs/ClangFormatStyleOptions.html -Language: Cpp -# BasedOnStyle: Google -AccessModifierOffset: -4 -AlignAfterOpenBracket: Align -AlignConsecutiveAssignments: false -AlignConsecutiveDeclarations: false -AlignEscapedNewlines: Left -AlignOperands: true -AlignTrailingComments: true -AllowAllParametersOfDeclarationOnNextLine: false -AllowShortBlocksOnASingleLine: true -AllowShortCaseLabelsOnASingleLine: true -AllowShortFunctionsOnASingleLine: All -AllowShortIfStatementsOnASingleLine: true -AllowShortLoopsOnASingleLine: true -# This is deprecated -AlwaysBreakAfterDefinitionReturnType: None -AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: true -AlwaysBreakTemplateDeclarations: true -BinPackArguments: false -BinPackParameters: false -BraceWrapping: - AfterClass: false - AfterControlStatement: false - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - AfterExternBlock: false - BeforeCatch: false - BeforeElse: false - IndentBraces: false - # disabling the below splits, else, they'll just add to the vertical length of source files! - SplitEmptyFunction: false - SplitEmptyRecord: false - SplitEmptyNamespace: false -BreakBeforeBinaryOperators: None -BreakBeforeBraces: WebKit -BreakBeforeInheritanceComma: false -BreakInheritanceList: BeforeColon -BreakBeforeTernaryOperators: true -BreakConstructorInitializersBeforeComma: false -BreakConstructorInitializers: BeforeColon -BreakAfterJavaFieldAnnotations: false -BreakStringLiterals: true -ColumnLimit: 100 -CommentPragmas: '^ IWYU pragma:' -CompactNamespaces: false -ConstructorInitializerAllOnOneLineOrOnePerLine: true -# Kept the below 2 to be the same as `IndentWidth` to keep everything uniform -ConstructorInitializerIndentWidth: 4 -ContinuationIndentWidth: 4 -Cpp11BracedListStyle: true -DerivePointerAlignment: false -DisableFormat: false -ExperimentalAutoDetectBinPacking: false -FixNamespaceComments: true -ForEachMacros: - - foreach - - Q_FOREACH - - BOOST_FOREACH -IncludeBlocks: Preserve -IncludeCategories: - - Regex: '^' - Priority: 2 - - Regex: '^<.*\.h>' - Priority: 1 - - Regex: '^<.*' - Priority: 2 - - Regex: '.*' - Priority: 3 -IncludeIsMainRegex: '([-_](test|unittest))?$' -IndentCaseLabels: true -IndentPPDirectives: None -IndentWidth: 4 -IndentWrappedFunctionNames: false -JavaScriptQuotes: Leave -JavaScriptWrapImports: true -KeepEmptyLinesAtTheStartOfBlocks: false -MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -ObjCBinPackProtocolList: Never -ObjCBlockIndentWidth: 4 -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: true -PenaltyBreakAssignment: 4 -PenaltyBreakBeforeFirstCallParameter: 1 -PenaltyBreakComment: 300 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakString: 1000 -PenaltyBreakTemplateDeclaration: 10 -PenaltyExcessCharacter: 1000000 -PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left -RawStringFormats: - - Language: Cpp - Delimiters: - - cc - - CC - - cpp - - Cpp - - CPP - - 'c++' - - 'C++' - CanonicalDelimiter: '' - - Language: TextProto - Delimiters: - - pb - - PB - - proto - - PROTO - EnclosingFunctions: - - EqualsProto - - EquivToProto - - PARSE_PARTIAL_TEXT_PROTO - - PARSE_TEST_PROTO - - PARSE_TEXT_PROTO - - ParseTextOrDie - - ParseTextProtoOrDie - CanonicalDelimiter: '' - BasedOnStyle: google -# Enabling comment reflow causes doxygen comments to be messed up in their formats! -ReflowComments: true -SortIncludes: true -SortUsingDeclarations: true -SpaceAfterCStyleCast: false -SpaceAfterTemplateKeyword: true -SpaceBeforeAssignmentOperators: true -SpaceBeforeCpp11BracedList: false -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SpaceInEmptyParentheses: false -SpacesBeforeTrailingComments: 2 -SpacesInAngles: false -SpacesInContainerLiterals: true -SpacesInCStyleCastParentheses: false -SpacesInParentheses: false -SpacesInSquareBrackets: false -Standard: Cpp11 -StatementMacros: - - Q_UNUSED - - QT_REQUIRE_VERSION -# Be consistent with indent-width, even for people who use tab for indentation! -TabWidth: 4 -UseTab: Never +--- +# Refer to the following link for the explanation of each params: +# http://releases.llvm.org/8.0.0/tools/clang/docs/ClangFormatStyleOptions.html +Language: Cpp +# BasedOnStyle: Google +AccessModifierOffset: -4 +AlignAfterOpenBracket: Align +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Left +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: true +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: All +AllowShortIfStatementsOnASingleLine: true +AllowShortLoopsOnASingleLine: true +# This is deprecated +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false + # disabling the below splits, else, they'll just add to the vertical length of source files! + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: WebKit +BreakBeforeInheritanceComma: false +BreakInheritanceList: BeforeColon +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 100 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +# Kept the below 2 to be the same as `IndentWidth` to keep everything uniform +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^' + Priority: 2 + - Regex: '^<.*\.h>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IndentCaseLabels: true +IndentPPDirectives: None +IndentWidth: 4 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Never +ObjCBlockIndentWidth: 4 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 4 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' + - Language: TextProto + Delimiters: + - pb + - PB + - proto + - PROTO + EnclosingFunctions: + - EqualsProto + - EquivToProto + - PARSE_PARTIAL_TEXT_PROTO + - PARSE_TEST_PROTO + - PARSE_TEXT_PROTO + - ParseTextOrDie + - ParseTextProtoOrDie + CanonicalDelimiter: '' + BasedOnStyle: google +# Enabling comment reflow causes doxygen comments to be messed up in their formats! +ReflowComments: true +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION +# Be consistent with indent-width, even for people who use tab for indentation! +TabWidth: 4 +UseTab: Never diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index b23e0910ab1f..4d5628768d36 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -1,35 +1,35 @@ -name: Formatting - -on: - push: - branches: - - 'master' - - 'staging**' - pull_request: - branches: - '**' - -jobs: - - # formatting and basic install on cpu-only machine - formatting: - runs-on: ubuntu-20.04 - - steps: - - uses: actions/checkout@v2 - - - name: environment - run: | - which python - python --version - pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - python -c "import torch; print('torch:', torch.__version__, torch)" - - - name: Install deepspeed - run: | - pip install .[dev,autotuning] - ds_report - - - name: Formatting checks - run: | - pre-commit run --all-files +name: Formatting + +on: + push: + branches: + - 'master' + - 'staging**' + pull_request: + branches: + '**' + +jobs: + + # formatting and basic install on cpu-only machine + formatting: + runs-on: ubuntu-20.04 + + steps: + - uses: actions/checkout@v2 + + - name: environment + run: | + which python + python --version + pip install torch==1.9.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + python -c "import torch; print('torch:', torch.__version__, torch)" + + - name: Install deepspeed + run: | + pip install .[dev,autotuning] + ds_report + + - name: Formatting checks + run: | + pre-commit run --all-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 571105d41230..76f237aacdb7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,6 +10,9 @@ repos: - id: end-of-file-fixer exclude: "examples/" exclude: "docs/CNAME" + - id: mixed-line-ending + exclude: "DeepSpeedExamples/" + args: [--fix=lf] - repo: https://github.com/pre-commit/mirrors-yapf diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index c72a5749c52a..f9ba8cf65f3e 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -1,9 +1,9 @@ -# Microsoft Open Source Code of Conduct - -This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). - -Resources: - -- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) -- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) -- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns +# Microsoft Open Source Code of Conduct + +This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). + +Resources: + +- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) +- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) +- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns diff --git a/LICENSE b/LICENSE index 3d8b93bc7987..9e841e7a26e4 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,21 @@ - MIT License - - Copyright (c) Microsoft Corporation. - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in all - copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - SOFTWARE + MIT License + + Copyright (c) Microsoft Corporation. + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE diff --git a/SECURITY.md b/SECURITY.md index 7ab49eb82964..e0dfff56a956 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,41 +1,41 @@ - - -## Security - -Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). - -If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below. - -## Reporting Security Issues - -**Please do not report security vulnerabilities through public GitHub issues.** - -Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). - -If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). - -You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). - -Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: - - * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) - * Full paths of source file(s) related to the manifestation of the issue - * The location of the affected source code (tag/branch/commit or direct URL) - * Any special configuration required to reproduce the issue - * Step-by-step instructions to reproduce the issue - * Proof-of-concept or exploit code (if possible) - * Impact of the issue, including how an attacker might exploit the issue - -This information will help us triage your report more quickly. - -If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. - -## Preferred Languages - -We prefer all communications to be in English. - -## Policy - -Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). - - + + +## Security + +Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). + +If you believe you have found a security vulnerability in any Microsoft-owned repository that meets Microsoft's [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)) of a security vulnerability, please report it to us as described below. + +## Reporting Security Issues + +**Please do not report security vulnerabilities through public GitHub issues.** + +Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). + +If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). + +You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). + +Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: + + * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) + * Full paths of source file(s) related to the manifestation of the issue + * The location of the affected source code (tag/branch/commit or direct URL) + * Any special configuration required to reproduce the issue + * Step-by-step instructions to reproduce the issue + * Proof-of-concept or exploit code (if possible) + * Impact of the issue, including how an attacker might exploit the issue + +This information will help us triage your report more quickly. + +If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. + +## Preferred Languages + +We prefer all communications to be in English. + +## Policy + +Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). + + diff --git a/csrc/adagrad/cpu_adagrad.cpp b/csrc/adagrad/cpu_adagrad.cpp index 607072dec1b9..4f2a9b69ef96 100644 --- a/csrc/adagrad/cpu_adagrad.cpp +++ b/csrc/adagrad/cpu_adagrad.cpp @@ -1,227 +1,227 @@ -#include "cpu_adagrad.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" - -static std::unordered_map> s_optimizers; - -// C++ interface - -void Adagrad_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); -#endif - if (_param_size > rounded_size) { - float step_size = -1 * _alpha; - __half* grads_cast_h; - __half* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast<__half*>(grads); - params_cast_h = reinterpret_cast<__half*>(_params); - } - for (size_t t = rounded_size; t < _param_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; - size_t offset = copy_size + t; - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#pragma omp parallel for - for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = grads[k]; - float variance = _exp_avg_sq[k]; - if (_weight_decay > 0) { grad = param * _weight_decay + grad; } - - variance += grad * grad; - - grad = sqrt(variance); - grad += _eps; - grad = momentum / grad; - param = grad * step_size + param; - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; - - if (half_precision) - params_cast_h[k] = (__half)param; - else - _params[k] = param; - // STORE UPDATE TERM TO GRAD'S MEMORY - grads[k] = grad * step_size; - _exp_avg_sq[k] = variance; - } - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - _buf_index = !_buf_index; - } - } - } -} - -void Adagrad_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); -#endif - if (_param_size > rounded_size) - Step_1((_params + rounded_size), - (grads + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int create_adagrad_optimizer(int optimizer_id, - float alpha = 1e-2, - float eps = 1e-8, - float weight_decay = 0, - bool should_log = false) -{ - auto opt = std::make_shared(alpha, eps, weight_decay); - - s_optimizers[optimizer_id] = opt; - - if (should_log) { - std::string avx_type = ""; -#if defined(__AVX512__) - avx_type = "AVX512"; -#else -#if defined(__AVX256__) - avx_type = "AVX2"; -#else - avx_type = "scalar"; -#endif -#endif - - printf("Adagrad Optimizer #%d is created with %s arithmetic capability.\n", - optimizer_id, - avx_type.c_str()); - printf("Config: alpha=%f, weight_decay=%f\n", alpha, weight_decay); - } - - return 0; -} - -void Adagrad_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>( - &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); -#endif - if (_param_size > rounded_size) - Step_4((_params + rounded_size), - (grads + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int ds_adagrad_step(int optimizer_id, - size_t step, - float lr, - float epsilon, - float weight_decay, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg_sq) -{ - auto params_c = params.contiguous(); - auto grads_c = grads.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step); - opt->update_state(lr, epsilon, weight_decay); - opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.size(0)); - - opt->SynchronizeStreams(); - return 0; -} - -int ds_adagrad_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float epsilon, - float weight_decay, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg_sq, - torch::Tensor& gpu_params) -{ - auto params_c = params.contiguous(); - auto gpu_params_c = gpu_params.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step); - opt->update_state(lr, epsilon, weight_decay); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_sq_ptr, - params_c.size(0), - gpu_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); - return 0; -} - -int destroy_adagrad_optimizer(int optimizer_id) -{ - s_optimizers.erase(optimizer_id); - - return 0; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)"); - m.def("adagrad_update_copy", - &ds_adagrad_step_plus_copy, - "DeepSpeed CPU Adagrad update and param copy (C++)"); - m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)"); - m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)"); -} +#include "cpu_adagrad.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" +#include "custom_cuda_layers.h" + +static std::unordered_map> s_optimizers; + +// C++ interface + +void Adagrad_Optimizer::Step_1(float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<1>( + &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); +#endif + if (_param_size > rounded_size) { + float step_size = -1 * _alpha; + __half* grads_cast_h; + __half* params_cast_h; + if (half_precision) { + grads_cast_h = reinterpret_cast<__half*>(grads); + params_cast_h = reinterpret_cast<__half*>(_params); + } + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; + float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = grads[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0) { grad = param * _weight_decay + grad; } + + variance += grad * grad; + + grad = sqrt(variance); + grad += _eps; + grad = momentum / grad; + param = grad * step_size + param; + if (dev_params) _doubled_buffer[_buf_index][k - t] = param; + + if (half_precision) + params_cast_h[k] = (__half)param; + else + _params[k] = param; + // STORE UPDATE TERM TO GRAD'S MEMORY + grads[k] = grad * step_size; + _exp_avg_sq[k] = variance; + } + if (dev_params) { + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); + _buf_index = !_buf_index; + } + } + } +} + +void Adagrad_Optimizer::Step_4(float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<4>( + &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); +#endif + if (_param_size > rounded_size) + Step_1((_params + rounded_size), + (grads + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int create_adagrad_optimizer(int optimizer_id, + float alpha = 1e-2, + float eps = 1e-8, + float weight_decay = 0, + bool should_log = false) +{ + auto opt = std::make_shared(alpha, eps, weight_decay); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Adagrad Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, weight_decay=%f\n", alpha, weight_decay); + } + + return 0; +} + +void Adagrad_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<8>( + &rounded_size, _params, grads, _exp_avg_sq, _param_size, dev_params, half_precision); +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int ds_adagrad_step(int optimizer_id, + size_t step, + float lr, + float epsilon, + float weight_decay, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg_sq) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step); + opt->update_state(lr, epsilon, weight_decay); + opt->Step_8(params_ptr, grads_ptr, exp_avg_sq_ptr, params_c.size(0)); + + opt->SynchronizeStreams(); + return 0; +} + +int ds_adagrad_step_plus_copy(int optimizer_id, + size_t step, + float lr, + float epsilon, + float weight_decay, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg_sq, + torch::Tensor& gpu_params) +{ + auto params_c = params.contiguous(); + auto gpu_params_c = gpu_params.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + auto grads_c = grads.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step); + opt->update_state(lr, epsilon, weight_decay); + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_sq_ptr, + params_c.size(0), + gpu_params_ptr, + (params.options().dtype() == at::kHalf)); + + opt->SynchronizeStreams(); + return 0; +} + +int destroy_adagrad_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("adagrad_update", &ds_adagrad_step, "DeepSpeed CPU Adagrad update (C++)"); + m.def("adagrad_update_copy", + &ds_adagrad_step_plus_copy, + "DeepSpeed CPU Adagrad update and param copy (C++)"); + m.def("create_adagrad", &create_adagrad_optimizer, "DeepSpeed CPU Adagrad (C++)"); + m.def("destroy_adagrad", &destroy_adagrad_optimizer, "DeepSpeed CPU Adagrad destroy (C++)"); +} diff --git a/csrc/adam/cpu_adam.cpp b/csrc/adam/cpu_adam.cpp index b9d993148128..727eec8182c1 100644 --- a/csrc/adam/cpu_adam.cpp +++ b/csrc/adam/cpu_adam.cpp @@ -1,292 +1,292 @@ -#include "cpu_adam.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" -#include "custom_cuda_layers.h" - -static std::unordered_map> s_optimizers; - -// C++ interface - -void Adam_Optimizer::Step_1(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<1>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) { - float betta1_minus1 = 1 - _betta1; - float betta2_minus1 = 1 - _betta2; - - float step_size = -1 * _alpha / _bias_correction1; - float w_decay = -1 * _alpha * _weight_decay; - __half* grads_cast_h; - __half* params_cast_h; - if (half_precision) { - grads_cast_h = reinterpret_cast<__half*>(grads); - params_cast_h = reinterpret_cast<__half*>(_params); - } - - for (size_t t = rounded_size; t < _param_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > _param_size) copy_size = _param_size - t; - size_t offset = copy_size + t; - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } - -#pragma omp parallel for - for (size_t k = t; k < offset; k++) { - float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; - float param = half_precision ? (float)params_cast_h[k] : _params[k]; - float momentum = _exp_avg[k]; - float variance = _exp_avg_sq[k]; - if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } - momentum = momentum * _betta1; - momentum = grad * betta1_minus1 + momentum; - - variance = variance * _betta2; - grad = grad * grad; - variance = grad * betta2_minus1 + variance; - - grad = sqrt(variance); - grad = grad * _bias_correction2 + _eps; - grad = momentum / grad; - if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } - param = grad * step_size + param; - if (dev_params) _doubled_buffer[_buf_index][k - t] = param; - - if (half_precision) - params_cast_h[k] = (__half)param; - else - _params[k] = param; - _exp_avg[k] = momentum; - _exp_avg_sq[k] = variance; - } - if (dev_params) { - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); - - _buf_index = !_buf_index; - } - } - } -} - -void Adam_Optimizer::Step_4(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<4>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) - Step_1((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int create_adam_optimizer(int optimizer_id, - float alpha = 1e-3, - float betta1 = 0.9, - float betta2 = 0.999, - float eps = 1e-8, - float weight_decay = 0, - bool adamw_mode = true, - bool should_log = false) -{ - auto opt = - std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); - - s_optimizers[optimizer_id] = opt; - - if (should_log) { - std::string avx_type = ""; -#if defined(__AVX512__) - avx_type = "AVX512"; -#else -#if defined(__AVX256__) - avx_type = "AVX2"; -#else - avx_type = "scalar"; -#endif -#endif - - printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", - optimizer_id, - avx_type.c_str()); - printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", - alpha, - betta1, - betta2, - weight_decay, - (int)adamw_mode); - } - - return 0; -} - -void Adam_Optimizer::Step_8(float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t rounded_size = 0; -#if defined(__AVX512__) or defined(__AVX256__) - Step_AVX<8>(&rounded_size, - _params, - grads, - _exp_avg, - _exp_avg_sq, - _param_size, - dev_params, - half_precision); -#endif - if (_param_size > rounded_size) - Step_4((_params + rounded_size), - (grads + rounded_size), - (_exp_avg + rounded_size), - (_exp_avg_sq + rounded_size), - (_param_size - rounded_size), - (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), - half_precision); -} - -int ds_adam_step(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq) -{ - auto params_c = params.contiguous(); - auto grads_c = grads.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - - // assert(params.options().dtype() == grads.options().dtype()); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.size(0), - nullptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); - return 0; -} - -int ds_adam_step_plus_copy(int optimizer_id, - size_t step, - float lr, - float beta1, - float beta2, - float epsilon, - float weight_decay, - bool bias_correction, - torch::Tensor& params, - torch::Tensor& grads, - torch::Tensor& exp_avg, - torch::Tensor& exp_avg_sq, - torch::Tensor& gpu_params) -{ - auto params_c = params.contiguous(); - auto gpu_params_c = gpu_params.contiguous(); - auto exp_avg_c = exp_avg.contiguous(); - auto exp_avg_sq_c = exp_avg_sq.contiguous(); - auto grads_c = grads.contiguous(); - - float* params_ptr = (float*)params_c.data_ptr(); - float* grads_ptr = (float*)grads_c.data_ptr(); - __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr(); - float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); - float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); - - std::shared_ptr opt = - std::static_pointer_cast(s_optimizers[optimizer_id]); - opt->IncrementStep(step, beta1, beta2); - opt->update_state(lr, epsilon, weight_decay, bias_correction); - opt->Step_8(params_ptr, - grads_ptr, - exp_avg_ptr, - exp_avg_sq_ptr, - params_c.size(0), - gpu_params_ptr, - (params.options().dtype() == at::kHalf)); - - opt->SynchronizeStreams(); - return 0; -} - -int destroy_adam_optimizer(int optimizer_id) -{ - s_optimizers.erase(optimizer_id); - - return 0; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); - m.def("adam_update_copy", - &ds_adam_step_plus_copy, - "DeepSpeed CPU Adam update and param copy (C++)"); - m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); - m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); -} +#include "cpu_adam.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" +#include "custom_cuda_layers.h" + +static std::unordered_map> s_optimizers; + +// C++ interface + +void Adam_Optimizer::Step_1(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<1>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) { + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + + float step_size = -1 * _alpha / _bias_correction1; + float w_decay = -1 * _alpha * _weight_decay; + __half* grads_cast_h; + __half* params_cast_h; + if (half_precision) { + grads_cast_h = reinterpret_cast<__half*>(grads); + params_cast_h = reinterpret_cast<__half*>(_params); + } + + for (size_t t = rounded_size; t < _param_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > _param_size) copy_size = _param_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } + +#pragma omp parallel for + for (size_t k = t; k < offset; k++) { + float grad = half_precision ? (float)grads_cast_h[k] : grads[k]; + float param = half_precision ? (float)params_cast_h[k] : _params[k]; + float momentum = _exp_avg[k]; + float variance = _exp_avg_sq[k]; + if (_weight_decay > 0 && !_adamw_mode) { grad = param * _weight_decay + grad; } + momentum = momentum * _betta1; + momentum = grad * betta1_minus1 + momentum; + + variance = variance * _betta2; + grad = grad * grad; + variance = grad * betta2_minus1 + variance; + + grad = sqrt(variance); + grad = grad * _bias_correction2 + _eps; + grad = momentum / grad; + if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; } + param = grad * step_size + param; + if (dev_params) _doubled_buffer[_buf_index][k - t] = param; + + if (half_precision) + params_cast_h[k] = (__half)param; + else + _params[k] = param; + _exp_avg[k] = momentum; + _exp_avg_sq[k] = variance; + } + if (dev_params) { + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, (copy_size), _streams[_buf_index]); + + _buf_index = !_buf_index; + } + } + } +} + +void Adam_Optimizer::Step_4(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<4>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) + Step_1((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int create_adam_optimizer(int optimizer_id, + float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true, + bool should_log = false) +{ + auto opt = + std::make_shared(alpha, betta1, betta2, eps, weight_decay, adamw_mode); + + s_optimizers[optimizer_id] = opt; + + if (should_log) { + std::string avx_type = ""; +#if defined(__AVX512__) + avx_type = "AVX512"; +#else +#if defined(__AVX256__) + avx_type = "AVX2"; +#else + avx_type = "scalar"; +#endif +#endif + + printf("Adam Optimizer #%d is created with %s arithmetic capability.\n", + optimizer_id, + avx_type.c_str()); + printf("Config: alpha=%f, betas=(%f, %f), weight_decay=%f, adam_w=%d\n", + alpha, + betta1, + betta2, + weight_decay, + (int)adamw_mode); + } + + return 0; +} + +void Adam_Optimizer::Step_8(float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t rounded_size = 0; +#if defined(__AVX512__) or defined(__AVX256__) + Step_AVX<8>(&rounded_size, + _params, + grads, + _exp_avg, + _exp_avg_sq, + _param_size, + dev_params, + half_precision); +#endif + if (_param_size > rounded_size) + Step_4((_params + rounded_size), + (grads + rounded_size), + (_exp_avg + rounded_size), + (_exp_avg_sq + rounded_size), + (_param_size - rounded_size), + (dev_params != nullptr ? (dev_params + rounded_size) : dev_params), + half_precision); +} + +int ds_adam_step(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq) +{ + auto params_c = params.contiguous(); + auto grads_c = grads.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + + // assert(params.options().dtype() == grads.options().dtype()); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.size(0), + nullptr, + (params.options().dtype() == at::kHalf)); + + opt->SynchronizeStreams(); + return 0; +} + +int ds_adam_step_plus_copy(int optimizer_id, + size_t step, + float lr, + float beta1, + float beta2, + float epsilon, + float weight_decay, + bool bias_correction, + torch::Tensor& params, + torch::Tensor& grads, + torch::Tensor& exp_avg, + torch::Tensor& exp_avg_sq, + torch::Tensor& gpu_params) +{ + auto params_c = params.contiguous(); + auto gpu_params_c = gpu_params.contiguous(); + auto exp_avg_c = exp_avg.contiguous(); + auto exp_avg_sq_c = exp_avg_sq.contiguous(); + auto grads_c = grads.contiguous(); + + float* params_ptr = (float*)params_c.data_ptr(); + float* grads_ptr = (float*)grads_c.data_ptr(); + __half* gpu_params_ptr = (__half*)gpu_params_c.data_ptr(); + float* exp_avg_ptr = (float*)exp_avg_c.data_ptr(); + float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr(); + + std::shared_ptr opt = + std::static_pointer_cast(s_optimizers[optimizer_id]); + opt->IncrementStep(step, beta1, beta2); + opt->update_state(lr, epsilon, weight_decay, bias_correction); + opt->Step_8(params_ptr, + grads_ptr, + exp_avg_ptr, + exp_avg_sq_ptr, + params_c.size(0), + gpu_params_ptr, + (params.options().dtype() == at::kHalf)); + + opt->SynchronizeStreams(); + return 0; +} + +int destroy_adam_optimizer(int optimizer_id) +{ + s_optimizers.erase(optimizer_id); + + return 0; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("adam_update", &ds_adam_step, "DeepSpeed CPU Adam update (C++)"); + m.def("adam_update_copy", + &ds_adam_step_plus_copy, + "DeepSpeed CPU Adam update and param copy (C++)"); + m.def("create_adam", &create_adam_optimizer, "DeepSpeed CPU Adam (C++)"); + m.def("destroy_adam", &destroy_adam_optimizer, "DeepSpeed CPU Adam destroy (C++)"); +} diff --git a/csrc/aio/common/deepspeed_aio_common.cpp b/csrc/aio/common/deepspeed_aio_common.cpp index 11927969c50f..9e405d8e704c 100644 --- a/csrc/aio/common/deepspeed_aio_common.cpp +++ b/csrc/aio/common/deepspeed_aio_common.cpp @@ -1,333 +1,333 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "deepspeed_aio_common.h" - -using namespace std; -using namespace std::chrono; - -#define DEBUG_DS_AIO_PERF 0 -#define DEBUG_DS_AIO_SUBMIT_PERF 0 - -static const std::string c_library_name = "deepspeed_aio"; - -static void _report_aio_statistics(const char* tag, - const std::vector>& latencies) - __attribute__((unused)); - -static void _report_aio_statistics(const char* tag, - const std::vector>& latencies) -{ - std::vector lat_usec; - for (auto& lat : latencies) { lat_usec.push_back(lat.count() * 1e6); } - const auto min_lat = *(std::min_element(lat_usec.begin(), lat_usec.end())); - const auto max_lat = *(std::max_element(lat_usec.begin(), lat_usec.end())); - const auto avg_lat = std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); - - std::cout << c_library_name << ": latency statistics(usec) " << tag - << " min/max/avg = " << min_lat << " " << max_lat << " " << avg_lat << std::endl; -} - -static void _get_aio_latencies(std::vector>& raw_latencies, - struct deepspeed_aio_latency_t& summary_latencies) -{ - std::vector lat_usec; - for (auto& lat : raw_latencies) { lat_usec.push_back(lat.count() * 1e6); } - summary_latencies._min_usec = *(std::min_element(lat_usec.begin(), lat_usec.end())); - summary_latencies._max_usec = *(std::max_element(lat_usec.begin(), lat_usec.end())); - summary_latencies._avg_usec = - std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); -} - -static void _do_io_submit_singles(const long long int n_iocbs, - const long long int iocb_index, - std::unique_ptr& aio_ctxt, - std::vector>& submit_times) -{ - for (auto i = 0; i < n_iocbs; ++i) { - const auto st = std::chrono::high_resolution_clock::now(); - const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, 1, aio_ctxt->_iocbs.data() + i); - submit_times.push_back(std::chrono::high_resolution_clock::now() - st); -#if DEBUG_DS_AIO_SUBMIT_PERF - printf("submit(usec) %f io_index=%lld buf=%p len=%lu off=%llu \n", - submit_times.back().count() * 1e6, - iocb_index, - aio_ctxt->_iocbs[i]->u.c.buf, - aio_ctxt->_iocbs[i]->u.c.nbytes, - aio_ctxt->_iocbs[i]->u.c.offset); -#endif - assert(submit_ret > 0); - } -} - -static void _do_io_submit_block(const long long int n_iocbs, - const long long int iocb_index, - std::unique_ptr& aio_ctxt, - std::vector>& submit_times) -{ - const auto st = std::chrono::high_resolution_clock::now(); - const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, n_iocbs, aio_ctxt->_iocbs.data()); - submit_times.push_back(std::chrono::high_resolution_clock::now() - st); -#if DEBUG_DS_AIO_SUBMIT_PERF - printf("submit(usec) %f io_index=%lld nr=%lld buf=%p len=%lu off=%llu \n", - submit_times.back().count() * 1e6, - iocb_index, - n_iocbs, - aio_ctxt->_iocbs[0]->u.c.buf, - aio_ctxt->_iocbs[0]->u.c.nbytes, - aio_ctxt->_iocbs[0]->u.c.offset); -#endif - assert(submit_ret > 0); -} - -static int _do_io_complete(const long long int min_completes, - const long long int max_completes, - std::unique_ptr& aio_ctxt, - std::vector>& reap_times) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - const auto n_completes = io_getevents( - aio_ctxt->_io_ctxt, min_completes, max_completes, aio_ctxt->_io_events.data(), nullptr); - reap_times.push_back(std::chrono::high_resolution_clock::now() - start_time); - - assert(n_completes >= min_completes); - return n_completes; -} - -void do_aio_operation_sequential(const bool read_op, - std::unique_ptr& aio_ctxt, - std::unique_ptr& xfer_ctxt, - deepspeed_aio_config_t* config, - deepspeed_aio_perf_t* perf) -{ - struct io_prep_context prep_ctxt(read_op, xfer_ctxt, aio_ctxt->_block_size, &aio_ctxt->_iocbs); - - const auto num_io_blocks = static_cast( - ceil(static_cast(xfer_ctxt->_num_bytes) / aio_ctxt->_block_size)); -#if DEBUG_DS_AIO_PERF - const auto io_op_name = std::string(read_op ? "read" : "write"); - std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes - << " bytes with " << num_io_blocks << " io blocks" << std::endl; -#endif - - std::vector> submit_times; - std::vector> reap_times; - const auto max_queue_bytes = - static_cast(aio_ctxt->_queue_depth * aio_ctxt->_block_size); - - auto start = std::chrono::high_resolution_clock::now(); - for (long long iocb_index = 0; iocb_index < num_io_blocks; - iocb_index += aio_ctxt->_queue_depth) { - const auto start_offset = iocb_index * aio_ctxt->_block_size; - const auto start_buffer = (char*)xfer_ctxt->_mem_buffer + start_offset; - const auto n_iocbs = - min(static_cast(aio_ctxt->_queue_depth), (num_io_blocks - iocb_index)); - const auto num_bytes = min(max_queue_bytes, (xfer_ctxt->_num_bytes - start_offset)); - prep_ctxt.prep_iocbs(n_iocbs, num_bytes, start_buffer, start_offset); - - if (config->_single_submit) { - _do_io_submit_singles(n_iocbs, iocb_index, aio_ctxt, submit_times); - } else { - _do_io_submit_block(n_iocbs, iocb_index, aio_ctxt, submit_times); - } - - _do_io_complete(n_iocbs, n_iocbs, aio_ctxt, reap_times); - } - const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; - - if (perf) { - _get_aio_latencies(submit_times, perf->_submit); - _get_aio_latencies(reap_times, perf->_complete); - perf->_e2e_usec = elapsed.count() * 1e6; - perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); - } - -#if DEBUG_DS_AIO_PERF - _report_aio_statistics("submit", submit_times); - _report_aio_statistics("complete", reap_times); -#endif - -#if DEBUG_DS_AIO_PERF - std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 - << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; -#endif - -#if DEBUG_DS_AIO_PERF - std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes - << " bytes " << std::endl; -#endif -} - -void do_aio_operation_overlap(const bool read_op, - std::unique_ptr& aio_ctxt, - std::unique_ptr& xfer_ctxt, - deepspeed_aio_config_t* config, - deepspeed_aio_perf_t* perf) -{ - struct io_prep_generator io_gen(read_op, xfer_ctxt, aio_ctxt->_block_size); - -#if DEBUG_DS_AIO_PERF - const auto io_op_name = std::string(read_op ? "read" : "write"); - std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes - << " bytes with " << io_gen._num_io_blocks << " io blocks" << std::endl; -#endif - - std::vector> submit_times; - std::vector> reap_times; - - auto request_iocbs = aio_ctxt->_queue_depth; - auto n_pending_iocbs = 0; - const auto min_completes = 1; - auto start = std::chrono::high_resolution_clock::now(); - while (true) { - const auto n_iocbs = io_gen.prep_iocbs(request_iocbs - n_pending_iocbs, &aio_ctxt->_iocbs); - if (n_iocbs > 0) { - if (config->_single_submit) { - _do_io_submit_singles( - n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); - } else { - _do_io_submit_block( - n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); - } - } - - n_pending_iocbs += n_iocbs; - assert(n_pending_iocbs <= aio_ctxt->_queue_depth); - - if (n_pending_iocbs == 0) { break; } - - const auto n_complete = - _do_io_complete(min_completes, n_pending_iocbs, aio_ctxt, reap_times); - n_pending_iocbs -= n_complete; - } - - const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; - - if (perf) { - _get_aio_latencies(submit_times, perf->_submit); - _get_aio_latencies(reap_times, perf->_complete); - perf->_e2e_usec = elapsed.count() * 1e6; - perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); - } - -#if DEBUG_DS_AIO_PERF - _report_aio_statistics("submit", submit_times); - _report_aio_statistics("complete", reap_times); -#endif - -#if DEBUG_DS_AIO_PERF - std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 - << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; -#endif - -#if DEBUG_DS_AIO_PERF - std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes - << " bytes " << std::endl; -#endif -} - -void report_file_error(const char* filename, const std::string file_op, const int error_code) -{ - std::string err_msg = file_op + std::string(" failed on ") + std::string(filename) + - " error = " + std::to_string(error_code); - std::cerr << c_library_name << ": " << err_msg << std::endl; -} - -int open_file(const char* filename, const bool read_op) -{ - const int flags = read_op ? (O_RDONLY | __O_DIRECT) : (O_WRONLY | O_CREAT | __O_DIRECT); - const int mode = 0600; - const auto fd = open(filename, flags, mode); - if (fd == -1) { - const auto error_code = errno; - const auto error_msg = read_op ? " open for read " : " open for write "; - report_file_error(filename, error_msg, error_code); - return -1; - } - return fd; -} - -int regular_read(const char* filename, std::vector& buffer) -{ - long long int num_bytes; - const auto f_size = get_file_size(filename, num_bytes); - assert(f_size != -1); - buffer.resize(num_bytes); - const auto fd = open(filename, O_RDONLY, 0600); - assert(fd != -1); - long long int read_bytes = 0; - auto r = 0; - do { - const auto buffer_ptr = buffer.data() + read_bytes; - const auto bytes_to_read = num_bytes - read_bytes; - r = read(fd, buffer_ptr, bytes_to_read); - read_bytes += r; - } while (r > 0); - - if (read_bytes != num_bytes) { - std::cerr << "read error " - << " read_bytes (read) = " << read_bytes << " num_bytes (fstat) = " << num_bytes - << std::endl; - } - assert(read_bytes == num_bytes); - close(fd); - return 0; -} - -static bool _validate_buffer(const char* filename, void* aio_buffer, const long long int num_bytes) -{ - std::vector regular_buffer; - const auto reg_ret = regular_read(filename, regular_buffer); - assert(0 == reg_ret); - std::cout << "regular read of " << filename << " returned " << regular_buffer.size() << " bytes" - << std::endl; - - if (static_cast(regular_buffer.size()) != num_bytes) { return false; } - - return (0 == memcmp(aio_buffer, regular_buffer.data(), regular_buffer.size())); -} - -bool validate_aio_operation(const bool read_op, - const char* filename, - void* aio_buffer, - const long long int num_bytes) -{ - const auto msg_suffix = std::string("deepspeed_aio_") + - std::string(read_op ? "read()" : "write()") + - std::string("using read()"); - - if (false == _validate_buffer(filename, aio_buffer, num_bytes)) { - std::cout << "Fail: correctness of " << msg_suffix << std::endl; - return false; - } - - std::cout << "Pass: correctness of " << msg_suffix << std::endl; - return true; -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "deepspeed_aio_common.h" + +using namespace std; +using namespace std::chrono; + +#define DEBUG_DS_AIO_PERF 0 +#define DEBUG_DS_AIO_SUBMIT_PERF 0 + +static const std::string c_library_name = "deepspeed_aio"; + +static void _report_aio_statistics(const char* tag, + const std::vector>& latencies) + __attribute__((unused)); + +static void _report_aio_statistics(const char* tag, + const std::vector>& latencies) +{ + std::vector lat_usec; + for (auto& lat : latencies) { lat_usec.push_back(lat.count() * 1e6); } + const auto min_lat = *(std::min_element(lat_usec.begin(), lat_usec.end())); + const auto max_lat = *(std::max_element(lat_usec.begin(), lat_usec.end())); + const auto avg_lat = std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); + + std::cout << c_library_name << ": latency statistics(usec) " << tag + << " min/max/avg = " << min_lat << " " << max_lat << " " << avg_lat << std::endl; +} + +static void _get_aio_latencies(std::vector>& raw_latencies, + struct deepspeed_aio_latency_t& summary_latencies) +{ + std::vector lat_usec; + for (auto& lat : raw_latencies) { lat_usec.push_back(lat.count() * 1e6); } + summary_latencies._min_usec = *(std::min_element(lat_usec.begin(), lat_usec.end())); + summary_latencies._max_usec = *(std::max_element(lat_usec.begin(), lat_usec.end())); + summary_latencies._avg_usec = + std::accumulate(lat_usec.begin(), lat_usec.end(), 0) / lat_usec.size(); +} + +static void _do_io_submit_singles(const long long int n_iocbs, + const long long int iocb_index, + std::unique_ptr& aio_ctxt, + std::vector>& submit_times) +{ + for (auto i = 0; i < n_iocbs; ++i) { + const auto st = std::chrono::high_resolution_clock::now(); + const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, 1, aio_ctxt->_iocbs.data() + i); + submit_times.push_back(std::chrono::high_resolution_clock::now() - st); +#if DEBUG_DS_AIO_SUBMIT_PERF + printf("submit(usec) %f io_index=%lld buf=%p len=%lu off=%llu \n", + submit_times.back().count() * 1e6, + iocb_index, + aio_ctxt->_iocbs[i]->u.c.buf, + aio_ctxt->_iocbs[i]->u.c.nbytes, + aio_ctxt->_iocbs[i]->u.c.offset); +#endif + assert(submit_ret > 0); + } +} + +static void _do_io_submit_block(const long long int n_iocbs, + const long long int iocb_index, + std::unique_ptr& aio_ctxt, + std::vector>& submit_times) +{ + const auto st = std::chrono::high_resolution_clock::now(); + const auto submit_ret = io_submit(aio_ctxt->_io_ctxt, n_iocbs, aio_ctxt->_iocbs.data()); + submit_times.push_back(std::chrono::high_resolution_clock::now() - st); +#if DEBUG_DS_AIO_SUBMIT_PERF + printf("submit(usec) %f io_index=%lld nr=%lld buf=%p len=%lu off=%llu \n", + submit_times.back().count() * 1e6, + iocb_index, + n_iocbs, + aio_ctxt->_iocbs[0]->u.c.buf, + aio_ctxt->_iocbs[0]->u.c.nbytes, + aio_ctxt->_iocbs[0]->u.c.offset); +#endif + assert(submit_ret > 0); +} + +static int _do_io_complete(const long long int min_completes, + const long long int max_completes, + std::unique_ptr& aio_ctxt, + std::vector>& reap_times) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + const auto n_completes = io_getevents( + aio_ctxt->_io_ctxt, min_completes, max_completes, aio_ctxt->_io_events.data(), nullptr); + reap_times.push_back(std::chrono::high_resolution_clock::now() - start_time); + + assert(n_completes >= min_completes); + return n_completes; +} + +void do_aio_operation_sequential(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf) +{ + struct io_prep_context prep_ctxt(read_op, xfer_ctxt, aio_ctxt->_block_size, &aio_ctxt->_iocbs); + + const auto num_io_blocks = static_cast( + ceil(static_cast(xfer_ctxt->_num_bytes) / aio_ctxt->_block_size)); +#if DEBUG_DS_AIO_PERF + const auto io_op_name = std::string(read_op ? "read" : "write"); + std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes with " << num_io_blocks << " io blocks" << std::endl; +#endif + + std::vector> submit_times; + std::vector> reap_times; + const auto max_queue_bytes = + static_cast(aio_ctxt->_queue_depth * aio_ctxt->_block_size); + + auto start = std::chrono::high_resolution_clock::now(); + for (long long iocb_index = 0; iocb_index < num_io_blocks; + iocb_index += aio_ctxt->_queue_depth) { + const auto start_offset = iocb_index * aio_ctxt->_block_size; + const auto start_buffer = (char*)xfer_ctxt->_mem_buffer + start_offset; + const auto n_iocbs = + min(static_cast(aio_ctxt->_queue_depth), (num_io_blocks - iocb_index)); + const auto num_bytes = min(max_queue_bytes, (xfer_ctxt->_num_bytes - start_offset)); + prep_ctxt.prep_iocbs(n_iocbs, num_bytes, start_buffer, start_offset); + + if (config->_single_submit) { + _do_io_submit_singles(n_iocbs, iocb_index, aio_ctxt, submit_times); + } else { + _do_io_submit_block(n_iocbs, iocb_index, aio_ctxt, submit_times); + } + + _do_io_complete(n_iocbs, n_iocbs, aio_ctxt, reap_times); + } + const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; + + if (perf) { + _get_aio_latencies(submit_times, perf->_submit); + _get_aio_latencies(reap_times, perf->_complete); + perf->_e2e_usec = elapsed.count() * 1e6; + perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); + } + +#if DEBUG_DS_AIO_PERF + _report_aio_statistics("submit", submit_times); + _report_aio_statistics("complete", reap_times); +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 + << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes " << std::endl; +#endif +} + +void do_aio_operation_overlap(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf) +{ + struct io_prep_generator io_gen(read_op, xfer_ctxt, aio_ctxt->_block_size); + +#if DEBUG_DS_AIO_PERF + const auto io_op_name = std::string(read_op ? "read" : "write"); + std::cout << c_library_name << ": start " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes with " << io_gen._num_io_blocks << " io blocks" << std::endl; +#endif + + std::vector> submit_times; + std::vector> reap_times; + + auto request_iocbs = aio_ctxt->_queue_depth; + auto n_pending_iocbs = 0; + const auto min_completes = 1; + auto start = std::chrono::high_resolution_clock::now(); + while (true) { + const auto n_iocbs = io_gen.prep_iocbs(request_iocbs - n_pending_iocbs, &aio_ctxt->_iocbs); + if (n_iocbs > 0) { + if (config->_single_submit) { + _do_io_submit_singles( + n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); + } else { + _do_io_submit_block( + n_iocbs, (io_gen._next_iocb_index - n_iocbs), aio_ctxt, submit_times); + } + } + + n_pending_iocbs += n_iocbs; + assert(n_pending_iocbs <= aio_ctxt->_queue_depth); + + if (n_pending_iocbs == 0) { break; } + + const auto n_complete = + _do_io_complete(min_completes, n_pending_iocbs, aio_ctxt, reap_times); + n_pending_iocbs -= n_complete; + } + + const std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; + + if (perf) { + _get_aio_latencies(submit_times, perf->_submit); + _get_aio_latencies(reap_times, perf->_complete); + perf->_e2e_usec = elapsed.count() * 1e6; + perf->_e2e_rate_GB = (xfer_ctxt->_num_bytes / elapsed.count() / 1e9); + } + +#if DEBUG_DS_AIO_PERF + _report_aio_statistics("submit", submit_times); + _report_aio_statistics("complete", reap_times); +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": runtime(usec) " << elapsed.count() * 1e6 + << " rate(GB/sec) = " << (xfer_ctxt->_num_bytes / elapsed.count() / 1e9) << std::endl; +#endif + +#if DEBUG_DS_AIO_PERF + std::cout << c_library_name << ": finish " << io_op_name << " " << xfer_ctxt->_num_bytes + << " bytes " << std::endl; +#endif +} + +void report_file_error(const char* filename, const std::string file_op, const int error_code) +{ + std::string err_msg = file_op + std::string(" failed on ") + std::string(filename) + + " error = " + std::to_string(error_code); + std::cerr << c_library_name << ": " << err_msg << std::endl; +} + +int open_file(const char* filename, const bool read_op) +{ + const int flags = read_op ? (O_RDONLY | __O_DIRECT) : (O_WRONLY | O_CREAT | __O_DIRECT); + const int mode = 0600; + const auto fd = open(filename, flags, mode); + if (fd == -1) { + const auto error_code = errno; + const auto error_msg = read_op ? " open for read " : " open for write "; + report_file_error(filename, error_msg, error_code); + return -1; + } + return fd; +} + +int regular_read(const char* filename, std::vector& buffer) +{ + long long int num_bytes; + const auto f_size = get_file_size(filename, num_bytes); + assert(f_size != -1); + buffer.resize(num_bytes); + const auto fd = open(filename, O_RDONLY, 0600); + assert(fd != -1); + long long int read_bytes = 0; + auto r = 0; + do { + const auto buffer_ptr = buffer.data() + read_bytes; + const auto bytes_to_read = num_bytes - read_bytes; + r = read(fd, buffer_ptr, bytes_to_read); + read_bytes += r; + } while (r > 0); + + if (read_bytes != num_bytes) { + std::cerr << "read error " + << " read_bytes (read) = " << read_bytes << " num_bytes (fstat) = " << num_bytes + << std::endl; + } + assert(read_bytes == num_bytes); + close(fd); + return 0; +} + +static bool _validate_buffer(const char* filename, void* aio_buffer, const long long int num_bytes) +{ + std::vector regular_buffer; + const auto reg_ret = regular_read(filename, regular_buffer); + assert(0 == reg_ret); + std::cout << "regular read of " << filename << " returned " << regular_buffer.size() << " bytes" + << std::endl; + + if (static_cast(regular_buffer.size()) != num_bytes) { return false; } + + return (0 == memcmp(aio_buffer, regular_buffer.data(), regular_buffer.size())); +} + +bool validate_aio_operation(const bool read_op, + const char* filename, + void* aio_buffer, + const long long int num_bytes) +{ + const auto msg_suffix = std::string("deepspeed_aio_") + + std::string(read_op ? "read()" : "write()") + + std::string("using read()"); + + if (false == _validate_buffer(filename, aio_buffer, num_bytes)) { + std::cout << "Fail: correctness of " << msg_suffix << std::endl; + return false; + } + + std::cout << "Pass: correctness of " << msg_suffix << std::endl; + return true; +} diff --git a/csrc/aio/common/deepspeed_aio_common.h b/csrc/aio/common/deepspeed_aio_common.h index 1f32fc8f794f..cc62d33765c8 100644 --- a/csrc/aio/common/deepspeed_aio_common.h +++ b/csrc/aio/common/deepspeed_aio_common.h @@ -1,36 +1,36 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include -#include - -using namespace std; - -void do_aio_operation_sequential(const bool read_op, - std::unique_ptr& aio_ctxt, - std::unique_ptr& xfer_ctxt, - deepspeed_aio_config_t* config, - deepspeed_aio_perf_t* perf); - -void do_aio_operation_overlap(const bool read_op, - std::unique_ptr& aio_ctxt, - std::unique_ptr& xfer_ctxt, - deepspeed_aio_config_t* config, - deepspeed_aio_perf_t* perf); - -int open_file(const char* filename, const bool read_op); - -void report_file_error(const char* filename, const std::string file_op, const int error_code); - -int regular_read(const char* filename, std::vector& buffer); - -bool validate_aio_operation(const bool read_op, - const char* filename, - void* aio_buffer, - const long long int num_bytes); +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include +#include + +using namespace std; + +void do_aio_operation_sequential(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf); + +void do_aio_operation_overlap(const bool read_op, + std::unique_ptr& aio_ctxt, + std::unique_ptr& xfer_ctxt, + deepspeed_aio_config_t* config, + deepspeed_aio_perf_t* perf); + +int open_file(const char* filename, const bool read_op); + +void report_file_error(const char* filename, const std::string file_op, const int error_code); + +int regular_read(const char* filename, std::vector& buffer); + +bool validate_aio_operation(const bool read_op, + const char* filename, + void* aio_buffer, + const long long int num_bytes); diff --git a/csrc/aio/common/deepspeed_aio_types.cpp b/csrc/aio/common/deepspeed_aio_types.cpp index 5f717c3b5658..e5811bb91149 100644 --- a/csrc/aio/common/deepspeed_aio_types.cpp +++ b/csrc/aio/common/deepspeed_aio_types.cpp @@ -1,74 +1,74 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include - -#include "deepspeed_aio_utils.h" - -using namespace std; - -const int c_block_size = 128 * 1024; -const int c_io_queue_depth = 8; - -deepspeed_aio_config_t::deepspeed_aio_config_t() - : _block_size(c_block_size), - _queue_depth(c_io_queue_depth), - _single_submit(false), - _overlap_events(false), - _lock_memory(false) -{ -} - -deepspeed_aio_config_t::deepspeed_aio_config_t(const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool lock_memory) - : _block_size(block_size), - _queue_depth(queue_depth), - _single_submit(single_submit), - _overlap_events(overlap_events), - _lock_memory(lock_memory) -{ -} - -void deepspeed_aio_latency_t::dump(const std::string tag) -{ - std::cout << tag << _min_usec << " " << _max_usec << " " << _avg_usec << " " << std::endl; -} - -void deepspeed_aio_latency_t::accumulate(const struct deepspeed_aio_latency_t& other) -{ - _min_usec += other._min_usec; - _max_usec += other._max_usec; - _avg_usec += other._avg_usec; -} - -void deepspeed_aio_latency_t::scale(const float scaler) -{ - _min_usec *= scaler; - _max_usec *= scaler; - _avg_usec *= scaler; -} - -aio_context::aio_context(const int block_size, const int queue_depth) -{ - _block_size = block_size; - _queue_depth = queue_depth; - for (auto i = 0; i < queue_depth; ++i) { - _iocbs.push_back((struct iocb*)calloc(1, sizeof(struct iocb))); - } - _io_events.resize(queue_depth); - io_queue_init(queue_depth, &_io_ctxt); -} - -aio_context::~aio_context() -{ - for (auto& iocb : _iocbs) { free(iocb); } - _io_events.resize(0); - io_queue_release(_io_ctxt); -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include + +#include "deepspeed_aio_utils.h" + +using namespace std; + +const int c_block_size = 128 * 1024; +const int c_io_queue_depth = 8; + +deepspeed_aio_config_t::deepspeed_aio_config_t() + : _block_size(c_block_size), + _queue_depth(c_io_queue_depth), + _single_submit(false), + _overlap_events(false), + _lock_memory(false) +{ +} + +deepspeed_aio_config_t::deepspeed_aio_config_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool lock_memory) + : _block_size(block_size), + _queue_depth(queue_depth), + _single_submit(single_submit), + _overlap_events(overlap_events), + _lock_memory(lock_memory) +{ +} + +void deepspeed_aio_latency_t::dump(const std::string tag) +{ + std::cout << tag << _min_usec << " " << _max_usec << " " << _avg_usec << " " << std::endl; +} + +void deepspeed_aio_latency_t::accumulate(const struct deepspeed_aio_latency_t& other) +{ + _min_usec += other._min_usec; + _max_usec += other._max_usec; + _avg_usec += other._avg_usec; +} + +void deepspeed_aio_latency_t::scale(const float scaler) +{ + _min_usec *= scaler; + _max_usec *= scaler; + _avg_usec *= scaler; +} + +aio_context::aio_context(const int block_size, const int queue_depth) +{ + _block_size = block_size; + _queue_depth = queue_depth; + for (auto i = 0; i < queue_depth; ++i) { + _iocbs.push_back((struct iocb*)calloc(1, sizeof(struct iocb))); + } + _io_events.resize(queue_depth); + io_queue_init(queue_depth, &_io_ctxt); +} + +aio_context::~aio_context() +{ + for (auto& iocb : _iocbs) { free(iocb); } + _io_events.resize(0); + io_queue_release(_io_ctxt); +} diff --git a/csrc/aio/common/deepspeed_aio_types.h b/csrc/aio/common/deepspeed_aio_types.h index 5c5dcdf0b559..be3b352d6be2 100644 --- a/csrc/aio/common/deepspeed_aio_types.h +++ b/csrc/aio/common/deepspeed_aio_types.h @@ -1,57 +1,57 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include - -#include -#include - -using namespace std; - -struct deepspeed_aio_latency_t { - double _min_usec; - double _max_usec; - double _avg_usec; - - void dump(const std::string tag); - void accumulate(const deepspeed_aio_latency_t&); - void scale(const float value); -}; - -struct deepspeed_aio_perf_t { - deepspeed_aio_latency_t _submit; - deepspeed_aio_latency_t _complete; - double _e2e_usec; - double _e2e_rate_GB; -}; - -struct deepspeed_aio_config_t { - const int _block_size; - const int _queue_depth; - const bool _single_submit; - const bool _overlap_events; - const bool _lock_memory; - - deepspeed_aio_config_t(); - deepspeed_aio_config_t(const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool lock_memory); -}; - -struct aio_context { - io_context_t _io_ctxt; - std::vector _io_events; - std::vector _iocbs; - int _block_size; - int _queue_depth; - - aio_context(const int block_size, const int queue_depth); - ~aio_context(); -}; +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include + +#include +#include + +using namespace std; + +struct deepspeed_aio_latency_t { + double _min_usec; + double _max_usec; + double _avg_usec; + + void dump(const std::string tag); + void accumulate(const deepspeed_aio_latency_t&); + void scale(const float value); +}; + +struct deepspeed_aio_perf_t { + deepspeed_aio_latency_t _submit; + deepspeed_aio_latency_t _complete; + double _e2e_usec; + double _e2e_rate_GB; +}; + +struct deepspeed_aio_config_t { + const int _block_size; + const int _queue_depth; + const bool _single_submit; + const bool _overlap_events; + const bool _lock_memory; + + deepspeed_aio_config_t(); + deepspeed_aio_config_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool lock_memory); +}; + +struct aio_context { + io_context_t _io_ctxt; + std::vector _io_events; + std::vector _iocbs; + int _block_size; + int _queue_depth; + + aio_context(const int block_size, const int queue_depth); + ~aio_context(); +}; diff --git a/csrc/aio/common/deepspeed_aio_utils.cpp b/csrc/aio/common/deepspeed_aio_utils.cpp index a3d89be5ad3e..200c7030f120 100644 --- a/csrc/aio/common/deepspeed_aio_utils.cpp +++ b/csrc/aio/common/deepspeed_aio_utils.cpp @@ -1,123 +1,123 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include - -#include "deepspeed_aio_utils.h" - -using namespace std; - -const int c_block_size = 128 * 1024; -const int c_io_queue_depth = 8; - -io_xfer_ctxt::io_xfer_ctxt(const int fd, - const long long int file_offset, - const long long int num_bytes, - const void* buffer) - : _fd(fd), _base_offset(file_offset), _mem_buffer(buffer), _num_bytes(num_bytes) -{ -} - -io_prep_context::io_prep_context(const bool read_op, - const std::unique_ptr& xfer_ctxt, - const size_t block_size, - const std::vector* iocbs) - : _read_op(read_op), _xfer_ctxt(xfer_ctxt), _block_size(block_size), _iocbs(iocbs) -{ -} - -void io_prep_context::prep_iocbs(const int n_iocbs, - const size_t num_bytes, - const void* start_buffer, - const long long int start_offset) -{ - assert(static_cast(n_iocbs) <= _iocbs->size()); - for (auto i = 0; i < n_iocbs; ++i) { - const auto shift = i * _block_size; - const auto xfer_buffer = (char*)start_buffer + _xfer_ctxt->_base_offset + shift; - const auto xfer_offset = _xfer_ctxt->_base_offset + start_offset + shift; - auto byte_count = _block_size; - if ((shift + _block_size) > num_bytes) { byte_count = num_bytes - shift; } - - if (_read_op) { - io_prep_pread(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); - } else { - io_prep_pwrite(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); - } - } -} - -io_prep_generator::io_prep_generator(const bool read_op, - const std::unique_ptr& xfer_ctxt, - const size_t block_size) - : _read_op(read_op), - _xfer_ctxt(xfer_ctxt), - _block_size(block_size), - _remaining_bytes(xfer_ctxt->_num_bytes), - _next_iocb_index(0) -{ - _num_io_blocks = - static_cast(ceil(static_cast(xfer_ctxt->_num_bytes) / block_size)); - _remaining_io_blocks = _num_io_blocks; -} - -int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector* iocbs) -{ - if ((_remaining_bytes) == 0 || (_remaining_io_blocks == 0)) { - assert(static_cast(_remaining_bytes) == _remaining_io_blocks); - return 0; - } - - assert(static_cast(n_iocbs) <= iocbs->size()); - - auto actual_n_iocbs = min(static_cast(n_iocbs), _remaining_io_blocks); - for (auto i = 0; i < actual_n_iocbs; ++i, ++_next_iocb_index) { - const auto xfer_offset = _xfer_ctxt->_base_offset + (_next_iocb_index * _block_size); - const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + xfer_offset; - const auto num_bytes = min(static_cast(_block_size), _remaining_bytes); - - if (_read_op) { - io_prep_pread(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); - } else { - io_prep_pwrite(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); - } - _remaining_bytes -= num_bytes; - } - _remaining_io_blocks -= actual_n_iocbs; - - return actual_n_iocbs; -} - -int get_file_size(const char* filename, long long int& size) -{ - struct stat st; - if (stat(filename, &st) == -1) { return -1; } - size = st.st_size; - return 0; -} - -void* ds_page_aligned_alloc(const size_t size, const bool lock) -{ - void* ptr; - int retval; - - retval = posix_memalign(&ptr, (size_t)sysconf(_SC_PAGESIZE), size); - if (retval) { return nullptr; } - - if (lock == false) { return ptr; } - - auto mlock_ret = mlock(ptr, size); - if (mlock_ret != 0) { - auto mlock_error = errno; - printf("mlock failed with %d %s\n", mlock_error, strerror(mlock_error)); - - free(ptr); - return nullptr; - } - - return ptr; -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include + +#include "deepspeed_aio_utils.h" + +using namespace std; + +const int c_block_size = 128 * 1024; +const int c_io_queue_depth = 8; + +io_xfer_ctxt::io_xfer_ctxt(const int fd, + const long long int file_offset, + const long long int num_bytes, + const void* buffer) + : _fd(fd), _base_offset(file_offset), _mem_buffer(buffer), _num_bytes(num_bytes) +{ +} + +io_prep_context::io_prep_context(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size, + const std::vector* iocbs) + : _read_op(read_op), _xfer_ctxt(xfer_ctxt), _block_size(block_size), _iocbs(iocbs) +{ +} + +void io_prep_context::prep_iocbs(const int n_iocbs, + const size_t num_bytes, + const void* start_buffer, + const long long int start_offset) +{ + assert(static_cast(n_iocbs) <= _iocbs->size()); + for (auto i = 0; i < n_iocbs; ++i) { + const auto shift = i * _block_size; + const auto xfer_buffer = (char*)start_buffer + _xfer_ctxt->_base_offset + shift; + const auto xfer_offset = _xfer_ctxt->_base_offset + start_offset + shift; + auto byte_count = _block_size; + if ((shift + _block_size) > num_bytes) { byte_count = num_bytes - shift; } + + if (_read_op) { + io_prep_pread(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); + } else { + io_prep_pwrite(_iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, byte_count, xfer_offset); + } + } +} + +io_prep_generator::io_prep_generator(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size) + : _read_op(read_op), + _xfer_ctxt(xfer_ctxt), + _block_size(block_size), + _remaining_bytes(xfer_ctxt->_num_bytes), + _next_iocb_index(0) +{ + _num_io_blocks = + static_cast(ceil(static_cast(xfer_ctxt->_num_bytes) / block_size)); + _remaining_io_blocks = _num_io_blocks; +} + +int io_prep_generator::prep_iocbs(const int n_iocbs, std::vector* iocbs) +{ + if ((_remaining_bytes) == 0 || (_remaining_io_blocks == 0)) { + assert(static_cast(_remaining_bytes) == _remaining_io_blocks); + return 0; + } + + assert(static_cast(n_iocbs) <= iocbs->size()); + + auto actual_n_iocbs = min(static_cast(n_iocbs), _remaining_io_blocks); + for (auto i = 0; i < actual_n_iocbs; ++i, ++_next_iocb_index) { + const auto xfer_offset = _xfer_ctxt->_base_offset + (_next_iocb_index * _block_size); + const auto xfer_buffer = (char*)_xfer_ctxt->_mem_buffer + xfer_offset; + const auto num_bytes = min(static_cast(_block_size), _remaining_bytes); + + if (_read_op) { + io_prep_pread(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); + } else { + io_prep_pwrite(iocbs->at(i), _xfer_ctxt->_fd, xfer_buffer, num_bytes, xfer_offset); + } + _remaining_bytes -= num_bytes; + } + _remaining_io_blocks -= actual_n_iocbs; + + return actual_n_iocbs; +} + +int get_file_size(const char* filename, long long int& size) +{ + struct stat st; + if (stat(filename, &st) == -1) { return -1; } + size = st.st_size; + return 0; +} + +void* ds_page_aligned_alloc(const size_t size, const bool lock) +{ + void* ptr; + int retval; + + retval = posix_memalign(&ptr, (size_t)sysconf(_SC_PAGESIZE), size); + if (retval) { return nullptr; } + + if (lock == false) { return ptr; } + + auto mlock_ret = mlock(ptr, size); + if (mlock_ret != 0) { + auto mlock_error = errno; + printf("mlock failed with %d %s\n", mlock_error, strerror(mlock_error)); + + free(ptr); + return nullptr; + } + + return ptr; +} diff --git a/csrc/aio/common/deepspeed_aio_utils.h b/csrc/aio/common/deepspeed_aio_utils.h index f37a95c5149a..6c5952749dd3 100644 --- a/csrc/aio/common/deepspeed_aio_utils.h +++ b/csrc/aio/common/deepspeed_aio_utils.h @@ -1,77 +1,77 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -struct io_xfer_ctxt { - const int _fd; - const long long int _base_offset; - const void* _mem_buffer; - const long long int _num_bytes; - - io_xfer_ctxt(const int fd, - const long long int file_offset, - const long long int num_bytes, - const void* buffer); -}; - -struct io_prep_context { - const bool _read_op; - const std::unique_ptr& _xfer_ctxt; - const size_t _block_size; - const std::vector* _iocbs; - - io_prep_context(const bool read_op, - const std::unique_ptr& xfer_ctxt, - const size_t block_size, - const std::vector* iocbs); - - void prep_iocbs(const int n_iocbs, - const size_t num_bytes, - const void* start_buffer, - const long long int start_offset); -}; - -struct io_prep_generator { - const bool _read_op; - const std::unique_ptr& _xfer_ctxt; - const size_t _block_size; - - long long int _remaining_bytes; - long long int _num_io_blocks; - long long int _remaining_io_blocks; - long long int _next_iocb_index; - - io_prep_generator(const bool read_op, - const std::unique_ptr& xfer_ctxt, - const size_t block_size); - - int prep_iocbs(const int n_iocbs, std::vector* iocbs); -}; - -void* ds_page_aligned_alloc(const size_t size, const bool lock = false); - -int get_file_size(const char* filename, long long int& size); +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +struct io_xfer_ctxt { + const int _fd; + const long long int _base_offset; + const void* _mem_buffer; + const long long int _num_bytes; + + io_xfer_ctxt(const int fd, + const long long int file_offset, + const long long int num_bytes, + const void* buffer); +}; + +struct io_prep_context { + const bool _read_op; + const std::unique_ptr& _xfer_ctxt; + const size_t _block_size; + const std::vector* _iocbs; + + io_prep_context(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size, + const std::vector* iocbs); + + void prep_iocbs(const int n_iocbs, + const size_t num_bytes, + const void* start_buffer, + const long long int start_offset); +}; + +struct io_prep_generator { + const bool _read_op; + const std::unique_ptr& _xfer_ctxt; + const size_t _block_size; + + long long int _remaining_bytes; + long long int _num_io_blocks; + long long int _remaining_io_blocks; + long long int _next_iocb_index; + + io_prep_generator(const bool read_op, + const std::unique_ptr& xfer_ctxt, + const size_t block_size); + + int prep_iocbs(const int n_iocbs, std::vector* iocbs); +}; + +void* ds_page_aligned_alloc(const size_t size, const bool lock = false); + +int get_file_size(const char* filename, long long int& size); diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.cpp b/csrc/aio/py_lib/deepspeed_aio_thread.cpp index 2c7509cb3ba0..a2670fb7b4cb 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.cpp +++ b/csrc/aio/py_lib/deepspeed_aio_thread.cpp @@ -1,84 +1,84 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include "deepspeed_aio_thread.h" - -using namespace std; - -io_op_desc_t::io_op_desc_t(const bool read_op, - const torch::Tensor& buffer, - const int fd, - const char* filename, - const long long int num_bytes, - const bool validate) - : _read_op(read_op), - _buffer(buffer), - _fd(fd), - _filename(filename), - _num_bytes(num_bytes), - _validate(validate) -{ - _cpu_buffer = _buffer.is_cuda() ? _buffer.to(torch::kCPU).pin_memory() : _buffer; - _contiguous_buffer = _cpu_buffer.contiguous(); -} - -char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } - -void io_op_desc_t::fini() -{ - if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } -} - -deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config) - : _tid(tid), - _aio_config(aio_config), - _aio_ctxt(new aio_context(aio_config._block_size, aio_config._queue_depth)), - _time_to_exit(false) -{ -} - -deepspeed_aio_thread_t::~deepspeed_aio_thread_t() {} - -void deepspeed_aio_thread_t::run() -{ - while (true) { - std::shared_ptr next_io_op = nullptr; - - { - std::unique_lock lock(_work_sync._mutex); - _work_sync._cond_var.wait(lock, - [this] { return (!_work_queue.empty() || _time_to_exit); }); - if (!_work_queue.empty()) { - next_io_op = _work_queue.front(); - _work_queue.pop(); - } - } - - if (next_io_op) { - const auto base_offset = next_io_op->_num_bytes * _tid; - - std::unique_ptr xfer_ctxt(new io_xfer_ctxt( - next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr())); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap( - next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential( - next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - - { - std::lock_guard lock(_complete_sync._mutex); - _complete_queue.push(next_io_op); - } - _complete_sync._cond_var.notify_one(); - } - - if (_time_to_exit) { break; } - } -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_aio_thread.h" + +using namespace std; + +io_op_desc_t::io_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int num_bytes, + const bool validate) + : _read_op(read_op), + _buffer(buffer), + _fd(fd), + _filename(filename), + _num_bytes(num_bytes), + _validate(validate) +{ + _cpu_buffer = _buffer.is_cuda() ? _buffer.to(torch::kCPU).pin_memory() : _buffer; + _contiguous_buffer = _cpu_buffer.contiguous(); +} + +char* io_op_desc_t::data_ptr() const { return (char*)_contiguous_buffer.data_ptr(); } + +void io_op_desc_t::fini() +{ + if (_read_op && _buffer.is_cuda()) { _buffer.copy_(_cpu_buffer.to(torch::kCUDA)); } +} + +deepspeed_aio_thread_t::deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config) + : _tid(tid), + _aio_config(aio_config), + _aio_ctxt(new aio_context(aio_config._block_size, aio_config._queue_depth)), + _time_to_exit(false) +{ +} + +deepspeed_aio_thread_t::~deepspeed_aio_thread_t() {} + +void deepspeed_aio_thread_t::run() +{ + while (true) { + std::shared_ptr next_io_op = nullptr; + + { + std::unique_lock lock(_work_sync._mutex); + _work_sync._cond_var.wait(lock, + [this] { return (!_work_queue.empty() || _time_to_exit); }); + if (!_work_queue.empty()) { + next_io_op = _work_queue.front(); + _work_queue.pop(); + } + } + + if (next_io_op) { + const auto base_offset = next_io_op->_num_bytes * _tid; + + std::unique_ptr xfer_ctxt(new io_xfer_ctxt( + next_io_op->_fd, base_offset, next_io_op->_num_bytes, next_io_op->data_ptr())); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap( + next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential( + next_io_op->_read_op, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + + { + std::lock_guard lock(_complete_sync._mutex); + _complete_queue.push(next_io_op); + } + _complete_sync._cond_var.notify_one(); + } + + if (_time_to_exit) { break; } + } +} diff --git a/csrc/aio/py_lib/deepspeed_aio_thread.h b/csrc/aio/py_lib/deepspeed_aio_thread.h index ee099dd2d16c..d1cfcab8bfc2 100644 --- a/csrc/aio/py_lib/deepspeed_aio_thread.h +++ b/csrc/aio/py_lib/deepspeed_aio_thread.h @@ -1,57 +1,57 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include -#include "deepspeed_py_aio.h" - -struct io_op_desc_t { - const bool _read_op; - torch::Tensor _buffer; - int _fd; - const std::string _filename; - const long long int _num_bytes; - torch::Tensor _cpu_buffer; - torch::Tensor _contiguous_buffer; - const bool _validate; - - io_op_desc_t(const bool read_op, - const torch::Tensor& buffer, - const int fd, - const char* filename, - const long long int num_bytes, - const bool validate); - - char* data_ptr() const; - void fini(); -}; - -struct thread_sync_t { - std::mutex _mutex; - std::condition_variable _cond_var; -}; - -struct deepspeed_aio_thread_t { - const int _tid; - deepspeed_aio_config_t& _aio_config; - - std::unique_ptr _aio_ctxt; - std::queue> _work_queue; - std::queue> _complete_queue; - - bool _time_to_exit; - - struct thread_sync_t _work_sync; - struct thread_sync_t _complete_sync; - - deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config); - - ~deepspeed_aio_thread_t(); - - void run(); -}; +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include +#include "deepspeed_py_aio.h" + +struct io_op_desc_t { + const bool _read_op; + torch::Tensor _buffer; + int _fd; + const std::string _filename; + const long long int _num_bytes; + torch::Tensor _cpu_buffer; + torch::Tensor _contiguous_buffer; + const bool _validate; + + io_op_desc_t(const bool read_op, + const torch::Tensor& buffer, + const int fd, + const char* filename, + const long long int num_bytes, + const bool validate); + + char* data_ptr() const; + void fini(); +}; + +struct thread_sync_t { + std::mutex _mutex; + std::condition_variable _cond_var; +}; + +struct deepspeed_aio_thread_t { + const int _tid; + deepspeed_aio_config_t& _aio_config; + + std::unique_ptr _aio_ctxt; + std::queue> _work_queue; + std::queue> _complete_queue; + + bool _time_to_exit; + + struct thread_sync_t _work_sync; + struct thread_sync_t _complete_sync; + + deepspeed_aio_thread_t(const int tid, deepspeed_aio_config_t& aio_config); + + ~deepspeed_aio_thread_t(); + + void run(); +}; diff --git a/csrc/aio/py_lib/deepspeed_py_aio.cpp b/csrc/aio/py_lib/deepspeed_py_aio.cpp index cc2895cc74b3..49ff1f240c43 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio.cpp @@ -1,121 +1,121 @@ - -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "deepspeed_py_aio.h" - -using namespace std; -using namespace std::chrono; - -#define DEBUG_DS_AIO_READ 0 -#define DEBUG_DS_AIO_WRITE 0 - -static const std::string c_library_name = "deepspeed_aio"; - -int deepspeed_py_aio_write(const torch::Tensor& buffer, - const char* filename, - const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool validate) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto write_buffer = (char*)buffer.data_ptr(); - const auto num_write_bytes = static_cast(buffer.nbytes()); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); - std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); - - if (config._overlap_events) { - do_aio_operation_overlap(false, aio_ctxt, xfer_ctxt, &config, nullptr); - } else { - do_aio_operation_sequential(false, aio_ctxt, xfer_ctxt, &config, nullptr); - } - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - close(fd); - - if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } - - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; -} - -int deepspeed_py_aio_read(torch::Tensor& buffer, - const char* filename, - const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool validate) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - - deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto read_buffer = (char*)buffer.data_ptr(); - assert(static_cast(buffer.nbytes()) == num_file_bytes); - - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); - std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); - - if (config._overlap_events) { - do_aio_operation_overlap(true, aio_ctxt, xfer_ctxt, &config, nullptr); - } else { - do_aio_operation_sequential(true, aio_ctxt, xfer_ctxt, &config, nullptr); - } - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - close(fd); - - if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } - - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; -} + +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "deepspeed_py_aio.h" + +using namespace std; +using namespace std::chrono; + +#define DEBUG_DS_AIO_READ 0 +#define DEBUG_DS_AIO_WRITE 0 + +static const std::string c_library_name = "deepspeed_aio"; + +int deepspeed_py_aio_write(const torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto write_buffer = (char*)buffer.data_ptr(); + const auto num_write_bytes = static_cast(buffer.nbytes()); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); + + if (config._overlap_events) { + do_aio_operation_overlap(false, aio_ctxt, xfer_ctxt, &config, nullptr); + } else { + do_aio_operation_sequential(false, aio_ctxt, xfer_ctxt, &config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; + return 0; +} + +int deepspeed_py_aio_read(torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + long long num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + + deepspeed_aio_config_t config(block_size, queue_depth, single_submit, overlap_events, false); + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto read_buffer = (char*)buffer.data_ptr(); + assert(static_cast(buffer.nbytes()) == num_file_bytes); + + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + std::unique_ptr aio_ctxt(new aio_context(config._block_size, config._queue_depth)); + + if (config._overlap_events) { + do_aio_operation_overlap(true, aio_ctxt, xfer_ctxt, &config, nullptr); + } else { + do_aio_operation_sequential(true, aio_ctxt, xfer_ctxt, &config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; + return 0; +} diff --git a/csrc/aio/py_lib/deepspeed_py_aio.h b/csrc/aio/py_lib/deepspeed_py_aio.h index a78d5734009d..230d88da9763 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio.h +++ b/csrc/aio/py_lib/deepspeed_py_aio.h @@ -1,27 +1,27 @@ - -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include - -int deepspeed_py_aio_write(const torch::Tensor& buffer, - const char* filename, - const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool validate); - -int deepspeed_py_aio_read(torch::Tensor& buffer, - const char* filename, - const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const bool validate); + +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include + +int deepspeed_py_aio_write(const torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate); + +int deepspeed_py_aio_read(torch::Tensor& buffer, + const char* filename, + const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const bool validate); diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp index 4635e751d6d8..417319f8ae5c 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.cpp @@ -1,282 +1,282 @@ - -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include "deepspeed_py_aio_handle.h" - -using namespace std; - -static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); } - -deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const int num_threads) - : _aio_ctxt(new aio_context(block_size, queue_depth)), - _single_submit(single_submit), - _overlap_events(overlap_events), - _num_threads(num_threads), - _aio_config(block_size, queue_depth, single_submit, overlap_events, false), - _num_pending_ops(0) -{ - for (auto i = 0; i < num_threads; ++i) { - _thread_contexts.push_back(std::make_shared(i, _aio_config)); - } - - for (auto& ctxt : _thread_contexts) { - _threads.push_back(std::thread(_start_aio_thread, ctxt)); - } -} - -deepspeed_aio_handle_t::~deepspeed_aio_handle_t() -{ - _stop_threads(); - for (auto& thr : _threads) { thr.join(); } -} - -const int deepspeed_aio_handle_t::get_block_size() const -{ - return _aio_ctxt ? _aio_ctxt->_block_size : -1; -} - -const int deepspeed_aio_handle_t::get_queue_depth() const -{ - return _aio_ctxt ? _aio_ctxt->_queue_depth : -1; -} - -const bool deepspeed_aio_handle_t::get_single_submit() const { return _single_submit; } - -const bool deepspeed_aio_handle_t::get_overlap_events() const { return _overlap_events; } - -const int deepspeed_aio_handle_t::get_thread_count() const { return _num_threads; } - -int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) -{ - const auto start_time = std::chrono::high_resolution_clock::now(); - - assert(_aio_ctxt); - - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - assert(static_cast(buffer.nbytes()) == num_file_bytes); - - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto read_buffer = (char*)buffer.data_ptr(); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - - close(fd); - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; -} - -int deepspeed_aio_handle_t::write(const torch::Tensor& buffer, - const char* filename, - const bool validate) -{ - assert(_aio_ctxt); - - const auto start_time = std::chrono::high_resolution_clock::now(); - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto write_buffer = (char*)buffer.data_ptr(); - const auto num_write_bytes = static_cast(buffer.nbytes()); - std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); - - if (_aio_config._overlap_events) { - do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } else { - do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); - } - const std::chrono::duration aio_time = - std::chrono::high_resolution_clock::now() - start_time; - - close(fd); - - if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } - - const std::chrono::duration fn_time = - std::chrono::high_resolution_clock::now() - start_time; - std::cout << "Elapsed time(usec): " - << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 - << std::endl; - return 0; -} - -void deepspeed_aio_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op) -{ - for (auto& ctxt : _thread_contexts) { - { - std::lock_guard lock(ctxt->_work_sync._mutex); - ctxt->_work_queue.push(scheduled_op); - } - ctxt->_work_sync._cond_var.notify_one(); - } - _num_pending_ops++; -} - -std::shared_ptr deepspeed_aio_handle_t::_wait_for_aio_work() -{ - std::shared_ptr completed_op = nullptr; - for (auto& ctxt : _thread_contexts) { - std::unique_lock lock(ctxt->_complete_sync._mutex); - ctxt->_complete_sync._cond_var.wait(lock, - [ctxt] { return !ctxt->_complete_queue.empty(); }); - completed_op = ctxt->_complete_queue.front(); - ctxt->_complete_queue.pop(); - } - return completed_op; -} - -void deepspeed_aio_handle_t::_stop_threads() -{ - assert(0 == _num_pending_ops); - for (auto& ctxt : _thread_contexts) { - { - std::lock_guard lock(ctxt->_work_sync._mutex); - ctxt->_time_to_exit = true; - } - ctxt->_work_sync._cond_var.notify_one(); - } -} - -int deepspeed_aio_handle_t::wait() -{ - assert(_num_pending_ops > 0); - auto num_completed_ops = 0; - - while (_num_pending_ops > 0) { - auto completed_op = _wait_for_aio_work(); - - completed_op->fini(); - - close(completed_op->_fd); - - if (completed_op->_validate) { - validate_aio_operation(completed_op->_read_op, - completed_op->_filename.c_str(), - completed_op->data_ptr(), - _num_threads * completed_op->_num_bytes); - } - --_num_pending_ops; - ++num_completed_ops; - } - - return num_completed_ops; -} - -bool deepspeed_aio_handle_t::_is_valid_parallel_aio_op(const bool read_op, - const long long int num_bytes) -{ - const auto op_string = read_op ? "Read" : "Write"; - if (num_bytes % get_thread_count()) { - std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes - << " not divisible by thread count = " << get_thread_count() << std::endl; - return false; - } - - return true; -} - -int deepspeed_aio_handle_t::pread(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async) -{ - long long num_file_bytes; - if (-1 == get_file_size(filename, num_file_bytes)) { - const auto error_code = errno; - report_file_error(filename, " fstat for read", error_code); - return -1; - } - const auto buffer_bytes = static_cast(buffer.nbytes()); - if (buffer_bytes != num_file_bytes) { - std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes - << " != " << num_file_bytes << std::endl; - } - assert(static_cast(buffer.nbytes()) == num_file_bytes); - assert((num_file_bytes % _num_threads) == 0); - - if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } - - const auto fd = open_file(filename, true); - if (fd == -1) { return -1; } - - auto scheduled_op = std::make_shared( - true, buffer, fd, filename, (num_file_bytes / _num_threads), validate); - - _schedule_aio_work(scheduled_op); - - if (async) { return 0; } - - return wait(); -} - -int deepspeed_aio_handle_t::pwrite(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async) -{ - const auto num_write_bytes = static_cast(buffer.nbytes()); - assert((num_write_bytes % _num_threads) == 0); - - if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } - - const auto fd = open_file(filename, false); - if (fd == -1) { return -1; } - - auto scheduled_op = std::make_shared( - false, buffer, fd, filename, (num_write_bytes / _num_threads), validate); - - _schedule_aio_work(scheduled_op); - - if (async) { return 0; } - - return wait(); -} - -int deepspeed_aio_handle_t::sync_pread(torch::Tensor& buffer, const char* filename) -{ - return pread(buffer, filename, false, false); -} - -int deepspeed_aio_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename) -{ - return pwrite(buffer, filename, false, false); -} - -int deepspeed_aio_handle_t::async_pread(torch::Tensor& buffer, const char* filename) -{ - return pread(buffer, filename, false, true); -} - -int deepspeed_aio_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename) -{ - return pwrite(buffer, filename, false, true); -} + +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_py_aio_handle.h" + +using namespace std; + +static void _start_aio_thread(std::shared_ptr ctxt) { ctxt->run(); } + +deepspeed_aio_handle_t::deepspeed_aio_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int num_threads) + : _aio_ctxt(new aio_context(block_size, queue_depth)), + _single_submit(single_submit), + _overlap_events(overlap_events), + _num_threads(num_threads), + _aio_config(block_size, queue_depth, single_submit, overlap_events, false), + _num_pending_ops(0) +{ + for (auto i = 0; i < num_threads; ++i) { + _thread_contexts.push_back(std::make_shared(i, _aio_config)); + } + + for (auto& ctxt : _thread_contexts) { + _threads.push_back(std::thread(_start_aio_thread, ctxt)); + } +} + +deepspeed_aio_handle_t::~deepspeed_aio_handle_t() +{ + _stop_threads(); + for (auto& thr : _threads) { thr.join(); } +} + +const int deepspeed_aio_handle_t::get_block_size() const +{ + return _aio_ctxt ? _aio_ctxt->_block_size : -1; +} + +const int deepspeed_aio_handle_t::get_queue_depth() const +{ + return _aio_ctxt ? _aio_ctxt->_queue_depth : -1; +} + +const bool deepspeed_aio_handle_t::get_single_submit() const { return _single_submit; } + +const bool deepspeed_aio_handle_t::get_overlap_events() const { return _overlap_events; } + +const int deepspeed_aio_handle_t::get_thread_count() const { return _num_threads; } + +int deepspeed_aio_handle_t::read(torch::Tensor& buffer, const char* filename, const bool validate) +{ + const auto start_time = std::chrono::high_resolution_clock::now(); + + assert(_aio_ctxt); + + long long num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + assert(static_cast(buffer.nbytes()) == num_file_bytes); + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto read_buffer = (char*)buffer.data_ptr(); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_file_bytes, read_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(true, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + + close(fd); + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + if (validate) { validate_aio_operation(true, filename, read_buffer, num_file_bytes); } + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; + return 0; +} + +int deepspeed_aio_handle_t::write(const torch::Tensor& buffer, + const char* filename, + const bool validate) +{ + assert(_aio_ctxt); + + const auto start_time = std::chrono::high_resolution_clock::now(); + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto write_buffer = (char*)buffer.data_ptr(); + const auto num_write_bytes = static_cast(buffer.nbytes()); + std::unique_ptr xfer_ctxt(new io_xfer_ctxt(fd, 0, num_write_bytes, write_buffer)); + + if (_aio_config._overlap_events) { + do_aio_operation_overlap(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } else { + do_aio_operation_sequential(false, _aio_ctxt, xfer_ctxt, &_aio_config, nullptr); + } + const std::chrono::duration aio_time = + std::chrono::high_resolution_clock::now() - start_time; + + close(fd); + + if (validate) { validate_aio_operation(false, filename, write_buffer, num_write_bytes); } + + const std::chrono::duration fn_time = + std::chrono::high_resolution_clock::now() - start_time; + std::cout << "Elapsed time(usec): " + << "aio = " << aio_time.count() * 1e6 << " call = " << fn_time.count() * 1e6 + << std::endl; + return 0; +} + +void deepspeed_aio_handle_t::_schedule_aio_work(std::shared_ptr scheduled_op) +{ + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_work_queue.push(scheduled_op); + } + ctxt->_work_sync._cond_var.notify_one(); + } + _num_pending_ops++; +} + +std::shared_ptr deepspeed_aio_handle_t::_wait_for_aio_work() +{ + std::shared_ptr completed_op = nullptr; + for (auto& ctxt : _thread_contexts) { + std::unique_lock lock(ctxt->_complete_sync._mutex); + ctxt->_complete_sync._cond_var.wait(lock, + [ctxt] { return !ctxt->_complete_queue.empty(); }); + completed_op = ctxt->_complete_queue.front(); + ctxt->_complete_queue.pop(); + } + return completed_op; +} + +void deepspeed_aio_handle_t::_stop_threads() +{ + assert(0 == _num_pending_ops); + for (auto& ctxt : _thread_contexts) { + { + std::lock_guard lock(ctxt->_work_sync._mutex); + ctxt->_time_to_exit = true; + } + ctxt->_work_sync._cond_var.notify_one(); + } +} + +int deepspeed_aio_handle_t::wait() +{ + assert(_num_pending_ops > 0); + auto num_completed_ops = 0; + + while (_num_pending_ops > 0) { + auto completed_op = _wait_for_aio_work(); + + completed_op->fini(); + + close(completed_op->_fd); + + if (completed_op->_validate) { + validate_aio_operation(completed_op->_read_op, + completed_op->_filename.c_str(), + completed_op->data_ptr(), + _num_threads * completed_op->_num_bytes); + } + --_num_pending_ops; + ++num_completed_ops; + } + + return num_completed_ops; +} + +bool deepspeed_aio_handle_t::_is_valid_parallel_aio_op(const bool read_op, + const long long int num_bytes) +{ + const auto op_string = read_op ? "Read" : "Write"; + if (num_bytes % get_thread_count()) { + std::cout << "deepspeed_aio failure: parallel " << op_string << " num_bytes = " << num_bytes + << " not divisible by thread count = " << get_thread_count() << std::endl; + return false; + } + + return true; +} + +int deepspeed_aio_handle_t::pread(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async) +{ + long long num_file_bytes; + if (-1 == get_file_size(filename, num_file_bytes)) { + const auto error_code = errno; + report_file_error(filename, " fstat for read", error_code); + return -1; + } + const auto buffer_bytes = static_cast(buffer.nbytes()); + if (buffer_bytes != num_file_bytes) { + std::cout << filename << ": buffer nbytes != file bytes " << buffer_bytes + << " != " << num_file_bytes << std::endl; + } + assert(static_cast(buffer.nbytes()) == num_file_bytes); + assert((num_file_bytes % _num_threads) == 0); + + if (!_is_valid_parallel_aio_op(true, num_file_bytes)) { return -1; } + + const auto fd = open_file(filename, true); + if (fd == -1) { return -1; } + + auto scheduled_op = std::make_shared( + true, buffer, fd, filename, (num_file_bytes / _num_threads), validate); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_aio_handle_t::pwrite(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async) +{ + const auto num_write_bytes = static_cast(buffer.nbytes()); + assert((num_write_bytes % _num_threads) == 0); + + if (!_is_valid_parallel_aio_op(false, num_write_bytes)) { return -1; } + + const auto fd = open_file(filename, false); + if (fd == -1) { return -1; } + + auto scheduled_op = std::make_shared( + false, buffer, fd, filename, (num_write_bytes / _num_threads), validate); + + _schedule_aio_work(scheduled_op); + + if (async) { return 0; } + + return wait(); +} + +int deepspeed_aio_handle_t::sync_pread(torch::Tensor& buffer, const char* filename) +{ + return pread(buffer, filename, false, false); +} + +int deepspeed_aio_handle_t::sync_pwrite(const torch::Tensor& buffer, const char* filename) +{ + return pwrite(buffer, filename, false, false); +} + +int deepspeed_aio_handle_t::async_pread(torch::Tensor& buffer, const char* filename) +{ + return pread(buffer, filename, false, true); +} + +int deepspeed_aio_handle_t::async_pwrite(const torch::Tensor& buffer, const char* filename) +{ + return pwrite(buffer, filename, false, true); +} diff --git a/csrc/aio/py_lib/deepspeed_py_aio_handle.h b/csrc/aio/py_lib/deepspeed_py_aio_handle.h index 09358f4d927b..22de4c3961d2 100644 --- a/csrc/aio/py_lib/deepspeed_py_aio_handle.h +++ b/csrc/aio/py_lib/deepspeed_py_aio_handle.h @@ -1,68 +1,68 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include -#include "deepspeed_aio_thread.h" - -struct deepspeed_aio_handle_t { - std::unique_ptr _aio_ctxt; - const bool _single_submit; - const bool _overlap_events; - const int _num_threads; - deepspeed_aio_config_t _aio_config; - - std::vector> _thread_contexts; - std::vector _threads; - int _num_pending_ops; - - deepspeed_aio_handle_t(const int block_size, - const int queue_depth, - const bool single_submit, - const bool overlap_events, - const int num_threads); - - ~deepspeed_aio_handle_t(); - - const int get_block_size() const; - const int get_queue_depth() const; - const bool get_single_submit() const; - const bool get_overlap_events() const; - const int get_thread_count() const; - - int read(torch::Tensor& buffer, const char* filename, const bool validate); - - int write(const torch::Tensor& buffer, const char* filename, const bool validate); - - int pread(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async); - - int pwrite(const torch::Tensor& buffer, - const char* filename, - const bool validate, - const bool async); - - int sync_pread(torch::Tensor& buffer, const char* filename); - - int sync_pwrite(const torch::Tensor& buffer, const char* filename); - - int async_pread(torch::Tensor& buffer, const char* filename); - - int async_pwrite(const torch::Tensor& buffer, const char* filename); - - int wait(); - - void _stop_threads(); - - void _schedule_aio_work(std::shared_ptr scheduled_op); - - std::shared_ptr _wait_for_aio_work(); - - bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes); -}; +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include +#include "deepspeed_aio_thread.h" + +struct deepspeed_aio_handle_t { + std::unique_ptr _aio_ctxt; + const bool _single_submit; + const bool _overlap_events; + const int _num_threads; + deepspeed_aio_config_t _aio_config; + + std::vector> _thread_contexts; + std::vector _threads; + int _num_pending_ops; + + deepspeed_aio_handle_t(const int block_size, + const int queue_depth, + const bool single_submit, + const bool overlap_events, + const int num_threads); + + ~deepspeed_aio_handle_t(); + + const int get_block_size() const; + const int get_queue_depth() const; + const bool get_single_submit() const; + const bool get_overlap_events() const; + const int get_thread_count() const; + + int read(torch::Tensor& buffer, const char* filename, const bool validate); + + int write(const torch::Tensor& buffer, const char* filename, const bool validate); + + int pread(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async); + + int pwrite(const torch::Tensor& buffer, + const char* filename, + const bool validate, + const bool async); + + int sync_pread(torch::Tensor& buffer, const char* filename); + + int sync_pwrite(const torch::Tensor& buffer, const char* filename); + + int async_pread(torch::Tensor& buffer, const char* filename); + + int async_pwrite(const torch::Tensor& buffer, const char* filename); + + int wait(); + + void _stop_threads(); + + void _schedule_aio_work(std::shared_ptr scheduled_op); + + std::shared_ptr _wait_for_aio_work(); + + bool _is_valid_parallel_aio_op(const bool read_op, const long long int num_bytes); +}; diff --git a/csrc/aio/py_lib/deepspeed_py_copy.cpp b/csrc/aio/py_lib/deepspeed_py_copy.cpp index 3cdb5ed344bf..ee51147f9c41 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.cpp +++ b/csrc/aio/py_lib/deepspeed_py_copy.cpp @@ -1,133 +1,133 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include "deepspeed_py_copy.h" -#include - -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) - -#if defined(__AVX512__) or defined(__AVX256__) -union AVX_Data { -#if defined(__AVX512__) - __m512 data; -#else - __m256 data; -#endif -}; -#endif - -static void helper_memcpy_1(float* dest, float* src, size_t param_size) -{ - size_t rounded_size = 0; - -#if defined(__AVX512__) or defined(__AVX256__) - - rounded_size = ROUND_DOWN(param_size, SIMD_WIDTH); - - for (size_t t = 0; t < rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > rounded_size) copy_size = rounded_size - t; - size_t offset = copy_size + t; -#pragma omp parallel for - for (size_t i = t; i < offset; i += SIMD_WIDTH) { - AVX_Data src_4; - src_4.data = SIMD_LOAD(src + i); - - SIMD_STORE(dest + i, src_4.data); - } - } - -#endif - - if (param_size > rounded_size) { -#pragma omp parallel for - for (size_t k = rounded_size; k < param_size; k++) { dest[k] = src[k]; } - } -} - -static void helper_memcpy_4(float* dest, float* src, size_t param_size) -{ - size_t rounded_size = 0; - -#if defined(__AVX512__) or defined(__AVX256__) - - rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2)); - - for (size_t t = 0; t < rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > rounded_size) copy_size = rounded_size - t; - size_t offset = copy_size + t; -#pragma omp parallel for - for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) { - AVX_Data src_4[4]; - src_4[0].data = SIMD_LOAD(src + i); - src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH); - src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1)); - src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3); - - SIMD_STORE(dest + i, src_4[0].data); - SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data); - SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data); - } - } -#endif - if (param_size > rounded_size) - helper_memcpy_1((dest + rounded_size), (src + rounded_size), (param_size - rounded_size)); -} - -static void helper_mempcy_8(float* dest, float* src, size_t param_size) -{ - size_t rounded_size = 0; - -#if defined(__AVX512__) or defined(__AVX256__) - - rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2)); - - for (size_t t = 0; t < rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > rounded_size) copy_size = rounded_size - t; - size_t offset = copy_size + t; -#pragma omp parallel for - for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) { - AVX_Data src_4[8]; - src_4[0].data = SIMD_LOAD(src + i); - src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH); - src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1)); - src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3); - src_4[4].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 2)); - src_4[5].data = SIMD_LOAD(src + i + SIMD_WIDTH * 5); - src_4[6].data = SIMD_LOAD(src + i + SIMD_WIDTH * 6); - src_4[7].data = SIMD_LOAD(src + i + SIMD_WIDTH * 7); - - SIMD_STORE(dest + i, src_4[0].data); - SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data); - SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data); - SIMD_STORE(dest + i + (SIMD_WIDTH << 2), src_4[4].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 5, src_4[5].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 6, src_4[6].data); - SIMD_STORE(dest + i + SIMD_WIDTH * 7, src_4[7].data); - } - } -#endif - if (param_size > rounded_size) - helper_memcpy_4((dest + rounded_size), (src + rounded_size), (param_size - rounded_size)); -} - -int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src) -{ - auto dest_c = dest.contiguous(); - auto src_c = src.contiguous(); - - float* dest_ptr = (float*)dest_c.data_ptr(); - float* src_ptr = (float*)src_c.data_ptr(); - - helper_mempcy_8(dest_ptr, src_ptr, dest_c.size(0)); - - return 0; -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include "deepspeed_py_copy.h" +#include + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) + +#if defined(__AVX512__) or defined(__AVX256__) +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#else + __m256 data; +#endif +}; +#endif + +static void helper_memcpy_1(float* dest, float* src, size_t param_size) +{ + size_t rounded_size = 0; + +#if defined(__AVX512__) or defined(__AVX256__) + + rounded_size = ROUND_DOWN(param_size, SIMD_WIDTH); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH) { + AVX_Data src_4; + src_4.data = SIMD_LOAD(src + i); + + SIMD_STORE(dest + i, src_4.data); + } + } + +#endif + + if (param_size > rounded_size) { +#pragma omp parallel for + for (size_t k = rounded_size; k < param_size; k++) { dest[k] = src[k]; } + } +} + +static void helper_memcpy_4(float* dest, float* src, size_t param_size) +{ + size_t rounded_size = 0; + +#if defined(__AVX512__) or defined(__AVX256__) + + rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2)); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t i = t; i < offset; i += (SIMD_WIDTH << 2)) { + AVX_Data src_4[4]; + src_4[0].data = SIMD_LOAD(src + i); + src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH); + src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1)); + src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3); + + SIMD_STORE(dest + i, src_4[0].data); + SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data); + SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data); + } + } +#endif + if (param_size > rounded_size) + helper_memcpy_1((dest + rounded_size), (src + rounded_size), (param_size - rounded_size)); +} + +static void helper_mempcy_8(float* dest, float* src, size_t param_size) +{ + size_t rounded_size = 0; + +#if defined(__AVX512__) or defined(__AVX256__) + + rounded_size = ROUND_DOWN(param_size, (SIMD_WIDTH << 2)); + + for (size_t t = 0; t < rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > rounded_size) copy_size = rounded_size - t; + size_t offset = copy_size + t; +#pragma omp parallel for + for (size_t i = t; i < offset; i += (SIMD_WIDTH << 3)) { + AVX_Data src_4[8]; + src_4[0].data = SIMD_LOAD(src + i); + src_4[1].data = SIMD_LOAD(src + i + SIMD_WIDTH); + src_4[2].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 1)); + src_4[3].data = SIMD_LOAD(src + i + SIMD_WIDTH * 3); + src_4[4].data = SIMD_LOAD(src + i + (SIMD_WIDTH << 2)); + src_4[5].data = SIMD_LOAD(src + i + SIMD_WIDTH * 5); + src_4[6].data = SIMD_LOAD(src + i + SIMD_WIDTH * 6); + src_4[7].data = SIMD_LOAD(src + i + SIMD_WIDTH * 7); + + SIMD_STORE(dest + i, src_4[0].data); + SIMD_STORE(dest + i + SIMD_WIDTH, src_4[1].data); + SIMD_STORE(dest + i + (SIMD_WIDTH << 1), src_4[2].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 3, src_4[3].data); + SIMD_STORE(dest + i + (SIMD_WIDTH << 2), src_4[4].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 5, src_4[5].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 6, src_4[6].data); + SIMD_STORE(dest + i + SIMD_WIDTH * 7, src_4[7].data); + } + } +#endif + if (param_size > rounded_size) + helper_memcpy_4((dest + rounded_size), (src + rounded_size), (param_size - rounded_size)); +} + +int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src) +{ + auto dest_c = dest.contiguous(); + auto src_c = src.contiguous(); + + float* dest_ptr = (float*)dest_c.data_ptr(); + float* src_ptr = (float*)src_c.data_ptr(); + + helper_mempcy_8(dest_ptr, src_ptr, dest_c.size(0)); + + return 0; +} diff --git a/csrc/aio/py_lib/deepspeed_py_copy.h b/csrc/aio/py_lib/deepspeed_py_copy.h index 819d568bb92b..69b044851eca 100644 --- a/csrc/aio/py_lib/deepspeed_py_copy.h +++ b/csrc/aio/py_lib/deepspeed_py_copy.h @@ -1,42 +1,42 @@ - -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#if (__x86_64__ || __i386__) -#include -#include -#endif - -#include -#include -#include - -#define TILE (1024 * 1024 * 1024) - -#if defined(__AVX512__) -#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm512_loadu_ps(x) -#define SIMD_SET(x) _mm512_set1_ps(x) -#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm512_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm512_div_ps(x, y) -#define SIMD_WIDTH 16 -#else -#if defined(__AVX256__) -#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm256_loadu_ps(x) -#define SIMD_SET(x) _mm256_set1_ps(x) -#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm256_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm256_div_ps(x, y) -#define SIMD_WIDTH 8 -#endif -#endif - -int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src); + +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#include +#include +#include + +#define TILE (1024 * 1024 * 1024) + +#if defined(__AVX512__) +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_WIDTH 16 +#else +#if defined(__AVX256__) +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_WIDTH 8 +#endif +#endif + +int deepspeed_py_memcpy(torch::Tensor& dest, const torch::Tensor& src); diff --git a/csrc/aio/py_lib/py_ds_aio.cpp b/csrc/aio/py_lib/py_ds_aio.cpp index eee2cba0a962..61f95cd99c0e 100755 --- a/csrc/aio/py_lib/py_ds_aio.cpp +++ b/csrc/aio/py_lib/py_ds_aio.cpp @@ -1,41 +1,41 @@ -/* -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality for swapping optimizer tensors to/from (NVMe) storage devices. -*/ - -#include -#include "deepspeed_py_aio_handle.h" -#include "deepspeed_py_copy.h" - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("aio_read", &deepspeed_py_aio_read, "DeepSpeed Asynchornous I/O Read"); - - m.def("aio_write", &deepspeed_py_aio_write, "DeepSpeed Asynchornous I/O Write"); - - m.def("deepspeed_memcpy", &deepspeed_py_memcpy, "DeepSpeed Memory Copy"); - - py::class_(m, "aio_handle") - .def(py::init()) - - .def("get_block_size", &deepspeed_aio_handle_t::get_block_size) - .def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth) - .def("get_single_submit", &deepspeed_aio_handle_t::get_single_submit) - .def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events) - .def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count) - - .def("read", &deepspeed_aio_handle_t::read) - .def("write", &deepspeed_aio_handle_t::write) - - .def("pread", &deepspeed_aio_handle_t::pread) - .def("pwrite", &deepspeed_aio_handle_t::pwrite) - - .def("sync_pread", &deepspeed_aio_handle_t::sync_pread) - .def("sync_pwrite", &deepspeed_aio_handle_t::sync_pwrite) - .def("async_pread", &deepspeed_aio_handle_t::async_pread) - .def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite) - - .def("wait", &deepspeed_aio_handle_t::wait); -} +/* +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality for swapping optimizer tensors to/from (NVMe) storage devices. +*/ + +#include +#include "deepspeed_py_aio_handle.h" +#include "deepspeed_py_copy.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("aio_read", &deepspeed_py_aio_read, "DeepSpeed Asynchornous I/O Read"); + + m.def("aio_write", &deepspeed_py_aio_write, "DeepSpeed Asynchornous I/O Write"); + + m.def("deepspeed_memcpy", &deepspeed_py_memcpy, "DeepSpeed Memory Copy"); + + py::class_(m, "aio_handle") + .def(py::init()) + + .def("get_block_size", &deepspeed_aio_handle_t::get_block_size) + .def("get_queue_depth", &deepspeed_aio_handle_t::get_queue_depth) + .def("get_single_submit", &deepspeed_aio_handle_t::get_single_submit) + .def("get_overlap_events", &deepspeed_aio_handle_t::get_overlap_events) + .def("get_thread_count", &deepspeed_aio_handle_t::get_thread_count) + + .def("read", &deepspeed_aio_handle_t::read) + .def("write", &deepspeed_aio_handle_t::write) + + .def("pread", &deepspeed_aio_handle_t::pread) + .def("pwrite", &deepspeed_aio_handle_t::pwrite) + + .def("sync_pread", &deepspeed_aio_handle_t::sync_pread) + .def("sync_pwrite", &deepspeed_aio_handle_t::sync_pwrite) + .def("async_pread", &deepspeed_aio_handle_t::async_pread) + .def("async_pwrite", &deepspeed_aio_handle_t::async_pwrite) + + .def("wait", &deepspeed_aio_handle_t::wait); +} diff --git a/csrc/aio/py_test/ds_aio_basic.py b/csrc/aio/py_test/ds_aio_basic.py index e9f399d502a8..cf70b6655e9c 100755 --- a/csrc/aio/py_test/ds_aio_basic.py +++ b/csrc/aio/py_test/ds_aio_basic.py @@ -1,144 +1,144 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import torch -import os -import time -from deepspeed.ops.aio import AsyncIOBuilder -from multiprocessing import Pool, Barrier -from test_ds_aio_utils import report_results, task_log, task_barrier - - -def pre_basic(args, tid, read_op): - io_string = "Read" if read_op else "Write" - num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size - file = args.read_file if read_op else f'{args.write_file}.{tid}' - - task_log(tid, f'Allocate tensor of size {num_bytes} bytes') - buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() - task_log( - tid, - f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' - ) - - ctxt = {} - ctxt['file'] = file - ctxt['num_bytes'] = num_bytes - ctxt['buffer'] = buffer - ctxt['elapsed_sec'] = 0 - - return ctxt - - -def pre_basic_read(pool_params): - args, tid = pool_params - ctxt = pre_basic(args, tid, True) - return ctxt - - -def pre_basic_write(pool_params): - args, tid = pool_params - ctxt = pre_basic(args, tid, False) - return ctxt - - -def post_basic(pool_params): - _, _, ctxt = pool_params - ctxt["buffer"].detach() - ctxt["buffer"] = None - return ctxt - - -def main_basic_read(pool_params): - args, tid, ctxt = pool_params - start_time = time.time() - AsyncIOBuilder().load().aio_read(ctxt['buffer'], - ctxt['file'], - args.block_size, - args.queue_depth, - args.single_submit, - args.overlap_events, - args.validate) - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_basic_write(pool_params): - args, tid, ctxt = pool_params - start_time = time.time() - AsyncIOBuilder().load().aio_write(ctxt['buffer'], - ctxt['file'], - args.block_size, - args.queue_depth, - args.single_submit, - args.overlap_events, - args.validate) - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def get_schedule(args, read_op): - schedule = {} - if read_op: - schedule['pre'] = pre_basic_read - schedule['post'] = post_basic - schedule['main'] = main_basic_read - else: - schedule['pre'] = pre_basic_write - schedule['post'] = post_basic - schedule['main'] = main_basic_write - - return schedule - - -def _aio_handle_tasklet(pool_params): - args, tid, read_op = pool_params - - # Create schedule - schedule = get_schedule(args, read_op) - task_log(tid, f'schedule = {schedule}') - task_barrier(aio_barrier, args.threads) - - # Run pre task - task_log(tid, f'running pre-task') - ctxt = schedule["pre"]((args, tid)) - task_barrier(aio_barrier, args.threads) - - # Run main tasks in a loop - ctxt["main_task_sec"] = 0 - for i in range(args.loops): - task_log(tid, f'running main task {i}') - start_time = time.time() - ctxt = schedule["main"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - stop_time = time.time() - ctxt["main_task_sec"] += stop_time - start_time - - # Run post task - task_log(tid, f'running post-task') - ctxt = schedule["post"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - - return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops - - -def _init_tasklet(b): - global aio_barrier - aio_barrier = b - - -def aio_basic_multiprocessing(args, read_op): - b = Barrier(args.threads) - pool_params = [(args, p, read_op) for p in range(args.threads)] - with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: - pool_results = p.map(_aio_handle_tasklet, pool_params) - - report_results(args, read_op, pool_results) +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import torch +import os +import time +from deepspeed.ops.aio import AsyncIOBuilder +from multiprocessing import Pool, Barrier +from test_ds_aio_utils import report_results, task_log, task_barrier + + +def pre_basic(args, tid, read_op): + io_string = "Read" if read_op else "Write" + num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size + file = args.read_file if read_op else f'{args.write_file}.{tid}' + + task_log(tid, f'Allocate tensor of size {num_bytes} bytes') + buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() + task_log( + tid, + f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' + ) + + ctxt = {} + ctxt['file'] = file + ctxt['num_bytes'] = num_bytes + ctxt['buffer'] = buffer + ctxt['elapsed_sec'] = 0 + + return ctxt + + +def pre_basic_read(pool_params): + args, tid = pool_params + ctxt = pre_basic(args, tid, True) + return ctxt + + +def pre_basic_write(pool_params): + args, tid = pool_params + ctxt = pre_basic(args, tid, False) + return ctxt + + +def post_basic(pool_params): + _, _, ctxt = pool_params + ctxt["buffer"].detach() + ctxt["buffer"] = None + return ctxt + + +def main_basic_read(pool_params): + args, tid, ctxt = pool_params + start_time = time.time() + AsyncIOBuilder().load().aio_read(ctxt['buffer'], + ctxt['file'], + args.block_size, + args.queue_depth, + args.single_submit, + args.overlap_events, + args.validate) + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def main_basic_write(pool_params): + args, tid, ctxt = pool_params + start_time = time.time() + AsyncIOBuilder().load().aio_write(ctxt['buffer'], + ctxt['file'], + args.block_size, + args.queue_depth, + args.single_submit, + args.overlap_events, + args.validate) + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def get_schedule(args, read_op): + schedule = {} + if read_op: + schedule['pre'] = pre_basic_read + schedule['post'] = post_basic + schedule['main'] = main_basic_read + else: + schedule['pre'] = pre_basic_write + schedule['post'] = post_basic + schedule['main'] = main_basic_write + + return schedule + + +def _aio_handle_tasklet(pool_params): + args, tid, read_op = pool_params + + # Create schedule + schedule = get_schedule(args, read_op) + task_log(tid, f'schedule = {schedule}') + task_barrier(aio_barrier, args.threads) + + # Run pre task + task_log(tid, f'running pre-task') + ctxt = schedule["pre"]((args, tid)) + task_barrier(aio_barrier, args.threads) + + # Run main tasks in a loop + ctxt["main_task_sec"] = 0 + for i in range(args.loops): + task_log(tid, f'running main task {i}') + start_time = time.time() + ctxt = schedule["main"]((args, tid, ctxt)) + task_barrier(aio_barrier, args.threads) + stop_time = time.time() + ctxt["main_task_sec"] += stop_time - start_time + + # Run post task + task_log(tid, f'running post-task') + ctxt = schedule["post"]((args, tid, ctxt)) + task_barrier(aio_barrier, args.threads) + + return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops + + +def _init_tasklet(b): + global aio_barrier + aio_barrier = b + + +def aio_basic_multiprocessing(args, read_op): + b = Barrier(args.threads) + pool_params = [(args, p, read_op) for p in range(args.threads)] + with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: + pool_results = p.map(_aio_handle_tasklet, pool_params) + + report_results(args, read_op, pool_results) diff --git a/csrc/aio/py_test/ds_aio_handle.py b/csrc/aio/py_test/ds_aio_handle.py index 68abbe80261b..947ee2e6cb63 100755 --- a/csrc/aio/py_test/ds_aio_handle.py +++ b/csrc/aio/py_test/ds_aio_handle.py @@ -1,176 +1,176 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import torch -import os -import time -from multiprocessing import Pool, Barrier -from deepspeed.ops.aio import AsyncIOBuilder -from test_ds_aio_utils import report_results, task_log, task_barrier - - -def pre_handle(args, tid, read_op): - io_string = "Read" if read_op else "Write" - num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size - file = args.read_file if read_op else f'{args.write_file}.{tid}' - - task_log(tid, f'Allocate tensor of size {num_bytes} bytes') - if args.gpu: - buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cuda') - else: - buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() - task_log( - tid, - f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' - ) - - io_parallel = args.io_parallel if args.io_parallel else 1 - handle = AsyncIOBuilder().load().aio_handle(args.block_size, - args.queue_depth, - args.single_submit, - args.overlap_events, - io_parallel) - task_log(tid, f'created deepspeed aio handle') - - ctxt = {} - ctxt['file'] = file - ctxt['num_bytes'] = num_bytes - ctxt['handle'] = handle - ctxt['buffer'] = buffer - ctxt['elapsed_sec'] = 0 - - return ctxt - - -def pre_handle_read(pool_params): - args, tid = pool_params - ctxt = pre_handle(args, tid, True) - return ctxt - - -def pre_handle_write(pool_params): - args, tid = pool_params - ctxt = pre_handle(args, tid, False) - return ctxt - - -def post_handle(pool_params): - _, _, ctxt = pool_params - ctxt["buffer"].detach() - ctxt["buffer"] = None - return ctxt - - -def main_parallel_read(pool_params): - args, tid, ctxt = pool_params - handle = ctxt['handle'] - - start_time = time.time() - ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True) - assert ret != -1 - handle.wait() - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_parallel_write(pool_params): - args, tid, ctxt = pool_params - handle = ctxt['handle'] - start_time = time.time() - ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True) - assert ret != -1 - handle.wait() - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_handle_read(pool_parms): - args, tid, ctxt = pool_parms - handle = ctxt['handle'] - - start_time = time.time() - ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate) - assert ret != -1 - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def main_handle_write(pool_parms): - args, tid, ctxt = pool_parms - handle = ctxt['handle'] - start_time = time.time() - ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate) - assert ret != -1 - end_time = time.time() - ctxt['elapsed_sec'] += end_time - start_time - - return ctxt - - -def get_schedule(args, read_op): - schedule = {} - if read_op: - schedule['pre'] = pre_handle_read - schedule['post'] = post_handle - schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read - else: - schedule['pre'] = pre_handle_write - schedule['post'] = post_handle - schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write - - return schedule - - -def _aio_handle_tasklet(pool_params): - args, tid, read_op = pool_params - - # Create schedule - schedule = get_schedule(args, read_op) - task_log(tid, f'schedule = {schedule}') - task_barrier(aio_barrier, args.threads) - - # Run pre task - task_log(tid, f'running pre-task') - ctxt = schedule["pre"]((args, tid)) - task_barrier(aio_barrier, args.threads) - - # Run main tasks in a loop - ctxt["main_task_sec"] = 0 - for i in range(args.loops): - task_log(tid, f'running main task {i}') - start_time = time.time() - ctxt = schedule["main"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - stop_time = time.time() - ctxt["main_task_sec"] += stop_time - start_time - - # Run post task - task_log(tid, f'running post-task') - ctxt = schedule["post"]((args, tid, ctxt)) - task_barrier(aio_barrier, args.threads) - - return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops - - -def _init_tasklet(b): - global aio_barrier - aio_barrier = b - - -def aio_handle_multiprocessing(args, read_op): - b = Barrier(args.threads) - pool_params = [(args, p, read_op) for p in range(args.threads)] - with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: - pool_results = p.map(_aio_handle_tasklet, pool_params) - - report_results(args, read_op, pool_results) +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import torch +import os +import time +from multiprocessing import Pool, Barrier +from deepspeed.ops.aio import AsyncIOBuilder +from test_ds_aio_utils import report_results, task_log, task_barrier + + +def pre_handle(args, tid, read_op): + io_string = "Read" if read_op else "Write" + num_bytes = os.path.getsize(args.read_file) if read_op else args.write_size + file = args.read_file if read_op else f'{args.write_file}.{tid}' + + task_log(tid, f'Allocate tensor of size {num_bytes} bytes') + if args.gpu: + buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cuda') + else: + buffer = torch.empty(num_bytes, dtype=torch.uint8, device='cpu').pin_memory() + task_log( + tid, + f'{io_string} file {file} of size {num_bytes} bytes from buffer on device {buffer.device}' + ) + + io_parallel = args.io_parallel if args.io_parallel else 1 + handle = AsyncIOBuilder().load().aio_handle(args.block_size, + args.queue_depth, + args.single_submit, + args.overlap_events, + io_parallel) + task_log(tid, f'created deepspeed aio handle') + + ctxt = {} + ctxt['file'] = file + ctxt['num_bytes'] = num_bytes + ctxt['handle'] = handle + ctxt['buffer'] = buffer + ctxt['elapsed_sec'] = 0 + + return ctxt + + +def pre_handle_read(pool_params): + args, tid = pool_params + ctxt = pre_handle(args, tid, True) + return ctxt + + +def pre_handle_write(pool_params): + args, tid = pool_params + ctxt = pre_handle(args, tid, False) + return ctxt + + +def post_handle(pool_params): + _, _, ctxt = pool_params + ctxt["buffer"].detach() + ctxt["buffer"] = None + return ctxt + + +def main_parallel_read(pool_params): + args, tid, ctxt = pool_params + handle = ctxt['handle'] + + start_time = time.time() + ret = handle.pread(ctxt['buffer'], ctxt['file'], args.validate, True) + assert ret != -1 + handle.wait() + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def main_parallel_write(pool_params): + args, tid, ctxt = pool_params + handle = ctxt['handle'] + start_time = time.time() + ret = handle.pwrite(ctxt['buffer'], ctxt['file'], args.validate, True) + assert ret != -1 + handle.wait() + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def main_handle_read(pool_parms): + args, tid, ctxt = pool_parms + handle = ctxt['handle'] + + start_time = time.time() + ret = handle.read(ctxt['buffer'], ctxt['file'], args.validate) + assert ret != -1 + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def main_handle_write(pool_parms): + args, tid, ctxt = pool_parms + handle = ctxt['handle'] + start_time = time.time() + ret = handle.write(ctxt['buffer'], ctxt['file'], args.validate) + assert ret != -1 + end_time = time.time() + ctxt['elapsed_sec'] += end_time - start_time + + return ctxt + + +def get_schedule(args, read_op): + schedule = {} + if read_op: + schedule['pre'] = pre_handle_read + schedule['post'] = post_handle + schedule['main'] = main_parallel_read if args.io_parallel else main_handle_read + else: + schedule['pre'] = pre_handle_write + schedule['post'] = post_handle + schedule['main'] = main_parallel_write if args.io_parallel else main_handle_write + + return schedule + + +def _aio_handle_tasklet(pool_params): + args, tid, read_op = pool_params + + # Create schedule + schedule = get_schedule(args, read_op) + task_log(tid, f'schedule = {schedule}') + task_barrier(aio_barrier, args.threads) + + # Run pre task + task_log(tid, f'running pre-task') + ctxt = schedule["pre"]((args, tid)) + task_barrier(aio_barrier, args.threads) + + # Run main tasks in a loop + ctxt["main_task_sec"] = 0 + for i in range(args.loops): + task_log(tid, f'running main task {i}') + start_time = time.time() + ctxt = schedule["main"]((args, tid, ctxt)) + task_barrier(aio_barrier, args.threads) + stop_time = time.time() + ctxt["main_task_sec"] += stop_time - start_time + + # Run post task + task_log(tid, f'running post-task') + ctxt = schedule["post"]((args, tid, ctxt)) + task_barrier(aio_barrier, args.threads) + + return ctxt["main_task_sec"], ctxt["elapsed_sec"], ctxt["num_bytes"] * args.loops + + +def _init_tasklet(b): + global aio_barrier + aio_barrier = b + + +def aio_handle_multiprocessing(args, read_op): + b = Barrier(args.threads) + pool_params = [(args, p, read_op) for p in range(args.threads)] + with Pool(processes=args.threads, initializer=_init_tasklet, initargs=(b, )) as p: + pool_results = p.map(_aio_handle_tasklet, pool_params) + + report_results(args, read_op, pool_results) diff --git a/csrc/aio/py_test/parse_aio_stats.py b/csrc/aio/py_test/parse_aio_stats.py index 3e4600a4666d..1921973e4f73 100755 --- a/csrc/aio/py_test/parse_aio_stats.py +++ b/csrc/aio/py_test/parse_aio_stats.py @@ -1,154 +1,154 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import os -import argparse -import re - -READ_SPEED = 'read_speed' -WRITE_SPEED = 'write_speed' - -PERF_METRICS = [READ_SPEED, WRITE_SPEED] - -METRIC_SEARCH = {READ_SPEED: 'E2E Read Speed', WRITE_SPEED: 'E2E Write Speed'} - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - parser.add_argument('--log_dir', - type=str, - required=True, - help='Folder of statistics logs') - - parser.add_argument('--metric', - type=str, - required=True, - help='Performance metric to report: [read_speed|write_speed]') - - args = parser.parse_args() - print(f'args = {args}') - - return args - - -def extract_value(key, file): - INVALID_PREFIXES = ["ds"] - for p in INVALID_PREFIXES: - if key.startswith(p): - return key - try: - if key[0] in ['t', 'd', 'p']: - return int(key[1:]) - if key.startswith("bs"): - if key.endswith('K'): - v = key[2:].split('K') - return int(v[0]) * 1024 - elif key.endswith('M'): - v = key[2:].split('M') - return int(v[0]) * 1024 * 1024 - else: - return int(key[2:]) - except: - print(f"{file}: extract_value fails on {key}") - return None - - return key - - -def get_file_key(file): - f, _ = os.path.splitext(os.path.basename(file)) - fields = f.split('_') - values = [extract_value(k, file) for k in fields] - return tuple(values) - - -def get_thread_count(file): - f, _ = os.path.splitext(os.path.basename(file)) - fields = f.split('_') - for key in fields: - if key[0] == 't': - return int(key[1:]) - return 1 - - -""" -Extract performance metric from log file. -Sample file lines are: -Task Read Latency = 0.031647682189941406 sec -Task Read Speed = 12.342926020792527 GB/sec -E2E Read Latency = 0.031697988510131836 sec -E2E Read Speed = 12.323337169333062 GB/sec - -For the above sample, -metric = "read_speed" corresponds to "E2E Read Speed", and 12.32 will be returned -""" - - -def get_metric(file, metric): - thread_count = get_thread_count(file) - with open(file) as f: - for line in f.readlines(): - if line.startswith(METRIC_SEARCH[metric]): - if metric in [READ_SPEED, WRITE_SPEED]: - fields = line.split() - return float(fields[-2]) - else: - fields = line.split('=') - return float(fields[-1]) - - return None - - -def validate_args(args): - if not args.metric in PERF_METRICS: - print(f'{args.metric} is not a valid performance metrics') - return False - - if not os.path.isdir(args.log_dir): - print(f'{args.log_dir} folder is not existent') - return False - - return True - - -def get_results(log_files, metric): - results = {} - for f in log_files: - file_key = get_file_key(f) - value = get_metric(f, metric) - results[file_key] = value - - return results - - -def get_sorted_results(log_dir, metric): - log_files = [ - f for f in os.listdir(log_dir) if os.path.isfile(os.path.join(log_dir, - f)) - ] - - log_files_path = [os.path.join(log_dir, f) for f in log_files] - results = get_results(log_files_path, metric) - result_keys = list(results.keys()) - sorted_keys = sorted(result_keys) - return sorted_keys, results - - -def main(): - print("Parsing aio statistics") - args = parse_arguments() - - if not validate_args(args): - quit() - - sorted_keys, results = get_sorted_results(args.log_dir, args.metric) - for k in sorted_keys: - print(f'{k} = {results[k]}') - - -if __name__ == "__main__": - main() +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import os +import argparse +import re + +READ_SPEED = 'read_speed' +WRITE_SPEED = 'write_speed' + +PERF_METRICS = [READ_SPEED, WRITE_SPEED] + +METRIC_SEARCH = {READ_SPEED: 'E2E Read Speed', WRITE_SPEED: 'E2E Write Speed'} + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--log_dir', + type=str, + required=True, + help='Folder of statistics logs') + + parser.add_argument('--metric', + type=str, + required=True, + help='Performance metric to report: [read_speed|write_speed]') + + args = parser.parse_args() + print(f'args = {args}') + + return args + + +def extract_value(key, file): + INVALID_PREFIXES = ["ds"] + for p in INVALID_PREFIXES: + if key.startswith(p): + return key + try: + if key[0] in ['t', 'd', 'p']: + return int(key[1:]) + if key.startswith("bs"): + if key.endswith('K'): + v = key[2:].split('K') + return int(v[0]) * 1024 + elif key.endswith('M'): + v = key[2:].split('M') + return int(v[0]) * 1024 * 1024 + else: + return int(key[2:]) + except: + print(f"{file}: extract_value fails on {key}") + return None + + return key + + +def get_file_key(file): + f, _ = os.path.splitext(os.path.basename(file)) + fields = f.split('_') + values = [extract_value(k, file) for k in fields] + return tuple(values) + + +def get_thread_count(file): + f, _ = os.path.splitext(os.path.basename(file)) + fields = f.split('_') + for key in fields: + if key[0] == 't': + return int(key[1:]) + return 1 + + +""" +Extract performance metric from log file. +Sample file lines are: +Task Read Latency = 0.031647682189941406 sec +Task Read Speed = 12.342926020792527 GB/sec +E2E Read Latency = 0.031697988510131836 sec +E2E Read Speed = 12.323337169333062 GB/sec + +For the above sample, -metric = "read_speed" corresponds to "E2E Read Speed", and 12.32 will be returned +""" + + +def get_metric(file, metric): + thread_count = get_thread_count(file) + with open(file) as f: + for line in f.readlines(): + if line.startswith(METRIC_SEARCH[metric]): + if metric in [READ_SPEED, WRITE_SPEED]: + fields = line.split() + return float(fields[-2]) + else: + fields = line.split('=') + return float(fields[-1]) + + return None + + +def validate_args(args): + if not args.metric in PERF_METRICS: + print(f'{args.metric} is not a valid performance metrics') + return False + + if not os.path.isdir(args.log_dir): + print(f'{args.log_dir} folder is not existent') + return False + + return True + + +def get_results(log_files, metric): + results = {} + for f in log_files: + file_key = get_file_key(f) + value = get_metric(f, metric) + results[file_key] = value + + return results + + +def get_sorted_results(log_dir, metric): + log_files = [ + f for f in os.listdir(log_dir) if os.path.isfile(os.path.join(log_dir, + f)) + ] + + log_files_path = [os.path.join(log_dir, f) for f in log_files] + results = get_results(log_files_path, metric) + result_keys = list(results.keys()) + sorted_keys = sorted(result_keys) + return sorted_keys, results + + +def main(): + print("Parsing aio statistics") + args = parse_arguments() + + if not validate_args(args): + quit() + + sorted_keys, results = get_sorted_results(args.log_dir, args.metric) + for k in sorted_keys: + print(f'{k} = {results[k]}') + + +if __name__ == "__main__": + main() diff --git a/csrc/aio/py_test/test_ds_aio.py b/csrc/aio/py_test/test_ds_aio.py index db7c12d7f782..f97d3e676c03 100755 --- a/csrc/aio/py_test/test_ds_aio.py +++ b/csrc/aio/py_test/test_ds_aio.py @@ -1,101 +1,101 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import os -import torch -import argparse -import time -import sys -from multiprocessing import Pool -import multiprocessing as mp -from ds_aio_basic import aio_basic_multiprocessing -from ds_aio_handle import aio_handle_multiprocessing -from test_ds_aio_utils import refine_args - - -def parse_arguments(): - parser = argparse.ArgumentParser() - - parser.add_argument('--read_file', type=str, default=None, help='Read file.') - - parser.add_argument('--write_file', type=str, default=None, help='Write file.') - - parser.add_argument('--write_size', - type=str, - default=None, - help='Number of bytes to write.') - - parser.add_argument('--block_size', type=str, default='1M', help='I/O block size.') - - parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth.') - - parser.add_argument('--threads', - type=int, - default=1, - help='Thread parallelism count.') - - parser.add_argument( - '--single_submit', - action='store_true', - help= - 'Submit I/O requests in singles (default is submit queue_depth amount at once.).' - ) - - parser.add_argument('--overlap_events', - action='store_true', - help='Overlap I/O submission and completion requests.') - - parser.add_argument('--validate', - action='store_true', - help='Perform validation in library.') - - parser.add_argument('--handle', action='store_true', help='Use AIO handle.') - - parser.add_argument('--loops', - type=int, - default=1, - help='Count of operation repetitions') - - parser.add_argument('--io_parallel', - type=int, - default=None, - help='Per iop parallelism') - - parser.add_argument('--gpu', action='store_true', help='Use GPU memory') - - args = parser.parse_args() - print(f'args = {args}') - return args - - -def validate_args(args): - if args.read_file and not os.path.isfile(args.read_file): - print(f'args validation error: {args.read_file} not found') - return False - - return True - - -def main(): - print(f'Testing deepspeed_aio python frontend') - - args = parse_arguments() - refine_args(args) - if not validate_args(args): - quit() - - mp.set_start_method('spawn') - multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing - if args.read_file: - multiprocess_function(args, True) - - if args.write_file: - multiprocess_function(args, False) - - -if __name__ == "__main__": - main() +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import os +import torch +import argparse +import time +import sys +from multiprocessing import Pool +import multiprocessing as mp +from ds_aio_basic import aio_basic_multiprocessing +from ds_aio_handle import aio_handle_multiprocessing +from test_ds_aio_utils import refine_args + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument('--read_file', type=str, default=None, help='Read file.') + + parser.add_argument('--write_file', type=str, default=None, help='Write file.') + + parser.add_argument('--write_size', + type=str, + default=None, + help='Number of bytes to write.') + + parser.add_argument('--block_size', type=str, default='1M', help='I/O block size.') + + parser.add_argument('--queue_depth', type=int, default=32, help='I/O queue depth.') + + parser.add_argument('--threads', + type=int, + default=1, + help='Thread parallelism count.') + + parser.add_argument( + '--single_submit', + action='store_true', + help= + 'Submit I/O requests in singles (default is submit queue_depth amount at once.).' + ) + + parser.add_argument('--overlap_events', + action='store_true', + help='Overlap I/O submission and completion requests.') + + parser.add_argument('--validate', + action='store_true', + help='Perform validation in library.') + + parser.add_argument('--handle', action='store_true', help='Use AIO handle.') + + parser.add_argument('--loops', + type=int, + default=1, + help='Count of operation repetitions') + + parser.add_argument('--io_parallel', + type=int, + default=None, + help='Per iop parallelism') + + parser.add_argument('--gpu', action='store_true', help='Use GPU memory') + + args = parser.parse_args() + print(f'args = {args}') + return args + + +def validate_args(args): + if args.read_file and not os.path.isfile(args.read_file): + print(f'args validation error: {args.read_file} not found') + return False + + return True + + +def main(): + print(f'Testing deepspeed_aio python frontend') + + args = parse_arguments() + refine_args(args) + if not validate_args(args): + quit() + + mp.set_start_method('spawn') + multiprocess_function = aio_handle_multiprocessing if args.handle else aio_basic_multiprocessing + if args.read_file: + multiprocess_function(args, True) + + if args.write_file: + multiprocess_function(args, False) + + +if __name__ == "__main__": + main() diff --git a/csrc/aio/py_test/test_ds_aio_utils.py b/csrc/aio/py_test/test_ds_aio_utils.py index fa0f0f6be79a..c68dfdddc233 100755 --- a/csrc/aio/py_test/test_ds_aio_utils.py +++ b/csrc/aio/py_test/test_ds_aio_utils.py @@ -1,59 +1,59 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -Licensed under the MIT license. - -Functionality of swapping optimizer tensors to/from (NVMe) storage devices. -""" - -import os - -BYTES_PER_GB = 1024**3 -LOG_TIDS = [0] - - -def task_log(tid, msg): - if tid in LOG_TIDS: - print(f'tid {tid}: {msg}') - - -def task_barrier(barrier, num_parties): - assert barrier.parties == num_parties - barrier.wait() - assert barrier.broken == False - - -def report_results(args, read_op, pool_results): - #print(f'pool_results = {pool_results}') - io_string = 'Read' if read_op else 'Write' - if None in pool_results: - print(f'Failure in one of {args.threads} {io_string} processes') - return - - total_bytes = sum([num_bytes for _, _, num_bytes in pool_results]) - - task_latency_sec = max([sec for _, sec, _ in pool_results]) - task_speed_GB = total_bytes / task_latency_sec / BYTES_PER_GB - print(f'Task {io_string} Latency = {task_latency_sec} sec') - print(f'Task {io_string} Speed = {task_speed_GB} GB/sec') - - e2e_latency_sec = max([sec for sec, _, _ in pool_results]) - e2e_speed_GB = total_bytes / e2e_latency_sec / BYTES_PER_GB - print(f'E2E {io_string} Latency = {e2e_latency_sec} sec') - print(f'E2E {io_string} Speed = {e2e_speed_GB} GB/sec') - - -def refine_integer_value(value): - unit_dict = {'K': 1024, 'M': 1024**2, 'G': 1024**3} - - if value[-1] in list(unit_dict.keys()): - int_value = int(value[:-1]) * unit_dict[value[-1]] - return int_value - return int(value) - - -def refine_args(args): - if args.write_size and type(args.write_size) == str: - args.write_size = refine_integer_value(args.write_size) - - if args.block_size and type(args.block_size) == str: - args.block_size = refine_integer_value(args.block_size) +""" +Copyright 2020 The Microsoft DeepSpeed Team +Licensed under the MIT license. + +Functionality of swapping optimizer tensors to/from (NVMe) storage devices. +""" + +import os + +BYTES_PER_GB = 1024**3 +LOG_TIDS = [0] + + +def task_log(tid, msg): + if tid in LOG_TIDS: + print(f'tid {tid}: {msg}') + + +def task_barrier(barrier, num_parties): + assert barrier.parties == num_parties + barrier.wait() + assert barrier.broken == False + + +def report_results(args, read_op, pool_results): + #print(f'pool_results = {pool_results}') + io_string = 'Read' if read_op else 'Write' + if None in pool_results: + print(f'Failure in one of {args.threads} {io_string} processes') + return + + total_bytes = sum([num_bytes for _, _, num_bytes in pool_results]) + + task_latency_sec = max([sec for _, sec, _ in pool_results]) + task_speed_GB = total_bytes / task_latency_sec / BYTES_PER_GB + print(f'Task {io_string} Latency = {task_latency_sec} sec') + print(f'Task {io_string} Speed = {task_speed_GB} GB/sec') + + e2e_latency_sec = max([sec for sec, _, _ in pool_results]) + e2e_speed_GB = total_bytes / e2e_latency_sec / BYTES_PER_GB + print(f'E2E {io_string} Latency = {e2e_latency_sec} sec') + print(f'E2E {io_string} Speed = {e2e_speed_GB} GB/sec') + + +def refine_integer_value(value): + unit_dict = {'K': 1024, 'M': 1024**2, 'G': 1024**3} + + if value[-1] in list(unit_dict.keys()): + int_value = int(value[:-1]) * unit_dict[value[-1]] + return int_value + return int(value) + + +def refine_args(args): + if args.write_size and type(args.write_size) == str: + args.write_size = refine_integer_value(args.write_size) + + if args.block_size and type(args.block_size) == str: + args.block_size = refine_integer_value(args.block_size) diff --git a/csrc/common/custom_cuda_kernel.cu b/csrc/common/custom_cuda_kernel.cu index dee09aac55b2..f7a2b5d480df 100644 --- a/csrc/common/custom_cuda_kernel.cu +++ b/csrc/common/custom_cuda_kernel.cu @@ -1,39 +1,39 @@ -#include "custom_cuda_layers.h" - -__global__ void param_update_kernel(const float* input, __half* output, int size) -{ - int id = blockIdx.x * blockDim.x + threadIdx.x; - - if (id < size) { output[id] = (__half)input[id]; } -} - -void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream) -{ - int threads = 1024; - - dim3 grid_dim((size - 1) / threads + 1); - dim3 block_dim(threads); - - param_update_kernel<<>>(input, output, size); -} - -__global__ void param_update_kernel_half(const float* input, __half* output, int size) -{ - int id = blockIdx.x * blockDim.x + threadIdx.x; - __half2* output_cast = reinterpret_cast<__half2*>(output); - if (id < size) { - float input_f = input[id]; - __half2* input_h = reinterpret_cast<__half2*>(&input_f); - output_cast[id] = *input_h; - } -} - -void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream) -{ - int threads = 1024; - size /= 2; - dim3 grid_dim((size - 1) / threads + 1); - dim3 block_dim(threads); - - param_update_kernel_half<<>>(input, output, size); -} +#include "custom_cuda_layers.h" + +__global__ void param_update_kernel(const float* input, __half* output, int size) +{ + int id = blockIdx.x * blockDim.x + threadIdx.x; + + if (id < size) { output[id] = (__half)input[id]; } +} + +void launch_param_update(const float* input, __half* output, int size, cudaStream_t stream) +{ + int threads = 1024; + + dim3 grid_dim((size - 1) / threads + 1); + dim3 block_dim(threads); + + param_update_kernel<<>>(input, output, size); +} + +__global__ void param_update_kernel_half(const float* input, __half* output, int size) +{ + int id = blockIdx.x * blockDim.x + threadIdx.x; + __half2* output_cast = reinterpret_cast<__half2*>(output); + if (id < size) { + float input_f = input[id]; + __half2* input_h = reinterpret_cast<__half2*>(&input_f); + output_cast[id] = *input_h; + } +} + +void launch_param_update_half(const float* input, __half* output, int size, cudaStream_t stream) +{ + int threads = 1024; + size /= 2; + dim3 grid_dim((size - 1) / threads + 1); + dim3 block_dim(threads); + + param_update_kernel_half<<>>(input, output, size); +} diff --git a/csrc/includes/Timer.h b/csrc/includes/Timer.h index 7c20854a056a..efc7fff84abb 100644 --- a/csrc/includes/Timer.h +++ b/csrc/includes/Timer.h @@ -1,47 +1,47 @@ - -#ifndef __TIMER_H__ -#define __TIMER_H__ - -#include -#include -#include "cuda.h" - -class GPUTimer { - cudaEvent_t start, stop; - -public: - GPUTimer() - { - cudaEventCreate(&start); - cudaEventCreate(&stop); - } - ~GPUTimer() - { - cudaEventDestroy(start); - cudaEventDestroy(stop); - } - inline void Record() { cudaEventRecord(start); } - inline void Elapsed(float& time_elapsed) - { - cudaEventRecord(stop); - cudaEventSynchronize(stop); - cudaEventElapsedTime(&time_elapsed, start, stop); - } -}; - -class CPUTimer { - std::chrono::high_resolution_clock::time_point start; - -public: - CPUTimer() : start(std::chrono::high_resolution_clock::now()) {} - inline void Reset() { start = std::chrono::high_resolution_clock::now(); } - inline float Elapsed() - { - auto temp = start; - start = std::chrono::high_resolution_clock::now(); - return (float)(std::chrono::duration_cast(start - temp).count() / - 1e3); - } -}; - -#endif + +#ifndef __TIMER_H__ +#define __TIMER_H__ + +#include +#include +#include "cuda.h" + +class GPUTimer { + cudaEvent_t start, stop; + +public: + GPUTimer() + { + cudaEventCreate(&start); + cudaEventCreate(&stop); + } + ~GPUTimer() + { + cudaEventDestroy(start); + cudaEventDestroy(stop); + } + inline void Record() { cudaEventRecord(start); } + inline void Elapsed(float& time_elapsed) + { + cudaEventRecord(stop); + cudaEventSynchronize(stop); + cudaEventElapsedTime(&time_elapsed, start, stop); + } +}; + +class CPUTimer { + std::chrono::high_resolution_clock::time_point start; + +public: + CPUTimer() : start(std::chrono::high_resolution_clock::now()) {} + inline void Reset() { start = std::chrono::high_resolution_clock::now(); } + inline float Elapsed() + { + auto temp = start; + start = std::chrono::high_resolution_clock::now(); + return (float)(std::chrono::duration_cast(start - temp).count() / + 1e3); + } +}; + +#endif diff --git a/csrc/includes/cpu_adagrad.h b/csrc/includes/cpu_adagrad.h index c398246f1d43..6bfd09bd2a36 100644 --- a/csrc/includes/cpu_adagrad.h +++ b/csrc/includes/cpu_adagrad.h @@ -1,147 +1,147 @@ -#pragma once - -#include -#include -#include -#include -#include "cuda.h" -#include "custom_cuda_layers.h" -#include "simd.h" - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - __half* dev_param = nullptr, \ - bool half_precision = false); - -class Adagrad_Optimizer { -public: - Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0) - : _alpha(alpha), _eps(eps), _weight_decay(weight_decay), _buf_index(false) - { - cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _streams[0] = Context::Instance().GetCurrentStream(); - _streams[1] = Context::Instance().GetNewStream(); - } - ~Adagrad_Optimizer() - { - cudaFreeHost(_doubled_buffer[0]); - cudaFreeHost(_doubled_buffer[1]); - } -#if defined(__AVX512__) or defined(__AVX256__) - template - void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg_sq, - size_t param_size, - __half* dev_param = nullptr, - bool half_precision = false); -#endif - STEP(1) - STEP(4) - STEP(8) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); - } - inline void IncrementStep(size_t step) - { - _step++; - if (_step != step) { _step = step; } - } - inline void update_state(float lr, float epsilon, float weight_decay) - { - _alpha = lr; - _eps = epsilon; - _weight_decay = weight_decay; - } - -private: - float _alpha; - float _eps; - float _weight_decay; - - float _betta1_t; - float _betta2_t; - size_t _step; - - float* _doubled_buffer[2]; - bool _buf_index; - - cudaStream_t _streams[2]; -}; - -#if defined(__AVX512__) or defined(__AVX256__) -template -void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t new_rounded_size = 0; - AVX_Data eps_4; - eps_4.data = SIMD_SET(_eps); - - float step_size = -1 * _alpha; - AVX_Data step_size_4; - step_size_4.data = SIMD_SET(step_size); - - AVX_Data weight_decay4; - if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay); - new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); - for (size_t t = 0; t < new_rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; - size_t offset = copy_size + t; - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#pragma omp parallel for - for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { - AVX_Data grad_4[span]; - simd_load(grad_4, grads + i, half_precision); - - AVX_Data momentum_4[span]; - simd_load(momentum_4, grads + i, false); - - AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); - - AVX_Data param_4[span]; - simd_load(param_4, _params + i, half_precision); - - if (_weight_decay > 0) { simd_fma(grad_4, param_4, weight_decay4, grad_4); } - - simd_fma(variance_4, grad_4, grad_4, variance_4); - simd_sqrt(grad_4, variance_4); - simd_add(grad_4, grad_4, eps_4); - simd_div(grad_4, momentum_4, grad_4); - simd_fma(param_4, grad_4, step_size_4, param_4); - - simd_store(_params + i, param_4, half_precision); - if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); - } - simd_store(_exp_avg_sq + i, variance_4, false); - } - - if (dev_params) { - if (half_precision) - launch_param_update_half( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - else - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - - _buf_index = !_buf_index; - } - } - *rounded_size = new_rounded_size; -} -#endif +#pragma once + +#include +#include +#include +#include +#include "cuda.h" +#include "custom_cuda_layers.h" +#include "simd.h" + +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg_sq, \ + size_t _param_size, \ + __half* dev_param = nullptr, \ + bool half_precision = false); + +class Adagrad_Optimizer { +public: + Adagrad_Optimizer(float alpha = 1e-2, float eps = 1e-8, float weight_decay = 0) + : _alpha(alpha), _eps(eps), _weight_decay(weight_decay), _buf_index(false) + { + cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); + cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); + + _streams[0] = Context::Instance().GetCurrentStream(); + _streams[1] = Context::Instance().GetNewStream(); + } + ~Adagrad_Optimizer() + { + cudaFreeHost(_doubled_buffer[0]); + cudaFreeHost(_doubled_buffer[1]); + } +#if defined(__AVX512__) or defined(__AVX256__) + template + void Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg_sq, + size_t param_size, + __half* dev_param = nullptr, + bool half_precision = false); +#endif + STEP(1) + STEP(4) + STEP(8) + inline void SynchronizeStreams() + { + for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); + } + inline void IncrementStep(size_t step) + { + _step++; + if (_step != step) { _step = step; } + } + inline void update_state(float lr, float epsilon, float weight_decay) + { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + } + +private: + float _alpha; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float* _doubled_buffer[2]; + bool _buf_index; + + cudaStream_t _streams[2]; +}; + +#if defined(__AVX512__) or defined(__AVX256__) +template +void Adagrad_Optimizer::Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t new_rounded_size = 0; + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + AVX_Data weight_decay4; + if (_weight_decay > 0) weight_decay4.data = SIMD_SET(_weight_decay); + new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); + for (size_t t = 0; t < new_rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { + AVX_Data grad_4[span]; + simd_load(grad_4, grads + i, half_precision); + + AVX_Data momentum_4[span]; + simd_load(momentum_4, grads + i, false); + + AVX_Data variance_4[span]; + simd_load(variance_4, _exp_avg_sq + i, false); + + AVX_Data param_4[span]; + simd_load(param_4, _params + i, half_precision); + + if (_weight_decay > 0) { simd_fma(grad_4, param_4, weight_decay4, grad_4); } + + simd_fma(variance_4, grad_4, grad_4, variance_4); + simd_sqrt(grad_4, variance_4); + simd_add(grad_4, grad_4, eps_4); + simd_div(grad_4, momentum_4, grad_4); + simd_fma(param_4, grad_4, step_size_4, param_4); + + simd_store(_params + i, param_4, half_precision); + if (dev_params) { + simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); + } + simd_store(_exp_avg_sq + i, variance_4, false); + } + + if (dev_params) { + if (half_precision) + launch_param_update_half( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + else + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + + _buf_index = !_buf_index; + } + } + *rounded_size = new_rounded_size; +} +#endif diff --git a/csrc/includes/cpu_adam.h b/csrc/includes/cpu_adam.h index 88779ef5fa8a..9a4e80593f21 100644 --- a/csrc/includes/cpu_adam.h +++ b/csrc/includes/cpu_adam.h @@ -1,222 +1,222 @@ -#pragma once - -#include -#include -#include -#include -#include "cuda.h" -#include "custom_cuda_layers.h" -#include "simd.h" - -#define STEP(SPAN) \ - void Step_##SPAN(float* _params, \ - float* grads, \ - float* _exp_avg, \ - float* _exp_avg_sq, \ - size_t _param_size, \ - __half* dev_param = nullptr, \ - bool half_precision = false); - -class Adam_Optimizer { -public: - Adam_Optimizer(float alpha = 1e-3, - float betta1 = 0.9, - float betta2 = 0.999, - float eps = 1e-8, - float weight_decay = 0, - bool adamw_mode = true) - : _alpha(alpha), - _betta1(betta1), - _betta2(betta2), - _eps(eps), - _weight_decay(weight_decay), - _betta1_t(1.0), - _betta2_t(1.0), - _step(0), - _buf_index(false), - _adamw_mode(adamw_mode) - { - cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); - cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); - - _streams[0] = Context::Instance().GetCurrentStream(); - _streams[1] = Context::Instance().GetNewStream(); - } - ~Adam_Optimizer() - { - cudaFreeHost(_doubled_buffer[0]); - cudaFreeHost(_doubled_buffer[1]); - } -#if defined(__AVX512__) or defined(__AVX256__) - template - void Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t param_size, - __half* dev_param = nullptr, - bool half_precision = false); -#endif - STEP(1) - STEP(4) - STEP(8) - inline void SynchronizeStreams() - { - for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); - } - inline void IncrementStep(size_t step, float beta1, float beta2) - { - if (beta1 != _betta1 || beta2 != _betta2) { - _step = step; - _betta1 = beta1; - _betta2 = beta2; - _betta1_t = std::pow(_betta1, step); - _betta2_t = std::pow(_betta2, step); - } else { - _step++; - if (_step != step) { - _betta1_t = std::pow(_betta1, step); - _betta2_t = std::pow(_betta2, step); - _step = step; - } else { - _betta1_t *= _betta1; - _betta2_t *= _betta2; - } - } - } - inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction) - { - _alpha = lr; - _eps = epsilon; - _weight_decay = weight_decay; - - _bias_correction1 = 1.0f; - _bias_correction2 = 1.0f; - if (bias_correction == 1) { - _bias_correction1 = 1 - _betta1_t; - _bias_correction2 = 1 / sqrt(1 - _betta2_t); - } - } - -private: - float _alpha; - float _betta1; - float _betta2; - float _eps; - float _weight_decay; - - float _betta1_t; - float _betta2_t; - size_t _step; - - float _bias_correction1; - float _bias_correction2; - - float* _doubled_buffer[2]; - bool _buf_index; - bool _adamw_mode; - - cudaStream_t _streams[2]; -}; - -#if defined(__AVX512__) or defined(__AVX256__) -template -void Adam_Optimizer::Step_AVX(size_t* rounded_size, - float* _params, - float* grads, - float* _exp_avg, - float* _exp_avg_sq, - size_t _param_size, - __half* dev_params, - bool half_precision) -{ - size_t new_rounded_size = 0; - - AVX_Data betta1_4; - betta1_4.data = SIMD_SET(_betta1); - AVX_Data betta2_4; - betta2_4.data = SIMD_SET(_betta2); - - float betta1_minus1 = 1 - _betta1; - float betta2_minus1 = 1 - _betta2; - AVX_Data betta1_minus1_4; - betta1_minus1_4.data = SIMD_SET(betta1_minus1); - AVX_Data betta2_minus1_4; - betta2_minus1_4.data = SIMD_SET(betta2_minus1); - - AVX_Data bias2_sqrt; - bias2_sqrt.data = SIMD_SET(_bias_correction2); - - AVX_Data eps_4; - eps_4.data = SIMD_SET(_eps); - - float step_size = -1 * _alpha / _bias_correction1; - AVX_Data step_size_4; - step_size_4.data = SIMD_SET(step_size); - - float w_decay = -1 * _alpha * _weight_decay; - AVX_Data weight_decay4; - if (_weight_decay > 0) - weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); - new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); - for (size_t t = 0; t < new_rounded_size; t += TILE) { - size_t copy_size = TILE; - if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; - size_t offset = copy_size + t; - if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } -#pragma omp parallel for - for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { - AVX_Data grad_4[span]; - simd_load(grad_4, grads + i, half_precision); - - AVX_Data momentum_4[span]; - simd_load(momentum_4, _exp_avg + i, false); - - AVX_Data variance_4[span]; - simd_load(variance_4, _exp_avg_sq + i, false); - - AVX_Data param_4[span]; - simd_load(param_4, _params + i, half_precision); - - if (_weight_decay > 0 && !_adamw_mode) { - simd_fma(grad_4, param_4, weight_decay4, grad_4); - } - - simd_mul(momentum_4, momentum_4, betta1_4); - simd_fma(momentum_4, grad_4, betta1_minus1_4, momentum_4); - simd_mul(variance_4, variance_4, betta2_4); - simd_mul(grad_4, grad_4, grad_4); - simd_fma(variance_4, grad_4, betta2_minus1_4, variance_4); - simd_sqrt(grad_4, variance_4); - simd_fma(grad_4, grad_4, bias2_sqrt, eps_4); - simd_div(grad_4, momentum_4, grad_4); - - if (_weight_decay > 0 && _adamw_mode) { - simd_fma(param_4, param_4, weight_decay4, param_4); - } - - simd_fma(param_4, grad_4, step_size_4, param_4); - - simd_store(_params + i, param_4, half_precision); - if (dev_params) { - simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); - } - simd_store(_exp_avg + i, momentum_4, false); - simd_store(_exp_avg_sq + i, variance_4, false); - } - - if (dev_params) { - if (half_precision) - launch_param_update_half( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - else - launch_param_update( - _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); - - _buf_index = !_buf_index; - } - } - *rounded_size = new_rounded_size; -} -#endif +#pragma once + +#include +#include +#include +#include +#include "cuda.h" +#include "custom_cuda_layers.h" +#include "simd.h" + +#define STEP(SPAN) \ + void Step_##SPAN(float* _params, \ + float* grads, \ + float* _exp_avg, \ + float* _exp_avg_sq, \ + size_t _param_size, \ + __half* dev_param = nullptr, \ + bool half_precision = false); + +class Adam_Optimizer { +public: + Adam_Optimizer(float alpha = 1e-3, + float betta1 = 0.9, + float betta2 = 0.999, + float eps = 1e-8, + float weight_decay = 0, + bool adamw_mode = true) + : _alpha(alpha), + _betta1(betta1), + _betta2(betta2), + _eps(eps), + _weight_decay(weight_decay), + _betta1_t(1.0), + _betta2_t(1.0), + _step(0), + _buf_index(false), + _adamw_mode(adamw_mode) + { + cudaMallocHost((void**)_doubled_buffer, TILE * sizeof(float)); + cudaMallocHost((void**)(_doubled_buffer + 1), TILE * sizeof(float)); + + _streams[0] = Context::Instance().GetCurrentStream(); + _streams[1] = Context::Instance().GetNewStream(); + } + ~Adam_Optimizer() + { + cudaFreeHost(_doubled_buffer[0]); + cudaFreeHost(_doubled_buffer[1]); + } +#if defined(__AVX512__) or defined(__AVX256__) + template + void Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t param_size, + __half* dev_param = nullptr, + bool half_precision = false); +#endif + STEP(1) + STEP(4) + STEP(8) + inline void SynchronizeStreams() + { + for (int i = 0; i < 2; i++) cudaStreamSynchronize(_streams[i]); + } + inline void IncrementStep(size_t step, float beta1, float beta2) + { + if (beta1 != _betta1 || beta2 != _betta2) { + _step = step; + _betta1 = beta1; + _betta2 = beta2; + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + } else { + _step++; + if (_step != step) { + _betta1_t = std::pow(_betta1, step); + _betta2_t = std::pow(_betta2, step); + _step = step; + } else { + _betta1_t *= _betta1; + _betta2_t *= _betta2; + } + } + } + inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction) + { + _alpha = lr; + _eps = epsilon; + _weight_decay = weight_decay; + + _bias_correction1 = 1.0f; + _bias_correction2 = 1.0f; + if (bias_correction == 1) { + _bias_correction1 = 1 - _betta1_t; + _bias_correction2 = 1 / sqrt(1 - _betta2_t); + } + } + +private: + float _alpha; + float _betta1; + float _betta2; + float _eps; + float _weight_decay; + + float _betta1_t; + float _betta2_t; + size_t _step; + + float _bias_correction1; + float _bias_correction2; + + float* _doubled_buffer[2]; + bool _buf_index; + bool _adamw_mode; + + cudaStream_t _streams[2]; +}; + +#if defined(__AVX512__) or defined(__AVX256__) +template +void Adam_Optimizer::Step_AVX(size_t* rounded_size, + float* _params, + float* grads, + float* _exp_avg, + float* _exp_avg_sq, + size_t _param_size, + __half* dev_params, + bool half_precision) +{ + size_t new_rounded_size = 0; + + AVX_Data betta1_4; + betta1_4.data = SIMD_SET(_betta1); + AVX_Data betta2_4; + betta2_4.data = SIMD_SET(_betta2); + + float betta1_minus1 = 1 - _betta1; + float betta2_minus1 = 1 - _betta2; + AVX_Data betta1_minus1_4; + betta1_minus1_4.data = SIMD_SET(betta1_minus1); + AVX_Data betta2_minus1_4; + betta2_minus1_4.data = SIMD_SET(betta2_minus1); + + AVX_Data bias2_sqrt; + bias2_sqrt.data = SIMD_SET(_bias_correction2); + + AVX_Data eps_4; + eps_4.data = SIMD_SET(_eps); + + float step_size = -1 * _alpha / _bias_correction1; + AVX_Data step_size_4; + step_size_4.data = SIMD_SET(step_size); + + float w_decay = -1 * _alpha * _weight_decay; + AVX_Data weight_decay4; + if (_weight_decay > 0) + weight_decay4.data = (_adamw_mode ? SIMD_SET(w_decay) : SIMD_SET(_weight_decay)); + new_rounded_size = ROUND_DOWN(_param_size, SIMD_WIDTH * span); + for (size_t t = 0; t < new_rounded_size; t += TILE) { + size_t copy_size = TILE; + if ((t + TILE) > new_rounded_size) copy_size = new_rounded_size - t; + size_t offset = copy_size + t; + if ((t / TILE) >= 2) { cudaStreamSynchronize(_streams[_buf_index]); } +#pragma omp parallel for + for (size_t i = t; i < offset; i += SIMD_WIDTH * span) { + AVX_Data grad_4[span]; + simd_load(grad_4, grads + i, half_precision); + + AVX_Data momentum_4[span]; + simd_load(momentum_4, _exp_avg + i, false); + + AVX_Data variance_4[span]; + simd_load(variance_4, _exp_avg_sq + i, false); + + AVX_Data param_4[span]; + simd_load(param_4, _params + i, half_precision); + + if (_weight_decay > 0 && !_adamw_mode) { + simd_fma(grad_4, param_4, weight_decay4, grad_4); + } + + simd_mul(momentum_4, momentum_4, betta1_4); + simd_fma(momentum_4, grad_4, betta1_minus1_4, momentum_4); + simd_mul(variance_4, variance_4, betta2_4); + simd_mul(grad_4, grad_4, grad_4); + simd_fma(variance_4, grad_4, betta2_minus1_4, variance_4); + simd_sqrt(grad_4, variance_4); + simd_fma(grad_4, grad_4, bias2_sqrt, eps_4); + simd_div(grad_4, momentum_4, grad_4); + + if (_weight_decay > 0 && _adamw_mode) { + simd_fma(param_4, param_4, weight_decay4, param_4); + } + + simd_fma(param_4, grad_4, step_size_4, param_4); + + simd_store(_params + i, param_4, half_precision); + if (dev_params) { + simd_store(_doubled_buffer[_buf_index] + (i - t), param_4, half_precision); + } + simd_store(_exp_avg + i, momentum_4, false); + simd_store(_exp_avg_sq + i, variance_4, false); + } + + if (dev_params) { + if (half_precision) + launch_param_update_half( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + else + launch_param_update( + _doubled_buffer[_buf_index], dev_params + t, copy_size, _streams[_buf_index]); + + _buf_index = !_buf_index; + } + } + *rounded_size = new_rounded_size; +} +#endif diff --git a/csrc/includes/dropout.h b/csrc/includes/dropout.h index f6e32af5608d..a72572d0876e 100644 --- a/csrc/includes/dropout.h +++ b/csrc/includes/dropout.h @@ -1,76 +1,76 @@ -#pragma once - -#include -#include -#include - -template -class Dropout { -public: - struct Config { - float ratio; - uint32_t dim; - bool training; - - Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {} - - float RATIO() const { return training ? ratio : 0.0; } - inline void SetDim(uint32_t d) { dim = d; } - }; - - Dropout(const Config& config) : _config(config), _mask(nullptr) {} - - virtual ~Dropout() {} - - void Forward(int bsz, T* out, const T* vals, cudaStream_t stream, bool bwd = false) - { - launch_dropout( - out, vals, _mask, bsz * _config.dim, _config.dim, _config.RATIO(), stream, bwd); - } - - void ForwardWithBias(int bsz, T* vals, const T* bias, cudaStream_t stream) - { - launch_dropout(vals, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); - } - - void ForwardWithBias(int bsz, - T* out, - const T* vals, - const T* residual, - const T* bias, - cudaStream_t stream) - { - launch_dropout( - out, vals, residual, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); - } - - void Backward(int bsz, T* d_vals, cudaStream_t stream) - { - launch_dropout_grad(d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); - } - - void Backward(int bsz, T* d_vals_out, const T* d_vals, cudaStream_t stream) - { - launch_dropout_grad( - d_vals_out, d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); - } - - bool HasDropout() const { return _config.RATIO() > 0.0; } - - void SetTrainingMode(bool training) { _config.training = training; } - - void SetMask(uint8_t* mask) - { - if (!mask) { throw std::runtime_error("Dropout mask is null."); } - - _mask = mask; - } - - Config GetConfig() const { return _config; } - - inline void SetDimension(uint32_t dim) { _config.SetDim(dim); } - -private: - uint8_t* _mask; - Config _config; -}; +#pragma once + +#include +#include +#include + +template +class Dropout { +public: + struct Config { + float ratio; + uint32_t dim; + bool training; + + Config(float r, uint32_t d) : ratio(r), dim(d), training(true) {} + + float RATIO() const { return training ? ratio : 0.0; } + inline void SetDim(uint32_t d) { dim = d; } + }; + + Dropout(const Config& config) : _config(config), _mask(nullptr) {} + + virtual ~Dropout() {} + + void Forward(int bsz, T* out, const T* vals, cudaStream_t stream, bool bwd = false) + { + launch_dropout( + out, vals, _mask, bsz * _config.dim, _config.dim, _config.RATIO(), stream, bwd); + } + + void ForwardWithBias(int bsz, T* vals, const T* bias, cudaStream_t stream) + { + launch_dropout(vals, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); + } + + void ForwardWithBias(int bsz, + T* out, + const T* vals, + const T* residual, + const T* bias, + cudaStream_t stream) + { + launch_dropout( + out, vals, residual, bias, _mask, bsz, _config.dim, _config.RATIO(), stream); + } + + void Backward(int bsz, T* d_vals, cudaStream_t stream) + { + launch_dropout_grad(d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); + } + + void Backward(int bsz, T* d_vals_out, const T* d_vals, cudaStream_t stream) + { + launch_dropout_grad( + d_vals_out, d_vals, _mask, bsz * _config.dim, _config.RATIO(), stream); + } + + bool HasDropout() const { return _config.RATIO() > 0.0; } + + void SetTrainingMode(bool training) { _config.training = training; } + + void SetMask(uint8_t* mask) + { + if (!mask) { throw std::runtime_error("Dropout mask is null."); } + + _mask = mask; + } + + Config GetConfig() const { return _config; } + + inline void SetDimension(uint32_t dim) { _config.SetDim(dim); } + +private: + uint8_t* _mask; + Config _config; +}; diff --git a/csrc/includes/feed_forward.h b/csrc/includes/feed_forward.h index 7b7379d9b998..fc4d5f90a203 100644 --- a/csrc/includes/feed_forward.h +++ b/csrc/includes/feed_forward.h @@ -1,93 +1,93 @@ -#ifndef __FEEDFORWARD_H__ -#define __FEEDFORWARD_H__ - -#include -#include -#include -#include "custom_cuda_layers.h" - -template -class FeedForward { -public: - struct Config { - int batchSize, outputSize; - int inputSize; - std::array gemm_algos; - Config(int batch, int outputs, int inputs, const std::array& algos) - : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos) - { - } - }; - - FeedForward(Config config) : config_(config) {} - - ~FeedForward() {} - - void Forward(int bsz, - const T* input_ptr, - const T* weights, - T* out, - cublasHandle_t& _cublasHandle) - { - float alpha = T(1.); - float beta = T(0.); - - cublas_gemm_ex(_cublasHandle, - CUBLAS_OP_T, - CUBLAS_OP_N, - config_.outputSize, - bsz, - config_.inputSize, - &alpha, - &beta, - weights, - input_ptr, - out, - cublasGemmAlgo_t(config_.gemm_algos[0])); - } - void Backward(int bsz, - const T* out_grad, - const T* input_ptr, - const T* weights, - T* weights_grad, - T* bias_grad, - cublasHandle_t& _cublasHandle, - cudaStream_t& stream, - T* inp_grad_out = nullptr, - T* out_grad_trans_out = nullptr) - { - float alpha = (T)1.0, beta = (T)0.0; - cublas_gemm_ex(_cublasHandle, - CUBLAS_OP_N, - CUBLAS_OP_T, - config_.inputSize, - config_.outputSize, - bsz, - &alpha, - &beta, - input_ptr, - out_grad, - weights_grad, - cublasGemmAlgo_t(config_.gemm_algos[1])); - - cublas_gemm_ex(_cublasHandle, - CUBLAS_OP_N, - CUBLAS_OP_N, - config_.inputSize, - bsz, - config_.outputSize, - &alpha, - &beta, - weights, - out_grad, - inp_grad_out, - cublasGemmAlgo_t(config_.gemm_algos[2])); - - launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, config_.outputSize, stream); - } - -private: - Config config_; -}; - -#endif +#ifndef __FEEDFORWARD_H__ +#define __FEEDFORWARD_H__ + +#include +#include +#include +#include "custom_cuda_layers.h" + +template +class FeedForward { +public: + struct Config { + int batchSize, outputSize; + int inputSize; + std::array gemm_algos; + Config(int batch, int outputs, int inputs, const std::array& algos) + : batchSize(batch), outputSize(outputs), inputSize(inputs), gemm_algos(algos) + { + } + }; + + FeedForward(Config config) : config_(config) {} + + ~FeedForward() {} + + void Forward(int bsz, + const T* input_ptr, + const T* weights, + T* out, + cublasHandle_t& _cublasHandle) + { + float alpha = T(1.); + float beta = T(0.); + + cublas_gemm_ex(_cublasHandle, + CUBLAS_OP_T, + CUBLAS_OP_N, + config_.outputSize, + bsz, + config_.inputSize, + &alpha, + &beta, + weights, + input_ptr, + out, + cublasGemmAlgo_t(config_.gemm_algos[0])); + } + void Backward(int bsz, + const T* out_grad, + const T* input_ptr, + const T* weights, + T* weights_grad, + T* bias_grad, + cublasHandle_t& _cublasHandle, + cudaStream_t& stream, + T* inp_grad_out = nullptr, + T* out_grad_trans_out = nullptr) + { + float alpha = (T)1.0, beta = (T)0.0; + cublas_gemm_ex(_cublasHandle, + CUBLAS_OP_N, + CUBLAS_OP_T, + config_.inputSize, + config_.outputSize, + bsz, + &alpha, + &beta, + input_ptr, + out_grad, + weights_grad, + cublasGemmAlgo_t(config_.gemm_algos[1])); + + cublas_gemm_ex(_cublasHandle, + CUBLAS_OP_N, + CUBLAS_OP_N, + config_.inputSize, + bsz, + config_.outputSize, + &alpha, + &beta, + weights, + out_grad, + inp_grad_out, + cublasGemmAlgo_t(config_.gemm_algos[2])); + + launch_fuse_transpose_bias_kernel(out_grad, bias_grad, bsz, config_.outputSize, stream); + } + +private: + Config config_; +}; + +#endif diff --git a/csrc/includes/gelu.h b/csrc/includes/gelu.h index 41cf6f2a68a7..560f4140ed61 100644 --- a/csrc/includes/gelu.h +++ b/csrc/includes/gelu.h @@ -1,36 +1,36 @@ -#pragma once - -#include -#include -#include -#include "custom_cuda_layers.h" - -template -class Gelu { -public: - struct Config { - uint32_t intermediate_size; - Config(uint32_t inter_size) : intermediate_size(inter_size) {} - }; - - Gelu(const Config& config) : _config(config) {} - - virtual ~Gelu() {} - - void ForwardWithBiasAdd(int bsz, - const T* input_buf, - const T* bias, - T* output, - cudaStream_t stream) - { - launch_bias_gelu(input_buf, bias, output, _config.intermediate_size, bsz, stream); - } - - void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream) - { - launch_d_gelu(d_output, input_buf, bias, _config.intermediate_size, bsz, stream); - } - -private: - Config _config; -}; +#pragma once + +#include +#include +#include +#include "custom_cuda_layers.h" + +template +class Gelu { +public: + struct Config { + uint32_t intermediate_size; + Config(uint32_t inter_size) : intermediate_size(inter_size) {} + }; + + Gelu(const Config& config) : _config(config) {} + + virtual ~Gelu() {} + + void ForwardWithBiasAdd(int bsz, + const T* input_buf, + const T* bias, + T* output, + cudaStream_t stream) + { + launch_bias_gelu(input_buf, bias, output, _config.intermediate_size, bsz, stream); + } + + void Backward(int bsz, T* d_output, const T* input_buf, const T* bias, cudaStream_t stream) + { + launch_d_gelu(d_output, input_buf, bias, _config.intermediate_size, bsz, stream); + } + +private: + Config _config; +}; diff --git a/csrc/includes/gemm_test.h b/csrc/includes/gemm_test.h index b920896b419e..3bfeee35dcd8 100644 --- a/csrc/includes/gemm_test.h +++ b/csrc/includes/gemm_test.h @@ -1,293 +1,293 @@ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include "StopWatch.h" -#include "cublas_wrappers.h" - -template -void check(T result, char const* const func, const char* const file, int const line) -{ - if (result) { - std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) + - " \n"); - } -} - -#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) - -template -class GemmTest { -public: - GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h) - : M(m), N(n), K(k), transa(ta), transb(tb), handle(h) - { - check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K)); - check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N)); - check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N)); - } - - ~GemmTest() - { - check_cuda_error(cudaFree(A)); - check_cuda_error(cudaFree(B)); - check_cuda_error(cudaFree(C)); - } - - std::array TestAlgo(int loops) - { - float alpha = (T)1.0f; - float beta = (T)0.0f; - - int algo_fw = Run(loops, [=](int algo) { - cublas_gemm_ex(handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - N, - M, - K, - &alpha, - &beta, - B, - A, - C, - static_cast(algo)); - }); - - int algo_bw1 = Run(loops, [=](int algo) { - cublas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - K, - N, - M, - &alpha, - &beta, - A, - C, - B, - static_cast(algo)); - }); - - int algo_bw2 = Run(loops, [=](int algo) { - cublas_gemm_ex(handle, - CUBLAS_OP_N, - CUBLAS_OP_N, - K, - M, - N, - &alpha, - &beta, - B, - C, - A, - static_cast(algo)); - }); - - return std::array({algo_fw, algo_bw1, algo_bw2}); - } - - template - int Run(int loops, Func f) - { - float fast_latency = (std::numeric_limits::max)(); - int fast_algo = 0; - - for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; - algo++) { - int warm_up = 5; - for (int i = 0; i < warm_up; ++i) f(algo); - - cudaDeviceSynchronize(); - Stopwatch timer; - timer.Restart(); - - for (int i = 0; i < loops; ++i) f(algo); - - cudaDeviceSynchronize(); - timer.Stop(); - - float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; - - printf("algo-%d: %.3fms\n", algo, avg_latency); - - if (avg_latency < fast_latency) { - fast_latency = avg_latency; - fast_algo = algo; - } - } - - printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); - - return fast_algo; - } - -private: - int M, N, K; - cublasHandle_t handle; - cublasOperation_t transa, transb; - T *A, *B, *C; -}; - -template -class StridedGemmTest { -public: - StridedGemmTest(int b, - int m, - int n, - int k, - cublasOperation_t ta, - cublasOperation_t tb, - cublasHandle_t h) - : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h) - { - check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz)); - check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz)); - check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz)); - } - - ~StridedGemmTest() - { - check_cuda_error(cudaFree(A)); - check_cuda_error(cudaFree(B)); - check_cuda_error(cudaFree(C)); - } - - std::array TestAlgo(int loops) - { - float alpha = (T)1.0f; - float beta = (T)0.0f; - - int algo_fw = Run(loops, [=](int algo) { - int stride_a = M * K; - int stride_b = N * K; - int stride_c = M * N; - - cublas_strided_batched_gemm(handle, - M, - N, - K, - &alpha, - &beta, - A, - B, - C, - transa, - transb, - stride_a, - stride_b, - stride_c, - bsz, - static_cast(algo)); - }); - - int algo_bw1 = Run(loops, [=](int algo) { - int mb = (transa == CUBLAS_OP_T ? K : M); - int kb = (transa == CUBLAS_OP_T ? M : K); - - int stride_a = mb * N; - int stride_b = N * kb; - int stride_c = M * K; - - // B need to transpose. - cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - // Calculate d_A. - cublas_strided_batched_gemm(handle, - mb, - kb, - N, - &alpha, - &beta, - (transa == CUBLAS_OP_T ? B : C), - (transa == CUBLAS_OP_T ? C : B), - A, - CUBLAS_OP_N, - op_b, - stride_a, - stride_b, - stride_c, - bsz, - static_cast(algo)); - }); - - int algo_bw2 = Run(loops, [=](int algo) { - // A need to transpose. - cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - int stride_a = M * K; - int stride_b = M * N; - int stride_c = N * K; - - // Calculate d_B. - cublas_strided_batched_gemm(handle, - K, - N, - M, - &alpha, - &beta, - A, - C, - B, - op_a, - CUBLAS_OP_N, - stride_a, - stride_b, - stride_c, - bsz, - static_cast(algo)); - }); - - return std::array({algo_fw, algo_bw1, algo_bw2}); - } - - template - int Run(int loops, Func f) - { - float fast_latency = (std::numeric_limits::max)(); - int fast_algo = 0; - - for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; - algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; - algo++) { - int warm_up = 5; - for (int i = 0; i < warm_up; ++i) f(algo); - - cudaDeviceSynchronize(); - Stopwatch timer; - timer.Restart(); - - for (int i = 0; i < loops; ++i) f(algo); - - cudaDeviceSynchronize(); - timer.Stop(); - - float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; - - printf("algo-%d: %.3fms\n", algo, avg_latency); - - if (avg_latency < fast_latency) { - fast_latency = avg_latency; - fast_algo = algo; - } - } - - printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); - - return fast_algo; - } - -private: - int bsz, M, N, K; - cublasHandle_t handle; - cublasOperation_t transa, transb; - T *A, *B, *C; -}; + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "StopWatch.h" +#include "cublas_wrappers.h" + +template +void check(T result, char const* const func, const char* const file, int const line) +{ + if (result) { + std::cout << (std::string("CUDA runtime error: ") + +file + ":" + std::to_string(line) + + " \n"); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) + +template +class GemmTest { +public: + GemmTest(int m, int n, int k, cublasOperation_t ta, cublasOperation_t tb, cublasHandle_t h) + : M(m), N(n), K(k), transa(ta), transb(tb), handle(h) + { + check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K)); + check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N)); + check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N)); + } + + ~GemmTest() + { + check_cuda_error(cudaFree(A)); + check_cuda_error(cudaFree(B)); + check_cuda_error(cudaFree(C)); + } + + std::array TestAlgo(int loops) + { + float alpha = (T)1.0f; + float beta = (T)0.0f; + + int algo_fw = Run(loops, [=](int algo) { + cublas_gemm_ex(handle, + CUBLAS_OP_T, + CUBLAS_OP_N, + N, + M, + K, + &alpha, + &beta, + B, + A, + C, + static_cast(algo)); + }); + + int algo_bw1 = Run(loops, [=](int algo) { + cublas_gemm_ex(handle, + CUBLAS_OP_N, + CUBLAS_OP_T, + K, + N, + M, + &alpha, + &beta, + A, + C, + B, + static_cast(algo)); + }); + + int algo_bw2 = Run(loops, [=](int algo) { + cublas_gemm_ex(handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + K, + M, + N, + &alpha, + &beta, + B, + C, + A, + static_cast(algo)); + }); + + return std::array({algo_fw, algo_bw1, algo_bw2}); + } + + template + int Run(int loops, Func f) + { + float fast_latency = (std::numeric_limits::max)(); + int fast_algo = 0; + + for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + algo++) { + int warm_up = 5; + for (int i = 0; i < warm_up; ++i) f(algo); + + cudaDeviceSynchronize(); + Stopwatch timer; + timer.Restart(); + + for (int i = 0; i < loops; ++i) f(algo); + + cudaDeviceSynchronize(); + timer.Stop(); + + float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; + + printf("algo-%d: %.3fms\n", algo, avg_latency); + + if (avg_latency < fast_latency) { + fast_latency = avg_latency; + fast_algo = algo; + } + } + + printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); + + return fast_algo; + } + +private: + int M, N, K; + cublasHandle_t handle; + cublasOperation_t transa, transb; + T *A, *B, *C; +}; + +template +class StridedGemmTest { +public: + StridedGemmTest(int b, + int m, + int n, + int k, + cublasOperation_t ta, + cublasOperation_t tb, + cublasHandle_t h) + : bsz(b), M(m), N(n), K(k), transa(ta), transb(tb), handle(h) + { + check_cuda_error(cudaMalloc((void**)&A, sizeof(T) * M * K * bsz)); + check_cuda_error(cudaMalloc((void**)&B, sizeof(T) * K * N * bsz)); + check_cuda_error(cudaMalloc((void**)&C, sizeof(T) * M * N * bsz)); + } + + ~StridedGemmTest() + { + check_cuda_error(cudaFree(A)); + check_cuda_error(cudaFree(B)); + check_cuda_error(cudaFree(C)); + } + + std::array TestAlgo(int loops) + { + float alpha = (T)1.0f; + float beta = (T)0.0f; + + int algo_fw = Run(loops, [=](int algo) { + int stride_a = M * K; + int stride_b = N * K; + int stride_c = M * N; + + cublas_strided_batched_gemm(handle, + M, + N, + K, + &alpha, + &beta, + A, + B, + C, + transa, + transb, + stride_a, + stride_b, + stride_c, + bsz, + static_cast(algo)); + }); + + int algo_bw1 = Run(loops, [=](int algo) { + int mb = (transa == CUBLAS_OP_T ? K : M); + int kb = (transa == CUBLAS_OP_T ? M : K); + + int stride_a = mb * N; + int stride_b = N * kb; + int stride_c = M * K; + + // B need to transpose. + cublasOperation_t op_b = (transb == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + // Calculate d_A. + cublas_strided_batched_gemm(handle, + mb, + kb, + N, + &alpha, + &beta, + (transa == CUBLAS_OP_T ? B : C), + (transa == CUBLAS_OP_T ? C : B), + A, + CUBLAS_OP_N, + op_b, + stride_a, + stride_b, + stride_c, + bsz, + static_cast(algo)); + }); + + int algo_bw2 = Run(loops, [=](int algo) { + // A need to transpose. + cublasOperation_t op_a = (transa == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + int stride_a = M * K; + int stride_b = M * N; + int stride_c = N * K; + + // Calculate d_B. + cublas_strided_batched_gemm(handle, + K, + N, + M, + &alpha, + &beta, + A, + C, + B, + op_a, + CUBLAS_OP_N, + stride_a, + stride_b, + stride_c, + bsz, + static_cast(algo)); + }); + + return std::array({algo_fw, algo_bw1, algo_bw2}); + } + + template + int Run(int loops, Func f) + { + float fast_latency = (std::numeric_limits::max)(); + int fast_algo = 0; + + for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP; + algo++) { + int warm_up = 5; + for (int i = 0; i < warm_up; ++i) f(algo); + + cudaDeviceSynchronize(); + Stopwatch timer; + timer.Restart(); + + for (int i = 0; i < loops; ++i) f(algo); + + cudaDeviceSynchronize(); + timer.Stop(); + + float avg_latency = (float)timer.GetTimeInSeconds() * 1000 / loops; + + printf("algo-%d: %.3fms\n", algo, avg_latency); + + if (avg_latency < fast_latency) { + fast_latency = avg_latency; + fast_algo = algo; + } + } + + printf("fast_algo %d: %.3f ms\n", fast_algo, fast_latency); + + return fast_algo; + } + +private: + int bsz, M, N, K; + cublasHandle_t handle; + cublasOperation_t transa, transb; + T *A, *B, *C; +}; diff --git a/csrc/includes/general_kernels.h b/csrc/includes/general_kernels.h index 588cf2aaa048..90e15b770637 100644 --- a/csrc/includes/general_kernels.h +++ b/csrc/includes/general_kernels.h @@ -1,47 +1,47 @@ -#include -#include -#include -#include - -#include -#include - -#include "context.h" -#include "cublas_wrappers.h" - -#define THREADS 256 -#define TILE_DIM 32 - -#define minus_infinity -1 * std::numeric_limits::infinity() - -#define FINAL_MASK 0xffffffff - -template -void launch_fused_add2(T* out, - const T* inp1, - const T* inp2, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream); - -template -void launch_fused_add4(T* out, - const T* inp1, - const T* inp2, - const T* inp3, - const T* inp4, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream); - -template -void launch_fused_add3(T* out, - const T* inp1, - const T* inp2, - const T* inp3, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream); +#include +#include +#include +#include + +#include +#include + +#include "context.h" +#include "cublas_wrappers.h" + +#define THREADS 256 +#define TILE_DIM 32 + +#define minus_infinity -1 * std::numeric_limits::infinity() + +#define FINAL_MASK 0xffffffff + +template +void launch_fused_add2(T* out, + const T* inp1, + const T* inp2, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream); + +template +void launch_fused_add4(T* out, + const T* inp1, + const T* inp2, + const T* inp3, + const T* inp4, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream); + +template +void launch_fused_add3(T* out, + const T* inp1, + const T* inp2, + const T* inp3, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream); diff --git a/csrc/includes/normalize_layer.h b/csrc/includes/normalize_layer.h index e18e01a33893..b4d135ec495e 100644 --- a/csrc/includes/normalize_layer.h +++ b/csrc/includes/normalize_layer.h @@ -1,202 +1,202 @@ -#pragma once - -#include -#include -#include -#include -#include "custom_cuda_layers.h" - -using namespace std; - -template -class Normalize_Layer { -public: - struct Config { - uint32_t batchSize; - uint32_t seqLength; - uint32_t hiddenDim; - float epsilon; - bool training; - bool useMean; - Config(uint32_t batch, - uint32_t seq, - uint32_t h, - float epsilon = 1e-12, - bool training = true, - bool useMean = true) - : batchSize(batch), - seqLength(seq), - hiddenDim(h), - epsilon(epsilon), - training(training), - useMean(useMean) - { - } - }; - - Normalize_Layer(Config config) - : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr) - { - } - - ~Normalize_Layer() {} - - void ForwardCheckpoint(int bsz, // batch * seq - T* vals, - const T* residual, - const T* gamma, - const T* betta, - cudaStream_t& stream, - bool preLayerNorm = false) - { - launch_bias_residual_layer_norm(vals, - residual, - gamma, - betta, - config_.epsilon, - bsz, - config_.hiddenDim, - stream, - preLayerNorm, - config_.training, - vars, - means); - } - - void Forward(int bsz, - T* vals, - const T* residual, - const T* gamma, - const T* betta, - cudaStream_t& stream, - bool preLayerNorm = false) - { - launch_bias_residual_layer_norm(vals, - residual, - gamma, - betta, - config_.epsilon, - bsz, - config_.hiddenDim, - stream, - preLayerNorm, - config_.training, - vars); - } - - void Backward(int bsz, - const T* out_grad, - const T* gamma, - T* gamma_grad, - T* betta_grad, - cudaStream_t stream[2], - T* inp_grad_out, - const T* norm_in = nullptr) - { - launch_layerNorm_backward(out_grad, - norm_in, - vars, - means, - gamma, - gamma_grad, - betta_grad, - inp_grad_out, - bsz, - config_.hiddenDim, - stream); - } - - void Backward(int bsz, - const T* out_grad, - const T* gamma, - const T* betta, - T* gamma_grad, - T* betta_grad, - cudaStream_t stream[2], - T* inp_grad_out, - const T* norm_out) - { - launch_layerNorm_backward(out_grad, - norm_out, - vars, - gamma, - gamma_grad, - betta_grad, - inp_grad_out, - bsz, - config_.hiddenDim, - stream, - !config_.useMean, - betta); - } - - void BackwardFusedAdd(int bsz, - const T* out_grad1, - const T* out_grad2, - const T* gamma, - T* gamma_grad, - T* betta_grad, - cudaStream_t stream[2], - T* inp_grad_out, - const T* norm_in = nullptr) - { - launch_layerNorm_backward_fused_add(out_grad1, - out_grad2, - norm_in, - vars, - means, - gamma, - gamma_grad, - betta_grad, - inp_grad_out, - bsz, - config_.hiddenDim, - stream); - } - - void BackwardFusedAdd(int bsz, - const T* out_grad1, - const T* out_grad2, - const T* gamma, - const T* betta, - T* gamma_grad, - T* betta_grad, - cudaStream_t stream[2], - T* inp_grad_out, - const T* norm_out) - { - launch_layerNorm_backward_fused_add(out_grad1, - out_grad2, - norm_out, - vars, - gamma, - gamma_grad, - betta_grad, - inp_grad_out, - bsz, - config_.hiddenDim, - stream, - !config_.useMean, - betta); - } - - inline bool UseMean() const { return config_.useMean; } - - inline void SetVar(T* variance) - { - if (!variance) { throw std::runtime_error("Normalize variance is null."); } - vars = variance; - } - - inline void SetMean(T* mean) - { - if (!mean) { throw std::runtime_error("Normalize mean is null."); } - means = mean; - } - -private: - Config config_; - T* vars; - T* means; - T* vals_hat; -}; +#pragma once + +#include +#include +#include +#include +#include "custom_cuda_layers.h" + +using namespace std; + +template +class Normalize_Layer { +public: + struct Config { + uint32_t batchSize; + uint32_t seqLength; + uint32_t hiddenDim; + float epsilon; + bool training; + bool useMean; + Config(uint32_t batch, + uint32_t seq, + uint32_t h, + float epsilon = 1e-12, + bool training = true, + bool useMean = true) + : batchSize(batch), + seqLength(seq), + hiddenDim(h), + epsilon(epsilon), + training(training), + useMean(useMean) + { + } + }; + + Normalize_Layer(Config config) + : config_(config), vars(nullptr), means(nullptr), vals_hat(nullptr) + { + } + + ~Normalize_Layer() {} + + void ForwardCheckpoint(int bsz, // batch * seq + T* vals, + const T* residual, + const T* gamma, + const T* betta, + cudaStream_t& stream, + bool preLayerNorm = false) + { + launch_bias_residual_layer_norm(vals, + residual, + gamma, + betta, + config_.epsilon, + bsz, + config_.hiddenDim, + stream, + preLayerNorm, + config_.training, + vars, + means); + } + + void Forward(int bsz, + T* vals, + const T* residual, + const T* gamma, + const T* betta, + cudaStream_t& stream, + bool preLayerNorm = false) + { + launch_bias_residual_layer_norm(vals, + residual, + gamma, + betta, + config_.epsilon, + bsz, + config_.hiddenDim, + stream, + preLayerNorm, + config_.training, + vars); + } + + void Backward(int bsz, + const T* out_grad, + const T* gamma, + T* gamma_grad, + T* betta_grad, + cudaStream_t stream[2], + T* inp_grad_out, + const T* norm_in = nullptr) + { + launch_layerNorm_backward(out_grad, + norm_in, + vars, + means, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream); + } + + void Backward(int bsz, + const T* out_grad, + const T* gamma, + const T* betta, + T* gamma_grad, + T* betta_grad, + cudaStream_t stream[2], + T* inp_grad_out, + const T* norm_out) + { + launch_layerNorm_backward(out_grad, + norm_out, + vars, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream, + !config_.useMean, + betta); + } + + void BackwardFusedAdd(int bsz, + const T* out_grad1, + const T* out_grad2, + const T* gamma, + T* gamma_grad, + T* betta_grad, + cudaStream_t stream[2], + T* inp_grad_out, + const T* norm_in = nullptr) + { + launch_layerNorm_backward_fused_add(out_grad1, + out_grad2, + norm_in, + vars, + means, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream); + } + + void BackwardFusedAdd(int bsz, + const T* out_grad1, + const T* out_grad2, + const T* gamma, + const T* betta, + T* gamma_grad, + T* betta_grad, + cudaStream_t stream[2], + T* inp_grad_out, + const T* norm_out) + { + launch_layerNorm_backward_fused_add(out_grad1, + out_grad2, + norm_out, + vars, + gamma, + gamma_grad, + betta_grad, + inp_grad_out, + bsz, + config_.hiddenDim, + stream, + !config_.useMean, + betta); + } + + inline bool UseMean() const { return config_.useMean; } + + inline void SetVar(T* variance) + { + if (!variance) { throw std::runtime_error("Normalize variance is null."); } + vars = variance; + } + + inline void SetMean(T* mean) + { + if (!mean) { throw std::runtime_error("Normalize mean is null."); } + means = mean; + } + +private: + Config config_; + T* vars; + T* means; + T* vals_hat; +}; diff --git a/csrc/includes/simd.h b/csrc/includes/simd.h index 8c6533f3db6f..216d281a2672 100644 --- a/csrc/includes/simd.h +++ b/csrc/includes/simd.h @@ -1,137 +1,137 @@ -#pragma once - -#if (__x86_64__ || __i386__) -#include -#include -#endif - -#define TILE (128 * 1024 * 1024) -#if defined(__AVX512__) or defined(__AVX256__) - -#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) - -#if defined(__AVX512__) -#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm512_loadu_ps(x) -#define SIMD_SET(x) _mm512_set1_ps(x) -#define SIMD_ADD(x, y) _mm512_add_ps(x, y) -#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm512_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm512_div_ps(x, y) -#define SIMD_WIDTH 16 - -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)x)) : _mm512_loadu_ps(x)) -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm512_storeu_ps(x, d)) - -#define INTV __m256i -#elif defined(__AVX256__) -#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) -#define SIMD_LOAD(x) _mm256_loadu_ps(x) -#define SIMD_SET(x) _mm256_set1_ps(x) -#define SIMD_ADD(x, y) _mm256_add_ps(x, y) -#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) -#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) -#define SIMD_SQRT(x) _mm256_sqrt_ps(x) -#define SIMD_DIV(x, y) _mm256_div_ps(x, y) -#define SIMD_WIDTH 8 -#define SIMD_LOAD2(x, h) \ - ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x)) - -#define SIMD_STORE2(x, d, h) \ - ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ - : _mm256_storeu_ps(x, d)) - -#define INTV __m128i -#endif - -union AVX_Data { -#if defined(__AVX512__) - __m512 data; -#elif defined(__AVX256__) - __m256 data; -#endif - // float data_f[16]; -}; - -template -inline void simd_store(float* dst, AVX_Data* src, bool half_precision) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - SIMD_STORE2(dst + SIMD_WIDTH * i, src[i].data, half_precision); - } -} -template -inline void simd_load(AVX_Data* dst, float* src, bool half_precision) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_LOAD2(src + SIMD_WIDTH * i, half_precision); - } -} -template -inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data); - } -} -template -inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data); - } -} -template -inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { - dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data); - } -} -template -inline void simd_sqrt(AVX_Data* dst, AVX_Data* src) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); } -} -template -inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); } -} -template -inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); } -} -template -inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); } -} -template -inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); } -} -template -inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) -{ -#pragma omp parallel for - for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); } -} - -#endif +#pragma once + +#if (__x86_64__ || __i386__) +#include +#include +#endif + +#define TILE (128 * 1024 * 1024) +#if defined(__AVX512__) or defined(__AVX256__) + +#define ROUND_DOWN(size, step) ((size) & ~((step)-1)) + +#if defined(__AVX512__) +#define SIMD_STORE(a, d) _mm512_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm512_loadu_ps(x) +#define SIMD_SET(x) _mm512_set1_ps(x) +#define SIMD_ADD(x, y) _mm512_add_ps(x, y) +#define SIMD_MUL(x, y) _mm512_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm512_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm512_div_ps(x, y) +#define SIMD_WIDTH 16 + +#define SIMD_LOAD2(x, h) \ + ((h) ? _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)x)) : _mm512_loadu_ps(x)) +#define SIMD_STORE2(x, d, h) \ + ((h) ? _mm256_store_ps(x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ + : _mm512_storeu_ps(x, d)) + +#define INTV __m256i +#elif defined(__AVX256__) +#define SIMD_STORE(a, d) _mm256_storeu_ps(a, d) +#define SIMD_LOAD(x) _mm256_loadu_ps(x) +#define SIMD_SET(x) _mm256_set1_ps(x) +#define SIMD_ADD(x, y) _mm256_add_ps(x, y) +#define SIMD_MUL(x, y) _mm256_mul_ps(x, y) +#define SIMD_FMA(x, y, c) _mm256_fmadd_ps(x, y, c) +#define SIMD_SQRT(x) _mm256_sqrt_ps(x) +#define SIMD_DIV(x, y) _mm256_div_ps(x, y) +#define SIMD_WIDTH 8 +#define SIMD_LOAD2(x, h) \ + ((h) ? _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)x)) : _mm256_loadu_ps(x)) + +#define SIMD_STORE2(x, d, h) \ + ((h) ? _mm_store_ps(x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT))) \ + : _mm256_storeu_ps(x, d)) + +#define INTV __m128i +#endif + +union AVX_Data { +#if defined(__AVX512__) + __m512 data; +#elif defined(__AVX256__) + __m256 data; +#endif + // float data_f[16]; +}; + +template +inline void simd_store(float* dst, AVX_Data* src, bool half_precision) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + SIMD_STORE2(dst + SIMD_WIDTH * i, src[i].data, half_precision); + } +} +template +inline void simd_load(AVX_Data* dst, float* src, bool half_precision) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_LOAD2(src + SIMD_WIDTH * i, half_precision); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data* src_a) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a[i].data); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data src_m_r, AVX_Data src_a) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r.data, src_a.data); + } +} +template +inline void simd_fma(AVX_Data* dst, AVX_Data* src_m_l, AVX_Data* src_m_r, AVX_Data* src_a) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { + dst[i].data = SIMD_FMA(src_m_l[i].data, src_m_r[i].data, src_a[i].data); + } +} +template +inline void simd_sqrt(AVX_Data* dst, AVX_Data* src) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_SQRT(src[i].data); } +} +template +inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_add(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_ADD(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r.data); } +} +template +inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_MUL(src_a_l[i].data, src_a_r[i].data); } +} +template +inline void simd_div(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data* src_a_r) +{ +#pragma omp parallel for + for (size_t i = 0; i < span; ++i) { dst[i].data = SIMD_DIV(src_a_l[i].data, src_a_r[i].data); } +} + +#endif diff --git a/csrc/includes/softmax.h b/csrc/includes/softmax.h index ca3cf8a8ee6a..8d541a6fe0e6 100755 --- a/csrc/includes/softmax.h +++ b/csrc/includes/softmax.h @@ -1,60 +1,60 @@ -#pragma once - -#include -#include -#include -#include "custom_cuda_layers.h" - -#include - -using namespace std; - -template -class Softmax { -public: - struct Config { - size_t batchSize; - size_t heads; - size_t seq_length; - size_t prob_depth; - float temperature; - bool mem_alloc; - Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false) - : batchSize(batch), - heads(h), - seq_length(seq), - prob_depth(prob_size), - temperature(1.0), - mem_alloc(mem_alloc) - { - } - }; - - Softmax(Config config) : config_(config) {} - - ~Softmax() {} - - void Forward(int bsz, T* vals, const T* attn_mask, cudaStream_t& stream) - { - launch_attn_softmax(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream); - } - - void Backward(int bsz, T* out_grad, const T* soft_out, cudaStream_t stream) - { - launch_attn_softmax_backward_v2( - out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream); - } - - inline size_t GetProbDepth() const { return config_.prob_depth; } - - inline size_t GetBatchSize() const { return config_.batchSize; } - - inline size_t GetNumHeads() const { return config_.heads; } - - inline size_t GetSeqLength() const { return config_.seq_length; } - - inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; } - -private: - Config config_; -}; +#pragma once + +#include +#include +#include +#include "custom_cuda_layers.h" + +#include + +using namespace std; + +template +class Softmax { +public: + struct Config { + size_t batchSize; + size_t heads; + size_t seq_length; + size_t prob_depth; + float temperature; + bool mem_alloc; + Config(size_t batch, size_t h, size_t seq, int prob_size = 0, bool mem_alloc = false) + : batchSize(batch), + heads(h), + seq_length(seq), + prob_depth(prob_size), + temperature(1.0), + mem_alloc(mem_alloc) + { + } + }; + + Softmax(Config config) : config_(config) {} + + ~Softmax() {} + + void Forward(int bsz, T* vals, const T* attn_mask, cudaStream_t& stream) + { + launch_attn_softmax(vals, attn_mask, bsz, config_.heads, config_.seq_length, stream); + } + + void Backward(int bsz, T* out_grad, const T* soft_out, cudaStream_t stream) + { + launch_attn_softmax_backward_v2( + out_grad, soft_out, bsz, config_.heads, config_.seq_length, stream); + } + + inline size_t GetProbDepth() const { return config_.prob_depth; } + + inline size_t GetBatchSize() const { return config_.batchSize; } + + inline size_t GetNumHeads() const { return config_.heads; } + + inline size_t GetSeqLength() const { return config_.seq_length; } + + inline void SetSeqLength(size_t seq_len) { config_.seq_length = seq_len; } + +private: + Config config_; +}; diff --git a/csrc/includes/strided_batch_gemm.h b/csrc/includes/strided_batch_gemm.h index 44a1b313b986..3a9ad65bc8ee 100644 --- a/csrc/includes/strided_batch_gemm.h +++ b/csrc/includes/strided_batch_gemm.h @@ -1,179 +1,179 @@ -#pragma once - -#include -#include -#include -#include "context.h" - -template -class StridedBatchGemm { -public: - struct Config { - int batch_size; - int m; - int n; - int k; - float alpha; - float beta; - cublasOperation_t op_A; - cublasOperation_t op_B; - std::array gemm_algos; - - Config(int batch, - int mm, - int nn, - int kk, - float param_alpha, - float param_beta, - cublasOperation_t opA, - cublasOperation_t opB, - const std::array& algos) - : batch_size(batch), - m(mm), - n(nn), - k(kk), - alpha(param_alpha), - beta(param_beta), - op_A(opA), - op_B(opB), - gemm_algos(algos) - { - } - void SetConfig(int mm, int nn, int kk) - { - m = mm; - n = nn; - k = kk; - } - }; - - StridedBatchGemm(const Config& config) : _config(config) {} - - virtual ~StridedBatchGemm() {} - - void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle) - { - int stride_a = _config.m * _config.k; - int stride_b = _config.n * _config.k; - int stride_c = _config.m * _config.n; - - cublas_strided_batched_gemm(handle, - _config.m, - _config.n, - _config.k, - &_config.alpha, - &_config.beta, - _buffer_a, - _buffer_b, - output, - _config.op_A, - _config.op_B, - stride_a, - stride_b, - stride_c, - bsz, - cublasGemmAlgo_t(_config.gemm_algos[0])); - } - - void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle) - { - int stride_a = _config.m * _config.k; - int stride_b = _config.n * _config.k; - int stride_c = _config.m * _config.n; - - cublas_strided_batched_gemm(handle, - _config.m, - _config.n, - _config.k, - &_config.alpha, - &_config.beta, - _buffer_a, - _buffer_b, - output, - _config.op_A, - _config.op_B, - stride_a, - stride_b, - stride_c, - _config.batch_size, - cublasGemmAlgo_t(_config.gemm_algos[0])); - - k_buf = _buffer_a; - q_buf = _buffer_b; - } - - void Backward(int bsz, - const T* d_output, - const T* _buffer_a, - const T* _buffer_b, - cublasHandle_t handle, - T* inpGradA = nullptr, - T* inpGradB = nullptr) - { - int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); - int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); - - int stride_a = mb * _config.n; - int stride_b = _config.n * kb; - int stride_c = _config.m * _config.k; - - // B need to transpose. - cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - // Calculate d_A. - cublas_strided_batched_gemm(handle, - mb, - kb, - _config.n, - &_config.alpha, - &_config.beta, - (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), - (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), - inpGradA, - CUBLAS_OP_N, - op_b, - stride_a, - stride_b, - stride_c, - bsz, - cublasGemmAlgo_t(_config.gemm_algos[1])); - - // A need to transpose. - cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); - - stride_a = _config.m * _config.k; - stride_b = _config.m * _config.n; - stride_c = _config.n * _config.k; - - // Calculate d_B. - cublas_strided_batched_gemm(handle, - _config.k, - _config.n, - _config.m, - &_config.alpha, - &_config.beta, - _buffer_a, - d_output, - inpGradB, - op_a, - CUBLAS_OP_N, - stride_a, - stride_b, - stride_c, - bsz, - cublasGemmAlgo_t(_config.gemm_algos[2])); - } - - inline int GetN() const { return _config.k; } - - inline const T* GetBufferA() const { return k_buf; } - - inline const T* GetBufferB() const { return q_buf; } - - inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } - -private: - Config _config; - const T* q_buf; - const T* k_buf; -}; +#pragma once + +#include +#include +#include +#include "context.h" + +template +class StridedBatchGemm { +public: + struct Config { + int batch_size; + int m; + int n; + int k; + float alpha; + float beta; + cublasOperation_t op_A; + cublasOperation_t op_B; + std::array gemm_algos; + + Config(int batch, + int mm, + int nn, + int kk, + float param_alpha, + float param_beta, + cublasOperation_t opA, + cublasOperation_t opB, + const std::array& algos) + : batch_size(batch), + m(mm), + n(nn), + k(kk), + alpha(param_alpha), + beta(param_beta), + op_A(opA), + op_B(opB), + gemm_algos(algos) + { + } + void SetConfig(int mm, int nn, int kk) + { + m = mm; + n = nn; + k = kk; + } + }; + + StridedBatchGemm(const Config& config) : _config(config) {} + + virtual ~StridedBatchGemm() {} + + void Forward(int bsz, T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle) + { + int stride_a = _config.m * _config.k; + int stride_b = _config.n * _config.k; + int stride_c = _config.m * _config.n; + + cublas_strided_batched_gemm(handle, + _config.m, + _config.n, + _config.k, + &_config.alpha, + &_config.beta, + _buffer_a, + _buffer_b, + output, + _config.op_A, + _config.op_B, + stride_a, + stride_b, + stride_c, + bsz, + cublasGemmAlgo_t(_config.gemm_algos[0])); + } + + void ForwardPlusSave(T* output, const T* _buffer_a, const T* _buffer_b, cublasHandle_t handle) + { + int stride_a = _config.m * _config.k; + int stride_b = _config.n * _config.k; + int stride_c = _config.m * _config.n; + + cublas_strided_batched_gemm(handle, + _config.m, + _config.n, + _config.k, + &_config.alpha, + &_config.beta, + _buffer_a, + _buffer_b, + output, + _config.op_A, + _config.op_B, + stride_a, + stride_b, + stride_c, + _config.batch_size, + cublasGemmAlgo_t(_config.gemm_algos[0])); + + k_buf = _buffer_a; + q_buf = _buffer_b; + } + + void Backward(int bsz, + const T* d_output, + const T* _buffer_a, + const T* _buffer_b, + cublasHandle_t handle, + T* inpGradA = nullptr, + T* inpGradB = nullptr) + { + int mb = (_config.op_A == CUBLAS_OP_T ? _config.k : _config.m); + int kb = (_config.op_A == CUBLAS_OP_T ? _config.m : _config.k); + + int stride_a = mb * _config.n; + int stride_b = _config.n * kb; + int stride_c = _config.m * _config.k; + + // B need to transpose. + cublasOperation_t op_b = (_config.op_B == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + // Calculate d_A. + cublas_strided_batched_gemm(handle, + mb, + kb, + _config.n, + &_config.alpha, + &_config.beta, + (_config.op_A == CUBLAS_OP_T ? _buffer_b : d_output), + (_config.op_A == CUBLAS_OP_T ? d_output : _buffer_b), + inpGradA, + CUBLAS_OP_N, + op_b, + stride_a, + stride_b, + stride_c, + bsz, + cublasGemmAlgo_t(_config.gemm_algos[1])); + + // A need to transpose. + cublasOperation_t op_a = (_config.op_A == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T); + + stride_a = _config.m * _config.k; + stride_b = _config.m * _config.n; + stride_c = _config.n * _config.k; + + // Calculate d_B. + cublas_strided_batched_gemm(handle, + _config.k, + _config.n, + _config.m, + &_config.alpha, + &_config.beta, + _buffer_a, + d_output, + inpGradB, + op_a, + CUBLAS_OP_N, + stride_a, + stride_b, + stride_c, + bsz, + cublasGemmAlgo_t(_config.gemm_algos[2])); + } + + inline int GetN() const { return _config.k; } + + inline const T* GetBufferA() const { return k_buf; } + + inline const T* GetBufferB() const { return q_buf; } + + inline void SetConfig(int m, int n, int k) { _config.SetConfig(m, n, k); } + +private: + Config _config; + const T* q_buf; + const T* k_buf; +}; diff --git a/csrc/transformer/dropout_kernels.cu b/csrc/transformer/dropout_kernels.cu index 98f2ac22fddd..d1ba135f4900 100755 --- a/csrc/transformer/dropout_kernels.cu +++ b/csrc/transformer/dropout_kernels.cu @@ -1,868 +1,868 @@ -#include "custom_cuda_layers.h" - -const int unroll_factor = 4; - -__global__ void dropout_kernel(const int N, - const float ratio, - float* out, - const float* Xdata, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float4 rand = curand_uniform4(&state); - uint8_t m[unroll_factor]; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - int i = j * unroll_factor; - - mask[i] = (uint8_t)m[0]; - mask[i + 1] = (uint8_t)m[1]; - mask[i + 2] = (uint8_t)m[2]; - mask[i + 3] = (uint8_t)m[3]; - - out[i] = Xdata[i] * scale * m[0]; - out[i + 1] = Xdata[i + 1] * scale * m[1]; - out[i + 2] = Xdata[i + 2] * scale * m[2]; - out[i + 3] = Xdata[i + 3] * scale * m[3]; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - out[i] = Xdata[i] * scale * m; - mask[i] = m; - } - } -} - -__global__ void dropout_kernel(const int N, - const float ratio, - __half* out, - const __half* Xdata, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - - int idx = blockIdx.x * blockDim.x + threadIdx.x; - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - -#ifdef __STOCHASTIC_MODE__ - - const __half2 h_scale = __float2half2_rn(scale); - const float2* x_cast = reinterpret_cast(Xdata); - float2* out_cast = reinterpret_cast(out); - uint32_t* mask_cast = reinterpret_cast(mask); - - uint32_t m_32; - uint8_t* m = reinterpret_cast(&m_32); - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - __half2 mask_h[2]; - float2 mask_f[2]; - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float2 x_f = x_cast[j]; - __half2* x_h = reinterpret_cast<__half2*>(&x_f); - - float4 rand = curand_uniform4(&state); - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - float* mask_f_data = &mask_f[0].x; -#pragma unroll - for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); - - mask_h[0] = __float22half2_rn(mask_f[0]); - mask_h[1] = __float22half2_rn(mask_f[1]); - - result_h[0] = x_h[0] * h_scale * mask_h[0]; - result_h[1] = x_h[1] * h_scale * mask_h[1]; - - out_cast[j] = result_f; - - mask_cast[j] = m_32; - } - -#else - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - int i = j * unroll_factor; - - const __half2* vals_half = reinterpret_cast(Xdata + i); - float2 vals_half_f[2]; - vals_half_f[0] = __half22float2(vals_half[0]); - vals_half_f[1] = __half22float2(vals_half[1]); - - uint8_t m[unroll_factor]; - float4 rand = curand_uniform4(&state); - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - out[i] = __float2half(vals_half_f[0].x * scale * m[0]); - out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); - out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); - out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); - - mask[i] = m[0]; - mask[i + 1] = m[1]; - mask[i + 2] = m[2]; - mask[i + 3] = m[3]; - } - -#endif - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - out[i] = __float2half((float)Xdata[i] * scale * m); - mask[i] = m; - } - } -} - -__global__ void dropout_kernel_bwd(const int N, - const float ratio, - const float* Xdata, - float* out, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - int i = j * unroll_factor; - - out[i] = mask[i] ? Xdata[i] * scale : 0.0; - out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0; - out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0; - out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; } - } -} - -__global__ void dropout_kernel_bwd(const int N, - const float ratio, - const __half* Xdata, - __half* out, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - -#ifdef __STOCHASTIC_MODE__ - - const __half2 h_scale = __float2half2_rn(scale); - - const float2* x_cast = reinterpret_cast(Xdata); - float2* out_cast = reinterpret_cast(out); - uint32_t* mask_cast = reinterpret_cast(mask); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float2 x_f = x_cast[j]; - __half2* x_h = reinterpret_cast<__half2*>(&x_f); - - uint32_t m_32 = mask_cast[j]; - uint8_t* m = (uint8_t*)&m_32; - - __half2 mask_h[2]; - float2 mask_f[2]; - - float* mask_f_data = &mask_f[0].x; -#pragma unroll - for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); - -#pragma unroll - for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]); - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - result_h[0] = x_h[0] * h_scale * mask_h[0]; - result_h[1] = x_h[1] * h_scale * mask_h[1]; - - out_cast[j] = result_f; - } - -#else - - const __half h_scale = __float2half(scale); - const __half h_zero = __float2half(0.0); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - int i = j * unroll_factor; - - const __half2* vals_half = reinterpret_cast(Xdata + i); - - uint8_t* m = mask + i; - - float2 vals_half_f[2]; - - vals_half_f[0] = __half22float2(vals_half[0]); - vals_half_f[1] = __half22float2(vals_half[1]); - - out[i] = __float2half(vals_half_f[0].x * scale * m[0]); - out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); - out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); - out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); - } - -#endif - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - for (int i = high_index; i < N; i++) { - out[i] = __float2half((float)Xdata[i] * scale * mask[i]); - } - } -} - -template -void launch_dropout(T* out, - const T* vals, - uint8_t* mask, - int total_count, - int dim, - float ratio, - cudaStream_t stream, - bool bwd) -{ - assert(unroll_factor == 4); - - dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor); - dim3 block_dim = DS_CUDA_NUM_THREADS; - - if (dim > 512) { - block_dim.x >>= 1; - grid_dim.x <<= 1; - } - uint64_t inc = total_count / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); - if (bwd) - dropout_kernel_bwd<<>>( - total_count, ratio, vals, out, mask, seed); - else - dropout_kernel<<>>( - total_count, ratio, out, vals, mask, seed); -} - -template void launch_dropout(float* out, - const float* vals, - uint8_t* mask, - int total_count, - int dim, - float ratio, - cudaStream_t stream, - bool); -template void launch_dropout(__half* out, - const __half* vals, - uint8_t* mask, - int total_count, - int dim, - float ratio, - cudaStream_t stream, - bool); - -__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask) -{ - CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; } -} - -__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask) -{ - const __half2 h_scale = __float2half2_rn(scale); - float2* x_cast = reinterpret_cast(Xdata); - uint32_t* mask_cast = reinterpret_cast(mask); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float2 x_data = x_cast[j]; - uint32_t m_32 = mask_cast[j]; - uint8_t* m = (uint8_t*)&m_32; - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - -#ifdef __STOCHASTIC_MODE__ - - __half2* x_data_h = reinterpret_cast<__half2*>(&x_data); - __half2 mask_h[2]; - float2 mask_f[2]; - - float* mask_f_data = &mask_f[0].x; -#pragma unroll - for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]); - - mask_h[0] = __float22half2_rn(mask_f[0]); - mask_h[1] = __float22half2_rn(mask_f[1]); - - result_h[0] = x_data_h[0] * h_scale * mask_h[0]; - result_h[1] = x_data_h[1] * h_scale * mask_h[1]; - -#else - - __half* x_data_h = reinterpret_cast<__half*>(&x_data); - float2 result[2]; - - result[0].x = (float)x_data_h[0] * scale * m[0]; - result[0].y = (float)x_data_h[1] * scale * m[1]; - result[1].x = (float)x_data_h[2] * scale * m[2]; - result[1].y = (float)x_data_h[3] * scale * m[3]; - - result_h[0] = __float22half2_rn(result[0]); - result_h[1] = __float22half2_rn(result[1]); - -#endif - x_cast[j] = result_f; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - for (int i = high_index; i < N; i++) { - Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]); - } - } -} - -template -void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream) -{ - assert(unroll_factor == 4); - - const float scale = 1. / (1. - ratio); - dropout_grad_kernel<<>>(total_count, scale, vals, mask); -} - -template void launch_dropout_grad(float* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream); -template void launch_dropout_grad(__half* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream); - -__global__ void dropout_grad_kernel(const int N, - const float scale, - const float* Xdata, - float* out, - uint8_t* mask) -{ - CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; } -} - -__global__ void dropout_grad_kernel(const int N, - const float scale, - const __half* Xdata, - __half* out, - uint8_t* mask) -{ - const float2* x_cast = reinterpret_cast(Xdata); - float2* out_cast = reinterpret_cast(out); - const uint32_t* mask_cast = reinterpret_cast(mask); - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) - { - float2 x_data = x_cast[j]; - uint32_t m_32 = mask_cast[j]; - uint8_t* m = (uint8_t*)&m_32; - - __half* x_data_h = reinterpret_cast<__half*>(&x_data); - float2 result[2]; - - result[0].x = (float)x_data_h[0] * scale * m[0]; - result[0].y = (float)x_data_h[1] * scale * m[1]; - result[1].x = (float)x_data_h[2] * scale * m[2]; - result[1].y = (float)x_data_h[3] * scale * m[3]; - - result_h[0] = __float22half2_rn(result[0]); - result_h[1] = __float22half2_rn(result[1]); - - out_cast[j] = result_f; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - for (int i = high_index; i < N; i++) { - out[i] = __float2half((float)Xdata[i] * scale * mask[i]); - } - } -} - -template -void launch_dropout_grad(T* vals_out, - const T* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream) -{ - assert(unroll_factor == 4); - - const float scale = 1. / (1. - ratio); - dropout_grad_kernel<<>>(total_count, scale, vals, vals_out, mask); -} -template void launch_dropout_grad(float*, - const float* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream); -template void launch_dropout_grad(__half*, - const __half* vals, - uint8_t* mask, - int total_count, - float ratio, - cudaStream_t stream); - -__global__ void dropout_kernel(const int N, - const int dim, - const float ratio, - const float* bias, - float* Xdata, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int tid = threadIdx.x % (dim / unroll_factor); - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - float4* Xdata_cast = reinterpret_cast(Xdata); - uint32_t* mask_32 = reinterpret_cast(mask); - const float4* bias_cast = reinterpret_cast(bias); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 rand = curand_uniform4(&state); - uint32_t m_32; - uint8_t* m = (uint8_t*)&m_32; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - float4 x_data = Xdata_cast[j]; - float4 b_data = bias_cast[j % (dim / unroll_factor)]; - - x_data.x += b_data.x; - x_data.y += b_data.y; - x_data.z += b_data.z; - x_data.w += b_data.w; - - x_data.x = x_data.x * scale * m[0]; - x_data.y = x_data.y * scale * m[1]; - x_data.z = x_data.z * scale * m[2]; - x_data.w = x_data.w * scale * m[3]; - - mask_32[j] = m_32; - Xdata_cast[j] = x_data; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - float x_data = Xdata[i] + bias[i % dim]; - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - Xdata[i] = x_data * scale * m; - mask[i] = m; - } - } -} - -__global__ void dropout_kernel(const int N, - const int dim, - const float ratio, - const __half* bias, - __half* Xdata, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int tid = threadIdx.x % (dim / unroll_factor); - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - float2* Xdata_cast = reinterpret_cast(Xdata); - uint32_t* mask_32 = reinterpret_cast(mask); - const float2* bias_cast = reinterpret_cast(bias); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 rand = curand_uniform4(&state); - - float2 data_f; - __half2* data_h = reinterpret_cast<__half2*>(&data_f); - - float2 bias_f; - __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); - - data_f = Xdata_cast[j]; - bias_f = bias_cast[j % (dim / unroll_factor)]; - - float2 data_h_0 = __half22float2(data_h[0]); - float2 data_h_1 = __half22float2(data_h[1]); - - float2 bias_h_0 = __half22float2(bias_h[0]); - float2 bias_h_1 = __half22float2(bias_h[1]); - - data_h_0.x += bias_h_0.x; - data_h_0.y += bias_h_0.y; - data_h_1.x += bias_h_1.x; - data_h_1.y += bias_h_1.y; - - uint32_t m_32; - uint8_t* m = (uint8_t*)&m_32; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - data_h_0.x = __float2half(data_h_0.x * scale * m[0]); - data_h_0.y = __float2half(data_h_0.y * scale * m[1]); - data_h_1.x = __float2half(data_h_1.x * scale * m[2]); - data_h_1.y = __float2half(data_h_1.y * scale * m[3]); - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - result_h[0] = __float22half2_rn(data_h_0); - result_h[1] = __float22half2_rn(data_h_1); - - Xdata_cast[j] = result_f; - mask_32[j] = m_32; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - float x_data = (float)Xdata[i] + (float)bias[i % dim]; - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - Xdata[i] = __float2half(x_data * scale * m); - mask[i] = m; - } - } -} - -template -void launch_dropout(T* out, - const T* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream) -{ - assert(unroll_factor == 4); - - int total_count = batch * dim / unroll_factor; - - dim3 grid_dim = DS_GET_BLOCKS(total_count); - dim3 block_dim = DS_CUDA_NUM_THREADS; - - uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); - - dropout_kernel<<>>( - total_count, dim, ratio, bias, out, mask, seed); -} - -template void launch_dropout(float*, - const float* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream); -template void launch_dropout(__half*, - const __half* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream); - -__global__ void dropout_kernel(const int N, - const int dim, - const float ratio, - const float* input, - const float* residual, - const float* bias, - float* out, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int tid = threadIdx.x % (dim / unroll_factor); - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - float4* out_cast = reinterpret_cast(out); - uint32_t* mask_32 = reinterpret_cast(mask); - - const float4* bias_cast = reinterpret_cast(bias); - const float4* residual_cast = reinterpret_cast(residual); - const float4* input_cast = reinterpret_cast(input); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 rand = curand_uniform4(&state); - - uint32_t m_32; - uint8_t* m = (uint8_t*)&m_32; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - float4 out_data; - float4 b_data = bias_cast[j % (dim / unroll_factor)]; - float4 res_data = residual_cast[j]; - float4 inp_data = input_cast[j]; - - out_data.x = (b_data.x + inp_data.x); - out_data.y = (b_data.y + inp_data.y); - out_data.z = (b_data.z + inp_data.z); - out_data.w = (b_data.w + inp_data.w); - - out_data.x = out_data.x * scale * m[0]; - out_data.y = out_data.y * scale * m[1]; - out_data.z = out_data.z * scale * m[2]; - out_data.w = out_data.w * scale * m[3]; - - out_data.x += res_data.x; - out_data.y += res_data.y; - out_data.z += res_data.z; - out_data.w += res_data.w; - - mask_32[j] = m_32; - out_cast[j] = out_data; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - float x_data = input[i] + bias[i % dim]; - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - x_data = x_data * scale * m; - x_data += residual[i]; - - out[i] = x_data; - mask[i] = m; - } - } -} - -__global__ void dropout_kernel(const int N, - const int dim, - const float ratio, - const __half* input, - const __half* residual, - const __half* bias, - __half* out, - uint8_t* mask, - std::pair seed) -{ - const float scale = 1. / (1. - ratio); - int idx = blockIdx.x * blockDim.x + threadIdx.x; - int tid = threadIdx.x % (dim / unroll_factor); - - curandStatePhilox4_32_10_t state; - curand_init(seed.first, idx, seed.second, &state); - - float2* out_cast = reinterpret_cast(out); - uint32_t* mask_32 = reinterpret_cast(mask); - - const float2* bias_cast = reinterpret_cast(bias); - const float2* residual_cast = reinterpret_cast(residual); - const float2* input_cast = reinterpret_cast(input); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 rand = curand_uniform4(&state); - - float2 data_f; - __half2* data_h = reinterpret_cast<__half2*>(&data_f); - - float2 bias_f; - __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); - - float2 residual_f; - __half2* residual_h = reinterpret_cast<__half2*>(&residual_f); - - float2 input_f; - __half2* input_h = reinterpret_cast<__half2*>(&input_f); - - bias_f = bias_cast[j % (dim / unroll_factor)]; - residual_f = residual_cast[j]; - input_f = input_cast[j]; - - float2 data_h_0 = __half22float2(data_h[0]); - float2 data_h_1 = __half22float2(data_h[1]); - - float2 bias_h_0 = __half22float2(bias_h[0]); - float2 bias_h_1 = __half22float2(bias_h[1]); - - float2 residual_h_0 = __half22float2(residual_h[0]); - float2 residual_h_1 = __half22float2(residual_h[1]); - - float2 input_h_0 = __half22float2(input_h[0]); - float2 input_h_1 = __half22float2(input_h[1]); - - data_h_0.x = (bias_h_0.x + input_h_0.x); - data_h_0.y = (bias_h_0.y + input_h_0.y); - data_h_1.x = (bias_h_1.x + input_h_1.x); - data_h_1.y = (bias_h_1.y + input_h_1.y); - - uint32_t m_32; - uint8_t* m = (uint8_t*)&m_32; - - m[0] = (uint8_t)(rand.x > ratio); - m[1] = (uint8_t)(rand.y > ratio); - m[2] = (uint8_t)(rand.z > ratio); - m[3] = (uint8_t)(rand.w > ratio); - - data_h_0.x = __float2half(data_h_0.x * scale * m[0]); - data_h_0.y = __float2half(data_h_0.y * scale * m[1]); - data_h_1.x = __float2half(data_h_1.x * scale * m[2]); - data_h_1.y = __float2half(data_h_1.y * scale * m[3]); - - data_h_0.x += residual_h_0.x; - data_h_0.y += residual_h_0.y; - data_h_1.x += residual_h_1.x; - data_h_1.y += residual_h_1.y; - - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - result_h[0] = __float22half2_rn(data_h_0); - result_h[1] = __float22half2_rn(data_h_1); - - out_cast[j] = result_f; - mask_32[j] = m_32; - } - int high_index = - ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; - if (N > high_index) { - float4 rand = curand_uniform4(&state); - float* rand_data = &(rand.x); - int k = 0; - for (int i = high_index; i < N; i++) { - float x_data = (float)input[i] + (float)bias[i % dim]; - uint8_t m = (uint8_t)(rand_data[k++] > ratio); - x_data = x_data * scale * m; - x_data += (float)residual[i]; - - out[i] = __float2half(x_data); - mask[i] = m; - } - } -} - -template -void launch_dropout(T* out, - const T* input, - const T* residual, - const T* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream) -{ - assert(unroll_factor == 4); - - int total_count = batch * dim / unroll_factor; - dim3 grid_dim = DS_GET_BLOCKS(total_count); - dim3 block_dim = DS_CUDA_NUM_THREADS; - - uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; - std::pair seed = Context::Instance().IncrementOffset(inc); - - dropout_kernel<<>>( - total_count, dim, ratio, input, residual, bias, out, mask, seed); -} - -template void launch_dropout(float*, - const float*, - const float* residual, - const float* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream); -template void launch_dropout(__half*, - const __half*, - const __half* residual, - const __half* bias, - uint8_t* mask, - int batch, - int dim, - float ratio, - cudaStream_t stream); +#include "custom_cuda_layers.h" + +const int unroll_factor = 4; + +__global__ void dropout_kernel(const int N, + const float ratio, + float* out, + const float* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float4 rand = curand_uniform4(&state); + uint8_t m[unroll_factor]; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + int i = j * unroll_factor; + + mask[i] = (uint8_t)m[0]; + mask[i + 1] = (uint8_t)m[1]; + mask[i + 2] = (uint8_t)m[2]; + mask[i + 3] = (uint8_t)m[3]; + + out[i] = Xdata[i] * scale * m[0]; + out[i + 1] = Xdata[i + 1] * scale * m[1]; + out[i + 2] = Xdata[i + 2] * scale * m[2]; + out[i + 3] = Xdata[i + 3] * scale * m[3]; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + out[i] = Xdata[i] * scale * m; + mask[i] = m; + } + } +} + +__global__ void dropout_kernel(const int N, + const float ratio, + __half* out, + const __half* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + +#ifdef __STOCHASTIC_MODE__ + + const __half2 h_scale = __float2half2_rn(scale); + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_cast = reinterpret_cast(mask); + + uint32_t m_32; + uint8_t* m = reinterpret_cast(&m_32); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + __half2 mask_h[2]; + float2 mask_f[2]; + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_f = x_cast[j]; + __half2* x_h = reinterpret_cast<__half2*>(&x_f); + + float4 rand = curand_uniform4(&state); + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); + + mask_h[0] = __float22half2_rn(mask_f[0]); + mask_h[1] = __float22half2_rn(mask_f[1]); + + result_h[0] = x_h[0] * h_scale * mask_h[0]; + result_h[1] = x_h[1] * h_scale * mask_h[1]; + + out_cast[j] = result_f; + + mask_cast[j] = m_32; + } + +#else + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + const __half2* vals_half = reinterpret_cast(Xdata + i); + float2 vals_half_f[2]; + vals_half_f[0] = __half22float2(vals_half[0]); + vals_half_f[1] = __half22float2(vals_half[1]); + + uint8_t m[unroll_factor]; + float4 rand = curand_uniform4(&state); + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + out[i] = __float2half(vals_half_f[0].x * scale * m[0]); + out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); + out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); + out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); + + mask[i] = m[0]; + mask[i + 1] = m[1]; + mask[i + 2] = m[2]; + mask[i + 3] = m[3]; + } + +#endif + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + out[i] = __float2half((float)Xdata[i] * scale * m); + mask[i] = m; + } + } +} + +__global__ void dropout_kernel_bwd(const int N, + const float ratio, + const float* Xdata, + float* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + out[i] = mask[i] ? Xdata[i] * scale : 0.0; + out[i + 1] = mask[i + 1] ? Xdata[i + 1] * scale : 0.0; + out[i + 2] = mask[i + 2] ? Xdata[i + 2] * scale : 0.0; + out[i + 3] = mask[i + 3] ? Xdata[i + 3] * scale : 0.0; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { out[i] = mask[i] ? Xdata[i] * scale : 0.0; } + } +} + +__global__ void dropout_kernel_bwd(const int N, + const float ratio, + const __half* Xdata, + __half* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + +#ifdef __STOCHASTIC_MODE__ + + const __half2 h_scale = __float2half2_rn(scale); + + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_cast = reinterpret_cast(mask); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_f = x_cast[j]; + __half2* x_h = reinterpret_cast<__half2*>(&x_f); + + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + __half2 mask_h[2]; + float2 mask_f[2]; + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) mask_f_data[i] = (float)(m[i]); + +#pragma unroll + for (int i = 0; i < 2; i++) mask_h[i] = __float22half2_rn(mask_f[i]); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + result_h[0] = x_h[0] * h_scale * mask_h[0]; + result_h[1] = x_h[1] * h_scale * mask_h[1]; + + out_cast[j] = result_f; + } + +#else + + const __half h_scale = __float2half(scale); + const __half h_zero = __float2half(0.0); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + int i = j * unroll_factor; + + const __half2* vals_half = reinterpret_cast(Xdata + i); + + uint8_t* m = mask + i; + + float2 vals_half_f[2]; + + vals_half_f[0] = __half22float2(vals_half[0]); + vals_half_f[1] = __half22float2(vals_half[1]); + + out[i] = __float2half(vals_half_f[0].x * scale * m[0]); + out[i + 1] = __float2half(vals_half_f[0].y * scale * m[1]); + out[i + 2] = __float2half(vals_half_f[1].x * scale * m[2]); + out[i + 3] = __float2half(vals_half_f[1].y * scale * m[3]); + } + +#endif + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { + out[i] = __float2half((float)Xdata[i] * scale * mask[i]); + } + } +} + +template +void launch_dropout(T* out, + const T* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool bwd) +{ + assert(unroll_factor == 4); + + dim3 grid_dim = DS_GET_BLOCKS(total_count / unroll_factor); + dim3 block_dim = DS_CUDA_NUM_THREADS; + + if (dim > 512) { + block_dim.x >>= 1; + grid_dim.x <<= 1; + } + uint64_t inc = total_count / grid_dim.x / block_dim.x; + std::pair seed = Context::Instance().IncrementOffset(inc); + if (bwd) + dropout_kernel_bwd<<>>( + total_count, ratio, vals, out, mask, seed); + else + dropout_kernel<<>>( + total_count, ratio, out, vals, mask, seed); +} + +template void launch_dropout(float* out, + const float* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool); +template void launch_dropout(__half* out, + const __half* vals, + uint8_t* mask, + int total_count, + int dim, + float ratio, + cudaStream_t stream, + bool); + +__global__ void dropout_grad_kernel(const int N, const float scale, float* Xdata, uint8_t* mask) +{ + CUDA_1D_KERNEL_LOOP(i, N) { Xdata[i] *= scale * mask[i]; } +} + +__global__ void dropout_grad_kernel(const int N, const float scale, __half* Xdata, uint8_t* mask) +{ + const __half2 h_scale = __float2half2_rn(scale); + float2* x_cast = reinterpret_cast(Xdata); + uint32_t* mask_cast = reinterpret_cast(mask); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_data = x_cast[j]; + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + +#ifdef __STOCHASTIC_MODE__ + + __half2* x_data_h = reinterpret_cast<__half2*>(&x_data); + __half2 mask_h[2]; + float2 mask_f[2]; + + float* mask_f_data = &mask_f[0].x; +#pragma unroll + for (int i = 0; i < unroll_factor; i++) *(mask_f_data++) = (float)(m[i]); + + mask_h[0] = __float22half2_rn(mask_f[0]); + mask_h[1] = __float22half2_rn(mask_f[1]); + + result_h[0] = x_data_h[0] * h_scale * mask_h[0]; + result_h[1] = x_data_h[1] * h_scale * mask_h[1]; + +#else + + __half* x_data_h = reinterpret_cast<__half*>(&x_data); + float2 result[2]; + + result[0].x = (float)x_data_h[0] * scale * m[0]; + result[0].y = (float)x_data_h[1] * scale * m[1]; + result[1].x = (float)x_data_h[2] * scale * m[2]; + result[1].y = (float)x_data_h[3] * scale * m[3]; + + result_h[0] = __float22half2_rn(result[0]); + result_h[1] = __float22half2_rn(result[1]); + +#endif + x_cast[j] = result_f; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { + Xdata[i] = __float2half((float)Xdata[i] * scale * mask[i]); + } + } +} + +template +void launch_dropout_grad(T* vals, uint8_t* mask, int total_count, float ratio, cudaStream_t stream) +{ + assert(unroll_factor == 4); + + const float scale = 1. / (1. - ratio); + dropout_grad_kernel<<>>(total_count, scale, vals, mask); +} + +template void launch_dropout_grad(float* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); +template void launch_dropout_grad(__half* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); + +__global__ void dropout_grad_kernel(const int N, + const float scale, + const float* Xdata, + float* out, + uint8_t* mask) +{ + CUDA_1D_KERNEL_LOOP(i, N) { out[i] = Xdata[i] * scale * mask[i]; } +} + +__global__ void dropout_grad_kernel(const int N, + const float scale, + const __half* Xdata, + __half* out, + uint8_t* mask) +{ + const float2* x_cast = reinterpret_cast(Xdata); + float2* out_cast = reinterpret_cast(out); + const uint32_t* mask_cast = reinterpret_cast(mask); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + CUDA_1D_KERNEL_LOOP(j, N / unroll_factor) + { + float2 x_data = x_cast[j]; + uint32_t m_32 = mask_cast[j]; + uint8_t* m = (uint8_t*)&m_32; + + __half* x_data_h = reinterpret_cast<__half*>(&x_data); + float2 result[2]; + + result[0].x = (float)x_data_h[0] * scale * m[0]; + result[0].y = (float)x_data_h[1] * scale * m[1]; + result[1].x = (float)x_data_h[2] * scale * m[2]; + result[1].y = (float)x_data_h[3] * scale * m[3]; + + result_h[0] = __float22half2_rn(result[0]); + result_h[1] = __float22half2_rn(result[1]); + + out_cast[j] = result_f; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + for (int i = high_index; i < N; i++) { + out[i] = __float2half((float)Xdata[i] * scale * mask[i]); + } + } +} + +template +void launch_dropout_grad(T* vals_out, + const T* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream) +{ + assert(unroll_factor == 4); + + const float scale = 1. / (1. - ratio); + dropout_grad_kernel<<>>(total_count, scale, vals, vals_out, mask); +} +template void launch_dropout_grad(float*, + const float* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); +template void launch_dropout_grad(__half*, + const __half* vals, + uint8_t* mask, + int total_count, + float ratio, + cudaStream_t stream); + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const float* bias, + float* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float4* Xdata_cast = reinterpret_cast(Xdata); + uint32_t* mask_32 = reinterpret_cast(mask); + const float4* bias_cast = reinterpret_cast(bias); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float4 x_data = Xdata_cast[j]; + float4 b_data = bias_cast[j % (dim / unroll_factor)]; + + x_data.x += b_data.x; + x_data.y += b_data.y; + x_data.z += b_data.z; + x_data.w += b_data.w; + + x_data.x = x_data.x * scale * m[0]; + x_data.y = x_data.y * scale * m[1]; + x_data.z = x_data.z * scale * m[2]; + x_data.w = x_data.w * scale * m[3]; + + mask_32[j] = m_32; + Xdata_cast[j] = x_data; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = Xdata[i] + bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + Xdata[i] = x_data * scale * m; + mask[i] = m; + } + } +} + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const __half* bias, + __half* Xdata, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float2* Xdata_cast = reinterpret_cast(Xdata); + uint32_t* mask_32 = reinterpret_cast(mask); + const float2* bias_cast = reinterpret_cast(bias); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + + float2 data_f; + __half2* data_h = reinterpret_cast<__half2*>(&data_f); + + float2 bias_f; + __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); + + data_f = Xdata_cast[j]; + bias_f = bias_cast[j % (dim / unroll_factor)]; + + float2 data_h_0 = __half22float2(data_h[0]); + float2 data_h_1 = __half22float2(data_h[1]); + + float2 bias_h_0 = __half22float2(bias_h[0]); + float2 bias_h_1 = __half22float2(bias_h[1]); + + data_h_0.x += bias_h_0.x; + data_h_0.y += bias_h_0.y; + data_h_1.x += bias_h_1.x; + data_h_1.y += bias_h_1.y; + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + data_h_0.x = __float2half(data_h_0.x * scale * m[0]); + data_h_0.y = __float2half(data_h_0.y * scale * m[1]); + data_h_1.x = __float2half(data_h_1.x * scale * m[2]); + data_h_1.y = __float2half(data_h_1.y * scale * m[3]); + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + result_h[0] = __float22half2_rn(data_h_0); + result_h[1] = __float22half2_rn(data_h_1); + + Xdata_cast[j] = result_f; + mask_32[j] = m_32; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = (float)Xdata[i] + (float)bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + Xdata[i] = __float2half(x_data * scale * m); + mask[i] = m; + } + } +} + +template +void launch_dropout(T* out, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream) +{ + assert(unroll_factor == 4); + + int total_count = batch * dim / unroll_factor; + + dim3 grid_dim = DS_GET_BLOCKS(total_count); + dim3 block_dim = DS_CUDA_NUM_THREADS; + + uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; + std::pair seed = Context::Instance().IncrementOffset(inc); + + dropout_kernel<<>>( + total_count, dim, ratio, bias, out, mask, seed); +} + +template void launch_dropout(float*, + const float* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); +template void launch_dropout(__half*, + const __half* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const float* input, + const float* residual, + const float* bias, + float* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float4* out_cast = reinterpret_cast(out); + uint32_t* mask_32 = reinterpret_cast(mask); + + const float4* bias_cast = reinterpret_cast(bias); + const float4* residual_cast = reinterpret_cast(residual); + const float4* input_cast = reinterpret_cast(input); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + float4 out_data; + float4 b_data = bias_cast[j % (dim / unroll_factor)]; + float4 res_data = residual_cast[j]; + float4 inp_data = input_cast[j]; + + out_data.x = (b_data.x + inp_data.x); + out_data.y = (b_data.y + inp_data.y); + out_data.z = (b_data.z + inp_data.z); + out_data.w = (b_data.w + inp_data.w); + + out_data.x = out_data.x * scale * m[0]; + out_data.y = out_data.y * scale * m[1]; + out_data.z = out_data.z * scale * m[2]; + out_data.w = out_data.w * scale * m[3]; + + out_data.x += res_data.x; + out_data.y += res_data.y; + out_data.z += res_data.z; + out_data.w += res_data.w; + + mask_32[j] = m_32; + out_cast[j] = out_data; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = input[i] + bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + x_data = x_data * scale * m; + x_data += residual[i]; + + out[i] = x_data; + mask[i] = m; + } + } +} + +__global__ void dropout_kernel(const int N, + const int dim, + const float ratio, + const __half* input, + const __half* residual, + const __half* bias, + __half* out, + uint8_t* mask, + std::pair seed) +{ + const float scale = 1. / (1. - ratio); + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int tid = threadIdx.x % (dim / unroll_factor); + + curandStatePhilox4_32_10_t state; + curand_init(seed.first, idx, seed.second, &state); + + float2* out_cast = reinterpret_cast(out); + uint32_t* mask_32 = reinterpret_cast(mask); + + const float2* bias_cast = reinterpret_cast(bias); + const float2* residual_cast = reinterpret_cast(residual); + const float2* input_cast = reinterpret_cast(input); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 rand = curand_uniform4(&state); + + float2 data_f; + __half2* data_h = reinterpret_cast<__half2*>(&data_f); + + float2 bias_f; + __half2* bias_h = reinterpret_cast<__half2*>(&bias_f); + + float2 residual_f; + __half2* residual_h = reinterpret_cast<__half2*>(&residual_f); + + float2 input_f; + __half2* input_h = reinterpret_cast<__half2*>(&input_f); + + bias_f = bias_cast[j % (dim / unroll_factor)]; + residual_f = residual_cast[j]; + input_f = input_cast[j]; + + float2 data_h_0 = __half22float2(data_h[0]); + float2 data_h_1 = __half22float2(data_h[1]); + + float2 bias_h_0 = __half22float2(bias_h[0]); + float2 bias_h_1 = __half22float2(bias_h[1]); + + float2 residual_h_0 = __half22float2(residual_h[0]); + float2 residual_h_1 = __half22float2(residual_h[1]); + + float2 input_h_0 = __half22float2(input_h[0]); + float2 input_h_1 = __half22float2(input_h[1]); + + data_h_0.x = (bias_h_0.x + input_h_0.x); + data_h_0.y = (bias_h_0.y + input_h_0.y); + data_h_1.x = (bias_h_1.x + input_h_1.x); + data_h_1.y = (bias_h_1.y + input_h_1.y); + + uint32_t m_32; + uint8_t* m = (uint8_t*)&m_32; + + m[0] = (uint8_t)(rand.x > ratio); + m[1] = (uint8_t)(rand.y > ratio); + m[2] = (uint8_t)(rand.z > ratio); + m[3] = (uint8_t)(rand.w > ratio); + + data_h_0.x = __float2half(data_h_0.x * scale * m[0]); + data_h_0.y = __float2half(data_h_0.y * scale * m[1]); + data_h_1.x = __float2half(data_h_1.x * scale * m[2]); + data_h_1.y = __float2half(data_h_1.y * scale * m[3]); + + data_h_0.x += residual_h_0.x; + data_h_0.y += residual_h_0.y; + data_h_1.x += residual_h_1.x; + data_h_1.y += residual_h_1.y; + + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + result_h[0] = __float22half2_rn(data_h_0); + result_h[1] = __float22half2_rn(data_h_1); + + out_cast[j] = result_f; + mask_32[j] = m_32; + } + int high_index = + ((((N / unroll_factor) - 1) / blockDim.x + 1) * (unroll_factor * blockDim.x)) + threadIdx.x; + if (N > high_index) { + float4 rand = curand_uniform4(&state); + float* rand_data = &(rand.x); + int k = 0; + for (int i = high_index; i < N; i++) { + float x_data = (float)input[i] + (float)bias[i % dim]; + uint8_t m = (uint8_t)(rand_data[k++] > ratio); + x_data = x_data * scale * m; + x_data += (float)residual[i]; + + out[i] = __float2half(x_data); + mask[i] = m; + } + } +} + +template +void launch_dropout(T* out, + const T* input, + const T* residual, + const T* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream) +{ + assert(unroll_factor == 4); + + int total_count = batch * dim / unroll_factor; + dim3 grid_dim = DS_GET_BLOCKS(total_count); + dim3 block_dim = DS_CUDA_NUM_THREADS; + + uint64_t inc = (batch * dim) / grid_dim.x / block_dim.x; + std::pair seed = Context::Instance().IncrementOffset(inc); + + dropout_kernel<<>>( + total_count, dim, ratio, input, residual, bias, out, mask, seed); +} + +template void launch_dropout(float*, + const float*, + const float* residual, + const float* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); +template void launch_dropout(__half*, + const __half*, + const __half* residual, + const __half* bias, + uint8_t* mask, + int batch, + int dim, + float ratio, + cudaStream_t stream); diff --git a/csrc/transformer/gelu_kernels.cu b/csrc/transformer/gelu_kernels.cu index 12048006266e..cea337b064ac 100644 --- a/csrc/transformer/gelu_kernels.cu +++ b/csrc/transformer/gelu_kernels.cu @@ -1,330 +1,330 @@ -#include "custom_cuda_layers.h" - -inline __device__ float gelu(const float x) -{ - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); -} - -inline __device__ float d_gelu(const float x) -{ - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - - float x2mul = x * x * mul_param; - float tan_h = tanhf(sqrt_param * (x + x * x2mul)); - float dg1 = 0.5f * (1.0f + tan_h); - float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); - float dg3 = dg2 * 3 * x2mul; - return (dg1 + dg2 + dg3); -} - -/* -Fused bias add with GELU - -Loads a vector of 4 elements each iteration, for stride -iterations. It was written with the intention to launch 256 thread -threadblocks, so to launch for bert-large, we would set ITERATIONS -to 4. This is currently done automatically as a heuristic, setting -the number of iterations as blocks of 1024. - -For FP16, the values are loaded from memory as __half, but converted -to FP32 for the arithmetic itself, to prevent numerous overflow on -the intermediate hyperbolic tangent, since there's no intrinsic -that computes it directly. -*/ - -__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations) -{ - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - const float4* input_cast = reinterpret_cast(input); - float4* vals_cast = reinterpret_cast(vals); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float4 data = input_cast[row * row_stride + i * loop_stride + id]; - - data.x = gelu(data.x); - data.y = gelu(data.y); - data.z = gelu(data.z); - data.w = gelu(data.w); - - vals_cast[row * row_stride + i * loop_stride + id] = data; - } - } -} - -__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations) -{ -#if __CUDA_ARCH__ >= 700 - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - const float2* input_cast = reinterpret_cast(input); - float2* vals_cast = reinterpret_cast(vals); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - low_data.x = gelu(low_data.x); - low_data.y = gelu(low_data.y); - high_data.x = gelu(high_data.x); - high_data.y = gelu(high_data.y); - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; - } - } -#endif -} - -__global__ void fused_bias_gelu(const float* input, - const float* bias, - float* vals, - int row_stride, - int iterations) -{ - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - const float4* input_cast = reinterpret_cast(input); - float4* vals_cast = reinterpret_cast(vals); - const float4* bias_cast = reinterpret_cast(bias); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float4 data = input_cast[row * row_stride + i * loop_stride + id]; - float4 bias_data = bias_cast[i * loop_stride + id]; - - data.x += bias_data.x; - data.y += bias_data.y; - data.z += bias_data.z; - data.w += bias_data.w; - - data.x = gelu(data.x); - data.y = gelu(data.y); - data.z = gelu(data.z); - data.w = gelu(data.w); - - vals_cast[row * row_stride + i * loop_stride + id] = data; - } - } -} - -__global__ void fused_bias_gelu(const __half* input, - const __half* bias, - __half* vals, - int row_stride, - int iterations) -{ -#if __CUDA_ARCH__ >= 700 - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - const float2* input_cast = reinterpret_cast(input); - float2* vals_cast = reinterpret_cast(vals); - const float2* bias_cast = reinterpret_cast(bias); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; - float2 bias_vec = bias_cast[i * loop_stride + id]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - - low_data.x += low_bias.x; - low_data.y += low_bias.y; - high_data.x += high_bias.x; - high_data.y += high_bias.y; - - low_data.x = gelu(low_data.x); - low_data.y = gelu(low_data.y); - high_data.x = gelu(high_data.x); - high_data.y = gelu(high_data.y); - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; - } - } -#endif -} - -__global__ void d_gelu_func(float* d_output, - const float* gelu_input, - const float* bias, - int row_stride, - int iterations) -{ - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - float4* d_output_cast = reinterpret_cast(d_output); - const float4* gelu_input_cast = reinterpret_cast(gelu_input); - const float4* bias_cast = reinterpret_cast(bias); - - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; - float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; - float4 bias_data = bias_cast[i * loop_stride + id]; - - gelu_input_data.x += bias_data.x; - gelu_input_data.y += bias_data.y; - gelu_input_data.z += bias_data.z; - gelu_input_data.w += bias_data.w; - - output_data.x *= d_gelu(gelu_input_data.x); - output_data.y *= d_gelu(gelu_input_data.y); - output_data.z *= d_gelu(gelu_input_data.z); - output_data.w *= d_gelu(gelu_input_data.w); - - d_output_cast[row * row_stride + i * loop_stride + id] = output_data; - } - } -} - -__global__ void d_gelu_func(__half* d_output, - const __half* gelu_input, - const __half* bias, - int row_stride, - int iterations) -{ -#if __CUDA_ARCH__ >= 700 - int row = blockIdx.x; - int id = threadIdx.x; - int loop_stride = blockDim.x; - - float2* d_output_cast = reinterpret_cast(d_output); - const float2* gelu_input_cast = reinterpret_cast(gelu_input); - const float2* bias_cast = reinterpret_cast(bias); - -#pragma unroll - for (int i = 0; i < iterations; i++) { - if (i * loop_stride + id < row_stride) { - float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; - float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; - float2 bias_vec = bias_cast[i * loop_stride + id]; - - __half2* output_data_half = reinterpret_cast<__half2*>(&output_data); - __half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 output_half_0 = __half22float2(output_data_half[0]); - float2 output_half_1 = __half22float2(output_data_half[1]); - - float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]); - float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]); - - float2 bias_half_0 = __half22float2(bias_half[0]); - float2 bias_half_1 = __half22float2(bias_half[1]); - - gelu_input_half_0.x += bias_half_0.x; - gelu_input_half_0.y += bias_half_0.y; - gelu_input_half_1.x += bias_half_1.x; - gelu_input_half_1.y += bias_half_1.y; - - output_half_0.x *= d_gelu(gelu_input_half_0.x); - output_half_0.y *= d_gelu(gelu_input_half_0.y); - output_half_1.x *= d_gelu(gelu_input_half_1.x); - output_half_1.y *= d_gelu(gelu_input_half_1.y); - - float2 result; - __half2* result_half2 = reinterpret_cast<__half2*>(&result); - - result_half2[0] = __float22half2_rn(output_half_0); - result_half2[1] = __float22half2_rn(output_half_1); - - d_output_cast[row * row_stride + i * loop_stride + id] = result; - } - } -#endif -} - -template -void launch_bias_gelu(const T* input, - const T* bias, - T* output, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int iterations = (intermediate_size + 1023) / 1024; - int threads = (intermediate_size - 1) / (iterations * 4) + 1; - dim3 block_dims(threads); - dim3 grid_dims(batch_size); - - fused_bias_gelu<<>>( - input, bias, output, intermediate_size / 4, iterations); -} - -template -void launch_gelu(const T* input, - T* output, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int iterations = (intermediate_size + 1023) / 1024; - int threads = (intermediate_size - 1) / (iterations * 4) + 1; - dim3 block_dims(threads); - dim3 grid_dims(batch_size); - - gelu_kernel<<>>( - input, output, intermediate_size / 4, iterations); -} - -template void launch_bias_gelu(const float*, const float*, float*, int, int, cudaStream_t); -template void launch_bias_gelu<__half>(const __half*, - const __half*, - __half*, - int, - int, - cudaStream_t); - -template void launch_gelu(const float*, float*, int, int, cudaStream_t); -template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t); - -template -void launch_d_gelu(T* d_output, - const T* input, - const T* bias, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int iterations = (intermediate_size + 1023) / 1024; - int threads = (intermediate_size - 1) / (iterations * 4) + 1; - dim3 block_dims(threads); - dim3 grid_dims(batch_size); - - d_gelu_func<<>>( - d_output, input, bias, intermediate_size / 4, iterations); -} - -template void launch_d_gelu(float*, const float*, const float*, int, int, cudaStream_t); -template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t); +#include "custom_cuda_layers.h" + +inline __device__ float gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); +} + +inline __device__ float d_gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + + float x2mul = x * x * mul_param; + float tan_h = tanhf(sqrt_param * (x + x * x2mul)); + float dg1 = 0.5f * (1.0f + tan_h); + float dg2 = x * 0.5f * sqrt_param * (1 - tan_h * tan_h); + float dg3 = dg2 * 3 * x2mul; + return (dg1 + dg2 + dg3); +} + +/* +Fused bias add with GELU + +Loads a vector of 4 elements each iteration, for stride +iterations. It was written with the intention to launch 256 thread +threadblocks, so to launch for bert-large, we would set ITERATIONS +to 4. This is currently done automatically as a heuristic, setting +the number of iterations as blocks of 1024. + +For FP16, the values are loaded from memory as __half, but converted +to FP32 for the arithmetic itself, to prevent numerous overflow on +the intermediate hyperbolic tangent, since there's no intrinsic +that computes it directly. +*/ + +__global__ void gelu_kernel(const float* input, float* vals, int row_stride, int iterations) +{ + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float4* input_cast = reinterpret_cast(input); + float4* vals_cast = reinterpret_cast(vals); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 data = input_cast[row * row_stride + i * loop_stride + id]; + + data.x = gelu(data.x); + data.y = gelu(data.y); + data.z = gelu(data.z); + data.w = gelu(data.w); + + vals_cast[row * row_stride + i * loop_stride + id] = data; + } + } +} + +__global__ void gelu_kernel(const __half* input, __half* vals, int row_stride, int iterations) +{ +#if __CUDA_ARCH__ >= 700 + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float2* input_cast = reinterpret_cast(input); + float2* vals_cast = reinterpret_cast(vals); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + low_data.x = gelu(low_data.x); + low_data.y = gelu(low_data.y); + high_data.x = gelu(high_data.x); + high_data.y = gelu(high_data.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; + } + } +#endif +} + +__global__ void fused_bias_gelu(const float* input, + const float* bias, + float* vals, + int row_stride, + int iterations) +{ + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float4* input_cast = reinterpret_cast(input); + float4* vals_cast = reinterpret_cast(vals); + const float4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 data = input_cast[row * row_stride + i * loop_stride + id]; + float4 bias_data = bias_cast[i * loop_stride + id]; + + data.x += bias_data.x; + data.y += bias_data.y; + data.z += bias_data.z; + data.w += bias_data.w; + + data.x = gelu(data.x); + data.y = gelu(data.y); + data.z = gelu(data.z); + data.w = gelu(data.w); + + vals_cast[row * row_stride + i * loop_stride + id] = data; + } + } +} + +__global__ void fused_bias_gelu(const __half* input, + const __half* bias, + __half* vals, + int row_stride, + int iterations) +{ +#if __CUDA_ARCH__ >= 700 + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + const float2* input_cast = reinterpret_cast(input); + float2* vals_cast = reinterpret_cast(vals); + const float2* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 vals_vec = input_cast[row * row_stride + i * loop_stride + id]; + float2 bias_vec = bias_cast[i * loop_stride + id]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += low_bias.x; + low_data.y += low_bias.y; + high_data.x += high_bias.x; + high_data.y += high_bias.y; + + low_data.x = gelu(low_data.x); + low_data.y = gelu(low_data.y); + high_data.x = gelu(high_data.x); + high_data.y = gelu(high_data.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + vals_cast[row * row_stride + i * loop_stride + id] = vals_vec; + } + } +#endif +} + +__global__ void d_gelu_func(float* d_output, + const float* gelu_input, + const float* bias, + int row_stride, + int iterations) +{ + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + float4* d_output_cast = reinterpret_cast(d_output); + const float4* gelu_input_cast = reinterpret_cast(gelu_input); + const float4* bias_cast = reinterpret_cast(bias); + + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float4 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; + float4 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; + float4 bias_data = bias_cast[i * loop_stride + id]; + + gelu_input_data.x += bias_data.x; + gelu_input_data.y += bias_data.y; + gelu_input_data.z += bias_data.z; + gelu_input_data.w += bias_data.w; + + output_data.x *= d_gelu(gelu_input_data.x); + output_data.y *= d_gelu(gelu_input_data.y); + output_data.z *= d_gelu(gelu_input_data.z); + output_data.w *= d_gelu(gelu_input_data.w); + + d_output_cast[row * row_stride + i * loop_stride + id] = output_data; + } + } +} + +__global__ void d_gelu_func(__half* d_output, + const __half* gelu_input, + const __half* bias, + int row_stride, + int iterations) +{ +#if __CUDA_ARCH__ >= 700 + int row = blockIdx.x; + int id = threadIdx.x; + int loop_stride = blockDim.x; + + float2* d_output_cast = reinterpret_cast(d_output); + const float2* gelu_input_cast = reinterpret_cast(gelu_input); + const float2* bias_cast = reinterpret_cast(bias); + +#pragma unroll + for (int i = 0; i < iterations; i++) { + if (i * loop_stride + id < row_stride) { + float2 output_data = d_output_cast[row * row_stride + i * loop_stride + id]; + float2 gelu_input_data = gelu_input_cast[row * row_stride + i * loop_stride + id]; + float2 bias_vec = bias_cast[i * loop_stride + id]; + + __half2* output_data_half = reinterpret_cast<__half2*>(&output_data); + __half2* gelu_input_data_half = reinterpret_cast<__half2*>(&gelu_input_data); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 output_half_0 = __half22float2(output_data_half[0]); + float2 output_half_1 = __half22float2(output_data_half[1]); + + float2 gelu_input_half_0 = __half22float2(gelu_input_data_half[0]); + float2 gelu_input_half_1 = __half22float2(gelu_input_data_half[1]); + + float2 bias_half_0 = __half22float2(bias_half[0]); + float2 bias_half_1 = __half22float2(bias_half[1]); + + gelu_input_half_0.x += bias_half_0.x; + gelu_input_half_0.y += bias_half_0.y; + gelu_input_half_1.x += bias_half_1.x; + gelu_input_half_1.y += bias_half_1.y; + + output_half_0.x *= d_gelu(gelu_input_half_0.x); + output_half_0.y *= d_gelu(gelu_input_half_0.y); + output_half_1.x *= d_gelu(gelu_input_half_1.x); + output_half_1.y *= d_gelu(gelu_input_half_1.y); + + float2 result; + __half2* result_half2 = reinterpret_cast<__half2*>(&result); + + result_half2[0] = __float22half2_rn(output_half_0); + result_half2[1] = __float22half2_rn(output_half_1); + + d_output_cast[row * row_stride + i * loop_stride + id] = result; + } + } +#endif +} + +template +void launch_bias_gelu(const T* input, + const T* bias, + T* output, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); + + fused_bias_gelu<<>>( + input, bias, output, intermediate_size / 4, iterations); +} + +template +void launch_gelu(const T* input, + T* output, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); + + gelu_kernel<<>>( + input, output, intermediate_size / 4, iterations); +} + +template void launch_bias_gelu(const float*, const float*, float*, int, int, cudaStream_t); +template void launch_bias_gelu<__half>(const __half*, + const __half*, + __half*, + int, + int, + cudaStream_t); + +template void launch_gelu(const float*, float*, int, int, cudaStream_t); +template void launch_gelu<__half>(const __half*, __half*, int, int, cudaStream_t); + +template +void launch_d_gelu(T* d_output, + const T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int iterations = (intermediate_size + 1023) / 1024; + int threads = (intermediate_size - 1) / (iterations * 4) + 1; + dim3 block_dims(threads); + dim3 grid_dims(batch_size); + + d_gelu_func<<>>( + d_output, input, bias, intermediate_size / 4, iterations); +} + +template void launch_d_gelu(float*, const float*, const float*, int, int, cudaStream_t); +template void launch_d_gelu<__half>(__half*, const __half*, const __half*, int, int, cudaStream_t); diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu index 7d318773f354..1eaa94e1e71a 100644 --- a/csrc/transformer/general_kernels.cu +++ b/csrc/transformer/general_kernels.cu @@ -1,411 +1,411 @@ -#include "general_kernels.h" - -namespace cg = cooperative_groups; - -template -__global__ void column_sum_reduce(const T* __restrict__ inp, - T* __restrict__ out, - int rows, - int width) -{ - __shared__ float tile[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - - int y_stride = width * TILE_DIM; - - float localSum = 0; - - // Loop across matrix height - if (idx < width) { - int offset = threadIdx.y * width + idx; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - localSum += (float)inp[offset]; - offset += y_stride; - } - } - - tile[threadIdx.x][threadIdx.y] = localSum; - - __syncthreads(); - - // Sum the shared buffer. - float sum = tile[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i); - - if (threadIdx.x == 0) { - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (pos < width) out[pos] = sum; - } -} - -template -void launch_fuse_transpose_bias_kernel(const T* inp, - T* out, - int rows, - int cols, - cudaStream_t stream); - -template <> -void launch_fuse_transpose_bias_kernel(const float* inp, - float* out, - int rows, - int cols, - cudaStream_t stream) -{ - // assert(rows % TILE_DIM == 0); - // assert(cols % TILE_DIM == 0); - - dim3 grid_dim((cols - 1) / TILE_DIM + 1); - dim3 block_dim(TILE_DIM, TILE_DIM); - - column_sum_reduce<<>>(inp, out, rows, cols); -} - -template <> -void launch_fuse_transpose_bias_kernel<__half>(const __half* inp, - __half* out, - int rows, - int cols, - cudaStream_t stream) -{ - // assert(rows % TILE_DIM == 0); - // assert(cols % TILE_DIM == 0); - - dim3 grid_dim((cols - 1) / TILE_DIM + 1); - dim3 block_dim(TILE_DIM, TILE_DIM); - - column_sum_reduce<__half><<>>(inp, out, rows, cols); -} - -__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2) -{ - const float4* inp1_4 = reinterpret_cast(inp1); - const float4* inp2_4 = reinterpret_cast(inp2); - float4* out_4 = reinterpret_cast(out); - - CUDA_1D_KERNEL_LOOP(j, N) - { - float4 val; - float4 inp1_reg = inp1_4[j]; - float4 inp2_reg = inp2_4[j]; - - val.x = inp1_reg.x + inp2_reg.x; - val.y = inp1_reg.y + inp2_reg.y; - val.z = inp1_reg.z + inp2_reg.z; - val.w = inp1_reg.w + inp2_reg.w; - - out_4[j] = val; - } -} - -__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2) -{ - float2 inp1_4; - float2 inp2_4; - - __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); - __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); - - const float2* inp1_arr = reinterpret_cast(inp1); - const float2* inp2_arr = reinterpret_cast(inp2); - - CUDA_1D_KERNEL_LOOP(j, N) - { - inp1_4 = inp1_arr[j]; - inp2_4 = inp2_arr[j]; - - float2 inp1_h_f_0 = __half22float2(inp1_h[0]); - float2 inp1_h_f_1 = __half22float2(inp1_h[1]); - - float2 inp2_h_f_0 = __half22float2(inp2_h[0]); - float2 inp2_h_f_1 = __half22float2(inp2_h[1]); - - inp1_h_f_0.x += inp2_h_f_0.x; - inp1_h_f_0.y += inp2_h_f_0.y; - inp1_h_f_1.x += inp2_h_f_1.x; - inp1_h_f_1.y += inp2_h_f_1.y; - - float2 val_f; - __half2* val_h = reinterpret_cast<__half2*>(&val_f); - - val_h[0] = __float22half2_rn(inp1_h_f_0); - val_h[1] = __float22half2_rn(inp1_h_f_1); - - float2* out_4 = reinterpret_cast(out); - out_4[j] = val_f; - } -} - -template <> -void launch_fused_add2(float* out, - const float* inp1, - const float* inp2, - int batch_size, - int seq_length, - int hidden_dim, - cudaStream_t& stream) -{ - int total_count = batch_size * seq_length * hidden_dim / 4; - dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); - - fused_add2_kernel<<>>(total_count, out, inp1, inp2); -} - -template <> -void launch_fused_add2<__half>(__half* out, - const __half* inp1, - const __half* inp2, - int batch_size, - int seq_length, - int hidden_dim, - cudaStream_t& stream) -{ - int total_count = batch_size * seq_length * hidden_dim / 4; - dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); - - fused_add2_kernel<<>>(total_count, out, inp1, inp2); -} - -__global__ void fused_add3_kernel(float* out, - const float* inp1, - const float* inp2, - const float* inp3, - int size, - int row_stride) -{ - int row = blockIdx.x; - int id = threadIdx.x; - - const float4* inp1_4 = reinterpret_cast(inp1); - const float4* inp2_4 = reinterpret_cast(inp2); - const float4* inp3_4 = reinterpret_cast(inp3); - - float4* out_4 = reinterpret_cast(out); - - float4 val; - float4 inp1_reg = inp1_4[row * row_stride + id]; - float4 inp2_reg = inp2_4[row * row_stride + id]; - float4 inp3_reg = inp3_4[row * row_stride + id]; - - val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x; - val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y; - val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z; - val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w; - - out_4[row * row_stride + id] = val; -} - -__global__ void fused_add3_kernel(__half* out, - const __half* inp1, - const __half* inp2, - const __half* inp3, - int size, - int row_stride) -{ - int row = blockIdx.x; - int id = threadIdx.x; - const float2* inp1_arr = reinterpret_cast(inp1); - const float2* inp2_arr = reinterpret_cast(inp2); - const float2* inp3_arr = reinterpret_cast(inp3); - - float2 inp1_4 = inp1_arr[row * row_stride + id]; - float2 inp2_4 = inp2_arr[row * row_stride + id]; - float2 inp3_4 = inp3_arr[row * row_stride + id]; - - __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); - __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); - __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); - - float2 inp1_h_f_0 = __half22float2(inp1_h[0]); - float2 inp1_h_f_1 = __half22float2(inp1_h[1]); - - float2 inp2_h_f_0 = __half22float2(inp2_h[0]); - float2 inp2_h_f_1 = __half22float2(inp2_h[1]); - - float2 inp3_h_f_0 = __half22float2(inp3_h[0]); - float2 inp3_h_f_1 = __half22float2(inp3_h[1]); - - inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x); - inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y); - inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x); - inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y); - - float2 val_f; - __half2* val_h = reinterpret_cast<__half2*>(&val_f); - - val_h[0] = __float22half2_rn(inp1_h_f_0); - val_h[1] = __float22half2_rn(inp1_h_f_1); - - float2* out_4 = reinterpret_cast(out); - out_4[row * row_stride + id] = val_f; -} - -template <> -void launch_fused_add3(float* out, - const float* inp1, - const float* inp2, - const float* inp3, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream) -{ - dim3 grid_dim(batch_size * seq_length); - - dim3 block_dim(hidden_size / 4); - - fused_add3_kernel<<>>( - out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); -} - -template <> -void launch_fused_add3<__half>(__half* out, - const __half* inp1, - const __half* inp2, - const __half* inp3, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream) -{ - dim3 grid_dim(batch_size * seq_length); - - dim3 block_dim(hidden_size / 4); - - fused_add3_kernel<<>>( - out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); -} - -__global__ void fused_add4_kernel(float* out, - const float* inp1, - const float* inp2, - const float* inp3, - const float* inp4, - int size, - int row_stride) -{ - int row = blockIdx.x; - int id = threadIdx.x; - - const float4* inp1_4 = reinterpret_cast(inp1); - const float4* inp2_4 = reinterpret_cast(inp2); - const float4* inp3_4 = reinterpret_cast(inp3); - const float4* inp4_4 = reinterpret_cast(inp4); - float4* out_4 = reinterpret_cast(out); - - float4 val; - float4 inp1_reg = inp1_4[row * row_stride + id]; - float4 inp2_reg = inp2_4[row * row_stride + id]; - float4 inp3_reg = inp3_4[row * row_stride + id]; - float4 inp4_reg = inp4_4[row * row_stride + id]; - - val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x; - val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y; - val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z; - val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w; - - out_4[row * row_stride + id] = val; -} - -__global__ void fused_add4_kernel(__half* out, - const __half* inp1, - const __half* inp2, - const __half* inp3, - const __half* inp4, - int size, - int row_stride) -{ - int row = blockIdx.x; - int id = threadIdx.x; - const float2* inp1_arr = reinterpret_cast(inp1); - const float2* inp2_arr = reinterpret_cast(inp2); - const float2* inp3_arr = reinterpret_cast(inp3); - const float2* inp4_arr = reinterpret_cast(inp4); - - float2 inp1_4 = inp1_arr[row * row_stride + id]; - float2 inp2_4 = inp2_arr[row * row_stride + id]; - float2 inp3_4 = inp3_arr[row * row_stride + id]; - float2 inp4_4 = inp4_arr[row * row_stride + id]; - - __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); - __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); - __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); - __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4); - - float2 inp1_h_f_0 = __half22float2(inp1_h[0]); - float2 inp1_h_f_1 = __half22float2(inp1_h[1]); - - float2 inp2_h_f_0 = __half22float2(inp2_h[0]); - float2 inp2_h_f_1 = __half22float2(inp2_h[1]); - - float2 inp3_h_f_0 = __half22float2(inp3_h[0]); - float2 inp3_h_f_1 = __half22float2(inp3_h[1]); - - float2 inp4_h_f_0 = __half22float2(inp4_h[0]); - float2 inp4_h_f_1 = __half22float2(inp4_h[1]); - - inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x); - inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y); - inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x); - inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y); - - float2 val_f; - __half2* val_h = reinterpret_cast<__half2*>(&val_f); - - val_h[0] = __float22half2_rn(inp1_h_f_0); - val_h[1] = __float22half2_rn(inp1_h_f_1); - - float2* out_4 = reinterpret_cast(out); - out_4[row * row_stride + id] = val_f; -} - -template <> -void launch_fused_add4(float* out, - const float* inp1, - const float* inp2, - const float* inp3, - const float* inp4, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream) -{ - dim3 grid_dim(batch_size * seq_length); - - dim3 block_dim(hidden_size / 4); - - fused_add4_kernel<<>>( - out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); -} - -template <> -void launch_fused_add4<__half>(__half* out, - const __half* inp1, - const __half* inp2, - const __half* inp3, - const __half* inp4, - int batch_size, - int seq_length, - int hidden_size, - cudaStream_t& stream) -{ - dim3 grid_dim(batch_size * seq_length); - - dim3 block_dim(hidden_size / 4); - - fused_add4_kernel<<>>( - out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); -} +#include "general_kernels.h" + +namespace cg = cooperative_groups; + +template +__global__ void column_sum_reduce(const T* __restrict__ inp, + T* __restrict__ out, + int rows, + int width) +{ + __shared__ float tile[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + + int y_stride = width * TILE_DIM; + + float localSum = 0; + + // Loop across matrix height + if (idx < width) { + int offset = threadIdx.y * width + idx; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + localSum += (float)inp[offset]; + offset += y_stride; + } + } + + tile[threadIdx.x][threadIdx.y] = localSum; + + __syncthreads(); + + // Sum the shared buffer. + float sum = tile[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i); + + if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + if (pos < width) out[pos] = sum; + } +} + +template +void launch_fuse_transpose_bias_kernel(const T* inp, + T* out, + int rows, + int cols, + cudaStream_t stream); + +template <> +void launch_fuse_transpose_bias_kernel(const float* inp, + float* out, + int rows, + int cols, + cudaStream_t stream) +{ + // assert(rows % TILE_DIM == 0); + // assert(cols % TILE_DIM == 0); + + dim3 grid_dim((cols - 1) / TILE_DIM + 1); + dim3 block_dim(TILE_DIM, TILE_DIM); + + column_sum_reduce<<>>(inp, out, rows, cols); +} + +template <> +void launch_fuse_transpose_bias_kernel<__half>(const __half* inp, + __half* out, + int rows, + int cols, + cudaStream_t stream) +{ + // assert(rows % TILE_DIM == 0); + // assert(cols % TILE_DIM == 0); + + dim3 grid_dim((cols - 1) / TILE_DIM + 1); + dim3 block_dim(TILE_DIM, TILE_DIM); + + column_sum_reduce<__half><<>>(inp, out, rows, cols); +} + +__global__ void fused_add2_kernel(const int N, float* out, const float* inp1, const float* inp2) +{ + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + float4* out_4 = reinterpret_cast(out); + + CUDA_1D_KERNEL_LOOP(j, N) + { + float4 val; + float4 inp1_reg = inp1_4[j]; + float4 inp2_reg = inp2_4[j]; + + val.x = inp1_reg.x + inp2_reg.x; + val.y = inp1_reg.y + inp2_reg.y; + val.z = inp1_reg.z + inp2_reg.z; + val.w = inp1_reg.w + inp2_reg.w; + + out_4[j] = val; + } +} + +__global__ void fused_add2_kernel(const int N, __half* out, const __half* inp1, const __half* inp2) +{ + float2 inp1_4; + float2 inp2_4; + + __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); + __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); + + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + + CUDA_1D_KERNEL_LOOP(j, N) + { + inp1_4 = inp1_arr[j]; + inp2_4 = inp2_arr[j]; + + float2 inp1_h_f_0 = __half22float2(inp1_h[0]); + float2 inp1_h_f_1 = __half22float2(inp1_h[1]); + + float2 inp2_h_f_0 = __half22float2(inp2_h[0]); + float2 inp2_h_f_1 = __half22float2(inp2_h[1]); + + inp1_h_f_0.x += inp2_h_f_0.x; + inp1_h_f_0.y += inp2_h_f_0.y; + inp1_h_f_1.x += inp2_h_f_1.x; + inp1_h_f_1.y += inp2_h_f_1.y; + + float2 val_f; + __half2* val_h = reinterpret_cast<__half2*>(&val_f); + + val_h[0] = __float22half2_rn(inp1_h_f_0); + val_h[1] = __float22half2_rn(inp1_h_f_1); + + float2* out_4 = reinterpret_cast(out); + out_4[j] = val_f; + } +} + +template <> +void launch_fused_add2(float* out, + const float* inp1, + const float* inp2, + int batch_size, + int seq_length, + int hidden_dim, + cudaStream_t& stream) +{ + int total_count = batch_size * seq_length * hidden_dim / 4; + dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); + + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + + fused_add2_kernel<<>>(total_count, out, inp1, inp2); +} + +template <> +void launch_fused_add2<__half>(__half* out, + const __half* inp1, + const __half* inp2, + int batch_size, + int seq_length, + int hidden_dim, + cudaStream_t& stream) +{ + int total_count = batch_size * seq_length * hidden_dim / 4; + dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); + + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + + fused_add2_kernel<<>>(total_count, out, inp1, inp2); +} + +__global__ void fused_add3_kernel(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + const float4* inp3_4 = reinterpret_cast(inp3); + + float4* out_4 = reinterpret_cast(out); + + float4 val; + float4 inp1_reg = inp1_4[row * row_stride + id]; + float4 inp2_reg = inp2_4[row * row_stride + id]; + float4 inp3_reg = inp3_4[row * row_stride + id]; + + val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x; + val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y; + val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z; + val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w; + + out_4[row * row_stride + id] = val; +} + +__global__ void fused_add3_kernel(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + const float2* inp3_arr = reinterpret_cast(inp3); + + float2 inp1_4 = inp1_arr[row * row_stride + id]; + float2 inp2_4 = inp2_arr[row * row_stride + id]; + float2 inp3_4 = inp3_arr[row * row_stride + id]; + + __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); + __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); + __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); + + float2 inp1_h_f_0 = __half22float2(inp1_h[0]); + float2 inp1_h_f_1 = __half22float2(inp1_h[1]); + + float2 inp2_h_f_0 = __half22float2(inp2_h[0]); + float2 inp2_h_f_1 = __half22float2(inp2_h[1]); + + float2 inp3_h_f_0 = __half22float2(inp3_h[0]); + float2 inp3_h_f_1 = __half22float2(inp3_h[1]); + + inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x); + inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y); + inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x); + inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y); + + float2 val_f; + __half2* val_h = reinterpret_cast<__half2*>(&val_f); + + val_h[0] = __float22half2_rn(inp1_h_f_0); + val_h[1] = __float22half2_rn(inp1_h_f_1); + + float2* out_4 = reinterpret_cast(out); + out_4[row * row_stride + id] = val_f; +} + +template <> +void launch_fused_add3(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add3_kernel<<>>( + out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); +} + +template <> +void launch_fused_add3<__half>(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add3_kernel<<>>( + out, inp1, inp2, inp3, (batch_size * seq_length * hidden_size), hidden_size / 4); +} + +__global__ void fused_add4_kernel(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + const float* inp4, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + + const float4* inp1_4 = reinterpret_cast(inp1); + const float4* inp2_4 = reinterpret_cast(inp2); + const float4* inp3_4 = reinterpret_cast(inp3); + const float4* inp4_4 = reinterpret_cast(inp4); + float4* out_4 = reinterpret_cast(out); + + float4 val; + float4 inp1_reg = inp1_4[row * row_stride + id]; + float4 inp2_reg = inp2_4[row * row_stride + id]; + float4 inp3_reg = inp3_4[row * row_stride + id]; + float4 inp4_reg = inp4_4[row * row_stride + id]; + + val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x; + val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y; + val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z; + val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w; + + out_4[row * row_stride + id] = val; +} + +__global__ void fused_add4_kernel(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + const __half* inp4, + int size, + int row_stride) +{ + int row = blockIdx.x; + int id = threadIdx.x; + const float2* inp1_arr = reinterpret_cast(inp1); + const float2* inp2_arr = reinterpret_cast(inp2); + const float2* inp3_arr = reinterpret_cast(inp3); + const float2* inp4_arr = reinterpret_cast(inp4); + + float2 inp1_4 = inp1_arr[row * row_stride + id]; + float2 inp2_4 = inp2_arr[row * row_stride + id]; + float2 inp3_4 = inp3_arr[row * row_stride + id]; + float2 inp4_4 = inp4_arr[row * row_stride + id]; + + __half2* inp1_h = reinterpret_cast<__half2*>(&inp1_4); + __half2* inp2_h = reinterpret_cast<__half2*>(&inp2_4); + __half2* inp3_h = reinterpret_cast<__half2*>(&inp3_4); + __half2* inp4_h = reinterpret_cast<__half2*>(&inp4_4); + + float2 inp1_h_f_0 = __half22float2(inp1_h[0]); + float2 inp1_h_f_1 = __half22float2(inp1_h[1]); + + float2 inp2_h_f_0 = __half22float2(inp2_h[0]); + float2 inp2_h_f_1 = __half22float2(inp2_h[1]); + + float2 inp3_h_f_0 = __half22float2(inp3_h[0]); + float2 inp3_h_f_1 = __half22float2(inp3_h[1]); + + float2 inp4_h_f_0 = __half22float2(inp4_h[0]); + float2 inp4_h_f_1 = __half22float2(inp4_h[1]); + + inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x); + inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y); + inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x); + inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y); + + float2 val_f; + __half2* val_h = reinterpret_cast<__half2*>(&val_f); + + val_h[0] = __float22half2_rn(inp1_h_f_0); + val_h[1] = __float22half2_rn(inp1_h_f_1); + + float2* out_4 = reinterpret_cast(out); + out_4[row * row_stride + id] = val_f; +} + +template <> +void launch_fused_add4(float* out, + const float* inp1, + const float* inp2, + const float* inp3, + const float* inp4, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add4_kernel<<>>( + out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); +} + +template <> +void launch_fused_add4<__half>(__half* out, + const __half* inp1, + const __half* inp2, + const __half* inp3, + const __half* inp4, + int batch_size, + int seq_length, + int hidden_size, + cudaStream_t& stream) +{ + dim3 grid_dim(batch_size * seq_length); + + dim3 block_dim(hidden_size / 4); + + fused_add4_kernel<<>>( + out, inp1, inp2, inp3, inp4, (batch_size * seq_length * hidden_size), hidden_size / 4); +} diff --git a/csrc/transformer/inference/csrc/dequantize.cu b/csrc/transformer/inference/csrc/dequantize.cu index ddf7a958822a..0fc15d0fbeef 100644 --- a/csrc/transformer/inference/csrc/dequantize.cu +++ b/csrc/transformer/inference/csrc/dequantize.cu @@ -1,110 +1,110 @@ -#include "custom_cuda_layers.h" - -#define MAX_QUANTIZE_GROUPING 1024 - -#define loop_unroll 1 -#define loop_unroll_bits 1 - -__global__ void dequantize_kernel(float* output, - const int8_t* input, - const float* qscale, - int output_size, - int hidden_dim, - int groups, - int merge_count) -{ - unsigned merge_hidden = hidden_dim >> merge_count; - unsigned quantization_stride = (merge_hidden * output_size) / groups; - - unsigned bid = blockIdx.x; - unsigned tid = threadIdx.x; - - while (tid < output_size) { - unsigned w_index = bid / merge_hidden; - unsigned q_index = tid + bid * output_size; - - auto q = input[q_index]; - - unsigned merge_hidden_total = w_index * merge_hidden; - unsigned scale_index = - ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) - << merge_count) + - w_index; - - float scale_data = qscale[scale_index]; - - output[q_index] = (scale_data * (float)q); - tid += blockDim.x; - } -} - -__global__ void dequantize_kernel(__half* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - unsigned merge_count) -{ -#if __CUDA_ARCH__ >= 700 - - unsigned merge_hidden = hidden_dim >> merge_count; - unsigned quantization_stride = (merge_hidden * output_size) / groups; - - unsigned bid = blockIdx.x; - unsigned tid = threadIdx.x; - - while (tid < output_size) { - unsigned w_index = bid / merge_hidden; - unsigned q_index = tid + bid * output_size; - - auto q = input[q_index]; - - unsigned merge_hidden_total = w_index * merge_hidden; - unsigned scale_index = - ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) - << merge_count) + - w_index; - - float scale_data = qscale[scale_index]; - - output[q_index] = __float2half(scale_data * (float)q); - tid += blockDim.x; - } -#endif -} - -template -void launch_dequantize(T* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - unsigned merge_count, - cudaStream_t stream) -{ - unsigned threads = 1024; - dim3 block_dims(threads); - dim3 grid_dims(hidden_dim); - - dequantize_kernel<<>>( - output, input, qscale, output_size, hidden_dim, groups, merge_count); -} - -template void launch_dequantize(float*, - const int8_t*, - const float*, - unsigned, - unsigned, - unsigned, - unsigned, - cudaStream_t); -template void launch_dequantize<__half>(__half*, - const int8_t*, - const float*, - unsigned, - unsigned, - unsigned, - unsigned, - cudaStream_t); +#include "custom_cuda_layers.h" + +#define MAX_QUANTIZE_GROUPING 1024 + +#define loop_unroll 1 +#define loop_unroll_bits 1 + +__global__ void dequantize_kernel(float* output, + const int8_t* input, + const float* qscale, + int output_size, + int hidden_dim, + int groups, + int merge_count) +{ + unsigned merge_hidden = hidden_dim >> merge_count; + unsigned quantization_stride = (merge_hidden * output_size) / groups; + + unsigned bid = blockIdx.x; + unsigned tid = threadIdx.x; + + while (tid < output_size) { + unsigned w_index = bid / merge_hidden; + unsigned q_index = tid + bid * output_size; + + auto q = input[q_index]; + + unsigned merge_hidden_total = w_index * merge_hidden; + unsigned scale_index = + ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) + << merge_count) + + w_index; + + float scale_data = qscale[scale_index]; + + output[q_index] = (scale_data * (float)q); + tid += blockDim.x; + } +} + +__global__ void dequantize_kernel(__half* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count) +{ +#if __CUDA_ARCH__ >= 700 + + unsigned merge_hidden = hidden_dim >> merge_count; + unsigned quantization_stride = (merge_hidden * output_size) / groups; + + unsigned bid = blockIdx.x; + unsigned tid = threadIdx.x; + + while (tid < output_size) { + unsigned w_index = bid / merge_hidden; + unsigned q_index = tid + bid * output_size; + + auto q = input[q_index]; + + unsigned merge_hidden_total = w_index * merge_hidden; + unsigned scale_index = + ((((bid - merge_hidden_total) + tid * merge_hidden) / quantization_stride) + << merge_count) + + w_index; + + float scale_data = qscale[scale_index]; + + output[q_index] = __float2half(scale_data * (float)q); + tid += blockDim.x; + } +#endif +} + +template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count, + cudaStream_t stream) +{ + unsigned threads = 1024; + dim3 block_dims(threads); + dim3 grid_dims(hidden_dim); + + dequantize_kernel<<>>( + output, input, qscale, output_size, hidden_dim, groups, merge_count); +} + +template void launch_dequantize(float*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + unsigned, + cudaStream_t); +template void launch_dequantize<__half>(__half*, + const int8_t*, + const float*, + unsigned, + unsigned, + unsigned, + unsigned, + cudaStream_t); diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index fc3faacc54e8..10adaa6fe98f 100755 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -1,266 +1,266 @@ -#include "custom_cuda_layers.h" - -inline __device__ float gelu(const float x) -{ - const float sqrt_param = 0.79788456080286535587989211986876f; - const float mul_param = 0.044715; - return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); -} - -__global__ void fused_bias_gelu(float* input, - const float* bias, - int total_count, - int intermediate_size) -{ - float4* input_cast = reinterpret_cast(input); - const float4* bias_cast = reinterpret_cast(bias); - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float4 data = input_cast[offset]; - float4 bias_data = bias_cast[offset % intermediate_size]; - - data.x += bias_data.x; - data.y += bias_data.y; - data.z += bias_data.z; - data.w += bias_data.w; - - data.x = gelu(data.x); - data.y = gelu(data.y); - data.z = gelu(data.z); - data.w = gelu(data.w); - - input_cast[offset] = data; - } -} - -__global__ void fused_bias_gelu(__half* input, - const __half* bias, - int total_count, - int intermediate_size) -{ -#if __CUDA_ARCH__ >= 700 - - float2* input_cast = reinterpret_cast(input); - const float2* bias_cast = reinterpret_cast(bias); - - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float2 vals_vec = input_cast[offset]; - float2 bias_vec = bias_cast[offset % intermediate_size]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - - low_data.x += low_bias.x; - low_data.y += low_bias.y; - high_data.x += high_bias.x; - high_data.y += high_bias.y; - - low_data.x = gelu(low_data.x); - low_data.y = gelu(low_data.y); - high_data.x = gelu(high_data.x); - high_data.y = gelu(high_data.y); - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - input_cast[offset] = vals_vec; - } -#endif -} - -template -void launch_bias_gelu(T* input, - const T* bias, - int intermediate_size, - int batch_size, - cudaStream_t stream) -{ - int total_count = batch_size * (intermediate_size / 4); - int threads = 1024; // intermediate_size / iterations / 4; - dim3 block_dims(threads); - dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size); - - fused_bias_gelu<<>>( - input, bias, total_count, intermediate_size / 4); -} - -template void launch_bias_gelu(float*, const float*, int, int, cudaStream_t); -template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t); - -__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size) -{ - float4* input_cast = reinterpret_cast(input); - const float4* bias_cast = reinterpret_cast(bias); - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float4 data = input_cast[offset]; - float4 bias_data = bias_cast[offset % hidden_size]; - - data.x += bias_data.x; - data.y += bias_data.y; - data.z += bias_data.z; - data.w += bias_data.w; - - input_cast[offset] = data; - } -} - -__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size) -{ -#if __CUDA_ARCH__ >= 700 - - float2* input_cast = reinterpret_cast(input); - const float2* bias_cast = reinterpret_cast(bias); - - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float2 vals_vec = input_cast[offset]; - float2 bias_vec = bias_cast[offset % hidden_size]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - - low_data.x += low_bias.x; - low_data.y += low_bias.y; - high_data.x += high_bias.x; - high_data.y += high_bias.y; - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - input_cast[offset] = vals_vec; - } -#endif -} - -template -void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream) -{ - int total_count = batch_size * (hidden_size / 4); - int threads = 1024; // hidden_size / iterations / 4; - dim3 block_dims(threads); - dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size); - - fused_bias_add<<>>(input, bias, total_count, hidden_size / 4); -} - -template void launch_bias_add(float*, const float*, int, int, cudaStream_t); -template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t); - -__global__ void fused_bias_residual(float* input, - const float* residual, - const float* bias, - int total_count, - int intermediate_size) -{ - float4* input_cast = reinterpret_cast(input); - const float4* residual_cast = reinterpret_cast(residual); - const float4* bias_cast = reinterpret_cast(bias); - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float4 data = input_cast[offset]; - float4 res_vec = residual_cast[offset]; - float4 bias_data = bias_cast[offset % intermediate_size]; - - data.x += (res_vec.x + bias_data.x); - data.y += (res_vec.y + bias_data.y); - data.z += (res_vec.z + bias_data.z); - data.w += (res_vec.w + bias_data.w); - - input_cast[offset] = data; - } -} - -__global__ void fused_bias_residual(__half* input, - const __half* residual, - const __half* bias, - int total_count, - int intermediate_size) -{ -#if __CUDA_ARCH__ >= 700 - - float2* input_cast = reinterpret_cast(input); - const float2* residual_cast = reinterpret_cast(residual); - - const float2* bias_cast = reinterpret_cast(bias); - - int offset = blockIdx.x * blockDim.x + threadIdx.x; - - if (offset < total_count) { - float2 vals_vec = input_cast[offset]; - float2 res_vec = residual_cast[offset]; - - float2 bias_vec = bias_cast[offset % intermediate_size]; - - __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); - __half2* res_half = reinterpret_cast<__half2*>(&res_vec); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - - float2 low_data = __half22float2(vals_half[0]); - float2 high_data = __half22float2(vals_half[1]); - - float2 low_res = __half22float2(res_half[0]); - float2 high_res = __half22float2(res_half[1]); - - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - - low_data.x += (low_res.x + low_bias.x); - low_data.y += (low_res.y + low_bias.y); - high_data.x += (high_res.x + high_bias.x); - high_data.y += (high_res.y + high_bias.y); - - vals_half[0] = __float22half2_rn(low_data); - vals_half[1] = __float22half2_rn(high_data); - - input_cast[offset] = vals_vec; - } -#endif -} - -template -void launch_bias_residual(T* input, - const T* residual, - const T* bias, - int batch, - int intermediate_size, - cudaStream_t stream) -{ - int total_count = batch * intermediate_size / 4; - dim3 block_dims(1024); - dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size); - - fused_bias_residual<<>>( - input, residual, bias, total_count, intermediate_size / 4); -} - -template void launch_bias_residual(float*, - const float*, - const float*, - int, - int, - cudaStream_t); -template void launch_bias_residual<__half>(__half*, - const __half*, - const __half*, - int, - int, - cudaStream_t); +#include "custom_cuda_layers.h" + +inline __device__ float gelu(const float x) +{ + const float sqrt_param = 0.79788456080286535587989211986876f; + const float mul_param = 0.044715; + return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x))); +} + +__global__ void fused_bias_gelu(float* input, + const float* bias, + int total_count, + int intermediate_size) +{ + float4* input_cast = reinterpret_cast(input); + const float4* bias_cast = reinterpret_cast(bias); + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float4 data = input_cast[offset]; + float4 bias_data = bias_cast[offset % intermediate_size]; + + data.x += bias_data.x; + data.y += bias_data.y; + data.z += bias_data.z; + data.w += bias_data.w; + + data.x = gelu(data.x); + data.y = gelu(data.y); + data.z = gelu(data.z); + data.w = gelu(data.w); + + input_cast[offset] = data; + } +} + +__global__ void fused_bias_gelu(__half* input, + const __half* bias, + int total_count, + int intermediate_size) +{ +#if __CUDA_ARCH__ >= 700 + + float2* input_cast = reinterpret_cast(input); + const float2* bias_cast = reinterpret_cast(bias); + + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float2 vals_vec = input_cast[offset]; + float2 bias_vec = bias_cast[offset % intermediate_size]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += low_bias.x; + low_data.y += low_bias.y; + high_data.x += high_bias.x; + high_data.y += high_bias.y; + + low_data.x = gelu(low_data.x); + low_data.y = gelu(low_data.y); + high_data.x = gelu(high_data.x); + high_data.y = gelu(high_data.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + input_cast[offset] = vals_vec; + } +#endif +} + +template +void launch_bias_gelu(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream) +{ + int total_count = batch_size * (intermediate_size / 4); + int threads = 1024; // intermediate_size / iterations / 4; + dim3 block_dims(threads); + dim3 grid_dims(((total_count - 1) / 1024 + 1)); // (batch_size); + + fused_bias_gelu<<>>( + input, bias, total_count, intermediate_size / 4); +} + +template void launch_bias_gelu(float*, const float*, int, int, cudaStream_t); +template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t); + +__global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size) +{ + float4* input_cast = reinterpret_cast(input); + const float4* bias_cast = reinterpret_cast(bias); + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float4 data = input_cast[offset]; + float4 bias_data = bias_cast[offset % hidden_size]; + + data.x += bias_data.x; + data.y += bias_data.y; + data.z += bias_data.z; + data.w += bias_data.w; + + input_cast[offset] = data; + } +} + +__global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size) +{ +#if __CUDA_ARCH__ >= 700 + + float2* input_cast = reinterpret_cast(input); + const float2* bias_cast = reinterpret_cast(bias); + + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float2 vals_vec = input_cast[offset]; + float2 bias_vec = bias_cast[offset % hidden_size]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += low_bias.x; + low_data.y += low_bias.y; + high_data.x += high_bias.x; + high_data.y += high_bias.y; + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + input_cast[offset] = vals_vec; + } +#endif +} + +template +void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream) +{ + int total_count = batch_size * (hidden_size / 4); + int threads = 1024; // hidden_size / iterations / 4; + dim3 block_dims(threads); + dim3 grid_dims(((total_count - 1) / threads + 1)); // (batch_size); + + fused_bias_add<<>>(input, bias, total_count, hidden_size / 4); +} + +template void launch_bias_add(float*, const float*, int, int, cudaStream_t); +template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t); + +__global__ void fused_bias_residual(float* input, + const float* residual, + const float* bias, + int total_count, + int intermediate_size) +{ + float4* input_cast = reinterpret_cast(input); + const float4* residual_cast = reinterpret_cast(residual); + const float4* bias_cast = reinterpret_cast(bias); + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float4 data = input_cast[offset]; + float4 res_vec = residual_cast[offset]; + float4 bias_data = bias_cast[offset % intermediate_size]; + + data.x += (res_vec.x + bias_data.x); + data.y += (res_vec.y + bias_data.y); + data.z += (res_vec.z + bias_data.z); + data.w += (res_vec.w + bias_data.w); + + input_cast[offset] = data; + } +} + +__global__ void fused_bias_residual(__half* input, + const __half* residual, + const __half* bias, + int total_count, + int intermediate_size) +{ +#if __CUDA_ARCH__ >= 700 + + float2* input_cast = reinterpret_cast(input); + const float2* residual_cast = reinterpret_cast(residual); + + const float2* bias_cast = reinterpret_cast(bias); + + int offset = blockIdx.x * blockDim.x + threadIdx.x; + + if (offset < total_count) { + float2 vals_vec = input_cast[offset]; + float2 res_vec = residual_cast[offset]; + + float2 bias_vec = bias_cast[offset % intermediate_size]; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* res_half = reinterpret_cast<__half2*>(&res_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + + float2 low_data = __half22float2(vals_half[0]); + float2 high_data = __half22float2(vals_half[1]); + + float2 low_res = __half22float2(res_half[0]); + float2 high_res = __half22float2(res_half[1]); + + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + low_data.x += (low_res.x + low_bias.x); + low_data.y += (low_res.y + low_bias.y); + high_data.x += (high_res.x + high_bias.x); + high_data.y += (high_res.y + high_bias.y); + + vals_half[0] = __float22half2_rn(low_data); + vals_half[1] = __float22half2_rn(high_data); + + input_cast[offset] = vals_vec; + } +#endif +} + +template +void launch_bias_residual(T* input, + const T* residual, + const T* bias, + int batch, + int intermediate_size, + cudaStream_t stream) +{ + int total_count = batch * intermediate_size / 4; + dim3 block_dims(1024); + dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size); + + fused_bias_residual<<>>( + input, residual, bias, total_count, intermediate_size / 4); +} + +template void launch_bias_residual(float*, + const float*, + const float*, + int, + int, + cudaStream_t); +template void launch_bias_residual<__half>(__half*, + const __half*, + const __half*, + int, + int, + cudaStream_t); diff --git a/csrc/transformer/inference/csrc/normalize.cu b/csrc/transformer/inference/csrc/normalize.cu index ecd73154f37f..23cceff37166 100755 --- a/csrc/transformer/inference/csrc/normalize.cu +++ b/csrc/transformer/inference/csrc/normalize.cu @@ -1,426 +1,426 @@ -#include -#include "custom_cuda_layers.h" - -#include -#include -#include -#include - -#define NORM_REG (MAX_REGISTERS) - -namespace cg = cooperative_groups; - -__global__ void fused_bias_residual_layer_norm(float* output, - const float* vals, - const float* gamma, - const float* beta, - float epsilon, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - int warp_num = iteration_stride >> 5; - - float inp_reg[NORM_REG]; - - int k = 0; - float sum = 0; - int input_id = id; - while (input_id < row_stride) { - inp_reg[k] = vals[input_id + row * row_stride]; - sum += inp_reg[k++]; - input_id += iteration_stride; - } - - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - - __shared__ float shr[MAX_WARP_NUM]; - - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - - float mean = sum / (row_stride); - sum = 0.f; - for (int f = 0; f < k; f++) { - inp_reg[f] -= mean; - sum += inp_reg[f] * inp_reg[f]; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= (row_stride); - sum += epsilon; - sum = __frsqrt_rn(sum); - for (int f = 0; f < k; f++) { - int out_id = f * iteration_stride + id; - inp_reg[f] = inp_reg[f] * sum; - inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id]; - output[out_id + row * row_stride] = inp_reg[f]; - } -} - -__global__ void fused_bias_residual_layer_norm(__half* output, - const __half* vals, - const __half* gamma, - const __half* beta, - float epsilon, - int row_stride) -{ -#if __CUDA_ARCH__ >= 700 - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - int warp_num = iteration_stride >> 5; - - __half2 inp_reg[NORM_REG]; - - const __half2* vals_cast = reinterpret_cast(vals); - __half2* out_cast = reinterpret_cast<__half2*>(output); - - int k = 0; - int input_id = id; - while (input_id < row_stride) { - inp_reg[k++] = vals_cast[input_id + row * row_stride]; - input_id += iteration_stride; - } - float sum = 0; - for (int f = k - 1; f >= 0; f--) { - float2 inp_f = __half22float2(inp_reg[f]); - sum += inp_f.x + inp_f.y; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - __shared__ float shr[MAX_WARP_NUM]; - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - float mean = sum / (row_stride << 1); - sum = 0.f; - for (int f = 0; f < k; f++) { - float2 inp_f = __half22float2(inp_reg[f]); - inp_f.x -= mean; - inp_f.y -= mean; - inp_reg[f] = __float22half2_rn(inp_f); - sum += inp_f.x * inp_f.x; - sum += inp_f.y * inp_f.y; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= (row_stride << 1); - sum += epsilon; - sum = __frsqrt_rn(sum); - __half2 variance_h = __float2half2_rn(sum); - const __half2* gamma_cast = reinterpret_cast(gamma); - const __half2* beta_cast = reinterpret_cast(beta); - for (int f = 0; f < k; f++) { - int out_id = f * iteration_stride + id; - inp_reg[f] = inp_reg[f] * variance_h; - inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id]; - out_cast[out_id + row * row_stride] = inp_reg[f]; - } -#endif -} - -template -void launch_layer_norm(T* out, - T* vals, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream); - -template <> -void launch_layer_norm(float* out, - float* vals, - const float* gamma, - const float* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream) -{ - constexpr int threads = 1024; - - dim3 grid_dim(batch_size); - - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - out, vals, gamma, beta, epsilon, hidden_dim); -} - -template <> -void launch_layer_norm<__half>(__half* out, - __half* vals, - const __half* gamma, - const __half* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream) -{ - constexpr int threads = 1024; - - dim3 grid_dim(batch_size); - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - out, vals, gamma, beta, epsilon, hidden_dim / 2); -} - -__global__ void fused_residual_layer_norm(float* norm, - float* res_add, - float* vals, - float* residual, - const float* bias, - const float* gamma, - const float* beta, - float epsilon, - int row_stride, - bool preLN) -{ - int iteration_stride = blockDim.x; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - int warp_num = iteration_stride >> 5; - - float inp_reg[NORM_REG]; - - int k = 0; - int input_id = id; - - float sum = 0; - while (input_id < row_stride) { - inp_reg[k] = vals[input_id + row * row_stride]; - float res_f = (residual[input_id + row * row_stride]); - float bias_f = (bias[input_id]); - inp_reg[k] += res_f + bias_f; - if (preLN) res_add[input_id + row * row_stride] = inp_reg[k]; - sum += inp_reg[k++]; - input_id += iteration_stride; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - - __shared__ float shr[MAX_WARP_NUM]; - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - float mean = sum / (row_stride); - sum = 0.f; - for (int f = 0; f < k; f++) { - inp_reg[f] -= mean; - sum += inp_reg[f] * inp_reg[f]; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= (row_stride); - sum += epsilon; - sum = __frsqrt_rn(sum); - - for (int f = 0; f < k; f++) { - int out_id = f * iteration_stride + id; - inp_reg[f] = inp_reg[f] * sum; - inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id]; - norm[out_id + row * row_stride] = inp_reg[f]; - } -} - -__global__ void fused_residual_layer_norm(__half* norm, - __half* res_add, - __half* vals, - __half* residual, - const __half* bias, - const __half* gamma, - const __half* beta, - float epsilon, - int row_stride, - bool preLN) -{ -#if __CUDA_ARCH__ >= 700 - int iteration_stride = blockDim.x; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - int warp_num = iteration_stride >> 5; - - __half2 inp_reg[NORM_REG]; - - __half2* vals_cast = reinterpret_cast<__half2*>(vals); - __half2* norm_cast = reinterpret_cast<__half2*>(norm); - __half2* res_add_cast = reinterpret_cast<__half2*>(res_add); - __half2* residual_cast = reinterpret_cast<__half2*>(residual); - const __half2* bias_cast = reinterpret_cast(bias); - - int k = 0; - int input_id = id; - - float sum = 0; - while (input_id < row_stride) { - inp_reg[k] = vals_cast[input_id + row * row_stride]; - float2 inp_f = __half22float2(inp_reg[k]); - float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]); - float2 bias_f = __half22float2(bias_cast[input_id]); - inp_f.x += res_f.x + bias_f.x; - inp_f.y += res_f.y + bias_f.y; - inp_reg[k] = __float22half2_rn(inp_f); - - if (preLN) res_add_cast[input_id + row * row_stride] = inp_reg[k]; - sum += inp_f.x + inp_f.y; - input_id += iteration_stride; - k++; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - __shared__ float shr[MAX_WARP_NUM]; - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - float mean = sum / (row_stride << 1); - sum = 0.f; - for (int f = 0; f < k; f++) { - float2 inp_f = __half22float2(inp_reg[f]); - inp_f.x -= mean; - inp_f.y -= mean; - inp_reg[f] = __float22half2_rn(inp_f); - sum += inp_f.x * inp_f.x; - sum += inp_f.y * inp_f.y; - } - for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); - if (g.thread_rank() == 0) shr[gid] = sum; - b.sync(); - if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; - b.sync(); - for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= (row_stride << 1); - sum += epsilon; - sum = __frsqrt_rn(sum); - __half2 variance_h = __float2half2_rn(sum); - const __half2* gamma_cast = reinterpret_cast(gamma); - const __half2* beta_cast = reinterpret_cast(beta); - for (int f = 0; f < k; f++) { - int out_id = f * iteration_stride + id; - inp_reg[f] = inp_reg[f] * variance_h; - inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id]; - norm_cast[out_id + row * row_stride] = inp_reg[f]; - } -#endif -} - -template -void launch_residual_layer_norm(T* norm, - T* res_add, - T* vals, - T* residual, - const T* bias, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - bool preLN, - cudaStream_t stream); - -template <> -void launch_residual_layer_norm(float* norm, - float* res_add, - float* vals, - float* residual, - const float* bias, - const float* gamma, - const float* beta, - float epsilon, - int batch_size, - int hidden_dim, - bool preLN, - cudaStream_t stream) -{ - constexpr int threads = 1024; - - dim3 grid_dim(batch_size); - - dim3 block_dim(threads); - - fused_residual_layer_norm<<>>( - norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim, preLN); -} - -template <> -void launch_residual_layer_norm<__half>(__half* norm, - __half* res_add, - __half* vals, - __half* residual, - const __half* bias, - const __half* gamma, - const __half* beta, - float epsilon, - int batch_size, - int hidden_dim, - bool preLN, - cudaStream_t stream) -{ - constexpr int threads = 1024; - - dim3 grid_dim(batch_size); - dim3 block_dim(threads); - - fused_residual_layer_norm<<>>( - norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim / 2, preLN); -} +#include +#include "custom_cuda_layers.h" + +#include +#include +#include +#include + +#define NORM_REG (MAX_REGISTERS) + +namespace cg = cooperative_groups; + +__global__ void fused_bias_residual_layer_norm(float* output, + const float* vals, + const float* gamma, + const float* beta, + float epsilon, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + int warp_num = iteration_stride >> 5; + + float inp_reg[NORM_REG]; + + int k = 0; + float sum = 0; + int input_id = id; + while (input_id < row_stride) { + inp_reg[k] = vals[input_id + row * row_stride]; + sum += inp_reg[k++]; + input_id += iteration_stride; + } + + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + + __shared__ float shr[MAX_WARP_NUM]; + + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + + float mean = sum / (row_stride); + sum = 0.f; + for (int f = 0; f < k; f++) { + inp_reg[f] -= mean; + sum += inp_reg[f] * inp_reg[f]; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= (row_stride); + sum += epsilon; + sum = __frsqrt_rn(sum); + for (int f = 0; f < k; f++) { + int out_id = f * iteration_stride + id; + inp_reg[f] = inp_reg[f] * sum; + inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id]; + output[out_id + row * row_stride] = inp_reg[f]; + } +} + +__global__ void fused_bias_residual_layer_norm(__half* output, + const __half* vals, + const __half* gamma, + const __half* beta, + float epsilon, + int row_stride) +{ +#if __CUDA_ARCH__ >= 700 + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + int warp_num = iteration_stride >> 5; + + __half2 inp_reg[NORM_REG]; + + const __half2* vals_cast = reinterpret_cast(vals); + __half2* out_cast = reinterpret_cast<__half2*>(output); + + int k = 0; + int input_id = id; + while (input_id < row_stride) { + inp_reg[k++] = vals_cast[input_id + row * row_stride]; + input_id += iteration_stride; + } + float sum = 0; + for (int f = k - 1; f >= 0; f--) { + float2 inp_f = __half22float2(inp_reg[f]); + sum += inp_f.x + inp_f.y; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + __shared__ float shr[MAX_WARP_NUM]; + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + float mean = sum / (row_stride << 1); + sum = 0.f; + for (int f = 0; f < k; f++) { + float2 inp_f = __half22float2(inp_reg[f]); + inp_f.x -= mean; + inp_f.y -= mean; + inp_reg[f] = __float22half2_rn(inp_f); + sum += inp_f.x * inp_f.x; + sum += inp_f.y * inp_f.y; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= (row_stride << 1); + sum += epsilon; + sum = __frsqrt_rn(sum); + __half2 variance_h = __float2half2_rn(sum); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + for (int f = 0; f < k; f++) { + int out_id = f * iteration_stride + id; + inp_reg[f] = inp_reg[f] * variance_h; + inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id]; + out_cast[out_id + row * row_stride] = inp_reg[f]; + } +#endif +} + +template +void launch_layer_norm(T* out, + T* vals, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream); + +template <> +void launch_layer_norm(float* out, + float* vals, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream) +{ + constexpr int threads = 1024; + + dim3 grid_dim(batch_size); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + out, vals, gamma, beta, epsilon, hidden_dim); +} + +template <> +void launch_layer_norm<__half>(__half* out, + __half* vals, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream) +{ + constexpr int threads = 1024; + + dim3 grid_dim(batch_size); + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + out, vals, gamma, beta, epsilon, hidden_dim / 2); +} + +__global__ void fused_residual_layer_norm(float* norm, + float* res_add, + float* vals, + float* residual, + const float* bias, + const float* gamma, + const float* beta, + float epsilon, + int row_stride, + bool preLN) +{ + int iteration_stride = blockDim.x; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + int warp_num = iteration_stride >> 5; + + float inp_reg[NORM_REG]; + + int k = 0; + int input_id = id; + + float sum = 0; + while (input_id < row_stride) { + inp_reg[k] = vals[input_id + row * row_stride]; + float res_f = (residual[input_id + row * row_stride]); + float bias_f = (bias[input_id]); + inp_reg[k] += res_f + bias_f; + if (preLN) res_add[input_id + row * row_stride] = inp_reg[k]; + sum += inp_reg[k++]; + input_id += iteration_stride; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + + __shared__ float shr[MAX_WARP_NUM]; + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + float mean = sum / (row_stride); + sum = 0.f; + for (int f = 0; f < k; f++) { + inp_reg[f] -= mean; + sum += inp_reg[f] * inp_reg[f]; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= (row_stride); + sum += epsilon; + sum = __frsqrt_rn(sum); + + for (int f = 0; f < k; f++) { + int out_id = f * iteration_stride + id; + inp_reg[f] = inp_reg[f] * sum; + inp_reg[f] = inp_reg[f] * gamma[out_id] + beta[out_id]; + norm[out_id + row * row_stride] = inp_reg[f]; + } +} + +__global__ void fused_residual_layer_norm(__half* norm, + __half* res_add, + __half* vals, + __half* residual, + const __half* bias, + const __half* gamma, + const __half* beta, + float epsilon, + int row_stride, + bool preLN) +{ +#if __CUDA_ARCH__ >= 700 + int iteration_stride = blockDim.x; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + int warp_num = iteration_stride >> 5; + + __half2 inp_reg[NORM_REG]; + + __half2* vals_cast = reinterpret_cast<__half2*>(vals); + __half2* norm_cast = reinterpret_cast<__half2*>(norm); + __half2* res_add_cast = reinterpret_cast<__half2*>(res_add); + __half2* residual_cast = reinterpret_cast<__half2*>(residual); + const __half2* bias_cast = reinterpret_cast(bias); + + int k = 0; + int input_id = id; + + float sum = 0; + while (input_id < row_stride) { + inp_reg[k] = vals_cast[input_id + row * row_stride]; + float2 inp_f = __half22float2(inp_reg[k]); + float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]); + float2 bias_f = __half22float2(bias_cast[input_id]); + inp_f.x += res_f.x + bias_f.x; + inp_f.y += res_f.y + bias_f.y; + inp_reg[k] = __float22half2_rn(inp_f); + + if (preLN) res_add_cast[input_id + row * row_stride] = inp_reg[k]; + sum += inp_f.x + inp_f.y; + input_id += iteration_stride; + k++; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + __shared__ float shr[MAX_WARP_NUM]; + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + float mean = sum / (row_stride << 1); + sum = 0.f; + for (int f = 0; f < k; f++) { + float2 inp_f = __half22float2(inp_reg[f]); + inp_f.x -= mean; + inp_f.y -= mean; + inp_reg[f] = __float22half2_rn(inp_f); + sum += inp_f.x * inp_f.x; + sum += inp_f.y * inp_f.y; + } + for (int i = 1; i < 32; i *= 2) sum += g.shfl_down(sum, i); + if (g.thread_rank() == 0) shr[gid] = sum; + b.sync(); + if (g.thread_rank() < (warp_num)) sum = shr[g.thread_rank()]; + b.sync(); + for (int i = 1; i < (warp_num); i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= (row_stride << 1); + sum += epsilon; + sum = __frsqrt_rn(sum); + __half2 variance_h = __float2half2_rn(sum); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + for (int f = 0; f < k; f++) { + int out_id = f * iteration_stride + id; + inp_reg[f] = inp_reg[f] * variance_h; + inp_reg[f] = inp_reg[f] * gamma_cast[out_id] + beta_cast[out_id]; + norm_cast[out_id + row * row_stride] = inp_reg[f]; + } +#endif +} + +template +void launch_residual_layer_norm(T* norm, + T* res_add, + T* vals, + T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + bool preLN, + cudaStream_t stream); + +template <> +void launch_residual_layer_norm(float* norm, + float* res_add, + float* vals, + float* residual, + const float* bias, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + bool preLN, + cudaStream_t stream) +{ + constexpr int threads = 1024; + + dim3 grid_dim(batch_size); + + dim3 block_dim(threads); + + fused_residual_layer_norm<<>>( + norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim, preLN); +} + +template <> +void launch_residual_layer_norm<__half>(__half* norm, + __half* res_add, + __half* vals, + __half* residual, + const __half* bias, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + bool preLN, + cudaStream_t stream) +{ + constexpr int threads = 1024; + + dim3 grid_dim(batch_size); + dim3 block_dim(threads); + + fused_residual_layer_norm<<>>( + norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim / 2, preLN); +} diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 1ebadaeb53b4..b587b6233173 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -1,630 +1,630 @@ - - -#include -#include -#include -#include "context.h" -#include "cublas_wrappers.h" -#include "custom_cuda_layers.h" - -std::array gemm_algos = std::array({99, 99, 99}); - -template -at::Tensor ds_softmax(at::Tensor& attn_scores, - at::Tensor& attn_mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size) -{ - auto attn_scores_c = attn_scores.contiguous(); - int bsz = attn_scores_c.size(0); - int seq_len = attn_scores_c.size(2); - int soft_len = attn_scores_c.size(3); - int heads = attn_scores_c.size(1); - launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(), - (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), - triangular, - recompute, - local_attention, - window_size, - bsz, - heads, - seq_len, - soft_len, - 1.0, - at::cuda::getCurrentCUDAStream()); - - return attn_scores_c; -} - -template -void attention_unfused(at::Tensor& prev_key_cont, - at::Tensor& query_cont, - at::Tensor& attn_mask, - at::Tensor& prev_value_cont, - at::Tensor& output, - int& bsz, - int& seq_len, - int& soft_len, - int& heads, - float& norm_factor, - bool triangular, - bool recompute, - bool local_attention, - int window_size) -{ - auto options = at::TensorOptions() - .dtype(query_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - float alpha = norm_factor; - float gemm_beta = 0.0; - auto attn_score = at::zeros({bsz, heads, seq_len, soft_len}, options); - int k = prev_value_cont.size(2) / heads; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), - soft_len, - seq_len, - k, - &alpha, - &gemm_beta, - (T*)prev_key_cont.data_ptr(), - (T*)query_cont.data_ptr(), - (T*)attn_score.data_ptr(), - CUBLAS_OP_N, - CUBLAS_OP_N, - soft_len * k, - seq_len * k, - seq_len * soft_len, - bsz * heads, - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - attn_score = - ds_softmax(attn_score, attn_mask, triangular, recompute, local_attention, window_size); - alpha = 1.0; - cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), - k, - seq_len, - soft_len, - &alpha, - &gemm_beta, - (T*)prev_value_cont.data_ptr(), - (T*)attn_score.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_OP_N, - CUBLAS_OP_N, - soft_len * k, - seq_len * soft_len, - seq_len * k, - bsz * heads, - CUBLAS_GEMM_DEFAULT_TENSOR_OP); -} - -template -std::vector ds_softmax_context(at::Tensor& query, - at::Tensor& prev_key, - at::Tensor& new_key, - at::Tensor& attn_mask, - at::Tensor& prev_value, - at::Tensor& new_value, - int heads, - float norm_factor, - bool merging, - bool triangular, - bool local_attention, - int window_size, - bool no_masking) -{ - auto query_cont = query.contiguous(); - auto prev_key_cont = prev_key.contiguous(); - auto prev_value_cont = prev_value.contiguous(); - - int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0); - - // Attn_Score [ batch Head Sequence-length Softmax-length] - - int bsz = query_cont.size(0); - int seq_len = query_cont.size(1); - int soft_len = prev_value.size(1); - - auto options = at::TensorOptions() - .dtype(query_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = - at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options); - attention_unfused(prev_key_cont, - query_cont, - attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()), - prev_value_cont, - output, - bsz, - seq_len, - soft_len, - heads, - norm_factor, - (triangular && (new_size == 0)), - (new_size == 0), - local_attention, - window_size); - - return {output, prev_key, prev_value}; -} - -template -at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - - int bsz = input_cont.size(0) * input_cont.size(1); - int intermediate_size = input_cont.size(2); - - launch_bias_gelu((T*)input_cont.data_ptr(), - (T*)bias.data_ptr(), - intermediate_size, - bsz, - Context::Instance().GetCurrentStream()); - return input_cont; -} - -template -at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - auto residual_cont = residual.contiguous(); - - int bsz = input_cont.size(0) * input_cont.size(1); - - launch_bias_residual((T*)input_cont.data_ptr(), - (T*)residual_cont.data_ptr(), - (T*)bias.data_ptr(), - bsz, - input_cont.size(2), - Context::Instance().GetCurrentStream()); - return input_cont; -} - -template -at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon) -{ - int bsz = input_cont.size(0) * input_cont.size(1); - auto inp_norm = at::empty_like(input_cont); - launch_layer_norm((T*)inp_norm.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)gamma.data_ptr(), - (T*)betta.data_ptr(), - epsilon, - bsz, - input_cont.size(2), - Context::Instance().GetCurrentStream()); - return inp_norm; -} - -template -void qkv_unfused_cublas(at::Tensor& output, - at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool add_bias) -{ - auto inp_norm = ds_layernorm(input, gamma, beta, epsilon); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - int bsz = input.size(0) * input.size(1); - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)inp_norm.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - if (add_bias) - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); -} - -template -at::Tensor ds_qkv_gemm(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool add_bias) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - int bsz = input_cont.size(0) * input_cont.size(1); - qkv_unfused_cublas(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias); - - return output; -} - -template -void quantized_gemm(at::Tensor& output, - at::Tensor& input, - at::Tensor& weight, - at::Tensor& qscale, - int groups, - int merge_count) -{ - int bsz = input.size(0) * input.size(1); - auto options = at::TensorOptions() - .dtype(input.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); - - launch_dequantize((T*)weight16.data_ptr(), - (int8_t*)weight.data_ptr(), - (float*)qscale.data_ptr(), - weight.size(1), - weight.size(0), - groups, - merge_count, - Context::Instance().GetCurrentStream()); - - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight16.data_ptr(), - (T*)input.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); -} - -template -at::Tensor ds_qkv_gemm_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - at::Tensor& q_scale, - int groups, - bool add_bias) -{ - int bsz = input.size(0) * input.size(1); - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - auto inp_norm = ds_layernorm(input_cont, gamma, beta, epsilon); - - quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); - if (add_bias) - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); - - return output; -} - -template -at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - int bsz = input_cont.size(0) * input_cont.size(1); - - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input_cont.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); - - return output; -} - -template -at::Tensor ds_linear_layer_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& q_scale, - int groups) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - int bsz = input_cont.size(0) * input_cont.size(1); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - quantized_gemm(output, input_cont, weight, q_scale, groups, 0); - launch_bias_add((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); - return output; -} - -template -at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - int bsz = input_cont.size(0) * input_cont.size(1); - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input_cont.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - - return output; -} - -template -at::Tensor ds_vector_matmul_int8(at::Tensor& input, - at::Tensor& weight, - at::Tensor& q_scale, - int groups, - int merge_count) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - quantized_gemm(output, input_cont, weight, q_scale, groups, merge_count); - return output; -} - -template -void mlp_unfused_cublas(at::Tensor& output, - at::Tensor& residual_add, - at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool preLayerNorm) -{ - int bsz = input.size(0) * input.size(1); - auto inp_norm = preLayerNorm ? at::empty_like(input) : residual_add; - - launch_residual_layer_norm((T*)inp_norm.data_ptr(), - (T*)residual_add.data_ptr(), - (T*)input.data_ptr(), - (T*)residual.data_ptr(), - (T*)input_bias.data_ptr(), - (T*)gamma.data_ptr(), - (T*)beta.data_ptr(), - epsilon, - bsz, - input.size(2), - preLayerNorm, - Context::Instance().GetCurrentStream()); - - float alpha = (T)1.0; - float gemm_beta = (T)0.0; - cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); - cublas_gemm_ex(Context::Instance().GetCublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - weight.size(1), - bsz, - input.size(2), - &alpha, - &gemm_beta, - (T*)weight.data_ptr(), - (T*)inp_norm.data_ptr(), - (T*)output.data_ptr(), - CUBLAS_GEMM_DEFAULT_TENSOR_OP); - - launch_bias_gelu((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); -} -template -std::vector ds_mlp_gemm(at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool preLayerNorm) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - auto residual_add = at::empty_like(input_cont); - int bsz = input_cont.size(0) * input_cont.size(1); - - mlp_unfused_cublas(output, - residual_add, - input, - residual, - input_bias, - weight, - bias, - gamma, - beta, - epsilon, - preLayerNorm); - - return {output, residual_add}; -} - -template -std::vector ds_mlp_gemm_int8(at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - at::Tensor& q_scale, - int groups, - bool preLayerNorm) -{ - auto input_cont = input.contiguous(); - auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) - .layout(at::kStrided) - .device(at::kCUDA) - .requires_grad(false); - - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - - int bsz = input_cont.size(0) * input_cont.size(1); - auto inp_norm = at::empty_like(input_cont); - - auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm); - // computing the blocking across K dimension - launch_residual_layer_norm((T*)inp_norm.data_ptr(), - (T*)residual_add.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)residual.data_ptr(), - (T*)input_bias.data_ptr(), - (T*)gamma.data_ptr(), - (T*)beta.data_ptr(), - epsilon, - bsz, - input_cont.size(2), - preLayerNorm, - Context::Instance().GetCurrentStream()); - - quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); - launch_bias_gelu((T*)output.data_ptr(), - (T*)bias.data_ptr(), - weight.size(1), - bsz, - Context::Instance().GetCurrentStream()); - - return {output, residual_add}; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ - m.def("softmax_fp32", &ds_softmax, "DeepSpeed SoftMax with fp32 (CUDA)"); - m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)"); - m.def( - "softmax_context_fp32", &ds_softmax_context, "DeepSpeed attention with fp32 (CUDA)"); - m.def("softmax_context_fp16", - &ds_softmax_context<__half>, - "DeepSpeed attention with fp32 (CUDA)"); - m.def("bias_gelu_fp32", &ds_bias_gelu, "DeepSpeed Gelu with fp32 (CUDA)"); - m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)"); - m.def("bias_residual_fp32", - &ds_bias_residual, - "DeepSpeed residual-bias add with fp32 (CUDA)"); - m.def("bias_residual_fp16", - &ds_bias_residual<__half>, - "DeepSpeed residual-bias add with fp32 (CUDA)"); - m.def("layer_norm_fp32", &ds_layernorm, "DeepSpeed layer-norm with fp32 (CUDA)"); - m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)"); - m.def("qkv_gemm_fp32", &ds_qkv_gemm, "DeepSpeed qkv gemm with fp32 (CUDA)"); - m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)"); - m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)"); - m.def("mlp_gemm_fp32", &ds_mlp_gemm, "DeepSpeed mlp with fp32 (CUDA)"); - m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)"); - m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)"); - m.def("vector_matmul_fp32", &ds_vector_matmul, "DeepSpeed vector-MM with fp32 (CUDA)"); - m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); - m.def("vector_matmul_int8", - &ds_vector_matmul_int8<__half>, - "DeepSpeed vector-MM with int8 (CUDA)"); - m.def("linear_layer_fp32", &ds_linear_layer, "DeepSpeed linear_layer with fp32 (CUDA)"); - m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)"); - m.def("linear_layer_int8", - &ds_linear_layer_int8<__half>, - "DeepSpeed linear_layer with int8 (CUDA)"); -} + + +#include +#include +#include +#include "context.h" +#include "cublas_wrappers.h" +#include "custom_cuda_layers.h" + +std::array gemm_algos = std::array({99, 99, 99}); + +template +at::Tensor ds_softmax(at::Tensor& attn_scores, + at::Tensor& attn_mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size) +{ + auto attn_scores_c = attn_scores.contiguous(); + int bsz = attn_scores_c.size(0); + int seq_len = attn_scores_c.size(2); + int soft_len = attn_scores_c.size(3); + int heads = attn_scores_c.size(1); + launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(), + (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + 1.0, + at::cuda::getCurrentCUDAStream()); + + return attn_scores_c; +} + +template +void attention_unfused(at::Tensor& prev_key_cont, + at::Tensor& query_cont, + at::Tensor& attn_mask, + at::Tensor& prev_value_cont, + at::Tensor& output, + int& bsz, + int& seq_len, + int& soft_len, + int& heads, + float& norm_factor, + bool triangular, + bool recompute, + bool local_attention, + int window_size) +{ + auto options = at::TensorOptions() + .dtype(query_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + float alpha = norm_factor; + float gemm_beta = 0.0; + auto attn_score = at::zeros({bsz, heads, seq_len, soft_len}, options); + int k = prev_value_cont.size(2) / heads; + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + soft_len, + seq_len, + k, + &alpha, + &gemm_beta, + (T*)prev_key_cont.data_ptr(), + (T*)query_cont.data_ptr(), + (T*)attn_score.data_ptr(), + CUBLAS_OP_N, + CUBLAS_OP_N, + soft_len * k, + seq_len * k, + seq_len * soft_len, + bsz * heads, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + attn_score = + ds_softmax(attn_score, attn_mask, triangular, recompute, local_attention, window_size); + alpha = 1.0; + cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + k, + seq_len, + soft_len, + &alpha, + &gemm_beta, + (T*)prev_value_cont.data_ptr(), + (T*)attn_score.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_OP_N, + CUBLAS_OP_N, + soft_len * k, + seq_len * soft_len, + seq_len * k, + bsz * heads, + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +template +std::vector ds_softmax_context(at::Tensor& query, + at::Tensor& prev_key, + at::Tensor& new_key, + at::Tensor& attn_mask, + at::Tensor& prev_value, + at::Tensor& new_value, + int heads, + float norm_factor, + bool merging, + bool triangular, + bool local_attention, + int window_size, + bool no_masking) +{ + auto query_cont = query.contiguous(); + auto prev_key_cont = prev_key.contiguous(); + auto prev_value_cont = prev_value.contiguous(); + + int new_size = (new_value.sizes().size() > 1 ? new_value.size(1) : 0); + + // Attn_Score [ batch Head Sequence-length Softmax-length] + + int bsz = query_cont.size(0); + int seq_len = query_cont.size(1); + int soft_len = prev_value.size(1); + + auto options = at::TensorOptions() + .dtype(query_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = + at::empty({prev_value.size(0), heads, seq_len, prev_value.size(2) / heads}, options); + attention_unfused(prev_key_cont, + query_cont, + attn_mask, //(no_masking ? nullptr : (T*)attn_mask.data_ptr()), + prev_value_cont, + output, + bsz, + seq_len, + soft_len, + heads, + norm_factor, + (triangular && (new_size == 0)), + (new_size == 0), + local_attention, + window_size); + + return {output, prev_key, prev_value}; +} + +template +at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + int intermediate_size = input_cont.size(2); + + launch_bias_gelu((T*)input_cont.data_ptr(), + (T*)bias.data_ptr(), + intermediate_size, + bsz, + Context::Instance().GetCurrentStream()); + return input_cont; +} + +template +at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + auto residual_cont = residual.contiguous(); + + int bsz = input_cont.size(0) * input_cont.size(1); + + launch_bias_residual((T*)input_cont.data_ptr(), + (T*)residual_cont.data_ptr(), + (T*)bias.data_ptr(), + bsz, + input_cont.size(2), + Context::Instance().GetCurrentStream()); + return input_cont; +} + +template +at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& betta, float epsilon) +{ + int bsz = input_cont.size(0) * input_cont.size(1); + auto inp_norm = at::empty_like(input_cont); + launch_layer_norm((T*)inp_norm.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)gamma.data_ptr(), + (T*)betta.data_ptr(), + epsilon, + bsz, + input_cont.size(2), + Context::Instance().GetCurrentStream()); + return inp_norm; +} + +template +void qkv_unfused_cublas(at::Tensor& output, + at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool add_bias) +{ + auto inp_norm = ds_layernorm(input, gamma, beta, epsilon); + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + int bsz = input.size(0) * input.size(1); + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)inp_norm.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + if (add_bias) + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); +} + +template +at::Tensor ds_qkv_gemm(at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool add_bias) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int bsz = input_cont.size(0) * input_cont.size(1); + qkv_unfused_cublas(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias); + + return output; +} + +template +void quantized_gemm(at::Tensor& output, + at::Tensor& input, + at::Tensor& weight, + at::Tensor& qscale, + int groups, + int merge_count) +{ + int bsz = input.size(0) * input.size(1); + auto options = at::TensorOptions() + .dtype(input.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + auto weight16 = at::empty({weight.size(0), weight.size(1)}, options); + + launch_dequantize((T*)weight16.data_ptr(), + (int8_t*)weight.data_ptr(), + (float*)qscale.data_ptr(), + weight.size(1), + weight.size(0), + groups, + merge_count, + Context::Instance().GetCurrentStream()); + + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight16.data_ptr(), + (T*)input.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +template +at::Tensor ds_qkv_gemm_int8(at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + at::Tensor& q_scale, + int groups, + bool add_bias) +{ + int bsz = input.size(0) * input.size(1); + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + auto inp_norm = ds_layernorm(input_cont, gamma, beta, epsilon); + + quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); + if (add_bias) + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + + return output; +} + +template +at::Tensor ds_linear_layer(at::Tensor& input, at::Tensor& weight, at::Tensor& bias) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int bsz = input_cont.size(0) * input_cont.size(1); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input_cont.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + + return output; +} + +template +at::Tensor ds_linear_layer_int8(at::Tensor& input, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& q_scale, + int groups) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + int bsz = input_cont.size(0) * input_cont.size(1); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + quantized_gemm(output, input_cont, weight, q_scale, groups, 0); + launch_bias_add((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + return output; +} + +template +at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + int bsz = input_cont.size(0) * input_cont.size(1); + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input_cont.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + return output; +} + +template +at::Tensor ds_vector_matmul_int8(at::Tensor& input, + at::Tensor& weight, + at::Tensor& q_scale, + int groups, + int merge_count) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + quantized_gemm(output, input_cont, weight, q_scale, groups, merge_count); + return output; +} + +template +void mlp_unfused_cublas(at::Tensor& output, + at::Tensor& residual_add, + at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm) +{ + int bsz = input.size(0) * input.size(1); + auto inp_norm = preLayerNorm ? at::empty_like(input) : residual_add; + + launch_residual_layer_norm((T*)inp_norm.data_ptr(), + (T*)residual_add.data_ptr(), + (T*)input.data_ptr(), + (T*)residual.data_ptr(), + (T*)input_bias.data_ptr(), + (T*)gamma.data_ptr(), + (T*)beta.data_ptr(), + epsilon, + bsz, + input.size(2), + preLayerNorm, + Context::Instance().GetCurrentStream()); + + float alpha = (T)1.0; + float gemm_beta = (T)0.0; + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublas_gemm_ex(Context::Instance().GetCublasHandle(), + CUBLAS_OP_N, + CUBLAS_OP_N, + weight.size(1), + bsz, + input.size(2), + &alpha, + &gemm_beta, + (T*)weight.data_ptr(), + (T*)inp_norm.data_ptr(), + (T*)output.data_ptr(), + CUBLAS_GEMM_DEFAULT_TENSOR_OP); + + launch_bias_gelu((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); +} +template +std::vector ds_mlp_gemm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + auto residual_add = at::empty_like(input_cont); + int bsz = input_cont.size(0) * input_cont.size(1); + + mlp_unfused_cublas(output, + residual_add, + input, + residual, + input_bias, + weight, + bias, + gamma, + beta, + epsilon, + preLayerNorm); + + return {output, residual_add}; +} + +template +std::vector ds_mlp_gemm_int8(at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + at::Tensor& q_scale, + int groups, + bool preLayerNorm) +{ + auto input_cont = input.contiguous(); + auto options = at::TensorOptions() + .dtype(input_cont.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); + + int bsz = input_cont.size(0) * input_cont.size(1); + auto inp_norm = at::empty_like(input_cont); + + auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm); + // computing the blocking across K dimension + launch_residual_layer_norm((T*)inp_norm.data_ptr(), + (T*)residual_add.data_ptr(), + (T*)input_cont.data_ptr(), + (T*)residual.data_ptr(), + (T*)input_bias.data_ptr(), + (T*)gamma.data_ptr(), + (T*)beta.data_ptr(), + epsilon, + bsz, + input_cont.size(2), + preLayerNorm, + Context::Instance().GetCurrentStream()); + + quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); + launch_bias_gelu((T*)output.data_ptr(), + (T*)bias.data_ptr(), + weight.size(1), + bsz, + Context::Instance().GetCurrentStream()); + + return {output, residual_add}; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("softmax_fp32", &ds_softmax, "DeepSpeed SoftMax with fp32 (CUDA)"); + m.def("softmax_fp16", &ds_softmax<__half>, "DeepSpeed SoftMax with fp32 (CUDA)"); + m.def( + "softmax_context_fp32", &ds_softmax_context, "DeepSpeed attention with fp32 (CUDA)"); + m.def("softmax_context_fp16", + &ds_softmax_context<__half>, + "DeepSpeed attention with fp32 (CUDA)"); + m.def("bias_gelu_fp32", &ds_bias_gelu, "DeepSpeed Gelu with fp32 (CUDA)"); + m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)"); + m.def("bias_residual_fp32", + &ds_bias_residual, + "DeepSpeed residual-bias add with fp32 (CUDA)"); + m.def("bias_residual_fp16", + &ds_bias_residual<__half>, + "DeepSpeed residual-bias add with fp32 (CUDA)"); + m.def("layer_norm_fp32", &ds_layernorm, "DeepSpeed layer-norm with fp32 (CUDA)"); + m.def("layer_norm_fp16", &ds_layernorm<__half>, "DeepSpeed layer-norm with fp16 (CUDA)"); + m.def("qkv_gemm_fp32", &ds_qkv_gemm, "DeepSpeed qkv gemm with fp32 (CUDA)"); + m.def("qkv_gemm_fp16", &ds_qkv_gemm<__half>, "DeepSpeed qkv gemm with fp16 (CUDA)"); + m.def("qkv_gemm_int8", &ds_qkv_gemm_int8<__half>, "DeepSpeed qkv gemm with int8 (CUDA)"); + m.def("mlp_gemm_fp32", &ds_mlp_gemm, "DeepSpeed mlp with fp32 (CUDA)"); + m.def("mlp_gemm_fp16", &ds_mlp_gemm<__half>, "DeepSpeed mlp with fp16 (CUDA)"); + m.def("mlp_gemm_int8", &ds_mlp_gemm_int8<__half>, "DeepSpeed mlp with int8 (CUDA)"); + m.def("vector_matmul_fp32", &ds_vector_matmul, "DeepSpeed vector-MM with fp32 (CUDA)"); + m.def("vector_matmul_fp16", &ds_vector_matmul<__half>, "DeepSpeed vector-MM with fp16 (CUDA)"); + m.def("vector_matmul_int8", + &ds_vector_matmul_int8<__half>, + "DeepSpeed vector-MM with int8 (CUDA)"); + m.def("linear_layer_fp32", &ds_linear_layer, "DeepSpeed linear_layer with fp32 (CUDA)"); + m.def("linear_layer_fp16", &ds_linear_layer<__half>, "DeepSpeed linear_layer with fp16 (CUDA)"); + m.def("linear_layer_int8", + &ds_linear_layer_int8<__half>, + "DeepSpeed linear_layer with int8 (CUDA)"); +} diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index 950ae6aeaafb..774e7ce6c2a7 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -1,432 +1,432 @@ -#include -#include "custom_cuda_layers.h" - -#include -#include -#include -#include - -#define ATTN_THREADS 1024 -#define MAX_REG_SIZE 8 - -#define minus_infinity -10000.0 - -void CheckCudaErrorAux(const char* file, unsigned line) -{ - cudaError_t err = cudaGetLastError(); - if (err == cudaSuccess) return; - std::cerr << cudaGetErrorString(err) << "(" << err << ") at " << file << ":" << line - << std::endl; - throw std::runtime_error("CUDA ERROR!!!\n"); -} - -#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__) - -namespace cg = cooperative_groups; - -__global__ void attn_softmax_v2(__half* vals, - __half* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int total_count, - int heads, - int sequence_length, - int num_seq, - float scale, - int iterations, - int reduceWidth) -{ -#if __CUDA_ARCH__ >= 700 - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - float2 low_data[MAX_REG_SIZE]; - float2 high_data[MAX_REG_SIZE]; - - __half2 h_scale = __float2half2_rn(scale); - - int wid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - int warp_num = blockDim.x >> 5; - - int reduce_blocks = reduceWidth >> 5; - int seq_lane = threadIdx.x % reduceWidth; - - __shared__ float partialSum[MAX_WARP_NUM]; - - int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); - - if (iter_offset < total_count) { - vals += (iter_offset * sequence_length); - - int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); - int seq_id = iter_offset % num_seq; - int seq_id4 = seq_id >> 2; - - int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); - int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) - ? (real_seq_id >> 2) - (window_size >> 2) - : 0; - int window_stride = - (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; - - float max_val = minus_infinity; - - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane << 2); - if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && - data_id < sequence_length) { - if ((sequence_length - data_id) >= 4) { - low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) - : minus_infinity; - low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && - (data_id + 1) > window_stride) - ? __half2float(vals[data_id + 1]) - : minus_infinity; - high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) && - (data_id + 2) > window_stride) - ? __half2float(vals[data_id + 2]) - : minus_infinity; - high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) && - (data_id + 3) > window_stride) - ? __half2float(vals[data_id + 3]) - : minus_infinity; - if (mask && recompute) { - low_data[i].x += __half2float(mask[data_id + mask_offset]); - low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); - high_data[i].x += __half2float(mask[data_id + mask_offset + 2]); - high_data[i].y += __half2float(mask[data_id + mask_offset + 3]); - } - } else { - low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) - : minus_infinity; - low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) && - (data_id + 1) > window_stride) && - (data_id + 1) < sequence_length) - ? __half2float(vals[data_id + 1]) - : minus_infinity; - high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) && - (data_id + 2) > window_stride) && - (data_id + 2) < sequence_length) - ? __half2float(vals[data_id + 2]) - : minus_infinity; - high_data[i].y = minus_infinity; - if (mask && recompute) { - low_data[i].x += __half2float(mask[data_id + mask_offset]); - if ((data_id + 1) < sequence_length) - low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); - if ((data_id + 2) < sequence_length) - high_data[i].x += __half2float(mask[data_id + mask_offset + 2]); - } - } - // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); - max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); - max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); - max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); - max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); - } else { - low_data[i].x = minus_infinity; - low_data[i].y = minus_infinity; - high_data[i].x = minus_infinity; - high_data[i].y = minus_infinity; - } - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = max_val; - b.sync(); - - if (lane < warp_num) max_val = partialSum[lane]; - - b.sync(); - - for (int i = 1; i < reduce_blocks; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); - } - float sum = 0; - for (int i = 0; i < iterations; i++) { - low_data[i].x = __expf(low_data[i].x - max_val); - low_data[i].y = __expf(low_data[i].y - max_val); - high_data[i].x = __expf(high_data[i].x - max_val); - high_data[i].y = __expf(high_data[i].y - max_val); - - sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = sum; - b.sync(); - - if (lane < warp_num) sum = partialSum[lane]; - - b.sync(); - - for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } - - sum = g.shfl(sum, threadIdx.x / WARP_SIZE); - } - sum += 1e-6; - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane << 2); - - if (data_id < sequence_length) { - if ((sequence_length - data_id) >= 4) { - vals[data_id] = low_data[i].x / sum; - vals[data_id + 1] = low_data[i].y / sum; - vals[data_id + 2] = high_data[i].x / sum; - vals[data_id + 3] = high_data[i].y / sum; - } else { - vals[data_id] = low_data[i].x / sum; - if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum; - if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum; - } - } - } - } -#endif -} - -__global__ void attn_softmax_v2(float* vals, - float* attn_mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int total_count, - int heads, - int sequence_length, - int num_seq, - float scale, - int iterations, - int reduceWidth) -{ - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - float4 data[MAX_REG_SIZE]; - - int wid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - int warp_num = blockDim.x >> 5; - - int reduce_blocks = reduceWidth >> 5; - int seq_lane = threadIdx.x % reduceWidth; - - __shared__ float partialSum[MAX_WARP_NUM]; - - int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); - if (iter_offset < total_count) { - vals += (iter_offset * sequence_length); - - int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); - int seq_id = iter_offset % num_seq; - int seq_id4 = seq_id >> 2; - - int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); - int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) - ? (real_seq_id >> 2) - (window_size >> 2) - : 0; - int window_stride = - (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; - - float max_val = minus_infinity; - - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane << 2); - if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && - data_id < sequence_length) { - if ((sequence_length - data_id) >= 4) { - data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity); - data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && - (data_id + 1) > window_stride) - ? vals[data_id + 1] - : minus_infinity; - data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) && - (data_id + 2) > window_stride) - ? vals[data_id + 2] - : minus_infinity; - data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) && - (data_id + 3) > window_stride) - ? vals[data_id + 3] - : minus_infinity; - if (attn_mask && recompute) { - data[i].x += attn_mask[data_id + mask_offset]; - data[i].y += attn_mask[data_id + mask_offset + 1]; - data[i].z += attn_mask[data_id + mask_offset + 2]; - data[i].w += attn_mask[data_id + mask_offset + 3]; - } - } else { - data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity; - data[i].y = (((!triangular || (data_id + 1) <= seq_id)) && - (data_id + 1) > window_stride && (data_id + 1) < sequence_length) - ? (vals[data_id + 1]) - : minus_infinity; - data[i].z = (((!triangular || (data_id + 2) <= seq_id)) && - (data_id + 2) > window_stride && (data_id + 2) < sequence_length) - ? (vals[data_id + 2]) - : minus_infinity; - data[i].w = minus_infinity; - if (attn_mask && recompute) { - data[i].x += attn_mask[data_id + mask_offset]; - if ((data_id + 1) < sequence_length) - data[i].y += attn_mask[data_id + mask_offset + 1]; - if ((data_id + 2) < sequence_length) - data[i].z += attn_mask[data_id + mask_offset + 2]; - } - } - max_val = (data[i].x > max_val ? data[i].x : max_val); - max_val = (data[i].y > max_val ? data[i].y : max_val); - max_val = (data[i].z > max_val ? data[i].z : max_val); - max_val = (data[i].w > max_val ? data[i].w : max_val); - } else { - data[i].x = minus_infinity; - data[i].y = minus_infinity; - data[i].z = minus_infinity; - data[i].w = minus_infinity; - } - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = max_val; - b.sync(); - - if (lane < warp_num) max_val = partialSum[lane]; - - b.sync(); - - for (int i = 1; i < reduce_blocks; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); - } - - float sum = 0; - for (int i = 0; i < iterations; i++) { - data[i].x = __expf(data[i].x - max_val); - data[i].y = __expf(data[i].y - max_val); - data[i].z = __expf(data[i].z - max_val); - data[i].w = __expf(data[i].w - max_val); - - sum += (data[i].x + data[i].y + data[i].z + data[i].w); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); - - if (reduceWidth > WARP_SIZE) { - if (lane == 0) partialSum[wid] = sum; - b.sync(); - - if (lane < warp_num) sum = partialSum[lane]; - - b.sync(); - - for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } - - sum = g.shfl(sum, threadIdx.x / WARP_SIZE); - } - sum += 1e-6; - - for (int i = 0; i < iterations; i++) { - int data_id = i * (reduceWidth << 2) + (seq_lane << 2); - - if (data_id < sequence_length) { - if ((sequence_length - data_id) >= 4) { - vals[data_id] = data[i].x / sum; - vals[data_id + 1] = data[i].y / sum; - vals[data_id + 2] = data[i].z / sum; - vals[data_id + 3] = data[i].w / sum; - } else { - vals[data_id] = data[i].x / sum; - if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum; - if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum; - } - } - } - } -} - -template -void launch_attn_softmax_v2(T* vals, - T* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - float scale, - cudaStream_t stream) -{ - int total_count = batch_size * heads * num_seq; - dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1); - dim3 block_dim(ATTN_THREADS); - - const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE; - const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1; - - if (sequence_length <= 32768) - attn_softmax_v2<<>>( - vals, - mask, - triangular, - recompute, - local_attention, - window_size, - total_count, - (triangular ? (heads * batch_size) : heads), - sequence_length, - num_seq, - scale, - iterations, - reduce_width); - else - throw std::runtime_error("Unsupport Seq_Length!"); -} - -template void launch_attn_softmax_v2(float* vals, - float* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - float scale, - cudaStream_t stream); -template void launch_attn_softmax_v2(__half* vals, - __half* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - float scale, - cudaStream_t stream); +#include +#include "custom_cuda_layers.h" + +#include +#include +#include +#include + +#define ATTN_THREADS 1024 +#define MAX_REG_SIZE 8 + +#define minus_infinity -10000.0 + +void CheckCudaErrorAux(const char* file, unsigned line) +{ + cudaError_t err = cudaGetLastError(); + if (err == cudaSuccess) return; + std::cerr << cudaGetErrorString(err) << "(" << err << ") at " << file << ":" << line + << std::endl; + throw std::runtime_error("CUDA ERROR!!!\n"); +} + +#define CUDA_CHECK_ERROR() CheckCudaErrorAux(__FILE__, __LINE__) + +namespace cg = cooperative_groups; + +__global__ void attn_softmax_v2(__half* vals, + __half* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int total_count, + int heads, + int sequence_length, + int num_seq, + float scale, + int iterations, + int reduceWidth) +{ +#if __CUDA_ARCH__ >= 700 + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + float2 low_data[MAX_REG_SIZE]; + float2 high_data[MAX_REG_SIZE]; + + __half2 h_scale = __float2half2_rn(scale); + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + + int reduce_blocks = reduceWidth >> 5; + int seq_lane = threadIdx.x % reduceWidth; + + __shared__ float partialSum[MAX_WARP_NUM]; + + int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); + + if (iter_offset < total_count) { + vals += (iter_offset * sequence_length); + + int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); + int seq_id = iter_offset % num_seq; + int seq_id4 = seq_id >> 2; + + int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = + (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && + data_id < sequence_length) { + if ((sequence_length - data_id) >= 4) { + low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) + : minus_infinity; + low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && + (data_id + 1) > window_stride) + ? __half2float(vals[data_id + 1]) + : minus_infinity; + high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) && + (data_id + 2) > window_stride) + ? __half2float(vals[data_id + 2]) + : minus_infinity; + high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) && + (data_id + 3) > window_stride) + ? __half2float(vals[data_id + 3]) + : minus_infinity; + if (mask && recompute) { + low_data[i].x += __half2float(mask[data_id + mask_offset]); + low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); + high_data[i].x += __half2float(mask[data_id + mask_offset + 2]); + high_data[i].y += __half2float(mask[data_id + mask_offset + 3]); + } + } else { + low_data[i].x = data_id > window_stride ? __half2float(vals[data_id]) + : minus_infinity; + low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) && + (data_id + 1) > window_stride) && + (data_id + 1) < sequence_length) + ? __half2float(vals[data_id + 1]) + : minus_infinity; + high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) && + (data_id + 2) > window_stride) && + (data_id + 2) < sequence_length) + ? __half2float(vals[data_id + 2]) + : minus_infinity; + high_data[i].y = minus_infinity; + if (mask && recompute) { + low_data[i].x += __half2float(mask[data_id + mask_offset]); + if ((data_id + 1) < sequence_length) + low_data[i].y += __half2float(mask[data_id + mask_offset + 1]); + if ((data_id + 2) < sequence_length) + high_data[i].x += __half2float(mask[data_id + mask_offset + 2]); + } + } + // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); + max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); + max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); + max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); + max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); + } else { + low_data[i].x = minus_infinity; + low_data[i].y = minus_infinity; + high_data[i].x = minus_infinity; + high_data[i].y = minus_infinity; + } + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); + } + float sum = 0; + for (int i = 0; i < iterations; i++) { + low_data[i].x = __expf(low_data[i].x - max_val); + low_data[i].y = __expf(low_data[i].y - max_val); + high_data[i].x = __expf(high_data[i].x - max_val); + high_data[i].y = __expf(high_data[i].y - max_val); + + sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / WARP_SIZE); + } + sum += 1e-6; + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + + if (data_id < sequence_length) { + if ((sequence_length - data_id) >= 4) { + vals[data_id] = low_data[i].x / sum; + vals[data_id + 1] = low_data[i].y / sum; + vals[data_id + 2] = high_data[i].x / sum; + vals[data_id + 3] = high_data[i].y / sum; + } else { + vals[data_id] = low_data[i].x / sum; + if ((data_id + 1) < sequence_length) vals[data_id + 1] = low_data[i].y / sum; + if ((data_id + 2) < sequence_length) vals[data_id + 2] = high_data[i].x / sum; + } + } + } + } +#endif +} + +__global__ void attn_softmax_v2(float* vals, + float* attn_mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int total_count, + int heads, + int sequence_length, + int num_seq, + float scale, + int iterations, + int reduceWidth) +{ + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + float4 data[MAX_REG_SIZE]; + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + int warp_num = blockDim.x >> 5; + + int reduce_blocks = reduceWidth >> 5; + int seq_lane = threadIdx.x % reduceWidth; + + __shared__ float partialSum[MAX_WARP_NUM]; + + int iter_offset = blockIdx.x * (warp_num / reduce_blocks) + (wid / reduce_blocks); + if (iter_offset < total_count) { + vals += (iter_offset * sequence_length); + + int mask_offset = (iter_offset / (heads * num_seq)) * (sequence_length); + int seq_id = iter_offset % num_seq; + int seq_id4 = seq_id >> 2; + + int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); + int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) + ? (real_seq_id >> 2) - (window_size >> 2) + : 0; + int window_stride = + (local_attention && real_seq_id >= window_size) ? real_seq_id - window_size : -1; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && + data_id < sequence_length) { + if ((sequence_length - data_id) >= 4) { + data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity); + data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && + (data_id + 1) > window_stride) + ? vals[data_id + 1] + : minus_infinity; + data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) && + (data_id + 2) > window_stride) + ? vals[data_id + 2] + : minus_infinity; + data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) && + (data_id + 3) > window_stride) + ? vals[data_id + 3] + : minus_infinity; + if (attn_mask && recompute) { + data[i].x += attn_mask[data_id + mask_offset]; + data[i].y += attn_mask[data_id + mask_offset + 1]; + data[i].z += attn_mask[data_id + mask_offset + 2]; + data[i].w += attn_mask[data_id + mask_offset + 3]; + } + } else { + data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity; + data[i].y = (((!triangular || (data_id + 1) <= seq_id)) && + (data_id + 1) > window_stride && (data_id + 1) < sequence_length) + ? (vals[data_id + 1]) + : minus_infinity; + data[i].z = (((!triangular || (data_id + 2) <= seq_id)) && + (data_id + 2) > window_stride && (data_id + 2) < sequence_length) + ? (vals[data_id + 2]) + : minus_infinity; + data[i].w = minus_infinity; + if (attn_mask && recompute) { + data[i].x += attn_mask[data_id + mask_offset]; + if ((data_id + 1) < sequence_length) + data[i].y += attn_mask[data_id + mask_offset + 1]; + if ((data_id + 2) < sequence_length) + data[i].z += attn_mask[data_id + mask_offset + 2]; + } + } + max_val = (data[i].x > max_val ? data[i].x : max_val); + max_val = (data[i].y > max_val ? data[i].y : max_val); + max_val = (data[i].z > max_val ? data[i].z : max_val); + max_val = (data[i].w > max_val ? data[i].w : max_val); + } else { + data[i].x = minus_infinity; + data[i].y = minus_infinity; + data[i].z = minus_infinity; + data[i].w = minus_infinity; + } + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / WARP_SIZE); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + data[i].x = __expf(data[i].x - max_val); + data[i].y = __expf(data[i].y - max_val); + data[i].z = __expf(data[i].z - max_val); + data[i].w = __expf(data[i].w - max_val); + + sum += (data[i].x + data[i].y + data[i].z + data[i].w); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) sum += g.shfl_xor(sum, i); + + if (reduceWidth > WARP_SIZE) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + + b.sync(); + + for (int i = 1; i < reduce_blocks; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / WARP_SIZE); + } + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + int data_id = i * (reduceWidth << 2) + (seq_lane << 2); + + if (data_id < sequence_length) { + if ((sequence_length - data_id) >= 4) { + vals[data_id] = data[i].x / sum; + vals[data_id + 1] = data[i].y / sum; + vals[data_id + 2] = data[i].z / sum; + vals[data_id + 3] = data[i].w / sum; + } else { + vals[data_id] = data[i].x / sum; + if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum; + if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum; + } + } + } + } +} + +template +void launch_attn_softmax_v2(T* vals, + T* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + float scale, + cudaStream_t stream) +{ + int total_count = batch_size * heads * num_seq; + dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1); + dim3 block_dim(ATTN_THREADS); + + const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE; + const int iterations = (sequence_length - 1) / (reduce_width << 2) + 1; + + if (sequence_length <= 32768) + attn_softmax_v2<<>>( + vals, + mask, + triangular, + recompute, + local_attention, + window_size, + total_count, + (triangular ? (heads * batch_size) : heads), + sequence_length, + num_seq, + scale, + iterations, + reduce_width); + else + throw std::runtime_error("Unsupport Seq_Length!"); +} + +template void launch_attn_softmax_v2(float* vals, + float* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + float scale, + cudaStream_t stream); +template void launch_attn_softmax_v2(__half* vals, + __half* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + float scale, + cudaStream_t stream); diff --git a/csrc/transformer/inference/includes/context.h b/csrc/transformer/inference/includes/context.h index 65e464f57af7..4385bd7d5452 100755 --- a/csrc/transformer/inference/includes/context.h +++ b/csrc/transformer/inference/includes/context.h @@ -1,112 +1,112 @@ -#pragma once - -#include -#include -#include -#include -#include -#include "cublas_v2.h" -#include "cuda.h" -#include "curand.h" - -#define WARP_SIZE 32 - -#define CUDA_CHECK(callstr) \ - { \ - cudaError_t error_code = callstr; \ - if (error_code != cudaSuccess) { \ - std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ - assert(0); \ - } \ - } - -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) - -#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \ - for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y) - -#define DS_CUDA_NUM_THREADS 512 -#define DS_MAXIMUM_NUM_BLOCKS 262144 - -inline int DS_GET_BLOCKS(const int N) -{ - return std::max( - std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), - // Use at least 1 block, since CUDA does not allow empty block - 1); -} - -class Context { -public: - Context() : _workspace(nullptr), _seed(42), _curr_offset(0) - { - curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); - curandSetPseudoRandomGeneratorSeed(_gen, 123); - if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { - auto message = std::string("Fail to create cublas handle."); - std::cerr << message << std::endl; - throw std::runtime_error(message); - } - cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); - } - - virtual ~Context() - { - cublasDestroy(_cublasHandle); - cudaFree(_workspace); - } - - static Context& Instance() - { - static Context _ctx; - return _ctx; - } - - void GenWorkSpace(size_t size) - { - if (!_workspace) { - assert(_workspace == nullptr); - cudaMalloc(&_workspace, size); - } else if (_workSpaceSize < size) { - cudaFree(_workspace); - cudaMalloc(&_workspace, size); - } - - _workSpaceSize = size; - } - - void* GetWorkSpace() { return _workspace; } - - curandGenerator_t& GetRandGenerator() { return _gen; } - - cudaStream_t GetCurrentStream() - { - // get current pytorch stream. - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - return stream; - } - - cublasHandle_t GetCublasHandle() { return _cublasHandle; } - - std::pair IncrementOffset(uint64_t offset_inc) - { - uint64_t offset = _curr_offset; - _curr_offset += offset_inc; - return std::pair(_seed, offset); - } - - void SetSeed(uint64_t new_seed) { _seed = new_seed; } - - const std::vector>& GetGemmAlgos() const { return _gemm_algos; } - -private: - curandGenerator_t _gen; - cublasHandle_t _cublasHandle; - void* _workspace; - uint64_t _seed; - uint64_t _curr_offset; - size_t _workSpaceSize; - std::vector> _gemm_algos; -}; +#pragma once + +#include +#include +#include +#include +#include +#include "cublas_v2.h" +#include "cuda.h" +#include "curand.h" + +#define WARP_SIZE 32 + +#define CUDA_CHECK(callstr) \ + { \ + cudaError_t error_code = callstr; \ + if (error_code != cudaSuccess) { \ + std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ + assert(0); \ + } \ + } + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + +#define CUDA_2D_KERNEL_LOOP(i, n, j, m) \ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) \ + for (size_t j = blockIdx.y * blockDim.y + threadIdx.y; j < (m); j += blockDim.y * gridDim.y) + +#define DS_CUDA_NUM_THREADS 512 +#define DS_MAXIMUM_NUM_BLOCKS 262144 + +inline int DS_GET_BLOCKS(const int N) +{ + return std::max( + std::min((N + DS_CUDA_NUM_THREADS - 1) / DS_CUDA_NUM_THREADS, DS_MAXIMUM_NUM_BLOCKS), + // Use at least 1 block, since CUDA does not allow empty block + 1); +} + +class Context { +public: + Context() : _workspace(nullptr), _seed(42), _curr_offset(0) + { + curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); + curandSetPseudoRandomGeneratorSeed(_gen, 123); + if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) { + auto message = std::string("Fail to create cublas handle."); + std::cerr << message << std::endl; + throw std::runtime_error(message); + } + cublasSetMathMode(_cublasHandle, CUBLAS_TENSOR_OP_MATH); + } + + virtual ~Context() + { + cublasDestroy(_cublasHandle); + cudaFree(_workspace); + } + + static Context& Instance() + { + static Context _ctx; + return _ctx; + } + + void GenWorkSpace(size_t size) + { + if (!_workspace) { + assert(_workspace == nullptr); + cudaMalloc(&_workspace, size); + } else if (_workSpaceSize < size) { + cudaFree(_workspace); + cudaMalloc(&_workspace, size); + } + + _workSpaceSize = size; + } + + void* GetWorkSpace() { return _workspace; } + + curandGenerator_t& GetRandGenerator() { return _gen; } + + cudaStream_t GetCurrentStream() + { + // get current pytorch stream. + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + return stream; + } + + cublasHandle_t GetCublasHandle() { return _cublasHandle; } + + std::pair IncrementOffset(uint64_t offset_inc) + { + uint64_t offset = _curr_offset; + _curr_offset += offset_inc; + return std::pair(_seed, offset); + } + + void SetSeed(uint64_t new_seed) { _seed = new_seed; } + + const std::vector>& GetGemmAlgos() const { return _gemm_algos; } + +private: + curandGenerator_t _gen; + cublasHandle_t _cublasHandle; + void* _workspace; + uint64_t _seed; + uint64_t _curr_offset; + size_t _workSpaceSize; + std::vector> _gemm_algos; +}; diff --git a/csrc/transformer/inference/includes/cublas_wrappers.h b/csrc/transformer/inference/includes/cublas_wrappers.h index 4be9e09a300a..0b37c51ca367 100755 --- a/csrc/transformer/inference/includes/cublas_wrappers.h +++ b/csrc/transformer/inference/includes/cublas_wrappers.h @@ -1,208 +1,208 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include "cublas_wrappers.h" - -int cublas_gemm_ex(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - const float* beta, - const float* A, - const float* B, - float* C, - cublasGemmAlgo_t algo) -{ - cublasStatus_t status = cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - (const void*)alpha, - (const void*)A, - CUDA_R_32F, - (transa == CUBLAS_OP_N) ? m : k, - (const void*)B, - CUDA_R_32F, - (transb == CUBLAS_OP_N) ? k : n, - (const void*)beta, - C, - CUDA_R_32F, - m, - CUDA_R_32F, - algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, - n, - k, - (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_gemm_ex(cublasHandle_t handle, - cublasOperation_t transa, - cublasOperation_t transb, - int m, - int n, - int k, - const float* alpha, - const float* beta, - const __half* A, - const __half* B, - __half* C, - cublasGemmAlgo_t algo) -{ - cublasStatus_t status = cublasGemmEx(handle, - transa, - transb, - m, - n, - k, - (const void*)alpha, - (const void*)A, - CUDA_R_16F, - (transa == CUBLAS_OP_N) ? m : k, - (const void*)B, - CUDA_R_16F, - (transb == CUBLAS_OP_N) ? k : n, - (const void*)beta, - (void*)C, - CUDA_R_16F, - m, - CUDA_R_32F, - algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, - n, - k, - (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, - int m, - int n, - int k, - const float* alpha, - const float* beta, - const float* A, - const float* B, - float* C, - cublasOperation_t op_A, - cublasOperation_t op_B, - int stride_A, - int stride_B, - int stride_C, - int batch, - cublasGemmAlgo_t algo) -{ - cublasStatus_t status = cublasGemmStridedBatchedEx(handle, - op_A, - op_B, - m, - n, - k, - alpha, - A, - CUDA_R_32F, - (op_A == CUBLAS_OP_N) ? m : k, - stride_A, - B, - CUDA_R_32F, - (op_B == CUBLAS_OP_N) ? k : n, - stride_B, - beta, - C, - CUDA_R_32F, - m, - stride_C, - batch, - CUDA_R_32F, - algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n", - batch, - m, - n, - k, - (int)status); - return EXIT_FAILURE; - } - return 0; -} - -int cublas_strided_batched_gemm(cublasHandle_t handle, - int m, - int n, - int k, - const float* alpha, - const float* beta, - const __half* A, - const __half* B, - __half* C, - cublasOperation_t op_A, - cublasOperation_t op_B, - int stride_A, - int stride_B, - int stride_C, - int batch, - cublasGemmAlgo_t algo) -{ - cublasStatus_t status = cublasGemmStridedBatchedEx(handle, - op_A, - op_B, - m, - n, - k, - alpha, - A, - CUDA_R_16F, - (op_A == CUBLAS_OP_N) ? m : k, - stride_A, - B, - CUDA_R_16F, - (op_B == CUBLAS_OP_N) ? k : n, - stride_B, - beta, - C, - CUDA_R_16F, - m, - stride_C, - batch, - CUDA_R_32F, - algo); - - if (status != CUBLAS_STATUS_SUCCESS) { - fprintf(stderr, - "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", - m, - n, - k, - (int)status); - return EXIT_FAILURE; - } - - return 0; -} +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "cublas_wrappers.h" + +int cublas_gemm_ex(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + cublasGemmAlgo_t algo) +{ + cublasStatus_t status = cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + CUDA_R_32F, + (transa == CUBLAS_OP_N) ? m : k, + (const void*)B, + CUDA_R_32F, + (transb == CUBLAS_OP_N) ? k : n, + (const void*)beta, + C, + CUDA_R_32F, + m, + CUDA_R_32F, + algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_gemm_ex(cublasHandle_t handle, + cublasOperation_t transa, + cublasOperation_t transb, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const __half* A, + const __half* B, + __half* C, + cublasGemmAlgo_t algo) +{ + cublasStatus_t status = cublasGemmEx(handle, + transa, + transb, + m, + n, + k, + (const void*)alpha, + (const void*)A, + CUDA_R_16F, + (transa == CUBLAS_OP_N) ? m : k, + (const void*)B, + CUDA_R_16F, + (transb == CUBLAS_OP_N) ? k : n, + (const void*)beta, + (void*)C, + CUDA_R_16F, + m, + CUDA_R_32F, + algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const float* A, + const float* B, + float* C, + cublasOperation_t op_A, + cublasOperation_t op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + cublasGemmAlgo_t algo) +{ + cublasStatus_t status = cublasGemmStridedBatchedEx(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + CUDA_R_32F, + (op_A == CUBLAS_OP_N) ? m : k, + stride_A, + B, + CUDA_R_32F, + (op_B == CUBLAS_OP_N) ? k : n, + stride_B, + beta, + C, + CUDA_R_32F, + m, + stride_C, + batch, + CUDA_R_32F, + algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (batch: %d, m: %d, n: %d, k: %d, error: %d) \n", + batch, + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + return 0; +} + +int cublas_strided_batched_gemm(cublasHandle_t handle, + int m, + int n, + int k, + const float* alpha, + const float* beta, + const __half* A, + const __half* B, + __half* C, + cublasOperation_t op_A, + cublasOperation_t op_B, + int stride_A, + int stride_B, + int stride_C, + int batch, + cublasGemmAlgo_t algo) +{ + cublasStatus_t status = cublasGemmStridedBatchedEx(handle, + op_A, + op_B, + m, + n, + k, + alpha, + A, + CUDA_R_16F, + (op_A == CUBLAS_OP_N) ? m : k, + stride_A, + B, + CUDA_R_16F, + (op_B == CUBLAS_OP_N) ? k : n, + stride_B, + beta, + C, + CUDA_R_16F, + m, + stride_C, + batch, + CUDA_R_32F, + algo); + + if (status != CUBLAS_STATUS_SUCCESS) { + fprintf(stderr, + "!!!! kernel execution error. (m: %d, n: %d, k: %d, error: %d) \n", + m, + n, + k, + (int)status); + return EXIT_FAILURE; + } + + return 0; +} diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index b544517fa657..dfb43a07ae33 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -1,79 +1,79 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#define MAX_WARP_NUM 32 -#define WARP_SIZE 32 -#define SMs 80 - -#define MAX_REGISTERS 256 -template -void launch_attn_softmax_v2(T* vals, - T* mask, - bool triangular, - bool recompute, - bool local_attention, - int window_size, - int batch_size, - int heads, - int num_seq, - int sequence_length, - float scale, - cudaStream_t stream); - -// Fused bias add with gelu activation -template -void launch_bias_gelu(T* input, - const T* bias, - int intermediate_size, - int batch_size, - cudaStream_t stream); -template -void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream); - -template -void launch_bias_residual(T* input, - const T* residual, - const T* bias, - int size, - int intermediate_size, - cudaStream_t stream); - -template -void launch_layer_norm(T* out, - T* vals, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream); - -template -void launch_residual_layer_norm(T* norm, - T* res_add, - T* vals, - T* residual, - const T* bias, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - bool preLN, - cudaStream_t stream); -template -void launch_dequantize(T* output, - const int8_t* input, - const float* qscale, - unsigned output_size, - unsigned hidden_dim, - unsigned groups, - unsigned merge_count, - cudaStream_t stream); +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#define MAX_WARP_NUM 32 +#define WARP_SIZE 32 +#define SMs 80 + +#define MAX_REGISTERS 256 +template +void launch_attn_softmax_v2(T* vals, + T* mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int batch_size, + int heads, + int num_seq, + int sequence_length, + float scale, + cudaStream_t stream); + +// Fused bias add with gelu activation +template +void launch_bias_gelu(T* input, + const T* bias, + int intermediate_size, + int batch_size, + cudaStream_t stream); +template +void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream); + +template +void launch_bias_residual(T* input, + const T* residual, + const T* bias, + int size, + int intermediate_size, + cudaStream_t stream); + +template +void launch_layer_norm(T* out, + T* vals, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream); + +template +void launch_residual_layer_norm(T* norm, + T* res_add, + T* vals, + T* residual, + const T* bias, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + bool preLN, + cudaStream_t stream); +template +void launch_dequantize(T* output, + const int8_t* input, + const float* qscale, + unsigned output_size, + unsigned hidden_dim, + unsigned groups, + unsigned merge_count, + cudaStream_t stream); diff --git a/csrc/transformer/normalize_kernels.cu b/csrc/transformer/normalize_kernels.cu index ba8314c58b52..b7afa3589f2c 100644 --- a/csrc/transformer/normalize_kernels.cu +++ b/csrc/transformer/normalize_kernels.cu @@ -1,2103 +1,2103 @@ -#include "custom_cuda_layers.h" - -namespace cg = cooperative_groups; - -/* -Fused bias add, residual (elementwise) add, and normalization layer. - -For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for -__half2 instructions, and avoid the conversion overhead (1/8 of __hal2 arithmetic). - -For specific launch constraints, see the launch functions. -*/ - -#define NORM_REG (MAX_REGISTERS / 4) - -__global__ void fused_bias_residual_layer_norm(float* vals, - const float* residual, - const float* gamma, - const float* beta, - float epsilon, - bool preLayerNorm, - bool training, - float* vars, - float* means, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id / WARP_SIZE; - - float vals_arr[NORM_REG]; - __shared__ float shr[MAX_WARP_NUM]; - - residual += (row * row_stride); - vals += (row * row_stride); - - float sum = 0.f; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - vals_arr[i] = residual[i * iteration_stride + id]; - sum += vals_arr[i]; - } - if (high_index < row_stride) { - vals_arr[iterations] = residual[high_index]; - sum += vals_arr[iterations]; - iterations++; - } - - for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) shr[gid] = sum; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; - -#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } - - sum = g.shfl(sum, 0); - float mean = sum / row_stride; - if (training) - if (threadIdx.x == 0) means[row] = mean; - float variance = 0.f; - for (int i = 0; i < iterations; i++) { - vals_arr[i] -= mean; - variance += vals_arr[i] * vals_arr[i]; - } - - for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } - - if (g.thread_rank() == 0) shr[gid] = variance; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } - variance = g.shfl(variance, 0); - variance /= row_stride; - variance += epsilon; - if (training) - if (threadIdx.x == 0) vars[row] = variance; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr[i] = vals_arr[i] * rsqrtf(variance); - vals_arr[i] = - vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; - vals[i * iteration_stride + id] = vals_arr[i]; - } - if ((high_index) < row_stride) { - vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); - vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; - vals[high_index] = vals_arr[iterations]; - } -} - -__global__ void fused_bias_residual_layer_norm(__half* vals, - const __half* residual, - const __half* gamma, - const __half* beta, - float epsilon, - bool preLayerNorm, - bool training, - __half* vars, - __half* means, - int row_stride) -{ -#if __CUDA_ARCH__ >= 700 - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - - float2 vals_f[NORM_REG]; - __shared__ float shr[MAX_WARP_NUM]; - - __half2* vals_cast = reinterpret_cast<__half2*>(vals); - const __half2* residual_cast = reinterpret_cast(residual); - - residual_cast += (row * row_stride); - vals_cast += (row * row_stride); - - float sum = 0.f; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); - sum += vals_f[i].x; - sum += vals_f[i].y; - } - if ((high_index) < row_stride) { - vals_f[iterations] = __half22float2(residual_cast[high_index]); - sum += vals_f[iterations].x; - sum += vals_f[iterations].y; - iterations++; - } - - for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) shr[gid] = sum; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } - sum = g.shfl(sum, 0); - float mean = sum / (row_stride * 2); - - float variance = 0.f; - for (int i = 0; i < iterations; i++) { - vals_f[i].x -= mean; - vals_f[i].y -= mean; - variance += vals_f[i].x * vals_f[i].x; - variance += vals_f[i].y * vals_f[i].y; - } - - for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } - - if (g.thread_rank() == 0) shr[gid] = variance; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } - variance = g.shfl(variance, 0); - variance /= (row_stride * 2); - variance += epsilon; - - __half2 variance_h = __float2half2_rn(variance); - const __half2* gamma_cast = reinterpret_cast(gamma); - const __half2* beta_cast = reinterpret_cast(beta); - - if (training && threadIdx.x == 0) { - vars[row] = __float2half(variance); - means[row] = __float2half(mean); - } - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - __half2 vals_arr = __float22half2_rn(vals_f[i]); - vals_arr = vals_arr * h2rsqrt(variance_h); - vals_arr = - vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; - vals_cast[i * iteration_stride + id] = vals_arr; - } - if ((high_index) < row_stride) { - __half2 vals_arr = __float22half2_rn(vals_f[iterations]); - vals_arr = vals_arr * h2rsqrt(variance_h); - vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; - vals_cast[high_index] = vals_arr; - } -#endif -} - -template -void launch_bias_residual_layer_norm(T* vals, - const T* residual, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - T* vars, - T* means); - -template <> -void launch_bias_residual_layer_norm(float* vals, - const float* residual, - const float* gamma, - const float* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - float* vars, - float* means) -{ - int threads = THREADS; - - dim3 grid_dim(batch_size); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim); -} - -template <> -void launch_bias_residual_layer_norm<__half>(__half* vals, - const __half* residual, - const __half* gamma, - const __half* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - __half* vars, - __half* means) -{ - int threads = 128; - - dim3 grid_dim(batch_size); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2); -} - -__global__ void fused_bias_residual_layer_norm(float* vals, - const float* residual, - const float* gamma, - const float* beta, - float epsilon, - bool preLayerNorm, - bool training, - float* vars, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id / 32; - - float vals_arr[NORM_REG]; - __shared__ float shr[MAX_WARP_NUM]; - - residual += (row * row_stride); - vals += (row * row_stride); - - float sum = 0.f; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - vals_arr[i] = residual[i * iteration_stride + id]; - sum += vals_arr[i]; - } - if ((high_index) < row_stride) { - vals_arr[iterations] = residual[high_index]; - sum += vals_arr[iterations]; - iterations++; - } - - for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) shr[gid] = sum; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; - -#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } - - sum = g.shfl(sum, 0); - float mean = sum / row_stride; - float variance = 0.f; - for (int i = 0; i < iterations; i++) { - vals_arr[i] -= mean; - variance += vals_arr[i] * vals_arr[i]; - } - - for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } - - if (g.thread_rank() == 0) shr[gid] = variance; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } - variance = g.shfl(variance, 0); - variance /= row_stride; - variance += epsilon; - if (training) - if (threadIdx.x == 0) vars[row] = variance; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr[i] = vals_arr[i] * rsqrtf(variance); - vals_arr[i] = - vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; - vals[i * iteration_stride + id] = vals_arr[i]; - } - if ((high_index) < row_stride) { - vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); - vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; - vals[high_index] = vals_arr[iterations]; - } -} - -__global__ void fused_bias_residual_layer_norm(__half* vals, - const __half* residual, - const __half* gamma, - const __half* beta, - float epsilon, - bool preLayerNorm, - bool training, - __half* vars, - int row_stride) -{ -#if __CUDA_ARCH__ >= 700 - - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int gid = id >> 5; - - float2 vals_f[NORM_REG]; - __shared__ float shr[MAX_WARP_NUM]; - - __half2* vals_cast = reinterpret_cast<__half2*>(vals); - const __half2* residual_cast = reinterpret_cast(residual); - - residual_cast += (row * row_stride); - vals_cast += (row * row_stride); - - float sum = 0.f; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); - sum += vals_f[i].x; - sum += vals_f[i].y; - } - if ((high_index) < row_stride) { - vals_f[iterations] = __half22float2(residual_cast[high_index]); - sum += vals_f[iterations].x; - sum += vals_f[iterations].y; - iterations++; - } - - for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) shr[gid] = sum; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } - sum = g.shfl(sum, 0); - float mean = sum / (row_stride * 2); - - float variance = 0.f; - for (int i = 0; i < iterations; i++) { - vals_f[i].x -= mean; - vals_f[i].y -= mean; - variance += vals_f[i].x * vals_f[i].x; - variance += vals_f[i].y * vals_f[i].y; - } - - for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } - - if (g.thread_rank() == 0) shr[gid] = variance; - - b.sync(); - - if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } - variance = g.shfl(variance, 0); - variance /= (row_stride * 2); - variance += epsilon; - - __half2 variance_h = __float2half2_rn(variance); - const __half2* gamma_cast = reinterpret_cast(gamma); - const __half2* beta_cast = reinterpret_cast(beta); - - if (training && threadIdx.x == 0) vars[row] = __float2half(variance); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - __half2 vals_arr = __float22half2_rn(vals_f[i]); - vals_arr = vals_arr * h2rsqrt(variance_h); - vals_arr = - vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; - vals_cast[i * iteration_stride + id] = vals_arr; - } - if ((high_index) < row_stride) { - __half2 vals_arr = __float22half2_rn(vals_f[iterations]); - vals_arr = vals_arr * h2rsqrt(variance_h); - vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; - vals_cast[high_index] = vals_arr; - } -#endif -} - -template -void launch_bias_residual_layer_norm(T* vals, - const T* residual, - const T* gamma, - const T* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - T* vars); - -/* -To tune this launch the following restrictions must be met: - -For float: -row_stride == hidden_size -threads * iterations == row_stride -threads is in [32, 64, 128, 256, 512, 1024] - -For half: -row_stride == hidden_size / 2 -threads * iterations == row_stride -threads is in [32, 64, 128, 256, 512, 1024] - -*/ - -template <> -void launch_bias_residual_layer_norm(float* vals, - const float* residual, - const float* gamma, - const float* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - float* vars) -{ - int threads = THREADS; - - dim3 grid_dim(batch_size); - - // There are some limitations to call below functions, now just enumerate the situations. - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim(threads); - - fused_bias_residual_layer_norm<<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim); -} - -template <> -void launch_bias_residual_layer_norm<__half>(__half* vals, - const __half* residual, - const __half* gamma, - const __half* beta, - float epsilon, - int batch_size, - int hidden_dim, - cudaStream_t stream, - bool preLayerNorm, - bool training, - __half* vars) -{ - int threads = 128; - - dim3 grid_dim(batch_size); - - // There are some limitations to call below functions, now just enumerate the situations. - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim(threads); - fused_bias_residual_layer_norm<<>>( - vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2); -} - -/* Normalize Gamma & Betta gradients - * Compute gradients using either X_hat or - * normalize input (invertible). - * Combine transpose with gradients computation. - */ - -template -__global__ void LayerNormBackward1(const T* __restrict__ out_grad, - const T* __restrict__ vals_hat, - const T* __restrict__ gamma, - const T* __restrict__ betta, - T* __restrict__ gamma_grad, - T* __restrict__ betta_grad, - int rows, - int width, - bool invertible) -{ - __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - float betta_reg = (invertible ? (float)betta[idx] : 0.0f); - float gamma_reg = (float)gamma[idx]; - - // Loop across matrix height - float betta_tmp = 0; - float gamma_tmp = 0; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - float grad = (float)out_grad[offset]; - float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg - : (float)vals_hat[offset]); - betta_tmp += grad; - gamma_tmp += (val * grad); - - offset += y_stride; - } - - betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; - gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; - - __syncthreads(); - - // Sum the shared buffer. - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - if (threadIdx.x == 0) { - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -/* Normalize Gamma & Betta gradients - * Compute gradients using the input to - * the normalize. - * Combine transpose with gradients computation. - */ - -template -__global__ void LayerNormBackward1(const T* __restrict__ out_grad, - const T* __restrict__ X_data, - const T* __restrict__ vars, - const T* __restrict__ means, - T* __restrict__ gamma_grad, - T* __restrict__ betta_grad, - int rows, - int width) -{ - __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - // Loop across matrix height - - float betta_tmp = 0; - float gamma_tmp = 0; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - float grad = (float)out_grad[offset]; - float val = (float)X_data[offset]; - val = (val - (float)means[r]) * rsqrtf((float)vars[r]); - betta_tmp += grad; - gamma_tmp += (val * grad); - - offset += y_stride; - } - - betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; - gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; - - __syncthreads(); - - // Sum the shared buffer. - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - if (threadIdx.x == 0) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} -/* - -/* Backward Normalize (Input-Gradient) - * Using the means and variances from the input - * This type of backward is invertible! - * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization. - */ - -__global__ void LayerNormBackward2(const float* out_grad, - const float* vals_hat, - const float* gamma, - const float* betta, - const float* vars, - float* inp_grad, - bool invertible, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - out_grad += (row * row_stride); - vals_hat += (row * row_stride); - inp_grad += (row * row_stride); - - float vals_arr[NORM_REG]; - float vals_hat_arr[NORM_REG]; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - float gamma_reg = gamma[i * iteration_stride + id]; - vals_arr[i] = out_grad[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - vals_hat_arr[i] = - (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / - gamma_reg - : vals_hat[i * iteration_stride + id]); - } - if ((high_index) < row_stride) { - float gamma_reg = gamma[high_index]; - vals_arr[iterations] = out_grad[high_index]; - vals_arr[iterations] *= gamma_reg; - vals_hat_arr[iterations] = - (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg - : vals_hat[high_index]); - iterations++; - } - - float var_reg = vars[row]; - - float sum = 0; - for (int i = 0; i < iterations; i++) { - sum += vals_hat_arr[i] * vals_arr[i] * - sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad - vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var) - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= row_stride; - - for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } - - sum = 0; - for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= row_stride; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); - if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); -} - -__global__ void LayerNormBackward2(const __half* out_grad, - const __half* vals_hat, - const __half* gamma, - const __half* betta, - const __half* vars, - __half* inp_grad, - bool invertible, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - __half2 vals_arr[NORM_REG]; - float2 vals_arr_f[NORM_REG]; - __half2 vals_hat_arr[NORM_REG]; - - __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); - const __half2* out_grad_h = reinterpret_cast(out_grad); - const __half2* vals_hat_h = reinterpret_cast(vals_hat); - - inp_grad_h += (row * row_stride); - out_grad_h += (row * row_stride); - vals_hat_h += (row * row_stride); - - const __half2* gamma_h = reinterpret_cast(gamma); - const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - __half2 gamma_reg = gamma_h[i * iteration_stride + id]; - vals_arr[i] = out_grad_h[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - vals_hat_arr[i] = - (invertible - ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / - gamma_reg - : vals_hat_h[i * iteration_stride + id]); - } - if ((high_index) < row_stride) { - __half2 gamma_reg = gamma_h[high_index]; - vals_arr[iterations] = out_grad_h[high_index]; - vals_arr[iterations] *= gamma_reg; - vals_hat_arr[iterations] = - (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg - : vals_hat_h[high_index]); - iterations++; - } - __half var_h = vars[row]; - __half2 var_reg = __halves2half2(var_h, var_h); - - float sum = 0.f; - for (int i = 0; i < iterations; i++) { - __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); - float2 result_f = __half22float2(result_h); - sum += result_f.x; - sum += result_f.y; - vals_arr[i] *= h2rsqrt(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - __half2 sum_h = __float2half2_rn(sum); - - for (int i = 0; i < iterations; i++) { - __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); - vals_arr_f[i] = __half22float2(vals_arr[i]); - float2 temp_f = __half22float2(temp); - vals_arr_f[i].x += temp_f.x; - vals_arr_f[i].y += temp_f.y; - } - sum = 0.f; - - for (int i = 0; i < iterations; i++) { - sum += (vals_arr_f[i].x); - sum += (vals_arr_f[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr_f[i].x -= sum; - vals_arr_f[i].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[i]); - - inp_grad_h[i * iteration_stride + id] = temp; - } - if ((high_index) < row_stride) { - vals_arr_f[iterations].x -= sum; - vals_arr_f[iterations].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[iterations]); - - inp_grad_h[high_index] = temp; - } -} - -template <> -void launch_layerNorm_backward(const float* out_grad, - const float* vals_hat, - const float* vars, - const float* gamma, - float* gamma_grad, - float* betta_grad, - float* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2], - bool invertible, - const float* betta) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<<>>( - out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); - - dim3 grid_dim2(batch); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads); - - LayerNormBackward2<<>>( - out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); -} - -template <> -void launch_layerNorm_backward<__half>(const __half* out_grad, - const __half* vals_hat, - const __half* vars, - const __half* gamma, - __half* gamma_grad, - __half* betta_grad, - __half* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2], - bool invertible, - const __half* betta) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<__half><<>>( - out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); - - dim3 grid_dim2(batch); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads / 2); - - LayerNormBackward2<<>>( - out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); -} - -/* Backward Normalize (Input-Gradient) - * Using the means and variances from the input - * This type of backward is not invertible! - * We do the backward using the input (X) - */ - -__global__ void LayerNormBackward2(const float* out_grad, - const float* X_vals, - const float* gamma, - const float* vars, - const float* means, - float* inp_grad, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - out_grad += (row * row_stride); - X_vals += (row * row_stride); - inp_grad += (row * row_stride); - - float vals_arr[NORM_REG]; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - float gamma_reg = gamma[i * iteration_stride + id]; - vals_arr[i] = out_grad[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - } - if ((high_index) < row_stride) { - float gamma_reg = gamma[high_index]; - vals_arr[iterations] = out_grad[high_index]; - vals_arr[iterations] *= gamma_reg; - iterations++; - } - - float var_reg = vars[row]; - float mean_reg = means[row]; - - float sum = 0; - float xu[NORM_REG]; - for (int i = 0; i < iterations; i++) { - xu[i] = (X_vals[i * iteration_stride + id] - mean_reg); - sum += vals_arr[i] * xu[i]; - vals_arr[i] *= rsqrtf(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= row_stride; - - for (int i = 0; i < iterations; i++) { - vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); - } - - sum = 0; - for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= row_stride; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); - if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); -} - -__global__ void LayerNormBackward2(const __half* out_grad, - const __half* X_vals, - const __half* gamma, - const __half* vars, - const __half* means, - __half* inp_grad, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; - - __shared__ float partialSum[MAX_WARP_NUM]; - - __half2 vals_arr[NORM_REG]; - float2 vals_arr_f[NORM_REG]; - - __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); - const __half2* out_grad_h = reinterpret_cast(out_grad); - const __half2* vals_hat_h = reinterpret_cast(X_vals); - - inp_grad_h += (row * row_stride); - out_grad_h += (row * row_stride); - vals_hat_h += (row * row_stride); - - const __half2* gamma_h = reinterpret_cast(gamma); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - __half2 gamma_reg = gamma_h[i * iteration_stride + id]; - vals_arr[i] = out_grad_h[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; // out_grad * gamma - } - if ((high_index) < row_stride) { - __half2 gamma_reg = gamma_h[high_index]; - vals_arr[iterations] = out_grad_h[high_index]; - vals_arr[iterations] *= gamma_reg; // out_grad * gamma - iterations++; - } - __half mean_h = means[row]; - __half var_h = vars[row]; - __half2 var_reg = __halves2half2(var_h, var_h); - __half2 mean_reg = __halves2half2(mean_h, mean_h); - __half2 xu[NORM_REG]; - - float sum = 0.f; - for (int i = 0; i < iterations; i++) { - xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg); - __half2 result_h = (xu[i] * vals_arr[i]); - float2 result_f = __half22float2(result_h); - sum += result_f.x; - sum += result_f.y; - vals_arr[i] *= h2rsqrt(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - __half2 sum_h = __float2half2_rn(sum); - - for (int i = 0; i < iterations; i++) { - __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); - vals_arr_f[i] = __half22float2(vals_arr[i]); - float2 xu_grad_f = __half22float2(xu_grad); - vals_arr_f[i].x += xu_grad_f.x; - vals_arr_f[i].y += xu_grad_f.y; - } - - sum = 0.f; - for (int i = 0; i < iterations; i++) { - sum += (vals_arr_f[i].x); - sum += (vals_arr_f[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr_f[i].x -= sum; - vals_arr_f[i].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[i]); - inp_grad_h[i * iteration_stride + id] = temp; - } - if ((high_index) < row_stride) { - vals_arr_f[iterations].x -= sum; - vals_arr_f[iterations].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[iterations]); - inp_grad_h[high_index] = temp; - } -} - -template <> -void launch_layerNorm_backward(const float* out_grad, - const float* X_data, - const float* vars, - const float* means, - const float* gamma, - float* gamma_grad, - float* betta_grad, - float* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2]) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<<>>( - out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); - - dim3 grid_dim2(batch); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads); - LayerNormBackward2<<>>( - out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim); -} - -template <> -void launch_layerNorm_backward<__half>(const __half* out_grad, - const __half* X_data, - const __half* vars, - const __half* means, - const __half* gamma, - __half* gamma_grad, - __half* betta_grad, - __half* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2]) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<__half><<>>( - out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); - - dim3 grid_dim2(batch); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads / 2); - LayerNormBackward2<<>>( - out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); -} - -template -__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, - const T* __restrict__ out_grad2, - const T* __restrict__ vals_hat, - const T* __restrict__ gamma, - const T* __restrict__ betta, - T* __restrict__ gamma_grad, - T* __restrict__ betta_grad, - int rows, - int width, - bool invertible) -{ - __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - float betta_reg = (invertible ? (float)betta[idx] : 0.0f); - float gamma_reg = (float)gamma[idx]; - - // Loop across matrix height - float betta_tmp = 0; - float gamma_tmp = 0; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; - float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg - : (float)vals_hat[offset]); - betta_tmp += grad; - gamma_tmp += (val * grad); - - offset += y_stride; - } - - betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; - gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; - - __syncthreads(); - - // Sum the shared buffer. - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - if (threadIdx.x == 0) { - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -template -__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, - const T* __restrict__ out_grad2, - const T* __restrict__ X_data, - const T* __restrict__ vars, - const T* __restrict__ means, - T* __restrict__ gamma_grad, - T* __restrict__ betta_grad, - int rows, - int width) -{ - __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; - __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int idx = blockDim.x * blockIdx.x + threadIdx.x; - int offset = threadIdx.y * width + idx; - int y_stride = width * TILE_DIM; - - int pos = blockIdx.x * TILE_DIM + threadIdx.y; - // Loop across matrix height - - float betta_tmp = 0; - float gamma_tmp = 0; - for (int r = threadIdx.y; r < rows; r += TILE_DIM) { - float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; - float val = (float)X_data[offset]; - val = (val - (float)means[r]) * rsqrtf((float)vars[r]); - betta_tmp += grad; - gamma_tmp += (val * grad); - - offset += y_stride; - } - - betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; - gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; - - __syncthreads(); - - // Sum the shared buffer. - float s1 = betta_buffer[threadIdx.y][threadIdx.x]; - float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < TILE_DIM; i <<= 1) { - s1 += g.shfl_down(s1, i); - s2 += g.shfl_down(s2, i); - } - - if (threadIdx.x == 0) { - betta_grad[pos] = s1; - gamma_grad[pos] = s2; - } -} - -__global__ void LayerNormBackward2_fused_add(const float* out_grad1, - const float* out_grad2, - const float* vals_hat, - const float* gamma, - const float* betta, - const float* vars, - float* inp_grad, - bool invertible, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - out_grad1 += (row * row_stride); - out_grad2 += (row * row_stride); - vals_hat += (row * row_stride); - inp_grad += (row * row_stride); - - float vals_arr[NORM_REG]; - float vals_hat_arr[NORM_REG]; - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - float gamma_reg = gamma[i * iteration_stride + id]; - vals_arr[i] = out_grad1[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - vals_hat_arr[i] = - (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / - gamma_reg - : vals_hat[i * iteration_stride + id]); - } - if ((high_index) < row_stride) { - float gamma_reg = gamma[high_index]; - vals_arr[iterations] = out_grad1[high_index]; - vals_arr[iterations] *= gamma_reg; - vals_hat_arr[iterations] = - (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg - : vals_hat[high_index]); - iterations++; - } - - float var_reg = vars[row]; - - float sum = 0; - for (int i = 0; i < iterations; i++) { - sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg); - vals_arr[i] *= rsqrtf(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= row_stride; - - for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } - - sum = 0; - for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= row_stride; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) - inp_grad[i * iteration_stride + id] = - (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; - if ((high_index) < row_stride) - inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; -} - -__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, - const __half* out_grad2, - const __half* vals_hat, - const __half* gamma, - const __half* betta, - const __half* vars, - __half* inp_grad, - bool invertible, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - __half2 vals_arr[NORM_REG]; - float2 vals_arr_f[NORM_REG]; - __half2 vals_hat_arr[NORM_REG]; - - // float2 result[iterations]; - - __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); - const __half2* out_grad_h1 = reinterpret_cast(out_grad1); - const __half2* out_grad_h2 = reinterpret_cast(out_grad2); - const __half2* vals_hat_h = reinterpret_cast(vals_hat); - - inp_grad_h += (row * row_stride); - out_grad_h1 += (row * row_stride); - out_grad_h2 += (row * row_stride); - vals_hat_h += (row * row_stride); - - const __half2* gamma_h = reinterpret_cast(gamma); - const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - __half2 gamma_reg = gamma_h[i * iteration_stride + id]; - vals_arr[i] = out_grad_h1[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; // out_grad * gamma - vals_hat_arr[i] = - (invertible - ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / - gamma_reg - : vals_hat_h[i * iteration_stride + id]); - } - if ((high_index) < row_stride) { - __half2 gamma_reg = gamma_h[high_index]; - vals_arr[iterations] = out_grad_h1[high_index]; - vals_arr[iterations] *= gamma_reg; // out_grad * gamma - vals_hat_arr[iterations] = - (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg - : vals_hat_h[high_index]); - iterations++; - } - __half var_h = vars[row]; - __half2 var_reg = __halves2half2(var_h, var_h); - - float sum = 0.f; - for (int i = 0; i < iterations; i++) { - __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); - float2 result_f = __half22float2(result_h); - sum += result_f.x; - sum += result_f.y; - vals_arr[i] *= h2rsqrt(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - __half2 sum_h = __float2half2_rn(sum); - - for (int i = 0; i < iterations; i++) { - __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); - vals_arr_f[i] = __half22float2(vals_arr[i]); - float2 temp_f = __half22float2(temp); - vals_arr_f[i].x += temp_f.x; - vals_arr_f[i].y += temp_f.y; - } - sum = 0.f; - for (int i = 0; i < iterations; i++) { - sum += (vals_arr_f[i].x); - sum += (vals_arr_f[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr_f[i].x -= sum; - vals_arr_f[i].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[i]); - - inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; - } - if ((high_index) < row_stride) { - vals_arr_f[iterations].x -= sum; - vals_arr_f[iterations].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[iterations]); - - inp_grad_h[high_index] = temp + out_grad_h2[high_index]; - } -} - -template <> -void launch_layerNorm_backward_fused_add(const float* out_grad1, - const float* out_grad2, - const float* vals_hat, - const float* vars, - const float* gamma, - float* gamma_grad, - float* betta_grad, - float* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2], - bool invertible, - const float* betta) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - LayerNormBackward1<<>>( - out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); - - dim3 grid_dim2(batch); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads); - LayerNormBackward2_fused_add<<>>( - out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); -} - -template <> -void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, - const __half* out_grad2, - const __half* vals_hat, - const __half* vars, - const __half* gamma, - __half* gamma_grad, - __half* betta_grad, - __half* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2], - bool invertible, - const __half* betta) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<__half><<>>( - out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); - - dim3 grid_dim2(batch); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads / 2); - LayerNormBackward2_fused_add<<>>( - out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); -} - -/* Backward Normalize (Input-Gradient) - * Using the means and variances from the input - * This type of backward is not invertible! - * We do the backward using the input (X) - */ - -__global__ void LayerNormBackward2_fused_add(const float* out_grad1, - const float* out_grad2, - const float* X_vals, - const float* gamma, - const float* vars, - const float* means, - float* inp_grad, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; - __shared__ float partialSum[MAX_WARP_NUM]; - - float vals_arr[NORM_REG]; - float vals_hat_arr[NORM_REG]; - - out_grad1 += (row * row_stride); - out_grad2 += (row * row_stride); - X_vals += (row * row_stride); - inp_grad += (row * row_stride); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - float gamma_reg = gamma[i * iteration_stride + id]; - vals_arr[i] = out_grad1[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; - vals_hat_arr[i] = X_vals[i * iteration_stride + id]; - } - if ((high_index) < row_stride) { - float gamma_reg = gamma[high_index]; - vals_arr[iterations] = out_grad1[high_index]; - vals_arr[iterations] *= gamma_reg; - vals_hat_arr[iterations] = X_vals[high_index]; - iterations++; - } - - float var_reg = vars[row]; - float mean_reg = means[row]; - - float sum = 0; - float xu[NORM_REG]; - for (int i = 0; i < iterations; i++) { - xu[i] = (vals_hat_arr[i] - mean_reg); - sum += vals_arr[i] * xu[i]; - vals_arr[i] *= rsqrtf(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= row_stride; - - for (int i = 0; i < iterations; i++) { - vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); - } - - sum = 0; - for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - sum = g.shfl(sum, 0); - sum /= row_stride; - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) - inp_grad[i * iteration_stride + id] = - (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; - if ((high_index) < row_stride) - inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; -} - -__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, - const __half* out_grad2, - const __half* X_vals, - const __half* gamma, - const __half* vars, - const __half* means, - __half* inp_grad, - int row_stride) -{ - int iteration_stride = blockDim.x; - int iterations = row_stride / iteration_stride; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - int wid = id / WARP_SIZE; - int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; - - __shared__ float partialSum[MAX_WARP_NUM]; - - __half2 vals_arr[NORM_REG]; - float2 vals_arr_f[NORM_REG]; - __half2 vals_hat_arr[NORM_REG]; - - __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); - const __half2* out_grad_h1 = reinterpret_cast(out_grad1); - const __half2* out_grad_h2 = reinterpret_cast(out_grad2); - const __half2* vals_hat_h = reinterpret_cast(X_vals); - - out_grad_h1 += (row * row_stride); - out_grad_h2 += (row * row_stride); - inp_grad_h += (row * row_stride); - vals_hat_h += (row * row_stride); - - const __half2* gamma_h = reinterpret_cast(gamma); - int high_index = iterations * iteration_stride + id; -#pragma unroll - for (int i = 0; i < iterations; i++) { - __half2 gamma_reg = gamma_h[i * iteration_stride + id]; - vals_arr[i] = out_grad_h1[i * iteration_stride + id]; - vals_arr[i] *= gamma_reg; // out_grad * gamma - vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id]; - } - if ((high_index) < row_stride) { - __half2 gamma_reg = gamma_h[high_index]; - vals_arr[iterations] = out_grad_h1[high_index]; - vals_arr[iterations] *= gamma_reg; // out_grad * gamma - vals_hat_arr[iterations] = vals_hat_h[high_index]; - iterations++; - } - - __half mean_h = means[row]; - __half var_h = vars[row]; - __half2 var_reg = __halves2half2(var_h, var_h); - __half2 mean_reg = __halves2half2(mean_h, mean_h); - __half2 xu[NORM_REG]; - - float sum = 0.f; - for (int i = 0; i < iterations; i++) { - xu[i] = (vals_hat_arr[i] - mean_reg); - __half2 result_h = (xu[i] * vals_arr[i]); - float2 result_f = __half22float2(result_h); - sum += result_f.x; - sum += result_f.y; - vals_arr[i] *= h2rsqrt(var_reg); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - __half2 sum_h = __float2half2_rn(sum); - - for (int i = 0; i < iterations; i++) { - __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); - vals_arr_f[i] = __half22float2(vals_arr[i]); - float2 xu_grad_f = __half22float2(xu_grad); - vals_arr_f[i].x += xu_grad_f.x; - vals_arr_f[i].y += xu_grad_f.y; - } - - sum = 0.f; - for (int i = 0; i < iterations; i++) { - sum += (vals_arr_f[i].x); - sum += (vals_arr_f[i].y); - } - - for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } - - if (g.thread_rank() == 0) partialSum[wid] = sum; - - __syncthreads(); - - if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; - -#ifndef __STOCHASTIC_MODE__ - __syncthreads(); -#endif - - for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); - - sum = g.shfl(sum, 0); - sum /= (2 * row_stride); - - iterations = row_stride / iteration_stride; - for (int i = 0; i < iterations; i++) { - vals_arr_f[i].x -= sum; - vals_arr_f[i].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[i]); - inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; - } - if ((high_index) < row_stride) { - vals_arr_f[iterations].x -= sum; - vals_arr_f[iterations].y -= sum; - __half2 temp = __float22half2_rn(vals_arr_f[iterations]); - inp_grad_h[high_index] = temp + out_grad_h2[high_index]; - } -} - -template <> -void launch_layerNorm_backward_fused_add(const float* out_grad1, - const float* out_grad2, - const float* X_data, - const float* vars, - const float* means, - const float* gamma, - float* gamma_grad, - float* betta_grad, - float* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2]) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<<>>( - out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); - - dim3 grid_dim2(batch); - - if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 1; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 2; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads); - LayerNormBackward2_fused_add<<>>( - out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim); -} - -template <> -void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, - const __half* out_grad2, - const __half* X_data, - const __half* vars, - const __half* means, - const __half* gamma, - __half* gamma_grad, - __half* betta_grad, - __half* inp_grad, - int batch, - int hidden_dim, - cudaStream_t stream[2]) -{ - int threads = THREADS; - - dim3 grid_dim(hidden_dim / TILE_DIM); - dim3 block_dim(TILE_DIM, TILE_DIM); - - LayerNormBackward1<__half><<>>( - out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); - - dim3 grid_dim2(batch); - - if (hidden_dim > 8192 && hidden_dim <= 16384) - threads <<= 1; - else if (hidden_dim > 16384 && hidden_dim <= 32768) - threads <<= 2; - else if (hidden_dim > 32768 && hidden_dim <= 65536) - threads <<= 3; - else if (hidden_dim > 65536) - throw std::runtime_error("Unsupport hidden_dim."); - - dim3 block_dim2(threads / 2); - LayerNormBackward2_fused_add<<>>( - out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); -} +#include "custom_cuda_layers.h" + +namespace cg = cooperative_groups; + +/* +Fused bias add, residual (elementwise) add, and normalization layer. + +For FP16, this kernel does not promote to FP32 in order to utilize the 2x throughput for +__half2 instructions, and avoid the conversion overhead (1/8 of __hal2 arithmetic). + +For specific launch constraints, see the launch functions. +*/ + +#define NORM_REG (MAX_REGISTERS / 4) + +__global__ void fused_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + bool preLayerNorm, + bool training, + float* vars, + float* means, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id / WARP_SIZE; + + float vals_arr[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + residual += (row * row_stride); + vals += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_arr[i] = residual[i * iteration_stride + id]; + sum += vals_arr[i]; + } + if (high_index < row_stride) { + vals_arr[iterations] = residual[high_index]; + sum += vals_arr[iterations]; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + +#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } + + sum = g.shfl(sum, 0); + float mean = sum / row_stride; + if (training) + if (threadIdx.x == 0) means[row] = mean; + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_arr[i] -= mean; + variance += vals_arr[i] * vals_arr[i]; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } + variance = g.shfl(variance, 0); + variance /= row_stride; + variance += epsilon; + if (training) + if (threadIdx.x == 0) vars[row] = variance; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr[i] = vals_arr[i] * rsqrtf(variance); + vals_arr[i] = + vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; + vals[i * iteration_stride + id] = vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); + vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; + vals[high_index] = vals_arr[iterations]; + } +} + +__global__ void fused_bias_residual_layer_norm(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + bool preLayerNorm, + bool training, + __half* vars, + __half* means, + int row_stride) +{ +#if __CUDA_ARCH__ >= 700 + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + + float2 vals_f[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + __half2* vals_cast = reinterpret_cast<__half2*>(vals); + const __half2* residual_cast = reinterpret_cast(residual); + + residual_cast += (row * row_stride); + vals_cast += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); + sum += vals_f[i].x; + sum += vals_f[i].y; + } + if ((high_index) < row_stride) { + vals_f[iterations] = __half22float2(residual_cast[high_index]); + sum += vals_f[iterations].x; + sum += vals_f[iterations].y; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } + sum = g.shfl(sum, 0); + float mean = sum / (row_stride * 2); + + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_f[i].x -= mean; + vals_f[i].y -= mean; + variance += vals_f[i].x * vals_f[i].x; + variance += vals_f[i].y * vals_f[i].y; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } + variance = g.shfl(variance, 0); + variance /= (row_stride * 2); + variance += epsilon; + + __half2 variance_h = __float2half2_rn(variance); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + + if (training && threadIdx.x == 0) { + vars[row] = __float2half(variance); + means[row] = __float2half(mean); + } + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + __half2 vals_arr = __float22half2_rn(vals_f[i]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = + vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; + vals_cast[i * iteration_stride + id] = vals_arr; + } + if ((high_index) < row_stride) { + __half2 vals_arr = __float22half2_rn(vals_f[iterations]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; + vals_cast[high_index] = vals_arr; + } +#endif +} + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + T* vars, + T* means); + +template <> +void launch_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + float* vars, + float* means) +{ + int threads = THREADS; + + dim3 grid_dim(batch_size); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim); +} + +template <> +void launch_bias_residual_layer_norm<__half>(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + __half* vars, + __half* means) +{ + int threads = 128; + + dim3 grid_dim(batch_size); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, means, hidden_dim / 2); +} + +__global__ void fused_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + bool preLayerNorm, + bool training, + float* vars, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id / 32; + + float vals_arr[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + residual += (row * row_stride); + vals += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_arr[i] = residual[i * iteration_stride + id]; + sum += vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = residual[high_index]; + sum += vals_arr[iterations]; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + +#if !defined(__STOCHASTIC_MODE__) || __CUDA_ARCH__ < 700 + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } + + sum = g.shfl(sum, 0); + float mean = sum / row_stride; + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_arr[i] -= mean; + variance += vals_arr[i] * vals_arr[i]; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } + variance = g.shfl(variance, 0); + variance /= row_stride; + variance += epsilon; + if (training) + if (threadIdx.x == 0) vars[row] = variance; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr[i] = vals_arr[i] * rsqrtf(variance); + vals_arr[i] = + vals_arr[i] * gamma[i * iteration_stride + id] + beta[i * iteration_stride + id]; + vals[i * iteration_stride + id] = vals_arr[i]; + } + if ((high_index) < row_stride) { + vals_arr[iterations] = vals_arr[iterations] * rsqrtf(variance); + vals_arr[iterations] = vals_arr[iterations] * gamma[high_index] + beta[high_index]; + vals[high_index] = vals_arr[iterations]; + } +} + +__global__ void fused_bias_residual_layer_norm(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + bool preLayerNorm, + bool training, + __half* vars, + int row_stride) +{ +#if __CUDA_ARCH__ >= 700 + + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int gid = id >> 5; + + float2 vals_f[NORM_REG]; + __shared__ float shr[MAX_WARP_NUM]; + + __half2* vals_cast = reinterpret_cast<__half2*>(vals); + const __half2* residual_cast = reinterpret_cast(residual); + + residual_cast += (row * row_stride); + vals_cast += (row * row_stride); + + float sum = 0.f; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + vals_f[i] = __half22float2(residual_cast[i * iteration_stride + id]); + sum += vals_f[i].x; + sum += vals_f[i].y; + } + if ((high_index) < row_stride) { + vals_f[iterations] = __half22float2(residual_cast[high_index]); + sum += vals_f[iterations].x; + sum += vals_f[iterations].y; + iterations++; + } + + for (int i = 1; i < 32; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) shr[gid] = sum; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) sum = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { sum += g.shfl_down(sum, i); } + sum = g.shfl(sum, 0); + float mean = sum / (row_stride * 2); + + float variance = 0.f; + for (int i = 0; i < iterations; i++) { + vals_f[i].x -= mean; + vals_f[i].y -= mean; + variance += vals_f[i].x * vals_f[i].x; + variance += vals_f[i].y * vals_f[i].y; + } + + for (int i = 1; i < 32; i *= 2) { variance += g.shfl_down(variance, i); } + + if (g.thread_rank() == 0) shr[gid] = variance; + + b.sync(); + + if (g.thread_rank() < (iteration_stride >> 5)) variance = shr[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + for (int i = 1; i < (iteration_stride >> 5); i *= 2) { variance += g.shfl_down(variance, i); } + variance = g.shfl(variance, 0); + variance /= (row_stride * 2); + variance += epsilon; + + __half2 variance_h = __float2half2_rn(variance); + const __half2* gamma_cast = reinterpret_cast(gamma); + const __half2* beta_cast = reinterpret_cast(beta); + + if (training && threadIdx.x == 0) vars[row] = __float2half(variance); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + __half2 vals_arr = __float22half2_rn(vals_f[i]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = + vals_arr * gamma_cast[i * iteration_stride + id] + beta_cast[i * iteration_stride + id]; + vals_cast[i * iteration_stride + id] = vals_arr; + } + if ((high_index) < row_stride) { + __half2 vals_arr = __float22half2_rn(vals_f[iterations]); + vals_arr = vals_arr * h2rsqrt(variance_h); + vals_arr = vals_arr * gamma_cast[high_index] + beta_cast[high_index]; + vals_cast[high_index] = vals_arr; + } +#endif +} + +template +void launch_bias_residual_layer_norm(T* vals, + const T* residual, + const T* gamma, + const T* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + T* vars); + +/* +To tune this launch the following restrictions must be met: + +For float: +row_stride == hidden_size +threads * iterations == row_stride +threads is in [32, 64, 128, 256, 512, 1024] + +For half: +row_stride == hidden_size / 2 +threads * iterations == row_stride +threads is in [32, 64, 128, 256, 512, 1024] + +*/ + +template <> +void launch_bias_residual_layer_norm(float* vals, + const float* residual, + const float* gamma, + const float* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + float* vars) +{ + int threads = THREADS; + + dim3 grid_dim(batch_size); + + // There are some limitations to call below functions, now just enumerate the situations. + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim); +} + +template <> +void launch_bias_residual_layer_norm<__half>(__half* vals, + const __half* residual, + const __half* gamma, + const __half* beta, + float epsilon, + int batch_size, + int hidden_dim, + cudaStream_t stream, + bool preLayerNorm, + bool training, + __half* vars) +{ + int threads = 128; + + dim3 grid_dim(batch_size); + + // There are some limitations to call below functions, now just enumerate the situations. + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim(threads); + fused_bias_residual_layer_norm<<>>( + vals, residual, gamma, beta, epsilon, preLayerNorm, training, vars, hidden_dim / 2); +} + +/* Normalize Gamma & Betta gradients + * Compute gradients using either X_hat or + * normalize input (invertible). + * Combine transpose with gradients computation. + */ + +template +__global__ void LayerNormBackward1(const T* __restrict__ out_grad, + const T* __restrict__ vals_hat, + const T* __restrict__ gamma, + const T* __restrict__ betta, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width, + bool invertible) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + float betta_reg = (invertible ? (float)betta[idx] : 0.0f); + float gamma_reg = (float)gamma[idx]; + + // Loop across matrix height + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad[offset]; + float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg + : (float)vals_hat[offset]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +/* Normalize Gamma & Betta gradients + * Compute gradients using the input to + * the normalize. + * Combine transpose with gradients computation. + */ + +template +__global__ void LayerNormBackward1(const T* __restrict__ out_grad, + const T* __restrict__ X_data, + const T* __restrict__ vars, + const T* __restrict__ means, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + // Loop across matrix height + + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad[offset]; + float val = (float)X_data[offset]; + val = (val - (float)means[r]) * rsqrtf((float)vars[r]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} +/* + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is invertible! + * We do the backward using the X_hat (X - u) / sqrt(variance) or the output of Normalization. + */ + +__global__ void LayerNormBackward2(const float* out_grad, + const float* vals_hat, + const float* gamma, + const float* betta, + const float* vars, + float* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + out_grad += (row * row_stride); + vals_hat += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / + gamma_reg + : vals_hat[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg + : vals_hat[high_index]); + iterations++; + } + + float var_reg = vars[row]; + + float sum = 0; + for (int i = 0; i < iterations; i++) { + sum += vals_hat_arr[i] * vals_arr[i] * + sqrtf(var_reg); // dval_hat = gamma * (x - u) * out_grad + vals_arr[i] *= rsqrtf(var_reg); // dvar_inv = gamma * out_grad / sqrt(var) + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); + if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); +} + +__global__ void LayerNormBackward2(const __half* out_grad, + const __half* vals_hat, + const __half* gamma, + const __half* betta, + const __half* vars, + __half* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 vals_hat_arr[NORM_REG]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h = reinterpret_cast(out_grad); + const __half2* vals_hat_h = reinterpret_cast(vals_hat); + + inp_grad_h += (row * row_stride); + out_grad_h += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible + ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / + gamma_reg + : vals_hat_h[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg + : vals_hat_h[high_index]); + iterations++; + } + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 temp_f = __half22float2(temp); + vals_arr_f[i].x += temp_f.x; + vals_arr_f[i].y += temp_f.y; + } + sum = 0.f; + + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + + inp_grad_h[i * iteration_stride + id] = temp; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + + inp_grad_h[high_index] = temp; + } +} + +template <> +void launch_layerNorm_backward(const float* out_grad, + const float* vals_hat, + const float* vars, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const float* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<<>>( + out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + + LayerNormBackward2<<>>( + out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); +} + +template <> +void launch_layerNorm_backward<__half>(const __half* out_grad, + const __half* vals_hat, + const __half* vars, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const __half* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + + LayerNormBackward2<<>>( + out_grad, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); +} + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is not invertible! + * We do the backward using the input (X) + */ + +__global__ void LayerNormBackward2(const float* out_grad, + const float* X_vals, + const float* gamma, + const float* vars, + const float* means, + float* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + out_grad += (row * row_stride); + X_vals += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad[high_index]; + vals_arr[iterations] *= gamma_reg; + iterations++; + } + + float var_reg = vars[row]; + float mean_reg = means[row]; + + float sum = 0; + float xu[NORM_REG]; + for (int i = 0; i < iterations; i++) { + xu[i] = (X_vals[i * iteration_stride + id] - mean_reg); + sum += vals_arr[i] * xu[i]; + vals_arr[i] *= rsqrtf(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { + vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); + } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) inp_grad[i * iteration_stride + id] = (vals_arr[i] - sum); + if ((high_index) < row_stride) inp_grad[high_index] = (vals_arr[iterations] - sum); +} + +__global__ void LayerNormBackward2(const __half* out_grad, + const __half* X_vals, + const __half* gamma, + const __half* vars, + const __half* means, + __half* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; + + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h = reinterpret_cast(out_grad); + const __half2* vals_hat_h = reinterpret_cast(X_vals); + + inp_grad_h += (row * row_stride); + out_grad_h += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + iterations++; + } + __half mean_h = means[row]; + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + __half2 mean_reg = __halves2half2(mean_h, mean_h); + __half2 xu[NORM_REG]; + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + xu[i] = (vals_hat_h[i * iteration_stride + id] - mean_reg); + __half2 result_h = (xu[i] * vals_arr[i]); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 xu_grad_f = __half22float2(xu_grad); + vals_arr_f[i].x += xu_grad_f.x; + vals_arr_f[i].y += xu_grad_f.y; + } + + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + inp_grad_h[i * iteration_stride + id] = temp; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + inp_grad_h[high_index] = temp; + } +} + +template <> +void launch_layerNorm_backward(const float* out_grad, + const float* X_data, + const float* vars, + const float* means, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<<>>( + out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + LayerNormBackward2<<>>( + out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim); +} + +template <> +void launch_layerNorm_backward<__half>(const __half* out_grad, + const __half* X_data, + const __half* vars, + const __half* means, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + LayerNormBackward2<<>>( + out_grad, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); +} + +template +__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, + const T* __restrict__ out_grad2, + const T* __restrict__ vals_hat, + const T* __restrict__ gamma, + const T* __restrict__ betta, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width, + bool invertible) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + float betta_reg = (invertible ? (float)betta[idx] : 0.0f); + float gamma_reg = (float)gamma[idx]; + + // Loop across matrix height + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; + float val = (invertible ? ((float)vals_hat[offset] - betta_reg) / gamma_reg + : (float)vals_hat[offset]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +template +__global__ void LayerNormBackward1_fused_add(const T* __restrict__ out_grad1, + const T* __restrict__ out_grad2, + const T* __restrict__ X_data, + const T* __restrict__ vars, + const T* __restrict__ means, + T* __restrict__ gamma_grad, + T* __restrict__ betta_grad, + int rows, + int width) +{ + __shared__ float betta_buffer[TILE_DIM][TILE_DIM + 1]; + __shared__ float gamma_buffer[TILE_DIM][TILE_DIM + 1]; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int offset = threadIdx.y * width + idx; + int y_stride = width * TILE_DIM; + + int pos = blockIdx.x * TILE_DIM + threadIdx.y; + // Loop across matrix height + + float betta_tmp = 0; + float gamma_tmp = 0; + for (int r = threadIdx.y; r < rows; r += TILE_DIM) { + float grad = (float)out_grad1[offset] + (float)out_grad2[offset]; + float val = (float)X_data[offset]; + val = (val - (float)means[r]) * rsqrtf((float)vars[r]); + betta_tmp += grad; + gamma_tmp += (val * grad); + + offset += y_stride; + } + + betta_buffer[threadIdx.x][threadIdx.y] = betta_tmp; + gamma_buffer[threadIdx.x][threadIdx.y] = gamma_tmp; + + __syncthreads(); + + // Sum the shared buffer. + float s1 = betta_buffer[threadIdx.y][threadIdx.x]; + float s2 = gamma_buffer[threadIdx.y][threadIdx.x]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < TILE_DIM; i <<= 1) { + s1 += g.shfl_down(s1, i); + s2 += g.shfl_down(s2, i); + } + + if (threadIdx.x == 0) { + betta_grad[pos] = s1; + gamma_grad[pos] = s2; + } +} + +__global__ void LayerNormBackward2_fused_add(const float* out_grad1, + const float* out_grad2, + const float* vals_hat, + const float* gamma, + const float* betta, + const float* vars, + float* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + out_grad1 += (row * row_stride); + out_grad2 += (row * row_stride); + vals_hat += (row * row_stride); + inp_grad += (row * row_stride); + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = + (invertible ? (vals_hat[i * iteration_stride + id] - betta[i * iteration_stride + id]) / + gamma_reg + : vals_hat[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad1[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = + (invertible ? (vals_hat[high_index] - betta[high_index]) / gamma_reg + : vals_hat[high_index]); + iterations++; + } + + float var_reg = vars[row]; + + float sum = 0; + for (int i = 0; i < iterations; i++) { + sum += vals_hat_arr[i] * vals_arr[i] * sqrtf(var_reg); + vals_arr[i] *= rsqrtf(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { vals_arr[i] += ((-sum * vals_hat_arr[i]) / var_reg); } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + inp_grad[i * iteration_stride + id] = + (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; + if ((high_index) < row_stride) + inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; +} + +__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, + const __half* out_grad2, + const __half* vals_hat, + const __half* gamma, + const __half* betta, + const __half* vars, + __half* inp_grad, + bool invertible, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 vals_hat_arr[NORM_REG]; + + // float2 result[iterations]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h1 = reinterpret_cast(out_grad1); + const __half2* out_grad_h2 = reinterpret_cast(out_grad2); + const __half2* vals_hat_h = reinterpret_cast(vals_hat); + + inp_grad_h += (row * row_stride); + out_grad_h1 += (row * row_stride); + out_grad_h2 += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + const __half2* betta_h = (invertible ? reinterpret_cast(betta) : nullptr); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + vals_hat_arr[i] = + (invertible + ? (vals_hat_h[i * iteration_stride + id] - betta_h[i * iteration_stride + id]) / + gamma_reg + : vals_hat_h[i * iteration_stride + id]); + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h1[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + vals_hat_arr[iterations] = + (invertible ? (vals_hat_h[high_index] - betta_h[high_index]) / gamma_reg + : vals_hat_h[high_index]); + iterations++; + } + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + __half2 result_h = (vals_hat_arr[i] * vals_arr[i] * h2sqrt(var_reg)); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 temp = ((-sum_h * vals_hat_arr[i]) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 temp_f = __half22float2(temp); + vals_arr_f[i].x += temp_f.x; + vals_arr_f[i].y += temp_f.y; + } + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + + inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + + inp_grad_h[high_index] = temp + out_grad_h2[high_index]; + } +} + +template <> +void launch_layerNorm_backward_fused_add(const float* out_grad1, + const float* out_grad2, + const float* vals_hat, + const float* vars, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const float* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + LayerNormBackward1<<>>( + out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim); +} + +template <> +void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, + const __half* out_grad2, + const __half* vals_hat, + const __half* vars, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2], + bool invertible, + const __half* betta) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad1, vals_hat, gamma, betta, gamma_grad, betta_grad, batch, hidden_dim, invertible); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, vals_hat, gamma, betta, vars, inp_grad, invertible, hidden_dim / 2); +} + +/* Backward Normalize (Input-Gradient) + * Using the means and variances from the input + * This type of backward is not invertible! + * We do the backward using the input (X) + */ + +__global__ void LayerNormBackward2_fused_add(const float* out_grad1, + const float* out_grad2, + const float* X_vals, + const float* gamma, + const float* vars, + const float* means, + float* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (THREADS < row_stride ? THREADS : row_stride) / WARP_SIZE; + __shared__ float partialSum[MAX_WARP_NUM]; + + float vals_arr[NORM_REG]; + float vals_hat_arr[NORM_REG]; + + out_grad1 += (row * row_stride); + out_grad2 += (row * row_stride); + X_vals += (row * row_stride); + inp_grad += (row * row_stride); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + float gamma_reg = gamma[i * iteration_stride + id]; + vals_arr[i] = out_grad1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; + vals_hat_arr[i] = X_vals[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + float gamma_reg = gamma[high_index]; + vals_arr[iterations] = out_grad1[high_index]; + vals_arr[iterations] *= gamma_reg; + vals_hat_arr[iterations] = X_vals[high_index]; + iterations++; + } + + float var_reg = vars[row]; + float mean_reg = means[row]; + + float sum = 0; + float xu[NORM_REG]; + for (int i = 0; i < iterations; i++) { + xu[i] = (vals_hat_arr[i] - mean_reg); + sum += vals_arr[i] * xu[i]; + vals_arr[i] *= rsqrtf(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= row_stride; + + for (int i = 0; i < iterations; i++) { + vals_arr[i] += (-sum * xu[i] * rsqrtf(var_reg) / (var_reg)); + } + + sum = 0; + for (int i = 0; i < iterations; i++) { sum += vals_arr[i]; } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + sum = g.shfl(sum, 0); + sum /= row_stride; + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) + inp_grad[i * iteration_stride + id] = + (vals_arr[i] - sum) + out_grad2[i * iteration_stride + id]; + if ((high_index) < row_stride) + inp_grad[high_index] = (vals_arr[iterations] - sum) + out_grad2[high_index]; +} + +__global__ void LayerNormBackward2_fused_add(const __half* out_grad1, + const __half* out_grad2, + const __half* X_vals, + const __half* gamma, + const __half* vars, + const __half* means, + __half* inp_grad, + int row_stride) +{ + int iteration_stride = blockDim.x; + int iterations = row_stride / iteration_stride; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + int wid = id / WARP_SIZE; + int warp_num = (iteration_stride < row_stride ? iteration_stride : row_stride) / WARP_SIZE; + + __shared__ float partialSum[MAX_WARP_NUM]; + + __half2 vals_arr[NORM_REG]; + float2 vals_arr_f[NORM_REG]; + __half2 vals_hat_arr[NORM_REG]; + + __half2* inp_grad_h = reinterpret_cast<__half2*>(inp_grad); + const __half2* out_grad_h1 = reinterpret_cast(out_grad1); + const __half2* out_grad_h2 = reinterpret_cast(out_grad2); + const __half2* vals_hat_h = reinterpret_cast(X_vals); + + out_grad_h1 += (row * row_stride); + out_grad_h2 += (row * row_stride); + inp_grad_h += (row * row_stride); + vals_hat_h += (row * row_stride); + + const __half2* gamma_h = reinterpret_cast(gamma); + int high_index = iterations * iteration_stride + id; +#pragma unroll + for (int i = 0; i < iterations; i++) { + __half2 gamma_reg = gamma_h[i * iteration_stride + id]; + vals_arr[i] = out_grad_h1[i * iteration_stride + id]; + vals_arr[i] *= gamma_reg; // out_grad * gamma + vals_hat_arr[i] = vals_hat_h[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + __half2 gamma_reg = gamma_h[high_index]; + vals_arr[iterations] = out_grad_h1[high_index]; + vals_arr[iterations] *= gamma_reg; // out_grad * gamma + vals_hat_arr[iterations] = vals_hat_h[high_index]; + iterations++; + } + + __half mean_h = means[row]; + __half var_h = vars[row]; + __half2 var_reg = __halves2half2(var_h, var_h); + __half2 mean_reg = __halves2half2(mean_h, mean_h); + __half2 xu[NORM_REG]; + + float sum = 0.f; + for (int i = 0; i < iterations; i++) { + xu[i] = (vals_hat_arr[i] - mean_reg); + __half2 result_h = (xu[i] * vals_arr[i]); + float2 result_f = __half22float2(result_h); + sum += result_f.x; + sum += result_f.y; + vals_arr[i] *= h2rsqrt(var_reg); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + __half2 sum_h = __float2half2_rn(sum); + + for (int i = 0; i < iterations; i++) { + __half2 xu_grad = ((-sum_h * xu[i] * h2rsqrt(var_reg)) / (var_reg)); + vals_arr_f[i] = __half22float2(vals_arr[i]); + float2 xu_grad_f = __half22float2(xu_grad); + vals_arr_f[i].x += xu_grad_f.x; + vals_arr_f[i].y += xu_grad_f.y; + } + + sum = 0.f; + for (int i = 0; i < iterations; i++) { + sum += (vals_arr_f[i].x); + sum += (vals_arr_f[i].y); + } + + for (int i = 1; i < WARP_SIZE; i *= 2) { sum += g.shfl_down(sum, i); } + + if (g.thread_rank() == 0) partialSum[wid] = sum; + + __syncthreads(); + + if (g.thread_rank() < warp_num) sum = partialSum[g.thread_rank()]; + +#ifndef __STOCHASTIC_MODE__ + __syncthreads(); +#endif + + for (int i = 1; i < warp_num; i *= 2) sum += g.shfl_down(sum, i); + + sum = g.shfl(sum, 0); + sum /= (2 * row_stride); + + iterations = row_stride / iteration_stride; + for (int i = 0; i < iterations; i++) { + vals_arr_f[i].x -= sum; + vals_arr_f[i].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[i]); + inp_grad_h[i * iteration_stride + id] = temp + out_grad_h2[i * iteration_stride + id]; + } + if ((high_index) < row_stride) { + vals_arr_f[iterations].x -= sum; + vals_arr_f[iterations].y -= sum; + __half2 temp = __float22half2_rn(vals_arr_f[iterations]); + inp_grad_h[high_index] = temp + out_grad_h2[high_index]; + } +} + +template <> +void launch_layerNorm_backward_fused_add(const float* out_grad1, + const float* out_grad2, + const float* X_data, + const float* vars, + const float* means, + const float* gamma, + float* gamma_grad, + float* betta_grad, + float* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<<>>( + out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 1; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 2; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim); +} + +template <> +void launch_layerNorm_backward_fused_add<__half>(const __half* out_grad1, + const __half* out_grad2, + const __half* X_data, + const __half* vars, + const __half* means, + const __half* gamma, + __half* gamma_grad, + __half* betta_grad, + __half* inp_grad, + int batch, + int hidden_dim, + cudaStream_t stream[2]) +{ + int threads = THREADS; + + dim3 grid_dim(hidden_dim / TILE_DIM); + dim3 block_dim(TILE_DIM, TILE_DIM); + + LayerNormBackward1<__half><<>>( + out_grad1, X_data, vars, means, gamma_grad, betta_grad, batch, hidden_dim); + + dim3 grid_dim2(batch); + + if (hidden_dim > 8192 && hidden_dim <= 16384) + threads <<= 1; + else if (hidden_dim > 16384 && hidden_dim <= 32768) + threads <<= 2; + else if (hidden_dim > 32768 && hidden_dim <= 65536) + threads <<= 3; + else if (hidden_dim > 65536) + throw std::runtime_error("Unsupport hidden_dim."); + + dim3 block_dim2(threads / 2); + LayerNormBackward2_fused_add<<>>( + out_grad1, out_grad2, X_data, gamma, vars, means, inp_grad, hidden_dim / 2); +} diff --git a/csrc/transformer/softmax_kernels.cu b/csrc/transformer/softmax_kernels.cu index 4c0d987e617a..d0f3dca5c3ca 100644 --- a/csrc/transformer/softmax_kernels.cu +++ b/csrc/transformer/softmax_kernels.cu @@ -1,595 +1,595 @@ -#include -#include "custom_cuda_layers.h" -#include "general_kernels.h" - -namespace cg = cooperative_groups; - -dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads) -{ - int seq_length4 = sequence_length / 4; - int block_compute_size = - (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); - // Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited: - // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications - // The batch size is typically relatively small, while the sequence length could potentially be - // arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit. - unsigned x = heads * sequence_length / block_compute_size; - unsigned y = batch_size; - return {x, y}; -} - -// Fused attention + softmax -template -__global__ void attn_softmax(float* vals, - const float* attn_mask, - int heads, - int seq_length, - int iterations) -{ - __shared__ float partialSum[MAX_WARP_NUM]; - - int warp_num = blockDim.x >> 5; - - int iteration_stride = blockDim.x; - int block_width = blockStride * seq_length; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int batch = blockIdx.y; - int row = blockIdx.x; - int max_threads_in_sequence = std::max(seq_length, tbSeq); - int seq_lane = threadIdx.x % max_threads_in_sequence; - - int data_offset = batch * (gridDim.x * block_width) + row * block_width + - (threadIdx.x / max_threads_in_sequence) * seq_length; - int mask_offset = batch * seq_length; - - int wid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - - float4* val_cast = reinterpret_cast(vals); - const float4* attn_mask_cast = reinterpret_cast(attn_mask); - - float4 data[MAX_THREAD_ITERATIONS]; - - float max_val = minus_infinity; - - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) { - float4 mask = attn_mask_cast[mask_offset + data_id]; - data[i] = val_cast[data_offset + data_id]; - - data[i].x += mask.x; - data[i].y += mask.y; - data[i].z += mask.z; - data[i].w += mask.w; - - max_val = (data[i].x > max_val ? data[i].x : max_val); - max_val = (data[i].y > max_val ? data[i].y : max_val); - max_val = (data[i].z > max_val ? data[i].z : max_val); - max_val = (data[i].w > max_val ? data[i].w : max_val); - } else { - data[i].x = minus_infinity; - data[i].y = minus_infinity; - data[i].z = minus_infinity; - data[i].w = minus_infinity; - } - } - - for (int i = 1; i < tbSize; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = max_val; - b.sync(); - - if (lane < warp_num) max_val = partialSum[lane]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - int iters = warp_num; - if (seq_length < iteration_stride) - iters = warp_num / (iteration_stride / max_threads_in_sequence); - - for (int i = 1; i < iters; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - max_val = g.shfl(max_val, threadIdx.x / tbSize); - } - - float sum = 0; - for (int i = 0; i < iterations; i++) { - data[i].x = __expf(data[i].x - max_val); - data[i].y = __expf(data[i].y - max_val); - data[i].z = __expf(data[i].z - max_val); - data[i].w = __expf(data[i].w - max_val); - - sum += (data[i].x + data[i].y + data[i].z + data[i].w); - } - - for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = sum; - b.sync(); - - if (lane < warp_num) sum = partialSum[lane]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - int iters = warp_num; - if (seq_length < iteration_stride) - iters = warp_num / (iteration_stride / max_threads_in_sequence); - - for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } - - sum = g.shfl(sum, threadIdx.x / tbSize); - } - - sum += 1e-6; - - for (int i = 0; i < iterations; i++) { - data[i].x /= sum; - data[i].y /= sum; - data[i].z /= sum; - data[i].w /= sum; - - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) val_cast[data_offset + data_id] = data[i]; - } -} - -template -__global__ void attn_softmax(__half* vals, - const __half* attn_mask, - int heads, - int seq_length, - int iterations) -{ -#if __CUDA_ARCH__ >= 700 - __shared__ float partialSum[MAX_WARP_NUM]; - - int warp_num = blockDim.x >> 5; - - int iteration_stride = blockDim.x; - int block_width = blockStride * seq_length; - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int batch = blockIdx.y; - int row = blockIdx.x; - int max_threads_in_sequence = std::max(seq_length, tbSeq); - int seq_lane = threadIdx.x % max_threads_in_sequence; - - int data_offset = batch * (gridDim.x * block_width) + row * block_width + - (threadIdx.x / max_threads_in_sequence) * seq_length; - int mask_offset = batch * seq_length; - - int wid = threadIdx.x >> 5; - int lane = threadIdx.x & 0x1f; - - float2* val_cast = reinterpret_cast(vals); - const float2* attn_mask_cast = reinterpret_cast(attn_mask); - - val_cast += data_offset; - attn_mask_cast += mask_offset; - - float2 low_data[MAX_THREAD_ITERATIONS]; - float2 high_data[MAX_THREAD_ITERATIONS]; - - float max_val = minus_infinity; - - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) { - float2 data = val_cast[data_id]; - float2 mask = attn_mask_cast[data_id]; - - __half2* data_arr = reinterpret_cast<__half2*>(&data); - __half2* mask_arr = reinterpret_cast<__half2*>(&mask); - - low_data[i] = __half22float2(data_arr[0]); - high_data[i] = __half22float2(data_arr[1]); - float2 low_mask = __half22float2(mask_arr[0]); - float2 high_mask = __half22float2(mask_arr[1]); - - low_data[i].x += low_mask.x; - low_data[i].y += low_mask.y; - high_data[i].x += high_mask.x; - high_data[i].y += high_mask.y; - - max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); - max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); - max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); - max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); - } - } - - for (int i = 1; i < tbSize; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = max_val; - b.sync(); - - if (lane < warp_num) max_val = partialSum[lane]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - int iters = warp_num; - if (seq_length < iteration_stride) - iters = warp_num / (iteration_stride / max_threads_in_sequence); - - for (int i = 1; i < iters; i *= 2) { - auto temp = g.shfl_xor(max_val, i); - max_val = (temp > max_val ? temp : max_val); - } - - max_val = g.shfl(max_val, threadIdx.x / tbSize); - } - - float sum = 0; - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) { - low_data[i].x = __expf(low_data[i].x - max_val); - low_data[i].y = __expf(low_data[i].y - max_val); - high_data[i].x = __expf(high_data[i].x - max_val); - high_data[i].y = __expf(high_data[i].y - max_val); - - sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); - } - } - - for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = sum; - b.sync(); - - if (lane < warp_num) sum = partialSum[lane]; - -#ifndef __STOCHASTIC_MODE__ - b.sync(); -#endif - - int iters = warp_num; - if (seq_length < iteration_stride) - iters = warp_num / (iteration_stride / max_threads_in_sequence); - - for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } - - sum = g.shfl(sum, threadIdx.x / tbSize); - } - - sum += 1e-6; - - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + seq_lane; - if (data_id < seq_length) { - float2 result_f; - __half2* result_h = reinterpret_cast<__half2*>(&result_f); - - low_data[i].x /= sum; - low_data[i].y /= sum; - high_data[i].x /= sum; - high_data[i].y /= sum; - - result_h[0] = __float22half2_rn(low_data[i]); - result_h[1] = __float22half2_rn(high_data[i]); - - val_cast[data_id] = result_f; - } - } - -#endif -} - -template -void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t); - -template <> -void launch_attn_softmax(float* vals, - const float* attn_mask, - int batch_size, - int heads, - int sequence_length, - cudaStream_t stream) -{ - const int threads = 128; - int seq_length4 = sequence_length / 4; - - dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); - - int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; - - dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / - subblock_max_workload * threads) - : threads); - int iterations = - (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads - : MAX_THREAD_ITERATIONS); - - if (sequence_length <= 8) - attn_softmax<2, (threads / 2), 2> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 16) - attn_softmax<4, (threads / 4), 4> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 32) - attn_softmax<8, (threads / 8), 8> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 64) - attn_softmax<16, (threads / 16), 16> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 128) - attn_softmax<32, (threads / 32), 32> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 256) - attn_softmax<32, (threads / 64), 64> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else { - const int threads = 256; - dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); - - int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; - - dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / - subblock_max_workload * threads) - : threads); - iterations = - (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads - : MAX_THREAD_ITERATIONS); - if (sequence_length <= 512) - attn_softmax<32, (threads / 128), 128><<>>( - vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) - attn_softmax<32, 1, 128><<>>( - vals, attn_mask, heads, seq_length4, iterations); - else - throw std::runtime_error( - "Unsupport Seq_Length! Check the restriction of the max_threads and " - "max_thread_iterations!"); - } -} - -template <> -void launch_attn_softmax<__half>(__half* vals, - const __half* attn_mask, - int batch_size, - int heads, - int sequence_length, - cudaStream_t stream) -{ - const int threads = 128; - int seq_length4 = sequence_length / 4; - - dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); - - int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; - - dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / - subblock_max_workload * threads) - : threads); - - int iterations = - (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads - : MAX_THREAD_ITERATIONS); - - if (sequence_length <= 8) - attn_softmax<2, (threads / 2), 2> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 16) - attn_softmax<4, (threads / 4), 4> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 32) - attn_softmax<8, (threads / 8), 8> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 64) - attn_softmax<16, (threads / 16), 16> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 128) - attn_softmax<32, (threads / 32), 32> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length <= 256) - attn_softmax<32, (threads / 64), 64> - <<>>(vals, attn_mask, heads, seq_length4, iterations); - else { - const int threads = 256; - dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); - - int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; - - dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / - subblock_max_workload * threads) - : threads); - iterations = - (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads - : MAX_THREAD_ITERATIONS); - if (sequence_length <= 512) - attn_softmax<32, (threads / 128), 128><<>>( - vals, attn_mask, heads, seq_length4, iterations); - else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) - attn_softmax<32, 1, 128><<>>( - vals, attn_mask, heads, seq_length4, iterations); - else - throw std::runtime_error( - "Unsupport Seq_Length! Check the restriction of the max_threads and " - "max_thread_iterations!"); - } -} - -template -__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length) -{ - __shared__ float partialSum[MAX_WARP_NUM]; - - int warp_num = blockDim.x >> 5; // warp-count = num_threads / WARP_SIZE (32) - - int iteration_stride = blockDim.x; - int block_width = blockStride * seq_length; - - int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride) - ? (seq_length + iteration_stride - 1) / iteration_stride - : MAX_THREAD_ITERATIONS); - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - int row = blockIdx.x; - int id = threadIdx.x; - - int wid = id >> 5; - int lane = id & 0x1f; - - T val_reg[MAX_THREAD_ITERATIONS]; - T soft_reg[MAX_THREAD_ITERATIONS]; - float grad_reg = 0.0f; - -#pragma unroll - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + id; - if (data_id < block_width) { - val_reg[i] = out_grad[row * block_width + data_id]; - soft_reg[i] = soft_inp[row * block_width + data_id]; - - grad_reg += ((float)val_reg[i] * - (float)soft_reg[i]); // if done in half, the multiplication, we may lose - // 2% of accuracy in computation!! - } - } - for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); - - if (seq_length > tbSize) { - if (lane == 0) partialSum[wid] = grad_reg; - b.sync(); - - if (lane < warp_num) grad_reg = partialSum[lane]; - - int iters = warp_num; - if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); - - for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); - - grad_reg = g.shfl(grad_reg, id / tbSize); - } - - for (int i = 0; i < iterations; i++) { - int data_id = i * iteration_stride + id; - if (data_id < block_width) { - float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg); - out_grad[row * block_width + data_id] = (T)temp; - } - } -} - -template -__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/, - const T* output, - int softmax_length) -{ - int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; - int offset = batch_idx * softmax_length + threadIdx.x; - - grad += offset; - output += offset; - - T grad_reg[ITERATIONS]; - T output_reg[ITERATIONS]; - float sum = 0.0; - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) { - grad_reg[i] = grad[i * WARP_SIZE]; - output_reg[i] = output[i * WARP_SIZE]; - sum += (float)grad_reg[i] * (float)output_reg[i]; - } - } - - cg::thread_block b = cg::this_thread_block(); - cg::thread_block_tile g = cg::tiled_partition(b); - - for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); - -#pragma unroll - for (int i = 0; i < ITERATIONS; ++i) { - int curr_idx = threadIdx.x + i * WARP_SIZE; - if (curr_idx < softmax_length) - grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum); - } -} - -template -void launch_attn_softmax_backward_v2(T* out_grad, - const T* soft_inp, - int batch_size, - int heads, - int seq_length, - cudaStream_t stream) -{ - const int warps_per_block = 4; - dim3 grid_dim(batch_size * heads * seq_length / warps_per_block); - dim3 block_dim(WARP_SIZE, warps_per_block); - - if (seq_length <= 32) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 64) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 128) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 256) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 384) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 512) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 768) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 1024) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else if (seq_length <= 2048) - softmax_backward_kernel_v2 - <<>>(out_grad, soft_inp, seq_length); - else - throw std::runtime_error( - std::string("Special sequence length found in softmax backward, seq_length: ") + - std::to_string(seq_length)); -} - -template void launch_attn_softmax_backward_v2<__half>(__half* out_grad, - const __half* soft_inp, - int batch_size, - int heads, - int seq_length, - cudaStream_t stream); -template void launch_attn_softmax_backward_v2(float* out_grad, - const float* soft_inp, - int batch_size, - int heads, - int seq_length, - cudaStream_t stream); +#include +#include "custom_cuda_layers.h" +#include "general_kernels.h" + +namespace cg = cooperative_groups; + +dim3 get_attn_softmax_grid(int batch_size, int heads, int sequence_length, int threads) +{ + int seq_length4 = sequence_length / 4; + int block_compute_size = + (seq_length4 < threads ? (int)pow(2.0, floor(log2((float)(threads / seq_length4)))) : 1); + // Note that the Y and Z dimensions are limited to 65535, while X is basically unlimited: + // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications + // The batch size is typically relatively small, while the sequence length could potentially be + // arbitrarily large. We therefore place the batch size second to avoid hitting the Y limit. + unsigned x = heads * sequence_length / block_compute_size; + unsigned y = batch_size; + return {x, y}; +} + +// Fused attention + softmax +template +__global__ void attn_softmax(float* vals, + const float* attn_mask, + int heads, + int seq_length, + int iterations) +{ + __shared__ float partialSum[MAX_WARP_NUM]; + + int warp_num = blockDim.x >> 5; + + int iteration_stride = blockDim.x; + int block_width = blockStride * seq_length; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int batch = blockIdx.y; + int row = blockIdx.x; + int max_threads_in_sequence = std::max(seq_length, tbSeq); + int seq_lane = threadIdx.x % max_threads_in_sequence; + + int data_offset = batch * (gridDim.x * block_width) + row * block_width + + (threadIdx.x / max_threads_in_sequence) * seq_length; + int mask_offset = batch * seq_length; + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + + float4* val_cast = reinterpret_cast(vals); + const float4* attn_mask_cast = reinterpret_cast(attn_mask); + + float4 data[MAX_THREAD_ITERATIONS]; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float4 mask = attn_mask_cast[mask_offset + data_id]; + data[i] = val_cast[data_offset + data_id]; + + data[i].x += mask.x; + data[i].y += mask.y; + data[i].z += mask.z; + data[i].w += mask.w; + + max_val = (data[i].x > max_val ? data[i].x : max_val); + max_val = (data[i].y > max_val ? data[i].y : max_val); + max_val = (data[i].z > max_val ? data[i].z : max_val); + max_val = (data[i].w > max_val ? data[i].w : max_val); + } else { + data[i].x = minus_infinity; + data[i].y = minus_infinity; + data[i].z = minus_infinity; + data[i].w = minus_infinity; + } + } + + for (int i = 1; i < tbSize; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / tbSize); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + data[i].x = __expf(data[i].x - max_val); + data[i].y = __expf(data[i].y - max_val); + data[i].z = __expf(data[i].z - max_val); + data[i].w = __expf(data[i].w - max_val); + + sum += (data[i].x + data[i].y + data[i].z + data[i].w); + } + + for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / tbSize); + } + + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + data[i].x /= sum; + data[i].y /= sum; + data[i].z /= sum; + data[i].w /= sum; + + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) val_cast[data_offset + data_id] = data[i]; + } +} + +template +__global__ void attn_softmax(__half* vals, + const __half* attn_mask, + int heads, + int seq_length, + int iterations) +{ +#if __CUDA_ARCH__ >= 700 + __shared__ float partialSum[MAX_WARP_NUM]; + + int warp_num = blockDim.x >> 5; + + int iteration_stride = blockDim.x; + int block_width = blockStride * seq_length; + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int batch = blockIdx.y; + int row = blockIdx.x; + int max_threads_in_sequence = std::max(seq_length, tbSeq); + int seq_lane = threadIdx.x % max_threads_in_sequence; + + int data_offset = batch * (gridDim.x * block_width) + row * block_width + + (threadIdx.x / max_threads_in_sequence) * seq_length; + int mask_offset = batch * seq_length; + + int wid = threadIdx.x >> 5; + int lane = threadIdx.x & 0x1f; + + float2* val_cast = reinterpret_cast(vals); + const float2* attn_mask_cast = reinterpret_cast(attn_mask); + + val_cast += data_offset; + attn_mask_cast += mask_offset; + + float2 low_data[MAX_THREAD_ITERATIONS]; + float2 high_data[MAX_THREAD_ITERATIONS]; + + float max_val = minus_infinity; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float2 data = val_cast[data_id]; + float2 mask = attn_mask_cast[data_id]; + + __half2* data_arr = reinterpret_cast<__half2*>(&data); + __half2* mask_arr = reinterpret_cast<__half2*>(&mask); + + low_data[i] = __half22float2(data_arr[0]); + high_data[i] = __half22float2(data_arr[1]); + float2 low_mask = __half22float2(mask_arr[0]); + float2 high_mask = __half22float2(mask_arr[1]); + + low_data[i].x += low_mask.x; + low_data[i].y += low_mask.y; + high_data[i].x += high_mask.x; + high_data[i].y += high_mask.y; + + max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); + max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); + max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); + max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); + } + } + + for (int i = 1; i < tbSize; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = max_val; + b.sync(); + + if (lane < warp_num) max_val = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { + auto temp = g.shfl_xor(max_val, i); + max_val = (temp > max_val ? temp : max_val); + } + + max_val = g.shfl(max_val, threadIdx.x / tbSize); + } + + float sum = 0; + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + low_data[i].x = __expf(low_data[i].x - max_val); + low_data[i].y = __expf(low_data[i].y - max_val); + high_data[i].x = __expf(high_data[i].x - max_val); + high_data[i].y = __expf(high_data[i].y - max_val); + + sum += (low_data[i].x + low_data[i].y + high_data[i].x + high_data[i].y); + } + } + + for (int i = 1; i < tbSize; i *= 2) { sum += g.shfl_xor(sum, i); } + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = sum; + b.sync(); + + if (lane < warp_num) sum = partialSum[lane]; + +#ifndef __STOCHASTIC_MODE__ + b.sync(); +#endif + + int iters = warp_num; + if (seq_length < iteration_stride) + iters = warp_num / (iteration_stride / max_threads_in_sequence); + + for (int i = 1; i < iters; i *= 2) { sum += g.shfl_xor(sum, i); } + + sum = g.shfl(sum, threadIdx.x / tbSize); + } + + sum += 1e-6; + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + seq_lane; + if (data_id < seq_length) { + float2 result_f; + __half2* result_h = reinterpret_cast<__half2*>(&result_f); + + low_data[i].x /= sum; + low_data[i].y /= sum; + high_data[i].x /= sum; + high_data[i].y /= sum; + + result_h[0] = __float22half2_rn(low_data[i]); + result_h[1] = __float22half2_rn(high_data[i]); + + val_cast[data_id] = result_f; + } + } + +#endif +} + +template +void launch_attn_softmax(T*, const T*, int, int, int, cudaStream_t); + +template <> +void launch_attn_softmax(float* vals, + const float* attn_mask, + int batch_size, + int heads, + int sequence_length, + cudaStream_t stream) +{ + const int threads = 128; + int seq_length4 = sequence_length / 4; + + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + int iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + + if (sequence_length <= 8) + attn_softmax<2, (threads / 2), 2> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 16) + attn_softmax<4, (threads / 4), 4> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 32) + attn_softmax<8, (threads / 8), 8> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 64) + attn_softmax<16, (threads / 16), 16> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 128) + attn_softmax<32, (threads / 32), 32> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 256) + attn_softmax<32, (threads / 64), 64> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else { + const int threads = 256; + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + if (sequence_length <= 512) + attn_softmax<32, (threads / 128), 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) + attn_softmax<32, 1, 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else + throw std::runtime_error( + "Unsupport Seq_Length! Check the restriction of the max_threads and " + "max_thread_iterations!"); + } +} + +template <> +void launch_attn_softmax<__half>(__half* vals, + const __half* attn_mask, + int batch_size, + int heads, + int sequence_length, + cudaStream_t stream) +{ + const int threads = 128; + int seq_length4 = sequence_length / 4; + + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + + int iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + + if (sequence_length <= 8) + attn_softmax<2, (threads / 2), 2> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 16) + attn_softmax<4, (threads / 4), 4> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 32) + attn_softmax<8, (threads / 8), 8> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 64) + attn_softmax<16, (threads / 16), 16> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 128) + attn_softmax<32, (threads / 32), 32> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length <= 256) + attn_softmax<32, (threads / 64), 64> + <<>>(vals, attn_mask, heads, seq_length4, iterations); + else { + const int threads = 256; + dim3 grid_dim = get_attn_softmax_grid(batch_size, heads, sequence_length, threads); + + int subblock_max_workload = MAX_THREAD_ITERATIONS * 4 * threads; + + dim3 block_dim(seq_length4 > threads ? ((sequence_length + subblock_max_workload - 1) / + subblock_max_workload * threads) + : threads); + iterations = + (sequence_length < subblock_max_workload ? (seq_length4 + threads - 1) / threads + : MAX_THREAD_ITERATIONS); + if (sequence_length <= 512) + attn_softmax<32, (threads / 128), 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else if (sequence_length < (MAX_THREADS * MAX_THREAD_ITERATIONS * 4)) + attn_softmax<32, 1, 128><<>>( + vals, attn_mask, heads, seq_length4, iterations); + else + throw std::runtime_error( + "Unsupport Seq_Length! Check the restriction of the max_threads and " + "max_thread_iterations!"); + } +} + +template +__global__ void softmax_backward_kernel(T* out_grad, const T* soft_inp, int seq_length) +{ + __shared__ float partialSum[MAX_WARP_NUM]; + + int warp_num = blockDim.x >> 5; // warp-count = num_threads / WARP_SIZE (32) + + int iteration_stride = blockDim.x; + int block_width = blockStride * seq_length; + + int iterations = (seq_length < (MAX_THREAD_ITERATIONS * iteration_stride) + ? (seq_length + iteration_stride - 1) / iteration_stride + : MAX_THREAD_ITERATIONS); + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + int row = blockIdx.x; + int id = threadIdx.x; + + int wid = id >> 5; + int lane = id & 0x1f; + + T val_reg[MAX_THREAD_ITERATIONS]; + T soft_reg[MAX_THREAD_ITERATIONS]; + float grad_reg = 0.0f; + +#pragma unroll + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + id; + if (data_id < block_width) { + val_reg[i] = out_grad[row * block_width + data_id]; + soft_reg[i] = soft_inp[row * block_width + data_id]; + + grad_reg += ((float)val_reg[i] * + (float)soft_reg[i]); // if done in half, the multiplication, we may lose + // 2% of accuracy in computation!! + } + } + for (int i = 1; i < tbSize; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); + + if (seq_length > tbSize) { + if (lane == 0) partialSum[wid] = grad_reg; + b.sync(); + + if (lane < warp_num) grad_reg = partialSum[lane]; + + int iters = warp_num; + if (seq_length < iteration_stride) iters = warp_num / (iteration_stride / seq_length); + + for (int i = 1; i < iters; i *= 2) grad_reg += g.shfl_xor(grad_reg, i); + + grad_reg = g.shfl(grad_reg, id / tbSize); + } + + for (int i = 0; i < iterations; i++) { + int data_id = i * iteration_stride + id; + if (data_id < block_width) { + float temp = (float)soft_reg[i] * ((float)val_reg[i] - grad_reg); + out_grad[row * block_width + data_id] = (T)temp; + } + } +} + +template +__global__ void softmax_backward_kernel_v2(T* grad /* input & output*/, + const T* output, + int softmax_length) +{ + int batch_idx = blockIdx.x * blockDim.y + threadIdx.y; + int offset = batch_idx * softmax_length + threadIdx.x; + + grad += offset; + output += offset; + + T grad_reg[ITERATIONS]; + T output_reg[ITERATIONS]; + float sum = 0.0; + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) { + grad_reg[i] = grad[i * WARP_SIZE]; + output_reg[i] = output[i * WARP_SIZE]; + sum += (float)grad_reg[i] * (float)output_reg[i]; + } + } + + cg::thread_block b = cg::this_thread_block(); + cg::thread_block_tile g = cg::tiled_partition(b); + + for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); + +#pragma unroll + for (int i = 0; i < ITERATIONS; ++i) { + int curr_idx = threadIdx.x + i * WARP_SIZE; + if (curr_idx < softmax_length) + grad[i * WARP_SIZE] = (float)output_reg[i] * ((float)grad_reg[i] - sum); + } +} + +template +void launch_attn_softmax_backward_v2(T* out_grad, + const T* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream) +{ + const int warps_per_block = 4; + dim3 grid_dim(batch_size * heads * seq_length / warps_per_block); + dim3 block_dim(WARP_SIZE, warps_per_block); + + if (seq_length <= 32) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 64) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 128) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 256) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 384) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 512) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 768) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 1024) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else if (seq_length <= 2048) + softmax_backward_kernel_v2 + <<>>(out_grad, soft_inp, seq_length); + else + throw std::runtime_error( + std::string("Special sequence length found in softmax backward, seq_length: ") + + std::to_string(seq_length)); +} + +template void launch_attn_softmax_backward_v2<__half>(__half* out_grad, + const __half* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream); +template void launch_attn_softmax_backward_v2(float* out_grad, + const float* soft_inp, + int batch_size, + int heads, + int seq_length, + cudaStream_t stream); diff --git a/csrc/transformer/transform_kernels.cu b/csrc/transformer/transform_kernels.cu index 7d8a27eeeb43..b7924c6f216f 100755 --- a/csrc/transformer/transform_kernels.cu +++ b/csrc/transformer/transform_kernels.cu @@ -1,575 +1,575 @@ -#include "custom_cuda_layers.h" - -#define rows_trans 16 -#define cols_trans 16 - -template -__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width) -{ - __shared__ T data_block[rows_trans * (cols_trans + 1)]; - - int r = threadIdx.x / cols_trans; - int c = threadIdx.x % cols_trans; - - int m = row_width / cols_trans; - - int i = blockIdx.x / m * rows_trans + r; - int j = blockIdx.x % m * cols_trans + c; - - int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS); - - for (int k = 0; k < rows_trans; k += row_stride) - data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j]; - - __syncthreads(); - - i = blockIdx.x % m * rows_trans + r; - j = blockIdx.x / m * cols_trans + c; - - for (int k = 0; k < rows_trans; k += row_stride) - out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k]; -} - -template <> -void Transpose<__half>(const __half* inp_mat, - __half* out_mat, - int rows, - int cols, - cudaStream_t stream) -{ - int threads = THREADS; - - Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( - inp_mat, out_mat, cols, rows); -} - -template <> -void Transpose(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream) -{ - int threads = THREADS; - - Transpose_Kernel<<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( - inp_mat, out_mat, cols, rows); -} - -template -__global__ void transform_0213(T* output, - const T* vals, - int hidden_dim, - int seq_length, - int heads, - int head_ext); - -template <> -__global__ void transform_0213(float* output, - const float* vals, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) - int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - const float4* vals_vec = reinterpret_cast(vals); - float4* output_vec = reinterpret_cast(output); - - float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; - output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs; -} - -template <> -__global__ void transform_0213<__half>(__half* output, - const __half* vals, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ -#if __CUDA_ARCH__ >= 700 - - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) - int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - float4 vals_arr[1]; - - const float4* vals_vec = reinterpret_cast(vals); - float4* output_vec = reinterpret_cast(output); - - vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; - output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0]; -#endif -} - -template <> -void launch_transform_0213(float* output, - const float* vals, - int batch_size, - int seq_length, - int hidden_dim, - int heads, - cudaStream_t stream) -{ - hidden_dim >>= 2; - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(batch_size, (seq_length * head_ext)); - - transform_0213 - <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); -} - -template <> -void launch_transform_0213<__half>(__half* output, - const __half* vals, - int batch_size, - int seq_length, - int hidden_dim, - int heads, - cudaStream_t stream) -{ - hidden_dim >>= 3; - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(batch_size, (seq_length * head_ext)); - transform_0213<__half> - <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); -} - -// Bias add -template -__global__ void bias_add_transform_0213(T* output, - const T* vals, - const T* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext); - -template <> -__global__ void bias_add_transform_0213(float* output, - const float* vals, - const float* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y; // Sequence ID (0-127) - int cnt = blockIdx.z / head_ext; // Hidden count - int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - const float4* vals_vec = reinterpret_cast(vals); - const float4* bias_vec = reinterpret_cast(bias); - float4* output_vec = reinterpret_cast(output); - - float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride + - d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3]; - float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; - - float4 outputs; - outputs.x = inputs.x + biases.x; - outputs.y = inputs.y + biases.y; - outputs.z = inputs.z + biases.z; - outputs.w = inputs.w + biases.w; - - output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride + - d2 * d2_out_stride + d3] = outputs; -} - -#define ATTN_H 3 -#define MAX_SEQ_LINE 10 - -template <> -__global__ void bias_add_transform_0213<__half>(__half* output, - const __half* vals, - const __half* bias, - int hidden_dim, - int seq_length, - int heads, - int head_ext) -{ -#if __CUDA_ARCH__ >= 700 - - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y; // Sequence ID (0-127) - int cnt = blockIdx.z / head_ext; // Hidden count - int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - float4 vals_arr; - float4 bias_arr; - float4 output_arr; - __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); - __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr); - __half2* output_half = reinterpret_cast<__half2*>(&output_arr); - - const float4* vals_vec = reinterpret_cast(vals); - const float4* bias_vec = reinterpret_cast(bias); - float4* output_vec = reinterpret_cast(output); - - vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); - vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); - vals_vec += (cnt * d1_stride); - vals_vec += (d2 * d2_stride); - - bias_vec += (cnt * d1_stride); - bias_vec += (d2 * d2_stride); - - output_vec += (cnt * d0_stride * gridDim.x); - output_vec += (d1 * d2_stride); - output_vec += (d0 * d0_stride); - output_vec += (d2 * d2_out_stride); - - bias_arr = bias_vec[d3]; - vals_arr = vals_vec[d3]; - -#if defined(__ACC_HALF__) - output_half[0] = vals_half[0] + bias_half[0]; - output_half[1] = vals_half[1] + bias_half[1]; - output_half[2] = vals_half[2] + bias_half[2]; - output_half[3] = vals_half[3] + bias_half[3]; -#else - float2 bias_arr_f[4]; - float2 vals_arr_f[4]; -#pragma unroll - for (int l = 0; l < 4; l++) { - bias_arr_f[l] = __half22float2(bias_half[l]); - vals_arr_f[l] = __half22float2(vals_half[l]); - vals_arr_f[l].x += bias_arr_f[l].x; - vals_arr_f[l].y += bias_arr_f[l].y; - output_half[l] = __float22half2_rn(vals_arr_f[l]); - } -#endif - output_vec[d3] = output_arr; - -#endif -} - -__global__ void bias_add_transform_0213_v2(__half* output, - const __half* vals, - const __half* bias, - int hidden_dim, - int seq_length, - int heads) -{ -#if __CUDA_ARCH__ >= 700 - __shared__ float4 in_data[3072]; - - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8 - int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8 - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = d2_stride * seq_length; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y; // Sequence ID (0-127) - int cnt = threadIdx.z; // blockIdx.z; // Hidden count - int d2 = threadIdx.y; // Head (0-11) - int d3 = threadIdx.x; // Values (groups of 4) - - float4 vals_arr[1]; - float4 bias_arr[1]; - float4 output_arr[1]; - __half2* vals_half = reinterpret_cast<__half2*>(vals_arr); - __half2* bias_half = reinterpret_cast<__half2*>(bias_arr); - __half2* output_half = reinterpret_cast<__half2*>(output_arr); - - const float4* vals_vec = reinterpret_cast(vals); - const float4* bias_vec = reinterpret_cast(bias); - float4* output_vec = reinterpret_cast(output); - - int iter_index = cnt * d1_stride + d2 * d2_stride + d3; - int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); - bias_arr[0] = bias_vec[iter_index]; - -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_id = iter * iteration_stride + iter_index; - vals_arr[0] = vals_vec[input_offset + iter_id]; - - output_half[0] = vals_half[0] + bias_half[0]; - output_half[1] = vals_half[1] + bias_half[1]; - output_half[2] = vals_half[2] + bias_half[2]; - output_half[3] = vals_half[3] + bias_half[3]; - - in_data[iter_id] = output_arr[0]; - } - __syncthreads(); - - iteration_stride = blockDim.z * (blockDim.y >> 1); - int matrix_stride = (d0_out_stride * gridDim.x); - int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1); - - int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; - -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_row = (iter * iteration_stride) + head_count; - int iter_offset = - (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride; - output_vec[out_index + iter_offset] = - in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)]; - } -#endif -} - -// [B S C*H] - > C * [B A S N] -template <> -void launch_bias_add_transform_0213(float* output, - const float* vals, - const float* bias, - int batch_size, - int seq_length, - int hidden_dim, - int heads, - cudaStream_t stream, - int trans_count) -{ - hidden_dim >>= 2; - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - - dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); - - bias_add_transform_0213<<>>( - output, vals, bias, hidden_dim, seq_length, heads, head_ext); -} - -template <> -void launch_bias_add_transform_0213<__half>(__half* output, - const __half* vals, - const __half* bias, - int batch_size, - int seq_length, - int hidden_dim, - int heads, - cudaStream_t stream, - int trans_count) -{ - hidden_dim >>= 3; - if (hidden_dim > 128 || hidden_dim < 16) { - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - dim3 block_dim(hidden_dim / heads, (heads / head_ext)); - dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); - bias_add_transform_0213<__half><<>>( - output, vals, bias, hidden_dim, seq_length, heads, head_ext); - } else { - dim3 block_dim(hidden_dim / heads, heads, trans_count); - dim3 grid_dim(batch_size, seq_length / 2); - bias_add_transform_0213_v2<<>>( - output, vals, bias, hidden_dim, seq_length, heads); - } -} - -template -__global__ void transform4d_0213(T* out, - const T* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext); - -template <> -__global__ void transform4d_0213(float* out, - const float* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext) -{ - int d0_stride = hidden_dim * seq_length; - int d1_stride = d0_stride / heads; - int d2_stride = hidden_dim / heads; - - int d0_out_stride = d0_stride; - int d1_out_stride = d2_stride; - int d2_out_stride = hidden_dim; - - int d0 = blockIdx.x; // Batch - int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head - int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length; - int cnt = blockIdx.z; - int d3 = threadIdx.x; // Values (groups of 8) - - if (d2 < seq_length) { - const float4* in_vec = reinterpret_cast(in); - float4* out_vec = reinterpret_cast(out); - - float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride + - d2 * d2_stride + d3]; - out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride + - d2 * d2_out_stride * gridDim.z + d3] = vals_vec; - } -} - -template <> -__global__ void transform4d_0213<__half>(__half* out, - const __half* in, - int heads, - int seq_length, - int hidden_dim, - int head_ext) -{ -#if __CUDA_ARCH__ >= 700 - - int d0_stride = hidden_dim * (seq_length / head_ext); - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0 = blockIdx.x; // Batch - int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head - int d2 = blockIdx.z / head_ext; // Sequence - int cnt = blockIdx.y; // Hidden count - int d3 = threadIdx.x; // Values (groups of 8) - - const float4* in_vec = reinterpret_cast(in); - float4* out_vec = reinterpret_cast(out); - - in_vec += (cnt * d0_stride * gridDim.x); - in_vec += (d0 * d0_stride); - in_vec += (d2 * d2_stride); - in_vec += (d1 * d2_stride * seq_length); - - out_vec += (cnt * d1_stride); - out_vec += (d1 * d2_stride); - out_vec += (d0 * d0_stride * gridDim.y); - out_vec += (d2 * d1_stride * gridDim.y); - - out_vec[d3] = in_vec[d3]; - -#endif -} - -__global__ void transform4d_0213_v2(__half* out, - const __half* in, - int heads, - int seq_length, - int hidden_dim) -{ -#if __CUDA_ARCH__ >= 700 - __shared__ float4 in_data[3072]; - - int d0_stride = hidden_dim * seq_length; - int d1_stride = hidden_dim; - int d2_stride = hidden_dim / heads; - - int d0 = blockIdx.x; // Batch - int d1 = threadIdx.y; // Head - int d2 = blockIdx.y; // Sequence - int cnt = threadIdx.z; // Hidden count - int d3 = threadIdx.x; // Values (groups of 8) - - const float4* in_vec = reinterpret_cast(in); - float4* out_vec = reinterpret_cast(out); - - int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; - int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1); - int iteration_stride = blockDim.z * (blockDim.y >> 1); - int matrix_stride = (d0_stride * gridDim.x); - -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_row = iter * iteration_stride + head_count; - int iter_offset = (iter_row % blockDim.y) * d2_stride; - - in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] = - in_vec[input_offset + iter_offset * seq_length + - (iter_row / blockDim.y) * matrix_stride]; - } - __syncthreads(); - - iteration_stride = d1_stride * blockDim.z; - int iter_index = cnt * d1_stride + d1 * d2_stride + d3; - int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1); - -#pragma unroll - for (int iter = 0; iter < 2; iter++) { - int iter_id = iter * iteration_stride + iter_index; - out_vec[output_offset + iter_id] = in_data[iter_id]; - } -#endif -} - -// 3 * [B A S N] - > [B S C*H] -template <> -void launch_transform4d_0213(float* out, - const float* in, - int batch_size, - int heads, - int seq_length, - int hidden_dim, - cudaStream_t stream, - int trans_count) -{ - hidden_dim >>= 2; - dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count); - dim3 block_dims(hidden_dim / heads, 8); - transform4d_0213 - <<>>(out, in, heads, seq_length, hidden_dim, 1); -} - -template <> -void launch_transform4d_0213<__half>(__half* out, - const __half* in, - int batch_size, - int heads, - int seq_length, - int hidden_dim, - cudaStream_t stream, - int trans_count) -{ - hidden_dim >>= 3; - if (hidden_dim > 128 || hidden_dim < 16) { - int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; - dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext)); - dim3 block_dims(hidden_dim / heads, (heads / head_ext)); - transform4d_0213<__half><<>>( - out, in, heads, seq_length, hidden_dim, head_ext); - } else { - dim3 grid_dims(batch_size, seq_length / 2); - dim3 block_dims(hidden_dim / heads, heads, trans_count); - transform4d_0213_v2<<>>( - out, in, heads, seq_length, hidden_dim); - } -} +#include "custom_cuda_layers.h" + +#define rows_trans 16 +#define cols_trans 16 + +template +__global__ void Transpose_Kernel(const T* inp, T* out, int row_width, int col_width) +{ + __shared__ T data_block[rows_trans * (cols_trans + 1)]; + + int r = threadIdx.x / cols_trans; + int c = threadIdx.x % cols_trans; + + int m = row_width / cols_trans; + + int i = blockIdx.x / m * rows_trans + r; + int j = blockIdx.x % m * cols_trans + c; + + int row_stride = rows_trans / ((rows_trans * cols_trans + THREADS - 1) / THREADS); + + for (int k = 0; k < rows_trans; k += row_stride) + data_block[(k + r) * cols_trans + c] = inp[(i + k) * row_width + j]; + + __syncthreads(); + + i = blockIdx.x % m * rows_trans + r; + j = blockIdx.x / m * cols_trans + c; + + for (int k = 0; k < rows_trans; k += row_stride) + out[(i + k) * col_width + j] = data_block[c * cols_trans + r + k]; +} + +template <> +void Transpose<__half>(const __half* inp_mat, + __half* out_mat, + int rows, + int cols, + cudaStream_t stream) +{ + int threads = THREADS; + + Transpose_Kernel<__half><<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( + inp_mat, out_mat, cols, rows); +} + +template <> +void Transpose(const float* inp_mat, float* out_mat, int rows, int cols, cudaStream_t stream) +{ + int threads = THREADS; + + Transpose_Kernel<<<(rows * cols + threads - 1) / threads, threads, 0, stream>>>( + inp_mat, out_mat, cols, rows); +} + +template +__global__ void transform_0213(T* output, + const T* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext); + +template <> +__global__ void transform_0213(float* output, + const float* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) + int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = reinterpret_cast(output); + + float4 inputs = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = inputs; +} + +template <> +__global__ void transform_0213<__half>(__half* output, + const __half* vals, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ +#if __CUDA_ARCH__ >= 700 + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / head_ext; // Sequence ID (0-127) + int d2 = threadIdx.y + (blockIdx.y % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr[1]; + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = reinterpret_cast(output); + + vals_arr[0] = vals_vec[d0 * d0_stride + d1 * d1_stride + d2 * d2_stride + d3]; + output_vec[d0 * d0_out_stride + d1 * d1_out_stride + d2 * d2_out_stride + d3] = vals_arr[0]; +#endif +} + +template <> +void launch_transform_0213(float* output, + const float* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, (seq_length * head_ext)); + + transform_0213 + <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); +} + +template <> +void launch_transform_0213<__half>(__half* output, + const __half* vals, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream) +{ + hidden_dim >>= 3; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, (seq_length * head_ext)); + transform_0213<__half> + <<>>(output, vals, hidden_dim, seq_length, heads, head_ext); +} + +// Bias add +template +__global__ void bias_add_transform_0213(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext); + +template <> +__global__ void bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride + + d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3]; + float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; + + float4 outputs; + outputs.x = inputs.x + biases.x; + outputs.y = inputs.y + biases.y; + outputs.z = inputs.z + biases.z; + outputs.w = inputs.w + biases.w; + + output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride + d3] = outputs; +} + +#define ATTN_H 3 +#define MAX_SEQ_LINE 10 + +template <> +__global__ void bias_add_transform_0213<__half>(__half* output, + const __half* vals, + const __half* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ +#if __CUDA_ARCH__ >= 700 + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr; + float4 bias_arr; + float4 output_arr; + __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr); + __half2* output_half = reinterpret_cast<__half2*>(&output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); + vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); + vals_vec += (cnt * d1_stride); + vals_vec += (d2 * d2_stride); + + bias_vec += (cnt * d1_stride); + bias_vec += (d2 * d2_stride); + + output_vec += (cnt * d0_stride * gridDim.x); + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_stride); + output_vec += (d2 * d2_out_stride); + + bias_arr = bias_vec[d3]; + vals_arr = vals_vec[d3]; + +#if defined(__ACC_HALF__) + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; +#else + float2 bias_arr_f[4]; + float2 vals_arr_f[4]; +#pragma unroll + for (int l = 0; l < 4; l++) { + bias_arr_f[l] = __half22float2(bias_half[l]); + vals_arr_f[l] = __half22float2(vals_half[l]); + vals_arr_f[l].x += bias_arr_f[l].x; + vals_arr_f[l].y += bias_arr_f[l].y; + output_half[l] = __float22half2_rn(vals_arr_f[l]); + } +#endif + output_vec[d3] = output_arr; + +#endif +} + +__global__ void bias_add_transform_0213_v2(__half* output, + const __half* vals, + const __half* bias, + int hidden_dim, + int seq_length, + int heads) +{ +#if __CUDA_ARCH__ >= 700 + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8 + int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8 + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = threadIdx.z; // blockIdx.z; // Hidden count + int d2 = threadIdx.y; // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr[1]; + float4 bias_arr[1]; + float4 output_arr[1]; + __half2* vals_half = reinterpret_cast<__half2*>(vals_arr); + __half2* bias_half = reinterpret_cast<__half2*>(bias_arr); + __half2* output_half = reinterpret_cast<__half2*>(output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + int iter_index = cnt * d1_stride + d2 * d2_stride + d3; + int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); + bias_arr[0] = bias_vec[iter_index]; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + vals_arr[0] = vals_vec[input_offset + iter_id]; + + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; + + in_data[iter_id] = output_arr[0]; + } + __syncthreads(); + + iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_out_stride * gridDim.x); + int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1); + + int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = (iter * iteration_stride) + head_count; + int iter_offset = + (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride; + output_vec[out_index + iter_offset] = + in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)]; + } +#endif +} + +// [B S C*H] - > C * [B A S N] +template <> +void launch_bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); + + bias_add_transform_0213<<>>( + output, vals, bias, hidden_dim, seq_length, heads, head_ext); +} + +template <> +void launch_bias_add_transform_0213<__half>(__half* output, + const __half* vals, + const __half* bias, + int batch_size, + int seq_length, + int hidden_dim, + int heads, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 3; + if (hidden_dim > 128 || hidden_dim < 16) { + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); + bias_add_transform_0213<__half><<>>( + output, vals, bias, hidden_dim, seq_length, heads, head_ext); + } else { + dim3 block_dim(hidden_dim / heads, heads, trans_count); + dim3 grid_dim(batch_size, seq_length / 2); + bias_add_transform_0213_v2<<>>( + output, vals, bias, hidden_dim, seq_length, heads); + } +} + +template +__global__ void transform4d_0213(T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext); + +template <> +__global__ void transform4d_0213(float* out, + const float* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = d0_stride / heads; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = hidden_dim; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head + int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length; + int cnt = blockIdx.z; + int d3 = threadIdx.x; // Values (groups of 8) + + if (d2 < seq_length) { + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride + + d2 * d2_stride + d3]; + out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride * gridDim.z + d3] = vals_vec; + } +} + +template <> +__global__ void transform4d_0213<__half>(__half* out, + const __half* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ +#if __CUDA_ARCH__ >= 700 + + int d0_stride = hidden_dim * (seq_length / head_ext); + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head + int d2 = blockIdx.z / head_ext; // Sequence + int cnt = blockIdx.y; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + in_vec += (cnt * d0_stride * gridDim.x); + in_vec += (d0 * d0_stride); + in_vec += (d2 * d2_stride); + in_vec += (d1 * d2_stride * seq_length); + + out_vec += (cnt * d1_stride); + out_vec += (d1 * d2_stride); + out_vec += (d0 * d0_stride * gridDim.y); + out_vec += (d2 * d1_stride * gridDim.y); + + out_vec[d3] = in_vec[d3]; + +#endif +} + +__global__ void transform4d_0213_v2(__half* out, + const __half* in, + int heads, + int seq_length, + int hidden_dim) +{ +#if __CUDA_ARCH__ >= 700 + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y; // Head + int d2 = blockIdx.y; // Sequence + int cnt = threadIdx.z; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; + int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1); + int iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_stride * gridDim.x); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = iter * iteration_stride + head_count; + int iter_offset = (iter_row % blockDim.y) * d2_stride; + + in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] = + in_vec[input_offset + iter_offset * seq_length + + (iter_row / blockDim.y) * matrix_stride]; + } + __syncthreads(); + + iteration_stride = d1_stride * blockDim.z; + int iter_index = cnt * d1_stride + d1 * d2_stride + d3; + int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + out_vec[output_offset + iter_id] = in_data[iter_id]; + } +#endif +} + +// 3 * [B A S N] - > [B S C*H] +template <> +void launch_transform4d_0213(float* out, + const float* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 2; + dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count); + dim3 block_dims(hidden_dim / heads, 8); + transform4d_0213 + <<>>(out, in, heads, seq_length, hidden_dim, 1); +} + +template <> +void launch_transform4d_0213<__half>(__half* out, + const __half* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 3; + if (hidden_dim > 128 || hidden_dim < 16) { + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext)); + dim3 block_dims(hidden_dim / heads, (heads / head_ext)); + transform4d_0213<__half><<>>( + out, in, heads, seq_length, hidden_dim, head_ext); + } else { + dim3 grid_dims(batch_size, seq_length / 2); + dim3 block_dims(hidden_dim / heads, heads, trans_count); + transform4d_0213_v2<<>>( + out, in, heads, seq_length, hidden_dim); + } +} diff --git a/deepspeed/launcher/constants.py b/deepspeed/launcher/constants.py index fd56facc4343..15b262342d47 100644 --- a/deepspeed/launcher/constants.py +++ b/deepspeed/launcher/constants.py @@ -1,9 +1,9 @@ -# Copyright 2020 The Microsoft DeepSpeed Team - -PDSH_LAUNCHER = 'pdsh' -PDSH_MAX_FAN_OUT = 1024 - -OPENMPI_LAUNCHER = 'openmpi' - -MVAPICH_LAUNCHER = 'mvapich' -MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile' +# Copyright 2020 The Microsoft DeepSpeed Team + +PDSH_LAUNCHER = 'pdsh' +PDSH_MAX_FAN_OUT = 1024 + +OPENMPI_LAUNCHER = 'openmpi' + +MVAPICH_LAUNCHER = 'mvapich' +MVAPICH_TMP_HOSTFILE = '/tmp/deepspeed_mvapich_hostfile' diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py index 4c4ac3b490d0..a8fc6fcc14b2 100644 --- a/deepspeed/launcher/multinode_runner.py +++ b/deepspeed/launcher/multinode_runner.py @@ -1,227 +1,227 @@ -import os -import sys -import shutil -import subprocess -import warnings -from abc import ABC, abstractmethod - -from ..utils import logger -from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE - - -class MultiNodeRunner(ABC): - def __init__(self, args, world_info_base64): - self.args = args - self.validate_args() - self.user_arguments = self.parse_user_args() - self.user_script = args.user_script - self.world_info_base64 = world_info_base64 - self.exports = {} - - @abstractmethod - def backend_exists(self): - """Return whether the corresponding backend exists""" - - @abstractmethod - def get_cmd(self, environment, active_resources): - """Return the command to execute on node""" - - def add_export(self, key, var): - self.exports[key.strip()] = var.strip() - - def parse_user_args(self): - return self.args.user_args - - @property - def name(self): - """Return the name of the backend""" - return self.__class__.__name__ - - def validate_args(self): - """Validate self.args""" - - -class PDSHRunner(MultiNodeRunner): - def __init__(self, args, world_info_base64): - super().__init__(args, world_info_base64) - - def backend_exists(self): - return shutil.which('pdsh') - - @property - def name(self): - return "pdsh" - - def parse_user_args(self): - return list( - map(lambda x: x if x.startswith("-") else f"'{x}'", - self.args.user_args)) - - def get_cmd(self, environment, active_resources): - environment['PDSH_RCMD_TYPE'] = 'ssh' - - active_workers = ",".join(active_resources.keys()) - logger.info("Running on the following workers: %s" % active_workers) - - # PDSH flags for max node fan out and specific hosts to launch on - # See https://linux.die.net/man/1/pdsh for flag details - pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers] - - exports = "" - for key, val in self.exports.items(): - exports += f"export {key}={val}; " - - # https://linux.die.net/man/1/pdsh - # %n will be replaced by pdsh command - deepspeed_launch = [ - exports, - f"cd {os.path.abspath('.')};", - sys.executable, - "-u", - "-m", - "deepspeed.launcher.launch", - f'--world_info={self.world_info_base64}', - "--node_rank=%n", - f"--master_addr={self.args.master_addr}", - f"--master_port={self.args.master_port}" - ] - - return pdsh_cmd_args + deepspeed_launch + [self.user_script - ] + self.user_arguments - - -class OpenMPIRunner(MultiNodeRunner): - def __init__(self, args, world_info_base64, resource_pool): - super().__init__(args, world_info_base64) - self.resource_pool = resource_pool - self.add_export('UCX_TLS', 'tcp') - - def backend_exists(self): - #TODO: if IB is available we should suggestion mvapich - return shutil.which('ompi_info') - - @property - def name(self): - return "openmpi" - - def validate_args(self): - super().validate_args() - #TODO: Allow for include/exclude at node-level but not gpu-level - if self.args.include != "" or self.args.exclude != "": - raise ValueError( - f"{self.name} backend does not support worker include/exclusion") - if self.args.num_nodes != -1 or self.args.num_gpus != -1: - raise ValueError( - f"{self.name} backend does not support limiting num nodes/gpus") - - def get_cmd(self, environment, active_resources): - total_process_count = sum(self.resource_pool.values()) - - mpirun_cmd = [ - 'mpirun', - '-n', - f'{total_process_count}', - '-hostfile', - f'{self.args.hostfile}', - '--mca', - 'btl', - '^openib', - '--mca', - 'btl_tcp_if_include', - 'eth0', - ] - - export_cmd = [] - for k, v in self.exports.items(): - export_cmd += ['-x', f'{k}={v}'] - - python_exec = [sys.executable, "-u"] - - return mpirun_cmd + export_cmd + python_exec + [self.user_script - ] + self.user_arguments - - -class MVAPICHRunner(MultiNodeRunner): - def __init__(self, args, world_info_base64, resource_pool): - super().__init__(args, world_info_base64) - self.resource_pool = resource_pool - - # Disable the CMA kernel module, not available on Ubuntu systems - self.add_export('MV2_SMP_USE_CMA', '0') - - # If we fail this will output more verbose logging - self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1') - - # Enabled cuda-aware communication - self.add_export('MV2_USE_CUDA', '1') - - # Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/ - self.add_export('MV2_SUPPORT_DL', '1') - - # Support MPI_THREAD_MULTIPLE - self.add_export('MV2_ENABLE_AFFINITY', '0') - - # Performance tuning flags for allgather - self.add_export('MV2_INTER_ALLGATHER_TUNING', '5') - self.add_export('MV2_CUDA_USE_NAIVE', '0') - - def backend_exists(self): - #TODO: if IB is available we should suggestion mvapich - mpiname_exists = shutil.which('mpiname') - exists = False - if not mpiname_exists: - warnings.warn("mpiname does not exist, mvapich is not installed properly") - else: - results = subprocess.check_output('mpiname', shell=True) - mpiname_results = results.decode('utf-8').strip() - if "MVAPICH2-GDR" in mpiname_results: - exists = True - else: - warnings.warn( - f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}" - ) - return exists - - @property - def name(self): - return "mvapich" - - def validate_args(self): - super().validate_args() - #TODO: Allow for include/exclude at node-level but not gpu-level - if self.args.include != "" or self.args.exclude != "": - raise ValueError( - f"{self.name} backend does not support worker include/exclusion") - if self.args.num_nodes != -1 or self.args.num_gpus != -1: - raise ValueError( - f"{self.name} backend does not support limiting num nodes/gpus") - - def get_cmd(self, environment, active_resources): - devices_per_node = self.resource_pool.values() - total_process_count = sum(devices_per_node) - process_per_node = list(devices_per_node)[0] - if not all([n == process_per_node for n in devices_per_node]): - raise ValueError("mvapich requires same number of devices per node") - - with open(MVAPICH_TMP_HOSTFILE, 'w') as fd: - for host in self.resource_pool.keys(): - fd.write(f'{host}\n') - - mpirun_cmd = [ - 'mpirun', - '-np', - f'{total_process_count}', - '-ppn', - f'{process_per_node}', - '--hostfile', - f'{MVAPICH_TMP_HOSTFILE}', - ] - - export_cmd = [] - for k, v in self.exports.items(): - export_cmd += ['-env', f'{k}={v}'] - - python_exec = [sys.executable, "-u"] - - return mpirun_cmd + export_cmd + python_exec + [self.user_script - ] + self.user_arguments +import os +import sys +import shutil +import subprocess +import warnings +from abc import ABC, abstractmethod + +from ..utils import logger +from .constants import PDSH_MAX_FAN_OUT, MVAPICH_TMP_HOSTFILE + + +class MultiNodeRunner(ABC): + def __init__(self, args, world_info_base64): + self.args = args + self.validate_args() + self.user_arguments = self.parse_user_args() + self.user_script = args.user_script + self.world_info_base64 = world_info_base64 + self.exports = {} + + @abstractmethod + def backend_exists(self): + """Return whether the corresponding backend exists""" + + @abstractmethod + def get_cmd(self, environment, active_resources): + """Return the command to execute on node""" + + def add_export(self, key, var): + self.exports[key.strip()] = var.strip() + + def parse_user_args(self): + return self.args.user_args + + @property + def name(self): + """Return the name of the backend""" + return self.__class__.__name__ + + def validate_args(self): + """Validate self.args""" + + +class PDSHRunner(MultiNodeRunner): + def __init__(self, args, world_info_base64): + super().__init__(args, world_info_base64) + + def backend_exists(self): + return shutil.which('pdsh') + + @property + def name(self): + return "pdsh" + + def parse_user_args(self): + return list( + map(lambda x: x if x.startswith("-") else f"'{x}'", + self.args.user_args)) + + def get_cmd(self, environment, active_resources): + environment['PDSH_RCMD_TYPE'] = 'ssh' + + active_workers = ",".join(active_resources.keys()) + logger.info("Running on the following workers: %s" % active_workers) + + # PDSH flags for max node fan out and specific hosts to launch on + # See https://linux.die.net/man/1/pdsh for flag details + pdsh_cmd_args = ['pdsh', '-f', str(PDSH_MAX_FAN_OUT), '-w', active_workers] + + exports = "" + for key, val in self.exports.items(): + exports += f"export {key}={val}; " + + # https://linux.die.net/man/1/pdsh + # %n will be replaced by pdsh command + deepspeed_launch = [ + exports, + f"cd {os.path.abspath('.')};", + sys.executable, + "-u", + "-m", + "deepspeed.launcher.launch", + f'--world_info={self.world_info_base64}', + "--node_rank=%n", + f"--master_addr={self.args.master_addr}", + f"--master_port={self.args.master_port}" + ] + + return pdsh_cmd_args + deepspeed_launch + [self.user_script + ] + self.user_arguments + + +class OpenMPIRunner(MultiNodeRunner): + def __init__(self, args, world_info_base64, resource_pool): + super().__init__(args, world_info_base64) + self.resource_pool = resource_pool + self.add_export('UCX_TLS', 'tcp') + + def backend_exists(self): + #TODO: if IB is available we should suggestion mvapich + return shutil.which('ompi_info') + + @property + def name(self): + return "openmpi" + + def validate_args(self): + super().validate_args() + #TODO: Allow for include/exclude at node-level but not gpu-level + if self.args.include != "" or self.args.exclude != "": + raise ValueError( + f"{self.name} backend does not support worker include/exclusion") + if self.args.num_nodes != -1 or self.args.num_gpus != -1: + raise ValueError( + f"{self.name} backend does not support limiting num nodes/gpus") + + def get_cmd(self, environment, active_resources): + total_process_count = sum(self.resource_pool.values()) + + mpirun_cmd = [ + 'mpirun', + '-n', + f'{total_process_count}', + '-hostfile', + f'{self.args.hostfile}', + '--mca', + 'btl', + '^openib', + '--mca', + 'btl_tcp_if_include', + 'eth0', + ] + + export_cmd = [] + for k, v in self.exports.items(): + export_cmd += ['-x', f'{k}={v}'] + + python_exec = [sys.executable, "-u"] + + return mpirun_cmd + export_cmd + python_exec + [self.user_script + ] + self.user_arguments + + +class MVAPICHRunner(MultiNodeRunner): + def __init__(self, args, world_info_base64, resource_pool): + super().__init__(args, world_info_base64) + self.resource_pool = resource_pool + + # Disable the CMA kernel module, not available on Ubuntu systems + self.add_export('MV2_SMP_USE_CMA', '0') + + # If we fail this will output more verbose logging + self.add_export('MV2_DEBUG_SHOW_BACKTRACE', '1') + + # Enabled cuda-aware communication + self.add_export('MV2_USE_CUDA', '1') + + # Support deep learning frameworks: http://hidl.cse.ohio-state.edu/userguide/horovod/ + self.add_export('MV2_SUPPORT_DL', '1') + + # Support MPI_THREAD_MULTIPLE + self.add_export('MV2_ENABLE_AFFINITY', '0') + + # Performance tuning flags for allgather + self.add_export('MV2_INTER_ALLGATHER_TUNING', '5') + self.add_export('MV2_CUDA_USE_NAIVE', '0') + + def backend_exists(self): + #TODO: if IB is available we should suggestion mvapich + mpiname_exists = shutil.which('mpiname') + exists = False + if not mpiname_exists: + warnings.warn("mpiname does not exist, mvapich is not installed properly") + else: + results = subprocess.check_output('mpiname', shell=True) + mpiname_results = results.decode('utf-8').strip() + if "MVAPICH2-GDR" in mpiname_results: + exists = True + else: + warnings.warn( + f"Expected MVAPICH2-GDR as return for mpiname but received {mpiname_results}" + ) + return exists + + @property + def name(self): + return "mvapich" + + def validate_args(self): + super().validate_args() + #TODO: Allow for include/exclude at node-level but not gpu-level + if self.args.include != "" or self.args.exclude != "": + raise ValueError( + f"{self.name} backend does not support worker include/exclusion") + if self.args.num_nodes != -1 or self.args.num_gpus != -1: + raise ValueError( + f"{self.name} backend does not support limiting num nodes/gpus") + + def get_cmd(self, environment, active_resources): + devices_per_node = self.resource_pool.values() + total_process_count = sum(devices_per_node) + process_per_node = list(devices_per_node)[0] + if not all([n == process_per_node for n in devices_per_node]): + raise ValueError("mvapich requires same number of devices per node") + + with open(MVAPICH_TMP_HOSTFILE, 'w') as fd: + for host in self.resource_pool.keys(): + fd.write(f'{host}\n') + + mpirun_cmd = [ + 'mpirun', + '-np', + f'{total_process_count}', + '-ppn', + f'{process_per_node}', + '--hostfile', + f'{MVAPICH_TMP_HOSTFILE}', + ] + + export_cmd = [] + for k, v in self.exports.items(): + export_cmd += ['-env', f'{k}={v}'] + + python_exec = [sys.executable, "-u"] + + return mpirun_cmd + export_cmd + python_exec + [self.user_script + ] + self.user_arguments diff --git a/deepspeed/module_inject/module_quantize.py b/deepspeed/module_inject/module_quantize.py index 26c5422840fc..fde6990eba28 100755 --- a/deepspeed/module_inject/module_quantize.py +++ b/deepspeed/module_inject/module_quantize.py @@ -1,80 +1,80 @@ -import copy -import torch -import deepspeed - - -def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False): - """ Quantize bert-style transformer layers with DeepSpeed's transformer layer - Arguments: - orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, - e.g., transformers.modeling_bert.BertLayer. - model (torch.nn.Module): user's nn.module representing their model - - megatron (bool): megatron model-parallel implementation (this is supported for inference only) - preln (bool): does the original layer implementation do pre or post layer norm? - - Note: For Bert kind of models, we inject based on the DeepSpeed-Example models, if not setting huggingface flag. - - Returns: - Updated nn.module with quantized transformer layers - """ - def quantize_weight(weight): - return weight.to(torch.int8) - - def megatron_layer_quantize(layer): - layer.attention.query_key_value.weight.data = quantize_weight( - layer.attention.query_key_value.weight.data) - layer.attention.dense.weight.data = quantize_weight( - layer.attention.dense.weight.data) - layer.mlp.dense_h_to_4h.weight.data = quantize_weight( - layer.mlp.dense_h_to_4h.weight.data) - layer.mlp.dense_4h_to_h.weight.data = quantize_weight( - layer.mlp.dense_4h_to_h.weight.data) - - def bert_layer_quantize(layer): - layer.attention.self.query.weight.data = quantize_weight( - layer.attention.self.query.weight.data) - layer.attention.self.key.weight.data = quantize_weight( - layer.attention.self.key.weight.data) - layer.attention.self.value.weight.data = quantize_weight( - layer.attention.self.value.weight.data) - layer.attention.output.dense.weight.data = quantize_weight( - layer.attention.output.dense.weight.data) - if preln: - layer.intermediate.dense_act.weight.data = quantize_weight( - layer.intermediate.dense_act.weight.data) - else: - layer.intermediate.dense.weight.data = quantize_weight( - layer.intermediate.dense.weight.data) - layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data) - - def quantize_fn(child): - if megatron: - # Quantize megatron GPT2 / GPT3 trained model - megatron_layer_quantize(child) - else: - # Quantize either DeepSpeed or HuggingFace trained model - bert_layer_quantize(child) - - return child - - return quantize_module(model=model, - orig_class=orig_layer_impl, - quantize_fn=quantize_fn) - - -def quantize_module(model, orig_class, quantize_fn): - policy = {orig_class: quantize_fn} - return _quantize_module(model, policy) - - -def _quantize_module(model, policies): - for name, child in model.named_children(): - if child.__class__ in policies: - orig = repr(child) - setattr(model, name, policies[child.__class__](child)) - new = getattr(model, name) - else: - _quantize_module(child, policies) - - return model +import copy +import torch +import deepspeed + + +def quantize_transformer_layer(orig_layer_impl, model, megatron=False, preln=False): + """ Quantize bert-style transformer layers with DeepSpeed's transformer layer + Arguments: + orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for, + e.g., transformers.modeling_bert.BertLayer. + model (torch.nn.Module): user's nn.module representing their model + + megatron (bool): megatron model-parallel implementation (this is supported for inference only) + preln (bool): does the original layer implementation do pre or post layer norm? + + Note: For Bert kind of models, we inject based on the DeepSpeed-Example models, if not setting huggingface flag. + + Returns: + Updated nn.module with quantized transformer layers + """ + def quantize_weight(weight): + return weight.to(torch.int8) + + def megatron_layer_quantize(layer): + layer.attention.query_key_value.weight.data = quantize_weight( + layer.attention.query_key_value.weight.data) + layer.attention.dense.weight.data = quantize_weight( + layer.attention.dense.weight.data) + layer.mlp.dense_h_to_4h.weight.data = quantize_weight( + layer.mlp.dense_h_to_4h.weight.data) + layer.mlp.dense_4h_to_h.weight.data = quantize_weight( + layer.mlp.dense_4h_to_h.weight.data) + + def bert_layer_quantize(layer): + layer.attention.self.query.weight.data = quantize_weight( + layer.attention.self.query.weight.data) + layer.attention.self.key.weight.data = quantize_weight( + layer.attention.self.key.weight.data) + layer.attention.self.value.weight.data = quantize_weight( + layer.attention.self.value.weight.data) + layer.attention.output.dense.weight.data = quantize_weight( + layer.attention.output.dense.weight.data) + if preln: + layer.intermediate.dense_act.weight.data = quantize_weight( + layer.intermediate.dense_act.weight.data) + else: + layer.intermediate.dense.weight.data = quantize_weight( + layer.intermediate.dense.weight.data) + layer.output.dense.weight.data = quantize_weight(layer.output.dense.weight.data) + + def quantize_fn(child): + if megatron: + # Quantize megatron GPT2 / GPT3 trained model + megatron_layer_quantize(child) + else: + # Quantize either DeepSpeed or HuggingFace trained model + bert_layer_quantize(child) + + return child + + return quantize_module(model=model, + orig_class=orig_layer_impl, + quantize_fn=quantize_fn) + + +def quantize_module(model, orig_class, quantize_fn): + policy = {orig_class: quantize_fn} + return _quantize_module(model, policy) + + +def _quantize_module(model, policies): + for name, child in model.named_children(): + if child.__class__ in policies: + orig = repr(child) + setattr(model, name, policies[child.__class__](child)) + new = getattr(model, name) + else: + _quantize_module(child, policies) + + return model diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index 3758ffd9b522..cda2a685d43e 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -1,239 +1,239 @@ -from abc import ABC - -import torch - - -class DSPolicy(ABC): - def __init__(self, inference=True, linear_layer=True, scale_attention=True): - self.inference = inference - self.linear_layer = linear_layer - self.scale_attention = scale_attention - - def attention(self): - """ - Returns attention qkv and dense parameters - weight: (3*hidden, hidden) and (hidden, hidden) - bias: (3*hidden) and (hidden) - """ - raise NotImplementedError - - def get_hidden_heads(self): - """ - return hidden_size and number of heads - """ - raise NotImplementedError - - def mlp(self): - """ - Returns mlp intermediate and output - weight: (intermediate, hidden) and (hidden, intermediate) - bias: (intermediate) and (hidden) - """ - raise NotImplementedError - - def layerNorm(self): - """ - Returns LayerNorms used in transformer layer - Post-Attention and pre/post layer norm - gamma and beta with shape: (hidden) - """ - raise NotImplementedError - - -class HFBertLayerPolicy(DSPolicy): - _orig_layer_class = None - - def __init__(self, client_module, inference=False, preln=False): - super().__init__(inference) - self.client_module = client_module - self.preln = preln - if HFBertLayerPolicy._orig_layer_class is None: - try: - import transformers - HFBertLayerPolicy._orig_layer_class = transformers.models.bert.modeling_bert.BertLayer - except: - HFBertLayerPolicy._orig_layer_class = None - - def get_hidden_heads(self): - return self.client_module.attention.self.query.weight.data.shape[1], \ - self.client_module.attention.self.num_attention_heads - - def attention(self): - qw = self.client_module.attention.self.query.weight.data - qb = self.client_module.attention.self.query.bias.data - kw = self.client_module.attention.self.key.weight.data - kb = self.client_module.attention.self.key.bias.data - vw = self.client_module.attention.self.value.weight.data - vb = self.client_module.attention.self.value.bias.data - - qkvw = torch.cat((qw, kw, vw), dim=0) - qkvb = torch.cat((qb, kb, vb), dim=0) - - return self.linear_layer, \ - qkvw, \ - qkvb, \ - self.client_module.attention.output.dense.weight.data, \ - self.client_module.attention.output.dense.bias.data, \ - self.scale_attention - - def mlp(self): - if self.preln: - intermediate_ff = self.client_module.intermediate.dense_act - else: - intermediate_ff = self.client_module.intermediate.dense - - return self.linear_layer, intermediate_ff.weight.data, intermediate_ff.bias.data, \ - self.client_module.output.dense.weight.data, \ - self.client_module.output.dense.bias.data - - def layerNorm(self): - if self.preln: - attention_layernorm = self.client_module.PostAttentionLayerNorm - transformer_layernorm = self.client_module.PreAttentionLayerNorm - else: - attention_layernorm = self.client_module.attention.output.LayerNorm - transformer_layernorm = self.client_module.output.LayerNorm - return attention_layernorm.weight.data, \ - attention_layernorm.bias.data, \ - transformer_layernorm.weight.data, \ - transformer_layernorm.bias.data - - -class HFGPTNEOLayerPolicy(DSPolicy): - _orig_layer_class = None - - def __init__(self, client_module, inference=True): - super().__init__(inference, scale_attention=False) - self.client_module = client_module - try: - import transformers - HFGPTNEOLayerPolicy._orig_layer_class = transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock - except: - HFGPTNEOLayerPolicy._orig_layer_class = None - - def get_hidden_heads(self): - return self.client_module.attn.attention.q_proj.weight.data.shape[1], \ - self.client_module.attn.attention.num_heads - - def attention(self): - qw = self.client_module.attn.attention.q_proj.weight.data - kw = self.client_module.attn.attention.k_proj.weight.data - vw = self.client_module.attn.attention.v_proj.weight.data - - qkvw = torch.cat((qw, kw, vw), dim=0) - - return self.linear_layer, \ - qkvw, \ - None, \ - self.client_module.attn.attention.out_proj.weight.data, \ - self.client_module.attn.attention.out_proj.bias.data, \ - self.scale_attention - - def mlp(self): - return self.linear_layer, \ - self.client_module.mlp.c_fc.weight.data, \ - self.client_module.mlp.c_fc.bias.data, \ - self.client_module.mlp.c_proj.weight.data, \ - self.client_module.mlp.c_proj.bias.data - - def layerNorm(self): - return self.client_module.ln_2.weight.data, \ - self.client_module.ln_2.bias.data, \ - self.client_module.ln_1.weight.data, \ - self.client_module.ln_1.bias.data - - -class MegatronLayerPolicy(DSPolicy): - _orig_layer_class = None - - def __init__(self, client_module, version=0, inference=True): - super().__init__(inference) - self.client_module = client_module - # we use megatron version to differentiate between the old and new - # megatron-lm source code - self.version = version - if MegatronLayerPolicy._orig_layer_class is None: - try: - import megatron - from megatron.model.transformer import ParallelTransformerLayer - MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer - except ImportError: - MegatronLayerPolicy._orig_layer_class = None - - def get_hidden_heads(self): - return self.client_module.attention.query_key_value.weight.data.shape[1], \ - self.client_module.attention.num_attention_heads - - def attention(self): - if self.inference: - if self.version == 0: - attention = self.client_module.attention - else: - attention = self.client_module.self_attention - - return self.linear_layer, \ - attention.query_key_value.weight.data, \ - attention.query_key_value.bias.data, \ - attention.dense.weight.data, \ - attention.dense.bias.data, \ - self.scale_attention - - def mlp(self): - return self.linear_layer, \ - self.client_module.mlp.dense_h_to_4h.weight.data, \ - self.client_module.mlp.dense_h_to_4h.bias.data, \ - self.client_module.mlp.dense_4h_to_h.weight.data, \ - self.client_module.mlp.dense_4h_to_h.bias.data - - def layerNorm(self): - return self.client_module.post_attention_layernorm.weight.data, \ - self.client_module.post_attention_layernorm.bias.data, \ - self.client_module.input_layernorm.weight.data, \ - self.client_module.input_layernorm.bias.data - - -class HFGPT2LayerPolicy(DSPolicy): - _orig_layer_class = None - - def __init__(self, client_module, inference=True): - # HuggingFace GPT2 uses convolutional layer instead of linear layer - super().__init__(inference, linear_layer=False) - self.client_module = client_module - try: - import transformers - HFGPT2LayerPolicy._orig_layer_class = transformers.models.gpt2.modeling_gpt2.GPT2Block - except ImportError: - HFGPT2LayerPolicy._orig_layer_class = None - - def get_hidden_heads(self): - return self.client_module.attn.embed_dim, \ - self.client_module.attn.num_heads - - def attention(self): - return self.linear_layer, \ - self.client_module.attn.c_attn.weight.data, \ - self.client_module.attn.c_attn.bias.data, \ - self.client_module.attn.c_proj.weight.data, \ - self.client_module.attn.c_proj.bias.data, \ - self.scale_attention - - def mlp(self): - return self.linear_layer, \ - self.client_module.mlp.c_fc.weight.data, \ - self.client_module.mlp.c_fc.bias.data, \ - self.client_module.mlp.c_proj.weight.data, \ - self.client_module.mlp.c_proj.bias.data - - def layerNorm(self): - return self.client_module.ln_2.weight.data, \ - self.client_module.ln_2.bias.data, \ - self.client_module.ln_1.weight.data, \ - self.client_module.ln_1.bias.data - - -replace_policies = [ - HFBertLayerPolicy, - HFGPTNEOLayerPolicy, - MegatronLayerPolicy, - HFGPT2LayerPolicy, -] +from abc import ABC + +import torch + + +class DSPolicy(ABC): + def __init__(self, inference=True, linear_layer=True, scale_attention=True): + self.inference = inference + self.linear_layer = linear_layer + self.scale_attention = scale_attention + + def attention(self): + """ + Returns attention qkv and dense parameters + weight: (3*hidden, hidden) and (hidden, hidden) + bias: (3*hidden) and (hidden) + """ + raise NotImplementedError + + def get_hidden_heads(self): + """ + return hidden_size and number of heads + """ + raise NotImplementedError + + def mlp(self): + """ + Returns mlp intermediate and output + weight: (intermediate, hidden) and (hidden, intermediate) + bias: (intermediate) and (hidden) + """ + raise NotImplementedError + + def layerNorm(self): + """ + Returns LayerNorms used in transformer layer + Post-Attention and pre/post layer norm + gamma and beta with shape: (hidden) + """ + raise NotImplementedError + + +class HFBertLayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, inference=False, preln=False): + super().__init__(inference) + self.client_module = client_module + self.preln = preln + if HFBertLayerPolicy._orig_layer_class is None: + try: + import transformers + HFBertLayerPolicy._orig_layer_class = transformers.models.bert.modeling_bert.BertLayer + except: + HFBertLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attention.self.query.weight.data.shape[1], \ + self.client_module.attention.self.num_attention_heads + + def attention(self): + qw = self.client_module.attention.self.query.weight.data + qb = self.client_module.attention.self.query.bias.data + kw = self.client_module.attention.self.key.weight.data + kb = self.client_module.attention.self.key.bias.data + vw = self.client_module.attention.self.value.weight.data + vb = self.client_module.attention.self.value.bias.data + + qkvw = torch.cat((qw, kw, vw), dim=0) + qkvb = torch.cat((qb, kb, vb), dim=0) + + return self.linear_layer, \ + qkvw, \ + qkvb, \ + self.client_module.attention.output.dense.weight.data, \ + self.client_module.attention.output.dense.bias.data, \ + self.scale_attention + + def mlp(self): + if self.preln: + intermediate_ff = self.client_module.intermediate.dense_act + else: + intermediate_ff = self.client_module.intermediate.dense + + return self.linear_layer, intermediate_ff.weight.data, intermediate_ff.bias.data, \ + self.client_module.output.dense.weight.data, \ + self.client_module.output.dense.bias.data + + def layerNorm(self): + if self.preln: + attention_layernorm = self.client_module.PostAttentionLayerNorm + transformer_layernorm = self.client_module.PreAttentionLayerNorm + else: + attention_layernorm = self.client_module.attention.output.LayerNorm + transformer_layernorm = self.client_module.output.LayerNorm + return attention_layernorm.weight.data, \ + attention_layernorm.bias.data, \ + transformer_layernorm.weight.data, \ + transformer_layernorm.bias.data + + +class HFGPTNEOLayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, inference=True): + super().__init__(inference, scale_attention=False) + self.client_module = client_module + try: + import transformers + HFGPTNEOLayerPolicy._orig_layer_class = transformers.models.gpt_neo.modeling_gpt_neo.GPTNeoBlock + except: + HFGPTNEOLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attn.attention.q_proj.weight.data.shape[1], \ + self.client_module.attn.attention.num_heads + + def attention(self): + qw = self.client_module.attn.attention.q_proj.weight.data + kw = self.client_module.attn.attention.k_proj.weight.data + vw = self.client_module.attn.attention.v_proj.weight.data + + qkvw = torch.cat((qw, kw, vw), dim=0) + + return self.linear_layer, \ + qkvw, \ + None, \ + self.client_module.attn.attention.out_proj.weight.data, \ + self.client_module.attn.attention.out_proj.bias.data, \ + self.scale_attention + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.c_fc.weight.data, \ + self.client_module.mlp.c_fc.bias.data, \ + self.client_module.mlp.c_proj.weight.data, \ + self.client_module.mlp.c_proj.bias.data + + def layerNorm(self): + return self.client_module.ln_2.weight.data, \ + self.client_module.ln_2.bias.data, \ + self.client_module.ln_1.weight.data, \ + self.client_module.ln_1.bias.data + + +class MegatronLayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, version=0, inference=True): + super().__init__(inference) + self.client_module = client_module + # we use megatron version to differentiate between the old and new + # megatron-lm source code + self.version = version + if MegatronLayerPolicy._orig_layer_class is None: + try: + import megatron + from megatron.model.transformer import ParallelTransformerLayer + MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer + except ImportError: + MegatronLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attention.query_key_value.weight.data.shape[1], \ + self.client_module.attention.num_attention_heads + + def attention(self): + if self.inference: + if self.version == 0: + attention = self.client_module.attention + else: + attention = self.client_module.self_attention + + return self.linear_layer, \ + attention.query_key_value.weight.data, \ + attention.query_key_value.bias.data, \ + attention.dense.weight.data, \ + attention.dense.bias.data, \ + self.scale_attention + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.dense_h_to_4h.weight.data, \ + self.client_module.mlp.dense_h_to_4h.bias.data, \ + self.client_module.mlp.dense_4h_to_h.weight.data, \ + self.client_module.mlp.dense_4h_to_h.bias.data + + def layerNorm(self): + return self.client_module.post_attention_layernorm.weight.data, \ + self.client_module.post_attention_layernorm.bias.data, \ + self.client_module.input_layernorm.weight.data, \ + self.client_module.input_layernorm.bias.data + + +class HFGPT2LayerPolicy(DSPolicy): + _orig_layer_class = None + + def __init__(self, client_module, inference=True): + # HuggingFace GPT2 uses convolutional layer instead of linear layer + super().__init__(inference, linear_layer=False) + self.client_module = client_module + try: + import transformers + HFGPT2LayerPolicy._orig_layer_class = transformers.models.gpt2.modeling_gpt2.GPT2Block + except ImportError: + HFGPT2LayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + return self.client_module.attn.embed_dim, \ + self.client_module.attn.num_heads + + def attention(self): + return self.linear_layer, \ + self.client_module.attn.c_attn.weight.data, \ + self.client_module.attn.c_attn.bias.data, \ + self.client_module.attn.c_proj.weight.data, \ + self.client_module.attn.c_proj.bias.data, \ + self.scale_attention + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.c_fc.weight.data, \ + self.client_module.mlp.c_fc.bias.data, \ + self.client_module.mlp.c_proj.weight.data, \ + self.client_module.mlp.c_proj.bias.data + + def layerNorm(self): + return self.client_module.ln_2.weight.data, \ + self.client_module.ln_2.bias.data, \ + self.client_module.ln_1.weight.data, \ + self.client_module.ln_1.bias.data + + +replace_policies = [ + HFBertLayerPolicy, + HFGPTNEOLayerPolicy, + MegatronLayerPolicy, + HFGPT2LayerPolicy, +] diff --git a/deepspeed/ops/adagrad/cpu_adagrad.py b/deepspeed/ops/adagrad/cpu_adagrad.py index 4c86a215893b..44b052fa3f67 100755 --- a/deepspeed/ops/adagrad/cpu_adagrad.py +++ b/deepspeed/ops/adagrad/cpu_adagrad.py @@ -1,135 +1,135 @@ -''' -Copyright 2020 The Microsoft DeepSpeed Team -''' - -import math -import torch -import time -from pathlib import Path -from ..op_builder import CPUAdagradBuilder -from deepspeed.utils.logging import should_log_le - - -class DeepSpeedCPUAdagrad(torch.optim.Optimizer): - optimizer_id = 0 - - def __init__(self, - model_params, - lr=1e-2, - eps=1e-10, - weight_decay=0, - amsgrad=False, - fp32_optimizer_states=True): - - default_args = dict(lr=lr, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) - super(DeepSpeedCPUAdagrad, self).__init__(model_params, default_args) - - self.opt_id = DeepSpeedCPUAdagrad.optimizer_id - DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1 - self.fp32_optimizer_states = fp32_optimizer_states - self.ds_opt_adagrad = CPUAdagradBuilder().load() - - self.ds_opt_adagrad.create_adagrad(self.opt_id, - lr, - eps, - weight_decay, - should_log_le("info")) - - def __del__(self): - # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize - # is used multiple times in the same process (notebook or pytest worker) - self.ds_opt_adagrad.destroy_adagrad(self.opt_id) - - def __setstate__(self, state): - super(DeepSpeedCPUAdagrad, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('amsgrad', False) - - @torch.no_grad() - def step(self, closure=None, fp16_param_groups=None): - """Update the model parameters. - - .. note:: - This method will be called internally by ZeRO-Offload. DeepSpeed - users should still use ``engine.step()`` as shown in the - `Getting Started - `_ guide. - - Args: - closure (callable, optional): closure to compute the loss. - Defaults to ``None``. - fp16_param_groups: FP16 GPU parameters to update. Performing the - copy here reduces communication time. Defaults to ``None``. - - Returns: - loss: if ``closure`` is provided. Otherwise ``None``. - """ - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group_id, group in enumerate(self.param_groups): - for param_id, p in enumerate(group['params']): - - if p.grad is None: - continue - - state = self.state[p] - # State initialization - if len(state) == 0: - #print(f'group {group_id} param {param_id} = {p.numel()}') - state['step'] = 0 - - #use full precision by default unless self.fp32_optimizer_states is off - state_dtype = torch.float if self.fp32_optimizer_states else p.dtype - - #memory_format=torch.preserve_format) - # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p.data, - dtype=state_dtype, - device='cpu') - #memory_format=torch.preserve_format) - - state['step'] += 1 - - if p.grad.is_sparse == True: - sparse_param = p.sparse_mask(p.grad) - sparse_exp_avg_sq = state['exp_avg_sq'].sparse_mask(p.grad) - self.ds_opt_adagrad.adagrad_update(self.opt_id, - state['step'], - group['lr'], - group['eps'], - group['weight_decay'], - sparse_param.values(), - p.grad.values(), - sparse_exp_avg_sq.values()) - p[sparse_param.indices()] = sparse_param.values() - state['exp_avg_sq'][ - sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values() - if fp16_param_groups is not None: - fp16_param_groups[group_id][param_id][ - sparse_param.indices()] = sparse_param.values() - else: - if fp16_param_groups is not None: - self.ds_opt_adagrad.adagrad_update_copy( - self.opt_id, - state['step'], - group['lr'], - group['eps'], - group['weight_decay'], - p.data, - p.grad.data, - state['exp_avg_sq'], - fp16_param_groups[group_id][param_id].data) - else: - self.ds_opt_adagrad.adagrad_update(self.opt_id, - state['step'], - group['lr'], - group['eps'], - group['weight_decay'], - p.data, - p.grad.data, - state['exp_avg_sq']) - return loss +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + +import math +import torch +import time +from pathlib import Path +from ..op_builder import CPUAdagradBuilder +from deepspeed.utils.logging import should_log_le + + +class DeepSpeedCPUAdagrad(torch.optim.Optimizer): + optimizer_id = 0 + + def __init__(self, + model_params, + lr=1e-2, + eps=1e-10, + weight_decay=0, + amsgrad=False, + fp32_optimizer_states=True): + + default_args = dict(lr=lr, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) + super(DeepSpeedCPUAdagrad, self).__init__(model_params, default_args) + + self.opt_id = DeepSpeedCPUAdagrad.optimizer_id + DeepSpeedCPUAdagrad.optimizer_id = DeepSpeedCPUAdagrad.optimizer_id + 1 + self.fp32_optimizer_states = fp32_optimizer_states + self.ds_opt_adagrad = CPUAdagradBuilder().load() + + self.ds_opt_adagrad.create_adagrad(self.opt_id, + lr, + eps, + weight_decay, + should_log_le("info")) + + def __del__(self): + # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize + # is used multiple times in the same process (notebook or pytest worker) + self.ds_opt_adagrad.destroy_adagrad(self.opt_id) + + def __setstate__(self, state): + super(DeepSpeedCPUAdagrad, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None, fp16_param_groups=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + fp16_param_groups: FP16 GPU parameters to update. Performing the + copy here reduces communication time. Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + + if p.grad is None: + continue + + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + #memory_format=torch.preserve_format) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p.data, + dtype=state_dtype, + device='cpu') + #memory_format=torch.preserve_format) + + state['step'] += 1 + + if p.grad.is_sparse == True: + sparse_param = p.sparse_mask(p.grad) + sparse_exp_avg_sq = state['exp_avg_sq'].sparse_mask(p.grad) + self.ds_opt_adagrad.adagrad_update(self.opt_id, + state['step'], + group['lr'], + group['eps'], + group['weight_decay'], + sparse_param.values(), + p.grad.values(), + sparse_exp_avg_sq.values()) + p[sparse_param.indices()] = sparse_param.values() + state['exp_avg_sq'][ + sparse_exp_avg_sq.indices()] = sparse_exp_avg_sq.values() + if fp16_param_groups is not None: + fp16_param_groups[group_id][param_id][ + sparse_param.indices()] = sparse_param.values() + else: + if fp16_param_groups is not None: + self.ds_opt_adagrad.adagrad_update_copy( + self.opt_id, + state['step'], + group['lr'], + group['eps'], + group['weight_decay'], + p.data, + p.grad.data, + state['exp_avg_sq'], + fp16_param_groups[group_id][param_id].data) + else: + self.ds_opt_adagrad.adagrad_update(self.opt_id, + state['step'], + group['lr'], + group['eps'], + group['weight_decay'], + p.data, + p.grad.data, + state['exp_avg_sq']) + return loss diff --git a/deepspeed/ops/adam/__init__.py b/deepspeed/ops/adam/__init__.py index 6e620b36bd8e..6ab6cbd37f35 100755 --- a/deepspeed/ops/adam/__init__.py +++ b/deepspeed/ops/adam/__init__.py @@ -1,2 +1,2 @@ -from .cpu_adam import DeepSpeedCPUAdam -from .fused_adam import FusedAdam +from .cpu_adam import DeepSpeedCPUAdam +from .fused_adam import FusedAdam diff --git a/deepspeed/ops/adam/cpu_adam.py b/deepspeed/ops/adam/cpu_adam.py index 5d6b597141c3..9304cdeacbde 100755 --- a/deepspeed/ops/adam/cpu_adam.py +++ b/deepspeed/ops/adam/cpu_adam.py @@ -1,186 +1,186 @@ -''' -Copyright 2020 The Microsoft DeepSpeed Team -''' - -import math -import torch -import time -from pathlib import Path -from ..op_builder import CPUAdamBuilder -from deepspeed.utils.logging import should_log_le - - -class DeepSpeedCPUAdam(torch.optim.Optimizer): - optimizer_id = 0 - - def __init__(self, - model_params, - lr=1e-3, - bias_correction=True, - betas=(0.9, - 0.999), - eps=1e-8, - weight_decay=0, - amsgrad=False, - adamw_mode=True, - fp32_optimizer_states=True): - """Fast vectorized implementation of two variations of Adam optimizer on CPU: - - * Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980); - * AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101) - - DeepSpeed CPU Adam(W) provides between 5x to 7x speedup over torch.optim.adam(W). - In order to apply this optimizer, the model requires to have its master parameter (in FP32) - reside on the CPU memory. - - To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers - the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory, - with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role to minimize - the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial - (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology. - - For calling step function, there are two options available: (1) update optimizer's states and (2) update - optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second - option can bring 30% higher throughput than the doing the copy separately using option one. - - - .. note:: - We recommend using our `config - `_ - to allow :meth:`deepspeed.initialize` to build this optimizer - for you. - - - Arguments: - model_params (iterable): iterable of parameters to optimize or dicts defining - parameter groups. - lr (float, optional): learning rate. (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square. (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability. (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - amsgrad (boolean, optional): whether to use the AMSGrad variant of this - algorithm from the paper `On the Convergence of Adam and Beyond`_ - (default: False) NOT SUPPORTED in DeepSpeed CPUAdam! - adamw_mode: select between Adam and AdamW implementations (default: AdamW) - full_precision_optimizer_states: creates momementum and variance in full precision regardless of - the precision of the parameters (default: True) - """ - - default_args = dict(lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - bias_correction=bias_correction, - amsgrad=amsgrad) - super(DeepSpeedCPUAdam, self).__init__(model_params, default_args) - - self.opt_id = DeepSpeedCPUAdam.optimizer_id - DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1 - self.adam_w_mode = adamw_mode - self.fp32_optimizer_states = fp32_optimizer_states - self.ds_opt_adam = CPUAdamBuilder().load() - - self.ds_opt_adam.create_adam(self.opt_id, - lr, - betas[0], - betas[1], - eps, - weight_decay, - adamw_mode, - should_log_le("info")) - - def __del__(self): - # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize - # is used multiple times in the same process (notebook or pytest worker) - self.ds_opt_adam.destroy_adam(self.opt_id) - - def __setstate__(self, state): - super(DeepSpeedCPUAdam, self).__setstate__(state) - for group in self.param_groups: - group.setdefault('amsgrad', False) - - @torch.no_grad() - def step(self, closure=None, fp16_param_groups=None): - """Update the model parameters. - - .. note:: - This method will be called internally by ZeRO-Offload. DeepSpeed - users should still use ``engine.step()`` as shown in the - `Getting Started - `_ guide. - - Args: - closure (callable, optional): closure to compute the loss. - Defaults to ``None``. - fp16_param_groups: FP16 GPU parameters to update. Performing the - copy here reduces communication time. Defaults to ``None``. - - Returns: - loss: if ``closure`` is provided. Otherwise ``None``. - """ - - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group_id, group in enumerate(self.param_groups): - for param_id, p in enumerate(group['params']): - - if p.grad is None: - continue - - state = self.state[p] - # State initialization - if len(state) == 0: - #print(f'group {group_id} param {param_id} = {p.numel()}') - state['step'] = 0 - - #use full precision by default unless self.fp32_optimizer_states is off - state_dtype = torch.float if self.fp32_optimizer_states else p.dtype - - # gradient momentums - state['exp_avg'] = torch.zeros_like(p.data, - dtype=state_dtype, - device='cpu') - #memory_format=torch.preserve_format) - # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p.data, - dtype=state_dtype, - device='cpu') - #memory_format=torch.preserve_format) - - state['step'] += 1 - beta1, beta2 = group['betas'] - - if fp16_param_groups is not None: - self.ds_opt_adam.adam_update_copy( - self.opt_id, - state['step'], - group['lr'], - beta1, - beta2, - group['eps'], - group['weight_decay'], - group['bias_correction'], - p.data, - p.grad.data, - state['exp_avg'], - state['exp_avg_sq'], - fp16_param_groups[group_id][param_id].data) - else: - self.ds_opt_adam.adam_update(self.opt_id, - state['step'], - group['lr'], - beta1, - beta2, - group['eps'], - group['weight_decay'], - group['bias_correction'], - p.data, - p.grad.data, - state['exp_avg'], - state['exp_avg_sq']) - return loss +''' +Copyright 2020 The Microsoft DeepSpeed Team +''' + +import math +import torch +import time +from pathlib import Path +from ..op_builder import CPUAdamBuilder +from deepspeed.utils.logging import should_log_le + + +class DeepSpeedCPUAdam(torch.optim.Optimizer): + optimizer_id = 0 + + def __init__(self, + model_params, + lr=1e-3, + bias_correction=True, + betas=(0.9, + 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + adamw_mode=True, + fp32_optimizer_states=True): + """Fast vectorized implementation of two variations of Adam optimizer on CPU: + + * Adam: A Method for Stochastic Optimization: (https://arxiv.org/abs/1412.6980); + * AdamW: Fixing Weight Decay Regularization in Adam (https://arxiv.org/abs/1711.05101) + + DeepSpeed CPU Adam(W) provides between 5x to 7x speedup over torch.optim.adam(W). + In order to apply this optimizer, the model requires to have its master parameter (in FP32) + reside on the CPU memory. + + To train on a heterogeneous system, such as coordinating CPU and GPU, DeepSpeed offers + the ZeRO-Offload technology which efficiently offloads the optimizer states into CPU memory, + with minimal impact on training throughput. DeepSpeedCPUAdam plays an important role to minimize + the overhead of the optimizer's latency on CPU. Please refer to ZeRO-Offload tutorial + (https://www.deepspeed.ai/tutorials/zero-offload/) for more information on how to enable this technology. + + For calling step function, there are two options available: (1) update optimizer's states and (2) update + optimizer's states and copy the parameters back to GPU at the same time. We have seen that the second + option can bring 30% higher throughput than the doing the copy separately using option one. + + + .. note:: + We recommend using our `config + `_ + to allow :meth:`deepspeed.initialize` to build this optimizer + for you. + + + Arguments: + model_params (iterable): iterable of parameters to optimize or dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED in DeepSpeed CPUAdam! + adamw_mode: select between Adam and AdamW implementations (default: AdamW) + full_precision_optimizer_states: creates momementum and variance in full precision regardless of + the precision of the parameters (default: True) + """ + + default_args = dict(lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + bias_correction=bias_correction, + amsgrad=amsgrad) + super(DeepSpeedCPUAdam, self).__init__(model_params, default_args) + + self.opt_id = DeepSpeedCPUAdam.optimizer_id + DeepSpeedCPUAdam.optimizer_id = DeepSpeedCPUAdam.optimizer_id + 1 + self.adam_w_mode = adamw_mode + self.fp32_optimizer_states = fp32_optimizer_states + self.ds_opt_adam = CPUAdamBuilder().load() + + self.ds_opt_adam.create_adam(self.opt_id, + lr, + betas[0], + betas[1], + eps, + weight_decay, + adamw_mode, + should_log_le("info")) + + def __del__(self): + # need to destroy the C++ object explicitly to avoid a memory leak when deepspeed.initialize + # is used multiple times in the same process (notebook or pytest worker) + self.ds_opt_adam.destroy_adam(self.opt_id) + + def __setstate__(self, state): + super(DeepSpeedCPUAdam, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None, fp16_param_groups=None): + """Update the model parameters. + + .. note:: + This method will be called internally by ZeRO-Offload. DeepSpeed + users should still use ``engine.step()`` as shown in the + `Getting Started + `_ guide. + + Args: + closure (callable, optional): closure to compute the loss. + Defaults to ``None``. + fp16_param_groups: FP16 GPU parameters to update. Performing the + copy here reduces communication time. Defaults to ``None``. + + Returns: + loss: if ``closure`` is provided. Otherwise ``None``. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group_id, group in enumerate(self.param_groups): + for param_id, p in enumerate(group['params']): + + if p.grad is None: + continue + + state = self.state[p] + # State initialization + if len(state) == 0: + #print(f'group {group_id} param {param_id} = {p.numel()}') + state['step'] = 0 + + #use full precision by default unless self.fp32_optimizer_states is off + state_dtype = torch.float if self.fp32_optimizer_states else p.dtype + + # gradient momentums + state['exp_avg'] = torch.zeros_like(p.data, + dtype=state_dtype, + device='cpu') + #memory_format=torch.preserve_format) + # gradient variances + state['exp_avg_sq'] = torch.zeros_like(p.data, + dtype=state_dtype, + device='cpu') + #memory_format=torch.preserve_format) + + state['step'] += 1 + beta1, beta2 = group['betas'] + + if fp16_param_groups is not None: + self.ds_opt_adam.adam_update_copy( + self.opt_id, + state['step'], + group['lr'], + beta1, + beta2, + group['eps'], + group['weight_decay'], + group['bias_correction'], + p.data, + p.grad.data, + state['exp_avg'], + state['exp_avg_sq'], + fp16_param_groups[group_id][param_id].data) + else: + self.ds_opt_adam.adam_update(self.opt_id, + state['step'], + group['lr'], + beta1, + beta2, + group['eps'], + group['weight_decay'], + group['bias_correction'], + p.data, + p.grad.data, + state['exp_avg'], + state['exp_avg_sq']) + return loss diff --git a/deepspeed/ops/aio/__init__.py b/deepspeed/ops/aio/__init__.py index 50e6c9a3c988..d25f815739aa 100755 --- a/deepspeed/ops/aio/__init__.py +++ b/deepspeed/ops/aio/__init__.py @@ -1,6 +1,6 @@ -''' -Copyright 2020 The Microsoft DeepSpeed Team. -Licensed under the MIT license. -''' - -from ..op_builder import AsyncIOBuilder +''' +Copyright 2020 The Microsoft DeepSpeed Team. +Licensed under the MIT license. +''' + +from ..op_builder import AsyncIOBuilder diff --git a/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py b/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py index 7a0bd4f4c0eb..6c134d71f2b5 100755 --- a/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py +++ b/deepspeed/ops/sparse_attention/bert_sparse_self_attention.py @@ -1,78 +1,78 @@ -""" -Copyright 2020 The Microsoft DeepSpeed Team -""" - -from torch import nn -from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig - - -class BertSparseSelfAttention(nn.Module): - """Implements Sparse Self Attention layer of Bert model based on https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373 - - For more information please see, TODO DeepSpeed Sparse Transformer. - - For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial. - """ - def __init__( - self, - config, - # SparsityConfig parameters needs to be set accordingly - sparsity_config=FixedSparsityConfig(num_heads=4)): - """Initialize the bert sparse self attention layer. - - Note) you can use any of the provided sparsity configs or simply add yours! - - Arguments: - config: required: Bert model config - sparsity_config: optional: this parameter determines sparsity pattern configuration; it is based on FixedSparsityConfig class. - """ - - super(BertSparseSelfAttention, self).__init__() - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, - config.num_attention_heads)) - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.sparse_self_attention = SparseSelfAttention(sparsity_config) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, attention_mask): - """Applies forward phase of bert sparse self attention - - Arguments: - hidden_states: required: hidden_states tensor of the bert model - attn_mask: required: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported - - Return: - context_layer: a dense tensor containing attention context - """ - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - context_layer = self.sparse_self_attention(query_layer, - key_layer, - value_layer, - key_padding_mask=attention_mask) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) - context_layer = context_layer.view(*new_context_layer_shape) - return context_layer +""" +Copyright 2020 The Microsoft DeepSpeed Team +""" + +from torch import nn +from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig + + +class BertSparseSelfAttention(nn.Module): + """Implements Sparse Self Attention layer of Bert model based on https://github.com/microsoft/DeepSpeedExamples/blob/master/bing_bert/nvidia/modelingpreln.py#L373 + + For more information please see, TODO DeepSpeed Sparse Transformer. + + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial. + """ + def __init__( + self, + config, + # SparsityConfig parameters needs to be set accordingly + sparsity_config=FixedSparsityConfig(num_heads=4)): + """Initialize the bert sparse self attention layer. + + Note) you can use any of the provided sparsity configs or simply add yours! + + Arguments: + config: required: Bert model config + sparsity_config: optional: this parameter determines sparsity pattern configuration; it is based on FixedSparsityConfig class. + """ + + super(BertSparseSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, + config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.sparse_self_attention = SparseSelfAttention(sparsity_config) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + """Applies forward phase of bert sparse self attention + + Arguments: + hidden_states: required: hidden_states tensor of the bert model + attn_mask: required: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported + + Return: + context_layer: a dense tensor containing attention context + """ + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + context_layer = self.sparse_self_attention(query_layer, + key_layer, + value_layer, + key_padding_mask=attention_mask) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, ) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer diff --git a/deepspeed/runtime/activation_checkpointing/checkpointing.py b/deepspeed/runtime/activation_checkpointing/checkpointing.py index 8bf3c5d586b1..9a5bee886629 100644 --- a/deepspeed/runtime/activation_checkpointing/checkpointing.py +++ b/deepspeed/runtime/activation_checkpointing/checkpointing.py @@ -1,912 +1,912 @@ -''' -Copyright (c) Microsoft Corporation -Licensed under the MIT license. - -Use to partition the activations stored for backward propagation -Therefore reduces the memory consumption -Also implements CPU checkpointing and contiguous memory checkpointing -Reduces memory consumption and memory fragmentation - -Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py -b886b7bb972afe72bac0f5de4f42a4a7bae8ebef -''' - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -import copy -import torch -import contextlib -import torch.distributed as dist - -import mmap -from torch import _C -from torch.cuda import _lazy_call, device as device_ctx_manager - -from deepspeed.runtime.config import DeepSpeedConfig -from deepspeed.utils import logger -from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank -from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers - -# DeepSpeed Checkpointing Enabled or Disabled -deepspeed_checkpointing_enabled = False - -# MP parameters -mpu = None -mp_rank = None -mp_size = None -mp_group = None - -# Model Parameters -num_layers = None - -# Checkpointing buffers -contiguous_data_buffers = [] -data_offsets = [] - -contiguous_size_buffers = [] -size_offsets = [] - -timers = None - -# optimization flags -PARTITION_ACTIVATIONS = False -CPU_CHECKPOINT = False -CONTIGUOUS_CHECKPOINTING = False -SYNCHRONIZE = False -PROFILE_TIME = False - -# Default name for the model parallel rng tracker. -_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' -transport_stream = None -cuda_device = None - - -def detach_variable(inputs, device=None): - if isinstance(inputs, tuple): - out = [] - for inp in inputs: - if not isinstance(inp, torch.Tensor): - out.append(inp) - continue - - requires_grad = inp.requires_grad - - if device is not None: - x = inp.to(device=device) - else: - x = inp - - x = x.detach() - x.requires_grad = requires_grad - out.append(x) - return tuple(out) - else: - raise RuntimeError( - "Only tuple of tensors is supported. Got Unsupported input type: ", - type(inputs).__name__) - - -def _set_cuda_rng_state(new_state, device=-1): - """Sets the random number generator state of the current GPU. - - Arguments: - new_state (torch.ByteTensor): The desired state - This function is adapted from PyTorch repo (torch.cuda.set_rng_state) - with a single change: the input state is not cloned. Cloning caused - major performance issues for +4 GPU cases. - """ - if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): - # older PyTorch - def cb(): - with device_ctx_manager(device): - _C._cuda_setRNGState(new_state) - else: - # newer PyTorch - if device == -1: - device = torch.device('cuda') - elif isinstance(device, str): - device = torch.device(device) - elif isinstance(device, int): - device = torch.device('cuda', device) - - def cb(): - idx = device.index - if idx is None: - idx = torch.cuda.current_device() - default_generator = torch.cuda.default_generators[idx] - default_generator.set_state(new_state) - - _lazy_call(cb) - - -class CudaRNGStatesTracker: - """Tracker for the cuda RNG states. - - Using the `add` method, a cuda rng state is initialized based on - the input `seed` and is assigned to `name`. Later, by forking the - rng state, we can perform operations and return to our starting - cuda state. - """ - def __init__(self): - # Map from a string name to the cuda rng state. - self.states_ = {} - # Seeds are just for book keeping and ensure no seed is set twice. - self.seeds_ = set() - - def reset(self): - """Set to the initial state (no tracker).""" - self.states_ = {} - self.seeds_ = set() - - def get_states(self): - """Get rng states. Copy the dictionary so we have direct - pointers to the states, not just a pointer to the dictionary.""" - return copy.copy(self.states_) - - def set_states(self, states): - """Set the rng states. For efficiency purposes, we do not check - the size of seed for compatibility.""" - self.states_ = states - - def add(self, name, seed): - """Track the rng state.""" - # Check seed is not already used. - if seed in self.seeds_: - raise Exception('seed {} already exists'.format(seed)) - self.seeds_.add(seed) - # Check that state is not already defined. - if name in self.states_: - raise Exception('cuda rng state {} already exists'.format(name)) - # Get the current rng state. - orig_rng_state = torch.cuda.get_rng_state() - # Set the new state and store it. - torch.cuda.manual_seed(seed) - self.states_[name] = torch.cuda.get_rng_state() - # Reset rng state to what it was. - _set_cuda_rng_state(orig_rng_state) - - @contextlib.contextmanager - def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): - """Fork the cuda rng state, perform operations, and exit with - the original state.""" - # Check if we have added the state - if name not in self.states_: - raise Exception('cuda rng state {} is not added'.format(name)) - # Store current rng state. - orig_cuda_rng_state = torch.cuda.get_rng_state() - # Set rng state to the desired one - _set_cuda_rng_state(self.states_[name]) - # Do the stuff we wanted to do. - try: - yield - finally: - # Update the current rng state for later use. - self.states_[name] = torch.cuda.get_rng_state() - # And set the state to the original state we started with. - _set_cuda_rng_state(orig_cuda_rng_state) - - -# RNG tracker object. -_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() - - -def get_cuda_rng_tracker(): - """Get cuda rng tracker.""" - return _CUDA_RNG_STATE_TRACKER - - -def model_parallel_cuda_manual_seed(seed): - """Initialize model parallel cuda seed. - - This function should be called after the model parallel is - initialized. Also, no torch.cuda.manual_seed should be called - after this function. Basically, this is replacement for that - function. - Two set of RNG states are tracked: - default state: This is for data parallelism and is the same among a - set of model parallel GPUs but different across - different model parallel groups. This is used for - example for dropout in the non-model-parallel regions. - model-parallel state: This state is different among a set of model - parallel GPUs, but the same across data parallel - groups. This is used for example for dropout in - model parallel regions. - """ - global mpu - - tp_rank = bwc_tensor_model_parallel_rank(mpu) - - # 2718 is just for fun and any POSITIVE value will work. - offset = seed + 2718 - model_parallel_seed = offset + tp_rank - # Data parallel gets the original seed. - data_parallel_seed = seed - - if torch.distributed.get_rank() == 0: - logger.info( - '> initializing model parallel cuda seeds on global rank {}, ' - 'model parallel rank {}, and data parallel rank {} with ' - 'model parallel seed: {} and data parallel seed: {}'.format( - torch.distributed.get_rank(), - tp_rank, - mpu.get_data_parallel_rank(), - model_parallel_seed, - data_parallel_seed), - ) - _CUDA_RNG_STATE_TRACKER.reset() - # Set the default state. - torch.cuda.manual_seed(data_parallel_seed) - # and model parallel state. - _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) - - -def get_partition_start(item): - global mp_rank, mp_size, mp_group - size = item.numel() - partition_size = size / mp_size - start = partition_size * mp_rank - return int(start) - - -def get_partition_size(item): - global mp_rank, mp_size, mp_group - size = item.numel() - assert size % mp_size == 0, "Doesn't handle if partition activation if item is not divisible by mp size" - partition_size = size / mp_size - return int(partition_size) - - -def gather_partitioned_activations(tensors, device=None): - global mp_rank, mp_size, mp_group - assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}' - inputs = [] - num_args = int(len(tensors) / 2) - for i in range(num_args): - - item = tensors[2 * i] - size = tensors[2 * i + 1] - - if not is_activation_to_checkpoint(item): - inputs.append(item) - continue - - partition_size = item.numel() - tensor_size = partition_size * mp_size - if device is not None: - flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device) - else: - flat_tensor = torch.zeros([tensor_size], - dtype=item.dtype, - device=item.device) - partitions = [] - for i in range(mp_size): - part_i = flat_tensor.narrow(0, partition_size * i, partition_size) - if i == mp_rank: - part_i.copy_(item) - partitions.append(part_i) - if mp_group is not None: - dist.all_gather(partitions, partitions[mp_rank], group=mp_group) - input_tensor = flat_tensor.view(list(size.numpy())) - item.data = input_tensor.data - - inputs.append(item) - - return tuple(inputs) - - -def extract_tensors(all_objects): - """ - Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation. - The order of tensors and non-tensors is preserved in their respective output groups. - - Parameters: - all_objects (list/tuple): Objects containing tensors and non-tensors to be split. - - Returns: - tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor. - - """ - tensor_objects = [v for v in all_objects if torch.is_tensor(v)] - non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)] - tensor_flags = [torch.is_tensor(v) for v in all_objects] - if type(all_objects) is tuple: - return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags) - return tensor_objects, non_tensor_objects, tensor_flags - - -def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags): - """ - Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple). - - Parameters: - tensor_objects (list/tuple): Tensors to merge. - non_tensor_objects (list/tuple): Non-tensors to merge. - tensor_flags (list/tuple): Indicates whether each position in output is a tensor. - - Returns: - tuple: Merge of tensors and non-tensors - """ - merged_objects = [] - tensor_idx = 0 - non_tensor_idx = 0 - - real_tensor_flags = None - - # remove the flags that are assigned to the size of the flattened tensors - if PARTITION_ACTIVATIONS: - real_tensor_flags = [] - previous_flag = False - for flag in tensor_flags: - if previous_flag: - previous_flag = False - continue - previous_flag = flag - real_tensor_flags.append(flag) - else: - real_tensor_flags = tensor_flags - - for is_tensor in real_tensor_flags: - if is_tensor: - merged_objects.append(tensor_objects[tensor_idx]) - tensor_idx += 1 - else: - merged_objects.append(non_tensor_objects[non_tensor_idx]) - non_tensor_idx += 1 - - return tuple(merged_objects) - - -def is_activation_to_checkpoint(item): - """ - Is an activation to be checkpointed - """ - global mp_size - return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size - - -def partition_activations(args, cpu_checkpoint, contiguous_checkpoint): - global contiguous_data_buffers, data_offsets - - inputs = [] - num_non_fp_tensors = 0 - - for arg_index, item in enumerate(args): - if not is_activation_to_checkpoint(item): - inputs.append(item) - num_non_fp_tensors += 1 - continue - - i = arg_index - num_non_fp_tensors - partition_size = get_partition_size(item) - partition = item.detach().contiguous().view(-1).narrow( - 0, - get_partition_start(item), - partition_size).clone() - - buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device - - if contiguous_checkpoint: - if i >= len(contiguous_data_buffers): - tensor_list = [ - torch.tensor(()).new_empty([partition_size], - dtype=partition.dtype, - device=buffer_device) - for _ in range(num_layers) - ] - contiguous_data_buffers.append(tensor_list) - data_offsets.append(0) - elif contiguous_data_buffers[i] is None: - tensor_list = [ - torch.tensor(()).new_empty([partition_size], - dtype=partition.dtype, - device=buffer_device) - for _ in range(num_layers) - ] - contiguous_data_buffers[i] = tensor_list - data_offsets[i] = 0 - - # Because the 'new_empty' returns uninitialized pages, - # the pages need to be populated during the cudaMemcpy time - # which increases the data copy time. To avoid this, we - # pre-populate these pages by simply writing 0 ahead of - # the actual cudaMemcpy operation time. Due to the - # previously launched GPU kernels, there is a small - # window of time here for CPUs to populate pages asynchronously. - contiguous_data_buffers[i][data_offsets[i]].data[range( - 0, - contiguous_data_buffers[i][data_offsets[i]].data.shape[0], - int(mmap.PAGESIZE / - contiguous_data_buffers[i][data_offsets[i]].data.element_size()) - )] = 0 - - contiguous_partition = contiguous_data_buffers[i][ - data_offsets[i]].data.copy_(partition.data) - data_offsets[i] = data_offsets[i] + 1 - inputs.append(contiguous_partition) - else: - partition = partition.cpu() if CPU_CHECKPOINT else partition - inputs.append(partition) - - return inputs - - -def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint): - global contiguous_size_buffers, size_offsets - - new_args = [] - num_non_fp_tensors = 0 - - for arg_index, (arg, inp) in enumerate(zip(args, inputs)): - size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None - if not is_activation_to_checkpoint(arg): - new_args.append(arg) - new_args.append(size) - num_non_fp_tensors += 1 - continue - - arg.data = inp.data - new_args.append(arg) - i = arg_index - num_non_fp_tensors - - if contiguous_checkpoint: - numel = size.numel() - if i >= len(contiguous_size_buffers): - tmp = torch.tensor(()) - contiguous_size_buffers.append( - tmp.new_empty([numel * num_layers], - dtype=size.dtype, - device=size.device)) - size_offsets.append(0) - elif contiguous_size_buffers[i] is None: - tmp = torch.tensor(()) - contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], - dtype=size.dtype, - device=size.device) - size_offsets[i] = 0 - - contiguous_size = contiguous_size_buffers[i].narrow( - 0, - size_offsets[i], - numel).data.copy_(size.data) - contiguous_size = contiguous_size.view_as(size) - size_offsets[i] = size_offsets[i] + numel - new_args.append(contiguous_size) - else: - new_args.append(size) - - return new_args - - -def get_cpu_activations_for_backward(args, inputs): - new_args = [] - for i, (arg, inp) in enumerate(zip(args, inputs)): - if not is_activation_to_checkpoint(arg): - new_args.append(arg) - continue - - arg.data = inp.data - new_args.append(arg) - - return new_args - - -class CheckpointFunction(torch.autograd.Function): - """This function is adapted from torch.utils.checkpoint with - two main changes: - 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` - 2) the states in the model parallel tracker are also properly - tracked/set/reset. - 3) Performance activation partitioning, contiguous memory optimization - 4) CPU Checkpointing - 5) Profile forward and backward functions - """ - @staticmethod - def forward(ctx, run_function, all_outputs, *args): - global mpu, timers, SYNCHRONIZE, PROFILE_TIME - - def save_args_for_backward(*all_args): - tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args) - ctx.deepspeed_saved_tensors = tensor_args - ctx.non_tensor_args = non_tensor_args - ctx.tensor_flags = tensor_flags - - if SYNCHRONIZE: - torch.cuda.synchronize() - - if timers is None and PROFILE_TIME: - timers = Timers() - - if PROFILE_TIME: - timers('forward').start() - - ctx.run_function = run_function - global num_layers - global mp_rank, mp_size, mp_group - global contiguous_data_buffers, contiguous_size_buffers - global data_offsets, size_offsets - if mp_rank is None: - if mpu is not None: - if hasattr(mpu, 'get_tensor_model_parallel_rank'): - mp_rank = mpu.get_tensor_model_parallel_rank() - mp_size = mpu.get_tensor_model_parallel_world_size() - mp_group = mpu.get_tensor_model_parallel_group() - else: - mp_rank = mpu.get_model_parallel_rank() - mp_size = mpu.get_model_parallel_world_size() - mp_group = mpu.get_model_parallel_group() - else: - mp_rank = 0 - mp_size = 1 - mp_group = None - - global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset - - if cuda_device is None: - see_memory_usage("First Forward Beginning", force=False) - if dist.get_rank() == 0: - logger.info(f"Activation Checkpointing Information") - logger.info( - f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}" - ) - logger.info( - f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers" - ) - logger.info(f"----Synchronization {SYNCHRONIZE}") - logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}") - - cuda_device = torch.cuda.current_device() - transport_stream = torch.cuda.Stream(device=cuda_device) - - if PARTITION_ACTIVATIONS: - inputs = partition_activations(args, - CPU_CHECKPOINT, - CONTIGUOUS_CHECKPOINTING) - elif CPU_CHECKPOINT: - inputs = copy_to_device(args, - device=torch.device('cpu'), - criterion_func=is_activation_to_checkpoint) - - # just in case something funky is happening such as reuse of inputs - inputs_cuda = copy_to_device(args, - device=cuda_device, - criterion_func=is_activation_to_checkpoint) - - # Copy the rng states. - ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - see_memory_usage("Before running forward on the layer", force=False) - # ctx.save_for_backward(*args) - with torch.no_grad(): - outputs = run_function(*inputs_cuda) - - see_memory_usage("After running forward on the layer", force=False) - del inputs_cuda - - if PARTITION_ACTIVATIONS: - new_args = get_partitioned_activations_for_backward( - args, - inputs, - CONTIGUOUS_CHECKPOINTING) - assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}' - save_args_for_backward(*new_args) - elif CPU_CHECKPOINT: - new_args = get_cpu_activations_for_backward(args, inputs) - save_args_for_backward(*new_args) - else: - save_args_for_backward(*args) - - if PROFILE_TIME: - timers('forward').stop() - timers.log(['forward']) - if SYNCHRONIZE: - torch.cuda.synchronize() - - # Tensors returned from forward() may not be differentiable. - if torch.is_tensor(outputs): - non_grad_outputs = [outputs] if not outputs.is_floating_point() else [] - else: - non_grad_outputs = [ - o for o in outputs if torch.is_tensor(o) and not o.is_floating_point() - ] - ctx.mark_non_differentiable(*non_grad_outputs) - - if torch.is_tensor(outputs): - all_outputs += [outputs] - return outputs - else: - all_outputs += outputs - outputs, _, _ = extract_tensors(all_objects=outputs) - return tuple(outputs) - - @staticmethod - def backward(ctx, *grads): - global timers - see_memory_usage("In backward", force=False) - # removing pointers to the contiguous buffer memory - # so that they can be garbage collected once the checkpoints - # have been used - if SYNCHRONIZE: - torch.cuda.synchronize() - if PROFILE_TIME: - timers('backward').start() - - if CONTIGUOUS_CHECKPOINTING: - global data_offsets, size_offsets - global contiguous_data_buffers, contiguous_size_buffers - - for buffers in contiguous_data_buffers: - buffers = [] - - # frees up all the pointers to the checkpoints except for the ones - # stored by save for backward - contiguous_data_buffers = [] - contiguous_size_buffers = [] - data_offsets = [] - size_offsets = [] - - see_memory_usage("In backward checkpointing code", force=False) - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError("Checkpointing is not compatible with .grad(), " - "please use .backward() if possible") - - global cuda_device, transport_stream, PARTITION_ACTIVATIONS - - if PARTITION_ACTIVATIONS: - # with torch.cuda.stream(transport_stream): - inputs = gather_partitioned_activations( - ctx.deepspeed_saved_tensors, - device=cuda_device if CPU_CHECKPOINT else None) - detached_inputs = detach_variable(inputs) - elif CPU_CHECKPOINT: - inputs = move_to_device(ctx.deepspeed_saved_tensors, - cuda_device, - is_activation_to_checkpoint) - detached_inputs = detach_variable(inputs) - else: - inputs = ctx.deepspeed_saved_tensors - detached_inputs = detach_variable(inputs) - - # Add non tensor input args - detached_inputs = merge_tensors(tensor_objects=detached_inputs, - non_tensor_objects=ctx.non_tensor_args, - tensor_flags=ctx.tensor_flags) - - # Store the current states. - bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() - bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - # Set the states to what it used to be before the forward pass. - torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) - - # if PARTITION_ACTIVATIONS: - # current_stream=torch.cuda.current_stream() - # current_stream.wait_stream(transport_stream) - - see_memory_usage("In backward checkpointing code before forward", force=False) - - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - - see_memory_usage("In backward checkpointing code after forward", force=False) - # Set the states back to what it was at the start of this function. - torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) - - if isinstance(outputs, torch.Tensor): - outputs = (outputs, ) - - # Filter out non tensor outputs - outputs, _, _ = extract_tensors(all_objects=outputs) - - # Construct arguments to autograd.backward(). - # This is usually just outputs and grads, but forward() can return tensors that - # are not differentiable. - output_tensors = [] - grad_tensors = [] - for out, grad in zip(outputs, grads): - if out.requires_grad: - output_tensors.append(out) - grad_tensors.append(grad) - - see_memory_usage("In backward checkpointing code before backward", force=False) - - torch.autograd.backward(output_tensors, grad_tensors) - - see_memory_usage("After backward checkpointing code after backward", force=False) - - if PROFILE_TIME: - timers('backward').stop() - timers.log(['backward']) - if SYNCHRONIZE: - torch.cuda.synchronize() - ret_list = [None, None] # first None for ctx - for inp in detached_inputs: - if torch.is_tensor(inp): - ret_list.append(inp.grad) - else: - ret_list.append(None) - - return tuple(ret_list) - - -def checkpoint(function, *args): - """Checkpoint a model or part of the model. - This has been directly copied from torch.utils.checkpoint. """ - - all_outputs = [] - CheckpointFunction.apply(function, all_outputs, *args) - if len(all_outputs) == 1: - return all_outputs[0] - else: - return tuple(all_outputs) - - -def partition_activations_in_checkpoint(partition_activation): - global PARTITION_ACTIVATIONS - PARTITION_ACTIVATIONS = partition_activation - if dist.get_rank() == 0: - logger.info( - f"**************Partition Activations {PARTITION_ACTIVATIONS}************") - - -def set_num_layers(nlayers): - global num_layers - num_layers = nlayers - - -def reset(): - """Resets memory buffers related to contiguous memory optimizations. - Should be called during eval when multiple forward propagations are - computed without any backward propagation that usually clears these - buffers. - Arguments: - None - - Return: - None - """ - if CONTIGUOUS_CHECKPOINTING: - global data_offsets, size_offsets - global contiguous_data_buffers, contiguous_size_buffers - - for buffers in contiguous_data_buffers: - buffers = [] - - # frees up all the pointers to the checkpoints except for the ones - # stored by save for backward - contiguous_data_buffers = [] - contiguous_size_buffers = [] - data_offsets = [] - size_offsets = [] - - -def _configure_using_config_file(config, mpu=None): - global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ - CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME - - config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config - if dist.get_rank() == 0: - logger.info(config.repr()) - PARTITION_ACTIVATIONS = config.partition_activations - CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization - num_layers = config.number_checkpoints - CPU_CHECKPOINT = config.cpu_checkpointing - SYNCHRONIZE = config.synchronize_checkpoint_boundary - PROFILE_TIME = config.profile - - -def _configure_defaults(): - - global mpu, num_layers, deepspeed_checkpointing_enabled - - global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ - CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME - - PARTITION_ACTIVATIONS = False - CONTIGUOUS_CHECKPOINTING = False - num_layers = False - CPU_CHECKPOINT = False - SYNCHRONIZE = False - PROFILE_TIME = False - deepspeed_checkpointing_enabled = True - - -def configure( - mpu_, - deepspeed_config=None, - partition_activations=None, - contiguous_checkpointing=None, - num_checkpoints=None, - checkpoint_in_cpu=None, - synchronize=None, - profile=None, -): - """Configure DeepSpeed Activation Checkpointing. - - Arguments: - mpu_: Optional: An object that implements the following methods - get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size - - deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to - configure DeepSpeed Activation Checkpointing - - partition_activations: Optional: Partitions activation checkpoint across model parallel - GPUs when enabled. By default False. Will overwrite deepspeed_config if provided - - contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory - buffer. Works only with homogeneous checkpoints when partition_activations is enabled. - Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if - provided - - num_checkpoints: Optional: Number of activation checkpoints stored during the forward - propagation of the model. Used to calculate the buffer size for contiguous_checkpointing - Will overwrite deepspeed_config if provided - - checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with - partition_activation. Default is false. Will overwrite deepspeed_config if provided - - synchronize: Optional: Performs torch.cuda.synchronize() at the beginning and end of - each call to deepspeed.checkpointing.checkpoint for both forward and backward pass. - By default false. Will overwrite deepspeed_config if provided - - profile: Optional: Logs the forward and backward time for each - deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config - if provided - - Returns: - None - """ - global mpu, num_layers, deepspeed_checkpointing_enabled - - global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ - CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME - - _configure_defaults() - - if mpu_ is not None: - mpu = mpu_ - - if deepspeed_config is not None: - _configure_using_config_file(deepspeed_config, mpu=mpu) - - if partition_activations is not None: - PARTITION_ACTIVATIONS = partition_activations - - if contiguous_checkpointing is not None: - CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing - - if num_checkpoints is not None: - num_layers = num_checkpoints - - if checkpoint_in_cpu is not None: - CPU_CHECKPOINT = checkpoint_in_cpu - - if synchronize is not None: - SYNCHRONIZE = synchronize - - if profile is not None: - PROFILE_TIME = profile - - if CONTIGUOUS_CHECKPOINTING: - assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only available with partitioned activations. Set partitioned activations to true in deepspeed config" - if CONTIGUOUS_CHECKPOINTING: - assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing" - - -def is_configured(): - """True if deepspeed activation checkpointing has been configured - by calling deepspeed.checkpointing.configure, else returns false - - Arguments: - None - - Return: - True of configured, else False - """ - return deepspeed_checkpointing_enabled +''' +Copyright (c) Microsoft Corporation +Licensed under the MIT license. + +Use to partition the activations stored for backward propagation +Therefore reduces the memory consumption +Also implements CPU checkpointing and contiguous memory checkpointing +Reduces memory consumption and memory fragmentation + +Code for rng checkpointing taken from NVIDIA Megatron-LM mpu/random.py +b886b7bb972afe72bac0f5de4f42a4a7bae8ebef +''' + +# Parts of the code here are adapted from PyTorch +# repo: https://github.com/pytorch/pytorch +import copy +import torch +import contextlib +import torch.distributed as dist + +import mmap +from torch import _C +from torch.cuda import _lazy_call, device as device_ctx_manager + +from deepspeed.runtime.config import DeepSpeedConfig +from deepspeed.utils import logger +from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank +from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers + +# DeepSpeed Checkpointing Enabled or Disabled +deepspeed_checkpointing_enabled = False + +# MP parameters +mpu = None +mp_rank = None +mp_size = None +mp_group = None + +# Model Parameters +num_layers = None + +# Checkpointing buffers +contiguous_data_buffers = [] +data_offsets = [] + +contiguous_size_buffers = [] +size_offsets = [] + +timers = None + +# optimization flags +PARTITION_ACTIVATIONS = False +CPU_CHECKPOINT = False +CONTIGUOUS_CHECKPOINTING = False +SYNCHRONIZE = False +PROFILE_TIME = False + +# Default name for the model parallel rng tracker. +_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' +transport_stream = None +cuda_device = None + + +def detach_variable(inputs, device=None): + if isinstance(inputs, tuple): + out = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + out.append(inp) + continue + + requires_grad = inp.requires_grad + + if device is not None: + x = inp.to(device=device) + else: + x = inp + + x = x.detach() + x.requires_grad = requires_grad + out.append(x) + return tuple(out) + else: + raise RuntimeError( + "Only tuple of tensors is supported. Got Unsupported input type: ", + type(inputs).__name__) + + +def _set_cuda_rng_state(new_state, device=-1): + """Sets the random number generator state of the current GPU. + + Arguments: + new_state (torch.ByteTensor): The desired state + This function is adapted from PyTorch repo (torch.cuda.set_rng_state) + with a single change: the input state is not cloned. Cloning caused + major performance issues for +4 GPU cases. + """ + if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): + # older PyTorch + def cb(): + with device_ctx_manager(device): + _C._cuda_setRNGState(new_state) + else: + # newer PyTorch + if device == -1: + device = torch.device('cuda') + elif isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device('cuda', device) + + def cb(): + idx = device.index + if idx is None: + idx = torch.cuda.current_device() + default_generator = torch.cuda.default_generators[idx] + default_generator.set_state(new_state) + + _lazy_call(cb) + + +class CudaRNGStatesTracker: + """Tracker for the cuda RNG states. + + Using the `add` method, a cuda rng state is initialized based on + the input `seed` and is assigned to `name`. Later, by forking the + rng state, we can perform operations and return to our starting + cuda state. + """ + def __init__(self): + # Map from a string name to the cuda rng state. + self.states_ = {} + # Seeds are just for book keeping and ensure no seed is set twice. + self.seeds_ = set() + + def reset(self): + """Set to the initial state (no tracker).""" + self.states_ = {} + self.seeds_ = set() + + def get_states(self): + """Get rng states. Copy the dictionary so we have direct + pointers to the states, not just a pointer to the dictionary.""" + return copy.copy(self.states_) + + def set_states(self, states): + """Set the rng states. For efficiency purposes, we do not check + the size of seed for compatibility.""" + self.states_ = states + + def add(self, name, seed): + """Track the rng state.""" + # Check seed is not already used. + if seed in self.seeds_: + raise Exception('seed {} already exists'.format(seed)) + self.seeds_.add(seed) + # Check that state is not already defined. + if name in self.states_: + raise Exception('cuda rng state {} already exists'.format(name)) + # Get the current rng state. + orig_rng_state = torch.cuda.get_rng_state() + # Set the new state and store it. + torch.cuda.manual_seed(seed) + self.states_[name] = torch.cuda.get_rng_state() + # Reset rng state to what it was. + _set_cuda_rng_state(orig_rng_state) + + @contextlib.contextmanager + def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): + """Fork the cuda rng state, perform operations, and exit with + the original state.""" + # Check if we have added the state + if name not in self.states_: + raise Exception('cuda rng state {} is not added'.format(name)) + # Store current rng state. + orig_cuda_rng_state = torch.cuda.get_rng_state() + # Set rng state to the desired one + _set_cuda_rng_state(self.states_[name]) + # Do the stuff we wanted to do. + try: + yield + finally: + # Update the current rng state for later use. + self.states_[name] = torch.cuda.get_rng_state() + # And set the state to the original state we started with. + _set_cuda_rng_state(orig_cuda_rng_state) + + +# RNG tracker object. +_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() + + +def get_cuda_rng_tracker(): + """Get cuda rng tracker.""" + return _CUDA_RNG_STATE_TRACKER + + +def model_parallel_cuda_manual_seed(seed): + """Initialize model parallel cuda seed. + + This function should be called after the model parallel is + initialized. Also, no torch.cuda.manual_seed should be called + after this function. Basically, this is replacement for that + function. + Two set of RNG states are tracked: + default state: This is for data parallelism and is the same among a + set of model parallel GPUs but different across + different model parallel groups. This is used for + example for dropout in the non-model-parallel regions. + model-parallel state: This state is different among a set of model + parallel GPUs, but the same across data parallel + groups. This is used for example for dropout in + model parallel regions. + """ + global mpu + + tp_rank = bwc_tensor_model_parallel_rank(mpu) + + # 2718 is just for fun and any POSITIVE value will work. + offset = seed + 2718 + model_parallel_seed = offset + tp_rank + # Data parallel gets the original seed. + data_parallel_seed = seed + + if torch.distributed.get_rank() == 0: + logger.info( + '> initializing model parallel cuda seeds on global rank {}, ' + 'model parallel rank {}, and data parallel rank {} with ' + 'model parallel seed: {} and data parallel seed: {}'.format( + torch.distributed.get_rank(), + tp_rank, + mpu.get_data_parallel_rank(), + model_parallel_seed, + data_parallel_seed), + ) + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) + + +def get_partition_start(item): + global mp_rank, mp_size, mp_group + size = item.numel() + partition_size = size / mp_size + start = partition_size * mp_rank + return int(start) + + +def get_partition_size(item): + global mp_rank, mp_size, mp_group + size = item.numel() + assert size % mp_size == 0, "Doesn't handle if partition activation if item is not divisible by mp size" + partition_size = size / mp_size + return int(partition_size) + + +def gather_partitioned_activations(tensors, device=None): + global mp_rank, mp_size, mp_group + assert len(tensors) % 2 == 0, f'Expected even count of tensors, instead got {len(tensors)}' + inputs = [] + num_args = int(len(tensors) / 2) + for i in range(num_args): + + item = tensors[2 * i] + size = tensors[2 * i + 1] + + if not is_activation_to_checkpoint(item): + inputs.append(item) + continue + + partition_size = item.numel() + tensor_size = partition_size * mp_size + if device is not None: + flat_tensor = torch.zeros([tensor_size], dtype=item.dtype, device=device) + else: + flat_tensor = torch.zeros([tensor_size], + dtype=item.dtype, + device=item.device) + partitions = [] + for i in range(mp_size): + part_i = flat_tensor.narrow(0, partition_size * i, partition_size) + if i == mp_rank: + part_i.copy_(item) + partitions.append(part_i) + if mp_group is not None: + dist.all_gather(partitions, partitions[mp_rank], group=mp_group) + input_tensor = flat_tensor.view(list(size.numpy())) + item.data = input_tensor.data + + inputs.append(item) + + return tuple(inputs) + + +def extract_tensors(all_objects): + """ + Separate objects in list/tuple into tensors and non-tensors and create a mapping to enable re-aggregation. + The order of tensors and non-tensors is preserved in their respective output groups. + + Parameters: + all_objects (list/tuple): Objects containing tensors and non-tensors to be split. + + Returns: + tuple: Containing tensors, non-tensors, and bools of whether each position in original list/tuple was a tensor. + + """ + tensor_objects = [v for v in all_objects if torch.is_tensor(v)] + non_tensor_objects = [v for v in all_objects if not torch.is_tensor(v)] + tensor_flags = [torch.is_tensor(v) for v in all_objects] + if type(all_objects) is tuple: + return tuple(tensor_objects), tuple(non_tensor_objects), tuple(tensor_flags) + return tensor_objects, non_tensor_objects, tensor_flags + + +def merge_tensors(tensor_objects, non_tensor_objects, tensor_flags): + """ + Merge two lists (or tuples) of tensors and non-tensors using a mapping of positions in merged list (or tuple). + + Parameters: + tensor_objects (list/tuple): Tensors to merge. + non_tensor_objects (list/tuple): Non-tensors to merge. + tensor_flags (list/tuple): Indicates whether each position in output is a tensor. + + Returns: + tuple: Merge of tensors and non-tensors + """ + merged_objects = [] + tensor_idx = 0 + non_tensor_idx = 0 + + real_tensor_flags = None + + # remove the flags that are assigned to the size of the flattened tensors + if PARTITION_ACTIVATIONS: + real_tensor_flags = [] + previous_flag = False + for flag in tensor_flags: + if previous_flag: + previous_flag = False + continue + previous_flag = flag + real_tensor_flags.append(flag) + else: + real_tensor_flags = tensor_flags + + for is_tensor in real_tensor_flags: + if is_tensor: + merged_objects.append(tensor_objects[tensor_idx]) + tensor_idx += 1 + else: + merged_objects.append(non_tensor_objects[non_tensor_idx]) + non_tensor_idx += 1 + + return tuple(merged_objects) + + +def is_activation_to_checkpoint(item): + """ + Is an activation to be checkpointed + """ + global mp_size + return torch.is_tensor(item) and item.is_floating_point() and item.numel() >= mp_size + + +def partition_activations(args, cpu_checkpoint, contiguous_checkpoint): + global contiguous_data_buffers, data_offsets + + inputs = [] + num_non_fp_tensors = 0 + + for arg_index, item in enumerate(args): + if not is_activation_to_checkpoint(item): + inputs.append(item) + num_non_fp_tensors += 1 + continue + + i = arg_index - num_non_fp_tensors + partition_size = get_partition_size(item) + partition = item.detach().contiguous().view(-1).narrow( + 0, + get_partition_start(item), + partition_size).clone() + + buffer_device = torch.device('cpu') if cpu_checkpoint else partition.device + + if contiguous_checkpoint: + if i >= len(contiguous_data_buffers): + tensor_list = [ + torch.tensor(()).new_empty([partition_size], + dtype=partition.dtype, + device=buffer_device) + for _ in range(num_layers) + ] + contiguous_data_buffers.append(tensor_list) + data_offsets.append(0) + elif contiguous_data_buffers[i] is None: + tensor_list = [ + torch.tensor(()).new_empty([partition_size], + dtype=partition.dtype, + device=buffer_device) + for _ in range(num_layers) + ] + contiguous_data_buffers[i] = tensor_list + data_offsets[i] = 0 + + # Because the 'new_empty' returns uninitialized pages, + # the pages need to be populated during the cudaMemcpy time + # which increases the data copy time. To avoid this, we + # pre-populate these pages by simply writing 0 ahead of + # the actual cudaMemcpy operation time. Due to the + # previously launched GPU kernels, there is a small + # window of time here for CPUs to populate pages asynchronously. + contiguous_data_buffers[i][data_offsets[i]].data[range( + 0, + contiguous_data_buffers[i][data_offsets[i]].data.shape[0], + int(mmap.PAGESIZE / + contiguous_data_buffers[i][data_offsets[i]].data.element_size()) + )] = 0 + + contiguous_partition = contiguous_data_buffers[i][ + data_offsets[i]].data.copy_(partition.data) + data_offsets[i] = data_offsets[i] + 1 + inputs.append(contiguous_partition) + else: + partition = partition.cpu() if CPU_CHECKPOINT else partition + inputs.append(partition) + + return inputs + + +def get_partitioned_activations_for_backward(args, inputs, contiguous_checkpoint): + global contiguous_size_buffers, size_offsets + + new_args = [] + num_non_fp_tensors = 0 + + for arg_index, (arg, inp) in enumerate(zip(args, inputs)): + size = torch.tensor(arg.size()) if torch.is_tensor(arg) else None + if not is_activation_to_checkpoint(arg): + new_args.append(arg) + new_args.append(size) + num_non_fp_tensors += 1 + continue + + arg.data = inp.data + new_args.append(arg) + i = arg_index - num_non_fp_tensors + + if contiguous_checkpoint: + numel = size.numel() + if i >= len(contiguous_size_buffers): + tmp = torch.tensor(()) + contiguous_size_buffers.append( + tmp.new_empty([numel * num_layers], + dtype=size.dtype, + device=size.device)) + size_offsets.append(0) + elif contiguous_size_buffers[i] is None: + tmp = torch.tensor(()) + contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers], + dtype=size.dtype, + device=size.device) + size_offsets[i] = 0 + + contiguous_size = contiguous_size_buffers[i].narrow( + 0, + size_offsets[i], + numel).data.copy_(size.data) + contiguous_size = contiguous_size.view_as(size) + size_offsets[i] = size_offsets[i] + numel + new_args.append(contiguous_size) + else: + new_args.append(size) + + return new_args + + +def get_cpu_activations_for_backward(args, inputs): + new_args = [] + for i, (arg, inp) in enumerate(zip(args, inputs)): + if not is_activation_to_checkpoint(arg): + new_args.append(arg) + continue + + arg.data = inp.data + new_args.append(arg) + + return new_args + + +class CheckpointFunction(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + 3) Performance activation partitioning, contiguous memory optimization + 4) CPU Checkpointing + 5) Profile forward and backward functions + """ + @staticmethod + def forward(ctx, run_function, all_outputs, *args): + global mpu, timers, SYNCHRONIZE, PROFILE_TIME + + def save_args_for_backward(*all_args): + tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args) + ctx.deepspeed_saved_tensors = tensor_args + ctx.non_tensor_args = non_tensor_args + ctx.tensor_flags = tensor_flags + + if SYNCHRONIZE: + torch.cuda.synchronize() + + if timers is None and PROFILE_TIME: + timers = Timers() + + if PROFILE_TIME: + timers('forward').start() + + ctx.run_function = run_function + global num_layers + global mp_rank, mp_size, mp_group + global contiguous_data_buffers, contiguous_size_buffers + global data_offsets, size_offsets + if mp_rank is None: + if mpu is not None: + if hasattr(mpu, 'get_tensor_model_parallel_rank'): + mp_rank = mpu.get_tensor_model_parallel_rank() + mp_size = mpu.get_tensor_model_parallel_world_size() + mp_group = mpu.get_tensor_model_parallel_group() + else: + mp_rank = mpu.get_model_parallel_rank() + mp_size = mpu.get_model_parallel_world_size() + mp_group = mpu.get_model_parallel_group() + else: + mp_rank = 0 + mp_size = 1 + mp_group = None + + global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset + + if cuda_device is None: + see_memory_usage("First Forward Beginning", force=False) + if dist.get_rank() == 0: + logger.info(f"Activation Checkpointing Information") + logger.info( + f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}" + ) + logger.info( + f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers" + ) + logger.info(f"----Synchronization {SYNCHRONIZE}") + logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}") + + cuda_device = torch.cuda.current_device() + transport_stream = torch.cuda.Stream(device=cuda_device) + + if PARTITION_ACTIVATIONS: + inputs = partition_activations(args, + CPU_CHECKPOINT, + CONTIGUOUS_CHECKPOINTING) + elif CPU_CHECKPOINT: + inputs = copy_to_device(args, + device=torch.device('cpu'), + criterion_func=is_activation_to_checkpoint) + + # just in case something funky is happening such as reuse of inputs + inputs_cuda = copy_to_device(args, + device=cuda_device, + criterion_func=is_activation_to_checkpoint) + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + see_memory_usage("Before running forward on the layer", force=False) + # ctx.save_for_backward(*args) + with torch.no_grad(): + outputs = run_function(*inputs_cuda) + + see_memory_usage("After running forward on the layer", force=False) + del inputs_cuda + + if PARTITION_ACTIVATIONS: + new_args = get_partitioned_activations_for_backward( + args, + inputs, + CONTIGUOUS_CHECKPOINTING) + assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}' + save_args_for_backward(*new_args) + elif CPU_CHECKPOINT: + new_args = get_cpu_activations_for_backward(args, inputs) + save_args_for_backward(*new_args) + else: + save_args_for_backward(*args) + + if PROFILE_TIME: + timers('forward').stop() + timers.log(['forward']) + if SYNCHRONIZE: + torch.cuda.synchronize() + + # Tensors returned from forward() may not be differentiable. + if torch.is_tensor(outputs): + non_grad_outputs = [outputs] if not outputs.is_floating_point() else [] + else: + non_grad_outputs = [ + o for o in outputs if torch.is_tensor(o) and not o.is_floating_point() + ] + ctx.mark_non_differentiable(*non_grad_outputs) + + if torch.is_tensor(outputs): + all_outputs += [outputs] + return outputs + else: + all_outputs += outputs + outputs, _, _ = extract_tensors(all_objects=outputs) + return tuple(outputs) + + @staticmethod + def backward(ctx, *grads): + global timers + see_memory_usage("In backward", force=False) + # removing pointers to the contiguous buffer memory + # so that they can be garbage collected once the checkpoints + # have been used + if SYNCHRONIZE: + torch.cuda.synchronize() + if PROFILE_TIME: + timers('backward').start() + + if CONTIGUOUS_CHECKPOINTING: + global data_offsets, size_offsets + global contiguous_data_buffers, contiguous_size_buffers + + for buffers in contiguous_data_buffers: + buffers = [] + + # frees up all the pointers to the checkpoints except for the ones + # stored by save for backward + contiguous_data_buffers = [] + contiguous_size_buffers = [] + data_offsets = [] + size_offsets = [] + + see_memory_usage("In backward checkpointing code", force=False) + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError("Checkpointing is not compatible with .grad(), " + "please use .backward() if possible") + + global cuda_device, transport_stream, PARTITION_ACTIVATIONS + + if PARTITION_ACTIVATIONS: + # with torch.cuda.stream(transport_stream): + inputs = gather_partitioned_activations( + ctx.deepspeed_saved_tensors, + device=cuda_device if CPU_CHECKPOINT else None) + detached_inputs = detach_variable(inputs) + elif CPU_CHECKPOINT: + inputs = move_to_device(ctx.deepspeed_saved_tensors, + cuda_device, + is_activation_to_checkpoint) + detached_inputs = detach_variable(inputs) + else: + inputs = ctx.deepspeed_saved_tensors + detached_inputs = detach_variable(inputs) + + # Add non tensor input args + detached_inputs = merge_tensors(tensor_objects=detached_inputs, + non_tensor_objects=ctx.non_tensor_args, + tensor_flags=ctx.tensor_flags) + + # Store the current states. + bwd_cpu_rng_state = torch.get_rng_state() + bwd_cuda_rng_state = torch.cuda.get_rng_state() + bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + # Set the states to what it used to be before the forward pass. + torch.set_rng_state(ctx.fwd_cpu_rng_state) + _set_cuda_rng_state(ctx.fwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) + + # if PARTITION_ACTIVATIONS: + # current_stream=torch.cuda.current_stream() + # current_stream.wait_stream(transport_stream) + + see_memory_usage("In backward checkpointing code before forward", force=False) + + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) + + see_memory_usage("In backward checkpointing code after forward", force=False) + # Set the states back to what it was at the start of this function. + torch.set_rng_state(bwd_cpu_rng_state) + _set_cuda_rng_state(bwd_cuda_rng_state) + get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs, ) + + # Filter out non tensor outputs + outputs, _, _ = extract_tensors(all_objects=outputs) + + # Construct arguments to autograd.backward(). + # This is usually just outputs and grads, but forward() can return tensors that + # are not differentiable. + output_tensors = [] + grad_tensors = [] + for out, grad in zip(outputs, grads): + if out.requires_grad: + output_tensors.append(out) + grad_tensors.append(grad) + + see_memory_usage("In backward checkpointing code before backward", force=False) + + torch.autograd.backward(output_tensors, grad_tensors) + + see_memory_usage("After backward checkpointing code after backward", force=False) + + if PROFILE_TIME: + timers('backward').stop() + timers.log(['backward']) + if SYNCHRONIZE: + torch.cuda.synchronize() + ret_list = [None, None] # first None for ctx + for inp in detached_inputs: + if torch.is_tensor(inp): + ret_list.append(inp.grad) + else: + ret_list.append(None) + + return tuple(ret_list) + + +def checkpoint(function, *args): + """Checkpoint a model or part of the model. + This has been directly copied from torch.utils.checkpoint. """ + + all_outputs = [] + CheckpointFunction.apply(function, all_outputs, *args) + if len(all_outputs) == 1: + return all_outputs[0] + else: + return tuple(all_outputs) + + +def partition_activations_in_checkpoint(partition_activation): + global PARTITION_ACTIVATIONS + PARTITION_ACTIVATIONS = partition_activation + if dist.get_rank() == 0: + logger.info( + f"**************Partition Activations {PARTITION_ACTIVATIONS}************") + + +def set_num_layers(nlayers): + global num_layers + num_layers = nlayers + + +def reset(): + """Resets memory buffers related to contiguous memory optimizations. + Should be called during eval when multiple forward propagations are + computed without any backward propagation that usually clears these + buffers. + Arguments: + None + + Return: + None + """ + if CONTIGUOUS_CHECKPOINTING: + global data_offsets, size_offsets + global contiguous_data_buffers, contiguous_size_buffers + + for buffers in contiguous_data_buffers: + buffers = [] + + # frees up all the pointers to the checkpoints except for the ones + # stored by save for backward + contiguous_data_buffers = [] + contiguous_size_buffers = [] + data_offsets = [] + size_offsets = [] + + +def _configure_using_config_file(config, mpu=None): + global num_layers, PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ + CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME + + config = DeepSpeedConfig(config, mpu=mpu).activation_checkpointing_config + if dist.get_rank() == 0: + logger.info(config.repr()) + PARTITION_ACTIVATIONS = config.partition_activations + CONTIGUOUS_CHECKPOINTING = config.contiguous_memory_optimization + num_layers = config.number_checkpoints + CPU_CHECKPOINT = config.cpu_checkpointing + SYNCHRONIZE = config.synchronize_checkpoint_boundary + PROFILE_TIME = config.profile + + +def _configure_defaults(): + + global mpu, num_layers, deepspeed_checkpointing_enabled + + global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ + CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME + + PARTITION_ACTIVATIONS = False + CONTIGUOUS_CHECKPOINTING = False + num_layers = False + CPU_CHECKPOINT = False + SYNCHRONIZE = False + PROFILE_TIME = False + deepspeed_checkpointing_enabled = True + + +def configure( + mpu_, + deepspeed_config=None, + partition_activations=None, + contiguous_checkpointing=None, + num_checkpoints=None, + checkpoint_in_cpu=None, + synchronize=None, + profile=None, +): + """Configure DeepSpeed Activation Checkpointing. + + Arguments: + mpu_: Optional: An object that implements the following methods + get_model_parallel_rank/group/world_size, and get_data_parallel_rank/group/world_size + + deepspeed_config: Optional: DeepSpeed Config json file when provided will be used to + configure DeepSpeed Activation Checkpointing + + partition_activations: Optional: Partitions activation checkpoint across model parallel + GPUs when enabled. By default False. Will overwrite deepspeed_config if provided + + contiguous_checkpointing: Optional: Copies activation checkpoints to a contiguous memory + buffer. Works only with homogeneous checkpoints when partition_activations is enabled. + Must provide num_checkpoints. By default False. Will overwrite deepspeed_config if + provided + + num_checkpoints: Optional: Number of activation checkpoints stored during the forward + propagation of the model. Used to calculate the buffer size for contiguous_checkpointing + Will overwrite deepspeed_config if provided + + checkpoint_in_cpu: Optional: Moves the activation checkpoint to CPU. Only works with + partition_activation. Default is false. Will overwrite deepspeed_config if provided + + synchronize: Optional: Performs torch.cuda.synchronize() at the beginning and end of + each call to deepspeed.checkpointing.checkpoint for both forward and backward pass. + By default false. Will overwrite deepspeed_config if provided + + profile: Optional: Logs the forward and backward time for each + deepspeed.checkpointing.checkpoint invocation. Will overwrite deepspeed_config + if provided + + Returns: + None + """ + global mpu, num_layers, deepspeed_checkpointing_enabled + + global PARTITION_ACTIVATIONS, CONTIGUOUS_CHECKPOINTING, \ + CPU_CHECKPOINT, SYNCHRONIZE, PROFILE_TIME + + _configure_defaults() + + if mpu_ is not None: + mpu = mpu_ + + if deepspeed_config is not None: + _configure_using_config_file(deepspeed_config, mpu=mpu) + + if partition_activations is not None: + PARTITION_ACTIVATIONS = partition_activations + + if contiguous_checkpointing is not None: + CONTIGUOUS_CHECKPOINTING = contiguous_checkpointing + + if num_checkpoints is not None: + num_layers = num_checkpoints + + if checkpoint_in_cpu is not None: + CPU_CHECKPOINT = checkpoint_in_cpu + + if synchronize is not None: + SYNCHRONIZE = synchronize + + if profile is not None: + PROFILE_TIME = profile + + if CONTIGUOUS_CHECKPOINTING: + assert PARTITION_ACTIVATIONS, "Contiguous Checkpointing is only available with partitioned activations. Set partitioned activations to true in deepspeed config" + if CONTIGUOUS_CHECKPOINTING: + assert num_layers is not None, "Must specify the number of layers with contiguous memory checkpointing" + + +def is_configured(): + """True if deepspeed activation checkpointing has been configured + by calling deepspeed.checkpointing.configure, else returns false + + Arguments: + None + + Return: + True of configured, else False + """ + return deepspeed_checkpointing_enabled diff --git a/deepspeed/runtime/activation_checkpointing/config.py b/deepspeed/runtime/activation_checkpointing/config.py index 19e904980da7..0ab59ac64eea 100755 --- a/deepspeed/runtime/activation_checkpointing/config.py +++ b/deepspeed/runtime/activation_checkpointing/config.py @@ -1,103 +1,103 @@ -""" -Copyright (c) Microsoft Corporation -Licensed under the MIT license. -""" - -from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject - -######################################### -# DeepSpeed Activation Checkpointing -######################################### -# Activation Checkpointing Allows to save memory by only keeping a select few -#activations for the backpropagation. -ACTIVATION_CHKPT_FORMAT = ''' -Activation Checkpointing should be configured as: -"session_params": { - "activation_checkpointing": { - "partitioned_activations": [true|false], - "number_checkpoints": 100, - "contiguous_memory_optimization": [true|false], - "cpu_checkpointing": [true|false] - "profile": [true|false], - "synchronize_checkpoint_boundary": [true|false], - } -} -''' - -ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations' -ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False - -ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints' -ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None - -ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization' -ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False - -ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary' -ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False - -ACT_CHKPT_PROFILE = 'profile' -ACT_CHKPT_PROFILE_DEFAULT = False - -ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing' -ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False - -ACT_CHKPT = 'activation_checkpointing' - -ACT_CHKPT_DEFAULT = { - ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT, - ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT, - ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION: - ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT, - ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY: - ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT, - ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT, - ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT -} - - -class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject): - def __init__(self, param_dict): - super(DeepSpeedActivationCheckpointingConfig, self).__init__() - - self.partition_activations = None - self.contiguous_memory_optimization = None - self.cpu_checkpointing = None - self.number_checkpoints = None - self.synchronize_checkpoint_boundary = None - self.profile = None - - if ACT_CHKPT in param_dict.keys(): - act_chkpt_config_dict = param_dict[ACT_CHKPT] - else: - act_chkpt_config_dict = ACT_CHKPT_DEFAULT - - self._initialize(act_chkpt_config_dict) - - def _initialize(self, act_chkpt_config_dict): - self.partition_activations = get_scalar_param( - act_chkpt_config_dict, - ACT_CHKPT_PARTITION_ACTIVATIONS, - ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT) - - self.contiguous_memory_optimization = get_scalar_param( - act_chkpt_config_dict, - ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION, - ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT) - - self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict, - ACT_CHKPT_CPU_CHECKPOINTING, - ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT) - - self.number_checkpoints = get_scalar_param(act_chkpt_config_dict, - ACT_CHKPT_NUMBER_CHECKPOINTS, - ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT) - - self.profile = get_scalar_param(act_chkpt_config_dict, - ACT_CHKPT_PROFILE, - ACT_CHKPT_PROFILE_DEFAULT) - - self.synchronize_checkpoint_boundary = get_scalar_param( - act_chkpt_config_dict, - ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY, - ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT) +""" +Copyright (c) Microsoft Corporation +Licensed under the MIT license. +""" + +from deepspeed.runtime.config_utils import get_scalar_param, DeepSpeedConfigObject + +######################################### +# DeepSpeed Activation Checkpointing +######################################### +# Activation Checkpointing Allows to save memory by only keeping a select few +#activations for the backpropagation. +ACTIVATION_CHKPT_FORMAT = ''' +Activation Checkpointing should be configured as: +"session_params": { + "activation_checkpointing": { + "partitioned_activations": [true|false], + "number_checkpoints": 100, + "contiguous_memory_optimization": [true|false], + "cpu_checkpointing": [true|false] + "profile": [true|false], + "synchronize_checkpoint_boundary": [true|false], + } +} +''' + +ACT_CHKPT_PARTITION_ACTIVATIONS = 'partition_activations' +ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT = False + +ACT_CHKPT_NUMBER_CHECKPOINTS = 'number_checkpoints' +ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT = None + +ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION = 'contiguous_memory_optimization' +ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT = False + +ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY = 'synchronize_checkpoint_boundary' +ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT = False + +ACT_CHKPT_PROFILE = 'profile' +ACT_CHKPT_PROFILE_DEFAULT = False + +ACT_CHKPT_CPU_CHECKPOINTING = 'cpu_checkpointing' +ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT = False + +ACT_CHKPT = 'activation_checkpointing' + +ACT_CHKPT_DEFAULT = { + ACT_CHKPT_PARTITION_ACTIVATIONS: ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT, + ACT_CHKPT_NUMBER_CHECKPOINTS: ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT, + ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION: + ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT, + ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY: + ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT, + ACT_CHKPT_PROFILE: ACT_CHKPT_PROFILE_DEFAULT, + ACT_CHKPT_CPU_CHECKPOINTING: ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT +} + + +class DeepSpeedActivationCheckpointingConfig(DeepSpeedConfigObject): + def __init__(self, param_dict): + super(DeepSpeedActivationCheckpointingConfig, self).__init__() + + self.partition_activations = None + self.contiguous_memory_optimization = None + self.cpu_checkpointing = None + self.number_checkpoints = None + self.synchronize_checkpoint_boundary = None + self.profile = None + + if ACT_CHKPT in param_dict.keys(): + act_chkpt_config_dict = param_dict[ACT_CHKPT] + else: + act_chkpt_config_dict = ACT_CHKPT_DEFAULT + + self._initialize(act_chkpt_config_dict) + + def _initialize(self, act_chkpt_config_dict): + self.partition_activations = get_scalar_param( + act_chkpt_config_dict, + ACT_CHKPT_PARTITION_ACTIVATIONS, + ACT_CHKPT_PARTITION_ACTIVATIONS_DEFAULT) + + self.contiguous_memory_optimization = get_scalar_param( + act_chkpt_config_dict, + ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION, + ACT_CHKPT_CONTIGUOUS_MEMORY_OPTIMIZATION_DEFAULT) + + self.cpu_checkpointing = get_scalar_param(act_chkpt_config_dict, + ACT_CHKPT_CPU_CHECKPOINTING, + ACT_CHKPT_CPU_CHECKPOINTING_DEFAULT) + + self.number_checkpoints = get_scalar_param(act_chkpt_config_dict, + ACT_CHKPT_NUMBER_CHECKPOINTS, + ACT_CHKPT_NUMBER_CHECKPOINTS_DEFAULT) + + self.profile = get_scalar_param(act_chkpt_config_dict, + ACT_CHKPT_PROFILE, + ACT_CHKPT_PROFILE_DEFAULT) + + self.synchronize_checkpoint_boundary = get_scalar_param( + act_chkpt_config_dict, + ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY, + ACT_CHKPT_SYNCHRONIZE_CHECKPOINT_BOUNDARY_DEFAULT) diff --git a/deepspeed/runtime/config_utils.py b/deepspeed/runtime/config_utils.py index 1d499cdcb3dd..199c773f4379 100755 --- a/deepspeed/runtime/config_utils.py +++ b/deepspeed/runtime/config_utils.py @@ -1,80 +1,80 @@ -""" -Copyright (c) Microsoft Corporation -Licensed under the MIT license. -""" -""" -Collection of DeepSpeed configuration utilities -""" -import json -import collections - - -# adapted from https://stackoverflow.com/a/50701137/9201239 -class ScientificNotationEncoder(json.JSONEncoder): - """ - This class overrides ``json.dumps`` default formatter. - - This version keeps everything as normal except formats numbers bigger than 1e3 using scientific notation. - - Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it - - """ - def iterencode(self, o, _one_shot=False, level=0): - indent = self.indent if self.indent is not None else 4 - prefix_close = " " * level * indent - level += 1 - prefix = " " * level * indent - if isinstance(o, bool): - return "true" if o else "false" - elif isinstance(o, float) or isinstance(o, int): - if o > 1e3: - return f"{o:e}" - else: - return f"{o}" - elif isinstance(o, collections.Mapping): - x = [ - f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k, - v in o.items() - ] - return "{" + ', '.join(x) + f"\n{prefix_close}" + "}" - elif isinstance(o, collections.Sequence) and not isinstance(o, str): - return f"[{ f', '.join(map(self.iterencode, o)) }]" - return "\n, ".join(super().iterencode(o, _one_shot)) - - -class DeepSpeedConfigObject(object): - """ - For json serialization - """ - def repr(self): - return self.__dict__ - - def __repr__(self): - return json.dumps( - self.__dict__, - sort_keys=True, - indent=4, - cls=ScientificNotationEncoder, - ) - - -def get_scalar_param(param_dict, param_name, param_default_value): - return param_dict.get(param_name, param_default_value) - - -def get_list_param(param_dict, param_name, param_default_value): - return param_dict.get(param_name, param_default_value) - - -def get_dict_param(param_dict, param_name, param_default_value): - return param_dict.get(param_name, param_default_value) - - -def dict_raise_error_on_duplicate_keys(ordered_pairs): - """Reject duplicate keys.""" - d = dict((k, v) for k, v in ordered_pairs) - if len(d) != len(ordered_pairs): - counter = collections.Counter([pair[0] for pair in ordered_pairs]) - keys = [key for key, value in counter.items() if value > 1] - raise ValueError("Duplicate keys in DeepSpeed config: {}".format(keys)) - return d +""" +Copyright (c) Microsoft Corporation +Licensed under the MIT license. +""" +""" +Collection of DeepSpeed configuration utilities +""" +import json +import collections + + +# adapted from https://stackoverflow.com/a/50701137/9201239 +class ScientificNotationEncoder(json.JSONEncoder): + """ + This class overrides ``json.dumps`` default formatter. + + This version keeps everything as normal except formats numbers bigger than 1e3 using scientific notation. + + Just pass ``cls=ScientificNotationEncoder`` to ``json.dumps`` to activate it + + """ + def iterencode(self, o, _one_shot=False, level=0): + indent = self.indent if self.indent is not None else 4 + prefix_close = " " * level * indent + level += 1 + prefix = " " * level * indent + if isinstance(o, bool): + return "true" if o else "false" + elif isinstance(o, float) or isinstance(o, int): + if o > 1e3: + return f"{o:e}" + else: + return f"{o}" + elif isinstance(o, collections.Mapping): + x = [ + f'\n{prefix}"{k}": {self.iterencode(v, level=level)}' for k, + v in o.items() + ] + return "{" + ', '.join(x) + f"\n{prefix_close}" + "}" + elif isinstance(o, collections.Sequence) and not isinstance(o, str): + return f"[{ f', '.join(map(self.iterencode, o)) }]" + return "\n, ".join(super().iterencode(o, _one_shot)) + + +class DeepSpeedConfigObject(object): + """ + For json serialization + """ + def repr(self): + return self.__dict__ + + def __repr__(self): + return json.dumps( + self.__dict__, + sort_keys=True, + indent=4, + cls=ScientificNotationEncoder, + ) + + +def get_scalar_param(param_dict, param_name, param_default_value): + return param_dict.get(param_name, param_default_value) + + +def get_list_param(param_dict, param_name, param_default_value): + return param_dict.get(param_name, param_default_value) + + +def get_dict_param(param_dict, param_name, param_default_value): + return param_dict.get(param_name, param_default_value) + + +def dict_raise_error_on_duplicate_keys(ordered_pairs): + """Reject duplicate keys.""" + d = dict((k, v) for k, v in ordered_pairs) + if len(d) != len(ordered_pairs): + counter = collections.Counter([pair[0] for pair in ordered_pairs]) + keys = [key for key, value in counter.items() if value > 1] + raise ValueError("Duplicate keys in DeepSpeed config: {}".format(keys)) + return d diff --git a/deepspeed/runtime/eigenvalue.py b/deepspeed/runtime/eigenvalue.py index b1974d975f9f..490899bda576 100755 --- a/deepspeed/runtime/eigenvalue.py +++ b/deepspeed/runtime/eigenvalue.py @@ -1,152 +1,152 @@ -import torch -from deepspeed.utils import log_dist -import numpy as np -import logging - - -class Eigenvalue(object): - def __init__(self, - verbose=False, - max_iter=100, - tol=1e-2, - stability=0, - gas_boundary_resolution=1, - layer_name='', - layer_num=0): - super().__init__() - - self.verbose = verbose - self.max_iter = max_iter - self.tol = tol - self.stability = stability - self.gas_boundary_resolution = gas_boundary_resolution - self.layer_name = layer_name - self.layer_num = layer_num - - assert len(self.layer_name) > 0 and layer_num > 0 - - log_dist( - f'enabled eigenvalue with verbose={verbose}, max_iter={max_iter}, tol={tol}, stability={stability}, gas_boundary_resolution={gas_boundary_resolution}, layer_name={layer_name}, layer_num={layer_num}', - ranks=[0]) - - # Replace all nan/pos-inf/neg-inf to zero - # TODO: Pytorch new version may add this function, replace this one by then. - def nan_to_num(self, x): - device = x.device - x = x.cpu().numpy() - x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) - return torch.from_numpy(x).to(device) - - def normalize(self, v): - norm_squared = self.inner_product(v, v) - norm = norm_squared**0.5 + self.stability - normalized_vectors = [vector / norm for vector in v] - normalized_vectors = [self.nan_to_num(vector) for vector in normalized_vectors] - return normalized_vectors - - def inner_product(self, xs, ys): - return sum([torch.sum(x * y) for (x, y) in zip(xs, ys)]) - - def get_layers(self, module): - scope_names = self.layer_name.split('.') - assert len(scope_names) > 0 - - m = module - for name in scope_names: - assert hasattr(m, name), "layer_name configuration is invalid." - m = getattr(m, name) - - return m - - def compute_eigenvalue(self, module, device=None, scale=1.0): - block_eigenvalue = [] - param_keys = [] - layers = self.get_layers(module) - - for block in range(self.layer_num): - model_block = layers[block] - - # We found this randn() has obvious accuracy impact in some cases, save/recover random state here. - rng_state = torch.random.get_rng_state() - if device is None: - v = [ - torch.randn(p.size()) for p in model_block.parameters() - if p.grad is not None and p.grad.grad_fn is not None - ] - else: - v = [ - torch.randn(p.size(), - device=device) for p in model_block.parameters() - if p.grad is not None and p.grad.grad_fn is not None - ] - torch.random.set_rng_state(rng_state) - - grads = [ - param.grad for param in model_block.parameters() - if param.grad is not None and param.grad.grad_fn is not None - ] - params = [ - param for param in model_block.parameters() - if param.grad is not None and param.grad.grad_fn is not None - ] - - layer_keys = [id(p) for p in model_block.parameters()] - param_keys.append(layer_keys) - - v = self.normalize(v) - - # Disable eigenvalue if the model doesn't support second order gradients computation, - # e.g. when enabling DS transformer kernel. - if len(grads) == 0 or len(params) == 0: - log_dist(f'The model does NOT support eigenvalue computation.', - ranks=[0], - level=logging.WARNING) - return [] - - i = 0 - eigenvalue_current, eigenvalue_previous = 1., 0. - - while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( - (eigenvalue_current - eigenvalue_previous) / - eigenvalue_current) >= self.tol): # test convergence criteria - eigenvalue_previous = eigenvalue_current - - Hv = torch.autograd.grad(grads, - params, - grad_outputs=v, - only_inputs=True, - retain_graph=True) - #Hv = [hv.float() for hv in Hv] - Hv = [self.nan_to_num(hv).float() for hv in Hv] - - eigenvalue_current = self.inner_product(Hv, v).item() - - v = self.normalize(Hv) - v = [x / scale for x in v] - i += 1 - - eigenvalue_current *= scale - block_eigenvalue.append(eigenvalue_current) - - if self.verbose: - log_dist( - f'block: {block}, power iteration: {i}, eigenvalue: {eigenvalue_current}', - ranks=[0]) - - block_eigenvalue = self.post_process(block_eigenvalue) - - if self.verbose: - log_dist(f'post processed block_eigenvalue: {block_eigenvalue}', ranks=[0]) - - # {param_id: (eigenvalue, layer_id)} - ev_dict = {} - for i, (layer_keys, value) in enumerate(zip(param_keys, block_eigenvalue)): - ev_dict.update(dict.fromkeys(layer_keys, (value, i))) - - return ev_dict - - # 1. Map all eigenvalues to [0, 1.0]. - # 2. Some layers can't generate valid eigenvalues on fp16 precision, use 1.0 instead. - def post_process(self, value_list): - max_value = abs(max(value_list, key=abs)) - return [abs(v) / max_value if v != 0.0 else 1.0 for v in value_list] +import torch +from deepspeed.utils import log_dist +import numpy as np +import logging + + +class Eigenvalue(object): + def __init__(self, + verbose=False, + max_iter=100, + tol=1e-2, + stability=0, + gas_boundary_resolution=1, + layer_name='', + layer_num=0): + super().__init__() + + self.verbose = verbose + self.max_iter = max_iter + self.tol = tol + self.stability = stability + self.gas_boundary_resolution = gas_boundary_resolution + self.layer_name = layer_name + self.layer_num = layer_num + + assert len(self.layer_name) > 0 and layer_num > 0 + + log_dist( + f'enabled eigenvalue with verbose={verbose}, max_iter={max_iter}, tol={tol}, stability={stability}, gas_boundary_resolution={gas_boundary_resolution}, layer_name={layer_name}, layer_num={layer_num}', + ranks=[0]) + + # Replace all nan/pos-inf/neg-inf to zero + # TODO: Pytorch new version may add this function, replace this one by then. + def nan_to_num(self, x): + device = x.device + x = x.cpu().numpy() + x = np.nan_to_num(x=x, copy=False, nan=0.0, posinf=0.0, neginf=0.0) + return torch.from_numpy(x).to(device) + + def normalize(self, v): + norm_squared = self.inner_product(v, v) + norm = norm_squared**0.5 + self.stability + normalized_vectors = [vector / norm for vector in v] + normalized_vectors = [self.nan_to_num(vector) for vector in normalized_vectors] + return normalized_vectors + + def inner_product(self, xs, ys): + return sum([torch.sum(x * y) for (x, y) in zip(xs, ys)]) + + def get_layers(self, module): + scope_names = self.layer_name.split('.') + assert len(scope_names) > 0 + + m = module + for name in scope_names: + assert hasattr(m, name), "layer_name configuration is invalid." + m = getattr(m, name) + + return m + + def compute_eigenvalue(self, module, device=None, scale=1.0): + block_eigenvalue = [] + param_keys = [] + layers = self.get_layers(module) + + for block in range(self.layer_num): + model_block = layers[block] + + # We found this randn() has obvious accuracy impact in some cases, save/recover random state here. + rng_state = torch.random.get_rng_state() + if device is None: + v = [ + torch.randn(p.size()) for p in model_block.parameters() + if p.grad is not None and p.grad.grad_fn is not None + ] + else: + v = [ + torch.randn(p.size(), + device=device) for p in model_block.parameters() + if p.grad is not None and p.grad.grad_fn is not None + ] + torch.random.set_rng_state(rng_state) + + grads = [ + param.grad for param in model_block.parameters() + if param.grad is not None and param.grad.grad_fn is not None + ] + params = [ + param for param in model_block.parameters() + if param.grad is not None and param.grad.grad_fn is not None + ] + + layer_keys = [id(p) for p in model_block.parameters()] + param_keys.append(layer_keys) + + v = self.normalize(v) + + # Disable eigenvalue if the model doesn't support second order gradients computation, + # e.g. when enabling DS transformer kernel. + if len(grads) == 0 or len(params) == 0: + log_dist(f'The model does NOT support eigenvalue computation.', + ranks=[0], + level=logging.WARNING) + return [] + + i = 0 + eigenvalue_current, eigenvalue_previous = 1., 0. + + while (i < self.max_iter) and abs(eigenvalue_current) > 0 and (abs( + (eigenvalue_current - eigenvalue_previous) / + eigenvalue_current) >= self.tol): # test convergence criteria + eigenvalue_previous = eigenvalue_current + + Hv = torch.autograd.grad(grads, + params, + grad_outputs=v, + only_inputs=True, + retain_graph=True) + #Hv = [hv.float() for hv in Hv] + Hv = [self.nan_to_num(hv).float() for hv in Hv] + + eigenvalue_current = self.inner_product(Hv, v).item() + + v = self.normalize(Hv) + v = [x / scale for x in v] + i += 1 + + eigenvalue_current *= scale + block_eigenvalue.append(eigenvalue_current) + + if self.verbose: + log_dist( + f'block: {block}, power iteration: {i}, eigenvalue: {eigenvalue_current}', + ranks=[0]) + + block_eigenvalue = self.post_process(block_eigenvalue) + + if self.verbose: + log_dist(f'post processed block_eigenvalue: {block_eigenvalue}', ranks=[0]) + + # {param_id: (eigenvalue, layer_id)} + ev_dict = {} + for i, (layer_keys, value) in enumerate(zip(param_keys, block_eigenvalue)): + ev_dict.update(dict.fromkeys(layer_keys, (value, i))) + + return ev_dict + + # 1. Map all eigenvalues to [0, 1.0]. + # 2. Some layers can't generate valid eigenvalues on fp16 precision, use 1.0 instead. + def post_process(self, value_list): + max_value = abs(max(value_list, key=abs)) + return [abs(v) / max_value if v != 0.0 else 1.0 for v in value_list] diff --git a/deepspeed/runtime/progressive_layer_drop.py b/deepspeed/runtime/progressive_layer_drop.py index 770978a940a0..41c08cfd9e7c 100755 --- a/deepspeed/runtime/progressive_layer_drop.py +++ b/deepspeed/runtime/progressive_layer_drop.py @@ -1,33 +1,33 @@ -import numpy as np -from deepspeed.utils import log_dist - - -class ProgressiveLayerDrop(object): - r""" Progressive Layer Dropping (PLD) for model training. - This implements the PLD technique for compressed model training - from this paper: https://arxiv.org/pdf/2010.13369.pdf - Args: - theta (float): a hyper-parameter that controls the trade-off between training time and robustness. - The lower the theta value, the faster the training speed. Default value: 0.5. - gamma (float): a hyper-parameter that controls how fast the drop ratio increases. Default value: 0.001. - """ - def __init__(self, theta=0.5, gamma=0.001): - super().__init__() - - self.theta = theta - self.gamma = gamma - self.current_theta = 1.0 - log_dist(f'Enabled progressive layer dropping (theta = {self.theta})', ranks=[0]) - - def get_state(self): - kwargs = {'progressive_layer_drop': True, 'pld_theta': self.get_theta()} - return kwargs - - def get_theta(self): - return self.current_theta - - def update_state(self, global_step): - def _prob(x, gamma, p): - return (1. - p) * np.exp(-gamma * x) + p - - self.current_theta = _prob(global_step, self.gamma, self.theta) +import numpy as np +from deepspeed.utils import log_dist + + +class ProgressiveLayerDrop(object): + r""" Progressive Layer Dropping (PLD) for model training. + This implements the PLD technique for compressed model training + from this paper: https://arxiv.org/pdf/2010.13369.pdf + Args: + theta (float): a hyper-parameter that controls the trade-off between training time and robustness. + The lower the theta value, the faster the training speed. Default value: 0.5. + gamma (float): a hyper-parameter that controls how fast the drop ratio increases. Default value: 0.001. + """ + def __init__(self, theta=0.5, gamma=0.001): + super().__init__() + + self.theta = theta + self.gamma = gamma + self.current_theta = 1.0 + log_dist(f'Enabled progressive layer dropping (theta = {self.theta})', ranks=[0]) + + def get_state(self): + kwargs = {'progressive_layer_drop': True, 'pld_theta': self.get_theta()} + return kwargs + + def get_theta(self): + return self.current_theta + + def update_state(self, global_step): + def _prob(x, gamma, p): + return (1. - p) * np.exp(-gamma * x) + p + + self.current_theta = _prob(global_step, self.gamma, self.theta) diff --git a/deepspeed/runtime/quantize.py b/deepspeed/runtime/quantize.py index a23d189aaab8..05fc50201b77 100755 --- a/deepspeed/runtime/quantize.py +++ b/deepspeed/runtime/quantize.py @@ -1,224 +1,224 @@ -import torch -import math -from deepspeed.utils import log_dist -from deepspeed.utils import logger -from deepspeed.ops.quantizer import ds_quantizer - -# number of 2-dimensional parameters in a layer -# this is set for transformer-based models -TWO_D_PARAMS = 6 - - -class Quantizer(object): - def __init__(self, - q_target_bits=8, - q_start_bits=16, - q_period=100, - q_offset=100, - q_groups=1, - q_mixed_fp16=False, - q_change_ratio=0.01, - q_type=0, - q_rounding=0, - q_verbose=False, - q_eigenvalue=False, - use_quantizer_kernel=False, - layer_num=0): - - self.q_target_bits = q_target_bits - - self.q_start_bits = [q_start_bits] * (layer_num if layer_num != 0 else 1) - self.q_period = [q_period] * (layer_num if layer_num != 0 else 1) - self.q_offset = q_offset - self.q_groups = q_groups - self.q_mixed_fp16 = q_mixed_fp16 - self.q_change_ratio = q_change_ratio - self.q_type = q_type - self.qsteps = 0 - self.q_init_period = q_period - self.quantize_real_ratio = 1.000 - self.q_verbose = q_verbose - self.q_eigenvalue = q_eigenvalue - self.use_quantizer_kernel = use_quantizer_kernel - self.q_rounding = q_rounding - self.layer_num = layer_num - - def any_precision_switch(self): - if self.layer_num == 0: - return True - result = False - for index in range(self.layer_num): - if self.q_start_bits[index] != self.q_target_bits: - next_step = self.qsteps + ( - TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1)) - if next_step >= self.q_period[index]: - result = True - return result - - def quantize(self, - parameter_group, - overflow, - eigenvalue_enabled, - block_eigenvalue={}): - - if overflow and not eigenvalue_enabled: - return - - self.step() - - self.update_fp16_ratio() - - for i in range(len(parameter_group)): - for p in parameter_group[i]: - if len(p.size()) > 1: - param_id = id(p) - eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None, 0) - if eigenvalue is not None: - factor = 1 + math.floor(eigenvalue * 4) - p.data = self.compute_quantization(p.data, layer_id, factor) - else: - p.data = self.compute_quantization(p.data, layer_id) - - def step(self): - self.qsteps += (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1)) - - def sr_quantize(self, input_flat, input_g, scale): - # Random number generator (Uniform) - p = torch.cuda.FloatTensor(input_flat.size(), - device=input_flat.device).uniform_() - p = torch.split(p, p.size(0) // self.q_groups) - add_s = torch.zeros_like(input_flat) - add_s = torch.split(add_s, add_s.size(0) // self.q_groups) - - scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g] - # Quantize with INT rounding - input_flat = [(g * s).int().float() / s for (g, s) in zip(input_g, scale)] - # Compute the error - error = [((g - q).abs() / s) for (g, s, q) in zip(input_g, scale, input_flat)] - # Stochastic Rounding - add_s = [ - a_s.masked_fill_(pg < err_g, - 1 / s) for (a_s, - pg, - err_g, - s) in zip(add_s, - p, - error, - scale) - ] - add_s = [ - a_s * (g > 0).float() - a_s * (g < 0).float() for a_s, - g in zip(add_s, - input_flat) - ] - input_flat = [((q + a_s) * s).clamp(-(q_range >> 1), - (q_range >> 1) - 1) / s for q, - a_s, - s in zip(input_flat, - add_s, - scale)] - return input_flat - - def mixed_fp16_quantize(self, input, input_q, index): - if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1): - input_q = input * self.quantize_real_ratio + ( - 1 - self.quantize_real_ratio) * input_q - return input_q - return input_q - - def compute_quantization(self, input, index=0, factor=1): - # fixing the quantization bits based on the training steps - # when reducing 1 bit at each period, we increase the period - # to go slowly toward the target quantization bits - # the period and starting bit can be configured - if self.q_offset > 0: - if self.qsteps >= self.q_offset: - self.q_offset = 0 - self.qsteps = 0 - else: - return input - - if self.q_start_bits[index] != self.q_target_bits: - if self.qsteps >= self.q_period[index]: - self.quantize_real_ratio = 1.0 - if self.q_eigenvalue: - self.q_period[index] <<= 1 - self.q_period[index] *= factor - self.q_start_bits[index] -= 1 - else: - for i in range(len(self.q_start_bits)): - self.q_start_bits[i] -= 1 - self.q_period[i] <<= 1 - if self.q_verbose: - logger.info( - f'Quantization settings: current bit-precision = {self.q_start_bits[index]}, step = {self.qsteps}, quantization period = {self.q_period[index]}, index = {index}' - ) - assert (self.q_start_bits[index] >= self.q_target_bits), \ - 'Quantization bit is lower than target precision bits!' - - # quantize the weights base on the selected bits and the value-range - if not self.use_quantizer_kernel: - q_range = 2**self.q_start_bits[index] - input_flat = input.view(-1) - input_g = torch.split(input_flat, input_flat.size(0) // self.q_groups) - if self.q_type == 0: #symmetric - if self.use_quantizer_kernel: - input_q = ds_quantizer(input.clone(), - self.q_groups, - self.q_start_bits[index]) - else: - scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g] - if self.q_rounding == 0: # Nearest value rounding - input_flat = [(g * s).round().clamp(-(q_range >> 1), - (q_range >> 1) - 1) / s for g, - s in zip(input_g, - scale)] - else: # Stochastic Rounding - if self.use_quantizer_kernel: - input_q = ds_quantizer(input.clone(), - self.q_groups, - self.q_start_bits[index], - sr=True) - else: - input_flat = self.sr_quantize(input_flat, input_g) - else: #asymmetric - if self.q_rounding == 0: - if self.use_quantizer_kernel: - input_q = ds_quantizer(input.clone(), - self.q_groups, - self.q_start_bits[index], - asym=True) - else: - scale = [(g.max() - g.min()) / q_range for g in input_g] - input_flat = [ - ((g - g.min()) / s).round().clamp(0, - (q_range - 1)) * s + g.min() - for g, - s in zip(input_g, - scale) - ] - else: - input_q = ds_quantizer(input.clone(), - self.q_groups, - self.q_start_bits[index], - asym=True) - - if self.use_quantizer_kernel or (self.q_type and self.q_rounding): - return self.mixed_fp16_quantize(input, input_q, index) - else: - if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - - 1): - input_flat = [(self.quantize_real_ratio * g) + - ((1 - self.quantize_real_ratio) * g_q) for g, - g_q in zip(input_g, - input_flat)] - input_q = torch.cat(input_flat) - input_q = input_q.reshape(input.size()) - return input_q - - def update_fp16_ratio(self): - if self.q_mixed_fp16: - if self.quantize_real_ratio > 0: - self.quantize_real_ratio -= self.q_change_ratio - else: - self.quantize_real_ratio = 0.000 +import torch +import math +from deepspeed.utils import log_dist +from deepspeed.utils import logger +from deepspeed.ops.quantizer import ds_quantizer + +# number of 2-dimensional parameters in a layer +# this is set for transformer-based models +TWO_D_PARAMS = 6 + + +class Quantizer(object): + def __init__(self, + q_target_bits=8, + q_start_bits=16, + q_period=100, + q_offset=100, + q_groups=1, + q_mixed_fp16=False, + q_change_ratio=0.01, + q_type=0, + q_rounding=0, + q_verbose=False, + q_eigenvalue=False, + use_quantizer_kernel=False, + layer_num=0): + + self.q_target_bits = q_target_bits + + self.q_start_bits = [q_start_bits] * (layer_num if layer_num != 0 else 1) + self.q_period = [q_period] * (layer_num if layer_num != 0 else 1) + self.q_offset = q_offset + self.q_groups = q_groups + self.q_mixed_fp16 = q_mixed_fp16 + self.q_change_ratio = q_change_ratio + self.q_type = q_type + self.qsteps = 0 + self.q_init_period = q_period + self.quantize_real_ratio = 1.000 + self.q_verbose = q_verbose + self.q_eigenvalue = q_eigenvalue + self.use_quantizer_kernel = use_quantizer_kernel + self.q_rounding = q_rounding + self.layer_num = layer_num + + def any_precision_switch(self): + if self.layer_num == 0: + return True + result = False + for index in range(self.layer_num): + if self.q_start_bits[index] != self.q_target_bits: + next_step = self.qsteps + ( + TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1)) + if next_step >= self.q_period[index]: + result = True + return result + + def quantize(self, + parameter_group, + overflow, + eigenvalue_enabled, + block_eigenvalue={}): + + if overflow and not eigenvalue_enabled: + return + + self.step() + + self.update_fp16_ratio() + + for i in range(len(parameter_group)): + for p in parameter_group[i]: + if len(p.size()) > 1: + param_id = id(p) + eigenvalue, layer_id = block_eigenvalue[param_id] if param_id in block_eigenvalue else (None, 0) + if eigenvalue is not None: + factor = 1 + math.floor(eigenvalue * 4) + p.data = self.compute_quantization(p.data, layer_id, factor) + else: + p.data = self.compute_quantization(p.data, layer_id) + + def step(self): + self.qsteps += (TWO_D_PARAMS * (self.layer_num if self.layer_num != 0 else 1)) + + def sr_quantize(self, input_flat, input_g, scale): + # Random number generator (Uniform) + p = torch.cuda.FloatTensor(input_flat.size(), + device=input_flat.device).uniform_() + p = torch.split(p, p.size(0) // self.q_groups) + add_s = torch.zeros_like(input_flat) + add_s = torch.split(add_s, add_s.size(0) // self.q_groups) + + scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g] + # Quantize with INT rounding + input_flat = [(g * s).int().float() / s for (g, s) in zip(input_g, scale)] + # Compute the error + error = [((g - q).abs() / s) for (g, s, q) in zip(input_g, scale, input_flat)] + # Stochastic Rounding + add_s = [ + a_s.masked_fill_(pg < err_g, + 1 / s) for (a_s, + pg, + err_g, + s) in zip(add_s, + p, + error, + scale) + ] + add_s = [ + a_s * (g > 0).float() - a_s * (g < 0).float() for a_s, + g in zip(add_s, + input_flat) + ] + input_flat = [((q + a_s) * s).clamp(-(q_range >> 1), + (q_range >> 1) - 1) / s for q, + a_s, + s in zip(input_flat, + add_s, + scale)] + return input_flat + + def mixed_fp16_quantize(self, input, input_q, index): + if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - 1): + input_q = input * self.quantize_real_ratio + ( + 1 - self.quantize_real_ratio) * input_q + return input_q + return input_q + + def compute_quantization(self, input, index=0, factor=1): + # fixing the quantization bits based on the training steps + # when reducing 1 bit at each period, we increase the period + # to go slowly toward the target quantization bits + # the period and starting bit can be configured + if self.q_offset > 0: + if self.qsteps >= self.q_offset: + self.q_offset = 0 + self.qsteps = 0 + else: + return input + + if self.q_start_bits[index] != self.q_target_bits: + if self.qsteps >= self.q_period[index]: + self.quantize_real_ratio = 1.0 + if self.q_eigenvalue: + self.q_period[index] <<= 1 + self.q_period[index] *= factor + self.q_start_bits[index] -= 1 + else: + for i in range(len(self.q_start_bits)): + self.q_start_bits[i] -= 1 + self.q_period[i] <<= 1 + if self.q_verbose: + logger.info( + f'Quantization settings: current bit-precision = {self.q_start_bits[index]}, step = {self.qsteps}, quantization period = {self.q_period[index]}, index = {index}' + ) + assert (self.q_start_bits[index] >= self.q_target_bits), \ + 'Quantization bit is lower than target precision bits!' + + # quantize the weights base on the selected bits and the value-range + if not self.use_quantizer_kernel: + q_range = 2**self.q_start_bits[index] + input_flat = input.view(-1) + input_g = torch.split(input_flat, input_flat.size(0) // self.q_groups) + if self.q_type == 0: #symmetric + if self.use_quantizer_kernel: + input_q = ds_quantizer(input.clone(), + self.q_groups, + self.q_start_bits[index]) + else: + scale = [q_range / (2 * max(g.max(), g.min().abs())) for g in input_g] + if self.q_rounding == 0: # Nearest value rounding + input_flat = [(g * s).round().clamp(-(q_range >> 1), + (q_range >> 1) - 1) / s for g, + s in zip(input_g, + scale)] + else: # Stochastic Rounding + if self.use_quantizer_kernel: + input_q = ds_quantizer(input.clone(), + self.q_groups, + self.q_start_bits[index], + sr=True) + else: + input_flat = self.sr_quantize(input_flat, input_g) + else: #asymmetric + if self.q_rounding == 0: + if self.use_quantizer_kernel: + input_q = ds_quantizer(input.clone(), + self.q_groups, + self.q_start_bits[index], + asym=True) + else: + scale = [(g.max() - g.min()) / q_range for g in input_g] + input_flat = [ + ((g - g.min()) / s).round().clamp(0, + (q_range - 1)) * s + g.min() + for g, + s in zip(input_g, + scale) + ] + else: + input_q = ds_quantizer(input.clone(), + self.q_groups, + self.q_start_bits[index], + asym=True) + + if self.use_quantizer_kernel or (self.q_type and self.q_rounding): + return self.mixed_fp16_quantize(input, input_q, index) + else: + if self.q_mixed_fp16 and self.q_start_bits[index] >= (self.q_target_bits - + 1): + input_flat = [(self.quantize_real_ratio * g) + + ((1 - self.quantize_real_ratio) * g_q) for g, + g_q in zip(input_g, + input_flat)] + input_q = torch.cat(input_flat) + input_q = input_q.reshape(input.size()) + return input_q + + def update_fp16_ratio(self): + if self.q_mixed_fp16: + if self.quantize_real_ratio > 0: + self.quantize_real_ratio -= self.q_change_ratio + else: + self.quantize_real_ratio = 0.000 diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 2134020d52c0..d8f79f69a577 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1,3351 +1,3351 @@ -""" -"Copyright 2020 The Microsoft DeepSpeed Team. -Licensed under the MIT license. -""" - -import gc -from dataclasses import dataclass -import functools -import os -import collections -from collections import OrderedDict, UserDict -import itertools -from typing import Deque, Dict, Iterable, Set, Tuple -import torch -from torch.cuda import Event, Stream -from torch.nn import Module, Parameter -import torch.distributed as dist -import math -from torch._six import inf -from torch.nn import Module -from torch.nn.parameter import Parameter - -from deepspeed.utils.logging import logger -from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler -from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced -from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim -from deepspeed.runtime.zero.partition_parameters import * -from deepspeed.runtime.zero.partition_parameters import _init_external_params -from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS -from deepspeed.ops.adam import DeepSpeedCPUAdam -from deepspeed.ops.op_builder import UtilsBuilder -from deepspeed.runtime.zero.offload_constants import * -from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus -from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper -from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper - -# Toggle this to true to enable correctness test -# with gradient partitioning and without -pg_correctness_test = False - -FWD_MODULE_STACK = list() - - -def print_rank_0(message, debug=False, force=False): - rank = torch.distributed.get_rank() - if rank == 0 and (debug or force): - print(message) - # other variations - # - print for all ranks w/o interleaving - # printflock(f"[{rank}] {message}") - # - print to log file per rank - # log_rank_file(rank, message) - - -def input(msg): - return - - -def isclose(a, b, rtol=1e-09, atol=0.0): - return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) - - -def lcm(x, y): - from fractions import gcd # or can import gcd from `math` in Python 3 - return x * y // gcd(x, y) - - -def debug_rank0(message: str) -> None: - if dist.get_rank() == 0: - logger.debug(message) - - -def get_cuda_mem_allocated_str() -> str: - # this is really slow. when enabled the python process becomes slow - # to the point where it can't keep the GPU fed with work, so only enable - # for memory debugging. - # return f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f}GB" - return "xGB" - - -def move_to_cpu(tensor_list): - for tensor in tensor_list: - tensor.data = tensor.data.cpu() - - -@instrument_w_nvtx -def get_all_parameters(sub_module, recurse=False): - return itertools.chain(sub_module.named_parameters(recurse=recurse), - sub_module.ds_external_parameters()) - - -def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: - return map(lambda pair: pair[1], get_all_parameters(module, recurse)) - - -#apply torch.autograd.Function that calls a backward_function to tensors in output -def _apply_to_tensors_only(module, functional, backward_function, outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_to_tensors_only(module, - functional, - backward_function, - output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - return functional.apply(module, backward_function, outputs) - else: - return outputs - - -#for each tensor in outputs run the forward_function and register backward_function as hook -def _apply_forward_and_backward_to_tensors_only(module, - forward_function, - backward_function, - outputs): - if type(outputs) is tuple: - touched_outputs = [] - for output in outputs: - touched_output = _apply_forward_and_backward_to_tensors_only( - module, - forward_function, - backward_function, - output) - touched_outputs.append(touched_output) - return tuple(touched_outputs) - elif type(outputs) is torch.Tensor: - forward_function(outputs) - if outputs.requires_grad: - outputs.register_hook(backward_function) - return outputs - else: - return outputs - - -class ZeROOrderedDict(OrderedDict): - def __init__(self, parent_module, *args, **kwargs): - """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. - - Args: - parent_module (``collections.OrderedDict``): the collection to replace - """ - - super().__init__(*args, **kwargs) - self._parent_module = parent_module - self._in_forward = False - - def __getitem__(self, key): - param = super().__getitem__(key) - - # Params can be registered as None (e.g., bias) - if param is None: - return param - - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - if self._parent_module._parameters._in_forward: - print_rank_0(f'Registering external parameter from getter {key}', - force=False) - register_external_parameter(FWD_MODULE_STACK[-1], param) - param.all_gather() - - return param - - -def _inject_parameters(module, cls): - for module in module.modules(): - if cls == ZeROOrderedDict: - new_param = cls(parent_module=module) - else: - new_param = cls() - - for key, param in module._parameters.items(): - new_param[key] = param - module._parameters = new_param - - -class PartitionedParameterCoordinator: - """Handles partitioning and gathering of parameters.""" - class __InflightParamRegistry(UserDict): - """registry for parameters in flight""" - def __setitem__(self, - param: Parameter, - handle: AllGatherCoalescedHandle) -> None: - if param in self.data: - raise RuntimeError(f"{param.ds_summary()} already in registry") - if param.ds_status != ZeroParamStatus.INFLIGHT: - raise RuntimeError( - f"attempted to add non-inflight parameter to registry {param.ds_summary()}" - ) - self.data[param] = handle - - @dataclass - class __ParamInTrace: - param: Parameter - step_id_last_used_at: int - - def __init__( - self, - prefetch_bucket_sz: int, - max_reuse_distance_in_numel: int, - max_available_parameters_in_numel: int, - allgather_stream: Stream, - prefetch_nvme: bool = False, - ) -> None: - # mapping of param -> handle for each param that is currently in flight - self.__inflight_param_registry = __class__.__InflightParamRegistry() - # keeps track of the number of submodules invoked so far. - self.__step_id: int = 0 - # whether or not we have completed a trace of the entire network. This should - # always be true after the first forward pass + backward pass. - self.trace_complete: bool = False - # sequence of submodules/parameters in forward pass + backward pass - self.__submodule_order: Iterable[Module] = [] - self.__param_order: Iterable[__class__.__ParamInTrace] = [] - self.__most_recent_step_id_param_fetched_for = collections.defaultdict( - lambda: int(-1e10)) - # number of available params, and max number of available params - self.__n_available_params: int = 0 - self.__max_n_available_params: int = max_available_parameters_in_numel - # max distance between two use of the module beyond which module is released - self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel - # queue for parameters to fetch. parameters will be popped off the left - # side of the dequeue as they are fetched - self.__param_queue: Deque[__class__.__ParamInTrace] = None - self.__prefetch_bucket_sz: int = prefetch_bucket_sz - self.__prefetch_nvme: bool = prefetch_nvme - self.hierarchy: int = 0 - - # stream that will be used for allgather operations - self.__allgather_stream: Stream = allgather_stream - - # limit the number of fetch events that can be queued at once - # otherwise, what happens is memory is allocated by the host thread at the - # time of the call, but not used until later by the asynchronous cuda stream. - # allowing an infinite number of these to queue up causes a lot of memory - # pressure that then becomes detrimental to performance. - # this is a much less elegant way of fixing this vs something like using - # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now - # because ideally in the future its replaced by an async allocation - # mechanism which doesnt require any configuration by the user. - self.__ongoing_fetch_events: Deque[Event] = collections.deque() - self.__max_ongoing_fetch_events: int = 2 - - """Tracing and Tracking - TODO. consider performing trace before initializing PartitionedParameterCoordinator - and passing trace results into constructor. This way all the code in here can - just assume that the trace is complete and the results can be entirely - immutable. - - Bookkeeping operations used to track where we are in the forward/backward pass - """ - - def record_trace(self, sub_module: Module) -> None: - """adds sub module to trace""" - if self.trace_complete: - raise RuntimeError( - "attemted to record trace when trace was already complete") - - self.__submodule_order.append(sub_module) - for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): - self.__param_order.append( - __class__.__ParamInTrace(param=param, - step_id_last_used_at=self.__step_id)) - - def reset_step(self) -> None: - """indicate that we have completed one fwd+bwd for the model""" - if self.__inflight_param_registry: - raise RuntimeError( - f"still have inflight params " - f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") - - if not self.trace_complete: - # make sure that recorded parameter and submodule orders are - # identical across ranks - assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) - assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) - assert_ints_same_as_other_ranks( - [p.step_id_last_used_at for p in self.__param_order]) - - self.__submodule_order = tuple(self.__submodule_order) # freeze - self.__param_order = tuple(self.__param_order) # freeze - self.trace_complete = True - print_rank_0(f"completed trace: {[m.id for m in self.__submodule_order]}", - force=True) - - self.__param_queue = collections.deque(self.__param_order) # reset fetch queue - self.__most_recent_step_id_param_fetched_for = collections.defaultdict( - lambda: int(-1e10)) - self.__step_id = 0 - self.__n_available_params = 0 - - """Fetch and Release - Fetching, prefetching, and releasing parameters - """ - - @instrument_w_nvtx - @torch.no_grad() - def fetch_sub_module(self, current_submodule: Module) -> None: - """This method does the following (in order): - 1. kick off fetch for parameters in immediately required sub module - 2. kick off fetch for next few parameters we will need later (prefetch) - 3. block on parameters in immediately required sub module - """ - debug_rank0( - f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " - + str({ - "avail": f"{self.__n_available_params:.1e}", - "queue_sz": f"{len(self.__param_queue or [])}", - "inflight": [p.ds_id for p in self.__inflight_param_registry], - "allocated": get_cuda_mem_allocated_str() - })) - - params_to_fetch = frozenset(iter_params(current_submodule)) - - # kick off all gather for params in the immediately required submodule - for param in params_to_fetch: - debug_rank0(f"-fetch: {param.ds_summary()}") - self.__all_gather_params(params_to_fetch) - - # wait for parameters in the immediately needed submodule to become available - for param in iter_params(current_submodule): - param.ds_active_sub_modules.add(current_submodule.id) - debug_rank0(f"-wait: {param.ds_summary()}") - if param in self.__inflight_param_registry: - with torch.cuda.stream(self.__allgather_stream): - while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ - 0].query(): - self.__ongoing_fetch_events.popleft() - if len(self.__ongoing_fetch_events - ) > self.__max_ongoing_fetch_events: - self.__ongoing_fetch_events.popleft().synchronize() - - self.__inflight_param_registry.pop(param).wait() - - event = Event() - event.record() - self.__ongoing_fetch_events.append(event) - - assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() - torch.cuda.current_stream().wait_stream(self.__allgather_stream) - - # kick off parameter prefetches for upcoming modules - # don't prefetch if we dont have a completed model trace, or if we aren't - # training (throws off the tracing and don't want to prefetch modules for bwd) - if self.trace_complete and current_submodule.training: - # go through the parameters we need for the current module and pop them - # off the fetch queue so that they aren't prefetched later. - # if params have already been popped off the fetch queue by earlier - # prefetches we won't look for them here - discarded_from_prefetch_queue = set() - params_not_already_fetched = set( - filter( - lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. - __step_id, - params_to_fetch)) - while self.__param_queue and len(discarded_from_prefetch_queue) < len( - params_not_already_fetched): - param_in_trace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - discarded_from_prefetch_queue.add(param_in_trace.param) - if discarded_from_prefetch_queue != params_not_already_fetched: - raise RuntimeError( - f"tracing error at step {self.__step_id}: " - f"expected the next {len(params_not_already_fetched)} parameters in the " - f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " - f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." - ) - - # kick off all gather for params in the next few submodules (prefetch) - max_params_to_prefetch = min( - self.__max_n_available_params - self.__n_available_params, - self.__prefetch_bucket_sz) - params_to_prefetch = set() - numel_prefetching = 0 - while self.__param_queue and numel_prefetching < max_params_to_prefetch: - param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() - self.__most_recent_step_id_param_fetched_for[ - param_in_trace.param] = param_in_trace.step_id_last_used_at - if param_in_trace.param not in params_to_prefetch: - params_to_prefetch.add(param_in_trace.param) - numel_prefetching += param_in_trace.param.ds_numel - for param in params_to_prefetch: - debug_rank0(f"-prefetch: {param.ds_summary()}") - self.__all_gather_params(params_to_prefetch) - - if self.__prefetch_nvme: - self.__prefetch_nvme_param_partitions() - - self.__step_id += 1 - - @instrument_w_nvtx - @torch.no_grad() - def release_sub_module(self, submodule: Module) -> None: - """release the parameters of a sub module, assuming they meet conditions to - be released.""" - params_to_release = (self.__params_to_release(submodule, - self.__step_id) - if self.trace_complete else set( - p.ds_id for p in iter_params(submodule))) - - for param in iter_params(submodule): - param.ds_active_sub_modules.discard(submodule.id) - if param.ds_id in params_to_release and not param.is_external_param: - self.__release_param(param) - - @instrument_w_nvtx - @torch.no_grad() - def release_and_reset_all(self) -> None: - """release all module parameters""" - for param in map(lambda p: p.param, self.__param_order): - if param in self.__inflight_param_registry: - raise RuntimeError(f"param {param.ds_summary()} still in flight") - - # TODO. make this throw if if there are still active submodules. currently - # there's a hook execution issue - param.ds_active_sub_modules.clear() - self.__release_param(param) - - for param_in_trace in self.__param_order: - if param_in_trace.param.ds_status != ZeroParamStatus.NOT_AVAILABLE: - raise RuntimeError( - f"{param_in_trace.param.ds_summary()} expected to be released") - - @instrument_w_nvtx - def __all_gather_params(self, params: Set[Parameter]) -> None: - """for each partitioned parameter, kick off an async allgather and store - the work handle for the in flight parameters.""" - partitioned_params = [] - for param in params: - if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: - partitioned_params.append(param) - self.__n_available_params += param.ds_numel - - if partitioned_params: - with torch.cuda.stream(self.__allgather_stream): - handle = partitioned_params[0].all_gather_coalesced(partitioned_params) - - for param in partitioned_params: - assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() - self.__inflight_param_registry[param] = handle - - @instrument_w_nvtx - def __release_param(self, param: Parameter) -> None: - if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: - debug_rank0(f"-release: {param.ds_summary()}") - param.partition() - self.__n_available_params -= param.ds_numel - - @instrument_w_nvtx - @functools.lru_cache(maxsize=None) - def __params_to_release(self, - submodule_to_release: Module, - step_id: int) -> Set[int]: - if not self.trace_complete: - raise RuntimeError("expected trace to be complete") - - params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) - if not p.ds_persist) - - # examine all modules within `max_reuse_dist_in_numel` of the current step, - # if we see any of the candidate parameters to be released reoccur while - # doing this, remove them from the set of parameters to release. - params_traversed = 0 - for module in self.__submodule_order[step_id:]: - if params_traversed > self.__max_reuse_dist_in_numel: - break - for param in iter_params(module): - params_to_release.discard(param.ds_id) - params_traversed += param.ds_numel - - return params_to_release - - @instrument_w_nvtx - def __prefetch_nvme_param_partitions(self) -> None: - """swap in parameter partitions from nvme for those parameters that will be used - after the ones that are already being prefetched into full parameters - """ - if not self.trace_complete: - return - - numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) - - numel_considered = 0 - swap_in_params = [] - for param_in_trace in self.__param_queue: - param = param_in_trace.param - if param.nvme_swapper is None: - continue - if (numel_considered > 2 * numel_in_flight or len(swap_in_params) >= - param.nvme_swapper.available_swap_in_buffers()): - break - if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: - swap_in_params.append(param) - numel_considered += param.ds_numel - - if swap_in_params: - swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) - - -class PreBackwardFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, module, pre_backward_function, outputs): - ctx.module = module - ctx.pre_backward_function = pre_backward_function - if not hasattr(module, "applied_pre_backward_ref_cnt"): - module.applied_pre_backward_ref_cnt = 0 - module.applied_pre_backward_ref_cnt += 1 - #print(f"After Forward: {ctx.module.__class__.__name__}") - outputs = outputs.detach() - return outputs - - @staticmethod - def backward(ctx, *args): - #print(f"Before Backward: {ctx.module.__class__.__name__}") - ctx.pre_backward_function(ctx.module) - return (None, None) + args - - -class PostBackwardFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, module, pre_backward_function, output): - ctx.module = module - if output.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.pre_backward_function = pre_backward_function - output = output.detach() - return output - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.pre_backward_function(ctx.module) - #print(f"After Backward: {ctx.module.__class__.__name__}") - return (None, None) + args - - -class FP16_DeepSpeedZeroOptimizer_Stage3(object): - """ - DeepSpeedZeroOptimizer designed to reduce the memory footprint - required for training large deep learning models. - - For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models - https://arxiv.org/abs/1910.02054 - - For usage examples, refer to TODO: DeepSpeed Tutorial - - """ - def __init__(self, - module, - init_optimizer, - timers, - ds_config, - static_loss_scale=1.0, - dynamic_loss_scale=False, - dynamic_loss_args=None, - verbose=True, - contiguous_gradients=True, - reduce_bucket_size=500000000, - prefetch_bucket_size=50000000, - max_reuse_distance=1000000000, - max_live_parameters=1000000000, - param_persistence_threshold=100000, - dp_process_group=None, - reduce_scatter=True, - overlap_comm=False, - offload_optimizer_config=None, - offload_param_config=None, - sub_group_size=1000000000000, - mpu=None, - clip_grad=0.0, - allreduce_always_fp32=False, - postscale_gradients=True, - gradient_predivide_factor=1.0, - gradient_accumulation_steps=1, - elastic_checkpoint=False, - aio_config=None): - - see_memory_usage("Stage 3 initialize beginning", force=False) - - if dist.get_rank() == 0: - logger.info(f"initialized {__class__.__name__} with args: {locals()}") - logger.info(f"Reduce bucket size {reduce_bucket_size}") - logger.info(f"Allgather bucket size {prefetch_bucket_size}") - # The fused optimizer does all the work. We need this layer for two reason: - # 1. maintain same user API from apex.fp16_utils - # 2. keep common stuff here in case we need to add ne552w fused optimizer later - - # differences from apex.fp16_utils: - # - assume all model params in fp16 - # - assume all params requires grad - # - flat by groups, not keeping state. TODO: remove state explicitly? - # - master gard and unflat master weight never exist. TODO: a way to save out unflat master? - if not torch.cuda.is_available: - raise SystemError("Cannot use fp16 without CUDA.") - self.optimizer = init_optimizer - - # Load pre-built or JIT compile (un)flatten ops - util_ops = UtilsBuilder().load() - self.flatten = util_ops.flatten - self.unflatten = util_ops.unflatten - self.dtype = self.optimizer.param_groups[0]['params'][0].dtype - self._global_grad_norm = 0. - - self.optimizer_swapper = None - self.swap_optimizer = False - - self.offload_optimizer = False - self.offload_optimizer_pin_memory = False - self.offload_optimizer_fast_init = False - self.offload_param = False - self.offload_param_pin_memory = False - self.params_in_nvme_and_cpu = False - self.max_params_in_cpu = 0 - - self._configure_offloading(offload_optimizer_config, offload_param_config) - - self._convert_to_zero_parameters(ds_config, module, mpu) - - for m in module.modules(): - _init_external_params(m) - - self.module = module - self.elastic_checkpoint = elastic_checkpoint - - # Replace ._parameters with a new class to enable auto-registration of - # external parameters - _inject_parameters(module, ZeROOrderedDict) - - self.__inf_or_nan_tracker: Tensor = torch.zeros( - 1, - dtype=torch.bool, - device=torch.cuda.current_device(), - requires_grad=False) - - self.deepspeed_adam_offload = (self.offload_optimizer - and type(init_optimizer) == DeepSpeedCPUAdam) - - self.device = torch.cuda.current_device( - ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE - ### streams used for overlapping computation with communication - self.__allgather_stream = Stream( - ) if overlap_comm else torch.cuda.default_stream() - self.__reduce_and_partition_stream = Stream( - ) if overlap_comm else torch.cuda.default_stream() - - ############################################################################ - - see_memory_usage("Before Partitioned Parameter Coordinator", force=False) - self.param_coordinator = PartitionedParameterCoordinator( - prefetch_bucket_sz=int(prefetch_bucket_size), - max_reuse_distance_in_numel=int(max_reuse_distance), - max_available_parameters_in_numel=int(max_live_parameters), - allgather_stream=self.__allgather_stream, - prefetch_nvme=self.params_in_nvme_and_cpu, - ) - see_memory_usage("After Partitioned Parameter Coordinator", force=False) - - self.__n_caching_allocator_flushes = 0 - - #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) - #-------------Stage 3 Setup-------------------# - # parameters smaller than the threshold will be collectively gathered at the - # end of the optimizer step and will be kept till the end of the backward pass - # TODO maybe worth just replicating these parameters and doing all reduce for them - self.persistence_threshold = int(param_persistence_threshold) - - self.persistent_parameters = self.persistent_parameters() - - self.setup_zero_stage3_hooks() - - #resetting ds_tensor just in case parameters have been changed after initialization - #example .half() or .to() - #self.reset_ds_tensor() - #---------------------------------------------# - - self.timers = timers - - self.dp_process_group = dp_process_group - - self.partition_count = dist.get_world_size(group=self.dp_process_group) - - if mpu is None: - self.model_parallel_group = None - self.model_parallel_rank = 0 - else: - self.model_parallel_group = mpu.get_model_parallel_group() - self.model_parallel_rank = mpu.get_model_parallel_rank() - - self.overflow = False - self.clip_grad = clip_grad - self.allreduce_always_fp32 = allreduce_always_fp32 - self.gradient_predivide_factor = gradient_predivide_factor - self.postscale_gradients = postscale_gradients - self.gradient_accumulation_steps = gradient_accumulation_steps - self.micro_step_id = 0 - - # Holds the mode parameter - # The param.data may not hold any meaningful data - # when param's status is NOT_AVAILABLE or IN_FLGHT - self.fp16_groups = [] - - # Hold partitioned parameters - self.fp16_partitioned_groups = [] - - # Holds a fused and flattened copy of the parameters - self.fp16_partitioned_groups_flat = [] - self.fp16_partitioned_groups_flat_numel = [] - - #defragmented pinned memory - self.param_groups_fp16_flat_cpu_memory = [] - - #a single 32-bit partition of the parallel partitioned parameters - #that this process will update - self.fp32_partitioned_groups_flat = [] - self.next_swappable_fp32_partitioned_groups = [] - - # number of elements per partition in each group - self.partition_size = [] - - self.all_reduce_print = False - - self.prefetch_elements = int(prefetch_bucket_size) - - # padding on each partition for alignment purposes - self.groups_padding = [] - - self.sub_group_size = sub_group_size - - self.sub_group_to_group_id = {} - see_memory_usage("Before creating fp16 partitions", force=False) - self._create_fp16_partitions_with_defragmentation() - num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) - see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", - force=False) - - # Optimizer tensor swapping - if self.swap_optimizer: - self._configure_tensor_swapping(offload_optimizer_config, aio_config) - - see_memory_usage("Before creating fp32 partitions", force=False) - if not isinstance(self.optimizer, DummyOptim): - self._create_fp32_partitions() - see_memory_usage("After creating fp32 partitions", force=False) - dist.barrier() - - # To support pipelined optimizer swapping - if not isinstance(init_optimizer, DummyOptim): - self._create_next_swappable_fp32_groups() - - see_memory_usage("Before initializing optimizer states", force=False) - if not isinstance(init_optimizer, DummyOptim): - self.initialize_optimizer_states() - see_memory_usage("After initializing optimizer states", force=False) - dist.barrier() - - if dist.get_rank() == 0: - logger.info(f"optimizer state initialized") - - self.reduce_bucket_size = int(reduce_bucket_size) - - # IPG - if contiguous_gradients: - self.__ipg_bucket_flat_buffer: Tensor = torch.empty( - int(reduce_bucket_size), - dtype=self.dtype, - device=torch.cuda.current_device()) - - self.__param_id_to_grad_partition: Dict[int, Tensor] = {} - - all_params = list(itertools.chain.from_iterable(self.fp16_groups)) - - grad_partitions_flat_buffer: Tensor = torch.zeros( - sum(p.ds_tensor.ds_numel for p in all_params), - dtype=self.dtype, - device=self.device, - pin_memory=self.offload_optimizer_pin_memory) - - offset = 0 - for param in all_params: - self.__param_id_to_grad_partition[ - param.ds_id] = grad_partitions_flat_buffer.narrow( - 0, - offset, - param.ds_tensor.numel()) - offset += param.ds_tensor.numel() - - self.__params_in_ipg_bucket: List[Parameter] = [] - self.is_gradient_accumulation_boundary: bool = True - - self.__param_reduce_events: Deque[Event] = collections.deque() - self.__max_param_reduce_events: int = 2 - - if dist.get_rank() == 0: - logger.info(f"optimizer state initialized") - - self.param_dict = {} - - # map between param_id and bool to specify if a param is in this partition - self.is_param_in_current_partition = {} - - self.contiguous_gradients = contiguous_gradients - self.extra_large_param_to_reduce = None - self.grads_in_ipg_bucket = [] - self.params_in_ipg_bucket = [] - self.params_already_reduced = [] - self.is_gradient_accumulation_boundary = True - self._release_ipg_buffers() - self.previous_reduced_grads = None - - # simplified param id - self.param_id = {} - - count = 0 - for i, params_group in enumerate(self.fp16_groups): - for param in params_group: - unique_id = id(param) - self.param_id[unique_id] = count - self.param_dict[count] = param - self.params_already_reduced.append(False) - count = count + 1 - - #Largest partitioned param - largest_partitioned_param_numel = max([ - max([tensor.numel() for tensor in fp16_partitioned_group]) - for fp16_partitioned_group in self.fp16_partitioned_groups - ]) - print_rank_0( - f'Largest partitioned param numel = {largest_partitioned_param_numel}', - force=False) - - see_memory_usage(f"Before Set Grad positions", force=False) - - self.grad_position = {} - self.set_grad_positions() - see_memory_usage(f"Before CPU Offload initialization", force=False) - - self.grads_in_partition = None - - if self.offload_optimizer: - self.norm_for_param_grads = {} - self.local_overflow = False - - see_memory_usage(f"After CPU Offload initialization", force=False) - - # stores if a partition has been reduced in this step - self.is_partition_reduced = {} - - # stores if a grad in a partition has been computed or not - self.is_grad_computed = {} - - # will store the averaged gradients required by this paritition - self.averaged_gradients = {} - - #creates backward hooks for gradient partitioning - self.create_reduce_and_remove_grad_hooks() - - #exit(0) - - # we may have a way of fusing dynamic scale. Do not support for now - if self.dtype == torch.float or not dynamic_loss_scale: - loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale - - self.dynamic_loss_scale = False - self.loss_scaler = LossScaler(scale=loss_scale_value) - cur_iter = 0 - else: - if dynamic_loss_args is None: - self.loss_scaler = DynamicLossScaler() - else: - self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) - - self.dynamic_loss_scale = True - - self.debug_fp16_grads = [{} for _ in self.fp16_groups] - - if dist.get_rank(group=self.dp_process_group) == 0: - see_memory_usage(f"After initializing ZeRO optimizer", force=False) - - @staticmethod - def defragment(tensors: List[Tensor]) -> Tensor: - """move provided tensors into a contiguous flat buffer, with some additional - measures taken to reduce memory fragmentation""" - assert len(set(t.dtype for t in tensors)) == 1 - assert len(set(t.device for t in tensors)) == 1 - - cpu_buffer = torch.empty(sum(p.numel() for p in tensors), - dtype=get_only_unique_item(t.dtype for t in tensors), - device="cpu") - tensor_infos: List[Tuple[Tensor, int, int]] = [] - orig_device = get_only_unique_item(t.device for t in tensors) - - offset = 0 - for tensor in tensors: - tensor_numel = tensor.numel() - # move the tensor from device memory to host memory - cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) - tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) - - # record some data so we can restore the device tensor later - tensor_infos.append((tensor, offset, tensor_numel)) - - offset += tensor_numel - - gc.collect() - torch.cuda.empty_cache() - - # copy tensors (now flattened and contiguous) back to GPU - device_buffer = cpu_buffer.to(orig_device) - - # restore device tensors - for tensor, offset, tensor_numel in tensor_infos: - tensor.data = device_buffer.narrow(0, offset, tensor_numel) - - return device_buffer - - def _configure_offloading(self, offload_optimizer_config, offload_param_config): - ###################### offload optimizer setup ################################## - if offload_optimizer_config is not None: - self.offload_optimizer = True - self.offload_optimizer_pin_memory = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_PIN_MEMORY] - self.swap_optimizer = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE - self.offload_optimizer_fast_init = offload_optimizer_config[ - OFFLOAD_OPTIMIZER_FAST_INIT] - - ###################### offload param setup ################################## - if offload_param_config is not None: - if not isinstance(self.optimizer, DummyOptim): - assert self.offload_optimizer, "parameter offload is only available with optimizer state offload" - self.offload_param = True - self.offload_param_pin_memory = offload_param_config[ - OFFLOAD_PARAM_PIN_MEMORY] - self.params_in_nvme_and_cpu = offload_param_config[ - OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE - self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU] - print_rank_0( - f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}", - force=False) - - def _convert_to_zero_parameters(self, ds_config, module, mpu): - non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] - if non_zero_params: - zero_params = [p for p in module.parameters() if is_zero_param(p)] - if zero_params: - zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) - else: - group = None - if mpu: - group = mpu.get_data_parallel_group() - - if self.params_in_nvme_and_cpu: - remote_device = OFFLOAD_NVME_DEVICE - elif self.offload_param: - remote_device = OFFLOAD_CPU_DEVICE - else: - remote_device = None - - Init(module=module, - data_parallel_group=group, - dtype=self.dtype, - config_dict_or_path=ds_config, - remote_device=remote_device, - pin_memory=self.offload_param_pin_memory, - mpu=mpu) - - def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): - nvme_swap_folder = os.path.join( - offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH], - 'zero_stage_3') - os.makedirs(nvme_swap_folder, exist_ok=True) - if torch.distributed.get_rank() == 0: - logger.info(f'Tensor Swapping: Adding optimizer tensors') - - swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[ - OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper - - self.optimizer_swapper = swapper_type( - swap_config=offload_optimizer_config, - aio_config=aio_config, - base_folder=nvme_swap_folder, - optimizer=self.optimizer, - largest_numel=max(self.fp16_partitioned_groups_flat_numel), - device=self.device, - dtype=torch.float32, - timers=self.timers) - - @property - def elements_in_ipg_bucket(self): - return sum(p.ds_numel for p in self.__params_in_ipg_bucket) - - def _create_fp16_partitions(self): - dist.barrier() - partition_id = dist.get_rank(group=self.dp_process_group) - - # loop to deal with groups - for j, param_group in enumerate(self.optimizer.param_groups): - - sub_groups = self._create_fp16_sub_groups(param_group['params']) - for sub_group in sub_groups: - i = len(self.fp16_groups) - - # push this group to list before modify - self.fp16_groups.append(sub_group) - self.sub_group_to_group_id[i] = j - - #These are the list of the partitioned parameters - self.fp16_partitioned_groups.append( - [param.ds_tensor for param in self.fp16_groups[i]]) - - print_rank_0( - f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" - ) - - # Record padding required to align group to world size (only applies to last rank) - if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: - padding = [p.padding_size() for p in self.fp16_groups[i]] - else: - padding = [0] * len(self.fp16_groups[i]) - self.groups_padding.append(padding) - - #not sure why apex was cloning the weights before flattening - #removing cloning here - see_memory_usage(f"Before Flattening param group {i}", force=False) - - if not self.offload_param: - see_memory_usage(f"Before moving param group {i} to CPU", - force=False) - #move all the parameters to cpu to free up GPU space for creating flat buffer - move_to_cpu(self.fp16_partitioned_groups[i]) - see_memory_usage(f"After moving param group {i} to CPU", force=False) - - #create flat buffer in CPU and move to GPU - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size(group=self.dp_process_group)).cuda( - torch.cuda.current_device())) - see_memory_usage( - f"After flattening and moving param group {i} to GPU", - force=False) - else: - #Without the detach, seems like the flattening becomes part of the - #model graph causing errors downstream - self.fp16_partitioned_groups_flat.append( - self.flatten_dense_tensors_aligned( - self.fp16_partitioned_groups[i], - dist.get_world_size( - group=self.dp_process_group)).detach().pin_memory()) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - see_memory_usage(f"After Flattening param group {i}", force=False) - - #set model fp16 weight to slices of flattened buffer - updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i], - self.fp16_partitioned_groups[i]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params): - partitioned_param.data = q.data - - def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): - '''If flat buffer is None then the parameters in the param_list are - not copied to the flat buffer. This is because they excede the number of max_params_in_cpu - Some of these parameters may aready be in CPU in unflattened buffers - or they maybe in GPU, or they maybe in NVME. If they are in NVME, then - they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are - needed during training.''' - if flat_buffer is None: - # this dst buffer is on NVMe, so skip this - return - - start = 0 - for param in param_list: - src = param.ds_tensor - dest = flat_buffer.narrow(0, start, src.ds_numel) - start = start + src.ds_numel - '''if the parameter was initialized in nvme then bring it to the destination buffer directly''' - if src.status == PartitionedParamStatus.NOT_AVAILABLE: - print_rank_0( - f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU" - ) - param.nvme_swapper.swap_into_buffer(param, dest) - src.data = dest.data - src.status = PartitionedParamStatus.AVAILABLE - else: - assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here" - if not avoid_copy: - dest.data.copy_(src.data) - src.data = dest.data - - # Final location must be gpu/cpu in this case - param.ds_tensor.final_location = 'not-nvme' - - def _create_param_groups_fp16_flat_cpu_memory(self): - - aggregate_params_count = 0 - - for j, param_group in enumerate(self.optimizer.param_groups): - params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']]) - - flat_buffer_size = params_in_group - - if self.params_in_nvme_and_cpu and \ - aggregate_params_count + params_in_group > self.max_params_in_cpu: - - flat_buffer_size = max(0, - self.max_params_in_cpu - aggregate_params_count) - - aggregate_params_count += params_in_group - - if flat_buffer_size > 0: - print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", - force=False) - self.param_groups_fp16_flat_cpu_memory.append( - torch.empty(int(flat_buffer_size), - dtype=self.dtype, - pin_memory=True)) - else: - print_rank_0( - f"No flat buffer size. Param group size was {params_in_group}", - force=False) - - self.param_groups_fp16_flat_cpu_memory.append( - torch.empty(1, - dtype=self.dtype)) - - def _create_fp16_partitions_with_defragmentation(self): - dist.barrier() - param_groups: List[List[Parameter]] = tuple( - self._create_fp16_sub_groups(param_group["params"]) - for param_group in self.optimizer.param_groups) - - # bookkeeping related to param groups - for param_group_idx, param_group in enumerate(param_groups): - for sub_group in param_group: - sub_group_idx = len(self.fp16_groups) - - # record sub group and partitions - self.fp16_groups.append(sub_group) - self.fp16_partitioned_groups.append( - [param.ds_tensor for param in sub_group]) - - # record sub group -> group mapping - self.sub_group_to_group_id[sub_group_idx] = param_group_idx - - # record total elements of parameter partitions in sub group - self.fp16_partitioned_groups_flat_numel.append( - sum(p.ds_tensor.ds_numel for p in sub_group)) - - # record padding required to align group to world size (only applies to last rank) - rank_requires_padding = dist.get_rank( - self.dp_process_group) == dist.get_world_size( - self.dp_process_group) - 1 - self.groups_padding.append([ - p.padding_size() if rank_requires_padding else 0 for p in sub_group - ]) - - # move parameters to flattened buffer - if not self.offload_param: # partitioned params remain in GPU during training - # move parameter partitions into a single contiguous flat buffer - parameter_partitions: List[Tensor] = [] - for sub_group in self.fp16_groups: - for param in sub_group: - parameter_partitions.append(param.ds_tensor) - device_buffer = __class__.defragment(parameter_partitions) - - # setup flat buffers per subgroup, these are each just sections of the - # contiguous flat buffer for all parameters that we created earlier - offset = 0 - for sub_group in self.fp16_groups: - sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group) - self.fp16_partitioned_groups_flat.append( - device_buffer.narrow(0, - offset, - sub_group_numel)) - offset += sub_group_numel - else: # partitioned params offloaded to CPU when not in use - # create a flat CPU memory allocation for each param group - self._create_param_groups_fp16_flat_cpu_memory() - for param_group_idx, param_group in enumerate(param_groups): - flat_offset = 0 - for i, sub_group in enumerate(param_group): - total_elements = sum(p.ds_tensor.ds_numel for p in sub_group) - print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") - #Flat buffer may not be available for parameters that reside in NVME - if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ - param_group_idx].numel(): - fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ - param_group_idx].narrow(0, - flat_offset, - total_elements) - print_rank_0( - f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", - force=False) - elif self.params_in_nvme_and_cpu: - fp16_partitioned_group_flat = None - print_rank_0( - f"No flat buffer for sub group {i} of {total_elements} elements", - force=False) - else: - assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs" - - self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) - flat_offset += total_elements - - self._move_to_flat_buffer(sub_group, - fp16_partitioned_group_flat, - avoid_copy=not self.offload_param) - - # if necessary, create a pinned memory buffer to be used for swapping out - # params to NVME after optimizer step - should_create_fp16_flat_reuse_buffer = any( - flattened_partition_group is None - for flattened_partition_group in self.fp16_partitioned_groups_flat) - if should_create_fp16_flat_reuse_buffer: - max_partition_numel, largest_partition_numel = 0, None - for sub_group in self.fp16_groups: - total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) - if total_elements > max_partition_numel: - largest_partition_numel = [t.ds_numel for t in sub_group] - max_partition_numel = total_elements - - assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' - self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( - largest_partition_numel) - - def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): - offset = 0 - elements_in_sub_group = sum( - [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) - assert (flat_buffer.numel() == elements_in_sub_group) - for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): - dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel) - if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: - print_rank_0( - f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}" - ) - param.nvme_swapper.swap_in([param], async_op=False) - dest.data.copy_(partitioned_param.data) - param.nvme_swapper.remove_partition_and_release_buffers([param]) - print_rank_0(f"Swapping in {param.ds_id} done") - else: - dest.data.copy_(partitioned_param.data) - offset += partitioned_param.ds_numel - - def _create_next_swappable_fp32_groups(self): - reverse_order_indices = [ - i for i in range(len(self.fp32_partitioned_groups_flat)) - ] - reverse_order_indices.reverse() - - next_group = None - for i in reverse_order_indices: - self.next_swappable_fp32_partitioned_groups.append(next_group) - if self._swappable_optimizer_subgroup(i): - next_group = self.fp32_partitioned_groups_flat[i] - - self.next_swappable_fp32_partitioned_groups.reverse() - - def _get_sub_group_partitions(self, sub_group_id): - sub_group_partitions = [] - for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): - if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: - swap_path = param.nvme_swapper.get_path(param, True) - sub_group_partitions.append((partitioned_param, - param.ds_tensor.ds_numel, - swap_path)) - else: - sub_group_partitions.append((partitioned_param, - partitioned_param.ds_numel, - None)) - - return sub_group_partitions - - def _create_fp32_partitions(self): - cpu_memory_usage = 0 - cpu_memory_sub_groups = 0 - nvme_memory_usage = 0 - num_swappable_partitions = 0 - num_swap_from_nvme_partitions = 0 - num_swap_from_cpu_partitions = 0 - swap_from_nvme_memory_usage = 0 - swap_from_cpu_memory_usage = 0 - GIGA_BYTES = (1024**3) - - swappable_fp32_tensors = [] - swappable_fp16_src_tensors = [] - nvme_fp16_partitions_info = [] - nvme_fp16_num_elems = [] - nvme_fp32_dest_tensors = [] - fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() - - for i, tensor in enumerate(self.fp16_partitioned_groups_flat): - num_elements = self.fp16_partitioned_groups_flat_numel[i] - - # a partition of the fp32 master weights that will be updated by this process - if self._swappable_optimizer_subgroup(i): - self.fp32_partitioned_groups_flat.append(torch.Tensor()) - nvme_memory_usage += (fp32_element_size * num_elements) - num_swappable_partitions += 1 - - if self.params_in_nvme_and_cpu and tensor is None: - num_swap_from_nvme_partitions += 1 - swap_from_nvme_memory_usage += (fp32_element_size * num_elements) - if self.offload_optimizer_fast_init: - sub_group_partitions = self._get_sub_group_partitions(i) - nvme_fp16_partitions_info.append(sub_group_partitions) - nvme_fp16_num_elems.append(num_elements) - nvme_fp32_dest_tensors.append( - self.fp32_partitioned_groups_flat[i]) - else: - unpinned_fp32_buffer = torch.empty(num_elements, - device=self.device, - dtype=torch.float) - self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) - self.optimizer_swapper.initialize_parameters( - parameters=[self.fp32_partitioned_groups_flat[i]], - src_tensors=[unpinned_fp32_buffer]) - else: - num_swap_from_cpu_partitions += 1 - swap_from_cpu_memory_usage += (fp32_element_size * num_elements) - swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i]) - swappable_fp16_src_tensors.append( - self.fp16_partitioned_groups_flat[i]) - else: - cpu_memory_usage += (fp32_element_size * num_elements) - cpu_memory_sub_groups += 1 - - if self.params_in_nvme_and_cpu and tensor is None: - unpinned_fp32_buffer = torch.empty(num_elements, - device=self.device, - dtype=torch.float) - self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) - self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) - else: - self.fp32_partitioned_groups_flat.append( - self.fp16_partitioned_groups_flat[i].to( - self.device).clone().float().detach()) - - self.fp32_partitioned_groups_flat[ - i].requires_grad = True # keep this in case internal optimizer uses it - - if len(swappable_fp32_tensors) > 0: - self.optimizer_swapper.initialize_parameters( - parameters=swappable_fp32_tensors, - src_tensors=swappable_fp16_src_tensors) - - if len(nvme_fp32_dest_tensors) > 0: - fp16_pinned_buffers = self.fp16_groups[0][ - 0].nvme_swapper.reserve_available_buffers() - assert len(fp16_pinned_buffers) > 0 - self.optimizer_swapper.initialize_from_swapped_fp16_params( - fp16_partitions_info=nvme_fp16_partitions_info, - fp16_num_elems=nvme_fp16_num_elems, - fp16_pinned_buffers=fp16_pinned_buffers, - fp32_parameters=nvme_fp32_dest_tensors) - self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers() - - nvme_gigabytes = nvme_memory_usage / GIGA_BYTES - print_rank_0( - f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB', - force=False) - if self.params_in_nvme_and_cpu: - print_rank_0( - f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB', - force=False) - print_rank_0( - f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB', - force=False) - - cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES - print_rank_0( - f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB', - force=False) - - # Clear for on-the-fly population before the optimizer step - for param_group in self.optimizer.param_groups: - param_group['params'] = [] - - def _create_fp16_sub_groups(self, params_group): - - params_group_numel = sum([param.partitioned_size() for param in params_group]) - sub_group_size = self.sub_group_size - - if sub_group_size is None or sub_group_size >= params_group_numel: - return [params_group] - - sub_groups = [] - sub_group = [] - local_sub_group_size = 0 - for param in params_group: - - sub_group.append(param) - local_sub_group_size += param.partitioned_size() - - if local_sub_group_size >= sub_group_size or id(param) == id( - params_group[-1]): - - sub_groups.append(sub_group) - - sub_group = [] - local_sub_group_size = 0 - - return sub_groups - - # def reset_ds_tensor(self): - # for name, param in self.module.named_parameters(recurse=True): - # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible" - # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now" - # param.ds_tensor.data = param.data - - def setup_zero_stage3_hooks(self): - self.hierarchy = 0 - - #reset step if in inference mode - @instrument_w_nvtx - def _end_of_forward_hook(module, *args): - - if not torch._C.is_grad_enabled(): - self.param_coordinator.reset_step() - - #likely one of them should be enough but just to be safe - self._register_hooks_recursively(self.module) - self.module.register_forward_hook(_end_of_forward_hook) - - # Add top module to stack trace - global FWD_MODULE_STACK - FWD_MODULE_STACK.append(self.module) - - def persistent_parameters(self): - persistent_params = [] - total_persistent_parameters = 0 - params_count = 0 - for _, param in self.module.named_parameters(recurse=True): - if param.ds_numel < self.persistence_threshold: - params_count += 1 - param.ds_persist = True - persistent_params.append(param) - total_persistent_parameters += param.ds_numel - - print_rank_0( - f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params", - force=False) - return persistent_params - - def _register_hooks_recursively(self, module, count=[0]): - my_count = count[0] - module.id = my_count - - #print(f"{module.__class__} : {module.id}") - - for child in module.children(): - count[0] = count[0] + 1 - self._register_hooks_recursively(child, count=count) - - @instrument_w_nvtx - def _pre_forward_module_hook(module, *args): - self.pre_sub_module_forward_function(module) - - @instrument_w_nvtx - def _post_forward_module_hook(module, input, output): - global FWD_MODULE_STACK - FWD_MODULE_STACK.pop() - if output is None: - output = [] - elif not isinstance(output, (list, tuple)): - if torch.is_tensor(output): - output = [output] - else: - #print(f'got UNKNOWN type {type(output)}') - outputs = [] - output = output if isinstance(output, dict) else vars(output) - for name, val in output.items(): - if not name.startswith('__') and torch.is_tensor(val): - outputs.append(val) - output = outputs - #print(f'convert output to {output}') - - for item in filter(lambda item: is_zero_param(item), output): - if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): - item.is_external_param = True - module_to_register = FWD_MODULE_STACK[-1] - print_rank_0( - f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', - force=False) - register_external_parameter(module_to_register, item) - - # It's possible that the parameter was already external to the completed module. If so, remove it the - # registration as it will be covered by the outer module instead. - if id(item) in module._external_params: - print_rank_0( - f' Unregistering nested dangling parameter from module {module.__class__.__name__}', - force=False) - unregister_external_parameter(module, item) - - item.all_gather() - - self.post_sub_module_forward_function(module) - - def _pre_backward_module_hook(module, inputs, output): - @instrument_w_nvtx - def _run_before_backward_function(sub_module): - # some models (e.g. Albert) may run multiple forwards on the same layer in a loop - # before doing backwards, so each backward will need a pre-fetch - using reference - # counting to support this scenario - #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") - if sub_module.applied_pre_backward_ref_cnt > 0: - self.pre_sub_module_backward_function(sub_module) - sub_module.applied_pre_backward_ref_cnt -= 1 - #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") - - return _apply_to_tensors_only(module, - PreBackwardFunction, - _run_before_backward_function, - output) - - #This is an alternate to doing _post_backward_module_hook - #it uses tensor.register_hook instead of using torch.autograd.Function - def _alternate_post_backward_module_hook(module, inputs): - module.ds_grads_remaining = 0 - - #print(f"Before Forward {module.__class__.__name__}") - - def _run_after_backward_hook(*unused): - module.ds_grads_remaining = module.ds_grads_remaining - 1 - if module.ds_grads_remaining == 0: - #print(f"After backward {module.__class__.__name__}") - self.post_sub_module_backward_function(module) - - def _run_before_forward_function(input): - if input.requires_grad: - module.ds_grads_remaining += 1 - - return _apply_forward_and_backward_to_tensors_only( - module, - _run_before_forward_function, - _run_after_backward_hook, - inputs) - - def _post_backward_module_hook(module, inputs): - module.ds_grads_remaining = 0 - - @instrument_w_nvtx - def _run_after_backward_function(sub_module): - if sub_module.ds_grads_remaining == 0: - self.post_sub_module_backward_function(sub_module) - - return _apply_to_tensors_only(module, - PostBackwardFunction, - _run_after_backward_function, - inputs) - - # Pre forward hook - module.register_forward_pre_hook(_pre_forward_module_hook) - # Post forward hook - module.register_forward_hook(_post_forward_module_hook) - - # Pre backward hook - module.register_forward_hook(_pre_backward_module_hook) - - # post backward hook - module.register_forward_pre_hook(_post_backward_module_hook) - - @torch.no_grad() - def pre_sub_module_forward_function(self, sub_module): - see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", - force=False) - - global FWD_MODULE_STACK - FWD_MODULE_STACK.append(sub_module) - - if not self.param_coordinator.trace_complete: - self.param_coordinator.record_trace(sub_module) - - self.param_coordinator.fetch_sub_module(sub_module) - see_memory_usage( - f"Before sub module function {sub_module.__class__.__name__} after fetch", - force=False) - - @torch.no_grad() - def post_sub_module_forward_function(self, sub_module): - see_memory_usage( - f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", - force=False) - - self.param_coordinator.release_sub_module(sub_module) - - see_memory_usage( - f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", - force=False) - - @torch.no_grad() - def pre_sub_module_backward_function(self, sub_module): - if not self.param_coordinator.trace_complete: - self.param_coordinator.record_trace(sub_module) - self.param_coordinator.fetch_sub_module(sub_module) - - @torch.no_grad() - def post_sub_module_backward_function(self, sub_module): - see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", - force=False) - self.param_coordinator.release_sub_module(sub_module) - see_memory_usage( - f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", - force=False) - - def _release_ipg_buffers(self): - if self.contiguous_gradients: - self.ipg_buffer = None - if not self.offload_optimizer and self.is_gradient_accumulation_boundary: - self.grads_in_partition = None - - self.grads_in_partition_offset = 0 - - def _optimizer_step(self, sub_group_id): - param_group_id = self.sub_group_to_group_id[sub_group_id] - fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] - - self.optimizer.step() - self.optimizer.param_groups[param_group_id]['params'] = [] - - def _swappable_optimizer_subgroup(self, sub_group_id): - if not self.swap_optimizer: - return False - - return self.optimizer_swapper.swappable_tensor( - None, - numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) - - def _partitioned_params_swap_out(self, i): - offset = 0 - fp32_param = self.fp32_partitioned_groups_flat[i] - assert fp32_param is not None, \ - f'fp32 parameters of sub_group {i} is None' - - swap_fp16_params = [] - swap_fp32_params = [] - for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]): - src = fp32_param.narrow(0, offset, partitioned_param.ds_numel) - if partitioned_param.status == PartitionedParamStatus.AVAILABLE: - partitioned_param.data.copy_(src.data) - else: - swap_fp32_params.append(src) - swap_fp16_params.append(param) - offset += partitioned_param.ds_numel - - if len(swap_fp16_params): - swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params( - dst_fp16_params=swap_fp16_params, - src_fp32_params=swap_fp32_params) - - def initialize_optimizer_states(self): - num_subgroups = len(self.fp16_groups) - - largest_numel = max( - [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups]) - gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype - gradient_buffer = torch.zeros(int(largest_numel), - dtype=gradient_dtype, - device=self.device) - - timers = self.timers - timer_names = set() - - if self.swap_optimizer: - self.optimizer_swapper.init_timers() - - INIT_OPTIMIZER_TIMER = 'init_optimizer_state' - timer_names.add(INIT_OPTIMIZER_TIMER) - self.start_timers([INIT_OPTIMIZER_TIMER]) - - for i, group in enumerate(self.fp16_groups): - swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) - swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None - - num_elements = int(self.fp16_partitioned_groups_flat_numel[i]) - - see_memory_usage( - f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', - force=False) - - if swappable_optimizer_subgroup: - self._optimizer_states_and_gradient_swap_in(i, timer_names) - - if self.offload_optimizer and not swappable_optimizer_subgroup: - subgroup_gradient_buffer = torch.zeros(num_elements, - dtype=gradient_dtype, - device=self.device) - if self.offload_optimizer_pin_memory: - subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory() - - self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer - else: - self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow( - 0, - 0, - num_elements) - - self._optimizer_step(i) - - if swappable_param_subgroup: - self._partitioned_params_swap_out(i) - - if swappable_optimizer_subgroup: - self._optimizer_states_and_gradient_swap_out(i, timer_names) - - see_memory_usage( - f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', - force=False) - - self.stop_timers([INIT_OPTIMIZER_TIMER]) - self.log_timers(timer_names) - - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - - if not self.offload_optimizer: - for group in self.fp32_partitioned_groups_flat: - group.grad = None - - # Reset steps - return - - ######################################################################### - #########################ZeRO Partition Gradients######################## - ######################################################################### - - def get_first_param_index(self, group_id, param_group, partition_id): - for index, param in enumerate(param_group): - param_id = self.get_param_id(param) - if partition_id in self.param_to_partition_ids[group_id][param_id]: - return index - return None - - def initialize_gradient_partitioning_data_structures(self): - - total_partitions = dist.get_world_size(group=self.dp_process_group) - - for i, param_group in enumerate(self.fp16_groups): - - self.param_to_partition_ids[i] = {} - self.is_partition_reduced[i] = {} - self.total_grads_in_partition[i] = {} - self.remaining_grads_in_partition[i] = {} - self.is_grad_computed[i] = {} - self.grad_partition_insertion_offset[i] = {} - self.grad_start_offset[i] = {} - self.first_param_index_in_partition[i] = {} - - for partition_id in range(total_partitions): - self.is_grad_computed[i][partition_id] = {} - self.grad_partition_insertion_offset[i][partition_id] = {} - self.grad_start_offset[i][partition_id] = {} - self.initialize_gradient_partition(i, param_group, partition_id) - self.is_partition_reduced[i][partition_id] = False - self.first_param_index_in_partition[i][ - partition_id] = self.get_first_param_index( - i, - param_group, - partition_id) - - @instrument_w_nvtx - def independent_gradient_partition_epilogue(self): - self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) - self.__reduce_and_partition_ipg_grads() - self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) - - self.__reduce_and_partition_stream.synchronize() - - # if dist.get_rank() == 0: - # logger.info("Params already reduced %s", self.params_already_reduced) - for i in range(len(self.params_already_reduced)): - self.params_already_reduced[i] = False - - #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad - #TODO: use a similar code path for both cpu_offload and non-cpu offload - if not self.offload_optimizer: - for i, sub_group in enumerate(self.fp16_groups): - self.averaged_gradients[i] = [ - self.__param_id_to_grad_partition[param.ds_id] - if param.requires_grad else torch.zeros_like(param.ds_tensor) - for param in sub_group - ] - # self.averaged_gradients[i] = self.get_flat_partition( - # self.fp16_groups[i], - # 0, - # self.fp32_partitioned_groups_flat[i].numel(), - # return_tensor_list=True) - - # this method gets called after every backward. need to increment - # here because if it gets incremented in backward() the micro step - # id will be off by one when we do the reduce and partition at the. - # start of this method. - # TODO. make this less error prone - self.micro_step_id += 1 - - def overlapping_partition_gradients_reduce_epilogue(self): - self.independent_gradient_partition_epilogue() - - def create_reduce_and_remove_grad_hooks(self): - print_rank_0(f'[Begin] Create gradient reduction hooks') - self.grad_accs = [] - for i, param_group in enumerate(self.fp16_groups): - for param in param_group: - if param.requires_grad: - #print_rank_0(f" Before all gather {param.device}, {param.shape}") - - # The hook must be created in un-partitioned parameter - param.all_gather() - - #print(f"After all gather {param.device}, {param.shape}") - def wrapper(param, i): - param_tmp = param.expand_as(param) - grad_acc = param_tmp.grad_fn.next_functions[0][0] - - @instrument_w_nvtx - def reduce_partition_and_remove_grads(*notneeded): - self.reduce_ready_partitions_and_remove_grads(param, i) - - grad_acc.register_hook(reduce_partition_and_remove_grads) - self.grad_accs.append(grad_acc) - - #print(f"param grad fn {param.expand_as(param).grad_fn}") - wrapper(param, i) - - # Partition the parameter after creating the hook - param.partition() - print_rank_0(f'[End] Create gradient reduction hooks') - - def get_param_id(self, param): - unique_id = id(param) - return self.param_id[unique_id] - - def report_ipg_memory_usage(self, tag, param_elems): - elem_count = self.elements_in_ipg_bucket + param_elems - percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size - see_memory_usage( - f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", - force=False) - - ###############Idependent Partition Gradient ######################## - def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): - #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) - - # Because the ipg bucket is initialized with a random place holder tensor, we must - # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket > - # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a - # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be - # empty, while reduction_list will have that garbage data. - if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: - self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", - param.ds_numel) - - self.__reduce_and_partition_ipg_grads() - - param_id = self.get_param_id(param) - assert self.params_already_reduced[param_id] == False, \ - f"The parameter {param_id} has already been reduced. \ - Gradient computed twice for this partition. \ - Multiple gradient reduction is currently not supported" - - self.__add_grad_to_ipg_bucket(param) - - @instrument_w_nvtx - @torch.no_grad() - def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: - self.__reduce_and_partition_stream.wait_stream(torch.cuda.default_stream()) - - if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel( - ) < self.reduce_bucket_size: - # move the gradient to a contiguous buffer - with torch.cuda.stream(self.__reduce_and_partition_stream): - # move the parameter's gradient to the contiguous flat buffer - new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow( - 0, - self.elements_in_ipg_bucket, - param.grad.numel()).view_as(param.grad) - new_grad_tensor.copy_(param.grad, non_blocking=True) - param.grad.record_stream(torch.cuda.current_stream()) - param.grad.data = new_grad_tensor - - self.__params_in_ipg_bucket.append(param) - - @instrument_w_nvtx - @torch.no_grad() - def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: - if not self.__params_in_ipg_bucket: - return - - for param in self.__params_in_ipg_bucket: - if param.grad.numel() != param.ds_numel: - raise RuntimeError( - f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " - f"gradients whose size is not same as the params") - - self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id) - - assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len( - self.__params_in_ipg_bucket) - - while self.__param_reduce_events and self.__param_reduce_events[0].query(): - self.__param_reduce_events.popleft() - if len(self.__param_reduce_events) > self.__max_param_reduce_events: - self.__param_reduce_events.popleft().synchronize() - - with torch.cuda.stream(self.__reduce_and_partition_stream): - if safe_mode: - assert_ints_same_as_other_ranks( - [p.ds_id for p in self.__params_in_ipg_bucket]) - - grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) - self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) - - self.__params_in_ipg_bucket.clear() - - event = Event() - event.record() - self.__param_reduce_events.append(event) - - @instrument_w_nvtx - def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: - """average gradients and scatter partitions across ranks""" - dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce) - - full_grads_for_rank = [p.grad for p in params_to_reduce] - if self.allreduce_always_fp32: - full_grads_for_rank = [g.float() for g in full_grads_for_rank] - - if self.postscale_gradients and self.gradient_predivide_factor != 1.0: - full_grads_for_rank = [ - g.div(self.gradient_predivide_factor) for g in full_grads_for_rank - ] - - grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, - self.dp_process_group) - - if self.postscale_gradients and self.gradient_predivide_factor != dist.get_world_size( - self.dp_process_group): - grad_partitions_for_rank = [ - g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank - ] - - if self.allreduce_always_fp32: - grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank] - - return grad_partitions_for_rank - - def set_grad_positions(self): - for i, group in enumerate(self.fp16_groups): - current_offset = 0 - for param in group: - param_id = self.get_param_id(param) - num_elements = param.ds_tensor.ds_numel - - self.grad_position[param_id] = [ - int(i), - int(current_offset), - int(num_elements) - ] - #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") - current_offset += num_elements - - def _constant_buffered_norm2(self, input, buffer_size=250000000): - norm = None - for part in input.view(-1).split(buffer_size): - if norm is None: - norm = part.data.double().norm(2)**2.0 - else: - norm += part.data.double().norm(2)**2.0 - return norm**0.5 - - def set_norm_for_param_grad_in_gpu(self, param): - param_id = self.get_param_id(param) - #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2) - #Using a more memory efficient version - self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) - - def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): - with torch.cuda.stream(self.copy_grad_stream): - param_id = self.get_param_id(param) - src_tensor = param.grad.view(-1).float() - #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") - fp32_grad_tensor.copy_(src_tensor, non_blocking=True) - param.grad = None - - def complete_grad_norm_calculation_for_cpu_offload(self, params): - total_norm = 0.0 - norm_type = 2.0 - for p in params: - if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - param_id = self.get_param_id(p) - if param_id in self.norm_for_param_grads.keys(): - param_norm = self.norm_for_param_grads[param_id] - total_norm += param_norm.item()**2 - - # Sum across all model parallel GPUs. - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=self.dp_process_group) - - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) - - total_norm = total_norm_cuda[0].item()**(1. / norm_type) - - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 - - return total_norm - - @instrument_w_nvtx - def __partition_grads(self, - params_to_release: List[Parameter], - grad_partitions: List[Tensor]) -> None: - for param, grad_partition in zip(params_to_release, grad_partitions): - if param.ds_tensor.ds_numel * dist.get_rank( - self.dp_process_group) > param.ds_numel: - # this grad partition is empty - don't need to do anything - continue - - # move or accumulate gradient partition to target buffer - grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow( - 0, - 0, - grad_partition.numel()) - if self.micro_step_id == 0: # don't accumulate - grad_buffer.copy_(grad_partition, non_blocking=True) - # ensure grad buffer is a CUDA buffer to speed up the next few - # operations and so it can be used asynchronously - grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) - elif grad_buffer.is_cuda: - grad_buffer.add_(grad_partition) - else: - # if dst is CPU, copy first to src device, do the addition - # there, then move back to dst. adding directly to cpu is very slow - cuda_grad_buffer = grad_buffer.to(grad_partition.device, - non_blocking=True) - cuda_grad_buffer.add_(grad_partition) - grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) - # ensure grad buffer is a CUDA buffer to speed up the next few - # operations and so it can be used asynchronously - grad_buffer = cuda_grad_buffer - - if hasattr(self.__inf_or_nan_tracker, "logical_or_"): - self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) - self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) - else: - # logical_or_ not available in older versions of pytorch - self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any() - self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any() - self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0 - - # offload the gradient partition if applicable - if self.offload_optimizer: - i, dest_offset, _ = self.grad_position[self.get_param_id(param)] - offload_fp32_gradients = {} - offload_fp32_offsets = {} - - if self.is_gradient_accumulation_boundary: - self.norm_for_param_grads[self.get_param_id( - param)] = self._constant_buffered_norm2(grad_buffer) - - if self._swappable_optimizer_subgroup(i): - if not i in offload_fp32_gradients.keys(): - offload_fp32_gradients[i] = [] - offload_fp32_offsets[i] = [] - - offload_fp32_gradients[i].append(grad_buffer.float()) - offload_fp32_offsets[i].append(dest_offset) - else: - fp32_grad_tensor = self.fp32_partitioned_groups_flat[ - i].grad.narrow(0, - dest_offset, - grad_buffer.numel()) - fp32_grad_tensor.copy_(grad_buffer) - - # free the gradient - param.grad.record_stream(torch.cuda.current_stream()) - param.grad = None - - if self.offload_optimizer and self.swap_optimizer: - for i in offload_fp32_gradients.keys(): - self.optimizer_swapper.swap_out_gradients( - parameter=self.fp32_partitioned_groups_flat[i], - gradient_offsets=offload_fp32_offsets[i], - gradient_tensors=offload_fp32_gradients[i]) - - def reduce_ready_partitions_and_remove_grads(self, param, i): - #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) - self.reduce_independent_p_g_buckets_and_remove_grads(param, i) - - def zero_reduced_gradients(self, partition_id, i): - def are_all_related_partitions_reduced(params_id): - for partition_id in self.param_to_partition_ids[i][params_id]: - if not self.is_partition_reduced[i][partition_id]: - return False - return True - - for params_id in self.is_grad_computed[i][partition_id]: - if are_all_related_partitions_reduced(params_id): - self.param_dict[params_id].grad = None - - def flatten_and_print(self, message, tensors, start=0, n=5): - flatten_tensor = self.flatten(tensors) - - def print_func(): - logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) - - self.sequential_execution(print_func, message) - - def get_grads_to_reduce(self, i, partition_id): - def get_reducible_portion(key): - grad = self.param_dict[key].grad - total_elements = grad.numel() - start = self.grad_start_offset[i][partition_id][key] - num_elements = min( - total_elements - start, - self.partition_size[i] - - self.grad_partition_insertion_offset[i][partition_id][key]) - if not pg_correctness_test: - if num_elements == total_elements: - return grad - else: - return grad.contiguous().view(-1).narrow(0, - int(start), - int(num_elements)) - else: - if num_elements == total_elements: - return grad.clone() - else: - return grad.clone().contiguous().view(-1).narrow( - 0, - int(start), - int(num_elements)) - - grads_to_reduce = [] - for key in self.is_grad_computed[i][partition_id]: - grad = get_reducible_portion(key) - grads_to_reduce.append(grad) - return grads_to_reduce - - def sequential_execution(self, function, message, group=None): - if group is None: - group = self.dp_process_group - if dist.get_rank(group=group) == 0: - logger.info(message) - for id in range(dist.get_world_size(group=group)): - if id == dist.get_rank(group=group): - function() - dist.barrier(group=group) - - def set_none_gradients_to_zero(self, i, partition_id): - for param_id in self.is_grad_computed[i][partition_id]: - param = self.param_dict[param_id] - if param.grad is None: - param.grad = torch.zero_like(param) - - ######################Reduction Related Methods############################## - - def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None): - rank = None - tensor = self.flatten(bucket) - - tensor_to_allreduce = tensor - - if pg_correctness_test: - allreduce_always_fp32 = True - - if allreduce_always_fp32: - tensor_to_allreduce = tensor.float() - - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) - - if rank is None: - # "All Reducing" - dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) - else: - global_rank = _get_global_rank(self.dp_process_group, rank) - dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) - - if allreduce_always_fp32 and tensor is not tensor_to_allreduce: - if rank is None or rank == dist.get_rank(group=self.dp_process_group): - tensor.copy_(tensor_to_allreduce) - - return tensor - - # if rank is specified do a reduction instead of an allreduce - def allreduce_and_copy(self, small_bucket, rank=None, log=None): - with torch.cuda.stream(self.reduction_stream): - allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) - if rank is None or rank == dist.get_rank(group=self.dp_process_group): - for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): - buf.copy_(synced) - - def allreduce_no_retain(self, - bucket, - numel_per_bucket=500000000, - rank=None, - log=None): - small_bucket = [] - numel = 0 - for tensor in bucket: - small_bucket.append(tensor) - numel = numel + tensor.numel() - if numel > numel_per_bucket: - self.allreduce_and_copy(small_bucket, rank=rank, log=None) - small_bucket = [] - if len(small_bucket) > 0: - self.allreduce_and_copy(small_bucket, rank=rank, log=log) - - ############################################################################# - ############################################################################# - ############################################################################# - - # views the tensor as multiple partitions and returns - # those partitions - def get_data_parallel_partitions(self, tensor): - partitions = [] - - dp = dist.get_world_size(group=self.dp_process_group) - dp_id = dist.get_rank(group=self.dp_process_group) - - total_num_elements = tensor.numel() - - base_size = total_num_elements // dp - remaining = total_num_elements % dp - - start = 0 - for id in range(dp): - partition_size = base_size - if id < remaining: - partition_size = partition_size + 1 - partitions.append(tensor.narrow(0, start, partition_size)) - start = start + partition_size - return partitions - - def get_partition_info(self, tensor_list, partition_size, partition_id): - params_in_partition = [] - params_not_in_partition = [] - - start_index = partition_size * partition_id - end_index = partition_size * (partition_id + 1) - - current_index = 0 - first_offset = 0 - - for tensor in tensor_list: - - tensor_size = tensor.numel() - - if (current_index >= start_index and current_index < end_index): - params_in_partition.append(tensor) - - elif start_index > current_index and start_index < (current_index + - tensor_size): - params_in_partition.append(tensor) - - assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" - first_offset = start_index - current_index - - else: - params_not_in_partition.append(tensor) - - current_index = current_index + tensor_size - - return params_in_partition, params_not_in_partition, first_offset - - @instrument_w_nvtx - def zero_grad(self, set_grads_to_None=True): - """ - Zero FP16 parameter grads. - """ - self.micro_step_id = 0 - - # FP32 grad should never exist. - # For speed, set model fp16 grad to None by default - for group in self.fp16_groups: - for p in group: - if set_grads_to_None: - if p.grad is not None and p.grad.is_cuda: - p.grad.record_stream(torch.cuda.current_stream()) - p.grad = None - else: - if p.grad is not None: - p.grad.detach_() - p.grad.zero_() - - def _model_parallel_all_reduce(self, tensor, op): - """ Perform all reduce within model parallel group, if any. - """ - if self.model_parallel_group is None: - pass - else: - torch.distributed.all_reduce(tensor=tensor, - op=op, - group=self.model_parallel_group) - - @instrument_w_nvtx - def get_grad_norm_direct(self, gradients, params, norm_type=2): - """Clips gradient norm of an iterable of parameters. - - This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and - added functionality to handle model parallel parameters. Note that - the gradients are modified in place. - - Arguments: - parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a - single Tensor that will have gradients normalized - max_norm (float or int): max norm of the gradients - norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for - infinity norm. - - Returns: - Total norm of the parameters (viewed as a single vector). - """ - norm_type = float(norm_type) - if norm_type == inf: - total_norm = max(g.data.abs().max() for g in gradients) - total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.MAX, - group=self.dp_process_group) - - # Take max across all GPUs. - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.MAX) - total_norm = total_norm_cuda[0].item() - else: - # if dist.get_rank() == 0: - # logger.info(f"Total Norm beginning {total_norm}") - grad_norms = [] - for g, p in zip(gradients, params): - if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): - grad_norms.append(g.cuda(non_blocking=True).double().norm(2)) - - # Sum across all model parallel GPUs. - total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) - - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=self.dp_process_group) - - self._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) - - total_norm = total_norm_cuda.item()**(1. / norm_type) - - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: - total_norm = -1 - - return total_norm - - # creates a flat fused tensor from the tensor list starting at the first_offset - # in the first tensor of the list. If there are not enough elements in the tensor - # list then the flat tensor will be padded with zeros - def get_flat_partition(self, - tensor_list, - first_offset, - partition_size, - return_tensor_list=False): - flat_tensor_list = [] - current_size = 0 - for i, tensor in enumerate(tensor_list): - if tensor.grad is None: - tensor.grad = torch.zeros_like(tensor) - - tensor = tensor.grad - num_elements = tensor.numel() - tensor_offset = 0 - - # we need to offset to get to the right element - if i == 0 and first_offset > 0: - tensor_offset = first_offset - num_elements = num_elements - tensor_offset - - # we dont need all elements of the tensor - if num_elements > (partition_size - current_size): - num_elements = partition_size - current_size - - # we need a narrow view of the tensor based on the tensor offset and number of elements that - # we need from this tensor - if tensor_offset > 0 or num_elements < tensor.numel(): - flat_tensor_list.append(tensor.contiguous().view(-1).narrow( - 0, - int(tensor_offset), - int(num_elements))) - else: - flat_tensor_list.append(tensor) - - current_size = current_size + num_elements - - # this means its the last partition and does not align with the dp boundary. We need to pad before flattening - if current_size < partition_size: - flat_tensor_list.append( - torch.zeros(int(partition_size - current_size), - dtype=tensor_list[0].dtype, - device=tensor_list[0].device)) - - if return_tensor_list: - return flat_tensor_list - - return self.flatten(flat_tensor_list) - - def free_grad_in_param_list(self, param_list): - for p in param_list: - p.grad = None - - def reset_cpu_buffers(self): - self.norm_for_param_grads = {} - self.local_overflow = False - - def log_timers(self, timer_names): - if self.timers is None: - return - - self.timers.log(names=list(timer_names)) - - def start_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).start() - - def stop_timers(self, timer_names): - if self.timers is None: - return - - for name in timer_names: - self.timers(name).stop() - - def _pre_step(self): - self.micro_step_id = 0 - - print_rank_0(f"Inside Step function") - see_memory_usage(f"In step before checking overflow", force=False) - - print_rank_0("Finished Tracing at Beginning of Step") - self.param_coordinator.hierarchy = 0 - - print_rank_0("Finished Tracing at Beginning of Step") - - @instrument_w_nvtx - def _get_norm_groups(self): - norm_groups = [] - for i, group in enumerate(self.fp16_groups): - if self.offload_optimizer: - norm_groups.append( - self.complete_grad_norm_calculation_for_cpu_offload( - self.fp16_groups[i])) - else: - norm_groups.append( - self.get_grad_norm_direct(self.averaged_gradients[i], - self.fp16_groups[i])) - return norm_groups - - @instrument_w_nvtx - def _prepare_fp32_grad_for_sub_group(self, sub_group_id): - partition_id = dist.get_rank(group=self.dp_process_group) - - single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( - self.fp32_partitioned_groups_flat[sub_group_id].dtype) - - assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ - "averaged gradients have different number of elements that partition size {} {} {} {}".format( - single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) - - self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition - - # release all the gradient since we have already created a necessary copy in dp_grad_partition - self.zero_grad() - - for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): - grad.record_stream(torch.cuda.current_stream()) - - self.averaged_gradients[sub_group_id] = None - - @instrument_w_nvtx - def _prepare_sub_group(self, sub_group_id, timer_names=set()): - see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', - force=False) - if self._swappable_optimizer_subgroup(sub_group_id): - self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) - elif not self.offload_optimizer: - self._prepare_fp32_grad_for_sub_group(sub_group_id) - see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', - force=False) - - def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()): - param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) - assert self._swappable_optimizer_subgroup(sub_group_id), \ - f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - - OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state' - see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', - force=False) - self.start_timers([OPTIMIZER_SWAP_IN_STATE]) - - self.optimizer_swapper.swap_in_optimizer_state( - parameter=self.fp32_partitioned_groups_flat[sub_group_id], - async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) - - self.stop_timers([OPTIMIZER_SWAP_IN_STATE]) - timer_names.add(OPTIMIZER_SWAP_IN_STATE) - see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', - force=False) - - @instrument_w_nvtx - def _release_sub_group(self, sub_group_id, timer_names=set()): - see_memory_usage(f'Before release optimizer sub group {sub_group_id}', - force=False) - # get rid of the fp32 gradients. Not needed anymore - if not self.offload_optimizer: - self.fp32_partitioned_groups_flat[sub_group_id].grad = None - - if self._swappable_optimizer_subgroup(sub_group_id): - self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names) - see_memory_usage(f'After release optimizer sub group {sub_group_id}', - force=False) - - # create a flat tensor aligned at the alignment boundary - @instrument_w_nvtx - def flatten_dense_tensors_aligned(self, tensor_list, alignment): - num_elements = 0 - for tens in tensor_list: - num_elements = num_elements + tens.numel() - - remaining = num_elements % alignment - - if remaining: - elements_to_add = alignment - remaining - pad_tensor = torch.zeros(elements_to_add, - device=tensor_list[0].device, - dtype=tensor_list[0].dtype) - padded_tensor_list = tensor_list + [pad_tensor] - - num_elements = num_elements + elements_to_add - else: - padded_tensor_list = tensor_list - - return self.flatten(padded_tensor_list) - - def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()): - param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] - fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) - assert self._swappable_optimizer_subgroup(sub_group_id), \ - f'Parameter {fp32_param_id} of numel={param_length} is not swappable' - - OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state' - see_memory_usage( - f'post-step Before swapping out optimizer tensors {sub_group_id}', - force=False) - self.start_timers([OPTIMIZER_SWAP_OUT_STATE]) - - self.optimizer_swapper.swap_out_optimizer_state( - parameter=self.fp32_partitioned_groups_flat[sub_group_id], - async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is - not None) - - self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) - see_memory_usage( - f'post-step After swapping out optimizer tensors {sub_group_id}', - force=False) - timer_names.add(OPTIMIZER_SWAP_OUT_STATE) - - # get rid of the fp32 gradients. Not needed anymore - self.fp32_partitioned_groups_flat[sub_group_id].grad = None - - def _unflatten_partitioned_parameters(self, sub_group_id): - updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], - self.fp16_partitioned_groups[sub_group_id]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): - partitioned_param.data = q.data - - def _overflow_clean_up(self, prev_scale): - see_memory_usage('After overflow before clearing gradients', force=False) - self.zero_grad() - - if self.offload_optimizer: - self.reset_cpu_buffers() - else: - self.averaged_gradients = {} - - see_memory_usage('After overflow after clearing gradients', force=False) - - if torch.distributed.get_rank() == 0: - logger.info( - "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, " - "reducing to {}".format(dist.get_rank(), - prev_scale, - self.loss_scale)) - - @instrument_w_nvtx - def _overflow_check_and_loss_scale_update(self): - - # First compute norm for all group so we know if there is overflow - self.check_overflow() - - #loss scaling related computation - prev_scale = self.loss_scale - self._update_scale(self.overflow) - - if self.overflow: - self._overflow_clean_up(prev_scale) - - return self.overflow - - @instrument_w_nvtx - def _post_step(self, timer_names=set()): - if self.offload_optimizer: - self.reset_cpu_buffers() - - #Gathering persisting parameters - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - - self.log_timers(timer_names) - - see_memory_usage('After zero_optimizer step', force=False) - print_rank_0(f"------------------Finishing Step-----------------------") - - @instrument_w_nvtx - def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): - if self.fp16_partitioned_groups_flat[sub_group_id] is not None: - self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( - self.fp32_partitioned_groups_flat[sub_group_id].data) - - #unflatten fp16 parameter subgroup - self._unflatten_partitioned_parameters(sub_group_id) - else: - self._partitioned_params_swap_out(sub_group_id) - - @instrument_w_nvtx - def step(self, closure=None): - """ - Not supporting closure. - """ - self._pre_step() - self._partition_all_parameters() - - #checks for overflow, adjust the loss scale accordingly - if self._overflow_check_and_loss_scale_update(): - if self.swap_optimizer: - self.optimizer_swapper.log_timers() - return - - norm_groups = self._get_norm_groups() - self._global_grad_norm = get_global_norm(norm_list=norm_groups) - - timer_names = set() - - timer_names.add('optimizer_step') - self.start_timers(['optimizer_step']) - - #update parameters one sub group at a time - for sub_group_id, group in enumerate(self.fp16_groups): - - #prepare optimizer states, gradients and fp32 parameters for update - self._prepare_sub_group(sub_group_id, timer_names) - - #scale the fp32 gradients - self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) - - #apply the optimizer step on the sub group and copy fp32 parameters to fp16 - self._optimizer_step(sub_group_id) - - #put fp16 parameters in appropriate location - self._reassign_or_swap_out_partitioned_parameters(sub_group_id) - - #release memory or swap out optimizer states of fp32 parameters - self._release_sub_group(sub_group_id, timer_names) - - self.stop_timers(['optimizer_step']) - - self._post_step(timer_names) - - # warn user about caching allocator flushes - alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] if hasattr( - torch.cuda, - "memory_stats") else 0 - if alloc_retries > self.__n_caching_allocator_flushes: - if dist.get_rank() == 0: - logger.warning( - "%d pytorch allocator cache flushes since last step. this happens " - "when there is high memory pressure and is detrimental to " - "performance. if this is happening frequently consider adjusting " - "settings to reduce memory consumption. If you are unable to " - "make the cache flushes go away consider adding " - "torch.cuda.empty_cache() calls in your training loop to ensure " - "that all ranks flush their caches at the same time", - alloc_retries - self.__n_caching_allocator_flushes) - self.__n_caching_allocator_flushes = alloc_retries - - def dump_pre_step_gradients(self, debug_fp32_grads): - # Dump gradient norms for debugging - for i, _ in enumerate(self.fp16_groups): - print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC') - for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]): - param_id = self.get_param_id(fp16_param) - fp16_grad_norm = self.debug_fp16_grads[i][param_id] - - fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad] - norm_list = [fp16_grad_norm, fp32_grad_norm] - print(f'Pre-Step Norms {i} {param_id} = {norm_list}') - - def dump_post_step_gradients(self): - # Dump gradient norms for debugging - for i, group in enumerate(self.fp16_groups): - print( - f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT') - unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i]) - unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i], - self.fp16_groups[i]) - for j, p in enumerate(self.fp16_groups[i]): - param_id = self.get_param_id(p) - param_norm = float(p.data.float().norm(2)) - ds_norm = float(p.ds_tensor.data.float().norm(2)) - - unflat_norm = [ - float(t.data.float().norm(2)) - for t in [unflat_fp16[j], - unflat_fp32[j]] - ] - norm_list = [param_norm, ds_norm] + unflat_norm - print(f'Post-Step Norms {i} {param_id} = {norm_list}') - - @instrument_w_nvtx - def unscale_and_clip_grads(self, sub_group_id, total_norm): - grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] - - # compute combined scale factor for this group - combined_scale = self.loss_scale - if self.clip_grad > 0.: - # norm is in fact norm*scale - clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad - if clip > 1: - combined_scale = clip * self.loss_scale - # to maintain behavior of averaging over accumulation steps - combined_scale *= self.micro_step_id + 1 - - for grad in grad_groups_flat: - if isinstance(grad, list): - sub_partitions = grad - for g in sub_partitions: - g.data.mul_(1. / combined_scale) - else: - grad.data.mul_(1. / combined_scale) - - def _check_overflow(self, partition_gradients=True): - self.overflow = self.has_overflow(partition_gradients) - - # `params` is a list / generator of torch.Variable - def has_overflow_serial(self, params, is_grad_list=False): - for p in params: - if p.grad is not None and self._has_inf_or_nan(p.grad.data): - return True - - return False - - def has_overflow_partitioned_grads_serial(self): - for i in range(len(self.fp16_groups)): - for j, grad in enumerate(self.averaged_gradients[i]): - if grad is not None and self._has_inf_or_nan(grad.data, j): - return True - return False - - @instrument_w_nvtx - def has_overflow(self, partition_gradients=True): - if partition_gradients: - with torch.cuda.stream(self.__reduce_and_partition_stream): - self.local_overflow = bool(self.__inf_or_nan_tracker.item()) - self.__inf_or_nan_tracker.zero_() - - overflow = self.local_overflow - #overflow = self.has_overflow_partitioned_grads_serial() - overflow_gpu = torch.cuda.ByteTensor([overflow]) - torch.distributed.all_reduce(overflow_gpu, - op=torch.distributed.ReduceOp.MAX, - group=self.dp_process_group) - - else: - params = [] - for group in self.fp16_groups: - for param in group: - params.append(param) - - overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) - overflow_gpu = torch.cuda.ByteTensor([overflow]) - - # Since each model parallel GPU carries only part of the model, - # make sure overflow flag is synced across all the model parallel GPUs - self._model_parallel_all_reduce(tensor=overflow_gpu, - op=torch.distributed.ReduceOp.MAX) - - overflow = overflow_gpu[0].item() - return bool(overflow) - - # `x` is a torch.Tensor - @staticmethod - def _has_inf_or_nan(x, j=None): - try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x - # (which is true for some recent version of pytorch). - cpu_sum = float(x.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # cpu_sum = float(x.sum()) - except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. - if "value cannot be converted" not in instance.args[0]: - raise - return True - else: - if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: - return True - return False - - @instrument_w_nvtx - def backward(self, loss, retain_graph=False): - """ - :attr:`backward` performs the following steps: - - 1. fp32_loss = loss.float() - 2. scaled_loss = fp32_loss*loss_scale - 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves - """ - if self.swap_optimizer: - self.optimizer_swapper.pre_backward() - - see_memory_usage(f"Before backward", force=False) - - self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) - - self.param_coordinator.reset_step() - - if self.swap_optimizer: - self.optimizer_swapper.post_backward() - - def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: - """get fp32 gradient partition dictionary - accessed as grad_dict[parameter_group_index][parameter_index] - """ - self.__reduce_and_partition_stream.synchronize() - grad_dict = collections.defaultdict(dict) - if self.offload_optimizer: - for group in self.fp16_groups: - for param_idx, param in enumerate(group): - group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] - fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow( - 0, - dest_offset, - num_elements) - grad_dict[group_idx][param_idx] = fp32_grad - else: - for group_idx, group in self.averaged_gradients.items(): - for param_idx, gradient in enumerate(group): - grad_dict[group_idx][param_idx] = gradient.float() - - return grad_dict - - @instrument_w_nvtx - def _partition_all_parameters(self): - """Partitioning Parameters that were not partitioned usually if parameters - of modules whose input parameters do not require grad computation do not - trigger post call and will therefore will remain unpartitioned""" - self.param_coordinator.release_and_reset_all() - for param in iter_params(self.module, recurse=True): - if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: - raise RuntimeError(f"{param.ds_summary()} expected to be released") - - def check_overflow(self, partition_gradients=True): - self._check_overflow(partition_gradients) - - def _update_scale(self, has_overflow=False): - self.loss_scaler.update_scale(has_overflow) - - # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" - def _get_state(self): - return self.optimizer.state - - def _set_state(self, value): - self.optimizer.state = value - - state = property(_get_state, _set_state) - - # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" - # (for example, to adjust the learning rate) - def _get_param_groups(self): - return self.optimizer.param_groups - - def _set_param_groups(self, value): - self.optimizer.param_groups = value - - param_groups = property(_get_param_groups, _set_param_groups) - - # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" - def _get_loss_scale(self): - return self.loss_scaler.loss_scale - - def _set_loss_scale(self, value): - self.loss_scaler.cur_scale = value - - loss_scale = property(_get_loss_scale, _set_loss_scale) - cur_scale = property(_get_loss_scale, _set_loss_scale) - - def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings): - # Remove paddings from flattened tensor - individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors) - lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)] - lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)] - #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}') - return lean_tensors - - #TODO REVISIT this for stage 3 - def get_lean_optimizer_state(self): - # Return optimizer states after removing paddings. - # This method assumes that each param group contains a single flattened tensor. - optimizer_groups_state = [] - - for i, group in enumerate(self.optimizer.param_groups): - p = group['params'][0] - lean_state = {} - for key, value in self.optimizer.state[p].items(): - if torch.is_tensor(value): - padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]] - lean_state[key] = self._get_lean_tensors( - value, - self.fp16_partitioned_groups[i], - self.groups_padding[i]) - lean_flat_len = sum([t.numel() for t in lean_state[key]]) - else: - lean_state[key] = value - - optimizer_groups_state.append(lean_state) - - return optimizer_groups_state - - def get_groups_without_padding(self, groups_with_padding): - # Return group tensor after removing paddings added for alignment to DP world size. - groups_without_padding = [] - for i, group in enumerate(groups_with_padding): - lean_group = self._get_lean_tensors(group, - self.fp16_partitioned_groups[i], - self.groups_padding[i]) - groups_without_padding.append(lean_group) - - return groups_without_padding - - def _set_fp32_optimizer_param_groups(self): - for sub_group_id, _ in enumerate(self.fp16_groups): - param_group_id = self.sub_group_to_group_id[sub_group_id] - self.optimizer.param_groups[param_group_id]['params'].append( - self.fp32_partitioned_groups_flat[sub_group_id]) - - def _clear_fp32_optimizer_param_groups(self): - for param_group in self.optimizer.param_groups: - param_group['params'] = [] - - def _rigid_state_dict(self): - state_dict = {} - state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS - state_dict['loss_scaler'] = self.loss_scaler - state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale - state_dict['overflow'] = self.overflow - state_dict['partition_count'] = self.partition_count - - self._set_fp32_optimizer_param_groups() - state_dict['optimizer_state_dict'] = self.optimizer.state_dict() - state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat - self._clear_fp32_optimizer_param_groups() - - return state_dict - - def state_dict(self): - """ - Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. - This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict - of the contained Pytorch optimizer. - Example:: - checkpoint = {} - checkpoint['model'] = model.state_dict() - checkpoint['optimizer'] = optimizer.state_dict() - torch.save(checkpoint, "saved.pth") - """ - if self.elastic_checkpoint: - raise NotImplementedError( - "ZeRO-3 does not yet support elastic checkpointing, please disable for now." - ) - - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." - ) - - return self._rigid_state_dict() - - -# Restore base optimizer fp32 weights from checkpoint by: -# 1) Merging fp32 weights from checkpoints of all partitions -# 2) Extracting fp32 weights for current partition from merged weights -# 3) Using extracted weights to update base optimizer weights directly. - - def _restore_from_fp32_weights(self, all_state_dict): - - flat_local_partition = [] - for i in range(len(self.fp32_partitioned_groups_flat)): - merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] - flat_local_partition.append(self._get_flattened_partition(merged_partitions)) - - for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): - current.data.copy_(saved.data) - - # Restore base optimizer fp32 weights from ZeRO fp16 weights - def _restore_from_fp16_weights(self): - for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): - fp32_partition.data.copy_(fp16_partitions.data) - - # Refresh the fp32 master params from the fp16 copies. - def refresh_fp32_params(self): - self._restore_from_fp16_weights() - - # Extract flattened partition for current rank from all partitions - def _get_flattened_partition(self, all_partition_states): - partition_id = dist.get_rank(group=self.dp_process_group) - alignment = dist.get_world_size(group=self.dp_process_group) - - param_partitions = [[] for _ in range(len(all_partition_states[0]))] - for i, partition in enumerate(all_partition_states): - for j, param in enumerate(partition): - param_partitions[j].append(param) - - local_state_partitions = [] - for param_index, param_slices in enumerate(param_partitions): - flattened_merged_tensor = self.flatten_dense_tensors_aligned( - param_slices, - alignment) - new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) - local_state_partitions.append(new_partitions[partition_id]) - - if torch.is_tensor(local_state_partitions[0]): - return self.flatten_dense_tensors_aligned(local_state_partitions, alignment) - - # Assume non-tensor states are not partitioned and equal across ranks, so return first one - return local_state_partitions[0] - - # Restore base optimizer state from checkpoint by - # 1) Merging optimizer state from checkpoints of all partitions - # 2) Extracting optimizer state for current partition from the merged state - # 3) Using the extracted value to directly update the base optimizer. - def _restore_base_optimizer_state(self, all_state_dict): - base_optimizer_group_states = [] - for i in range(len(self.optimizer.param_groups)): - partition_states = {} - all_partition_group_states = [ - sd['base_optimizer_state'][i] for sd in all_state_dict - ] - for key in all_partition_group_states[0].keys(): - all_partition_states = [ - all_states[key] for all_states in all_partition_group_states - ] - partition_states[key] = self._get_flattened_partition( - all_partition_states) - base_optimizer_group_states.append(partition_states) - - for i, group in enumerate(self.optimizer.param_groups): - p = group['params'][0] - for key, saved in base_optimizer_group_states[i].items(): - if torch.is_tensor(self.optimizer.state[p][key]): - self.optimizer.state[p][key].data.copy_(saved.data) - else: - self.optimizer.state[p][key] = saved - - def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): - # I think it should actually be ok to reload the optimizer before the model. - self.loss_scaler = state_dict['loss_scaler'] - self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] - self.overflow = state_dict['overflow'] - - if load_optimizer_states: - self._set_fp32_optimizer_param_groups() - self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) - self._clear_fp32_optimizer_param_groups() - - # restore fp32 partitions - for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']): - curr_param.data.copy_(saved_param.data) - - # restore fp16 partitions from fp32 - for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): - fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] - fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] - fp16_param.data.copy_(fp32_param.data) - - # update fp16 unflattened params - for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): - updated_params = self.unflatten( - self.fp16_partitioned_groups_flat[sub_group_id], - self.fp16_partitioned_groups[sub_group_id]) - - for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): - partitioned_param.data = q.data - - # TODO: Support different/changing load/save DP degree. - def load_state_dict(self, - state_dict_list, - load_optimizer_states=True, - load_from_fp32_weights=False): - r"""Loading a ZeRO checkpoint - Arguments: - state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. - Note that the number of saved partitions may differ from number of loading partitions to support - changing GPU count, specifically DP world size, between saving and loading checkpoints. - load_optimizer_states: Boolean indicating whether or not to load base optimizer states - load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 - copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). - """ - """ - Loads a state_dict created by an earlier call to state_dict(). - If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, - whose parameters in turn came from ``model``, it is expected that the user - will call ``model.load_state_dict()`` before - ``fp16_optimizer_instance.load_state_dict()`` is called. - Example:: - model = torch.nn.Linear(D_in, D_out).cuda().half() - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) - ... - checkpoint = torch.load("saved.pth") - model.load_state_dict(checkpoint['model']) - optimizer.load_state_dict(checkpoint['optimizer']) - """ - - if self.elastic_checkpoint: - raise NotImplementedError( - "ZeRO-3 does not yet support elastic checkpointing, please disable for now." - ) - - if self.swap_optimizer or self.params_in_nvme_and_cpu: - raise NotImplementedError( - "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." - ) - - self._rigid_load_state_dict( - state_dict_list[dist.get_rank(group=self.dp_process_group)], - load_optimizer_states=load_optimizer_states) - - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].partition(self.persistent_parameters) - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - def save_checkpoint_prologue(self): - self._partition_all_parameters() - - def save_checkpoint_epilogue(self): - if len(self.persistent_parameters) > 0: - self.persistent_parameters[0].all_gather(self.persistent_parameters) - - -def _handle_overflow(cpu_sum, x, i): - import math - rank = torch.distributed.get_rank() - if rank == 0: - t_i = -1 - for v_i, v in enumerate(x.data.contiguous().view(-1)): - if not math.isfinite(float(v)): - t_i = v_i - break - logger.info( - f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}" - ) - - -def estimate_zero3_model_states_mem_needs(total_params, - largest_layer_params, - num_gpus_per_node=1, - num_nodes=1, - cpu_offload=True, - cpu_offload_params=True, - zero_init=True, - additional_buffer_factor=1.5): - - total_gpus = num_nodes * num_gpus_per_node - gpus_factor = 1 / num_nodes - largest_layer_memory = (4 * largest_layer_params) - - if cpu_offload: - if cpu_offload_params: - gpu_mem = largest_layer_memory - - if zero_init: - cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor - else: - cpu_mem = total_params * max(4 * num_gpus_per_node, - 18 * gpus_factor) * additional_buffer_factor - - else: - gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus) - - if zero_init: - cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor - else: - cpu_mem = total_params * max(4 * num_gpus_per_node, - 16 * gpus_factor) * additional_buffer_factor - else: - gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus) - if zero_init: - cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor - else: - cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor - - return int(cpu_mem), int(gpu_mem), largest_layer_memory - - -def model_to_params(model): - # shared params calculated only once - total_params = sum( - dict((p.data_ptr(), - p.numel()) for p in model.parameters()).values()) - - largest_layer_params = 0 - for m in model.modules(): - # assuming no shared params within a single layer - layer_params = sum(p.numel() for p in m.parameters(recurse=False)) - largest_layer_params = max(largest_layer_params, layer_params) - - return total_params, largest_layer_params - - -import math - - -def estimate_zero3_model_states_mem_needs_all_live(model, - num_gpus_per_node=1, - num_nodes=1, - additional_buffer_factor=1.5): - """ - Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients - for a given ``model`` and hardware setup. - - If you have an actual model object, use this function and everything will be derived - automatically. - - If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass - the ``total_params`` and ``largest_layer_params`` explicitly. - - Args: - - ``model``: ``nn.Module`` object - - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - - ``num_nodes``: how many nodes (defaults to 1), - - ``additional_buffer_factor``: estimation factor (defaults to 1.5): - - """ - - total_params, largest_layer_params = model_to_params(model) - - estimate_zero3_model_states_mem_needs_all_cold( - total_params=total_params, - largest_layer_params=largest_layer_params, - num_gpus_per_node=num_gpus_per_node, - num_nodes=num_nodes, - additional_buffer_factor=additional_buffer_factor) - - -def estimate_zero3_model_states_mem_needs_all_cold(total_params, - largest_layer_params, - num_gpus_per_node=1, - num_nodes=1, - additional_buffer_factor=1.5): - """ - Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients - for a given ``model`` and hardware setup. - - If it's a hypothetical model, use this function where you have to pass - the ``total_params`` and ``largest_layer_params`` explicitly. - - If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything - will be derived automatically. - - Args: - - ``total_params``: total model params - - ``largest_layer_params``: largest layer's params - - ``num_gpus_per_node``: how many gpus per node (defaults to 1) - - ``num_nodes``: how many nodes (defaults to 1), - - ``additional_buffer_factor``: estimation factor (defaults to 1.5): - - """ - def format_options(cpu_offload, cpu_offload_params, zero_init): - enabled = [] - padded_cpu_str = f'{OFFLOAD_CPU_DEVICE:4}' - param_device = padded_cpu_str if cpu_offload_params else "none" - enabled.append(f"{OFFLOAD_PARAM}={param_device}") - optimizer_device = padded_cpu_str if cpu_offload else "none" - enabled.append(f"{OFFLOAD_OPTIMIZER}={optimizer_device}") - enabled.append(f"zero_init={1 if zero_init else 0}") - return ", ".join(enabled) - - nodes_str = "nodes" if num_nodes > 1 else "node" - gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" - print( - "Estimated memory needed for params, optim states and gradients for a:\n" - f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" - f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params." - ) - print(" per CPU | per GPU | Options") - for cpu_offload in [True, False]: - for cpu_offload_params in [True, False]: - if not cpu_offload and cpu_offload_params: - continue - for zero_init in [True, False]: - cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs( - total_params=total_params, - largest_layer_params=largest_layer_params, - num_gpus_per_node=num_gpus_per_node, - num_nodes=num_nodes, - cpu_offload=cpu_offload, - cpu_offload_params=cpu_offload_params, - zero_init=zero_init, - additional_buffer_factor=additional_buffer_factor - ) - - options_str = format_options(cpu_offload=cpu_offload, - cpu_offload_params=cpu_offload_params, - zero_init=zero_init) - print( - f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") +""" +"Copyright 2020 The Microsoft DeepSpeed Team. +Licensed under the MIT license. +""" + +import gc +from dataclasses import dataclass +import functools +import os +import collections +from collections import OrderedDict, UserDict +import itertools +from typing import Deque, Dict, Iterable, Set, Tuple +import torch +from torch.cuda import Event, Stream +from torch.nn import Module, Parameter +import torch.distributed as dist +import math +from torch._six import inf +from torch.nn import Module +from torch.nn.parameter import Parameter + +from deepspeed.utils.logging import logger +from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler +from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced +from deepspeed.runtime.utils import get_global_norm, see_memory_usage, is_model_parallel_parameter, DummyOptim +from deepspeed.runtime.zero.partition_parameters import * +from deepspeed.runtime.zero.partition_parameters import _init_external_params +from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS +from deepspeed.ops.adam import DeepSpeedCPUAdam +from deepspeed.ops.op_builder import UtilsBuilder +from deepspeed.runtime.zero.offload_constants import * +from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus +from deepspeed.runtime.swap_tensor.partitioned_optimizer_swapper import PartitionedOptimizerSwapper +from deepspeed.runtime.swap_tensor.pipelined_optimizer_swapper import PipelinedOptimizerSwapper + +# Toggle this to true to enable correctness test +# with gradient partitioning and without +pg_correctness_test = False + +FWD_MODULE_STACK = list() + + +def print_rank_0(message, debug=False, force=False): + rank = torch.distributed.get_rank() + if rank == 0 and (debug or force): + print(message) + # other variations + # - print for all ranks w/o interleaving + # printflock(f"[{rank}] {message}") + # - print to log file per rank + # log_rank_file(rank, message) + + +def input(msg): + return + + +def isclose(a, b, rtol=1e-09, atol=0.0): + return abs(a - b) <= max(rtol * max(abs(a), abs(b)), atol) + + +def lcm(x, y): + from fractions import gcd # or can import gcd from `math` in Python 3 + return x * y // gcd(x, y) + + +def debug_rank0(message: str) -> None: + if dist.get_rank() == 0: + logger.debug(message) + + +def get_cuda_mem_allocated_str() -> str: + # this is really slow. when enabled the python process becomes slow + # to the point where it can't keep the GPU fed with work, so only enable + # for memory debugging. + # return f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f}GB" + return "xGB" + + +def move_to_cpu(tensor_list): + for tensor in tensor_list: + tensor.data = tensor.data.cpu() + + +@instrument_w_nvtx +def get_all_parameters(sub_module, recurse=False): + return itertools.chain(sub_module.named_parameters(recurse=recurse), + sub_module.ds_external_parameters()) + + +def iter_params(module: Module, recurse=False) -> Iterable[Parameter]: + return map(lambda pair: pair[1], get_all_parameters(module, recurse)) + + +#apply torch.autograd.Function that calls a backward_function to tensors in output +def _apply_to_tensors_only(module, functional, backward_function, outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_to_tensors_only(module, + functional, + backward_function, + output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + return functional.apply(module, backward_function, outputs) + else: + return outputs + + +#for each tensor in outputs run the forward_function and register backward_function as hook +def _apply_forward_and_backward_to_tensors_only(module, + forward_function, + backward_function, + outputs): + if type(outputs) is tuple: + touched_outputs = [] + for output in outputs: + touched_output = _apply_forward_and_backward_to_tensors_only( + module, + forward_function, + backward_function, + output) + touched_outputs.append(touched_output) + return tuple(touched_outputs) + elif type(outputs) is torch.Tensor: + forward_function(outputs) + if outputs.requires_grad: + outputs.register_hook(backward_function) + return outputs + else: + return outputs + + +class ZeROOrderedDict(OrderedDict): + def __init__(self, parent_module, *args, **kwargs): + """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. + + Args: + parent_module (``collections.OrderedDict``): the collection to replace + """ + + super().__init__(*args, **kwargs) + self._parent_module = parent_module + self._in_forward = False + + def __getitem__(self, key): + param = super().__getitem__(key) + + # Params can be registered as None (e.g., bias) + if param is None: + return param + + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if self._parent_module._parameters._in_forward: + print_rank_0(f'Registering external parameter from getter {key}', + force=False) + register_external_parameter(FWD_MODULE_STACK[-1], param) + param.all_gather() + + return param + + +def _inject_parameters(module, cls): + for module in module.modules(): + if cls == ZeROOrderedDict: + new_param = cls(parent_module=module) + else: + new_param = cls() + + for key, param in module._parameters.items(): + new_param[key] = param + module._parameters = new_param + + +class PartitionedParameterCoordinator: + """Handles partitioning and gathering of parameters.""" + class __InflightParamRegistry(UserDict): + """registry for parameters in flight""" + def __setitem__(self, + param: Parameter, + handle: AllGatherCoalescedHandle) -> None: + if param in self.data: + raise RuntimeError(f"{param.ds_summary()} already in registry") + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError( + f"attempted to add non-inflight parameter to registry {param.ds_summary()}" + ) + self.data[param] = handle + + @dataclass + class __ParamInTrace: + param: Parameter + step_id_last_used_at: int + + def __init__( + self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: Stream, + prefetch_nvme: bool = False, + ) -> None: + # mapping of param -> handle for each param that is currently in flight + self.__inflight_param_registry = __class__.__InflightParamRegistry() + # keeps track of the number of submodules invoked so far. + self.__step_id: int = 0 + # whether or not we have completed a trace of the entire network. This should + # always be true after the first forward pass + backward pass. + self.trace_complete: bool = False + # sequence of submodules/parameters in forward pass + backward pass + self.__submodule_order: Iterable[Module] = [] + self.__param_order: Iterable[__class__.__ParamInTrace] = [] + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + # number of available params, and max number of available params + self.__n_available_params: int = 0 + self.__max_n_available_params: int = max_available_parameters_in_numel + # max distance between two use of the module beyond which module is released + self.__max_reuse_dist_in_numel: int = max_reuse_distance_in_numel + # queue for parameters to fetch. parameters will be popped off the left + # side of the dequeue as they are fetched + self.__param_queue: Deque[__class__.__ParamInTrace] = None + self.__prefetch_bucket_sz: int = prefetch_bucket_sz + self.__prefetch_nvme: bool = prefetch_nvme + self.hierarchy: int = 0 + + # stream that will be used for allgather operations + self.__allgather_stream: Stream = allgather_stream + + # limit the number of fetch events that can be queued at once + # otherwise, what happens is memory is allocated by the host thread at the + # time of the call, but not used until later by the asynchronous cuda stream. + # allowing an infinite number of these to queue up causes a lot of memory + # pressure that then becomes detrimental to performance. + # this is a much less elegant way of fixing this vs something like using + # cudaMallocAsync/cudaFreeAsync. Choosing to not expose this to the user now + # because ideally in the future its replaced by an async allocation + # mechanism which doesnt require any configuration by the user. + self.__ongoing_fetch_events: Deque[Event] = collections.deque() + self.__max_ongoing_fetch_events: int = 2 + + """Tracing and Tracking + TODO. consider performing trace before initializing PartitionedParameterCoordinator + and passing trace results into constructor. This way all the code in here can + just assume that the trace is complete and the results can be entirely + immutable. + + Bookkeeping operations used to track where we are in the forward/backward pass + """ + + def record_trace(self, sub_module: Module) -> None: + """adds sub module to trace""" + if self.trace_complete: + raise RuntimeError( + "attemted to record trace when trace was already complete") + + self.__submodule_order.append(sub_module) + for param in sorted(set(iter_params(sub_module)), key=lambda p: p.ds_id): + self.__param_order.append( + __class__.__ParamInTrace(param=param, + step_id_last_used_at=self.__step_id)) + + def reset_step(self) -> None: + """indicate that we have completed one fwd+bwd for the model""" + if self.__inflight_param_registry: + raise RuntimeError( + f"still have inflight params " + f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}") + + if not self.trace_complete: + # make sure that recorded parameter and submodule orders are + # identical across ranks + assert_ints_same_as_other_ranks([m.id for m in self.__submodule_order]) + assert_ints_same_as_other_ranks([p.param.ds_id for p in self.__param_order]) + assert_ints_same_as_other_ranks( + [p.step_id_last_used_at for p in self.__param_order]) + + self.__submodule_order = tuple(self.__submodule_order) # freeze + self.__param_order = tuple(self.__param_order) # freeze + self.trace_complete = True + print_rank_0(f"completed trace: {[m.id for m in self.__submodule_order]}", + force=True) + + self.__param_queue = collections.deque(self.__param_order) # reset fetch queue + self.__most_recent_step_id_param_fetched_for = collections.defaultdict( + lambda: int(-1e10)) + self.__step_id = 0 + self.__n_available_params = 0 + + """Fetch and Release + Fetching, prefetching, and releasing parameters + """ + + @instrument_w_nvtx + @torch.no_grad() + def fetch_sub_module(self, current_submodule: Module) -> None: + """This method does the following (in order): + 1. kick off fetch for parameters in immediately required sub module + 2. kick off fetch for next few parameters we will need later (prefetch) + 3. block on parameters in immediately required sub module + """ + debug_rank0( + f"{self.__step_id}: M{current_submodule.id}({type(current_submodule).__name__}) P{[p.ds_id for p in iter_params(current_submodule)]} " + + str({ + "avail": f"{self.__n_available_params:.1e}", + "queue_sz": f"{len(self.__param_queue or [])}", + "inflight": [p.ds_id for p in self.__inflight_param_registry], + "allocated": get_cuda_mem_allocated_str() + })) + + params_to_fetch = frozenset(iter_params(current_submodule)) + + # kick off all gather for params in the immediately required submodule + for param in params_to_fetch: + debug_rank0(f"-fetch: {param.ds_summary()}") + self.__all_gather_params(params_to_fetch) + + # wait for parameters in the immediately needed submodule to become available + for param in iter_params(current_submodule): + param.ds_active_sub_modules.add(current_submodule.id) + debug_rank0(f"-wait: {param.ds_summary()}") + if param in self.__inflight_param_registry: + with torch.cuda.stream(self.__allgather_stream): + while self.__ongoing_fetch_events and self.__ongoing_fetch_events[ + 0].query(): + self.__ongoing_fetch_events.popleft() + if len(self.__ongoing_fetch_events + ) > self.__max_ongoing_fetch_events: + self.__ongoing_fetch_events.popleft().synchronize() + + self.__inflight_param_registry.pop(param).wait() + + event = Event() + event.record() + self.__ongoing_fetch_events.append(event) + + assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary() + torch.cuda.current_stream().wait_stream(self.__allgather_stream) + + # kick off parameter prefetches for upcoming modules + # don't prefetch if we dont have a completed model trace, or if we aren't + # training (throws off the tracing and don't want to prefetch modules for bwd) + if self.trace_complete and current_submodule.training: + # go through the parameters we need for the current module and pop them + # off the fetch queue so that they aren't prefetched later. + # if params have already been popped off the fetch queue by earlier + # prefetches we won't look for them here + discarded_from_prefetch_queue = set() + params_not_already_fetched = set( + filter( + lambda p: self.__most_recent_step_id_param_fetched_for[p] < self. + __step_id, + params_to_fetch)) + while self.__param_queue and len(discarded_from_prefetch_queue) < len( + params_not_already_fetched): + param_in_trace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + discarded_from_prefetch_queue.add(param_in_trace.param) + if discarded_from_prefetch_queue != params_not_already_fetched: + raise RuntimeError( + f"tracing error at step {self.__step_id}: " + f"expected the next {len(params_not_already_fetched)} parameters in the " + f"parameter fetch queue to be {tuple(p.ds_summary() for p in params_not_already_fetched)} " + f"but got {tuple(p.ds_summary() for p in discarded_from_prefetch_queue)}." + ) + + # kick off all gather for params in the next few submodules (prefetch) + max_params_to_prefetch = min( + self.__max_n_available_params - self.__n_available_params, + self.__prefetch_bucket_sz) + params_to_prefetch = set() + numel_prefetching = 0 + while self.__param_queue and numel_prefetching < max_params_to_prefetch: + param_in_trace: __class__.__ParamInTrace = self.__param_queue.popleft() + self.__most_recent_step_id_param_fetched_for[ + param_in_trace.param] = param_in_trace.step_id_last_used_at + if param_in_trace.param not in params_to_prefetch: + params_to_prefetch.add(param_in_trace.param) + numel_prefetching += param_in_trace.param.ds_numel + for param in params_to_prefetch: + debug_rank0(f"-prefetch: {param.ds_summary()}") + self.__all_gather_params(params_to_prefetch) + + if self.__prefetch_nvme: + self.__prefetch_nvme_param_partitions() + + self.__step_id += 1 + + @instrument_w_nvtx + @torch.no_grad() + def release_sub_module(self, submodule: Module) -> None: + """release the parameters of a sub module, assuming they meet conditions to + be released.""" + params_to_release = (self.__params_to_release(submodule, + self.__step_id) + if self.trace_complete else set( + p.ds_id for p in iter_params(submodule))) + + for param in iter_params(submodule): + param.ds_active_sub_modules.discard(submodule.id) + if param.ds_id in params_to_release and not param.is_external_param: + self.__release_param(param) + + @instrument_w_nvtx + @torch.no_grad() + def release_and_reset_all(self) -> None: + """release all module parameters""" + for param in map(lambda p: p.param, self.__param_order): + if param in self.__inflight_param_registry: + raise RuntimeError(f"param {param.ds_summary()} still in flight") + + # TODO. make this throw if if there are still active submodules. currently + # there's a hook execution issue + param.ds_active_sub_modules.clear() + self.__release_param(param) + + for param_in_trace in self.__param_order: + if param_in_trace.param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError( + f"{param_in_trace.param.ds_summary()} expected to be released") + + @instrument_w_nvtx + def __all_gather_params(self, params: Set[Parameter]) -> None: + """for each partitioned parameter, kick off an async allgather and store + the work handle for the in flight parameters.""" + partitioned_params = [] + for param in params: + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + partitioned_params.append(param) + self.__n_available_params += param.ds_numel + + if partitioned_params: + with torch.cuda.stream(self.__allgather_stream): + handle = partitioned_params[0].all_gather_coalesced(partitioned_params) + + for param in partitioned_params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, param.ds_summary() + self.__inflight_param_registry[param] = handle + + @instrument_w_nvtx + def __release_param(self, param: Parameter) -> None: + if param.ds_status == ZeroParamStatus.AVAILABLE and not param.ds_active_sub_modules: + debug_rank0(f"-release: {param.ds_summary()}") + param.partition() + self.__n_available_params -= param.ds_numel + + @instrument_w_nvtx + @functools.lru_cache(maxsize=None) + def __params_to_release(self, + submodule_to_release: Module, + step_id: int) -> Set[int]: + if not self.trace_complete: + raise RuntimeError("expected trace to be complete") + + params_to_release = set(p.ds_id for p in iter_params(submodule_to_release) + if not p.ds_persist) + + # examine all modules within `max_reuse_dist_in_numel` of the current step, + # if we see any of the candidate parameters to be released reoccur while + # doing this, remove them from the set of parameters to release. + params_traversed = 0 + for module in self.__submodule_order[step_id:]: + if params_traversed > self.__max_reuse_dist_in_numel: + break + for param in iter_params(module): + params_to_release.discard(param.ds_id) + params_traversed += param.ds_numel + + return params_to_release + + @instrument_w_nvtx + def __prefetch_nvme_param_partitions(self) -> None: + """swap in parameter partitions from nvme for those parameters that will be used + after the ones that are already being prefetched into full parameters + """ + if not self.trace_complete: + return + + numel_in_flight = sum(param.ds_numel for param in self.__inflight_param_registry) + + numel_considered = 0 + swap_in_params = [] + for param_in_trace in self.__param_queue: + param = param_in_trace.param + if param.nvme_swapper is None: + continue + if (numel_considered > 2 * numel_in_flight or len(swap_in_params) >= + param.nvme_swapper.available_swap_in_buffers()): + break + if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_in_params.append(param) + numel_considered += param.ds_numel + + if swap_in_params: + swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) + + +class PreBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, outputs): + ctx.module = module + ctx.pre_backward_function = pre_backward_function + if not hasattr(module, "applied_pre_backward_ref_cnt"): + module.applied_pre_backward_ref_cnt = 0 + module.applied_pre_backward_ref_cnt += 1 + #print(f"After Forward: {ctx.module.__class__.__name__}") + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + #print(f"Before Backward: {ctx.module.__class__.__name__}") + ctx.pre_backward_function(ctx.module) + return (None, None) + args + + +class PostBackwardFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, module, pre_backward_function, output): + ctx.module = module + if output.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." + #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.pre_backward_function = pre_backward_function + output = output.detach() + return output + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.pre_backward_function(ctx.module) + #print(f"After Backward: {ctx.module.__class__.__name__}") + return (None, None) + args + + +class FP16_DeepSpeedZeroOptimizer_Stage3(object): + """ + DeepSpeedZeroOptimizer designed to reduce the memory footprint + required for training large deep learning models. + + For more details please see ZeRO: Memory Optimization Towards Training A Trillion Parameter Models + https://arxiv.org/abs/1910.02054 + + For usage examples, refer to TODO: DeepSpeed Tutorial + + """ + def __init__(self, + module, + init_optimizer, + timers, + ds_config, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + dp_process_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + offload_param_config=None, + sub_group_size=1000000000000, + mpu=None, + clip_grad=0.0, + allreduce_always_fp32=False, + postscale_gradients=True, + gradient_predivide_factor=1.0, + gradient_accumulation_steps=1, + elastic_checkpoint=False, + aio_config=None): + + see_memory_usage("Stage 3 initialize beginning", force=False) + + if dist.get_rank() == 0: + logger.info(f"initialized {__class__.__name__} with args: {locals()}") + logger.info(f"Reduce bucket size {reduce_bucket_size}") + logger.info(f"Allgather bucket size {prefetch_bucket_size}") + # The fused optimizer does all the work. We need this layer for two reason: + # 1. maintain same user API from apex.fp16_utils + # 2. keep common stuff here in case we need to add ne552w fused optimizer later + + # differences from apex.fp16_utils: + # - assume all model params in fp16 + # - assume all params requires grad + # - flat by groups, not keeping state. TODO: remove state explicitly? + # - master gard and unflat master weight never exist. TODO: a way to save out unflat master? + if not torch.cuda.is_available: + raise SystemError("Cannot use fp16 without CUDA.") + self.optimizer = init_optimizer + + # Load pre-built or JIT compile (un)flatten ops + util_ops = UtilsBuilder().load() + self.flatten = util_ops.flatten + self.unflatten = util_ops.unflatten + self.dtype = self.optimizer.param_groups[0]['params'][0].dtype + self._global_grad_norm = 0. + + self.optimizer_swapper = None + self.swap_optimizer = False + + self.offload_optimizer = False + self.offload_optimizer_pin_memory = False + self.offload_optimizer_fast_init = False + self.offload_param = False + self.offload_param_pin_memory = False + self.params_in_nvme_and_cpu = False + self.max_params_in_cpu = 0 + + self._configure_offloading(offload_optimizer_config, offload_param_config) + + self._convert_to_zero_parameters(ds_config, module, mpu) + + for m in module.modules(): + _init_external_params(m) + + self.module = module + self.elastic_checkpoint = elastic_checkpoint + + # Replace ._parameters with a new class to enable auto-registration of + # external parameters + _inject_parameters(module, ZeROOrderedDict) + + self.__inf_or_nan_tracker: Tensor = torch.zeros( + 1, + dtype=torch.bool, + device=torch.cuda.current_device(), + requires_grad=False) + + self.deepspeed_adam_offload = (self.offload_optimizer + and type(init_optimizer) == DeepSpeedCPUAdam) + + self.device = torch.cuda.current_device( + ) if not self.offload_optimizer else OFFLOAD_CPU_DEVICE + ### streams used for overlapping computation with communication + self.__allgather_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + self.__reduce_and_partition_stream = Stream( + ) if overlap_comm else torch.cuda.default_stream() + + ############################################################################ + + see_memory_usage("Before Partitioned Parameter Coordinator", force=False) + self.param_coordinator = PartitionedParameterCoordinator( + prefetch_bucket_sz=int(prefetch_bucket_size), + max_reuse_distance_in_numel=int(max_reuse_distance), + max_available_parameters_in_numel=int(max_live_parameters), + allgather_stream=self.__allgather_stream, + prefetch_nvme=self.params_in_nvme_and_cpu, + ) + see_memory_usage("After Partitioned Parameter Coordinator", force=False) + + self.__n_caching_allocator_flushes = 0 + + #self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream()) + #-------------Stage 3 Setup-------------------# + # parameters smaller than the threshold will be collectively gathered at the + # end of the optimizer step and will be kept till the end of the backward pass + # TODO maybe worth just replicating these parameters and doing all reduce for them + self.persistence_threshold = int(param_persistence_threshold) + + self.persistent_parameters = self.persistent_parameters() + + self.setup_zero_stage3_hooks() + + #resetting ds_tensor just in case parameters have been changed after initialization + #example .half() or .to() + #self.reset_ds_tensor() + #---------------------------------------------# + + self.timers = timers + + self.dp_process_group = dp_process_group + + self.partition_count = dist.get_world_size(group=self.dp_process_group) + + if mpu is None: + self.model_parallel_group = None + self.model_parallel_rank = 0 + else: + self.model_parallel_group = mpu.get_model_parallel_group() + self.model_parallel_rank = mpu.get_model_parallel_rank() + + self.overflow = False + self.clip_grad = clip_grad + self.allreduce_always_fp32 = allreduce_always_fp32 + self.gradient_predivide_factor = gradient_predivide_factor + self.postscale_gradients = postscale_gradients + self.gradient_accumulation_steps = gradient_accumulation_steps + self.micro_step_id = 0 + + # Holds the mode parameter + # The param.data may not hold any meaningful data + # when param's status is NOT_AVAILABLE or IN_FLGHT + self.fp16_groups = [] + + # Hold partitioned parameters + self.fp16_partitioned_groups = [] + + # Holds a fused and flattened copy of the parameters + self.fp16_partitioned_groups_flat = [] + self.fp16_partitioned_groups_flat_numel = [] + + #defragmented pinned memory + self.param_groups_fp16_flat_cpu_memory = [] + + #a single 32-bit partition of the parallel partitioned parameters + #that this process will update + self.fp32_partitioned_groups_flat = [] + self.next_swappable_fp32_partitioned_groups = [] + + # number of elements per partition in each group + self.partition_size = [] + + self.all_reduce_print = False + + self.prefetch_elements = int(prefetch_bucket_size) + + # padding on each partition for alignment purposes + self.groups_padding = [] + + self.sub_group_size = sub_group_size + + self.sub_group_to_group_id = {} + see_memory_usage("Before creating fp16 partitions", force=False) + self._create_fp16_partitions_with_defragmentation() + num_fp16_subgroups = len(self.fp16_partitioned_groups_flat) + see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}", + force=False) + + # Optimizer tensor swapping + if self.swap_optimizer: + self._configure_tensor_swapping(offload_optimizer_config, aio_config) + + see_memory_usage("Before creating fp32 partitions", force=False) + if not isinstance(self.optimizer, DummyOptim): + self._create_fp32_partitions() + see_memory_usage("After creating fp32 partitions", force=False) + dist.barrier() + + # To support pipelined optimizer swapping + if not isinstance(init_optimizer, DummyOptim): + self._create_next_swappable_fp32_groups() + + see_memory_usage("Before initializing optimizer states", force=False) + if not isinstance(init_optimizer, DummyOptim): + self.initialize_optimizer_states() + see_memory_usage("After initializing optimizer states", force=False) + dist.barrier() + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") + + self.reduce_bucket_size = int(reduce_bucket_size) + + # IPG + if contiguous_gradients: + self.__ipg_bucket_flat_buffer: Tensor = torch.empty( + int(reduce_bucket_size), + dtype=self.dtype, + device=torch.cuda.current_device()) + + self.__param_id_to_grad_partition: Dict[int, Tensor] = {} + + all_params = list(itertools.chain.from_iterable(self.fp16_groups)) + + grad_partitions_flat_buffer: Tensor = torch.zeros( + sum(p.ds_tensor.ds_numel for p in all_params), + dtype=self.dtype, + device=self.device, + pin_memory=self.offload_optimizer_pin_memory) + + offset = 0 + for param in all_params: + self.__param_id_to_grad_partition[ + param.ds_id] = grad_partitions_flat_buffer.narrow( + 0, + offset, + param.ds_tensor.numel()) + offset += param.ds_tensor.numel() + + self.__params_in_ipg_bucket: List[Parameter] = [] + self.is_gradient_accumulation_boundary: bool = True + + self.__param_reduce_events: Deque[Event] = collections.deque() + self.__max_param_reduce_events: int = 2 + + if dist.get_rank() == 0: + logger.info(f"optimizer state initialized") + + self.param_dict = {} + + # map between param_id and bool to specify if a param is in this partition + self.is_param_in_current_partition = {} + + self.contiguous_gradients = contiguous_gradients + self.extra_large_param_to_reduce = None + self.grads_in_ipg_bucket = [] + self.params_in_ipg_bucket = [] + self.params_already_reduced = [] + self.is_gradient_accumulation_boundary = True + self._release_ipg_buffers() + self.previous_reduced_grads = None + + # simplified param id + self.param_id = {} + + count = 0 + for i, params_group in enumerate(self.fp16_groups): + for param in params_group: + unique_id = id(param) + self.param_id[unique_id] = count + self.param_dict[count] = param + self.params_already_reduced.append(False) + count = count + 1 + + #Largest partitioned param + largest_partitioned_param_numel = max([ + max([tensor.numel() for tensor in fp16_partitioned_group]) + for fp16_partitioned_group in self.fp16_partitioned_groups + ]) + print_rank_0( + f'Largest partitioned param numel = {largest_partitioned_param_numel}', + force=False) + + see_memory_usage(f"Before Set Grad positions", force=False) + + self.grad_position = {} + self.set_grad_positions() + see_memory_usage(f"Before CPU Offload initialization", force=False) + + self.grads_in_partition = None + + if self.offload_optimizer: + self.norm_for_param_grads = {} + self.local_overflow = False + + see_memory_usage(f"After CPU Offload initialization", force=False) + + # stores if a partition has been reduced in this step + self.is_partition_reduced = {} + + # stores if a grad in a partition has been computed or not + self.is_grad_computed = {} + + # will store the averaged gradients required by this paritition + self.averaged_gradients = {} + + #creates backward hooks for gradient partitioning + self.create_reduce_and_remove_grad_hooks() + + #exit(0) + + # we may have a way of fusing dynamic scale. Do not support for now + if self.dtype == torch.float or not dynamic_loss_scale: + loss_scale_value = 1.0 if self.dtype == torch.float else static_loss_scale + + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(scale=loss_scale_value) + cur_iter = 0 + else: + if dynamic_loss_args is None: + self.loss_scaler = DynamicLossScaler() + else: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + + self.dynamic_loss_scale = True + + self.debug_fp16_grads = [{} for _ in self.fp16_groups] + + if dist.get_rank(group=self.dp_process_group) == 0: + see_memory_usage(f"After initializing ZeRO optimizer", force=False) + + @staticmethod + def defragment(tensors: List[Tensor]) -> Tensor: + """move provided tensors into a contiguous flat buffer, with some additional + measures taken to reduce memory fragmentation""" + assert len(set(t.dtype for t in tensors)) == 1 + assert len(set(t.device for t in tensors)) == 1 + + cpu_buffer = torch.empty(sum(p.numel() for p in tensors), + dtype=get_only_unique_item(t.dtype for t in tensors), + device="cpu") + tensor_infos: List[Tuple[Tensor, int, int]] = [] + orig_device = get_only_unique_item(t.device for t in tensors) + + offset = 0 + for tensor in tensors: + tensor_numel = tensor.numel() + # move the tensor from device memory to host memory + cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor) + tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device) + + # record some data so we can restore the device tensor later + tensor_infos.append((tensor, offset, tensor_numel)) + + offset += tensor_numel + + gc.collect() + torch.cuda.empty_cache() + + # copy tensors (now flattened and contiguous) back to GPU + device_buffer = cpu_buffer.to(orig_device) + + # restore device tensors + for tensor, offset, tensor_numel in tensor_infos: + tensor.data = device_buffer.narrow(0, offset, tensor_numel) + + return device_buffer + + def _configure_offloading(self, offload_optimizer_config, offload_param_config): + ###################### offload optimizer setup ################################## + if offload_optimizer_config is not None: + self.offload_optimizer = True + self.offload_optimizer_pin_memory = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_PIN_MEMORY] + self.swap_optimizer = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_DEVICE] == OFFLOAD_NVME_DEVICE + self.offload_optimizer_fast_init = offload_optimizer_config[ + OFFLOAD_OPTIMIZER_FAST_INIT] + + ###################### offload param setup ################################## + if offload_param_config is not None: + if not isinstance(self.optimizer, DummyOptim): + assert self.offload_optimizer, "parameter offload is only available with optimizer state offload" + self.offload_param = True + self.offload_param_pin_memory = offload_param_config[ + OFFLOAD_PARAM_PIN_MEMORY] + self.params_in_nvme_and_cpu = offload_param_config[ + OFFLOAD_PARAM_DEVICE] == OFFLOAD_NVME_DEVICE + self.max_params_in_cpu = offload_param_config[OFFLOAD_PARAM_MAX_IN_CPU] + print_rank_0( + f"FP16 params swapping is {self.params_in_nvme_and_cpu}, Max params in CPU is {self.max_params_in_cpu}", + force=False) + + def _convert_to_zero_parameters(self, ds_config, module, mpu): + non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] + if non_zero_params: + zero_params = [p for p in module.parameters() if is_zero_param(p)] + if zero_params: + zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) + else: + group = None + if mpu: + group = mpu.get_data_parallel_group() + + if self.params_in_nvme_and_cpu: + remote_device = OFFLOAD_NVME_DEVICE + elif self.offload_param: + remote_device = OFFLOAD_CPU_DEVICE + else: + remote_device = None + + Init(module=module, + data_parallel_group=group, + dtype=self.dtype, + config_dict_or_path=ds_config, + remote_device=remote_device, + pin_memory=self.offload_param_pin_memory, + mpu=mpu) + + def _configure_tensor_swapping(self, offload_optimizer_config, aio_config): + nvme_swap_folder = os.path.join( + offload_optimizer_config[OFFLOAD_OPTIMIZER_NVME_PATH], + 'zero_stage_3') + os.makedirs(nvme_swap_folder, exist_ok=True) + if torch.distributed.get_rank() == 0: + logger.info(f'Tensor Swapping: Adding optimizer tensors') + + swapper_type = PipelinedOptimizerSwapper if offload_optimizer_config[ + OFFLOAD_OPTIMIZER_PIPELINE] else PartitionedOptimizerSwapper + + self.optimizer_swapper = swapper_type( + swap_config=offload_optimizer_config, + aio_config=aio_config, + base_folder=nvme_swap_folder, + optimizer=self.optimizer, + largest_numel=max(self.fp16_partitioned_groups_flat_numel), + device=self.device, + dtype=torch.float32, + timers=self.timers) + + @property + def elements_in_ipg_bucket(self): + return sum(p.ds_numel for p in self.__params_in_ipg_bucket) + + def _create_fp16_partitions(self): + dist.barrier() + partition_id = dist.get_rank(group=self.dp_process_group) + + # loop to deal with groups + for j, param_group in enumerate(self.optimizer.param_groups): + + sub_groups = self._create_fp16_sub_groups(param_group['params']) + for sub_group in sub_groups: + i = len(self.fp16_groups) + + # push this group to list before modify + self.fp16_groups.append(sub_group) + self.sub_group_to_group_id[i] = j + + #These are the list of the partitioned parameters + self.fp16_partitioned_groups.append( + [param.ds_tensor for param in self.fp16_groups[i]]) + + print_rank_0( + f"fp16 group {i} partitioned_param norms : {[param.ds_tensor.norm().item() for param in self.fp16_groups[i]]}" + ) + + # Record padding required to align group to world size (only applies to last rank) + if partition_id == dist.get_world_size(group=self.dp_process_group) - 1: + padding = [p.padding_size() for p in self.fp16_groups[i]] + else: + padding = [0] * len(self.fp16_groups[i]) + self.groups_padding.append(padding) + + #not sure why apex was cloning the weights before flattening + #removing cloning here + see_memory_usage(f"Before Flattening param group {i}", force=False) + + if not self.offload_param: + see_memory_usage(f"Before moving param group {i} to CPU", + force=False) + #move all the parameters to cpu to free up GPU space for creating flat buffer + move_to_cpu(self.fp16_partitioned_groups[i]) + see_memory_usage(f"After moving param group {i} to CPU", force=False) + + #create flat buffer in CPU and move to GPU + self.fp16_partitioned_groups_flat.append( + self.flatten_dense_tensors_aligned( + self.fp16_partitioned_groups[i], + dist.get_world_size(group=self.dp_process_group)).cuda( + torch.cuda.current_device())) + see_memory_usage( + f"After flattening and moving param group {i} to GPU", + force=False) + else: + #Without the detach, seems like the flattening becomes part of the + #model graph causing errors downstream + self.fp16_partitioned_groups_flat.append( + self.flatten_dense_tensors_aligned( + self.fp16_partitioned_groups[i], + dist.get_world_size( + group=self.dp_process_group)).detach().pin_memory()) + + see_memory_usage(f"After Flattening param group {i}", force=False) + + see_memory_usage(f"After Flattening param group {i}", force=False) + + #set model fp16 weight to slices of flattened buffer + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[i], + self.fp16_partitioned_groups[i]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[i], updated_params): + partitioned_param.data = q.data + + def _move_to_flat_buffer(self, param_list, flat_buffer, avoid_copy=False): + '''If flat buffer is None then the parameters in the param_list are + not copied to the flat buffer. This is because they excede the number of max_params_in_cpu + Some of these parameters may aready be in CPU in unflattened buffers + or they maybe in GPU, or they maybe in NVME. If they are in NVME, then + they will be marked as NOT_AVAILABLE, and will be moved to CPU when they are + needed during training.''' + if flat_buffer is None: + # this dst buffer is on NVMe, so skip this + return + + start = 0 + for param in param_list: + src = param.ds_tensor + dest = flat_buffer.narrow(0, start, src.ds_numel) + start = start + src.ds_numel + '''if the parameter was initialized in nvme then bring it to the destination buffer directly''' + if src.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with partition size {param.ds_tensor.ds_numel} permanently to CPU" + ) + param.nvme_swapper.swap_into_buffer(param, dest) + src.data = dest.data + src.status = PartitionedParamStatus.AVAILABLE + else: + assert src.status == PartitionedParamStatus.AVAILABLE, "Partitioned Param must be available here" + if not avoid_copy: + dest.data.copy_(src.data) + src.data = dest.data + + # Final location must be gpu/cpu in this case + param.ds_tensor.final_location = 'not-nvme' + + def _create_param_groups_fp16_flat_cpu_memory(self): + + aggregate_params_count = 0 + + for j, param_group in enumerate(self.optimizer.param_groups): + params_in_group = sum([p.ds_tensor.ds_numel for p in param_group['params']]) + + flat_buffer_size = params_in_group + + if self.params_in_nvme_and_cpu and \ + aggregate_params_count + params_in_group > self.max_params_in_cpu: + + flat_buffer_size = max(0, + self.max_params_in_cpu - aggregate_params_count) + + aggregate_params_count += params_in_group + + if flat_buffer_size > 0: + print_rank_0(f"group {j} flat buffer size {flat_buffer_size}", + force=False) + self.param_groups_fp16_flat_cpu_memory.append( + torch.empty(int(flat_buffer_size), + dtype=self.dtype, + pin_memory=True)) + else: + print_rank_0( + f"No flat buffer size. Param group size was {params_in_group}", + force=False) + + self.param_groups_fp16_flat_cpu_memory.append( + torch.empty(1, + dtype=self.dtype)) + + def _create_fp16_partitions_with_defragmentation(self): + dist.barrier() + param_groups: List[List[Parameter]] = tuple( + self._create_fp16_sub_groups(param_group["params"]) + for param_group in self.optimizer.param_groups) + + # bookkeeping related to param groups + for param_group_idx, param_group in enumerate(param_groups): + for sub_group in param_group: + sub_group_idx = len(self.fp16_groups) + + # record sub group and partitions + self.fp16_groups.append(sub_group) + self.fp16_partitioned_groups.append( + [param.ds_tensor for param in sub_group]) + + # record sub group -> group mapping + self.sub_group_to_group_id[sub_group_idx] = param_group_idx + + # record total elements of parameter partitions in sub group + self.fp16_partitioned_groups_flat_numel.append( + sum(p.ds_tensor.ds_numel for p in sub_group)) + + # record padding required to align group to world size (only applies to last rank) + rank_requires_padding = dist.get_rank( + self.dp_process_group) == dist.get_world_size( + self.dp_process_group) - 1 + self.groups_padding.append([ + p.padding_size() if rank_requires_padding else 0 for p in sub_group + ]) + + # move parameters to flattened buffer + if not self.offload_param: # partitioned params remain in GPU during training + # move parameter partitions into a single contiguous flat buffer + parameter_partitions: List[Tensor] = [] + for sub_group in self.fp16_groups: + for param in sub_group: + parameter_partitions.append(param.ds_tensor) + device_buffer = __class__.defragment(parameter_partitions) + + # setup flat buffers per subgroup, these are each just sections of the + # contiguous flat buffer for all parameters that we created earlier + offset = 0 + for sub_group in self.fp16_groups: + sub_group_numel = sum(param.ds_tensor.ds_numel for param in sub_group) + self.fp16_partitioned_groups_flat.append( + device_buffer.narrow(0, + offset, + sub_group_numel)) + offset += sub_group_numel + else: # partitioned params offloaded to CPU when not in use + # create a flat CPU memory allocation for each param group + self._create_param_groups_fp16_flat_cpu_memory() + for param_group_idx, param_group in enumerate(param_groups): + flat_offset = 0 + for i, sub_group in enumerate(param_group): + total_elements = sum(p.ds_tensor.ds_numel for p in sub_group) + print_rank_0(f"Params in nvme and cpu {self.params_in_nvme_and_cpu}") + #Flat buffer may not be available for parameters that reside in NVME + if not self.params_in_nvme_and_cpu or flat_offset + total_elements <= self.param_groups_fp16_flat_cpu_memory[ + param_group_idx].numel(): + fp16_partitioned_group_flat = self.param_groups_fp16_flat_cpu_memory[ + param_group_idx].narrow(0, + flat_offset, + total_elements) + print_rank_0( + f"Creating a flat buffer for subgroup {i} requiring {total_elements} elements, and cumulative CPU elements {flat_offset + total_elements}", + force=False) + elif self.params_in_nvme_and_cpu: + fp16_partitioned_group_flat = None + print_rank_0( + f"No flat buffer for sub group {i} of {total_elements} elements", + force=False) + else: + assert False, "Either params are in nvme, or they are in CPU memory. This code path should not be triggered. Please see you max_params_in_cpu and params_in_nvme configs" + + self.fp16_partitioned_groups_flat.append(fp16_partitioned_group_flat) + flat_offset += total_elements + + self._move_to_flat_buffer(sub_group, + fp16_partitioned_group_flat, + avoid_copy=not self.offload_param) + + # if necessary, create a pinned memory buffer to be used for swapping out + # params to NVME after optimizer step + should_create_fp16_flat_reuse_buffer = any( + flattened_partition_group is None + for flattened_partition_group in self.fp16_partitioned_groups_flat) + if should_create_fp16_flat_reuse_buffer: + max_partition_numel, largest_partition_numel = 0, None + for sub_group in self.fp16_groups: + total_elements = sum(t.ds_tensor.ds_numel for t in sub_group) + if total_elements > max_partition_numel: + largest_partition_numel = [t.ds_numel for t in sub_group] + max_partition_numel = total_elements + + assert len(largest_partition_numel) > 0, f'Unexpected that largest partition is empty' + self.fp16_groups[0][0].nvme_swapper.reserve_partitioned_swap_space( + largest_partition_numel) + + def _swap_in_sub_group_to_flat_buffer(self, flat_buffer, sub_group_id): + offset = 0 + elements_in_sub_group = sum( + [t.ds_numel for t in self.fp16_partitioned_groups[sub_group_id]]) + assert (flat_buffer.numel() == elements_in_sub_group) + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): + dest = flat_buffer.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + print_rank_0( + f"Swapping in {param.ds_id} with elements {param.ds_numel} and partition {param.ds_tensor.ds_numel}" + ) + param.nvme_swapper.swap_in([param], async_op=False) + dest.data.copy_(partitioned_param.data) + param.nvme_swapper.remove_partition_and_release_buffers([param]) + print_rank_0(f"Swapping in {param.ds_id} done") + else: + dest.data.copy_(partitioned_param.data) + offset += partitioned_param.ds_numel + + def _create_next_swappable_fp32_groups(self): + reverse_order_indices = [ + i for i in range(len(self.fp32_partitioned_groups_flat)) + ] + reverse_order_indices.reverse() + + next_group = None + for i in reverse_order_indices: + self.next_swappable_fp32_partitioned_groups.append(next_group) + if self._swappable_optimizer_subgroup(i): + next_group = self.fp32_partitioned_groups_flat[i] + + self.next_swappable_fp32_partitioned_groups.reverse() + + def _get_sub_group_partitions(self, sub_group_id): + sub_group_partitions = [] + for param, partitioned_param in zip(self.fp16_groups[sub_group_id], self.fp16_partitioned_groups[sub_group_id]): + if partitioned_param.status == PartitionedParamStatus.NOT_AVAILABLE: + swap_path = param.nvme_swapper.get_path(param, True) + sub_group_partitions.append((partitioned_param, + param.ds_tensor.ds_numel, + swap_path)) + else: + sub_group_partitions.append((partitioned_param, + partitioned_param.ds_numel, + None)) + + return sub_group_partitions + + def _create_fp32_partitions(self): + cpu_memory_usage = 0 + cpu_memory_sub_groups = 0 + nvme_memory_usage = 0 + num_swappable_partitions = 0 + num_swap_from_nvme_partitions = 0 + num_swap_from_cpu_partitions = 0 + swap_from_nvme_memory_usage = 0 + swap_from_cpu_memory_usage = 0 + GIGA_BYTES = (1024**3) + + swappable_fp32_tensors = [] + swappable_fp16_src_tensors = [] + nvme_fp16_partitions_info = [] + nvme_fp16_num_elems = [] + nvme_fp32_dest_tensors = [] + fp32_element_size = torch.tensor([], dtype=torch.float32).element_size() + + for i, tensor in enumerate(self.fp16_partitioned_groups_flat): + num_elements = self.fp16_partitioned_groups_flat_numel[i] + + # a partition of the fp32 master weights that will be updated by this process + if self._swappable_optimizer_subgroup(i): + self.fp32_partitioned_groups_flat.append(torch.Tensor()) + nvme_memory_usage += (fp32_element_size * num_elements) + num_swappable_partitions += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + num_swap_from_nvme_partitions += 1 + swap_from_nvme_memory_usage += (fp32_element_size * num_elements) + if self.offload_optimizer_fast_init: + sub_group_partitions = self._get_sub_group_partitions(i) + nvme_fp16_partitions_info.append(sub_group_partitions) + nvme_fp16_num_elems.append(num_elements) + nvme_fp32_dest_tensors.append( + self.fp32_partitioned_groups_flat[i]) + else: + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.optimizer_swapper.initialize_parameters( + parameters=[self.fp32_partitioned_groups_flat[i]], + src_tensors=[unpinned_fp32_buffer]) + else: + num_swap_from_cpu_partitions += 1 + swap_from_cpu_memory_usage += (fp32_element_size * num_elements) + swappable_fp32_tensors.append(self.fp32_partitioned_groups_flat[i]) + swappable_fp16_src_tensors.append( + self.fp16_partitioned_groups_flat[i]) + else: + cpu_memory_usage += (fp32_element_size * num_elements) + cpu_memory_sub_groups += 1 + + if self.params_in_nvme_and_cpu and tensor is None: + unpinned_fp32_buffer = torch.empty(num_elements, + device=self.device, + dtype=torch.float) + self._swap_in_sub_group_to_flat_buffer(unpinned_fp32_buffer, i) + self.fp32_partitioned_groups_flat.append(unpinned_fp32_buffer) + else: + self.fp32_partitioned_groups_flat.append( + self.fp16_partitioned_groups_flat[i].to( + self.device).clone().float().detach()) + + self.fp32_partitioned_groups_flat[ + i].requires_grad = True # keep this in case internal optimizer uses it + + if len(swappable_fp32_tensors) > 0: + self.optimizer_swapper.initialize_parameters( + parameters=swappable_fp32_tensors, + src_tensors=swappable_fp16_src_tensors) + + if len(nvme_fp32_dest_tensors) > 0: + fp16_pinned_buffers = self.fp16_groups[0][ + 0].nvme_swapper.reserve_available_buffers() + assert len(fp16_pinned_buffers) > 0 + self.optimizer_swapper.initialize_from_swapped_fp16_params( + fp16_partitions_info=nvme_fp16_partitions_info, + fp16_num_elems=nvme_fp16_num_elems, + fp16_pinned_buffers=fp16_pinned_buffers, + fp32_parameters=nvme_fp32_dest_tensors) + self.fp16_groups[0][0].nvme_swapper.release_reserved_buffers() + + nvme_gigabytes = nvme_memory_usage / GIGA_BYTES + print_rank_0( + f'Swappable FP32 Partitions: count={num_swappable_partitions} size={nvme_gigabytes:5.2f} GB', + force=False) + if self.params_in_nvme_and_cpu: + print_rank_0( + f'Swap from NVMe Partitions: count = {num_swap_from_nvme_partitions}, size = {swap_from_nvme_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + print_rank_0( + f'Swap from CPU Partitions: count = {num_swap_from_cpu_partitions}, size = {swap_from_cpu_memory_usage/GIGA_BYTES:5.2f}GB', + force=False) + + cpu_memory_gigabytes = cpu_memory_usage / GIGA_BYTES + print_rank_0( + f'In-Memory FP32 Partitions: count={cpu_memory_sub_groups} size={cpu_memory_gigabytes:5.2f} GB', + force=False) + + # Clear for on-the-fly population before the optimizer step + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _create_fp16_sub_groups(self, params_group): + + params_group_numel = sum([param.partitioned_size() for param in params_group]) + sub_group_size = self.sub_group_size + + if sub_group_size is None or sub_group_size >= params_group_numel: + return [params_group] + + sub_groups = [] + sub_group = [] + local_sub_group_size = 0 + for param in params_group: + + sub_group.append(param) + local_sub_group_size += param.partitioned_size() + + if local_sub_group_size >= sub_group_size or id(param) == id( + params_group[-1]): + + sub_groups.append(sub_group) + + sub_group = [] + local_sub_group_size = 0 + + return sub_groups + + # def reset_ds_tensor(self): + # for name, param in self.module.named_parameters(recurse=True): + # assert hasattr(param,'ds_id'), "Parameters have not been converted to be Zero 3 compatible" + # assert (param.ds_status == ZeroParamStatus.NOT_AVAILABLE), "All the parameters must have been partitioned by now" + # param.ds_tensor.data = param.data + + def setup_zero_stage3_hooks(self): + self.hierarchy = 0 + + #reset step if in inference mode + @instrument_w_nvtx + def _end_of_forward_hook(module, *args): + + if not torch._C.is_grad_enabled(): + self.param_coordinator.reset_step() + + #likely one of them should be enough but just to be safe + self._register_hooks_recursively(self.module) + self.module.register_forward_hook(_end_of_forward_hook) + + # Add top module to stack trace + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(self.module) + + def persistent_parameters(self): + persistent_params = [] + total_persistent_parameters = 0 + params_count = 0 + for _, param in self.module.named_parameters(recurse=True): + if param.ds_numel < self.persistence_threshold: + params_count += 1 + param.ds_persist = True + persistent_params.append(param) + total_persistent_parameters += param.ds_numel + + print_rank_0( + f"ZeRO 3: Total persistent parameters: {total_persistent_parameters} in {params_count} params", + force=False) + return persistent_params + + def _register_hooks_recursively(self, module, count=[0]): + my_count = count[0] + module.id = my_count + + #print(f"{module.__class__} : {module.id}") + + for child in module.children(): + count[0] = count[0] + 1 + self._register_hooks_recursively(child, count=count) + + @instrument_w_nvtx + def _pre_forward_module_hook(module, *args): + self.pre_sub_module_forward_function(module) + + @instrument_w_nvtx + def _post_forward_module_hook(module, input, output): + global FWD_MODULE_STACK + FWD_MODULE_STACK.pop() + if output is None: + output = [] + elif not isinstance(output, (list, tuple)): + if torch.is_tensor(output): + output = [output] + else: + #print(f'got UNKNOWN type {type(output)}') + outputs = [] + output = output if isinstance(output, dict) else vars(output) + for name, val in output.items(): + if not name.startswith('__') and torch.is_tensor(val): + outputs.append(val) + output = outputs + #print(f'convert output to {output}') + + for item in filter(lambda item: is_zero_param(item), output): + if not any(id(item) in m._external_params for m in FWD_MODULE_STACK): + item.is_external_param = True + module_to_register = FWD_MODULE_STACK[-1] + print_rank_0( + f'Registering dangling parameter for module {module_to_register.__class__.__name__}.', + force=False) + register_external_parameter(module_to_register, item) + + # It's possible that the parameter was already external to the completed module. If so, remove it the + # registration as it will be covered by the outer module instead. + if id(item) in module._external_params: + print_rank_0( + f' Unregistering nested dangling parameter from module {module.__class__.__name__}', + force=False) + unregister_external_parameter(module, item) + + item.all_gather() + + self.post_sub_module_forward_function(module) + + def _pre_backward_module_hook(module, inputs, output): + @instrument_w_nvtx + def _run_before_backward_function(sub_module): + # some models (e.g. Albert) may run multiple forwards on the same layer in a loop + # before doing backwards, so each backward will need a pre-fetch - using reference + # counting to support this scenario + #print(f"COUNTER before: {sub_module.applied_pre_backward_ref_cnt}") + if sub_module.applied_pre_backward_ref_cnt > 0: + self.pre_sub_module_backward_function(sub_module) + sub_module.applied_pre_backward_ref_cnt -= 1 + #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") + + return _apply_to_tensors_only(module, + PreBackwardFunction, + _run_before_backward_function, + output) + + #This is an alternate to doing _post_backward_module_hook + #it uses tensor.register_hook instead of using torch.autograd.Function + def _alternate_post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + #print(f"Before Forward {module.__class__.__name__}") + + def _run_after_backward_hook(*unused): + module.ds_grads_remaining = module.ds_grads_remaining - 1 + if module.ds_grads_remaining == 0: + #print(f"After backward {module.__class__.__name__}") + self.post_sub_module_backward_function(module) + + def _run_before_forward_function(input): + if input.requires_grad: + module.ds_grads_remaining += 1 + + return _apply_forward_and_backward_to_tensors_only( + module, + _run_before_forward_function, + _run_after_backward_hook, + inputs) + + def _post_backward_module_hook(module, inputs): + module.ds_grads_remaining = 0 + + @instrument_w_nvtx + def _run_after_backward_function(sub_module): + if sub_module.ds_grads_remaining == 0: + self.post_sub_module_backward_function(sub_module) + + return _apply_to_tensors_only(module, + PostBackwardFunction, + _run_after_backward_function, + inputs) + + # Pre forward hook + module.register_forward_pre_hook(_pre_forward_module_hook) + # Post forward hook + module.register_forward_hook(_post_forward_module_hook) + + # Pre backward hook + module.register_forward_hook(_pre_backward_module_hook) + + # post backward hook + module.register_forward_pre_hook(_post_backward_module_hook) + + @torch.no_grad() + def pre_sub_module_forward_function(self, sub_module): + see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", + force=False) + + global FWD_MODULE_STACK + FWD_MODULE_STACK.append(sub_module) + + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) + + self.param_coordinator.fetch_sub_module(sub_module) + see_memory_usage( + f"Before sub module function {sub_module.__class__.__name__} after fetch", + force=False) + + @torch.no_grad() + def post_sub_module_forward_function(self, sub_module): + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + + self.param_coordinator.release_sub_module(sub_module) + + see_memory_usage( + f"After sub module function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) + + @torch.no_grad() + def pre_sub_module_backward_function(self, sub_module): + if not self.param_coordinator.trace_complete: + self.param_coordinator.record_trace(sub_module) + self.param_coordinator.fetch_sub_module(sub_module) + + @torch.no_grad() + def post_sub_module_backward_function(self, sub_module): + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release", + force=False) + self.param_coordinator.release_sub_module(sub_module) + see_memory_usage( + f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release", + force=False) + + def _release_ipg_buffers(self): + if self.contiguous_gradients: + self.ipg_buffer = None + if not self.offload_optimizer and self.is_gradient_accumulation_boundary: + self.grads_in_partition = None + + self.grads_in_partition_offset = 0 + + def _optimizer_step(self, sub_group_id): + param_group_id = self.sub_group_to_group_id[sub_group_id] + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + self.optimizer.param_groups[param_group_id]['params'] = [fp32_param] + + self.optimizer.step() + self.optimizer.param_groups[param_group_id]['params'] = [] + + def _swappable_optimizer_subgroup(self, sub_group_id): + if not self.swap_optimizer: + return False + + return self.optimizer_swapper.swappable_tensor( + None, + numel=self.fp16_partitioned_groups_flat_numel[sub_group_id]) + + def _partitioned_params_swap_out(self, i): + offset = 0 + fp32_param = self.fp32_partitioned_groups_flat[i] + assert fp32_param is not None, \ + f'fp32 parameters of sub_group {i} is None' + + swap_fp16_params = [] + swap_fp32_params = [] + for param, partitioned_param in zip(self.fp16_groups[i], self.fp16_partitioned_groups[i]): + src = fp32_param.narrow(0, offset, partitioned_param.ds_numel) + if partitioned_param.status == PartitionedParamStatus.AVAILABLE: + partitioned_param.data.copy_(src.data) + else: + swap_fp32_params.append(src) + swap_fp16_params.append(param) + offset += partitioned_param.ds_numel + + if len(swap_fp16_params): + swap_fp16_params[0].nvme_swapper.swap_out_partitioned_params( + dst_fp16_params=swap_fp16_params, + src_fp32_params=swap_fp32_params) + + def initialize_optimizer_states(self): + num_subgroups = len(self.fp16_groups) + + largest_numel = max( + [sum([p.ds_numel for p in psg]) for psg in self.fp16_partitioned_groups]) + gradient_dtype = self.fp32_partitioned_groups_flat[0].dtype + gradient_buffer = torch.zeros(int(largest_numel), + dtype=gradient_dtype, + device=self.device) + + timers = self.timers + timer_names = set() + + if self.swap_optimizer: + self.optimizer_swapper.init_timers() + + INIT_OPTIMIZER_TIMER = 'init_optimizer_state' + timer_names.add(INIT_OPTIMIZER_TIMER) + self.start_timers([INIT_OPTIMIZER_TIMER]) + + for i, group in enumerate(self.fp16_groups): + swappable_optimizer_subgroup = self._swappable_optimizer_subgroup(i) + swappable_param_subgroup = self.fp16_partitioned_groups_flat[i] is None + + num_elements = int(self.fp16_partitioned_groups_flat_numel[i]) + + see_memory_usage( + f'[Begin] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_in(i, timer_names) + + if self.offload_optimizer and not swappable_optimizer_subgroup: + subgroup_gradient_buffer = torch.zeros(num_elements, + dtype=gradient_dtype, + device=self.device) + if self.offload_optimizer_pin_memory: + subgroup_gradient_buffer = subgroup_gradient_buffer.pin_memory() + + self.fp32_partitioned_groups_flat[i].grad = subgroup_gradient_buffer + else: + self.fp32_partitioned_groups_flat[i].grad = gradient_buffer.narrow( + 0, + 0, + num_elements) + + self._optimizer_step(i) + + if swappable_param_subgroup: + self._partitioned_params_swap_out(i) + + if swappable_optimizer_subgroup: + self._optimizer_states_and_gradient_swap_out(i, timer_names) + + see_memory_usage( + f'[End] Initialize optimizer states {i} / {num_subgroups} subgroups, num_elems: {num_elements}, swappable opt/param:{swappable_optimizer_subgroup}/{swappable_param_subgroup}', + force=False) + + self.stop_timers([INIT_OPTIMIZER_TIMER]) + self.log_timers(timer_names) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + if not self.offload_optimizer: + for group in self.fp32_partitioned_groups_flat: + group.grad = None + + # Reset steps + return + + ######################################################################### + #########################ZeRO Partition Gradients######################## + ######################################################################### + + def get_first_param_index(self, group_id, param_group, partition_id): + for index, param in enumerate(param_group): + param_id = self.get_param_id(param) + if partition_id in self.param_to_partition_ids[group_id][param_id]: + return index + return None + + def initialize_gradient_partitioning_data_structures(self): + + total_partitions = dist.get_world_size(group=self.dp_process_group) + + for i, param_group in enumerate(self.fp16_groups): + + self.param_to_partition_ids[i] = {} + self.is_partition_reduced[i] = {} + self.total_grads_in_partition[i] = {} + self.remaining_grads_in_partition[i] = {} + self.is_grad_computed[i] = {} + self.grad_partition_insertion_offset[i] = {} + self.grad_start_offset[i] = {} + self.first_param_index_in_partition[i] = {} + + for partition_id in range(total_partitions): + self.is_grad_computed[i][partition_id] = {} + self.grad_partition_insertion_offset[i][partition_id] = {} + self.grad_start_offset[i][partition_id] = {} + self.initialize_gradient_partition(i, param_group, partition_id) + self.is_partition_reduced[i][partition_id] = False + self.first_param_index_in_partition[i][ + partition_id] = self.get_first_param_index( + i, + param_group, + partition_id) + + @instrument_w_nvtx + def independent_gradient_partition_epilogue(self): + self.report_ipg_memory_usage(f"In ipg_epilogue before reduce_ipg_grads", 0) + self.__reduce_and_partition_ipg_grads() + self.report_ipg_memory_usage(f"In ipg_epilogue after reduce_ipg_grads", 0) + + self.__reduce_and_partition_stream.synchronize() + + # if dist.get_rank() == 0: + # logger.info("Params already reduced %s", self.params_already_reduced) + for i in range(len(self.params_already_reduced)): + self.params_already_reduced[i] = False + + #in case of cpu offload, averaged gradients are already in fp32_partitioned_groups_flat.grad + #TODO: use a similar code path for both cpu_offload and non-cpu offload + if not self.offload_optimizer: + for i, sub_group in enumerate(self.fp16_groups): + self.averaged_gradients[i] = [ + self.__param_id_to_grad_partition[param.ds_id] + if param.requires_grad else torch.zeros_like(param.ds_tensor) + for param in sub_group + ] + # self.averaged_gradients[i] = self.get_flat_partition( + # self.fp16_groups[i], + # 0, + # self.fp32_partitioned_groups_flat[i].numel(), + # return_tensor_list=True) + + # this method gets called after every backward. need to increment + # here because if it gets incremented in backward() the micro step + # id will be off by one when we do the reduce and partition at the. + # start of this method. + # TODO. make this less error prone + self.micro_step_id += 1 + + def overlapping_partition_gradients_reduce_epilogue(self): + self.independent_gradient_partition_epilogue() + + def create_reduce_and_remove_grad_hooks(self): + print_rank_0(f'[Begin] Create gradient reduction hooks') + self.grad_accs = [] + for i, param_group in enumerate(self.fp16_groups): + for param in param_group: + if param.requires_grad: + #print_rank_0(f" Before all gather {param.device}, {param.shape}") + + # The hook must be created in un-partitioned parameter + param.all_gather() + + #print(f"After all gather {param.device}, {param.shape}") + def wrapper(param, i): + param_tmp = param.expand_as(param) + grad_acc = param_tmp.grad_fn.next_functions[0][0] + + @instrument_w_nvtx + def reduce_partition_and_remove_grads(*notneeded): + self.reduce_ready_partitions_and_remove_grads(param, i) + + grad_acc.register_hook(reduce_partition_and_remove_grads) + self.grad_accs.append(grad_acc) + + #print(f"param grad fn {param.expand_as(param).grad_fn}") + wrapper(param, i) + + # Partition the parameter after creating the hook + param.partition() + print_rank_0(f'[End] Create gradient reduction hooks') + + def get_param_id(self, param): + unique_id = id(param) + return self.param_id[unique_id] + + def report_ipg_memory_usage(self, tag, param_elems): + elem_count = self.elements_in_ipg_bucket + param_elems + percent_of_bucket_size = (100.0 * elem_count) // self.reduce_bucket_size + see_memory_usage( + f"{tag}: elems in_bucket {self.elements_in_ipg_bucket} param {param_elems} max_percent {percent_of_bucket_size}", + force=False) + + ###############Idependent Partition Gradient ######################## + def reduce_independent_p_g_buckets_and_remove_grads(self, param, i): + #print_rank_0(f"Inside reduce ipg buckets. {debug_param2name_id_shape(param)}, ipg elements {self.elements_in_ipg_bucket}, reduce bucket size {self.reduce_bucket_size}", force=True) + + # Because the ipg bucket is initialized with a random place holder tensor, we must + # explicitly check that the bucket has any real data in it (self.elements_in_ipg_bucket > + # 0). Otherwise if the incoming param.ds_numel is large, this branch may get triggered on a + # garbage data and `self.average_tensor()` will crash because its params_to_reduce will be + # empty, while reduction_list will have that garbage data. + if self.elements_in_ipg_bucket > 0 and self.elements_in_ipg_bucket + param.ds_numel > self.reduce_bucket_size: + self.report_ipg_memory_usage("In ipg_remove_grads before reduce_ipg_grads", + param.ds_numel) + + self.__reduce_and_partition_ipg_grads() + + param_id = self.get_param_id(param) + assert self.params_already_reduced[param_id] == False, \ + f"The parameter {param_id} has already been reduced. \ + Gradient computed twice for this partition. \ + Multiple gradient reduction is currently not supported" + + self.__add_grad_to_ipg_bucket(param) + + @instrument_w_nvtx + @torch.no_grad() + def __add_grad_to_ipg_bucket(self, param: Parameter) -> None: + self.__reduce_and_partition_stream.wait_stream(torch.cuda.default_stream()) + + if self.contiguous_gradients and self.elements_in_ipg_bucket + param.grad.numel( + ) < self.reduce_bucket_size: + # move the gradient to a contiguous buffer + with torch.cuda.stream(self.__reduce_and_partition_stream): + # move the parameter's gradient to the contiguous flat buffer + new_grad_tensor = self.__ipg_bucket_flat_buffer.narrow( + 0, + self.elements_in_ipg_bucket, + param.grad.numel()).view_as(param.grad) + new_grad_tensor.copy_(param.grad, non_blocking=True) + param.grad.record_stream(torch.cuda.current_stream()) + param.grad.data = new_grad_tensor + + self.__params_in_ipg_bucket.append(param) + + @instrument_w_nvtx + @torch.no_grad() + def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: + if not self.__params_in_ipg_bucket: + return + + for param in self.__params_in_ipg_bucket: + if param.grad.numel() != param.ds_numel: + raise RuntimeError( + f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter " + f"gradients whose size is not same as the params") + + self.__params_in_ipg_bucket.sort(key=lambda p: p.ds_id) + + assert len(set(p.ds_id for p in self.__params_in_ipg_bucket)) == len( + self.__params_in_ipg_bucket) + + while self.__param_reduce_events and self.__param_reduce_events[0].query(): + self.__param_reduce_events.popleft() + if len(self.__param_reduce_events) > self.__max_param_reduce_events: + self.__param_reduce_events.popleft().synchronize() + + with torch.cuda.stream(self.__reduce_and_partition_stream): + if safe_mode: + assert_ints_same_as_other_ranks( + [p.ds_id for p in self.__params_in_ipg_bucket]) + + grad_partitions = self.__avg_scatter_grads(self.__params_in_ipg_bucket) + self.__partition_grads(self.__params_in_ipg_bucket, grad_partitions) + + self.__params_in_ipg_bucket.clear() + + event = Event() + event.record() + self.__param_reduce_events.append(event) + + @instrument_w_nvtx + def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor]: + """average gradients and scatter partitions across ranks""" + dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce) + + full_grads_for_rank = [p.grad for p in params_to_reduce] + if self.allreduce_always_fp32: + full_grads_for_rank = [g.float() for g in full_grads_for_rank] + + if self.postscale_gradients and self.gradient_predivide_factor != 1.0: + full_grads_for_rank = [ + g.div(self.gradient_predivide_factor) for g in full_grads_for_rank + ] + + grad_partitions_for_rank = reduce_scatter_coalesced(full_grads_for_rank, + self.dp_process_group) + + if self.postscale_gradients and self.gradient_predivide_factor != dist.get_world_size( + self.dp_process_group): + grad_partitions_for_rank = [ + g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank + ] + + if self.allreduce_always_fp32: + grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank] + + return grad_partitions_for_rank + + def set_grad_positions(self): + for i, group in enumerate(self.fp16_groups): + current_offset = 0 + for param in group: + param_id = self.get_param_id(param) + num_elements = param.ds_tensor.ds_numel + + self.grad_position[param_id] = [ + int(i), + int(current_offset), + int(num_elements) + ] + #print(f"param id {param_id} i:{i}, ds_tensor {num_elements} numel {param.numel()}") + current_offset += num_elements + + def _constant_buffered_norm2(self, input, buffer_size=250000000): + norm = None + for part in input.view(-1).split(buffer_size): + if norm is None: + norm = part.data.double().norm(2)**2.0 + else: + norm += part.data.double().norm(2)**2.0 + return norm**0.5 + + def set_norm_for_param_grad_in_gpu(self, param): + param_id = self.get_param_id(param) + #self.norm_for_param_grads[param_id] = param.grad.data.double().norm(2) + #Using a more memory efficient version + self.norm_for_param_grads[param_id] = self._constant_buffered_norm2(param.grad) + + def async_inplace_copy_grad_to_fp32_buffer_from_gpu(self, param, fp32_grad_tensor): + with torch.cuda.stream(self.copy_grad_stream): + param_id = self.get_param_id(param) + src_tensor = param.grad.view(-1).float() + #print(f"src_tensor {src_tensor.size()} and fp32 grad {fp32_grad_tensor.size()}") + fp32_grad_tensor.copy_(src_tensor, non_blocking=True) + param.grad = None + + def complete_grad_norm_calculation_for_cpu_offload(self, params): + total_norm = 0.0 + norm_type = 2.0 + for p in params: + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_id = self.get_param_id(p) + if param_id in self.norm_for_param_grads.keys(): + param_norm = self.norm_for_param_grads[param_id] + total_norm += param_norm.item()**2 + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + @instrument_w_nvtx + def __partition_grads(self, + params_to_release: List[Parameter], + grad_partitions: List[Tensor]) -> None: + for param, grad_partition in zip(params_to_release, grad_partitions): + if param.ds_tensor.ds_numel * dist.get_rank( + self.dp_process_group) > param.ds_numel: + # this grad partition is empty - don't need to do anything + continue + + # move or accumulate gradient partition to target buffer + grad_buffer = self.__param_id_to_grad_partition[param.ds_id].narrow( + 0, + 0, + grad_partition.numel()) + if self.micro_step_id == 0: # don't accumulate + grad_buffer.copy_(grad_partition, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = grad_buffer.to(grad_partition.device, non_blocking=True) + elif grad_buffer.is_cuda: + grad_buffer.add_(grad_partition) + else: + # if dst is CPU, copy first to src device, do the addition + # there, then move back to dst. adding directly to cpu is very slow + cuda_grad_buffer = grad_buffer.to(grad_partition.device, + non_blocking=True) + cuda_grad_buffer.add_(grad_partition) + grad_buffer.copy_(cuda_grad_buffer, non_blocking=True) + # ensure grad buffer is a CUDA buffer to speed up the next few + # operations and so it can be used asynchronously + grad_buffer = cuda_grad_buffer + + if hasattr(self.__inf_or_nan_tracker, "logical_or_"): + self.__inf_or_nan_tracker.logical_or_(torch.isinf(grad_buffer).any()) + self.__inf_or_nan_tracker.logical_or_(torch.isnan(grad_buffer).any()) + else: + # logical_or_ not available in older versions of pytorch + self.__inf_or_nan_tracker += torch.isinf(grad_buffer).any() + self.__inf_or_nan_tracker += torch.isnan(grad_buffer).any() + self.__inf_or_nan_tracker = self.__inf_or_nan_tracker > 0 + + # offload the gradient partition if applicable + if self.offload_optimizer: + i, dest_offset, _ = self.grad_position[self.get_param_id(param)] + offload_fp32_gradients = {} + offload_fp32_offsets = {} + + if self.is_gradient_accumulation_boundary: + self.norm_for_param_grads[self.get_param_id( + param)] = self._constant_buffered_norm2(grad_buffer) + + if self._swappable_optimizer_subgroup(i): + if not i in offload_fp32_gradients.keys(): + offload_fp32_gradients[i] = [] + offload_fp32_offsets[i] = [] + + offload_fp32_gradients[i].append(grad_buffer.float()) + offload_fp32_offsets[i].append(dest_offset) + else: + fp32_grad_tensor = self.fp32_partitioned_groups_flat[ + i].grad.narrow(0, + dest_offset, + grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer) + + # free the gradient + param.grad.record_stream(torch.cuda.current_stream()) + param.grad = None + + if self.offload_optimizer and self.swap_optimizer: + for i in offload_fp32_gradients.keys(): + self.optimizer_swapper.swap_out_gradients( + parameter=self.fp32_partitioned_groups_flat[i], + gradient_offsets=offload_fp32_offsets[i], + gradient_tensors=offload_fp32_gradients[i]) + + def reduce_ready_partitions_and_remove_grads(self, param, i): + #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) + self.reduce_independent_p_g_buckets_and_remove_grads(param, i) + + def zero_reduced_gradients(self, partition_id, i): + def are_all_related_partitions_reduced(params_id): + for partition_id in self.param_to_partition_ids[i][params_id]: + if not self.is_partition_reduced[i][partition_id]: + return False + return True + + for params_id in self.is_grad_computed[i][partition_id]: + if are_all_related_partitions_reduced(params_id): + self.param_dict[params_id].grad = None + + def flatten_and_print(self, message, tensors, start=0, n=5): + flatten_tensor = self.flatten(tensors) + + def print_func(): + logger.info(flatten_tensor.contiguous().view(-1).narrow(0, start, n)) + + self.sequential_execution(print_func, message) + + def get_grads_to_reduce(self, i, partition_id): + def get_reducible_portion(key): + grad = self.param_dict[key].grad + total_elements = grad.numel() + start = self.grad_start_offset[i][partition_id][key] + num_elements = min( + total_elements - start, + self.partition_size[i] - + self.grad_partition_insertion_offset[i][partition_id][key]) + if not pg_correctness_test: + if num_elements == total_elements: + return grad + else: + return grad.contiguous().view(-1).narrow(0, + int(start), + int(num_elements)) + else: + if num_elements == total_elements: + return grad.clone() + else: + return grad.clone().contiguous().view(-1).narrow( + 0, + int(start), + int(num_elements)) + + grads_to_reduce = [] + for key in self.is_grad_computed[i][partition_id]: + grad = get_reducible_portion(key) + grads_to_reduce.append(grad) + return grads_to_reduce + + def sequential_execution(self, function, message, group=None): + if group is None: + group = self.dp_process_group + if dist.get_rank(group=group) == 0: + logger.info(message) + for id in range(dist.get_world_size(group=group)): + if id == dist.get_rank(group=group): + function() + dist.barrier(group=group) + + def set_none_gradients_to_zero(self, i, partition_id): + for param_id in self.is_grad_computed[i][partition_id]: + param = self.param_dict[param_id] + if param.grad is None: + param.grad = torch.zero_like(param) + + ######################Reduction Related Methods############################## + + def allreduce_bucket(self, bucket, allreduce_always_fp32=False, rank=None, log=None): + rank = None + tensor = self.flatten(bucket) + + tensor_to_allreduce = tensor + + if pg_correctness_test: + allreduce_always_fp32 = True + + if allreduce_always_fp32: + tensor_to_allreduce = tensor.float() + + tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) + + if rank is None: + # "All Reducing" + dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) + else: + global_rank = _get_global_rank(self.dp_process_group, rank) + dist.reduce(tensor_to_allreduce, global_rank, group=self.dp_process_group) + + if allreduce_always_fp32 and tensor is not tensor_to_allreduce: + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + tensor.copy_(tensor_to_allreduce) + + return tensor + + # if rank is specified do a reduction instead of an allreduce + def allreduce_and_copy(self, small_bucket, rank=None, log=None): + with torch.cuda.stream(self.reduction_stream): + allreduced = self.allreduce_bucket(small_bucket, rank=rank, log=log) + if rank is None or rank == dist.get_rank(group=self.dp_process_group): + for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): + buf.copy_(synced) + + def allreduce_no_retain(self, + bucket, + numel_per_bucket=500000000, + rank=None, + log=None): + small_bucket = [] + numel = 0 + for tensor in bucket: + small_bucket.append(tensor) + numel = numel + tensor.numel() + if numel > numel_per_bucket: + self.allreduce_and_copy(small_bucket, rank=rank, log=None) + small_bucket = [] + if len(small_bucket) > 0: + self.allreduce_and_copy(small_bucket, rank=rank, log=log) + + ############################################################################# + ############################################################################# + ############################################################################# + + # views the tensor as multiple partitions and returns + # those partitions + def get_data_parallel_partitions(self, tensor): + partitions = [] + + dp = dist.get_world_size(group=self.dp_process_group) + dp_id = dist.get_rank(group=self.dp_process_group) + + total_num_elements = tensor.numel() + + base_size = total_num_elements // dp + remaining = total_num_elements % dp + + start = 0 + for id in range(dp): + partition_size = base_size + if id < remaining: + partition_size = partition_size + 1 + partitions.append(tensor.narrow(0, start, partition_size)) + start = start + partition_size + return partitions + + def get_partition_info(self, tensor_list, partition_size, partition_id): + params_in_partition = [] + params_not_in_partition = [] + + start_index = partition_size * partition_id + end_index = partition_size * (partition_id + 1) + + current_index = 0 + first_offset = 0 + + for tensor in tensor_list: + + tensor_size = tensor.numel() + + if (current_index >= start_index and current_index < end_index): + params_in_partition.append(tensor) + + elif start_index > current_index and start_index < (current_index + + tensor_size): + params_in_partition.append(tensor) + + assert (first_offset == 0), "This can happen either zero or only once as this must be the first tensor in the partition" + first_offset = start_index - current_index + + else: + params_not_in_partition.append(tensor) + + current_index = current_index + tensor_size + + return params_in_partition, params_not_in_partition, first_offset + + @instrument_w_nvtx + def zero_grad(self, set_grads_to_None=True): + """ + Zero FP16 parameter grads. + """ + self.micro_step_id = 0 + + # FP32 grad should never exist. + # For speed, set model fp16 grad to None by default + for group in self.fp16_groups: + for p in group: + if set_grads_to_None: + if p.grad is not None and p.grad.is_cuda: + p.grad.record_stream(torch.cuda.current_stream()) + p.grad = None + else: + if p.grad is not None: + p.grad.detach_() + p.grad.zero_() + + def _model_parallel_all_reduce(self, tensor, op): + """ Perform all reduce within model parallel group, if any. + """ + if self.model_parallel_group is None: + pass + else: + torch.distributed.all_reduce(tensor=tensor, + op=op, + group=self.model_parallel_group) + + @instrument_w_nvtx + def get_grad_norm_direct(self, gradients, params, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(g.data.abs().max() for g in gradients) + total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group) + + # Take max across all GPUs. + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.MAX) + total_norm = total_norm_cuda[0].item() + else: + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + grad_norms = [] + for g, p in zip(gradients, params): + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + grad_norms.append(g.cuda(non_blocking=True).double().norm(2)) + + # Sum across all model parallel GPUs. + total_norm_cuda = torch.sum(torch.pow(torch.stack(grad_norms), 2)) + + torch.distributed.all_reduce(total_norm_cuda, + op=torch.distributed.ReduceOp.SUM, + group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, + op=torch.distributed.ReduceOp.SUM) + + total_norm = total_norm_cuda.item()**(1. / norm_type) + + if total_norm == float( + 'inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + # creates a flat fused tensor from the tensor list starting at the first_offset + # in the first tensor of the list. If there are not enough elements in the tensor + # list then the flat tensor will be padded with zeros + def get_flat_partition(self, + tensor_list, + first_offset, + partition_size, + return_tensor_list=False): + flat_tensor_list = [] + current_size = 0 + for i, tensor in enumerate(tensor_list): + if tensor.grad is None: + tensor.grad = torch.zeros_like(tensor) + + tensor = tensor.grad + num_elements = tensor.numel() + tensor_offset = 0 + + # we need to offset to get to the right element + if i == 0 and first_offset > 0: + tensor_offset = first_offset + num_elements = num_elements - tensor_offset + + # we dont need all elements of the tensor + if num_elements > (partition_size - current_size): + num_elements = partition_size - current_size + + # we need a narrow view of the tensor based on the tensor offset and number of elements that + # we need from this tensor + if tensor_offset > 0 or num_elements < tensor.numel(): + flat_tensor_list.append(tensor.contiguous().view(-1).narrow( + 0, + int(tensor_offset), + int(num_elements))) + else: + flat_tensor_list.append(tensor) + + current_size = current_size + num_elements + + # this means its the last partition and does not align with the dp boundary. We need to pad before flattening + if current_size < partition_size: + flat_tensor_list.append( + torch.zeros(int(partition_size - current_size), + dtype=tensor_list[0].dtype, + device=tensor_list[0].device)) + + if return_tensor_list: + return flat_tensor_list + + return self.flatten(flat_tensor_list) + + def free_grad_in_param_list(self, param_list): + for p in param_list: + p.grad = None + + def reset_cpu_buffers(self): + self.norm_for_param_grads = {} + self.local_overflow = False + + def log_timers(self, timer_names): + if self.timers is None: + return + + self.timers.log(names=list(timer_names)) + + def start_timers(self, timer_names): + if self.timers is None: + return + + for name in timer_names: + self.timers(name).start() + + def stop_timers(self, timer_names): + if self.timers is None: + return + + for name in timer_names: + self.timers(name).stop() + + def _pre_step(self): + self.micro_step_id = 0 + + print_rank_0(f"Inside Step function") + see_memory_usage(f"In step before checking overflow", force=False) + + print_rank_0("Finished Tracing at Beginning of Step") + self.param_coordinator.hierarchy = 0 + + print_rank_0("Finished Tracing at Beginning of Step") + + @instrument_w_nvtx + def _get_norm_groups(self): + norm_groups = [] + for i, group in enumerate(self.fp16_groups): + if self.offload_optimizer: + norm_groups.append( + self.complete_grad_norm_calculation_for_cpu_offload( + self.fp16_groups[i])) + else: + norm_groups.append( + self.get_grad_norm_direct(self.averaged_gradients[i], + self.fp16_groups[i])) + return norm_groups + + @instrument_w_nvtx + def _prepare_fp32_grad_for_sub_group(self, sub_group_id): + partition_id = dist.get_rank(group=self.dp_process_group) + + single_grad_partition = self.flatten(self.averaged_gradients[sub_group_id]).to( + self.fp32_partitioned_groups_flat[sub_group_id].dtype) + + assert single_grad_partition.numel() == self.fp32_partitioned_groups_flat[sub_group_id].numel(), \ + "averaged gradients have different number of elements that partition size {} {} {} {}".format( + single_grad_partition.numel(), self.fp32_partitioned_groups_flat[sub_group_id].numel(), sub_group_id, partition_id) + + self.fp32_partitioned_groups_flat[sub_group_id].grad = single_grad_partition + + # release all the gradient since we have already created a necessary copy in dp_grad_partition + self.zero_grad() + + for grad in filter(lambda g: g.is_cuda, self.averaged_gradients[sub_group_id]): + grad.record_stream(torch.cuda.current_stream()) + + self.averaged_gradients[sub_group_id] = None + + @instrument_w_nvtx + def _prepare_sub_group(self, sub_group_id, timer_names=set()): + see_memory_usage(f'Before prepare optimizer sub group {sub_group_id}', + force=False) + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_in(sub_group_id, timer_names) + elif not self.offload_optimizer: + self._prepare_fp32_grad_for_sub_group(sub_group_id) + see_memory_usage(f'After prepare optimizer sub group {sub_group_id}', + force=False) + + def _optimizer_states_and_gradient_swap_in(self, sub_group_id, timer_names=set()): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + OPTIMIZER_SWAP_IN_STATE = 'optimizer_swap_in_state' + see_memory_usage(f'pre-step Before swapping in optimizer tensors {sub_group_id}', + force=False) + self.start_timers([OPTIMIZER_SWAP_IN_STATE]) + + self.optimizer_swapper.swap_in_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_parameter=self.next_swappable_fp32_partitioned_groups[sub_group_id]) + + self.stop_timers([OPTIMIZER_SWAP_IN_STATE]) + timer_names.add(OPTIMIZER_SWAP_IN_STATE) + see_memory_usage(f'pre-step After swapping in optimizer tensors {sub_group_id}', + force=False) + + @instrument_w_nvtx + def _release_sub_group(self, sub_group_id, timer_names=set()): + see_memory_usage(f'Before release optimizer sub group {sub_group_id}', + force=False) + # get rid of the fp32 gradients. Not needed anymore + if not self.offload_optimizer: + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + if self._swappable_optimizer_subgroup(sub_group_id): + self._optimizer_states_and_gradient_swap_out(sub_group_id, timer_names) + see_memory_usage(f'After release optimizer sub group {sub_group_id}', + force=False) + + # create a flat tensor aligned at the alignment boundary + @instrument_w_nvtx + def flatten_dense_tensors_aligned(self, tensor_list, alignment): + num_elements = 0 + for tens in tensor_list: + num_elements = num_elements + tens.numel() + + remaining = num_elements % alignment + + if remaining: + elements_to_add = alignment - remaining + pad_tensor = torch.zeros(elements_to_add, + device=tensor_list[0].device, + dtype=tensor_list[0].dtype) + padded_tensor_list = tensor_list + [pad_tensor] + + num_elements = num_elements + elements_to_add + else: + padded_tensor_list = tensor_list + + return self.flatten(padded_tensor_list) + + def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set()): + param_length = self.fp16_partitioned_groups_flat_numel[sub_group_id] + fp32_param_id = id(self.fp32_partitioned_groups_flat[sub_group_id]) + assert self._swappable_optimizer_subgroup(sub_group_id), \ + f'Parameter {fp32_param_id} of numel={param_length} is not swappable' + + OPTIMIZER_SWAP_OUT_STATE = 'optimizer_swap_out_state' + see_memory_usage( + f'post-step Before swapping out optimizer tensors {sub_group_id}', + force=False) + self.start_timers([OPTIMIZER_SWAP_OUT_STATE]) + + self.optimizer_swapper.swap_out_optimizer_state( + parameter=self.fp32_partitioned_groups_flat[sub_group_id], + async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is + not None) + + self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) + see_memory_usage( + f'post-step After swapping out optimizer tensors {sub_group_id}', + force=False) + timer_names.add(OPTIMIZER_SWAP_OUT_STATE) + + # get rid of the fp32 gradients. Not needed anymore + self.fp32_partitioned_groups_flat[sub_group_id].grad = None + + def _unflatten_partitioned_parameters(self, sub_group_id): + updated_params = self.unflatten(self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + def _overflow_clean_up(self, prev_scale): + see_memory_usage('After overflow before clearing gradients', force=False) + self.zero_grad() + + if self.offload_optimizer: + self.reset_cpu_buffers() + else: + self.averaged_gradients = {} + + see_memory_usage('After overflow after clearing gradients', force=False) + + if torch.distributed.get_rank() == 0: + logger.info( + "[deepscale] OVERFLOW! Rank {} Skipping step. Attempted loss scale: {}, " + "reducing to {}".format(dist.get_rank(), + prev_scale, + self.loss_scale)) + + @instrument_w_nvtx + def _overflow_check_and_loss_scale_update(self): + + # First compute norm for all group so we know if there is overflow + self.check_overflow() + + #loss scaling related computation + prev_scale = self.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + self._overflow_clean_up(prev_scale) + + return self.overflow + + @instrument_w_nvtx + def _post_step(self, timer_names=set()): + if self.offload_optimizer: + self.reset_cpu_buffers() + + #Gathering persisting parameters + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + + self.log_timers(timer_names) + + see_memory_usage('After zero_optimizer step', force=False) + print_rank_0(f"------------------Finishing Step-----------------------") + + @instrument_w_nvtx + def _reassign_or_swap_out_partitioned_parameters(self, sub_group_id): + if self.fp16_partitioned_groups_flat[sub_group_id] is not None: + self.fp16_partitioned_groups_flat[sub_group_id].data.copy_( + self.fp32_partitioned_groups_flat[sub_group_id].data) + + #unflatten fp16 parameter subgroup + self._unflatten_partitioned_parameters(sub_group_id) + else: + self._partitioned_params_swap_out(sub_group_id) + + @instrument_w_nvtx + def step(self, closure=None): + """ + Not supporting closure. + """ + self._pre_step() + self._partition_all_parameters() + + #checks for overflow, adjust the loss scale accordingly + if self._overflow_check_and_loss_scale_update(): + if self.swap_optimizer: + self.optimizer_swapper.log_timers() + return + + norm_groups = self._get_norm_groups() + self._global_grad_norm = get_global_norm(norm_list=norm_groups) + + timer_names = set() + + timer_names.add('optimizer_step') + self.start_timers(['optimizer_step']) + + #update parameters one sub group at a time + for sub_group_id, group in enumerate(self.fp16_groups): + + #prepare optimizer states, gradients and fp32 parameters for update + self._prepare_sub_group(sub_group_id, timer_names) + + #scale the fp32 gradients + self.unscale_and_clip_grads(sub_group_id, self._global_grad_norm) + + #apply the optimizer step on the sub group and copy fp32 parameters to fp16 + self._optimizer_step(sub_group_id) + + #put fp16 parameters in appropriate location + self._reassign_or_swap_out_partitioned_parameters(sub_group_id) + + #release memory or swap out optimizer states of fp32 parameters + self._release_sub_group(sub_group_id, timer_names) + + self.stop_timers(['optimizer_step']) + + self._post_step(timer_names) + + # warn user about caching allocator flushes + alloc_retries = torch.cuda.memory_stats()["num_alloc_retries"] if hasattr( + torch.cuda, + "memory_stats") else 0 + if alloc_retries > self.__n_caching_allocator_flushes: + if dist.get_rank() == 0: + logger.warning( + "%d pytorch allocator cache flushes since last step. this happens " + "when there is high memory pressure and is detrimental to " + "performance. if this is happening frequently consider adjusting " + "settings to reduce memory consumption. If you are unable to " + "make the cache flushes go away consider adding " + "torch.cuda.empty_cache() calls in your training loop to ensure " + "that all ranks flush their caches at the same time", + alloc_retries - self.__n_caching_allocator_flushes) + self.__n_caching_allocator_flushes = alloc_retries + + def dump_pre_step_gradients(self, debug_fp32_grads): + # Dump gradient norms for debugging + for i, _ in enumerate(self.fp16_groups): + print(f'Pre-Step Dump Norms for Group {i} FP16P, FP16G, FP32G, FP32GUC') + for fp16_param, fp32_grad in zip(self.fp16_groups[i], debug_fp32_grads[i]): + param_id = self.get_param_id(fp16_param) + fp16_grad_norm = self.debug_fp16_grads[i][param_id] + + fp32_grad_norm = [float(t.data.float().norm(2)) for t in fp32_grad] + norm_list = [fp16_grad_norm, fp32_grad_norm] + print(f'Pre-Step Norms {i} {param_id} = {norm_list}') + + def dump_post_step_gradients(self): + # Dump gradient norms for debugging + for i, group in enumerate(self.fp16_groups): + print( + f'Post-Step Dump Norms for Group {i} FP16P, FP16DS, FP16FLAT, FP32FLAT') + unflat_fp16 = self.unflatten(self.fp16_groups_flat[i], self.fp16_groups[i]) + unflat_fp32 = self.unflatten(self.fp32_partitioned_groups_flat[i], + self.fp16_groups[i]) + for j, p in enumerate(self.fp16_groups[i]): + param_id = self.get_param_id(p) + param_norm = float(p.data.float().norm(2)) + ds_norm = float(p.ds_tensor.data.float().norm(2)) + + unflat_norm = [ + float(t.data.float().norm(2)) + for t in [unflat_fp16[j], + unflat_fp32[j]] + ] + norm_list = [param_norm, ds_norm] + unflat_norm + print(f'Post-Step Norms {i} {param_id} = {norm_list}') + + @instrument_w_nvtx + def unscale_and_clip_grads(self, sub_group_id, total_norm): + grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] + + # compute combined scale factor for this group + combined_scale = self.loss_scale + if self.clip_grad > 0.: + # norm is in fact norm*scale + clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad + if clip > 1: + combined_scale = clip * self.loss_scale + # to maintain behavior of averaging over accumulation steps + combined_scale *= self.micro_step_id + 1 + + for grad in grad_groups_flat: + if isinstance(grad, list): + sub_partitions = grad + for g in sub_partitions: + g.data.mul_(1. / combined_scale) + else: + grad.data.mul_(1. / combined_scale) + + def _check_overflow(self, partition_gradients=True): + self.overflow = self.has_overflow(partition_gradients) + + # `params` is a list / generator of torch.Variable + def has_overflow_serial(self, params, is_grad_list=False): + for p in params: + if p.grad is not None and self._has_inf_or_nan(p.grad.data): + return True + + return False + + def has_overflow_partitioned_grads_serial(self): + for i in range(len(self.fp16_groups)): + for j, grad in enumerate(self.averaged_gradients[i]): + if grad is not None and self._has_inf_or_nan(grad.data, j): + return True + return False + + @instrument_w_nvtx + def has_overflow(self, partition_gradients=True): + if partition_gradients: + with torch.cuda.stream(self.__reduce_and_partition_stream): + self.local_overflow = bool(self.__inf_or_nan_tracker.item()) + self.__inf_or_nan_tracker.zero_() + + overflow = self.local_overflow + #overflow = self.has_overflow_partitioned_grads_serial() + overflow_gpu = torch.cuda.ByteTensor([overflow]) + torch.distributed.all_reduce(overflow_gpu, + op=torch.distributed.ReduceOp.MAX, + group=self.dp_process_group) + + else: + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + + overflow = self.has_overflow_serial(params, is_grad_list=partition_gradients) + overflow_gpu = torch.cuda.ByteTensor([overflow]) + + # Since each model parallel GPU carries only part of the model, + # make sure overflow flag is synced across all the model parallel GPUs + self._model_parallel_all_reduce(tensor=overflow_gpu, + op=torch.distributed.ReduceOp.MAX) + + overflow = overflow_gpu[0].item() + return bool(overflow) + + # `x` is a torch.Tensor + @staticmethod + def _has_inf_or_nan(x, j=None): + try: + # if x is half, the .float() incurs an additional deep copy, but it's necessary if + # Pytorch's .sum() creates a one-element tensor of the same type as x + # (which is true for some recent version of pytorch). + cpu_sum = float(x.float().sum()) + # More efficient version that can be used if .sum() returns a Python scalar + # cpu_sum = float(x.sum()) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + @instrument_w_nvtx + def backward(self, loss, retain_graph=False): + """ + :attr:`backward` performs the following steps: + + 1. fp32_loss = loss.float() + 2. scaled_loss = fp32_loss*loss_scale + 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves + """ + if self.swap_optimizer: + self.optimizer_swapper.pre_backward() + + see_memory_usage(f"Before backward", force=False) + + self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) + + self.param_coordinator.reset_step() + + if self.swap_optimizer: + self.optimizer_swapper.post_backward() + + def get_fp32_grad_partitions(self) -> Dict[int, Dict[int, Tensor]]: + """get fp32 gradient partition dictionary + accessed as grad_dict[parameter_group_index][parameter_index] + """ + self.__reduce_and_partition_stream.synchronize() + grad_dict = collections.defaultdict(dict) + if self.offload_optimizer: + for group in self.fp16_groups: + for param_idx, param in enumerate(group): + group_idx, dest_offset, num_elements = self.grad_position[self.get_param_id(param)] + fp32_grad = self.fp32_partitioned_groups_flat[group_idx].grad.narrow( + 0, + dest_offset, + num_elements) + grad_dict[group_idx][param_idx] = fp32_grad + else: + for group_idx, group in self.averaged_gradients.items(): + for param_idx, gradient in enumerate(group): + grad_dict[group_idx][param_idx] = gradient.float() + + return grad_dict + + @instrument_w_nvtx + def _partition_all_parameters(self): + """Partitioning Parameters that were not partitioned usually if parameters + of modules whose input parameters do not require grad computation do not + trigger post call and will therefore will remain unpartitioned""" + self.param_coordinator.release_and_reset_all() + for param in iter_params(self.module, recurse=True): + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"{param.ds_summary()} expected to be released") + + def check_overflow(self, partition_gradients=True): + self._check_overflow(partition_gradients) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" + def _get_state(self): + return self.optimizer.state + + def _set_state(self, value): + self.optimizer.state = value + + state = property(_get_state, _set_state) + + # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" + # (for example, to adjust the learning rate) + def _get_param_groups(self): + return self.optimizer.param_groups + + def _set_param_groups(self, value): + self.optimizer.param_groups = value + + param_groups = property(_get_param_groups, _set_param_groups) + + # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) + cur_scale = property(_get_loss_scale, _set_loss_scale) + + def _get_lean_tensors(self, padded_flattened_tensor, group_tensors, paddings): + # Remove paddings from flattened tensor + individual_tensors = self.unflatten(padded_flattened_tensor, group_tensors) + lean_lengths = [t.numel() - pad for t, pad in zip(group_tensors, paddings)] + lean_tensors = [t[:len] for t, len in zip(individual_tensors, lean_lengths)] + #logger.info(f'rank {dist.get_rank()}: lean_tensors = {[t.numel() for t in lean_tensors]}') + return lean_tensors + + #TODO REVISIT this for stage 3 + def get_lean_optimizer_state(self): + # Return optimizer states after removing paddings. + # This method assumes that each param group contains a single flattened tensor. + optimizer_groups_state = [] + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + lean_state = {} + for key, value in self.optimizer.state[p].items(): + if torch.is_tensor(value): + padded_lens = [t.numel() for t in self.fp16_partitioned_groups[i]] + lean_state[key] = self._get_lean_tensors( + value, + self.fp16_partitioned_groups[i], + self.groups_padding[i]) + lean_flat_len = sum([t.numel() for t in lean_state[key]]) + else: + lean_state[key] = value + + optimizer_groups_state.append(lean_state) + + return optimizer_groups_state + + def get_groups_without_padding(self, groups_with_padding): + # Return group tensor after removing paddings added for alignment to DP world size. + groups_without_padding = [] + for i, group in enumerate(groups_with_padding): + lean_group = self._get_lean_tensors(group, + self.fp16_partitioned_groups[i], + self.groups_padding[i]) + groups_without_padding.append(lean_group) + + return groups_without_padding + + def _set_fp32_optimizer_param_groups(self): + for sub_group_id, _ in enumerate(self.fp16_groups): + param_group_id = self.sub_group_to_group_id[sub_group_id] + self.optimizer.param_groups[param_group_id]['params'].append( + self.fp32_partitioned_groups_flat[sub_group_id]) + + def _clear_fp32_optimizer_param_groups(self): + for param_group in self.optimizer.param_groups: + param_group['params'] = [] + + def _rigid_state_dict(self): + state_dict = {} + state_dict['zero_stage'] = ZERO_OPTIMIZATION_WEIGHTS + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict['partition_count'] = self.partition_count + + self._set_fp32_optimizer_param_groups() + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_flat_groups'] = self.fp32_partitioned_groups_flat + self._clear_fp32_optimizer_param_groups() + + return state_dict + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + if self.elastic_checkpoint: + raise NotImplementedError( + "ZeRO-3 does not yet support elastic checkpointing, please disable for now." + ) + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + raise NotImplementedError( + "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." + ) + + return self._rigid_state_dict() + + +# Restore base optimizer fp32 weights from checkpoint by: +# 1) Merging fp32 weights from checkpoints of all partitions +# 2) Extracting fp32 weights for current partition from merged weights +# 3) Using extracted weights to update base optimizer weights directly. + + def _restore_from_fp32_weights(self, all_state_dict): + + flat_local_partition = [] + for i in range(len(self.fp32_partitioned_groups_flat)): + merged_partitions = [sd['fp32_groups'][i] for sd in all_state_dict] + flat_local_partition.append(self._get_flattened_partition(merged_partitions)) + + for current, saved in zip(self.fp32_partitioned_groups_flat, flat_local_partition): + current.data.copy_(saved.data) + + # Restore base optimizer fp32 weights from ZeRO fp16 weights + def _restore_from_fp16_weights(self): + for fp16_partitions, fp32_partition in zip(self.fp16_partitioned_groups_flat, self.fp32_partitioned_groups_flat): + fp32_partition.data.copy_(fp16_partitions.data) + + # Refresh the fp32 master params from the fp16 copies. + def refresh_fp32_params(self): + self._restore_from_fp16_weights() + + # Extract flattened partition for current rank from all partitions + def _get_flattened_partition(self, all_partition_states): + partition_id = dist.get_rank(group=self.dp_process_group) + alignment = dist.get_world_size(group=self.dp_process_group) + + param_partitions = [[] for _ in range(len(all_partition_states[0]))] + for i, partition in enumerate(all_partition_states): + for j, param in enumerate(partition): + param_partitions[j].append(param) + + local_state_partitions = [] + for param_index, param_slices in enumerate(param_partitions): + flattened_merged_tensor = self.flatten_dense_tensors_aligned( + param_slices, + alignment) + new_partitions = self.get_data_parallel_partitions(flattened_merged_tensor) + local_state_partitions.append(new_partitions[partition_id]) + + if torch.is_tensor(local_state_partitions[0]): + return self.flatten_dense_tensors_aligned(local_state_partitions, alignment) + + # Assume non-tensor states are not partitioned and equal across ranks, so return first one + return local_state_partitions[0] + + # Restore base optimizer state from checkpoint by + # 1) Merging optimizer state from checkpoints of all partitions + # 2) Extracting optimizer state for current partition from the merged state + # 3) Using the extracted value to directly update the base optimizer. + def _restore_base_optimizer_state(self, all_state_dict): + base_optimizer_group_states = [] + for i in range(len(self.optimizer.param_groups)): + partition_states = {} + all_partition_group_states = [ + sd['base_optimizer_state'][i] for sd in all_state_dict + ] + for key in all_partition_group_states[0].keys(): + all_partition_states = [ + all_states[key] for all_states in all_partition_group_states + ] + partition_states[key] = self._get_flattened_partition( + all_partition_states) + base_optimizer_group_states.append(partition_states) + + for i, group in enumerate(self.optimizer.param_groups): + p = group['params'][0] + for key, saved in base_optimizer_group_states[i].items(): + if torch.is_tensor(self.optimizer.state[p][key]): + self.optimizer.state[p][key].data.copy_(saved.data) + else: + self.optimizer.state[p][key] = saved + + def _rigid_load_state_dict(self, state_dict, load_optimizer_states=True): + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + + if load_optimizer_states: + self._set_fp32_optimizer_param_groups() + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + self._clear_fp32_optimizer_param_groups() + + # restore fp32 partitions + for curr_param, saved_param in zip(self.fp32_partitioned_groups_flat, state_dict['fp32_flat_groups']): + curr_param.data.copy_(saved_param.data) + + # restore fp16 partitions from fp32 + for sub_group_id in range(len(self.fp32_partitioned_groups_flat)): + fp32_param = self.fp32_partitioned_groups_flat[sub_group_id] + fp16_param = self.fp16_partitioned_groups_flat[sub_group_id] + fp16_param.data.copy_(fp32_param.data) + + # update fp16 unflattened params + for sub_group_id in range(len(self.fp16_partitioned_groups_flat)): + updated_params = self.unflatten( + self.fp16_partitioned_groups_flat[sub_group_id], + self.fp16_partitioned_groups[sub_group_id]) + + for partitioned_param, q in zip(self.fp16_partitioned_groups[sub_group_id], updated_params): + partitioned_param.data = q.data + + # TODO: Support different/changing load/save DP degree. + def load_state_dict(self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False): + r"""Loading a ZeRO checkpoint + Arguments: + state_dict_list: List of all saved ZeRO checkpoints, one for each saved partition. + Note that the number of saved partitions may differ from number of loading partitions to support + changing GPU count, specifically DP world size, between saving and loading checkpoints. + load_optimizer_states: Boolean indicating whether or not to load base optimizer states + load_from_fp32_weights: Boolean indicating whether to initialize fp32 master weights from fp32 + copies in checkpoints (no precision loss) or from model's fp16 copies (with precision loss). + """ + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + Example:: + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + + if self.elastic_checkpoint: + raise NotImplementedError( + "ZeRO-3 does not yet support elastic checkpointing, please disable for now." + ) + + if self.swap_optimizer or self.params_in_nvme_and_cpu: + raise NotImplementedError( + "ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now." + ) + + self._rigid_load_state_dict( + state_dict_list[dist.get_rank(group=self.dp_process_group)], + load_optimizer_states=load_optimizer_states) + + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].partition(self.persistent_parameters) + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + def save_checkpoint_prologue(self): + self._partition_all_parameters() + + def save_checkpoint_epilogue(self): + if len(self.persistent_parameters) > 0: + self.persistent_parameters[0].all_gather(self.persistent_parameters) + + +def _handle_overflow(cpu_sum, x, i): + import math + rank = torch.distributed.get_rank() + if rank == 0: + t_i = -1 + for v_i, v in enumerate(x.data.contiguous().view(-1)): + if not math.isfinite(float(v)): + t_i = v_i + break + logger.info( + f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}" + ) + + +def estimate_zero3_model_states_mem_needs(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + cpu_offload=True, + cpu_offload_params=True, + zero_init=True, + additional_buffer_factor=1.5): + + total_gpus = num_nodes * num_gpus_per_node + gpus_factor = 1 / num_nodes + largest_layer_memory = (4 * largest_layer_params) + + if cpu_offload: + if cpu_offload_params: + gpu_mem = largest_layer_memory + + if zero_init: + cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, + 18 * gpus_factor) * additional_buffer_factor + + else: + gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus) + + if zero_init: + cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor + else: + cpu_mem = total_params * max(4 * num_gpus_per_node, + 16 * gpus_factor) * additional_buffer_factor + else: + gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus) + if zero_init: + cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor + else: + cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor + + return int(cpu_mem), int(gpu_mem), largest_layer_memory + + +def model_to_params(model): + # shared params calculated only once + total_params = sum( + dict((p.data_ptr(), + p.numel()) for p in model.parameters()).values()) + + largest_layer_params = 0 + for m in model.modules(): + # assuming no shared params within a single layer + layer_params = sum(p.numel() for p in m.parameters(recurse=False)) + largest_layer_params = max(largest_layer_params, layer_params) + + return total_params, largest_layer_params + + +import math + + +def estimate_zero3_model_states_mem_needs_all_live(model, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If you have an actual model object, use this function and everything will be derived + automatically. + + If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + Args: + - ``model``: ``nn.Module`` object + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + + total_params, largest_layer_params = model_to_params(model) + + estimate_zero3_model_states_mem_needs_all_cold( + total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + additional_buffer_factor=additional_buffer_factor) + + +def estimate_zero3_model_states_mem_needs_all_cold(total_params, + largest_layer_params, + num_gpus_per_node=1, + num_nodes=1, + additional_buffer_factor=1.5): + """ + Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients + for a given ``model`` and hardware setup. + + If it's a hypothetical model, use this function where you have to pass + the ``total_params`` and ``largest_layer_params`` explicitly. + + If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything + will be derived automatically. + + Args: + - ``total_params``: total model params + - ``largest_layer_params``: largest layer's params + - ``num_gpus_per_node``: how many gpus per node (defaults to 1) + - ``num_nodes``: how many nodes (defaults to 1), + - ``additional_buffer_factor``: estimation factor (defaults to 1.5): + + """ + def format_options(cpu_offload, cpu_offload_params, zero_init): + enabled = [] + padded_cpu_str = f'{OFFLOAD_CPU_DEVICE:4}' + param_device = padded_cpu_str if cpu_offload_params else "none" + enabled.append(f"{OFFLOAD_PARAM}={param_device}") + optimizer_device = padded_cpu_str if cpu_offload else "none" + enabled.append(f"{OFFLOAD_OPTIMIZER}={optimizer_device}") + enabled.append(f"zero_init={1 if zero_init else 0}") + return ", ".join(enabled) + + nodes_str = "nodes" if num_nodes > 1 else "node" + gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU" + print( + "Estimated memory needed for params, optim states and gradients for a:\n" + f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n" + f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params." + ) + print(" per CPU | per GPU | Options") + for cpu_offload in [True, False]: + for cpu_offload_params in [True, False]: + if not cpu_offload and cpu_offload_params: + continue + for zero_init in [True, False]: + cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs( + total_params=total_params, + largest_layer_params=largest_layer_params, + num_gpus_per_node=num_gpus_per_node, + num_nodes=num_nodes, + cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init, + additional_buffer_factor=additional_buffer_factor + ) + + options_str = format_options(cpu_offload=cpu_offload, + cpu_offload_params=cpu_offload_params, + zero_init=zero_init) + print( + f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}") diff --git a/docs/README.md b/docs/README.md index 0ac7783f3860..4b80f6bd4a8b 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,49 +1,49 @@ -# DeepSpeed Documentation - -This directory includes the source code for the website and documentation of DeepSpeed. The `code-docs/` directory is used to build [deepspeed.readthedocs.io](https://deepspeed.readthedocs.io/en/latest/). - -[deepspeed.ai](https://www.deepspeed.ai/) is the recommended way to read all DeepSpeed documentation. Directly viewing the Markdown files in this directory will not include images and other features. - -## Building the documentation locally -You can serve the DeepSpeed website locally. This is especially useful for development. - -### Prerequisites -The DeepSpeed website relies on [Jekyll](https://jekyllrb.com/). There are several [guides for installation](https://jekyllrb.com/docs/installation/). The instructions below assume you are in an Ubuntu environment and have been tested on WSL. - -First ensure that you have the necessary packages (e.g., `make` and `zlib`). -``` -sudo apt-get install build-essential zlib1g-dev ruby-full -``` - -Add these lines to your `.bashrc` or equivalent to ensure you have permissions to install Ruby packages without `sudo`. -``` -export GEM_HOME="$HOME/gems" -export PATH="$HOME/gems/bin:$PATH" -``` -Don't forget to `source ~/.bashrc` afterwards 😊. - - -Now we can install Jekyll and [Bundler](https://bundler.io/): -``` -gem install jekyll bundler -``` - -### Start a local webserver -We now need to install the required Ruby packages for the website. - -**NOTE**: you should change to this folder (i.e., docs) before running the installation command to avoid this [error](https://stackoverflow.com/questions/10012181/bundle-install-returns-could-not-locate-gemfile/35157872): - -> Could not locate Gemfile - -**NOTE**: this step frequently hangs when connected to a VPN (including MSVPN). Simply disconnect for the package installation. - - -``` -bundle install -``` - -You can now start a local webserver via: -``` -bundle exec jekyll serve -``` -The website should now be accessible at [http://localhost:4000](http://localhost:4000) +# DeepSpeed Documentation + +This directory includes the source code for the website and documentation of DeepSpeed. The `code-docs/` directory is used to build [deepspeed.readthedocs.io](https://deepspeed.readthedocs.io/en/latest/). + +[deepspeed.ai](https://www.deepspeed.ai/) is the recommended way to read all DeepSpeed documentation. Directly viewing the Markdown files in this directory will not include images and other features. + +## Building the documentation locally +You can serve the DeepSpeed website locally. This is especially useful for development. + +### Prerequisites +The DeepSpeed website relies on [Jekyll](https://jekyllrb.com/). There are several [guides for installation](https://jekyllrb.com/docs/installation/). The instructions below assume you are in an Ubuntu environment and have been tested on WSL. + +First ensure that you have the necessary packages (e.g., `make` and `zlib`). +``` +sudo apt-get install build-essential zlib1g-dev ruby-full +``` + +Add these lines to your `.bashrc` or equivalent to ensure you have permissions to install Ruby packages without `sudo`. +``` +export GEM_HOME="$HOME/gems" +export PATH="$HOME/gems/bin:$PATH" +``` +Don't forget to `source ~/.bashrc` afterwards 😊. + + +Now we can install Jekyll and [Bundler](https://bundler.io/): +``` +gem install jekyll bundler +``` + +### Start a local webserver +We now need to install the required Ruby packages for the website. + +**NOTE**: you should change to this folder (i.e., docs) before running the installation command to avoid this [error](https://stackoverflow.com/questions/10012181/bundle-install-returns-could-not-locate-gemfile/35157872): + +> Could not locate Gemfile + +**NOTE**: this step frequently hangs when connected to a VPN (including MSVPN). Simply disconnect for the package installation. + + +``` +bundle install +``` + +You can now start a local webserver via: +``` +bundle exec jekyll serve +``` +The website should now be accessible at [http://localhost:4000](http://localhost:4000) diff --git a/docs/_posts/2021-03-08-zero3-offload.md b/docs/_posts/2021-03-08-zero3-offload.md index fa12ab5b25fb..3fba666ea095 100644 --- a/docs/_posts/2021-03-08-zero3-offload.md +++ b/docs/_posts/2021-03-08-zero3-offload.md @@ -1,100 +1,100 @@ ---- -layout: single -title: "DeepSpeed ZeRO-3 Offload" -excerpt: "" -categories: news -new_post: true -date: 2021-03-08 00:00:00 ---- -Today we are announcing the release of ZeRO-3 Offload, a highly efficient and easy to use implementation of ZeRO Stage 3 and ZeRO Offload combined, geared towards our continued goal of democratizing AI by making efficient large-scale DL training available to everyone. The key benefits of ZeRO-3 Offload are: - -* Unprecedented memory efficiency to run very large models on a limited number of GPU resources - e.g., fine-tune models with over 40B parameters on a single GPU and over 2 Trillion parameters on 512 GPUs! -* Extremely Easy to use: - * Scale to over a trillion parameters without the need to combine multiple parallelism techniques in complicated ways. - * For existing DeepSpeed users, turn on ZeRO-3 Offload with just a few flags in DeepSpeed Config file. -* High-performance per-GPU throughput and super-linear scalability across GPUs for distributed training. - * With 1 Trillion parameters, ZeRO-3 Offload sustains 25 PetaFlops in compute performance on 512 NVIDIA V100 GPUs, achieving 49 TFlops/GPU. - * Up to 2x improvement in throughput compared to ZeRO- 2 Offload on single GPU - - -

Overview of ZeRO family of technology

- -The Zero Redundancy Optimizer (abbreviated ZeRO) is a family of memory optimization technologies for large-scale distributed deep learning. Unlike data parallelism (that is efficient but can only support a limited model size) or model parallelism (that can support larger model sizes but requires significant code refactoring while adding communication overhead that limits efficiency), ZeRO allows fitting larger models in memory without requiring code refactoring while remaining very efficient. ZeRO does so by eliminating the memory redundancy that is inherent in data parallelism while limiting the communication overhead to a minimum. -ZeRO removes the memory redundancies across data-parallel processes by partitioning the three model states (optimizer states, gradients, and parameters) across data-parallel processes instead of replicating them. By doing this, it boosts memory efficiency compared to classic data-parallelism while retaining its computational granularity and communication efficiency. -There are three stages in ZeRO corresponding to three model states, as shown in the Figure 1: the first stage (ZeRO-1) partitions only the optimizer states, the second stage (ZeRO-2) partitions both the optimizer states and the gradients and the final stage (ZeRO-3) partitions all three model states (for more details see the ZeRO [paper](https://arxiv.org/abs/1910.02054v3)). - - - - -Figure 1. Overview of ZeRO memory savings - -In addition to these three stages, ZeRO family of technology also consists of ZeRO-2 Offload. ZeRO-2 Offload is a heterogenous DL training technology that works in conjunction with ZeRO-2 to offload partitioned optimizer states and gradients to CPU memory. ZeRO-2 Offload offers the full memory advantage of ZeRO-2 even on a single GPU, while at the same time offering great scalability of ZeRO-2 on multi-GPU setup. DeepSpeed library has been offering ZeRO-2 Offload since Sept 2020. For details, please see below: - -* ZeRO: [Stage 1 blog](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/), [Stage 2 blog](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/), [Tutorial](/tutorials/zero) -* ZeRO-Offload: [Blog](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-3), [Tutorials](/tutorials/zero-offload), [Paper link](https://arxiv.org/abs/2101.06840) - -

ZeRO-3 Offload

-With today’s release of ZeRO-3 Offload, we are adding support for partitioning and offloading parameters in addition to optimizer states and gradients partitioning already supported by ZeRO-2 Offload in DeepSpeed. With parameter partitioning ZeRO-3 Offload implements the full set of features in the three stages of ZeRO, that allows for a linear growth in model size with the number of GPUs. In addition, ZeRO-3 Offload can also optionally offload all these model states to CPU to further reduce GPU memory consumption, leveraging both CPU and GPU to maximize memory and compute efficiency of the entire system. - -We believe ZeRO-3 Offload offers a massive leap for large model training, in three regards: - -i) Unprecedented model scale, - -ii) Ease of supporting very-large models, and - -iii) Achieving excellent training efficiency. - - -

Unprecedented model scale

-Unlike ZeRO-2 and ZeRO-Offload where the parameters have to fit in the memory of a single GPU, ZeRO-3 Offload can partition the parameters across GPUs, and offload them to CPU, supporting model sizes that are much larger than the memory on a single GPU. Furthermore, ZeRO-3 Offload goes beyond the state-of-the-art hybrid 3D-parallelism (data, model and pipeline parallelism combined). While 3D Parallelism is limited by the aggregate GPU memory, ZeRO-3 Offload can exploit both GPU and CPU memory, the latter of which is much larger and cheaper compared to GPU memory. This allows ZeRO-3 Offload to train larger model sizes with the given GPU and CPU resources than any other currently available technology. - -Model Scale on Single GPU: ZeRO-3 Offload can train models with over 40B parameters efficiently on a single GPU (e.g., 32GB V100 GPU + 1.5TB CPU memory). This is 3x larger than what is possible with ZeRO-2 Offload, the current state-of-the art. - -Model Scale on Multi-GPUs: With ZeRO-3 Offload you can train a trillion and two trillion parameter models on NVIDIA 32GB V100 DGX-2 cluster with 256 GPUs and 512 GPUs, respectively. In contrast, the state-of-art 3D Parallelism requires 800 GPUs, and 1600 GPUs, respectively, to fit the same sized models. This represents a 3x reduction in GPUs required to fit models with over a trillion parameters. - -

Ease of supporting very large models

-From a system perspective, training models with hundreds of billions and trillions of parameters is extremely challenging. Data parallelism cannot scale the model size much further beyond a billion parameters, model parallelism (with tensor slicing) cannot be used to scale model size efficiently beyond a single node boundary due to massive communication overheads, and pipeline parallelism cannot scale beyond the number of layers available in a model, which limits both the model size and the number of GPUs that it can scale to. - -The only existing parallel technology available that can scale to over a trillion parameters on massively parallel GPU clusters is the [3D parallelism](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-0) that combines data, model and pipeline parallelism in complex ways. While such a system can be very efficient, it requires major model code refactoring from data scientists to split the model into load balanced pipeline stages. This also makes 3D parallelism inflexible in the type of models that it can support, since models with complex dependency graphs cannot be easily converted into a load balanced pipeline. - -ZeRO-3 Offload address these challenges in two ways: - -i) With ground-breaking memory efficiency, ZeRO-3 and ZeRO-3 Offload are the only DL parallel technology that can efficiently scale to over a trillion parameters by itself, without requiring a hybrid parallelism strategy, greatly simplifying the system stack for DL training. - -ii) ZeRO-3 Offload requires virtually no model refactoring from model scientists, liberating data scientists to scale up complex models to hundreds of billions to trillions of parameters. - -

Excellent training efficiency

-High-performance per-GPU throughput on multiple nodes: ZeRO-3 Offload offers excellent training efficiency for multi-billion and trillion parameter models on multiple nodes. It achieves a sustained throughput of up to 50 Tflops per GPU running on 32 DGX2 nodes comprising 512 NVIDIA V100 GPUs (see Figure 2). In comparison, the standard data parallel training with PyTorch can only achieve 30 TFlops per GPU for a 1.2B parameter model, the largest model that can be trained using data parallelism alone. - - - - -Figure 2. ZeRO-3 Offload: Multi-billion and trillion parameter model throughput on 512 V100 GPUs - -ZeRO-3 Offload obtains high efficiency despite the 50% communication overhead of ZeRO Stage 3 compared to standard data parallel training for a fixed batch size. This is made possible through a communication overlap centric design and implementation, which allows ZeRO-3 Offload to hide nearly all of the communication volume with computation, while taking advantage of a larger batch size for improved efficiency resulting from better GPU memory efficiency. - - -Efficient multi-billion parameter model training on a single GPU: ZeRO-3 Offload further democratizes AI by enabling efficient training of multi-billion parameter models on a single GPU. For single GPU training, ZeRO-3 Offload provides benefits over ZeRO-2 Offload along two dimensions. First, ZeRO-3 Offload increases the size of models trainable on a single V100 from 13B to 40B. Second, for ZeRO-3 Offload provides speedups (e.g., 2.3X for 13B) compared to ZeRO-2 Offload for model sizes trainable by both solutions. These results are summarized in Figure 3. - - - - -Figure 3. Multi-billion parameter model training on one V100 GPU - -Super-Linear scalability across GPUs: Additionally, ZeRO-3 Offload also preserves the super-linear scalability characteristics that we have demonstrated with all our previous ZeRO technologies (ZeRO Stage 1, ZeRO Stage 2 and ZeRO Offload). ZeRO-3 Offload can exploit the aggregate PCI-E bandwidth between GPU and CPU across all the GPUs in multi-GPU training configuration, and at the same time, it can also exploit the aggregate CPU compute across all the nodes. As a result, the CPU-GPU-CPU communication time as well as the optimizer update time decreases linearly with number of GPUs and nodes, respectively, allowing ZeRO-3 Offload to exhibit super-linear scaling (see Figure 4). - - - - -Figure 4. ZeRO-3 Offload Superlinear Scalability for a 200B parameter model. - -

How to use ZeRO-3 Offload

-As with many other existing DeepSpeed features, once the user model has been converted to use DeepSpeed, enabling ZeRO-3 Offload is as easy as turning on a couple of flags in DeepSpeed Config file. Supporting advanced features like weight sharing, or enabling extremely large models that requires to be partitioned across GPUs/nodes to fit in GPU/CPU memory, can be done with just a couple of additional lines of code change using the ZeRO-3 Offload API. - -If you are already a DeepSpeed user, you can find our detailed tutorial on ZeRO-3 Offload below. If you are new to DeepSpeed, we recommend that you start at the getting started page before trying out our ZeRO-3 Offload Tutorial. - -* DeepSpeed: [Getting Started Page](/getting-started/) - -* ZeRO-3 Offload [Documentation](https://deepspeed.readthedocs.io/en/latest/zero3.html), [Tutorial](/tutorials/zero/#training-trillion-scale-models-with-zero-3-offload) - -The DeepSpeed Team is very excited to share ZeRO-3 Offload with the DL community. +--- +layout: single +title: "DeepSpeed ZeRO-3 Offload" +excerpt: "" +categories: news +new_post: true +date: 2021-03-08 00:00:00 +--- +Today we are announcing the release of ZeRO-3 Offload, a highly efficient and easy to use implementation of ZeRO Stage 3 and ZeRO Offload combined, geared towards our continued goal of democratizing AI by making efficient large-scale DL training available to everyone. The key benefits of ZeRO-3 Offload are: + +* Unprecedented memory efficiency to run very large models on a limited number of GPU resources - e.g., fine-tune models with over 40B parameters on a single GPU and over 2 Trillion parameters on 512 GPUs! +* Extremely Easy to use: + * Scale to over a trillion parameters without the need to combine multiple parallelism techniques in complicated ways. + * For existing DeepSpeed users, turn on ZeRO-3 Offload with just a few flags in DeepSpeed Config file. +* High-performance per-GPU throughput and super-linear scalability across GPUs for distributed training. + * With 1 Trillion parameters, ZeRO-3 Offload sustains 25 PetaFlops in compute performance on 512 NVIDIA V100 GPUs, achieving 49 TFlops/GPU. + * Up to 2x improvement in throughput compared to ZeRO- 2 Offload on single GPU + + +

Overview of ZeRO family of technology

+ +The Zero Redundancy Optimizer (abbreviated ZeRO) is a family of memory optimization technologies for large-scale distributed deep learning. Unlike data parallelism (that is efficient but can only support a limited model size) or model parallelism (that can support larger model sizes but requires significant code refactoring while adding communication overhead that limits efficiency), ZeRO allows fitting larger models in memory without requiring code refactoring while remaining very efficient. ZeRO does so by eliminating the memory redundancy that is inherent in data parallelism while limiting the communication overhead to a minimum. +ZeRO removes the memory redundancies across data-parallel processes by partitioning the three model states (optimizer states, gradients, and parameters) across data-parallel processes instead of replicating them. By doing this, it boosts memory efficiency compared to classic data-parallelism while retaining its computational granularity and communication efficiency. +There are three stages in ZeRO corresponding to three model states, as shown in the Figure 1: the first stage (ZeRO-1) partitions only the optimizer states, the second stage (ZeRO-2) partitions both the optimizer states and the gradients and the final stage (ZeRO-3) partitions all three model states (for more details see the ZeRO [paper](https://arxiv.org/abs/1910.02054v3)). + + + + +Figure 1. Overview of ZeRO memory savings + +In addition to these three stages, ZeRO family of technology also consists of ZeRO-2 Offload. ZeRO-2 Offload is a heterogenous DL training technology that works in conjunction with ZeRO-2 to offload partitioned optimizer states and gradients to CPU memory. ZeRO-2 Offload offers the full memory advantage of ZeRO-2 even on a single GPU, while at the same time offering great scalability of ZeRO-2 on multi-GPU setup. DeepSpeed library has been offering ZeRO-2 Offload since Sept 2020. For details, please see below: + +* ZeRO: [Stage 1 blog](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/), [Stage 2 blog](https://www.microsoft.com/en-us/research/blog/zero-2-deepspeed-shattering-barriers-of-deep-learning-speed-scale/), [Tutorial](/tutorials/zero) +* ZeRO-Offload: [Blog](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-3), [Tutorials](/tutorials/zero-offload), [Paper link](https://arxiv.org/abs/2101.06840) + +

ZeRO-3 Offload

+With today’s release of ZeRO-3 Offload, we are adding support for partitioning and offloading parameters in addition to optimizer states and gradients partitioning already supported by ZeRO-2 Offload in DeepSpeed. With parameter partitioning ZeRO-3 Offload implements the full set of features in the three stages of ZeRO, that allows for a linear growth in model size with the number of GPUs. In addition, ZeRO-3 Offload can also optionally offload all these model states to CPU to further reduce GPU memory consumption, leveraging both CPU and GPU to maximize memory and compute efficiency of the entire system. + +We believe ZeRO-3 Offload offers a massive leap for large model training, in three regards: + +i) Unprecedented model scale, + +ii) Ease of supporting very-large models, and + +iii) Achieving excellent training efficiency. + + +

Unprecedented model scale

+Unlike ZeRO-2 and ZeRO-Offload where the parameters have to fit in the memory of a single GPU, ZeRO-3 Offload can partition the parameters across GPUs, and offload them to CPU, supporting model sizes that are much larger than the memory on a single GPU. Furthermore, ZeRO-3 Offload goes beyond the state-of-the-art hybrid 3D-parallelism (data, model and pipeline parallelism combined). While 3D Parallelism is limited by the aggregate GPU memory, ZeRO-3 Offload can exploit both GPU and CPU memory, the latter of which is much larger and cheaper compared to GPU memory. This allows ZeRO-3 Offload to train larger model sizes with the given GPU and CPU resources than any other currently available technology. + +Model Scale on Single GPU: ZeRO-3 Offload can train models with over 40B parameters efficiently on a single GPU (e.g., 32GB V100 GPU + 1.5TB CPU memory). This is 3x larger than what is possible with ZeRO-2 Offload, the current state-of-the art. + +Model Scale on Multi-GPUs: With ZeRO-3 Offload you can train a trillion and two trillion parameter models on NVIDIA 32GB V100 DGX-2 cluster with 256 GPUs and 512 GPUs, respectively. In contrast, the state-of-art 3D Parallelism requires 800 GPUs, and 1600 GPUs, respectively, to fit the same sized models. This represents a 3x reduction in GPUs required to fit models with over a trillion parameters. + +

Ease of supporting very large models

+From a system perspective, training models with hundreds of billions and trillions of parameters is extremely challenging. Data parallelism cannot scale the model size much further beyond a billion parameters, model parallelism (with tensor slicing) cannot be used to scale model size efficiently beyond a single node boundary due to massive communication overheads, and pipeline parallelism cannot scale beyond the number of layers available in a model, which limits both the model size and the number of GPUs that it can scale to. + +The only existing parallel technology available that can scale to over a trillion parameters on massively parallel GPU clusters is the [3D parallelism](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-0) that combines data, model and pipeline parallelism in complex ways. While such a system can be very efficient, it requires major model code refactoring from data scientists to split the model into load balanced pipeline stages. This also makes 3D parallelism inflexible in the type of models that it can support, since models with complex dependency graphs cannot be easily converted into a load balanced pipeline. + +ZeRO-3 Offload address these challenges in two ways: + +i) With ground-breaking memory efficiency, ZeRO-3 and ZeRO-3 Offload are the only DL parallel technology that can efficiently scale to over a trillion parameters by itself, without requiring a hybrid parallelism strategy, greatly simplifying the system stack for DL training. + +ii) ZeRO-3 Offload requires virtually no model refactoring from model scientists, liberating data scientists to scale up complex models to hundreds of billions to trillions of parameters. + +

Excellent training efficiency

+High-performance per-GPU throughput on multiple nodes: ZeRO-3 Offload offers excellent training efficiency for multi-billion and trillion parameter models on multiple nodes. It achieves a sustained throughput of up to 50 Tflops per GPU running on 32 DGX2 nodes comprising 512 NVIDIA V100 GPUs (see Figure 2). In comparison, the standard data parallel training with PyTorch can only achieve 30 TFlops per GPU for a 1.2B parameter model, the largest model that can be trained using data parallelism alone. + + + + +Figure 2. ZeRO-3 Offload: Multi-billion and trillion parameter model throughput on 512 V100 GPUs + +ZeRO-3 Offload obtains high efficiency despite the 50% communication overhead of ZeRO Stage 3 compared to standard data parallel training for a fixed batch size. This is made possible through a communication overlap centric design and implementation, which allows ZeRO-3 Offload to hide nearly all of the communication volume with computation, while taking advantage of a larger batch size for improved efficiency resulting from better GPU memory efficiency. + + +Efficient multi-billion parameter model training on a single GPU: ZeRO-3 Offload further democratizes AI by enabling efficient training of multi-billion parameter models on a single GPU. For single GPU training, ZeRO-3 Offload provides benefits over ZeRO-2 Offload along two dimensions. First, ZeRO-3 Offload increases the size of models trainable on a single V100 from 13B to 40B. Second, for ZeRO-3 Offload provides speedups (e.g., 2.3X for 13B) compared to ZeRO-2 Offload for model sizes trainable by both solutions. These results are summarized in Figure 3. + + + + +Figure 3. Multi-billion parameter model training on one V100 GPU + +Super-Linear scalability across GPUs: Additionally, ZeRO-3 Offload also preserves the super-linear scalability characteristics that we have demonstrated with all our previous ZeRO technologies (ZeRO Stage 1, ZeRO Stage 2 and ZeRO Offload). ZeRO-3 Offload can exploit the aggregate PCI-E bandwidth between GPU and CPU across all the GPUs in multi-GPU training configuration, and at the same time, it can also exploit the aggregate CPU compute across all the nodes. As a result, the CPU-GPU-CPU communication time as well as the optimizer update time decreases linearly with number of GPUs and nodes, respectively, allowing ZeRO-3 Offload to exhibit super-linear scaling (see Figure 4). + + + + +Figure 4. ZeRO-3 Offload Superlinear Scalability for a 200B parameter model. + +

How to use ZeRO-3 Offload

+As with many other existing DeepSpeed features, once the user model has been converted to use DeepSpeed, enabling ZeRO-3 Offload is as easy as turning on a couple of flags in DeepSpeed Config file. Supporting advanced features like weight sharing, or enabling extremely large models that requires to be partitioned across GPUs/nodes to fit in GPU/CPU memory, can be done with just a couple of additional lines of code change using the ZeRO-3 Offload API. + +If you are already a DeepSpeed user, you can find our detailed tutorial on ZeRO-3 Offload below. If you are new to DeepSpeed, we recommend that you start at the getting started page before trying out our ZeRO-3 Offload Tutorial. + +* DeepSpeed: [Getting Started Page](/getting-started/) + +* ZeRO-3 Offload [Documentation](https://deepspeed.readthedocs.io/en/latest/zero3.html), [Tutorial](/tutorials/zero/#training-trillion-scale-models-with-zero-3-offload) + +The DeepSpeed Team is very excited to share ZeRO-3 Offload with the DL community. diff --git a/docs/_posts/2021-05-05-inference-kernel-optimization.md b/docs/_posts/2021-05-05-inference-kernel-optimization.md index 18ab7c32186d..9b9e747a2766 100644 --- a/docs/_posts/2021-05-05-inference-kernel-optimization.md +++ b/docs/_posts/2021-05-05-inference-kernel-optimization.md @@ -1,73 +1,73 @@ ---- -layout: single -title: "DeepSpeed Inference: Multi-GPU inference with customized inference kernels and quantization support" -excerpt: "" -categories: news -new_post: false -date: 2021-03-16 00:00:00 ---- -While DeepSpeed supports training advanced large-scale models, using these trained models in the desired application scenarios is still challenging due to three major limitations in existing inference solutions: 1) lack of support for multi-GPU inference to fit large models and meet latency requirements, 2) limited GPU kernel performance when running inference with small batch sizes, and 3) difficulties in exploiting quantization, which includes both quantizing the model to reduce the model size and latency as well as supporting high-performance inference of quantized models without specialized hardware. - -To handle these challenges, we introduce DeepSpeed Inference, which seamlessly adds high-performance inference support to large models trained in DeepSpeed with three key features: inference-adapted parallelism for multi-GPU inference, inference-optimized kernels tuned for small batch sizes, and flexible support for quantize-aware training and inference kernels for quantized models. - -## Multi-GPU Inference with Adaptive Parallelism - -Parallelism is an effective approach to fit large models and reduce per-device memory consumption for both training and inference. However, simply applying training parallelism choices and degree to inference does not work well. The MP and PP configuration is normally set during the model training, apart from the data parallelism (DP), based on the memory footprint and computation style, and resource budget. On one hand, inference computation intrinsically requires less memory, so it can afford a larger partition per device. It helps reduce the degree of parallelism needed for model deployment. On the other hand, optimizing latency or meeting latency requirements is often a first-class citizen in inference while training optimizes throughput. - -To obtain desired latency, DeepSpeed Inference automatically adapts MP as an effective approach to reduce model latency, and its parallelism degree is often determined first. With MP, we can split the mode and parallelize computational operations across multiple devices (GPUs) to reduce latency, but it reduces computation granularity and increases communication that may hurt throughput. Once the latency target has been met, DeepSpeed can apply pipeline parallelism to maximize the throughput. Overall, DeepSpeed Inference supports flexible adaptation of both parallelism approach and degree choices from training to inference, minimizing latency while saving deployment costs. - - -## Customized Inference Kernels for Boosted Compute Efficiency of Transformer Blocks - -To achieve high compute efficiency, DeepSpeed-inference offers inference kernels tailored for Transformer blocks through operator fusion, taking model-parallelism for multi-GPU into account. The main difference between our kernel-fusion scheme and similar approaches is that we not only fuse element-wise operations (such as bias-add, residual, and activation function), but also merge the General matrix multiply (GeMM) operations with other operations. To do this, we design an efficient implementation for the vector-matrix or skinny matrix-matrix multiplication that allows us to fuse more operations at the reduction boundary of GeMM operations. - -# Kernel-Fusion - -We take two main policies for fusing operations: 1) keeping the access-pattern of inputs and outputs intact throughout the sequence of operations fused together; 2) fusing operations at each all-reduce boundary. The first policy ensures that different thread-blocks won’t encounter transferring data between Streaming-Multiprocessors (SMs). This is due to no straight-forward communication among SMs other than using the main memory which adds the block-synching overhead because of non-deterministic behavior of memory access. The reason behind the second policy is that we cannot continue the execution unless the partial results are reduced among the model-parallel GPUs. - -![Inference-Kernel-Fusion](/assets/images/inference-kernel-fusion.png){: .align-center} - -Figure 1: Transformer Layer with Megatron-style model-parallelism all-reduce components. The figure illustrates the parts of layer fused together with broken lines (width of line shows the fusion depth). - -Figure 1 shows the different components of a Transformer layer, and the groups of operations considered for fusion in our inference optimization. We also consider the NVIDIA Megatron-LM style of parallelism that partitions attention (Attn) and feed-forward (FF) blocks across multiple GPUs. Thus, we include the two all-reduce operations that reduce the results among parallel GPUs after Attn and FF blocks. As Figure 1 shows, we fuse the operations inside a Transformer layer at four main regions: -1. Input Layer-Norm plus Query, Key, and Value GeMMs and their bias adds. -2. Transform plus Attention. -3. Intermediate FF, Layer-Norm, Bias-add, Residual, and Gaussian Error Linear Unit (GELU). -4. Bias-add plus Residual. - -To fuse these operations, we exploit shared-memory as an intermediate cache for transferring data between reduction operations used in layer-norm and GeMM, and the element-wise operations. Moreover, we use the warp-level instructions to communicate data between threads when reducing partial computations. In addition, we use a new schedule for GeMM operations, which allows for fusing as many operations as needed for the third kernel-fusion. We also combine the GeMMs for the attention computation in the second kernel-fusion, by using an implicit matrix transformation in order to reduce the memory pressure. Compared to the unfused computation style using cuBLAS GeMM, we improve the performance by 1.5x, 2.9x. 3x, and 1.2x for all these kernel-fusions, respectively. - -## Seamless pipeline from training to inference with automatic kernel-injection - -To run the model in Inference mode, DeepSpeed simply requires the location of the model checkpoints and the desired parallelism configuration, i.e., MP/PP degree. DeepSpeed Inference kernels can also be enabled for many well-known model architectures such as HuggingFace (Bert and GPT-2) or Megatron GPT-based models using a pre-defined policy map that maps the original parameters to the parameters in the inference kernels. For other transformer-based models, user can specify their own policy map. Note that DS-Inference can run independent of the training pipeline as long as it receives all model checkpoints, and the DeepSpeed Transformer kernels for inference can be injected into any Transformer model if the right mapping policy is defined. For more information on how to enable Transformer inference kernel as well as specifying parallelism, please refer to out [inference tutorial](https://www.deepspeed.ai/tutorials/inference-tutorial/). - - -## Flexible quantization support - -To further reduce the inference cost for large-scale models, we created the DeepSpeed Quantization Toolkit, supporting flexible quantize-aware training and high-performance kernels for quantized inference. - -For training, we introduce a novel approach called Mixture of Quantization (MoQ), which is inspired by mixed-precision training while seamlessly applying quantization. With MoQ, we can control the precision of the model by simulating the impact of quantization when updating the parameters at each step of training. Moreover, it supports flexible quantization policies and schedules—we find that by dynamically adjusting the number of quantization bits during training, the final quantized model provides higher accuracy under the same compression ratio. To adapt to different tasks, MoQ can also leverage the second order information of models to detect their sensitivity to precision and adjust the quantization schedule and target accordingly. - -To maximize the performance gains from the quantization model, we provide inference kernels tailored for quantized models that reduce latency through optimizing data movement but do not require specialized hardware. Finally, our toolkit does not require any code changes on the client side, making it easy to use. - -## Performance results - -Boosting throughput and reducing inference cost. Figure 3 shows the inference throughput per GPU for the three model sizes corresponding to the three Transformer networks, GPT-2, Turing-NLG, and GPT-3. DeepSpeed Inference increases in per-GPU throughput by 2 to 4 times when using the same precision of FP16 as the baseline. By enabling quantization, we boost throughput further. We reach a throughput improvement of 3x for GPT-2, 5x for Turing-NLG, and 3x for a model that is similar in characteristics and size to GPT-3, which directly translates to 3–5x inference cost reduction on serving these large models. In addition, we achieve these throughput and cost improvements without compromising latency as shown in Figure 5. - -![Inference-Throughput](/assets/images/inference-throughput.png){: .align-center} - -Figure 3: Inference throughput for different model sizes. DeepSpeed Inference achieves 3x to 5x higher throughput than baseline. - -One source of inference cost reduction is through reducing the number of GPUs for hosting large models as shown in Figure 4. The optimized GPU resources comes from 1) using inference-adapted parallelism, allowing users to adjust the model and pipeline parallelism degree from the trained model checkpoints, and 2) shrinking model memory footprint by half with INT8 quantization. As shown in this figure, we use 2x less GPUs to run inference for the 17B model size by adapting the parallelism. Together with INT8 quantization through DeepSpeed MoQ, we use 4x and 2x fewer GPUs for 17B and 175B sizes respectively. - -![Inference-Throughput](/assets/images/gpu-numbers.png){: .align-center} - -Figure 4: Number of GPUs used for running inference on the different model sizes shown in Figure 4. - -Reducing inference latency. For the application scenarios where inference latency is critical, we can increase model parallelism degree in DeepSpeed Inference to reduce inference latency further. As Figure 5 depicts, we can reduce the latency by 2.3x compared to PyTorch as we increase the model-parallelism size to 4. Furthermore, we can still have high latency improvement with a fewer number of GPUs by adapting the parallelism at inference and using MoQ to quantize the model. We obtain 1.3x and 1.9x speedups while using 4x and 2x lower resources than baseline, respectively. - -For the application scenarios where inference latency is critical, we can increase model parallelism degree in DeepSpeed Inference to reduce inference latency further. As Figure 5 depicts, we can reduce the latency by 2.3x compared to PyTorch as we increase the model-parallelism size to 4. Furthermore, we can still have high latency improvement with a fewer number of GPUs by adapting the parallelism at inference and using MoQ to quantize the model. We obtain 1.3x and 1.9x speedups while using 4x and 2x lower resources than baseline, respectively. - -![Inference-Throughput](/assets/images/inference-latency.png){: .align-center} - -Figure 5. Inference latency for the 17B model using different parallelism configuration to optimize latency. +--- +layout: single +title: "DeepSpeed Inference: Multi-GPU inference with customized inference kernels and quantization support" +excerpt: "" +categories: news +new_post: false +date: 2021-03-16 00:00:00 +--- +While DeepSpeed supports training advanced large-scale models, using these trained models in the desired application scenarios is still challenging due to three major limitations in existing inference solutions: 1) lack of support for multi-GPU inference to fit large models and meet latency requirements, 2) limited GPU kernel performance when running inference with small batch sizes, and 3) difficulties in exploiting quantization, which includes both quantizing the model to reduce the model size and latency as well as supporting high-performance inference of quantized models without specialized hardware. + +To handle these challenges, we introduce DeepSpeed Inference, which seamlessly adds high-performance inference support to large models trained in DeepSpeed with three key features: inference-adapted parallelism for multi-GPU inference, inference-optimized kernels tuned for small batch sizes, and flexible support for quantize-aware training and inference kernels for quantized models. + +## Multi-GPU Inference with Adaptive Parallelism + +Parallelism is an effective approach to fit large models and reduce per-device memory consumption for both training and inference. However, simply applying training parallelism choices and degree to inference does not work well. The MP and PP configuration is normally set during the model training, apart from the data parallelism (DP), based on the memory footprint and computation style, and resource budget. On one hand, inference computation intrinsically requires less memory, so it can afford a larger partition per device. It helps reduce the degree of parallelism needed for model deployment. On the other hand, optimizing latency or meeting latency requirements is often a first-class citizen in inference while training optimizes throughput. + +To obtain desired latency, DeepSpeed Inference automatically adapts MP as an effective approach to reduce model latency, and its parallelism degree is often determined first. With MP, we can split the mode and parallelize computational operations across multiple devices (GPUs) to reduce latency, but it reduces computation granularity and increases communication that may hurt throughput. Once the latency target has been met, DeepSpeed can apply pipeline parallelism to maximize the throughput. Overall, DeepSpeed Inference supports flexible adaptation of both parallelism approach and degree choices from training to inference, minimizing latency while saving deployment costs. + + +## Customized Inference Kernels for Boosted Compute Efficiency of Transformer Blocks + +To achieve high compute efficiency, DeepSpeed-inference offers inference kernels tailored for Transformer blocks through operator fusion, taking model-parallelism for multi-GPU into account. The main difference between our kernel-fusion scheme and similar approaches is that we not only fuse element-wise operations (such as bias-add, residual, and activation function), but also merge the General matrix multiply (GeMM) operations with other operations. To do this, we design an efficient implementation for the vector-matrix or skinny matrix-matrix multiplication that allows us to fuse more operations at the reduction boundary of GeMM operations. + +# Kernel-Fusion + +We take two main policies for fusing operations: 1) keeping the access-pattern of inputs and outputs intact throughout the sequence of operations fused together; 2) fusing operations at each all-reduce boundary. The first policy ensures that different thread-blocks won’t encounter transferring data between Streaming-Multiprocessors (SMs). This is due to no straight-forward communication among SMs other than using the main memory which adds the block-synching overhead because of non-deterministic behavior of memory access. The reason behind the second policy is that we cannot continue the execution unless the partial results are reduced among the model-parallel GPUs. + +![Inference-Kernel-Fusion](/assets/images/inference-kernel-fusion.png){: .align-center} + +Figure 1: Transformer Layer with Megatron-style model-parallelism all-reduce components. The figure illustrates the parts of layer fused together with broken lines (width of line shows the fusion depth). + +Figure 1 shows the different components of a Transformer layer, and the groups of operations considered for fusion in our inference optimization. We also consider the NVIDIA Megatron-LM style of parallelism that partitions attention (Attn) and feed-forward (FF) blocks across multiple GPUs. Thus, we include the two all-reduce operations that reduce the results among parallel GPUs after Attn and FF blocks. As Figure 1 shows, we fuse the operations inside a Transformer layer at four main regions: +1. Input Layer-Norm plus Query, Key, and Value GeMMs and their bias adds. +2. Transform plus Attention. +3. Intermediate FF, Layer-Norm, Bias-add, Residual, and Gaussian Error Linear Unit (GELU). +4. Bias-add plus Residual. + +To fuse these operations, we exploit shared-memory as an intermediate cache for transferring data between reduction operations used in layer-norm and GeMM, and the element-wise operations. Moreover, we use the warp-level instructions to communicate data between threads when reducing partial computations. In addition, we use a new schedule for GeMM operations, which allows for fusing as many operations as needed for the third kernel-fusion. We also combine the GeMMs for the attention computation in the second kernel-fusion, by using an implicit matrix transformation in order to reduce the memory pressure. Compared to the unfused computation style using cuBLAS GeMM, we improve the performance by 1.5x, 2.9x. 3x, and 1.2x for all these kernel-fusions, respectively. + +## Seamless pipeline from training to inference with automatic kernel-injection + +To run the model in Inference mode, DeepSpeed simply requires the location of the model checkpoints and the desired parallelism configuration, i.e., MP/PP degree. DeepSpeed Inference kernels can also be enabled for many well-known model architectures such as HuggingFace (Bert and GPT-2) or Megatron GPT-based models using a pre-defined policy map that maps the original parameters to the parameters in the inference kernels. For other transformer-based models, user can specify their own policy map. Note that DS-Inference can run independent of the training pipeline as long as it receives all model checkpoints, and the DeepSpeed Transformer kernels for inference can be injected into any Transformer model if the right mapping policy is defined. For more information on how to enable Transformer inference kernel as well as specifying parallelism, please refer to out [inference tutorial](https://www.deepspeed.ai/tutorials/inference-tutorial/). + + +## Flexible quantization support + +To further reduce the inference cost for large-scale models, we created the DeepSpeed Quantization Toolkit, supporting flexible quantize-aware training and high-performance kernels for quantized inference. + +For training, we introduce a novel approach called Mixture of Quantization (MoQ), which is inspired by mixed-precision training while seamlessly applying quantization. With MoQ, we can control the precision of the model by simulating the impact of quantization when updating the parameters at each step of training. Moreover, it supports flexible quantization policies and schedules—we find that by dynamically adjusting the number of quantization bits during training, the final quantized model provides higher accuracy under the same compression ratio. To adapt to different tasks, MoQ can also leverage the second order information of models to detect their sensitivity to precision and adjust the quantization schedule and target accordingly. + +To maximize the performance gains from the quantization model, we provide inference kernels tailored for quantized models that reduce latency through optimizing data movement but do not require specialized hardware. Finally, our toolkit does not require any code changes on the client side, making it easy to use. + +## Performance results + +Boosting throughput and reducing inference cost. Figure 3 shows the inference throughput per GPU for the three model sizes corresponding to the three Transformer networks, GPT-2, Turing-NLG, and GPT-3. DeepSpeed Inference increases in per-GPU throughput by 2 to 4 times when using the same precision of FP16 as the baseline. By enabling quantization, we boost throughput further. We reach a throughput improvement of 3x for GPT-2, 5x for Turing-NLG, and 3x for a model that is similar in characteristics and size to GPT-3, which directly translates to 3–5x inference cost reduction on serving these large models. In addition, we achieve these throughput and cost improvements without compromising latency as shown in Figure 5. + +![Inference-Throughput](/assets/images/inference-throughput.png){: .align-center} + +Figure 3: Inference throughput for different model sizes. DeepSpeed Inference achieves 3x to 5x higher throughput than baseline. + +One source of inference cost reduction is through reducing the number of GPUs for hosting large models as shown in Figure 4. The optimized GPU resources comes from 1) using inference-adapted parallelism, allowing users to adjust the model and pipeline parallelism degree from the trained model checkpoints, and 2) shrinking model memory footprint by half with INT8 quantization. As shown in this figure, we use 2x less GPUs to run inference for the 17B model size by adapting the parallelism. Together with INT8 quantization through DeepSpeed MoQ, we use 4x and 2x fewer GPUs for 17B and 175B sizes respectively. + +![Inference-Throughput](/assets/images/gpu-numbers.png){: .align-center} + +Figure 4: Number of GPUs used for running inference on the different model sizes shown in Figure 4. + +Reducing inference latency. For the application scenarios where inference latency is critical, we can increase model parallelism degree in DeepSpeed Inference to reduce inference latency further. As Figure 5 depicts, we can reduce the latency by 2.3x compared to PyTorch as we increase the model-parallelism size to 4. Furthermore, we can still have high latency improvement with a fewer number of GPUs by adapting the parallelism at inference and using MoQ to quantize the model. We obtain 1.3x and 1.9x speedups while using 4x and 2x lower resources than baseline, respectively. + +For the application scenarios where inference latency is critical, we can increase model parallelism degree in DeepSpeed Inference to reduce inference latency further. As Figure 5 depicts, we can reduce the latency by 2.3x compared to PyTorch as we increase the model-parallelism size to 4. Furthermore, we can still have high latency improvement with a fewer number of GPUs by adapting the parallelism at inference and using MoQ to quantize the model. We obtain 1.3x and 1.9x speedups while using 4x and 2x lower resources than baseline, respectively. + +![Inference-Throughput](/assets/images/inference-latency.png){: .align-center} + +Figure 5. Inference latency for the 17B model using different parallelism configuration to optimize latency. diff --git a/docs/_tutorials/mixture-of-experts.md b/docs/_tutorials/mixture-of-experts.md index 39c85ebdbb4c..ef8ca1756b3e 100644 --- a/docs/_tutorials/mixture-of-experts.md +++ b/docs/_tutorials/mixture-of-experts.md @@ -1,197 +1,197 @@ ---- -title: "Mixture of Experts" ---- - -DeepSpeed v0.5 introduces new support for training Mixture of Experts (MoE) models. MoE models are an emerging class of sparsely activated models that have sublinear compute costs with respect to their parameters. For example, the [Switch Transformer](https://arxiv.org/abs/2101.03961) consists of over 1.6 trillion parameters, while the compute required to train it is approximately equal to that of a 10 billion-parameter dense model. This increase in model size offers tremendous accuracy gains for a constant compute budget. - -For more details on results and further discussion, please see our press release: [DeepSpeed powers 8x larger MoE model training with high performance]({{ site.press_release_v5 }}). - -## Getting started with a simple MoE example - -**Note:** DeepSpeed MoE requires Pytorch 1.8 or above. -{: .notice--info} - -As a simple starting point we will show how to apply DeepSpeed MoE to a cifar10 example. Please refer to -our [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) going forward. - -If you are adding MoE to an existing model you can use the snippet below to help guide you: - - -### Expert groups initialization - -DeepSpeed MoE supports five different forms of parallelism, and it exploits both GPU and CPU memory. Its flexible design enables users to mix different types of prevalent parallelism techniques, as shown in the table below. - -| Short Name | Flexible Parallelism Configurations | Benefit | -| ---------------- | ------------------------------------| --------------------------------------------------------------------------- | -| E | Expert | Scales the model size by increasing the number of experts | -| E + D | Expert + Data | Accelerates training throughput by scaling to multiple data parallel groups | -| E + Z | Expert + ZeRO-powered data | Partitions the nonexpert parameters to support larger base models | -| E + D + M | Expert + Data + Model | Supports massive hidden sizes and even larger base models than E+Z | -| E + D + Z | Expert + Data + ZeRO-powered data | Supports massive hidden sizes and even larger base models than E+Z | -| E + Z-Off + M | Expert + ZeRO-Offload + Model | Leverages both GPU and CPU memory for large MoE models on limited # of GPUs | - -To support different forms of parallelism, we create a notion of DeepSpeed process groups that resides in ```deepspeed.utils.groups.py``` - -For most cases, the model training code needs to initialize these groups by calling -```python -deepspeed.utils.groups.initialize(ep_size="desired expert-parallel world size") -``` - -The GPUs (or ranks) participating in an expert-parallel group will distribute the total number of experts specified by the model training code argument num_experts. - -### MoE layer API - -The hidden_size is the input dimension of a particular layer and the output dimension is the same as that. This could lead to some changes to your model definition, especially for vision/convolutional models because the input/output dimensions don't match in certain cases. E.g. in the CIFAR-10 example, we modify the third fully connected layer to add the MoE layer. To cater for this, we need to add an additional fully-connected layer, whose input dimension is equal to the output dimension of the MoE layer. - -Original model config - -```python - self.fc3 = nn.Linear(84, 10) -``` - -Updated with MoE Layers - -```python - self.fc3 = nn.Linear(84, 84) - self.fc3 = deepspeed.moe.layer.MoE(hidden_size=84, expert=self.fc3, num_experts=args.num_experts, ...) - self.fc4 = nn.Linear(84, 10) -``` - -### An Example Scenario - -Given a total number of GPUs in our world size and a subset of GPUs in our expert-parallel world as follows. - -```python -WORLD_SIZE = 4 -EP_WORLD_SIZE = 2 -EXPERTS = 8 -``` - -The user code needs to initialize the groups as follows. - -```python -groups.initialize (ep_size=EP_WORLD_SIZE) -``` - -After that, the model code needs to use the deepspeed.moe.layer.MoE API as follows. - -```python -self.experts = deepspeed.moe.layer.MoE(hidden_size=input_dim, expert=ExpertModule(), num_experts=EXPERTS) -``` -With the above two commands, the DeepSpeed runtime will be set to train an MoE model with a total of 8 experts on 4 GPUs in 4 experts/GPU mode. We call this the E + D mode as described earlier in the table. - -For more advanced use case of the groups API including the inter-operability with Megatron style mpu object, watch this space! - - -```python -import torch -import deepspeed -import deepspeed.utils.groups as groups -from deepspeed.moe.layer import MoE - -WORLD_SIZE = 4 -EP_WORLD_SIZE = 2 -EXPERTS = 8 - -groups.initialize(ep_size=EP_WORLD_SIZE) - -fc3 = torch.nn.Linear(84, 84) -fc3 = MoE(hidden_size=84, expert=self.fc3, num_experts=EXPERTS, k=1) -fc4 = torch.nn.Linear(84, 10) - -``` - -For a runnable end-to-end example, please look at [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) - -### Combining ZeRO-Offload and DeepSpeed MoE for very large models - -To use MoE Layers in DeepSpeed, we rely on two parameter groups that are passed to an optimizer. A concrete example to create such groups is available from the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar). - -The relevant function that creates these param groups is as follows. - -```python -def create_moe_param_groups(model): - from deepspeed.moe.utils import is_moe_param - - params_with_weight_decay = {'params': [], 'name': 'weight_decay_params'} - moe_params_with_weight_decay = { - 'params': [], - 'moe': True, - 'name': 'weight_decay_moe_params' - } - - for module_ in model.modules(): - moe_params_with_weight_decay['params'].extend([ - p for n, p in list(module_._parameters.items()) - if p is not None and is_moe_param(p) - ]) - params_with_weight_decay['params'].extend([ - p for n, p in list(module_._parameters.items()) - if p is not None and not is_moe_param(p) - ]) - - return params_with_weight_decay, moe_params_with_weight_decay -``` - -The above param groups can then be fed to the ZeRO stage-2 optimizer as follows. - -```python - -net = Net() - -parameters = create_moe_param_groups(net) - -model_engine, optimizer, trainloader, __ = deepspeed.initialize( - args=args, model=net, model_parameters=parameters, training_data=trainset) -``` - -We are working on automating this functionality in the DeepSpeed ZeRO optimizer so the model training code can be simplified further. - -To run the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) with ZeRO-Offload (stage 2) and MoE, please set the ds_config flags - -```json -"zero_optimization": { - "stage": 2, - "allgather_partitions": true, - "reduce_scatter": true, - "allgather_bucket_size": 50000000, - "reduce_bucket_size": 50000000, - "overlap_comm": true, - "contiguous_gradients": true, - "cpu_offload": true - } -``` - -An additional optimization to save memory for extremely large model training on limited number of GPUs has also been introduced. Please enable that using the following config flag to the fp16 optimizer in ds_config. - - ```json - "fp16": { - "enabled": true, - "fp16_master_weights_and_grads": true, - } - ``` - - - - - - -## Random Token Selection - -We have devised a new technique called “Random Token Selection” that greatly improves convergence. Random token selection addresses the limitation of biased selection problem in MoE model training. Our upcoming paper describes this technique and its results in detail. This feature is already part of the DeepSpeed runtime and is enabled by default so users can take advantage without any config flags or command-line arguments. - -## Advanced MoE usage - -Watch this space! We plan to add more interesting and detailed examples of using DeepSpeed MoE in the coming weeks. +--- +title: "Mixture of Experts" +--- + +DeepSpeed v0.5 introduces new support for training Mixture of Experts (MoE) models. MoE models are an emerging class of sparsely activated models that have sublinear compute costs with respect to their parameters. For example, the [Switch Transformer](https://arxiv.org/abs/2101.03961) consists of over 1.6 trillion parameters, while the compute required to train it is approximately equal to that of a 10 billion-parameter dense model. This increase in model size offers tremendous accuracy gains for a constant compute budget. + +For more details on results and further discussion, please see our press release: [DeepSpeed powers 8x larger MoE model training with high performance]({{ site.press_release_v5 }}). + +## Getting started with a simple MoE example + +**Note:** DeepSpeed MoE requires Pytorch 1.8 or above. +{: .notice--info} + +As a simple starting point we will show how to apply DeepSpeed MoE to a cifar10 example. Please refer to +our [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) going forward. + +If you are adding MoE to an existing model you can use the snippet below to help guide you: + + +### Expert groups initialization + +DeepSpeed MoE supports five different forms of parallelism, and it exploits both GPU and CPU memory. Its flexible design enables users to mix different types of prevalent parallelism techniques, as shown in the table below. + +| Short Name | Flexible Parallelism Configurations | Benefit | +| ---------------- | ------------------------------------| --------------------------------------------------------------------------- | +| E | Expert | Scales the model size by increasing the number of experts | +| E + D | Expert + Data | Accelerates training throughput by scaling to multiple data parallel groups | +| E + Z | Expert + ZeRO-powered data | Partitions the nonexpert parameters to support larger base models | +| E + D + M | Expert + Data + Model | Supports massive hidden sizes and even larger base models than E+Z | +| E + D + Z | Expert + Data + ZeRO-powered data | Supports massive hidden sizes and even larger base models than E+Z | +| E + Z-Off + M | Expert + ZeRO-Offload + Model | Leverages both GPU and CPU memory for large MoE models on limited # of GPUs | + +To support different forms of parallelism, we create a notion of DeepSpeed process groups that resides in ```deepspeed.utils.groups.py``` + +For most cases, the model training code needs to initialize these groups by calling +```python +deepspeed.utils.groups.initialize(ep_size="desired expert-parallel world size") +``` + +The GPUs (or ranks) participating in an expert-parallel group will distribute the total number of experts specified by the model training code argument num_experts. + +### MoE layer API + +The hidden_size is the input dimension of a particular layer and the output dimension is the same as that. This could lead to some changes to your model definition, especially for vision/convolutional models because the input/output dimensions don't match in certain cases. E.g. in the CIFAR-10 example, we modify the third fully connected layer to add the MoE layer. To cater for this, we need to add an additional fully-connected layer, whose input dimension is equal to the output dimension of the MoE layer. + +Original model config + +```python + self.fc3 = nn.Linear(84, 10) +``` + +Updated with MoE Layers + +```python + self.fc3 = nn.Linear(84, 84) + self.fc3 = deepspeed.moe.layer.MoE(hidden_size=84, expert=self.fc3, num_experts=args.num_experts, ...) + self.fc4 = nn.Linear(84, 10) +``` + +### An Example Scenario + +Given a total number of GPUs in our world size and a subset of GPUs in our expert-parallel world as follows. + +```python +WORLD_SIZE = 4 +EP_WORLD_SIZE = 2 +EXPERTS = 8 +``` + +The user code needs to initialize the groups as follows. + +```python +groups.initialize (ep_size=EP_WORLD_SIZE) +``` + +After that, the model code needs to use the deepspeed.moe.layer.MoE API as follows. + +```python +self.experts = deepspeed.moe.layer.MoE(hidden_size=input_dim, expert=ExpertModule(), num_experts=EXPERTS) +``` +With the above two commands, the DeepSpeed runtime will be set to train an MoE model with a total of 8 experts on 4 GPUs in 4 experts/GPU mode. We call this the E + D mode as described earlier in the table. + +For more advanced use case of the groups API including the inter-operability with Megatron style mpu object, watch this space! + + +```python +import torch +import deepspeed +import deepspeed.utils.groups as groups +from deepspeed.moe.layer import MoE + +WORLD_SIZE = 4 +EP_WORLD_SIZE = 2 +EXPERTS = 8 + +groups.initialize(ep_size=EP_WORLD_SIZE) + +fc3 = torch.nn.Linear(84, 84) +fc3 = MoE(hidden_size=84, expert=self.fc3, num_experts=EXPERTS, k=1) +fc4 = torch.nn.Linear(84, 10) + +``` + +For a runnable end-to-end example, please look at [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) + +### Combining ZeRO-Offload and DeepSpeed MoE for very large models + +To use MoE Layers in DeepSpeed, we rely on two parameter groups that are passed to an optimizer. A concrete example to create such groups is available from the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar). + +The relevant function that creates these param groups is as follows. + +```python +def create_moe_param_groups(model): + from deepspeed.moe.utils import is_moe_param + + params_with_weight_decay = {'params': [], 'name': 'weight_decay_params'} + moe_params_with_weight_decay = { + 'params': [], + 'moe': True, + 'name': 'weight_decay_moe_params' + } + + for module_ in model.modules(): + moe_params_with_weight_decay['params'].extend([ + p for n, p in list(module_._parameters.items()) + if p is not None and is_moe_param(p) + ]) + params_with_weight_decay['params'].extend([ + p for n, p in list(module_._parameters.items()) + if p is not None and not is_moe_param(p) + ]) + + return params_with_weight_decay, moe_params_with_weight_decay +``` + +The above param groups can then be fed to the ZeRO stage-2 optimizer as follows. + +```python + +net = Net() + +parameters = create_moe_param_groups(net) + +model_engine, optimizer, trainloader, __ = deepspeed.initialize( + args=args, model=net, model_parameters=parameters, training_data=trainset) +``` + +We are working on automating this functionality in the DeepSpeed ZeRO optimizer so the model training code can be simplified further. + +To run the [cifar10 example](https://github.com/microsoft/DeepSpeedExamples/tree/master/cifar) with ZeRO-Offload (stage 2) and MoE, please set the ds_config flags + +```json +"zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "reduce_scatter": true, + "allgather_bucket_size": 50000000, + "reduce_bucket_size": 50000000, + "overlap_comm": true, + "contiguous_gradients": true, + "cpu_offload": true + } +``` + +An additional optimization to save memory for extremely large model training on limited number of GPUs has also been introduced. Please enable that using the following config flag to the fp16 optimizer in ds_config. + + ```json + "fp16": { + "enabled": true, + "fp16_master_weights_and_grads": true, + } + ``` + + + + + + +## Random Token Selection + +We have devised a new technique called “Random Token Selection” that greatly improves convergence. Random token selection addresses the limitation of biased selection problem in MoE model training. Our upcoming paper describes this technique and its results in detail. This feature is already part of the DeepSpeed runtime and is enabled by default so users can take advantage without any config flags or command-line arguments. + +## Advanced MoE usage + +Watch this space! We plan to add more interesting and detailed examples of using DeepSpeed MoE in the coming weeks. diff --git a/docs/_tutorials/progressive_layer_dropping.md b/docs/_tutorials/progressive_layer_dropping.md index 8a447e97c945..8c184dfc6d21 100755 --- a/docs/_tutorials/progressive_layer_dropping.md +++ b/docs/_tutorials/progressive_layer_dropping.md @@ -1,155 +1,155 @@ ---- -title: "Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping" - ---- - -In this tutorial, we are going to introduce the progressive layer dropping (PLD) in DeepSpeed and provide examples on how to use PLD. PLD allows to train Transformer networks such as BERT 24% faster under the same number of samples and 2.5 times faster to get similar accuracy on downstream tasks. Detailed description of PLD and the experimental results are available in our [technical report](https://arxiv.org/pdf/2010.13369.pdf). - -To illustrate how to use PLD in DeepSpeed, we show how to enable PLD to pre-train a BERT model and fine-tune the pre-trained model on the GLUE datasets. - -## Running Pre-training with DeepSpeed and PLD - -To perform pre-training, one needs to first prepare the datasets. For this part, please refer our [BERT Pre-training](/tutorials/bert-pretraining/) post, which contains detailed information on how to do data downloading and pre-processing. For the below experiment, we use Wikipedia text and Bookcorpus, similar as [Devlin et. al.](https://arxiv.org/abs/1810.04805). - -The main part of pre-training is done in `deepspeed_train.py`, which has -already been modified to use DeepSpeed. The `ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh` is the shell script that launches the pre-training with DeepSpeed and PLD. - -```shell -bash ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh -``` - -Most of the flags in the above script should be familiar if you have stepped through the BERT pre-training [tutorial](/tutorials/bert-pretraining/). To enable training with PLD, one needs to enable PLD in both the client script and in the DeepSpeed engine. To enable PLD in the client script, one needs to add the following command line flag to enable progressive layer dropping on Transformer blocks. - -```bash ---progressive_layer_drop -``` - -To enable PLD in DeepSpeed, one needs to update the json configuration file with an appropriate PLD configuration dictionary like below: - -```json -{ - ... - "progressive_layer_drop": { - "enabled": true, - "theta": 0.5, - "gamma": 0.001 - } -} -``` - -we recommend a PLD theta value of 0.5 and gamma of 0.001 because these have worked well in our experiments. - -With these configuration changes, the DeepSpeed engine should print a runtime message as below: - - [INFO] [logging.py:60:log_dist] [Rank 0] Enabled progressive layer dropping (theta = 0.5) - -The `deepspeed_bsz4k_progressive_layer_drop_config_seq128.json` file allows users to specify DeepSpeed options in terms of batch size, micro batch size, optimizer, learning rate, sequence length, and other parameters. Below is the DeepSpeed configuration file we use for running BERT and PLD. - -```json -{ - "train_batch_size": 4096, - "train_micro_batch_size_per_gpu": 16, - "steps_per_print": 1000, - "prescale_gradients": true, - "gradient_predivide_factor": 8, - "optimizer": { - "type": "Adam", - "params": { - "lr": 1e-3, - "weight_decay": 0.01, - "bias_correction": false - } - }, - "gradient_clipping": 1.0, - - "wall_clock_breakdown": false, - - "fp16": { - "enabled": true, - "loss_scale": 0 - }, - - "progressive_layer_drop": { - "enabled": true, - "theta": 0.5, - "gamma": 0.001 - } -} -``` - -Note that the above configuration assumes training on 64 X 32GB V100 GPUs. Each GPU uses a micro batch size of 16 and accumulates gradients until the effective batch size reaches 4096. If you have GPUs with less memory, you may need to reduce "train_micro_batch_size_per_gpu". Alternatively, if you have more GPUs, you can increase the "train_batch_size" to increase training speed. We use the following hyperparameters for pre-training BERT with PLD enabled. - -| Parameters | Value | -| ------------------------------ | ----------------------- | -| Effective batch size | 4K | -| Train micro batch size per GPU | 16 | -| Optimizer | Adam | -| Peak learning rate | 1e-3 | -| Sequence-length | 128 | -| Learning rate scheduler | Warmup linear decay exp | -| Warmup ratio | 0.02 | -| Decay rate | 0.99 | -| Decay step | 1000 | -| Weight decay | 0.01 | -| Gradient clipping | 1.0 | - -Table 1. Pre-training hyperparameters - -**Note:** DeepSpeed now supports PreLayerNorm as the default way for training BERT, because of its ability to avoid vanishing gradient, stabilize optimization, and performance gains, as described in our fastest BERT training [blog post](https://www.deepspeed.ai/news/2020/05/27/fastest-bert-training.html). We therefore support the switchable Transformer block directly on the the BERT with PreLayerNorm. The implementation can be found at "example\bing_bert\nvidia\modelingpreln_layerdrop.py". - -## Fine-tuning with DeepSpeed on GLUE Tasks - -We use GLUE for fine-tuning tasks. GLUE (General Language Understanding Evaluation benchmark) (https://gluebenchmark.com/) is a collection of sentence or sentence-pair natural language understanding tasks including question answering, sentiment analysis, and textual entailment. It is designed to favor sample-efficient learning and knowledge-transfer across a range of different linguistic tasks in different domains. - -One can download all GLUE data using the provided helper [script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e). Once the data has been downloaded, one can set up the data and move the data to "/data/GlueData", which is the default location for hosting GLUE data. We then can use the PLD pre-trained BERT model checkpoint to run the fine-tuning. - -The main part of fine-tuning is done in `run_glue_classifier_bert_base.py`, which has -already been modified to use DeepSpeed. Before the fine-tuning, one needs to specify the BERT model configuration through the following config in `run_glue_classifier_bert_base.py`. In this case, it has already been modified to be the same as the configuration of the pre-trained model. - -```json - bert_model_config = { - "vocab_size_or_config_json_file": 119547, - "hidden_size": 768, - "num_hidden_layers": 12, - "num_attention_heads": 12, - "intermediate_size": 3072, - "hidden_act": "gelu", - "hidden_dropout_prob": 0.1, - "attention_probs_dropout_prob": 0.1, - "max_position_embeddings": 512, - "type_vocab_size": 2, - "initializer_range": 0.02 - } -``` - -Next, one can load a DeepSpeed style checkpoint with the following command, which has also already been added in the script. - -```shell -model.load_state_dict(checkpoint_state_dict['module'], strict=False) -``` - -Finally, the `run_glue_classifier_bert_base.sh` script invokes pre-training and setups several hyperparameters relevant to fine-tuning. - -```shell -bash run_glue_bert_base_finetune.sh [task] [batch size] [learning rate] [number of epochs] [job name] [checkpoint path] -``` - -An example would be: - -```shell -bash run_glue_bert_base_finetune.sh MNLI 32 3e-5 5 "fine_tune_MNLI" deepspeed_checkpoint.pt -``` - - - -### Expected Results - -The fine-tuning results can be found under the "logs" directory, and below are expected results for PLD on GLUE tasks. The "Lr" row indicates the learning rate we use for getting the corresponding accuracy result for each task. - -| | RTE | MRPC | STS-B | CoLA | SST-2 | QNLI | QQP | MNLI-m/mm | GLUE | -| ---------------------- | :--: | --------- | --------- | ---- | ----- | ---- | --------- | --------- | ---- | -| Metrics | Acc. | F1/Acc. | PCC/SCC | Acc. | Acc. | Acc. | F1/Acc. | Acc. | | -| Bert_{base} (original) | 66.4 | 88.9/84.8 | 87.1/89.2 | 52.1 | 93.5 | 90.5 | 71.2/89.2 | 84.6/83.4 | 80.7 | -| Bert_{base} (Our impl) | 67.8 | 88.0/86.0 | 89.5/89.2 | 52.5 | 91.2 | 87.1 | 89.0/90.6 | 82.5/83.4 | 82.1 | -| PLD | 69.3 | 86.6/84.3 | 90.0/89.6 | 55.8 | 91.6 | 90.7 | 89.6/91.2 | 84.1/83.8 | 82.9 | -| Lr | 7e-5 | 9e-5 | 7e-5 | 5e-5 | 7e-5 | 9e-5 | 2e-4 | 3e-5 | | +--- +title: "Accelerating Training of Transformer-Based Language Models with Progressive Layer Dropping" + +--- + +In this tutorial, we are going to introduce the progressive layer dropping (PLD) in DeepSpeed and provide examples on how to use PLD. PLD allows to train Transformer networks such as BERT 24% faster under the same number of samples and 2.5 times faster to get similar accuracy on downstream tasks. Detailed description of PLD and the experimental results are available in our [technical report](https://arxiv.org/pdf/2010.13369.pdf). + +To illustrate how to use PLD in DeepSpeed, we show how to enable PLD to pre-train a BERT model and fine-tune the pre-trained model on the GLUE datasets. + +## Running Pre-training with DeepSpeed and PLD + +To perform pre-training, one needs to first prepare the datasets. For this part, please refer our [BERT Pre-training](/tutorials/bert-pretraining/) post, which contains detailed information on how to do data downloading and pre-processing. For the below experiment, we use Wikipedia text and Bookcorpus, similar as [Devlin et. al.](https://arxiv.org/abs/1810.04805). + +The main part of pre-training is done in `deepspeed_train.py`, which has +already been modified to use DeepSpeed. The `ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh` is the shell script that launches the pre-training with DeepSpeed and PLD. + +```shell +bash ds_train_bert_progressive_layer_drop_bsz4k_seq128.sh +``` + +Most of the flags in the above script should be familiar if you have stepped through the BERT pre-training [tutorial](/tutorials/bert-pretraining/). To enable training with PLD, one needs to enable PLD in both the client script and in the DeepSpeed engine. To enable PLD in the client script, one needs to add the following command line flag to enable progressive layer dropping on Transformer blocks. + +```bash +--progressive_layer_drop +``` + +To enable PLD in DeepSpeed, one needs to update the json configuration file with an appropriate PLD configuration dictionary like below: + +```json +{ + ... + "progressive_layer_drop": { + "enabled": true, + "theta": 0.5, + "gamma": 0.001 + } +} +``` + +we recommend a PLD theta value of 0.5 and gamma of 0.001 because these have worked well in our experiments. + +With these configuration changes, the DeepSpeed engine should print a runtime message as below: + + [INFO] [logging.py:60:log_dist] [Rank 0] Enabled progressive layer dropping (theta = 0.5) + +The `deepspeed_bsz4k_progressive_layer_drop_config_seq128.json` file allows users to specify DeepSpeed options in terms of batch size, micro batch size, optimizer, learning rate, sequence length, and other parameters. Below is the DeepSpeed configuration file we use for running BERT and PLD. + +```json +{ + "train_batch_size": 4096, + "train_micro_batch_size_per_gpu": 16, + "steps_per_print": 1000, + "prescale_gradients": true, + "gradient_predivide_factor": 8, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3, + "weight_decay": 0.01, + "bias_correction": false + } + }, + "gradient_clipping": 1.0, + + "wall_clock_breakdown": false, + + "fp16": { + "enabled": true, + "loss_scale": 0 + }, + + "progressive_layer_drop": { + "enabled": true, + "theta": 0.5, + "gamma": 0.001 + } +} +``` + +Note that the above configuration assumes training on 64 X 32GB V100 GPUs. Each GPU uses a micro batch size of 16 and accumulates gradients until the effective batch size reaches 4096. If you have GPUs with less memory, you may need to reduce "train_micro_batch_size_per_gpu". Alternatively, if you have more GPUs, you can increase the "train_batch_size" to increase training speed. We use the following hyperparameters for pre-training BERT with PLD enabled. + +| Parameters | Value | +| ------------------------------ | ----------------------- | +| Effective batch size | 4K | +| Train micro batch size per GPU | 16 | +| Optimizer | Adam | +| Peak learning rate | 1e-3 | +| Sequence-length | 128 | +| Learning rate scheduler | Warmup linear decay exp | +| Warmup ratio | 0.02 | +| Decay rate | 0.99 | +| Decay step | 1000 | +| Weight decay | 0.01 | +| Gradient clipping | 1.0 | + +Table 1. Pre-training hyperparameters + +**Note:** DeepSpeed now supports PreLayerNorm as the default way for training BERT, because of its ability to avoid vanishing gradient, stabilize optimization, and performance gains, as described in our fastest BERT training [blog post](https://www.deepspeed.ai/news/2020/05/27/fastest-bert-training.html). We therefore support the switchable Transformer block directly on the the BERT with PreLayerNorm. The implementation can be found at "example\bing_bert\nvidia\modelingpreln_layerdrop.py". + +## Fine-tuning with DeepSpeed on GLUE Tasks + +We use GLUE for fine-tuning tasks. GLUE (General Language Understanding Evaluation benchmark) (https://gluebenchmark.com/) is a collection of sentence or sentence-pair natural language understanding tasks including question answering, sentiment analysis, and textual entailment. It is designed to favor sample-efficient learning and knowledge-transfer across a range of different linguistic tasks in different domains. + +One can download all GLUE data using the provided helper [script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e). Once the data has been downloaded, one can set up the data and move the data to "/data/GlueData", which is the default location for hosting GLUE data. We then can use the PLD pre-trained BERT model checkpoint to run the fine-tuning. + +The main part of fine-tuning is done in `run_glue_classifier_bert_base.py`, which has +already been modified to use DeepSpeed. Before the fine-tuning, one needs to specify the BERT model configuration through the following config in `run_glue_classifier_bert_base.py`. In this case, it has already been modified to be the same as the configuration of the pre-trained model. + +```json + bert_model_config = { + "vocab_size_or_config_json_file": 119547, + "hidden_size": 768, + "num_hidden_layers": 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 2, + "initializer_range": 0.02 + } +``` + +Next, one can load a DeepSpeed style checkpoint with the following command, which has also already been added in the script. + +```shell +model.load_state_dict(checkpoint_state_dict['module'], strict=False) +``` + +Finally, the `run_glue_classifier_bert_base.sh` script invokes pre-training and setups several hyperparameters relevant to fine-tuning. + +```shell +bash run_glue_bert_base_finetune.sh [task] [batch size] [learning rate] [number of epochs] [job name] [checkpoint path] +``` + +An example would be: + +```shell +bash run_glue_bert_base_finetune.sh MNLI 32 3e-5 5 "fine_tune_MNLI" deepspeed_checkpoint.pt +``` + + + +### Expected Results + +The fine-tuning results can be found under the "logs" directory, and below are expected results for PLD on GLUE tasks. The "Lr" row indicates the learning rate we use for getting the corresponding accuracy result for each task. + +| | RTE | MRPC | STS-B | CoLA | SST-2 | QNLI | QQP | MNLI-m/mm | GLUE | +| ---------------------- | :--: | --------- | --------- | ---- | ----- | ---- | --------- | --------- | ---- | +| Metrics | Acc. | F1/Acc. | PCC/SCC | Acc. | Acc. | Acc. | F1/Acc. | Acc. | | +| Bert_{base} (original) | 66.4 | 88.9/84.8 | 87.1/89.2 | 52.1 | 93.5 | 90.5 | 71.2/89.2 | 84.6/83.4 | 80.7 | +| Bert_{base} (Our impl) | 67.8 | 88.0/86.0 | 89.5/89.2 | 52.5 | 91.2 | 87.1 | 89.0/90.6 | 82.5/83.4 | 82.1 | +| PLD | 69.3 | 86.6/84.3 | 90.0/89.6 | 55.8 | 91.6 | 90.7 | 89.6/91.2 | 84.1/83.8 | 82.9 | +| Lr | 7e-5 | 9e-5 | 7e-5 | 5e-5 | 7e-5 | 9e-5 | 2e-4 | 3e-5 | | diff --git a/docs/_tutorials/zero-offload.md b/docs/_tutorials/zero-offload.md index 8b0f56ec510f..afc916d8fc33 100644 --- a/docs/_tutorials/zero-offload.md +++ b/docs/_tutorials/zero-offload.md @@ -1,75 +1,75 @@ ---- -title: "ZeRO-Offload" ---- -ZeRO-3 Offload consists of a subset of features in our newly released ZeRO-Infinity. Read our [ZeRO-Infinity blog](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) to learn more! - -We recommend that you read the tutorials on [Getting Started](/getting-started/) and [ZeRO](/tutorials/zero/) before stepping through this tutorial. - -ZeRO-Offload is a ZeRO optimization that offloads the optimizer memory and computation from the GPU to the host CPU. ZeRO-Offload enables large models with up to 13 billion parameters to be efficiently trained on a single GPU. In this tutorial we will use ZeRO-Offload to train a 10-billion parameter GPT-2 model in DeepSpeed. Furthermore, *using ZeRO-Offload in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration json*. No code changes are needed. - -## ZeRO-Offload Overview -For large model training, optimizers such as [Adam](https://arxiv.org/abs/1412.6980), can consume a significant amount of GPU compute and memory. ZeRO-Offload reduces the GPU compute and memory requirements of such models by leveraging compute and memory resources on the host CPU to execute the optimizer. Furthermore, to prevent the optimizer from becoming a bottleneck, ZeRO-Offload uses DeepSpeed's highly optimized CPU implementation of Adam called [DeeSpeedCPUAdam](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/ops/adam). DeepSpeedCPUAdam is 5X--7X faster than the standard PyTorch implementation. To deep dive into the design and performance of ZeRO-Offload, please see our [blog post](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-3). - -## Training Environment -For this tutorial, we will configure a 10 billion parameter GPT-2 model using the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3) GPT-2 code. We advise stepping through the Megatron-LM [tutorial](/tutorials/megatron/) if you have not previously done so. We will use a single [NVIDIA Tesla V100-SXM3 Tensor Core GPU](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM for this exercise. - -## Training a 10B parameter GPT-2 on 1 V100 GPU -We need to make changes to the Megatron-LM launch script and to the DeepSpeed configuration json. - -### Megatron-LM GPT-2 launch script changes -We need to apply two changes to the launch script for the DeepSpeed Megatron-LM GPT-2 model. The first change is to configure a 10B parameter GPT-2 model with activation checkpointing enabled, which can be achieved by the following set of changes: - -```bash - --model-parallel-size 1 \ - --num-layers 50 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --batch-size 10 \ - --deepspeed_config ds_zero_offload.config \ - --checkpoint-activations -``` - -Most of the flags in the changes above should be familiar if you have stepped through the Megatron-LM [tutorial](/tutorials/megatron/). - -Second, we need to apply the following changes to ensure that only one GPU is used for training. -```bash - deepspeed --num_nodes 1 --num_gpus 1 ... -``` - -### DeepSpeed Configuration Changes -ZeRO-Offload leverages much for ZeRO stage 2 mechanisms, and so the configuration changes to enable ZeRO-Offload is an extension of those required to enable ZeRO stage 2. The **zero_optimization** key to enable ZeRO-Offload is shown below: - -```json -{ - "zero_optimization": { - "stage": 2, - "cpu_offload": true, - "contiguous_gradients": true, - "overlap_comm": true - } -} -``` - -As seen above, in addition to setting the _stage_ field to **2** (to enable ZeRO stage 2), we also need to set _cpu_offload_ flag to **true** to enable ZeRO-Offload optimizations. In addition, we can set other ZeRO stage 2 optimization flags, such as _overlap_comm_ to tune ZeRO-Offload performance. With these changes we can now run the model. We share some screenshots of the training below. - -Here is a screenshot of the training log: - - - - - - -Here is a screenshot of nvidia-smi showing that only GPU 0 is active during training: - - - - - -Finally, here is a screenshot of htop showing host CPU and memory activity during optimizer computation: - - - - - -Congratulations! You have completed the ZeRO-Offload tutorial. - +--- +title: "ZeRO-Offload" +--- +ZeRO-3 Offload consists of a subset of features in our newly released ZeRO-Infinity. Read our [ZeRO-Infinity blog](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/) to learn more! + +We recommend that you read the tutorials on [Getting Started](/getting-started/) and [ZeRO](/tutorials/zero/) before stepping through this tutorial. + +ZeRO-Offload is a ZeRO optimization that offloads the optimizer memory and computation from the GPU to the host CPU. ZeRO-Offload enables large models with up to 13 billion parameters to be efficiently trained on a single GPU. In this tutorial we will use ZeRO-Offload to train a 10-billion parameter GPT-2 model in DeepSpeed. Furthermore, *using ZeRO-Offload in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration json*. No code changes are needed. + +## ZeRO-Offload Overview +For large model training, optimizers such as [Adam](https://arxiv.org/abs/1412.6980), can consume a significant amount of GPU compute and memory. ZeRO-Offload reduces the GPU compute and memory requirements of such models by leveraging compute and memory resources on the host CPU to execute the optimizer. Furthermore, to prevent the optimizer from becoming a bottleneck, ZeRO-Offload uses DeepSpeed's highly optimized CPU implementation of Adam called [DeeSpeedCPUAdam](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/ops/adam). DeepSpeedCPUAdam is 5X--7X faster than the standard PyTorch implementation. To deep dive into the design and performance of ZeRO-Offload, please see our [blog post](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-3). + +## Training Environment +For this tutorial, we will configure a 10 billion parameter GPT-2 model using the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3) GPT-2 code. We advise stepping through the Megatron-LM [tutorial](/tutorials/megatron/) if you have not previously done so. We will use a single [NVIDIA Tesla V100-SXM3 Tensor Core GPU](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM for this exercise. + +## Training a 10B parameter GPT-2 on 1 V100 GPU +We need to make changes to the Megatron-LM launch script and to the DeepSpeed configuration json. + +### Megatron-LM GPT-2 launch script changes +We need to apply two changes to the launch script for the DeepSpeed Megatron-LM GPT-2 model. The first change is to configure a 10B parameter GPT-2 model with activation checkpointing enabled, which can be achieved by the following set of changes: + +```bash + --model-parallel-size 1 \ + --num-layers 50 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --batch-size 10 \ + --deepspeed_config ds_zero_offload.config \ + --checkpoint-activations +``` + +Most of the flags in the changes above should be familiar if you have stepped through the Megatron-LM [tutorial](/tutorials/megatron/). + +Second, we need to apply the following changes to ensure that only one GPU is used for training. +```bash + deepspeed --num_nodes 1 --num_gpus 1 ... +``` + +### DeepSpeed Configuration Changes +ZeRO-Offload leverages much for ZeRO stage 2 mechanisms, and so the configuration changes to enable ZeRO-Offload is an extension of those required to enable ZeRO stage 2. The **zero_optimization** key to enable ZeRO-Offload is shown below: + +```json +{ + "zero_optimization": { + "stage": 2, + "cpu_offload": true, + "contiguous_gradients": true, + "overlap_comm": true + } +} +``` + +As seen above, in addition to setting the _stage_ field to **2** (to enable ZeRO stage 2), we also need to set _cpu_offload_ flag to **true** to enable ZeRO-Offload optimizations. In addition, we can set other ZeRO stage 2 optimization flags, such as _overlap_comm_ to tune ZeRO-Offload performance. With these changes we can now run the model. We share some screenshots of the training below. + +Here is a screenshot of the training log: + + + + + + +Here is a screenshot of nvidia-smi showing that only GPU 0 is active during training: + + + + + +Finally, here is a screenshot of htop showing host CPU and memory activity during optimizer computation: + + + + + +Congratulations! You have completed the ZeRO-Offload tutorial. + diff --git a/docs/_tutorials/zero.md b/docs/_tutorials/zero.md index 01595c11394b..adc0bbb40963 100644 --- a/docs/_tutorials/zero.md +++ b/docs/_tutorials/zero.md @@ -1,301 +1,301 @@ ---- -title: "Zero Redundancy Optimizer (ZeRO)" ---- -If you have not done so already, we advise that you read the DeepSpeed tutorials on [Getting Started](/getting-started/) and [Megatron-LM GPT-2](/tutorials/megatron/) before stepping through this tutorial. - -In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with trillions of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration JSON*. No code changes are needed. - -## ZeRO Overview -ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. To deep dive into ZeRO, please see our [paper](https://arxiv.org/abs/1910.02054v3). - -* **Stage 1**: The optimizer states (e.g., for [Adam optimizer](https://arxiv.org/abs/1412.6980), 32-bit weights, and the first, and second moment estimates) are partitioned across the processes, so that each process updates only its partition. - -* **Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states. - -* **Stage 3**: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes. - -In addition, ZeRO-3 includes the *infinity offload engine* to form ZeRO-Infinity ([paper](https://arxiv.org/abs/2104.07857)), which can offload to both CPU and NVMe memory for huge memory savings. - -## Training environment -We use the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3) GPT-2 code for this exercise. You can step through the Megatron-LM [tutorial](/tutorials/megatron/) to familiarize yourself with the code. We will train the models in this tutorial on [NVIDIA Tesla V100-SXM3 Tensor Core GPUs](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM. - -## Enabling ZeRO Optimization -To enable ZeRO optimizations for a DeepSpeed model, we simply add the **_zero_optimization_** key to the DeepSpeed JSON configuration. A full description of configuration knobs of the **zero_optimization** key is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). - -### Training a 1.5B Parameter GPT-2 model -We demonstrate the benefits of ZeRO stage 1 by showing that it enables data parallel training of a 1.5 billion parameter GPT-2 model on eight V100 GPUs. We configure training to use a batch size of 1 per device to ensure that the memory consumption is primarily due to model parameters and optimizer states. We create this training scenario by applying the following modifications to the deepspeed launch script: - -```bash - --model-parallel-size 1 \ - --num-layers 48 \ - --hidden-size 1600 \ - --num-attention-heads 16 \ - --batch-size 1 \ - --deepspeed_config ds_zero_stage_1.config \ -``` - -Training this model without ZeRO fails with an out-of-memory (OOM) error as shown below: - - - - - -A key reason why this model does not fit in GPU memory is that the Adam optimizer states for the model consume 18GB; a significant portion of the 32GB RAM. By using ZeRO stage 1 to partition the optimizer state among eight data parallel ranks, the per-device memory consumption can be reduced to 2.25GB, thus making the model trainable. To enable ZeRO stage 1, we simply update the DeepSpeed JSON config file as below: - -```json -{ - "zero_optimization": { - "stage":1, - "reduce_bucket_size": 5e8 - } -} -``` -As seen above, we set two fields in the **zero_optimization** key. Specifically we set the _stage_ field to 1, and the optional _reduce_bucket_size_ for gradient reduction to 500M. With ZeRO stage 1 enabled, the model can now train smoothly on 8 GPUs without running out of memory. Below we provide some screenshots of the model training: - - - - - - - - - - - -From the nvidia-smi screenshot above we can see that only GPUs 6-7 are being used for training the model. With ZeRO stage 1 we can further reduce the per-device memory consumption by increasing the data parallelism degree. These memory savings can be leveraged to either increase model size and/or batch size. In contrast, such benefits are not possible with data parallelism alone. - -### Training a 10B Parameter GPT-2 model -ZeRO stage 2 optimizations further increases the size of models that can be trained using data parallelism. We show this by training a model with 10B parameters using 32 V100 GPUs. - -First, we need to configure a 10B parameter model with activation checkpointing enabled. This can be done by applying the following GPT-2 model configuration changes to the DeepSpeed launch script. - -```bash - --model-parallel-size 1 \ - --num-layers 50 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --batch-size 1 \ - --deepspeed_config ds_zero_stage_2.config \ - --checkpoint-activations -``` - -Next, we need to update the DeepSpeed JSON configuration, as shown below, to enable ZeRO stage 2 optimizations: - -```json -{ - "zero_optimization": { - "stage":2, - "contiguous_gradients": true, - "overlap_comm": true, - "reduce_scatter": true, - "reduce_bucket_size": 5e8, - "allgather_bucket_size": 5e8 - } -} -``` - -In the above changes, we have set the _stage_ field to 2, and configured other optimization knobs that are available in ZeRO stage 2. For example, we have enabled _contiguous_gradients_ to reduce memory fragmentation during backward pass. A full description of these optimization knobs is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). With these changes, we can now launch the training run. - -Here is a screenshot of the training log: - - - - - -Here is a screenshot of nvidia-smi showing GPU activity during training: - - - - - -### Training trillion-scale models with ZeRO-Infinity - -ZeRO-3, the third stage of ZeRO, partitions the full model state (i.e., -weights, gradients, and optimizer states) to scale memory savings linearly -with the degree of data parallelism. ZeRO-3 can be enabled in the JSON -configuration. A full description of these configurations is available -[here](/docs/config-json/#zero-optimizations-for-fp16-training). - - -#### Offloading to CPU and NVMe with ZeRO-Infinity - -ZeRO-Infinity uses DeepSpeed's infinity offload engine to offload the full -model state to CPU or NVMe memory, allowing for even larger model sizes. Offloading -can be enabled inside the DeepSpeed configuration: - -```diff -@@ -6,5 +6,11 @@ - "zero_optimization": { - "stage": 3, - "contiguous_gradients": true, - "stage3_max_live_parameters": 1e9, - "stage3_max_reuse_distance": 1e9, - "stage3_prefetch_bucket_size": 1e7, - "stage3_param_persistence_threshold": 1e5, - "reduce_bucket_size": 1e7, -- "sub_group_size": 1e9 -+ "sub_group_size": 1e9, -+ "offload_optimizer": { -+ "device": "cpu" -+ }, -+ "offload_param": { -+ "device": "cpu" -+ } - } -``` - -**ZeRO-Infinity vs ZeRO-Offload:** -DeepSpeed first included offloading capabilities with ZeRO-Offload, -a system for offloading optimizer and gradient states to CPU memory -within ZeRO-2. ZeRO-Infinity is the next generation of offloading -capabilities accessible to ZeRO-3. ZeRO-Infinity is able to offload -more data than ZeRO-Offload and has more effective bandwidth utilization -and overlapping of computation and communication. -{: .notice--info} - - - - -#### Allocating Massive Megatron-LM Models - -We make two further changes to model initialization in order to support models -that exceed *local* system memory, but not *total* system memory. - -1. Allocate the model in a memory-scalable fashion. The model parameters will -be allocated and immediately partitioned across the data parallel group. If -`remote_device` is `"cpu"` or `"nvme"`, the model will also be allocated in CPU/NVMe memory -instead of GPU memory. Please see the full -[ZeRO-3 Init docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.Init) -for more details. - - ```python - with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), - remote_device=get_args().remote_device, - enabled=get_args().zero_stage==3): - model = GPT2Model(num_tokentypes=0, parallel_output=True) - ``` - -2. Gather the embeddings weight for initialization. DeepSpeed will automatically -gather a module's parameters during its constructor and for its forward and backward pass. -However, additional accesses must coordinate with DeepSpeed to ensure that parameter data -is gathered and subsequently partitioned. If the tensor is modified, the `modifier_rank` -argument should also be used to ensure all ranks have a consistent view of -the data. Please see the full -[GatheredParameters docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.GatheredParameters) -for more details. - - ```python - self.position_embeddings = torch.nn.Embedding(...) - with deepspeed.zero.GatheredParameters(self.position_embeddings.weight, - modifier_rank=0): - # Initialize the position embeddings. - self.init_method(self.position_embeddings.weight) - - ... - - self.tokentype_embeddings = torch.nn.Embedding(...) - with deepspeed.zero.GatheredParameters(self.tokentype_embeddings.weight, - modifier_rank=0): - # Initialize the token-type embeddings. - self.init_method(self.tokentype_embeddings.weight) - ``` - -#### Memory-centric tiling -ZeRO-Infinity includes a replacement for `Linear` layers that further reduces memory. -We optionally tile the model parallel linear layers found in each Transformer layer. Note -that model parallelism and tiling can be combined by specifying the corresponding -base class when building the layer. -The `deepspeed.zero.TiledLinear` module exploits the data fetch and release -pattern of ZeRO-3 to reduce the working memory requirements by breaking down -a large operator into smaller tiles that can be executed sequentially. - -We include the changes for one example from Megatron-LM's [ParallelMLP](https://github.com/microsoft/DeepSpeedExamples/blob/bdf8e59aede8c8e0577e8d4d557298ca8515268f/Megatron-LM-v1.1.5-ZeRO3/megatron/model/transformer.py#L82). Three more -model-parallel layers in `transformer.py` proceed similarly. - -The model parallel layers of Megatron-LM have a special form in which the -additive `bias` of the layer is delayed and instead returned from `forward()` -to be fused with a later operator. DeepSpeed's -`deepspeed.zero.TiledLinearReturnBias` subclass of `TiledLinear` simply also -forwards the returned `bias` parameter without accumulating. - -```diff -@@ -1,6 +1,9 @@ --self.dense_h_to_4h = mpu.ColumnParallelLinear( -+self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias( - args.hidden_size, - 4 * args.hidden_size, -+ in_splits=args.tile_factor, -+ out_splits=4*args.tile_factor, -+ linear_cls=mpu.ColumnParallelLinear, - gather_output=False, - init_method=init_method, - skip_bias_add=True) -``` - -Note that we scale `in_splits` and `out_splits` proportionally with `input_size` and `output_size`. This -results in tiles of fixed size `[hidden/tile_factor, hidden/tile_factor]`. - -#### Registering external parameters - -**Deprecated:** -DeepSpeed version `0.3.15` introduced automatic external parameter -registration and this step is no longer needed. -{: .notice--info} - - -## Extracting weights - -If you need to take the pretrained weights out of Deepspeed here is what you can do for getting fp16 weights: - -- under ZeRO-2 `state_dict` contains the fp16 model weights and these can be saved normally with `torch.save`. -- under ZeRO-3 `state_dict` contains just the placeholders since the model weights are partitioned across multiple GPUs. If you want to get to these weights enable: - -```json - "zero_optimization": { - "stage3_gather_fp16_weights_on_model_save": true - }, -``` -And then save the model using: - -```python - if self.deepspeed: - self.deepspeed.save_fp16_model(output_dir, output_file) -``` - -Because it requires consolidation of the weights on one GPU it can be slow and memory demanding, so only use this feature when needed. - -Note that if `stage3_gather_fp16_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them. -You can use this method to save ZeRO-2 weights as well. - -If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage: - -``` bash -$ cd /path/to/checkpoint_dir -$ ./zero_to_fp32.py . pytorch_model.bin -Processing zero checkpoint at global_step1 -Detected checkpoint of type zero stage 3, world_size: 2 -Saving fp32 state dict to pytorch_model.bin (total_numel=60506624) -``` - -The `zero_to_fp32.py` gets created automatically when you save a checkpoint. - -Note: currently this script uses 2x memory (general RAM) of the size of the final checkpoint. - -Alternatively, if you have plenty of spare CPU memory and instead of getting the file you want your model to be updated to its fp32 weights, you can do the following at the end of the training: - -``` python - from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint - fp32_model = load_state_dict_from_zero_checkpoint(deepspeed.module, checkpoint_dir) -``` - -Beware, that the model will be good for saving, but no longer good for continuing the training and will require a `deepspeed.initialize()` anew. - -If you just want the `state_dict`, you can do: - -``` python - from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) -``` - - -Congratulations! You have completed the ZeRO tutorial. +--- +title: "Zero Redundancy Optimizer (ZeRO)" +--- +If you have not done so already, we advise that you read the DeepSpeed tutorials on [Getting Started](/getting-started/) and [Megatron-LM GPT-2](/tutorials/megatron/) before stepping through this tutorial. + +In this tutorial, we will apply the ZeRO optimizer to the [Megatron-LM GPT-2](https://github.com/NVIDIA/Megatron-LM) model. ZeRO is a powerful set of memory optimization techniques that enable effective FP16 training of large models with trillions of parameters, such as [GPT-2](https://openai.com/blog/better-language-models/) and [Turing-NLG 17B](https://www.microsoft.com/en-us/research/blog/turing-nlg-a-17-billion-parameter-language-model-by-microsoft/). Compared to the alternative model parallelism approaches for training large models, a key appeal of ZeRO is that no model code modifications are required. As this tutorial will demonstrate, *using ZeRO in a DeepSpeed model is quick and easy because all you need is to change a few configurations in the DeepSpeed configuration JSON*. No code changes are needed. + +## ZeRO Overview +ZeRO leverages the aggregate computation and memory resources of data parallelism to reduce the memory and compute requirements of each device (GPU) used for model training. ZeRO reduces the memory consumption of each GPU by partitioning the various model training states (weights, gradients, and optimizer states) across the available devices (GPUs and CPUs) in the distributed training hardware. Concretely, ZeRO is being implemented as incremental stages of optimizations, where optimizations in earlier stages are available in the later stages. To deep dive into ZeRO, please see our [paper](https://arxiv.org/abs/1910.02054v3). + +* **Stage 1**: The optimizer states (e.g., for [Adam optimizer](https://arxiv.org/abs/1412.6980), 32-bit weights, and the first, and second moment estimates) are partitioned across the processes, so that each process updates only its partition. + +* **Stage 2**: The reduced 32-bit gradients for updating the model weights are also partitioned such that each process retains only the gradients corresponding to its portion of the optimizer states. + +* **Stage 3**: The 16-bit model parameters are partitioned across the processes. ZeRO-3 will automatically collect and partition them during the forward and backward passes. + +In addition, ZeRO-3 includes the *infinity offload engine* to form ZeRO-Infinity ([paper](https://arxiv.org/abs/2104.07857)), which can offload to both CPU and NVMe memory for huge memory savings. + +## Training environment +We use the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM-v1.1.5-ZeRO3) GPT-2 code for this exercise. You can step through the Megatron-LM [tutorial](/tutorials/megatron/) to familiarize yourself with the code. We will train the models in this tutorial on [NVIDIA Tesla V100-SXM3 Tensor Core GPUs](https://www.nvidia.com/en-us/data-center/v100/) with 32GB RAM. + +## Enabling ZeRO Optimization +To enable ZeRO optimizations for a DeepSpeed model, we simply add the **_zero_optimization_** key to the DeepSpeed JSON configuration. A full description of configuration knobs of the **zero_optimization** key is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). + +### Training a 1.5B Parameter GPT-2 model +We demonstrate the benefits of ZeRO stage 1 by showing that it enables data parallel training of a 1.5 billion parameter GPT-2 model on eight V100 GPUs. We configure training to use a batch size of 1 per device to ensure that the memory consumption is primarily due to model parameters and optimizer states. We create this training scenario by applying the following modifications to the deepspeed launch script: + +```bash + --model-parallel-size 1 \ + --num-layers 48 \ + --hidden-size 1600 \ + --num-attention-heads 16 \ + --batch-size 1 \ + --deepspeed_config ds_zero_stage_1.config \ +``` + +Training this model without ZeRO fails with an out-of-memory (OOM) error as shown below: + + + + + +A key reason why this model does not fit in GPU memory is that the Adam optimizer states for the model consume 18GB; a significant portion of the 32GB RAM. By using ZeRO stage 1 to partition the optimizer state among eight data parallel ranks, the per-device memory consumption can be reduced to 2.25GB, thus making the model trainable. To enable ZeRO stage 1, we simply update the DeepSpeed JSON config file as below: + +```json +{ + "zero_optimization": { + "stage":1, + "reduce_bucket_size": 5e8 + } +} +``` +As seen above, we set two fields in the **zero_optimization** key. Specifically we set the _stage_ field to 1, and the optional _reduce_bucket_size_ for gradient reduction to 500M. With ZeRO stage 1 enabled, the model can now train smoothly on 8 GPUs without running out of memory. Below we provide some screenshots of the model training: + + + + + + + + + + + +From the nvidia-smi screenshot above we can see that only GPUs 6-7 are being used for training the model. With ZeRO stage 1 we can further reduce the per-device memory consumption by increasing the data parallelism degree. These memory savings can be leveraged to either increase model size and/or batch size. In contrast, such benefits are not possible with data parallelism alone. + +### Training a 10B Parameter GPT-2 model +ZeRO stage 2 optimizations further increases the size of models that can be trained using data parallelism. We show this by training a model with 10B parameters using 32 V100 GPUs. + +First, we need to configure a 10B parameter model with activation checkpointing enabled. This can be done by applying the following GPT-2 model configuration changes to the DeepSpeed launch script. + +```bash + --model-parallel-size 1 \ + --num-layers 50 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --batch-size 1 \ + --deepspeed_config ds_zero_stage_2.config \ + --checkpoint-activations +``` + +Next, we need to update the DeepSpeed JSON configuration, as shown below, to enable ZeRO stage 2 optimizations: + +```json +{ + "zero_optimization": { + "stage":2, + "contiguous_gradients": true, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "allgather_bucket_size": 5e8 + } +} +``` + +In the above changes, we have set the _stage_ field to 2, and configured other optimization knobs that are available in ZeRO stage 2. For example, we have enabled _contiguous_gradients_ to reduce memory fragmentation during backward pass. A full description of these optimization knobs is available [here](/docs/config-json/#zero-optimizations-for-fp16-training). With these changes, we can now launch the training run. + +Here is a screenshot of the training log: + + + + + +Here is a screenshot of nvidia-smi showing GPU activity during training: + + + + + +### Training trillion-scale models with ZeRO-Infinity + +ZeRO-3, the third stage of ZeRO, partitions the full model state (i.e., +weights, gradients, and optimizer states) to scale memory savings linearly +with the degree of data parallelism. ZeRO-3 can be enabled in the JSON +configuration. A full description of these configurations is available +[here](/docs/config-json/#zero-optimizations-for-fp16-training). + + +#### Offloading to CPU and NVMe with ZeRO-Infinity + +ZeRO-Infinity uses DeepSpeed's infinity offload engine to offload the full +model state to CPU or NVMe memory, allowing for even larger model sizes. Offloading +can be enabled inside the DeepSpeed configuration: + +```diff +@@ -6,5 +6,11 @@ + "zero_optimization": { + "stage": 3, + "contiguous_gradients": true, + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_prefetch_bucket_size": 1e7, + "stage3_param_persistence_threshold": 1e5, + "reduce_bucket_size": 1e7, +- "sub_group_size": 1e9 ++ "sub_group_size": 1e9, ++ "offload_optimizer": { ++ "device": "cpu" ++ }, ++ "offload_param": { ++ "device": "cpu" ++ } + } +``` + +**ZeRO-Infinity vs ZeRO-Offload:** +DeepSpeed first included offloading capabilities with ZeRO-Offload, +a system for offloading optimizer and gradient states to CPU memory +within ZeRO-2. ZeRO-Infinity is the next generation of offloading +capabilities accessible to ZeRO-3. ZeRO-Infinity is able to offload +more data than ZeRO-Offload and has more effective bandwidth utilization +and overlapping of computation and communication. +{: .notice--info} + + + + +#### Allocating Massive Megatron-LM Models + +We make two further changes to model initialization in order to support models +that exceed *local* system memory, but not *total* system memory. + +1. Allocate the model in a memory-scalable fashion. The model parameters will +be allocated and immediately partitioned across the data parallel group. If +`remote_device` is `"cpu"` or `"nvme"`, the model will also be allocated in CPU/NVMe memory +instead of GPU memory. Please see the full +[ZeRO-3 Init docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.Init) +for more details. + + ```python + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=get_args().remote_device, + enabled=get_args().zero_stage==3): + model = GPT2Model(num_tokentypes=0, parallel_output=True) + ``` + +2. Gather the embeddings weight for initialization. DeepSpeed will automatically +gather a module's parameters during its constructor and for its forward and backward pass. +However, additional accesses must coordinate with DeepSpeed to ensure that parameter data +is gathered and subsequently partitioned. If the tensor is modified, the `modifier_rank` +argument should also be used to ensure all ranks have a consistent view of +the data. Please see the full +[GatheredParameters docs](https://deepspeed.readthedocs.io/en/latest/zero3.html#deepspeed.zero.GatheredParameters) +for more details. + + ```python + self.position_embeddings = torch.nn.Embedding(...) + with deepspeed.zero.GatheredParameters(self.position_embeddings.weight, + modifier_rank=0): + # Initialize the position embeddings. + self.init_method(self.position_embeddings.weight) + + ... + + self.tokentype_embeddings = torch.nn.Embedding(...) + with deepspeed.zero.GatheredParameters(self.tokentype_embeddings.weight, + modifier_rank=0): + # Initialize the token-type embeddings. + self.init_method(self.tokentype_embeddings.weight) + ``` + +#### Memory-centric tiling +ZeRO-Infinity includes a replacement for `Linear` layers that further reduces memory. +We optionally tile the model parallel linear layers found in each Transformer layer. Note +that model parallelism and tiling can be combined by specifying the corresponding +base class when building the layer. +The `deepspeed.zero.TiledLinear` module exploits the data fetch and release +pattern of ZeRO-3 to reduce the working memory requirements by breaking down +a large operator into smaller tiles that can be executed sequentially. + +We include the changes for one example from Megatron-LM's [ParallelMLP](https://github.com/microsoft/DeepSpeedExamples/blob/bdf8e59aede8c8e0577e8d4d557298ca8515268f/Megatron-LM-v1.1.5-ZeRO3/megatron/model/transformer.py#L82). Three more +model-parallel layers in `transformer.py` proceed similarly. + +The model parallel layers of Megatron-LM have a special form in which the +additive `bias` of the layer is delayed and instead returned from `forward()` +to be fused with a later operator. DeepSpeed's +`deepspeed.zero.TiledLinearReturnBias` subclass of `TiledLinear` simply also +forwards the returned `bias` parameter without accumulating. + +```diff +@@ -1,6 +1,9 @@ +-self.dense_h_to_4h = mpu.ColumnParallelLinear( ++self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias( + args.hidden_size, + 4 * args.hidden_size, ++ in_splits=args.tile_factor, ++ out_splits=4*args.tile_factor, ++ linear_cls=mpu.ColumnParallelLinear, + gather_output=False, + init_method=init_method, + skip_bias_add=True) +``` + +Note that we scale `in_splits` and `out_splits` proportionally with `input_size` and `output_size`. This +results in tiles of fixed size `[hidden/tile_factor, hidden/tile_factor]`. + +#### Registering external parameters + +**Deprecated:** +DeepSpeed version `0.3.15` introduced automatic external parameter +registration and this step is no longer needed. +{: .notice--info} + + +## Extracting weights + +If you need to take the pretrained weights out of Deepspeed here is what you can do for getting fp16 weights: + +- under ZeRO-2 `state_dict` contains the fp16 model weights and these can be saved normally with `torch.save`. +- under ZeRO-3 `state_dict` contains just the placeholders since the model weights are partitioned across multiple GPUs. If you want to get to these weights enable: + +```json + "zero_optimization": { + "stage3_gather_fp16_weights_on_model_save": true + }, +``` +And then save the model using: + +```python + if self.deepspeed: + self.deepspeed.save_fp16_model(output_dir, output_file) +``` + +Because it requires consolidation of the weights on one GPU it can be slow and memory demanding, so only use this feature when needed. + +Note that if `stage3_gather_fp16_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them. +You can use this method to save ZeRO-2 weights as well. + +If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage: + +``` bash +$ cd /path/to/checkpoint_dir +$ ./zero_to_fp32.py . pytorch_model.bin +Processing zero checkpoint at global_step1 +Detected checkpoint of type zero stage 3, world_size: 2 +Saving fp32 state dict to pytorch_model.bin (total_numel=60506624) +``` + +The `zero_to_fp32.py` gets created automatically when you save a checkpoint. + +Note: currently this script uses 2x memory (general RAM) of the size of the final checkpoint. + +Alternatively, if you have plenty of spare CPU memory and instead of getting the file you want your model to be updated to its fp32 weights, you can do the following at the end of the training: + +``` python + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + fp32_model = load_state_dict_from_zero_checkpoint(deepspeed.module, checkpoint_dir) +``` + +Beware, that the model will be good for saving, but no longer good for continuing the training and will require a `deepspeed.initialize()` anew. + +If you just want the `state_dict`, you can do: + +``` python + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) +``` + + +Congratulations! You have completed the ZeRO tutorial. diff --git a/docs/code-docs/source/schedulers.rst b/docs/code-docs/source/schedulers.rst index c7b67cbb20d8..5bc23ffb0acf 100755 --- a/docs/code-docs/source/schedulers.rst +++ b/docs/code-docs/source/schedulers.rst @@ -1,25 +1,25 @@ -Learning Rate Schedulers -=================== - -DeepSpeed offers implementations of ``LRRangeTest``, ``OneCycle``, ``WarmupLR``, ``WarmupDecayLR`` learning rate schedulers. When using a DeepSpeed's learning rate scheduler (specified in the `ds_config.json` file), DeepSpeed calls the `step()` method of the scheduler at every training step (when `model_engine.step()` is executed). When not using a DeepSpeed's learning rate scheduler: - * if the schedule is supposed to execute at every training step, then the user can pass the scheduler to `deepspeed.initialize` when initializing the DeepSpeed engine and let DeepSpeed manage it for update or save/restore. - * if the schedule is supposed to execute at any other interval (e.g., training epochs), then the user should NOT pass the scheduler to DeepSpeed during initialization and must manage it explicitly. - -LRRangeTest ---------------------------- -.. autoclass:: deepspeed.runtime.lr_schedules.LRRangeTest - - -OneCycle ---------------------------- -.. autoclass:: deepspeed.runtime.lr_schedules.OneCycle - - -WarmupLR ---------------------------- -.. autoclass:: deepspeed.runtime.lr_schedules.WarmupLR - - -WarmupDecayLR ---------------------------- -.. autoclass:: deepspeed.runtime.lr_schedules.WarmupDecayLR +Learning Rate Schedulers +=================== + +DeepSpeed offers implementations of ``LRRangeTest``, ``OneCycle``, ``WarmupLR``, ``WarmupDecayLR`` learning rate schedulers. When using a DeepSpeed's learning rate scheduler (specified in the `ds_config.json` file), DeepSpeed calls the `step()` method of the scheduler at every training step (when `model_engine.step()` is executed). When not using a DeepSpeed's learning rate scheduler: + * if the schedule is supposed to execute at every training step, then the user can pass the scheduler to `deepspeed.initialize` when initializing the DeepSpeed engine and let DeepSpeed manage it for update or save/restore. + * if the schedule is supposed to execute at any other interval (e.g., training epochs), then the user should NOT pass the scheduler to DeepSpeed during initialization and must manage it explicitly. + +LRRangeTest +--------------------------- +.. autoclass:: deepspeed.runtime.lr_schedules.LRRangeTest + + +OneCycle +--------------------------- +.. autoclass:: deepspeed.runtime.lr_schedules.OneCycle + + +WarmupLR +--------------------------- +.. autoclass:: deepspeed.runtime.lr_schedules.WarmupLR + + +WarmupDecayLR +--------------------------- +.. autoclass:: deepspeed.runtime.lr_schedules.WarmupDecayLR diff --git a/tests/perf/adam_test.py b/tests/perf/adam_test.py index 0f29cab4662e..1ddcd44bbdd4 100755 --- a/tests/perf/adam_test.py +++ b/tests/perf/adam_test.py @@ -1,24 +1,24 @@ -import torch -from deepspeed.ops.adam import DeepSpeedCPUAdam -import time - -device = 'cpu' -model_size = 1 * 1024**3 -group_size = [model_size, 274432] - -param = [torch.nn.Parameter(torch.ones(size, device=device)) for size in group_size] -optimizer = DeepSpeedCPUAdam(param) -#torch.set_num_threads(128) -for i, p in enumerate(param): - p.grad = torch.ones(group_size[i], device=device) -#param.grad = torch.ones(model_size, device=device) -avg = 0 -for i in range(100): - start = time.time() - optimizer.step() - stop = time.time() - avg += (stop - start) - for i, p in enumerate(param): - p.grad = torch.ones(group_size[i], device=device) * 2 - #param.grad = torch.ones(model_size, device=device) * 2 -print("Elapsed Time is ", avg / 100) +import torch +from deepspeed.ops.adam import DeepSpeedCPUAdam +import time + +device = 'cpu' +model_size = 1 * 1024**3 +group_size = [model_size, 274432] + +param = [torch.nn.Parameter(torch.ones(size, device=device)) for size in group_size] +optimizer = DeepSpeedCPUAdam(param) +#torch.set_num_threads(128) +for i, p in enumerate(param): + p.grad = torch.ones(group_size[i], device=device) +#param.grad = torch.ones(model_size, device=device) +avg = 0 +for i in range(100): + start = time.time() + optimizer.step() + stop = time.time() + avg += (stop - start) + for i, p in enumerate(param): + p.grad = torch.ones(group_size[i], device=device) * 2 + #param.grad = torch.ones(model_size, device=device) * 2 +print("Elapsed Time is ", avg / 100) diff --git a/tests/perf/adam_test1.py b/tests/perf/adam_test1.py index b0aba0fcd6b9..88f1a1c5961d 100755 --- a/tests/perf/adam_test1.py +++ b/tests/perf/adam_test1.py @@ -1,22 +1,22 @@ -import torch -from deepspeed.ops.adam import DeepSpeedCPUAdam -import time - -device = 'cpu' -model_size = 1 * 1024**3 -param = torch.nn.Parameter(torch.ones(model_size, device=device)) -param_fp16 = torch.nn.Parameter(torch.ones(model_size, - dtype=torch.half, - device='cuda:0')) - -optimizer = DeepSpeedCPUAdam([param]) -#torch.set_num_threads(128) -param.grad = torch.ones(model_size, device=device) -avg = 0 -for i in range(100): - start = time.time() - optimizer.step(fp16_param_groups=[param_fp16]) - stop = time.time() - avg += (stop - start) - param.grad = torch.ones(model_size, device=device) * 2 -print("Elapsed Time is ", avg / 100) +import torch +from deepspeed.ops.adam import DeepSpeedCPUAdam +import time + +device = 'cpu' +model_size = 1 * 1024**3 +param = torch.nn.Parameter(torch.ones(model_size, device=device)) +param_fp16 = torch.nn.Parameter(torch.ones(model_size, + dtype=torch.half, + device='cuda:0')) + +optimizer = DeepSpeedCPUAdam([param]) +#torch.set_num_threads(128) +param.grad = torch.ones(model_size, device=device) +avg = 0 +for i in range(100): + start = time.time() + optimizer.step(fp16_param_groups=[param_fp16]) + stop = time.time() + avg += (stop - start) + param.grad = torch.ones(model_size, device=device) * 2 +print("Elapsed Time is ", avg / 100) diff --git a/tests/unit/ds_batch_config.json b/tests/unit/ds_batch_config.json index 2558a5b9d31b..2e86c1929cae 100755 --- a/tests/unit/ds_batch_config.json +++ b/tests/unit/ds_batch_config.json @@ -1,15 +1,15 @@ -{ - "train_batch_size": 2, - "gradient_accumulation_steps": 1, - "optimizer": { - "type": "Adam", - "params": { - "lr": 0.00015 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": true, - "loss_scale": 0 - } - } +{ + "train_batch_size": 2, + "gradient_accumulation_steps": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00015 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": true, + "loss_scale": 0 + } + } diff --git a/tests/unit/modelingpreln.py b/tests/unit/modelingpreln.py index 43f210ec9944..7661303a4145 100755 --- a/tests/unit/modelingpreln.py +++ b/tests/unit/modelingpreln.py @@ -1,1692 +1,1692 @@ -# DeepSpeed note, code taken from commit 3d59216cec89a363649b4fe3d15295ba936ced0f -# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/modeling.py - -# coding=utf-8 -# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch BERT model.""" - -from __future__ import absolute_import, division, print_function, unicode_literals - -import copy -import json -import logging -import math -import os -import shutil -import tarfile -import tempfile -import sys -from io import open - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss -from torch.utils import checkpoint -import torch.distributed as dist - -from torch.nn import Module -from torch.nn.parameter import Parameter -import torch.nn.functional as F -import torch.nn.init as init -import time - -#from numba import cuda - -#from deepspeed_cuda import DeepSpeedSoftmaxConfig, DeepSpeedSoftmax - -logger = logging.getLogger(__name__) - -PRETRAINED_MODEL_ARCHIVE_MAP = { - 'bert-base-uncased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", - 'bert-large-uncased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", - 'bert-base-cased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", - 'bert-large-cased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", - 'bert-base-multilingual-uncased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", - 'bert-base-multilingual-cased': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", - 'bert-base-chinese': - "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", -} -CONFIG_NAME = 'bert_config.json' -WEIGHTS_NAME = 'pytorch_model.bin' -TF_WEIGHTS_NAME = 'model.ckpt' - - -def load_tf_weights_in_bert(model, tf_checkpoint_path): - """ Load tf checkpoints in a pytorch model - """ - try: - import re - import numpy as np - import tensorflow as tf - except ImportError: - print( - "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " - "https://www.tensorflow.org/install/ for installation instructions.") - raise - tf_path = os.path.abspath(tf_checkpoint_path) - print("Converting TensorFlow checkpoint from {}".format(tf_path)) - # Load weights from TF model - init_vars = tf.train.list_variables(tf_path) - names = [] - arrays = [] - for name, shape in init_vars: - print("Loading TF weight {} with shape {}".format(name, shape)) - array = tf.train.load_variable(tf_path, name) - names.append(name) - arrays.append(array) - - for name, array in zip(names, arrays): - name = name.split('/') - # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v - # which are not required for using pretrained model - if any(n in ["adam_v", "adam_m"] for n in name): - print("Skipping {}".format("/".join(name))) - continue - pointer = model - for m_name in name: - if re.fullmatch(r'[A-Za-z]+_\d+', m_name): - l = re.split(r'_(\d+)', m_name) - else: - l = [m_name] - if l[0] == 'kernel' or l[0] == 'gamma': - pointer = getattr(pointer, 'weight') - elif l[0] == 'output_bias' or l[0] == 'beta': - pointer = getattr(pointer, 'bias') - elif l[0] == 'output_weights': - pointer = getattr(pointer, 'weight') - else: - pointer = getattr(pointer, l[0]) - if len(l) >= 2: - num = int(l[1]) - pointer = pointer[num] - if m_name[-11:] == '_embeddings': - pointer = getattr(pointer, 'weight') - elif m_name == 'kernel': - array = np.transpose(array) - try: - assert pointer.shape == array.shape - except AssertionError as e: - e.args += (pointer.shape, array.shape) - raise - print("Initialize PyTorch weight {}".format(name)) - pointer.data = torch.from_numpy(array) - return model - - -""" -@torch.jit.script -def f_gelu(x): - return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) -@torch.jit.script -def bias_gelu(bias, y): - x = bias + y - return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) -@torch.jit.script -def bias_tanh(bias, y): - x = bias + y - return torch.tanh(x) - """ - - -def f_gelu(x): - x_type = x.dtype - x = x.float() - x = x * 0.5 * (1.0 + torch.erf(x / 1.41421)) - return x.to(x_type) - - -def bias_gelu(bias, y): - y_type = y.dtype - x = bias.float() + y.float() - x = x * 0.5 * (1.0 + torch.erf(x / 1.41421)) - return x.to(y_type) - - -def bias_tanh(bias, y): - y_type = y.dtype - x = bias.float() + y.float() - x = torch.tanh(x) - return x.to(y_type) - - -def gelu(x): - """Implementation of the gelu activation function. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - Also see https://arxiv.org/abs/1606.08415 - """ - return f_gelu(x) - - -def swish(x): - return x * torch.sigmoid(x) - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} - - -class GPUTimer: - def __init__(self): - super().__init__() - self.start = cuda.event() - self.stop = cuda.event() - - def record(self): - self.start.record() - - def elapsed(self): - self.stop.record() - self.stop.synchronize() - return self.start.elapsed_time(self.stop) / 1000.0 - - -class LinearActivation(Module): - r"""Fused Linear and activation Module. - """ - __constants__ = ['bias'] - - def __init__(self, - in_features, - out_features, - weights, - biases, - act='gelu', - bias=True): - super(LinearActivation, self).__init__() - self.in_features = in_features - self.out_features = out_features - self.fused_gelu = False - self.fused_tanh = False - if isinstance(act, - str) or (sys.version_info[0] == 2 and isinstance(act, - unicode)): - if bias and act == 'gelu': - self.fused_gelu = True - elif bias and act == 'tanh': - self.fused_tanh = True - else: - self.act_fn = ACT2FN[act] - else: - self.act_fn = act - #self.weight = Parameter(torch.Tensor(out_features, in_features)) - self.weight = weights[5] - self.bias = biases[5] - #if bias: - # self.bias = Parameter(torch.Tensor(out_features)) - #else: - # self.register_parameter('bias', None) - #self.reset_parameters() - - def reset_parameters(self): - init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - if self.bias is not None: - fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) - bound = 1 / math.sqrt(fan_in) - init.uniform_(self.bias, -bound, bound) - - def forward(self, input): - if self.fused_gelu: - #timing = [] - #t1 = GPUTimer() - #t1.record() - y = F.linear(input, self.weight, None) - #timing.append(t1.elapsed()) - #t1.record() - bg = bias_gelu(self.bias, y) - #timing.append(t1.elapsed()) - return bg - elif self.fused_tanh: - return bias_tanh(self.bias, F.linear(input, self.weight, None)) - else: - return self.act_fn(F.linear(input, self.weight, self.bias)) - - def extra_repr(self): - return 'in_features={}, out_features={}, bias={}'.format( - self.in_features, - self.out_features, - self.bias is not None) - - -class BertConfig(object): - """Configuration class to store the configuration of a `BertModel`. - """ - def __init__(self, - vocab_size_or_config_json_file, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - batch_size=8, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - fp16=False): - """Constructs BertConfig. - - Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. - hidden_size: Size of the encoder layers and the pooler layer. - num_hidden_layers: Number of hidden layers in the Transformer encoder. - num_attention_heads: Number of attention heads for each attention layer in - the Transformer encoder. - intermediate_size: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. - hidden_act: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported. - hidden_dropout_prob: The dropout probability for all fully connected - layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob: The dropout ratio for the attention - probabilities. - max_position_embeddings: The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048). - type_vocab_size: The vocabulary size of the `token_type_ids` passed into - `BertModel`. - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. - """ - if isinstance(vocab_size_or_config_json_file, - str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, - unicode)): - with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.vocab_size = vocab_size_or_config_json_file - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.batch_size = batch_size - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - self.fp16 = fp16 - else: - raise ValueError("First argument must be either a vocabulary size (int)" - "or the path to a pretrained model config file (str)") - - @classmethod - def from_dict(cls, json_object): - """Constructs a `BertConfig` from a Python dictionary of parameters.""" - config = BertConfig(vocab_size_or_config_json_file=-1) - for key, value in json_object.items(): - config.__dict__[key] = value - return config - - @classmethod - def from_json_file(cls, json_file): - """Constructs a `BertConfig` from a json file of parameters.""" - with open(json_file, "r", encoding='utf-8') as reader: - text = reader.read() - return cls.from_dict(json.loads(text)) - - def __repr__(self): - return str(self.to_json_string()) - - def to_dict(self): - """Serializes this instance to a Python dictionary.""" - output = copy.deepcopy(self.__dict__) - return output - - def to_json_string(self): - """Serializes this instance to a JSON string.""" - return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" - - -try: - import apex - #apex.amp.register_half_function(apex.normalization.fused_layer_norm, 'FusedLayerNorm') - import apex.normalization - #apex.amp.register_float_function(apex.normalization.FusedLayerNorm, 'forward') - BertLayerNorm = apex.normalization.FusedLayerNorm -except ImportError: - print( - "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex." - ) - - class BertLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-12): - """Construct a layernorm module in the TF style (epsilon inside the square root). - """ - super(BertLayerNorm, self).__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) - self.variance_epsilon = eps - - def forward(self, x): - pdtype = x.dtype - x = x.float() - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.weight * x.to(pdtype) + self.bias - - #def forward(self, x): - # u = x.mean(-1, keepdim=True) - # s = (x - u).pow(2).mean(-1, keepdim=True) - # x = (x - u) / torch.sqrt(s + self.variance_epsilon) - # return self.weight * x + self.bias - - -class BertEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings. - """ - def __init__(self, config): - super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.position_embeddings = nn.Embedding(config.max_position_embeddings, - config.hidden_size) - self.token_type_embeddings = nn.Embedding(config.type_vocab_size, - config.hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, input_ids, token_type_ids=None): - seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, - dtype=torch.long, - device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = words_embeddings + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertSelfAttention(nn.Module): - def __init__(self, i, config, weights, biases): - super(BertSelfAttention, self).__init__() - if config.hidden_size % config.num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, - config.num_attention_heads)) - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.query.weight = weights[0] - self.query.bias = biases[0] - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.key.weight = weights[1] - self.key.bias = biases[1] - self.value = nn.Linear(config.hidden_size, self.all_head_size) - self.value.weight = weights[2] - self.value.bias = biases[2] - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.softmax = nn.Softmax(dim=-1) - #self.softmax_config = DeepSpeedSoftmaxConfig() - #self.softmax_config.batch_size = config.batch_size - #self.softmax_config.max_seq_length = config.max_position_embeddings - #self.softmax_config.hidden_size = config.hidden_size - #self.softmax_config.heads = config.num_attention_heads - #self.softmax_config.softmax_id = i - #self.softmax_config.fp16 = config.fp16 - #self.softmax_config.prob_drop_out = 0.0 - #self.softmax = DeepSpeedSoftmax(i, self.softmax_config) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def transpose_key_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, - self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 3, 1) - - def forward(self, hidden_states, attention_mask, grads=None): - #timing = [] - #t1 = GPUTimer() - #t1.record() - mixed_query_layer = self.query(hidden_states) - - #timing.append(t1.elapsed()) - #print("Query elapsed: %s" % (time.clock() - start)) - #t1.record() - mixed_key_layer = self.key(hidden_states) - - #timing.append(t1.elapsed()) - #print("Key elapsed: %s" % (time.clock() - start)) - #t1.record() - mixed_value_layer = self.value(hidden_states) - #timing.append(t1.elapsed()) - #print("Value elapsed: %s" % (time.clock() - start)) - - #t1.record() - query_layer = self.transpose_for_scores(mixed_query_layer) - # print(query_layer) - #timing.append(t1.elapsed()) - #print("Query-Transform elapsed: %s" % (time.clock() - start)) - #t1.record() - key_layer = self.transpose_key_for_scores(mixed_key_layer) - # print(key_layer) - #timing.append(t1.elapsed()) - #print("Key-Transform elapsed: %s" % (time.clock() - start)) - #t1.record() - value_layer = self.transpose_for_scores(mixed_value_layer) - #print(value_layer) - #timing.append(t1.elapsed()) - #print("Value-Transform elapsed: %s" % (time.clock() - start)) - - # Take the dot product between "query" and "key" to get the raw attention scores. - #t1.record() - #print(query_layer.shape) - #print(key_layer.shape) - attention_scores = torch.matmul(query_layer, key_layer) - #print(attention_scores.shape) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - #print("Pytorch: ", attention_scores) - #timing.append(t1.elapsed()) - #print("Attention-Score elapsed: %s" % (time.clock() - start)) - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - #t1.record() - - # context_layer = self.softmax(query_layer, key_layer, value_layer, attention_mask) - #print("context shape is :", context_layer.shape) - #print("Cuda-ext:, ", attention_scores1) - # Normalize the attention scores to probabilities. - ####attention_probs = self.softmax(attention_scores) - #timing.append(t1.elapsed()) - #print("Softmax elapsed: %s" % (time.clock() - start)) - #t1 = GPUTimer() - #t1.record() - attention_scores = attention_scores + attention_mask - attention_probs = self.softmax(attention_scores) - #attention_scores = self.softmax(attention_scores, attention_mask) - #print("Softmax elapse {0:8.2f} ms", t1.elapsed() * 1000) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - #t1.record() - context_layer = torch.matmul(attention_probs, value_layer) - #timing.append(t1.elapsed()) - #print("Context elapsed: %s" % (time.clock() - start)) - #t1.record() - #context_layer1 = context_layer.permute( - # 0, 1, 3, 2, 4).contiguous() - #if grads is not None: - # context_layer.register_hook(lambda x, self = self : grads.append([x, "Context"])) - context_layer1 = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer1.size()[:-2] + (self.all_head_size, ) - context_layer1 = context_layer1.view(*new_context_layer_shape) - #timing.append(t1.elapsed()) - #print("Context-Transform elapsed: %s" % (time.clock() - start)) - - if grads is not None: - query_layer.register_hook(lambda x, self=self: grads.append([x, "Query"])) - key_layer.register_hook(lambda x, self=self: grads.append([x, "Key"])) - value_layer.register_hook(lambda x, self=self: grads.append([x, "Value"])) - - return context_layer1 - - -class BertSelfOutput(nn.Module): - def __init__(self, config, weights, biases): - super(BertSelfOutput, self).__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dense.weight = weights[3] - self.dense.bias = biases[3] - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - #timing = [] - #t1 = GPUTimer() - #t1.record() - hidden_states = self.dense(hidden_states) - #timing.append(t1.elapsed()) - #print("Attention Output elapsed: %s" % (time.clock() - start)) - hidden_states = self.dropout(hidden_states) - #t1.record() - #hidden_states = self.LayerNorm(hidden_states + input_tensor) - #timing.append(t1.elapsed()) - #print("LayerNorm elapsed: %s" % (time.clock() - start)) - return hidden_states - - def get_w(self): - return self.dense.weight - - -class BertAttention(nn.Module): - def __init__(self, i, config, weights, biases): - super(BertAttention, self).__init__() - self.self = BertSelfAttention(i, config, weights, biases) - self.output = BertSelfOutput(config, weights, biases) - - def forward(self, input_tensor, attention_mask): - self_output = self.self(input_tensor, attention_mask) - attention_output = self.output(self_output, input_tensor) - return attention_output - - def get_w(self): - return self.output.get_w() - - -class BertIntermediate(nn.Module): - def __init__(self, config, weights, biases): - super(BertIntermediate, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, - config.intermediate_size, - weights, - biases, - act=config.hidden_act) - - def forward(self, hidden_states): - hidden_states = self.dense_act(hidden_states) - return hidden_states - - -class BertOutput(nn.Module): - def __init__(self, config, weights, biases): - super(BertOutput, self).__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.dense.weight = weights[6] - self.dense.bias = biases[6] - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - #timing = [] - #t1 = GPUTimer() - #t1.record() - #print (hidden_states) - #print (self.dense.weight) - hidden_states = self.dense(hidden_states) - #timing.append(t1.elapsed()) - #print("FF2 elapsed: %s" % (time.clock() - start)) - hidden_states = self.dropout(hidden_states) - #t1.record() - #hidden_states = self.LayerNorm(hidden_states + input_tensor) - #timing.append(t1.elapsed()) - #print("LayerNorm elapsed: %s" % (time.clock() - start)) - return hidden_states - - -class BertLayer(nn.Module): - def __init__(self, i, config, weights, biases): - super(BertLayer, self).__init__() - self.attention = BertAttention(i, config, weights, biases) - self.PreAttentionLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.PostAttentionLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - self.intermediate = BertIntermediate(config, weights, biases) - self.output = BertOutput(config, weights, biases) - self.weight = weights - self.biases = biases - - def forward(self, hidden_states, attention_mask, grads, collect_all_grads=False): - input_layer_norm = self.PreAttentionLayerNorm(hidden_states) - attention_output = self.attention(input_layer_norm, attention_mask) - #print ("hidden shape is :", hidden_states.shape) - intermediate_input = hidden_states + attention_output - - intermediate_layer_norm = self.PostAttentionLayerNorm(intermediate_input) - intermediate_output = self.intermediate(intermediate_layer_norm) - layer_output = self.output(intermediate_output, attention_output) - - #attention_output = self.attention(hidden_states, attention_mask) - #intermediate_output = self.intermediate(attention_output) - #layer_output = self.output(intermediate_output, attention_output) - - if collect_all_grads: - # self.weight[0].register_hook(lambda x, self=self: grads.append([x,"Q_W"])) - # self.biases[0].register_hook(lambda x, self=self: grads.append([x,"Q_B"])) - # self.weight[1].register_hook(lambda x, self=self: grads.append([x,"K_W"])) - # self.biases[1].register_hook(lambda x, self=self: grads.append([x,"K_B"])) - self.weight[2].register_hook(lambda x, self=self: grads.append([x, "V_W"])) - self.biases[2].register_hook(lambda x, self=self: grads.append([x, "V_B"])) - self.weight[3].register_hook(lambda x, self=self: grads.append([x, "O_W"])) - self.biases[3].register_hook(lambda x, self=self: grads.append([x, "O_B"])) - self.PostAttentionLayerNorm.weight.register_hook( - lambda x, - self=self: grads.append([x, - "N2_W"])) - self.PostAttentionLayerNorm.bias.register_hook( - lambda x, - self=self: grads.append([x, - "N2_B"])) - self.weight[5].register_hook(lambda x, self=self: grads.append([x, "int_W"])) - self.biases[5].register_hook(lambda x, self=self: grads.append([x, "int_B"])) - self.weight[6].register_hook(lambda x, self=self: grads.append([x, "out_W"])) - self.biases[6].register_hook(lambda x, self=self: grads.append([x, "out_B"])) - self.PreAttentionLayerNorm.weight.register_hook( - lambda x, - self=self: grads.append([x, - "norm_W"])) - self.PreAttentionLayerNorm.bias.register_hook( - lambda x, - self=self: grads.append([x, - "norm_B"])) - - return layer_output + intermediate_input - - def get_w(self): - return self.attention.get_w() - - -class BertEncoder(nn.Module): - def __init__(self, config, weights, biases): - super(BertEncoder, self).__init__() - #layer = BertLayer(config, weights, biases) - self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - - self.layer = nn.ModuleList([ - copy.deepcopy(BertLayer(i, - config, - weights, - biases)) for i in range(config.num_hidden_layers) - ]) - self.grads = [] - self.graph = [] - - def get_grads(self): - return self.grads - - # def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): - # all_encoder_layers = [] - # for layer_module in self.layer: - # hidden_states = layer_module(hidden_states, attention_mask) - # if output_all_encoded_layers: - # all_encoder_layers.append(hidden_states) - # if not output_all_encoded_layers: - # all_encoder_layers.append(hidden_states) - # return all_encoder_layers - - def get_modules(self, big_node, input): - for mdl in big_node.named_children(): - graph.append(mdl) - get_modules(self, mdl, input) - - def forward(self, - hidden_states, - attention_mask, - output_all_encoded_layers=True, - checkpoint_activations=False): - all_encoder_layers = [] - - def custom(start, end): - def custom_forward(*inputs): - layers = self.layer[start:end] - x_ = inputs[0] - for layer in layers: - x_ = layer(x_, inputs[1]) - return x_ - - return custom_forward - - if checkpoint_activations: - l = 0 - num_layers = len(self.layer) - chunk_length = math.ceil(math.sqrt(num_layers)) - while l < num_layers: - hidden_states = checkpoint.checkpoint(custom(l, - l + chunk_length), - hidden_states, - attention_mask * 1) - l += chunk_length - # decoder layers - else: - for i, layer_module in enumerate(self.layer): - hidden_states = layer_module(hidden_states, - attention_mask, - self.grads, - collect_all_grads=True) - hidden_states.register_hook( - lambda x, - i=i, - self=self: self.grads.append([x, - "hidden_state"])) - #print("pytorch weight is: ", layer_module.get_w()) - - if output_all_encoded_layers: - all_encoder_layers.append((hidden_states)) - - if not output_all_encoded_layers or checkpoint_activations: - hidden_states = self.FinalLayerNorm(hidden_states) - all_encoder_layers.append((hidden_states)) - return all_encoder_layers - - -#class BertEncoder(nn.Module): -# def __init__(self, config): -# super(BertEncoder, self).__init__() -# layer = BertLayer(config) -# self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) -# -# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): -# all_encoder_layers = [] -# for layer_module in self.layer: -# hidden_states = layer_module(hidden_states, attention_mask) -# if output_all_encoded_layers: -# all_encoder_layers.append(hidden_states) -# if not output_all_encoded_layers: -# all_encoder_layers.append(hidden_states) -# return all_encoder_layers - - -class BertPooler(nn.Module): - def __init__(self, config): - super(BertPooler, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, - config.hidden_size, - act="tanh") - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense_act(first_token_tensor) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): - super(BertPredictionHeadTransform, self).__init__() - self.dense_act = LinearActivation(config.hidden_size, - config.hidden_size, - act=config.hidden_act) - self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) - - def forward(self, hidden_states): - hidden_states = self.dense_act(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super(BertLMPredictionHead, self).__init__() - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear(bert_model_embedding_weights.size(1), - bert_model_embedding_weights.size(0), - bias=False) - self.decoder.weight = bert_model_embedding_weights - self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - torch.cuda.nvtx.range_push( - "decoder input.size() = {}, weight.size() = {}".format( - hidden_states.size(), - self.decoder.weight.size())) - hidden_states = self.decoder(hidden_states) + self.bias - torch.cuda.nvtx.range_pop() - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super(BertOnlyMLMHead, self).__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertOnlyNSPHead(nn.Module): - def __init__(self, config): - super(BertOnlyNSPHead, self).__init__() - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, pooled_output): - seq_relationship_score = self.seq_relationship(pooled_output) - return seq_relationship_score - - -class BertPreTrainingHeads(nn.Module): - def __init__(self, config, bert_model_embedding_weights): - super(BertPreTrainingHeads, self).__init__() - self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) - self.seq_relationship = nn.Linear(config.hidden_size, 2) - - def forward(self, sequence_output, pooled_output): - prediction_scores = self.predictions(sequence_output) - seq_relationship_score = self.seq_relationship(pooled_output) - return prediction_scores, seq_relationship_score - - -class BertPreTrainedModel(nn.Module): - """ An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. - """ - def __init__(self, config, *inputs, **kwargs): - super(BertPreTrainedModel, self).__init__() - if not isinstance(config, BertConfig): - raise ValueError( - "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " - "To create a model from a Google pretrained model use " - "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( - self.__class__.__name__, - self.__class__.__name__)) - self.config = config - - def init_bert_weights(self, module): - """ Initialize the weights. - """ - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - @classmethod - def from_pretrained(cls, - pretrained_model_name_or_path, - state_dict=None, - cache_dir=None, - from_tf=False, - *inputs, - **kwargs): - """ - Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. - Download and cache the pre-trained model file if needed. - - Params: - pretrained_model_name_or_path: either: - - a str with the name of a pre-trained model to load selected in the list of: - . `bert-base-uncased` - . `bert-large-uncased` - . `bert-base-cased` - . `bert-large-cased` - . `bert-base-multilingual-uncased` - . `bert-base-multilingual-cased` - . `bert-base-chinese` - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance - - a path or url to a pretrained model archive containing: - . `bert_config.json` a configuration file for the model - . `model.chkpt` a TensorFlow checkpoint - from_tf: should we load the weights from a locally saved TensorFlow checkpoint - cache_dir: an optional path to a folder in which the pre-trained models will be cached. - state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models - *inputs, **kwargs: additional input for the specific Bert class - (ex: num_labels for BertForSequenceClassification) - """ - if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: - archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] - else: - archive_file = pretrained_model_name_or_path - if resolved_archive_file == archive_file: - logger.info("loading archive file {}".format(archive_file)) - else: - logger.info("loading archive file {} from cache at {}".format( - archive_file, - resolved_archive_file)) - tempdir = None - if os.path.isdir(resolved_archive_file) or from_tf: - serialization_dir = resolved_archive_file - else: - # Extract archive to temp dir - tempdir = tempfile.mkdtemp() - logger.info("extracting archive file {} to temp dir {}".format( - resolved_archive_file, - tempdir)) - with tarfile.open(resolved_archive_file, 'r:gz') as archive: - archive.extractall(tempdir) - serialization_dir = tempdir - # Load config - config_file = os.path.join(serialization_dir, CONFIG_NAME) - config = BertConfig.from_json_file(config_file) - logger.info("Model config {}".format(config)) - # Instantiate model. - model = cls(config, *inputs, **kwargs) - if state_dict is None and not from_tf: - weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) - state_dict = torch.load( - weights_path, - map_location='cpu' if not torch.cuda.is_available() else None) - if tempdir: - # Clean up temp dir - shutil.rmtree(tempdir) - if from_tf: - # Directly load from a TensorFlow checkpoint - weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) - return load_tf_weights_in_bert(model, weights_path) - # Load from a PyTorch state_dict - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - def load(module, prefix=''): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict(state_dict, - prefix, - local_metadata, - True, - missing_keys, - unexpected_keys, - error_msgs) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + '.') - - start_prefix = '' - if not hasattr(model, - 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): - start_prefix = 'bert.' - load(model, prefix=start_prefix) - if len(missing_keys) > 0: - logger.info("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, - missing_keys)) - if len(unexpected_keys) > 0: - logger.info("Weights from pretrained model not used in {}: {}".format( - model.__class__.__name__, - unexpected_keys)) - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - model.__class__.__name__, - "\n\t".join(error_msgs))) - return model - - -class BertModel(BertPreTrainedModel): - """BERT model ("Bidirectional Embedding Representations from a Transformer"). - - Params: - config: a BertConfig class instance with the configuration to build a new model - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. - - Outputs: Tuple of (encoded_layers, pooled_output) - `encoded_layers`: controlled by `output_all_encoded_layers` argument: - - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end - of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each - encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], - - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding - to the last attention block of shape [batch_size, sequence_length, hidden_size], - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a - classifier pretrained on top of the hidden state associated to the first character of the - input (`CLS`) to train on the Next-Sentence task (see BERT's paper). - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = modeling.BertModel(config=config) - all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config): - super(BertModel, self).__init__(config) - self.embeddings = BertEmbeddings(config) - self.encoder = BertEncoder(config) - self.pooler = BertPooler(config) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - output_all_encoded_layers=True, - checkpoint_activations=False): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next( - self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - embedding_output = self.embeddings(input_ids, token_type_ids) - encoded_layers = self.encoder( - embedding_output, - extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers, - checkpoint_activations=checkpoint_activations) - sequence_output = encoded_layers[-1] - pooled_output = self.pooler(sequence_output) - if not output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - return encoded_layers, pooled_output - - -class BertForPreTraining(BertPreTrainedModel): - """BERT model with pre-training heads. - This module comprises the BERT model followed by the two pre-training heads: - - the masked language modeling head, and - - the next sentence classification head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - - Outputs: - if `masked_lm_labels` and `next_sentence_label` are not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `masked_lm_labels` or `next_sentence_label` is `None`: - Outputs a tuple comprising - - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and - - the next sentence classification logits of shape [batch_size, 2]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForPreTraining(config) - masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config, args): - super(BertForPreTraining, self).__init__(config) - self.summary_writer = None - if dist.get_rank() == 0: - self.summary_writer = args.summary_writer - self.samples_per_step = dist.get_world_size() * args.train_batch_size - self.sample_count = self.samples_per_step - self.bert = BertModel(config) - self.cls = BertPreTrainingHeads(config, - self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def log_summary_writer(self, logs: dict, base='Train'): - if dist.get_rank() == 0: - module_name = "Samples" #self._batch_module_name.get(batch_type, self._get_batch_type_error(batch_type)) - for key, log in logs.items(): - self.summary_writer.add_scalar(f'{base}/{module_name}/{key}', - log, - self.sample_count) - self.sample_count += self.samples_per_step - - def forward(self, batch, log=True): - #input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False): - input_ids = batch[1] - token_type_ids = batch[3] - attention_mask = batch[2] - masked_lm_labels = batch[5] - next_sentence_label = batch[4] - checkpoint_activations = False - - sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) - prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) - - if masked_lm_labels is not None and next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, - self.config.vocab_size), - masked_lm_labels.view(-1)) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, - 2), - next_sentence_label.view(-1)) - #print("loss is {} {}".format(masked_lm_loss, next_sentence_loss)) - total_loss = masked_lm_loss + next_sentence_loss - # if log: - # self.log_summary_writer(logs={'train_loss': total_loss.item()}) - return total_loss - else: - return prediction_scores, seq_relationship_score - - -class BertForMaskedLM(BertPreTrainedModel): - """BERT model with the masked language modeling head. - This module comprises the BERT model followed by the masked language modeling head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss - is only computed for the labels set in [0, ..., vocab_size] - - Outputs: - if `masked_lm_labels` is not `None`: - Outputs the masked language modeling loss. - if `masked_lm_labels` is `None`: - Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForMaskedLM(config) - masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config): - super(BertForMaskedLM, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - masked_lm_labels=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False) - prediction_scores = self.cls(sequence_output) - - if masked_lm_labels is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - masked_lm_loss = loss_fct(prediction_scores.view(-1, - self.config.vocab_size), - masked_lm_labels.view(-1)) - return masked_lm_loss - else: - return prediction_scores - - -class BertForNextSentencePrediction(BertPreTrainedModel): - """BERT model with next sentence prediction head. - This module comprises the BERT model followed by the next sentence classification head. - - Params: - config: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] - with indices selected in [0, 1]. - 0 => next sentence is the continuation, 1 => next sentence is a random sentence. - - Outputs: - if `next_sentence_label` is not `None`: - Outputs the total_loss which is the sum of the masked language modeling loss and the next - sentence classification loss. - if `next_sentence_label` is `None`: - Outputs the next sentence classification logits of shape [batch_size, 2]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForNextSentencePrediction(config) - seq_relationship_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config): - super(BertForNextSentencePrediction, self).__init__(config) - self.bert = BertModel(config) - self.cls = BertOnlyNSPHead(config) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - next_sentence_label=None, - checkpoint_activations=False): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, - output_all_encoded_layers=False) - seq_relationship_score = self.cls(pooled_output) - - if next_sentence_label is not None: - loss_fct = CrossEntropyLoss(ignore_index=-1) - next_sentence_loss = loss_fct(seq_relationship_score.view(-1, - 2), - next_sentence_label.view(-1)) - return next_sentence_loss - else: - return seq_relationship_score - - -class BertForSequenceClassification(BertPreTrainedModel): - """BERT model for classification. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForSequenceClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config, num_labels): - super(BertForSequenceClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - labels=None, - checkpoint_activations=False): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss - else: - return logits - - -class BertForMultipleChoice(BertPreTrainedModel): - """BERT model for multiple choice tasks. - This module is composed of the BERT model with a linear layer on top of - the pooled output. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_choices`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] - with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` - and type 1 corresponds to a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] - with indices selected in [0, ..., num_choices]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) - input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) - token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_choices = 2 - - model = BertForMultipleChoice(config, num_choices) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config, num_choices): - super(BertForMultipleChoice, self).__init__(config) - self.num_choices = num_choices - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, 1) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - labels=None, - checkpoint_activations=False): - flat_input_ids = input_ids.view(-1, input_ids.size(-1)) - flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) - flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) - _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) - pooled_output = self.dropout(pooled_output) - logits = self.classifier(pooled_output) - reshaped_logits = logits.view(-1, self.num_choices) - - if labels is not None: - loss_fct = CrossEntropyLoss() - loss = loss_fct(reshaped_logits, labels) - return loss - else: - return reshaped_logits - - -class BertForTokenClassification(BertPreTrainedModel): - """BERT model for token-level classification. - This module is composed of the BERT model with a linear layer on top of - the full hidden state of the last layer. - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - `num_labels`: the number of classes for the classifier. Default = 2. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] - with indices selected in [0, ..., num_labels]. - - Outputs: - if `labels` is not `None`: - Outputs the CrossEntropy classification loss of the output with the labels. - if `labels` is `None`: - Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - num_labels = 2 - - model = BertForTokenClassification(config, num_labels) - logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config, num_labels): - super(BertForTokenClassification, self).__init__(config) - self.num_labels = num_labels - self.bert = BertModel(config) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.classifier = nn.Linear(config.hidden_size, num_labels) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - labels=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - sequence_output = self.dropout(sequence_output) - logits = self.classifier(sequence_output) - - if labels is not None: - loss_fct = CrossEntropyLoss() - # Only keep active parts of the loss - if attention_mask is not None: - active_loss = attention_mask.view(-1) == 1 - active_logits = logits.view(-1, self.num_labels)[active_loss] - active_labels = labels.view(-1)[active_loss] - loss = loss_fct(active_logits, active_labels) - else: - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - return loss - else: - return logits - - -class BertForQuestionAnswering(BertPreTrainedModel): - """BERT model for Question Answering (span extraction). - This module is composed of the BERT model with a linear layer on top of - the sequence output that computes start_logits and end_logits - - Params: - `config`: a BertConfig class instance with the configuration to build a new model. - - Inputs: - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] - with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts - `extract_features.py`, `run_classifier.py` and `run_squad.py`) - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token - types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to - a `sentence B` token (see BERT paper for more details). - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices - selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max - input sequence length in the current batch. It's the mask that we typically use for attention when - a batch has varying length sentences. - `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. - Positions are clamped to the length of the sequence and position outside of the sequence are not taken - into account for computing the loss. - - Outputs: - if `start_positions` and `end_positions` are not `None`: - Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. - if `start_positions` or `end_positions` is `None`: - Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end - position tokens of shape [batch_size, sequence_length]. - - Example usage: - ```python - # Already been converted into WordPiece token ids - input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) - input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) - token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) - - config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, - num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) - - model = BertForQuestionAnswering(config) - start_logits, end_logits = model(input_ids, token_type_ids, input_mask) - ``` - """ - def __init__(self, config): - super(BertForQuestionAnswering, self).__init__(config) - self.bert = BertModel(config) - # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version - # self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.qa_outputs = nn.Linear(config.hidden_size, 2) - self.apply(self.init_bert_weights) - - def forward(self, - input_ids, - token_type_ids=None, - attention_mask=None, - start_positions=None, - end_positions=None, - checkpoint_activations=False): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - logits = self.qa_outputs(sequence_output) - start_logits, end_logits = logits.split(1, dim=-1) - start_logits = start_logits.squeeze(-1) - end_logits = end_logits.squeeze(-1) - - if start_positions is not None and end_positions is not None: - # If we are on multi-GPU, split add a dimension - if len(start_positions.size()) > 1: - start_positions = start_positions.squeeze(-1) - if len(end_positions.size()) > 1: - end_positions = end_positions.squeeze(-1) - # sometimes the start/end positions are outside our model inputs, we ignore these terms - ignored_index = start_logits.size(1) - start_positions.clamp_(0, ignored_index) - end_positions.clamp_(0, ignored_index) - - loss_fct = CrossEntropyLoss(ignore_index=ignored_index) - start_loss = loss_fct(start_logits, start_positions) - end_loss = loss_fct(end_logits, end_positions) - total_loss = (start_loss + end_loss) / 2 - return total_loss - else: - return start_logits, end_logits +# DeepSpeed note, code taken from commit 3d59216cec89a363649b4fe3d15295ba936ced0f +# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/BERT/modeling.py + +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import copy +import json +import logging +import math +import os +import shutil +import tarfile +import tempfile +import sys +from io import open + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils import checkpoint +import torch.distributed as dist + +from torch.nn import Module +from torch.nn.parameter import Parameter +import torch.nn.functional as F +import torch.nn.init as init +import time + +#from numba import cuda + +#from deepspeed_cuda import DeepSpeedSoftmaxConfig, DeepSpeedSoftmax + +logger = logging.getLogger(__name__) + +PRETRAINED_MODEL_ARCHIVE_MAP = { + 'bert-base-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + 'bert-large-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + 'bert-base-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + 'bert-large-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", + 'bert-base-multilingual-uncased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", + 'bert-base-multilingual-cased': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", + 'bert-base-chinese': + "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", +} +CONFIG_NAME = 'bert_config.json' +WEIGHTS_NAME = 'pytorch_model.bin' +TF_WEIGHTS_NAME = 'model.ckpt' + + +def load_tf_weights_in_bert(model, tf_checkpoint_path): + """ Load tf checkpoints in a pytorch model + """ + try: + import re + import numpy as np + import tensorflow as tf + except ImportError: + print( + "Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions.") + raise + tf_path = os.path.abspath(tf_checkpoint_path) + print("Converting TensorFlow checkpoint from {}".format(tf_path)) + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + print("Loading TF weight {} with shape {}".format(name, shape)) + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array) + + for name, array in zip(names, arrays): + name = name.split('/') + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any(n in ["adam_v", "adam_m"] for n in name): + print("Skipping {}".format("/".join(name))) + continue + pointer = model + for m_name in name: + if re.fullmatch(r'[A-Za-z]+_\d+', m_name): + l = re.split(r'_(\d+)', m_name) + else: + l = [m_name] + if l[0] == 'kernel' or l[0] == 'gamma': + pointer = getattr(pointer, 'weight') + elif l[0] == 'output_bias' or l[0] == 'beta': + pointer = getattr(pointer, 'bias') + elif l[0] == 'output_weights': + pointer = getattr(pointer, 'weight') + else: + pointer = getattr(pointer, l[0]) + if len(l) >= 2: + num = int(l[1]) + pointer = pointer[num] + if m_name[-11:] == '_embeddings': + pointer = getattr(pointer, 'weight') + elif m_name == 'kernel': + array = np.transpose(array) + try: + assert pointer.shape == array.shape + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + print("Initialize PyTorch weight {}".format(name)) + pointer.data = torch.from_numpy(array) + return model + + +""" +@torch.jit.script +def f_gelu(x): + return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) +@torch.jit.script +def bias_gelu(bias, y): + x = bias + y + return x * 0.5 * (1.0 + torch.erf(x / 1.41421)) +@torch.jit.script +def bias_tanh(bias, y): + x = bias + y + return torch.tanh(x) + """ + + +def f_gelu(x): + x_type = x.dtype + x = x.float() + x = x * 0.5 * (1.0 + torch.erf(x / 1.41421)) + return x.to(x_type) + + +def bias_gelu(bias, y): + y_type = y.dtype + x = bias.float() + y.float() + x = x * 0.5 * (1.0 + torch.erf(x / 1.41421)) + return x.to(y_type) + + +def bias_tanh(bias, y): + y_type = y.dtype + x = bias.float() + y.float() + x = torch.tanh(x) + return x.to(y_type) + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return f_gelu(x) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class GPUTimer: + def __init__(self): + super().__init__() + self.start = cuda.event() + self.stop = cuda.event() + + def record(self): + self.start.record() + + def elapsed(self): + self.stop.record() + self.stop.synchronize() + return self.start.elapsed_time(self.stop) / 1000.0 + + +class LinearActivation(Module): + r"""Fused Linear and activation Module. + """ + __constants__ = ['bias'] + + def __init__(self, + in_features, + out_features, + weights, + biases, + act='gelu', + bias=True): + super(LinearActivation, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.fused_gelu = False + self.fused_tanh = False + if isinstance(act, + str) or (sys.version_info[0] == 2 and isinstance(act, + unicode)): + if bias and act == 'gelu': + self.fused_gelu = True + elif bias and act == 'tanh': + self.fused_tanh = True + else: + self.act_fn = ACT2FN[act] + else: + self.act_fn = act + #self.weight = Parameter(torch.Tensor(out_features, in_features)) + self.weight = weights[5] + self.bias = biases[5] + #if bias: + # self.bias = Parameter(torch.Tensor(out_features)) + #else: + # self.register_parameter('bias', None) + #self.reset_parameters() + + def reset_parameters(self): + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + if self.fused_gelu: + #timing = [] + #t1 = GPUTimer() + #t1.record() + y = F.linear(input, self.weight, None) + #timing.append(t1.elapsed()) + #t1.record() + bg = bias_gelu(self.bias, y) + #timing.append(t1.elapsed()) + return bg + elif self.fused_tanh: + return bias_tanh(self.bias, F.linear(input, self.weight, None)) + else: + return self.act_fn(F.linear(input, self.weight, self.bias)) + + def extra_repr(self): + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, + self.out_features, + self.bias is not None) + + +class BertConfig(object): + """Configuration class to store the configuration of a `BertModel`. + """ + def __init__(self, + vocab_size_or_config_json_file, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + batch_size=8, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + fp16=False): + """Constructs BertConfig. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probability for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + if isinstance(vocab_size_or_config_json_file, + str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, + unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.batch_size = batch_size + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.fp16 = fp16 + else: + raise ValueError("First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)") + + @classmethod + def from_dict(cls, json_object): + """Constructs a `BertConfig` from a Python dictionary of parameters.""" + config = BertConfig(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + +try: + import apex + #apex.amp.register_half_function(apex.normalization.fused_layer_norm, 'FusedLayerNorm') + import apex.normalization + #apex.amp.register_float_function(apex.normalization.FusedLayerNorm, 'forward') + BertLayerNorm = apex.normalization.FusedLayerNorm +except ImportError: + print( + "Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex." + ) + + class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + pdtype = x.dtype + x = x.float() + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x.to(pdtype) + self.bias + + #def forward(self, x): + # u = x.mean(-1, keepdim=True) + # s = (x - u).pow(2).mean(-1, keepdim=True) + # x = (x - u) / torch.sqrt(s + self.variance_epsilon) + # return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + def __init__(self, config): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, + config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, + config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, + dtype=torch.long, + device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, i, config, weights, biases): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, + config.num_attention_heads)) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.query.weight = weights[0] + self.query.bias = biases[0] + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.key.weight = weights[1] + self.key.bias = biases[1] + self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.value.weight = weights[2] + self.value.bias = biases[2] + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.softmax = nn.Softmax(dim=-1) + #self.softmax_config = DeepSpeedSoftmaxConfig() + #self.softmax_config.batch_size = config.batch_size + #self.softmax_config.max_seq_length = config.max_position_embeddings + #self.softmax_config.hidden_size = config.hidden_size + #self.softmax_config.heads = config.num_attention_heads + #self.softmax_config.softmax_id = i + #self.softmax_config.fp16 = config.fp16 + #self.softmax_config.prob_drop_out = 0.0 + #self.softmax = DeepSpeedSoftmax(i, self.softmax_config) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def transpose_key_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, + self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 3, 1) + + def forward(self, hidden_states, attention_mask, grads=None): + #timing = [] + #t1 = GPUTimer() + #t1.record() + mixed_query_layer = self.query(hidden_states) + + #timing.append(t1.elapsed()) + #print("Query elapsed: %s" % (time.clock() - start)) + #t1.record() + mixed_key_layer = self.key(hidden_states) + + #timing.append(t1.elapsed()) + #print("Key elapsed: %s" % (time.clock() - start)) + #t1.record() + mixed_value_layer = self.value(hidden_states) + #timing.append(t1.elapsed()) + #print("Value elapsed: %s" % (time.clock() - start)) + + #t1.record() + query_layer = self.transpose_for_scores(mixed_query_layer) + # print(query_layer) + #timing.append(t1.elapsed()) + #print("Query-Transform elapsed: %s" % (time.clock() - start)) + #t1.record() + key_layer = self.transpose_key_for_scores(mixed_key_layer) + # print(key_layer) + #timing.append(t1.elapsed()) + #print("Key-Transform elapsed: %s" % (time.clock() - start)) + #t1.record() + value_layer = self.transpose_for_scores(mixed_value_layer) + #print(value_layer) + #timing.append(t1.elapsed()) + #print("Value-Transform elapsed: %s" % (time.clock() - start)) + + # Take the dot product between "query" and "key" to get the raw attention scores. + #t1.record() + #print(query_layer.shape) + #print(key_layer.shape) + attention_scores = torch.matmul(query_layer, key_layer) + #print(attention_scores.shape) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + #print("Pytorch: ", attention_scores) + #timing.append(t1.elapsed()) + #print("Attention-Score elapsed: %s" % (time.clock() - start)) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + #t1.record() + + # context_layer = self.softmax(query_layer, key_layer, value_layer, attention_mask) + #print("context shape is :", context_layer.shape) + #print("Cuda-ext:, ", attention_scores1) + # Normalize the attention scores to probabilities. + ####attention_probs = self.softmax(attention_scores) + #timing.append(t1.elapsed()) + #print("Softmax elapsed: %s" % (time.clock() - start)) + #t1 = GPUTimer() + #t1.record() + attention_scores = attention_scores + attention_mask + attention_probs = self.softmax(attention_scores) + #attention_scores = self.softmax(attention_scores, attention_mask) + #print("Softmax elapse {0:8.2f} ms", t1.elapsed() * 1000) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + #t1.record() + context_layer = torch.matmul(attention_probs, value_layer) + #timing.append(t1.elapsed()) + #print("Context elapsed: %s" % (time.clock() - start)) + #t1.record() + #context_layer1 = context_layer.permute( + # 0, 1, 3, 2, 4).contiguous() + #if grads is not None: + # context_layer.register_hook(lambda x, self = self : grads.append([x, "Context"])) + context_layer1 = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer1.size()[:-2] + (self.all_head_size, ) + context_layer1 = context_layer1.view(*new_context_layer_shape) + #timing.append(t1.elapsed()) + #print("Context-Transform elapsed: %s" % (time.clock() - start)) + + if grads is not None: + query_layer.register_hook(lambda x, self=self: grads.append([x, "Query"])) + key_layer.register_hook(lambda x, self=self: grads.append([x, "Key"])) + value_layer.register_hook(lambda x, self=self: grads.append([x, "Value"])) + + return context_layer1 + + +class BertSelfOutput(nn.Module): + def __init__(self, config, weights, biases): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dense.weight = weights[3] + self.dense.bias = biases[3] + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + #timing = [] + #t1 = GPUTimer() + #t1.record() + hidden_states = self.dense(hidden_states) + #timing.append(t1.elapsed()) + #print("Attention Output elapsed: %s" % (time.clock() - start)) + hidden_states = self.dropout(hidden_states) + #t1.record() + #hidden_states = self.LayerNorm(hidden_states + input_tensor) + #timing.append(t1.elapsed()) + #print("LayerNorm elapsed: %s" % (time.clock() - start)) + return hidden_states + + def get_w(self): + return self.dense.weight + + +class BertAttention(nn.Module): + def __init__(self, i, config, weights, biases): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(i, config, weights, biases) + self.output = BertSelfOutput(config, weights, biases) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + def get_w(self): + return self.output.get_w() + + +class BertIntermediate(nn.Module): + def __init__(self, config, weights, biases): + super(BertIntermediate, self).__init__() + self.dense_act = LinearActivation(config.hidden_size, + config.intermediate_size, + weights, + biases, + act=config.hidden_act) + + def forward(self, hidden_states): + hidden_states = self.dense_act(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config, weights, biases): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.dense.weight = weights[6] + self.dense.bias = biases[6] + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + #timing = [] + #t1 = GPUTimer() + #t1.record() + #print (hidden_states) + #print (self.dense.weight) + hidden_states = self.dense(hidden_states) + #timing.append(t1.elapsed()) + #print("FF2 elapsed: %s" % (time.clock() - start)) + hidden_states = self.dropout(hidden_states) + #t1.record() + #hidden_states = self.LayerNorm(hidden_states + input_tensor) + #timing.append(t1.elapsed()) + #print("LayerNorm elapsed: %s" % (time.clock() - start)) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, i, config, weights, biases): + super(BertLayer, self).__init__() + self.attention = BertAttention(i, config, weights, biases) + self.PreAttentionLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.PostAttentionLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + self.intermediate = BertIntermediate(config, weights, biases) + self.output = BertOutput(config, weights, biases) + self.weight = weights + self.biases = biases + + def forward(self, hidden_states, attention_mask, grads, collect_all_grads=False): + input_layer_norm = self.PreAttentionLayerNorm(hidden_states) + attention_output = self.attention(input_layer_norm, attention_mask) + #print ("hidden shape is :", hidden_states.shape) + intermediate_input = hidden_states + attention_output + + intermediate_layer_norm = self.PostAttentionLayerNorm(intermediate_input) + intermediate_output = self.intermediate(intermediate_layer_norm) + layer_output = self.output(intermediate_output, attention_output) + + #attention_output = self.attention(hidden_states, attention_mask) + #intermediate_output = self.intermediate(attention_output) + #layer_output = self.output(intermediate_output, attention_output) + + if collect_all_grads: + # self.weight[0].register_hook(lambda x, self=self: grads.append([x,"Q_W"])) + # self.biases[0].register_hook(lambda x, self=self: grads.append([x,"Q_B"])) + # self.weight[1].register_hook(lambda x, self=self: grads.append([x,"K_W"])) + # self.biases[1].register_hook(lambda x, self=self: grads.append([x,"K_B"])) + self.weight[2].register_hook(lambda x, self=self: grads.append([x, "V_W"])) + self.biases[2].register_hook(lambda x, self=self: grads.append([x, "V_B"])) + self.weight[3].register_hook(lambda x, self=self: grads.append([x, "O_W"])) + self.biases[3].register_hook(lambda x, self=self: grads.append([x, "O_B"])) + self.PostAttentionLayerNorm.weight.register_hook( + lambda x, + self=self: grads.append([x, + "N2_W"])) + self.PostAttentionLayerNorm.bias.register_hook( + lambda x, + self=self: grads.append([x, + "N2_B"])) + self.weight[5].register_hook(lambda x, self=self: grads.append([x, "int_W"])) + self.biases[5].register_hook(lambda x, self=self: grads.append([x, "int_B"])) + self.weight[6].register_hook(lambda x, self=self: grads.append([x, "out_W"])) + self.biases[6].register_hook(lambda x, self=self: grads.append([x, "out_B"])) + self.PreAttentionLayerNorm.weight.register_hook( + lambda x, + self=self: grads.append([x, + "norm_W"])) + self.PreAttentionLayerNorm.bias.register_hook( + lambda x, + self=self: grads.append([x, + "norm_B"])) + + return layer_output + intermediate_input + + def get_w(self): + return self.attention.get_w() + + +class BertEncoder(nn.Module): + def __init__(self, config, weights, biases): + super(BertEncoder, self).__init__() + #layer = BertLayer(config, weights, biases) + self.FinalLayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + self.layer = nn.ModuleList([ + copy.deepcopy(BertLayer(i, + config, + weights, + biases)) for i in range(config.num_hidden_layers) + ]) + self.grads = [] + self.graph = [] + + def get_grads(self): + return self.grads + + # def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + # all_encoder_layers = [] + # for layer_module in self.layer: + # hidden_states = layer_module(hidden_states, attention_mask) + # if output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + # if not output_all_encoded_layers: + # all_encoder_layers.append(hidden_states) + # return all_encoder_layers + + def get_modules(self, big_node, input): + for mdl in big_node.named_children(): + graph.append(mdl) + get_modules(self, mdl, input) + + def forward(self, + hidden_states, + attention_mask, + output_all_encoded_layers=True, + checkpoint_activations=False): + all_encoder_layers = [] + + def custom(start, end): + def custom_forward(*inputs): + layers = self.layer[start:end] + x_ = inputs[0] + for layer in layers: + x_ = layer(x_, inputs[1]) + return x_ + + return custom_forward + + if checkpoint_activations: + l = 0 + num_layers = len(self.layer) + chunk_length = math.ceil(math.sqrt(num_layers)) + while l < num_layers: + hidden_states = checkpoint.checkpoint(custom(l, + l + chunk_length), + hidden_states, + attention_mask * 1) + l += chunk_length + # decoder layers + else: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, + attention_mask, + self.grads, + collect_all_grads=True) + hidden_states.register_hook( + lambda x, + i=i, + self=self: self.grads.append([x, + "hidden_state"])) + #print("pytorch weight is: ", layer_module.get_w()) + + if output_all_encoded_layers: + all_encoder_layers.append((hidden_states)) + + if not output_all_encoded_layers or checkpoint_activations: + hidden_states = self.FinalLayerNorm(hidden_states) + all_encoder_layers.append((hidden_states)) + return all_encoder_layers + + +#class BertEncoder(nn.Module): +# def __init__(self, config): +# super(BertEncoder, self).__init__() +# layer = BertLayer(config) +# self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) +# +# def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): +# all_encoder_layers = [] +# for layer_module in self.layer: +# hidden_states = layer_module(hidden_states, attention_mask) +# if output_all_encoded_layers: +# all_encoder_layers.append(hidden_states) +# if not output_all_encoded_layers: +# all_encoder_layers.append(hidden_states) +# return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, config): + super(BertPooler, self).__init__() + self.dense_act = LinearActivation(config.hidden_size, + config.hidden_size, + act="tanh") + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense_act(first_token_tensor) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super(BertPredictionHeadTransform, self).__init__() + self.dense_act = LinearActivation(config.hidden_size, + config.hidden_size, + act=config.hidden_act) + self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) + + def forward(self, hidden_states): + hidden_states = self.dense_act(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertLMPredictionHead, self).__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(bert_model_embedding_weights.size(1), + bert_model_embedding_weights.size(0), + bias=False) + self.decoder.weight = bert_model_embedding_weights + self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0))) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + torch.cuda.nvtx.range_push( + "decoder input.size() = {}, weight.size() = {}".format( + hidden_states.size(), + self.decoder.weight.size())) + hidden_states = self.decoder(hidden_states) + self.bias + torch.cuda.nvtx.range_pop() + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertOnlyMLMHead, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertOnlyNSPHead(nn.Module): + def __init__(self, config): + super(BertOnlyNSPHead, self).__init__() + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, pooled_output): + seq_relationship_score = self.seq_relationship(pooled_output) + return seq_relationship_score + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config, bert_model_embedding_weights): + super(BertPreTrainingHeads, self).__init__() + self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super(BertPreTrainedModel, self).__init__() + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, + self.__class__.__name__)) + self.config = config + + def init_bert_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + @classmethod + def from_pretrained(cls, + pretrained_model_name_or_path, + state_dict=None, + cache_dir=None, + from_tf=False, + *inputs, + **kwargs): + """ + Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name_or_path: either: + - a str with the name of a pre-trained model to load selected in the list of: + . `bert-base-uncased` + . `bert-large-uncased` + . `bert-base-cased` + . `bert-large-cased` + . `bert-base-multilingual-uncased` + . `bert-base-multilingual-cased` + . `bert-base-chinese` + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + from_tf: should we load the weights from a locally saved TensorFlow checkpoint + cache_dir: an optional path to a folder in which the pre-trained models will be cached. + state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP: + archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path] + else: + archive_file = pretrained_model_name_or_path + if resolved_archive_file == archive_file: + logger.info("loading archive file {}".format(archive_file)) + else: + logger.info("loading archive file {} from cache at {}".format( + archive_file, + resolved_archive_file)) + tempdir = None + if os.path.isdir(resolved_archive_file) or from_tf: + serialization_dir = resolved_archive_file + else: + # Extract archive to temp dir + tempdir = tempfile.mkdtemp() + logger.info("extracting archive file {} to temp dir {}".format( + resolved_archive_file, + tempdir)) + with tarfile.open(resolved_archive_file, 'r:gz') as archive: + archive.extractall(tempdir) + serialization_dir = tempdir + # Load config + config_file = os.path.join(serialization_dir, CONFIG_NAME) + config = BertConfig.from_json_file(config_file) + logger.info("Model config {}".format(config)) + # Instantiate model. + model = cls(config, *inputs, **kwargs) + if state_dict is None and not from_tf: + weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) + state_dict = torch.load( + weights_path, + map_location='cpu' if not torch.cuda.is_available() else None) + if tempdir: + # Clean up temp dir + shutil.rmtree(tempdir) + if from_tf: + # Directly load from a TensorFlow checkpoint + weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME) + return load_tf_weights_in_bert(model, weights_path) + # Load from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict(state_dict, + prefix, + local_metadata, + True, + missing_keys, + unexpected_keys, + error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + start_prefix = '' + if not hasattr(model, + 'bert') and any(s.startswith('bert.') for s in state_dict.keys()): + start_prefix = 'bert.' + load(model, prefix=start_prefix) + if len(missing_keys) > 0: + logger.info("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, + missing_keys)) + if len(unexpected_keys) > 0: + logger.info("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, + unexpected_keys)) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + model.__class__.__name__, + "\n\t".join(error_msgs))) + return model + + +class BertModel(BertPreTrainedModel): + """BERT model ("Bidirectional Embedding Representations from a Transformer"). + + Params: + config: a BertConfig class instance with the configuration to build a new model + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`. + + Outputs: Tuple of (encoded_layers, pooled_output) + `encoded_layers`: controlled by `output_all_encoded_layers` argument: + - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end + of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each + encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size], + - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding + to the last attention block of shape [batch_size, sequence_length, hidden_size], + `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a + classifier pretrained on top of the hidden state associated to the first character of the + input (`CLS`) to train on the Next-Sentence task (see BERT's paper). + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = modeling.BertModel(config=config) + all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertModel, self).__init__(config) + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + output_all_encoded_layers=True, + checkpoint_activations=False): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next( + self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder( + embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers, + checkpoint_activations=checkpoint_activations) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + +class BertForPreTraining(BertPreTrainedModel): + """BERT model with pre-training heads. + This module comprises the BERT model followed by the two pre-training heads: + - the masked language modeling head, and + - the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `masked_lm_labels` and `next_sentence_label` are not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `masked_lm_labels` or `next_sentence_label` is `None`: + Outputs a tuple comprising + - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and + - the next sentence classification logits of shape [batch_size, 2]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForPreTraining(config) + masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, args): + super(BertForPreTraining, self).__init__(config) + self.summary_writer = None + if dist.get_rank() == 0: + self.summary_writer = args.summary_writer + self.samples_per_step = dist.get_world_size() * args.train_batch_size + self.sample_count = self.samples_per_step + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config, + self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def log_summary_writer(self, logs: dict, base='Train'): + if dist.get_rank() == 0: + module_name = "Samples" #self._batch_module_name.get(batch_type, self._get_batch_type_error(batch_type)) + for key, log in logs.items(): + self.summary_writer.add_scalar(f'{base}/{module_name}/{key}', + log, + self.sample_count) + self.sample_count += self.samples_per_step + + def forward(self, batch, log=True): + #input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None, checkpoint_activations=False): + input_ids = batch[1] + token_type_ids = batch[3] + attention_mask = batch[2] + masked_lm_labels = batch[5] + next_sentence_label = batch[4] + checkpoint_activations = False + + sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False, checkpoint_activations=checkpoint_activations) + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + if masked_lm_labels is not None and next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct(prediction_scores.view(-1, + self.config.vocab_size), + masked_lm_labels.view(-1)) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, + 2), + next_sentence_label.view(-1)) + #print("loss is {} {}".format(masked_lm_loss, next_sentence_loss)) + total_loss = masked_lm_loss + next_sentence_loss + # if log: + # self.log_summary_writer(logs={'train_loss': total_loss.item()}) + return total_loss + else: + return prediction_scores, seq_relationship_score + + +class BertForMaskedLM(BertPreTrainedModel): + """BERT model with the masked language modeling head. + This module comprises the BERT model followed by the masked language modeling head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss + is only computed for the labels set in [0, ..., vocab_size] + + Outputs: + if `masked_lm_labels` is not `None`: + Outputs the masked language modeling loss. + if `masked_lm_labels` is `None`: + Outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForMaskedLM(config) + masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForMaskedLM, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + masked_lm_labels=None, + checkpoint_activations=False): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False) + prediction_scores = self.cls(sequence_output) + + if masked_lm_labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + masked_lm_loss = loss_fct(prediction_scores.view(-1, + self.config.vocab_size), + masked_lm_labels.view(-1)) + return masked_lm_loss + else: + return prediction_scores + + +class BertForNextSentencePrediction(BertPreTrainedModel): + """BERT model with next sentence prediction head. + This module comprises the BERT model followed by the next sentence classification head. + + Params: + config: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] + with indices selected in [0, 1]. + 0 => next sentence is the continuation, 1 => next sentence is a random sentence. + + Outputs: + if `next_sentence_label` is not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `next_sentence_label` is `None`: + Outputs the next sentence classification logits of shape [batch_size, 2]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForNextSentencePrediction(config) + seq_relationship_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForNextSentencePrediction, self).__init__(config) + self.bert = BertModel(config) + self.cls = BertOnlyNSPHead(config) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + next_sentence_label=None, + checkpoint_activations=False): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + output_all_encoded_layers=False) + seq_relationship_score = self.cls(pooled_output) + + if next_sentence_label is not None: + loss_fct = CrossEntropyLoss(ignore_index=-1) + next_sentence_loss = loss_fct(seq_relationship_score.view(-1, + 2), + next_sentence_label.view(-1)) + return next_sentence_loss + else: + return seq_relationship_score + + +class BertForSequenceClassification(BertPreTrainedModel): + """BERT model for classification. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_labels]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_labels = 2 + + model = BertForSequenceClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_labels): + super(BertForSequenceClassification, self).__init__(config) + self.num_labels = num_labels + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + checkpoint_activations=False): + _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return loss + else: + return logits + + +class BertForMultipleChoice(BertPreTrainedModel): + """BERT model for multiple choice tasks. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_choices`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` + and type 1 corresponds to a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_choices]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) + input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) + token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_choices = 2 + + model = BertForMultipleChoice(config, num_choices) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_choices): + super(BertForMultipleChoice, self).__init__(config) + self.num_choices = num_choices + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + checkpoint_activations=False): + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) + _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, self.num_choices) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + return loss + else: + return reshaped_logits + + +class BertForTokenClassification(BertPreTrainedModel): + """BERT model for token-level classification. + This module is composed of the BERT model with a linear layer on top of + the full hidden state of the last layer. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_labels`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size, sequence_length] + with indices selected in [0, ..., num_labels]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, sequence_length, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_labels = 2 + + model = BertForTokenClassification(config, num_labels) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_labels): + super(BertForTokenClassification, self).__init__(config) + self.num_labels = num_labels + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, num_labels) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + labels=None, + checkpoint_activations=False): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels)[active_loss] + active_labels = labels.view(-1)[active_loss] + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + return loss + else: + return logits + + +class BertForQuestionAnswering(BertPreTrainedModel): + """BERT model for Question Answering (span extraction). + This module is composed of the BERT model with a linear layer on top of + the sequence output that computes start_logits and end_logits + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token + types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to + a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + `end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size]. + Positions are clamped to the length of the sequence and position outside of the sequence are not taken + into account for computing the loss. + + Outputs: + if `start_positions` and `end_positions` are not `None`: + Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions. + if `start_positions` or `end_positions` is `None`: + Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end + position tokens of shape [batch_size, sequence_length]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) + input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) + token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) + + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + model = BertForQuestionAnswering(config) + start_logits, end_logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config): + super(BertForQuestionAnswering, self).__init__(config) + self.bert = BertModel(config) + # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version + # self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + self.apply(self.init_bert_weights) + + def forward(self, + input_ids, + token_type_ids=None, + attention_mask=None, + start_positions=None, + end_positions=None, + checkpoint_activations=False): + sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + return total_loss + else: + return start_logits, end_logits diff --git a/tests/unit/test_aio.py b/tests/unit/test_aio.py index ac0391176f9d..fdec95a35ae7 100755 --- a/tests/unit/test_aio.py +++ b/tests/unit/test_aio.py @@ -1,335 +1,335 @@ -import pytest -import os -import filecmp -import torch -import deepspeed -import torch.distributed as dist -from deepspeed.ops.aio import AsyncIOBuilder -from .common import distributed_test - -MEGA_BYTE = 1024**2 -BLOCK_SIZE = MEGA_BYTE -QUEUE_DEPTH = 2 -IO_SIZE = 16 * MEGA_BYTE -IO_PARALLEL = 2 - - -def _skip_if_no_aio(): - if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: - pytest.skip('Skip tests since async-io is not compatible') - - -def _do_ref_write(tmpdir, index=0): - file_suffix = f'{dist.get_rank()}_{index}' - ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt') - ref_buffer = os.urandom(IO_SIZE) - with open(ref_file, 'wb') as f: - f.write(ref_buffer) - - return ref_file, ref_buffer - - -def _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device, index=0): - file_suffix = f'{dist.get_rank()}_{index}' - test_file = os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt') - if cuda_device: - test_buffer = torch.cuda.ByteTensor(list(ref_buffer)) - else: - test_buffer = torch.ByteTensor(list(ref_buffer)).pin_memory() - - return test_file, test_buffer - - -def _validate_handle_state(handle, single_submit, overlap_events): - assert handle.get_single_submit() == single_submit - assert handle.get_overlap_events() == overlap_events - assert handle.get_thread_count() == IO_PARALLEL - assert handle.get_block_size() == BLOCK_SIZE - assert handle.get_queue_depth() == QUEUE_DEPTH - - -@pytest.mark.parametrize('single_submit, overlap_events', - [(False, - False), - (False, - True), - (True, - False), - (True, - True)]) -def test_parallel_read(tmpdir, single_submit, overlap_events): - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_parallel_read(single_submit, overlap_events): - ref_file, _ = _do_ref_write(tmpdir) - - aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - read_status = h.sync_pread(aio_buffer, ref_file) - assert read_status == 1 - - with open(ref_file, 'rb') as f: - ref_buffer = list(f.read()) - assert ref_buffer == aio_buffer.tolist() - - _test_parallel_read(single_submit, overlap_events) - - -@pytest.mark.parametrize('single_submit, overlap_events, cuda_device', - [(False, - False, - False), - (False, - True, - False), - (True, - False, - False), - (True, - True, - False), - (False, - False, - True), - (True, - True, - True)]) -def test_async_read(tmpdir, single_submit, overlap_events, cuda_device): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_async_read(single_submit, overlap_events, cuda_device): - ref_file, _ = _do_ref_write(tmpdir) - - if cuda_device: - aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') - else: - aio_buffer = torch.empty(IO_SIZE, - dtype=torch.uint8, - device='cpu').pin_memory() - - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - read_status = h.async_pread(aio_buffer, ref_file) - assert read_status == 0 - - wait_status = h.wait() - assert wait_status == 1 - - with open(ref_file, 'rb') as f: - ref_buffer = list(f.read()) - assert ref_buffer == aio_buffer.tolist() - - _test_async_read(single_submit, overlap_events, cuda_device) - - -@pytest.mark.parametrize('single_submit, overlap_events', - [(False, - False), - (False, - True), - (True, - False), - (True, - True)]) -def test_parallel_write(tmpdir, single_submit, overlap_events): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_parallel_write(single_submit, overlap_events): - ref_file, ref_buffer = _do_ref_write(tmpdir) - - aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, False) - - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - write_status = h.sync_pwrite(aio_buffer, aio_file) - assert write_status == 1 - - assert os.path.isfile(aio_file) - - filecmp.clear_cache() - assert filecmp.cmp(ref_file, aio_file, shallow=False) - - _test_parallel_write(single_submit, overlap_events) - - -@pytest.mark.parametrize('single_submit, overlap_events, cuda_device', - [(False, - False, - False), - (False, - True, - False), - (True, - False, - False), - (True, - True, - False), - (False, - False, - True), - (True, - True, - True)]) -def test_async_write(tmpdir, single_submit, overlap_events, cuda_device): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_async_write(single_submit, overlap_events, cuda_device): - ref_file, ref_buffer = _do_ref_write(tmpdir) - - aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device) - - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - write_status = h.async_pwrite(aio_buffer, aio_file) - assert write_status == 0 - - wait_status = h.wait() - assert wait_status == 1 - - assert os.path.isfile(aio_file) - - filecmp.clear_cache() - assert filecmp.cmp(ref_file, aio_file, shallow=False) - - _test_async_write(single_submit, overlap_events, cuda_device) - - -@pytest.mark.parametrize('async_queue, cuda_device', - [(2, - False), - (4, - False), - (2, - True), - (4, - True)]) -def test_async_queue_read(tmpdir, async_queue, cuda_device): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_async_queue_read(async_queue, cuda_device): - ref_files = [] - for i in range(async_queue): - f, _ = _do_ref_write(tmpdir, i) - ref_files.append(f) - - aio_buffers = [] - for i in range(async_queue): - if cuda_device: - buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') - else: - buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() - aio_buffers.append(buf) - - single_submit = True - overlap_events = True - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - for i in range(async_queue): - read_status = h.async_pread(aio_buffers[i], ref_files[i]) - assert read_status == 0 - - wait_status = h.wait() - assert wait_status == async_queue - - for i in range(async_queue): - with open(ref_files[i], 'rb') as f: - ref_buffer = list(f.read()) - assert ref_buffer == aio_buffers[i].tolist() - - _test_async_queue_read(async_queue, cuda_device) - - -@pytest.mark.parametrize('async_queue, cuda_device', - [(2, - False), - (7, - False), - (2, - True), - (7, - True)]) -def test_async_queue_write(tmpdir, async_queue, cuda_device): - - _skip_if_no_aio() - - @distributed_test(world_size=[2]) - def _test_async_queue_write(async_queue, cuda_device): - ref_files = [] - ref_buffers = [] - for i in range(async_queue): - f, buf = _do_ref_write(tmpdir, i) - ref_files.append(f) - ref_buffers.append(buf) - - aio_files = [] - aio_buffers = [] - for i in range(async_queue): - f, buf = _get_test_file_and_buffer(tmpdir, ref_buffers[i], cuda_device, i) - aio_files.append(f) - aio_buffers.append(buf) - - single_submit = True - overlap_events = True - h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, - QUEUE_DEPTH, - single_submit, - overlap_events, - IO_PARALLEL) - - _validate_handle_state(h, single_submit, overlap_events) - - for i in range(async_queue): - read_status = h.async_pwrite(aio_buffers[i], aio_files[i]) - assert read_status == 0 - - wait_status = h.wait() - assert wait_status == async_queue - - for i in range(async_queue): - assert os.path.isfile(aio_files[i]) - - filecmp.clear_cache() - assert filecmp.cmp(ref_files[i], aio_files[i], shallow=False) - - _test_async_queue_write(async_queue, cuda_device) +import pytest +import os +import filecmp +import torch +import deepspeed +import torch.distributed as dist +from deepspeed.ops.aio import AsyncIOBuilder +from .common import distributed_test + +MEGA_BYTE = 1024**2 +BLOCK_SIZE = MEGA_BYTE +QUEUE_DEPTH = 2 +IO_SIZE = 16 * MEGA_BYTE +IO_PARALLEL = 2 + + +def _skip_if_no_aio(): + if not deepspeed.ops.__compatible_ops__[AsyncIOBuilder.NAME]: + pytest.skip('Skip tests since async-io is not compatible') + + +def _do_ref_write(tmpdir, index=0): + file_suffix = f'{dist.get_rank()}_{index}' + ref_file = os.path.join(tmpdir, f'_py_random_{file_suffix}.pt') + ref_buffer = os.urandom(IO_SIZE) + with open(ref_file, 'wb') as f: + f.write(ref_buffer) + + return ref_file, ref_buffer + + +def _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device, index=0): + file_suffix = f'{dist.get_rank()}_{index}' + test_file = os.path.join(tmpdir, f'_aio_write_random_{file_suffix}.pt') + if cuda_device: + test_buffer = torch.cuda.ByteTensor(list(ref_buffer)) + else: + test_buffer = torch.ByteTensor(list(ref_buffer)).pin_memory() + + return test_file, test_buffer + + +def _validate_handle_state(handle, single_submit, overlap_events): + assert handle.get_single_submit() == single_submit + assert handle.get_overlap_events() == overlap_events + assert handle.get_thread_count() == IO_PARALLEL + assert handle.get_block_size() == BLOCK_SIZE + assert handle.get_queue_depth() == QUEUE_DEPTH + + +@pytest.mark.parametrize('single_submit, overlap_events', + [(False, + False), + (False, + True), + (True, + False), + (True, + True)]) +def test_parallel_read(tmpdir, single_submit, overlap_events): + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_parallel_read(single_submit, overlap_events): + ref_file, _ = _do_ref_write(tmpdir) + + aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + read_status = h.sync_pread(aio_buffer, ref_file) + assert read_status == 1 + + with open(ref_file, 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == aio_buffer.tolist() + + _test_parallel_read(single_submit, overlap_events) + + +@pytest.mark.parametrize('single_submit, overlap_events, cuda_device', + [(False, + False, + False), + (False, + True, + False), + (True, + False, + False), + (True, + True, + False), + (False, + False, + True), + (True, + True, + True)]) +def test_async_read(tmpdir, single_submit, overlap_events, cuda_device): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_async_read(single_submit, overlap_events, cuda_device): + ref_file, _ = _do_ref_write(tmpdir) + + if cuda_device: + aio_buffer = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') + else: + aio_buffer = torch.empty(IO_SIZE, + dtype=torch.uint8, + device='cpu').pin_memory() + + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + read_status = h.async_pread(aio_buffer, ref_file) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == 1 + + with open(ref_file, 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == aio_buffer.tolist() + + _test_async_read(single_submit, overlap_events, cuda_device) + + +@pytest.mark.parametrize('single_submit, overlap_events', + [(False, + False), + (False, + True), + (True, + False), + (True, + True)]) +def test_parallel_write(tmpdir, single_submit, overlap_events): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_parallel_write(single_submit, overlap_events): + ref_file, ref_buffer = _do_ref_write(tmpdir) + + aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, False) + + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + write_status = h.sync_pwrite(aio_buffer, aio_file) + assert write_status == 1 + + assert os.path.isfile(aio_file) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, aio_file, shallow=False) + + _test_parallel_write(single_submit, overlap_events) + + +@pytest.mark.parametrize('single_submit, overlap_events, cuda_device', + [(False, + False, + False), + (False, + True, + False), + (True, + False, + False), + (True, + True, + False), + (False, + False, + True), + (True, + True, + True)]) +def test_async_write(tmpdir, single_submit, overlap_events, cuda_device): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_async_write(single_submit, overlap_events, cuda_device): + ref_file, ref_buffer = _do_ref_write(tmpdir) + + aio_file, aio_buffer = _get_test_file_and_buffer(tmpdir, ref_buffer, cuda_device) + + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + write_status = h.async_pwrite(aio_buffer, aio_file) + assert write_status == 0 + + wait_status = h.wait() + assert wait_status == 1 + + assert os.path.isfile(aio_file) + + filecmp.clear_cache() + assert filecmp.cmp(ref_file, aio_file, shallow=False) + + _test_async_write(single_submit, overlap_events, cuda_device) + + +@pytest.mark.parametrize('async_queue, cuda_device', + [(2, + False), + (4, + False), + (2, + True), + (4, + True)]) +def test_async_queue_read(tmpdir, async_queue, cuda_device): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_async_queue_read(async_queue, cuda_device): + ref_files = [] + for i in range(async_queue): + f, _ = _do_ref_write(tmpdir, i) + ref_files.append(f) + + aio_buffers = [] + for i in range(async_queue): + if cuda_device: + buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cuda') + else: + buf = torch.empty(IO_SIZE, dtype=torch.uint8, device='cpu').pin_memory() + aio_buffers.append(buf) + + single_submit = True + overlap_events = True + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + for i in range(async_queue): + read_status = h.async_pread(aio_buffers[i], ref_files[i]) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == async_queue + + for i in range(async_queue): + with open(ref_files[i], 'rb') as f: + ref_buffer = list(f.read()) + assert ref_buffer == aio_buffers[i].tolist() + + _test_async_queue_read(async_queue, cuda_device) + + +@pytest.mark.parametrize('async_queue, cuda_device', + [(2, + False), + (7, + False), + (2, + True), + (7, + True)]) +def test_async_queue_write(tmpdir, async_queue, cuda_device): + + _skip_if_no_aio() + + @distributed_test(world_size=[2]) + def _test_async_queue_write(async_queue, cuda_device): + ref_files = [] + ref_buffers = [] + for i in range(async_queue): + f, buf = _do_ref_write(tmpdir, i) + ref_files.append(f) + ref_buffers.append(buf) + + aio_files = [] + aio_buffers = [] + for i in range(async_queue): + f, buf = _get_test_file_and_buffer(tmpdir, ref_buffers[i], cuda_device, i) + aio_files.append(f) + aio_buffers.append(buf) + + single_submit = True + overlap_events = True + h = AsyncIOBuilder().load().aio_handle(BLOCK_SIZE, + QUEUE_DEPTH, + single_submit, + overlap_events, + IO_PARALLEL) + + _validate_handle_state(h, single_submit, overlap_events) + + for i in range(async_queue): + read_status = h.async_pwrite(aio_buffers[i], aio_files[i]) + assert read_status == 0 + + wait_status = h.wait() + assert wait_status == async_queue + + for i in range(async_queue): + assert os.path.isfile(aio_files[i]) + + filecmp.clear_cache() + assert filecmp.cmp(ref_files[i], aio_files[i], shallow=False) + + _test_async_queue_write(async_queue, cuda_device) diff --git a/tests/unit/test_cpu_adagrad.py b/tests/unit/test_cpu_adagrad.py index b8a025fe02a8..f2ba26255847 100755 --- a/tests/unit/test_cpu_adagrad.py +++ b/tests/unit/test_cpu_adagrad.py @@ -1,125 +1,125 @@ -import torch -import numpy as np -import pytest - -import deepspeed -from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad -from deepspeed.ops.op_builder import CPUAdagradBuilder - -if not deepspeed.ops.__compatible_ops__[CPUAdagradBuilder.NAME]: - pytest.skip("cpu-adagrad is not compatible") - - -def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() - if verbose: - print("x = {}".format(x.flatten())) - print("y = {}".format(y.flatten())) - print('-' * 80) - np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) - - -@pytest.mark.parametrize('model_size', - [ - (64), - (22), - (55), - (127), - (1024), - (1048576), - (30000000), - ]) # yapf: disable -def test_cpu_adagrad_opt(model_size): - device = 'cpu' - rng_state = torch.get_rng_state() - param = torch.nn.Parameter(torch.randn(model_size, device=device)) - torch.set_rng_state(rng_state) - param1 = torch.nn.Parameter(torch.randn(model_size, device=device)) - torch.set_rng_state(rng_state) - - optimizer = DeepSpeedCPUAdagrad([param]) - optimizer1 = torch.optim.Adagrad([param1]) - - for i in range(10): - rng_state = torch.get_rng_state() - param.grad = torch.randn(model_size, device=device) - torch.set_rng_state(rng_state) - param1.grad = torch.randn(model_size, device=device) - optimizer.step() - optimizer1.step() - - check_equal(param, param1, atol=1e-2, verbose=True) - - -@pytest.mark.parametrize('model_size,vocabulary_size,dim', - [ - (16 * 2, 16 * 4, 16), - (16 * 32, 16 * 256, 16), - (16 * 256, 16 * 16384, 16), - ]) # yapf: disable -def test_cpu_adagrad_opt_sparse_embedding(model_size, vocabulary_size, dim): - device = 'cpu' - rng_state = torch.get_rng_state() - - def gen_sparse_grad(vocabulary_size, dim, num_indices, dtype, device): - i = torch.randint(vocabulary_size, - size=(1, - num_indices), - dtype=torch.int64, - device=device) - v = torch.randn(num_indices, dim, dtype=dtype, device=device) - t = torch.sparse_coo_tensor(i, v, (vocabulary_size, dim), device=device) - t = t.coalesce() - new_i = (t.indices().view(-1, - 1).repeat(1, - dim) * dim + - torch.tensor(range(dim))).flatten().unsqueeze(0) - new_v = t.values().flatten() - new_t = torch.sparse_coo_tensor(new_i, - new_v, - (vocabulary_size * dim, - ), - device=device) - new_t = new_t.coalesce() - new_t.requires_grad = False - return new_t - - voc_size = vocabulary_size - dim = dim - num_indices = int(model_size // dim) - dtype = torch.float32 - - param = torch.nn.Parameter(torch.randn((voc_size * dim, - ), - dtype=dtype, - device=device), - requires_grad=True) - torch.set_rng_state(rng_state) - param1 = torch.nn.Parameter(torch.randn((voc_size * dim, - ), - dtype=dtype, - device=device), - requires_grad=True) - torch.set_rng_state(rng_state) - - optimizer = DeepSpeedCPUAdagrad([param]) - optimizer1 = torch.optim.Adagrad([param1]) - - for i in range(10): - torch.set_rng_state(rng_state) - param.grad = gen_sparse_grad(voc_size, - dim, - num_indices, - dtype=dtype, - device=device) - torch.set_rng_state(rng_state) - param1.grad = gen_sparse_grad(voc_size, - dim, - num_indices, - dtype=dtype, - device=device) - optimizer.step() - optimizer1.step() - - check_equal(param, param1, atol=1e-2, verbose=True) +import torch +import numpy as np +import pytest + +import deepspeed +from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad +from deepspeed.ops.op_builder import CPUAdagradBuilder + +if not deepspeed.ops.__compatible_ops__[CPUAdagradBuilder.NAME]: + pytest.skip("cpu-adagrad is not compatible") + + +def check_equal(first, second, atol=1e-2, verbose=False): + x = first.detach().numpy() + y = second.detach().numpy() + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) + + +@pytest.mark.parametrize('model_size', + [ + (64), + (22), + (55), + (127), + (1024), + (1048576), + (30000000), + ]) # yapf: disable +def test_cpu_adagrad_opt(model_size): + device = 'cpu' + rng_state = torch.get_rng_state() + param = torch.nn.Parameter(torch.randn(model_size, device=device)) + torch.set_rng_state(rng_state) + param1 = torch.nn.Parameter(torch.randn(model_size, device=device)) + torch.set_rng_state(rng_state) + + optimizer = DeepSpeedCPUAdagrad([param]) + optimizer1 = torch.optim.Adagrad([param1]) + + for i in range(10): + rng_state = torch.get_rng_state() + param.grad = torch.randn(model_size, device=device) + torch.set_rng_state(rng_state) + param1.grad = torch.randn(model_size, device=device) + optimizer.step() + optimizer1.step() + + check_equal(param, param1, atol=1e-2, verbose=True) + + +@pytest.mark.parametrize('model_size,vocabulary_size,dim', + [ + (16 * 2, 16 * 4, 16), + (16 * 32, 16 * 256, 16), + (16 * 256, 16 * 16384, 16), + ]) # yapf: disable +def test_cpu_adagrad_opt_sparse_embedding(model_size, vocabulary_size, dim): + device = 'cpu' + rng_state = torch.get_rng_state() + + def gen_sparse_grad(vocabulary_size, dim, num_indices, dtype, device): + i = torch.randint(vocabulary_size, + size=(1, + num_indices), + dtype=torch.int64, + device=device) + v = torch.randn(num_indices, dim, dtype=dtype, device=device) + t = torch.sparse_coo_tensor(i, v, (vocabulary_size, dim), device=device) + t = t.coalesce() + new_i = (t.indices().view(-1, + 1).repeat(1, + dim) * dim + + torch.tensor(range(dim))).flatten().unsqueeze(0) + new_v = t.values().flatten() + new_t = torch.sparse_coo_tensor(new_i, + new_v, + (vocabulary_size * dim, + ), + device=device) + new_t = new_t.coalesce() + new_t.requires_grad = False + return new_t + + voc_size = vocabulary_size + dim = dim + num_indices = int(model_size // dim) + dtype = torch.float32 + + param = torch.nn.Parameter(torch.randn((voc_size * dim, + ), + dtype=dtype, + device=device), + requires_grad=True) + torch.set_rng_state(rng_state) + param1 = torch.nn.Parameter(torch.randn((voc_size * dim, + ), + dtype=dtype, + device=device), + requires_grad=True) + torch.set_rng_state(rng_state) + + optimizer = DeepSpeedCPUAdagrad([param]) + optimizer1 = torch.optim.Adagrad([param1]) + + for i in range(10): + torch.set_rng_state(rng_state) + param.grad = gen_sparse_grad(voc_size, + dim, + num_indices, + dtype=dtype, + device=device) + torch.set_rng_state(rng_state) + param1.grad = gen_sparse_grad(voc_size, + dim, + num_indices, + dtype=dtype, + device=device) + optimizer.step() + optimizer1.step() + + check_equal(param, param1, atol=1e-2, verbose=True) diff --git a/tests/unit/test_cpu_adam.py b/tests/unit/test_cpu_adam.py index dd5527b01371..94453c7e8265 100755 --- a/tests/unit/test_cpu_adam.py +++ b/tests/unit/test_cpu_adam.py @@ -1,62 +1,62 @@ -import argparse -import torch -import time -import numpy as np -import pytest -import copy - -import deepspeed -from deepspeed.ops.adam import FusedAdam -from deepspeed.ops.op_builder import CPUAdamBuilder - -if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: - pytest.skip("cpu-adam is not compatible") - - -def check_equal(first, second, atol=1e-2, verbose=False): - x = first.detach().numpy() - y = second.detach().numpy() - if verbose: - print("x = {}".format(x.flatten())) - print("y = {}".format(y.flatten())) - print('-' * 80) - np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) - -@pytest.mark.parametrize('model_size', - [ - (64), - (22), - (55), - (127), - (1024), - (1048576), - ]) # yapf: disable -def test_cpu_adam_opt(model_size): - from deepspeed.ops.adam import DeepSpeedCPUAdam - device = 'cpu' - rng_state = torch.get_rng_state() - param = torch.nn.Parameter(torch.randn(model_size, device=device)) - torch.set_rng_state(rng_state) - param1 = torch.nn.Parameter(torch.randn(model_size, device=device)) - torch.set_rng_state(rng_state) - param2_data = torch.randn(model_size, device=device).cuda() - param2 = torch.nn.Parameter(param2_data) - - optimizer1 = torch.optim.AdamW([param1]) - optimizer2 = FusedAdam([param2]) - optimizer = DeepSpeedCPUAdam([param]) - - for i in range(10): - rng_state = torch.get_rng_state() - param.grad = torch.randn(model_size, device=device) - torch.set_rng_state(rng_state) - param1.grad = torch.randn(model_size, device=device) - torch.set_rng_state(rng_state) - param2.grad = torch.randn(model_size, device=device).cuda() - - optimizer.step() - optimizer2.step() - optimizer1.step() - - check_equal(param, param1, atol=1e-2, verbose=True) - check_equal(param, param2.cpu(), atol=1e-2, verbose=True) +import argparse +import torch +import time +import numpy as np +import pytest +import copy + +import deepspeed +from deepspeed.ops.adam import FusedAdam +from deepspeed.ops.op_builder import CPUAdamBuilder + +if not deepspeed.ops.__compatible_ops__[CPUAdamBuilder.NAME]: + pytest.skip("cpu-adam is not compatible") + + +def check_equal(first, second, atol=1e-2, verbose=False): + x = first.detach().numpy() + y = second.detach().numpy() + if verbose: + print("x = {}".format(x.flatten())) + print("y = {}".format(y.flatten())) + print('-' * 80) + np.testing.assert_allclose(x, y, err_msg="param-update mismatch!", atol=atol) + +@pytest.mark.parametrize('model_size', + [ + (64), + (22), + (55), + (127), + (1024), + (1048576), + ]) # yapf: disable +def test_cpu_adam_opt(model_size): + from deepspeed.ops.adam import DeepSpeedCPUAdam + device = 'cpu' + rng_state = torch.get_rng_state() + param = torch.nn.Parameter(torch.randn(model_size, device=device)) + torch.set_rng_state(rng_state) + param1 = torch.nn.Parameter(torch.randn(model_size, device=device)) + torch.set_rng_state(rng_state) + param2_data = torch.randn(model_size, device=device).cuda() + param2 = torch.nn.Parameter(param2_data) + + optimizer1 = torch.optim.AdamW([param1]) + optimizer2 = FusedAdam([param2]) + optimizer = DeepSpeedCPUAdam([param]) + + for i in range(10): + rng_state = torch.get_rng_state() + param.grad = torch.randn(model_size, device=device) + torch.set_rng_state(rng_state) + param1.grad = torch.randn(model_size, device=device) + torch.set_rng_state(rng_state) + param2.grad = torch.randn(model_size, device=device).cuda() + + optimizer.step() + optimizer2.step() + optimizer1.step() + + check_equal(param, param1, atol=1e-2, verbose=True) + check_equal(param, param2.cpu(), atol=1e-2, verbose=True) diff --git a/tests/unit/test_onebit.py b/tests/unit/test_onebit.py index c07f79ba7b2d..d428826d3a7a 100644 --- a/tests/unit/test_onebit.py +++ b/tests/unit/test_onebit.py @@ -1,920 +1,920 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.distributed as dist -import deepspeed -import argparse -import pytest -import copy -import json -import os -import numpy as np -import time - -from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology -PipeTopo = PipeDataParallelTopology -from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec -from .common import distributed_test -from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args -from .test_pipe import AlexNetPipe, train_cifar - -TORCH_MAJOR = int(torch.__version__.split('.')[0]) -TORCH_MINOR = int(torch.__version__.split('.')[1]) -if TORCH_MAJOR < 1 or TORCH_MINOR < 8: - pytest.skip("NCCL-based 1-bit compression requires torch 1.8 or higher", - allow_module_level=True) - - -def test_onebitadam_fp16_basic(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[1, 2]) - def _test_onebitadam_fp16_basic(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - _test_onebitadam_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitadam_fp32_basic(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[1, 2]) - def _test_onebitadam_fp32_basic(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device, - dtype=torch.float) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - _test_onebitadam_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitadam_exp_avg_mask(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - param_optimizer = list(model.named_parameters()) - mask1 = torch.zeros_like(param_optimizer[0][1].data) - for col in range(mask1.size()[1]): - mask1[0][col] += 1 - mask1 = torch.flatten(mask1) - optimizer_grouped_parameters = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask1 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - @distributed_test(world_size=[2]) - def _test_onebitadam_exp_avg_mask(args, model, hidden_dim): - model, optimizer, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - # Test whether the momentum mask works - for v in optimizer.state.values(): - if v['exp_avg'].size() == mask1.size(): - assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" - - _test_onebitadam_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitadam_checkpointing(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - param_optimizer = list(model.named_parameters()) - mask1 = torch.zeros_like(param_optimizer[0][1].data) - mask2 = torch.zeros_like(param_optimizer[0][1].data) - for col in range(mask1.size()[1]): - mask1[0][col] += 1 - mask2[1][col] += 1 - mask1 = torch.flatten(mask1) - mask2 = torch.flatten(mask2) - - optimizer_grouped_parameters_1 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask1 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - optimizer_grouped_parameters_2 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask2 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - optimizer_grouped_parameters_3 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - @distributed_test(world_size=[2]) - def _test_onebitadam_checkpointing(mask1, mask2, args, model, hidden_dim): - model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_1) - data_loader = random_dataloader(model=model_1, - total_samples=10, - hidden_dim=hidden_dim, - device=model_1.device) - for n, batch in enumerate(data_loader): - loss = model_1(batch[0], batch[1]) - model_1.backward(loss) - model_1.step() - # Test whether momentum mask still exist after saving checkpoint - assert optimizer_1.optimizer.adam_freeze_key is True - mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) - assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" - save_folder = os.path.join(tmpdir, 'saved_checkpoint') - model_1.save_checkpoint(save_folder, tag=None) - assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" - - - model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_2) - # Test whether momentum mask stays the same after loading checkpoint - mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) - assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" - model_2.load_checkpoint(save_folder, - tag=None, - load_optimizer_states=True, - load_lr_scheduler_states=True) - assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" - # Test whether worker&server error is resetted - for v in optimizer_2.state.values(): - assert 'worker_error' not in v, f"Incorrect worker error" - assert 'server_error' not in v, f"Incorrect server error" - assert optimizer_2.optimizer.adam_freeze_key is True - - model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_3) - optimizer_3.optimizer.freeze_step = 20 - data_loader = random_dataloader(model=model_3, - total_samples=50, - hidden_dim=hidden_dim, - device=model_3.device) - for n, batch in enumerate(data_loader): - loss = model_3(batch[0], batch[1]) - model_3.backward(loss) - model_3.step() - assert optimizer_3.optimizer.adam_freeze_key is True - # Test whether momentum mask stays the same after loading checkpoint - assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" - model_3.load_checkpoint(save_folder, - tag=None, - load_optimizer_states=True, - load_lr_scheduler_states=True) - assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" - # Test whether worker&server error is resetted - for v in optimizer_3.state.values(): - assert 'worker_error' not in v, f"Incorrect worker error" - assert 'server_error' not in v, f"Incorrect server error" - assert optimizer_3.optimizer.adam_freeze_key is False - - _test_onebitadam_checkpointing(mask1, - mask2, - args=args, - model=model, - hidden_dim=hidden_dim) - - -def test_onebitadam_checkpointing_overflow(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[2]) - def _test_onebitadam_checkpointing_overflow(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=100, - hidden_dim=hidden_dim, - device=model.device) - save_folder = os.path.join(tmpdir, 'saved_checkpoint') - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - if dist.get_rank() == 0 and n >= 10: - loss = loss * 1000000.0 - model.backward(loss) - dist.barrier() - model.step() - dist.barrier() - model.save_checkpoint(save_folder, tag=None) - - _test_onebitadam_checkpointing_overflow(args=args, - model=model, - hidden_dim=hidden_dim) - - -@pytest.mark.parametrize('topo', - [ - PipeTopo(num_pp=1, - num_dp=4), - PipeTopo(num_pp=2, - num_dp=2), - PipeTopo(num_pp=4, - num_dp=1), - ]) -def test_onebitadam_fp16_pipeline(topo, tmpdir): - config_dict = { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, - "steps_per_print": 20, - "optimizer": { - "type": "OneBitAdam", - "params": { - "lr": 0.00001, - "betas": [0.9, - 0.999], - "eps": 1e-8, - "weight_decay": 3e-7, - "freeze_step": 200, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "zero_optimization": { - "stage": 0 - }, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, - "pipeline": { - "seed_layers": True, - "activation_checkpoint_interval": 1 - } - } - args = args_from_dict(tmpdir, config_dict) - - # Allocate model for consistent initial weights. - init_net = AlexNetPipe() - - @distributed_test(world_size=4) - def _helper(topo, tmpdir, steps=500): - assert steps >= 100 - - test_net = copy.deepcopy(init_net) - test_model = PipelineModule(layers=test_net.to_layers(), - topology=topo, - loss_fn=nn.CrossEntropyLoss()) - - test_losses = train_cifar(test_model, - args, - num_steps=steps, - fp16=config_dict['fp16']['enabled']) - - _helper(topo, tmpdir) - - -def test_onebitlamb_fp16_basic(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[1, 2]) - def _test_onebitlamb_fp16_basic(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - _test_onebitlamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitlamb_fp32_basic(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[1, 2]) - def _test_onebitlamb_fp32_basic(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device, - dtype=torch.float) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - _test_onebitlamb_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitlamb_exp_avg_mask(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - param_optimizer = list(model.named_parameters()) - mask1 = torch.zeros_like(param_optimizer[0][1].data) - for col in range(mask1.size()[1]): - mask1[0][col] += 1 - optimizer_grouped_parameters = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask1 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - @distributed_test(world_size=[2]) - def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim): - model, optimizer, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters) - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - # Test whether the momentum mask works - for v in optimizer.state.values(): - if v['exp_avg'].size() == mask1.size(): - assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" - - _test_onebitlamb_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) - - -def test_onebitlamb_checkpointing(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - param_optimizer = list(model.named_parameters()) - mask1 = torch.zeros_like(param_optimizer[0][1].data) - mask2 = torch.zeros_like(param_optimizer[0][1].data) - for col in range(mask1.size()[1]): - mask1[0][col] += 1 - mask2[1][col] += 1 - - optimizer_grouped_parameters_1 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask1 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - optimizer_grouped_parameters_2 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01, - 'exp_avg_mask': mask2 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - optimizer_grouped_parameters_3 = [{ - 'params': [param_optimizer[0][1]], - 'weight_decay': 0.01 - }, - { - 'params': [param_optimizer[1][1]], - 'weight_decay': 0.01 - }] - - @distributed_test(world_size=[2]) - def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim): - model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_1) - data_loader = random_dataloader(model=model_1, - total_samples=10, - hidden_dim=hidden_dim, - device=model_1.device) - for n, batch in enumerate(data_loader): - loss = model_1(batch[0], batch[1]) - model_1.backward(loss) - model_1.step() - # Test whether momentum mask still exist after saving checkpoint - assert optimizer_1.optimizer.lamb_freeze_key is True - mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) - assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" - scaling_coeff_1 = [] - for v in optimizer_1.state.values(): - assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" - scaling_coeff_1.append(v['scaling_coeff']) - save_folder = os.path.join(tmpdir, 'saved_checkpoint') - model_1.save_checkpoint(save_folder, tag=None) - assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" - - - model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_2) - # Test whether momentum mask stays the same after loading checkpoint - mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) - assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" - model_2.load_checkpoint(save_folder, - tag=None, - load_optimizer_states=True, - load_lr_scheduler_states=True) - assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" - # Test whether worker&server error is resetted - assert len(optimizer_2.optimizer.worker_errors) == 0, f"Incorrect worker error" - assert len(optimizer_2.optimizer.server_errors) == 0, f"Incorrect server error" - # Test whether scaling_coeffs is loaded correctly - scaling_coeff_2 = [] - for v in optimizer_2.state.values(): - assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" - scaling_coeff_2.append(v['scaling_coeff']) - assert list(sorted(scaling_coeff_2)) == list(sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs" - assert optimizer_2.optimizer.lamb_freeze_key is True - - model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=optimizer_grouped_parameters_3) - optimizer_3.optimizer.freeze_step = 20 - data_loader = random_dataloader(model=model_3, - total_samples=50, - hidden_dim=hidden_dim, - device=model_3.device) - for n, batch in enumerate(data_loader): - loss = model_3(batch[0], batch[1]) - model_3.backward(loss) - model_3.step() - assert optimizer_3.optimizer.lamb_freeze_key is True - # Test whether momentum mask stays the same after loading checkpoint - assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" - model_3.load_checkpoint(save_folder, - tag=None, - load_optimizer_states=True, - load_lr_scheduler_states=True) - assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" - # Test whether worker&server error is resetted - assert len(optimizer_3.optimizer.worker_errors) == 0, f"Incorrect worker error" - assert len(optimizer_3.optimizer.server_errors) == 0, f"Incorrect server error" - # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are resetted - for v in optimizer_3.state.values(): - assert v['lamb_coeff_freeze'] == 0.0, f"Incorrect lamb_coeff_freeze" - assert v['last_factor'] == 1.0, f"Incorrect last_factor" - assert 'scaling_coeff' not in v, f"Incorrect scaling_coeff" - assert optimizer_3.optimizer.lamb_freeze_key is False - - _test_onebitlamb_checkpointing(mask1, - mask2, - args=args, - model=model, - hidden_dim=hidden_dim) - - -def test_onebitlamb_checkpointing_overflow(tmpdir): - config_dict = { - "train_batch_size": 2, - "steps_per_print": 1, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00015, - "weight_decay": 0.01, - "max_coeff": 0.3, - "min_coeff": 0.01, - "freeze_step": 2, - "cuda_aware": False, - "comm_backend_name": "nccl", - "coeff_beta": 0.9, - "factor_max": 1.0, - "factor_min": 0.5, - "factor_threshold": 0.1 - } - }, - "gradient_clipping": 1.0, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - } - } - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim) - - @distributed_test(world_size=[2]) - def _test_onebitlamb_checkpointing_overflow(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - data_loader = random_dataloader(model=model, - total_samples=100, - hidden_dim=hidden_dim, - device=model.device) - save_folder = os.path.join(tmpdir, 'saved_checkpoint') - for n, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - if dist.get_rank() == 0 and n >= 10: - loss = loss * 1000000.0 - model.backward(loss) - dist.barrier() - model.step() - dist.barrier() - model.save_checkpoint(save_folder, tag=None) - - _test_onebitlamb_checkpointing_overflow(args=args, - model=model, - hidden_dim=hidden_dim) - - -@pytest.mark.parametrize('topo', - [ - PipeTopo(num_pp=1, - num_dp=4), - PipeTopo(num_pp=2, - num_dp=2), - PipeTopo(num_pp=4, - num_dp=1), - ]) -def test_onebitlamb_fp16_pipeline(topo, tmpdir): - config_dict = { - "train_batch_size": 16, - "train_micro_batch_size_per_gpu": 4, - "steps_per_print": 20, - "optimizer": { - "type": "OneBitLamb", - "params": { - "lr": 0.00001, - "betas": [0.9, - 0.999], - "eps": 1e-8, - "weight_decay": 3e-7, - "freeze_step": 200, - "cuda_aware": False, - "comm_backend_name": "nccl" - } - }, - "gradient_clipping": 1.0, - "zero_optimization": { - "stage": 0 - }, - "fp16": { - "enabled": True, - "loss_scale": 0, - "initial_scale_power": 16 - }, - "pipeline": { - "seed_layers": True, - "activation_checkpoint_interval": 1 - } - } - args = args_from_dict(tmpdir, config_dict) - - # Allocate model for consistent initial weights. - init_net = AlexNetPipe() - - @distributed_test(world_size=4) - def _helper(topo, tmpdir, steps=500): - assert steps >= 100 - - test_net = copy.deepcopy(init_net) - test_model = PipelineModule(layers=test_net.to_layers(), - topology=topo, - loss_fn=nn.CrossEntropyLoss()) - - test_losses = train_cifar(test_model, - args, - num_steps=steps, - fp16=config_dict['fp16']['enabled']) - - _helper(topo, tmpdir) - - -def test_compressed_allreduce_basic(tmpdir): - @distributed_test(world_size=[1, 2]) - def _test_compressed_allreduce_basic(): - from deepspeed.runtime.comm.nccl import NcclBackend - size = dist.get_world_size() - rank = dist.get_rank() - backend = NcclBackend() - local_rank = dist.get_rank() - device = torch.device("cuda", dist.get_rank()) - - # A simulated compression function using torch.distributed - def torch_sim(a): - a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) - scale = a.norm() / np.sqrt(a.numel()) - a_compressed = scale * a_sign - a_sign = None - worker_error = a - a_compressed - dist.all_reduce(a_compressed) - a_compressed.mul_(1 / dist.get_world_size()) - a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_( - 2.0) - a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) - server_scale = [ - chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list - ] - a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) - a_server_compressed = torch.cat( - [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) - rank = dist.get_rank() - server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] - torch.cuda.synchronize() - torch.distributed.barrier() - return a_server_compressed, worker_error, server_error - - tensor_size = 300 * 2**20 - server_size = int(tensor_size / size) - if tensor_size % (8 * size) != 0: - right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) - else: - right_tensor_size = tensor_size - right_server_size = right_tensor_size // size - - # Adding bias to the initialization of the gradient we are communicating - # In order to get rid of the case where some elements in the gradient are too small - a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank - - worker_error = torch.zeros(right_tensor_size, device=device) - server_error = torch.zeros(right_server_size, device=device) - - a_torch, worker_error_torch, server_error_torch = torch_sim(a) - torch.cuda.empty_cache() - - a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank) - - threshold = 1e-6 - magnitude_threshold = 1e-6 - diff_mask = (a_after - a_torch) > threshold - diff_server_mask = torch.chunk(diff_mask, size)[rank] - mpi_server = torch.chunk(a_after, size)[rank] + server_error - torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch - - # If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic - # The test would skip those numbers that are too small in compensated_server_m - check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold - if torch.sum(check_mag_mask) != 0: - print('Fails at {} of positions'.format(torch.sum(check_mag_mask))) - assert torch.sum(diff_server_mask) == 0 or torch.sum(check_mag_mask) == 0 - - _test_compressed_allreduce_basic() +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +import deepspeed +import argparse +import pytest +import copy +import json +import os +import numpy as np +import time + +from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology +PipeTopo = PipeDataParallelTopology +from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec +from .common import distributed_test +from .simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict, create_deepspeed_args +from .test_pipe import AlexNetPipe, train_cifar + +TORCH_MAJOR = int(torch.__version__.split('.')[0]) +TORCH_MINOR = int(torch.__version__.split('.')[1]) +if TORCH_MAJOR < 1 or TORCH_MINOR < 8: + pytest.skip("NCCL-based 1-bit compression requires torch 1.8 or higher", + allow_module_level=True) + + +def test_onebitadam_fp16_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitadam_fp16_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitadam_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitadam_fp32_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitadam_fp32_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitadam_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitadam_exp_avg_mask(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + mask1 = torch.flatten(mask1) + optimizer_grouped_parameters = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitadam_exp_avg_mask(args, model, hidden_dim): + model, optimizer, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + # Test whether the momentum mask works + for v in optimizer.state.values(): + if v['exp_avg'].size() == mask1.size(): + assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" + + _test_onebitadam_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitadam_checkpointing(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + mask2 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + mask2[1][col] += 1 + mask1 = torch.flatten(mask1) + mask2 = torch.flatten(mask2) + + optimizer_grouped_parameters_1 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_2 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask2 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_3 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitadam_checkpointing(mask1, mask2, args, model, hidden_dim): + model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_1) + data_loader = random_dataloader(model=model_1, + total_samples=10, + hidden_dim=hidden_dim, + device=model_1.device) + for n, batch in enumerate(data_loader): + loss = model_1(batch[0], batch[1]) + model_1.backward(loss) + model_1.step() + # Test whether momentum mask still exist after saving checkpoint + assert optimizer_1.optimizer.adam_freeze_key is True + mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + model_1.save_checkpoint(save_folder, tag=None) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" + + + model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_2) + # Test whether momentum mask stays the same after loading checkpoint + mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" + model_2.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + for v in optimizer_2.state.values(): + assert 'worker_error' not in v, f"Incorrect worker error" + assert 'server_error' not in v, f"Incorrect server error" + assert optimizer_2.optimizer.adam_freeze_key is True + + model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_3) + optimizer_3.optimizer.freeze_step = 20 + data_loader = random_dataloader(model=model_3, + total_samples=50, + hidden_dim=hidden_dim, + device=model_3.device) + for n, batch in enumerate(data_loader): + loss = model_3(batch[0], batch[1]) + model_3.backward(loss) + model_3.step() + assert optimizer_3.optimizer.adam_freeze_key is True + # Test whether momentum mask stays the same after loading checkpoint + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" + model_3.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + for v in optimizer_3.state.values(): + assert 'worker_error' not in v, f"Incorrect worker error" + assert 'server_error' not in v, f"Incorrect server error" + assert optimizer_3.optimizer.adam_freeze_key is False + + _test_onebitadam_checkpointing(mask1, + mask2, + args=args, + model=model, + hidden_dim=hidden_dim) + + +def test_onebitadam_checkpointing_overflow(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[2]) + def _test_onebitadam_checkpointing_overflow(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=100, + hidden_dim=hidden_dim, + device=model.device) + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0 and n >= 10: + loss = loss * 1000000.0 + model.backward(loss) + dist.barrier() + model.step() + dist.barrier() + model.save_checkpoint(save_folder, tag=None) + + _test_onebitadam_checkpointing_overflow(args=args, + model=model, + hidden_dim=hidden_dim) + + +@pytest.mark.parametrize('topo', + [ + PipeTopo(num_pp=1, + num_dp=4), + PipeTopo(num_pp=2, + num_dp=2), + PipeTopo(num_pp=4, + num_dp=1), + ]) +def test_onebitadam_fp16_pipeline(topo, tmpdir): + config_dict = { + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 4, + "steps_per_print": 20, + "optimizer": { + "type": "OneBitAdam", + "params": { + "lr": 0.00001, + "betas": [0.9, + 0.999], + "eps": 1e-8, + "weight_decay": 3e-7, + "freeze_step": 200, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 0 + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + }, + "pipeline": { + "seed_layers": True, + "activation_checkpoint_interval": 1 + } + } + args = args_from_dict(tmpdir, config_dict) + + # Allocate model for consistent initial weights. + init_net = AlexNetPipe() + + @distributed_test(world_size=4) + def _helper(topo, tmpdir, steps=500): + assert steps >= 100 + + test_net = copy.deepcopy(init_net) + test_model = PipelineModule(layers=test_net.to_layers(), + topology=topo, + loss_fn=nn.CrossEntropyLoss()) + + test_losses = train_cifar(test_model, + args, + num_steps=steps, + fp16=config_dict['fp16']['enabled']) + + _helper(topo, tmpdir) + + +def test_onebitlamb_fp16_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitlamb_fp16_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitlamb_fp16_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitlamb_fp32_basic(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[1, 2]) + def _test_onebitlamb_fp32_basic(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device, + dtype=torch.float) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + _test_onebitlamb_fp32_basic(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitlamb_exp_avg_mask(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + optimizer_grouped_parameters = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitlamb_exp_avg_mask(args, model, hidden_dim): + model, optimizer, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters) + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + # Test whether the momentum mask works + for v in optimizer.state.values(): + if v['exp_avg'].size() == mask1.size(): + assert torch.allclose(v['exp_avg'], v['exp_avg'].mul_(mask1.to(device=v['exp_avg'].device)), atol=1e-07), f"Momentum mask is not working properly" + + _test_onebitlamb_exp_avg_mask(args=args, model=model, hidden_dim=hidden_dim) + + +def test_onebitlamb_checkpointing(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + param_optimizer = list(model.named_parameters()) + mask1 = torch.zeros_like(param_optimizer[0][1].data) + mask2 = torch.zeros_like(param_optimizer[0][1].data) + for col in range(mask1.size()[1]): + mask1[0][col] += 1 + mask2[1][col] += 1 + + optimizer_grouped_parameters_1 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask1 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_2 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01, + 'exp_avg_mask': mask2 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + optimizer_grouped_parameters_3 = [{ + 'params': [param_optimizer[0][1]], + 'weight_decay': 0.01 + }, + { + 'params': [param_optimizer[1][1]], + 'weight_decay': 0.01 + }] + + @distributed_test(world_size=[2]) + def _test_onebitlamb_checkpointing(mask1, mask2, args, model, hidden_dim): + model_1, optimizer_1, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_1) + data_loader = random_dataloader(model=model_1, + total_samples=10, + hidden_dim=hidden_dim, + device=model_1.device) + for n, batch in enumerate(data_loader): + loss = model_1(batch[0], batch[1]) + model_1.backward(loss) + model_1.step() + # Test whether momentum mask still exist after saving checkpoint + assert optimizer_1.optimizer.lamb_freeze_key is True + mask1 = mask1.to(device=optimizer_1.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Incorrect momentum mask" + scaling_coeff_1 = [] + for v in optimizer_1.state.values(): + assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" + scaling_coeff_1.append(v['scaling_coeff']) + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + model_1.save_checkpoint(save_folder, tag=None) + assert torch.allclose(optimizer_1.param_groups[0]['exp_avg_mask'], mask1, atol=1e-07), f"Momentum mask should not change after saving checkpoint" + + + model_2, optimizer_2, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_2) + # Test whether momentum mask stays the same after loading checkpoint + mask2 = mask2.to(device=optimizer_2.param_groups[0]['exp_avg_mask'].device) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Incorrect momentum mask" + model_2.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert torch.allclose(optimizer_2.param_groups[0]['exp_avg_mask'], mask2, atol=1e-07), f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + assert len(optimizer_2.optimizer.worker_errors) == 0, f"Incorrect worker error" + assert len(optimizer_2.optimizer.server_errors) == 0, f"Incorrect server error" + # Test whether scaling_coeffs is loaded correctly + scaling_coeff_2 = [] + for v in optimizer_2.state.values(): + assert 'scaling_coeff' in v, f"Incorrect scaling_coeff" + scaling_coeff_2.append(v['scaling_coeff']) + assert list(sorted(scaling_coeff_2)) == list(sorted(scaling_coeff_1)), f"Incorrect scaling_coeffs" + assert optimizer_2.optimizer.lamb_freeze_key is True + + model_3, optimizer_3, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=optimizer_grouped_parameters_3) + optimizer_3.optimizer.freeze_step = 20 + data_loader = random_dataloader(model=model_3, + total_samples=50, + hidden_dim=hidden_dim, + device=model_3.device) + for n, batch in enumerate(data_loader): + loss = model_3(batch[0], batch[1]) + model_3.backward(loss) + model_3.step() + assert optimizer_3.optimizer.lamb_freeze_key is True + # Test whether momentum mask stays the same after loading checkpoint + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Incorrect momentum mask" + model_3.load_checkpoint(save_folder, + tag=None, + load_optimizer_states=True, + load_lr_scheduler_states=True) + assert 'exp_avg_mask' not in optimizer_3.param_groups[0], f"Momentum mask should not change after loading checkpoint" + # Test whether worker&server error is resetted + assert len(optimizer_3.optimizer.worker_errors) == 0, f"Incorrect worker error" + assert len(optimizer_3.optimizer.server_errors) == 0, f"Incorrect server error" + # Test whether scaling_coeffs, lamb_coeff_freeze, last_factor are resetted + for v in optimizer_3.state.values(): + assert v['lamb_coeff_freeze'] == 0.0, f"Incorrect lamb_coeff_freeze" + assert v['last_factor'] == 1.0, f"Incorrect last_factor" + assert 'scaling_coeff' not in v, f"Incorrect scaling_coeff" + assert optimizer_3.optimizer.lamb_freeze_key is False + + _test_onebitlamb_checkpointing(mask1, + mask2, + args=args, + model=model, + hidden_dim=hidden_dim) + + +def test_onebitlamb_checkpointing_overflow(tmpdir): + config_dict = { + "train_batch_size": 2, + "steps_per_print": 1, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00015, + "weight_decay": 0.01, + "max_coeff": 0.3, + "min_coeff": 0.01, + "freeze_step": 2, + "cuda_aware": False, + "comm_backend_name": "nccl", + "coeff_beta": 0.9, + "factor_max": 1.0, + "factor_min": 0.5, + "factor_threshold": 0.1 + } + }, + "gradient_clipping": 1.0, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + } + } + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim) + + @distributed_test(world_size=[2]) + def _test_onebitlamb_checkpointing_overflow(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + data_loader = random_dataloader(model=model, + total_samples=100, + hidden_dim=hidden_dim, + device=model.device) + save_folder = os.path.join(tmpdir, 'saved_checkpoint') + for n, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + if dist.get_rank() == 0 and n >= 10: + loss = loss * 1000000.0 + model.backward(loss) + dist.barrier() + model.step() + dist.barrier() + model.save_checkpoint(save_folder, tag=None) + + _test_onebitlamb_checkpointing_overflow(args=args, + model=model, + hidden_dim=hidden_dim) + + +@pytest.mark.parametrize('topo', + [ + PipeTopo(num_pp=1, + num_dp=4), + PipeTopo(num_pp=2, + num_dp=2), + PipeTopo(num_pp=4, + num_dp=1), + ]) +def test_onebitlamb_fp16_pipeline(topo, tmpdir): + config_dict = { + "train_batch_size": 16, + "train_micro_batch_size_per_gpu": 4, + "steps_per_print": 20, + "optimizer": { + "type": "OneBitLamb", + "params": { + "lr": 0.00001, + "betas": [0.9, + 0.999], + "eps": 1e-8, + "weight_decay": 3e-7, + "freeze_step": 200, + "cuda_aware": False, + "comm_backend_name": "nccl" + } + }, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": 0 + }, + "fp16": { + "enabled": True, + "loss_scale": 0, + "initial_scale_power": 16 + }, + "pipeline": { + "seed_layers": True, + "activation_checkpoint_interval": 1 + } + } + args = args_from_dict(tmpdir, config_dict) + + # Allocate model for consistent initial weights. + init_net = AlexNetPipe() + + @distributed_test(world_size=4) + def _helper(topo, tmpdir, steps=500): + assert steps >= 100 + + test_net = copy.deepcopy(init_net) + test_model = PipelineModule(layers=test_net.to_layers(), + topology=topo, + loss_fn=nn.CrossEntropyLoss()) + + test_losses = train_cifar(test_model, + args, + num_steps=steps, + fp16=config_dict['fp16']['enabled']) + + _helper(topo, tmpdir) + + +def test_compressed_allreduce_basic(tmpdir): + @distributed_test(world_size=[1, 2]) + def _test_compressed_allreduce_basic(): + from deepspeed.runtime.comm.nccl import NcclBackend + size = dist.get_world_size() + rank = dist.get_rank() + backend = NcclBackend() + local_rank = dist.get_rank() + device = torch.device("cuda", dist.get_rank()) + + # A simulated compression function using torch.distributed + def torch_sim(a): + a_sign = a.sign().add_(1).bool().float().add_(-0.5).mul_(2.0) + scale = a.norm() / np.sqrt(a.numel()) + a_compressed = scale * a_sign + a_sign = None + worker_error = a - a_compressed + dist.all_reduce(a_compressed) + a_compressed.mul_(1 / dist.get_world_size()) + a_server_sign = a_compressed.sign().add_(1).bool().float().add_(-0.5).mul_( + 2.0) + a_list = torch.chunk(a_compressed, chunks=dist.get_world_size()) + server_scale = [ + chunk_a.norm() / np.sqrt(chunk_a.numel()) for chunk_a in a_list + ] + a_sign_list = torch.chunk(a_server_sign, dist.get_world_size()) + a_server_compressed = torch.cat( + [server_scale[i] * a_sign_list[i] for i in range(dist.get_world_size())]) + rank = dist.get_rank() + server_error = a_list[rank] - server_scale[rank] * a_sign_list[rank] + torch.cuda.synchronize() + torch.distributed.barrier() + return a_server_compressed, worker_error, server_error + + tensor_size = 300 * 2**20 + server_size = int(tensor_size / size) + if tensor_size % (8 * size) != 0: + right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) + else: + right_tensor_size = tensor_size + right_server_size = right_tensor_size // size + + # Adding bias to the initialization of the gradient we are communicating + # In order to get rid of the case where some elements in the gradient are too small + a = (torch.rand(tensor_size, device=device) - 0.5) + 0.01 * rank + + worker_error = torch.zeros(right_tensor_size, device=device) + server_error = torch.zeros(right_server_size, device=device) + + a_torch, worker_error_torch, server_error_torch = torch_sim(a) + torch.cuda.empty_cache() + + a_after = backend.compressed_allreduce(a, worker_error, server_error, local_rank) + + threshold = 1e-6 + magnitude_threshold = 1e-6 + diff_mask = (a_after - a_torch) > threshold + diff_server_mask = torch.chunk(diff_mask, size)[rank] + mpi_server = torch.chunk(a_after, size)[rank] + server_error + torch_server = torch.chunk(a_torch, size)[rank] + server_error_torch + + # If the number in the compensated_server_m is too small (e.g 1e-8), then calling sign() might be problematic + # The test would skip those numbers that are too small in compensated_server_m + check_mag_mask = mpi_server[diff_server_mask] > magnitude_threshold + if torch.sum(check_mag_mask) != 0: + print('Fails at {} of positions'.format(torch.sum(check_mag_mask))) + assert torch.sum(diff_server_mask) == 0 or torch.sum(check_mag_mask) == 0 + + _test_compressed_allreduce_basic() diff --git a/tests/unit/test_pld.py b/tests/unit/test_pld.py index 0672da9177b1..2c6674620471 100755 --- a/tests/unit/test_pld.py +++ b/tests/unit/test_pld.py @@ -1,117 +1,117 @@ -import numpy as np -import deepspeed -import pytest -from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop -from .common import distributed_test -from .simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict - - -@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) -def test_pld_schedule(tmpdir, theta): - gamma = 0.001 - - pld_scheduler = ProgressiveLayerDrop(theta, gamma) - for i in range(10): - pld_scheduler.update_state(i) - expected_theta = (1. - theta) * np.exp(-gamma * i) + theta - actual_theta = pld_scheduler.get_theta() - assert expected_theta == actual_theta - - -@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) -def test_pld_model(tmpdir, theta): - gamma = 0.001 - config_dict = { - "train_batch_size": 1, - "steps_per_print": 1, - "optimizer": { - "type": 'Adam', - "params": { - "lr": 0.0001 - } - }, - "fp16": { - "enabled": True - }, - "progressive_layer_drop": { - "enabled": True, - "theta": theta, - "gamma": gamma - } - } - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = PLD_SimpleModel(hidden_dim, empty_grad=False) - - @distributed_test(world_size=[1]) - def _test_pld_model(args, model, hidden_dim, theta, gamma): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - - data_loader = random_dataloader(model=model, - total_samples=50, - hidden_dim=hidden_dim, - device=model.device) - - for i, batch in enumerate(data_loader): - loss = model(batch[0], batch[1]) - model.backward(loss) - model.step() - - expected_theta = (1. - theta) * np.exp(-gamma * i) + theta - actual_theta = model.get_pld_theta() - assert expected_theta == actual_theta - - _test_pld_model(args=args, - model=model, - hidden_dim=hidden_dim, - theta=theta, - gamma=gamma) - - -def test_non_pld_model(tmpdir): - gamma = 0.001 - theta = 0.5 - config_dict = { - "train_batch_size": 1, - "steps_per_print": 1, - "optimizer": { - "type": 'Adam', - "params": { - "lr": 0.0001 - } - }, - "fp16": { - "enabled": True - }, - "progressive_layer_drop": { - "enabled": True, - "theta": theta, - "gamma": gamma - } - } - - args = args_from_dict(tmpdir, config_dict) - hidden_dim = 10 - - model = SimpleModel(hidden_dim, empty_grad=False) - - @distributed_test(world_size=[1]) - def _test_non_pld_model(args, model, hidden_dim): - model, _, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=model.parameters()) - - data_loader = random_dataloader(model=model, - total_samples=1, - hidden_dim=hidden_dim, - device=model.device) - - for i, batch in enumerate(data_loader): - with pytest.raises(TypeError): - loss = model(batch[0], batch[1]) - - _test_non_pld_model(args=args, model=model, hidden_dim=hidden_dim) +import numpy as np +import deepspeed +import pytest +from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop +from .common import distributed_test +from .simple_model import SimpleModel, PLD_SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict + + +@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) +def test_pld_schedule(tmpdir, theta): + gamma = 0.001 + + pld_scheduler = ProgressiveLayerDrop(theta, gamma) + for i in range(10): + pld_scheduler.update_state(i) + expected_theta = (1. - theta) * np.exp(-gamma * i) + theta + actual_theta = pld_scheduler.get_theta() + assert expected_theta == actual_theta + + +@pytest.mark.parametrize('theta', [0, 0.1, 0.9, 1.0]) +def test_pld_model(tmpdir, theta): + gamma = 0.001 + config_dict = { + "train_batch_size": 1, + "steps_per_print": 1, + "optimizer": { + "type": 'Adam', + "params": { + "lr": 0.0001 + } + }, + "fp16": { + "enabled": True + }, + "progressive_layer_drop": { + "enabled": True, + "theta": theta, + "gamma": gamma + } + } + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = PLD_SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1]) + def _test_pld_model(args, model, hidden_dim, theta, gamma): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + data_loader = random_dataloader(model=model, + total_samples=50, + hidden_dim=hidden_dim, + device=model.device) + + for i, batch in enumerate(data_loader): + loss = model(batch[0], batch[1]) + model.backward(loss) + model.step() + + expected_theta = (1. - theta) * np.exp(-gamma * i) + theta + actual_theta = model.get_pld_theta() + assert expected_theta == actual_theta + + _test_pld_model(args=args, + model=model, + hidden_dim=hidden_dim, + theta=theta, + gamma=gamma) + + +def test_non_pld_model(tmpdir): + gamma = 0.001 + theta = 0.5 + config_dict = { + "train_batch_size": 1, + "steps_per_print": 1, + "optimizer": { + "type": 'Adam', + "params": { + "lr": 0.0001 + } + }, + "fp16": { + "enabled": True + }, + "progressive_layer_drop": { + "enabled": True, + "theta": theta, + "gamma": gamma + } + } + + args = args_from_dict(tmpdir, config_dict) + hidden_dim = 10 + + model = SimpleModel(hidden_dim, empty_grad=False) + + @distributed_test(world_size=[1]) + def _test_non_pld_model(args, model, hidden_dim): + model, _, _, _ = deepspeed.initialize(args=args, + model=model, + model_parameters=model.parameters()) + + data_loader = random_dataloader(model=model, + total_samples=1, + hidden_dim=hidden_dim, + device=model.device) + + for i, batch in enumerate(data_loader): + with pytest.raises(TypeError): + loss = model(batch[0], batch[1]) + + _test_non_pld_model(args=args, model=model, hidden_dim=hidden_dim) From b998206eb35d8a02eb397ce7256154752895b76d Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Mon, 29 Nov 2021 16:40:40 -0800 Subject: [PATCH 42/59] minor merge fixes --- .pre-commit-config.yaml | 1 - docs/_tutorials/zero-offload.md | 1 - 2 files changed, 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 76f237aacdb7..d27dd0b41a5f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,6 @@ repos: exclude: "DeepSpeedExamples/" args: [--fix=lf] - - repo: https://github.com/pre-commit/mirrors-yapf rev: v0.29.0 hooks: diff --git a/docs/_tutorials/zero-offload.md b/docs/_tutorials/zero-offload.md index afc916d8fc33..404355090855 100644 --- a/docs/_tutorials/zero-offload.md +++ b/docs/_tutorials/zero-offload.md @@ -72,4 +72,3 @@ Finally, here is a screenshot of htop showing host CPU and memory activity durin Congratulations! You have completed the ZeRO-Offload tutorial. - From d6deecb3343601784d79f6aa593bd42547aaa7ba Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 30 Nov 2021 10:52:36 -0800 Subject: [PATCH 43/59] remove extra bfloat16_enabled definition --- deepspeed/runtime/engine.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index c500b30f4b00..e59782d5b0a8 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -729,9 +729,6 @@ def bfloat16_enabled(self): def fp16_master_weights_and_gradients(self): return self._config.fp16_master_weights_and_gradients - def bfloat16_enabled(self): - return self._config.bfloat16_enabled - def amp_enabled(self): return self._config.amp_enabled From 2a4ef29aff9d415eb41743159388590158f6ebe0 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 30 Nov 2021 10:55:22 -0800 Subject: [PATCH 44/59] asserting params inflight for AllGatherHandle --- deepspeed/runtime/zero/partition_parameters.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index b359ec2a9f10..03537262d112 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -481,6 +481,9 @@ def _set_dtype(self, ds_config, dtype): class AllGatherHandle: def __init__(self, handle, param: Parameter) -> None: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to be available") + self.__handle = handle self.__param = param From 90182b66e0158335eb2104eef578fce01369a642 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 30 Nov 2021 12:42:45 -0800 Subject: [PATCH 45/59] remove get_cuda_mem_allocated_str --- deepspeed/runtime/zero/stage3.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index d8f79f69a577..a92426f44ef8 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -70,14 +70,6 @@ def debug_rank0(message: str) -> None: logger.debug(message) -def get_cuda_mem_allocated_str() -> str: - # this is really slow. when enabled the python process becomes slow - # to the point where it can't keep the GPU fed with work, so only enable - # for memory debugging. - # return f"{torch.cuda.memory_allocated() / 1024 ** 3:.2f}GB" - return "xGB" - - def move_to_cpu(tensor_list): for tensor in tensor_list: tensor.data = tensor.data.cpu() @@ -308,7 +300,6 @@ def fetch_sub_module(self, current_submodule: Module) -> None: "avail": f"{self.__n_available_params:.1e}", "queue_sz": f"{len(self.__param_queue or [])}", "inflight": [p.ds_id for p in self.__inflight_param_registry], - "allocated": get_cuda_mem_allocated_str() })) params_to_fetch = frozenset(iter_params(current_submodule)) From f590ba45c711d8815d96c77d6c027b02168d544c Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Wed, 8 Dec 2021 22:50:57 +0000 Subject: [PATCH 46/59] Format fixes --- .../runtime/comm/coalesced_collectives.py | 4 +-- .../runtime/zero/partition_parameters.py | 10 +++---- deepspeed/runtime/zero/stage3.py | 22 +++++--------- tests/unit/test_onebit.py | 1 + tests/unit/test_zero.py | 30 +++++++++---------- 5 files changed, 31 insertions(+), 36 deletions(-) diff --git a/deepspeed/runtime/comm/coalesced_collectives.py b/deepspeed/runtime/comm/coalesced_collectives.py index 1ac438734813..880a3cc46b89 100644 --- a/deepspeed/runtime/comm/coalesced_collectives.py +++ b/deepspeed/runtime/comm/coalesced_collectives.py @@ -41,8 +41,8 @@ def torch_reduce_scatter_fn(input_tensor: Tensor, output_tensor: Tensor, group): @instrument_w_nvtx @torch.no_grad() def reduce_scatter_coalesced( - tensors: List[Tensor], - group: ProcessGroup = None, + tensors: List[Tensor], + group: ProcessGroup = None, ) -> List[Tensor]: """simultaneously reduce-scatter a list of tensors - this can be done more efficiently than individual reduce scatter calls diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py index 8ad11b961729..42d5622704a9 100755 --- a/deepspeed/runtime/zero/partition_parameters.py +++ b/deepspeed/runtime/zero/partition_parameters.py @@ -494,11 +494,11 @@ def wait(self) -> None: class AllGatherCoalescedHandle: def __init__( - self, - allgather_handle, - params: List[Parameter], - partitions: List[Tensor], - world_size: int, + self, + allgather_handle, + params: List[Parameter], + partitions: List[Tensor], + world_size: int, ) -> None: self.__allgather_handle = allgather_handle self.__params = params diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 3a5b60fc7cd2..4675a6388551 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -189,12 +189,12 @@ class __ParamInTrace: step_id_last_used_at: int def __init__( - self, - prefetch_bucket_sz: int, - max_reuse_distance_in_numel: int, - max_available_parameters_in_numel: int, - allgather_stream: Stream, - prefetch_nvme: bool = False, + self, + prefetch_bucket_sz: int, + max_reuse_distance_in_numel: int, + max_available_parameters_in_numel: int, + allgather_stream: Stream, + prefetch_nvme: bool = False, ) -> None: # mapping of param -> handle for each param that is currently in flight self.__inflight_param_registry = __class__.__InflightParamRegistry() @@ -491,7 +491,6 @@ def __prefetch_nvme_param_partitions(self) -> None: swap_in_params[0].nvme_swapper.swap_in(swap_in_params, async_op=True) - class PreBackwardFunction(torch.autograd.Function): @staticmethod def forward(ctx, module, pre_backward_function, outputs): @@ -538,7 +537,6 @@ def backward(ctx, *args): return (None, None) + args - class FP16_DeepSpeedZeroOptimizer_Stage3(object): """ DeepSpeedZeroOptimizer designed to reduce the memory footprint @@ -1813,7 +1811,6 @@ def independent_gradient_partition_epilogue(self): def overlapping_partition_gradients_reduce_epilogue(self): self.independent_gradient_partition_epilogue() - def create_reduce_and_remove_grad_hooks(self): print_rank_0(f'[Begin] Create gradient reduction hooks') self.grad_accs = [] @@ -1962,7 +1959,6 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor] return grad_partitions_for_rank - def set_grad_positions(self): for i, group in enumerate(self.fp16_groups): current_offset = 0 @@ -2029,7 +2025,6 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params): return total_norm - @instrument_w_nvtx def __partition_grads(self, params_to_release: List[Parameter], @@ -2107,7 +2102,6 @@ def __partition_grads(self, gradient_offsets=offload_fp32_offsets[i], gradient_tensors=offload_fp32_gradients[i]) - def reduce_ready_partitions_and_remove_grads(self, param, i): #print_rank_0(f"Backward {debug_param2name_id_shape(param)}", force=True) self.reduce_independent_p_g_buckets_and_remove_grads(param, i) @@ -2597,8 +2591,8 @@ def _optimizer_states_and_gradient_swap_out(self, sub_group_id, timer_names=set( self.optimizer_swapper.swap_out_optimizer_state( parameter=self.fp32_partitioned_groups_flat[sub_group_id], - async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] is - not None) + async_swap=self.next_swappable_fp32_partitioned_groups[sub_group_id] + is not None) self.stop_timers([OPTIMIZER_SWAP_OUT_STATE]) see_memory_usage( diff --git a/tests/unit/test_onebit.py b/tests/unit/test_onebit.py index d428826d3a7a..a2034e8fd043 100644 --- a/tests/unit/test_onebit.py +++ b/tests/unit/test_onebit.py @@ -12,6 +12,7 @@ import time from deepspeed.runtime.pipe.topology import PipeDataParallelTopology, PipeModelDataParallelTopology + PipeTopo = PipeDataParallelTopology from deepspeed.runtime.pipe.module import PipelineModule, LayerSpec from .common import distributed_test diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index 132ec4ab23d2..b0e7886f05ee 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -480,10 +480,10 @@ def forward(self, x: Tensor) -> Tensor: class EltwiseMultiplicationTestNetwork(Module): """used for testing purposes""" def __init__( - self, - weight1: Parameter, - weight2: Parameter, - weight3: Parameter, + self, + weight1: Parameter, + weight2: Parameter, + weight3: Parameter, ) -> None: super().__init__() self.__layer1 = EltwiseMultiplicationModule(weight1) @@ -547,12 +547,12 @@ def forward(self, x: Tensor, y: Tensor, prefetching: bool) -> Dict[str, Tensor]: @pytest.mark.parametrize("zero_grad", [True, False]) @pytest.mark.parametrize("iteration", list(range(1))) def test_zero3_param_partitioning_base( - param_persistence_threshold: int, - fp16_enabled: bool, - contiguous_gradients: bool, - offload_optimizer: bool, - zero_grad: bool, - iteration: int, + param_persistence_threshold: int, + fp16_enabled: bool, + contiguous_gradients: bool, + offload_optimizer: bool, + zero_grad: bool, + iteration: int, ) -> None: @distributed_test(world_size=[2]) def _test_zero3_param_partitioning(): @@ -991,11 +991,11 @@ def _distributed_test(): @pytest.mark.parametrize("zero_grad", [True]) @pytest.mark.parametrize("iteration", list(range(1))) def test_zero3_param_partitioning_base_bf16( - param_persistence_threshold: int, - contiguous_gradients: bool, - offload_optimizer: bool, - zero_grad: bool, - iteration: int, + param_persistence_threshold: int, + contiguous_gradients: bool, + offload_optimizer: bool, + zero_grad: bool, + iteration: int, ) -> None: @distributed_test(world_size=[2]) def _test_zero3_param_partitioning(): From 9db815fe3ae06157bb5fc0986257fbd044fc853b Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 8 Dec 2021 15:07:51 -0800 Subject: [PATCH 47/59] fix bfloat16 zero stage check (broken after merge commit) --- DeepSpeedExamples | 2 +- deepspeed/runtime/config.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/DeepSpeedExamples b/DeepSpeedExamples index 1fed12e8b375..174ae3bc8dbb 160000 --- a/DeepSpeedExamples +++ b/DeepSpeedExamples @@ -1 +1 @@ -Subproject commit 1fed12e8b375b0c54902827e7140d8266dfccd59 +Subproject commit 174ae3bc8dbb688cfaccb4afa15d6e2cdbe19ce5 diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 5cb6deb97078..ca40592cb80d 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -899,7 +899,7 @@ def _initialize_params(self, param_dict): self.fp16_enabled = get_fp16_enabled(param_dict) self.bfloat16_enabled = get_bfloat16_enabled(param_dict) assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled' - assert not (self.bfloat16_enabled and (self.zero_optimization_stage != 2)), 'bfloat16 mode is only enabled for Zero2 currently' + assert not (self.bfloat16_enabled and (self.zero_optimization_stage not in {2, 3})), f'bfloat16 mode is only enabled for Zero 2 and 3 currently. got {self.zero_optimization_stage}' self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled( param_dict) self.amp_enabled = get_amp_enabled(param_dict) From 259ec153929f494bbfbd60b3628132b65787c38c Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Wed, 8 Dec 2021 23:15:59 +0000 Subject: [PATCH 48/59] +self.communication_data_type, -self.allreduce_always_fp32; delete dead code --- deepspeed/runtime/zero/stage3.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 4675a6388551..8804d0c27acc 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1937,7 +1937,7 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor] dtype = get_only_unique_item(p.grad.dtype for p in params_to_reduce) full_grads_for_rank = [p.grad for p in params_to_reduce] - if self.allreduce_always_fp32: + if self.communication_data_type == torch.float32: full_grads_for_rank = [g.float() for g in full_grads_for_rank] if self.postscale_gradients and self.gradient_predivide_factor != 1.0: @@ -1954,7 +1954,7 @@ def __avg_scatter_grads(self, params_to_reduce: List[Parameter]) -> List[Tensor] g.mul(self.gradient_predivide_factor) for g in grad_partitions_for_rank ] - if self.allreduce_always_fp32: + if self.communication_data_type == torch.float32: grad_partitions_for_rank = [g.to(dtype) for g in grad_partitions_for_rank] return grad_partitions_for_rank @@ -2229,20 +2229,6 @@ def allreduce_no_retain(self, if len(small_bucket) > 0: self.allreduce_and_copy(small_bucket, rank=rank, log=log) - # allows using reduction of gradients instead of using all_reduce - def buffered_reduce_fallback(self, - rank, - grads, - elements_per_buffer=500000000, - log=None): - split_buckets = split_half_float_double(grads) - - for i, bucket in enumerate(split_buckets): - self.allreduce_no_retain(bucket, - numel_per_bucket=elements_per_buffer, - rank=rank, - log=log) - ############################################################################# ############################################################################# ############################################################################# From 96d224713a2cefa5b148ff0d9147b256ccb4c431 Mon Sep 17 00:00:00 2001 From: Tunji Ruwase Date: Thu, 9 Dec 2021 00:20:27 +0000 Subject: [PATCH 49/59] Add self.reduce_scatter --- deepspeed/runtime/zero/stage3.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 8804d0c27acc..d22141af44b9 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -678,6 +678,8 @@ def __init__(self, self.timers = timers + self.reduce_scatter = reduce_scatter + self.dp_process_group = dp_process_group self.partition_count = dist.get_world_size(group=self.dp_process_group) From 0f8affe3d507ea81593102c2375ea5bf0c0fb001 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 30 Dec 2021 13:46:00 +0000 Subject: [PATCH 50/59] Format fix --- tests/unit/test_zero.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index f712d44717ac..034b0362eea8 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -437,7 +437,6 @@ def _test_partition_nccl_alignment(model, hidden_dim): _test_partition_nccl_alignment(model=model, hidden_dim=hidden_dim) - def _ds_initialize_for_param_partitioning_testing(model: Module, cfg: dict) -> DeepSpeedEngine: ds_engine, _, _, _ = deepspeed.initialize( From 601d1f191cc74ef7865696110607d6c779b31f90 Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Thu, 30 Dec 2021 19:10:09 +0000 Subject: [PATCH 51/59] Fix merge issues --- tests/unit/test_checkpointing.py | 2 +- tests/unit/test_zero_context.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_checkpointing.py b/tests/unit/test_checkpointing.py index 5b2e35d1e33f..9a438b96c960 100755 --- a/tests/unit/test_checkpointing.py +++ b/tests/unit/test_checkpointing.py @@ -15,7 +15,7 @@ from deepspeed.ops.op_builder import FusedLambBuilder, CPUAdamBuilder from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 -from util import required_torch_version +from .util import required_torch_version import argparse import pytest diff --git a/tests/unit/test_zero_context.py b/tests/unit/test_zero_context.py index 8e3b81776f41..66521e075ce1 100644 --- a/tests/unit/test_zero_context.py +++ b/tests/unit/test_zero_context.py @@ -7,7 +7,7 @@ import deepspeed from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus, partitioned_param_data_shape -from common import distributed_test, get_master_port +from .common import distributed_test, get_master_port def setup_serial_env(): From 31aecfca46427356050d01b16996763f0fab0b58 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Wed, 12 Jan 2022 11:47:59 -0800 Subject: [PATCH 52/59] iterate over params_to_fetch rather than make another iterator --- deepspeed/runtime/zero/stage3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index c93e3c4e4d2c..1c06ac1bad71 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -311,7 +311,7 @@ def fetch_sub_module(self, current_submodule: Module) -> None: self.__all_gather_params(params_to_fetch) # wait for parameters in the immediately needed submodule to become available - for param in iter_params(current_submodule): + for param in params_to_fetch: param.ds_active_sub_modules.add(current_submodule.id) debug_rank0(f"-wait: {param.ds_summary()}") if param in self.__inflight_param_registry: From 8736700eb9f0f77e298291af2107b03124962aa7 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Thu, 13 Jan 2022 16:18:08 -0800 Subject: [PATCH 53/59] add some TODOs --- deepspeed/runtime/zero/stage3.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 1c06ac1bad71..e432fb02c32a 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -233,6 +233,7 @@ def __init__( # because ideally in the future its replaced by an async allocation # mechanism which doesnt require any configuration by the user. self.__ongoing_fetch_events: Deque[Event] = collections.deque() + # TODO. make this configurable via JSON self.__max_ongoing_fetch_events: int = 2 """Tracing and Tracking @@ -801,6 +802,7 @@ def __init__(self, self.is_gradient_accumulation_boundary: bool = True self.__param_reduce_events: Deque[Event] = collections.deque() + # TODO. make this configurable via JSON self.__max_param_reduce_events: int = 2 if dist.get_rank() == 0: @@ -892,6 +894,7 @@ def __init__(self, if dist.get_rank(group=self.dp_process_group) == 0: see_memory_usage(f"After initializing ZeRO optimizer", force=False) + # TODO. factor out to a utility outside of stage3 @staticmethod def defragment(tensors: List[Tensor]) -> Tensor: """move provided tensors into a contiguous flat buffer, with some additional From 0bf7bcde3f5c1beac996b914e75f7e4e0fe761f5 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 18 Jan 2022 16:42:17 -0800 Subject: [PATCH 54/59] remove unnecessary division by micro_step_id --- deepspeed/runtime/zero/stage3.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index e432fb02c32a..6d567a74e1f0 100755 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2693,8 +2693,6 @@ def dump_post_step_gradients(self): @instrument_w_nvtx def unscale_and_clip_grads(self, sub_group_id, total_norm): - grad_groups_flat = [self.fp32_partitioned_groups_flat[sub_group_id].grad] - # compute combined scale factor for this group combined_scale = self.loss_scale if self.clip_grad > 0.: @@ -2702,16 +2700,8 @@ def unscale_and_clip_grads(self, sub_group_id, total_norm): clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad if clip > 1: combined_scale = clip * self.loss_scale - # to maintain behavior of averaging over accumulation steps - combined_scale *= self.micro_step_id + 1 - - for grad in grad_groups_flat: - if isinstance(grad, list): - sub_partitions = grad - for g in sub_partitions: - g.data.mul_(1. / combined_scale) - else: - grad.data.mul_(1. / combined_scale) + + self.fp32_partitioned_groups_flat[sub_group_id].grad.mul_(1. / combined_scale) def _check_overflow(self, partition_gradients=True): self.overflow = self.has_overflow(partition_gradients) From 43c00ff7488267cad169daa2686f1d29aeef3e18 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 18 Jan 2022 18:05:28 -0800 Subject: [PATCH 55/59] rename config keys "bfloat16" -> "bf16" --- deepspeed/runtime/config.py | 12 ++++++------ deepspeed/runtime/constants.py | 5 +++-- docs/_pages/config-json.md | 2 +- tests/unit/test_bf16.py | 12 ++++++------ tests/unit/test_zero.py | 2 +- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index ca40592cb80d..81db59b544ed 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -134,12 +134,12 @@ def get_fp16_enabled(param_dict): def get_bfloat16_enabled(param_dict): - if BFLOAT16 in param_dict.keys(): - return get_scalar_param(param_dict[BFLOAT16], - BFLOAT16_ENABLED, - BFLOAT16_ENABLED_DEFAULT) - else: - return False + for key in [BFLOAT16, BFLOAT16_OLD]: + if key in param_dict.keys(): + return get_scalar_param(param_dict[key], + BFLOAT16_ENABLED, + BFLOAT16_ENABLED_DEFAULT) + return False def get_fp16_master_weights_and_grads_enabled(param_dict): diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 8eeb3d5db513..2d16f39433c3 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -114,11 +114,12 @@ # Users can configure in ds_config.json as below example: BFLOAT16_FORMAT = ''' BFLOAT16 parameters should be of the format: -"bfloat16": { +"bf16": { "enabled": true } ''' -BFLOAT16 = "bfloat16" +BFLOAT16 = "bf16" +BFLOAT16_OLD = "bfloat16" # keeping for backwards compatibility BFLOAT16_ENABLED = "enabled" BFLOAT16_ENABLED_DEFAULT = False diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index d7b47de25f47..acea36c8f199 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -251,7 +251,7 @@ Example of **scheduler** | Configuration for using [bfloat16](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format) floating-point format as an alternative to FP16. BFLOAT16 requires hardware support (e.g., NVIDIA A100). An example, including the available dictionary keys is illustrated below. Training with bfloat16 does not require loss scaling. | None | ```json -"bfloat16": { +"bf16": { "enabled": true } ``` diff --git a/tests/unit/test_bf16.py b/tests/unit/test_bf16.py index 99d4a8c514ae..aa2ab132394c 100644 --- a/tests/unit/test_bf16.py +++ b/tests/unit/test_bf16.py @@ -45,7 +45,7 @@ def test_adam_bf16_zero_onecycle_compatibility(tmpdir, zero_stage, use_cpu_offlo "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True }, "zero_optimization": { @@ -95,7 +95,7 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage, use_cpu_offload): "fp16": { "enabled": False, }, - "bfloat16": { + "bf16": { "enabled": True }, "zero_optimization": { @@ -139,7 +139,7 @@ def test_zero_empty_partition(tmpdir, zero_stage, use_cpu_offload): "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True }, "optimizer": { @@ -199,7 +199,7 @@ def test_zero_supported_client_optimizer(tmpdir, zero_stage, optimizer_construct "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True }, "zero_optimization": { @@ -250,7 +250,7 @@ def test_zero2_reduce_scatter_off(tmpdir): "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True } } @@ -290,7 +290,7 @@ def test_zero_empty_grad(tmpdir, stage): "fp16": { "enabled": False }, - "bfloat16": { + "bf16": { "enabled": True }, "zero_optimization": { diff --git a/tests/unit/test_zero.py b/tests/unit/test_zero.py index 034b0362eea8..c2ff33a14042 100755 --- a/tests/unit/test_zero.py +++ b/tests/unit/test_zero.py @@ -1015,7 +1015,7 @@ def _test_zero3_param_partitioning(): "lr": 1. } }, - "bfloat16": { + "bf16": { "enabled": True, "loss_scale": 1., } From 4574bc71a47314488782bf0cc0a4433fe22f2462 Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 18 Jan 2022 17:51:10 -0800 Subject: [PATCH 56/59] rename stage3_gather_fp16_weights_on_model_save -> stage3_gather_16bit_weights_on_model_save --- .../config_templates/template_zero3.json | 2 +- deepspeed/runtime/engine.py | 23 +++++++++++-------- deepspeed/runtime/zero/config.py | 16 +++++++++---- deepspeed/runtime/zero/constants.py | 7 +++--- docs/_pages/config-json.md | 6 ++--- docs/_tutorials/zero.md | 6 ++--- docs/code-docs/source/training.rst | 2 +- 7 files changed, 37 insertions(+), 25 deletions(-) diff --git a/deepspeed/autotuning/config_templates/template_zero3.json b/deepspeed/autotuning/config_templates/template_zero3.json index e00f47f65560..620d7eb10e81 100644 --- a/deepspeed/autotuning/config_templates/template_zero3.json +++ b/deepspeed/autotuning/config_templates/template_zero3.json @@ -11,7 +11,7 @@ "stage3_max_reuse_distance": 1e9, "stage3_prefetch_bucket_size": 5e8, "stage3_param_persistence_threshold": 1e6, - "stage3_gather_fp16_weights_on_model_save": false, + "stage3_gather_16bit_weights_on_model_save": false, "sub_group_size": 1e12 } } diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index d8f71ef01882..c3f68447f015 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -708,8 +708,8 @@ def zero_prefetch_bucket_size(self): def zero_param_persistence_threshold(self): return self._config.zero_config.param_persistence_threshold - def zero_gather_fp16_weights_on_model_save(self): - return self._config.zero_config.gather_fp16_weights_on_model_save + def zero_gather_16bit_weights_on_model_save(self): + return self._config.zero_config.gather_16bit_weights_on_model_save def zero_grad_hooks(self): return self._config.zero_config.grad_hooks @@ -2955,7 +2955,7 @@ def _save_zero_checkpoint(self, save_path, tag): self._copy_recovery_script(save_path) logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name)) - def _zero3_consolidated_fp16_state_dict(self): + def _zero3_consolidated_16bit_state_dict(self): """ Get a full non-partitioned state_dict with fp16 weights on cpu. @@ -3024,9 +3024,14 @@ def get_layer_state_dict(module, prefix=""): return state_dict def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): - r"""Save fp16 model weights + """has been renamed to save_16bit_model, keeping this around for backwards + compatibility""" + return self.save_16bit_model(save_dir, save_filename) - This method saves the fp16 model weights at the desired destination. + def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin"): + r"""Save 16bit model weights + + This method saves the 16bit model weights at the desired destination. Arguments: save_dir: Required. Directory for saving the model @@ -3034,7 +3039,7 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): Returns: ``True`` when a model has been saved, ``False`` otherwise. It will not be saved if - stage3_gather_fp16_weights_on_model_save is ``False``. + stage3_gather_16bit_weights_on_model_save is ``False``. Important: all processes must call this method and not just the process with rank 0. It is because the processes need to work in sync to gather the weights. This method will hang @@ -3045,13 +3050,13 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): path = os.path.join(save_dir, save_filename) if self.zero_optimization_partition_weights(): - if self.zero_gather_fp16_weights_on_model_save(): + if self.zero_gather_16bit_weights_on_model_save(): # consolidation is expensive in time and memory and therefore isn't a default - state_dict = self._zero3_consolidated_fp16_state_dict() + state_dict = self._zero3_consolidated_16bit_state_dict() else: # the model will be bogus if not consolidated so don't confuse the user by saving it logger.info( - f"Did not save the model {path} because `stage3_gather_fp16_weights_on_model_save` is False" + f"Did not save the model {path} because `stage3_gather_16bit_weights_on_model_save` is False" ) return False else: diff --git a/deepspeed/runtime/zero/config.py b/deepspeed/runtime/zero/config.py index 782d4d9e39fd..3804fb50a371 100755 --- a/deepspeed/runtime/zero/config.py +++ b/deepspeed/runtime/zero/config.py @@ -36,7 +36,7 @@ def __init__(self, param_dict): self.param_persistence_threshold = None self.max_live_parameters = None self.max_reuse_distance = None - self.gather_fp16_weights_on_model_save = None + self.gather_16bit_weights_on_model_save = None self.ignore_unused_parameters = None self.round_robin_gradients = None @@ -171,10 +171,16 @@ def _initialize(self, zero_config_dict): ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD, ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT) - self.gather_fp16_weights_on_model_save = get_scalar_param( - zero_config_dict, - ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE, - ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT) + # config key has been renamed to use "16bit" instead of "fp16." falling back + # to old config name in order to preserve backwards compatibility + self.gather_16bit_weights_on_model_save = ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT + for key in [ + ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE, + ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE + ]: + if key in zero_config_dict: + self.gather_16bit_weights_on_model_save = zero_config_dict[key] + break self.ignore_unused_parameters = get_scalar_param( zero_config_dict, diff --git a/deepspeed/runtime/zero/constants.py b/deepspeed/runtime/zero/constants.py index e3b2dfc0c68f..13efe1af768a 100755 --- a/deepspeed/runtime/zero/constants.py +++ b/deepspeed/runtime/zero/constants.py @@ -113,7 +113,8 @@ # gathers params for saving a model - inefficient but is required in certain situations ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_fp16_weights_on_model_save' -ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False +ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE = 'stage3_gather_16bit_weights_on_model_save' +ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT = False # Now just used in stage2 complete_grad_norm_calculation_for_cpu_offload # Enable this option to avoid: @@ -161,8 +162,8 @@ ZERO_OPTIMIZATION_PREFETCH_BUCKET_SIZE_DEFAULT, ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD: ZERO_OPTIMIZATION_PARAM_PERSISTENCE_THRESHOLD_DEFAULT, - ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE: - ZERO_OPTIMIZATION_GATHER_FP16_WEIGHTS_ON_MODEL_SAVE_DEFAULT, + ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE: + ZERO_OPTIMIZATION_GATHER_16BIT_WEIGHTS_ON_MODEL_SAVE_DEFAULT, ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS: ZERO_OPTIMIZATION_IGNORE_UNUSED_PARAMETERS_DEFAULT, ZERO_OPTIMIZATION_LEGACY_STAGE1: diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md index acea36c8f199..84a7cd1fce95 100755 --- a/docs/_pages/config-json.md +++ b/docs/_pages/config-json.md @@ -329,7 +329,7 @@ Enabling and configuring ZeRO memory optimizations "stage3_param_persistence_threshold" : 1e6, "sub_group_size" : 1e12, "elastic_checkpoint" : [true|false], - "stage3_gather_fp16_weights_on_model_save": [true|false], + "stage3_gather_16bit_weights_on_model_save": [true|false], "ignore_unused_parameters": [true|false] "round_robin_gradients": [true|false] } @@ -433,11 +433,11 @@ Enabling and configuring ZeRO memory optimizations | Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e6` | -***stage3_gather_fp16_weights_on_model_save***: [boolean] +***stage3_gather_16bit_weights_on_model_save***: [boolean] | Description | Default | |--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ------- | -| Consolidate the weights before saving the model by `save_fp16_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` | +| Consolidate the weights before saving the model by `save_16bit_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` | ***cpu_offload***: [boolean] diff --git a/docs/_tutorials/zero.md b/docs/_tutorials/zero.md index 411ddf3478e7..7721f45ece4f 100644 --- a/docs/_tutorials/zero.md +++ b/docs/_tutorials/zero.md @@ -252,19 +252,19 @@ If you need to take the pretrained weights out of Deepspeed here is what you can ```json "zero_optimization": { - "stage3_gather_fp16_weights_on_model_save": true + "stage3_gather_16bit_weights_on_model_save": true }, ``` And then save the model using: ```python if self.deepspeed: - self.deepspeed.save_fp16_model(output_dir, output_file) + self.deepspeed.save_16bit_model(output_dir, output_file) ``` Because it requires consolidation of the weights on one GPU it can be slow and memory demanding, so only use this feature when needed. -Note that if `stage3_gather_fp16_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them). +Note that if `stage3_gather_16bit_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them). You can use this method to save ZeRO-2 weights as well. If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage: diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst index 52e124fc3b40..e3a7029aae50 100644 --- a/docs/code-docs/source/training.rst +++ b/docs/code-docs/source/training.rst @@ -35,7 +35,7 @@ Gradient Accumulation Model Saving ------------ -.. autofunction:: deepspeed.DeepSpeedEngine.save_fp16_model +.. autofunction:: deepspeed.DeepSpeedEngine.save_16bit_model Additionally when a DeepSpeed checkpoint is created, a script ``zero_to_fp32.py`` is added there which can be used to reconstruct fp32 master weights into a single pytorch ``state_dict`` file. From e04dc6a24f9d587f9d45cbd3b6873022951f49ee Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 18 Jan 2022 19:04:52 -0800 Subject: [PATCH 57/59] add unit test to check backwards compatibility for gather_16bit_weights --- tests/unit/test_config.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index e66544833c8e..e6d93b91bc88 100755 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -3,6 +3,9 @@ import pytest import json import argparse + +from deepspeed.runtime.zero.config import DeepSpeedZeroConfig + from .common import distributed_test, get_test_path from .simple_model import SimpleModel, create_config_from_dict, random_dataloader import torch.distributed as dist @@ -114,6 +117,22 @@ def test_temp_config_json(tmpdir): assert 'train_batch_size' in config_json +@pytest.mark.parametrize("gather_weights_key", + [ + "stage3_gather_16bit_weights_on_model_save", + "stage3_gather_fp16_weights_on_model_save" + ]) +def test_gather_16bit_params_on_model_save(gather_weights_key): + config_dict = { + "zero_optimization": { + gather_weights_key: True, + }, + } + config = DeepSpeedZeroConfig(config_dict) + + assert config.gather_16bit_weights_on_model_save == True + + def test_deprecated_deepscale_config(tmpdir): config_dict = { "train_batch_size": 1, From 391cecf700416623c5b015c1b4158644a0eaba7c Mon Sep 17 00:00:00 2001 From: Justin Chiu Date: Tue, 18 Jan 2022 19:13:33 -0800 Subject: [PATCH 58/59] added test to confirm bf16 key bwd compatibility --- tests/unit/test_config.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index e6d93b91bc88..a88cb2931d95 100755 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -12,7 +12,7 @@ # A test on its own import deepspeed -from deepspeed.runtime.config import DeepSpeedConfig +from deepspeed.runtime.config import DeepSpeedConfig, get_bfloat16_enabled def test_cuda(): @@ -133,6 +133,16 @@ def test_gather_16bit_params_on_model_save(gather_weights_key): assert config.gather_16bit_weights_on_model_save == True +@pytest.mark.parametrize("bf16_key", ["bf16", "bfloat16"]) +def test_get_bfloat16_enabled(bf16_key): + cfg = { + bf16_key: { + "enabled": True, + }, + } + assert get_bfloat16_enabled(cfg) == True + + def test_deprecated_deepscale_config(tmpdir): config_dict = { "train_batch_size": 1, From 536d1718bd713fad509a145c1f0e1b368b207aeb Mon Sep 17 00:00:00 2001 From: Olatunji Ruwase Date: Wed, 19 Jan 2022 18:13:19 +0000 Subject: [PATCH 59/59] Format fixes --- deepspeed/runtime/zero/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py index b8f184ee8afe..0017213a9941 100755 --- a/deepspeed/runtime/zero/utils.py +++ b/deepspeed/runtime/zero/utils.py @@ -81,6 +81,7 @@ def assert_ints_same_as_other_ranks(ints: List[int]) -> None: if ints != rank0_ints: raise RuntimeError(f"disagreement between rank0 and rank{dist.get_rank()}: " f"rank0: {rank0_ints}, rank{dist.get_rank()}: {ints}") - + + class ZeRORuntimeException(Exception): pass