From 46cddf875daf8e79467b05ca3f1ccc227fccd451 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 11 Aug 2022 21:42:37 +0000 Subject: [PATCH 1/7] Added test_engine_run_resume --- tests/ignite/engine/test_engine.py | 63 ++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 3bcdb522350..d945ab7cf9e 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -1014,3 +1014,66 @@ def train_step(engine, batch): assert trainer.state.iteration == 20 * 10 assert trainer.state.epoch == 20 assert trainer.state.dataloader is None + + +@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)]) +def test_engine_run_resume(data, epoch_length): + # https://github.com/pytorch/ignite/wiki/Roadmap#runresume-logic-improvements + engine = Engine(lambda e, b: None) + real_epoch_length = len(data) if data is not None else epoch_length + + first_epoch_iter = [None, None] + + @engine.on(Events.STARTED, first_epoch_iter) + def check_iter_epoch(first_epoch_iter): + assert engine.state.epoch == first_epoch_iter[0] + assert engine.state.iteration == first_epoch_iter[1] + + # (re)start from 0 to 5 + first_epoch_iter[0], first_epoch_iter[1] = 0, 0 + # Engine run starting with max_epochs=5 => state.epoch=5 + engine.run(data, max_epochs=5, epoch_length=epoch_length) + assert engine.state.epoch == 5 + assert engine.state.iteration == 5 * real_epoch_length + + # continue from 5 to 7 + first_epoch_iter[0], first_epoch_iter[1] = 5, 5 * real_epoch_length + # Engine run resuming from iteration 50, epoch 5 until 7 epochs => state.epoch=7 + engine.run(data, max_epochs=7, epoch_length=epoch_length) + assert engine.state.epoch == 7 + assert engine.state.iteration == 7 * real_epoch_length + + # error + with pytest.raises(ValueError, match="Argument max_epochs should be larger than the start epoch"): + engine.run(data, max_epochs=4, epoch_length=epoch_length) + + # restart from 0 to 7 (As state.epoch == max_epochs(=7), + # this should be like that as we always do: evaluator.run(data) without any other instructions) + first_epoch_iter[0], first_epoch_iter[1] = 0, 0 + # Engine run starting with max_epochs=7 => state.epoch=7 + engine.run(data, max_epochs=7, epoch_length=epoch_length) + assert engine.state.epoch == 7 + assert engine.state.iteration == 7 * real_epoch_length + + # forced restart from 0 to 5 + engine.state.max_epochs = None + first_epoch_iter[0], first_epoch_iter[1] = 0, 0 + # Engine run starting with max_epochs=5 => state.epoch=5 + engine.run(data, max_epochs=5, epoch_length=epoch_length) + assert engine.state.epoch == 5 + assert engine.state.iteration == 5 * real_epoch_length + + # forced restart from 0 to 9, instead of continue from state.epoch=5 + engine.state.max_epochs = None + first_epoch_iter[0], first_epoch_iter[1] = 0, 0 + # Engine run starting with max_epochs=9 => state.epoch=9 + engine.run(data, max_epochs=9, epoch_length=epoch_length) + assert engine.state.epoch == 9 + assert engine.state.iteration == 9 * real_epoch_length + + # continue from 9 until 10 + first_epoch_iter[0], first_epoch_iter[1] = 9, 9 * real_epoch_length + # Engine run resuming from iteration 90, epoch 9 until 10 epochs => state.epoch=10 + engine.run(data, max_epochs=10, epoch_length=epoch_length) + assert engine.state.epoch == 10 + assert engine.state.iteration == 10 * real_epoch_length From f027a12db0f2b9fd6540d13c5e79269406141e3e Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 11 Aug 2022 23:55:40 +0000 Subject: [PATCH 2/7] Terminate/Terminate Single Epoch work on all EPOCH/ITERATION events --- ignite/engine/engine.py | 32 +++++++---- tests/ignite/engine/test_engine.py | 87 ++++++++++++++++++++++++++---- 2 files changed, 97 insertions(+), 22 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 86e90763838..a6aaa7fba0b 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -757,31 +757,39 @@ def _internal_run(self) -> State: try: start_time = time.time() self._fire_event(Events.STARTED) + if self.should_terminate: + self._fire_event(Events.TERMINATE) + while not self._is_done(self.state) and not self.should_terminate: self.state.epoch += 1 + handlers_start_time = time.time() self._fire_event(Events.EPOCH_STARTED) + epoch_time_taken = time.time() - handlers_start_time + + if not self.should_terminate: + if self._dataloader_iter is None: + self._setup_engine() - if self._dataloader_iter is None: - self._setup_engine() + epoch_time_taken += self._run_once_on_dataset() + + # time is available for handlers but must be updated after fire + self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken - time_taken = self._run_once_on_dataset() - # time is available for handlers but must be update after fire - self.state.times[Events.EPOCH_COMPLETED.name] = time_taken handlers_start_time = time.time() if self.should_terminate: self._fire_event(Events.TERMINATE) else: self._fire_event(Events.EPOCH_COMPLETED) - time_taken += time.time() - handlers_start_time + epoch_time_taken += time.time() - handlers_start_time # update time wrt handlers - self.state.times[Events.EPOCH_COMPLETED.name] = time_taken - hours, mins, secs = _to_hours_mins_secs(time_taken) + self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken + hours, mins, secs = _to_hours_mins_secs(epoch_time_taken) self.logger.info(f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:02d}") if self.should_terminate: break time_taken = time.time() - start_time - # time is available for handlers but must be update after fire + # time is available for handlers but must be updated after fire self.state.times[Events.COMPLETED.name] = time_taken handlers_start_time = time.time() self._fire_event(Events.COMPLETED) @@ -856,8 +864,10 @@ def _run_once_on_dataset(self) -> float: self.state.iteration += 1 self._fire_event(Events.ITERATION_STARTED) - self.state.output = self._process_function(self, self.state.batch) - self._fire_event(Events.ITERATION_COMPLETED) + + if not (self.should_terminate or self.should_terminate_single_epoch): + self.state.output = self._process_function(self, self.state.batch) + self._fire_event(Events.ITERATION_COMPLETED) if self.should_terminate or self.should_terminate_single_epoch: self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter) diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index d945ab7cf9e..b82ccbe95d4 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -103,11 +103,11 @@ def end_of_epoch_handler(engine): assert engine.should_terminate -@pytest.mark.parametrize("data", [None, [1, 2, 3]]) -def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration(data): +@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)]) +def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration(data, epoch_length): max_epochs = 5 epoch_to_terminate_on = 3 - epoch_length = 3 + real_epoch_length = epoch_length if data is None else len(data) engine = Engine(MagicMock(return_value=1)) @@ -124,14 +124,47 @@ def start_of_epoch_handler(engine): # epoch is not completed so counter is not incremented assert state.epoch == epoch_to_terminate_on assert engine.should_terminate - # completes first iteration - assert state.iteration == ((epoch_to_terminate_on - 1) * epoch_length) + 1 + assert state.iteration == ((epoch_to_terminate_on - 1) * real_epoch_length) + # Engine continue from epoch_to_terminate_on until max_epochs + first_epoch_iter = [None, None] -@pytest.mark.parametrize("data", [None, list(range(10))]) -def test_terminate_stops_run_mid_epoch(data): - num_iterations_per_epoch = len(data) if data is not None else 10 - iteration_to_stop = num_iterations_per_epoch + 3 + @engine.on(Events.STARTED) + def check_iter_epoch(): + assert engine.state.epoch == first_epoch_iter[0] + assert engine.state.iteration == first_epoch_iter[1] + + if data is not None: + expected_data_iter = iter(data) + expected_iter = state.iteration + + @engine.on(Events.ITERATION_STARTED) + def check_iter_and_data(): + nonlocal expected_data_iter, expected_iter + + expected_iter += 1 + assert engine.state.iteration == expected_iter + + try: + assert engine.state.batch == next(expected_data_iter) + except StopIteration: + expected_data_iter = iter(data) + assert engine.state.batch == next(expected_data_iter) + + first_epoch_iter[0], first_epoch_iter[1] = state.epoch, state.iteration + state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length) + + assert state.epoch == max_epochs + assert not engine.should_terminate + # As terminated epoch is skipped -> iterations are not incremented + assert state.iteration == real_epoch_length * (max_epochs - 1) + + +@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)]) +def test_terminate_stops_run_mid_epoch(data, epoch_length): + max_epochs = 5 + iteration_to_stop = 13 + real_epoch_length = epoch_length if data is None else len(data) engine = Engine(MagicMock(return_value=1)) @@ -140,10 +173,42 @@ def start_of_iteration_handler(engine): engine.terminate() engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler) - state = engine.run(data, max_epochs=3, epoch_length=num_iterations_per_epoch) + state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length) # completes the iteration but doesn't increment counter (this happens just before a new iteration starts) assert state.iteration == iteration_to_stop - assert state.epoch == np.ceil(iteration_to_stop / num_iterations_per_epoch) # it starts from 0 + assert state.epoch == np.ceil(iteration_to_stop / real_epoch_length) # it starts from 0 + + # Engine continue from epoch_to_terminate_on until max_epochs + first_epoch_iter = [None, None] + + @engine.on(Events.STARTED, first_epoch_iter) + def check_iter_epoch(first_epoch_iter): + assert engine.state.epoch == first_epoch_iter[0] + assert engine.state.iteration == first_epoch_iter[1] + + if data is not None: + expected_data_iter = iter(data) + expected_iter = state.iteration + + @engine.on(Events.ITERATION_STARTED) + def check_iter_and_data(): + nonlocal expected_data_iter, expected_iter + + expected_iter += 1 + assert engine.state.iteration == expected_iter + + try: + assert engine.state.batch == next(expected_data_iter) + except StopIteration: + expected_data_iter = iter(data) + assert engine.state.batch == next(expected_data_iter) + + first_epoch_iter[0], first_epoch_iter[1] = state.epoch, state.iteration + state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length) + + assert state.epoch == max_epochs + assert not engine.should_terminate + assert state.iteration == real_epoch_length * (max_epochs - 1) @pytest.mark.parametrize("data", [None, list(range(10))]) From 8c36bfaae7faeaf6be7709adb0be0ced10596a71 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 15 Aug 2022 23:04:14 +0000 Subject: [PATCH 3/7] - terminate() work on all events, called on catched _EngineTerminateException - terminate_epoch work on iteration-based events, called on catched _EngineTerminateSingleEpochExpection - Fixed issue when attaching handlers on Events.TERMINATE_SINGLE_EPOCH --- ignite/engine/engine.py | 105 ++++++++++++++++++--------- tests/ignite/engine/test_engine.py | 109 +++++++++++++++++++++++++---- 2 files changed, 165 insertions(+), 49 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index a6aaa7fba0b..1cf68b375dc 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -301,7 +301,12 @@ def execute_something(): self._assert_allowed_event(event_name) - event_args = (Exception(),) if event_name == Events.EXCEPTION_RAISED else () + event_args = () # type: Tuple[Any, ...] + if event_name == Events.EXCEPTION_RAISED: + event_args += (Exception(),) + elif event_name == Events.TERMINATE_SINGLE_EPOCH: + event_args += (0,) + try: _check_signature(handler, "handler", self, *(event_args + args), **kwargs) self._event_handlers[event_name].append((handler, (self,) + args, kwargs)) @@ -755,38 +760,40 @@ def _internal_run(self) -> State: self.should_terminate = self.should_terminate_single_epoch = False self._init_timers(self.state) try: - start_time = time.time() - self._fire_event(Events.STARTED) - if self.should_terminate: - self._fire_event(Events.TERMINATE) + try: + start_time = time.time() + self._fire_event(Events.STARTED) + self._maybe_terminate() + + while not self._is_done(self.state) and not self.should_terminate: + self.state.epoch += 1 + handlers_start_time = time.time() + self._fire_event(Events.EPOCH_STARTED) + epoch_time_taken = time.time() - handlers_start_time + self._maybe_terminate() - while not self._is_done(self.state) and not self.should_terminate: - self.state.epoch += 1 - handlers_start_time = time.time() - self._fire_event(Events.EPOCH_STARTED) - epoch_time_taken = time.time() - handlers_start_time - - if not self.should_terminate: if self._dataloader_iter is None: self._setup_engine() epoch_time_taken += self._run_once_on_dataset() - # time is available for handlers but must be updated after fire - self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken + # time is available for handlers but must be updated after fire + self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken - handlers_start_time = time.time() - if self.should_terminate: - self._fire_event(Events.TERMINATE) - else: + handlers_start_time = time.time() self._fire_event(Events.EPOCH_COMPLETED) - epoch_time_taken += time.time() - handlers_start_time - # update time wrt handlers - self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken - hours, mins, secs = _to_hours_mins_secs(epoch_time_taken) - self.logger.info(f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:02d}") - if self.should_terminate: - break + epoch_time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken + self._maybe_terminate() + + hours, mins, secs = _to_hours_mins_secs(epoch_time_taken) + self.logger.info( + f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:02d}" + ) + + except _EngineTerminateException: + self._fire_event(Events.TERMINATE) time_taken = time.time() - start_time # time is available for handlers but must be updated after fire @@ -807,6 +814,13 @@ def _internal_run(self) -> State: self._dataloader_iter = None return self.state + def _maybe_terminate(self) -> None: + if self.should_terminate: + raise _EngineTerminateException() + + if self.should_terminate_single_epoch: + raise _EngineTerminateSingleEpochException() + def _run_once_on_dataset(self) -> float: start_time = time.time() @@ -826,8 +840,12 @@ def _run_once_on_dataset(self) -> float: # Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted if self.last_event_name != Events.DATALOADER_STOP_ITERATION: self._fire_event(Events.GET_BATCH_STARTED) + self._maybe_terminate() + self.state.batch = next(self._dataloader_iter) self._fire_event(Events.GET_BATCH_COMPLETED) + self._maybe_terminate() + iter_counter += 1 should_exit = False except StopIteration: @@ -856,24 +874,20 @@ def _run_once_on_dataset(self) -> float: break self._fire_event(Events.DATALOADER_STOP_ITERATION) - self._setup_dataloader_iter() + self._maybe_terminate() + self._setup_dataloader_iter() should_exit = True continue self.state.iteration += 1 self._fire_event(Events.ITERATION_STARTED) + self._maybe_terminate() - if not (self.should_terminate or self.should_terminate_single_epoch): - self.state.output = self._process_function(self, self.state.batch) - self._fire_event(Events.ITERATION_COMPLETED) - - if self.should_terminate or self.should_terminate_single_epoch: - self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter) - self.should_terminate_single_epoch = False - self._setup_dataloader_iter() - break + self.state.output = self._process_function(self, self.state.batch) + self._fire_event(Events.ITERATION_COMPLETED) + self._maybe_terminate() if self.state.epoch_length is not None and iter_counter == self.state.epoch_length: break @@ -882,6 +896,11 @@ def _run_once_on_dataset(self) -> float: self.should_terminate = True break + except _EngineTerminateSingleEpochException: + self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter) + self.should_terminate_single_epoch = False + self._setup_dataloader_iter() + except Exception as e: self.logger.error(f"Current run is terminating due to exception: {e}") self._handle_exception(e) @@ -893,3 +912,19 @@ def _get_none_data_iter(size: int) -> Iterator: # Sized iterator for data as None for _ in range(size): yield None + + +class _EngineTerminateSingleEpochException(Exception): + """ + Exception associated with Terminate Single Epoch event + """ + + pass + + +class _EngineTerminateException(Exception): + """ + Exception associated with Terminate event + """ + + pass diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index b82ccbe95d4..18058c14621 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -187,21 +187,15 @@ def check_iter_epoch(first_epoch_iter): assert engine.state.iteration == first_epoch_iter[1] if data is not None: - expected_data_iter = iter(data) expected_iter = state.iteration @engine.on(Events.ITERATION_STARTED) def check_iter_and_data(): - nonlocal expected_data_iter, expected_iter + nonlocal expected_iter expected_iter += 1 assert engine.state.iteration == expected_iter - - try: - assert engine.state.batch == next(expected_data_iter) - except StopIteration: - expected_data_iter = iter(data) - assert engine.state.batch == next(expected_data_iter) + assert engine.state.batch == data[(expected_iter - first_epoch_iter[1] - 1) % len(data)] first_epoch_iter[0], first_epoch_iter[1] = state.epoch, state.iteration state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length) @@ -211,10 +205,58 @@ def check_iter_and_data(): assert state.iteration == real_epoch_length * (max_epochs - 1) -@pytest.mark.parametrize("data", [None, list(range(10))]) -def test_terminate_epoch_stops_mid_epoch(data): - num_iterations_per_epoch = len(data) if data is not None else 10 - iteration_to_stop = num_iterations_per_epoch + 4 +class RecordedEngine(Engine): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.called_events = [] + + def _fire_event(self, event_name, *event_args, **event_kwargs): + self.called_events.append((self.state.epoch, self.state.iteration, event_name.name)) + return super()._fire_event(event_name, *event_args, **event_kwargs) + + +@pytest.mark.parametrize( + "terminate_event, e, i", + [ + (Events.STARTED, 0, 0), + (Events.EPOCH_STARTED(once=2), 2, None), + (Events.EPOCH_COMPLETED(once=2), 2, None), + (Events.GET_BATCH_STARTED(once=12), None, 12), + (Events.GET_BATCH_COMPLETED(once=12), None, 12), + (Events.ITERATION_STARTED(once=14), None, 14), + (Events.ITERATION_COMPLETED(once=14), None, 14), + ], +) +def test_terminate_events_sequence(terminate_event, e, i): + engine = RecordedEngine(MagicMock(return_value=1)) + data = range(10) + max_epochs = 5 + + @engine.on(terminate_event) + def call_terminate(): + engine.terminate() + + engine.run(data, max_epochs=max_epochs) + + if i is None: + if terminate_event == Events.EPOCH_STARTED: + i = len(data) * (e - 1) + else: + i = len(data) * e + + if e is None: + e = i // len(data) + 1 + + assert engine.called_events[0] == (0, 0, Events.STARTED) + assert engine.called_events[-1] == (e, i, Events.COMPLETED) + assert engine.called_events[-2] == (e, i, Events.TERMINATE) + assert engine.called_events[-3] == (e, i, terminate_event) + + +@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)]) +def test_terminate_epoch_stops_mid_epoch(data, epoch_length): + real_epoch_length = epoch_length if data is None else len(data) + iteration_to_stop = real_epoch_length + 4 engine = Engine(MagicMock(return_value=1)) @@ -224,10 +266,49 @@ def start_of_iteration_handler(engine): max_epochs = 3 engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler) - state = engine.run(data, max_epochs=max_epochs, epoch_length=num_iterations_per_epoch) + state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length) # completes the iteration but doesn't increment counter (this happens just before a new iteration starts) - true_value = num_iterations_per_epoch * (max_epochs - 1) + iteration_to_stop % num_iterations_per_epoch + true_value = real_epoch_length * (max_epochs - 1) + iteration_to_stop % real_epoch_length assert state.iteration == true_value + assert state.epoch == max_epochs + + +@pytest.mark.parametrize( + "terminate_epoch_event, i", + [ + (Events.GET_BATCH_STARTED(once=12), 12), + (Events.GET_BATCH_COMPLETED(once=12), 12), + (Events.ITERATION_STARTED(once=14), 14), + (Events.ITERATION_COMPLETED(once=14), 14), + ], +) +def test_terminate_epoch_events_sequence(terminate_epoch_event, i): + engine = RecordedEngine(MagicMock(return_value=1)) + data = range(10) + max_epochs = 5 + + # TODO: Bug: Events.GET_BATCH_STARTED(once=12) is called twice ! + # prevent call_terminate_epoch to be called twice + call_count = 0 + + @engine.on(terminate_epoch_event) + def call_terminate_epoch(): + nonlocal call_count + if call_count < 1: + engine.terminate_epoch() + call_count += 1 + + @engine.on(Events.TERMINATE_SINGLE_EPOCH) + def check_previous_events(iter_counter): + e = i // len(data) + 1 + + print("engine.called_events:", engine.called_events) + + assert engine.called_events[0] == (0, 0, Events.STARTED) + assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH) + assert engine.called_events[-2] == (e, i, terminate_epoch_event) + + engine.run(data, max_epochs=max_epochs) def _create_mock_data_loader(epochs, batches_per_epoch): From 4dc0a3604c3f76ddd9ffa83305110e71285edba7 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 15 Aug 2022 23:11:34 +0000 Subject: [PATCH 4/7] Updated docstring --- ignite/engine/events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/engine/events.py b/ignite/engine/events.py index a8292921579..d71c1bd4501 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -299,7 +299,7 @@ class CustomEvents(EventEnum): """triggered when the run is about to end completely, after receiving terminate() call.""" TERMINATE_SINGLE_EPOCH = "terminate_single_epoch" """triggered when the run is about to end the current epoch, - after receiving a terminate_epoch() or terminate() call.""" + after receiving a terminate_epoch() call.""" def __or__(self, other: Any) -> "EventsList": return EventsList() | self | other From 24ab65fe839e884a2b2bff6500cdbebed78c2e26 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 16 Aug 2022 06:23:13 +0000 Subject: [PATCH 5/7] Fixed issue with max_iters handling --- ignite/engine/engine.py | 2 +- tests/ignite/engine/test_engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 31fe6a7271f..f67e210694e 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -881,7 +881,7 @@ def _run_once_on_dataset(self) -> float: if self.state.max_iters is not None and self.state.iteration == self.state.max_iters: self.should_terminate = True - break + raise _EngineTerminateException() except _EngineTerminateSingleEpochException: self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter) diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 18058c14621..cfe637a37b0 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -1060,7 +1060,7 @@ def test_run_with_invalid_max_iters_and_max_epoch(): engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs) -def test_epoch_events_fired(): +def test_epoch_events_fired_max_iters(): max_iters = 32 engine = Engine(lambda e, b: 1) From 5365d42f89a4dabdf95413c1b6004872632dd0f8 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 16 Aug 2022 20:57:35 +0000 Subject: [PATCH 6/7] Fixed issue with _EngineTerminateException handled as a general exception --- ignite/engine/engine.py | 5 +++++ tests/ignite/engine/test_engine.py | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index f67e210694e..9bc10c7fccb 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -888,6 +888,11 @@ def _run_once_on_dataset(self) -> float: self.should_terminate_single_epoch = False self._setup_dataloader_iter() + except _EngineTerminateException as e: + # we need to reraise this exception such that it is not handled + # as a general exception by the code below + raise e + except Exception as e: self.logger.error(f"Current run is terminating due to exception: {e}") self._handle_exception(e) diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index cfe637a37b0..b038a3ce7ca 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -172,6 +172,10 @@ def start_of_iteration_handler(engine): if engine.state.iteration == iteration_to_stop: engine.terminate() + @engine.on(Events.EXCEPTION_RAISED) + def assert_no_exceptions(ee): + assert False, f"Engine should terminate without raising an exception, got '{type(ee)}'" + engine.add_event_handler(Events.ITERATION_STARTED, start_of_iteration_handler) state = engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length) # completes the iteration but doesn't increment counter (this happens just before a new iteration starts) @@ -236,6 +240,10 @@ def test_terminate_events_sequence(terminate_event, e, i): def call_terminate(): engine.terminate() + @engine.on(Events.EXCEPTION_RAISED) + def assert_no_exceptions(ee): + assert False, f"Engine should terminate without raising an exception, got '{type(ee)}'" + engine.run(data, max_epochs=max_epochs) if i is None: From d350ff51e2f13c8192721200f25960c2e352b5d3 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 22 Aug 2022 20:43:07 +0000 Subject: [PATCH 7/7] Updated tests and docs --- ignite/engine/engine.py | 22 ++++++++++++++++++---- tests/ignite/engine/test_engine.py | 19 ++++++++++++++----- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index ab9e137b59f..3bf935eea6d 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -438,14 +438,28 @@ def fire_event(self, event_name: Any) -> None: return self._fire_event(event_name) def terminate(self) -> None: - """Sends terminate signal to the engine, so that it terminates completely the run after - the current iteration.""" + """Sends terminate signal to the engine, so that it terminates completely the run. The run is + terminated after the event on which ``terminate`` method was called. The following events are triggered: + + - ... + - Terminating event + - :attr:`~ignite.engine.events.Events.TERMINATE` + - :attr:`~ignite.engine.events.Events.COMPLETED` + """ self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.") self.should_terminate = True def terminate_epoch(self) -> None: - """Sends terminate signal to the engine, so that it terminates the current epoch - after the current iteration.""" + """Sends terminate signal to the engine, so that it terminates the current epoch. The run + continues from the next epoch. The following events are triggered: + + - ... + - Event on which ``terminate_epoch`` method is called + - :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH` + - :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED` + - :attr:`~ignite.engine.events.Events.EPOCH_STARTED` + - ... + """ self.logger.info( "Terminate current epoch is signaled. " "Current epoch iteration will stop after current iteration is finished." diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index b038a3ce7ca..8d7d21f686f 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -104,7 +104,7 @@ def end_of_epoch_handler(engine): @pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)]) -def test_terminate_at_start_of_epoch_stops_run_after_completing_iteration(data, epoch_length): +def test_terminate_at_start_of_epoch(data, epoch_length): max_epochs = 5 epoch_to_terminate_on = 3 real_epoch_length = epoch_length if data is None else len(data) @@ -293,7 +293,7 @@ def start_of_iteration_handler(engine): def test_terminate_epoch_events_sequence(terminate_epoch_event, i): engine = RecordedEngine(MagicMock(return_value=1)) data = range(10) - max_epochs = 5 + max_epochs = 3 # TODO: Bug: Events.GET_BATCH_STARTED(once=12) is called twice ! # prevent call_terminate_epoch to be called twice @@ -310,14 +310,23 @@ def call_terminate_epoch(): def check_previous_events(iter_counter): e = i // len(data) + 1 - print("engine.called_events:", engine.called_events) - assert engine.called_events[0] == (0, 0, Events.STARTED) - assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH) assert engine.called_events[-2] == (e, i, terminate_epoch_event) + assert engine.called_events[-1] == (e, i, Events.TERMINATE_SINGLE_EPOCH) + + @engine.on(Events.EPOCH_COMPLETED) + def check_previous_events2(): + e = i // len(data) + 1 + if e == engine.state.epoch and i == engine.state.iteration: + assert engine.called_events[-3] == (e, i, terminate_epoch_event) + assert engine.called_events[-2] == (e, i, Events.TERMINATE_SINGLE_EPOCH) + assert engine.called_events[-1] == (e, i, Events.EPOCH_COMPLETED) engine.run(data, max_epochs=max_epochs) + assert engine.state.epoch == max_epochs + assert (max_epochs - 1) * len(data) < engine.state.iteration < max_epochs * len(data) + def _create_mock_data_loader(epochs, batches_per_epoch): batches = [MagicMock()] * batches_per_epoch