Skip to content

Commit

Permalink
attempt at extracting pointers in .to
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Jul 3, 2024
1 parent abdb7be commit 3f80796
Showing 1 changed file with 28 additions and 20 deletions.
48 changes: 28 additions & 20 deletions runhouse/resources/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,20 @@ def __init__(
self._system = _get_cluster_from(
system or _current_cluster(key="config"), dryrun=dryrun
)
self._env = env
# self._env = env
is_builtin = hasattr(sys.modules["runhouse"], self.__class__.__qualname__)
if not pointers and not is_builtin:
# If there are no pointers and this isn't a builtin module, we assume this is a user-created subclass
# of rh.Module, and we need to do the factory constructor logic here.

# When creating a module as a subclass of rh.Module, we need to collect pointers here
if not self._env:
self._env = self._system.default_env if self._system else Env()
# if not self._env:
# self._env = self._system.default_env if self._system else Env()
# If we're creating pointers, we're also local to the class definition and package, so it should be
# set as the workdir (we can do this in a fancier way later)
pointers, req_to_add = Module._extract_pointers(
self.__class__, reqs=self._env.reqs
)
if req_to_add:
self._env.reqs = [req_to_add] + self._env.reqs
pointers, _ = Module._extract_pointers(self.__class__, reqs=[])
# if req_to_add:
# self._env.reqs = [req_to_add] + self._env.reqs
self._pointers = pointers
self._endpoint = endpoint
self._signature = signature
Expand Down Expand Up @@ -444,9 +442,21 @@ def to(
_get_cluster_from(system, dryrun=self.dryrun) if system else self.system
)

env = self.env if not env else env
if not isinstance(env, Env):
env = _get_env_from(env)
if not env:
env = _get_env_from(_default_env_if_on_cluster())
if not env:
env = Env()

cls_pointers, working_dir_to_add = Module._extract_pointers(
self.__class__, env.reqs
)

if working_dir_to_add is not None:
env.reqs = [str(working_dir_to_add)] + env.reqs

env = _get_env_from(env)
self._pointers = cls_pointers

if system:
system.check_server()
Expand Down Expand Up @@ -1217,7 +1227,8 @@ def __call__(
"__call__": __call__,
"_module_init_only": _module_init_only,
}
new_type = type(cls_pointers[2], (Module, cls), methods)
cls_name = getattr(cls, "__qualname__", cls.__name__)
new_type = type(cls_name, (Module, cls), methods)
return new_type


Expand Down Expand Up @@ -1324,16 +1335,13 @@ class (e.g. ``to``, ``fetch``, etc.). Properties and private methods are not int
"Use `.to(system)` or `.get_or_to(system)` after construction to send and run the Module on the system."
)

if not isinstance(env, Env):
env = _get_env_from(env)
if not env:
env = _get_env_from(_default_env_if_on_cluster())
if not env:
env = Env()
if env:
raise Exception(
"`env` argument is no longer supported in the module factory function. "
"Use `.to(env)` or `.get_or_to(env)` after construction to send and run the Module in the environment."
)

cls_pointers, working_dir_to_add = Module._extract_pointers(cls, env.reqs)
if working_dir_to_add is not None:
env.reqs = [str(working_dir_to_add)] + env.reqs
cls_pointers, _ = Module._extract_pointers(cls, [])

name = name or (
cls_pointers[2] if cls_pointers else _generate_default_name(prefix="module")
Expand Down

0 comments on commit 3f80796

Please sign in to comment.