diff --git a/mmengine/_strategy/colossalai.py b/mmengine/_strategy/colossalai.py
index ffeb591dcf..1bdc96c794 100644
--- a/mmengine/_strategy/colossalai.py
+++ b/mmengine/_strategy/colossalai.py
@@ -94,20 +94,25 @@ class ColossalAIOptimWrapper(OptimWrapper):
     def __init__(self,
                  optimizer: torch.optim.Optimizer,
-                 booster: Booster,
+                 booster: Optional[Booster] = None,
                  accumulative_counts: int = 1):
         super().__init__(optimizer, accumulative_counts=accumulative_counts)
         self.booster = booster
 
     @contextmanager
     def optim_context(self, model: nn.Module):
+        assert isinstance(self.booster, Booster), \
+            'Please set the booster attribute before using ' \
+            '`ColossalAIOptimWrapper`.'
         if self.booster.plugin.support_no_sync():
-            sync_context = self.booster.no_sync(model, self.optimizer)
+            no_sync_context = self.booster.no_sync(model, self.optimizer)
         else:
             yield
             return
 
-        if not self.should_sync():
-            with sync_context:
+        if self.should_sync():
+            yield
+        else:
+            with no_sync_context:
                 yield
 
     def backward(self, loss: torch.Tensor, **kwargs) -> None:
@@ -305,7 +310,6 @@ def prepare(
         # optim_wrapper is required by booster
         if optim_wrapper is not None and isinstance(optim_wrapper, dict):
             optim_wrapper.setdefault('type', 'ColossalAIOptimWrapper')
-            optim_wrapper.setdefault('booster', self.booster)
             optim_wrapper_type = OPTIM_WRAPPERS.get(optim_wrapper['type'])
             if optim_wrapper_type is None:
                 raise ValueError(f'Failed to find {optim_wrapper["type"]} in '
@@ -318,6 +322,7 @@
                     '`ColossalAIOptimWrapper` (or subclass), but got '
                     f'{optim_wrapper_type}')
             optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
+            optim_wrapper.booster = self.booster  # type: ignore
 
         if optim_wrapper is not None:
             self.model, self.optim_wrapper = self._wrap(
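
Beyond making `booster` optional, the first hunk fixes a latent bug in `optim_context`: when the plugin supported `no_sync` but the current step *should* synchronize, the old generator fell through without ever yielding, so entering the context manager raised `RuntimeError: generator didn't yield`. Below is a minimal, self-contained sketch of the before/after control flow; `FakeBooster` and the `should_sync` flag are illustrative stand-ins, not the real colossalai/mmengine API.

from contextlib import contextmanager


class FakeBooster:
    """Illustrative stand-in for colossalai's Booster; only the two
    members touched by optim_context are mimicked."""

    class plugin:  # the real attribute is a plugin instance

        @staticmethod
        def support_no_sync():
            return True

    @contextmanager
    def no_sync(self, model, optimizer):
        yield  # the real context manager suppresses gradient sync


@contextmanager
def old_optim_context(booster, should_sync):
    # Pre-fix control flow: on a sync step nothing yields, so entering
    # the context raises "RuntimeError: generator didn't yield".
    if booster.plugin.support_no_sync():
        sync_context = booster.no_sync(None, None)
    else:
        yield
        return
    if not should_sync:
        with sync_context:
            yield


@contextmanager
def new_optim_context(booster, should_sync):
    # Post-fix control flow: every branch yields exactly once.
    if booster.plugin.support_no_sync():
        no_sync_context = booster.no_sync(None, None)
    else:
        yield
        return
    if should_sync:
        yield
    else:
        with no_sync_context:
            yield


booster = FakeBooster()
with new_optim_context(booster, should_sync=True):
    pass  # ok: the sync step now yields

try:
    with old_optim_context(booster, should_sync=True):
        pass
except RuntimeError as exc:
    print(exc)  # -> generator didn't yield

Running the sketch prints `generator didn't yield` for the old path, while the new path completes both the sync and no-sync cases.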
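
The two `prepare` hunks are the other half of the change: instead of injecting a live `Booster` instance into the optim-wrapper config dict via `setdefault('booster', self.booster)`, the strategy now builds the wrapper from a plain config and assigns `optim_wrapper.booster` afterwards, which is exactly the window the new assert guards against. A hedged sketch of that fail-fast behaviour, using hypothetical stand-in classes rather than mmengine's real ones:

from contextlib import contextmanager


class Booster:
    """Hypothetical stand-in for colossalai.booster.Booster."""


class WrapperSketch:
    """Hypothetical stand-in reproducing only the booster check."""

    def __init__(self, booster=None):
        self.booster = booster  # may stay None until the strategy's prepare()

    @contextmanager
    def optim_context(self, model):
        # Same guard the diff adds to ColossalAIOptimWrapper.optim_context:
        assert isinstance(self.booster, Booster), \
            'Please set the booster attribute before using ' \
            '`ColossalAIOptimWrapper`.'
        yield


wrapper = WrapperSketch()
try:
    with wrapper.optim_context(model=None):
        pass
except AssertionError as exc:
    print(exc)  # used before the strategy attached the booster

wrapper.booster = Booster()  # what prepare() now does after build_optim_wrapper
with wrapper.optim_context(model=None):
    pass  # fine once the booster is attached

The practical upshot is that an `optim_wrapper` config can remain a plain dict with no reference to a runtime object, and using the wrapper before `prepare` has run fails with a clear message instead of an `AttributeError` on `None`.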