Skip to content

Commit

Permalink
attempt at extracting pointers in .to
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Jul 5, 2024
1 parent 4870134 commit 3677a0d
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 2 additions & 0 deletions runhouse/resources/functions/function_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def function(
env.reqs = [repo_package] + env.reqs

new_function = Function(fn_pointers=fn_pointers, name=name, dryrun=dryrun, env=env)
if callable(fn):
new_function._raw_cls = fn

if load_secrets and not dryrun:
new_function.send_secrets()
Expand Down
23 changes: 22 additions & 1 deletion runhouse/resources/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
env: Optional[Env] = None,
dryrun: bool = False,
provenance: Optional[dict] = None,
raw_cls: Optional[Type] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -93,6 +94,7 @@ def __init__(
self._dumb_signature_cache = None
self._resolve = False
self._openapi_spec = None
self._raw_cls = raw_cls

def config(self, condensed=True):
if not self.system:
Expand Down Expand Up @@ -446,7 +448,22 @@ def to(

env = self.env if not env else env

env = _get_env_from(env)
if not isinstance(env, Env):
env = _get_env_from(env)
if not env:
env = _get_env_from(_default_env_if_on_cluster())
if not env:
env = Env()

if self._raw_cls:
cls_pointers, working_dir_to_add = Module._extract_pointers(
self._raw_cls, env.reqs
)

if working_dir_to_add is not None:
env.reqs = [str(working_dir_to_add)] + env.reqs

self._pointers = cls_pointers

if system:
system.check_server()
Expand Down Expand Up @@ -486,6 +503,7 @@ def to(
"_env",
"_pointers",
"_resolve",
"_raw_cls",
]
state = {}
# We only send over state for instances, not classes
Expand Down Expand Up @@ -1157,6 +1175,7 @@ def __init__(
signature=None,
name=None,
provenance=None,
raw_cls=None,
**kwargs,
):
# args and kwargs are passed to the cls's __init__ method if this is being called on a cluster. They
Expand All @@ -1170,6 +1189,7 @@ def __init__(
env=env,
dryrun=dryrun,
provenance=provenance,
raw_cls=raw_cls,
)
# This allows a class which is already on the cluster to construct an instance of itself with a factory
# method, e.g. my_module = MyModuleCls.factory_constructor(*args, **kwargs)
Expand Down Expand Up @@ -1345,6 +1365,7 @@ class (e.g. ``to``, ``fetch``, etc.). Properties and private methods are not int
dryrun=dryrun,
pointers=cls_pointers,
name=name,
raw_cls=cls,
)


Expand Down

0 comments on commit 3677a0d

Please sign in to comment.