Skip to content

Commit

Permalink
[tune] Prefix global object registry with job ID to avoid conflicts i…
Browse files Browse the repository at this point in the history
…n multi tenancy (#33095)

In #32560, we documented a workaround for the multi-tenancy issues in Ray Tune, e.g. those described in #30091.

This PR fixes the root issue by prefixing the global registry with the core worker job ID, which is unique per driver process. This avoids conflicts between Tune trials that run in parallel.

To prove that it works, we modify the fix from #32560 to not require a workaround anymore.

To avoid cluttering the global key-value store with stale objects, we also de-register objects from the global KV store after finishing a Ray Tune run.

Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: Kai Fricke <coding@kaifricke.com>
  • Loading branch information
krfricke authored Mar 7, 2023
1 parent 6c20bc6 commit 31a991f
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 60 deletions.
10 changes: 4 additions & 6 deletions doc/source/tune/faq.rst
Original file line number Diff line number Diff line change
Expand Up @@ -797,13 +797,11 @@ The reasons for this are:
3. Concurrent jobs are harder to debug. If a trial of job A fills the disk,
trials from job B on the same node are impacted. In practice, it's hard
to reason about these conditions from the logs if something goes wrong.
4. Some internal implementations in Ray Tune assume that you only have one job
running at a time. This can lead to conflicts.

The fourth reason is especially problematic when you run concurrent tuning jobs. For instance,
a symptom is when trials from job A use parameters specified in job B, leading to unexpected
results.
Previously, some internal implementations in Ray Tune assumed that you only have one job
running at a time. A symptom was when trials from job A used parameters specified in job B,
leading to unexpected results.

Please refer to
`this GitHub issue <https://github.com/ray-project/ray/issues/30091#issuecomment-1431676976>`_
for more context and a workaround.
for more context and a workaround if you run into this issue.
69 changes: 63 additions & 6 deletions python/ray/tune/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import atexit
import logging
import uuid
from functools import partial
from types import FunctionType
from typing import Callable, Optional, Type, Union
Expand All @@ -10,6 +10,7 @@
_internal_kv_get,
_internal_kv_initialized,
_internal_kv_put,
_internal_kv_del,
)
from ray.tune.error import TuneError
from ray.util.annotations import DeveloperAPI
Expand Down Expand Up @@ -111,6 +112,10 @@ def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool =
_global_registry.register(TRAINABLE_CLASS, name, trainable)


def _unregister_trainables():
    """Unregister every tracked ``TRAINABLE_CLASS`` entry of the global registry."""
    _global_registry.unregister_all(TRAINABLE_CLASS)


@DeveloperAPI
def register_env(name: str, env_creator: Callable):
"""Register a custom environment for use with RLlib.
Expand All @@ -128,6 +133,10 @@ def register_env(name: str, env_creator: Callable):
_global_registry.register(ENV_CREATOR, name, env_creator)


def _unregister_envs():
    """Unregister every tracked ``ENV_CREATOR`` entry of the global registry."""
    _global_registry.unregister_all(ENV_CREATOR)


@DeveloperAPI
def register_input(name: str, input_creator: Callable):
"""Register a custom input api for RLlib.
Expand All @@ -142,6 +151,10 @@ def register_input(name: str, input_creator: Callable):
_global_registry.register(RLLIB_INPUT, name, input_creator)


def _unregister_inputs():
    """Unregister every tracked ``RLLIB_INPUT`` entry of the global registry."""
    _global_registry.unregister_all(RLLIB_INPUT)


@DeveloperAPI
def registry_contains_input(name: str) -> bool:
return _global_registry.contains(RLLIB_INPUT, name)
Expand All @@ -152,6 +165,12 @@ def registry_get_input(name: str) -> Callable:
return _global_registry.get(RLLIB_INPUT, name)


def _unregister_all():
    """Run the per-category cleanup helpers: inputs, then envs, then trainables."""
    cleanup_steps = (
        _unregister_inputs,
        _unregister_envs,
        _unregister_trainables,
    )
    for cleanup in cleanup_steps:
        cleanup()


def _check_serializability(key, value):
_global_registry.register(TEST, key, value)

Expand Down Expand Up @@ -179,8 +198,29 @@ def _make_key(prefix: str, category: str, key: str):

