Skip to content

Commit

Permalink
Ray Distributor
Browse files Browse the repository at this point in the history
  • Loading branch information
dongreenberg committed Nov 13, 2024
1 parent 7688b08 commit 9257b5c
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 0 deletions.
54 changes: 54 additions & 0 deletions runhouse/resources/distributed/ray_distributed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import multiprocessing
import sys

from runhouse.resources.distributed.supervisor import Supervisor

from runhouse.resources.module import Module


class RayDistributed(Supervisor):
    """Supervisor that runs the wrapped module's function in a subprocess connected to Ray.

    Each call to :meth:`forward` starts a one-worker process pool, runs
    ``subprocess_ray_fn_call_helper`` in it with the module's ``fn_pointers``,
    relays the (message, stream) tuples the child sends over a pipe to this
    process's stdout/stderr, and returns the child's result.
    """

    def __init__(self, name, module: Module = None, ray_init_options=None, **kwargs):
        """Store the module to run and the options forwarded to the Ray helper.

        Args:
            name: Name passed through to ``Supervisor.__init__``.
            module: Module whose ``fn_pointers`` identify the function to run.
            ray_init_options: Optional dict handed to the subprocess helper as
                its Ray options; ``None`` is treated as an empty dict.
            **kwargs: Forwarded unchanged to ``Supervisor.__init__``.
        """
        super().__init__(name=name, **kwargs)
        self._module = module
        self._ray_init_options = ray_init_options or {}

    def _compute_signature(self, rich=False):
        """Return the signature of the wrapped module (via ``self.local._module``)."""
        return self.local._module.signature(rich=rich)

    def forward(self, item, *args, **kwargs):
        """Run the wrapped function in a spawned subprocess and return its result.

        Args:
            item: Accepted for interface compatibility; not used by this method.
            *args: Positional arguments passed to the wrapped function.
            **kwargs: Keyword arguments passed to the wrapped function.

        Returns:
            Whatever the subprocess helper returns.

        Raises:
            Any exception raised inside the subprocess is re-raised here by
            ``AsyncResult.get()``.
        """
        from runhouse.resources.distributed.utils import subprocess_ray_fn_call_helper

        # How long to wait for output before checking whether the child has finished.
        poll_interval = 0.1

        # TODO replace this with passing the filepath that this module is already writing to!
        parent_conn, child_conn = multiprocessing.Pipe()

        subproc_args = (
            self._module.fn_pointers,
            args,
            kwargs,
            child_conn,
            self._ray_init_options,
        )

        # Use a private "spawn" context instead of set_start_method(): the global
        # setter raises RuntimeError if a different start method was already
        # chosen, and it changes behavior for every other user of multiprocessing.
        spawn_ctx = multiprocessing.get_context("spawn")

        try:
            with spawn_ctx.Pool(processes=1) as pool:
                result = pool.apply_async(
                    subprocess_ray_fn_call_helper, args=subproc_args
                )
                while True:
                    if parent_conn.poll(poll_interval):
                        try:
                            msg, output_stream = parent_conn.recv()
                        except EOFError:
                            break
                        # The child sends the EOFError class itself as an
                        # end-of-output sentinel.
                        if msg is EOFError:
                            break
                        print(
                            msg,
                            end="",
                            file=sys.stdout if output_stream == "stdout" else sys.stderr,
                        )
                    elif result.ready() and not parent_conn.poll():
                        # The task finished (possibly by raising before it could
                        # send the sentinel) and the pipe is drained. Without this
                        # check recv() would block forever, because this process
                        # still holds child_conn open and never sees EOF.
                        break
                res = result.get()
        finally:
            parent_conn.close()
            child_conn.close()
        return res

    def __call__(self, *args, **kwargs):
        """Delegate direct calls to ``self.call``."""
        return self.call(*args, **kwargs)
46 changes: 46 additions & 0 deletions runhouse/resources/distributed/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import os
import sys
from pathlib import Path


def subprocess_ray_fn_call_helper(pointers, args, kwargs, conn, ray_opts=None):
    """Connect to Ray, call the function identified by ``pointers``, and return its result.

    While the function runs, ``sys.stdout.write`` and ``sys.stderr.write`` are
    replaced so that every write is sent over ``conn`` as a ``(msg, stream)``
    tuple, where ``stream`` is ``"stdout"`` or ``"stderr"``. When the call ends,
    successfully or not, ``(EOFError, None)`` is sent as an end-of-output
    sentinel and ``conn`` is closed.

    Args:
        pointers: ``(module_path, module_name, class_name)`` triple resolved with
            ``Module._get_obj_from_pointers``.
        args: Positional arguments for the resolved callable.
        kwargs: Keyword arguments for the resolved callable.
        conn: Connection-like object providing ``send`` and ``close``.
        ray_opts: Optional dict of extra keyword arguments for ``ray.init``.
            It is copied, never mutated; ``None`` means no extra options.

    Returns:
        The value returned by the resolved callable.

    Raises:
        Whatever ``ray.init``, pointer resolution, or the callable raises.
    """

    def write_stdout(msg):
        conn.send((msg, "stdout"))

    def write_stderr(msg):
        conn.send((msg, "stderr"))

    original_stdout_write = sys.stdout.write
    original_stderr_write = sys.stderr.write
    sys.stdout.write = write_stdout
    sys.stderr.write = write_stderr

    try:
        (module_path, module_name, class_name) = pointers
        abs_module_path = str(Path(module_path).expanduser().resolve())

        # Work on copies so neither the caller's dict (nor its nested dicts) nor
        # a shared default is modified.
        ray_opts = dict(ray_opts or {})
        runtime_env = dict(ray_opts.get("runtime_env") or {})
        env_vars = dict(runtime_env.get("env_vars") or {})

        env_vars["RH_LOG_LEVEL"] = os.environ.get("RH_LOG_LEVEL", "INFO")
        # Put the module's directory first on PYTHONPATH, keeping any existing value.
        if "PYTHONPATH" in env_vars:
            env_vars["PYTHONPATH"] = f"{abs_module_path}:{env_vars['PYTHONPATH']}"
        else:
            env_vars["PYTHONPATH"] = abs_module_path

        runtime_env["env_vars"] = env_vars
        ray_opts["runtime_env"] = runtime_env

        import ray

        ray.init(address="auto", **ray_opts)
        try:
            from runhouse.resources.module import Module

            orig_fn = Module._get_obj_from_pointers(
                module_path, module_name, class_name, reload=False
            )
            return orig_fn(*args, **kwargs)
        finally:
            # Shut down even when the call fails.
            ray.shutdown()
    finally:
        # Stop routing writes through the pipe before it is closed.
        sys.stdout.write = original_stdout_write
        sys.stderr.write = original_stderr_write
        try:
            # Send an EOFError over the pipe because for some reason .close is hanging.
            # Sending it from `finally` guarantees the reader is released on failure too.
            conn.send((EOFError, None))
        finally:
            conn.close()
8 changes: 8 additions & 0 deletions runhouse/resources/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,14 @@ def distribute(
**distribution_kwargs, name=name, replicas=replicas
).to(self.system, env=self.env.name)
return pooled_module
elif distribution == "ray":
from runhouse.resources.distributed.ray_distributed import RayDistributed

name = name or f"ray_{self.name}"
ray_module = RayDistributed(
**distribution_kwargs, name=name, module=self
).to(self.system, self.env.name)
return ray_module

@property
def remote(self):
Expand Down

0 comments on commit 9257b5c

Please sign in to comment.