diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f9ae66bb0..5b3fda00c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `force` and `bert-level` keywords to `catalyst-data text2embedding` ([#917](https://github.com/catalyst-team/catalyst/pull/917)) - `OptunaCallback` to `catalyst.contrib` ([#915](https://github.com/catalyst-team/catalyst/pull/915)) +- Multi-scheduler support for multi-optimizer case ([#923](https://github.com/catalyst-team/catalyst/pull/923)) ### Changed diff --git a/catalyst/dl/experiment/config.py b/catalyst/dl/experiment/config.py index bf229c559c..caf57838f7 100644 --- a/catalyst/dl/experiment/config.py +++ b/catalyst/dl/experiment/config.py @@ -343,31 +343,31 @@ def get_optimizer( return optimizer @staticmethod - def _get_scheduler(*, optimizer, **params): + def _get_scheduler( + *, optimizer: Union[Optimizer, Dict[str, Optimizer]], **params: Any + ) -> Union[Scheduler, Dict[str, Scheduler]]: + optimizer_key = params.pop("_optimizer", None) + optimizer_ = optimizer[optimizer_key] if optimizer_key else optimizer + scheduler = SCHEDULERS.get_from_params(**params, optimizer=optimizer_) + + return scheduler + + def get_scheduler( + self, stage: str, optimizer: Union[Optimizer, Dict[str, Optimizer]] + ) -> Union[Scheduler, Dict[str, Scheduler]]: + """Returns the scheduler for a given stage.""" + params = self.stages_config[stage].get("scheduler_params", {}) key_value_flag = params.pop("_key_value", False) if key_value_flag: - scheduler = {} - for scheduler_key, scheduler_params in params.items(): - scheduler[ - scheduler_key - ] = ConfigExperiment._get_scheduler( # noqa: WPS437 + scheduler: Dict[str, Scheduler] = {} + for key, scheduler_params in params.items(): + scheduler[key] = self._get_scheduler( optimizer=optimizer, **scheduler_params ) else: - scheduler = SCHEDULERS.get_from_params( - **params, optimizer=optimizer - ) - return scheduler + scheduler = self._get_scheduler(optimizer=optimizer, **params) - def get_scheduler(self, stage: str, optimizer: Optimizer) -> Scheduler: - """Returns the scheduler for a given stage.""" - scheduler_params = self.stages_config[stage].get( - "scheduler_params", {} - ) - scheduler = self._get_scheduler( - optimizer=optimizer, **scheduler_params - ) return scheduler @staticmethod