Skip to content

Commit

Permalink
Add utility to set_env_vars globally.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 committed Jan 29, 2025
1 parent 5cfd5b2 commit 937df4a
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 15 deletions.
10 changes: 7 additions & 3 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2728,6 +2728,12 @@ def ensure_process_created(
)
return name

def set_env_vars_globally(self, env_vars: Dict):
if self.on_this_cluster():
return obj_store.set_env_vars_globally(env_vars)
else:
return self.client.set_env_vars(env_vars)

def set_process_env_vars(self, name: str, env_vars: Dict):
"""Set the env vars for a process on the cluster.
Expand All @@ -2738,9 +2744,7 @@ def set_process_env_vars(self, name: str, env_vars: Dict):
if self.on_this_cluster():
return obj_store.set_process_env_vars(name, env_vars)
else:
return self.client.set_process_env_vars(
process_name=name, env_vars=env_vars
)
return self.client.set_env_vars(env_vars=env_vars, process_name=name)

def install_package(
self,
Expand Down
26 changes: 26 additions & 0 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ async def __init__(
self._key_to_servlet_name: Dict[Any, str] = {}
self._auth_cache: AuthCache = AuthCache()
self._paths_to_prepend_in_new_processes = []
self._env_vars_to_set_in_new_processes = {}
self._node_servlet_names: List[str] = []
self._cluster_name = self._cluster_config.get("name", None)
self._cluster_uri = (
Expand Down Expand Up @@ -133,6 +134,31 @@ async def aadd_path_to_prepend_in_new_processes(self, path: str):
async def aget_paths_to_prepend_in_new_processes(self) -> List[str]:
return self._paths_to_prepend_in_new_processes

##############################################
# Env vars to set and get env vars methods
##############################################
async def aset_env_vars_globally(self, env_vars: Dict[str, Any]):

await asyncio.gather(
*[
obj_store.acall_servlet_method(
servlet_name,
"aset_env_vars",
env_vars,
use_servlet_cache=False,
)
for servlet_name in await self.aget_all_initialized_servlet_args()
]
)

await self.aadd_env_vars_to_set_in_new_processes(env_vars)

async def aadd_env_vars_to_set_in_new_processes(self, env_vars: Dict[str, Any]):
self._env_vars_to_set_in_new_processes.update(env_vars)

async def aget_env_vars_to_set_in_new_processes(self) -> Dict[str, str]:
return self._env_vars_to_set_in_new_processes

##############################################
# Cluster config state storage methods
##############################################
Expand Down
10 changes: 5 additions & 5 deletions runhouse/servers/http/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
RenameObjectParams,
RunBashParams,
serialize_data,
SetProcessEnvVarsParams,
SetEnvVarsParams,
)

from runhouse.utils import ClusterLogsFormatter, generate_default_name, thread_coroutine
Expand Down Expand Up @@ -825,15 +825,15 @@ def create_process(
json_dict=params.model_dump(),
)

def set_process_env_vars(
def set_env_vars(
self,
process_name: str,
env_vars: Dict[str, str],
process_name: Optional[str] = None,
):
return self.request_json(
"/process_env_vars",
"/env_vars",
req_type="post",
json_dict=SetProcessEnvVarsParams(
json_dict=SetEnvVarsParams(
process_name=process_name, env_vars=env_vars
).model_dump(),
)
Expand Down
14 changes: 10 additions & 4 deletions runhouse/servers/http/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
RunBashParams,
serialize_data,
ServerSettings,
SetProcessEnvVarsParams,
SetEnvVarsParams,
)
from runhouse.servers.obj_store import (
ClusterServletSetupOption,
Expand Down Expand Up @@ -454,11 +454,17 @@ async def kill_process(request: Request, params: KillProcessParams):
)

@staticmethod
@app.post("/process_env_vars")
@app.post("/env_vars")
@validate_cluster_access
async def set_process_env_vars(request: Request, params: SetProcessEnvVarsParams):
async def set_env_vars(request: Request, params: SetEnvVarsParams):
try:
await obj_store.aset_process_env_vars(params.process_name, params.env_vars)
if params.process_name is not None:
await obj_store.aset_process_env_vars(
params.process_name, params.env_vars
)
else:
await obj_store.aset_env_vars_globally(params.env_vars)

return Response(output_type=OutputType.SUCCESS)
except Exception as e:
return handle_exception_response(
Expand Down
4 changes: 2 additions & 2 deletions runhouse/servers/http/http_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class CreateProcessParams(BaseModel):
env_vars: Optional[Dict] = {}


class SetProcessEnvVarsParams(BaseModel):
process_name: str
class SetEnvVarsParams(BaseModel):
process_name: Optional[str] = None
env_vars: Dict[str, str]


Expand Down
17 changes: 16 additions & 1 deletion runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,12 @@ async def ainitialize(
if path not in sys.path:
sys.path.insert(0, path)

# Set env vars that were passed in initialization
# Set env vars that need to be set on creation
env_vars_to_set = await self.aget_env_vars_to_set_in_new_processes()
self.set_process_env_vars_local(env_vars_to_set)

# Set env vars that were passed in initialization, these should override
# any env vars that were set globally via the previous global set
if create_process_params and create_process_params.env_vars:
self.set_process_env_vars_local(create_process_params.env_vars)

Expand Down Expand Up @@ -662,6 +667,11 @@ async def aset_process_env_vars(self, servlet_name: str, env_vars: Dict[str, str
def set_process_env_vars(self, servlet_name: str, env_vars: Dict[str, str]):
return sync_function(self.aset_process_env_vars)(servlet_name, env_vars)

async def aset_env_vars_globally(self, env_vars: Dict[str, str]):
return await self.acall_actor_method(
self.cluster_servlet, "aset_env_vars_globally", env_vars
)

##############################################
# Cluster config state storage methods
##############################################
Expand Down Expand Up @@ -825,6 +835,11 @@ async def aadd_path_to_prepend_in_new_processes(self, path: str):
self.cluster_servlet, "aadd_path_to_prepend_in_new_processes", path
)

async def aget_env_vars_to_set_in_new_processes(self) -> Dict[str, str]:
return await self.acall_actor_method(
self.cluster_servlet, "aget_env_vars_to_set_in_new_processes"
)

##############################################
# Remove Servlet
##############################################
Expand Down

0 comments on commit 937df4a

Please sign in to comment.