bf16 inference (#1917)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
tjruwase and jeffra authored Apr 29, 2022
1 parent 96c8bf3 commit af58f63
Showing 2 changed files with 21 additions and 9 deletions.
25 changes: 16 additions & 9 deletions deepspeed/runtime/bf16_optimizer.py
@@ -7,8 +7,7 @@
 from deepspeed.git_version_info import version
 from deepspeed.runtime.utils import (get_global_norm_of_tensors,
                                      clip_tensors_by_global_norm,
-                                     get_grad_norm,
-                                     clip_gradients,
+                                     DummyOptim,
                                      align_dense_tensors,
                                      all_gather_dp_groups,
                                      bwc_tensor_model_parallel_rank,
@@ -85,6 +84,8 @@ def __init__(self,
         see_memory_usage('begin bf16_optimizer', force=True)
         self.timers = timers
         self.optimizer = init_optimizer
+        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)
+
         self.clip_grad = clip_grad
         self.norm_type = norm_type
         self.mpu = mpu
@@ -94,10 +95,6 @@ def __init__(self,
         self.real_dp_process_group = [
             dp_process_group for i in range(len(self.optimizer.param_groups))
         ]
-        dp_world_size = dist.get_world_size(group=self.dp_process_group)
-        self.partition_count = [
-            dp_world_size for i in range(len(self.optimizer.param_groups))
-        ]
 
         # Load pre-built or JIT compile (un)flatten ops
         util_ops = UtilsBuilder().load()
@@ -124,6 +121,17 @@ def __init__(self,
         self.step_count = 0
         self.groups_padding = []
 
+        if self.using_real_optimizer:
+            self._setup_for_real_optimizer()
+
+        see_memory_usage('end bf16_optimizer', force=True)
+
+    def _setup_for_real_optimizer(self):
+        dp_world_size = dist.get_world_size(group=self.dp_process_group)
+        self.partition_count = [
+            dp_world_size for i in range(len(self.optimizer.param_groups))
+        ]
+
         for i, param_group in enumerate(self.optimizer.param_groups):
             see_memory_usage(f'before initializing group {i}', force=True)

@@ -210,8 +218,6 @@ def __init__(self,
         # Need optimizer states initialized before linking lp to optimizer state
         self._link_all_hp_params()
 
-        see_memory_usage('end bf16_optimizer', force=True)
-
     def _link_all_hp_params(self):
         dp_world_size = dist.get_world_size(group=self.dp_process_group)
         for i, param_group in enumerate(self.optimizer.param_groups):
@@ -482,7 +488,8 @@ def load_state_dict(self,
             src_tensor = _get_padded_tensor(saved, current.numel())
             current.data.copy_(src_tensor.data)
 
-        self._link_all_hp_params()
+        if load_optimizer_states:
+            self._link_all_hp_params()
 
     @property
     def param_groups(self):
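The net effect of the hunks above is that BF16_Optimizer can be constructed around a DummyOptim and skip all real-optimizer setup, which is what makes optimizer-free (inference) use possible. Below is a minimal standalone sketch of that gating pattern; PlaceholderOptim and Bf16WrapperSketch are illustrative stand-ins, not DeepSpeed's actual DummyOptim/BF16_Optimizer classes.

import torch

class PlaceholderOptim:
    """Illustrative stand-in for deepspeed.runtime.utils.DummyOptim: holds params, performs no updates."""
    def __init__(self, params):
        self.param_groups = [{'params': params}]

class Bf16WrapperSketch:
    """Sketch of the gating the hunks above add to BF16_Optimizer.__init__."""
    def __init__(self, init_optimizer):
        self.optimizer = init_optimizer
        # Mirrors `self.using_real_optimizer` in the diff.
        self.using_real_optimizer = not isinstance(self.optimizer, PlaceholderOptim)
        if self.using_real_optimizer:
            self._setup_for_real_optimizer()

    def _setup_for_real_optimizer(self):
        # Training path: this is where partition counts, fp32 state, etc. would be built.
        self.partition_count = [1 for _ in self.optimizer.param_groups]

# Inference-style construction: no real optimizer, so setup is skipped entirely.
params = [torch.nn.Parameter(torch.zeros(2, 2, dtype=torch.bfloat16))]
wrapper = Bf16WrapperSketch(PlaceholderOptim(params))
assert not wrapper.using_real_optimizer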
5 changes: 5 additions & 0 deletions deepspeed/runtime/engine.py
@@ -297,6 +297,8 @@ def __init__(
         elif self.zero_optimization():
             # no optim selected but zero is enabled
             self.optimizer = self._configure_zero_optimizer(optimizer=None)
+        elif self.bfloat16_enabled():
+            self.optimizer = self._configure_bf16_optimizer(optimizer=None)
 
         self._get_model_parameters()

@@ -1280,6 +1282,9 @@ def _configure_fp16_optimizer(self, optimizer):
     def _configure_bf16_optimizer(self, optimizer):
         clip_grad = self.gradient_clipping()
 
+        if optimizer is None:
+            optimizer = DummyOptim(list(self.module.parameters()))
+
         if self.global_rank == 0:
             logger.info('Creating unfused BF16 optimizer')
         timers = self.timers if self.wall_clock_breakdown() else None
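On the engine side, the change means a DeepSpeed engine can be built without a client optimizer when bf16 is enabled: the engine falls through to _configure_bf16_optimizer(optimizer=None), which wraps the module's parameters in a DummyOptim. A rough usage sketch follows; the config keys and the launch setup are assumptions drawn from DeepSpeed's general documentation, not from this diff, so treat it as illustrative only.

# Hypothetical usage sketch: bf16 forward passes without supplying an optimizer.
# Run under the DeepSpeed launcher (e.g. `deepspeed this_script.py`) so that
# distributed initialization is handled for you.
import torch
import deepspeed

model = torch.nn.Linear(8, 8)
ds_config = {
    "train_batch_size": 1,       # typically required by the engine config, even for eval
    "bf16": {"enabled": True},   # assumed config key enabling the BF16 path exercised above
}

# No optimizer argument: after this commit the engine calls
# _configure_bf16_optimizer(optimizer=None) instead of failing.
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
with torch.no_grad():
    out = engine(torch.zeros(1, 8, dtype=torch.bfloat16))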
