diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
index aaf8c31a1eb..f1b450cd5f1 100644
--- a/mmdet/models/detectors/base.py
+++ b/mmdet/models/detectors/base.py
@@ -198,6 +198,16 @@ def _parse_losses(self, losses):
         loss = sum(_value for _key, _value in log_vars.items()
                    if 'loss' in _key)
 
+        # If log_vars has different lengths, GPUs will wait indefinitely
+        if dist.is_available() and dist.is_initialized():
+            log_var_length = torch.tensor(len(log_vars), device=loss.device)
+            dist.all_reduce(log_var_length)
+            message = (f'rank {dist.get_rank()}' +
+                       f' len(log_vars): {len(log_vars)}' + ' keys: ' +
+                       ','.join(log_vars.keys()))
+            assert log_var_length == len(log_vars) * dist.get_world_size(), \
+                'loss log variables are different across GPUs!\n' + message
+
         log_vars['loss'] = loss
         for loss_name, loss_value in log_vars.items():
             # reduce loss when distributed training