-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #90 from taku-y/mlflow
Support Mlflow tracking
- Loading branch information
Showing
42 changed files
with
985 additions
and
327 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
[package] | ||
name = "border-candle-agent" | ||
version.workspace = true | ||
edition.workspace = true | ||
description.workspace = true | ||
repository.workspace = true | ||
keywords.workspace = true | ||
categories.workspace = true | ||
license.workspace = true | ||
readme = "README.md" | ||
|
||
[dependencies] | ||
border-core = { version = "0.0.6", path = "../border-core" } | ||
border-async-trainer = { version = "0.0.6", path = "../border-async-trainer", optional = true } | ||
serde = { workspace = true, features = ["derive"] } | ||
serde_yaml = { workspace = true } | ||
tensorboard-rs = { workspace = true } | ||
log = { workspace = true } | ||
thiserror = { workspace = true } | ||
anyhow = { workspace = true } | ||
chrono = { workspace = true } | ||
aquamarine = { workspace = true } | ||
candle-core = { workspace = true } | ||
fastrand = { workspace = true } | ||
segment-tree = { workspace = true } | ||
|
||
[dev-dependencies] | ||
tempdir = { workspace = true } | ||
|
||
# [package.metadata.docs.rs] | ||
# features = ["doc-only"] | ||
|
||
# [features] | ||
# doc-only = ["tch/doc-only"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
[package] | ||
name = "border-mlflow-tracking" | ||
version.workspace = true | ||
edition.workspace = true | ||
description.workspace = true | ||
repository.workspace = true | ||
keywords.workspace = true | ||
categories.workspace = true | ||
license.workspace = true | ||
readme = "README.md" | ||
|
||
[dependencies] | ||
border-core = { version = "0.0.6", path = "../border-core" } | ||
reqwest = { workspace = true } | ||
anyhow = { workspace = true } | ||
serde = { workspace = true, features = ["derive"] } | ||
log = { workspace = true } | ||
serde_json = { workspace = true } | ||
flatten-serde-json = "0.1.0" | ||
|
||
[dev-dependencies] | ||
env_logger = { workspace = true } | ||
|
||
[[example]] | ||
name = "tracking_basic" | ||
# test = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
Support [MLflow](https://mlflow.org) tracking to manage experiments. | ||
|
||
Before running the program using this crate, run a tracking server with the following command: | ||
|
||
```bash | ||
mlflow server --host 127.0.0.1 --port 8080 | ||
``` | ||
|
||
Then, training configurations and metrices can be logged to the tracking server. | ||
The following code is an example. Nested configuration parameters will be flattened, | ||
logged like `hyper_params.param1`, `hyper_params.param2`. | ||
|
||
```rust | ||
use anyhow::Result; | ||
use border_core::record::{Record, RecordValue, Recorder}; | ||
use border_mlflow_tracking::MlflowTrackingClient; | ||
use serde::Serialize; | ||
|
||
// Nested Configuration struct | ||
#[derive(Debug, Serialize)] | ||
struct Config { | ||
env_params: String, | ||
hyper_params: HyperParameters, | ||
} | ||
|
||
#[derive(Debug, Serialize)] | ||
struct HyperParameters { | ||
param1: i64, | ||
param2: Param2, | ||
param3: Param3, | ||
} | ||
|
||
#[derive(Debug, Serialize)] | ||
enum Param2 { | ||
Variant1, | ||
Variant2(f32), | ||
} | ||
|
||
#[derive(Debug, Serialize)] | ||
struct Param3 { | ||
dataset_name: String, | ||
} | ||
|
||
fn main() -> Result<()> { | ||
env_logger::init(); | ||
|
||
let config1 = Config { | ||
env_params: "env1".to_string(), | ||
hyper_params: HyperParameters { | ||
param1: 0, | ||
param2: Param2::Variant1, | ||
param3: Param3 { | ||
dataset_name: "a".to_string(), | ||
}, | ||
}, | ||
}; | ||
let config2 = Config { | ||
env_params: "env2".to_string(), | ||
hyper_params: HyperParameters { | ||
param1: 0, | ||
param2: Param2::Variant2(3.0), | ||
param3: Param3 { | ||
dataset_name: "a".to_string(), | ||
}, | ||
}, | ||
}; | ||
|
||
// Set experiment for runs | ||
let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; | ||
|
||
// Create recorders for logging | ||
let mut recorder_run1 = client.create_recorder("")?; | ||
let mut recorder_run2 = client.create_recorder("")?; | ||
recorder_run1.log_params(&config1)?; | ||
recorder_run2.log_params(&config2)?; | ||
|
||
// Logging while training | ||
for opt_steps in 0..100 { | ||
let opt_steps = opt_steps as f32; | ||
|
||
// Create a record | ||
let mut record = Record::empty(); | ||
record.insert("opt_steps", RecordValue::Scalar(opt_steps)); | ||
record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp())); | ||
|
||
// Log metrices in the record | ||
recorder_run1.write(record); | ||
} | ||
|
||
// Logging while training | ||
for opt_steps in 0..100 { | ||
let opt_steps = opt_steps as f32; | ||
|
||
// Create a record | ||
let mut record = Record::empty(); | ||
record.insert("opt_steps", RecordValue::Scalar(opt_steps)); | ||
record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp())); | ||
|
||
// Log metrices in the record | ||
recorder_run2.write(record); | ||
} | ||
|
||
Ok(()) | ||
} | ||
``` |
Oops, something went wrong.