Various ZeRO Stage3 Optimizations + Improvements (including bfloat16 support) #1453

Merged
91 commits merged on Jan 21, 2022

Changes from 4 commits

Commits (91)
fe26423
Changes for bfloat16 Zero2
raamjad Aug 14, 2021
8864f91
ZeRO stage3 optimizations, with some bug fixes
Sep 29, 2021
e66aedc
fix import in ut
Oct 12, 2021
350a7a0
ran yapf
Oct 12, 2021
b37a4f0
Merge branch 'master' into s3-pr
tjruwase Oct 13, 2021
f383947
improvements to cache flush warn log
Oct 13, 2021
b2a1c95
backwards compatibility with older versions of pytorch
Oct 14, 2021
d8678fa
handle edge case where reduced tensor smaller than world size
Oct 14, 2021
a0faca0
moved event synchronization to allgather handle wait() call
Oct 14, 2021
bf20c90
removed unnecessary barrier call
Oct 14, 2021
a353017
Merge branch 'master' into s3-pr
jfc4050 Oct 14, 2021
c51ba46
formatting fix after resolving merge conflict
Oct 14, 2021
ff01f5c
skip nvme prefetch when trace not complete
Oct 14, 2021
13093eb
opportunistically avoid memory allocation in allgather coalesced wher…
Oct 15, 2021
3cdcbdf
Merge branch 'master' into s3-pr
tjruwase Oct 20, 2021
64d74d1
Merge branch 'master' into s3-pr
tjruwase Oct 21, 2021
e30e6cc
Merge branch 'master' into s3-pr
tjruwase Oct 22, 2021
f19593d
fix indentation after merge
Oct 22, 2021
f72bc78
fixes to account for parameter offload
Oct 22, 2021
660df05
accounting for torch.cuda.memory_stats not being available
Oct 22, 2021
4f9477f
moved partition_all_params to optimizer step
Oct 22, 2021
818651c
Merge branch 'master' into s3-pr
jeffra Oct 26, 2021
f681201
Merge branch 'master' into s3-pr
jfc4050 Oct 26, 2021
bb34f90
allgathering on params before item gets called
Oct 25, 2021
9f3b504
fix param status checks
Oct 25, 2021
1772d41
fix grad accumulation with optimizer offload
Oct 25, 2021
5f213d8
grad norm computation fix for optimizer offload
Oct 26, 2021
3198805
change post divide in reduce-scatter to pre divide
Oct 26, 2021
2225659
fix gradient race condition w/ optimizer offload
Oct 26, 2021
5aa9bd5
improve inf/nan gradient tracking
Oct 26, 2021
a1a60ed
don't prefetch when not in training mode
Oct 26, 2021
df41659
format fix after merging
Oct 26, 2021
ab3a82a
fix prefetching issue when using NVME offload
Oct 27, 2021
025a41e
Merge branch 'master' into s3-pr
tjruwase Oct 29, 2021
6f9415b
Merge branch 'master' into s3-pr
jfc4050 Nov 1, 2021
8d12281
Merge branch 'master' into s3-pr
jfc4050 Nov 2, 2021
a26d1fb
improved defragmentation for fp16 parameters
Oct 31, 2021
937f04e
relative imports for bf16 tests
Nov 2, 2021
e74f509
changes for bwd compatibility with pytorch 1.2
Nov 2, 2021
6ee558d
remove buffered_reduce_fallback
Nov 2, 2021
14e22a2
removed unused parameter offset bookkeeping
Nov 3, 2021
16281df
fixed tracking for multiple param groups
Nov 3, 2021
38af6b1
Merge branch 'master' into s3-pr
tjruwase Nov 3, 2021
cc7011e
unbroke bfloat16 config after merge conflict
Nov 3, 2021
806b072
using base allgather params when only 1 param
Nov 3, 2021
bf0dd66
cleanup/fixes for fp16 partition defragmentation
Nov 3, 2021
73207ae
Merge branch 'master' into s3-pr
tjruwase Nov 5, 2021
d3ecb1f
Merge branch 'master' into s3-pr
tjruwase Nov 5, 2021
812fe67
Merge branch 'master' into s3-pr
tjruwase Nov 11, 2021
6dc21a6
switch to CRLF
jeffra Nov 18, 2021
2a38302
convert to same new-line style as master
jeffra Nov 18, 2021
16f1d21
align new line with master
jeffra Nov 18, 2021
11d590a
Merge branch 'master' into s3-pr
tjruwase Nov 23, 2021
2b5f6ea
Fix merge issues
tjruwase Nov 23, 2021
80b53d3
Merge branch 'master' into s3-pr
tjruwase Nov 24, 2021
6dfe693
Merge branch 'master' into s3-pr
tjruwase Nov 24, 2021
912e6f0
switch to CRLF
jeffra Nov 29, 2021
4b0133b
fix to LF line endings
jeffra Nov 30, 2021
b998206
minor merge fixes
jeffra Nov 30, 2021
d6deecb
remove extra bfloat16_enabled definition
Nov 30, 2021
2a4ef29
asserting params inflight for AllGatherHandle
Nov 30, 2021
90182b6
remove get_cuda_mem_allocated_str
Nov 30, 2021
ad847ed
Merge branch 'master' into s3-pr
tjruwase Dec 8, 2021
f590ba4
Format fixes
tjruwase Dec 8, 2021
9db815f
fix bfloat16 zero stage check (broken after merge commit)
Dec 8, 2021
259ec15
+self.communication_data_type, -self.allreduce_always_fp32; delete de…
tjruwase Dec 8, 2021
96d2247
Add self.reduce_scatter
tjruwase Dec 9, 2021
2630b75
Merge branch 'master' into s3-pr
tjruwase Dec 9, 2021
79fd42c
Merge branch 'master' into s3-pr
tjruwase Dec 11, 2021
8565e04
Merge branch 'master' into s3-pr
jeffra Dec 14, 2021
06eab1a
Merge branch 'master' into s3-pr
tjruwase Dec 30, 2021
0f8affe
Format fix
tjruwase Dec 30, 2021
3436422
Merge branch 'master' into s3-pr
tjruwase Dec 30, 2021
601d1f1
Fix merge issues
tjruwase Dec 30, 2021
5dcee36
Merge branch 's3-pr' of github.com:jfc4050/DeepSpeed into s3-pr
tjruwase Dec 30, 2021
580d25e
Merge branch 'master' into s3-pr
tjruwase Jan 3, 2022
872f451
Merge branch 'master' into s3-pr
jeffra Jan 7, 2022
e236293
Merge branch 'master' into s3-pr
tjruwase Jan 10, 2022
43b3b83
Merge branch 'master' into s3-pr
tjruwase Jan 11, 2022
83905ac
Merge branch 'master' into s3-pr
tjruwase Jan 12, 2022
31aecfc
iterate over params_to_fetch rather than make another iterator
Jan 12, 2022
8736700
add some TODOs
Jan 14, 2022
516379d
Merge branch 'master' into s3-pr
tjruwase Jan 14, 2022
0bf7bcd
remove unnecessary division by micro_step_id
Jan 19, 2022
43c00ff
rename config keys "bfloat16" -> "bf16"
Jan 19, 2022
4574bc7
rename stage3_gather_fp16_weights_on_model_save -> stage3_gather_16bi…
Jan 19, 2022
e04dc6a
add unit test to check backwards compatibility for gather_16bit_weights
Jan 19, 2022
391cecf
added test to confirm bf16 key bwd compatibility
Jan 19, 2022
3d26469
Merge branch 'master' into s3-pr
tjruwase Jan 19, 2022
536d171
Format fixes
tjruwase Jan 19, 2022
19f3538
Merge branch 'master' into s3-pr
tjruwase Jan 20, 2022
10 changes: 10 additions & 0 deletions csrc/includes/type_shim.h
@@ -26,6 +26,11 @@
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}

@@ -46,6 +51,11 @@
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}

88 changes: 88 additions & 0 deletions deepspeed/runtime/comm/coalesced_collectives.py
@@ -0,0 +1,88 @@
"""batched collective operations for overhead amortization and better
bandwidth utilization"""

import math
from typing import List

import torch
from torch import Tensor
import torch.distributed
from torch.distributed import ProcessGroup
import torch.nn.functional

from deepspeed.utils import instrument_w_nvtx


@instrument_w_nvtx
@torch.no_grad()
def reduce_scatter_coalesced(
tensors: List[Tensor],
group: ProcessGroup = None,
) -> List[Tensor]:
"""simultaneously reduce-scatter a list of tensors - this can be done more
efficiently than individual reduce scatter calls

TODO. see if PyTorch team wants a c++ verson of this for ProcessGroupNCCL
"""
this_rank = torch.distributed.get_rank(group)
world_sz = torch.distributed.get_world_size(group)

partition_lst_for_each_tensor = tuple(
torch.chunk(tensor.view(-1),
world_sz) for tensor in tensors)
jfc4050 marked this conversation as resolved.
Show resolved Hide resolved
padded_partition_sz_for_each_tensor = tuple(
math.ceil(t.numel() / world_sz) for t in tensors)

if len(tensors) == 1 and tensors[0].numel() % world_sz == 0:
# if there's only one tensor being reduced and we don't need to pad
# we have an opportunity to avoid a memory allocation
tensor_partition_flat_buffer = tensors[0].view(-1)
else:
# interleave tensor partitions such that the correct reduced partitions of each tensor
# end up at each rank
tensor_partitions_lst_with_padding = []
for rank in range(world_sz):
for tensor_idx in range(len(tensors)):
# add tensor content
tensor_chunk = partition_lst_for_each_tensor[tensor_idx][rank]
tensor_partitions_lst_with_padding.append(tensor_chunk)

# add padding if necessary
padding_sz = padded_partition_sz_for_each_tensor[
tensor_idx] - tensor_chunk.numel()
if padding_sz > 0:
tensor_partitions_lst_with_padding.append(
torch.empty(padding_sz,
dtype=tensor_chunk.dtype,
device=tensor_chunk.device))

tensor_partition_flat_buffer = instrument_w_nvtx(
torch.cat)(tensor_partitions_lst_with_padding)

tensor_partition_buffer_for_each_rank: List[Tensor] = torch.chunk(
tensor_partition_flat_buffer,
world_sz)

# batched reduce-scatter call
instrument_w_nvtx(torch.distributed._reduce_scatter_base)(
tensor_partition_buffer_for_each_rank[this_rank],
tensor_partition_flat_buffer,
group=group,
)

# post-divide
tensor_partition_buffer_for_each_rank[this_rank].div_(world_sz)

# reverse procedure of the interleaving done previously, done on the
# result of the batched reduce-scatter
output_lst: List[Tensor] = [None] * len(tensors)
offset = 0
for tensor_idx in range(len(tensors)):
output_lst[tensor_idx] = tensor_partition_buffer_for_each_rank[this_rank].narrow(
0,
offset,
partition_lst_for_each_tensor[tensor_idx][this_rank].numel())

offset += padded_partition_sz_for_each_tensor[tensor_idx]

return output_lst
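
Editor's note: for orientation, a minimal usage sketch of the new helper, assuming torch.distributed is already initialized with an NCCL backend and one CUDA device per rank; the tensor shapes and variable names are hypothetical and not part of this diff.

# hypothetical example - every rank calls this collectively with its local gradients
import torch
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced

local_grads = [torch.randn(1024, device="cuda"), torch.randn(4000, device="cuda")]
shards = reduce_scatter_coalesced(local_grads)  # group=None -> default world group

# shards[i] is this rank's partition of local_grads[i], averaged across ranks
# (the post-divide by world size turns the reduced sum into a mean)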
18 changes: 17 additions & 1 deletion deepspeed/runtime/config.py
@@ -114,6 +114,15 @@ def get_fp16_enabled(param_dict):
return False


def get_bfloat16_enabled(param_dict):
if BFLOAT16 in param_dict.keys():
return get_scalar_param(param_dict[BFLOAT16],
BFLOAT16_ENABLED,
BFLOAT16_ENABLED_DEFAULT)
else:
return False


def get_fp16_master_weights_and_grads_enabled(param_dict):
if get_fp16_enabled(param_dict):
return get_scalar_param(param_dict[FP16],
@@ -128,6 +137,8 @@ def get_loss_scale(param_dict):
return get_scalar_param(param_dict[FP16],
FP16_LOSS_SCALE,
FP16_LOSS_SCALE_DEFAULT)
elif get_bfloat16_enabled(param_dict):
return 1.0
else:
return FP16_LOSS_SCALE_DEFAULT

@@ -137,6 +148,8 @@ def get_initial_dynamic_scale(param_dict):
initial_scale_power = get_scalar_param(param_dict[FP16],
FP16_INITIAL_SCALE_POWER,
FP16_INITIAL_SCALE_POWER_DEFAULT)
elif get_bfloat16_enabled(param_dict):
initial_scale_power = 0
else:
initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT

@@ -791,6 +804,9 @@ def _initialize_params(self, param_dict):
self.fp16_enabled = get_fp16_enabled(param_dict)
self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(
param_dict)
self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
assert not (self.bfloat16_enabled and (self.zero_optimization_stage not in {2, 3})), f"bfloat16 mode is only enabled for ZeRO 2 and 3 currently, got {self.zero_optimization_stage}"
self.amp_enabled = get_amp_enabled(param_dict)
self.amp_params = get_amp_params(param_dict)
self.loss_scale = get_loss_scale(param_dict)
@@ -964,7 +980,7 @@ def _do_error_check(self):
assert self.zero_enabled and self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "Fp16_master_weights_and_grads is only supported with ZeRO Stage 2 for now."

def _do_warning_check(self):
fp16_enabled = self.fp16_enabled or self.zero_enabled
fp16_enabled = self.fp16_enabled

vocabulary_size = self._param_dict.get(VOCABULARY_SIZE, VOCABULARY_SIZE_DEFAULT)
if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0:
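
Editor's note: taken together, the new checks above require that fp16 and bfloat16 are not both enabled and that bfloat16 is used only with ZeRO stage 2 or 3. A minimal sketch of a consistent configuration, written as a Python dict mirroring ds_config.json; the values are illustrative only, and a later commit in this PR renames the "bfloat16" key to "bf16".

# hypothetical ds_config sketch - values are illustrative only
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3  # bfloat16 is only allowed with ZeRO stage 2 or 3 at this point
    },
    "bfloat16": {
        "enabled": True  # mutually exclusive with "fp16": {"enabled": True}
    },
}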
16 changes: 16 additions & 0 deletions deepspeed/runtime/constants.py
@@ -107,6 +107,22 @@
SPARSE_GRADIENTS = "sparse_gradients"
SPARSE_GRADIENTS_DEFAULT = False

#########################################
# BFLOAT16 support
#########################################
# BFLOAT16 feature. By default, this feature is not enabled.
# Users can configure in ds_config.json as below example:
BFLOAT16_FORMAT = '''
BFLOAT16 parameters should be of the format:
"bfloat16": {
"enabled": true
}
'''
BFLOAT16 = "bfloat16"
stas00 (Collaborator) commented on this line, Oct 12, 2021:

Re: bfloat16 - I proposed some months back that we don't create too many config entries, but instead switch to using a new dtype block, where the user can flip between bf16, fp16, and fp32.

Especially since they are mutually exclusive.

But that discussion wasn't concluded; now is a good time to bring it back.

The PR author (Contributor) replied:

That sounds reasonable. Note that this change will be published as part of a separate PR; I just made my changes on top of it, so it ended up in here. I will rebase once that PR makes it in.

A Collaborator replied:

Please note that I'm just a contributor, so my suggestions are just that - suggestions. Therefore, in order not to waste your time, please first secure agreement from the DeepSpeed team when it comes to changing APIs.

In particular this one, as it would require back-compat code to support the current config.

A Contributor commented:

@raamjad, in the context of #1398, if you could add this dtype block, perhaps as a follow-up PR, that would be great. Thanks!

A Contributor replied:

Since the current coverage of bf16 is very limited, I thought it might be better to do this refactoring of the config later, once there is higher coverage. Do you prefer that this be done now?
@stas00, can you point me to your comments where you suggested the dtype block, so I know what shape of config changes were proposed?

stas00 (Collaborator) commented, Oct 27, 2021:

One of my proposals was very simple: add a new top-level config key dtype, drop enabled from the fp16 config, and add a bf16 block:

{
    "dtype": "bf16",
    "fp16": {
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        [...]
    },    
}

This approach allows users to keep nuanced settings for each dtype in the same config, but dtype will enable one over the others, so one can easily switch the settings in only one place.

It hasn't been approved by the Deepspeed team (as in no decision has been made about it).

@tjruwase, if you have access to the Teams log from Apr-30, that is when we discussed this. But it won't show it to me - the search only shows a snippet. Search for 'dtype mockup'.


The other proposal is to have a single dtype block that takes over the fp16 block. This would be useful if many of the config options of bf16 and fp16 overlap. So:

{
    "dtype": {
        "enabled": "fp16",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
}

and for bf16:

{
    "dtype": {
        "enabled": "bf16",
        [...]
    },
}

stas00 (Collaborator) added, Oct 27, 2021:

As a maintainer of the DeepSpeed integration in HF Transformers, I find it easiest when users can download ready-to-use config files; in my experience, having all the config sections predefined in the file makes things easier for the user. So the first approach would be preferable for that particular use case.

But of course there can be other ways...

A Contributor replied:

@stas00, thanks for reviving your awesome proposals. Sorry, mere mortals such as we can barely keep up with your genius :).

@stas00, @raamjad, is it okay if we continue this chat on #1398? I will add a link to this thread and also post the referenced Teams chat. Thanks.

stas00 (Collaborator) replied:

You're too kind, @tjruwase - it's easy to come up with ideas, it's far from easy to make them a reality. So that's where your genius comes in ;)

Thank you for finding that old discussion and repasting it here, as MSFT Teams won't let me access it.


BFLOAT16_ENABLED = "enabled"
BFLOAT16_ENABLED_DEFAULT = False

#########################################
# FP16 support
#########################################
50 changes: 29 additions & 21 deletions deepspeed/runtime/engine.py
@@ -43,7 +43,7 @@
import deepspeed.runtime.lr_schedules as lr_schedules
import deepspeed.utils.groups as groups
from deepspeed.runtime.utils import get_grad_norm
from deepspeed.utils import logger, log_dist, init_distributed
from deepspeed.utils import logger, log_dist, init_distributed, instrument_w_nvtx
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.utils.debug import debug_extract_module_and_param_names
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
@@ -82,6 +82,7 @@ def split_half_float_double_csr(tensors):
"torch.cuda.HalfTensor",
"torch.cuda.FloatTensor",
"torch.cuda.DoubleTensor",
"torch.cuda.BFloat16Tensor",
CSRTensor.type()
]

@@ -527,6 +528,9 @@ def fp16_enabled(self):
def fp16_master_weights_and_gradients(self):
return self._config.fp16_master_weights_and_gradients

def bfloat16_enabled(self):
return self._config.bfloat16_enabled

def amp_enabled(self):
return self._config.amp_enabled

@@ -740,32 +744,32 @@ def is_replicated(p):
self.broadcast_src_rank,
group=self.data_parallel_group)

@staticmethod
def __check_params(model: Module, dtype: torch.dtype) -> None:
if not all(param.dtype == dtype
for param in model.parameters()) and dist.get_rank() == 0:
raise ValueError(
f"{dtype} is enabled but the following parameters have dtype that is "
f"not {dtype}: "
f"{[(n, p.dtype) for n, p in model.named_parameters() if p.dtype != dtype]}"
)

def _configure_distributed_model(self, model):
self.module = model
if self.fp16_enabled():
if self.zero_optimization_partition_weights() and any(
[hasattr(param,
'ds_id') for param in self.module.parameters()]):
if not all(
[param.dtype == torch.half for param in self.module.parameters()]):
names = [
n for n,
p in self.module.named_parameters() if p.dtype != torch.half
]
raise ValueError(
f"fp16 is enabled but the following parameters have dtype that is not fp16: {', '.join(names)}"
)
self.__check_params(self.module, torch.half)
self.module.half()
elif self.bfloat16_enabled():
if self.zero_optimization_partition_weights() and any(
hasattr(param,
'ds_id') for param in self.module.parameters()):
self.__check_params(self.module, torch.bfloat16)
self.module.bfloat16()
else:
if not all(
[param.dtype == torch.float for param in self.module.parameters()]):
names = [
n for n,
p in self.module.named_parameters() if p.dtype != torch.float
]
raise ValueError(
f"fp32 is enabled but the following parameters have dtype that is not fp32: {', '.join(names)}"
)
self.__check_params(self.module, torch.float)

if not self.dont_change_device:
self.module.to(self.device)
@@ -882,7 +886,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
)
self.optimizer = self._configure_zero_optimizer(basic_optimizer)
elif self.amp_enabled():
assert not self.fp16_enabled(), "Cannot enable both amp with (legacy) fp16 mode"
assert not (self.fp16_enabled() or self.bfloat16_enabled()), "Cannot enable amp with either (legacy) fp16 or bfloat16 mode"
amp_params = self.amp_params()
if self.global_rank == 0:
logger.info(f"Initializing AMP with these params: {amp_params}")
@@ -1278,6 +1282,7 @@ def _scale_loss_by_gas(self, prescaled_loss):

return scaled_loss

@instrument_w_nvtx
def forward(self, *inputs, **kwargs):
r"""Execute forward propagation

Expand Down Expand Up @@ -1318,7 +1323,8 @@ def forward(self, *inputs, **kwargs):
if self.training_dataloader is None:
self.tput_timer.start()

loss = self.module(*inputs, **kwargs)
with torch.cuda.nvtx.range("DeepspeedEngine.forward::module_forward"):
loss = self.module(*inputs, **kwargs)

if self.zero_optimization_partition_weights():
# Reset the ZeRO-3 state if we are only doing forward-passes (ie evaluation).
@@ -1347,6 +1353,7 @@

return loss

@instrument_w_nvtx
def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
# Pass (PP) gas boundary flag to optimizer (required for zero)
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
@@ -1364,6 +1371,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
else:
self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

@instrument_w_nvtx
def backward(self, loss, allreduce_gradients=True, release_loss=False):
r"""Execute backward pass on the loss

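
Editor's note: the instrument_w_nvtx decorator applied throughout this file is imported from deepspeed.utils; conceptually it records an NVTX range around the wrapped call so that forward, backward, and allreduce phases show up in Nsight traces. A rough sketch of the idea, not the actual DeepSpeed implementation:

# rough sketch only - the real decorator lives in deepspeed.utils
import functools
import torch


def instrument_w_nvtx_sketch(func):
    """Record an NVTX range around each call so it appears in profiler timelines."""
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        torch.cuda.nvtx.range_push(func.__qualname__)
        try:
            return func(*args, **kwargs)
        finally:
            torch.cuda.nvtx.range_pop()
    return wrapped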
9 changes: 9 additions & 0 deletions deepspeed/runtime/utils.py
@@ -840,3 +840,12 @@ def call_to_str(base, *args, **kwargs):
name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
name += ')'
return name


def get_only_unique_item(items):
item_set = set(items)
if len(item_set) != 1:
raise RuntimeError(f"expected there to be only one unique element in {items}")
unique_item, = item_set

return unique_item
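
Editor's note: a hypothetical usage sketch for the new helper; it is convenient for asserting that a collection is homogeneous, for example that a bucket of gradient shards shares a single dtype.

# hypothetical example - raises RuntimeError if the dtypes differ
import torch

grads = [torch.zeros(8, dtype=torch.bfloat16), torch.zeros(4, dtype=torch.bfloat16)]
bucket_dtype = get_only_unique_item(g.dtype for g in grads)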