diff --git a/CHANGELOG.md b/CHANGELOG.md index b5dcd0a8..8b601564 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,16 @@ # Changelog -## v0.0.6 (20??-??-??) +## v0.0.7 (20??-??-??) + +### Added + +Support MLflow tracking (`border-mlflow-tracking`) (https://github.com/taku-y/border/issues/2). + +### Changed + +* Take `self` in the signature of `push()` method of replay buffer (`border-core`) + +## v0.0.6 (2023-09-19) ### Added diff --git a/Cargo.toml b/Cargo.toml index 059177e3..06707687 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,8 +2,10 @@ members = [ "border-core", "border-tensorboard", + "border-mlflow-tracking", "border-py-gym-env", "border-tch-agent", + "border-candle-agent", "border-derive", "border-atari-env", "border-async-trainer", @@ -11,6 +13,16 @@ members = [ ] exclude = ["docker/"] +[workspace.package] +version = "0.0.6" +edition = "2018" +rust-version = "1.70" +description = "Reinforcement learning library" +repository = "https://github.com/taku-y/border" +keywords = ["rl"] +categories = ["science"] +license = "MIT OR Apache-2.0" + [workspace.dependencies] clap = "2.33.3" csv = "1.1.5" @@ -23,7 +35,8 @@ aquamarine = "0.1" log = "0.4" dirs = "3.0.2" thiserror = "1.0" -serde = "=1.0.126" +serde = "1.0.194" +serde_json = "^1.0.114" numpy = "0.14.1" env_logger = "0.8.2" tempdir = "0.3.7" @@ -34,3 +47,6 @@ ndarray = "0.15.1" chrono = "0.4" segment-tree = "2.0.0" image = "0.23.14" +candle-core = "0.2.2" +candle-nn = "0.2.2" +reqwest = { version = "0.11.26", features = ["json", "blocking"] } diff --git a/README.md b/README.md index e15b084b..a82b422e 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Border consists of the following crates: * [border-atari-env](https://crates.io/crates/border-atari-env) is a wrapper of [atari-env](https://crates.io/crates/atari-env), which is a part of [gym-rs](https://crates.io/crates/gym-rs). 
* [border-tch-agent](https://crates.io/crates/border-tch-agent) is a collection of RL agents based on [tch](https://crates.io/crates/tch), including Deep Q network (DQN), implicit quantile network (IQN), and soft actor critic (SAC). * [border-async-trainer](https://crates.io/crates/border-async-trainer) defines some traits and functions for asynchronous training of RL agents by multiple actors, which runs sampling processes in parallel. In each sampling process, an agent interacts with an environment to collect samples to be sent to a shared replay buffer. +* [border-mlflow-tracking](https://crates.io/crates/border-mlflow-tracking) supports MLflow tracking to log metrics during training via REST API. You can use a part of these crates for your purposes, though [border-core](https://crates.io/crates/border-core) is mandatory. [This crate](https://crates.io/crates/border) is just a collection of examples. See [Documentation](https://docs.rs/border) for more details. diff --git a/border-async-trainer/Cargo.toml b/border-async-trainer/Cargo.toml index 51d5a128..572817b7 100644 --- a/border-async-trainer/Cargo.toml +++ b/border-async-trainer/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border-async-trainer" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Atari environment based on gym-rs" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] anyhow = { workspace = true } diff --git a/border-atari-env/Cargo.toml b/border-atari-env/Cargo.toml index e7abc683..27848894 100644 --- a/border-atari-env/Cargo.toml +++ b/border-atari-env/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border-atari-env" 
-version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Atari environment based on gym-rs" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "GPL-2.0-or-later" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license = "GPL-2.0-or-later" readme = "README.md" -autoexamples = false [dependencies] anyhow = { workspace = true } diff --git a/border-atari-env/src/util/test.rs b/border-atari-env/src/util/test.rs index 3e481f27..6732792f 100644 --- a/border-atari-env/src/util/test.rs +++ b/border-atari-env/src/util/test.rs @@ -45,7 +45,7 @@ impl SubBatch for ObsBatch { } #[inline] - fn push(&mut self, i: usize, data: &Self) { + fn push(&mut self, i: usize, data: Self) { unsafe { let src: *const u8 = &data.buf[0]; let dst: *mut u8 = &mut self.buf[i * self.m]; @@ -100,7 +100,7 @@ impl SubBatch for ActBatch { } #[inline] - fn push(&mut self, i: usize, data: &Self) { + fn push(&mut self, i: usize, data: Self) { unsafe { let src: *const u8 = &data.buf[0]; let dst: *mut u8 = &mut self.buf[i * self.m]; diff --git a/border-candle-agent/Cargo.toml b/border-candle-agent/Cargo.toml new file mode 100644 index 00000000..32ed8d00 --- /dev/null +++ b/border-candle-agent/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "border-candle-agent" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +border-core = { version = "0.0.6", path = "../border-core" } +border-async-trainer = { version = "0.0.6", path = "../border-async-trainer", optional = true } +serde = { workspace = true, features = ["derive"] } +serde_yaml = { workspace = true } +tensorboard-rs = { workspace = true } +log = { 
workspace = true } +thiserror = { workspace = true } +anyhow = { workspace = true } +chrono = { workspace = true } +aquamarine = { workspace = true } +candle-core = { workspace = true } +fastrand = { workspace = true } +segment-tree = { workspace = true } + +[dev-dependencies] +tempdir = { workspace = true } + +# [package.metadata.docs.rs] +# features = ["doc-only"] + +# [features] +# doc-only = ["tch/doc-only"] diff --git a/border-candle-agent/src/lib.rs b/border-candle-agent/src/lib.rs new file mode 100644 index 00000000..e69de29b diff --git a/border-core/Cargo.toml b/border-core/Cargo.toml index f8028f9f..208e866f 100644 --- a/border-core/Cargo.toml +++ b/border-core/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border-core" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] serde = { workspace = true, features = ["derive"] } diff --git a/border-core/src/replay_buffer/base.rs b/border-core/src/replay_buffer/base.rs index 36c55c8b..cf360238 100644 --- a/border-core/src/replay_buffer/base.rs +++ b/border-core/src/replay_buffer/base.rs @@ -108,9 +108,9 @@ where fn push(&mut self, tr: Self::PushedItem) -> Result<()> { let len = tr.len(); // batch size let (obs, act, next_obs, reward, is_done, _, _) = tr.unpack(); - self.obs.push(self.i, &obs); - self.act.push(self.i, &act); - self.next_obs.push(self.i, &next_obs); + self.obs.push(self.i, obs); + self.act.push(self.i, act); + self.next_obs.push(self.i, next_obs); self.push_reward(self.i, &reward); self.push_is_done(self.i, &is_done); diff --git 
a/border-core/src/replay_buffer/subbatch.rs b/border-core/src/replay_buffer/subbatch.rs index 0fcfd321..3ce10308 100644 --- a/border-core/src/replay_buffer/subbatch.rs +++ b/border-core/src/replay_buffer/subbatch.rs @@ -6,7 +6,7 @@ pub trait SubBatch { fn new(capacity: usize) -> Self; /// Pushes the samples in `data`. - fn push(&mut self, i: usize, data: &Self); + fn push(&mut self, i: usize, data: Self); /// Takes samples in the batch. fn sample(&self, ixs: &Vec) -> Self; diff --git a/border-derive/Cargo.toml b/border-derive/Cargo.toml index 3c7ad2ef..6e1fafc7 100644 --- a/border-derive/Cargo.toml +++ b/border-derive/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border-derive" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Derive macros for observation and action in RL environments of border" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" -# readme = "README.md" -autoexamples = false +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" [lib] proc-macro = true diff --git a/border-derive/src/subbatch.rs b/border-derive/src/subbatch.rs index 89fed9b8..ab660a0a 100644 --- a/border-derive/src/subbatch.rs +++ b/border-derive/src/subbatch.rs @@ -32,8 +32,8 @@ fn tensor_sub_batch(ident: proc_macro2::Ident, field_type: syn::Type) -> proc_ma Self(TensorSubBatch::new(capacity)) } - fn push(&mut self, i: usize, data: &Self) { - self.0.push(i, &data.0) + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) } fn sample(&self, ixs: &Vec) -> Self { diff --git a/border-mlflow-tracking/Cargo.toml b/border-mlflow-tracking/Cargo.toml new file mode 100644 index 00000000..2fe1d2fb --- /dev/null +++ b/border-mlflow-tracking/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = 
"border-mlflow-tracking" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +border-core = { version = "0.0.6", path = "../border-core" } +reqwest = { workspace = true } +anyhow = { workspace = true } +serde = { workspace = true, features = ["derive"] } +log = { workspace = true } +serde_json = { workspace = true } +flatten-serde-json = "0.1.0" + +[dev-dependencies] +env_logger = { workspace = true } + +[[example]] +name = "tracking_basic" +# test = true diff --git a/border-mlflow-tracking/README.md b/border-mlflow-tracking/README.md new file mode 100644 index 00000000..7a78a9a3 --- /dev/null +++ b/border-mlflow-tracking/README.md @@ -0,0 +1,105 @@ +Support [MLflow](https://mlflow.org) tracking to manage experiments. + +Before running the program using this crate, run a tracking server with the following command: + +```bash +mlflow server --host 127.0.0.1 --port 8080 +``` + +Then, training configurations and metrices can be logged to the tracking server. +The following code is an example. Nested configuration parameters will be flattened, +logged like `hyper_params.param1`, `hyper_params.param2`. 
+ +```rust +use anyhow::Result; +use border_core::record::{Record, RecordValue, Recorder}; +use border_mlflow_tracking::MlflowTrackingClient; +use serde::Serialize; + +// Nested Configuration struct +#[derive(Debug, Serialize)] +struct Config { + env_params: String, + hyper_params: HyperParameters, +} + +#[derive(Debug, Serialize)] +struct HyperParameters { + param1: i64, + param2: Param2, + param3: Param3, +} + +#[derive(Debug, Serialize)] +enum Param2 { + Variant1, + Variant2(f32), +} + +#[derive(Debug, Serialize)] +struct Param3 { + dataset_name: String, +} + +fn main() -> Result<()> { + env_logger::init(); + + let config1 = Config { + env_params: "env1".to_string(), + hyper_params: HyperParameters { + param1: 0, + param2: Param2::Variant1, + param3: Param3 { + dataset_name: "a".to_string(), + }, + }, + }; + let config2 = Config { + env_params: "env2".to_string(), + hyper_params: HyperParameters { + param1: 0, + param2: Param2::Variant2(3.0), + param3: Param3 { + dataset_name: "a".to_string(), + }, + }, + }; + + // Set experiment for runs + let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; + + // Create recorders for logging + let mut recorder_run1 = client.create_recorder("")?; + let mut recorder_run2 = client.create_recorder("")?; + recorder_run1.log_params(&config1)?; + recorder_run2.log_params(&config2)?; + + // Logging while training + for opt_steps in 0..100 { + let opt_steps = opt_steps as f32; + + // Create a record + let mut record = Record::empty(); + record.insert("opt_steps", RecordValue::Scalar(opt_steps)); + record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp())); + + // Log metrices in the record + recorder_run1.write(record); + } + + // Logging while training + for opt_steps in 0..100 { + let opt_steps = opt_steps as f32; + + // Create a record + let mut record = Record::empty(); + record.insert("opt_steps", RecordValue::Scalar(opt_steps)); + record.insert("Loss", 
RecordValue::Scalar((-0.5f32 * opt_steps).exp())); + + // Log metrices in the record + recorder_run2.write(record); + } + + Ok(()) +} +``` diff --git a/border-mlflow-tracking/examples/tracking_basic.rs b/border-mlflow-tracking/examples/tracking_basic.rs new file mode 100644 index 00000000..98d22f9d --- /dev/null +++ b/border-mlflow-tracking/examples/tracking_basic.rs @@ -0,0 +1,91 @@ +use anyhow::Result; +use border_core::record::{Record, RecordValue, Recorder}; +use border_mlflow_tracking::MlflowTrackingClient; +use serde::Serialize; + +// Nested Configuration struct +#[derive(Debug, Serialize)] +struct Config { + env_params: String, + hyper_params: HyperParameters, +} + +#[derive(Debug, Serialize)] +struct HyperParameters { + param1: i64, + param2: Param2, + param3: Param3, +} + +#[derive(Debug, Serialize)] +enum Param2 { + Variant1, + Variant2(f32), +} + +#[derive(Debug, Serialize)] +struct Param3 { + dataset_name: String, +} + +fn main() -> Result<()> { + env_logger::init(); + + let config1 = Config { + env_params: "env1".to_string(), + hyper_params: HyperParameters { + param1: 0, + param2: Param2::Variant1, + param3: Param3 { + dataset_name: "a".to_string(), + }, + }, + }; + let config2 = Config { + env_params: "env2".to_string(), + hyper_params: HyperParameters { + param1: 0, + param2: Param2::Variant2(3.0), + param3: Param3 { + dataset_name: "a".to_string(), + }, + }, + }; + + // Set experiment for runs + let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; + + // Create recorders for logging + let mut recorder_run1 = client.create_recorder("")?; + let mut recorder_run2 = client.create_recorder("")?; + recorder_run1.log_params(&config1)?; + recorder_run2.log_params(&config2)?; + + // Logging while training + for opt_steps in 0..100 { + let opt_steps = opt_steps as f32; + + // Create a record + let mut record = Record::empty(); + record.insert("opt_steps", RecordValue::Scalar(opt_steps)); + record.insert("Loss", 
RecordValue::Scalar((-1f32 * opt_steps).exp())); + + // Log metrices in the record + recorder_run1.write(record); + } + + // Logging while training + for opt_steps in 0..100 { + let opt_steps = opt_steps as f32; + + // Create a record + let mut record = Record::empty(); + record.insert("opt_steps", RecordValue::Scalar(opt_steps)); + record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp())); + + // Log metrices in the record + recorder_run2.write(record); + } + + Ok(()) +} diff --git a/border-mlflow-tracking/src/client.rs b/border-mlflow-tracking/src/client.rs new file mode 100644 index 00000000..cc458b32 --- /dev/null +++ b/border-mlflow-tracking/src/client.rs @@ -0,0 +1,150 @@ +// use anyhow::Result; +use crate::{system_time_as_millis, Experiment, MlflowTrackingRecorder, Run}; +use anyhow::Result; +use log::info; +use reqwest::blocking::Client; +use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::fmt::Display; + +#[derive(Debug, Deserialize)] +/// Internally used. +struct Experiment_ { + pub(crate) experiment: Experiment, +} + +#[derive(Debug, Deserialize)] +/// Internally used. +struct Run_ { + run: Run, +} + +#[derive(Debug, Clone)] +pub struct GetExperimentIdError; + +impl Display for GetExperimentIdError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Failed to get experiment ID") + } +} + +impl Error for GetExperimentIdError {} + +#[derive(Debug, Serialize)] +/// Parameters adapted from https://mlflow.org/docs/latest/rest-api.html#id74 +/// +/// TODO: Support parameters in API, if required. +struct CreateRunParams { + experiment_id: String, + start_time: i64, + run_name: String, +} + +/// Provides access to a MLflow tracking server via REST API. +/// +/// Support Mlflow API version 2.0. +pub struct MlflowTrackingClient { + client: Client, + + /// Base URL. + base_url: String, + + /// Current experiment ID. 
+ experiment_id: Option, +} + +impl MlflowTrackingClient { + pub fn new(base_url: impl AsRef) -> Self { + Self { + client: Client::new(), + base_url: base_url.as_ref().to_string(), + experiment_id: None, + } + } + + /// Set the experiment ID to this struct. + /// + /// Get the ID from the tracking server by `name`. + pub fn set_experiment_id(self, name: impl AsRef) -> Result { + let experiment_id = { + self.get_experiment(name.as_ref()) + .expect(format!("Failed to get experiment: {:?}", name.as_ref()).as_str()) + .experiment_id + }; + + info!( + "For experiment '{}', id={} is set in MlflowTrackingClient", + name.as_ref(), + experiment_id + ); + + Ok(Self { + client: self.client, + base_url: self.base_url, + experiment_id: Some(experiment_id), + }) + + // let experiment_id = self.get_experiment_id(&name); + // match experiment_id { + // None => Err(GetExperimentIdError), + // Some(experiment_id) => Ok(Self { + // client: self.client, + // base_url: self.base_url, + // experiment_id: Some(experiment_id), + // }), + // } + } + + /// Get [`Experiment`] by name from the tracking server. + /// + /// TODO: Better error handling + pub fn get_experiment(&self, name: impl AsRef) -> Option { + let url = format!("{}/api/2.0/mlflow/experiments/get-by-name", self.base_url); + let resp = self + .client + .get(url) + .query(&[("experiment_name", name.as_ref())]) + .send() + .unwrap(); + let experiment: Experiment_ = serde_json::from_str(resp.text().unwrap().as_str()).unwrap(); + + Some(experiment.experiment) + } + + /// Create [`MlflowTrackingRecorder`] corresponding to a run. + /// + /// If `name` is empty (`""`), a run name is given by the tracking server. + /// + /// Need to call [`MlflowTrackingClient::set_experiment_id()`] before calling this method. 
+ pub fn create_recorder(&self, run_name: impl AsRef) -> Result { + let not_given_name = run_name.as_ref().len() == 0; + let experiment_id = self.experiment_id.as_ref().expect("Needs experiment_id"); + let url = format!("{}/api/2.0/mlflow/runs/create", self.base_url); + let params = CreateRunParams { + experiment_id: experiment_id.to_string(), + start_time: system_time_as_millis() as i64, + run_name: run_name.as_ref().to_string(), + }; + let resp = self + .client + .post(url) + .json(¶ms) // auto serialize + .send() + .unwrap(); + // println!("{:?}", resp); + // println!("{:?}", resp.text()); + // TODO: Check the response from the tracking server here + let run = { + let run: Run_ = + serde_json::from_str(&resp.text().unwrap()).expect("Failed to deserialize Run"); + run.run + }; + if not_given_name { + info!( + "Run name '{}' has been automatically generated", + run.info.run_name + ); + } + MlflowTrackingRecorder::new(&self.base_url, &experiment_id, &run) + } +} diff --git a/border-mlflow-tracking/src/experiment.rs b/border-mlflow-tracking/src/experiment.rs new file mode 100644 index 00000000..73851fe7 --- /dev/null +++ b/border-mlflow-tracking/src/experiment.rs @@ -0,0 +1,19 @@ +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +pub struct ExperimentTag { + pub key: String, + pub value: String, +} + +#[derive(Debug, Deserialize)] +/// all fields taken from https://mlflow.org/docs/latest/rest-api.html#mlflowexperiment +pub struct Experiment { + pub experiment_id: String, + pub name: String, + pub artifact_location: String, + pub lifecycle_stage: String, + pub last_update_time: i64, + pub creation_time: i64, + pub tags: Option>, +} diff --git a/border-mlflow-tracking/src/lib.rs b/border-mlflow-tracking/src/lib.rs new file mode 100644 index 00000000..30a36250 --- /dev/null +++ b/border-mlflow-tracking/src/lib.rs @@ -0,0 +1,122 @@ +//! Support [MLflow](https://mlflow.org) tracking to manage experiments. +//! +//! 
Before running the program using this crate, run a tracking server with the following command: +//! +//! ```bash +//! mlflow server --host 127.0.0.1 --port 8080 +//! ``` +//! +//! Then, training configurations and metrices can be logged to the tracking server. +//! The following code is an example. Nested configuration parameters will be flattened, +//! logged like `hyper_params.param1`, `hyper_params.param2`. +//! +//! ```no_run +//! use anyhow::Result; +//! use border_core::record::{Record, RecordValue, Recorder}; +//! use border_mlflow_tracking::MlflowTrackingClient; +//! use serde::Serialize; +//! +//! // Nested Configuration struct +//! #[derive(Debug, Serialize)] +//! struct Config { +//! env_params: String, +//! hyper_params: HyperParameters, +//! } +//! +//! #[derive(Debug, Serialize)] +//! struct HyperParameters { +//! param1: i64, +//! param2: Param2, +//! param3: Param3, +//! } +//! +//! #[derive(Debug, Serialize)] +//! enum Param2 { +//! Variant1, +//! Variant2(f32), +//! } +//! +//! #[derive(Debug, Serialize)] +//! struct Param3 { +//! dataset_name: String, +//! } +//! +//! fn main() -> Result<()> { +//! env_logger::init(); +//! +//! let config1 = Config { +//! env_params: "env1".to_string(), +//! hyper_params: HyperParameters { +//! param1: 0, +//! param2: Param2::Variant1, +//! param3: Param3 { +//! dataset_name: "a".to_string(), +//! }, +//! }, +//! }; +//! let config2 = Config { +//! env_params: "env2".to_string(), +//! hyper_params: HyperParameters { +//! param1: 0, +//! param2: Param2::Variant2(3.0), +//! param3: Param3 { +//! dataset_name: "a".to_string(), +//! }, +//! }, +//! }; +//! +//! // Set experiment for runs +//! let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; +//! +//! // Create recorders for logging +//! let mut recorder_run1 = client.create_recorder("")?; +//! let mut recorder_run2 = client.create_recorder("")?; +//! recorder_run1.log_params(&config1)?; +//! 
recorder_run2.log_params(&config2)?; +//! +//! // Logging while training +//! for opt_steps in 0..100 { +//! let opt_steps = opt_steps as f32; +//! +//! // Create a record +//! let mut record = Record::empty(); +//! record.insert("opt_steps", RecordValue::Scalar(opt_steps)); +//! record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp())); +//! +//! // Log metrices in the record +//! recorder_run1.write(record); +//! } +//! +//! // Logging while training +//! for opt_steps in 0..100 { +//! let opt_steps = opt_steps as f32; +//! +//! // Create a record +//! let mut record = Record::empty(); +//! record.insert("opt_steps", RecordValue::Scalar(opt_steps)); +//! record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp())); +//! +//! // Log metrices in the record +//! recorder_run2.write(record); +//! } +//! +//! Ok(()) +//! } +//! ``` +mod client; +mod experiment; +mod run; +mod writer; +pub use client::{GetExperimentIdError, MlflowTrackingClient}; +use experiment::Experiment; +pub use run::Run; +use std::time::{SystemTime, UNIX_EPOCH}; +pub use writer::MlflowTrackingRecorder; + +/// Code adapted from https://stackoverflow.com/questions/26593387 +fn system_time_as_millis() -> u128 { + let time = SystemTime::now(); + time.duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() +} diff --git a/border-mlflow-tracking/src/run.rs b/border-mlflow-tracking/src/run.rs new file mode 100644 index 00000000..4179d8ea --- /dev/null +++ b/border-mlflow-tracking/src/run.rs @@ -0,0 +1,58 @@ +use serde::Deserialize; + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +pub struct Run { + pub info: RunInfo, + data: Option, + inputs: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +pub struct RunInfo { + pub run_id: String, + pub run_name: String, + experiment_id: String, + status: Option, + start_time: i64, + end_time: Option, + artifact_uri: Option, + lifecycle_stage: Option +} + +#[allow(dead_code)] +#[derive(Debug, 
Deserialize)] +struct RunData { + metrics: Option>, + params: Option>, + tags: Option>, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +/// TODO: implement +struct RunInputs {} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct RunTag { + key: String, + value: String, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct Param { + key: String, + value: String, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct Metric { + key: String, + value: String, + timestamp: i64, + step: i64, +} diff --git a/border-mlflow-tracking/src/writer.rs b/border-mlflow-tracking/src/writer.rs new file mode 100644 index 00000000..419fe7ce --- /dev/null +++ b/border-mlflow-tracking/src/writer.rs @@ -0,0 +1,149 @@ +use crate::{system_time_as_millis, Run}; +use anyhow::Result; +use border_core::record::{RecordValue, Recorder}; +use reqwest::blocking::Client; +use serde::Serialize; +use serde_json::Value; + +#[derive(Debug, Serialize)] +struct LogParamParams<'a> { + run_id: &'a String, + key: &'a String, + value: String, +} + +#[derive(Debug, Serialize)] +struct LogMetricParams<'a> { + run_id: &'a String, + key: &'a String, + value: f64, + timestamp: i64, + step: i64, +} + +#[derive(Debug, Serialize)] +struct UpdateRunParams<'a> { + run_id: &'a String, + status: String, + end_time: i64, + run_name: &'a String, +} + +#[allow(dead_code)] +/// Record metrics to the MLflow tracking server during training. +/// +/// Before training, you can use [`MlflowTrackingRecorder::log_params()`] to log parameters +/// of the run like hyperparameters of the algorithm, the name of environment on which the +/// agent is trained, etc. +/// +/// [`MlflowTrackingRecorder::write()`] method logs [`RecordValue::Scalar`] values in the record +/// as metrics. As an exception, `opt_steps` is treated as the `step` field of Mlflow's metric data +/// (https://mlflow.org/docs/latest/rest-api.html#metric). 
+/// +/// Other types of values like [`RecordValue::Array1`] will be ignored. +/// +/// When dropped, this struct updates run's status to "FINISHED" +/// (https://mlflow.org/docs/latest/rest-api.html#mlflowrunstatus). +/// +/// [`RecordValue::Scalar`]: border_core::record::RecordValue::Scalar +/// [`RecordValue::Array1`]: border_core::record::RecordValue::Array1 +pub struct MlflowTrackingRecorder { + client: Client, + base_url: String, + experiment_id: String, + run_id: String, + run_name: String, +} + +impl MlflowTrackingRecorder { + pub fn new(base_url: &String, experiment_id: &String, run: &Run) -> Result { + let client = Client::new(); + Ok(Self { + client, + base_url: base_url.clone(), + experiment_id: experiment_id.to_string(), + run_id: run.info.run_id.clone(), + run_name: run.info.run_name.clone(), + }) + } + + pub fn log_params(&self, params: impl Serialize) -> Result<()> { + let url = format!("{}/api/2.0/mlflow/runs/log-parameter", self.base_url); + let flatten_map = { + let map = match serde_json::to_value(params).unwrap() { + Value::Object(map) => map, + _ => panic!("Failed to parse object"), + }; + flatten_serde_json::flatten(&map) + }; + for (key, value) in flatten_map.iter() { + let params = LogParamParams { + run_id: &self.run_id, + key, + value: value.to_string(), + }; + let _resp = self + .client + .post(&url) + .json(¶ms) // auto serialize + .send() + .unwrap(); + // TODO: error handling caused by API call + } + + Ok(()) + } +} + +impl Recorder for MlflowTrackingRecorder { + fn write(&mut self, record: border_core::record::Record) { + let url = format!("{}/api/2.0/mlflow/runs/log-metric", self.base_url); + let timestamp = system_time_as_millis() as i64; + let step = record.get_scalar("opt_steps").unwrap() as i64; + + for (key, value) in record.iter() { + if *key != "opt_steps" { + match value { + RecordValue::Scalar(v) => { + let value = *v as f64; + let params = LogMetricParams { + run_id: &self.run_id, + key, + value, + timestamp, + step, + }; 
+ let _resp = self + .client + .post(&url) + .json(¶ms) // auto serialize + .send() + .unwrap(); + // TODO: error handling caused by API call + } + _ => {} // ignore record value + } + } + } + } +} + +impl Drop for MlflowTrackingRecorder { + fn drop(&mut self) { + let end_time = system_time_as_millis() as i64; + let url = format!("{}/api/2.0/mlflow/runs/update", self.base_url); + let params = UpdateRunParams { + run_id: &self.run_id, + status: "FINISHED".to_string(), + end_time, + run_name: &self.run_name, + }; + let _resp = self + .client + .post(&url) + .json(¶ms) // auto serialize + .send() + .unwrap(); + // TODO: error handling caused by API call + } +} diff --git a/border-py-gym-env/Cargo.toml b/border-py-gym-env/Cargo.toml index b1592ee2..f9184eee 100644 --- a/border-py-gym-env/Cargo.toml +++ b/border-py-gym-env/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border-py-gym-env" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] border-core = { version = "0.0.6", path = "../border-core" } diff --git a/border-py-gym-env/examples/random_ant_pybullet.rs b/border-py-gym-env/examples/backup/random_ant_pybullet.rs similarity index 100% rename from border-py-gym-env/examples/random_ant_pybullet.rs rename to border-py-gym-env/examples/backup/random_ant_pybullet.rs diff --git a/border-tch-agent/Cargo.toml b/border-tch-agent/Cargo.toml index 230775ee..46a9e276 100644 --- a/border-tch-agent/Cargo.toml +++ b/border-tch-agent/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border-tch-agent" -version = "0.0.6" 
-authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] border-core = { version = "0.0.6", path = "../border-core" } diff --git a/border-tch-agent/src/tensor_batch.rs b/border-tch-agent/src/tensor_batch.rs index ea7805f5..7b47edef 100644 --- a/border-tch-agent/src/tensor_batch.rs +++ b/border-tch-agent/src/tensor_batch.rs @@ -83,7 +83,7 @@ impl SubBatch for TensorSubBatch { /// /// If the internal buffer is empty, it will be initialized with the shape /// `[capacity, data.buf.size()[1..]]`. - fn push(&mut self, index: usize, data: &Self) { + fn push(&mut self, index: usize, data: Self) { if data.buf.is_none() { return; } diff --git a/border-tensorboard/Cargo.toml b/border-tensorboard/Cargo.toml index 2bb44c7f..8c4b2854 100644 --- a/border-tensorboard/Cargo.toml +++ b/border-tensorboard/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border-tensorboard" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] border-core = { version = "0.0.6", path = "../border-core" } diff --git a/border/Cargo.toml b/border/Cargo.toml index bb57f4e8..4f4ee177 100644 --- a/border/Cargo.toml +++ 
b/border/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "GPL-2.0-or-later" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license = "GPL-2.0-or-later" readme = "README.md" -autoexamples = false [dependencies] aquamarine = { workspace = true } @@ -21,7 +17,7 @@ anyhow = { workspace = true } log = { workspace = true } dirs = { workspace = true } zip = "0.5.12" -reqwest = { version = "0.11.3", features = ["blocking"] } +reqwest = { workspace = true } border-core = { version = "0.0.6", path = "../border-core" } [[example]] @@ -67,10 +63,10 @@ path = "examples/gym-robotics/sac_fetch_reach.rs" required-features = ["tch"] test = false -[[example]] -name = "iqn_atari_rs" -required-features = ["tch"] -test = false +# [[example]] +# name = "iqn_atari_rs" +# required-features = ["tch"] +# test = false [[example]] name = "sac_ant" @@ -82,15 +78,15 @@ name = "sac_ant_async" required-features = ["tch", "border-async-trainer"] test = false -[[example]] -name = "make_cfg_dqn_atari" -required-features = ["border-async-trainer"] -test = false +# [[example]] +# name = "make_cfg_dqn_atari" +# required-features = ["border-async-trainer"] +# test = false -[[example]] -name = "make_cfg_iqn_atari" -required-features = ["border-async-trainer"] -test = false +# [[example]] +# name = "make_cfg_iqn_atari" +# required-features = ["border-async-trainer"] +# test = false [dev-dependencies] clap = { workspace = true } diff --git a/border/examples/iqn_atari_rs.rs b/border/examples/atari/iqn_atari.rs similarity index 100% rename from border/examples/iqn_atari_rs.rs rename to border/examples/atari/iqn_atari.rs diff --git 
a/border/examples/make_cfg_dqn_atari.rs b/border/examples/atari/make_cfg_dqn_atari.rs similarity index 98% rename from border/examples/make_cfg_dqn_atari.rs rename to border/examples/atari/make_cfg_dqn_atari.rs index e8a992b4..3467bb64 100644 --- a/border/examples/make_cfg_dqn_atari.rs +++ b/border/examples/atari/make_cfg_dqn_atari.rs @@ -65,7 +65,6 @@ fn make_trainer_config(env_name: String, params: &Params) -> Result Result IqnConfig { let hidden_dim = params.hidden_dim; let f_config = CnnConfig::new(n_stack, feature_dim) .skip_linear(true); - let m_config = MlpConfig::new(feature_dim, vec![hidden_dim], out_dim); + let m_config = MlpConfig::new(feature_dim, vec![hidden_dim], out_dim, false); let model_config = IqnModelConfig::default() .feature_dim(feature_dim) .embed_dim(params.embed_dim) @@ -69,7 +69,6 @@ fn make_trainer_config(env_name: String, params: &Params) -> Result) -> Self { @@ -130,8 +130,8 @@ mod act { Self(TensorSubBatch::new(capacity)) } - fn push(&mut self, i: usize, data: &Self) { - self.0.push(i, &data.0) + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) } fn sample(&self, ixs: &Vec) -> Self { diff --git a/border/examples/iqn_cartpole.rs b/border/examples/iqn_cartpole.rs index 1f20fda0..92116d01 100644 --- a/border/examples/iqn_cartpole.rs +++ b/border/examples/iqn_cartpole.rs @@ -82,8 +82,8 @@ mod obs { Self(TensorSubBatch::new(capacity)) } - fn push(&mut self, i: usize, data: &Self) { - self.0.push(i, &data.0) + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) } fn sample(&self, ixs: &Vec) -> Self { @@ -136,8 +136,8 @@ mod act { Self(TensorSubBatch::new(capacity)) } - fn push(&mut self, i: usize, data: &Self) { - self.0.push(i, &data.0) + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) } fn sample(&self, ixs: &Vec) -> Self { diff --git a/border/examples/sac_ant_async.rs b/border/examples/sac_ant_async.rs index 60ce75b8..289208e0 100644 --- a/border/examples/sac_ant_async.rs +++ 
b/border/examples/sac_ant_async.rs @@ -3,22 +3,18 @@ use border_async_trainer::{ actor_stats_fmt, ActorManager as ActorManager_, ActorManagerConfig, AsyncTrainer as AsyncTrainer_, AsyncTrainerConfig, }; -use border_atari_env::{ - BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs, - BorderAtariObsRawFilter, -}; use border_core::{ replay_buffer::{ SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }, - DefaultEvaluator, Env as _, + DefaultEvaluator, }; -use border_derive::{Act, Obs, SubBatch}; +use border_derive::SubBatch; use border_py_gym_env::{ util::{arrayd_to_tensor, tensor_to_arrayd}, - ArrayObsFilter, ContinuousActFilter, GymActFilter, GymContinuousAct, GymEnv, GymEnvConfig, - GymObs, GymObsFilter, + ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, + GymObsFilter, }; use border_tch_agent::{ mlp::{Mlp, Mlp2, MlpConfig}, diff --git a/border/examples/test_async_trainer.rs b/border/examples/test_async_trainer.rs deleted file mode 100644 index 5b7ce326..00000000 --- a/border/examples/test_async_trainer.rs +++ /dev/null @@ -1,104 +0,0 @@ -use super::{ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel}; -use border_atari_env::util::test::*; -use border_core::{ - record::BufferedRecorder, - replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - Env as _, -}; -use crossbeam_channel::unbounded; -use log::info; -use std::sync::{Arc, Mutex}; -use test_log::test; - -fn replay_buffer_config() -> SimpleReplayBufferConfig { - SimpleReplayBufferConfig::default() -} - -fn actor_man_config() -> ActorManagerConfig { - ActorManagerConfig::default() -} - -fn async_trainer_config() -> AsyncTrainerConfig { - AsyncTrainerConfig { - model_dir: Some("".to_string()), - record_interval: 5, - eval_interval: 105, - max_train_steps: 15, - save_interval: 5, - sync_interval: 5, - 
eval_episodes: 1, - } -} - -impl SyncModel for RandomAgent { - type ModelInfo = (); - - fn model_info(&self) -> (usize, Self::ModelInfo) { - info!("Returns the current model info"); - (self.n_opts_steps(), ()) - } - - fn sync_model(&mut self, _model_info: &Self::ModelInfo) { - info!("Sync model"); - } -} - -#[test] -fn test_async_trainer() { - type Agent = RandomAgent; - type StepProc = SimpleStepProcessor; - type ReplayBuffer = SimpleReplayBuffer; - type ActorManager_ = ActorManager; - type AsyncTrainer_ = AsyncTrainer; - - let env_config = env_config("pong".to_string()); - let env = Env::build(&env_config, 0).unwrap(); - let n_acts = env.get_num_actions_atari() as _; - let agent_config = RandomAgentConfig { n_acts }; - let step_proc_config = SimpleStepProcessorConfig::default(); - let replay_buffer_config = replay_buffer_config(); - let actor_man_config = actor_man_config(); - let async_trainer_config = async_trainer_config(); - let agent_configs = vec![agent_config.clone(); 2]; - - let mut recorder = BufferedRecorder::new(); - - // Shared flag to stop actor threads - let stop = Arc::new(Mutex::new(false)); - - // Pushed items into replay buffer - let (item_s, item_r) = unbounded(); - - // Synchronizing model - let (model_s, model_r) = unbounded(); - - // Prevents simlutaneous initialization of env - let guard_init_env = Arc::new(Mutex::new(true)); - - let mut actors = ActorManager_::build( - &actor_man_config, - &agent_configs, - &env_config, - &step_proc_config, - item_s, - model_r, - stop.clone(), - ); - let mut trainer = AsyncTrainer_::build( - &async_trainer_config, - &agent_config, - &env_config, - &replay_buffer_config, - item_r, - model_s, - stop, - ); - - actors.run(guard_init_env.clone()); - trainer.train(&mut recorder, guard_init_env); - - actors.stop_and_join(); -} diff --git a/border/tests/test_async_trainer.rs b/border/tests/test_async_trainer.rs new file mode 100644 index 00000000..52ad7d12 --- /dev/null +++ b/border/tests/test_async_trainer.rs @@ 
-0,0 +1,104 @@ +// use border_async_trainer::{ActorManagerConfig, AsyncTrainerConfig, SyncModel}; +// use border_atari_env::util::test::*; +// use border_core::{ +// replay_buffer::SimpleReplayBufferConfig, +// // record::BufferedRecorder, +// // replay_buffer::{ +// // SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, +// // SimpleStepProcessorConfig, +// // }, +// // Env as _, +// }; +// use log::info; +// // use std::sync::{Arc, Mutex}; +// // use test_log::test; + +// fn replay_buffer_config() -> SimpleReplayBufferConfig { +// SimpleReplayBufferConfig::default() +// } + +// fn actor_man_config() -> ActorManagerConfig { +// ActorManagerConfig::default() +// } + +// fn async_trainer_config() -> AsyncTrainerConfig { +// AsyncTrainerConfig { +// model_dir: Some("".to_string()), +// record_interval: 5, +// eval_interval: 105, +// max_train_steps: 15, +// save_interval: 5, +// sync_interval: 5, +// eval_episodes: 1, +// } +// } + +// impl SyncModel for RandomAgent { +// type ModelInfo = (); + +// fn model_info(&self) -> (usize, Self::ModelInfo) { +// info!("Returns the current model info"); +// (self.n_opts_steps(), ()) +// } + +// fn sync_model(&mut self, _model_info: &Self::ModelInfo) { +// info!("Sync model"); +// } +// } + +// #[test] +// fn test_async_trainer() { +// type Agent = RandomAgent; +// type StepProc = SimpleStepProcessor; +// type ReplayBuffer = SimpleReplayBuffer; +// type ActorManager_ = ActorManager; +// type AsyncTrainer_ = AsyncTrainer; + +// let env_config = env_config("pong".to_string()); +// let env = Env::build(&env_config, 0).unwrap(); +// let n_acts = env.get_num_actions_atari() as _; +// let agent_config = RandomAgentConfig { n_acts }; +// let step_proc_config = SimpleStepProcessorConfig::default(); +// let replay_buffer_config = replay_buffer_config(); +// let actor_man_config = actor_man_config(); +// let async_trainer_config = async_trainer_config(); +// let agent_configs = vec![agent_config.clone(); 2]; + +// let mut 
recorder = BufferedRecorder::new(); + +// // Shared flag to stop actor threads +// let stop = Arc::new(Mutex::new(false)); + +// // Pushed items into replay buffer +// let (item_s, item_r) = unbounded(); + +// // Synchronizing model +// let (model_s, model_r) = unbounded(); + +// // Prevents simlutaneous initialization of env +// let guard_init_env = Arc::new(Mutex::new(true)); + +// let mut actors = ActorManager_::build( +// &actor_man_config, +// &agent_configs, +// &env_config, +// &step_proc_config, +// item_s, +// model_r, +// stop.clone(), +// ); +// let mut trainer = AsyncTrainer_::build( +// &async_trainer_config, +// &agent_config, +// &env_config, +// &replay_buffer_config, +// item_r, +// model_s, +// stop, +// ); + +// actors.run(guard_init_env.clone()); +// trainer.train(&mut recorder, guard_init_env); + +// actors.stop_and_join(); +// } diff --git a/docker/aarch64/Dockerfile b/docker/aarch64/Dockerfile index 2a6ba1a8..819435f7 100644 --- a/docker/aarch64/Dockerfile +++ b/docker/aarch64/Dockerfile @@ -78,6 +78,7 @@ RUN source /root/venv/bin/activate && pip3 install numpy==1.21.3 RUN source /root/venv/bin/activate && pip3 install mujoco==2.3.7 RUN source /root/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 RUN source /root/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN source /root/venv/bin/activate && pip3 install mlflow==2.11.1 # RUN source /home/ubuntu/venv/bin/activate && pip3 install robosuite==1.3.2 # RUN source /home/ubuntu/venv/bin/activate && pip3 install -U 'mujoco-py<2.2,>=2.1' # RUN source /home/ubuntu/venv/bin/activate && pip3 install dm-control==1.0.8 diff --git a/docker/aarch64_doc/Dockerfile b/docker/aarch64_doc/Dockerfile deleted file mode 100644 index 42455746..00000000 --- a/docker/aarch64_doc/Dockerfile +++ /dev/null @@ -1,51 +0,0 @@ -FROM ubuntu:focal-20221130 - -ENV DEBIAN_FRONTEND noninteractive -RUN echo "Set disable_coredump false" >> /etc/sudo.conf -RUN apt-get update -q && \ - apt-get upgrade -yq 
&& \ - apt-get install -yq wget curl git build-essential vim sudo libssl-dev -RUN apt install -y -q libclang-dev zip cmake llvm pkg-config libx11-dev libxkbcommon-dev - -# # swig -# RUN apt install -y swig - -# # python -# RUN apt install -y python3.8 python3.8-dev python3.8-distutils python3.8-venv python3-pip - -# # headers required for building libtorch -# RUN apt install -y libgoogle-glog-dev libgflags-dev - -# # llvm, mesa for robosuite -# RUN apt install -y llvm libosmesa6-dev - -# # Used for Mujoco -# RUN apt install -y patchelf libglfw3 libglfw3-dev - -# Cleanup -RUN rm -rf /var/lib/apt/lists/* - -# COPY test_mujoco_py.py /test_mujoco_py.py -# RUN chmod 777 /test_mujoco_py.py - -# Add user -RUN useradd --create-home --home-dir /home/ubuntu --shell /bin/bash --user-group --groups adm,sudo ubuntu && \ - echo ubuntu:ubuntu | chpasswd && \ - echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers - -# Use bash -RUN mv /bin/sh /bin/sh_tmp && ln -s /bin/bash /bin/sh - -# User settings -USER ubuntu - -# rustup -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - -RUN echo 'export PATH=$HOME/.local/bin:$PATH' >> ~/.bashrc - -USER root -RUN rm /bin/sh && mv /bin/sh_tmp /bin/sh - -USER ubuntu -WORKDIR /home/ubuntu/border diff --git a/docker/aarch64_doc/README.md b/docker/aarch64_doc/README.md index 2e3eeb1e..6a9ef17c 100644 --- a/docker/aarch64_doc/README.md +++ b/docker/aarch64_doc/README.md @@ -1,39 +1,7 @@ -# Docker container for training - -This directory contains scripts to build and run a docker container for training. - -## Build - -The following command creates a container image locally, named `border_headless`. +Run the following command to create docs in border/doc. ```bash -cd $REPO/docker/aarch64_headless -sh build.sh +# in border/aarch64_doc +sh doc.sh ``` -# Build document - -The following commands builds the document and places it as `$REPO/doc`. - -## Run - -The following commands runs a program for training an agent. 
-The trained model will be saved in `$REPO/border/examples/model` directory, -which is mounted in the container. - -### DQN - -* Cartpole - - ```bash - cd $REPO/docker/aarch64_headless - sh run.sh "source /home/ubuntu/venv/bin/activate && cargo run --example dqn_cartpole --features='tch' -- --train" - ``` - - * Use a directory, not mounted on the host, as a cargo target directory, - making compile faster on Mac, where access to mounted directories is slow. - - ```bash - cd $REPO/docker/aarch64_headless - sh run.sh "source /home/ubuntu/venv/bin/activate && CARGO_TARGET_DIR=/home/ubuntu/target cargo run --example dqn_cartpole --features='tch' -- --train" - ``` diff --git a/docker/aarch64_doc/build.sh b/docker/aarch64_doc/build.sh deleted file mode 100644 index 12e19a79..00000000 --- a/docker/aarch64_doc/build.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -docker build -t border_doc . diff --git a/docker/aarch64_doc/doc.sh b/docker/aarch64_doc/doc.sh index a5609c6b..4cf37c27 100644 --- a/docker/aarch64_doc/doc.sh +++ b/docker/aarch64_doc/doc.sh @@ -2,5 +2,5 @@ docker run -it --rm \ --name border_headless \ --shm-size=512m \ --volume="$(pwd)/../..:/home/ubuntu/border" \ - border_doc bash -l -c \ - "CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." + border_headless bash -l -c \ + "cd /home/ubuntu/border; CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ."