[Tune; RLlib] Missing stopping criterion should not error (just warn). (
sven1977 authored Jun 9, 2024
1 parent 7928ca5 commit 8b89a7b
Showing 3 changed files with 65 additions and 38 deletions.
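In short, Tune's per-trial stop check no longer raises when a stop key is absent from a reported result; it logs a one-time warning instead. A minimal sketch of the user-facing difference (hypothetical trainable and metric names, not code from this commit):

from ray import train, tune


def train_fn(config):
    for i in range(5):
        train.report({"loss": 1.0 / (i + 1)})


# "accuracy" is never reported, so this stop criterion can never fire.
# Before this commit: TuneError. After: a single warning is logged and the
# trial simply runs until train_fn returns.
tune.run(train_fn, stop={"accuracy": 0.9})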
20 changes: 12 additions & 8 deletions python/ray/tune/experiment/trial.py
@@ -53,6 +53,7 @@
 from ray.tune.trainable.metadata import _TrainingRunMetadata
 from ray.tune.utils import date_str, flatten_dict
 from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
+from ray.util import log_once
 from ray.util.annotations import Deprecated, DeveloperAPI
 
 DEBUG_PRINT_INTERVAL = 5
@@ -851,18 +852,21 @@ def should_stop(self, result):
         if result.get(DONE):
             return True
 
-        for criteria, stop_value in self.stopping_criterion.items():
-            if criteria not in result:
-                raise TuneError(
-                    "Stopping criteria {} not provided in result dict. Keys "
-                    "are {}.".format(criteria, list(result.keys()))
-                )
-            elif isinstance(criteria, dict):
+        for criterion, stop_value in self.stopping_criterion.items():
+            if isinstance(criterion, dict):
                 raise ValueError(
                     "Stopping criteria is now flattened by default. "
                     "Use forward slashes to nest values `key1/key2/key3`."
                 )
-            elif result[criteria] >= stop_value:
+            elif criterion not in result:
+                if log_once("tune_trial_stop_criterion_not_found"):
+                    logger.warning(
+                        f"Stopping criterion '{criterion}' not found in result dict! "
+                        f"Available keys are {list(result.keys())}. If '{criterion}' is"
+                        " never reported, the run will continue until training is "
+                        "finished."
+                    )
+            elif result[criterion] >= stop_value:
                 return True
         return False

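The new warning in `Trial.should_stop()` is deduplicated via `ray.util.log_once`, imported above. A rough standalone sketch of that pattern follows; the helper name is hypothetical, and the per-process, first-call-only behavior of `log_once(key)` is an assumption drawn from how the diff uses it:

import logging

from ray.util import log_once

logger = logging.getLogger(__name__)


def warn_missing_criterion(criterion, result):
    # Emitted at most once per process for this key, no matter how many
    # results arrive without the criterion.
    if log_once("tune_trial_stop_criterion_not_found"):
        logger.warning(
            f"Stopping criterion '{criterion}' not found in result dict! "
            f"Available keys are {list(result.keys())}."
        )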
47 changes: 17 additions & 30 deletions python/ray/tune/tests/test_api.py
@@ -423,43 +423,23 @@ def f():
 
         self.assertRaises(TuneError, f)
 
-    def testBadParams5(self):
-        def f():
-            run_experiments({"foo": {"run": "__fake", "stop": {"asdf": 1}}})
-
-        self.assertRaises(TuneError, f)
-
     def testBadParams6(self):
         def f():
            run_experiments({"foo": {"run": "PPO", "resources_per_trial": {"asdf": 1}}})
 
         self.assertRaises(TuneError, f)
 
-    def testBadStoppingReturn(self):
-        def train_fn(config):
-            train.report(dict(a=1))
-
-        register_trainable("f1", train_fn)
-
-        def f():
-            run_experiments(
-                {
-                    "foo": {
-                        "run": "f1",
-                        "stop": {"time": 10},
-                    }
-                }
-            )
-
-        self.assertRaises(TuneError, f)
-
     def testNestedStoppingReturn(self):
         def train_fn(config):
             for i in range(10):
                 train.report(dict(test={"test1": {"test2": i}}))
 
-        with self.assertRaises(TuneError):
-            [trial] = tune.run(train_fn, stop={"test": {"test1": {"test2": 6}}}).trials
+        [trial] = tune.run(train_fn, stop={"test": {"test1": {"test2": 6}}}).trials
+        self.assertTrue(
+            "test" in trial.last_result
+            and "test1" in trial.last_result["test"]
+            and "test2" in trial.last_result["test"]["test1"]
+        )
         [trial] = tune.run(train_fn, stop={"test/test1/test2": 6}).trials
         self.assertEqual(trial.last_result["training_iteration"], 7)
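The nested-stopping test relies on Tune's flattened result keys: metrics reported as nested dicts are addressed in `stop` with slash-separated paths. A small sketch of that convention, mirroring the assertion above with hypothetical metric names:

from ray import train, tune


def train_fn(config):
    for i in range(10):
        # Reported as a nested dict ...
        train.report({"eval": {"mean_reward": i}})


# ... but addressed with a flattened, slash-separated key in `stop`.
# The trial stops once eval/mean_reward >= 6, i.e. on the 7th report.
[trial] = tune.run(train_fn, stop={"eval/mean_reward": 6}).trials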

@@ -1636,10 +1616,17 @@ def train_fn(config):
         self.assertTrue(
             all(set(result) >= set(flattened_keys) for result in algo.results)
         )
-        with self.assertRaises(TuneError):
-            [trial] = tune.run(train_fn, stop={"1/2/3": 20})
-        with self.assertRaises(TuneError):
-            [trial] = tune.run(train_fn, stop={"test": 1}).trials
+        # Test whether non-existent stop criteria do NOT cause an error anymore
+        # (just a warning).
+        [trial] = tune.run(train_fn, stop={"1/2/3": 20}).trials
+        self.assertFalse("1" in trial.last_result)
+        [trial] = tune.run(train_fn, stop={"test": 1}).trials
+        self.assertTrue(
+            "test" in trial.last_result
+            and "1" in trial.last_result["test"]
+            and "2" in trial.last_result["test"]["1"]
+            and "3" in trial.last_result["test"]["1"]["2"]
+        )
 
     def testIterationCounter(self):
         def train_fn(config):
36 changes: 36 additions & 0 deletions python/ray/tune/tests/test_trial.py
@@ -1,8 +1,10 @@
+import logging
 import sys
 
 import pytest
 
 from ray.exceptions import RayActorError, RayTaskError
+from ray.tests.conftest import propagate_logs  # noqa
 from ray.train import Checkpoint
 from ray.train._internal.session import _TrainingResult
 from ray.train._internal.storage import StorageContext
@@ -116,5 +118,39 @@ def test_trial_logdir_length():
     assert len(trial.storage.trial_dir_name) < 200
 
 
+def test_should_stop(caplog, propagate_logs):  # noqa
+    """Test whether `Trial.should_stop()` works as expected given a result dict."""
+    trial = Trial(
+        "MockTrainable",
+        stub=True,
+        trial_id="abcd1234",
+        stopping_criterion={"a": 10.0, "b/c": 20.0},
+    )
+
+    # Criterion is not reached yet -> don't stop.
+    result = {"a": 9.999, "b/c": 0.0, "some_other_key": True}
+    assert not trial.should_stop(result)
+
+    # Criterion is exactly reached -> stop.
+    result = {"a": 10.0, "b/c": 0.0, "some_other_key": False}
+    assert trial.should_stop(result)
+
+    # Criterion is exceeded -> stop.
+    result = {"a": 10000.0, "b/c": 0.0, "some_other_key": False}
+    assert trial.should_stop(result)
+
+    # Test nested criterion.
+    result = {"a": 5.0, "b/c": 1000.0, "some_other_key": False}
+    assert trial.should_stop(result)
+
+    # Test criterion NOT found in result metrics.
+    result = {"b/c": 1000.0}
+    with caplog.at_level(logging.WARNING):
+        trial.should_stop(result)
+    assert (
+        "Stopping criterion 'a' not found in result dict! Available keys are ['b/c']."
+    ) in caplog.text
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(["-v", __file__]))
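For context on the new test's fixtures: `caplog` is pytest's built-in log-capture fixture, and `propagate_logs` (imported above from `ray.tests.conftest`) is what, as used here, lets Ray's log records reach pytest's capture handler. A standalone sketch of the capture pattern, independent of Ray, with a hypothetical logger name and message:

import logging


def emit_warning(logger: logging.Logger) -> None:
    logger.warning("Stopping criterion 'a' not found in result dict!")


def test_warning_is_captured(caplog):
    # caplog.at_level temporarily sets the capture level; caplog.text holds
    # the concatenated captured records afterwards.
    with caplog.at_level(logging.WARNING):
        emit_warning(logging.getLogger("some.module"))
    assert "not found in result dict" in caplog.text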
