diff --git a/databricks/sdk/mixins/compute.py b/databricks/sdk/mixins/compute.py index 61733a6f..46eae4f4 100644 --- a/databricks/sdk/mixins/compute.py +++ b/databricks/sdk/mixins/compute.py @@ -209,7 +209,7 @@ def select_node_type(self, return nt.node_type_id raise ValueError("cannot determine smallest node type") - def ensure_cluster_is_running(self, cluster_id: str) -> None: + def ensure_cluster_is_running(self, cluster_id: str) -> compute.ClusterDetails: """Ensures that given cluster is running, regardless of the current state""" timeout = datetime.timedelta(minutes=20) deadline = time.time() + timeout.total_seconds() @@ -218,17 +218,14 @@ def ensure_cluster_is_running(self, cluster_id: str) -> None: state = compute.State info = self.get(cluster_id) if info.state == state.RUNNING: - return + return info elif info.state == state.TERMINATED: - self.start(cluster_id).result() - return + return self.start(cluster_id).result() elif info.state == state.TERMINATING: self.wait_get_cluster_terminated(cluster_id) - self.start(cluster_id).result() - return + return self.start(cluster_id).result() elif info.state in (state.PENDING, state.RESIZING, state.RESTARTING): - self.wait_get_cluster_running(cluster_id) - return + return self.wait_get_cluster_running(cluster_id) elif info.state in (state.ERROR, state.UNKNOWN): raise RuntimeError(f'Cluster {info.cluster_name} is {info.state}: {info.state_message}') except OperationFailed as e: