Skip to content

Commit

Permalink
[core] Add cluster ID to the Python layer (ray-project#37583)
Browse files Browse the repository at this point in the history
An earlier[ change](ray-project#37399) added a new RPC to each GCS client, causing some stress tests to fail. Instead of adding O(num_worker) RPCs, add O(num_node) RPCs by plumbing through the Python layer.

Signed-off-by: Shreyas Krishnaswamy <shrekris@anyscale.com>
  • Loading branch information
vitsai authored and shrekris-anyscale committed Aug 10, 2023
1 parent 69afa51 commit 548b136
Show file tree
Hide file tree
Showing 33 changed files with 297 additions and 129 deletions.
1 change: 1 addition & 0 deletions dashboard/modules/job/tests/test_job_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ def test_agent_logs_not_streamed_to_drivers():
err_str = proc.stderr.read().decode("ascii")

print(out_str, err_str)

assert "(raylet)" not in out_str
assert "(raylet)" not in err_str

Expand Down
2 changes: 2 additions & 0 deletions python/ray/_private/gcs_aio_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def __init__(
self._nums_reconnect_retry = nums_reconnect_retry

def _connect(self):
print("vct connecting")
self._gcs_client._connect()
print("vct connected")

@property
def address(self):
Expand Down
40 changes: 26 additions & 14 deletions python/ray/_private/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,11 @@ def _init_gcs_client(self):
last_ex = None
try:
gcs_address = self.gcs_address
client = GcsClient(address=gcs_address)
client = GcsClient(
address=gcs_address,
cluster_id=self._ray_params.cluster_id,
)
self.cluster_id = client.get_cluster_id()
if self.head:
# Send a simple request to make sure GCS is alive
# if it's a head node.
Expand All @@ -664,19 +668,26 @@ def _init_gcs_client(self):
time.sleep(1)

if self._gcs_client is None:
with open(os.path.join(self._logs_dir, "gcs_server.err")) as err:
# Use " C " or " E " to exclude the stacktrace.
# This should work for most cases, especitally
# it's when GCS is starting. Only display last 10 lines of logs.
errors = [e for e in err.readlines() if " C " in e or " E " in e][-10:]
error_msg = "\n" + "".join(errors) + "\n"
raise RuntimeError(
f"Failed to {'start' if self.head else 'connect to'} GCS. "
f" Last {len(errors)} lines of error files:"
f"{error_msg}."
f"Please check {os.path.join(self._logs_dir, 'gcs_server.out')}"
" for details"
)
if hasattr(self, "_logs_dir"):
with open(os.path.join(self._logs_dir, "gcs_server.err")) as err:
# Use " C " or " E " to exclude the stacktrace.
# This should work for most cases, especitally
# it's when GCS is starting. Only display last 10 lines of logs.
errors = [e for e in err.readlines() if " C " in e or " E " in e][
-10:
]
error_msg = "\n" + "".join(errors) + "\n"
raise RuntimeError(
f"Failed to {'start' if self.head else 'connect to'} GCS. "
f" Last {len(errors)} lines of error files:"
f"{error_msg}."
f"Please check {os.path.join(self._logs_dir, 'gcs_server.out')}"
" for details"
)
else:
raise RuntimeError(
f"Failed to {'start' if self.head else 'connect to'} GCS."
)

ray.experimental.internal_kv._initialize_internal_kv(self._gcs_client)

Expand Down Expand Up @@ -1064,6 +1075,7 @@ def start_raylet(
self._ray_params.node_manager_port,
self._raylet_socket_name,
self._plasma_store_socket_name,
self.cluster_id,
self._ray_params.worker_path,
self._ray_params.setup_worker_path,
self._ray_params.storage,
Expand Down
3 changes: 3 additions & 0 deletions python/ray/_private/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ class RayParams:
env_vars: Override environment variables for the raylet.
session_name: The name of the session of the ray cluster.
webui: The url of the UI.
cluster_id: The cluster ID.
"""

def __init__(
Expand Down Expand Up @@ -188,6 +189,7 @@ def __init__(
env_vars: Optional[Dict[str, str]] = None,
session_name: Optional[str] = None,
webui: Optional[str] = None,
cluster_id: Optional[str] = None,
):
self.redis_address = redis_address
self.gcs_address = gcs_address
Expand Down Expand Up @@ -249,6 +251,7 @@ def __init__(
self._enable_object_reconstruction = enable_object_reconstruction
self.labels = labels
self._check_usage()
self.cluster_id = cluster_id

# Set the internal config options for object reconstruction.
if enable_object_reconstruction:
Expand Down
3 changes: 3 additions & 0 deletions python/ray/_private/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,7 @@ def start_raylet(
node_manager_port: int,
raylet_name: str,
plasma_store_name: str,
cluster_id: str,
worker_path: str,
setup_worker_path: str,
storage: str,
Expand Down Expand Up @@ -1538,6 +1539,7 @@ def start_raylet(
f"--session-name={session_name}",
f"--temp-dir={temp_dir}",
f"--webui={webui}",
f"--cluster-id={cluster_id}",
]
)

Expand Down Expand Up @@ -1643,6 +1645,7 @@ def start_raylet(
f"--gcs-address={gcs_address}",
f"--session-name={session_name}",
f"--labels={labels_json_str}",
f"--cluster-id={cluster_id}",
]

if is_head_node:
Expand Down
5 changes: 3 additions & 2 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1581,7 +1581,7 @@ def init(
spawn_reaper=False,
connect_only=True,
)
except ConnectionError:
except (ConnectionError, RuntimeError):
if gcs_address == ray._private.utils.read_ray_address(_temp_dir):
logger.info(
"Failed to connect to the default Ray cluster address at "
Expand All @@ -1590,7 +1590,7 @@ def init(
"address to connect to, run `ray stop` or restart Ray with "
"`ray start`."
)
raise
raise ConnectionError

# Log a message to find the Ray address that we connected to and the
# dashboard URL.
Expand Down Expand Up @@ -2262,6 +2262,7 @@ def connect(
runtime_env_hash,
startup_token,
session_name,
node.cluster_id,
"" if mode != SCRIPT_MODE else entrypoint,
worker_launch_time_ms,
worker_launched_time_ms,
Expand Down
7 changes: 7 additions & 0 deletions python/ray/_private/workers/default_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
parser = argparse.ArgumentParser(
description=("Parse addresses for the worker to connect to.")
)
parser.add_argument(
"--cluster-id",
required=True,
type=str,
help="the auto-generated ID of the cluster",
)
parser.add_argument(
"--node-ip-address",
required=True,
Expand Down Expand Up @@ -207,6 +213,7 @@
gcs_address=args.gcs_address,
session_name=args.session_name,
webui=args.webui,
cluster_id=args.cluster_id,
)
node = ray._private.node.Node(
ray_params,
Expand Down
1 change: 0 additions & 1 deletion python/ray/_raylet.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ cdef class ObjectRef(BaseID):

cdef CObjectID native(self)


cdef class ActorID(BaseID):
cdef CActorID data

Expand Down
33 changes: 27 additions & 6 deletions python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ from ray.includes.common cimport (
)
from ray.includes.unique_ids cimport (
CActorID,
CObjectID,
CClusterID,
CNodeID,
CObjectID,
CPlacementGroupID,
ObjectIDIndexType,
)
Expand Down Expand Up @@ -2335,16 +2336,35 @@ cdef class GcsClient:
shared_ptr[CPythonGcsClient] inner
object address
object _nums_reconnect_retry
CClusterID cluster_id

def __cinit__(self, address, nums_reconnect_retry=5):
def __cinit__(self, address, nums_reconnect_retry=5, cluster_id=None):
cdef GcsClientOptions gcs_options = GcsClientOptions.from_gcs_address(address)
self.inner.reset(new CPythonGcsClient(dereference(gcs_options.native())))
self.address = address
self._nums_reconnect_retry = nums_reconnect_retry
self._connect()
cdef c_string c_cluster_id
if cluster_id is None:
self.cluster_id = CClusterID.Nil()
else:
c_cluster_id = cluster_id
self.cluster_id = CClusterID.FromHex(c_cluster_id)
self._connect(5)

def _connect(self):
check_status(self.inner.get().Connect())
def _connect(self, timeout_s=None):
cdef:
int64_t timeout_ms = round(1000 * timeout_s) if timeout_s else -1
size_t num_retries = self._nums_reconnect_retry
with nogil:
status = self.inner.get().Connect(self.cluster_id, timeout_ms, num_retries)

check_status(status)
if self.cluster_id.IsNil():
self.cluster_id = self.inner.get().GetClusterId()
assert not self.cluster_id.IsNil()

def get_cluster_id(self):
return self.cluster_id.Hex().decode()

@property
def address(self):
Expand Down Expand Up @@ -2844,7 +2864,7 @@ cdef class CoreWorker:
node_ip_address, node_manager_port, raylet_ip_address,
local_mode, driver_name, stdout_file, stderr_file,
serialized_job_config, metrics_agent_port, runtime_env_hash,
startup_token, session_name, entrypoint,
startup_token, session_name, cluster_id, entrypoint,
worker_launch_time_ms, worker_launched_time_ms):
self.is_local_mode = local_mode

Expand Down Expand Up @@ -2896,6 +2916,7 @@ cdef class CoreWorker:
options.runtime_env_hash = runtime_env_hash
options.startup_token = startup_token
options.session_name = session_name
options.cluster_id = CClusterID.FromHex(cluster_id)
options.entrypoint = entrypoint
options.worker_launch_time_ms = worker_launch_time_ms
options.worker_launched_time_ms = worker_launched_time_ms
Expand Down
1 change: 1 addition & 0 deletions python/ray/autoscaler/_private/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
gcs_channel
)
worker = ray._private.worker.global_worker
# TODO: eventually plumb ClusterID through to here
gcs_client = GcsClient(address=self.gcs_address)

if monitor_ip:
Expand Down
2 changes: 1 addition & 1 deletion python/ray/autoscaler/v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ py_test(

py_test(
name = "test_sdk",
size = "small",
size = "medium",
srcs = ["tests/test_sdk.py"],
tags = ["team:core", "exclusive"],
deps = ["//:ray_lib", ":conftest"],
Expand Down
9 changes: 6 additions & 3 deletions python/ray/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from ray.includes.optional cimport (
from ray.includes.unique_ids cimport (
CActorID,
CJobID,
CClusterID,
CWorkerID,
CObjectID,
CTaskID,
Expand Down Expand Up @@ -367,8 +368,10 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil:
cdef cppclass CPythonGcsClient "ray::gcs::PythonGcsClient":
CPythonGcsClient(const CGcsClientOptions &options)

CRayStatus Connect()

CRayStatus Connect(
const CClusterID &cluster_id,
int64_t timeout_ms,
size_t num_retries)
CRayStatus CheckAlive(
const c_vector[c_string] &raylet_addresses,
int64_t timeout_ms,
Expand Down Expand Up @@ -405,14 +408,14 @@ cdef extern from "ray/gcs/gcs_client/gcs_client.h" nogil:
CRayStatus GetClusterStatus(
int64_t timeout_ms,
c_string &serialized_reply)
CClusterID GetClusterId()
CRayStatus DrainNode(
const c_string &node_id,
int32_t reason,
const c_string &reason_message,
int64_t timeout_ms,
c_bool &is_accepted)


cdef extern from "ray/gcs/gcs_client/gcs_client.h" namespace "ray::gcs" nogil:
unordered_map[c_string, double] PythonGetResourcesTotal(
const CGcsNodeInfo& node_info)
Expand Down
2 changes: 2 additions & 0 deletions python/ray/includes/libcoreworker.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ from libcpp.vector cimport vector as c_vector

from ray.includes.unique_ids cimport (
CActorID,
CClusterID,
CNodeID,
CJobID,
CTaskID,
Expand Down Expand Up @@ -359,6 +360,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
c_bool connect_on_start
int runtime_env_hash
int startup_token
CClusterID cluster_id
c_string session_name
c_string entrypoint
int64_t worker_launch_time_ms
Expand Down
14 changes: 14 additions & 0 deletions python/ray/includes/unique_ids.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
@staticmethod
T FromBinary(const c_string &binary)

@staticmethod
T FromHex(const c_string &hex)

@staticmethod
const T Nil()

Expand Down Expand Up @@ -154,6 +157,17 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:

CTaskID TaskId() const

cdef cppclass CClusterID "ray::ClusterID"(CUniqueID):

@staticmethod
CClusterID FromHex(const c_string &hex_str)

@staticmethod
CClusterID FromRandom()

@staticmethod
const CClusterID Nil()

cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID):

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion python/ray/serve/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ py_test(

py_test(
name = "test_callback",
size = "small",
size = "medium",
srcs = serve_tests_srcs,
tags = ["exclusive", "team:serve"],
deps = [":serve_lib"],
Expand Down
2 changes: 1 addition & 1 deletion python/ray/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ py_test_module_list(
"test_unhandled_error.py",
"test_utils.py",
"test_widgets.py",
"test_node_labels.py",
],
size = "small",
tags = ["exclusive", "small_size_python_tests", "team:core"],
Expand All @@ -222,6 +221,7 @@ py_test_module_list(
files = [
"test_gcs_ha_e2e.py",
"test_memory_pressure.py",
"test_node_labels.py",
],
size = "medium",
tags = ["exclusive", "team:core", "xcommit"],
Expand Down
Loading

0 comments on commit 548b136

Please sign in to comment.