diff --git a/Cargo.lock b/Cargo.lock index a3088db..fd73ff7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1828,9 +1828,9 @@ checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd" [[package]] name = "gix" -version = "0.64.0" +version = "0.66.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d78414d29fcc82329080166077e0f7689f4016551fdb334d787c3d040fe2634f" +checksum = "9048b8d1ae2104f045cb37e5c450fc49d5d8af22609386bfc739c11ba88995eb" dependencies = [ "gix-actor", "gix-attributes", @@ -1850,7 +1850,6 @@ dependencies = [ "gix-ignore", "gix-index", "gix-lock", - "gix-macros", "gix-object", "gix-odb", "gix-pack", @@ -1876,9 +1875,9 @@ dependencies = [ [[package]] name = "gix-actor" -version = "0.31.5" +version = "0.32.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0e454357e34b833cc3a00b6efbbd3dd4d18b24b9fb0c023876ec2645e8aa3f2" +checksum = "fc19e312cd45c4a66cd003f909163dc2f8e1623e30a0c0c6df3776e89b308665" dependencies = [ "bstr", "gix-date", @@ -1890,9 +1889,9 @@ dependencies = [ [[package]] name = "gix-attributes" -version = "0.22.3" +version = "0.22.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e37ce99c7e81288c28b703641b6d5d119aacc45c1a6b247156e6249afa486257" +checksum = "ebccbf25aa4a973dd352564a9000af69edca90623e8a16dad9cbc03713131311" dependencies = [ "bstr", "gix-glob", @@ -1925,9 +1924,9 @@ dependencies = [ [[package]] name = "gix-command" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d76867867da891cbe32021ad454e8cae90242f6afb06762e4dd0d357afd1d7b" +checksum = "dff2e692b36bbcf09286c70803006ca3fd56551a311de450be317a0ab8ea92e7" dependencies = [ "bstr", "gix-path", @@ -1951,9 +1950,9 @@ dependencies = [ [[package]] name = "gix-config" -version = "0.38.0" +version = "0.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28f53fd03d1bf09ebcc2c8654f08969439c4556e644ca925f27cf033bc43e658" +checksum = "78e797487e6ca3552491de1131b4f72202f282fb33f198b1c34406d765b42bb0" dependencies = [ "bstr", "gix-config-value", @@ -1972,9 +1971,9 @@ dependencies = [ [[package]] name = "gix-config-value" -version = "0.14.7" +version = "0.14.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b328997d74dd15dc71b2773b162cb4af9a25c424105e4876e6d0686ab41c383e" +checksum = "03f76169faa0dec598eac60f83d7fcdd739ec16596eca8fb144c88973dbe6f8c" dependencies = [ "bitflags 2.6.0", "bstr", @@ -1985,21 +1984,21 @@ dependencies = [ [[package]] name = "gix-date" -version = "0.8.7" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9eed6931f21491ee0aeb922751bd7ec97b4b2fe8fbfedcb678e2a2dce5f3b8c0" +checksum = "35c84b7af01e68daf7a6bb8bb909c1ff5edb3ce4326f1f43063a5a96d3c3c8a5" dependencies = [ "bstr", "itoa", + "jiff", "thiserror", - "time", ] [[package]] name = "gix-diff" -version = "0.44.1" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1996d5c8a305b59709467d80617c9fde48d9d75fd1f4179ea970912630886c9d" +checksum = "92c9afd80fff00f8b38b1c1928442feb4cd6d2232a6ed806b6b193151a3d336c" dependencies = [ "bstr", "gix-hash", @@ -2009,9 +2008,9 @@ dependencies = [ [[package]] name = "gix-dir" -version = "0.6.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c975679aa00dd2d757bfd3ddb232e8a188c0094c3306400575a0813858b1365" +checksum = "0ed3a9076661359a1c5a27c12ad6c3ebe2dd96b8b3c0af6488ab7c128b7bdd98" dependencies = [ "bstr", "gix-discover", @@ -2029,9 +2028,9 @@ dependencies = [ [[package]] name = "gix-discover" -version = "0.33.0" +version = "0.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67662731cec3cb31ba3ed2463809493f76d8e5d6c6d245de8b0560438c13450e" +checksum = "0577366b9567376bc26e815fd74451ebd0e6218814e242f8e5b7072c58d956d2" dependencies = [ "bstr", "dunce", @@ -2064,9 +2063,9 @@ dependencies = [ [[package]] name = "gix-filter" -version = "0.11.3" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6547738da28275f4dff4e9f3a0f28509f53f94dd6bd822733c91cb306bca61a" +checksum = "4121790ae140066e5b953becc72e7496278138d19239be2e63b5067b0843119e" dependencies = [ "bstr", "encoding_rs", @@ -2085,9 +2084,9 @@ dependencies = [ [[package]] name = "gix-fs" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6adf99c27cdf17b1c4d77680c917e0d94d8783d4e1c73d3be0d1d63107163d7a" +checksum = "f2bfe6249cfea6d0c0e0990d5226a4cb36f030444ba9e35e0639275db8f98575" dependencies = [ "fastrand", "gix-features", @@ -2096,9 +2095,9 @@ dependencies = [ [[package]] name = "gix-glob" -version = "0.16.4" +version = "0.16.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7df15afa265cc8abe92813cd354d522f1ac06b29ec6dfa163ad320575cb447" +checksum = "74908b4bbc0a0a40852737e5d7889f676f081e340d5451a16e5b4c50d592f111" dependencies = [ "bitflags 2.6.0", "bstr", @@ -2129,9 +2128,9 @@ dependencies = [ [[package]] name = "gix-ignore" -version = "0.11.3" +version = "0.11.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6afb8f98e314d4e1adc822449389ada863c174b5707cedd327d67b84dba527" +checksum = "e447cd96598460f5906a0f6c75e950a39f98c2705fc755ad2f2020c9e937fab7" dependencies = [ "bstr", "gix-glob", @@ -2142,9 +2141,9 @@ dependencies = [ [[package]] name = "gix-index" -version = "0.33.1" +version = "0.35.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a9a44eb55bd84bb48f8a44980e951968ced21e171b22d115d1cdcef82a7d73f" +checksum = "0cd4203244444017682176e65fd0180be9298e58ed90bd4a8489a357795ed22d" dependencies = [ "bitflags 2.6.0", "bstr", @@ -2179,22 +2178,11 @@ dependencies = [ "thiserror", ] -[[package]] -name = "gix-macros" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "999ce923619f88194171a67fb3e6d613653b8d4d6078b529b15a765da0edcc17" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.75", -] - [[package]] name = "gix-object" -version = "0.42.3" +version = "0.44.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25da2f46b4e7c2fa7b413ce4dffb87f69eaf89c2057e386491f4c55cadbfe386" +checksum = "2f5b801834f1de7640731820c2df6ba88d95480dc4ab166a5882f8ff12b88efa" dependencies = [ "bstr", "gix-actor", @@ -2211,9 +2199,9 @@ dependencies = [ [[package]] name = "gix-odb" -version = "0.61.1" +version = "0.63.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20d384fe541d93d8a3bb7d5d5ef210780d6df4f50c4e684ccba32665a5e3bc9b" +checksum = "a3158068701c17df54f0ab2adda527f5a6aca38fd5fd80ceb7e3c0a2717ec747" dependencies = [ "arc-swap", "gix-date", @@ -2231,9 +2219,9 @@ dependencies = [ [[package]] name = "gix-pack" -version = "0.51.1" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e0594491fffe55df94ba1c111a6566b7f56b3f8d2e1efc750e77d572f5f5229" +checksum = "3223aa342eee21e1e0e403cad8ae9caf9edca55ef84c347738d10681676fd954" dependencies = [ "clru", "gix-chunk", @@ -2249,9 +2237,9 @@ dependencies = [ [[package]] name = "gix-packetline-blocking" -version = "0.17.4" +version = "0.17.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31d42378a3d284732e4d589979930d0d253360eccf7ec7a80332e5ccb77e14a" +checksum = "b9802304baa798dd6f5ff8008a2b6516d54b74a69ca2d3a2b9e2d6c3b5556b40" dependencies = [ "bstr", "faster-hex", @@ -2261,9 +2249,9 @@ dependencies = [ [[package]] name = "gix-path" -version = "0.10.10" +version = "0.10.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d5b8722112fa2fa87135298780bc833b0e9f6c56cc82795d209804b3a03484" +checksum = "ebfc4febd088abdcbc9f1246896e57e37b7a34f6909840045a1767c6dafac7af" dependencies = [ "bstr", "gix-trace", @@ -2274,9 +2262,9 @@ dependencies = [ [[package]] name = "gix-pathspec" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d307d1b8f84dc8386c4aa20ce0cf09242033840e15469a3ecba92f10cfb5c046" +checksum = "5d23bf239532b4414d0e63b8ab3a65481881f7237ed9647bb10c1e3cc54c5ceb" dependencies = [ "bitflags 2.6.0", "bstr", @@ -2300,9 +2288,9 @@ dependencies = [ [[package]] name = "gix-ref" -version = "0.45.0" +version = "0.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "636e96a0a5562715153fee098c217110c33a6f8218f08f4687ff99afde159bb5" +checksum = "ae0d8406ebf9aaa91f55a57f053c5a1ad1a39f60fdf0303142b7be7ea44311e5" dependencies = [ "gix-actor", "gix-features", @@ -2321,9 +2309,9 @@ dependencies = [ [[package]] name = "gix-refspec" -version = "0.23.1" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6868f8cd2e62555d1f7c78b784bece43ace40dd2a462daf3b588d5416e603f37" +checksum = "ebb005f82341ba67615ffdd9f7742c87787544441c88090878393d0682869ca6" dependencies = [ "bstr", "gix-hash", @@ -2335,9 +2323,9 @@ dependencies = [ [[package]] name = "gix-revision" -version = "0.27.2" +version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01b13e43c2118c4b0537ddac7d0821ae0dfa90b7b8dbf20c711e153fb749adce" +checksum = "ba4621b219ac0cdb9256883030c3d56a6c64a6deaa829a92da73b9a576825e1e" dependencies = [ "bstr", "gix-date", @@ -2349,9 +2337,9 @@ dependencies = [ [[package]] name = "gix-revwalk" -version = "0.13.2" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b030ccaab71af141f537e0225f19b9e74f25fefdba0372246b844491cab43e0" +checksum = "b41e72544b93084ee682ef3d5b31b1ba4d8fa27a017482900e5e044d5b1b3984" dependencies = [ "gix-commitgraph", "gix-date", @@ -2364,9 +2352,9 @@ dependencies = [ [[package]] name = "gix-sec" -version = "0.10.7" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1547d26fa5693a7f34f05b4a3b59a90890972922172653bcb891ab3f09f436df" +checksum = "0fe4d52f30a737bbece5276fab5d3a8b276dc2650df963e293d0673be34e7a5f" dependencies = [ "bitflags 2.6.0", "gix-path", @@ -2376,9 +2364,9 @@ dependencies = [ [[package]] name = "gix-submodule" -version = "0.12.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f2e0f69aa00805e39d39ec80472a7e9da20ed5d73318b27925a2cc198e854fd" +checksum = "529d0af78cc2f372b3218f15eb1e3d1635a21c8937c12e2dd0b6fc80c2ca874b" dependencies = [ "bstr", "gix-config", @@ -2407,15 +2395,15 @@ dependencies = [ [[package]] name = "gix-trace" -version = "0.1.9" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f924267408915fddcd558e3f37295cc7d6a3e50f8bd8b606cee0808c3915157e" +checksum = "6cae0e8661c3ff92688ce1c8b8058b3efb312aba9492bbe93661a21705ab431b" [[package]] name = "gix-traverse" -version = "0.39.2" +version = "0.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e499a18c511e71cf4a20413b743b9f5bcf64b3d9e81e9c3c6cd399eae55a8840" +checksum = "030da39af94e4df35472e9318228f36530989327906f38e27807df305fccb780" dependencies = [ "bitflags 2.6.0", "gix-commitgraph", @@ -2430,9 +2418,9 @@ dependencies = [ [[package]] name = "gix-url" -version = "0.27.4" +version = "0.27.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2eb9b35bba92ea8f0b5ab406fad3cf6b87f7929aa677ff10aa042c6da621156" +checksum = "fd280c5e84fb22e128ed2a053a0daeacb6379469be6a85e3d518a0636e160c89" dependencies = [ "bstr", "gix-features", @@ -2455,9 +2443,9 @@ dependencies = [ [[package]] name = "gix-validate" -version = "0.8.5" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82c27dd34a49b1addf193c92070bcbf3beaf6e10f16a78544de6372e146a0acf" +checksum = "81f2badbb64e57b404593ee26b752c26991910fd0d81fe6f9a71c1a8309b6c86" dependencies = [ "bstr", "thiserror", @@ -2465,9 +2453,9 @@ dependencies = [ [[package]] name = "gix-worktree" -version = "0.34.1" +version = "0.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26f7326ebe0b9172220694ea69d344c536009a9b98fb0f9de092c440f3efe7a6" +checksum = "c312ad76a3f2ba8e865b360d5cb3aa04660971d16dec6dd0ce717938d903149a" dependencies = [ "bstr", "gix-attributes", @@ -3095,6 +3083,31 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jiff" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a45489186a6123c128fdf6016183fcfab7113e1820eb813127e036e287233fb" +dependencies = [ + "jiff-tzdb-platform", + "windows-sys 0.59.0", +] + +[[package]] +name = "jiff-tzdb" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91335e575850c5c4c673b9bd467b0e025f164ca59d0564f69d0c2ee0ffad4653" + +[[package]] +name = "jiff-tzdb-platform" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9835f0060a626fe59f160437bc725491a6af23133ea906500027d1bd2f8f4329" +dependencies = [ + "jiff-tzdb", +] + [[package]] name = "jni-sys" version = "0.3.0" diff --git a/crates/heat-sdk-cli/Cargo.toml b/crates/heat-sdk-cli/Cargo.toml index bc94955..26b2807 100644 --- a/crates/heat-sdk-cli/Cargo.toml +++ b/crates/heat-sdk-cli/Cargo.toml @@ -33,11 +33,11 @@ flate2 = { version = "1.0.30", default-features = false, features = ["zlib"] } tar = { version = "0.4.40", default-features = false } walkdir = "2" ignore = "0.4.22" -gix = { version = "0.64.0", default-features = false, features = ["dirwalk"]} +gix = { version = "0.66.0", default-features = false, features = ["dirwalk"]} unicase = "2.7.0" itertools = "0.13.0" lazycell = "1.3.0" serde-untagged = "0.1.6" serde_ignored = "0.1.1" sha2 = "0.10" -pathdiff = "0.2.1" \ No newline at end of file +pathdiff = "0.2.1" diff --git a/crates/heat-sdk/src/client.rs b/crates/heat-sdk/src/client.rs index 2f42cd5..d7629cd 100644 --- a/crates/heat-sdk/src/client.rs +++ b/crates/heat-sdk/src/client.rs @@ -1,12 +1,16 @@ use std::path::PathBuf; use std::sync::mpsc; use std::sync::{Arc, RwLock}; +use std::thread; +use std::time::Duration; use burn::tensor::backend::Backend; +use reqwest::StatusCode; use serde::Serialize; use crate::errors::sdk::HeatSdkError; use crate::experiment::{Experiment, TempLogStore, WsMessage}; +use crate::http::error::HeatHttpError; use crate::http::{EndExperimentStatus, HttpClient}; use crate::schemas::{ CrateVersionMetadata, ExperimentPath, HeatCodeMetadata, PackagedCrateData, ProjectPath, @@ -40,6 +44,8 @@ pub struct HeatClientConfig { pub credentials: HeatCredentials, /// The number of retries to attempt when connecting to the Heat API. pub num_retries: u8, + /// The interval to wait between retries in seconds. + pub retry_interval: u64, /// The project ID to create the experiment in. pub project_path: ProjectPath, } @@ -66,6 +72,7 @@ impl HeatClientConfigBuilder { endpoint: "http://127.0.0.1:9001".into(), credentials: creds, num_retries: 3, + retry_interval: 3, project_path, }, } @@ -83,6 +90,12 @@ impl HeatClientConfigBuilder { self } + /// Set the interval to wait between retries in seconds + pub fn with_retry_interval(mut self, retry_interval: u64) -> HeatClientConfigBuilder { + self.config.retry_interval = retry_interval; + self + } + /// Build the HeatClientConfig pub fn build(self) -> HeatClientConfig { self.config @@ -126,10 +139,29 @@ impl HeatClient { match client.connect() { Ok(_) => break, Err(e) => { + println!("Failed to connect to the server: {}", e); + if i == client.config.num_retries { - return Err(HeatSdkError::ServerTimeoutError(e.to_string())); + return Err(HeatSdkError::CreateClientError( + "Server timeout".to_string(), + )); + } + + if let HeatSdkError::HttpError(HeatHttpError::HttpError( + StatusCode::UNPROCESSABLE_ENTITY, + msg, + )) = e + { + println!("Invalid API key. Please check your API key and try again."); + return Err(HeatSdkError::CreateClientError(format!( + "Invalid API key: {msg}" + ))); } - println!("Failed to connect to the server. Retrying..."); + println!( + "Failed to connect to the server. Retrying in {} seconds...", + client.config.retry_interval + ); + thread::sleep(Duration::from_secs(client.config.retry_interval)); } } } @@ -139,23 +171,27 @@ impl HeatClient { /// Start a new experiment. This will create a new experiment on the Heat backend and start it. pub fn start_experiment(&mut self, config: &impl Serialize) -> Result<(), HeatSdkError> { - let experiment = self.http_client.create_experiment( - self.config.project_path.owner_name(), - self.config.project_path.project_name(), - )?; + let experiment = self + .http_client + .create_experiment( + self.config.project_path.owner_name(), + self.config.project_path.project_name(), + ) + .map_err(HeatSdkError::HttpError)?; let experiment_path = ExperimentPath::try_from(format!( "{}/{}", self.config.project_path, experiment.experiment_num - )) - .map_err(HeatSdkError::ClientError)?; + ))?; - self.http_client.start_experiment( - self.config.project_path.owner_name(), - &experiment.project_name, - experiment.experiment_num, - config, - )?; + self.http_client + .start_experiment( + self.config.project_path.owner_name(), + &experiment.project_name, + experiment.experiment_num, + config, + ) + .map_err(HeatSdkError::HttpError)?; println!("Experiment num: {}", experiment.experiment_num); @@ -168,17 +204,9 @@ impl HeatClient { let mut ws_client = WebSocketClient::new(); ws_client.connect(ws_endpoint, self.http_client.get_session_cookie().unwrap())?; - let exp_log_store = TempLogStore::new(self.http_client.clone(), experiment_path); + let exp_log_store = TempLogStore::new(self.http_client.clone(), experiment_path.clone()); - let experiment = Experiment::new( - ExperimentPath::try_from(format!( - "{}/{}", - self.config.project_path, experiment.experiment_num - )) - .map_err(HeatSdkError::ClientError)?, - ws_client, - exp_log_store, - ); + let experiment = Experiment::new(experiment_path, ws_client, exp_log_store); let mut exp_guard = self .active_experiment .write() @@ -189,7 +217,7 @@ impl HeatClient { } /// Get the sender for the active experiment's WebSocket connection. - pub fn get_experiment_sender(&self) -> Result, HeatSdkError> { + pub(crate) fn get_experiment_sender(&self) -> Result, HeatSdkError> { let active_experiment = self .active_experiment .read() @@ -197,7 +225,7 @@ impl HeatClient { if let Some(w) = active_experiment.as_ref() { w.get_ws_sender() } else { - Err(HeatSdkError::ClientError( + Err(HeatSdkError::UnknownError( "No active experiment to get sender.".to_string(), )) } @@ -228,14 +256,14 @@ impl HeatClient { let time = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) - .unwrap() + .expect("Should be able to get time.") .as_millis(); self.http_client.upload_bytes_to_url(&url, checkpoint)?; let time_end = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) - .unwrap() + .expect("Should be able to get time.") .as_millis(); log::info!("Time to upload checkpoint: {}", time_end - time); @@ -273,7 +301,7 @@ impl HeatClient { .expect("Should be able to lock active_experiment as read."); if active_experiment.is_none() { - return Err(HeatSdkError::ClientError( + return Err(HeatSdkError::UnknownError( "No active experiment to upload final model.".to_string(), )); } @@ -307,7 +335,7 @@ impl HeatClient { let recorder = crate::record::RemoteRecorder::::final_model(self.clone()); let res = model.save_file("", &recorder); if let Err(e) = res { - return Err(HeatSdkError::ClientError(e.to_string())); + return Err(HeatSdkError::StopExperimentError(e.to_string())); } self.end_experiment_internal(EndExperimentStatus::Success) @@ -372,19 +400,18 @@ impl HeatClient { metadata, )?; - // assumes that the urls are returned in the same order as the names for (crate_name, file_path) in data.into_iter() { let url = urls .urls .get(&crate_name) - .ok_or(HeatSdkError::ClientError(format!( + .ok_or(HeatSdkError::UnknownError(format!( "No URL found for crate {}", crate_name )))?; let data = std::fs::read(file_path).map_err(|e| { - HeatSdkError::ClientError(format!( - "Failed to read crate data for {}: {}", + HeatSdkError::FileReadError(format!( + "Could not read crate data for crate {}: {}", crate_name, e )) })?; @@ -407,7 +434,9 @@ impl HeatClient { self.config.project_path.project_name(), project_version, command, - ) + )?; + + Ok(()) } } diff --git a/crates/heat-sdk/src/errors/sdk.rs b/crates/heat-sdk/src/errors/sdk.rs index 1845013..b9ae6dc 100644 --- a/crates/heat-sdk/src/errors/sdk.rs +++ b/crates/heat-sdk/src/errors/sdk.rs @@ -1,26 +1,41 @@ use thiserror::Error; -use crate::websocket::WebSocketError; +use crate::{http::error::HeatHttpError, websocket::WebSocketError}; #[derive(Error, Debug)] pub enum HeatSdkError { - #[error("Server Timeout Error: {0}")] - ServerTimeoutError(String), - #[error("Server Error: {0}")] - ServerError(String), - #[error("Client Error: {0}")] - ClientError(String), + #[error("Invalid experiment number: {0}")] + InvalidExperimentNumber(String), + #[error("Invalid experiment path: {0}")] + InvalidProjectPath(String), + #[error("Invalid experiment path: {0}")] + InvalidExperimentPath(String), #[error("Websocket Error: {0}")] WebSocketError(String), #[error("Macro Error: {0}")] MacroError(String), + #[error("Failed to start experiment: {0}")] + StartExperimentError(String), + #[error("Failed to stop experiment: {0}")] + StopExperimentError(String), + #[error("Failed to create client: {0}")] + CreateClientError(String), + #[error("Failed to create remote metric logger: {0}")] + CreateRemoteMetricLoggerError(String), + + #[error("File Read Error: {0}")] + FileReadError(String), + + #[error("HTTP Error: {0}")] + HttpError(HeatHttpError), + #[error("Unknown Error: {0}")] UnknownError(String), } impl From>> for HeatSdkError { fn from(error: std::sync::PoisonError>) -> Self { - HeatSdkError::ClientError(error.to_string()) + HeatSdkError::UnknownError(error.to_string()) } } @@ -29,3 +44,9 @@ impl From for HeatSdkError { HeatSdkError::WebSocketError(error.to_string()) } } + +impl From for HeatSdkError { + fn from(error: HeatHttpError) -> Self { + HeatSdkError::HttpError(error) + } +} diff --git a/crates/heat-sdk/src/experiment/base.rs b/crates/heat-sdk/src/experiment/base.rs index 9831089..cfe1dee 100644 --- a/crates/heat-sdk/src/experiment/base.rs +++ b/crates/heat-sdk/src/experiment/base.rs @@ -82,12 +82,12 @@ impl Experiment { &self.experiment_path } - pub fn get_ws_sender(&self) -> Result, HeatSdkError> { + pub(crate) fn get_ws_sender(&self) -> Result, HeatSdkError> { if let Some(handler) = &self.handler { Ok(handler.get_sender()) } else { - Err(HeatSdkError::ClientError( - "Experiment handling thread not started".to_string(), + Err(HeatSdkError::UnknownError( + "Experiment not started yet".to_string(), )) } } diff --git a/crates/heat-sdk/src/http/client.rs b/crates/heat-sdk/src/http/client.rs index 98f61ef..2c63310 100644 --- a/crates/heat-sdk/src/http/client.rs +++ b/crates/heat-sdk/src/http/client.rs @@ -2,10 +2,10 @@ use reqwest::header::{COOKIE, SET_COOKIE}; use reqwest::Url; use serde::Serialize; +use crate::http::error::HeatHttpError; use crate::schemas::HeatCodeMetadata; use crate::{ client::HeatCredentials, - errors::sdk::HeatSdkError, http::schemas::StartExperimentSchema, schemas::{CrateVersionMetadata, Experiment}, }; @@ -20,16 +20,25 @@ pub enum EndExperimentStatus { Fail(String), } -impl From for HeatSdkError { +impl From for HeatHttpError { fn from(error: reqwest::Error) -> Self { match error.status() { - Some(status) => match status { - reqwest::StatusCode::REQUEST_TIMEOUT => { - HeatSdkError::ServerTimeoutError(error.to_string()) - } - _ => HeatSdkError::ServerError(status.to_string()), - }, - None => HeatSdkError::ServerError(error.to_string()), + Some(status) => HeatHttpError::HttpError(status, error.to_string()), + None => HeatHttpError::UnknownError(error.to_string()), + } + } +} + +trait ResponseExt { + fn map_to_heat_err(self) -> Result; +} + +impl ResponseExt for reqwest::blocking::Response { + fn map_to_heat_err(self) -> Result { + if self.status().is_success() { + Ok(self) + } else { + Err(HeatHttpError::HttpError(self.status(), self.text()?)) } } } @@ -66,9 +75,9 @@ impl HttpClient { /// Check if the Heat server is reachable. #[allow(dead_code)] - pub fn health_check(&self) -> Result<(), HeatSdkError> { + pub fn health_check(&self) -> Result<(), HeatHttpError> { let url = format!("{}/health", self.base_url); - self.http_client.get(url).send()?.error_for_status()?; + self.http_client.get(url).send()?.map_to_heat_err()?; Ok(()) } @@ -78,7 +87,7 @@ impl HttpClient { } /// Log in to the Heat server with the given credentials. - pub fn login(&mut self, credentials: &HeatCredentials) -> Result<(), HeatSdkError> { + pub fn login(&mut self, credentials: &HeatCredentials) -> Result<(), HeatHttpError> { let url = format!("{}/login/api-key", self.base_url); let res = self .http_client @@ -86,22 +95,22 @@ impl HttpClient { .form::(credentials) .send()?; + let status = res.status(); + // store session cookie - if res.status().is_success() { + if status.is_success() { let cookie_header = res.headers().get(SET_COOKIE); if let Some(cookie) = cookie_header { let cookie_str = cookie .to_str() - .expect("Session cookie should be convert to str"); + .expect("Session cookie should be able to convert to str"); self.session_cookie = Some(cookie_str.to_string()); } else { - return Err(HeatSdkError::ClientError( - "Cannot connect to Heat server, bad session ID.".to_string(), - )); + return Err(HeatHttpError::BadSessionId); } } else { let error_message: String = format!("Cannot connect to Heat server({:?})", res.text()?); - return Err(HeatSdkError::ClientError(error_message)); + return Err(HeatHttpError::HttpError(status, error_message)); } Ok(()) @@ -133,7 +142,7 @@ impl HttpClient { &self, owner_name: &str, project_name: &str, - ) -> Result { + ) -> Result { self.validate_session_cookie()?; let url = format!( @@ -148,7 +157,7 @@ impl HttpClient { .json(&serde_json::json!({})) .header(COOKIE, self.session_cookie.as_ref().unwrap()) .send()? - .error_for_status()? + .map_to_heat_err()? .json::()?; let experiment = Experiment { @@ -173,7 +182,7 @@ impl HttpClient { project_name: &str, exp_num: i32, config: &impl Serialize, - ) -> Result<(), HeatSdkError> { + ) -> Result<(), HeatHttpError> { self.validate_session_cookie()?; let json = StartExperimentSchema { @@ -189,7 +198,7 @@ impl HttpClient { .header(COOKIE, self.session_cookie.as_ref().unwrap()) .json(&json) .send()? - .error_for_status()?; + .map_to_heat_err()?; Ok(()) } @@ -203,7 +212,7 @@ impl HttpClient { project_name: &str, exp_num: i32, end_status: EndExperimentStatus, - ) -> Result<(), HeatSdkError> { + ) -> Result<(), HeatHttpError> { self.validate_session_cookie()?; let url = format!( @@ -221,7 +230,7 @@ impl HttpClient { .header(COOKIE, self.session_cookie.as_ref().unwrap()) .json(&end_status) .send()? - .error_for_status()?; + .map_to_heat_err()?; Ok(()) } @@ -235,7 +244,7 @@ impl HttpClient { project_name: &str, exp_num: i32, file_name: &str, - ) -> Result { + ) -> Result { self.validate_session_cookie()?; let url: String = format!( @@ -248,7 +257,7 @@ impl HttpClient { .post(url) .header(COOKIE, self.session_cookie.as_ref().unwrap()) .send()? - .error_for_status()? + .map_to_heat_err()? .json::() .map(|res| res.url)?; @@ -264,7 +273,7 @@ impl HttpClient { project_name: &str, exp_num: i32, file_name: &str, - ) -> Result { + ) -> Result { self.validate_session_cookie()?; let url: String = format!( @@ -277,7 +286,7 @@ impl HttpClient { .get(url) .header(COOKIE, self.session_cookie.as_ref().unwrap()) .send()? - .error_for_status()? + .map_to_heat_err()? .json::() .map(|res| res.url)?; @@ -292,7 +301,7 @@ impl HttpClient { owner_name: &str, project_name: &str, exp_num: i32, - ) -> Result { + ) -> Result { self.validate_session_cookie()?; let url = format!( @@ -305,7 +314,7 @@ impl HttpClient { .post(url) .header(COOKIE, self.session_cookie.as_ref().unwrap()) .send()? - .error_for_status()? + .map_to_heat_err()? .json::()? .url; @@ -320,7 +329,7 @@ impl HttpClient { owner_name: &str, project_name: &str, exp_num: i32, - ) -> Result { + ) -> Result { self.validate_session_cookie()?; let url = format!( @@ -333,7 +342,7 @@ impl HttpClient { .post(url) .header(COOKIE, self.session_cookie.as_ref().unwrap()) .send()? - .error_for_status()? + .map_to_heat_err()? .json::()? .url; @@ -341,34 +350,32 @@ impl HttpClient { } /// Generic method to upload bytes to the given URL. - pub fn upload_bytes_to_url(&self, url: &str, bytes: Vec) -> Result<(), HeatSdkError> { + pub fn upload_bytes_to_url(&self, url: &str, bytes: Vec) -> Result<(), HeatHttpError> { self.http_client .put(url) .body(bytes) .send()? - .error_for_status()?; + .map_to_heat_err()?; Ok(()) } /// Generic method to download bytes from the given URL. - pub fn download_bytes_from_url(&self, url: &str) -> Result, HeatSdkError> { + pub fn download_bytes_from_url(&self, url: &str) -> Result, HeatHttpError> { let data = self .http_client .get(url) .send()? - .error_for_status()? + .map_to_heat_err()? .bytes()? .to_vec(); Ok(data) } - fn validate_session_cookie(&self) -> Result<(), HeatSdkError> { + fn validate_session_cookie(&self) -> Result<(), HeatHttpError> { if self.session_cookie.is_none() { - return Err(HeatSdkError::ClientError( - "Cannot connect to Heat server, no session ID.".to_string(), - )); + return Err(HeatHttpError::BadSessionId); } Ok(()) } @@ -380,7 +387,7 @@ impl HttpClient { target_package_name: &str, heat_metadata: HeatCodeMetadata, crates_metadata: Vec, - ) -> Result { + ) -> Result { self.validate_session_cookie()?; let url = format!( @@ -388,7 +395,7 @@ impl HttpClient { self.base_url, owner_name, project_name ); - let upload_urls = self + let response = self .http_client .post(url) .header(COOKIE, self.session_cookie.as_ref().unwrap()) @@ -398,9 +405,9 @@ impl HttpClient { crates: crates_metadata, }) .send()? - .error_for_status()? - .json::()?; + .map_to_heat_err()?; + let upload_urls = response.json::()?; Ok(upload_urls) } @@ -411,7 +418,7 @@ impl HttpClient { project_name: &str, project_version: u32, command: String, - ) -> Result<(), HeatSdkError> { + ) -> Result<(), HeatHttpError> { self.validate_session_cookie()?; let url = format!( @@ -430,7 +437,7 @@ impl HttpClient { .header(COOKIE, self.session_cookie.as_ref().unwrap()) .json(&body) .send()? - .error_for_status()?; + .map_to_heat_err()?; Ok(()) } diff --git a/crates/heat-sdk/src/http/error.rs b/crates/heat-sdk/src/http/error.rs new file mode 100644 index 0000000..c291f77 --- /dev/null +++ b/crates/heat-sdk/src/http/error.rs @@ -0,0 +1,13 @@ +use reqwest::StatusCode; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum HeatHttpError { + #[error("Bad session id")] + BadSessionId, + #[error("Http Error {0}: {1}")] + HttpError(StatusCode, String), + + #[error("Unknown Error: {0}")] + UnknownError(String), +} diff --git a/crates/heat-sdk/src/http/mod.rs b/crates/heat-sdk/src/http/mod.rs index 85714b9..0c289f8 100644 --- a/crates/heat-sdk/src/http/mod.rs +++ b/crates/heat-sdk/src/http/mod.rs @@ -1,4 +1,5 @@ mod client; +pub(crate) mod error; mod schemas; pub use client::*; diff --git a/crates/heat-sdk/src/metrics/metric_logger.rs b/crates/heat-sdk/src/metrics/metric_logger.rs index 6245acc..18fbb4e 100644 --- a/crates/heat-sdk/src/metrics/metric_logger.rs +++ b/crates/heat-sdk/src/metrics/metric_logger.rs @@ -4,6 +4,7 @@ use burn::train::logger::MetricLogger; use burn::train::metric::{MetricEntry, NumericEntry}; use crate::client::HeatClientState; +use crate::errors::sdk::HeatSdkError; use crate::experiment::{Split, WsMessage}; /// The remote metric logger, used to send metric logs to Heat. @@ -17,19 +18,23 @@ impl RemoteMetricLogger { /// Create a new instance of the remote metric logger for `Training` with the given [HeatClientState]. pub fn new_train(client: HeatClientState) -> Self { Self::new(client, Split::Train) + .expect("RemoteMetricLogger should be created successfully for training split") } /// Create a new instance of the remote metric logger for `Validation` with the given [HeatClientState]. pub fn new_validation(client: HeatClientState) -> Self { Self::new(client, Split::Val) + .expect("RemoteMetricLogger should be created successfully for validation split") } - fn new(client: HeatClientState, split: Split) -> Self { - Self { - sender: client.get_experiment_sender().unwrap(), + fn new(client: HeatClientState, split: Split) -> Result { + Ok(Self { + sender: client + .get_experiment_sender() + .map_err(|e| HeatSdkError::CreateRemoteMetricLoggerError(e.to_string()))?, epoch: 1, split, - } + }) } } diff --git a/crates/heat-sdk/src/schemas.rs b/crates/heat-sdk/src/schemas.rs index 33cb4c4..efef5d7 100644 --- a/crates/heat-sdk/src/schemas.rs +++ b/crates/heat-sdk/src/schemas.rs @@ -4,6 +4,8 @@ use once_cell::sync::Lazy; use regex::Regex; use serde::{Deserialize, Serialize}; +use crate::errors::sdk::HeatSdkError; + #[derive(Debug, Deserialize, Serialize)] #[serde(rename_all = "lowercase")] pub enum DepKind { @@ -163,11 +165,11 @@ impl ProjectPath { } impl TryFrom for ProjectPath { - type Error = String; + type Error = HeatSdkError; fn try_from(path: String) -> Result { if !ProjectPath::validate_path(&path) { - return Err(format!("Invalid project path: {}", path)); + return Err(HeatSdkError::InvalidProjectPath(path)); } let parts: Vec<&str> = path.split('/').collect(); @@ -227,18 +229,18 @@ impl ExperimentPath { } impl TryFrom for ExperimentPath { - type Error = String; + type Error = HeatSdkError; fn try_from(path: String) -> Result { if !ExperimentPath::validate_path(&path) { - return Err(format!("Invalid experiment path: {}", path)); + return Err(HeatSdkError::InvalidExperimentPath(path)); } let parts: Vec<&str> = path.split('/').collect(); let project_path = ProjectPath::try_from(parts[0..2].join("/"))?; let experiment_num = parts[2] .parse::() - .map_err(|e| format!("Failed to parse experiment number: {}", e))?; + .map_err(|_| HeatSdkError::InvalidExperimentNumber(parts[2].to_string()))?; Ok(ExperimentPath { project_path,