Skip to content

Commit

Permalink
Merge branch 'project_packages_metadata' into package
Browse files Browse the repository at this point in the history
  • Loading branch information
ThierryCantin-Demers committed Sep 6, 2024
2 parents 1b7a374 + dfb0d12 commit 0cd24ef
Show file tree
Hide file tree
Showing 17 changed files with 549 additions and 183 deletions.
12 changes: 12 additions & 0 deletions .editorconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# EditorConfig is awesome: https://EditorConfig.org
# top-most EditorConfig file
root = true

# Unix-style newlines with a newline ending every file
[*]
end_of_line = lf
insert_final_newline = true

[*.rs]
indent_style = space
indent_size = 4
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@ derive_more = { version = "0.99.18", features = ["display"], default-features =
dotenv = "0.15.0"
env_logger = "0.11.3"
log = "0.4.21"
once_cell = "1.19.0"
proc-macro2 = { version = "1.0.86" }
quote = "1.0.36"
rand = "0.8.5"
reqwest = "0.12.4"
regex = "1.10.5"
rmp-serde = "1.3.0"
rstest = "0.19.0"
serde = { version = "1.0.204", default-features = false, features = [
Expand Down
45 changes: 36 additions & 9 deletions crates/heat-sdk-cli/src/cli_commands/package/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use crate::context::HeatCliContext;
use clap::Parser;
use heat_sdk::client::{HeatClient, HeatClientConfig, HeatCredentials};
use heat_sdk::{
client::{HeatClient, HeatClientConfig, HeatCredentials},
schemas::{HeatCodeMetadata, ProjectPath, RegisteredHeatFunction},
};
use quote::ToTokens;

#[derive(Parser, Debug)]
pub struct PackageArgs {
Expand All @@ -12,7 +16,7 @@ pub struct PackageArgs {
required = true,
help = "<required> The Heat project ID."
)]
project: String,
project_path: String,
/// The Heat API key
#[clap(
short = 'k',
Expand All @@ -31,25 +35,48 @@ pub struct PackageArgs {
pub heat_endpoint: String,
}

fn create_heat_client(api_key: &str, url: &str, project: &str) -> HeatClient {
fn create_heat_client(api_key: &str, url: &str, project_path: &str) -> HeatClient {
let creds = HeatCredentials::new(api_key.to_owned());
let client_config = HeatClientConfig::builder(creds, project)
.with_endpoint(url)
.with_num_retries(10)
.build();
let client_config = HeatClientConfig::builder(
creds,
ProjectPath::try_from(project_path.to_string()).expect("Project path should be valid."),
)
.with_endpoint(url)
.with_num_retries(10)
.build();
HeatClient::create(client_config)
.expect("Should connect to the Heat server and create a client")
}

pub(crate) fn handle_command(args: PackageArgs, context: HeatCliContext) -> anyhow::Result<()> {
let heat_client = create_heat_client(&args.key, &args.heat_endpoint, &args.project);
let heat_client = create_heat_client(&args.key, &args.heat_endpoint, &args.project_path);

let crates = crate::util::cargo::package::package(
&context.get_artifacts_dir_path(),
context.package_name(),
)?;

heat_client.upload_new_project_version(context.package_name(), crates)?;
let flags = crate::registry::get_flags();

let mut registered_functions = Vec::<RegisteredHeatFunction>::new();
for flag in flags {
// function token stream to readable string
let itemfn = syn_serde::json::from_slice::<syn::ItemFn>(flag.token_stream).expect("Should be able to parse token stream.");
let syn_tree: syn::File = syn::parse2(itemfn.into_token_stream()).expect("Should be able to parse token stream.");
let code_str = prettyplease::unparse(&syn_tree);
registered_functions.push(RegisteredHeatFunction {
mod_path: flag.mod_path.to_string(),
fn_name: flag.fn_name.to_string(),
proc_type: flag.proc_type.to_string(),
code: code_str,
});
}

let heat_metadata = HeatCodeMetadata {
functions: registered_functions,
};

heat_client.upload_new_project_version(context.package_name(), heat_metadata, crates)?;

Ok(())
}
57 changes: 46 additions & 11 deletions crates/heat-sdk-cli/src/cli_commands/run/remote/training.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use clap::Parser;
use heat_sdk::client::{HeatClient, HeatClientConfig, HeatCredentials};
use heat_sdk::{
client::{HeatClient, HeatClientConfig, HeatCredentials},
schemas::{HeatCodeMetadata, ProjectPath, RegisteredHeatFunction},
};
use quote::ToTokens;

use crate::{context::HeatCliContext, generation::backend::BackendType};

Expand All @@ -24,7 +28,7 @@ pub struct RemoteTrainingRunArgs {
required = true,
help = "<required> The Heat project ID."
)]
project: String,
project_path: String,
/// The Heat API key
#[clap(
short = 'k',
Expand All @@ -41,14 +45,25 @@ pub struct RemoteTrainingRunArgs {
default_value = "http://127.0.0.1:9001"
)]
pub heat_endpoint: String,
/// The runner group name
#[clap(
short = 'r',
long = "runner",
help = "The runner group name.",
required = true,
)]
pub runner: String,
}

