diff --git a/doc/source/tune/faq.rst b/doc/source/tune/faq.rst
index bb02b09ddb06..bbefeef2f4cd 100644
--- a/doc/source/tune/faq.rst
+++ b/doc/source/tune/faq.rst
@@ -797,13 +797,11 @@ The reasons for this are:
 3. Concurrent jobs are harder to debug. If a trial of job A fills the disk,
    trials from job B on the same node are impacted. In practice, it's hard
    to reason about these conditions from the logs if something goes wrong.
-4. Some internal implementations in Ray Tune assume that you only have one job
-   running at a time. This can lead to conflicts.
 
-The fourth reason is especially problematic when you run concurrent tuning jobs. For instance,
-a symptom is when trials from job A use parameters specified in job B, leading to unexpected
-results.
+Previously, some internal implementations in Ray Tune assumed that you only have one job
+running at a time. A symptom was when trials from job A used parameters specified in job B,
+leading to unexpected results.
 
 Please refer to [this github issue](https://github.com/ray-project/ray/issues/30091#issuecomment-1431676976)
-for more context and a workaround.
+for more context and a workaround if you run into this issue.
 
diff --git a/python/ray/tune/registry.py b/python/ray/tune/registry.py
index 069689ae5d49..be868ce3cad7 100644
--- a/python/ray/tune/registry.py
+++ b/python/ray/tune/registry.py
@@ -1,5 +1,5 @@
+import atexit
 import logging
-import uuid
 from functools import partial
 from types import FunctionType
 from typing import Callable, Optional, Type, Union
@@ -10,6 +10,7 @@
     _internal_kv_get,
     _internal_kv_initialized,
     _internal_kv_put,
+    _internal_kv_del,
 )
 from ray.tune.error import TuneError
 from ray.util.annotations import DeveloperAPI
@@ -111,6 +112,10 @@ def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool =
     _global_registry.register(TRAINABLE_CLASS, name, trainable)
 
 
+def _unregister_trainables():
+    _global_registry.unregister_all(TRAINABLE_CLASS)
+
+
 @DeveloperAPI
 def register_env(name: str, env_creator: Callable):
     """Register a custom environment for use with RLlib.
@@ -128,6 +133,10 @@ def register_env(name: str, env_creator: Callable):
     _global_registry.register(ENV_CREATOR, name, env_creator)
 
 
+def _unregister_envs():
+    _global_registry.unregister_all(ENV_CREATOR)
+
+
 @DeveloperAPI
 def register_input(name: str, input_creator: Callable):
     """Register a custom input api for RLlib.
@@ -142,6 +151,10 @@ def register_input(name: str, input_creator: Callable):
     _global_registry.register(RLLIB_INPUT, name, input_creator)
 
 
+def _unregister_inputs():
+    _global_registry.unregister_all(RLLIB_INPUT)
+
+
 @DeveloperAPI
 def registry_contains_input(name: str) -> bool:
     return _global_registry.contains(RLLIB_INPUT, name)
@@ -152,6 +165,12 @@ def registry_get_input(name: str) -> Callable:
     return _global_registry.get(RLLIB_INPUT, name)
 
 
+def _unregister_all():
+    _unregister_inputs()
+    _unregister_envs()
+    _unregister_trainables()
+
+
 def _check_serializability(key, value):
     _global_registry.register(TEST, key, value)
 
@@ -179,8 +198,29 @@ def _make_key(prefix: str, category: str, key: str):
 
 class _Registry:
     def __init__(self, prefix: Optional[str] = None):
+        """If no prefix is given, use runtime context job ID."""
         self._to_flush = {}
-        self._prefix = prefix or uuid.uuid4().hex[:8]
+        self._prefix = prefix
+        self._registered = set()
+        self._atexit_handler_registered = False
+
+    @property
+    def prefix(self):
+        if not self._prefix:
+            self._prefix = ray.get_runtime_context().get_job_id()
+        return self._prefix
+
+    def _register_atexit(self):
+        if self._atexit_handler_registered:
+            # Already registered
+            return
+
+        if ray._private.worker.global_worker.mode != ray.SCRIPT_MODE:
+            # Only cleanup on the driver
+            return
+
+        atexit.register(_unregister_all)
+        self._atexit_handler_registered = True
 
     def register(self, category, key, value):
         """Registers the value with the global registry.
@@ -198,16 +238,31 @@ def register(self, category, key, value):
         if _internal_kv_initialized():
             self.flush_values()
 
+    def unregister(self, category, key):
+        if _internal_kv_initialized():
+            _internal_kv_del(_make_key(self.prefix, category, key))
+        else:
+            self._to_flush.pop((category, key), None)
+
+    def unregister_all(self, category: Optional[str] = None):
+        remaining = set()
+        for (cat, key) in self._registered:
+            if category and category == cat:
+                self.unregister(cat, key)
+            else:
+                remaining.add((cat, key))
+        self._registered = remaining
+
     def contains(self, category, key):
         if _internal_kv_initialized():
-            value = _internal_kv_get(_make_key(self._prefix, category, key))
+            value = _internal_kv_get(_make_key(self.prefix, category, key))
             return value is not None
         else:
             return (category, key) in self._to_flush
 
     def get(self, category, key):
         if _internal_kv_initialized():
-            value = _internal_kv_get(_make_key(self._prefix, category, key))
+            value = _internal_kv_get(_make_key(self.prefix, category, key))
             if value is None:
                 raise ValueError(
                     "Registry value for {}/{} doesn't exist.".format(category, key)
                 )
@@ -217,14 +272,16 @@ def get(self, category, key):
             return pickle.loads(self._to_flush[(category, key)])
 
     def flush_values(self):
+        self._register_atexit()
         for (category, key), value in self._to_flush.items():
             _internal_kv_put(
-                _make_key(self._prefix, category, key), value, overwrite=True
+                _make_key(self.prefix, category, key), value, overwrite=True
             )
+            self._registered.add((category, key))
         self._to_flush.clear()
 
 
-_global_registry = _Registry(prefix="global")
+_global_registry = _Registry()
 ray._private.worker._post_init_hooks.append(_global_registry.flush_values)
 
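For context on the registry change above, here is a minimal sketch, not part of this patch, of how a driver script exercises the per-job prefix. The trainable name, config, and returned metric are made up for illustration, and the key layout in the comment is only approximate (the exact format comes from `_make_key`, which this diff does not show).

```python
# Minimal sketch, assuming a Ray build that includes this change.
import ray
from ray.tune.registry import TRAINABLE_CLASS, _global_registry, register_trainable

ray.init()


def my_trainable(config):
    # A dummy function trainable; the name and metric are illustrative only.
    return {"score": config.get("x", 0)}


# With this patch, the entry is keyed by the current job's ID instead of the
# shared "global" prefix (roughly "<job_id>/<category>/<name>"), so a second
# driver registering the same name writes into its own namespace.
register_trainable("my_trainable", my_trainable)
assert _global_registry.contains(TRAINABLE_CLASS, "my_trainable")

# The driver also installs an atexit hook that deletes this job's entries
# when the script exits, so stale keys don't accumulate in the cluster.
ray.shutdown()
```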
diff --git a/python/ray/tune/tests/_test_multi_tenancy_run.py b/python/ray/tune/tests/_test_multi_tenancy_run.py
index 2b42a96c0f12..74da15b28328 100644
--- a/python/ray/tune/tests/_test_multi_tenancy_run.py
+++ b/python/ray/tune/tests/_test_multi_tenancy_run.py
@@ -33,9 +33,6 @@
 # are tracked by the driver, not the trainable.
 VALS = [int(os.environ["VAL_1"]), int(os.environ["VAL_2"])]
 
-# If 1, use workaround, if 0, just run (and fail in job 1).
-USE_WORKAROUND = bool(int(os.environ["WORKAROUND"]))
-
 # Wait for HANG_RUN_MARKER
 while HANG_RUN_MARKER and Path(HANG_RUN_MARKER).exists():
     time.sleep(0.1)
@@ -56,13 +53,6 @@ def train_func(config):
     session.report({"param": config["param"], "fixed": config["fixed"]})
 
 
-# Workaround: Just use a unique name per trainer/trainable
-if USE_WORKAROUND:
-    import uuid
-
-    DataParallelTrainer.__name__ = "DataParallelTrainer_" + uuid.uuid4().hex[:8]
-
-
 trainer = DataParallelTrainer(
     train_loop_per_worker=train_func,
     train_loop_config={
@@ -97,9 +87,3 @@ def train_func(config):
 # Put assertions last, so we don't finish early because of failures
 assert sorted([result.metrics["param"] for result in results]) == VALS
 assert [result.metrics["fixed"] for result in results] == [FIXED_VAL, FIXED_VAL]
-
-if USE_WORKAROUND:
-    from ray.experimental.internal_kv import _internal_kv_del
-    from ray.tune.registry import _make_key, TRAINABLE_CLASS
-
-    _internal_kv_del(_make_key("global", TRAINABLE_CLASS, DataParallelTrainer.__name__))
diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py
index c89202b50d2b..b99d3fbbce57 100644
--- a/python/ray/tune/tests/test_experiment_analysis.py
+++ b/python/ray/tune/tests/test_experiment_analysis.py
@@ -390,9 +390,7 @@ def testPickling(self):
         self.assertTrue(analysis.get_best_trial(metric=self.metric, mode="max"))
 
         ray.shutdown()
-        ray.tune.registry._global_registry = ray.tune.registry._Registry(
-            prefix="global"
-        )
+        ray.tune.registry._global_registry = ray.tune.registry._Registry()
 
         with open(pickle_path, "rb") as f:
             analysis = pickle.load(f)
@@ -406,9 +404,7 @@ def testFromPath(self):
         self.assertTrue(analysis.get_best_trial(metric=self.metric, mode="max"))
 
         ray.shutdown()
-        ray.tune.registry._global_registry = ray.tune.registry._Registry(
-            prefix="global"
-        )
+        ray.tune.registry._global_registry = ray.tune.registry._Registry()
 
         analysis = ExperimentAnalysis(self.test_path)
 
diff --git a/python/ray/tune/tests/test_multi_tenancy.py b/python/ray/tune/tests/test_multi_tenancy.py
index 71062c8595cc..b363d7dff7b6 100644
--- a/python/ray/tune/tests/test_multi_tenancy.py
+++ b/python/ray/tune/tests/test_multi_tenancy.py
@@ -15,14 +15,13 @@ def ray_start_4_cpus():
     ray.shutdown()
 
 
-@pytest.mark.parametrize("use_workaround", [False, True])
 @pytest.mark.parametrize("exit_same", [False, True])
-def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
+def test_registry_conflict(ray_start_4_cpus, tmpdir, exit_same):
     """Two concurrent Tune runs can conflict with each other when they
     use a trainable with the same name.
 
-    This test starts two runs in parallel and asserts that a workaround used
-    in the docs can alleviate the problem.
+    This test starts two runs in parallel and asserts that our fix in
+    https://github.com/ray-project/ray/pull/33095 resolves the issue.
 
     This is how we schedule the runs:
 
@@ -42,10 +41,6 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
     - Run 1 finally finishes, and we compare the expected results with the actual
       results.
 
-    When you don't use the workaround, expect an assertion error (if ``exit_same=True``,
-    see below), otherwise a KeyError (because a trial failed).
-    When the workaround is used, we expect everything to run without error.
-
     NOTE: Two errors can occur with registry conflicts.
     First, the trainable can be overwritten and captured, for example, when a fixed
     value is included in the trainable. The second trial of run 1 then has a wrong
@@ -57,10 +52,6 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
     removed already. Note that these objects are registered with
     ``tune.with_parameters()`` (not the global registry store). We test both
     scenarios using the ``exit_same`` parameter.
-
-    NOTE: If we resolve the registry issue (for example, with unique keys)
-    you can remove the test that expects the assertion error. We can remove
-    the parametrization and the workaround and assert that no conflict occurs.
     """
     # Create file markers
     run_1_running = tmpdir / "run_1_running"
@@ -75,7 +66,6 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
 
     run_1_env = {
         "RAY_ADDRESS": ray_address,
-        "WORKAROUND": str(int(use_workaround)),
         "FIXED_VAL": str(1),
         "VAL_1": str(2),
         "VAL_2": str(3),
@@ -93,7 +83,6 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
 
     run_2_env = {
         "RAY_ADDRESS": ray_address,
-        "WORKAROUND": str(int(use_workaround)),
         "FIXED_VAL": str(4),
         "VAL_1": str(5),
         "VAL_2": str(6),
@@ -123,18 +112,7 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
     print("Started run 2:", run_2.pid)
 
     assert run_2.wait() == 0
-
-    if use_workaround:
-        assert run_1.wait() == 0
-    else:
-        assert run_1.wait() != 0
-
-        stderr = run_1.stderr.read().decode()
-
-        if not exit_same:
-            assert "OwnerDiedError" in stderr, stderr
-        else:
-            assert "AssertionError" in stderr, stderr
+    assert run_1.wait() == 0
 
 
 if __name__ == "__main__":
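The registry change in this diff combines two small patterns: namespacing keys by the caller's job ID and cleaning them up from an atexit hook that is installed at most once, on the driver only. The following self-contained toy, which is illustrative only and not Ray code, shows roughly how those two pieces keep two concurrent jobs from overwriting each other's entries in a shared store.

```python
import atexit

# A single shared dict stands in for Ray's cluster-wide internal KV store.
SHARED_KV = {}


class ToyRegistry:
    """Illustrative stand-in for _Registry; names and layout are assumptions."""

    def __init__(self, job_id: str):
        # The real class derives this from ray.get_runtime_context().get_job_id().
        self._prefix = job_id
        self._registered = set()
        self._atexit_registered = False

    def _make_key(self, category: str, key: str) -> str:
        return f"{self._prefix}/{category}/{key}"

    def _register_atexit(self):
        # Install the cleanup hook at most once per registry instance.
        if self._atexit_registered:
            return
        atexit.register(self.unregister_all)
        self._atexit_registered = True

    def register(self, category: str, key: str, value) -> None:
        self._register_atexit()
        SHARED_KV[self._make_key(category, key)] = value
        self._registered.add((category, key))

    def get(self, category: str, key: str):
        return SHARED_KV[self._make_key(category, key)]

    def unregister_all(self) -> None:
        # Delete only the keys this job created; other jobs' keys stay put.
        for category, key in self._registered:
            SHARED_KV.pop(self._make_key(category, key), None)
        self._registered.clear()


# Two "jobs" registering the same trainable name no longer collide,
# because the job ID is part of the key.
job_a, job_b = ToyRegistry("job-a"), ToyRegistry("job-b")
job_a.register("trainable", "DataParallelTrainer", {"fixed": 1})
job_b.register("trainable", "DataParallelTrainer", {"fixed": 4})
assert job_a.get("trainable", "DataParallelTrainer") == {"fixed": 1}
assert job_b.get("trainable", "DataParallelTrainer") == {"fixed": 4}
```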