[Tune; RLlib] Missing stopping criterion should not error (just warn). #45613

Merged
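In short: a stopping criterion key that never appears in the reported results no longer raises a TuneError. Tune now logs a warning (once per missing key) and lets the trial run until training finishes on its own. The following is a minimal sketch of the new behavior; the metric names and iteration count are illustrative only, not taken from this PR.

from ray import train, tune

def train_fn(config):
    for i in range(5):
        # Only "loss" is ever reported; "accuracy" never appears in the results.
        train.report({"loss": 1.0 / (i + 1)})

# Before this change: TuneError, because "accuracy" is missing from the result dict.
# After this change: a warning is logged once and all 5 iterations run to completion.
tune.run(train_fn, stop={"accuracy": 0.9})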
python/ray/tune/experiment/trial.py (20 changes: 12 additions & 8 deletions)
@@ -53,6 +53,7 @@
from ray.tune.trainable.metadata import _TrainingRunMetadata
from ray.tune.utils import date_str, flatten_dict
from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
+ from ray.util import log_once
from ray.util.annotations import Deprecated, DeveloperAPI

DEBUG_PRINT_INTERVAL = 5
@@ -851,18 +852,21 @@ def should_stop(self, result):
    if result.get(DONE):
        return True

-   for criteria, stop_value in self.stopping_criterion.items():
-       if criteria not in result:
-           raise TuneError(
-               "Stopping criteria {} not provided in result dict. Keys "
-               "are {}.".format(criteria, list(result.keys()))
-           )
-       elif isinstance(criteria, dict):
+   for criterion, stop_value in self.stopping_criterion.items():
+       if isinstance(criterion, dict):
            raise ValueError(
                "Stopping criteria is now flattened by default. "
                "Use forward slashes to nest values `key1/key2/key3`."
            )
-       elif result[criteria] >= stop_value:
+       elif criterion not in result:
+           if log_once("tune_trial_stop_criterion_not_found"):
+               logger.warning(
+                   f"Stopping criterion '{criterion}' not found in result dict! "
+                   f"Available keys are {list(result.keys())}. If '{criterion}' is"
+                   " never reported, the run will continue until training is "
+                   "finished."
+               )
+       elif result[criterion] >= stop_value:
            return True
    return False
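The warning above is rate-limited with ray.util.log_once, which returns True only the first time it is called with a given key, so a long-running trial does not repeat the message on every result. A small sketch of that pattern, assuming log_once keeps per-key state for the lifetime of the process (the helper below is hypothetical, not part of the PR):

from ray.util import log_once

def warn_if_missing(criterion, result):
    # Hypothetical helper: warn only on the first call for this particular key.
    if criterion not in result and log_once(f"missing_stop_criterion_{criterion}"):
        print(f"Stopping criterion '{criterion}' not found in result dict!")

warn_if_missing("accuracy", {"loss": 0.3})  # warns once
warn_if_missing("accuracy", {"loss": 0.2})  # silent on later calls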

python/ray/tune/tests/test_api.py (47 changes: 17 additions & 30 deletions)
@@ -423,43 +423,23 @@ def f():

        self.assertRaises(TuneError, f)

-   def testBadParams5(self):
-       def f():
-           run_experiments({"foo": {"run": "__fake", "stop": {"asdf": 1}}})
-
-       self.assertRaises(TuneError, f)

    def testBadParams6(self):
        def f():
            run_experiments({"foo": {"run": "PPO", "resources_per_trial": {"asdf": 1}}})

        self.assertRaises(TuneError, f)

-   def testBadStoppingReturn(self):
-       def train_fn(config):
-           train.report(dict(a=1))
-
-       register_trainable("f1", train_fn)
-
-       def f():
-           run_experiments(
-               {
-                   "foo": {
-                       "run": "f1",
-                       "stop": {"time": 10},
-                   }
-               }
-           )
-
-       self.assertRaises(TuneError, f)

    def testNestedStoppingReturn(self):
        def train_fn(config):
            for i in range(10):
                train.report(dict(test={"test1": {"test2": i}}))

-       with self.assertRaises(TuneError):
-           [trial] = tune.run(train_fn, stop={"test": {"test1": {"test2": 6}}}).trials
+       [trial] = tune.run(train_fn, stop={"test": {"test1": {"test2": 6}}}).trials
+       self.assertTrue(
+           "test" in trial.last_result
+           and "test1" in trial.last_result["test"]
+           and "test2" in trial.last_result["test"]["test1"]
+       )
        [trial] = tune.run(train_fn, stop={"test/test1/test2": 6}).trials
        self.assertEqual(trial.last_result["training_iteration"], 7)

@@ -1636,10 +1616,17 @@ def train_fn(config):
        self.assertTrue(
            all(set(result) >= set(flattened_keys) for result in algo.results)
        )
-       with self.assertRaises(TuneError):
-           [trial] = tune.run(train_fn, stop={"1/2/3": 20})
-       with self.assertRaises(TuneError):
-           [trial] = tune.run(train_fn, stop={"test": 1}).trials
+       # Test, whether non-existent stop criteria do NOT cause an error anymore (just
+       # a warning).
+       [trial] = tune.run(train_fn, stop={"1/2/3": 20}).trials
+       self.assertFalse("1" in trial.last_result)
+       [trial] = tune.run(train_fn, stop={"test": 1}).trials
+       self.assertTrue(
+           "test" in trial.last_result
+           and "1" in trial.last_result["test"]
+           and "2" in trial.last_result["test"]["1"]
+           and "3" in trial.last_result["test"]["1"]["2"]
+       )

    def testIterationCounter(self):
        def train_fn(config):
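The nested-key assertions above rely on Tune flattening nested result dicts into "/"-delimited keys before stopping criteria are checked, which is why stop keys are written as "test/test1/test2" here and why should_stop() rejects dict-valued criteria. A quick illustration with ray.tune.utils.flatten_dict (the same helper trial.py already imports), assuming its default "/" delimiter:

from ray.tune.utils import flatten_dict

# A nested result dict is flattened to "/"-delimited keys, so a stop criterion
# must be written as "test/test1/test2" rather than as a nested dict.
print(flatten_dict({"test": {"test1": {"test2": 6}}}))
# -> {"test/test1/test2": 6}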
python/ray/tune/tests/test_trial.py (36 changes: 36 additions & 0 deletions)
@@ -1,8 +1,10 @@
+ import logging
import sys

import pytest

from ray.exceptions import RayActorError, RayTaskError
+ from ray.tests.conftest import propagate_logs  # noqa
from ray.train import Checkpoint
from ray.train._internal.session import _TrainingResult
from ray.train._internal.storage import StorageContext
@@ -116,5 +118,39 @@ def test_trial_logdir_length():
    assert len(trial.storage.trial_dir_name) < 200


+ def test_should_stop(caplog, propagate_logs):  # noqa
+     """Test whether `Trial.should_stop()` works as expected given a result dict."""
+     trial = Trial(
+         "MockTrainable",
+         stub=True,
+         trial_id="abcd1234",
+         stopping_criterion={"a": 10.0, "b/c": 20.0},
+     )
+
+     # Criterion is not reached yet -> don't stop.
+     result = {"a": 9.999, "b/c": 0.0, "some_other_key": True}
+     assert not trial.should_stop(result)
+
+     # Criterion is exactly reached -> stop.
+     result = {"a": 10.0, "b/c": 0.0, "some_other_key": False}
+     assert trial.should_stop(result)
+
+     # Criterion is exceeded -> stop.
+     result = {"a": 10000.0, "b/c": 0.0, "some_other_key": False}
+     assert trial.should_stop(result)
+
+     # Test nested criterion.
+     result = {"a": 5.0, "b/c": 1000.0, "some_other_key": False}
+     assert trial.should_stop(result)
+
+     # Test criterion NOT found in result metrics.
+     result = {"b/c": 1000.0}
+     with caplog.at_level(logging.WARNING):
+         trial.should_stop(result)
+         assert (
+             "Stopping criterion 'a' not found in result dict! Available keys are ['b/c']."
+         ) in caplog.text


if __name__ == "__main__":
sys.exit(pytest.main(["-v", __file__]))