Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support all early stopping options: LambdaEarlyStopping #6909

Closed
wants to merge 3 commits into from

Conversation

carmocca
Copy link
Contributor

@carmocca carmocca commented Apr 8, 2021

What does this PR do?

Fixes #6795

Pitch

inf = torch.tensor(np.Inf)  # our stop fn maximizes

def bounds_fn(score: float, _ = None) -> Tuple[bool, str]:
    return not (10 < score <= 100), "Out of bounds"

def stop_fn(score: float, best: float = inf) -> bool:
    # maximizing - did not improve
    return score < best

# stateful class example
class StdDevUnder:
    def __init__(self, threshold: float, num_values_to_keep: int) -> None:
        assert threshold > 0, "Standard deviation should be a positive value"
        assert num_values_to_keep > 1, (
            "The number of values to keep must be greater than 1 to compute "
            "the standard deviation"
        )
        self._threshold = threshold
        self._num_values_to_keep = num_values_to_keep
        self._values = []
        self._nval = 0

    def __call__(self, score: float, _ = None) -> Union[bool, Tuple[bool, str]]:
        # Add last value to the values
        if self._num_values_to_keep > len(self._values):
            self._values.append(value)
        else:
            self._values[self._nval] = value
            self._nval = (self._nval + 1) % self._num_values_to_keep

        # If not enough values are kept, return False
        if len(self._values) < self._num_values_to_keep:
            return False

        std = torch.std(torch.stack(self._values))
        if std < self._threshold:
            reason = f"Standard deviation {std} < Threshold {self._threshold}"
            return True, reason
        return False

Trainer(callbacks=[
    LambdaEarlyStopping(stop_fn),
    # stop immediately when the target cost is reached or the run failed completely
    LambdaEarlyStopping(bounds_fn, patience=0),
    # stop when the stddev over the last 10 values is lower than a threshold
    LambdaEarlyStopping(StdDevUnder(0.1, 10)),
])

There is no need to pass mode=min|max. The initial best value is provided by the user with the best keyword argument.

The Callable can optionally return a string with the stopping reason. The callable function name is also shown in the "report".

The Callable can be a class callable and keep a state.

Notes

This does not replace the PR in #6868 as we still want to provide that functionality out-of-the-box.

In a future PR

  • Refactor EarlyStopping so it subclasses LambdaEarlyStopping to reuse code: The objective of EarlyStopping will be to provide out-of-the-box stopping functions.

TODO

  • Code: Opened this PR to allow people to read the pitch
  • Docs
  • Tests

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

@carmocca carmocca added feature Is an improvement or enhancement callback labels Apr 8, 2021
@carmocca carmocca added this to the 1.3 milestone Apr 8, 2021
@carmocca carmocca self-assigned this Apr 8, 2021
@pep8speaks
Copy link

Hello @carmocca! Thanks for opening this PR.

Line 154:13: W503 line break before binary operator

Do see the Hitchhiker's guide to code style

@carmocca carmocca added _Will design Includes a design discussion and removed _Will labels Apr 8, 2021
@carmocca carmocca modified the milestones: v1.3, v1.4 May 3, 2021
@edenlightning edenlightning modified the milestones: v1.4, v1.5 Jul 9, 2021
@Borda
Copy link
Member

Borda commented Sep 23, 2021

seem this quite old PR with several handerts of commits behind the master, consider finishing it or closing as most likely the conflicts will make the PR challenging to finish... 🐰

@awaelchli awaelchli modified the milestones: v1.5, v1.6 Nov 1, 2021
@carmocca carmocca removed this from the 1.6 milestone Mar 28, 2022
@carmocca carmocca closed this Nov 8, 2022
@carmocca carmocca deleted the lambda-early-stopping branch November 8, 2022 16:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
callback design Includes a design discussion feature Is an improvement or enhancement
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add more early stopping options
6 participants