diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 6b77166a5..dc3a03e4e 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -2728,6 +2728,12 @@ def ensure_process_created( ) return name + def set_env_vars_globally(self, env_vars: Dict): + if self.on_this_cluster(): + return obj_store.set_env_vars_globally(env_vars) + else: + return self.client.set_env_vars(env_vars) + def set_process_env_vars(self, name: str, env_vars: Dict): """Set the env vars for a process on the cluster. @@ -2738,9 +2744,7 @@ def set_process_env_vars(self, name: str, env_vars: Dict): if self.on_this_cluster(): return obj_store.set_process_env_vars(name, env_vars) else: - return self.client.set_process_env_vars( - process_name=name, env_vars=env_vars - ) + return self.client.set_env_vars(env_vars=env_vars, process_name=name) def install_package( self, diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 4084d1f99..a5dacfa8a 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -67,6 +67,7 @@ async def __init__( self._key_to_servlet_name: Dict[Any, str] = {} self._auth_cache: AuthCache = AuthCache() self._paths_to_prepend_in_new_processes = [] + self._env_vars_to_set_in_new_processes = {} self._node_servlet_names: List[str] = [] self._cluster_name = self._cluster_config.get("name", None) self._cluster_uri = ( @@ -133,6 +134,31 @@ async def aadd_path_to_prepend_in_new_processes(self, path: str): async def aget_paths_to_prepend_in_new_processes(self) -> List[str]: return self._paths_to_prepend_in_new_processes + ############################################## + # Env vars to set and get env vars methods + ############################################## + async def aset_env_vars_globally(self, env_vars: Dict[str, Any]): + + await asyncio.gather( + *[ + obj_store.acall_servlet_method( + servlet_name, + "aset_env_vars", + env_vars, + use_servlet_cache=False, + ) + for servlet_name in await self.aget_all_initialized_servlet_args() + ] + ) + + await self.aadd_env_vars_to_set_in_new_processes(env_vars) + + async def aadd_env_vars_to_set_in_new_processes(self, env_vars: Dict[str, Any]): + self._env_vars_to_set_in_new_processes.update(env_vars) + + async def aget_env_vars_to_set_in_new_processes(self) -> Dict[str, str]: + return self._env_vars_to_set_in_new_processes + ############################################## # Cluster config state storage methods ############################################## diff --git a/runhouse/servers/http/http_client.py b/runhouse/servers/http/http_client.py index cb7058e50..9d3750184 100644 --- a/runhouse/servers/http/http_client.py +++ b/runhouse/servers/http/http_client.py @@ -38,7 +38,7 @@ RenameObjectParams, RunBashParams, serialize_data, - SetProcessEnvVarsParams, + SetEnvVarsParams, ) from runhouse.utils import ClusterLogsFormatter, generate_default_name, thread_coroutine @@ -825,15 +825,15 @@ def create_process( json_dict=params.model_dump(), ) - def set_process_env_vars( + def set_env_vars( self, - process_name: str, env_vars: Dict[str, str], + process_name: Optional[str] = None, ): return self.request_json( - "/process_env_vars", + "/env_vars", req_type="post", - json_dict=SetProcessEnvVarsParams( + json_dict=SetEnvVarsParams( process_name=process_name, env_vars=env_vars ).model_dump(), ) diff --git a/runhouse/servers/http/http_server.py b/runhouse/servers/http/http_server.py index a88d32592..8f2c53f47 100644 --- a/runhouse/servers/http/http_server.py +++ b/runhouse/servers/http/http_server.py @@ -65,7 +65,7 @@ RunBashParams, serialize_data, ServerSettings, - SetProcessEnvVarsParams, + SetEnvVarsParams, ) from runhouse.servers.obj_store import ( ClusterServletSetupOption, @@ -454,11 +454,17 @@ async def kill_process(request: Request, params: KillProcessParams): ) @staticmethod - @app.post("/process_env_vars") + @app.post("/env_vars") @validate_cluster_access - async def set_process_env_vars(request: Request, params: SetProcessEnvVarsParams): + async def set_env_vars(request: Request, params: SetEnvVarsParams): try: - await obj_store.aset_process_env_vars(params.process_name, params.env_vars) + if params.process_name is not None: + await obj_store.aset_process_env_vars( + params.process_name, params.env_vars + ) + else: + await obj_store.aset_env_vars_globally(params.env_vars) + return Response(output_type=OutputType.SUCCESS) except Exception as e: return handle_exception_response( diff --git a/runhouse/servers/http/http_utils.py b/runhouse/servers/http/http_utils.py index 15347cb95..b94d1d04a 100644 --- a/runhouse/servers/http/http_utils.py +++ b/runhouse/servers/http/http_utils.py @@ -52,8 +52,8 @@ class CreateProcessParams(BaseModel): env_vars: Optional[Dict] = {} -class SetProcessEnvVarsParams(BaseModel): - process_name: str +class SetEnvVarsParams(BaseModel): + process_name: Optional[str] = None env_vars: Dict[str, str] diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py index 25c8f48f7..856b13a7a 100644 --- a/runhouse/servers/obj_store.py +++ b/runhouse/servers/obj_store.py @@ -271,7 +271,12 @@ async def ainitialize( if path not in sys.path: sys.path.insert(0, path) - # Set env vars that were passed in initialization + # Set env vars that need to be set on creation + env_vars_to_set = await self.aget_env_vars_to_set_in_new_processes() + self.set_process_env_vars_local(env_vars_to_set) + + # Set env vars that were passed in initialization, these should override + # any env vars that were set globally via the previous global set if create_process_params and create_process_params.env_vars: self.set_process_env_vars_local(create_process_params.env_vars) @@ -662,6 +667,11 @@ async def aset_process_env_vars(self, servlet_name: str, env_vars: Dict[str, str def set_process_env_vars(self, servlet_name: str, env_vars: Dict[str, str]): return sync_function(self.aset_process_env_vars)(servlet_name, env_vars) + async def aset_env_vars_globally(self, env_vars: Dict[str, str]): + return await self.acall_actor_method( + self.cluster_servlet, "aset_env_vars_globally", env_vars + ) + ############################################## # Cluster config state storage methods ############################################## @@ -825,6 +835,11 @@ async def aadd_path_to_prepend_in_new_processes(self, path: str): self.cluster_servlet, "aadd_path_to_prepend_in_new_processes", path ) + async def aget_env_vars_to_set_in_new_processes(self) -> Dict[str, str]: + return await self.acall_actor_method( + self.cluster_servlet, "aget_env_vars_to_set_in_new_processes" + ) + ############################################## # Remove Servlet ##############################################