Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewritten Engine's terminate and terminate_epoch logic #2645

Merged
merged 13 commits into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 101 additions & 37 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,12 @@ def execute_something():

self._assert_allowed_event(event_name)

event_args = (Exception(),) if event_name == Events.EXCEPTION_RAISED else ()
event_args = () # type: Tuple[Any, ...]
if event_name == Events.EXCEPTION_RAISED:
event_args += (Exception(),)
elif event_name == Events.TERMINATE_SINGLE_EPOCH:
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
event_args += (0,)
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved

try:
_check_signature(handler, "handler", self, *(event_args + args), **kwargs)
self._event_handlers[event_name].append((handler, (self,) + args, kwargs))
Expand Down Expand Up @@ -433,14 +438,28 @@ def fire_event(self, event_name: Any) -> None:
return self._fire_event(event_name)

def terminate(self) -> None:
    """Sends terminate signal to the engine, so that it terminates completely the run. The run is
    terminated after the event on which ``terminate`` method was called. The following events are triggered:

    - ...
    - Terminating event
    - :attr:`~ignite.engine.events.Events.TERMINATE`
    - :attr:`~ignite.engine.events.Events.COMPLETED`
    """
    # NOTE(review): the log text still says "after current iteration is finished",
    # while the docstring says the run stops after the event on which this method
    # was called. The message is left untouched here; confirm whether it should be updated.
    self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
    # Only a flag is raised here; the run loop checks it after firing events.
    self.should_terminate = True

