Skip to content

Commit

Permalink
Better error handling in the sdk (#30)
Browse files Browse the repository at this point in the history
More variants and separated sdk errors from http errors

Co-authored-by: Jonathan Richard <jwric@users.noreply.github.com>
  • Loading branch information
ThierryCantin-Demers and jwric authored Sep 27, 2024
1 parent 8efe30c commit f607eee
Show file tree
Hide file tree
Showing 10 changed files with 271 additions and 180 deletions.
167 changes: 90 additions & 77 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions crates/heat-sdk-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ flate2 = { version = "1.0.30", default-features = false, features = ["zlib"] }
tar = { version = "0.4.40", default-features = false }
walkdir = "2"
ignore = "0.4.22"
gix = { version = "0.64.0", default-features = false, features = ["dirwalk"]}
gix = { version = "0.66.0", default-features = false, features = ["dirwalk"]}
unicase = "2.7.0"
itertools = "0.13.0"
lazycell = "1.3.0"
serde-untagged = "0.1.6"
serde_ignored = "0.1.1"
sha2 = "0.10"
pathdiff = "0.2.1"
pathdiff = "0.2.1"
99 changes: 64 additions & 35 deletions crates/heat-sdk/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
use std::path::PathBuf;
use std::sync::mpsc;
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::Duration;

use burn::tensor::backend::Backend;
use reqwest::StatusCode;
use serde::Serialize;

use crate::errors::sdk::HeatSdkError;
use crate::experiment::{Experiment, TempLogStore, WsMessage};
use crate::http::error::HeatHttpError;
use crate::http::{EndExperimentStatus, HttpClient};
use crate::schemas::{
CrateVersionMetadata, ExperimentPath, HeatCodeMetadata, PackagedCrateData, ProjectPath,
Expand Down Expand Up @@ -40,6 +44,8 @@ pub struct HeatClientConfig {
pub credentials: HeatCredentials,
/// The number of retries to attempt when connecting to the Heat API.
pub num_retries: u8,
/// The interval to wait between retries in seconds.
pub retry_interval: u64,
/// The project ID to create the experiment in.
pub project_path: ProjectPath,
}
Expand All @@ -66,6 +72,7 @@ impl HeatClientConfigBuilder {
endpoint: "http://127.0.0.1:9001".into(),
credentials: creds,
num_retries: 3,
retry_interval: 3,
project_path,
},
}
Expand All @@ -83,6 +90,12 @@ impl HeatClientConfigBuilder {
self
}

/// Set the interval to wait between retries in seconds
pub fn with_retry_interval(mut self, retry_interval: u64) -> HeatClientConfigBuilder {
self.config.retry_interval = retry_interval;
self
}

/// Build the HeatClientConfig
pub fn build(self) -> HeatClientConfig {
self.config
Expand Down Expand Up @@ -126,10 +139,29 @@ impl HeatClient {
match client.connect() {
Ok(_) => break,
Err(e) => {
println!("Failed to connect to the server: {}", e);

if i == client.config.num_retries {
return Err(HeatSdkError::ServerTimeoutError(e.to_string()));
return Err(HeatSdkError::CreateClientError(
"Server timeout".to_string(),
));
}

if let HeatSdkError::HttpError(HeatHttpError::HttpError(
StatusCode::UNPROCESSABLE_ENTITY,
msg,
)) = e
{
println!("Invalid API key. Please check your API key and try again.");
return Err(HeatSdkError::CreateClientError(format!(
"Invalid API key: {msg}"
)));
}
println!("Failed to connect to the server. Retrying...");
println!(
"Failed to connect to the server. Retrying in {} seconds...",
client.config.retry_interval
);
thread::sleep(Duration::from_secs(client.config.retry_interval));
}
}
}
Expand All @@ -139,23 +171,27 @@ impl HeatClient {

/// Start a new experiment. This will create a new experiment on the Heat backend and start it.
pub fn start_experiment(&mut self, config: &impl Serialize) -> Result<(), HeatSdkError> {
let experiment = self.http_client.create_experiment(
self.config.project_path.owner_name(),
self.config.project_path.project_name(),
)?;
let experiment = self
.http_client
.create_experiment(
self.config.project_path.owner_name(),
self.config.project_path.project_name(),
)
.map_err(HeatSdkError::HttpError)?;

let experiment_path = ExperimentPath::try_from(format!(
"{}/{}",
self.config.project_path, experiment.experiment_num
))
.map_err(HeatSdkError::ClientError)?;
))?;

self.http_client.start_experiment(
self.config.project_path.owner_name(),
&experiment.project_name,
experiment.experiment_num,
config,
)?;
self.http_client
.start_experiment(
self.config.project_path.owner_name(),
&experiment.project_name,
experiment.experiment_num,
config,
)
.map_err(HeatSdkError::HttpError)?;

println!("Experiment num: {}", experiment.experiment_num);

Expand All @@ -168,17 +204,9 @@ impl HeatClient {
let mut ws_client = WebSocketClient::new();
ws_client.connect(ws_endpoint, self.http_client.get_session_cookie().unwrap())?;

let exp_log_store = TempLogStore::new(self.http_client.clone(), experiment_path);
let exp_log_store = TempLogStore::new(self.http_client.clone(), experiment_path.clone());

let experiment = Experiment::new(
ExperimentPath::try_from(format!(
"{}/{}",
self.config.project_path, experiment.experiment_num
))
.map_err(HeatSdkError::ClientError)?,
ws_client,
exp_log_store,
);
let experiment = Experiment::new(experiment_path, ws_client, exp_log_store);
let mut exp_guard = self
.active_experiment
.write()
Expand All @@ -189,15 +217,15 @@ impl HeatClient {
}

/// Get the sender for the active experiment's WebSocket connection.
pub fn get_experiment_sender(&self) -> Result<mpsc::Sender<WsMessage>, HeatSdkError> {
pub(crate) fn get_experiment_sender(&self) -> Result<mpsc::Sender<WsMessage>, HeatSdkError> {
let active_experiment = self
.active_experiment
.read()
.expect("Should be able to lock active_experiment as read.");
if let Some(w) = active_experiment.as_ref() {
w.get_ws_sender()
} else {
Err(HeatSdkError::ClientError(
Err(HeatSdkError::UnknownError(
"No active experiment to get sender.".to_string(),
))
}
Expand Down Expand Up @@ -228,14 +256,14 @@ impl HeatClient {

let time = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.expect("Should be able to get time.")
.as_millis();

self.http_client.upload_bytes_to_url(&url, checkpoint)?;

let time_end = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.expect("Should be able to get time.")
.as_millis();

log::info!("Time to upload checkpoint: {}", time_end - time);
Expand Down Expand Up @@ -273,7 +301,7 @@ impl HeatClient {
.expect("Should be able to lock active_experiment as read.");

if active_experiment.is_none() {
return Err(HeatSdkError::ClientError(
return Err(HeatSdkError::UnknownError(
"No active experiment to upload final model.".to_string(),
));
}
Expand Down Expand Up @@ -307,7 +335,7 @@ impl HeatClient {
let recorder = crate::record::RemoteRecorder::<S>::final_model(self.clone());
let res = model.save_file("", &recorder);
if let Err(e) = res {
return Err(HeatSdkError::ClientError(e.to_string()));
return Err(HeatSdkError::StopExperimentError(e.to_string()));
}

self.end_experiment_internal(EndExperimentStatus::Success)
Expand Down Expand Up @@ -372,19 +400,18 @@ impl HeatClient {
metadata,
)?;

// assumes that the urls are returned in the same order as the names
for (crate_name, file_path) in data.into_iter() {
let url = urls
.urls
.get(&crate_name)
.ok_or(HeatSdkError::ClientError(format!(
.ok_or(HeatSdkError::UnknownError(format!(
"No URL found for crate {}",
crate_name
)))?;

let data = std::fs::read(file_path).map_err(|e| {
HeatSdkError::ClientError(format!(
"Failed to read crate data for {}: {}",
HeatSdkError::FileReadError(format!(
"Could not read crate data for crate {}: {}",
crate_name, e
))
})?;
Expand All @@ -407,7 +434,9 @@ impl HeatClient {
self.config.project_path.project_name(),
project_version,
command,
)
)?;

Ok(())
}
}

Expand Down
37 changes: 29 additions & 8 deletions crates/heat-sdk/src/errors/sdk.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,41 @@
use thiserror::Error;

use crate::websocket::WebSocketError;
use crate::{http::error::HeatHttpError, websocket::WebSocketError};

#[derive(Error, Debug)]
pub enum HeatSdkError {
#[error("Server Timeout Error: {0}")]
ServerTimeoutError(String),
#[error("Server Error: {0}")]
ServerError(String),
#[error("Client Error: {0}")]
ClientError(String),
#[error("Invalid experiment number: {0}")]
InvalidExperimentNumber(String),
#[error("Invalid experiment path: {0}")]
InvalidProjectPath(String),
#[error("Invalid experiment path: {0}")]
InvalidExperimentPath(String),
#[error("Websocket Error: {0}")]
WebSocketError(String),
#[error("Macro Error: {0}")]
MacroError(String),
#[error("Failed to start experiment: {0}")]
StartExperimentError(String),
#[error("Failed to stop experiment: {0}")]
StopExperimentError(String),
#[error("Failed to create client: {0}")]
CreateClientError(String),
#[error("Failed to create remote metric logger: {0}")]
CreateRemoteMetricLoggerError(String),

#[error("File Read Error: {0}")]
FileReadError(String),

#[error("HTTP Error: {0}")]
HttpError(HeatHttpError),

#[error("Unknown Error: {0}")]
UnknownError(String),
}

impl<T> From<std::sync::PoisonError<std::sync::MutexGuard<'_, T>>> for HeatSdkError {
fn from(error: std::sync::PoisonError<std::sync::MutexGuard<'_, T>>) -> Self {
HeatSdkError::ClientError(error.to_string())
HeatSdkError::UnknownError(error.to_string())
}
}

Expand All @@ -29,3 +44,9 @@ impl From<WebSocketError> for HeatSdkError {
HeatSdkError::WebSocketError(error.to_string())
}
}

impl From<HeatHttpError> for HeatSdkError {
fn from(error: HeatHttpError) -> Self {
HeatSdkError::HttpError(error)
}
}
6 changes: 3 additions & 3 deletions crates/heat-sdk/src/experiment/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ impl Experiment {
&self.experiment_path
}

pub fn get_ws_sender(&self) -> Result<mpsc::Sender<WsMessage>, HeatSdkError> {
pub(crate) fn get_ws_sender(&self) -> Result<mpsc::Sender<WsMessage>, HeatSdkError> {
if let Some(handler) = &self.handler {
Ok(handler.get_sender())
} else {
Err(HeatSdkError::ClientError(
"Experiment handling thread not started".to_string(),
Err(HeatSdkError::UnknownError(
"Experiment not started yet".to_string(),
))
}
}
Expand Down
Loading

0 comments on commit f607eee

Please sign in to comment.