Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add NoImprovementHandler #2574

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ignite/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from ignite.engine import Engine
from ignite.engine.events import Events
from ignite.handlers.checkpoint import Checkpoint, DiskSaver, ModelCheckpoint
from ignite.handlers.early_stopping import EarlyStopping
from ignite.handlers.early_stopping import EarlyStopping, NoImprovementHandler
from ignite.handlers.ema_handler import EMAHandler
from ignite.handlers.lr_finder import FastaiLRFinder
from ignite.handlers.param_scheduler import (
Expand Down Expand Up @@ -38,6 +38,7 @@
"Checkpoint",
"DiskSaver",
"Timer",
"NoImprovementHandler",
"EarlyStopping",
"TerminateOnNan",
"global_step_from_engine",
Expand Down
134 changes: 108 additions & 26 deletions ignite/handlers/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,57 @@
from ignite.engine import Engine
from ignite.utils import setup_logger

__all__ = ["EarlyStopping"]
__all__ = ["NoImprovementHandler", "EarlyStopping"]


class EarlyStopping(Serializable):
"""EarlyStopping handler can be used to stop the training if no improvement after a given number of events.

class NoImprovementHandler(Serializable):
"""NoImprovementHandler is a generalised version of Early stopping where you can define what should
happen if no improvement occurs after a given number of events.
Args:
patience: Number of events to wait if no improvement and then stop the training.
patience: Number of events to wait if no improvement and then call stop_function.
score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine`
object, and return a score `float`. An improvement is considered if the score is higher.
pass_function: It should be a function taking a single argument, the trainer
object, and defines what to do in the case when the stopping condition is not met.
stop_function: It should be a function taking a single argument, the trainer
object, and defines what to do in the case when the stopping condition is met.
trainer: Trainer engine to stop the run if no improvement.
min_delta: A minimum increase in the score to qualify as an improvement,
i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
cumulative_delta: It True, `min_delta` defines an increase since the last `patience` reset, otherwise,
it defines an increase after the last event. Default value is False.

Examples:
.. code-block:: python

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping
#Example where if the score doesn't improve a user defined value `alpha` is doubled.

from ignite.engine import Engine, Events
from ignite.handlers import NoImprovementHandler
def score_function(engine):
val_loss = engine.state.metrics['nll']
return -val_loss
def pass_function(engine):
pass
def stop_function(trainer):
trainer.state.alpha *= 2

trainer = Engine(do_nothing_update_fn)
trainer.state_dict_user_keys.append("alpha")
trainer.state.alpha = 0.1

h = NoImprovementHandler(patience=3, score_function=score_function, pass_function=pass_function,
stop_function=stop_function, trainer=trainer)

handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
evaluator.add_event_handler(Events.COMPLETED, handler)

"""

_state_dict_all_req_keys = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to save NoImprovementHandler's internal state we have to keep _state_dict_all_req_keys, IMO

"counter",
"best_score",
)

def __init__(
self,
patience: int,
score_function: Callable,
pass_function: Callable,
stop_function: Callable,
trainer: Engine,
min_delta: float = 0.0,
cumulative_delta: bool = False,
Expand All @@ -54,37 +64,50 @@ def __init__(
if not callable(score_function):
raise TypeError("Argument score_function should be a function.")

if not callable(pass_function):
raise TypeError("Argument pass_function should be a function.")

if not callable(stop_function):
raise TypeError("Argument stop_function should be a function.")

if not isinstance(trainer, Engine):
raise TypeError("Argument trainer should be an instance of Engine.")

if patience < 1:
raise ValueError("Argument patience should be positive integer.")

if min_delta < 0.0:
raise ValueError("Argument min_delta should not be a negative number.")

if not isinstance(trainer, Engine):
raise TypeError("Argument trainer should be an instance of Engine.")

self.score_function = score_function
self.patience = patience
self.min_delta = min_delta
self.cumulative_delta = cumulative_delta
self.score_function = score_function
self.pass_function = pass_function
self.stop_function = stop_function
self.trainer = trainer
self.counter = 0
self.best_score = None # type: Optional[float]
self.logger = setup_logger(__name__ + "." + self.__class__.__name__)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may want also the logger for NoImprovementHandler ?

self.min_delta = min_delta
self.cumulative_delta = cumulative_delta

def __call__(self, engine: Engine) -> None:
score = self.score_function(engine)

self._update_state(score)
sdesrozis marked this conversation as resolved.
Show resolved Hide resolved
if self.counter >= self.patience:
self.stop_function(self.trainer)
else:
self.pass_function(self.trainer)

def _update_state(self, score: int) -> None:

if self.best_score is None:
self.best_score = score

elif score <= self.best_score + self.min_delta:
if not self.cumulative_delta and score > self.best_score:
self.best_score = score
self.counter += 1
self.logger.debug("EarlyStopping: %i / %i" % (self.counter, self.patience))
if self.counter >= self.patience:
self.logger.info("EarlyStopping: Stop training")
self.trainer.terminate()

else:
self.best_score = score
self.counter = 0
Expand All @@ -104,3 +127,62 @@ def load_state_dict(self, state_dict: Mapping) -> None:
super().load_state_dict(state_dict)
self.counter = state_dict["counter"]
self.best_score = state_dict["best_score"]


class EarlyStopping(NoImprovementHandler):
"""EarlyStopping handler can be used to stop the training if no improvement after a given number of events.
Args:
patience: Number of events to wait if no improvement and then stop the training.
score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine`
object, and return a score `float`. An improvement is considered if the score is higher.
trainer: Trainer engine to stop the run if no improvement.
min_delta: A minimum increase in the score to qualify as an improvement,
i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
cumulative_delta: It True, `min_delta` defines an increase since the last `patience` reset, otherwise,
it defines an increase after the last event. Default value is False.
Examples:
.. code-block:: python
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping
def score_function(engine):
val_loss = engine.state.metrics['nll']
return -val_loss
handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
# Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
evaluator.add_event_handler(Events.COMPLETED, handler)
"""

_state_dict_all_req_keys = (
"counter",
"best_score",
)

def __init__(
self,
patience: int,
score_function: Callable,
trainer: Engine,
min_delta: float = 0.0,
cumulative_delta: bool = False,
):
super(EarlyStopping, self).__init__(
patience=patience,
score_function=score_function,
pass_function=self.pass_function,
stop_function=self.stop_function,
trainer=trainer,
min_delta=min_delta,
cumulative_delta=cumulative_delta,
)

self.logger = setup_logger(__name__ + "." + self.__class__.__name__)

def __call__(self, engine: Engine) -> None:
super(EarlyStopping, self).__call__(engine)

def pass_function(self, trainer: Engine) -> None:
pass

def stop_function(self, trainer: Engine) -> None:
self.logger.info("EarlyStopping: Stop training")
trainer.terminate()