def terminate_epoch(self) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch
after the current iteration."""
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
continues from the next epoch. The following events are triggered:

- ...
- Event on which ``terminate_epoch`` method is called
- :attr:`~ignite.engine.events.Events.TERMINATE_SINGLE_EPOCH`
- :attr:`~ignite.engine.events.Events.EPOCH_COMPLETED`
- :attr:`~ignite.engine.events.Events.EPOCH_STARTED`
- ...
"""
self.logger.info(
"Terminate current epoch is signaled. "
"Current epoch iteration will stop after current iteration is finished."
Expand Down Expand Up @@ -742,33 +761,43 @@ def _internal_run(self) -> State:
self.should_terminate = self.should_terminate_single_epoch = False
self._init_timers(self.state)
try:
start_time = time.time()
self._fire_event(Events.STARTED)
while not self._is_done(self.state) and not self.should_terminate:
self.state.epoch += 1
self._fire_event(Events.EPOCH_STARTED)

if self._dataloader_iter is None:
self._setup_engine()

time_taken = self._run_once_on_dataset()
# time is available for handlers but must be update after fire
self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
handlers_start_time = time.time()
if self.should_terminate:
self._fire_event(Events.TERMINATE)
else:
try:
start_time = time.time()
self._fire_event(Events.STARTED)
self._maybe_terminate()

while not self._is_done(self.state) and not self.should_terminate:
self.state.epoch += 1
handlers_start_time = time.time()
self._fire_event(Events.EPOCH_STARTED)
epoch_time_taken = time.time() - handlers_start_time
self._maybe_terminate()

if self._dataloader_iter is None:
self._setup_engine()

epoch_time_taken += self._run_once_on_dataset()

# time is available for handlers but must be updated after fire
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken

handlers_start_time = time.time()
self._fire_event(Events.EPOCH_COMPLETED)
sadra-barikbin marked this conversation as resolved.
Show resolved Hide resolved
time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = time_taken
hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
if self.should_terminate:
break
epoch_time_taken += time.time() - handlers_start_time
# update time wrt handlers
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
self._maybe_terminate()

hours, mins, secs = _to_hours_mins_secs(epoch_time_taken)
self.logger.info(
f"Epoch[{self.state.epoch}] Complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}"
)

except _EngineTerminateException:
self._fire_event(Events.TERMINATE)

time_taken = time.time() - start_time
# time is available for handlers but must be update after fire
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
handlers_start_time = time.time()
self._fire_event(Events.COMPLETED)
Expand All @@ -786,6 +815,13 @@ def _internal_run(self) -> State:
self._dataloader_iter = None
return self.state

def _maybe_terminate(self) -> None:
    """Raise the internal control-flow exception for a pending terminate request.

    Does nothing when neither flag is set. ``should_terminate`` is checked first,
    so a full-run termination request takes precedence over a single-epoch one
    when both flags are set.

    Raises:
        _EngineTerminateException: if ``self.should_terminate`` is truthy.
        _EngineTerminateSingleEpochException: if ``self.should_terminate_single_epoch``
            is truthy (and ``self.should_terminate`` is not).
    """
    if self.should_terminate:
        raise _EngineTerminateException()

    if self.should_terminate_single_epoch:
        raise _EngineTerminateSingleEpochException()

def _run_once_on_dataset(self) -> float:
start_time = time.time()

Expand All @@ -805,8 +841,12 @@ def _run_once_on_dataset(self) -> float:
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
if self.last_event_name != Events.DATALOADER_STOP_ITERATION:
self._fire_event(Events.GET_BATCH_STARTED)
self._maybe_terminate()

self.state.batch = next(self._dataloader_iter)
self._fire_event(Events.GET_BATCH_COMPLETED)
self._maybe_terminate()

iter_counter += 1
should_exit = False
except StopIteration:
Expand Down Expand Up @@ -835,29 +875,37 @@ def _run_once_on_dataset(self) -> float:
break

self._fire_event(Events.DATALOADER_STOP_ITERATION)
self._setup_dataloader_iter()
self._maybe_terminate()

self._setup_dataloader_iter()
should_exit = True

continue

self.state.iteration += 1
self._fire_event(Events.ITERATION_STARTED)
self._maybe_terminate()
sadra-barikbin marked this conversation as resolved.
Show resolved Hide resolved

self.state.output = self._process_function(self, self.state.batch)
self._fire_event(Events.ITERATION_COMPLETED)

if self.should_terminate or self.should_terminate_single_epoch:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self._setup_dataloader_iter()
break
self._maybe_terminate()

if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
break

if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
self.should_terminate = True
break
raise _EngineTerminateException()

except _EngineTerminateSingleEpochException:
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
self.should_terminate_single_epoch = False
self._setup_dataloader_iter()

except _EngineTerminateException as e:
# we need to reraise this exception such that it is not handled
# as a general exception by the code below
raise e

except Exception as e:
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
self.logger.error(f"Current run is terminating due to exception: {e}")
Expand All @@ -870,3 +918,19 @@ def _get_none_data_iter(size: int) -> Iterator:
# Sized iterator for data as None
for _ in range(size):
yield None


class _EngineTerminateSingleEpochException(Exception):
    """Exception associated with Terminate Single Epoch event.

    Internal control-flow signal: raised by ``_maybe_terminate`` when
    ``should_terminate_single_epoch`` is set, and caught in ``_run_once_on_dataset``,
    which then fires ``Events.TERMINATE_SINGLE_EPOCH``.
    """


class _EngineTerminateException(Exception):
    """Exception associated with Terminate event.

    Internal control-flow signal: raised by ``_maybe_terminate`` when
    ``should_terminate`` is set (and when ``max_iters`` is reached), re-raised by
    ``_run_once_on_dataset`` and caught in ``_internal_run``, which then fires
    ``Events.TERMINATE``.
    """
2 changes: 1 addition & 1 deletion ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class CustomEvents(EventEnum):
"""triggered when the run is about to end completely, after receiving terminate() call."""
TERMINATE_SINGLE_EPOCH = "terminate_single_epoch"
"""triggered when the run is about to end the current epoch,
after receiving a terminate_epoch() or terminate() call."""
after receiving a terminate_epoch() call."""

def __or__(self, other: Any) -> "EventsList":
return EventsList() | self | other
Expand Down
Loading