diff --git a/ignite/handlers/state_param_scheduler.py b/ignite/handlers/state_param_scheduler.py index 18790c3c055..d76766e06f9 100644 --- a/ignite/handlers/state_param_scheduler.py +++ b/ignite/handlers/state_param_scheduler.py @@ -460,21 +460,48 @@ class MultiStepStateScheduler(StateParamScheduler): Examples: - .. code-block:: python + .. testsetup:: - ... - engine = Engine(train_step) + default_trainer = get_default_trainer() + + .. testcode:: param_scheduler = MultiStepStateScheduler( - param_name="param", initial_value=10, gamma=0.99, milestones=[3, 6], + param_name="param", initial_value=1, gamma=0.9, milestones=[3, 6, 9, 12] ) - param_scheduler.attach(engine, Events.EPOCH_COMPLETED) + # parameter is param, initial_value sets param to 1, gamma is set as 0.9 + # Epoch 1 to 2, param does not change as milestone is 3 + # Epoch 3, param changes from 1 to 1*0.9, param = 0.9 + # Epoch 3 to 5, param does not change as milestone is 6 + # Epoch 6, param changes from 0.9 to 0.9*0.9, param = 0.81 + # Epoch 6 to 8, param does not change as milestone is 9 + # Epoch 9, param changes from 0.81 to 0.81*0.9, param = 0.729 + # Epoch 9 to 11, param does not change as milestone is 12 + # Epoch 12, param changes from 0.729 to 0.729*0.9, param = 0.6561 + + param_scheduler.attach(default_trainer, Events.EPOCH_COMPLETED) + + @default_trainer.on(Events.EPOCH_COMPLETED) + def print_param(): + print(default_trainer.state.param) + + default_trainer.run([0], max_epochs=12) - # basic handler to print scheduled state parameter - engine.add_event_handler(Events.EPOCH_COMPLETED, lambda _ : print(engine.state.param)) + .. testoutput:: - engine.run([0] * 8, max_epochs=10) + 1.0 + 1.0 + 0.9 + 0.9 + 0.9 + 0.81 + 0.81 + 0.81 + 0.7290... + 0.7290... + 0.7290... + 0.6561 .. versionadded:: 0.5.0