diff --git a/dashboard/modules/job/tests/test_job_agent.py b/dashboard/modules/job/tests/test_job_agent.py index 37d51161ed77..b6e1a838c378 100644 --- a/dashboard/modules/job/tests/test_job_agent.py +++ b/dashboard/modules/job/tests/test_job_agent.py @@ -532,6 +532,7 @@ def test_agent_logs_not_streamed_to_drivers(): err_str = proc.stderr.read().decode("ascii") print(out_str, err_str) + assert "(raylet)" not in out_str assert "(raylet)" not in err_str diff --git a/python/ray/_private/gcs_aio_client.py b/python/ray/_private/gcs_aio_client.py index d5502c6dcea1..f794fd7c7eee 100644 --- a/python/ray/_private/gcs_aio_client.py +++ b/python/ray/_private/gcs_aio_client.py @@ -56,7 +56,9 @@ def __init__( self._nums_reconnect_retry = nums_reconnect_retry def _connect(self): + print("vct connecting") self._gcs_client._connect() + print("vct connected") @property def address(self): diff --git a/python/ray/_private/node.py b/python/ray/_private/node.py index cdf4dd079734..181396049f4e 100644 --- a/python/ray/_private/node.py +++ b/python/ray/_private/node.py @@ -648,7 +648,11 @@ def _init_gcs_client(self): last_ex = None try: gcs_address = self.gcs_address - client = GcsClient(address=gcs_address) + client = GcsClient( + address=gcs_address, + cluster_id=self._ray_params.cluster_id, + ) + self.cluster_id = client.get_cluster_id() if self.head: # Send a simple request to make sure GCS is alive # if it's a head node. @@ -664,19 +668,26 @@ def _init_gcs_client(self): time.sleep(1) if self._gcs_client is None: - with open(os.path.join(self._logs_dir, "gcs_server.err")) as err: - # Use " C " or " E " to exclude the stacktrace. - # This should work for most cases, especitally - # it's when GCS is starting. Only display last 10 lines of logs. - errors = [e for e in err.readlines() if " C " in e or " E " in e][-10:] - error_msg = "\n" + "".join(errors) + "\n" - raise RuntimeError( - f"Failed to {'start' if self.head else 'connect to'} GCS. " - f" Last {len(errors)} lines of error files:" - f"{error_msg}." - f"Please check {os.path.join(self._logs_dir, 'gcs_server.out')}" - " for details" - ) + if hasattr(self, "_logs_dir"): + with open(os.path.join(self._logs_dir, "gcs_server.err")) as err: + # Use " C " or " E " to exclude the stacktrace. + # This should work for most cases, especitally + # it's when GCS is starting. Only display last 10 lines of logs. + errors = [e for e in err.readlines() if " C " in e or " E " in e][ + -10: + ] + error_msg = "\n" + "".join(errors) + "\n" + raise RuntimeError( + f"Failed to {'start' if self.head else 'connect to'} GCS. " + f" Last {len(errors)} lines of error files:" + f"{error_msg}." + f"Please check {os.path.join(self._logs_dir, 'gcs_server.out')}" + " for details" + ) + else: + raise RuntimeError( + f"Failed to {'start' if self.head else 'connect to'} GCS." + ) ray.experimental.internal_kv._initialize_internal_kv(self._gcs_client) @@ -1064,6 +1075,7 @@ def start_raylet( self._ray_params.node_manager_port, self._raylet_socket_name, self._plasma_store_socket_name, + self.cluster_id, self._ray_params.worker_path, self._ray_params.setup_worker_path, self._ray_params.storage, diff --git a/python/ray/_private/parameter.py b/python/ray/_private/parameter.py index d383f76370a4..538fe09a505d 100644 --- a/python/ray/_private/parameter.py +++ b/python/ray/_private/parameter.py @@ -127,6 +127,7 @@ class RayParams: env_vars: Override environment variables for the raylet. session_name: The name of the session of the ray cluster. webui: The url of the UI. + cluster_id: The cluster ID. """ def __init__( @@ -188,6 +189,7 @@ def __init__( env_vars: Optional[Dict[str, str]] = None, session_name: Optional[str] = None, webui: Optional[str] = None, + cluster_id: Optional[str] = None, ): self.redis_address = redis_address self.gcs_address = gcs_address @@ -249,6 +251,7 @@ def __init__( self._enable_object_reconstruction = enable_object_reconstruction self.labels = labels self._check_usage() + self.cluster_id = cluster_id # Set the internal config options for object reconstruction. if enable_object_reconstruction: diff --git a/python/ray/_private/services.py b/python/ray/_private/services.py index 96671793928b..996a2289cfa2 100644 --- a/python/ray/_private/services.py +++ b/python/ray/_private/services.py @@ -1357,6 +1357,7 @@ def start_raylet( node_manager_port: int, raylet_name: str, plasma_store_name: str, + cluster_id: str, worker_path: str, setup_worker_path: str, storage: str, @@ -1538,6 +1539,7 @@ def start_raylet( f"--session-name={session_name}", f"--temp-dir={temp_dir}", f"--webui={webui}", + f"--cluster-id={cluster_id}", ] ) @@ -1643,6 +1645,7 @@ def start_raylet( f"--gcs-address={gcs_address}", f"--session-name={session_name}", f"--labels={labels_json_str}", + f"--cluster-id={cluster_id}", ] if is_head_node: diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 6b053b43e49d..19e85bc97401 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -1581,7 +1581,7 @@ def init( spawn_reaper=False, connect_only=True, ) - except ConnectionError: + except (ConnectionError, RuntimeError): if gcs_address == ray._private.utils.read_ray_address(_temp_dir): logger.info( "Failed to connect to the default Ray cluster address at " @@ -1590,7 +1590,7 @@ def init( "address to connect to, run `ray stop` or restart Ray with " "`ray start`." ) - raise + raise ConnectionError # Log a message to find the Ray address that we connected to and the # dashboard URL. @@ -2262,6 +2262,7 @@ def connect( runtime_env_hash, startup_token, session_name, + node.cluster_id, "" if mode != SCRIPT_MODE else entrypoint, worker_launch_time_ms, worker_launched_time_ms, diff --git a/python/ray/_private/workers/default_worker.py b/python/ray/_private/workers/default_worker.py index 6000b4bd6fe2..6b49b685f707 100644 --- a/python/ray/_private/workers/default_worker.py +++ b/python/ray/_private/workers/default_worker.py @@ -17,6 +17,12 @@ parser = argparse.ArgumentParser( description=("Parse addresses for the worker to connect to.") ) +parser.add_argument( + "--cluster-id", + required=True, + type=str, + help="the auto-generated ID of the cluster", +) parser.add_argument( "--node-ip-address", required=True, @@ -207,6 +213,7 @@ gcs_address=args.gcs_address, session_name=args.session_name, webui=args.webui, + cluster_id=args.cluster_id, ) node = ray._private.node.Node( ray_params, diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 4099ac64645e..a465c15f952f 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -102,7 +102,6 @@ cdef class ObjectRef(BaseID): cdef CObjectID native(self) - cdef class ActorID(BaseID): cdef CActorID data diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index b4832b9302dd..f4f92fe8f612 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -129,8 +129,9 @@ from ray.includes.common cimport ( ) from ray.includes.unique_ids cimport ( CActorID, - CObjectID, + CClusterID, CNodeID, + CObjectID, CPlacementGroupID, ObjectIDIndexType, ) @@ -2335,16 +2336,35 @@ cdef class GcsClient: shared_ptr[CPythonGcsClient] inner object address object _nums_reconnect_retry + CClusterID cluster_id - def __cinit__(self, address, nums_reconnect_retry=5): + def __cinit__(self, address, nums_reconnect_retry=5, cluster_id=None): cdef GcsClientOptions gcs_options = GcsClientOptions.from_gcs_address(address) self.inner.reset(new CPythonGcsClient(dereference(gcs_options.native()))) self.address = address self._nums_reconnect_retry = nums_reconnect_retry - self._connect() + cdef c_string c_cluster_id + if cluster_id is None: + self.cluster_id = CClusterID.Nil() + else: + c_cluster_id = cluster_id + self.cluster_id = CClusterID.FromHex(c_cluster_id) + self._connect(5) - def _connect(self): - check_status(self.inner.get().Connect()) + def _connect(self, timeout_s=None): + cdef: + int64_t timeout_ms = round(1000 * timeout_s) if timeout_s else -1 + size_t num_retries = self._nums_reconnect_retry + with nogil: + status = self.inner.get().Connect(self.cluster_id, timeout_ms, num_retries) + + check_status(status) + if self.cluster_id.IsNil(): + self.cluster_id = self.inner.get().GetClusterId() + assert not self.cluster_id.IsNil() + + def get_cluster_id(self): + return self.cluster_id.Hex().decode() @property def address(self): @@ -2844,7 +2864,7 @@ cdef class CoreWorker: node_ip_address, node_manager_port, raylet_ip_address, local_mode, driver_name, stdout_file, stderr_file, serialized_job_config, metrics_agent_port, runtime_env_hash, - startup_token, session_name, entrypoint, + startup_token, session_name, cluster_id, entrypoint, worker_launch_time_ms, worker_launched_time_ms): self.is_local_mode = local_mode @@ -2896,6 +2916,7 @@ cdef class CoreWorker: options.runtime_env_hash = runtime_env_hash options.startup_token = startup_token options.session_name = session_name + options.cluster_id = CClusterID.FromHex(cluster_id) options.entrypoint = entrypoint options.worker_launch_time_ms = worker_launch_time_ms options.worker_launched_time_ms = worker_launched_time_ms diff --git a/python/ray/autoscaler/_private/monitor.py b/python/ray/autoscaler/_private/monitor.py index 9f22f283b54f..34db56f511f9 100644 --- a/python/ray/autoscaler/_private/monitor.py +++ b/python/ray/autoscaler/_private/monitor.py @@ -150,6 +150,7 @@ def __init__( gcs_channel ) worker = ray._private.worker.global_worker + # TODO: eventually plumb ClusterID through to here gcs_client = GcsClient(address=self.gcs_address) if monitor_ip: diff --git a/python/ray/autoscaler/v2/BUILD b/python/ray/autoscaler/v2/BUILD index ddc043706ab8..839f29587e74 100644 --- a/python/ray/autoscaler/v2/BUILD +++ b/python/ray/autoscaler/v2/BUILD @@ -80,7 +80,7 @@ py_test( py_test( name = "test_sdk", - size = "small", + size = "medium", srcs = ["tests/test_sdk.py"], tags = ["team:core", "exclusive"], deps = ["//:ray_lib", ":conftest"], diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 66c54b12f285..1f8c25ee525a 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -12,6 +12,7 @@ from ray.includes.optional cimport ( from ray.includes.unique_ids cimport ( CActorID, CJobID, + CClusterID, CWorkerID, CObjectID, CTaskID, @@ -367,8 +368,10 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil: cdef cppclass CPythonGcsClient "ray::gcs::PythonGcsClient": CPythonGcsClient(const CGcsClientOptions &options) - CRayStatus Connect() - + CRayStatus Connect( + const CClusterID &cluster_id, + int64_t timeout_ms, + size_t num_retries) CRayStatus CheckAlive( const c_vector[c_string] &raylet_addresses, int64_t timeout_ms, @@ -405,6 +408,7 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil: CRayStatus GetClusterStatus( int64_t timeout_ms, c_string &serialized_reply) + CClusterID GetClusterId() CRayStatus DrainNode( const c_string &node_id, int32_t reason, @@ -412,7 +416,6 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil: int64_t timeout_ms, c_bool &is_accepted) - cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace "ray::gcs" nogil: unordered_map[c_string, double] PythonGetResourcesTotal( const CGcsNodeInfo& node_info) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 3afb811405d2..1f52bbea0af0 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -13,6 +13,7 @@ from libcpp.vector cimport vector as c_vector from ray.includes.unique_ids cimport ( CActorID, + CClusterID, CNodeID, CJobID, CTaskID, @@ -359,6 +360,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_bool connect_on_start int runtime_env_hash int startup_token + CClusterID cluster_id c_string session_name c_string entrypoint int64_t worker_launch_time_ms diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index 2fb14e6322c0..cdb9b58d9188 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -7,6 +7,9 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: @staticmethod T FromBinary(const c_string &binary) + @staticmethod + T FromHex(const c_string &hex) + @staticmethod const T Nil() @@ -154,6 +157,17 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: CTaskID TaskId() const + cdef cppclass CClusterID "ray::ClusterID"(CUniqueID): + + @staticmethod + CClusterID FromHex(const c_string &hex_str) + + @staticmethod + CClusterID FromRandom() + + @staticmethod + const CClusterID Nil() + cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): @staticmethod diff --git a/python/ray/serve/BUILD b/python/ray/serve/BUILD index e87ed72d76cc..12631511778c 100644 --- a/python/ray/serve/BUILD +++ b/python/ray/serve/BUILD @@ -683,7 +683,7 @@ py_test( py_test( name = "test_callback", - size = "small", + size = "medium", srcs = serve_tests_srcs, tags = ["exclusive", "team:serve"], deps = [":serve_lib"], diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index 6e962187372e..7c78c3c3da16 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -211,7 +211,6 @@ py_test_module_list( "test_unhandled_error.py", "test_utils.py", "test_widgets.py", - "test_node_labels.py", ], size = "small", tags = ["exclusive", "small_size_python_tests", "team:core"], @@ -222,6 +221,7 @@ py_test_module_list( files = [ "test_gcs_ha_e2e.py", "test_memory_pressure.py", + "test_node_labels.py", ], size = "medium", tags = ["exclusive", "team:core", "xcommit"], diff --git a/python/ray/tests/test_autoscaler.py b/python/ray/tests/test_autoscaler.py index 0b6b7210157e..861349774ec7 100644 --- a/python/ray/tests/test_autoscaler.py +++ b/python/ray/tests/test_autoscaler.py @@ -3514,6 +3514,9 @@ class FaultyAutoscaler: def __init__(self, *args, **kwargs): raise AutoscalerInitFailException + prev_port = os.environ.get("RAY_GCS_SERVER_PORT") + os.environ["RAY_GCS_SERVER_PORT"] = "12345" + ray.init() with patch("ray._private.utils.publish_error_to_driver") as mock_publish: with patch.multiple( "ray.autoscaler._private.monitor", @@ -3521,11 +3524,17 @@ def __init__(self, *args, **kwargs): _internal_kv_initialized=Mock(return_value=False), ): monitor = Monitor( - address="here:12345", autoscaling_config="", log_dir=self.tmpdir + address="localhost:12345", + autoscaling_config="", + log_dir=self.tmpdir, ) with pytest.raises(AutoscalerInitFailException): monitor.run() mock_publish.assert_called_once() + if prev_port is not None: + os.environ["RAY_GCS_SERVER_PORT"] = prev_port + else: + del os.environ["RAY_GCS_SERVER_PORT"] def testInitializeSDKArguments(self): # https://github.com/ray-project/ray/issues/23166 diff --git a/python/ray/tests/test_gcs_fault_tolerance.py b/python/ray/tests/test_gcs_fault_tolerance.py index 869c9ab648fe..fcd260a77307 100644 --- a/python/ray/tests/test_gcs_fault_tolerance.py +++ b/python/ray/tests/test_gcs_fault_tolerance.py @@ -433,13 +433,16 @@ def test_gcs_aio_client_reconnect( passed = [False] async def async_kv_get(): - gcs_aio_client = gcs_utils.GcsAioClient( - address=gcs_address, nums_reconnect_retry=20 if auto_reconnect else 0 - ) if not auto_reconnect: with pytest.raises(Exception): + gcs_aio_client = gcs_utils.GcsAioClient( + address=gcs_address, nums_reconnect_retry=0 + ) await gcs_aio_client.internal_kv_get(b"a", None) else: + gcs_aio_client = gcs_utils.GcsAioClient( + address=gcs_address, nums_reconnect_retry=20 + ) assert await gcs_aio_client.internal_kv_get(b"a", None) == b"b" return True diff --git a/python/ray/tests/test_ray_init_2.py b/python/ray/tests/test_ray_init_2.py index 55821eeebf1b..7118873288bf 100644 --- a/python/ray/tests/test_ray_init_2.py +++ b/python/ray/tests/test_ray_init_2.py @@ -271,7 +271,7 @@ def verify(): return True try: - wait_for_condition(verify, timeout=10, retry_interval_ms=2000) + wait_for_condition(verify, timeout=15, retry_interval_ms=2000) finally: proc.terminate() proc.wait() diff --git a/python/ray/train/BUILD b/python/ray/train/BUILD index f28992e4de16..2f3f8f87cff2 100644 --- a/python/ray/train/BUILD +++ b/python/ray/train/BUILD @@ -326,6 +326,14 @@ py_test( deps = [":train_lib", ":conftest"] ) +py_test( + name = "test_gpu_2", + size = "medium", + srcs = ["tests/test_gpu_2.py"], + tags = ["team:ml", "exclusive", "gpu_only"], + deps = [":train_lib", ":conftest"] +) + py_test( name = "test_gpu_auto_transfer", size = "medium", diff --git a/python/ray/train/tests/test_gpu.py b/python/ray/train/tests/test_gpu.py index 7ce50a9a153a..a5ac24e90c3e 100644 --- a/python/ray/train/tests/test_gpu.py +++ b/python/ray/train/tests/test_gpu.py @@ -3,7 +3,6 @@ from unittest.mock import patch import pytest -import numpy as np import torch import torchvision from torch.nn.parallel import DistributedDataParallel @@ -359,35 +358,6 @@ def train_fn(): assert isinstance(exc_info.value.__cause__, RayTaskError) -@pytest.mark.parametrize("use_gpu", (True, False)) -def test_torch_iter_torch_batches_auto_device(ray_start_4_cpus_2_gpus, use_gpu): - """ - Tests that iter_torch_batches in TorchTrainer worker function uses the - default device. - """ - - def train_fn(): - dataset = train.get_dataset_shard("train") - for batch in dataset.iter_torch_batches(dtypes=torch.float, device="cpu"): - assert str(batch["data"].device) == "cpu" - - # Autodetect - for batch in dataset.iter_torch_batches(dtypes=torch.float): - assert str(batch["data"].device) == str(train.torch.get_device()) - - dataset = ray.data.from_numpy(np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]).T) - # Test that this works outside a Train function - for batch in dataset.iter_torch_batches(dtypes=torch.float, device="cpu"): - assert str(batch["data"].device) == "cpu" - - trainer = TorchTrainer( - train_fn, - scaling_config=ScalingConfig(num_workers=2, use_gpu=use_gpu), - datasets={"train": dataset}, - ) - trainer.fit() - - if __name__ == "__main__": import sys diff --git a/python/ray/train/tests/test_gpu_2.py b/python/ray/train/tests/test_gpu_2.py new file mode 100644 index 000000000000..1889a7d3349e --- /dev/null +++ b/python/ray/train/tests/test_gpu_2.py @@ -0,0 +1,72 @@ +import pytest +import numpy as np +import torch + +import ray +import ray.data +from ray import tune + +import ray.train as train +from ray.air.config import ScalingConfig +from ray.train.examples.pytorch.torch_linear_example import LinearDataset +from ray.train.torch.torch_trainer import TorchTrainer + + +class LinearDatasetDict(LinearDataset): + """Modifies the LinearDataset to return a Dict instead of a Tuple.""" + + def __getitem__(self, index): + return {"x": self.x[index, None], "y": self.y[index, None]} + + +class NonTensorDataset(LinearDataset): + """Modifies the LinearDataset to also return non-tensor objects.""" + + def __getitem__(self, index): + return {"x": self.x[index, None], "y": 2} + + +# Currently in DataParallelTrainers we only report metrics from rank 0. +# For testing purposes here, we need to be able to report from all +# workers. +class TorchTrainerPatchedMultipleReturns(TorchTrainer): + def _report(self, training_iterator) -> None: + for results in training_iterator: + tune.report(results=results) + + +@pytest.mark.parametrize("use_gpu", (True, False)) +def test_torch_iter_torch_batches_auto_device(ray_start_4_cpus_2_gpus, use_gpu): + """ + Tests that iter_torch_batches in TorchTrainer worker function uses the + default device. + """ + + def train_fn(): + dataset = train.get_dataset_shard("train") + for batch in dataset.iter_torch_batches(dtypes=torch.float, device="cpu"): + assert str(batch["data"].device) == "cpu" + + # Autodetect + for batch in dataset.iter_torch_batches(dtypes=torch.float): + assert str(batch["data"].device) == str(train.torch.get_device()) + + dataset = ray.data.from_numpy(np.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5]]).T) + # Test that this works outside a Train function + for batch in dataset.iter_torch_batches(dtypes=torch.float, device="cpu"): + assert str(batch["data"].device) == "cpu" + + trainer = TorchTrainer( + train_fn, + scaling_config=ScalingConfig(num_workers=2, use_gpu=use_gpu), + datasets={"train": dataset}, + ) + trainer.fit() + + +if __name__ == "__main__": + import sys + + import pytest + + sys.exit(pytest.main(["-v", "-x", "-s", __file__])) diff --git a/python/ray/tune/tests/test_progress_reporter.py b/python/ray/tune/tests/test_progress_reporter.py index d1e8a0c93119..87855f66ca3c 100644 --- a/python/ray/tune/tests/test_progress_reporter.py +++ b/python/ray/tune/tests/test_progress_reporter.py @@ -689,7 +689,9 @@ def testEndToEndReporting(self): if os.environ.get("TUNE_NEW_EXECUTION") == "0": assert EXPECTED_END_TO_END_START in output assert EXPECTED_END_TO_END_END in output - assert "(raylet)" not in output, "Unexpected raylet log messages" + for line in output.splitlines(): + if "(raylet)" in line: + assert "cluster ID" in line, "Unexpected raylet log messages" except Exception: print("*** BEGIN OUTPUT ***") print(output) diff --git a/src/mock/ray/gcs/gcs_client/gcs_client.h b/src/mock/ray/gcs/gcs_client/gcs_client.h index e7b687d04e7d..cf232a712f51 100644 --- a/src/mock/ray/gcs/gcs_client/gcs_client.h +++ b/src/mock/ray/gcs/gcs_client/gcs_client.h @@ -31,7 +31,10 @@ namespace gcs { class MockGcsClient : public GcsClient { public: - MOCK_METHOD(Status, Connect, (instrumented_io_context & io_service), (override)); + MOCK_METHOD(Status, + Connect, + (instrumented_io_context & io_service, const ClusterID &cluster_id), + (override)); MOCK_METHOD(void, Disconnect, (), (override)); MOCK_METHOD((std::pair), GetGcsServerAddress, (), (const, override)); MOCK_METHOD(std::string, DebugString, (), (const, override)); diff --git a/src/ray/common/id.h b/src/ray/common/id.h index 7c5f430830ab..bca6c66492ef 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -395,6 +395,7 @@ std::ostream &operator<<(std::ostream &os, const PlacementGroupID &id); type() : UniqueID() {} \ static type FromRandom() { return type(UniqueID::FromRandom()); } \ static type FromBinary(const std::string &binary) { return type(binary); } \ + static type FromHex(const std::string &hex) { return type(UniqueID::FromHex(hex)); } \ static type Nil() { return type(UniqueID::Nil()); } \ static constexpr size_t Size() { return kUniqueIDSize; } \ \ @@ -414,27 +415,6 @@ std::ostream &operator<<(std::ostream &os, const PlacementGroupID &id); // Restore the compiler alignment to default (8 bytes). #pragma pack(pop) -struct SafeClusterID { - private: - mutable absl::Mutex m_; - ClusterID id_ GUARDED_BY(m_); - - public: - SafeClusterID(const ClusterID &id) : id_(id) {} - - const ClusterID load() const { - absl::MutexLock l(&m_); - return id_; - } - - ClusterID exchange(const ClusterID &newId) { - absl::MutexLock l(&m_); - ClusterID old = id_; - id_ = newId; - return old; - } -}; - template BaseID::BaseID() { // Using const_cast to directly change data is dangerous. The cached diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 65737cac0793..c9ac1752063f 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -229,7 +229,7 @@ CoreWorker::CoreWorker(const CoreWorkerOptions &options, const WorkerID &worker_ gcs_client_ = std::make_shared(options_.gcs_options, GetWorkerID()); - RAY_CHECK_OK(gcs_client_->Connect(io_service_)); + RAY_CHECK_OK(gcs_client_->Connect(io_service_, options_.cluster_id)); RegisterToGcs(options_.worker_launch_time_ms, options_.worker_launched_time_ms); // Initialize the task state event buffer. diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index be1b5e002775..05623bb25d36 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -92,6 +92,7 @@ struct CoreWorkerOptions { metrics_agent_port(-1), connect_on_start(true), runtime_env_hash(0), + cluster_id(ClusterID::Nil()), session_name(""), entrypoint(""), worker_launch_time_ms(-1), @@ -182,6 +183,8 @@ struct CoreWorkerOptions { /// may not have the same pid as the process the worker pool /// starts (due to shim processes). StartupToken startup_token{0}; + /// Cluster ID associated with the core worker. + ClusterID cluster_id; /// The function to allocate a new object for the memory store. /// This allows allocating the objects in the language frontend's memory. /// For example, for the Java worker, we can allocate the objects in the JVM heap diff --git a/src/ray/gcs/gcs_client/gcs_client.cc b/src/ray/gcs/gcs_client/gcs_client.cc index dfc4b2c5cfd9..4fcc2fad191f 100644 --- a/src/ray/gcs/gcs_client/gcs_client.cc +++ b/src/ray/gcs/gcs_client/gcs_client.cc @@ -14,6 +14,8 @@ #include "ray/gcs/gcs_client/gcs_client.h" +#include +#include #include #include "ray/common/ray_config.h" @@ -81,9 +83,10 @@ void GcsSubscriberClient::PubsubCommandBatch( GcsClient::GcsClient(const GcsClientOptions &options, UniqueID gcs_client_id) : options_(options), gcs_client_id_(gcs_client_id) {} -Status GcsClient::Connect(instrumented_io_context &io_service) { +Status GcsClient::Connect(instrumented_io_context &io_service, + const ClusterID &cluster_id) { // Connect to gcs service. - client_call_manager_ = std::make_unique(io_service); + client_call_manager_ = std::make_unique(io_service, cluster_id); gcs_rpc_client_ = std::make_shared( options_.gcs_address_, options_.gcs_port_, *client_call_manager_); @@ -143,9 +146,7 @@ std::pair GcsClient::GetGcsServerAddress() const { return gcs_rpc_client_->GetAddress(); } -PythonGcsClient::PythonGcsClient(const GcsClientOptions &options) : options_(options) {} - -Status PythonGcsClient::Connect() { +PythonGcsClient::PythonGcsClient(const GcsClientOptions &options) : options_(options) { channel_ = rpc::GcsRpcClient::CreateGcsChannel(options_.gcs_address_, options_.gcs_port_); kv_stub_ = rpc::InternalKVGcsService::NewStub(channel_); @@ -153,27 +154,59 @@ Status PythonGcsClient::Connect() { node_info_stub_ = rpc::NodeInfoGcsService::NewStub(channel_); job_info_stub_ = rpc::JobInfoGcsService::NewStub(channel_); autoscaler_stub_ = rpc::autoscaler::AutoscalerStateService::NewStub(channel_); - return Status::OK(); } +namespace { Status HandleGcsError(rpc::GcsStatus status) { - RAY_CHECK(status.code() != static_cast(StatusCode::OK)); + RAY_CHECK_NE(status.code(), static_cast(StatusCode::OK)); return Status::Invalid(status.message() + " [GCS status code: " + std::to_string(status.code()) + "]"); } +} // namespace -void GrpcClientContextWithTimeoutMs(grpc::ClientContext &context, int64_t timeout_ms) { - if (timeout_ms != -1) { - context.set_deadline(std::chrono::system_clock::now() + - std::chrono::milliseconds(timeout_ms)); +Status PythonGcsClient::Connect(const ClusterID &cluster_id, + int64_t timeout_ms, + size_t num_retries) { + if (cluster_id.IsNil()) { + size_t tries = num_retries + 1; + RAY_CHECK(tries > 0) << "Expected positive retries, but got " << tries; + + RAY_LOG(DEBUG) << "Retrieving cluster ID from GCS server."; + rpc::GetClusterIdRequest request; + rpc::GetClusterIdReply reply; + + Status connect_status; + for (; tries > 0; tries--) { + grpc::ClientContext context; + PrepareContext(context, timeout_ms); + connect_status = + GrpcStatusToRayStatus(node_info_stub_->GetClusterId(&context, request, &reply)); + + if (connect_status.ok()) { + cluster_id_ = ClusterID::FromBinary(reply.cluster_id()); + RAY_LOG(DEBUG) << "Received cluster ID from GCS server: " << cluster_id_; + RAY_CHECK(!cluster_id_.IsNil()); + break; + } else if (!connect_status.IsGrpcError()) { + return HandleGcsError(reply.status()); + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + RAY_RETURN_NOT_OK(connect_status); + } else { + cluster_id_ = cluster_id; + RAY_LOG(DEBUG) << "Client initialized with provided cluster ID: " << cluster_id_; } + + RAY_CHECK(!cluster_id_.IsNil()) << "Unexpected nil cluster ID."; + return Status::OK(); } Status PythonGcsClient::CheckAlive(const std::vector &raylet_addresses, int64_t timeout_ms, std::vector &result) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::CheckAliveRequest request; for (const auto &address : raylet_addresses) { @@ -199,7 +232,7 @@ Status PythonGcsClient::InternalKVGet(const std::string &ns, int64_t timeout_ms, std::string &value) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::InternalKVGetRequest request; request.set_namespace_(ns); @@ -226,7 +259,7 @@ Status PythonGcsClient::InternalKVMultiGet( int64_t timeout_ms, std::unordered_map &result) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::InternalKVMultiGetRequest request; request.set_namespace_(ns); @@ -258,7 +291,7 @@ Status PythonGcsClient::InternalKVPut(const std::string &ns, int64_t timeout_ms, int &added_num) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::InternalKVPutRequest request; request.set_namespace_(ns); @@ -285,7 +318,7 @@ Status PythonGcsClient::InternalKVDel(const std::string &ns, int64_t timeout_ms, int &deleted_num) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::InternalKVDelRequest request; request.set_namespace_(ns); @@ -310,7 +343,7 @@ Status PythonGcsClient::InternalKVKeys(const std::string &ns, int64_t timeout_ms, std::vector &results) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::InternalKVKeysRequest request; request.set_namespace_(ns); @@ -334,7 +367,7 @@ Status PythonGcsClient::InternalKVExists(const std::string &ns, int64_t timeout_ms, bool &exists) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::InternalKVExistsRequest request; request.set_namespace_(ns); @@ -357,7 +390,7 @@ Status PythonGcsClient::PinRuntimeEnvUri(const std::string &uri, int expiration_s, int64_t timeout_ms) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::PinRuntimeEnvURIRequest request; request.set_uri(uri); @@ -385,7 +418,7 @@ Status PythonGcsClient::PinRuntimeEnvUri(const std::string &uri, Status PythonGcsClient::GetAllNodeInfo(int64_t timeout_ms, std::vector &result) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::GetAllNodeInfoRequest request; rpc::GetAllNodeInfoReply reply; @@ -405,7 +438,7 @@ Status PythonGcsClient::GetAllNodeInfo(int64_t timeout_ms, Status PythonGcsClient::GetAllJobInfo(int64_t timeout_ms, std::vector &result) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::GetAllJobInfoRequest request; rpc::GetAllJobInfoReply reply; @@ -427,7 +460,7 @@ Status PythonGcsClient::RequestClusterResourceConstraint( const std::vector> &bundles, const std::vector &count_array) { grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); rpc::autoscaler::RequestClusterResourceConstraintRequest request; rpc::autoscaler::RequestClusterResourceConstraintReply reply; @@ -458,7 +491,7 @@ Status PythonGcsClient::GetClusterStatus(int64_t timeout_ms, rpc::autoscaler::GetClusterStatusRequest request; rpc::autoscaler::GetClusterStatusReply reply; grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); grpc::Status status = autoscaler_stub_->GetClusterStatus(&context, request, &reply); @@ -484,7 +517,7 @@ Status PythonGcsClient::DrainNode(const std::string &node_id, rpc::autoscaler::DrainNodeReply reply; grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + PrepareContext(context, timeout_ms); grpc::Status status = autoscaler_stub_->DrainNode(&context, request, &reply); @@ -516,7 +549,10 @@ Status PythonCheckGcsHealth(const std::string &gcs_address, auto channel = rpc::GcsRpcClient::CreateGcsChannel(gcs_address, gcs_port); auto stub = rpc::NodeInfoGcsService::NewStub(channel); grpc::ClientContext context; - GrpcClientContextWithTimeoutMs(context, timeout_ms); + if (timeout_ms != -1) { + context.set_deadline(std::chrono::system_clock::now() + + std::chrono::milliseconds(timeout_ms)); + } rpc::CheckAliveRequest request; rpc::CheckAliveReply reply; grpc::Status status = stub->CheckAlive(&context, request, &reply); diff --git a/src/ray/gcs/gcs_client/gcs_client.h b/src/ray/gcs/gcs_client/gcs_client.h index 318ef69d6763..5773a0229b0b 100644 --- a/src/ray/gcs/gcs_client/gcs_client.h +++ b/src/ray/gcs/gcs_client/gcs_client.h @@ -84,7 +84,8 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this { /// \param instrumented_io_context IO execution service. /// /// \return Status - virtual Status Connect(instrumented_io_context &io_service); + virtual Status Connect(instrumented_io_context &io_service, + const ClusterID &cluster_id = ClusterID::Nil()); /// Disconnect with GCS Service. Non-thread safe. virtual void Disconnect(); @@ -191,7 +192,7 @@ class RAY_EXPORT GcsClient : public std::enable_shared_from_this { class RAY_EXPORT PythonGcsClient { public: explicit PythonGcsClient(const GcsClientOptions &options); - Status Connect(); + Status Connect(const ClusterID &cluster_id, int64_t timeout_ms, size_t num_retries); Status CheckAlive(const std::vector &raylet_addresses, int64_t timeout_ms, @@ -241,7 +242,20 @@ class RAY_EXPORT PythonGcsClient { int64_t timeout_ms, bool &is_accepted); + const ClusterID &GetClusterId() const { return cluster_id_; } + private: + void PrepareContext(grpc::ClientContext &context, int64_t timeout_ms) { + if (timeout_ms != -1) { + context.set_deadline(std::chrono::system_clock::now() + + std::chrono::milliseconds(timeout_ms)); + } + if (!cluster_id_.IsNil()) { + context.AddMetadata(kClusterIdKey, cluster_id_.Hex()); + } + } + + ClusterID cluster_id_; GcsClientOptions options_; std::unique_ptr kv_stub_; std::unique_ptr runtime_env_stub_; diff --git a/src/ray/object_manager/test/ownership_based_object_directory_test.cc b/src/ray/object_manager/test/ownership_based_object_directory_test.cc index 3b7624a604a3..a326b2c9eb65 100644 --- a/src/ray/object_manager/test/ownership_based_object_directory_test.cc +++ b/src/ray/object_manager/test/ownership_based_object_directory_test.cc @@ -101,7 +101,8 @@ class MockGcsClient : public gcs::GcsClient { return *node_accessor_; } - MOCK_METHOD1(Connect, Status(instrumented_io_context &io_service)); + MOCK_METHOD2(Connect, + Status(instrumented_io_context &io_service, const ClusterID &cluster_id)); MOCK_METHOD0(Disconnect, void()); }; diff --git a/src/ray/raylet/main.cc b/src/ray/raylet/main.cc index 65d5f672c5aa..bd74db30f2a3 100644 --- a/src/ray/raylet/main.cc +++ b/src/ray/raylet/main.cc @@ -69,6 +69,8 @@ DEFINE_int32(ray_debugger_external, 0, "Make Ray debugger externally accessible. DEFINE_int64(object_store_memory, -1, "The initial memory of the object store."); DEFINE_string(node_name, "", "The user-provided identifier or name for this node."); DEFINE_string(session_name, "", "Session name (ClusterID) of the cluster."); +DEFINE_string(cluster_id, "", "ID of the cluster, separate from observability."); + #ifdef __linux__ DEFINE_string(plasma_directory, "/dev/shm", @@ -153,6 +155,10 @@ int main(int argc, char *argv[]) { const std::string session_name = FLAGS_session_name; const bool is_head_node = FLAGS_head; const std::string labels_json_str = FLAGS_labels; + + RAY_CHECK_NE(FLAGS_cluster_id, "") << "Expected cluster ID."; + ray::ClusterID cluster_id = ray::ClusterID::FromHex(FLAGS_cluster_id); + RAY_LOG(INFO) << "Setting cluster ID to: " << cluster_id; gflags::ShutDownCommandLineFlags(); // Configuration for the node manager. @@ -171,7 +177,7 @@ int main(int argc, char *argv[]) { ray::gcs::GcsClientOptions client_options(FLAGS_gcs_address); gcs_client = std::make_shared(client_options); - RAY_CHECK_OK(gcs_client->Connect(main_service)); + RAY_CHECK_OK(gcs_client->Connect(main_service, cluster_id)); std::unique_ptr raylet; RAY_CHECK_OK(gcs_client->Nodes().AsyncGetInternalConfig( diff --git a/src/ray/rpc/client_call.h b/src/ray/rpc/client_call.h index 98c1e519d3c5..b0f998aa9e8d 100644 --- a/src/ray/rpc/client_call.h +++ b/src/ray/rpc/client_call.h @@ -194,7 +194,7 @@ class ClientCallManager { const ClusterID &cluster_id = ClusterID::Nil(), int num_threads = 1, int64_t call_timeout_ms = -1) - : cluster_id_(ClusterID::Nil()), + : cluster_id_(cluster_id), main_service_(main_service), num_threads_(num_threads), shutdown_(false), @@ -249,7 +249,7 @@ class ClientCallManager { } auto call = std::make_shared>( - callback, cluster_id_.load(), std::move(stats_handle), method_timeout_ms); + callback, cluster_id_, std::move(stats_handle), method_timeout_ms); // Send request. // Find the next completion queue to wait for response. call->response_reader_ = (stub.*prepare_async_function)( @@ -267,14 +267,6 @@ class ClientCallManager { return call; } - void SetClusterId(const ClusterID &cluster_id) { - auto old_id = cluster_id_.exchange(ClusterID::Nil()); - if (!old_id.IsNil() && (old_id != cluster_id)) { - RAY_LOG(FATAL) << "Expected cluster ID to be Nil or " << cluster_id << ", but got" - << old_id; - } - } - /// Get the main service of this rpc. instrumented_io_context &GetMainService() { return main_service_; } @@ -328,7 +320,7 @@ class ClientCallManager { /// UUID of the cluster. Potential race between creating a ClientCall object /// and setting the cluster ID. - SafeClusterID cluster_id_; + ClusterID cluster_id_; /// The main event loop, to which the callback functions will be posted. instrumented_io_context &main_service_; diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index 7c602cebda0a..94b0ca4d1b9e 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -117,7 +117,7 @@ class Executor { } \ delete executor; \ } else { \ - /* In case of GCS failure, we queue the request and these requets will be */ \ + /* In case of GCS failure, we queue the request and these requests will be */ \ /* executed once GCS is back. */ \ gcs_is_down_ = true; \ auto request_bytes = request.ByteSizeLong(); \