Skip to content

Commit

Permalink
Module distribute updates
Browse files Browse the repository at this point in the history
  • Loading branch information
dongreenberg committed Nov 6, 2024
1 parent 1ebaecb commit 9364e24
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 4 deletions.
4 changes: 2 additions & 2 deletions runhouse/resources/distributed/supervisor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod

from runhouse.resources.module import Module, MODULE_ATTRS, MODULE_METHODS
from runhouse.resources.module import Module, MODULE_ATTRS, MODULE_METHODS, MODULE_METHODS_REMOTEABLE
from runhouse.utils import client_call_wrapper


Expand All @@ -20,7 +20,7 @@ def __getattribute__(self, item):
"""Override to allow for remote execution if system is a remote cluster. If not, the subclass's own
__getattr__ will be called."""
if (
item in MODULE_METHODS
(item in MODULE_METHODS and item not in MODULE_METHODS_REMOTEABLE)
or item in MODULE_ATTRS
or not hasattr(self, "_client")
):
Expand Down
66 changes: 64 additions & 2 deletions runhouse/resources/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@
"_dumb_signature_cache",
]

# Module methods which should still run remotely when called on a remote module.
# __getattribute__ only applies its MODULE_METHODS check to names that are NOT listed here.
MODULE_METHODS_REMOTEABLE = ["distribute"]

logger = get_logger(__name__)


Expand Down Expand Up @@ -603,7 +606,7 @@ def __getattribute__(self, item):
"""Override to allow for remote execution if system is a remote cluster. If not, the subclass's own
__getattr__ will be called."""
if (
item in MODULE_METHODS
(item in MODULE_METHODS and item not in MODULE_METHODS_REMOTEABLE)
or item in MODULE_ATTRS
or not hasattr(self, "_client")
):
Expand Down Expand Up @@ -702,6 +705,7 @@ def refresh(self):
def replicate(
self,
num_replicas: int = 1,
replicas_per_node: Optional[int] = None,
names: List[str] = None,
envs: List["Env"] = None,
parallel: bool = False,
Expand Down Expand Up @@ -747,11 +751,25 @@ def create_replica(i):
env = copy.copy(self.env)
env.name = f"{self.env.name}_replica_{i}"

# TODO remove
env.reqs = None

if replicas_per_node is not None:
if env.compute:
raise ValueError(
"Cannot specify replicas_per_node of other compute requirements for env "
"placement are specified."
)
env.compute = env.compute or {}
env.compute["node_idx"] = i // replicas_per_node

new_module = copy.copy(self)
new_module.name = name
new_module.env = None
new_module.system = None
new_module = new_module.to(self.system, env=env)

env.to(self.system)
new_module = new_module.to(self.system, env=env.name)
return new_module

if parallel:
Expand All @@ -762,6 +780,50 @@ def create_replica(i):

return [create_replica(i) for i in range(num_replicas)]

def distribute(
    self,
    distribution: str,
    name: Optional[str] = None,
    num_replicas: Optional[int] = 1,
    replicas_per_node: Optional[int] = None,
    replication_kwargs: Optional[dict] = None,
    **distribution_kwargs,
):
    """Distribute the module on the cluster and return the distributed module.

    Args:
        distribution (str): The distribution method to use. Only ``"pool"`` and ``"queue"`` are
            handled by this method.
        name (str, optional): The name to give to the distributed module. Only used for ``"queue"``,
            where it falls back to ``distributed_<module name>``; must not be set for ``"pool"``.
            (Default: ``None``)
        num_replicas (int, optional): The number of replicas to create. (Default: 1)
        replicas_per_node (int, optional): The number of replicas to create per node. (Default: ``None``)
        replication_kwargs (dict, optional): The keyword arguments to pass to the replicate method.
            ``None`` is treated as no extra arguments. (Default: ``None``)
        distribution_kwargs: The keyword arguments to pass to the distribution method
            (forwarded to the ``DistributedQueue`` constructor for ``"queue"``).

    Returns:
        For ``"pool"``: a list containing this module followed by the replicas returned by
        ``replicate``. For ``"queue"``: the ``DistributedQueue`` wrapping the replicas, after
        being sent to this module's system and env.

    Raises:
        ValueError: If ``name`` is given for a ``"pool"`` distribution, or if ``distribution`` is
            not one of the supported methods.
    """
    # TODO create the replicas remotely

    # Default is None rather than {} so that no dict object is shared across calls; an explicit
    # None from the caller is accepted as well.
    replication_kwargs = replication_kwargs or {}

    if distribution == "pool":
        if name:
            raise ValueError("Cannot specify a name for a pool distribution.")
        return [self] + self.replicate(
            num_replicas=num_replicas,
            replicas_per_node=replicas_per_node,
            **replication_kwargs,
        )

    if distribution == "queue":
        # Imported lazily, as in the original implementation
        # (presumably to avoid a circular import with the distributed package -- TODO confirm).
        from runhouse.resources.distributed.distributed_queue import (
            DistributedQueue,
        )

        replicas = self.replicate(
            num_replicas=num_replicas,
            replicas_per_node=replicas_per_node,
            **replication_kwargs,
        )
        name = name or f"distributed_{self.name}"
        pooled_module = DistributedQueue(
            **distribution_kwargs, name=name, replicas=replicas
        ).to(self.system, env=self.env.name)
        return pooled_module

    # Fail loudly instead of silently returning None for an unrecognised method.
    raise ValueError(
        f"Unsupported distribution method {distribution!r}. "
        'Supported methods are "pool" and "queue".'
    )

@property
def remote(self):
"""Helper property to allow for access to remote properties, both public and private. Returning functions
Expand Down

0 comments on commit 9364e24

Please sign in to comment.