Enable multi-node logging
dongreenberg authored and carolineechen committed Nov 8, 2024
1 parent dfda760 commit 7ec911c
Showing 5 changed files with 74 additions and 112 deletions.
5 changes: 4 additions & 1 deletion runhouse/servers/http/http_client.py
@@ -450,6 +450,7 @@ def call_module_method(
             logs_future = executor.submit(
                 thread_coroutine,
                 self._alogs_request(
+                    key=key,
                     run_name=run_name,
                     serialization=serialization,
                     error_str=error_str,
@@ -565,6 +566,7 @@ async def _acall_request(

     async def _alogs_request(
         self,
+        key: str,
         run_name: str,
         serialization: str,
         error_str: str,
@@ -579,7 +581,7 @@

         async with client.stream(
             "GET",
-            self._formatted_url(f"logs/{run_name}/{serialization}"),
+            self._formatted_url(f"logs/{key}/{run_name}/{serialization}"),
             headers=self._request_headers,
         ) as res:
             if res.status_code != 200:
@@ -649,6 +651,7 @@ async def acall_module_method(
         )
         alogs_request = asyncio.create_task(
             self._alogs_request(
+                key=key,
                 run_name=run_name,
                 serialization=serialization,
                 error_str=error_str,
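The client-side change threads the servlet key into the log-streaming request so the server can locate the right servlet's logs. A minimal sketch of consuming the new key-scoped route with httpx — the base URL, key, and run name below are illustrative stand-ins, not values from this diff:

import asyncio

import httpx


async def stream_logs(base_url: str, key: str, run_name: str, serialization: str):
    # Mirrors the shape of _alogs_request: open a streaming GET against the
    # key-scoped logs route and consume newline-delimited JSON chunks.
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "GET", f"{base_url}/logs/{key}/{run_name}/{serialization}"
        ) as res:
            if res.status_code != 200:
                raise RuntimeError(f"Log stream failed with {res.status_code}")
            async for line in res.aiter_lines():
                print(line)


# Hypothetical values for illustration only.
asyncio.run(stream_logs("http://localhost:32300", "my_key", "my_run", "json"))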
95 changes: 13 additions & 82 deletions runhouse/servers/http/http_server.py
@@ -23,10 +23,6 @@
     DEFAULT_SERVER_HOST,
     DEFAULT_SERVER_PORT,
     EMPTY_DEFAULT_ENV_NAME,
-    LOGGING_WAIT_TIME,
-    LOGS_TO_SHOW_UP_CHECK_TIME,
-    MAX_LOGS_TO_SHOW_UP_WAIT_TIME,
-    RH_LOGFILE_PATH,
 )
 from runhouse.globals import configs, obj_store, rns_client
 from runhouse.logger import get_logger
@@ -770,93 +766,33 @@ async def get_call(

     # `/logs` POST endpoint that takes in request and LogParams
     @staticmethod
-    @app.get("/logs/{run_name}/{serialization}")
+    @app.get("/logs/{key}/{run_name}/{serialization}")
     @validate_cluster_access
     async def get_logs(
         request: Request,
+        key: str,
         run_name: str,
         serialization: str,
     ):
-        # This call could've been made fast enough that the future hasn't been stored yet
-        sleeps = 0
-        while run_name not in running_futures:
-            if sleeps * LOGS_TO_SHOW_UP_CHECK_TIME >= MAX_LOGS_TO_SHOW_UP_WAIT_TIME:
-                raise HTTPException(
-                    status_code=404,
-                    detail=f"Logs for call {run_name} not found.",
-                )
-            await asyncio.sleep(LOGS_TO_SHOW_UP_CHECK_TIME)
-
         return StreamingResponse(
-            HTTPServer._get_results_and_logs_generator(
-                running_futures[run_name], run_name, serialization
-            ),
+            HTTPServer._get_logs_generator(key, run_name, serialization),
             media_type="application/json",
         )

-    @staticmethod
-    def _get_logfiles(log_key, log_type=None):
-        if not log_key:
-            return None
-        key_logs_path = Path(RH_LOGFILE_PATH) / log_key
-        if key_logs_path.exists():
-            # Logs are like: `.rh/logs/key/key.[out|err]`
-            glob_pattern = (
-                "*.out"
-                if log_type == "stdout"
-                else "*.err"
-                if log_type == "stderr"
-                else "*.[oe][ur][tr]"
-            )
-            return [str(f.absolute()) for f in key_logs_path.glob(glob_pattern)]
-        else:
-            return None
-
-    @staticmethod
-    def open_new_logfiles(key, open_files):
-        logfiles = HTTPServer._get_logfiles(key)
-        if logfiles:
-            for f in logfiles:
-                if f not in [o.name for o in open_files]:
-                    logger.info(f"Streaming logs from {f}")
-                    open_files.append(open(f, "r"))
-        return open_files
-
     @staticmethod
-    async def _get_results_and_logs_generator(fut, run_name, serialization=None):
+    async def _get_logs_generator(key, run_name, serialization=None):
         logger.debug(f"Streaming logs for key {run_name}")
-        open_logfiles = []
-        waiting_for_results = True

         try:
-            while waiting_for_results:
-                if fut.done():
-                    waiting_for_results = False
-                    del running_futures[run_name]
-                else:
-                    await asyncio.sleep(LOGGING_WAIT_TIME)
-                # Grab all the lines written to all the log files since the last time we checked, including
-                # any new log files that have been created
-                open_logfiles = HTTPServer.open_new_logfiles(run_name, open_logfiles)
-                ret_lines = []
-                for i, f in enumerate(open_logfiles):
-                    file_lines = f.readlines()
-                    if file_lines:
-                        # TODO [DG] handle .out vs .err, and multiple workers
-                        # if len(logfiles) > 1:
-                        #     ret_lines.append(f"Process {i}:")
-                        ret_lines += file_lines
-                if ret_lines:
-                    logger.debug(f"Yielding logs for key {run_name}")
-                    yield json.dumps(
-                        jsonable_encoder(
-                            Response(
-                                data=ret_lines,
-                                output_type=OutputType.STDOUT,
-                            )
-                        )
-                    ) + "\n"
+            async for log_lines in obj_store.alogs_for_run_name(run_name, key=key):
+                logger.debug(f"Yielding logs for key {run_name}")
+                yield json.dumps(
+                    jsonable_encoder(
+                        Response(
+                            data=log_lines,
+                            output_type=OutputType.STDOUT,
+                        )
+                    )
+                ) + "\n"
         except Exception as e:
             logger.exception(e)
             # NOTE: We do not convert the exception to an HTTPException here, because once we're inside this
@@ -878,11 +814,6 @@ async def _get_results_and_logs_generator(fut, run_name, serialization=None):
                 )
             )
         )
-        finally:
-            if not open_logfiles:
-                logger.warning(f"No logfiles found for call {run_name}")
-            for f in open_logfiles:
-                f.close()

     ################################################################################################
     # Cluster status and metadata methods
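Server-side, get_logs no longer polls running_futures or tails log files itself; it delegates to obj_store.alogs_for_run_name and streams each batch as a newline-delimited JSON chunk. A self-contained sketch of that StreamingResponse pattern, with a stand-in generator in place of the object store:

import asyncio
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_log_batches():
    # Stand-in for obj_store.alogs_for_run_name: yields lists of log lines.
    for batch in (["step 1\n"], ["step 2\n", "step 3\n"]):
        await asyncio.sleep(0.1)
        yield batch


@app.get("/logs/{key}/{run_name}/{serialization}")
async def get_logs(key: str, run_name: str, serialization: str):
    async def generator():
        # One JSON object per chunk, newline-delimited, as in the real handler.
        async for lines in fake_log_batches():
            yield json.dumps({"data": lines, "output_type": "stdout"}) + "\n"

    return StreamingResponse(generator(), media_type="application/json")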
52 changes: 29 additions & 23 deletions runhouse/servers/obj_store.py
@@ -14,7 +14,12 @@
 import ray
 from pydantic import BaseModel

-from runhouse.constants import LOGGING_WAIT_TIME, RH_LOGFILE_PATH
+from runhouse.constants import (
+    LOGGING_WAIT_TIME,
+    LOGS_TO_SHOW_UP_CHECK_TIME,
+    MAX_LOGS_TO_SHOW_UP_WAIT_TIME,
+    RH_LOGFILE_PATH,
+)
 from runhouse.logger import get_logger

 from runhouse.rns.defaults import req_ctx
@@ -1437,9 +1442,15 @@ async def acall(
         # with the run_name and print them while the call runs
         logs_task = None
         if stream_logs:
-            logs_task = asyncio.create_task(
-                self.alogs_for_servlet_name(servlet_name_containing_key, run_name)
-            )
+
+            async def print_logs():
+                async for logs in self.alogs_for_run_name(
+                    run_name, servlet_name=servlet_name_containing_key
+                ):
+                    for log in logs:
+                        print(f"({key}): {log}", end="")
+
+            logs_task = asyncio.create_task(print_logs())

         res = await self.acall_for_servlet_name(
             servlet_name_containing_key,
@@ -1496,30 +1507,20 @@ def call(
             remote,
         )

-    async def alogs_for_servlet_name(
+    async def alogs_for_run_name(
         self,
-        servlet_name: str,
         run_name: Optional[str] = None,
-        print_stream: bool = True,
+        key: str = None,
+        servlet_name: str = None,
     ):
+        if not servlet_name:
+            servlet_name = await self.aget_servlet_name_for_key(key)
+        servlet = self.get_servlet(servlet_name)
-        # If stream_logs is True, print the logs as they come in. Otherwise just concatenate them together
-        # and return them as a long string.
-        printed_first_log = False
-        servlet = self.get_servlet(servlet_name)
-        full_logs = ""
         async for log_ref in servlet.alogs_local.remote(run_name=run_name):
             logs = await log_ref
-            for log in logs:
-                if print_stream:
-                    if not printed_first_log:
-                        print(f"---------------- Call {run_name} ----------------")
-                        printed_first_log = True
-                    print(log, end="")
-                full_logs += log
-
-        if print_stream and printed_first_log:
-            print(f"---------------- End Call {run_name} ------------")
-        return full_logs
+            yield logs

     @staticmethod
     def _get_logfiles(log_key, log_type=None):
@@ -1554,15 +1555,20 @@ async def alogs_local(self, run_name: Optional[str] = None):
         open_logfiles = []

         # Wait for a maximum of 5 seconds for the log files to be created
-        for _ in range(20):
+        sleeps = 0
+        while (sleeps * LOGS_TO_SHOW_UP_CHECK_TIME) <= MAX_LOGS_TO_SHOW_UP_WAIT_TIME:
             open_logfiles = ObjStore.open_new_logfiles(run_name, open_logfiles)
             if open_logfiles:
                 break
             else:
-                await asyncio.sleep(LOGGING_WAIT_TIME)
+                await asyncio.sleep(LOGS_TO_SHOW_UP_CHECK_TIME)
+                sleeps += 1

         if not open_logfiles:
             logger.warning(f"No logfiles found for call {run_name}")
+            # raise ObjStoreError(
+            #     f"Logs for call {run_name} not found."
+            # )

         call_in_progress = True
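The old alogs_for_servlet_name printed logs as a side effect and returned one concatenated string; alogs_for_run_name is now an async generator that resolves the servlet from the key and yields log batches as they arrive, which is what both the HTTP server and ObjStore.acall consume. A small consumption sketch, using a stand-in generator since this snippet doesn't run inside a cluster:

import asyncio


async def fake_alogs_for_run_name(run_name, key=None):
    # Stand-in for ObjStore.alogs_for_run_name: yields batches of log lines
    # as the servlet flushes them, instead of returning one big string.
    for batch in (["starting...\n"], ["done.\n"]):
        await asyncio.sleep(0.1)
        yield batch


async def print_logs(run_name, key="my_key"):
    # Mirrors the print_logs task that ObjStore.acall now creates.
    async for logs in fake_alogs_for_run_name(run_name, key=key):
        for log in logs:
            print(f"({key}): {log}", end="")


asyncio.run(print_logs("my_run"))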
21 changes: 19 additions & 2 deletions tests/test_resources/test_clusters/test_multinode_cluster.py
@@ -3,6 +3,7 @@

 import pytest
 import runhouse as rh
+from runhouse.utils import capture_stdout

 from tests.utils import get_pid_and_ray_node

@@ -96,8 +97,24 @@ def test_send_envs_to_specific_worker_node(self, cluster):
         get_pid_2 = rh.function(get_pid_and_ray_node).to(
             name="get_pid_2", system=cluster, env=env_2
         )
-        assert get_pid_0()[1] != get_pid_1()[1]
-        assert get_pid_1()[1] == get_pid_2()[1]
+
+        with capture_stdout() as stdout_0:
+            pid_0, node_id_0 = get_pid_0()
+            assert str(pid_0) in str(stdout_0)
+            assert str(node_id_0) in str(stdout_0)
+
+        with capture_stdout() as stdout_1:
+            pid_1, node_id_1 = get_pid_1()
+            assert str(pid_1) in str(stdout_1)
+            assert str(node_id_1) in str(stdout_1)
+
+        with capture_stdout() as stdout_2:
+            pid_2, node_id_2 = get_pid_2()
+            assert str(pid_2) in str(stdout_2)
+            assert str(node_id_2) in str(stdout_2)
+
+        assert node_id_0 != node_id_1
+        assert node_id_1 == node_id_2

     @pytest.mark.level("release")
     def test_specifying_resources(self, cluster):
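The test now round-trips each worker's pid and node ID through the streamed stdout rather than only comparing return values. capture_stdout is imported from runhouse.utils; its implementation isn't shown in this diff, but a plausible minimal version (an assumption for illustration, not the library's actual code) would tee sys.stdout into a buffer whose str() yields the captured text:

import io
import sys
from contextlib import contextmanager


@contextmanager
def capture_stdout():
    # Hypothetical sketch: tee sys.stdout into a buffer so the test can
    # assert on what was printed while still echoing it to the console.
    buffer = io.StringIO()
    original = sys.stdout

    class Tee:
        def write(self, data):
            original.write(data)
            buffer.write(data)

        def flush(self):
            original.flush()

        def __str__(self):
            return buffer.getvalue()

    sys.stdout = Tee()
    try:
        yield sys.stdout
    finally:
        sys.stdout = original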
13 changes: 9 additions & 4 deletions tests/utils.py
@@ -47,12 +47,17 @@ def get_ray_cluster_servlet(cluster_config=None):


 def get_pid_and_ray_node(a=0):
+    import logging
+
     import ray

-    return (
-        os.getpid(),
-        ray.runtime_context.RuntimeContext(ray.worker.global_worker).get_node_id(),
-    )
+    pid = os.getpid()
+    node_id = ray.runtime_context.RuntimeContext(ray.worker.global_worker).get_node_id()
+
+    print(f"PID: {pid}")
+    logging.info(f"Node ID: {node_id}")
+
+    return pid, node_id


 def get_random_str(length: int = 8):
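The helper now emits through both print and logging, so the multinode test can confirm that stdout and logger output from remote workers both surface in the streamed logs. A quick local smoke check (illustrative only; in the test this runs inside a Ray worker on the cluster, and a local Ray runtime is assumed here):

import ray

from tests.utils import get_pid_and_ray_node

ray.init()  # assumes a local Ray runtime is available
pid, node_id = get_pid_and_ray_node()
print(pid, node_id)
ray.shutdown()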