class _Registry:
def __init__(self, prefix: Optional[str] = None):
"""If no prefix is given, use runtime context job ID."""
self._to_flush = {}
self._prefix = prefix or uuid.uuid4().hex[:8]
self._prefix = prefix
self._registered = set()
self._atexit_handler_registered = False

@property
def prefix(self):
    """Key prefix used for this registry's entries.

    If no (truthy) prefix has been set, the job ID from the Ray runtime
    context is looked up on first access and cached on the instance.
    """
    if self._prefix:
        return self._prefix
    # Resolved lazily on first access rather than at construction time.
    self._prefix = ray.get_runtime_context().get_job_id()
    return self._prefix

def _register_atexit(self):
if self._atexit_handler_registered:
# Already registered
return

if ray._private.worker.global_worker.mode != ray.SCRIPT_MODE:
# Only cleanup on the driver
return

atexit.register(_unregister_all)
self._atexit_handler_registered = True

def register(self, category, key, value):
"""Registers the value with the global registry.
Expand All @@ -198,16 +238,31 @@ def register(self, category, key, value):
if _internal_kv_initialized():
self.flush_values()

def unregister(self, category, key):
    """Remove a single ``(category, key)`` entry.

    If the internal KV store is not initialized yet, the entry is dropped
    from the pending ``_to_flush`` buffer (a missing entry is ignored).
    Otherwise the prefixed key is deleted from the internal KV store.
    """
    if not _internal_kv_initialized():
        self._to_flush.pop((category, key), None)
        return
    _internal_kv_del(_make_key(self.prefix, category, key))

def unregister_all(self, category: Optional[str] = None):
    """Unregister tracked entries, optionally restricted to one category.

    Args:
        category: Category whose entries should be removed. If ``None``
            (the default), entries of every category are removed.
    """
    remaining = set()
    for (cat, key) in self._registered:
        # ``None`` means "no filter". The previous check
        # (``category and category == cat``) never matched in that case,
        # so calling ``unregister_all()`` silently removed nothing.
        if category is None or category == cat:
            self.unregister(cat, key)
        else:
            remaining.add((cat, key))
    # Keep tracking only the entries that were not unregistered.
    self._registered = remaining

def contains(self, category, key):
if _internal_kv_initialized():
value = _internal_kv_get(_make_key(self._prefix, category, key))
value = _internal_kv_get(_make_key(self.prefix, category, key))
return value is not None
else:
return (category, key) in self._to_flush

def get(self, category, key):
if _internal_kv_initialized():
value = _internal_kv_get(_make_key(self._prefix, category, key))
value = _internal_kv_get(_make_key(self.prefix, category, key))
if value is None:
raise ValueError(
"Registry value for {}/{} doesn't exist.".format(category, key)
Expand All @@ -217,14 +272,16 @@ def get(self, category, key):
return pickle.loads(self._to_flush[(category, key)])

def flush_values(self):
self._register_atexit()
for (category, key), value in self._to_flush.items():
_internal_kv_put(
_make_key(self._prefix, category, key), value, overwrite=True
_make_key(self.prefix, category, key), value, overwrite=True
)
self._registered.add((category, key))
self._to_flush.clear()


_global_registry = _Registry(prefix="global")
_global_registry = _Registry()
ray._private.worker._post_init_hooks.append(_global_registry.flush_values)


Expand Down
16 changes: 0 additions & 16 deletions python/ray/tune/tests/_test_multi_tenancy_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
# are tracked by the driver, not the trainable.
VALS = [int(os.environ["VAL_1"]), int(os.environ["VAL_2"])]

# If 1, use workaround, if 0, just run (and fail in job 1).
USE_WORKAROUND = bool(int(os.environ["WORKAROUND"]))

# Wait for HANG_RUN_MARKER
while HANG_RUN_MARKER and Path(HANG_RUN_MARKER).exists():
time.sleep(0.1)
Expand All @@ -56,13 +53,6 @@ def train_func(config):
session.report({"param": config["param"], "fixed": config["fixed"]})


# Workaround: Just use a unique name per trainer/trainable
if USE_WORKAROUND:
import uuid

DataParallelTrainer.__name__ = "DataParallelTrainer_" + uuid.uuid4().hex[:8]


trainer = DataParallelTrainer(
train_loop_per_worker=train_func,
train_loop_config={
Expand Down Expand Up @@ -97,9 +87,3 @@ def train_func(config):
# Put assertions last, so we don't finish early because of failures
assert sorted([result.metrics["param"] for result in results]) == VALS
assert [result.metrics["fixed"] for result in results] == [FIXED_VAL, FIXED_VAL]

if USE_WORKAROUND:
from ray.experimental.internal_kv import _internal_kv_del
from ray.tune.registry import _make_key, TRAINABLE_CLASS

_internal_kv_del(_make_key("global", TRAINABLE_CLASS, DataParallelTrainer.__name__))
8 changes: 2 additions & 6 deletions python/ray/tune/tests/test_experiment_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,7 @@ def testPickling(self):
self.assertTrue(analysis.get_best_trial(metric=self.metric, mode="max"))

ray.shutdown()
ray.tune.registry._global_registry = ray.tune.registry._Registry(
prefix="global"
)
ray.tune.registry._global_registry = ray.tune.registry._Registry()

with open(pickle_path, "rb") as f:
analysis = pickle.load(f)
Expand All @@ -406,9 +404,7 @@ def testFromPath(self):
self.assertTrue(analysis.get_best_trial(metric=self.metric, mode="max"))

ray.shutdown()
ray.tune.registry._global_registry = ray.tune.registry._Registry(
prefix="global"
)
ray.tune.registry._global_registry = ray.tune.registry._Registry()

analysis = ExperimentAnalysis(self.test_path)

Expand Down
30 changes: 4 additions & 26 deletions python/ray/tune/tests/test_multi_tenancy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,13 @@ def ray_start_4_cpus():
ray.shutdown()


@pytest.mark.parametrize("use_workaround", [False, True])
@pytest.mark.parametrize("exit_same", [False, True])
def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
def test_registry_conflict(ray_start_4_cpus, tmpdir, exit_same):
"""Two concurrent Tune runs can conflict with each other when they
use a trainable with the same name.
This test starts two runs in parallel and asserts that a workaround used
in the docs can alleviate the problem.
This test starts two runs in parallel and asserts that our fix in
https://github.com/ray-project/ray/pull/33095 resolves the issue.
This is how we schedule the runs:
Expand All @@ -42,10 +41,6 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
- Run 1 finally finishes, and we compare the expected results with the actual
results.
When you don't use the workaround, expect an assertion error (if ``exit_same=True``,
see below), otherwise a KeyError (because a trial failed).
When the workaround is used, we expect everything to run without error.
NOTE: Two errors can occur with registry conflicts. First,
the trainable can be overwritten and captured, for example, when a fixed value
is included in the trainable. The second trial of run 1 then has a wrong
Expand All @@ -57,10 +52,6 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
removed already. Note that these objects are registered with
``tune.with_parameters()`` (not the global registry store).
We test both scenarios using the ``exit_same`` parameter.
NOTE: If we resolve the registry issue (for example, with unique keys)
you can remove the test that expects the assertion error. We can remove
the parametrization and the workaround and assert that no conflict occurs.
"""
# Create file markers
run_1_running = tmpdir / "run_1_running"
Expand All @@ -75,7 +66,6 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):

run_1_env = {
"RAY_ADDRESS": ray_address,
"WORKAROUND": str(int(use_workaround)),
"FIXED_VAL": str(1),
"VAL_1": str(2),
"VAL_2": str(3),
Expand All @@ -93,7 +83,6 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):

run_2_env = {
"RAY_ADDRESS": ray_address,
"WORKAROUND": str(int(use_workaround)),
"FIXED_VAL": str(4),
"VAL_1": str(5),
"VAL_2": str(6),
Expand Down Expand Up @@ -123,18 +112,7 @@ def test_registry_conflict(ray_start_4_cpus, tmpdir, use_workaround, exit_same):
print("Started run 2:", run_2.pid)

assert run_2.wait() == 0

if use_workaround:
assert run_1.wait() == 0
else:
assert run_1.wait() != 0

stderr = run_1.stderr.read().decode()

if not exit_same:
assert "OwnerDiedError" in stderr, stderr
else:
assert "AssertionError" in stderr, stderr
assert run_1.wait() == 0


if __name__ == "__main__":
Expand Down

0 comments on commit 31a991f

Please sign in to comment.