Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update cluster client flow #1098

Merged
merged 1 commit into from
Aug 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 31 additions & 46 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import re
import subprocess
import threading
import time
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
Expand Down Expand Up @@ -126,15 +125,33 @@ def address(self, addr):
@property
def client(self):
    """Return the HTTP client for this cluster, connecting on first access.

    A client that is already cached in ``self._http_client`` is returned
    immediately. Otherwise the cluster must have an address and answer
    ``self._ping(retry=True)``; ``self.connect_server_client`` is then run in
    a helper thread that is waited on for at most 5 seconds, and the new
    client's ``check_server()`` is called before the client is handed back.

    Raises:
        ValueError: if ``self.address`` is not set.
        Exception: if the cluster does not respond to the ping.
        ConnectionError: if the connect thread is still running after the
            timeout, if it finished without setting ``self._http_client``,
            or if ``check_server()`` fails (the original error is attached
            as ``__cause__``).
    """
    # NOTE(review): assumes the server check below belongs to the
    # "no cached client yet" path, i.e. it is not repeated on every access
    # -- TODO confirm against the repository version of this property.
    if self._http_client:
        return self._http_client

    if not self.address:
        raise ValueError(f"No address set for cluster <{self.name}>. Is it up?")
    if not self._ping(retry=True):
        # ping cluster, and refresh ips if ondemand cluster and first ping fails
        raise Exception(
            f"Could not reach cluster {self.name} ({self.ips}). Is it up?"
        )

    # Connect in a helper thread so a hanging connection attempt cannot block
    # the caller for more than the join timeout.
    connect_call = threading.Thread(target=self.connect_server_client)
    connect_call.start()
    connect_call.join(timeout=5)
    if connect_call.is_alive():
        raise ConnectionError(
            f"Could not connect to client. Please check that the cluster {self.name} is up."
        )
    if not self._http_client:
        raise ConnectionError(
            f"Error occured trying to form connection for cluster {self.name}."
        )

    try:
        self._http_client.check_server()
    except (
        requests.exceptions.ConnectionError,
        requests.exceptions.ReadTimeout,
        requests.exceptions.ChunkedEncodingError,
        ValueError,
    ) as e:
        # Drop the client that failed verification; otherwise the next access
        # would find it cached and return it without any check at all.
        self._http_client = None
        raise ConnectionError(f"Check server failed: {e}.") from e
    return self._http_client

@property
Expand Down Expand Up @@ -625,51 +642,20 @@ def on_this_cluster(self):

# ----------------- RPC Methods ----------------- #

# NOTE(review): this span is a diff hunk whose +/- markers and indentation
# have been lost. It interleaves lines from two versions of
# call_client_method (note the two `def call_client_method` lines with
# different signatures) and is not runnable as written.
# Presumably the signature without `restart_server` is the post-change one --
# TODO confirm against the repository before relying on any reading of it.
def call_client_method(self, method_name, *args, restart_server=True, **kwargs):
def check_and_call():
def call_client_method(self, method_name, *args, **kwargs):
# Look up the requested method on the cluster client and invoke it.
method = getattr(self.client, method_name)
try:
return method(*args, **kwargs)
except ConnectionError:
try:
self.client.check_server()
# Clearing the cached client makes the next `self.client` access reconnect.
self._http_client = None
method = getattr(self.client, method_name)
return method(*args, **kwargs)
except (
requests.exceptions.ConnectionError,
requests.exceptions.ReadTimeout,
requests.exceptions.ChunkedEncodingError,
ValueError,
) as e:
# Only ValueErrors whose message contains "Error checking server:" are
# converted; any other ValueError is re-raised unchanged.
if isinstance(e, ValueError) and "Error checking server:" not in str(e):
raise e
raise ConnectionError(f"Check server failed: {e}.")

try:
return check_and_call()
except ConnectionError as e:
if not restart_server:
raise ConnectionError(f"Could not connect to server {self.name}: {e}")
elif "Check server failed: " not in str(e):
raise e

if not self._ping(retry=True):
raise Exception(f"Could not reach cluster {self.name}. Is it up?")

logger.info(
f"Cluster {self.name} is up, but the Runhouse API server may not be up."
)

self._http_client = None
self.restart_server()
for i in range(3):
logger.info(f"Checking server {self.name} again [{i + 1}/3]")
# NOTE(review): a bare `except:` also intercepts KeyboardInterrupt and
# SystemExit and replaces them with ConnectionError.
except:
raise ConnectionError("Could not connect to Runhouse server.")

try:
return self.call_client_method(
method_name, *args, restart_server=False, **kwargs
)
except ConnectionError as e:
# Give up after the third attempt; otherwise wait 5 seconds and retry.
if i == 2:
raise e
time.sleep(5)
return
return method(*args, **kwargs)
# Catching and immediately re-raising the same exception has no effect.
except Exception as e:
raise e

def connect_tunnel(self, force_reconnect=False):
if self._rpc_tunnel and force_reconnect:
Expand Down Expand Up @@ -748,7 +734,6 @@ def status(self, resource_address: str = None):
else:
status = self.call_client_method(
"status",
restart_server=False,
resource_address=resource_address or self.rns_address,
)
return status
Expand Down
Loading