fn create_heat_client(api_key: &str, url: &str, project: &str) -> HeatClient {
fn create_heat_client(api_key: &str, url: &str, project_path: &str) -> HeatClient {
let creds = HeatCredentials::new(api_key.to_owned());
let client_config = HeatClientConfig::builder(creds, project)
.with_endpoint(url)
.with_num_retries(10)
.build();
let client_config = HeatClientConfig::builder(
creds,
ProjectPath::try_from(project_path.to_string()).expect("Project path should be valid."),
)
.with_endpoint(url)
.with_num_retries(10)
.build();
HeatClient::create(client_config)
.expect("Should connect to the Heat server and create a client")
}
Expand All @@ -57,18 +72,38 @@ pub(crate) fn handle_command(
args: RemoteTrainingRunArgs,
context: HeatCliContext,
) -> anyhow::Result<()> {
let heat_client = create_heat_client(&args.key, &args.heat_endpoint, &args.project);
let heat_client = create_heat_client(&args.key, &args.heat_endpoint, &args.project_path);

let crates = crate::util::cargo::package::package(
&context.get_artifacts_dir_path(),
context.package_name(),
)?;

let project_version = heat_client.upload_new_project_version(context.package_name(), crates)?;
let flags = crate::registry::get_flags();

let mut registered_functions = Vec::<RegisteredHeatFunction>::new();
for flag in flags {
// function token stream to readable string
let itemfn = syn_serde::json::from_slice::<syn::ItemFn>(flag.token_stream).expect("Should be able to parse token stream.");
let syn_tree: syn::File = syn::parse2(itemfn.into_token_stream()).expect("Should be able to parse token stream.");
let code_str = prettyplease::unparse(&syn_tree);
registered_functions.push(RegisteredHeatFunction {
mod_path: flag.mod_path.to_string(),
fn_name: flag.fn_name.to_string(),
proc_type: flag.proc_type.to_string(),
code: code_str,
});
}

let heat_metadata = HeatCodeMetadata {
functions: registered_functions,
};

let project_version = heat_client.upload_new_project_version(context.package_name(), heat_metadata, crates)?;

heat_client.start_remote_job(
args.runner,
project_version,
context.package_name().to_string(),
format!(
"run local training --functions {} --backends {} --configs {} --project {} --key {}",
args.functions.join(" "),
Expand All @@ -78,7 +113,7 @@ pub(crate) fn handle_command(
.collect::<Vec<_>>()
.join(" "),
args.configs.join(" "),
args.project,
args.project_path,
args.key
),
)?;
Expand Down
12 changes: 9 additions & 3 deletions crates/heat-sdk-cli/src/generation/crate_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,12 +250,12 @@ fn generate_training_function(
quote! {
let client = create_heat_client(&key, &heat_endpoint, &project);
let training_config_str = std::fs::read_to_string(&config_path).expect("Config should be read");
let training_config: serde_json::Value = serde_json::from_str(&training_config_str).expect("Config should be deserialized");

let mut train_cmd_context = TrainCommandContext::new(client, vec![device], training_config_str);

let conf_ser = train_cmd_context.config().as_bytes().to_vec();
train_cmd_context.client()
.start_experiment(&conf_ser)
.start_experiment(&training_config)
.expect("Experiment should be started");

pub fn trigger<B: Backend, T, M: Module<B>, E: Into<Box<dyn std::error::Error>>, H: TrainCommandHandler<B, T, M, E>>(handler: H, context: TrainCommandContext<B>) -> Result<M, Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -339,7 +339,7 @@ fn generate_main_rs(main_backend: &BackendType) -> String {

fn create_heat_client(api_key: &str, url: &str, project: &str) -> tracel::heat::client::HeatClient {
let creds = tracel::heat::client::HeatCredentials::new(api_key.to_owned());
let client_config = tracel::heat::client::HeatClientConfig::builder(creds, project)
let client_config = tracel::heat::client::HeatClientConfig::builder(creds, tracel::heat::schemas::ProjectPath::try_from(project.to_string()).expect("Project path should be valid."))
.with_endpoint(url)
.with_num_retries(10)
.build();
Expand Down Expand Up @@ -400,6 +400,12 @@ pub fn create_crate(
None,
vec!["cargo".to_string()],
));
generated_crate.add_dependency(Dependency::new(
"serde_json".to_string(),
"*".to_string(),
None,
vec![],
));
find_required_dependencies(vec!["tracel", "burn"])
.drain(..)
.for_each(|mut dep| {
Expand Down
2 changes: 2 additions & 0 deletions crates/heat-sdk-cli/src/logging.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![allow(dead_code)]

use colored::{Colorize, CustomColor};

pub const BURN_ORANGE: CustomColor = CustomColor {
Expand Down
2 changes: 2 additions & 0 deletions crates/heat-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,5 @@ tracing-core = { version = "0.1.32" }
tracing-subscriber = { version = "0.3.18" }
tungstenite = { version = "0.21.0" }
uuid = { workspace = true }
regex = { workspace = true }
once_cell = { workspace = true }
Loading

0 comments on commit 0cd24ef

Please sign in to comment.