From 1597b8689ce76edf5d23146a2916ee1f6c8a65a4 Mon Sep 17 00:00:00 2001
From: taku-y
Date: Sat, 6 Jul 2024 01:14:30 +0900
Subject: [PATCH 01/21] WIP: Add edge policy (#19)

---
 Cargo.toml                       |  1 +
 border-edge-policy/Cargo.toml    | 24 +++++++++
 border-edge-policy/src/lib.rs    |  3 ++
 border-edge-policy/src/mat.rs    | 83 ++++++++++++++++++++++++++++++++
 border-edge-policy/tests/test.rs | 24 +++++++++
 5 files changed, 135 insertions(+)
 create mode 100644 border-edge-policy/Cargo.toml
 create mode 100644 border-edge-policy/src/lib.rs
 create mode 100644 border-edge-policy/src/mat.rs
 create mode 100644 border-edge-policy/tests/test.rs

diff --git a/Cargo.toml b/Cargo.toml
index 1eff1f68..a7263122 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,6 +9,7 @@ members = [
     "border-derive",
     "border-atari-env",
     "border-async-trainer",
+    "border-edge-policy",
     "border",
 ]
 exclude = ["docker/"]
diff --git a/border-edge-policy/Cargo.toml b/border-edge-policy/Cargo.toml
new file mode 100644
index 00000000..4a96a84e
--- /dev/null
+++ b/border-edge-policy/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "border-edge-policy"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true
+readme = "README.md"
+
+[dependencies]
+border-core = { version = "0.0.6", path = "../border-core" }
+serde = { workspace = true, features = ["derive"] }
+log = { workspace = true }
+anyhow = { workspace = true }
+tch = { workspace = true, optional = true }
+
+[dev-dependencies]
+tempdir = { workspace = true }
+tch = { workspace = true }
+
+[features]
+tch = ["dep:tch"]
diff --git a/border-edge-policy/src/lib.rs b/border-edge-policy/src/lib.rs
new file mode 100644
index 00000000..f99690cc
--- /dev/null
+++ b/border-edge-policy/src/lib.rs
@@ -0,0 +1,3 @@
+mod mat;
+
+pub use mat::Mat;
diff --git a/border-edge-policy/src/mat.rs b/border-edge-policy/src/mat.rs
new file mode 100644
index 00000000..f3b019a6
--- /dev/null
+++ b/border-edge-policy/src/mat.rs
@@ -0,0 +1,83 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
+pub struct Mat {
+    pub data: Vec<f32>,
+    pub shape: Vec<i32>,
+}
+
+#[cfg(feature = "tch")]
+impl From<tch::Tensor> for Mat {
+    fn from(x: tch::Tensor) -> Self {
+        let shape: Vec<i32> = x.size().iter().map(|e| *e as i32).collect();
+        let (n, shape) = match shape.len() {
+            1 => (shape[0] as usize, vec![shape[0], 1]),
+            2 => ((shape[0] * shape[1]) as usize, shape),
+            _ => panic!("Invalid matrix size: {:?}", shape),
+        };
+        let mut data: Vec<f32> = vec![0f32; n];
+        x.f_copy_data(&mut data, n).unwrap();
+        Self { data, shape }
+    }
+}
+
+impl Mat {
+    pub fn matmul(&self, x: &Mat) -> Self {
+        let (m, l, n) = (
+            self.shape[0] as usize,
+            self.shape[1] as usize,
+            x.shape[1] as usize,
+        );
+        let mut data = vec![0.0f32; (m * n) as usize];
+        for i in 0..m as usize {
+            for j in 0..n as usize {
+                let kk = i * n as usize + j;
+                for k in 0..l as usize {
+                    data[kk] += self.data[i * l + k] * x.data[k * n + j];
+                }
+            }
+        }
+
+        Self {
+            shape: vec![m as _, n as _],
+            data,
+        }
+    }
+
+    pub fn add(&self, x: &Mat) -> Self {
+        if self.shape[0] != x.shape[0] || self.shape[1] != x.shape[1] {
+            panic!(
+                "Trying to add matrices of different sizes: {:?}",
+                (&self.shape, &x.shape)
+            );
+        }
+
+        let data = self
+            .data
+            .iter()
+            .zip(x.data.iter())
+            .map(|(a, b)| *a + *b)
+            .collect();
+
+        Mat {
+            data,
+            shape: self.shape.clone(),
+        }
+    }
+
+    pub fn relu(&self) -> Self {
+        let data = self
+            .data
+            .iter()
.map(|a| match *a < 0. { + true => 0., + false => *a, + }) + .collect(); + + Self { + data, + shape: self.shape.clone(), + } + } +} diff --git a/border-edge-policy/tests/test.rs b/border-edge-policy/tests/test.rs new file mode 100644 index 00000000..f67b4e1d --- /dev/null +++ b/border-edge-policy/tests/test.rs @@ -0,0 +1,24 @@ +use tch::Tensor; +use border_edge_policy::Mat; + +#[test] +fn test_matmul() { + let x1 = Tensor::from_slice2(&[&[1.0f32, 2., 3.], &[4., 5., 6.]]); + let y1 = Tensor::from_slice(&[7.0f32, 8., 9.]); + let z1 = x1.matmul(&y1); + + let x2: Mat = x1.into(); + let y2: Mat = y1.into(); + let z2 = x2.matmul(&y2); + + let z3 = { + let mut data = vec![0.0f32; 2]; + z1.f_copy_data(&mut data, 2).unwrap(); + Mat { + shape: vec![2 as _, 1 as _], + data, + } + }; + + assert_eq!(z2, z3) +} From 258d4598aa5f572ec01e0e8d6bd25757890ebc4c Mon Sep 17 00:00:00 2001 From: taku-y Date: Sat, 6 Jul 2024 11:28:20 +0900 Subject: [PATCH 02/21] Remove old dirs (#20) --- border/examples/model/.gitignore | 10 ------ .../tch/dqn_HeroNoFrameskip-v4/agent.yaml | 26 -------------- .../tch/dqn_HeroNoFrameskip-v4/model.yaml | 8 ----- .../dqn_HeroNoFrameskip-v4/replay_buffer.yaml | 4 --- .../tch/dqn_HeroNoFrameskip-v4/trainer.yaml | 9 ----- .../tch/dqn_PongNoFrameskip-v4/agent.yaml | 25 -------------- .../tch/dqn_PongNoFrameskip-v4/model.yaml | 8 ----- .../dqn_PongNoFrameskip-v4/replay_buffer.yaml | 4 --- .../tch/dqn_PongNoFrameskip-v4/trainer.yaml | 9 ----- .../dqn_PongNoFrameskip-v4_debug/agent.yaml | 26 -------------- .../replay_buffer.yaml | 4 --- .../dqn_PongNoFrameskip-v4_debug/trainer.yaml | 9 ----- .../tch/dqn_PongNoFrameskip-v4_per/agent.yaml | 24 ------------- .../replay_buffer.yaml | 8 ----- .../dqn_PongNoFrameskip-v4_per/trainer.yaml | 9 ----- .../tch/dqn_PongNoFrameskip-v4_vec/agent.yaml | 19 ----------- .../tch/dqn_PongNoFrameskip-v4_vec/model.yaml | 8 ----- .../dqn_PongNoFrameskip-v4_vec/trainer.yaml | 6 ---- .../tch/dqn_SeaquestNoFrameskip-v4/agent.yaml | 19 ----------- .../dqn_SeaquestNoFrameskip-v4/trainer.yaml | 6 ---- .../tch/iqn_PongNoFrameskip-v4/agent.yaml | 22 ------------ .../tch/iqn_PongNoFrameskip-v4/model.yaml | 15 -------- .../tch/iqn_PongNoFrameskip-v4/trainer.yaml | 6 ---- .../tch/iqn_SeaquestNoFrameskip-v4/agent.yaml | 34 ------------------- .../replay_buffer.yaml | 4 --- .../iqn_SeaquestNoFrameskip-v4/trainer.yaml | 9 ----- 26 files changed, 331 deletions(-) delete mode 100644 border/examples/model/.gitignore delete mode 100644 border/examples/model/tch/dqn_HeroNoFrameskip-v4/agent.yaml delete mode 100644 border/examples/model/tch/dqn_HeroNoFrameskip-v4/model.yaml delete mode 100644 border/examples/model/tch/dqn_HeroNoFrameskip-v4/replay_buffer.yaml delete mode 100644 border/examples/model/tch/dqn_HeroNoFrameskip-v4/trainer.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4/agent.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4/model.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4/replay_buffer.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4/trainer.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/agent.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/replay_buffer.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/trainer.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4_per/agent.yaml delete mode 100644 
border/examples/model/tch/dqn_PongNoFrameskip-v4_per/replay_buffer.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4_per/trainer.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/agent.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/model.yaml delete mode 100644 border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/trainer.yaml delete mode 100644 border/examples/model/tch/dqn_SeaquestNoFrameskip-v4/agent.yaml delete mode 100644 border/examples/model/tch/dqn_SeaquestNoFrameskip-v4/trainer.yaml delete mode 100644 border/examples/model/tch/iqn_PongNoFrameskip-v4/agent.yaml delete mode 100644 border/examples/model/tch/iqn_PongNoFrameskip-v4/model.yaml delete mode 100644 border/examples/model/tch/iqn_PongNoFrameskip-v4/trainer.yaml delete mode 100644 border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/agent.yaml delete mode 100644 border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/replay_buffer.yaml delete mode 100644 border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/trainer.yaml diff --git a/border/examples/model/.gitignore b/border/examples/model/.gitignore deleted file mode 100644 index f59f4eff..00000000 --- a/border/examples/model/.gitignore +++ /dev/null @@ -1,10 +0,0 @@ -# !.gitignore -# */* -# * -*.pt -events* -*.csv -*.zip -*.gz -backup -dqn_cartpole_tch/** diff --git a/border/examples/model/tch/dqn_HeroNoFrameskip-v4/agent.yaml b/border/examples/model/tch/dqn_HeroNoFrameskip-v4/agent.yaml deleted file mode 100644 index 9e9cf22d..00000000 --- a/border/examples/model/tch/dqn_HeroNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,26 +0,0 @@ ---- -model_config: - q_config: - n_stack: 4 - out_dim: 0 - opt_config: - Adam: - lr: 0.0001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -clip_reward: 1.0 -double_dqn: false -clip_td_err: ~ -device: ~ -phantom: ~ diff --git a/border/examples/model/tch/dqn_HeroNoFrameskip-v4/model.yaml b/border/examples/model/tch/dqn_HeroNoFrameskip-v4/model.yaml deleted file mode 100644 index b5f3800b..00000000 --- a/border/examples/model/tch/dqn_HeroNoFrameskip-v4/model.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -q_config: - n_stack: 4 - out_dim: 0 -opt_config: - Adam: - lr: 0.0001 -phantom: ~ diff --git a/border/examples/model/tch/dqn_HeroNoFrameskip-v4/replay_buffer.yaml b/border/examples/model/tch/dqn_HeroNoFrameskip-v4/replay_buffer.yaml deleted file mode 100644 index 079309ba..00000000 --- a/border/examples/model/tch/dqn_HeroNoFrameskip-v4/replay_buffer.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -capacity: 1048576 -seed: 42 -per_config: ~ diff --git a/border/examples/model/tch/dqn_HeroNoFrameskip-v4/trainer.yaml b/border/examples/model/tch/dqn_HeroNoFrameskip-v4/trainer.yaml deleted file mode 100644 index c2c148c7..00000000 --- a/border/examples/model/tch/dqn_HeroNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 50000000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: "./border/examples/model/dqn_HeroNoFrameskip-v4" -opt_interval: 1 -eval_interval: 500000 -record_interval: 50000 -save_interval: 10000000 diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4/agent.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4/agent.yaml deleted file mode 100644 index fd66381b..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,25 
+0,0 @@ ---- -model_config: - q_config: - n_stack: 4 - out_dim: 0 - opt_config: - Adam: - lr: 0.0001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -clip_reward: 1.0 -double_dqn: false -clip_td_err: ~ -device: ~ -phantom: ~ diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4/model.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4/model.yaml deleted file mode 100644 index b5f3800b..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4/model.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -q_config: - n_stack: 4 - out_dim: 0 -opt_config: - Adam: - lr: 0.0001 -phantom: ~ diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4/replay_buffer.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4/replay_buffer.yaml deleted file mode 100644 index 1e0ce1e7..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4/replay_buffer.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -capacity: 65536 -seed: 42 -per_config: ~ diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4/trainer.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4/trainer.yaml deleted file mode 100644 index aa9b153c..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 3000000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: "./border/examples/model/dqn_PongNoFrameskip-v4" -opt_interval: 1 -eval_interval: 50000 -record_interval: 50000 -save_interval: 500000 diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/agent.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/agent.yaml deleted file mode 100644 index 9e9cf22d..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/agent.yaml +++ /dev/null @@ -1,26 +0,0 @@ ---- -model_config: - q_config: - n_stack: 4 - out_dim: 0 - opt_config: - Adam: - lr: 0.0001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -clip_reward: 1.0 -double_dqn: false -clip_td_err: ~ -device: ~ -phantom: ~ diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/replay_buffer.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/replay_buffer.yaml deleted file mode 100644 index 1e0ce1e7..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/replay_buffer.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -capacity: 65536 -seed: 42 -per_config: ~ diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/trainer.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/trainer.yaml deleted file mode 100644 index 9fc0f649..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4_debug/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 1000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: "./border/examples/model/dqn_PongNoFrameskip-v4_debug" -opt_interval: 100 -eval_interval: 50000 -record_interval: 100 -save_interval: 500000 diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4_per/agent.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4_per/agent.yaml deleted file mode 100644 index 40b79f77..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4_per/agent.yaml +++ /dev/null @@ -1,24 +0,0 @@ ---- -model_config: - q_config: - n_stack: 4 - out_dim: 0 - 
opt_config: - Adam: - lr: 0.0001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -clip_reward: 1.0 -double_dqn: false -phantom: ~ diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4_per/replay_buffer.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4_per/replay_buffer.yaml deleted file mode 100644 index 942608b4..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4_per/replay_buffer.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -capacity: 65536 -seed: 42 -per_config: - alpha: 0.6000000238418579 - beta_0: 0.4000000059604645 - beta_final: 1.0 - n_opts_final: 500000 diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4_per/trainer.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4_per/trainer.yaml deleted file mode 100644 index fe02da59..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4_per/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 3000000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: "./border/examples/model/dqn_PongNoFrameskip-v4_per" -opt_interval: 1 -eval_interval: 50000 -record_interval: 50000 -save_interval: 500000 diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/agent.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/agent.yaml deleted file mode 100644 index db95a127..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/agent.yaml +++ /dev/null @@ -1,19 +0,0 @@ ---- -opt_interval_counter: - opt_interval: - Steps: 1 - count: 0 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -replay_burffer_capacity: 50000 diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/model.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/model.yaml deleted file mode 100644 index b5f3800b..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/model.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -q_config: - n_stack: 4 - out_dim: 0 -opt_config: - Adam: - lr: 0.0001 -phantom: ~ diff --git a/border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/trainer.yaml b/border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/trainer.yaml deleted file mode 100644 index a8e92db7..00000000 --- a/border/examples/model/tch/dqn_PongNoFrameskip-v4_vec/trainer.yaml +++ /dev/null @@ -1,6 +0,0 @@ ---- -max_opts: 3000000 -eval_interval: 10000 -n_episodes_per_eval: 1 -eval_threshold: ~ -model_dir: "./examples/model/dqn_PongNoFrameskip-v4_vec" diff --git a/border/examples/model/tch/dqn_SeaquestNoFrameskip-v4/agent.yaml b/border/examples/model/tch/dqn_SeaquestNoFrameskip-v4/agent.yaml deleted file mode 100644 index 55f99184..00000000 --- a/border/examples/model/tch/dqn_SeaquestNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,19 +0,0 @@ ---- -opt_interval_counter: - opt_interval: - Steps: 1 - count: 0 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -train: false -discount_factor: 0.99 -tau: 1.0 -replay_burffer_capacity: 1000000 -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 diff --git a/border/examples/model/tch/dqn_SeaquestNoFrameskip-v4/trainer.yaml b/border/examples/model/tch/dqn_SeaquestNoFrameskip-v4/trainer.yaml deleted file mode 
100644 index 9ff6f869..00000000 --- a/border/examples/model/tch/dqn_SeaquestNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,6 +0,0 @@ ---- -max_opts: 50000000 -eval_interval: 10000 -n_episodes_per_eval: 1 -eval_threshold: ~ -model_dir: "./examples/model/dqn_SeaquestNoFrameskip-v4" diff --git a/border/examples/model/tch/iqn_PongNoFrameskip-v4/agent.yaml b/border/examples/model/tch/iqn_PongNoFrameskip-v4/agent.yaml deleted file mode 100644 index 01f6f654..00000000 --- a/border/examples/model/tch/iqn_PongNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,22 +0,0 @@ ---- -opt_interval_counter: - opt_interval: - Steps: 1 - count: 0 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -sample_percents_pred: Uniform8 -sample_percents_tgt: Uniform8 -sample_percents_act: Uniform32 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -replay_buffer_capacity: 50000 diff --git a/border/examples/model/tch/iqn_PongNoFrameskip-v4/model.yaml b/border/examples/model/tch/iqn_PongNoFrameskip-v4/model.yaml deleted file mode 100644 index c7258234..00000000 --- a/border/examples/model/tch/iqn_PongNoFrameskip-v4/model.yaml +++ /dev/null @@ -1,15 +0,0 @@ ---- -feature_dim: 3136 -embed_dim: 64 -learning_rate: 0.0001 -f_config: - n_stack: 4 - feature_dim: 3136 -m_config: - in_dim: 3136 - hidden_dim: 512 - out_dim: 0 -phantom: ~ -opt_config: - Adam: - lr: 0.001 diff --git a/border/examples/model/tch/iqn_PongNoFrameskip-v4/trainer.yaml b/border/examples/model/tch/iqn_PongNoFrameskip-v4/trainer.yaml deleted file mode 100644 index c23fa53c..00000000 --- a/border/examples/model/tch/iqn_PongNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,6 +0,0 @@ ---- -max_opts: 5000000 -eval_interval: 10000 -n_episodes_per_eval: 1 -eval_threshold: ~ -model_dir: "./examples/model/iqn_PongNoFrameskip-v4" diff --git a/border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/agent.yaml b/border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/agent.yaml deleted file mode 100644 index b400c5dd..00000000 --- a/border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,34 +0,0 @@ ---- -model_config: - feature_dim: 3136 - embed_dim: 64 - f_config: - n_stack: 4 - out_dim: 3136 - skip_linear: true - m_config: - in_dim: 3136 - units: - - 512 - out_dim: 0 - opt_config: - Adam: - lr: 0.00001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -sample_percents_pred: Uniform64 -sample_percents_tgt: Uniform64 -sample_percents_act: Uniform32 -device: ~ -phantom: ~ diff --git a/border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/replay_buffer.yaml b/border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/replay_buffer.yaml deleted file mode 100644 index 079309ba..00000000 --- a/border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/replay_buffer.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -capacity: 1048576 -seed: 42 -per_config: ~ diff --git a/border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/trainer.yaml b/border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/trainer.yaml deleted file mode 100644 index ebc92b01..00000000 --- a/border/examples/model/tch/iqn_SeaquestNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 50000000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: 
"./border/examples/model/iqn_SeaquestNoFrameskip-v4" -opt_interval: 1 -eval_interval: 500000 -record_interval: 500000 -save_interval: 10000000 From dde51f91e77ccac27c59af28de91d895a825a714 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sat, 6 Jul 2024 17:53:45 +0900 Subject: [PATCH 03/21] WIP: Bump the version of clap (#20) --- Cargo.toml | 4 +- border-async-trainer/Cargo.toml | 4 +- border-atari-env/Cargo.toml | 2 +- border-candle-agent/Cargo.toml | 4 +- border-derive/Cargo.toml | 8 +- border-mlflow-tracking/Cargo.toml | 2 +- border-py-gym-env/Cargo.toml | 2 +- border-tch-agent/Cargo.toml | 4 +- border-tch-agent/src/iqn/base.rs | 4 +- border-tensorboard/Cargo.toml | 2 +- border/Cargo.toml | 20 +- .../examples/gym-robotics/sac_fetch_reach.rs | 277 +++++++++--------- border/examples/gym/dqn_cartpole.rs | 75 +++-- border/examples/gym/dqn_cartpole_tch.rs | 80 +++-- border/examples/gym/iqn_cartpole_tch.rs | 4 +- border/examples/gym/sac_lunarlander_cont.rs | 60 ++-- .../examples/gym/sac_lunarlander_cont_tch.rs | 60 ++-- border/examples/gym/sac_pendulum.rs | 73 ++--- border/examples/gym/sac_pendulum_tch.rs | 63 ++-- 19 files changed, 349 insertions(+), 399 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a7263122..9ea44c8d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ members = [ exclude = ["docker/"] [workspace.package] -version = "0.0.6" +version = "0.0.7" edition = "2018" rust-version = "1.70" description = "Reinforcement learning library" @@ -25,7 +25,7 @@ categories = ["science"] license = "MIT OR Apache-2.0" [workspace.dependencies] -clap = "2.33.3" +clap = { version = "4.5.8", features = ["derive"] } csv = "1.1.5" fastrand = "1.4.0" tch = "0.16.0" diff --git a/border-async-trainer/Cargo.toml b/border-async-trainer/Cargo.toml index 572817b7..be81813b 100644 --- a/border-async-trainer/Cargo.toml +++ b/border-async-trainer/Cargo.toml @@ -12,8 +12,8 @@ readme = "README.md" [dependencies] anyhow = { workspace = true } aquamarine = { workspace = true } -border-core = { version = "0.0.6", path = "../border-core" } -border-tensorboard = { version = "0.0.6", path = "../border-tensorboard" } +border-core = { version = "0.0.7", path = "../border-core" } +border-tensorboard = { version = "0.0.7", path = "../border-tensorboard" } serde = { workspace = true, features = ["derive"] } log = { workspace = true } tokio = { version = "1.14.0", features = ["full"] } diff --git a/border-atari-env/Cargo.toml b/border-atari-env/Cargo.toml index 727aa09a..a29fbc6b 100644 --- a/border-atari-env/Cargo.toml +++ b/border-atari-env/Cargo.toml @@ -14,7 +14,7 @@ anyhow = { workspace = true } pixels = { version = "0.2.0", optional = true } winit = { version = "0.24.0", optional = true } dirs = { workspace = true } -border-core = { version = "0.0.6", path = "../border-core" } +border-core = { version = "0.0.7", path = "../border-core" } image = { workspace = true } tch = { workspace = true, optional = true } candle-core = { workspace = true, optional = true } diff --git a/border-candle-agent/Cargo.toml b/border-candle-agent/Cargo.toml index f2549cbe..e06c2a16 100644 --- a/border-candle-agent/Cargo.toml +++ b/border-candle-agent/Cargo.toml @@ -10,8 +10,8 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.6", path = "../border-core" } -border-async-trainer = { version = "0.0.6", path = "../border-async-trainer", optional = true } +border-core = { version = "0.0.7", path = "../border-core" } +border-async-trainer = { version = "0.0.7", path = 
"../border-async-trainer", optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } tensorboard-rs = { workspace = true } diff --git a/border-derive/Cargo.toml b/border-derive/Cargo.toml index 453a981a..53ba7fff 100644 --- a/border-derive/Cargo.toml +++ b/border-derive/Cargo.toml @@ -21,10 +21,10 @@ tch = { workspace = true, optional = true } candle-core = { workspace = true, optional = true } [dev-dependencies] -border-tch-agent = { version = "0.0.6", path = "../border-tch-agent" } -border-candle-agent = { version = "0.0.6", path = "../border-candle-agent" } -border-py-gym-env = { version = "0.0.6", path = "../border-py-gym-env" } -border-core = { version = "0.0.6", path = "../border-core" } +border-tch-agent = { version = "0.0.7", path = "../border-tch-agent" } +border-candle-agent = { version = "0.0.7", path = "../border-candle-agent" } +border-py-gym-env = { version = "0.0.7", path = "../border-py-gym-env" } +border-core = { version = "0.0.7", path = "../border-core" } ndarray = { workspace = true } tch = { workspace = true } candle-core = { workspace = true } diff --git a/border-mlflow-tracking/Cargo.toml b/border-mlflow-tracking/Cargo.toml index 2fe1d2fb..08022a6e 100644 --- a/border-mlflow-tracking/Cargo.toml +++ b/border-mlflow-tracking/Cargo.toml @@ -10,7 +10,7 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.6", path = "../border-core" } +border-core = { version = "0.0.7", path = "../border-core" } reqwest = { workspace = true } anyhow = { workspace = true } serde = { workspace = true, features = ["derive"] } diff --git a/border-py-gym-env/Cargo.toml b/border-py-gym-env/Cargo.toml index 97b7145c..a996eccb 100644 --- a/border-py-gym-env/Cargo.toml +++ b/border-py-gym-env/Cargo.toml @@ -10,7 +10,7 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.6", path = "../border-core" } +border-core = { version = "0.0.7", path = "../border-core" } numpy = { workspace = true } pyo3 = { workspace = true, default-features = false, features = [ "auto-initialize", diff --git a/border-tch-agent/Cargo.toml b/border-tch-agent/Cargo.toml index 46a9e276..b3ec2531 100644 --- a/border-tch-agent/Cargo.toml +++ b/border-tch-agent/Cargo.toml @@ -10,8 +10,8 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.6", path = "../border-core" } -border-async-trainer = { version = "0.0.6", path = "../border-async-trainer", optional = true } +border-core = { version = "0.0.7", path = "../border-core" } +border-async-trainer = { version = "0.0.7", path = "../border-async-trainer", optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } tensorboard-rs = { workspace = true } diff --git a/border-tch-agent/src/iqn/base.rs b/border-tch-agent/src/iqn/base.rs index 9314323d..2525b67f 100644 --- a/border-tch-agent/src/iqn/base.rs +++ b/border-tch-agent/src/iqn/base.rs @@ -63,7 +63,9 @@ where let obs = obs.into(); let act = act.into().to(self.device); let next_obs = next_obs.into(); - let reward = Tensor::from_slice(&reward[..]).to(self.device).unsqueeze(-1); + let reward = Tensor::from_slice(&reward[..]) + .to(self.device) + .unsqueeze(-1); let is_terminated = Tensor::from_slice(&is_terminated[..]) .to(self.device) .unsqueeze(-1); diff --git a/border-tensorboard/Cargo.toml b/border-tensorboard/Cargo.toml index 8c4b2854..c7b709d2 100644 --- a/border-tensorboard/Cargo.toml +++ 
b/border-tensorboard/Cargo.toml @@ -10,5 +10,5 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.6", path = "../border-core" } +border-core = { version = "0.0.7", path = "../border-core" } tensorboard-rs = { workspace = true } diff --git a/border/Cargo.toml b/border/Cargo.toml index 53bc2d44..25af06c9 100644 --- a/border/Cargo.toml +++ b/border/Cargo.toml @@ -13,13 +13,13 @@ readme = "README.md" aquamarine = { workspace = true } tch = { workspace = true, optional = true } candle-core = { workspace = true, optional = true } -border-async-trainer = { version = "0.0.6", path = "../border-async-trainer", optional = true } +border-async-trainer = { version = "0.0.7", path = "../border-async-trainer", optional = true } anyhow = { workspace = true } log = { workspace = true } dirs = { workspace = true } zip = "0.5.12" reqwest = { workspace = true } -border-core = { version = "0.0.6", path = "../border-core" } +border-core = { version = "0.0.7", path = "../border-core" } [[example]] name = "dqn_cartpole" @@ -129,14 +129,14 @@ test = false clap = { workspace = true } csv = { workspace = true } tempdir = { workspace = true } -border-derive = { version = "0.0.6", path = "../border-derive" } -border-core = { version = "0.0.6", path = "../border-core" } -border-tensorboard = { version = "0.0.6", path = "../border-tensorboard" } -border-tch-agent = { version = "0.0.6", path = "../border-tch-agent" } -border-py-gym-env = { version = "0.0.6", path = "../border-py-gym-env" } -border-atari-env = { version = "0.0.6", path = "../border-atari-env" } -border-candle-agent = { version = "0.0.6", path = "../border-candle-agent" } -border-mlflow-tracking = { version = "0.0.6", path = "../border-mlflow-tracking" } +border-derive = { version = "0.0.7", path = "../border-derive" } +border-core = { version = "0.0.7", path = "../border-core" } +border-tensorboard = { version = "0.0.7", path = "../border-tensorboard" } +border-tch-agent = { version = "0.0.7", path = "../border-tch-agent" } +border-py-gym-env = { version = "0.0.7", path = "../border-py-gym-env" } +border-atari-env = { version = "0.0.7", path = "../border-atari-env" } +border-candle-agent = { version = "0.0.7", path = "../border-candle-agent" } +border-mlflow-tracking = { version = "0.0.7", path = "../border-mlflow-tracking" } serde = { workspace = true, features = ["derive"] } crossbeam-channel = { workspace = true } env_logger = { workspace = true } diff --git a/border/examples/gym-robotics/sac_fetch_reach.rs b/border/examples/gym-robotics/sac_fetch_reach.rs index f8766c74..8899f6e8 100644 --- a/border/examples/gym-robotics/sac_fetch_reach.rs +++ b/border/examples/gym-robotics/sac_fetch_reach.rs @@ -1,13 +1,15 @@ use anyhow::Result; use border_core::{ - record::{Record, RecordValue, Recorder}, - replay_buffer::{ + generic_replay_buffer::{ SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }, - Agent, DefaultEvaluator, Evaluator as _, Policy, Trainer, TrainerConfig, + record::{AggregateRecorder, Record, RecordValue}, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, }; -use border_derive::SubBatch; +use border_derive::BatchBase; +use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ util::{arrayd_to_pyobj, arrayd_to_tensor, tensor_to_arrayd, ArrayType}, ArrayDictObsFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, @@ -20,12 +22,11 @@ use border_tch_agent::{ 
TensorBatch, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; // use csv::WriterBuilder; use ndarray::ArrayD; use pyo3::PyObject; // use serde::Serialize; -use border_mlflow_tracking::MlflowTrackingClient; use std::convert::TryFrom; use tch::Tensor; @@ -45,15 +46,20 @@ const TAU: f64 = 0.02; const TARGET_ENTROPY: f64 = -(DIM_ACT as f64); const LR_ENT_COEF: f64 = 3e-4; const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; +const MODEL_DIR: &str = "./border/examples/gym-robotics/model/tch/sac_fetch_reach/"; -mod obs { +fn cuda_if_available() -> tch::Device { + tch::Device::cuda_if_available() +} + +mod obs_act_types { use super::*; use border_py_gym_env::util::Array; #[derive(Clone, Debug)] pub struct Obs(Vec<(String, Array)>); - #[derive(Clone, SubBatch)] + #[derive(Clone, BatchBase)] pub struct ObsBatch(TensorBatch); impl border_core::Obs for Obs { @@ -89,10 +95,6 @@ mod obs { Self(TensorBatch::from_tensor(tensor)) } } -} - -mod act { - use super::*; #[derive(Clone, Debug)] pub struct Act(ArrayD); @@ -118,7 +120,7 @@ mod act { } } - #[derive(SubBatch)] + #[derive(BatchBase)] pub struct ActBatch(TensorBatch); impl From for ActBatch { @@ -157,101 +159,136 @@ mod act { (arrayd_to_pyobj(act_filt), record) } } -} -use act::{Act, ActBatch, ActFilter}; -use obs::{Obs, ObsBatch}; - -type ObsFilter = ArrayDictObsFilter; -type Env = GymEnv; -type StepProc = SimpleStepProcessor; -type ReplayBuffer = SimpleReplayBuffer; -type Evaluator = DefaultEvaluator>; - -fn create_agent(in_dim: i64, out_dim: i64) -> Sac { - let device = tch::Device::cuda_if_available(); - let actor_config = ActorConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) - .out_dim(out_dim) - .pi_config(MlpConfig::new(in_dim, vec![64, 64], out_dim, false)); - let critic_config = CriticConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) - .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, false)); - let sac_config = SacConfig::default() - .batch_size(BATCH_SIZE) - .min_transitions_warmup(N_TRANSITIONS_WARMUP) - .actor_config(actor_config) - .critic_config(critic_config) - .tau(TAU) - .critic_loss(CRITIC_LOSS) - .n_critics(N_CRITICS) - .ent_coef_mode(EntCoefMode::Auto(TARGET_ENTROPY, LR_ENT_COEF)) - .device(device); - Sac::build(sac_config) + pub type ObsFilter = ArrayDictObsFilter; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Evaluator = DefaultEvaluator>; } -fn env_config() -> GymEnvConfig { - GymEnvConfig::::default() - .name("FetchReach-v2".to_string()) - .obs_filter_config(ObsFilter::default_config().add_key_and_types(vec![ - ("observation", ArrayType::F32Array), - ("desired_goal", ArrayType::F32Array), - ("achieved_goal", ArrayType::F32Array), - ])) - .act_filter_config(ActFilter::default_config()) +use obs_act_types::*; + +mod config { + use super::*; + + pub fn env_config() -> GymEnvConfig { + GymEnvConfig::::default() + .name("FetchReach-v2".to_string()) + .obs_filter_config(ObsFilter::default_config().add_key_and_types(vec![ + ("observation", ArrayType::F32Array), + ("desired_goal", ArrayType::F32Array), + ("achieved_goal", ArrayType::F32Array), + ])) + .act_filter_config(ActFilter::default_config()) + } + + pub fn create_trainer_config(max_opts: usize, model_dir: &str) -> TrainerConfig { + TrainerConfig::default() + .max_opts(max_opts) + .opt_interval(OPT_INTERVAL) + .eval_interval(EVAL_INTERVAL) + .record_agent_info_interval(EVAL_INTERVAL) + 
.record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(N_TRANSITIONS_WARMUP) + .model_dir(model_dir) + } + + pub fn create_sac_config(dim_obs: i64, dim_act: i64, target_ent: f64) -> SacConfig { + let device = cuda_if_available(); + let actor_config = ActorConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) + .out_dim(dim_act) + .pi_config(MlpConfig::new(dim_obs, vec![64, 64], dim_act, false)); + let critic_config = CriticConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) + .q_config(MlpConfig::new(dim_obs + dim_act, vec![64, 64], 1, false)); + + SacConfig::default() + .batch_size(BATCH_SIZE) + .actor_config(actor_config) + .critic_config(critic_config) + .tau(TAU) + .critic_loss(CRITIC_LOSS) + .n_critics(N_CRITICS) + .ent_coef_mode(EntCoefMode::Auto(target_ent, LR_ENT_COEF)) + .device(device) + } } -fn create_recorder( - model_dir: &str, - mlflow: bool, - config: &TrainerConfig, -) -> Result> { - match mlflow { - true => { - let client = - MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; - let recorder_run = client.create_recorder("")?; - recorder_run.log_params(&config)?; - Ok(Box::new(recorder_run)) +mod utils { + use super::*; + + pub fn create_recorder( + model_dir: &str, + mlflow: bool, + config: &TrainerConfig, + ) -> Result> { + match mlflow { + true => { + let client = MlflowTrackingClient::new("http://localhost:8080") + .set_experiment_id("Fetch")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", "reach")?; + recorder_run.set_tag("algo", "sac")?; + recorder_run.set_tag("backend", "tch")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), } - false => Ok(Box::new(TensorboardRecorder::new(model_dir))), } } -fn train(max_opts: usize, model_dir: &str, eval_interval: usize, mlflow: bool) -> Result<()> { - let (mut trainer, config) = { - let env_config = env_config(); - let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = - SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let config = TrainerConfig::default() - .max_opts(max_opts) - .opt_interval(OPT_INTERVAL) - .eval_interval(eval_interval) - .record_interval(eval_interval) - .save_interval(eval_interval) - .model_dir(model_dir); - let trainer = Trainer::::build( - config.clone(), - env_config, - step_proc_config, - replay_buffer_config, - ); - - (trainer, config) - }; - let mut agent = create_agent(DIM_OBS, DIM_ACT); - let mut recorder = create_recorder(model_dir, mlflow, &config)?; - let mut evaluator = Evaluator::new(&env_config(), 0, N_EPISODES_PER_EVAL)?; +/// Train/eval SAC agent in fetch environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, - trainer.train(&mut agent, &mut recorder, &mut evaluator)?; + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn train(max_opts: usize, model_dir: &str, mlflow: bool) -> Result<()> { + let trainer_config = config::create_trainer_config(max_opts, model_dir); + let env_config = config::env_config(); + let step_proc_config = SimpleStepProcessorConfig {}; + let sac_config = config::create_sac_config(DIM_OBS, 
DIM_ACT, TARGET_ENTROPY); + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + + let env = Env::build(&env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut recorder = utils::create_recorder(model_dir, mlflow, &trainer_config)?; + let mut trainer = Trainer::build(trainer_config); + let mut agent = Sac::build(sac_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut evaluator = Evaluator::new(&config::env_config(), 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; Ok(()) } fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { let env_config = { - let mut env_config = env_config(); + let mut env_config = config::env_config(); if render { env_config = env_config .render_mode(Some("human".to_string())) @@ -260,7 +297,7 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { env_config }; let mut agent = { - let mut agent = create_agent(DIM_OBS, DIM_ACT); + let mut agent = Sac::build(config::create_sac_config(DIM_OBS, DIM_ACT, TARGET_ENTROPY)); agent.load(model_dir)?; agent.eval(); agent @@ -272,55 +309,19 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { Ok(()) } -fn create_matches<'a>() -> ArgMatches<'a> { - App::new("sac_fetch_reach") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .get_matches() -} - fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); tch::manual_seed(42); - let matches = create_matches(); - let do_train = matches.is_present("train"); - let do_eval = matches.is_present("eval"); - let mlflow = matches.is_present("mlflow"); - - if !do_train && !do_eval { - println!("You need to give either --train or --eval in the command line argument."); - return Ok(()); - } + let args = Args::parse(); - if do_train { - train( - MAX_OPTS, - "./border/examples/model/sac_fetch_reach", - EVAL_INTERVAL, - mlflow, - )?; - } - if do_eval { - eval(5, true, "./border/examples/model/sac_fetch_reach/best")?; + if args.train { + train(MAX_OPTS, MODEL_DIR, args.mlflow)?; + } else if args.eval { + eval(5, true, format!("{}/best", MODEL_DIR).as_str())?; + } else { + train(MAX_OPTS, MODEL_DIR, args.mlflow)?; + eval(5, true, format!("{}/best", MODEL_DIR).as_str())?; } Ok(()) @@ -337,7 +338,7 @@ mod test { let model_dir = TempDir::new("sac_fetch_reach")?; let model_dir = model_dir.path().to_str().unwrap(); - train(100, model_dir, 100, false)?; + train(100, model_dir, false)?; eval(1, false, (model_dir.to_string() + "/best").as_str())?; Ok(()) diff --git a/border/examples/gym/dqn_cartpole.rs b/border/examples/gym/dqn_cartpole.rs index 87a805c6..322106ba 100644 --- a/border/examples/gym/dqn_cartpole.rs +++ b/border/examples/gym/dqn_cartpole.rs @@ -22,7 +22,7 @@ use border_py_gym_env::{ }; use border_tensorboard::TensorboardRecorder; use candle_core::{Device, Tensor}; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; use ndarray::{ArrayD, IxDyn}; use serde::Serialize; @@ -234,11 +234,11 @@ mod utils { use super::*; pub fn create_recorder( - matches: &ArgMatches, + args: 
&Args, model_dir: &str, config: &DqnCartpoleConfig, ) -> Result> { - match matches.is_present("mlflow") { + match args.mlflow { true => { let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; @@ -252,43 +252,30 @@ mod utils { false => Ok(Box::new(TensorboardRecorder::new(model_dir))), } } +} - pub fn create_matches<'a>() -> ArgMatches<'a> { - App::new("dqn_cartpole") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .get_matches() - } +/// Train/eval DQN agent in cartpole environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train DQN agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, } -fn train( - matches: &ArgMatches, - max_opts: usize, - model_dir: &str, - eval_interval: usize, -) -> Result<()> { +fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = DqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, model_dir, eval_interval); let step_proc_config = SimpleStepProcessorConfig {}; let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let mut recorder = utils::create_recorder(&matches, model_dir, &config)?; + let mut recorder = utils::create_recorder(&args, model_dir, &config)?; let mut trainer = Trainer::build(config.trainer_config.clone()); let env = Env::build(&config.env_config, 0)?; @@ -334,14 +321,15 @@ fn eval(model_dir: &str, render: bool) -> Result<()> { fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); // TODO: set seed - let matches = utils::create_matches(); - if matches.is_present("train") { - train(&matches, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; - } else if matches.is_present("eval") { + let args = Args::parse(); + + if args.train { + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + } else if args.eval { eval(&(MODEL_DIR.to_owned() + "/best"), true)?; } else { - train(&matches, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; eval(&(MODEL_DIR.to_owned() + "/best"), true)?; } @@ -350,8 +338,6 @@ fn main() -> Result<()> { #[cfg(test)] mod tests { - use crate::utils::create_matches; - use super::{eval, train}; use anyhow::Result; use tempdir::TempDir; @@ -363,7 +349,12 @@ mod tests { Some(s) => s, None => panic!("Failed to get string of temporary directory"), }; - train(&create_matches(), 100, model_dir, 100)?; + let args = Args { + train: false, + eval: false, + mlflow: false, + }; + train(&args, 100, model_dir, 100)?; eval(&(model_dir.to_owned() + "/best"), false)?; Ok(()) } diff --git a/border/examples/gym/dqn_cartpole_tch.rs b/border/examples/gym/dqn_cartpole_tch.rs index 36451818..688740d5 100644 --- a/border/examples/gym/dqn_cartpole_tch.rs +++ b/border/examples/gym/dqn_cartpole_tch.rs @@ -20,7 +20,7 @@ use border_tch_agent::{ TensorBatch, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; use 
ndarray::{ArrayD, IxDyn}; use serde::Serialize; use std::convert::TryFrom; @@ -114,8 +114,8 @@ mod obs_act_types { impl From for Act { // `t` must be a 1-dimentional tensor of `f32` fn from(t: Tensor) -> Self { - let data = - Vec::::try_from(&t.flatten(0, -1)).expect("Failed to convert from Tensor to Vec"); + let data = Vec::::try_from(&t.flatten(0, -1)) + .expect("Failed to convert from Tensor to Vec"); let data = data.iter().map(|&e| e as i32).collect(); Act(data) } @@ -234,11 +234,11 @@ mod utils { use super::*; pub fn create_recorder( - matches: &ArgMatches, + args: &Args, model_dir: &str, config: &DqnCartpoleConfig, ) -> Result> { - match matches.is_present("mlflow") { + match args.mlflow { true => { let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; @@ -252,43 +252,30 @@ mod utils { false => Ok(Box::new(TensorboardRecorder::new(model_dir))), } } +} - pub fn create_matches<'a>() -> ArgMatches<'a> { - App::new("dqn_cartpole_tch") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .get_matches() - } +/// Train/eval DQN agent in cartpole environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train DQN agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, } -fn train( - matches: &ArgMatches, - max_opts: usize, - model_dir: &str, - eval_interval: usize, -) -> Result<()> { +fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = DqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, model_dir, eval_interval); let step_proc_config = SimpleStepProcessorConfig {}; let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let mut recorder = utils::create_recorder(&matches, model_dir, &config)?; + let mut recorder = utils::create_recorder(&args, model_dir, &config)?; let mut trainer = Trainer::build(config.trainer_config.clone()); let env = Env::build(&config.env_config, 0)?; @@ -335,14 +322,14 @@ fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); tch::manual_seed(42); - let matches = utils::create_matches(); + let args = Args::parse(); - if matches.is_present("train") { - train(&matches, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; - } else if matches.is_present("eval") { + if args.train { + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + } else if args.eval { eval(&(MODEL_DIR.to_owned() + "/best"), true)?; } else { - train(&matches, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; eval(&(MODEL_DIR.to_owned() + "/best"), true)?; } @@ -351,9 +338,7 @@ fn main() -> Result<()> { #[cfg(test)] mod tests { - use crate::utils::create_matches; - - use super::{eval, train}; + use super::{eval, train, Args}; use anyhow::Result; use tempdir::TempDir; @@ -364,7 +349,12 @@ mod tests { Some(s) => s, None => panic!("Failed to get string of temporary directory"), }; - train(&create_matches(), 100, model_dir, 
100)?; + let args = Args { + train: false, + eval: false, + mlflow: false, + }; + train(&args, 100, model_dir, 100)?; eval(&(model_dir.to_owned() + "/best"), false)?; Ok(()) } diff --git a/border/examples/gym/iqn_cartpole_tch.rs b/border/examples/gym/iqn_cartpole_tch.rs index 51bb53ef..c097daa2 100644 --- a/border/examples/gym/iqn_cartpole_tch.rs +++ b/border/examples/gym/iqn_cartpole_tch.rs @@ -118,8 +118,8 @@ mod obs_act_types { impl From for Act { // `t` must be a 1-dimentional tensor of `f32` fn from(t: Tensor) -> Self { - let data = - Vec::::try_from(&t.flatten(0, -1)).expect("Failed to convert from Tensor to Vec"); + let data = Vec::::try_from(&t.flatten(0, -1)) + .expect("Failed to convert from Tensor to Vec"); let data = data.iter().map(|&e| e as i32).collect(); Act(data) } diff --git a/border/examples/gym/sac_lunarlander_cont.rs b/border/examples/gym/sac_lunarlander_cont.rs index 87c4866e..fd46b8c4 100644 --- a/border/examples/gym/sac_lunarlander_cont.rs +++ b/border/examples/gym/sac_lunarlander_cont.rs @@ -20,7 +20,7 @@ use border_py_gym_env::{ ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; //use csv::WriterBuilder; use border_mlflow_tracking::MlflowTrackingClient; use candle_core::Tensor; @@ -205,10 +205,10 @@ mod utils { use super::*; pub fn create_recorder( - matches: &ArgMatches, + args: &Args, config: &config::SacLunarLanderConfig, ) -> Result> { - match matches.is_present("mlflow") { + match args.mlflow { true => { let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; @@ -222,34 +222,26 @@ mod utils { false => Ok(Box::new(TensorboardRecorder::new(MODEL_DIR))), } } +} - pub fn create_matches<'a>() -> ArgMatches<'a> { - App::new("sac_lunarlander_cont") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .get_matches() - } +/// Train/eval SAC agent in lunarlander environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, } -fn train(matches: ArgMatches, max_opts: usize) -> Result<()> { +fn train(args: &Args, max_opts: usize) -> Result<()> { let env_config = config::env_config(); let trainer_config = config::trainer_config(max_opts, EVAL_INTERVAL); let step_proc_config = SimpleStepProcessorConfig {}; @@ -260,7 +252,7 @@ fn train(matches: ArgMatches, max_opts: usize) -> Result<()> { replay_buffer_config: replay_buffer_config.clone(), trainer_config, }; - let mut recorder = utils::create_recorder(&matches, &config)?; + let mut recorder = utils::create_recorder(&args, &config)?; let mut trainer = Trainer::build(config.trainer_config.clone()); let env = Env::build(&env_config, 0)?; @@ -307,14 +299,14 @@ fn eval(render: bool) -> Result<()> { fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - let matches = 
utils::create_matches(); + let args = Args::parse(); - if matches.is_present("eval") { + if args.eval { eval(true)?; - } else if matches.is_present("train") { - train(matches, MAX_OPTS)?; + } else if args.train { + train(&args, MAX_OPTS)?; } else { - train(matches, MAX_OPTS)?; + train(&args, MAX_OPTS)?; eval(true)?; } diff --git a/border/examples/gym/sac_lunarlander_cont_tch.rs b/border/examples/gym/sac_lunarlander_cont_tch.rs index 19f1ec96..ac810923 100644 --- a/border/examples/gym/sac_lunarlander_cont_tch.rs +++ b/border/examples/gym/sac_lunarlander_cont_tch.rs @@ -20,7 +20,7 @@ use border_tch_agent::{ TensorBatch, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; //use csv::WriterBuilder; use border_mlflow_tracking::MlflowTrackingClient; use ndarray::{ArrayD, IxDyn}; @@ -206,10 +206,10 @@ mod utils { use super::*; pub fn create_recorder( - matches: &ArgMatches, + args: &Args, config: &config::SacLunarLanderConfig, ) -> Result> { - match matches.is_present("mlflow") { + match args.mlflow { true => { let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; @@ -223,34 +223,26 @@ mod utils { false => Ok(Box::new(TensorboardRecorder::new(MODEL_DIR))), } } +} - pub fn create_matches<'a>() -> ArgMatches<'a> { - App::new("sac_lunarlander_cont_tch") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .get_matches() - } +/// Train/eval SAC agent in lunarlander environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, } -fn train(matches: ArgMatches, max_opts: usize) -> Result<()> { +fn train(args: &Args, max_opts: usize) -> Result<()> { let env_config = config::env_config(); let trainer_config = config::trainer_config(max_opts, EVAL_INTERVAL); let step_proc_config = SimpleStepProcessorConfig {}; @@ -261,7 +253,7 @@ fn train(matches: ArgMatches, max_opts: usize) -> Result<()> { replay_buffer_config: replay_buffer_config.clone(), trainer_config, }; - let mut recorder = utils::create_recorder(&matches, &config)?; + let mut recorder = utils::create_recorder(&args, &config)?; let mut trainer = Trainer::build(config.trainer_config.clone()); let env = Env::build(&env_config, 0)?; @@ -309,14 +301,14 @@ fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); tch::manual_seed(42); - let matches = utils::create_matches(); + let args = Args::parse(); - if matches.is_present("eval") { + if args.eval { eval(true)?; - } else if matches.is_present("train") { - train(matches, MAX_OPTS)?; + } else if args.train { + train(&args, MAX_OPTS)?; } else { - train(matches, MAX_OPTS)?; + train(&args, MAX_OPTS)?; eval(true)?; } diff --git a/border/examples/gym/sac_pendulum.rs b/border/examples/gym/sac_pendulum.rs index d8016c24..8597fbbc 100644 --- a/border/examples/gym/sac_pendulum.rs +++ b/border/examples/gym/sac_pendulum.rs @@ -20,7 +20,7 @@ use 
border_py_gym_env::{ ArrayObsFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; // use csv::WriterBuilder; use border_mlflow_tracking::MlflowTrackingClient; use candle_core::{Device, Tensor}; @@ -229,6 +229,23 @@ fn create_recorder( } } +/// Train/eval SAC agent in pendulum environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + fn train(max_opts: usize, model_dir: &str, eval_interval: usize, mlflow: bool) -> Result<()> { let env_config = env_config(); let step_proc_config = SimpleStepProcessorConfig {}; @@ -298,56 +315,28 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { Ok(()) } -fn create_matches<'a>() -> ArgMatches<'a> { - App::new("sac_pendulum") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .get_matches() -} - fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - let matches = create_matches(); - let mlflow = matches.is_present("mlflow"); - - let do_train = (matches.is_present("train") && !matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); - let do_eval = (!matches.is_present("train") && matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); + let args = Args::parse(); - if do_train { + if args.train { train( MAX_OPTS, - "./border/examples/gym/model/candle/sac_pendulum", + "./border/examples/gym/model/tch/sac_pendulum", EVAL_INTERVAL, - mlflow, + args.mlflow, )?; - } - if do_eval { - eval( - 5, - true, - "./border/examples/gym/model/candle/sac_pendulum/best", + } else if args.eval { + eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; + } else { + train( + MAX_OPTS, + "./border/examples/gym/model/tch/sac_pendulum", + EVAL_INTERVAL, + args.mlflow, )?; + eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; } Ok(()) diff --git a/border/examples/gym/sac_pendulum_tch.rs b/border/examples/gym/sac_pendulum_tch.rs index f2538244..cdddaafb 100644 --- a/border/examples/gym/sac_pendulum_tch.rs +++ b/border/examples/gym/sac_pendulum_tch.rs @@ -20,7 +20,7 @@ use border_tch_agent::{ TensorBatch, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; // use csv::WriterBuilder; use border_mlflow_tracking::MlflowTrackingClient; use ndarray::{ArrayD, IxDyn}; @@ -297,52 +297,45 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { Ok(()) } -fn create_matches<'a>() -> ArgMatches<'a> { - App::new("sac_pendulum_tch") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .arg( - 
Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .get_matches() +/// Train/eval SAC agent in pendulum environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, } fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); tch::manual_seed(42); - let matches = create_matches(); - let mlflow = matches.is_present("mlflow"); + let args = Args::parse(); - let do_train = (matches.is_present("train") && !matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); - let do_eval = (!matches.is_present("train") && matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); - - if do_train { + if args.train { train( MAX_OPTS, "./border/examples/gym/model/tch/sac_pendulum", EVAL_INTERVAL, - mlflow, + args.mlflow, + )?; + } else if args.eval { + eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; + } else { + train( + MAX_OPTS, + "./border/examples/gym/model/tch/sac_pendulum", + EVAL_INTERVAL, + args.mlflow, )?; - } - if do_eval { eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; } From 1dc6a23d6e035925a67cb8f5d1ff3c8bb9f0a34f Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 7 Jul 2024 00:21:12 +0900 Subject: [PATCH 04/21] Bump the version of clap (#20) Bump the version of clap (#20) Bump the version of clap (#20) Bump the version of clap (#20) --- border-derive/src/act.rs | 8 +- border/Cargo.toml | 10 +- border/examples/atari/dqn_atari.rs | 205 +++++++++------------- border/examples/atari/dqn_atari_tch.rs | 206 +++++++++-------------- border/examples/atari/util_dqn_atari.rs | 2 +- border/examples/gym/dqn_cartpole.rs | 2 +- border/examples/gym/iqn_cartpole_tch.rs | 75 ++++----- border/examples/mujoco/sac_ant_async.rs | 2 +- border/examples/mujoco/sac_mujoco.rs | 126 ++++++-------- border/examples/mujoco/sac_mujoco_tch.rs | 126 ++++++-------- 10 files changed, 324 insertions(+), 438 deletions(-) diff --git a/border-derive/src/act.rs b/border-derive/src/act.rs index b4b10d0c..f18566f3 100644 --- a/border-derive/src/act.rs +++ b/border-derive/src/act.rs @@ -177,7 +177,13 @@ fn atari_env_act(ident: proc_macro2::Ident, field_type: syn::Type) -> proc_macro impl From for #ident { fn from(t: tch::Tensor) -> Self { - let data: Vec = t.into(); + let data: Vec = { + let t = t.to_dtype(tch::Kind::Int64, false, true); + let n = t.numel(); + let mut data = vec![0i64; n]; + t.f_copy_data(&mut data, n).unwrap(); + data + }; // Non-vectorized environment #ident(BorderAtariAct::new(data[0] as u8)) } diff --git a/border/Cargo.toml b/border/Cargo.toml index 25af06c9..121e84b0 100644 --- a/border/Cargo.toml +++ b/border/Cargo.toml @@ -109,11 +109,11 @@ path = "examples/mujoco/sac_mujoco_tch.rs" required-features = ["tch"] test = false -[[example]] -name = "sac_ant_async" -path = "examples/ant/sac_ant_async.rs" -required-features = ["tch", "border-async-trainer"] -test = false +# [[example]] +# name = "sac_ant_async" +# path = "examples/mujoco/sac_ant_async.rs" +# required-features = ["tch", "border-async-trainer"] +# test = false # [[example]] # name = "make_cfg_dqn_atari" diff --git 
a/border/examples/atari/dqn_atari.rs b/border/examples/atari/dqn_atari.rs index 2ad682b4..46aae46a 100644 --- a/border/examples/atari/dqn_atari.rs +++ b/border/examples/atari/dqn_atari.rs @@ -10,24 +10,25 @@ use border_candle_agent::{ TensorBatch, }; use border_core::{ - record::AggregateRecorder, - replay_buffer::{ + generic_replay_buffer::{ SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }, - Agent, DefaultEvaluator, Env as _, Evaluator as _, Policy, Trainer, TrainerConfig, + record::AggregateRecorder, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, }; -use border_derive::{Act, SubBatch}; +use border_derive::{Act, BatchBase}; use border_mlflow_tracking::MlflowTrackingClient; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; mod obs_act_types { use super::*; pub type Obs = BorderAtariObs; - #[derive(Clone, SubBatch)] + #[derive(Clone, BatchBase)] pub struct ObsBatch(TensorBatch); impl From for ObsBatch { @@ -37,7 +38,7 @@ mod obs_act_types { } } - #[derive(SubBatch)] + #[derive(BatchBase)] pub struct ActBatch(TensorBatch); impl From for ActBatch { @@ -123,8 +124,8 @@ mod config { Ok(config.into()) } - pub fn create_trainer_config(matches: &ArgMatches) -> Result<()> { - let model_dir = utils::model_dir(matches); + pub fn create_trainer_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); let config = util_dqn_atari::DqnAtariTrainerConfig::default(); let path = model_dir + "/trainer.yaml"; let mut file = std::fs::File::create(path.clone())?; @@ -133,8 +134,8 @@ mod config { Ok(()) } - pub fn create_replay_buffer_config(matches: &ArgMatches) -> Result<()> { - let model_dir = utils::model_dir(matches); + pub fn create_replay_buffer_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); let config = util_dqn_atari::DqnAtariReplayBufferConfig::default(); let path = model_dir + "/replay_buffer.yaml"; let mut file = std::fs::File::create(path.clone())?; @@ -143,8 +144,8 @@ mod config { Ok(()) } - pub fn create_agent_config(matches: &ArgMatches) -> Result<()> { - let model_dir = utils::model_dir(matches); + pub fn create_agent_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); let config = util_dqn_atari::DqnAtariAgentConfig::default(); let path = model_dir + "/agent.yaml"; let mut file = std::fs::File::create(path.clone())?; @@ -164,11 +165,8 @@ mod config { mod utils { use super::*; - pub fn model_dir(matches: &ArgMatches) -> String { - let name = matches - .value_of("name") - .expect("The name of the environment was not given") - .to_string(); + pub fn model_dir(args: &Args) -> String { + let name = &args.name; format!("./border/examples/atari/model/candle/dqn_{}", name) // let mut params = Params::default(); @@ -188,8 +186,8 @@ mod utils { // model_dir_(name, ¶ms) } - pub fn model_dir_for_play(matches: &ArgMatches) -> String { - matches.value_of("play").unwrap().to_string() + pub fn model_dir_for_eval(args: &Args) -> String { + model_dir(args) } pub fn n_actions(env_config: &EnvConfig) -> Result { @@ -197,13 +195,13 @@ mod utils { } pub fn create_recorder( - matches: &ArgMatches, + args: &Args, model_dir: &str, config: &DqnAtariConfig, ) -> Result> { - match matches.is_present("mlflow") { + match args.mlflow { true => { - let name = matches.value_of("name").unwrap(); + let name = &args.name; let client = 
MlflowTrackingClient::new("http://localhost:8080") .set_experiment_id("Atari")?; let recorder_run = client.create_recorder("")?; @@ -216,85 +214,45 @@ mod utils { false => Ok(Box::new(TensorboardRecorder::new(model_dir))), } } +} - pub fn create_matches<'a>() -> ArgMatches<'a> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - - let matches = App::new("dqn_atari") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("name") - .long("name") - .takes_value(true) - .required(true) - .index(1) - .help("The name of the atari rom (e.g., pong)"), - ) - .arg( - Arg::with_name("play") - .long("play") - .takes_value(true) - .help("Play with the trained model of the given path"), - ) - .arg( - Arg::with_name("play-gdrive") - .long("play-gdrive") - .takes_value(false) - .help("Play with the trained model downloaded from google drive"), - ) - .arg( - Arg::with_name("create-config") - .long("create-config") - .help("Create config files"), - ) - // not supported yet - // .arg( - // Arg::with_name("per") - // .long("per") - // .takes_value(false) - // .help("Train/play with prioritized experience replay"), - // ) - // .arg( - // Arg::with_name("ddqn") - // .long("ddqn") - // .takes_value(false) - // .help("Train/play with double DQN"), - // ) - // .arg( - // Arg::with_name("debug") - // .long("debug") - // .takes_value(false) - // .help("Run with debug configuration"), - // ) - .arg( - Arg::with_name("wait") - .long("wait") - .takes_value(true) - .default_value("25") - .help("Waiting time in milliseconds between frames when playing"), - ) - .arg( - Arg::with_name("show-config") - .long("show-config") - .takes_value(false) - .help("Showing configuration loaded from files"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .help("Logging with mlflow"), - ) - .get_matches(); - - matches - } +/// Train/eval DQN agent in atari environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Name of the game + #[arg(long)] + name: String, + + /// Train DQN agent, not evaluate + #[arg(long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(long, default_value_t = false)] + eval: bool, + + /// Create config files + #[arg(long, default_value_t = false)] + create_config: bool, + + /// Show config + #[arg(long, default_value_t = false)] + show_config: bool, + + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames when evaluation + #[arg(long, default_value_t = 25)] + wait: u64, } -fn train(matches: ArgMatches) -> Result<()> { +fn train(args: &Args) -> Result<()> { // Configurations - let name = matches.value_of("name").unwrap(); - let model_dir = utils::model_dir(&matches); + let name = &args.name; + let model_dir = utils::model_dir(&args); let env_config_train = config::env_config(name); let env_config_eval = config::env_config(name).eval(); let n_actions = utils::n_actions(&env_config_train)?; @@ -310,7 +268,7 @@ fn train(matches: ArgMatches) -> Result<()> { let step_proc_config = SimpleStepProcessorConfig {}; // Show configs or train - if matches.is_present("show-config") { + if args.show_config { config::show_config(&env_config_train, &agent_config, &trainer_config); } else { let config = DqnAtariConfig { @@ -318,25 +276,30 @@ fn train(matches: ArgMatches) -> Result<()> { replay_buffer: replay_buffer_config.clone(), agent: agent_config.clone(), }; + let mut trainer = 
Trainer::build(trainer_config); + let env = Env::build(&env_config_train, 0)?; + let step_proc = StepProc::build(&step_proc_config); let mut agent = Dqn::build(agent_config); - let mut recorder = utils::create_recorder(&matches, &model_dir, &config)?; + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut recorder = utils::create_recorder(&args, &model_dir, &config)?; let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?; - let mut trainer = Trainer::::build( - trainer_config, - env_config_train, - step_proc_config, - replay_buffer_config, - ); - - trainer.train(&mut agent, &mut recorder, &mut evaluator)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; } Ok(()) } -fn play(matches: ArgMatches) -> Result<()> { - let name = matches.value_of("name").unwrap(); - let model_dir = utils::model_dir_for_play(&matches); +fn eval(args: &Args) -> Result<()> { + let name = &args.name; + let model_dir = utils::model_dir_for_eval(&args); let (env_config, n_actions) = { let env_config = config::env_config(name).render(true); @@ -360,22 +323,22 @@ fn play(matches: ArgMatches) -> Result<()> { Ok(()) } -fn create_config(matches: ArgMatches) -> Result<()> { - config::create_trainer_config(&matches)?; - config::create_replay_buffer_config(&matches)?; - config::create_agent_config(&matches)?; +fn create_config(args: &Args) -> Result<()> { + config::create_trainer_config(&args)?; + config::create_replay_buffer_config(&args)?; + config::create_agent_config(&args)?; Ok(()) } fn main() -> Result<()> { - let matches = utils::create_matches(); + let args = Args::parse(); - if matches.is_present("play") || matches.is_present("play-gdrive") { - play(matches)?; - } else if matches.is_present("create-config") { - create_config(matches)?; + if args.eval { + eval(&args)?; + } else if args.create_config { + create_config(&args)?; } else { - train(matches)?; + train(&args)?; } Ok(()) diff --git a/border/examples/atari/dqn_atari_tch.rs b/border/examples/atari/dqn_atari_tch.rs index 5d2b752a..edb5eb4e 100644 --- a/border/examples/atari/dqn_atari_tch.rs +++ b/border/examples/atari/dqn_atari_tch.rs @@ -5,14 +5,15 @@ use border_atari_env::{ BorderAtariObsRawFilter, }; use border_core::{ - record::AggregateRecorder, - replay_buffer::{ + generic_replay_buffer::{ SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }, - Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, Trainer, TrainerConfig, + record::AggregateRecorder, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, }; -use border_derive::{Act, SubBatch}; +use border_derive::{Act, BatchBase}; use border_mlflow_tracking::MlflowTrackingClient; use border_tch_agent::{ cnn::Cnn, @@ -20,14 +21,14 @@ use border_tch_agent::{ TensorBatch, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; mod obs_act_types { use super::*; pub type Obs = BorderAtariObs; - #[derive(Clone, SubBatch)] + #[derive(Clone, BatchBase)] pub struct ObsBatch(TensorBatch); impl From for ObsBatch { @@ -37,7 +38,7 @@ mod obs_act_types { } } - #[derive(SubBatch)] + #[derive(BatchBase)] pub struct ActBatch(TensorBatch); impl From for ActBatch { @@ -123,8 +124,8 @@ mod config { Ok(config.into()) } - pub fn create_trainer_config(matches: &ArgMatches) -> Result<()> { - let model_dir = utils::model_dir(matches); + pub fn create_trainer_config(args: &Args) -> 
Result<()> { + let model_dir = utils::model_dir(args); let config = util_dqn_atari::DqnAtariTrainerConfig::default(); let path = model_dir + "/trainer.yaml"; let mut file = std::fs::File::create(path.clone())?; @@ -133,8 +134,8 @@ mod config { Ok(()) } - pub fn create_replay_buffer_config(matches: &ArgMatches) -> Result<()> { - let model_dir = utils::model_dir(matches); + pub fn create_replay_buffer_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); let config = util_dqn_atari::DqnAtariReplayBufferConfig::default(); let path = model_dir + "/replay_buffer.yaml"; let mut file = std::fs::File::create(path.clone())?; @@ -143,8 +144,8 @@ mod config { Ok(()) } - pub fn create_agent_config(matches: &ArgMatches) -> Result<()> { - let model_dir = utils::model_dir(matches); + pub fn create_agent_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); let config = util_dqn_atari::DqnAtariAgentConfig::default(); let path = model_dir + "/agent.yaml"; let mut file = std::fs::File::create(path.clone())?; @@ -164,11 +165,8 @@ mod config { mod utils { use super::*; - pub fn model_dir(matches: &ArgMatches) -> String { - let name = matches - .value_of("name") - .expect("The name of the environment was not given") - .to_string(); + pub fn model_dir(args: &Args) -> String { + let name = &args.name; format!("./border/examples/atari/model/tch/dqn_{}", name) // let mut params = Params::default(); @@ -188,8 +186,8 @@ mod utils { // model_dir_(name, ¶ms) } - pub fn model_dir_for_play(matches: &ArgMatches) -> String { - matches.value_of("play").unwrap().to_string() + pub fn model_dir_for_eval(args: &Args) -> String { + model_dir(args) } pub fn n_actions(env_config: &EnvConfig) -> Result { @@ -197,13 +195,13 @@ mod utils { } pub fn create_recorder( - matches: &ArgMatches, + args: &Args, model_dir: &str, config: &DqnAtariConfig, ) -> Result> { - match matches.is_present("mlflow") { + match args.mlflow { true => { - let name = matches.value_of("name").unwrap(); + let name = &args.name; let client = MlflowTrackingClient::new("http://localhost:8080") .set_experiment_id("Atari")?; let recorder_run = client.create_recorder("")?; @@ -216,86 +214,45 @@ mod utils { false => Ok(Box::new(TensorboardRecorder::new(model_dir))), } } +} - pub fn create_matches<'a>() -> ArgMatches<'a> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - tch::manual_seed(42); - - let matches = App::new("dqn_atari_tch") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("name") - .long("name") - .takes_value(true) - .required(true) - .index(1) - .help("The name of the atari rom (e.g., pong)"), - ) - .arg( - Arg::with_name("play") - .long("play") - .takes_value(true) - .help("Play with the trained model of the given path"), - ) - .arg( - Arg::with_name("play-gdrive") - .long("play-gdrive") - .takes_value(false) - .help("Play with the trained model downloaded from google drive"), - ) - .arg( - Arg::with_name("create-config") - .long("create-config") - .help("Create config files"), - ) - // not supported yet - // .arg( - // Arg::with_name("per") - // .long("per") - // .takes_value(false) - // .help("Train/play with prioritized experience replay"), - // ) - // .arg( - // Arg::with_name("ddqn") - // .long("ddqn") - // .takes_value(false) - // .help("Train/play with double DQN"), - // ) - // .arg( - // Arg::with_name("debug") - // .long("debug") - // .takes_value(false) - // .help("Run with debug configuration"), - // ) - .arg( - 
Arg::with_name("wait") - .long("wait") - .takes_value(true) - .default_value("25") - .help("Waiting time in milliseconds between frames when playing"), - ) - .arg( - Arg::with_name("show-config") - .long("show-config") - .takes_value(false) - .help("Showing configuration loaded from files"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .help("Logging with mlflow"), - ) - .get_matches(); - - matches - } +/// Train/eval DQN agent in atari environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Name of the game + #[arg(long)] + name: String, + + /// Train DQN agent, not evaluate + #[arg(long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(long, default_value_t = false)] + eval: bool, + + /// Create config files + #[arg(long, default_value_t = false)] + create_config: bool, + + /// Show config + #[arg(long, default_value_t = false)] + show_config: bool, + + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames when evaluation + #[arg(long, default_value_t = 25)] + wait: u64, } -fn train(matches: ArgMatches) -> Result<()> { +fn train(args: &Args) -> Result<()> { // Configurations - let name = matches.value_of("name").unwrap(); - let model_dir = utils::model_dir(&matches); + let name = &args.name; + let model_dir = utils::model_dir(&args); let env_config_train = config::env_config(name); let env_config_eval = config::env_config(name).eval(); let n_actions = utils::n_actions(&env_config_train)?; @@ -311,7 +268,7 @@ fn train(matches: ArgMatches) -> Result<()> { let step_proc_config = SimpleStepProcessorConfig {}; // Show configs or train - if matches.is_present("show-config") { + if args.show_config { config::show_config(&env_config_train, &agent_config, &trainer_config); } else { let config = DqnAtariConfig { @@ -319,25 +276,30 @@ fn train(matches: ArgMatches) -> Result<()> { replay_buffer: replay_buffer_config.clone(), agent: agent_config.clone(), }; + let mut trainer = Trainer::build(trainer_config); + let env = Env::build(&env_config_train, 0)?; + let step_proc = StepProc::build(&step_proc_config); let mut agent = Dqn::build(agent_config); - let mut recorder = utils::create_recorder(&matches, &model_dir, &config)?; + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut recorder = utils::create_recorder(&args, &model_dir, &config)?; let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?; - let mut trainer = Trainer::::build( - trainer_config, - env_config_train, - step_proc_config, - replay_buffer_config, - ); - - trainer.train(&mut agent, &mut recorder, &mut evaluator)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; } Ok(()) } -fn play(matches: ArgMatches) -> Result<()> { - let name = matches.value_of("name").unwrap(); - let model_dir = utils::model_dir_for_play(&matches); +fn eval(args: &Args) -> Result<()> { + let name = &args.name; + let model_dir = utils::model_dir_for_eval(&args); let (env_config, n_actions) = { let env_config = config::env_config(name).render(true); @@ -361,22 +323,22 @@ fn play(matches: ArgMatches) -> Result<()> { Ok(()) } -fn create_config(matches: ArgMatches) -> Result<()> { - config::create_trainer_config(&matches)?; - config::create_replay_buffer_config(&matches)?; - config::create_agent_config(&matches)?; +fn create_config(args: &Args) -> Result<()> { + config::create_trainer_config(&args)?; + 
config::create_replay_buffer_config(&args)?; + config::create_agent_config(&args)?; Ok(()) } fn main() -> Result<()> { - let matches = utils::create_matches(); + let args = Args::parse(); - if matches.is_present("play") || matches.is_present("play-gdrive") { - play(matches)?; - } else if matches.is_present("create-config") { - create_config(matches)?; + if args.eval { + eval(&args)?; + } else if args.create_config { + create_config(&args)?; } else { - train(matches)?; + train(&args)?; } Ok(()) diff --git a/border/examples/atari/util_dqn_atari.rs b/border/examples/atari/util_dqn_atari.rs index 8f5d50d3..7f048aa1 100644 --- a/border/examples/atari/util_dqn_atari.rs +++ b/border/examples/atari/util_dqn_atari.rs @@ -302,7 +302,7 @@ mod async_trainer_config { } mod replay_buffer_config { - use border_core::replay_buffer::{PerConfig, SimpleReplayBufferConfig}; + use border_core::generic_replay_buffer::{PerConfig, SimpleReplayBufferConfig}; use serde::{Deserialize, Serialize}; #[derive(Deserialize, Serialize)] diff --git a/border/examples/gym/dqn_cartpole.rs b/border/examples/gym/dqn_cartpole.rs index 322106ba..f052698c 100644 --- a/border/examples/gym/dqn_cartpole.rs +++ b/border/examples/gym/dqn_cartpole.rs @@ -338,7 +338,7 @@ fn main() -> Result<()> { #[cfg(test)] mod tests { - use super::{eval, train}; + use super::{eval, train, Args}; use anyhow::Result; use tempdir::TempDir; diff --git a/border/examples/gym/iqn_cartpole_tch.rs b/border/examples/gym/iqn_cartpole_tch.rs index c097daa2..f717cb79 100644 --- a/border/examples/gym/iqn_cartpole_tch.rs +++ b/border/examples/gym/iqn_cartpole_tch.rs @@ -19,7 +19,7 @@ use border_tch_agent::{ TensorBatch, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; use ndarray::{ArrayD, IxDyn}; use serde::Serialize; use std::convert::TryFrom; @@ -149,7 +149,7 @@ mod obs_act_types { } } - // Required by Dqn + // Required by Iqn impl From for Tensor { fn from(act: ActBatch) -> Self { act.0.into() @@ -242,11 +242,11 @@ mod utils { use super::*; pub fn create_recorder( - matches: &ArgMatches, + args: &Args, model_dir: &str, config: &config::IqnCartpoleConfig, ) -> Result> { - match matches.is_present("mlflow") { + match args.mlflow { true => { let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; @@ -260,35 +260,27 @@ mod utils { false => Ok(Box::new(TensorboardRecorder::new(model_dir))), } } +} - pub fn create_matches<'a>() -> ArgMatches<'a> { - App::new("iqn_cartpole_tch") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .get_matches() - } +/// Train/eval IQN agent in cartpole environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train IQN agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate IQN agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, } fn train( - matches: &ArgMatches, + args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize, @@ -297,7 +289,7 @@ fn train( config::IqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, model_dir, 
eval_interval); let step_proc_config = SimpleStepProcessorConfig {}; let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let mut recorder = utils::create_recorder(&matches, model_dir, &config)?; + let mut recorder = utils::create_recorder(args, model_dir, &config)?; let mut trainer = Trainer::build(config.trainer_config.clone()); let env = Env::build(&config.env_config, 0)?; @@ -344,14 +336,14 @@ fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); tch::manual_seed(42); - let matches = utils::create_matches(); + let args = Args::parse(); - if matches.is_present("train") { - train(&matches, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; - } else if matches.is_present("eval") { + if args.train { + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + } else if args.eval { eval(&(MODEL_DIR.to_owned() + "/best"), true)?; } else { - train(&matches, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; eval(&(MODEL_DIR.to_owned() + "/best"), true)?; } @@ -360,20 +352,23 @@ fn main() -> Result<()> { #[cfg(test)] mod test { - use crate::utils::create_matches; - - use super::{eval, train}; + use super::{eval, train, Args}; use anyhow::Result; use tempdir::TempDir; #[test] - fn test_dqn_cartpole() -> Result<()> { + fn test_iqn_cartpole() -> Result<()> { let tmp_dir = TempDir::new("iqn_cartpole")?; let model_dir = match tmp_dir.as_ref().to_str() { Some(s) => s, None => panic!("Failed to get string of temporary directory"), }; - train(&create_matches(), 100, model_dir, 100)?; + let args = Args { + train: false, + eval: false, + mlflow: false, + }; + train(&args, 100, model_dir, 100)?; eval(&(model_dir.to_owned() + "/best"), false)?; Ok(()) } diff --git a/border/examples/mujoco/sac_ant_async.rs b/border/examples/mujoco/sac_ant_async.rs index 2c8dabd4..7647df9c 100644 --- a/border/examples/mujoco/sac_ant_async.rs +++ b/border/examples/mujoco/sac_ant_async.rs @@ -53,7 +53,7 @@ const TAU: f64 = 0.02; const TARGET_ENTROPY: f64 = -(DIM_ACT as f64); const LR_ENT_COEF: f64 = 3e-4; const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; -const MODEL_DIR: &str = "./border/examples/model/sac_ant_async"; +const MODEL_DIR: &str = "./border/examples/mujoco/model/tch/sac_ant_async"; mod obs { use super::*; diff --git a/border/examples/mujoco/sac_mujoco.rs b/border/examples/mujoco/sac_mujoco.rs index 06af9444..ce19292d 100644 --- a/border/examples/mujoco/sac_mujoco.rs +++ b/border/examples/mujoco/sac_mujoco.rs @@ -24,7 +24,7 @@ use border_py_gym_env::{ }; use border_tensorboard::TensorboardRecorder; use candle_core::Tensor; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; // use log::info; use ndarray::{ArrayD, IxDyn}; @@ -41,6 +41,7 @@ const N_CRITICS: usize = 2; const TAU: f64 = 0.02; const LR_ENT_COEF: f64 = 3e-4; const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; +const MODEL_DIR_BASE: &str = "./border/examples/mujoco/model/candle"; fn cuda_if_available() -> candle_core::Device { candle_core::Device::cuda_if_available(0).unwrap() @@ -188,12 +189,12 @@ mod utils { use super::*; pub fn create_recorder( - matches: &ArgMatches, + args: &Args, config: &config::SacAntConfig, ) -> Result> { - let env_name = matches.value_of("env").unwrap(); - let (_, _, _, _, model_dir) = env_params(matches); - match matches.is_present("mlflow") { + let env_name = &args.env; + let (_, _, _, _, model_dir) = env_params(&args); + match args.mlflow { true => { let client = 
MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; @@ -208,59 +209,11 @@ mod utils { } } - pub fn create_matches<'a>() -> ArgMatches<'a> { - App::new("sac_ant") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("env") - .long("env") - .value_name("name") - .default_value("ant") - .takes_value(true) - .help("Environment name (ant, cheetah, walker, hopper)"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .value_name("path") - .default_value("") - .takes_value(true) - .help("Evaluation with the trained model of the given path"), - ) - // .arg( - // Arg::with_name("play-gdrive") - // .long("play-gdrive") - // .takes_value(false) - // .help("Play with the trained model downloaded from google drive"), - // ) - .arg( - Arg::with_name("wait") - .long("wait") - .takes_value(true) - .default_value("25") - .help("Waiting time in milliseconds between frames when evaluation"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Training"), - ) - .get_matches() - } - /// Returns (dim_obs, dim_act, target_ent, env_name, model_dir) - pub fn env_params<'a>(matches: &ArgMatches) -> (i64, i64, f64, &'a str, String) { - let env_name = matches.value_of("env").unwrap(); - let model_dir = format!("./border/examples/mujoco/model/{}/candle", env_name); - match matches.value_of("env").unwrap() { + pub fn env_params<'a>(args: &Args) -> (i64, i64, f64, &'a str, String) { + let env_name = &args.env; + let model_dir = format!("{}/sac_{}", MODEL_DIR_BASE, env_name); + match args.env.as_str() { "ant" => (27, 8, -8., "Ant-v4", model_dir), "cheetah" => (17, 6, -6., "HalfCheetah-v4", model_dir), "walker" => (17, 6, -6., "Walker2d-v4", model_dir), @@ -270,8 +223,34 @@ mod utils { } } -fn train(matches: ArgMatches) -> Result<()> { - let (dim_obs, dim_act, target_ent, env_name, model_dir) = utils::env_params(&matches); +/// Train/eval SAC agent in Mujoco environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Environment name (ant, cheetah, walker, hopper) + #[arg(long)] + env: String, + + /// Train DQN agent, not evaluate + #[arg(long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(long, default_value_t = false)] + eval: bool, + // #[arg(long, default_value_t = String::new())] + // eval: String, + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames when evaluation + #[arg(long, default_value_t = 25)] + wait: u64, +} + +fn train(args: &Args) -> Result<()> { + let (dim_obs, dim_act, target_ent, env_name, model_dir) = utils::env_params(&args); let env_config = config::env_config(env_name); let step_proc_config = SimpleStepProcessorConfig {}; let agent_config = config::create_sac_config(dim_obs, dim_act, target_ent); @@ -288,7 +267,7 @@ fn train(matches: ArgMatches) -> Result<()> { let step_proc = StepProc::build(&step_proc_config); let mut agent = Sac::build(agent_config); let mut buffer = ReplayBuffer::build(&replay_buffer_config); - let mut recorder = utils::create_recorder(&matches, &config)?; + let mut recorder = utils::create_recorder(&args, &config)?; let mut evaluator = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?; trainer.train( @@ -303,8 +282,8 @@ fn train(matches: ArgMatches) -> Result<()> { Ok(()) } -fn eval(matches: &ArgMatches, model_dir: &str, render: 
bool, wait: u64) -> Result<()> { - let (dim_obs, dim_act, target_ent, env_name, _) = utils::env_params(&matches); +fn eval(args: &Args, model_dir: &str, render: bool, wait: u64) -> Result<()> { + let (dim_obs, dim_act, target_ent, env_name, _) = utils::env_params(&args); let env_config = { let mut env_config = config::env_config(&env_name); if render { @@ -331,16 +310,14 @@ fn eval(matches: &ArgMatches, model_dir: &str, render: bool, wait: u64) -> Resul Ok(()) } -fn eval1(matches: ArgMatches) -> Result<()> { +fn eval1(args: &Args) -> Result<()> { let model_dir = { - let model_dir = matches - .value_of("eval") - .expect("Failed to parse model directory"); - format!("{}{}", model_dir, "/best").to_owned() + let env_name = &args.env; + format!("{}/sac_{}/best", MODEL_DIR_BASE, env_name) }; let render = true; - let wait = matches.value_of("wait").unwrap().parse().unwrap(); - eval(&matches, &model_dir, render, wait) + let wait = args.wait; + eval(&args, &model_dir, render, wait) } // fn eval2(matches: ArgMatches) -> Result<()> { @@ -361,12 +338,15 @@ fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); fastrand::seed(42); - let matches = utils::create_matches(); + let args = Args::parse(); - if matches.is_present("train") { - train(matches)?; + if args.train { + train(&args)?; + } else if args.eval { + eval1(&args)?; } else { - eval1(matches)?; + train(&args)?; + eval1(&args)?; } // } else if matches.is_present("play-gdrive") { // eval2(matches)?; diff --git a/border/examples/mujoco/sac_mujoco_tch.rs b/border/examples/mujoco/sac_mujoco_tch.rs index 39d0022f..5c2f6921 100644 --- a/border/examples/mujoco/sac_mujoco_tch.rs +++ b/border/examples/mujoco/sac_mujoco_tch.rs @@ -23,7 +23,7 @@ use border_tch_agent::{ TensorBatch, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; +use clap::Parser; // use log::info; use ndarray::{ArrayD, IxDyn}; use std::convert::TryFrom; @@ -42,6 +42,7 @@ const N_CRITICS: usize = 2; const TAU: f64 = 0.02; const LR_ENT_COEF: f64 = 3e-4; const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; +const MODEL_DIR_BASE: &str = "./border/examples/mujoco/model/tch"; fn cuda_if_available() -> tch::Device { tch::Device::cuda_if_available() @@ -189,12 +190,12 @@ mod utils { use super::*; pub fn create_recorder( - matches: &ArgMatches, + args: &Args, config: &config::SacAntConfig, ) -> Result> { - let env_name = matches.value_of("env").unwrap(); - let (_, _, _, _, model_dir) = env_params(matches); - match matches.is_present("mlflow") { + let env_name = &args.env; + let (_, _, _, _, model_dir) = env_params(&args); + match args.mlflow { true => { let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; @@ -209,59 +210,11 @@ mod utils { } } - pub fn create_matches<'a>() -> ArgMatches<'a> { - App::new("sac_ant_tch") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("env") - .long("env") - .value_name("name") - .default_value("ant") - .takes_value(true) - .help("Environment name (ant, cheetah, walker, hopper)"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .value_name("path") - .default_value("") - .takes_value(true) - .help("Evaluation with the trained model of the given path"), - ) - // .arg( - // Arg::with_name("play-gdrive") - // .long("play-gdrive") - // .takes_value(false) - // .help("Play with the trained model downloaded from google drive"), - // ) - .arg( - Arg::with_name("wait") - .long("wait") - .takes_value(true) - 
.default_value("25") - .help("Waiting time in milliseconds between frames when evaluation"), - ) - .arg( - Arg::with_name("mlflow") - .long("mlflow") - .takes_value(false) - .help("User mlflow tracking"), - ) - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Training"), - ) - .get_matches() - } - /// Returns (dim_obs, dim_act, target_ent, env_name, model_dir) - pub fn env_params<'a>(matches: &ArgMatches) -> (i64, i64, f64, &'a str, String) { - let env_name = matches.value_of("env").unwrap(); - let model_dir = format!("./border/examples/mujoco/model/{}/tch", env_name); - match matches.value_of("env").unwrap() { + pub fn env_params<'a>(args: &Args) -> (i64, i64, f64, &'a str, String) { + let env_name = &args.env; + let model_dir = format!("{}/sac_{}", MODEL_DIR_BASE, env_name); + match args.env.as_str() { "ant" => (27, 8, -8., "Ant-v4", model_dir), "cheetah" => (17, 6, -6., "HalfCheetah-v4", model_dir), "walker" => (17, 6, -6., "Walker2d-v4", model_dir), @@ -271,8 +224,34 @@ mod utils { } } -fn train(matches: ArgMatches) -> Result<()> { - let (dim_obs, dim_act, target_ent, env_name, model_dir) = utils::env_params(&matches); +/// Train/eval SAC agent in Mujoco environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Environment name (ant, cheetah, walker, hopper) + #[arg(long)] + env: String, + + /// Train DQN agent, not evaluate + #[arg(long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(long, default_value_t = false)] + eval: bool, + // #[arg(long, default_value_t = String::new())] + // eval: String, + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames when evaluation + #[arg(long, default_value_t = 25)] + wait: u64, +} + +fn train(args: &Args) -> Result<()> { + let (dim_obs, dim_act, target_ent, env_name, model_dir) = utils::env_params(args); let env_config = config::env_config(env_name); let step_proc_config = SimpleStepProcessorConfig {}; let agent_config = config::create_sac_config(dim_obs, dim_act, target_ent); @@ -289,7 +268,7 @@ fn train(matches: ArgMatches) -> Result<()> { let step_proc = StepProc::build(&step_proc_config); let mut agent = Sac::build(agent_config); let mut buffer = ReplayBuffer::build(&replay_buffer_config); - let mut recorder = utils::create_recorder(&matches, &config)?; + let mut recorder = utils::create_recorder(args, &config)?; let mut evaluator = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?; trainer.train( @@ -304,8 +283,8 @@ fn train(matches: ArgMatches) -> Result<()> { Ok(()) } -fn eval(matches: &ArgMatches, model_dir: &str, render: bool, wait: u64) -> Result<()> { - let (dim_obs, dim_act, target_ent, env_name, _) = utils::env_params(&matches); +fn eval(args: &Args, model_dir: &str, render: bool, wait: u64) -> Result<()> { + let (dim_obs, dim_act, target_ent, env_name, _) = utils::env_params(&args); let env_config = { let mut env_config = config::env_config(&env_name); if render { @@ -332,16 +311,14 @@ fn eval(matches: &ArgMatches, model_dir: &str, render: bool, wait: u64) -> Resul Ok(()) } -fn eval1(matches: ArgMatches) -> Result<()> { +fn eval1(args: &Args) -> Result<()> { let model_dir = { - let model_dir = matches - .value_of("eval") - .expect("Failed to parse model directory"); - format!("{}{}", model_dir, "/best").to_owned() + let env_name = &args.env; + format!("{}/sac_{}/best", MODEL_DIR_BASE, env_name) }; let render = true; - let wait = 
matches.value_of("wait").unwrap().parse().unwrap(); - eval(&matches, &model_dir, render, wait) + let wait = args.wait; + eval(&args, &model_dir, render, wait) } // fn eval2(matches: ArgMatches) -> Result<()> { @@ -363,12 +340,15 @@ fn main() -> Result<()> { tch::manual_seed(42); fastrand::seed(42); - let matches = utils::create_matches(); + let args = Args::parse(); - if matches.is_present("train") { - train(matches)?; + if args.train { + train(&args)?; + } else if args.eval { + eval1(&args)?; } else { - eval1(matches)?; + train(&args)?; + eval1(&args)?; } // } else if matches.is_present("play-gdrive") { // eval2(matches)?; From e36ac66015c701f26027d1a7eef8b5f98cc304d2 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 7 Jul 2024 10:27:55 +0900 Subject: [PATCH 05/21] Improve interface of config of agent (#21) Fix random_fetch_reach (#21) Fix examples (#21) Fix examples (#21) --- border-async-trainer/src/async_trainer/base.rs | 2 +- border-atari-env/examples/random_pong.rs | 3 ++- border-atari-env/src/util/test.rs | 9 +++++---- border-candle-agent/src/dqn/base.rs | 4 ++-- border-candle-agent/src/sac/base.rs | 4 ++-- border-core/src/base/agent.rs | 8 ++++---- border-core/src/base/policy.rs | 17 +++++++++++++++-- border-core/src/trainer.rs | 2 +- border-py-gym-env/examples/random_ant.rs | 3 ++- border-py-gym-env/examples/random_cartpole.rs | 4 ++-- .../examples/random_fetch_reach.rs | 3 ++- .../examples/random_lunarlander_cont.rs | 3 ++- border-tch-agent/src/dqn/base.rs | 4 ++-- border-tch-agent/src/iqn/base.rs | 4 ++-- border-tch-agent/src/sac/base.rs | 4 ++-- border/examples/atari/dqn_atari.rs | 2 +- border/examples/atari/dqn_atari_tch.rs | 2 +- border/examples/atari/iqn_atari.rs | 2 +- border/examples/backup/random_atari.rs | 2 +- border/examples/gym-robotics/sac_fetch_reach.rs | 2 +- border/examples/gym/dqn_cartpole.rs | 2 +- border/examples/gym/dqn_cartpole_tch.rs | 2 +- border/examples/gym/iqn_cartpole_tch.rs | 9 ++------- border/examples/gym/sac_lunarlander_cont.rs | 2 +- border/examples/gym/sac_lunarlander_cont_tch.rs | 2 +- border/examples/gym/sac_pendulum.rs | 2 +- border/examples/gym/sac_pendulum_tch.rs | 2 +- border/examples/mujoco/sac_mujoco.rs | 2 +- border/examples/mujoco/sac_mujoco_tch.rs | 2 +- 29 files changed, 61 insertions(+), 48 deletions(-) diff --git a/border-async-trainer/src/async_trainer/base.rs b/border-async-trainer/src/async_trainer/base.rs index d7324696..2a16f5de 100644 --- a/border-async-trainer/src/async_trainer/base.rs +++ b/border-async-trainer/src/async_trainer/base.rs @@ -151,7 +151,7 @@ where } fn save_model(agent: &A, model_dir: String) { - match agent.save(&model_dir) { + match agent.save_params(&model_dir) { Ok(()) => info!("Saved the model in {:?}.", &model_dir), Err(_) => info!("Failed to save model in {:?}.", &model_dir), } diff --git a/border-atari-env/examples/random_pong.rs b/border-atari-env/examples/random_pong.rs index 18bb9b29..07fcb88d 100644 --- a/border-atari-env/examples/random_pong.rs +++ b/border-atari-env/examples/random_pong.rs @@ -4,6 +4,7 @@ use border_atari_env::{ BorderAtariObsRawFilter, }; use border_core::{Configurable, DefaultEvaluator, Env as _, Evaluator, Policy}; +use serde::Deserialize; type Obs = BorderAtariObs; type Act = BorderAtariAct; @@ -12,7 +13,7 @@ type ActFilter = BorderAtariActRawFilter; type EnvConfig = BorderAtariEnvConfig; type Env = BorderAtariEnv; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig { pub n_acts: usize, } diff --git a/border-atari-env/src/util/test.rs 
b/border-atari-env/src/util/test.rs index 15ecf22a..5c7c8064 100644 --- a/border-atari-env/src/util/test.rs +++ b/border-atari-env/src/util/test.rs @@ -9,6 +9,7 @@ use border_core::{ record::Record, Agent as Agent_, Configurable, Policy, ReplayBufferBase, }; +use serde::Deserialize; use std::ptr::copy; pub type Obs = BorderAtariObs; @@ -132,8 +133,8 @@ impl From for ActBatch { } } -#[derive(Clone)] -/// Configuration of [RandomAgent]. +#[derive(Clone, Deserialize)] +/// Configuration of [`RandomAgent``]. pub struct RandomAgentConfig { pub n_acts: usize, } @@ -182,12 +183,12 @@ impl Agent_ for RandomAgent { Record::empty() } - fn save>(&self, _path: T) -> Result<()> { + fn save_params>(&self, _path: T) -> Result<()> { println!("save() was invoked"); Ok(()) } - fn load>(&mut self, _path: T) -> Result<()> { + fn load_params>(&mut self, _path: T) -> Result<()> { println!("load() was invoked"); Ok(()) } diff --git a/border-candle-agent/src/dqn/base.rs b/border-candle-agent/src/dqn/base.rs index cf053335..a7f2d7fb 100644 --- a/border-candle-agent/src/dqn/base.rs +++ b/border-candle-agent/src/dqn/base.rs @@ -330,7 +330,7 @@ where record } - fn save>(&self, path: T) -> Result<()> { + fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; self.qnet.save(&path.as_ref().join("qnet.pt").as_path())?; @@ -339,7 +339,7 @@ where Ok(()) } - fn load>(&mut self, path: T) -> Result<()> { + fn load_params>(&mut self, path: T) -> Result<()> { self.qnet.load(&path.as_ref().join("qnet.pt").as_path())?; self.qnet_tgt .load(&path.as_ref().join("qnet_tgt.pt").as_path())?; diff --git a/border-candle-agent/src/sac/base.rs b/border-candle-agent/src/sac/base.rs index 9e5a9b52..05e4ea9e 100644 --- a/border-candle-agent/src/sac/base.rs +++ b/border-candle-agent/src/sac/base.rs @@ -322,7 +322,7 @@ where self.opt_(buffer).expect("Failed in Sac::opt_()") } - fn save>(&self, path: T) -> Result<()> { + fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; for (i, (qnet, qnet_tgt)) in self.qnets.iter().zip(&self.qnets_tgt).enumerate() { @@ -335,7 +335,7 @@ where Ok(()) } - fn load>(&mut self, path: T) -> Result<()> { + fn load_params>(&mut self, path: T) -> Result<()> { for (i, (qnet, qnet_tgt)) in self.qnets.iter_mut().zip(&mut self.qnets_tgt).enumerate() { qnet.load(&path.as_ref().join(format!("qnet_{}.pt", i)).as_path())?; qnet_tgt.load(&path.as_ref().join(format!("qnet_tgt_{}.pt", i)).as_path())?; diff --git a/border-core/src/base/agent.rs b/border-core/src/base/agent.rs index cc0812c6..c6ab3ac0 100644 --- a/border-core/src/base/agent.rs +++ b/border-core/src/base/agent.rs @@ -26,12 +26,12 @@ pub trait Agent: Policy { /// Performs an optimization step and returns some information. fn opt_with_record(&mut self, buffer: &mut R) -> Record; - /// Save the agent in the given directory. + /// Save the parameters of the agent in the given directory. /// This method commonly creates a number of files consisting the agent /// in the directory. For example, the DQN agent in `border_tch_agent` crate saves /// two Q-networks corresponding to the original and target networks. - fn save>(&self, path: T) -> Result<()>; + fn save_params>(&self, path: T) -> Result<()>; - /// Load the agent from the given directory. - fn load>(&mut self, path: T) -> Result<()>; + /// Load the parameters of the agent from the given directory. 
+ fn load_params>(&mut self, path: T) -> Result<()>; } diff --git a/border-core/src/base/policy.rs b/border-core/src/base/policy.rs index 4d5afce5..090cc0ba 100644 --- a/border-core/src/base/policy.rs +++ b/border-core/src/base/policy.rs @@ -1,6 +1,8 @@ //! Policy. use super::Env; -// use anyhow::Result; +use anyhow::Result; +use serde::de::DeserializeOwned; +use std::path::Path; /// A policy on an environment. /// @@ -14,8 +16,19 @@ pub trait Policy { /// A configurable object, having type parameter. pub trait Configurable { /// Configuration. - type Config: Clone; + type Config: Clone + DeserializeOwned; /// Builds the object. fn build(config: Self::Config) -> Self; + + /// Build the object with the configuration in the yaml file of the given path. + fn build_from_path(path: impl AsRef) -> Result + where + Self: Sized, + { + let file = std::fs::File::open(path)?; + let rdr = std::io::BufReader::new(file); + let config = serde_yaml::from_reader(rdr)?; + Ok(Self::build(config)) + } } diff --git a/border-core/src/trainer.rs b/border-core/src/trainer.rs index c543f64d..8d7ffdda 100644 --- a/border-core/src/trainer.rs +++ b/border-core/src/trainer.rs @@ -157,7 +157,7 @@ impl Trainer { A: Agent, R: ReplayBufferBase, { - match agent.save(&model_dir) { + match agent.save_params(&model_dir) { Ok(()) => info!("Saved the model in {:?}.", &model_dir), Err(_) => info!("Failed to save model in {:?}.", &model_dir), } diff --git a/border-py-gym-env/examples/random_ant.rs b/border-py-gym-env/examples/random_ant.rs index e8a4674a..12b27bbd 100644 --- a/border-py-gym-env/examples/random_ant.rs +++ b/border-py-gym-env/examples/random_ant.rs @@ -4,6 +4,7 @@ use border_py_gym_env::{ ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; use ndarray::{Array, ArrayD, IxDyn}; +use serde::Deserialize; use std::default::Default; mod obs { @@ -58,7 +59,7 @@ type ActFilter = ContinuousActFilter; type Env = GymEnv; type Evaluator = DefaultEvaluator; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig; struct RandomPolicy; diff --git a/border-py-gym-env/examples/random_cartpole.rs b/border-py-gym-env/examples/random_cartpole.rs index 7f22d8c9..11ac54d5 100644 --- a/border-py-gym-env/examples/random_cartpole.rs +++ b/border-py-gym-env/examples/random_cartpole.rs @@ -3,7 +3,7 @@ use border_core::{record::Record, Configurable, DefaultEvaluator, Evaluator as _ use border_py_gym_env::{ ArrayObsFilter, DiscreteActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::convert::TryFrom; type PyObsDtype = f32; @@ -60,7 +60,7 @@ type ActFilter = DiscreteActFilter; type Env = GymEnv; type Evaluator = DefaultEvaluator; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig; struct RandomPolicy; diff --git a/border-py-gym-env/examples/random_fetch_reach.rs b/border-py-gym-env/examples/random_fetch_reach.rs index 7cf27ec7..c7f8b5f4 100644 --- a/border-py-gym-env/examples/random_fetch_reach.rs +++ b/border-py-gym-env/examples/random_fetch_reach.rs @@ -55,13 +55,14 @@ mod act { use act::Act; use obs::Obs; +use serde::Deserialize; type ObsFilter = ArrayDictObsFilter; type ActFilter = ContinuousActFilter; type Env = GymEnv; type Evaluator = DefaultEvaluator; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig; struct RandomPolicy; diff --git a/border-py-gym-env/examples/random_lunarlander_cont.rs b/border-py-gym-env/examples/random_lunarlander_cont.rs index 
6536c921..8af6c5e7 100644 --- a/border-py-gym-env/examples/random_lunarlander_cont.rs +++ b/border-py-gym-env/examples/random_lunarlander_cont.rs @@ -51,13 +51,14 @@ mod act { use act::Act; use obs::Obs; +use serde::Deserialize; type ObsFilter = ArrayObsFilter; type ActFilter = ContinuousActFilter; type Env = GymEnv; type Evaluator = DefaultEvaluator; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig; struct RandomPolicy; diff --git a/border-tch-agent/src/dqn/base.rs b/border-tch-agent/src/dqn/base.rs index 82f233a5..b43663ca 100644 --- a/border-tch-agent/src/dqn/base.rs +++ b/border-tch-agent/src/dqn/base.rs @@ -341,7 +341,7 @@ where record } - fn save>(&self, path: T) -> Result<()> { + fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; self.qnet @@ -351,7 +351,7 @@ where Ok(()) } - fn load>(&mut self, path: T) -> Result<()> { + fn load_params>(&mut self, path: T) -> Result<()> { self.qnet .load(&path.as_ref().join("qnet.pt.tch").as_path())?; self.qnet_tgt diff --git a/border-tch-agent/src/iqn/base.rs b/border-tch-agent/src/iqn/base.rs index 2525b67f..bd549d26 100644 --- a/border-tch-agent/src/iqn/base.rs +++ b/border-tch-agent/src/iqn/base.rs @@ -295,7 +295,7 @@ where self.opt_(buffer) } - fn save>(&self, path: T) -> Result<()> { + fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; self.iqn.save(&path.as_ref().join("iqn.pt.tch").as_path())?; @@ -304,7 +304,7 @@ where Ok(()) } - fn load>(&mut self, path: T) -> Result<()> { + fn load_params>(&mut self, path: T) -> Result<()> { self.iqn.load(&path.as_ref().join("iqn.pt.tch").as_path())?; self.iqn_tgt .load(&path.as_ref().join("iqn_tgt.pt.tch").as_path())?; diff --git a/border-tch-agent/src/sac/base.rs b/border-tch-agent/src/sac/base.rs index f39ab29b..8483cae6 100644 --- a/border-tch-agent/src/sac/base.rs +++ b/border-tch-agent/src/sac/base.rs @@ -301,7 +301,7 @@ where self.opt_(buffer) } - fn save>(&self, path: T) -> Result<()> { + fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; for (i, (qnet, qnet_tgt)) in self.qnets.iter().zip(&self.qnets_tgt).enumerate() { @@ -319,7 +319,7 @@ where Ok(()) } - fn load>(&mut self, path: T) -> Result<()> { + fn load_params>(&mut self, path: T) -> Result<()> { for (i, (qnet, qnet_tgt)) in self.qnets.iter_mut().zip(&mut self.qnets_tgt).enumerate() { qnet.load(&path.as_ref().join(format!("qnet_{}.pt.tch", i)).as_path())?; qnet_tgt.load( diff --git a/border/examples/atari/dqn_atari.rs b/border/examples/atari/dqn_atari.rs index 46aae46a..a3bb05aa 100644 --- a/border/examples/atari/dqn_atari.rs +++ b/border/examples/atari/dqn_atari.rs @@ -312,7 +312,7 @@ fn eval(args: &Args) -> Result<()> { .out_dim(n_actions as _) .device(device); let mut agent = Dqn::build(agent_config); - agent.load(model_dir + "/best")?; + agent.load_params(model_dir + "/best")?; agent.eval(); agent }; diff --git a/border/examples/atari/dqn_atari_tch.rs b/border/examples/atari/dqn_atari_tch.rs index edb5eb4e..d88bc2cc 100644 --- a/border/examples/atari/dqn_atari_tch.rs +++ b/border/examples/atari/dqn_atari_tch.rs @@ -312,7 +312,7 @@ fn eval(args: &Args) -> Result<()> { .out_dim(n_actions as _) .device(device); let mut agent = Dqn::build(agent_config); - agent.load(model_dir + "/best")?; + agent.load_params(model_dir + "/best")?; agent.eval(); agent }; diff --git 
a/border/examples/atari/iqn_atari.rs b/border/examples/atari/iqn_atari.rs index 71244a25..78eada97 100644 --- a/border/examples/atari/iqn_atari.rs +++ b/border/examples/atari/iqn_atari.rs @@ -213,7 +213,7 @@ fn play(matches: ArgMatches) -> Result<()> { .out_dim(n_actions as _) .device(device); let mut agent = Iqn::build(agent_config); - agent.load(model_dir + "/best")?; + agent.load_params(model_dir + "/best")?; agent.eval(); agent }; diff --git a/border/examples/backup/random_atari.rs b/border/examples/backup/random_atari.rs index b4ffa39d..0b36ca3f 100644 --- a/border/examples/backup/random_atari.rs +++ b/border/examples/backup/random_atari.rs @@ -42,7 +42,7 @@ type ObsFilter = FrameStackFilter; type ActFilter = PyGymEnvDiscreteActRawFilter; type Env = PyGymEnv; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig { pub n_acts: usize, } diff --git a/border/examples/gym-robotics/sac_fetch_reach.rs b/border/examples/gym-robotics/sac_fetch_reach.rs index 8899f6e8..d4084f99 100644 --- a/border/examples/gym-robotics/sac_fetch_reach.rs +++ b/border/examples/gym-robotics/sac_fetch_reach.rs @@ -298,7 +298,7 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { }; let mut agent = { let mut agent = Sac::build(config::create_sac_config(DIM_OBS, DIM_ACT, TARGET_ENTROPY)); - agent.load(model_dir)?; + agent.load_params(model_dir)?; agent.eval(); agent }; diff --git a/border/examples/gym/dqn_cartpole.rs b/border/examples/gym/dqn_cartpole.rs index f052698c..ea0bf88c 100644 --- a/border/examples/gym/dqn_cartpole.rs +++ b/border/examples/gym/dqn_cartpole.rs @@ -308,7 +308,7 @@ fn eval(model_dir: &str, render: bool) -> Result<()> { }; let mut agent = { let mut agent = Dqn::build(create_agent_config(DIM_OBS, DIM_ACT)); - agent.load(model_dir)?; + agent.load_params(model_dir)?; agent.eval(); agent }; diff --git a/border/examples/gym/dqn_cartpole_tch.rs b/border/examples/gym/dqn_cartpole_tch.rs index 688740d5..b6604476 100644 --- a/border/examples/gym/dqn_cartpole_tch.rs +++ b/border/examples/gym/dqn_cartpole_tch.rs @@ -308,7 +308,7 @@ fn eval(model_dir: &str, render: bool) -> Result<()> { }; let mut agent = { let mut agent = Dqn::build(create_agent_config(DIM_OBS, DIM_ACT)); - agent.load(model_dir)?; + agent.load_params(model_dir)?; agent.eval(); agent }; diff --git a/border/examples/gym/iqn_cartpole_tch.rs b/border/examples/gym/iqn_cartpole_tch.rs index f717cb79..70b9be41 100644 --- a/border/examples/gym/iqn_cartpole_tch.rs +++ b/border/examples/gym/iqn_cartpole_tch.rs @@ -279,12 +279,7 @@ struct Args { mlflow: bool, } -fn train( - args: &Args, - max_opts: usize, - model_dir: &str, - eval_interval: usize, -) -> Result<()> { +fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { let config = config::IqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, model_dir, eval_interval); let step_proc_config = SimpleStepProcessorConfig {}; @@ -322,7 +317,7 @@ fn eval(model_dir: &str, render: bool) -> Result<()> { }; let mut agent = { let mut agent = Iqn::build(config::agent_config(DIM_OBS, DIM_ACT)); - agent.load(model_dir)?; + agent.load_params(model_dir)?; agent.eval(); agent }; diff --git a/border/examples/gym/sac_lunarlander_cont.rs b/border/examples/gym/sac_lunarlander_cont.rs index fd46b8c4..fa128e9c 100644 --- a/border/examples/gym/sac_lunarlander_cont.rs +++ b/border/examples/gym/sac_lunarlander_cont.rs @@ -286,7 +286,7 @@ fn eval(render: bool) -> Result<()> { }; let mut agent = { let mut agent = 
Sac::build(config::agent_config(DIM_OBS, DIM_ACT)); - agent.load(model_dir)?; + agent.load_params(model_dir)?; agent.eval(); agent }; diff --git a/border/examples/gym/sac_lunarlander_cont_tch.rs b/border/examples/gym/sac_lunarlander_cont_tch.rs index ac810923..afbcef35 100644 --- a/border/examples/gym/sac_lunarlander_cont_tch.rs +++ b/border/examples/gym/sac_lunarlander_cont_tch.rs @@ -287,7 +287,7 @@ fn eval(render: bool) -> Result<()> { }; let mut agent = { let mut agent = Sac::build(config::agent_config(DIM_OBS, DIM_ACT)); - agent.load(model_dir)?; + agent.load_params(model_dir)?; agent.eval(); agent }; diff --git a/border/examples/gym/sac_pendulum.rs b/border/examples/gym/sac_pendulum.rs index 8597fbbc..6921a18f 100644 --- a/border/examples/gym/sac_pendulum.rs +++ b/border/examples/gym/sac_pendulum.rs @@ -296,7 +296,7 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { }; let mut agent = { let mut agent = create_agent(DIM_OBS, DIM_ACT)?; - agent.load(model_dir)?; + agent.load_params(model_dir)?; agent.eval(); agent }; diff --git a/border/examples/gym/sac_pendulum_tch.rs b/border/examples/gym/sac_pendulum_tch.rs index cdddaafb..de543af7 100644 --- a/border/examples/gym/sac_pendulum_tch.rs +++ b/border/examples/gym/sac_pendulum_tch.rs @@ -278,7 +278,7 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { }; let mut agent = { let mut agent = create_agent(DIM_OBS, DIM_ACT); - agent.load(model_dir)?; + agent.load_params(model_dir)?; agent.eval(); agent }; diff --git a/border/examples/mujoco/sac_mujoco.rs b/border/examples/mujoco/sac_mujoco.rs index ce19292d..50dff870 100644 --- a/border/examples/mujoco/sac_mujoco.rs +++ b/border/examples/mujoco/sac_mujoco.rs @@ -296,7 +296,7 @@ fn eval(args: &Args, model_dir: &str, render: bool, wait: u64) -> Result<()> { let mut agent = { let agent_config = config::create_sac_config(dim_obs, dim_act, target_ent); let mut agent = Sac::build(agent_config); - match agent.load(model_dir) { + match agent.load_params(model_dir) { Ok(_) => {} Err(_) => println!("Failed to load model parameters from {:?}", model_dir), } diff --git a/border/examples/mujoco/sac_mujoco_tch.rs b/border/examples/mujoco/sac_mujoco_tch.rs index 5c2f6921..14afb7ac 100644 --- a/border/examples/mujoco/sac_mujoco_tch.rs +++ b/border/examples/mujoco/sac_mujoco_tch.rs @@ -297,7 +297,7 @@ fn eval(args: &Args, model_dir: &str, render: bool, wait: u64) -> Result<()> { let mut agent = { let agent_config = config::create_sac_config(dim_obs, dim_act, target_ent); let mut agent = Sac::build(agent_config); - match agent.load(model_dir) { + match agent.load_params(model_dir) { Ok(_) => {} Err(_) => println!("Failed to load model parameters from {:?}", model_dir), } From 9b4904b11340538a69ddfe8100413bc050e39c83 Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 8 Jul 2024 00:01:07 +0900 Subject: [PATCH 06/21] Fix doc.sh --- docker/aarch64_doc/doc.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/aarch64_doc/doc.sh b/docker/aarch64_doc/doc.sh index 4cf37c27..c823279f 100644 --- a/docker/aarch64_doc/doc.sh +++ b/docker/aarch64_doc/doc.sh @@ -3,4 +3,4 @@ docker run -it --rm \ --shm-size=512m \ --volume="$(pwd)/../..:/home/ubuntu/border" \ border_headless bash -l -c \ - "cd /home/ubuntu/border; CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." 
+ "cd /home/ubuntu/border; source /root/venv/bin/activate; LIBTORCH_USE_PYTORCH=1 LD_LIBRARY_PATH=$HOME/venv/lib/python3.10/site-packages/torch/lib:$LD_LIBRARY_PATH CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." From ad665a26e7ff09828acba5e372dd1ce60e7878ad Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 8 Jul 2024 00:20:41 +0900 Subject: [PATCH 07/21] WIP: Add example of conversion from tch sac agent to the one w/o backend (#19) --- border-edge-policy/Cargo.toml | 5 +- border-edge-policy/src/lib.rs | 2 + border-edge-policy/src/mat.rs | 12 +- border-edge-policy/src/mlp.rs | 44 ++++ border-tch-agent/src/sac/actor/base.rs | 2 +- border-tch-agent/src/sac/base.rs | 4 + border/Cargo.toml | 8 + .../gym/convert_sac_policy_to_edge.rs | 208 ++++++++++++++++++ docker/aarch64_headless/Dockerfile | 2 +- 9 files changed, 282 insertions(+), 5 deletions(-) create mode 100644 border-edge-policy/src/mlp.rs create mode 100644 border/examples/gym/convert_sac_policy_to_edge.rs diff --git a/border-edge-policy/Cargo.toml b/border-edge-policy/Cargo.toml index 4a96a84e..0929f6b0 100644 --- a/border-edge-policy/Cargo.toml +++ b/border-edge-policy/Cargo.toml @@ -10,7 +10,8 @@ license.workspace = true readme = "README.md" [dependencies] -border-core = { version = "0.0.6", path = "../border-core" } +border-core = { version = "0.0.7", path = "../border-core" } +border-tch-agent = { version = "0.0.7", path = "../border-tch-agent", optional = true } serde = { workspace = true, features = ["derive"] } log = { workspace = true } anyhow = { workspace = true } @@ -21,4 +22,4 @@ tempdir = { workspace = true } tch = { workspace = true } [features] -tch = ["dep:tch"] +border-tch-agent = ["dep:border-tch-agent", "dep:tch"] diff --git a/border-edge-policy/src/lib.rs b/border-edge-policy/src/lib.rs index f99690cc..35a6a583 100644 --- a/border-edge-policy/src/lib.rs +++ b/border-edge-policy/src/lib.rs @@ -1,3 +1,5 @@ mod mat; +mod mlp; pub use mat::Mat; +pub use mlp::Mlp; diff --git a/border-edge-policy/src/mat.rs b/border-edge-policy/src/mat.rs index f3b019a6..b1e86e02 100644 --- a/border-edge-policy/src/mat.rs +++ b/border-edge-policy/src/mat.rs @@ -6,7 +6,7 @@ pub struct Mat { pub shape: Vec, } -#[cfg(feature = "tch")] +#[cfg(feature = "border-tch-agent")] impl From for Mat { fn from(x: tch::Tensor) -> Self { let shape: Vec = x.size().iter().map(|e| *e as i32).collect(); @@ -81,3 +81,13 @@ impl Mat { } } } + +impl From> for Mat { + fn from(x: Vec) -> Self { + let shape = vec![x.len() as i32, 1]; + Self { + shape, + data: x, + } + } +} diff --git a/border-edge-policy/src/mlp.rs b/border-edge-policy/src/mlp.rs new file mode 100644 index 00000000..8805a50d --- /dev/null +++ b/border-edge-policy/src/mlp.rs @@ -0,0 +1,44 @@ +use crate::Mat; +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "border-tch-agent")] +use tch::nn::VarStore; + +#[derive(Clone, Debug, Deserialize, Serialize)] +/// Multilayer perceptron with ReLU activation function. +pub struct Mlp { + /// Weights of layers. + ws: Vec, + + /// Biases of layers. 
+ bs: Vec, +} + +impl Mlp { + pub fn forward(&self, x: &Mat) -> Mat { + let n_layers = self.ws.len(); + let mut x = x.clone(); + for i in 0..n_layers { + x = self.ws[i].matmul(&x).add(&self.bs[i]); + if i != n_layers - 1 { + x = x.relu(); + } + } + x + } + + #[cfg(feature = "border-tch-agent")] + pub fn from_varstore(vs: &VarStore, w_names: &[&str], b_names: &[&str]) -> Self { + let vars = vs.variables(); + let ws: Vec = w_names + .iter() + .map(|name| vars[&name.to_string()].copy().into()) + .collect(); + let bs: Vec = b_names + .iter() + .map(|name| vars[&name.to_string()].copy().into()) + .collect(); + + Self { ws, bs } + } +} diff --git a/border-tch-agent/src/sac/actor/base.rs b/border-tch-agent/src/sac/actor/base.rs index 3cbf5129..69f5c610 100644 --- a/border-tch-agent/src/sac/actor/base.rs +++ b/border-tch-agent/src/sac/actor/base.rs @@ -36,7 +36,7 @@ where P: SubModel, P::Config: DeserializeOwned + Serialize + OutDim, { - /// Constructs [Actor]. + /// Constructs [`Actor`]. pub fn build(config: ActorConfig, device: Device) -> Result> { let pi_config = config.pi_config.context("pi_config is not set.")?; let out_dim = pi_config.get_out_dim(); diff --git a/border-tch-agent/src/sac/base.rs b/border-tch-agent/src/sac/base.rs index 8483cae6..03b4bffc 100644 --- a/border-tch-agent/src/sac/base.rs +++ b/border-tch-agent/src/sac/base.rs @@ -191,6 +191,10 @@ where ), ]) } + + pub fn get_policy_net(&self) -> &Actor

{ + &self.pi + } } impl Policy for Sac diff --git a/border/Cargo.toml b/border/Cargo.toml index 121e84b0..020b080e 100644 --- a/border/Cargo.toml +++ b/border/Cargo.toml @@ -109,6 +109,12 @@ path = "examples/mujoco/sac_mujoco_tch.rs" required-features = ["tch"] test = false +[[example]] +name = "convert_sac_policy_to_edge" +path = "examples/gym/convert_sac_policy_to_edge.rs" +required-features = ["border-tch-agent", "tch"] +test = false + # [[example]] # name = "sac_ant_async" # path = "examples/mujoco/sac_ant_async.rs" @@ -133,6 +139,7 @@ border-derive = { version = "0.0.7", path = "../border-derive" } border-core = { version = "0.0.7", path = "../border-core" } border-tensorboard = { version = "0.0.7", path = "../border-tensorboard" } border-tch-agent = { version = "0.0.7", path = "../border-tch-agent" } +border-edge-policy = { version = "0.0.7", path = "../border-edge-policy" } border-py-gym-env = { version = "0.0.7", path = "../border-py-gym-env" } border-atari-env = { version = "0.0.7", path = "../border-atari-env" } border-candle-agent = { version = "0.0.7", path = "../border-candle-agent" } @@ -157,3 +164,4 @@ features = ["doc-only"] doc-only = ["tch/doc-only"] cuda = ["candle-core/cuda"] cudnn = ["candle-core/cudnn"] +border-tch-agent = [] diff --git a/border/examples/gym/convert_sac_policy_to_edge.rs b/border/examples/gym/convert_sac_policy_to_edge.rs new file mode 100644 index 00000000..07ec4421 --- /dev/null +++ b/border/examples/gym/convert_sac_policy_to_edge.rs @@ -0,0 +1,208 @@ +use anyhow::Result; +use border_core::{Agent, Configurable}; +use border_edge_policy::Mlp; +use border_tch_agent::{ + mlp, + model::ModelBase, + sac::{ActorConfig, CriticConfig, SacConfig}, +}; + +const DIM_OBS: i64 = 3; +const DIM_ACT: i64 = 1; + +// Dummy types +mod dummy { + use super::mlp::{Mlp, Mlp2}; + use border_tch_agent::sac::Sac as Sac_; + + #[derive(Clone, Debug)] + pub struct DummyObs; + + impl border_core::Obs for DummyObs { + fn dummy(_n: usize) -> Self { + unimplemented!(); + } + + fn len(&self) -> usize { + unimplemented!(); + } + } + + impl Into for DummyObs { + fn into(self) -> tch::Tensor { + unimplemented!(); + } + } + + #[derive(Clone, Debug)] + pub struct DummyAct; + + impl border_core::Act for DummyAct { + fn len(&self) -> usize { + unimplemented!(); + } + } + + impl Into for DummyAct { + fn into(self) -> tch::Tensor { + unimplemented!(); + } + } + + impl From for DummyAct { + fn from(_value: tch::Tensor) -> Self { + unimplemented!(); + } + } + + #[derive(Clone)] + pub struct DummyInnerBatch; + + impl Into for DummyInnerBatch { + fn into(self) -> tch::Tensor { + unimplemented!(); + } + } + + pub struct DummyBatch; + + impl border_core::TransitionBatch for DummyBatch { + type ObsBatch = DummyInnerBatch; + type ActBatch = DummyInnerBatch; + + fn len(&self) -> usize { + unimplemented!(); + } + + fn obs(&self) -> &Self::ObsBatch { + unimplemented!(); + } + + fn unpack( + self, + ) -> ( + Self::ObsBatch, + Self::ActBatch, + Self::ObsBatch, + Vec, + Vec, + Vec, + Option>, + Option>, + ) { + unimplemented!(); + } + } + + pub struct DummyReplayBuffer; + + impl border_core::ReplayBufferBase for DummyReplayBuffer { + type Batch = DummyBatch; + type Config = usize; + + fn batch(&mut self, _size: usize) -> anyhow::Result { + unimplemented!(); + } + + fn build(_config: &Self::Config) -> Self { + unimplemented!(); + } + + fn update_priority(&mut self, _ixs: &Option>, _td_err: &Option>) { + unimplemented!(); + } + } + + #[derive(Clone, Debug)] + pub struct DummyInfo; + + impl border_core::Info 
for DummyInfo {} + + pub struct DummyEnv; + + impl border_core::Env for DummyEnv { + type Config = usize; + type Act = DummyAct; + type Obs = DummyObs; + type Info = DummyInfo; + + fn build(_config: &Self::Config, _seed: i64) -> anyhow::Result + where + Self: Sized, + { + unimplemented!(); + } + + fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { + unimplemented!(); + } + + fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { + unimplemented!(); + } + + fn step(&mut self, _a: &Self::Act) -> (border_core::Step, border_core::record::Record) + where + Self: Sized, + { + unimplemented!(); + } + + fn step_with_reset( + &mut self, + _a: &Self::Act, + ) -> (border_core::Step, border_core::record::Record) + where + Self: Sized, + { + unimplemented!(); + } + } + + pub type Env = DummyEnv; + pub type Sac = Sac_; +} + +use dummy::Sac; + +fn create_sac_config() -> SacConfig { + // Omit learning related parameters + let actor_config = ActorConfig::default() + .out_dim(DIM_ACT) + .pi_config(mlp::MlpConfig::new(DIM_OBS, vec![64, 64], DIM_ACT, false)); + let critic_config = CriticConfig::default().q_config(mlp::MlpConfig::new( + DIM_OBS + DIM_ACT, + vec![64, 64], + 1, + false, + )); + SacConfig::default() + .actor_config(actor_config) + .critic_config(critic_config) + .device(tch::Device::Cpu) +} + +fn main() -> Result<()> { + let src_path = "./border/examples/gym/model/tch/sac_pendulum/best"; + let dest_path = ""; + + // Load Sac model + let sac = { + let config = create_sac_config(); + let mut sac = Sac::build(config); + sac.load_params(src_path)?; + sac + }; + + // Check variables in the VarStore + let mlp = { + let vs = sac.get_policy_net().get_var_store(); + let w_names = ["mlp.al0.weight", "mlp.al1.weight", "ml.weight"]; + let b_names = ["mlp.al0.bias", "mlp.al1.bias", "ml.bias"]; + Mlp::from_varstore(vs, &w_names, &b_names) + }; + + // println!("{:?}", mlp); + + Ok(()) +} diff --git a/docker/aarch64_headless/Dockerfile b/docker/aarch64_headless/Dockerfile index ace7c4a9..c3fc15b0 100644 --- a/docker/aarch64_headless/Dockerfile +++ b/docker/aarch64_headless/Dockerfile @@ -57,7 +57,7 @@ RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y RUN cd /root && python3 -m venv venv RUN source /root/venv/bin/activate && pip3 install --upgrade pip RUN source /root/venv/bin/activate && pip3 install pyyaml typing-extensions -RUN source /root/venv/bin/activate && pip3 install torch==1.12.0 +RUN source /root/venv/bin/activate && pip3 install torch==2.3.0 RUN source /root/venv/bin/activate && pip3 install ipython jupyterlab RUN source /root/venv/bin/activate && pip3 install numpy==1.21.3 RUN source /root/venv/bin/activate && pip3 install mujoco==2.3.7 From 1000bc2292228d7aa90de06035e745c95a4c8c42 Mon Sep 17 00:00:00 2001 From: taku-y Date: Mon, 8 Jul 2024 09:45:30 +0900 Subject: [PATCH 08/21] WIP: Serialize MLP and save to a file (#21) --- Cargo.toml | 3 ++- border-edge-policy/Cargo.toml | 1 + border/Cargo.toml | 1 + border/examples/gym/convert_sac_policy_to_edge.rs | 13 ++++++++++--- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9ea44c8d..105b5c6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,4 +55,5 @@ itertools = "0.12.1" ordered-float = "4.2.0" reqwest = { version = "0.11.26", features = ["json", "blocking"] } xxhash-rust = { version = "0.8.10", features = ["xxh3"] } -candle-optimisers = "0.4.0" \ No newline at end of file +candle-optimisers = "0.4.0" +bincode = "1.3.3" diff --git a/border-edge-policy/Cargo.toml 
b/border-edge-policy/Cargo.toml index 0929f6b0..e5b58fb4 100644 --- a/border-edge-policy/Cargo.toml +++ b/border-edge-policy/Cargo.toml @@ -21,5 +21,6 @@ tch = { workspace = true, optional = true } tempdir = { workspace = true } tch = { workspace = true } + [features] border-tch-agent = ["dep:border-tch-agent", "dep:tch"] diff --git a/border/Cargo.toml b/border/Cargo.toml index 020b080e..9437b8c6 100644 --- a/border/Cargo.toml +++ b/border/Cargo.toml @@ -156,6 +156,7 @@ chrono = { workspace = true } tensorboard-rs = { workspace = true } thiserror = { workspace = true } serde_yaml = { workspace = true } +bincode = { workspace = true } [package.metadata.docs.rs] features = ["doc-only"] diff --git a/border/examples/gym/convert_sac_policy_to_edge.rs b/border/examples/gym/convert_sac_policy_to_edge.rs index 07ec4421..90af3d72 100644 --- a/border/examples/gym/convert_sac_policy_to_edge.rs +++ b/border/examples/gym/convert_sac_policy_to_edge.rs @@ -6,6 +6,7 @@ use border_tch_agent::{ model::ModelBase, sac::{ActorConfig, CriticConfig, SacConfig}, }; +use std::{fs, io::Write}; const DIM_OBS: i64 = 3; const DIM_ACT: i64 = 1; @@ -184,7 +185,7 @@ fn create_sac_config() -> SacConfig { fn main() -> Result<()> { let src_path = "./border/examples/gym/model/tch/sac_pendulum/best"; - let dest_path = ""; + let dest_path = "./border/examples/gym/model/edge/sac_pendulum/best/mlp.bincode"; // Load Sac model let sac = { @@ -194,7 +195,7 @@ fn main() -> Result<()> { sac }; - // Check variables in the VarStore + // Create Mlp let mlp = { let vs = sac.get_policy_net().get_var_store(); let w_names = ["mlp.al0.weight", "mlp.al1.weight", "ml.weight"]; @@ -202,7 +203,13 @@ fn main() -> Result<()> { Mlp::from_varstore(vs, &w_names, &b_names) }; - // println!("{:?}", mlp); + // Serialize to file + let encoded = bincode::serialize(&mlp)?; + let mut file = fs::OpenOptions::new() + .create(true) + .write(true) + .open(&dest_path)?; + file.write_all(&encoded)?; Ok(()) } From 74592ac2e6cd54e8b134946c5bb80db39b4d8366 Mon Sep 17 00:00:00 2001 From: taku-y Date: Tue, 9 Jul 2024 07:19:57 +0900 Subject: [PATCH 09/21] Add policy running on edge (#21) --- border-edge-policy/src/mat.rs | 21 ++- border-edge-policy/tests/test.rs | 4 +- border/Cargo.toml | 5 + border/examples/gym/pendulum_edge.rs | 190 +++++++++++++++++++++++++++ 4 files changed, 214 insertions(+), 6 deletions(-) create mode 100644 border/examples/gym/pendulum_edge.rs diff --git a/border-edge-policy/src/mat.rs b/border-edge-policy/src/mat.rs index b1e86e02..05184476 100644 --- a/border-edge-policy/src/mat.rs +++ b/border-edge-policy/src/mat.rs @@ -28,6 +28,7 @@ impl Mat { self.shape[1] as usize, x.shape[1] as usize, ); + // println!("{:?}", (m, l, x.shape[0], n)); let mut data = vec![0.0f32; (m * n) as usize]; for i in 0..m as usize { for j in 0..n as usize { @@ -80,14 +81,26 @@ impl Mat { shape: self.shape.clone(), } } + + pub fn empty() -> Self { + Self { + data: vec![], + shape: vec![0, 0], + } + } + + pub fn shape(&self) -> &Vec { + &self.shape + } + + pub fn new(data: Vec, shape: Vec) -> Self { + Self { data, shape } + } } impl From> for Mat { fn from(x: Vec) -> Self { let shape = vec![x.len() as i32, 1]; - Self { - shape, - data: x, - } + Self { shape, data: x } } } diff --git a/border-edge-policy/tests/test.rs b/border-edge-policy/tests/test.rs index f67b4e1d..5e07eda2 100644 --- a/border-edge-policy/tests/test.rs +++ b/border-edge-policy/tests/test.rs @@ -1,12 +1,12 @@ -use tch::Tensor; use border_edge_policy::Mat; +use tch::Tensor; #[test] fn test_matmul() { 
let x1 = Tensor::from_slice2(&[&[1.0f32, 2., 3.], &[4., 5., 6.]]); let y1 = Tensor::from_slice(&[7.0f32, 8., 9.]); let z1 = x1.matmul(&y1); - + let x2: Mat = x1.into(); let y2: Mat = y1.into(); let z2 = x2.matmul(&y2); diff --git a/border/Cargo.toml b/border/Cargo.toml index 9437b8c6..ca46e2b2 100644 --- a/border/Cargo.toml +++ b/border/Cargo.toml @@ -115,6 +115,11 @@ path = "examples/gym/convert_sac_policy_to_edge.rs" required-features = ["border-tch-agent", "tch"] test = false +[[example]] +name = "pendulum_edge" +path = "examples/gym/pendulum_edge.rs" +test = false + # [[example]] # name = "sac_ant_async" # path = "examples/mujoco/sac_ant_async.rs" diff --git a/border/examples/gym/pendulum_edge.rs b/border/examples/gym/pendulum_edge.rs new file mode 100644 index 00000000..7ac44794 --- /dev/null +++ b/border/examples/gym/pendulum_edge.rs @@ -0,0 +1,190 @@ +use anyhow::Result; +use border_core::{DefaultEvaluator, Evaluator as _}; +use border_edge_policy::{Mat, Mlp}; +use border_py_gym_env::{ + ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, +}; +use clap::Parser; +use ndarray::ArrayD; +use std::fs; + +type PyObsDtype = f32; + +mod obs_act_types { + use super::*; + + #[derive(Clone, Debug)] + /// Observation type. + pub struct Obs(Mat); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(Mat::empty()) + } + + fn len(&self) -> usize { + self.0.shape()[0] as _ + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + let obs = obs.t().to_owned(); + let shape = obs.shape().iter().map(|e| *e as i32).collect(); + let data = obs.into_raw_vec(); + Self(Mat::new(data, shape)) + } + } + + impl From for Mat { + fn from(obs: Obs) -> Mat { + obs.0 + } + } + + #[derive(Clone, Debug)] + pub struct Act(Mat); + + impl border_core::Act for Act {} + + impl From for ArrayD { + fn from(value: Act) -> Self { + // let shape: Vec<_> = value.0.shape.iter().map(|e| *e as usize).collect(); + let shape = vec![(value.0.shape[0] * value.0.shape[1]) as usize]; + // let data = value.0.data; + let data: Vec = value.0.data.iter().map(|e| 2f32 * *e).collect(); + let t = ArrayD::from_shape_vec(shape, data).unwrap(); + t + } + } + + impl Into for Mat { + fn into(self) -> Act { + Act(self) + } + } +} + +mod policy { + use std::{io::Read, path::Path}; + + use super::*; + use border_core::Policy; + + pub struct MlpPolicy { + mlp: Mlp, + } + + impl Policy for MlpPolicy { + fn sample(&mut self, obs: &Obs) -> Act { + self.mlp.forward(&obs.clone().into()).into() + } + } + + impl MlpPolicy { + pub fn from_serialized_path(path: impl AsRef) -> Result { + let mut file = fs::OpenOptions::new().read(true).open(&path)?; + let mut buf = Vec::::new(); + let _ = file.read_to_end(&mut buf).unwrap(); + let mlp: Mlp = bincode::deserialize(&buf[..])?; + Ok(Self { mlp }) + } + } +} + +use obs_act_types::*; +use policy::*; + +type ObsFilter = ArrayObsFilter; +type ActFilter = ContinuousActFilter; +type Env = GymEnv; +type Evaluator = DefaultEvaluator; + +fn env_config() -> GymEnvConfig { + GymEnvConfig::::default() + .name("Pendulum-v1".to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) +} + +fn eval(n_episodes: usize, render: bool) -> Result<()> { + let env_config = { + let mut env_config = env_config(); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(10); + }; + env_config + }; + let mut policy = MlpPolicy::from_serialized_path( + 
"./border/examples/gym/model/edge/sac_pendulum/best/mlp.bincode", + )?; + + let _ = Evaluator::new(&env_config, 0, n_episodes)?.evaluate(&mut policy); + + Ok(()) +} + +/// Train/eval SAC agent in pendulum environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + let _ = eval(5, true)?; + + // let args = Args::parse(); + + // if args.train { + // train( + // MAX_OPTS, + // "./border/examples/gym/model/tch/sac_pendulum", + // EVAL_INTERVAL, + // args.mlflow, + // )?; + // } else if args.eval { + // eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; + // } else { + // train( + // MAX_OPTS, + // "./border/examples/gym/model/tch/sac_pendulum", + // EVAL_INTERVAL, + // args.mlflow, + // )?; + // eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; + // } + + Ok(()) +} + +// #[cfg(test)] +// mod test { +// use super::*; +// use tempdir::TempDir; + +// #[test] +// fn test_sac_pendulum() -> Result<()> { +// tch::manual_seed(42); + +// let model_dir = TempDir::new("sac_pendulum_tch")?; +// let model_dir = model_dir.path().to_str().unwrap(); +// train(100, model_dir, 100, false)?; +// eval(1, false, (model_dir.to_string() + "/best").as_str())?; + +// Ok(()) +// } +// } From 5d07749cb1dec130b2f6e558790d0b1afc687d48 Mon Sep 17 00:00:00 2001 From: taku-y Date: Thu, 11 Jul 2024 09:11:21 +0900 Subject: [PATCH 10/21] Tweaks docs --- border-core/src/lib.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/border-core/src/lib.rs b/border-core/src/lib.rs index c14d4a8a..de679b5d 100644 --- a/border-core/src/lib.rs +++ b/border-core/src/lib.rs @@ -11,11 +11,11 @@ //! [`Env`] trait is an abstraction of environments. It has four associated types: //! `Config`, `Obs`, `Act` and `Info`. `Obs` and `Act` are concrete types of //! observation and action of the environment. -//! These must implement [`Obs`] and [`Act`] traits, respectively. +//! These types must implement [`Obs`] and [`Act`] traits, respectively. //! The environment that implements [`Env`] generates [`Step`] object //! at every environment interaction step with [`Env::step()`] method. //! -//! `Info` stores some information at every step of interactions of an agent and +//! [`Info`] stores some information at every step of interactions of an agent and //! the environment. It could be empty (zero-sized struct). `Config` represents //! configurations of the environment and is used to build. //! @@ -32,18 +32,18 @@ //! the agent's policy might be probabilistic for exploration, while in evaluation mode, //! the policy might be deterministic. //! -//! [`Agent::opt()`] method does a single optimization step. The definition of an -//! optimization step depends on each agent. It might be multiple stochastic gradient +//! The [`Agent::opt()`] method performs a single optimization step. The definition of an +//! optimization step varies for each agent. It might be multiple stochastic gradient //! steps in an optimization step. Samples for training are taken from //! [`R: ReplayBufferBase`][`ReplayBufferBase`]. //! -//! 
This trait also has methods for saving/loading the trained policy -//! in the given directory. +//! This trait also has methods for saving/loading parameters of the trained policy +//! in a directory. //! //! # Batch //! //! [`TransitionBatch`] is a trait of a batch of transitions `(o_t, r_t, a_t, o_t+1)`. -//! This is used to train [`Agent`]s with an RL algorithm. +//! This trait is used to train [`Agent`]s using an RL algorithm. //! //! # Replay buffer //! From f2a068e2512df606630c8ca679abbfaf282ad656 Mon Sep 17 00:00:00 2001 From: taku-y Date: Fri, 12 Jul 2024 08:59:32 +0900 Subject: [PATCH 11/21] Improve docstring Improve docstring in border-derive Tweak Improve docstring in border-async-trainer Improve docstring in border-async-trainer Improve docstring in border-atari-env Tweaks for docstring Improve docstring in border-candle-agent Improve docstring in border-tch-agent Improve docstring in border-tch-agent --- Cargo.toml | 2 +- border-async-trainer/src/actor/base.rs | 14 +- border-async-trainer/src/actor/stat.rs | 4 +- border-async-trainer/src/actor_manager.rs | 2 +- .../src/actor_manager/base.rs | 25 +- .../src/async_trainer/base.rs | 23 +- .../src/async_trainer/config.rs | 18 +- border-async-trainer/src/lib.rs | 388 +++++++++++++++--- .../src/replay_buffer_proxy.rs | 2 +- border-atari-env/src/act.rs | 8 +- border-atari-env/src/atari_env.rs | 1 + border-atari-env/src/env/config.rs | 2 +- border-atari-env/src/lib.rs | 89 ++-- border-atari-env/src/obs.rs | 4 +- border-atari-env/src/util.rs | 1 + border-candle-agent/src/dqn/base.rs | 8 +- border-candle-agent/src/dqn/model.rs | 6 + border-candle-agent/src/lib.rs | 2 + border-candle-agent/src/model.rs | 6 +- border-candle-agent/src/opt.rs | 2 + border-candle-agent/src/sac.rs | 166 +++++++- border-candle-agent/src/sac/config.rs | 2 +- border-candle-agent/src/tensor_batch.rs | 6 +- border-candle-agent/src/util.rs | 26 +- border-core/src/lib.rs | 252 +++++++++++- border-derive/Cargo.toml | 8 +- border-derive/examples/border_atari_act.rs | 8 + border-derive/examples/border_gym_cont_act.rs | 8 + border-derive/examples/border_gym_disc_act.rs | 8 + border-derive/examples/border_tensor_batch.rs | 8 + border-derive/examples/test1.rs | 40 -- border-derive/src/act.rs | 6 +- border-derive/src/lib.rs | 213 +++++++++- border-derive/src/obs.rs | 3 +- border-tch-agent/src/dqn/base.rs | 4 + border-tch-agent/src/dqn/model/base.rs | 10 +- border-tch-agent/src/dqn/model/config.rs | 2 +- border-tch-agent/src/iqn/model/config.rs | 2 +- border-tch-agent/src/mlp/base.rs | 2 +- border-tch-agent/src/model/base.rs | 42 +- border-tch-agent/src/opt.rs | 2 + border-tch-agent/src/sac.rs | 164 +++++++- border-tch-agent/src/sac/actor/base.rs | 3 +- border-tch-agent/src/sac/actor/config.rs | 2 +- border-tch-agent/src/sac/config.rs | 3 +- border-tch-agent/src/sac/critic/config.rs | 2 +- border-tch-agent/src/tensor_batch.rs | 6 +- border-tch-agent/src/util.rs | 9 +- 48 files changed, 1310 insertions(+), 304 deletions(-) create mode 100644 border-derive/examples/border_atari_act.rs create mode 100644 border-derive/examples/border_gym_cont_act.rs create mode 100644 border-derive/examples/border_gym_disc_act.rs create mode 100644 border-derive/examples/border_tensor_batch.rs delete mode 100644 border-derive/examples/test1.rs diff --git a/Cargo.toml b/Cargo.toml index 105b5c6e..96ab8951 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ segment-tree = "2.0.0" image = "0.23.14" candle-core = { version = "=0.4.1", feature = ["cuda", "cudnn"] } candle-nn = "0.4.1" -rand 
= "0.8.5" +rand = { version = "0.8.5", features = ["small_rng"] } itertools = "0.12.1" ordered-float = "4.2.0" reqwest = { version = "0.11.26", features = ["json", "blocking"] } diff --git a/border-async-trainer/src/actor/base.rs b/border-async-trainer/src/actor/base.rs index 297a6f4f..95d0fdb3 100644 --- a/border-async-trainer/src/actor/base.rs +++ b/border-async-trainer/src/actor/base.rs @@ -10,7 +10,7 @@ use std::{ }; #[cfg_attr(doc, aquamarine::aquamarine)] -/// Runs interaction between an [`Agent`] and an [`Env`], then generates transitions. +/// Generate transitions by running [`Agent`] in [`Env`]. /// /// ```mermaid /// flowchart TB @@ -23,15 +23,18 @@ use std::{ /// C-->|ReplayBufferBase::PushedItem|F[ReplayBufferProxy] /// ``` /// -/// This diagram shows interaction of [`Agent`], [`Env`] and [`StepProcessor`], -/// as shown in [`border_core::Trainer`]. However, this diagram also shows that +/// In [`Actor`], an [`Agent`] runs on an [`Env`] and generates [`Step`] objects. +/// These objects are processed with [`StepProcessor`] and sent to [`ReplayBufferProxy`]. /// The [`Agent`] in the [`Actor`] periodically synchronizes with the [`Agent`] in -/// [`AsyncTrainer`] via [`SyncModel::ModelInfo`], and the transitions generated by -/// [`StepProcessor`] are sent to the [`ReplayBufferProxy`]. +/// [`AsyncTrainer`] via [`SyncModel::ModelInfo`]. /// /// See also the diagram in [`AsyncTrainer`]. /// /// [`AsyncTrainer`]: crate::AsyncTrainer +/// [`Agent`]: border_core::Agent +/// [`Env`]: border_core::Env +/// [`StepProcessor`]: border_core::StepProcessor +/// [`Step`]: border_core::Step pub struct Actor where A: Agent + Configurable + SyncModel, @@ -70,6 +73,7 @@ where env_seed: i64, stats: Arc>>, ) -> Self { + log::info!("Create actor {}", id); Self { id, stop, diff --git a/border-async-trainer/src/actor/stat.rs b/border-async-trainer/src/actor/stat.rs index bc989ffc..3fb26199 100644 --- a/border-async-trainer/src/actor/stat.rs +++ b/border-async-trainer/src/actor/stat.rs @@ -1,12 +1,12 @@ use std::time::Duration; -/// Stats of sampling process in each [`Actor`](crate::Actor). +/// Stats of sampling process in an [`Actor`](crate::Actor). #[derive(Clone, Debug)] pub struct ActorStat { /// The number of steps for interaction between agent and env. pub env_steps: usize, - /// Duration of sampling loop in [`Actor`](crate::Actor). + /// Duration of sampling loop in the [`Actor`](crate::Actor). pub duration: Duration, } diff --git a/border-async-trainer/src/actor_manager.rs b/border-async-trainer/src/actor_manager.rs index 60304b11..546371ff 100644 --- a/border-async-trainer/src/actor_manager.rs +++ b/border-async-trainer/src/actor_manager.rs @@ -1,4 +1,4 @@ -//! A manager of [Actor]()s. +//! A manager of [`Actor`](crate::Actor)s. mod base; mod config; pub use base::ActorManager; diff --git a/border-async-trainer/src/actor_manager/base.rs b/border-async-trainer/src/actor_manager/base.rs index cbfee205..2170921b 100644 --- a/border-async-trainer/src/actor_manager/base.rs +++ b/border-async-trainer/src/actor_manager/base.rs @@ -15,9 +15,11 @@ use std::{ /// Manages [`Actor`]s. /// /// This struct handles the following requests: -/// * From the [LearnerManager]() for updating the latest model info, stored in this struct. +/// * From the [`AsyncTrainer`] for updating the latest model info, stored in this struct. /// * From the [`Actor`]s for getting the latest model info. /// * From the [`Actor`]s for pushing sample batch to the `LearnerManager`. 
+/// +/// [`AsyncTrainer`]: crate::AsyncTrainer pub struct ActorManager where A: Agent + Configurable + SyncModel, @@ -25,10 +27,10 @@ where P: StepProcessor, R: ExperienceBufferBase + ReplayBufferBase, { - /// Configurations of [Agent]s. + /// Configurations of [`Agent`]s. agent_configs: Vec, - /// Configuration of [Env]. + /// Configuration of [`Env`]. env_config: E::Config, /// Configuration of a `StepProcessor`. @@ -77,7 +79,7 @@ where R::Item: Send + 'static, A::ModelInfo: Send + 'static, { - /// Builds a [ActorManager]. + /// Builds a [`ActorManager`]. pub fn build( config: &ActorManagerConfig, agent_configs: &Vec, @@ -103,10 +105,10 @@ where } } - /// Runs threads for [Actor]s and a thread for sending samples into the replay buffer. + /// Runs threads for [`Actor`]s and a thread for sending samples into the replay buffer. /// - /// A thread will wait for the initial [SyncModel::ModelInfo] from [AsyncTrainer](crate::AsyncTrainer), - /// which blocks execution of [Actor] threads. + /// Each thread is blocked until receiving the initial [`SyncModel::ModelInfo`] + /// from [`AsyncTrainer`](crate::AsyncTrainer). pub fn run(&mut self, guard_init_env: Arc>) { // Guard for sync of the initial model let guard_init_model = Arc::new(Mutex::new(true)); @@ -220,10 +222,11 @@ where // TODO: error handling, timeout // TODO: caching // TODO: stats - let msg = receiver.recv().unwrap(); - _n_samples += 1; - sender.try_send(msg).unwrap(); - // println!("{:?}", (_msg.id, n_samples)); + let msg = receiver.recv(); + if msg.is_ok() { + _n_samples += 1; + sender.try_send(msg.unwrap()).unwrap(); + } // Stop the loop if *stop.lock().unwrap() { diff --git a/border-async-trainer/src/async_trainer/base.rs b/border-async-trainer/src/async_trainer/base.rs index 2a16f5de..ffcc001d 100644 --- a/border-async-trainer/src/async_trainer/base.rs +++ b/border-async-trainer/src/async_trainer/base.rs @@ -33,25 +33,26 @@ use std::{ /// end /// ``` /// -/// * In [`ActorManager`] (right), [`Actor`]s sample transitions, which have type -/// [`ReplayBufferBase::Item`], in parallel and push the transitions into -/// [`ReplayBufferProxy`]. It should be noted that [`ReplayBufferProxy`] has a -/// type parameter of [`ReplayBufferBase`] and the proxy accepts -/// [`ReplayBufferBase::Item`]. -/// * The proxy sends the transitions into the replay buffer, implementing -/// [`ReplayBufferBase`], in the [`AsyncTrainer`]. -/// * The [`Agent`] in [`AsyncTrainer`] trains its model parameters by using batches +/// * The [`Agent`] in [`AsyncTrainer`] (left) is trained with batches /// of type [`ReplayBufferBase::Batch`], which are taken from the replay buffer. /// * The model parameters of the [`Agent`] in [`AsyncTrainer`] are wrapped in /// [`SyncModel::ModelInfo`] and periodically sent to the [`Agent`]s in [`Actor`]s. -/// [`Agent`] must implement [`SyncModel`] to synchronize its model. +/// [`Agent`] must implement [`SyncModel`] to synchronize the model parameters. +/// * In [`ActorManager`] (right), [`Actor`]s sample transitions, which have type +/// [`ReplayBufferBase::Item`], and push the transitions into +/// [`ReplayBufferProxy`]. +/// * [`ReplayBufferProxy`] has a type parameter of [`ReplayBufferBase`] and the proxy accepts +/// [`ReplayBufferBase::Item`]. +/// * The proxy sends the transitions into the replay buffer in the [`AsyncTrainer`]. 
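The data flow summarized in the list above can be pictured as two channels: actors push transitions one way, and the trainer broadcasts updated model parameters the other way. Below is a rough, self-contained sketch of that flow using `crossbeam_channel` (which the workspace already depends on), in the spirit of the channel-based example this patch removes from `lib.rs`; the payload types `u32` and `String` are placeholders, not the crate's actual `R::Item` or `SyncModel::ModelInfo` types.

```rust
// Standalone illustration only: an "actor" thread sends transitions to the
// "trainer", and the trainer sends back updated model info for synchronization.
use crossbeam_channel::unbounded;
use std::thread;

fn main() {
    let (item_s, item_r) = unbounded::<u32>(); // transitions pushed by an actor
    let (model_s, model_r) = unbounded::<String>(); // model info sent by the trainer

    let actor = thread::spawn(move || {
        for i in 0..3 {
            item_s.send(i).unwrap(); // push a "transition"
        }
        let info = model_r.recv().unwrap(); // synchronize with the latest "model"
        println!("actor synced with {}", info);
    });

    for _ in 0..3 {
        let _transition = item_r.recv().unwrap(); // trainer side consumes transitions
    }
    model_s.send("params after 1 opt step".to_string()).unwrap();
    actor.join().unwrap();
}
```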
/// /// [`ActorManager`]: crate::ActorManager /// [`Actor`]: crate::Actor /// [`ReplayBufferBase::Item`]: border_core::ReplayBufferBase::PushedItem +/// [`ReplayBufferBase::Batch`]: border_core::ReplayBufferBase::PushedBatch /// [`ReplayBufferProxy`]: crate::ReplayBufferProxy /// [`ReplayBufferBase`]: border_core::ReplayBufferBase /// [`SyncModel::ModelInfo`]: crate::SyncModel::ModelInfo +/// [`Agent`]: border_core::Agent pub struct AsyncTrainer where A: Agent + Configurable + SyncModel, @@ -266,11 +267,8 @@ where }; let mut agent = A::build(self.agent_config.clone()); let mut buffer = R::build(&self.replay_buffer_config); - // let buffer = Arc::new(Mutex::new(R::build(&self.replay_buffer_config))); agent.train(); - // self.run_replay_buffer_thread(buffer.clone()); - let mut max_eval_reward = f32::MIN; let mut opt_steps = 0; let mut samples = 0; @@ -294,7 +292,6 @@ where // Add stats wrt computation cost if opt_steps % self.record_compute_cost_interval == 0 { - // record.insert("fps", Scalar(sampler.fps())); record.insert("opt_steps_per_sec", Scalar(self.opt_steps_per_sec())); } diff --git a/border-async-trainer/src/async_trainer/config.rs b/border-async-trainer/src/async_trainer/config.rs index 48498e36..95f5b115 100644 --- a/border-async-trainer/src/async_trainer/config.rs +++ b/border-async-trainer/src/async_trainer/config.rs @@ -6,7 +6,7 @@ use std::{ path::Path, }; -/// Configuration of [AsyncTrainer](crate::AsyncTrainer) +/// Configuration of [`AsyncTrainer`](crate::AsyncTrainer). #[derive(Clone, Debug, Deserialize, Serialize)] pub struct AsyncTrainerConfig { /// The maximum number of optimization steps. @@ -56,3 +56,19 @@ impl AsyncTrainerConfig { Ok(self) } } + +impl Default for AsyncTrainerConfig { + /// There is no special intention behind these initial values. + fn default() -> Self { + Self { + max_opts: 10, //000, + model_dir: None, + eval_interval: 5000, + flush_record_interval: 5000, + record_compute_cost_interval: 5000, + save_interval: 50000, + sync_interval: 100, + warmup_period: 10000, + } + } +} diff --git a/border-async-trainer/src/lib.rs b/border-async-trainer/src/lib.rs index d6747291..ce775418 100644 --- a/border-async-trainer/src/lib.rs +++ b/border-async-trainer/src/lib.rs @@ -2,61 +2,127 @@ //! //! The code might look like below. //! -//! ```ignore -//! fn train() { -//! let agent_configs: Vec<_> = vec![agent_config()]; -//! let env_config_train = env_config(name); -//! let env_config_eval = env_config(name).eval(); -//! let replay_buffer_config = load_replay_buffer_config(model_dir.as_str())?; -//! let step_proc_config = SimpleStepProcessorConfig::default(); -//! let actor_man_config = ActorManagerConfig::default(); -//! let async_trainer_config = load_async_trainer_config(model_dir.as_str())?; -//! let mut recorder = TensorboardRecorder::new(model_dir); -//! let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?; -//! -//! // Shared flag to stop actor threads -//! let stop = Arc::new(Mutex::new(false)); +//! ``` +//! # use serde::{Deserialize, Serialize}; +//! # use border_core::test::{ +//! # TestAgent, TestAgentConfig, TestEnv, TestObs, TestObsBatch, +//! # TestAct, TestActBatch +//! # }; +//! # use border_async_trainer::{ +//! # //test::{TestAgent, TestAgentConfig, TestEnv}, +//! # ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, +//! # }; +//! # use border_core::{ +//! # generic_replay_buffer::{ +//! # SimpleReplayBuffer, SimpleReplayBufferConfig, +//! # SimpleStepProcessorConfig, SimpleStepProcessor +//! # }, +//! 
# record::{AggregateRecorder, NullRecorder}, DefaultEvaluator, +//! # }; +//! # +//! # fn agent_config() -> TestAgentConfig { +//! # TestAgentConfig +//! # } +//! # +//! # fn env_config() -> usize { +//! # 0 +//! # } +//! +//! type Env = TestEnv; +//! type ObsBatch = TestObsBatch; +//! type ActBatch = TestActBatch; +//! type ReplayBuffer = SimpleReplayBuffer; +//! type StepProcessor = SimpleStepProcessor; +//! +//! // Create a new agent by wrapping the existing agent in order to implement SyncModel. +//! struct TestAgent2(TestAgent); +//! +//! impl border_core::Configurable for TestAgent2 { +//! type Config = TestAgentConfig; +//! +//! fn build(config: Self::Config) -> Self { +//! Self(TestAgent::build(config)) +//! } +//! } +//! +//! impl border_core::Agent for TestAgent2 { +//! // Boilerplate code to delegate the method calls to the inner agent. +//! fn train(&mut self) { +//! self.0.train(); +//! } //! -//! // Creates channels -//! let (item_s, item_r) = unbounded(); // items pushed to replay buffer -//! let (model_s, model_r) = unbounded(); // model_info +//! // For other methods ... +//! # fn is_train(&self) -> bool { +//! # self.0.is_train() +//! # } +//! # +//! # fn eval(&mut self) { +//! # self.0.eval(); +//! # } +//! # +//! # fn opt_with_record(&mut self, buffer: &mut ReplayBuffer) -> border_core::record::Record { +//! # self.0.opt_with_record(buffer) +//! # } +//! # +//! # fn save_params>(&self, path: T) -> anyhow::Result<()> { +//! # self.0.save_params(path) +//! # } +//! # +//! # fn load_params>(&mut self, path: T) -> anyhow::Result<()> { +//! # self.0.load_params(path) +//! # } +//! # +//! # fn opt(&mut self, buffer: &mut ReplayBuffer) { +//! # self.0.opt_with_record(buffer); +//! # } +//! } //! -//! // guard for initialization of envs in multiple threads -//! let guard_init_env = Arc::new(Mutex::new(true)); +//! impl border_core::Policy for TestAgent2 { +//! // Boilerplate code to delegate the method calls to the inner agent. +//! // ... +//! # fn sample(&mut self, obs: &TestObs) -> TestAct { +//! # self.0.sample(obs) +//! # } +//! } +//! +//! impl border_async_trainer::SyncModel for TestAgent2{ +//! // Self::ModelInfo shold include the model parameters. +//! type ModelInfo = usize; +//! //! -//! // Actor manager and async trainer -//! let mut actors = ActorManager::build( -//! &actor_man_config, -//! &agent_configs, -//! &env_config_train, -//! &step_proc_config, -//! item_s, -//! model_r, -//! stop.clone(), -//! ); -//! let mut trainer = AsyncTrainer::build( -//! &async_trainer_config, -//! &agent_config, -//! &env_config_eval, -//! &replay_buffer_config, -//! item_r, -//! model_s, -//! stop.clone(), -//! ); +//! fn model_info(&self) -> (usize, Self::ModelInfo) { +//! // Extracts the model parameters and returns them as Self::ModelInfo. +//! // The first element of the tuple is the number of optimization steps. +//! (0, 0) +//! } //! -//! // Set the number of threads -//! tch::set_num_threads(1); +//! fn sync_model(&mut self, _model_info: &Self::ModelInfo) { +//! // implements synchronization of the model based on the _model_info +//! } +//! } //! -//! // Starts sampling and training -//! actors.run(guard_init_env.clone()); -//! let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env); -//! println!("Stats of async trainer"); -//! println!("{}", stats.fmt()); +//! let agent_configs: Vec<_> = vec![agent_config()]; +//! let env_config_train = env_config(); +//! let env_config_eval = env_config(); +//! 
let replay_buffer_config = SimpleReplayBufferConfig::default(); +//! let step_proc_config = SimpleStepProcessorConfig::default(); +//! let actor_man_config = ActorManagerConfig::default(); +//! let async_trainer_config = AsyncTrainerConfig::default(); +//! let mut recorder: Box = Box::new(NullRecorder {}); +//! let mut evaluator = DefaultEvaluator::::new(&env_config_eval, 0, 1).unwrap(); //! -//! let stats = actors.stop_and_join(); -//! println!("Stats of generated samples in actors"); -//! println!("{}", actor_stats_fmt(&stats)); -//! } +//! border_async_trainer::util::train_async::<_, _, _, StepProcessor>( +//! &agent_config(), +//! &agent_configs, +//! &env_config_train, +//! &env_config_eval, +//! &step_proc_config, +//! &replay_buffer_config, +//! &actor_man_config, +//! &async_trainer_config, +//! &mut recorder, +//! &mut evaluator, +//! ); //! ``` //! //! Training process consists of the following two components: @@ -89,6 +155,7 @@ mod messages; mod replay_buffer_proxy; mod sync_model; pub mod util; + pub use actor::{actor_stats_fmt, Actor, ActorStat}; pub use actor_manager::{ActorManager, ActorManagerConfig}; pub use async_trainer::{AsyncTrainStat, AsyncTrainer, AsyncTrainerConfig}; @@ -96,3 +163,226 @@ pub use error::BorderAsyncTrainerError; pub use messages::PushedItemMessage; pub use replay_buffer_proxy::{ReplayBufferProxy, ReplayBufferProxyConfig}; pub use sync_model::SyncModel; + +/// Agent and Env for testing. +#[cfg(test)] +pub mod test { + use serde::{Deserialize, Serialize}; + + /// Obs for testing. + #[derive(Clone, Debug)] + pub struct TestObs { + obs: usize, + } + + impl border_core::Obs for TestObs { + fn dummy(_n: usize) -> Self { + Self { obs: 0 } + } + + fn len(&self) -> usize { + 1 + } + } + + /// Batch of obs for testing. + pub struct TestObsBatch { + obs: Vec, + } + + impl border_core::generic_replay_buffer::BatchBase for TestObsBatch { + fn new(capacity: usize) -> Self { + Self { + obs: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.obs[i] = data.obs[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let obs = ixs.iter().map(|ix| self.obs[*ix]).collect(); + Self { obs } + } + } + + impl From for TestObsBatch { + fn from(obs: TestObs) -> Self { + Self { obs: vec![obs.obs] } + } + } + + /// Act for testing. + #[derive(Clone, Debug)] + pub struct TestAct { + act: usize, + } + + impl border_core::Act for TestAct {} + + /// Batch of act for testing. + pub struct TestActBatch { + act: Vec, + } + + impl From for TestActBatch { + fn from(act: TestAct) -> Self { + Self { act: vec![act.act] } + } + } + + impl border_core::generic_replay_buffer::BatchBase for TestActBatch { + fn new(capacity: usize) -> Self { + Self { + act: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.act[i] = data.act[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let act = ixs.iter().map(|ix| self.act[*ix]).collect(); + Self { act } + } + } + + /// Info for testing. + pub struct TestInfo {} + + impl border_core::Info for TestInfo {} + + /// Environment for testing. 
+ pub struct TestEnv { + state_init: usize, + state: usize, + } + + impl border_core::Env for TestEnv { + type Config = usize; + type Obs = TestObs; + type Act = TestAct; + type Info = TestInfo; + + fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn step_with_reset( + &mut self, + a: &Self::Act, + ) -> (border_core::Step, border_core::record::Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = border_core::Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: TestObs { + obs: self.state_init, + }, + }; + return (step, border_core::record::Record::empty()); + } + + fn step(&mut self, a: &Self::Act) -> (border_core::Step, border_core::record::Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = border_core::Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: TestObs { + obs: self.state_init, + }, + }; + return (step, border_core::record::Record::empty()); + } + + fn build(config: &Self::Config, _seed: i64) -> anyhow::Result + where + Self: Sized, + { + Ok(Self { + state_init: *config, + state: 0, + }) + } + } + + type ReplayBuffer = + border_core::generic_replay_buffer::SimpleReplayBuffer; + + /// Agent for testing. + pub struct TestAgent {} + + #[derive(Clone, Deserialize, Serialize)] + /// Config of agent for testing. + pub struct TestAgentConfig; + + impl border_core::Agent for TestAgent { + fn train(&mut self) {} + + fn is_train(&self) -> bool { + false + } + + fn eval(&mut self) {} + + fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> border_core::record::Record { + border_core::record::Record::empty() + } + + fn save_params>(&self, _path: T) -> anyhow::Result<()> { + Ok(()) + } + + fn load_params>(&mut self, _path: T) -> anyhow::Result<()> { + Ok(()) + } + } + + impl border_core::Policy for TestAgent { + fn sample(&mut self, _obs: &TestObs) -> TestAct { + TestAct { act: 1 } + } + } + + impl border_core::Configurable for TestAgent { + type Config = TestAgentConfig; + + fn build(_config: Self::Config) -> Self { + Self {} + } + } + + impl crate::SyncModel for TestAgent { + type ModelInfo = usize; + + fn model_info(&self) -> (usize, Self::ModelInfo) { + (0, 0) + } + + fn sync_model(&mut self, _model_info: &Self::ModelInfo) { + // nothing to do + } + } +} diff --git a/border-async-trainer/src/replay_buffer_proxy.rs b/border-async-trainer/src/replay_buffer_proxy.rs index ccd263e0..263c5beb 100644 --- a/border-async-trainer/src/replay_buffer_proxy.rs +++ b/border-async-trainer/src/replay_buffer_proxy.rs @@ -9,7 +9,7 @@ use std::marker::PhantomData; pub struct ReplayBufferProxyConfig { /// Number of samples buffered until sent to the trainer. /// - /// Here, a sample corresponds to a `R::Item` for [`ReplayBufferProxy`]``. + /// A sample is a `R::Item` for [`ReplayBufferProxy`]``. 
pub n_buffer: usize, } diff --git a/border-atari-env/src/act.rs b/border-atari-env/src/act.rs index 9ee50316..6feae6bf 100644 --- a/border-atari-env/src/act.rs +++ b/border-atari-env/src/act.rs @@ -5,7 +5,9 @@ use serde::{Deserialize, Serialize}; use std::{default::Default, marker::PhantomData}; #[derive(Debug, Clone)] -/// Action for [BorderAtariEnv](crate::BorderAtariEnv) +/// Action for [`BorderAtariEnv`](crate::BorderAtariEnv). +/// +/// This action is a discrete action and denotes pushing a button. pub struct BorderAtariAct { pub act: u8, } @@ -28,7 +30,7 @@ impl From for BorderAtariAct { } } -/// Converts `A` to [`BorderAtariAct`]. +/// Converts action of type `A` to [`BorderAtariAct`]. pub trait BorderAtariActFilter { /// Configuration of the filter. type Config: Clone + Default; @@ -56,7 +58,7 @@ impl Default for BorderAtariActRawFilterConfig { } } -/// A filter without any processing. +/// A filter that performs no processing. pub struct BorderAtariActRawFilter { phantom: PhantomData, } diff --git a/border-atari-env/src/atari_env.rs b/border-atari-env/src/atari_env.rs index ec4d0758..b2731cd2 100644 --- a/border-atari-env/src/atari_env.rs +++ b/border-atari-env/src/atari_env.rs @@ -1,3 +1,4 @@ +//! Atari environment for reinforcement learning. pub mod ale; use std::path::Path; diff --git a/border-atari-env/src/env/config.rs b/border-atari-env/src/env/config.rs index 1f410644..9abb742b 100644 --- a/border-atari-env/src/env/config.rs +++ b/border-atari-env/src/env/config.rs @@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; use std::{default::Default, env}; #[derive(Serialize, Deserialize, Debug)] -/// Configurations of [`BorderAtariEnv`](super::BorderAtariEnv). +/// Configuration of [`BorderAtariEnv`](super::BorderAtariEnv). pub struct BorderAtariEnvConfig where O: Obs, diff --git a/border-atari-env/src/lib.rs b/border-atari-env/src/lib.rs index 0f2d8277..7d47560e 100644 --- a/border-atari-env/src/lib.rs +++ b/border-atari-env/src/lib.rs @@ -1,13 +1,13 @@ -//! A thin wrapper of [atari-env](https://crates.io/crates/atari-env) for [Border](https://crates.io/crates/border). +//! A thin wrapper of [`atari-env`](https://crates.io/crates/atari-env) for [`Border`](https://crates.io/crates/border). //! //! The code under [atari_env] is adapted from the -//! [atari-env](https://crates.io/crates/atari-env) crate +//! [`atari-env`](https://crates.io/crates/atari-env) crate //! (rev = `0ef0422f953d79e96b32ad14284c9600bd34f335`), //! because the crate registered in crates.io does not implement //! [`atari_env::AtariEnv::lives()`] method, which is required for episodic life environments. //! //! This environment applies some preprocessing to observation as in -//! [atari_wrapper.py](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py). +//! [`atari_wrapper.py`](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py). //! //! You need to place Atari Rom directories under the directory specified by environment variable //! `ATARI_ROM_DIR`. An easy way to do this is to use [AutoROM](https://pypi.org/project/AutoROM/) @@ -28,55 +28,50 @@ //! BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, //! BorderAtariObs, BorderAtariObsRawFilter, //! }; -//! use border_core::{util, Env as _, Policy, DefaultEvaluator, Evaluator as _}; -//! -//! type Obs = BorderAtariObs; -//! type Act = BorderAtariAct; -//! type ObsFilter = BorderAtariObsRawFilter; -//! type ActFilter = BorderAtariActRawFilter; -//! 
type EnvConfig = BorderAtariEnvConfig; -//! type Env = BorderAtariEnv; -//! -//! #[derive(Clone)] -//! struct RandomPolicyConfig { -//! pub n_acts: usize, -//! } -//! -//! struct RandomPolicy { -//! n_acts: usize, -//! } -//! -//! impl Policy for RandomPolicy { -//! type Config = RandomPolicyConfig; -//! -//! fn build(config: Self::Config) -> Self { -//! Self { -//! n_acts: config.n_acts, -//! } -//! } -//! -//! fn sample(&mut self, _: &Obs) -> Act { -//! fastrand::u8(..self.n_acts as u8).into() -//! } -//! } -//! -//! fn env_config(name: String) -> EnvConfig { -//! EnvConfig::default().name(name) -//! } -//! +//! use border_core::{Env as _, Policy, DefaultEvaluator, Evaluator as _}; +//! +//! # type Obs = BorderAtariObs; +//! # type Act = BorderAtariAct; +//! # type ObsFilter = BorderAtariObsRawFilter; +//! # type ActFilter = BorderAtariActRawFilter; +//! # type EnvConfig = BorderAtariEnvConfig; +//! # type Env = BorderAtariEnv; +//! # +//! # #[derive(Clone)] +//! # struct RandomPolicyConfig { +//! # pub n_acts: usize, +//! # } +//! # +//! # struct RandomPolicy { +//! # n_acts: usize, +//! # } +//! # +//! # impl RandomPolicy { +//! # pub fn build(n_acts: usize) -> Self { +//! # Self { n_acts } +//! # } +//! # } +//! # +//! # impl Policy for RandomPolicy { +//! # fn sample(&mut self, _: &Obs) -> Act { +//! # fastrand::u8(..self.n_acts as u8).into() +//! # } +//! # } +//! # +//! # fn env_config(name: String) -> EnvConfig { +//! # EnvConfig::default().name(name) +//! # } +//! # //! fn main() -> Result<()> { -//! env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); -//! fastrand::seed(42); -//! +//! # env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); +//! # fastrand::seed(42); +//! # //! // Creates Pong environment //! let env_config = env_config("pong".to_string()); //! //! // Creates a random policy -//! let n_acts = 4; // number of actions; -//! let policy_config = RandomPolicyConfig { -//! n_acts: n_acts as _, -//! }; -//! let mut policy = RandomPolicy::build(policy_config); +//! let n_acts = 4; +//! let mut policy = RandomPolicy::build(n_acts); //! //! // Runs evaluation //! let env_config = env_config.render(true); diff --git a/border-atari-env/src/obs.rs b/border-atari-env/src/obs.rs index f37e4ab0..2a5e4b59 100644 --- a/border-atari-env/src/obs.rs +++ b/border-atari-env/src/obs.rs @@ -67,7 +67,7 @@ impl From for Tensor { } } -/// Converts [`BorderAtariObs`] to `O` with an arbitrary processing. +/// Converts [`BorderAtariObs`] to observation of type `O` with an arbitrary processing. pub trait BorderAtariObsFilter { /// Configuration of the filter. type Config: Clone + Default; @@ -98,7 +98,7 @@ impl Default for BorderAtariObsRawFilterConfig { } } -/// A filter without any processing. +/// A filter that performs no processing. pub struct BorderAtariObsRawFilter { phantom: PhantomData, } diff --git a/border-atari-env/src/util.rs b/border-atari-env/src/util.rs index 7b788c20..8788d89b 100644 --- a/border-atari-env/src/util.rs +++ b/border-atari-env/src/util.rs @@ -1 +1,2 @@ +//! Utility functions for testing. pub mod test; diff --git a/border-candle-agent/src/dqn/base.rs b/border-candle-agent/src/dqn/base.rs index a7f2d7fb..f67055b0 100644 --- a/border-candle-agent/src/dqn/base.rs +++ b/border-candle-agent/src/dqn/base.rs @@ -1,4 +1,4 @@ -//! DQN agent implemented with tch-rs. +//! DQN agent implemented with candle. 
use super::{config::DqnConfig, explorer::DqnExplorer, model::DqnModel}; use crate::{ model::SubModel1, @@ -17,7 +17,7 @@ use std::convert::TryFrom; use std::{fs, marker::PhantomData, path::Path}; #[allow(clippy::upper_case_acronyms, dead_code)] -/// DQN agent implemented with tch-rs. +/// DQN agent implemented with candle. pub struct Dqn where Q: SubModel1, @@ -330,6 +330,10 @@ where record } + /// Save model parameters in the given directory. + /// + /// The parameters of the model are saved as `qnet.pt`. + /// The parameters of the target model are saved as `qnet_tgt.pt`. fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; diff --git a/border-candle-agent/src/dqn/model.rs b/border-candle-agent/src/dqn/model.rs index c14f888d..90efee8b 100644 --- a/border-candle-agent/src/dqn/model.rs +++ b/border-candle-agent/src/dqn/model.rs @@ -79,6 +79,12 @@ where } } +/// Action value function model for DQN. +/// +/// The architecture of the model is defined by the type parameter `Q`, +/// which should implement [`SubModel1`]. +/// This takes [`SubModel1::Input`] as input and outputs a tensor. +/// The output tensor should have the same dimension as the number of actions. pub struct DqnModel where Q: SubModel1, diff --git a/border-candle-agent/src/lib.rs b/border-candle-agent/src/lib.rs index 2ae29440..9809a488 100644 --- a/border-candle-agent/src/lib.rs +++ b/border-candle-agent/src/lib.rs @@ -16,6 +16,8 @@ pub use tensor_batch::{TensorBatch, ZeroTensor}; /// Device for using candle. /// /// This enum is added because [`candle_core::Device`] does not support serialization. +/// +/// [`candle_core::Device`]: https://docs.rs/candle-core/0.4.1/candle_core/enum.Device.html pub enum Device { /// The main CPU device. Cpu, diff --git a/border-candle-agent/src/model.rs b/border-candle-agent/src/model.rs index 43f245c4..6fbacc6a 100644 --- a/border-candle-agent/src/model.rs +++ b/border-candle-agent/src/model.rs @@ -7,7 +7,7 @@ use candle_nn::VarBuilder; /// Neural network model not owing its [`VarMap`] internally. /// -/// [`VarMap`]: candle_nn::VarMap +/// [`VarMap`]: https://docs.rs/candle-nn/0.4.1/candle_nn/var_map/struct.VarMap.html pub trait SubModel1 { /// Configuration from which [`SubModel1`] is constructed. type Config; @@ -19,6 +19,8 @@ pub trait SubModel1 { type Output; /// Builds [`SubModel1`] with [`VarBuilder`] and [`SubModel1::Config`]. + /// + /// [`VarBuilder`]: https://docs.rs/candle-nn/0.4.1/candle_nn/var_builder/type.VarBuilder.html fn build(vb: VarBuilder, config: Self::Config) -> Self; /// A generalized forward function. @@ -29,7 +31,7 @@ pub trait SubModel1 { /// /// The difference from [`SubModel1`] is that this trait takes two inputs. /// -/// [`VarMap`]: candle_nn::VarMap +/// [`VarMap`]: https://docs.rs/candle-nn/0.4.1/candle_nn/var_map/struct.VarMap.html pub trait SubModel2 { /// Configuration from which [`SubModel2`] is constructed. type Config; diff --git a/border-candle-agent/src/opt.rs b/border-candle-agent/src/opt.rs index 48fe7cd4..0dff9522 100644 --- a/border-candle-agent/src/opt.rs +++ b/border-candle-agent/src/opt.rs @@ -113,6 +113,8 @@ impl Default for OptimizerConfig { /// Optimizers. /// /// This is a thin wrapper of [`candle_nn::optim::Optimizer`]. +/// +/// [`candle_nn::optim::Optimizer`]: https://docs.rs/candle-nn/0.4.1/candle_nn/optim/trait.Optimizer.html pub enum Optimizer { /// Adam optimizer. 
AdamW(AdamW), diff --git a/border-candle-agent/src/sac.rs b/border-candle-agent/src/sac.rs index bd2b31ea..89164dfd 100644 --- a/border-candle-agent/src/sac.rs +++ b/border-candle-agent/src/sac.rs @@ -1,10 +1,156 @@ //! SAC agent. //! -//! Here is an example in `border/examples/sac_pendulum.rs` +//! Here is an example of creating SAC agent: //! -//! ```rust,ignore +//! ```no_run +//! # use anyhow::Result; +//! use border_core::{ +//! # Env as Env_, Obs as Obs_, Act as Act_, Step, test::{ +//! # TestAct as TestAct_, TestActBatch as TestActBatch_, +//! # TestEnv as TestEnv_, +//! # TestObs as TestObs_, TestObsBatch as TestObsBatch_, +//! # }, +//! # record::Record, +//! # generic_replay_buffer::{SimpleReplayBuffer, BatchBase}, +//! Configurable, +//! }; +//! use border_candle_agent::{ +//! sac::{ActorConfig, CriticConfig, Sac, SacConfig}, +//! mlp::{Mlp, Mlp2, MlpConfig}, +//! opt::OptimizerConfig +//! }; +//! +//! # struct TestEnv(TestEnv_); +//! # #[derive(Clone, Debug)] +//! # struct TestObs(TestObs_); +//! # #[derive(Clone, Debug)] +//! # struct TestAct(TestAct_); +//! # struct TestObsBatch(TestObsBatch_); +//! # struct TestActBatch(TestActBatch_); +//! # +//! # impl Obs_ for TestObs { +//! # fn dummy(n: usize) -> Self { +//! # Self(TestObs_::dummy(n)) +//! # } +//! # +//! # fn len(&self) -> usize { +//! # self.0.len() +//! # } +//! # } +//! # +//! # impl Into for TestObs { +//! # fn into(self) -> candle_core::Tensor { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl BatchBase for TestObsBatch { +//! # fn new(n: usize) -> Self { +//! # Self(TestObsBatch_::new(n)) +//! # } +//! # +//! # fn push(&mut self, ix: usize, data: Self) { +//! # self.0.push(ix, data.0); +//! # } +//! # +//! # fn sample(&self, ixs: &Vec) -> Self { +//! # Self(self.0.sample(ixs)) +//! # } +//! # } +//! # +//! # impl BatchBase for TestActBatch { +//! # fn new(n: usize) -> Self { +//! # Self(TestActBatch_::new(n)) +//! # } +//! # +//! # fn push(&mut self, ix: usize, data: Self) { +//! # self.0.push(ix, data.0); +//! # } +//! # +//! # fn sample(&self, ixs: &Vec) -> Self { +//! # Self(self.0.sample(ixs)) +//! # } +//! # } +//! # +//! # impl Act_ for TestAct { +//! # fn len(&self) -> usize { +//! # self.0.len() +//! # } +//! # } +//! # +//! # impl From for TestAct { +//! # fn from(t: candle_core::Tensor) -> Self { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl Into for TestAct { +//! # fn into(self) -> candle_core::Tensor { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl Env_ for TestEnv { +//! # type Config = ::Config; +//! # type Obs = TestObs; +//! # type Act = TestAct; +//! # type Info = ::Info; +//! # +//! # fn build(config: &Self::Config, seed: i64) -> Result { +//! # Ok(Self(TestEnv_::build(&config, seed).unwrap())) +//! # } +//! # +//! # fn step(&mut self, act: &TestAct) -> (Step, Record) { +//! # let (step, record) = self.0.step(&act.0); +//! # let step = Step { +//! # obs: TestObs(step.obs), +//! # act: TestAct(step.act), +//! # reward: step.reward, +//! # is_terminated: step.is_terminated, +//! # is_truncated: step.is_truncated, +//! # info: step.info, +//! # init_obs: TestObs(step.init_obs), +//! # }; +//! # (step, record) +//! # } +//! # +//! # fn reset(&mut self, is_done: Option<&Vec>) -> Result { +//! # Ok(TestObs(self.0.reset(is_done).unwrap())) +//! # } +//! # +//! # fn step_with_reset(&mut self, a: &TestAct) -> (Step, Record) { +//! # let (step, record) = self.0.step_with_reset(&a.0); +//! # let step = Step { +//! # obs: TestObs(step.obs), +//! 
# act: TestAct(step.act), +//! # reward: step.reward, +//! # is_terminated: step.is_terminated, +//! # is_truncated: step.is_truncated, +//! # info: step.info, +//! # init_obs: TestObs(step.init_obs), +//! # }; +//! # (step, record) +//! # } +//! # +//! # fn reset_with_index(&mut self, ix: usize) -> Result { +//! # Ok(TestObs(self.0.reset_with_index(ix).unwrap())) +//! # } +//! # } +//! # +//! # type Env = TestEnv; +//! # type ObsBatch = TestObsBatch; +//! # type ActBatch = TestActBatch; +//! # type ReplayBuffer = SimpleReplayBuffer; +//! # +//! const DIM_OBS: i64 = 3; +//! const DIM_ACT: i64 = 1; +//! const LR_ACTOR: f64 = 1e-3; +//! const LR_CRITIC: f64 = 1e-3; +//! const BATCH_SIZE: usize = 256; +//! //! fn create_agent(in_dim: i64, out_dim: i64) -> Sac { -//! let device = tch::Device::cuda_if_available(); +//! let device = candle_core::Device::cuda_if_available(0).unwrap(); +//! //! let actor_config = ActorConfig::default() //! .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) //! .out_dim(out_dim) @@ -12,25 +158,13 @@ //! let critic_config = CriticConfig::default() //! .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) //! .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, true)); -//! let sac_config = SacConfig::default() +//! let sac_config = SacConfig::::default() //! .batch_size(BATCH_SIZE) -//! .min_transitions_warmup(N_TRANSITIONS_WARMUP) //! .actor_config(actor_config) //! .critic_config(critic_config) //! .device(device); //! Sac::build(sac_config) //! } -//! -//! fn train(max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { -//! let trainer = //... -//! let mut agent = create_agent(DIM_OBS, DIM_ACT); -//! let mut recorder = TensorboardRecorder::new(model_dir); -//! let mut evaluator = Evaluator::new(&env_config(), 0, N_EPISODES_PER_EVAL)?; -//! -//! trainer.train(&mut agent, &mut recorder, &mut evaluator)?; -//! -//! Ok(()) -//! } //! ``` mod actor; mod base; diff --git a/border-candle-agent/src/sac/config.rs b/border-candle-agent/src/sac/config.rs index b789090f..3256740f 100644 --- a/border-candle-agent/src/sac/config.rs +++ b/border-candle-agent/src/sac/config.rs @@ -18,7 +18,7 @@ use std::{ path::Path, }; -/// Constructs [`Sac`](super::Sac). +/// Configuration of [`Sac`](super::Sac). #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Deserialize, Serialize, PartialEq)] pub struct SacConfig diff --git a/border-candle-agent/src/tensor_batch.rs b/border-candle-agent/src/tensor_batch.rs index 21c031bb..410ac023 100644 --- a/border-candle-agent/src/tensor_batch.rs +++ b/border-candle-agent/src/tensor_batch.rs @@ -1,7 +1,9 @@ use border_core::generic_replay_buffer::BatchBase; use candle_core::{error::Result, DType, Device, Tensor}; -/// Adds capability of constructing [Tensor] with a static method. +/// Adds capability of constructing [`Tensor`] with a static method. +/// +/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html pub trait ZeroTensor { /// Constructs zero tensor. fn zeros(shape: &[usize]) -> Result; @@ -28,6 +30,8 @@ impl ZeroTensor for i64 { /// A buffer consisting of a [`Tensor`]. /// /// The internal buffer is `Vec`. 
+/// +/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html #[derive(Clone, Debug)] pub struct TensorBatch { buf: Vec, diff --git a/border-candle-agent/src/util.rs b/border-candle-agent/src/util.rs index dad828f6..1e967546 100644 --- a/border-candle-agent/src/util.rs +++ b/border-candle-agent/src/util.rs @@ -23,27 +23,10 @@ pub enum CriticLoss { SmoothL1, } -// /// Apply soft update on a model. -// /// -// /// Variables are identified by their names. -// pub fn track(dest: &mut M, src: &mut M, tau: f64) { -// let src = &mut src.get_var_store().variables(); -// let dest = &mut dest.get_var_store().variables(); -// debug_assert_eq!(src.len(), dest.len()); - -// let names = src.keys(); -// tch::no_grad(|| { -// for name in names { -// let src = src.get(name).unwrap(); -// let dest = dest.get_mut(name).unwrap(); -// dest.copy_(&(tau * src + (1.0 - tau) * &*dest)); -// } -// }); -// trace!("soft update"); -// } - -/// Apply soft update on model parameters. +/// Apply soft update on variables. /// +/// Variables are identified by their names. +/// /// dest = tau * src + (1.0 - tau) * dest pub fn track(dest: &VarMap, src: &VarMap, tau: f64) -> Result<()> { trace!("dest"); @@ -69,6 +52,7 @@ pub fn track(dest: &VarMap, src: &VarMap, tau: f64) -> Result<()> { // v // } +/// Interface for handling output dimensions. pub trait OutDim { /// Returns the output dimension. fn get_out_dim(&self) -> i64; @@ -141,6 +125,7 @@ pub fn smooth_l1_loss(x: &Tensor, y: &Tensor) -> Result f32 { t.broadcast_sub(&t.mean_all().unwrap()) .unwrap() @@ -154,6 +139,7 @@ pub fn std(t: &Tensor) -> f32 { .unwrap() } +/// Returns the mean and standard deviation of the parameters. pub fn param_stats(varmap: &VarMap) -> Record { let mut record = Record::empty(); diff --git a/border-core/src/lib.rs b/border-core/src/lib.rs index de679b5d..0f8efd09 100644 --- a/border-core/src/lib.rs +++ b/border-core/src/lib.rs @@ -4,7 +4,8 @@ //! # Observation and action //! //! [`Obs`] and [`Act`] traits are abstractions of observation and action in environments. -//! These traits can handle two or more samples for implementing vectorized environments. +//! These traits can handle two or more samples for implementing vectorized environments, +//! although there is currently no implementation of vectorized environment. //! //! # Environment //! @@ -14,16 +15,14 @@ //! These types must implement [`Obs`] and [`Act`] traits, respectively. //! The environment that implements [`Env`] generates [`Step`] object //! at every environment interaction step with [`Env::step()`] method. -//! //! [`Info`] stores some information at every step of interactions of an agent and //! the environment. It could be empty (zero-sized struct). `Config` represents //! configurations of the environment and is used to build. //! //! # Policy //! -//! [`Policy`] represents a policy, from which actions are sampled for -//! environment `E`. [`Policy::sample()`] takes `E::Obs` and emits `E::Act`. -//! It could be probabilistic or deterministic. +//! [`Policy`] represents a policy. [`Policy::sample()`] takes `E::Obs` and +//! generates `E::Act`. It could be probabilistic or deterministic. //! //! # Agent //! @@ -45,21 +44,23 @@ //! [`TransitionBatch`] is a trait of a batch of transitions `(o_t, r_t, a_t, o_t+1)`. //! This trait is used to train [`Agent`]s using an RL algorithm. //! -//! # Replay buffer +//! # Replay buffer and experience buffer //! -//! [`ReplayBufferBase`] trait is an abstraction of replay buffers. For handling samples, -//! 
there are two associated types: `Item` and `Batch`. `Item` is a type -//! representing samples pushed to the buffer. These samples might be generated from -//! [`Step`]. [`StepProcessor`] trait provides the interface -//! for converting [`Step`] into `Item`. +//! [`ReplayBufferBase`] trait is an abstraction of replay buffers. +//! One of the associated types, [`ReplayBufferBase::Batch`], represents samples taken from +//! the buffer for training [`Agent`]s. Agents must implement the [`Agent::opt()`] method, +//! where [`ReplayBufferBase::Batch`] has appropriate type or trait bound(s) for training +//! the agent. //! -//! `Batch` is a type of samples taken from the buffer for training [`Agent`]s. -//! The user implements [`Agent::opt()`] method such that it handles `Batch` objects -//! for doing an optimization step. +//! As explained above, the [`ReplayBufferBase`] trait has the ability to generate batches +//! of samples with which agents are trained. On the other hand, the [`ExperienceBufferBase`] +//! trait has the ability to store samples. [`ExperienceBufferBase::push()`] is used to push +//! samples of type [`ExperienceBufferBase::Item`], which might be obtained via interaction +//! steps with an environment. //! //! ## A reference implementation //! -//! [`SimpleReplayBuffer`] implementats [`ReplayBufferBase`]. +//! [`SimpleReplayBuffer`] implements both [`ReplayBufferBase`] and [`ExperienceBufferBase`]. //! This type has two parameters `O` and `A`, which are representation of //! observation and action in the replay buffer. `O` and `A` must implement //! [`BatchBase`], which has the functionality of storing samples, like `Vec`, @@ -74,10 +75,12 @@ //! # Trainer //! //! [`Trainer`] manages training loop and related objects. The [`Trainer`] object is -//! built with configurations of [`Env`], [`ReplayBufferBase`], [`StepProcessor`] -//! and some training parameters. Then, [`Trainer::train`] method starts training loop with -//! given [`Agent`] and [`Recorder`](crate::record::Recorder). -//! +//! built with configurations of training parameters such as the maximum number of +//! optimization steps, the model directory where agent parameters are saved during training, etc. +//! The [`Trainer::train`] method executes online training of an agent on an environment. +//! In the training loop of this method, the agent interacts with the environment to +//! take samples and perform optimization steps. Some metrics are recorded at the same time. +//! //! [`SimpleReplayBuffer`]: replay_buffer::SimpleReplayBuffer //! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer //! [`BatchBase`]: generic_replay_buffer::BatchBase @@ -98,3 +101,214 @@ pub use base::{ mod trainer; pub use evaluator::{DefaultEvaluator, Evaluator}; pub use trainer::{Sampler, Trainer, TrainerConfig}; + +// TODO: Consider compiling this module only for tests. +/// Agent and Env for testing. +pub mod test { + use serde::{Deserialize, Serialize}; + + /// Obs for testing. + #[derive(Clone, Debug)] + pub struct TestObs { + obs: usize, + } + + impl crate::Obs for TestObs { + fn dummy(_n: usize) -> Self { + Self { obs: 0 } + } + + fn len(&self) -> usize { + 1 + } + } + + /// Batch of obs for testing.
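+    ///
+    /// A minimal, illustrative sketch of exercising this batch type through the
+    /// public [`crate::generic_replay_buffer::BatchBase`] methods (the variable
+    /// names are hypothetical):
+    ///
+    /// ```ignore
+    /// use border_core::{generic_replay_buffer::BatchBase, test::{TestObs, TestObsBatch}, Obs};
+    ///
+    /// // Allocate a batch with capacity 2, push one dummy observation, then sample it back.
+    /// let mut batch = TestObsBatch::new(2);
+    /// batch.push(0, TestObsBatch::from(TestObs::dummy(1)));
+    /// let sampled = batch.sample(&vec![0]);
+    /// ```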
+ pub struct TestObsBatch { + obs: Vec, + } + + impl crate::generic_replay_buffer::BatchBase for TestObsBatch { + fn new(capacity: usize) -> Self { + Self { + obs: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.obs[i] = data.obs[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let obs = ixs.iter().map(|ix| self.obs[*ix]).collect(); + Self { obs } + } + } + + impl From for TestObsBatch { + fn from(obs: TestObs) -> Self { + Self { obs: vec![obs.obs] } + } + } + + /// Act for testing. + #[derive(Clone, Debug)] + pub struct TestAct { + act: usize, + } + + impl crate::Act for TestAct {} + + /// Batch of act for testing. + pub struct TestActBatch { + act: Vec, + } + + impl From for TestActBatch { + fn from(act: TestAct) -> Self { + Self { act: vec![act.act] } + } + } + + impl crate::generic_replay_buffer::BatchBase for TestActBatch { + fn new(capacity: usize) -> Self { + Self { + act: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.act[i] = data.act[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let act = ixs.iter().map(|ix| self.act[*ix]).collect(); + Self { act } + } + } + + /// Info for testing. + pub struct TestInfo {} + + impl crate::Info for TestInfo {} + + /// Environment for testing. + pub struct TestEnv { + state_init: usize, + state: usize, + } + + impl crate::Env for TestEnv { + type Config = usize; + type Obs = TestObs; + type Act = TestAct; + type Info = TestInfo; + + fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn step_with_reset( + &mut self, + a: &Self::Act, + ) -> (crate::Step, crate::record::Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = crate::Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: TestObs { + obs: self.state_init, + }, + }; + return (step, crate::record::Record::empty()); + } + + fn step(&mut self, a: &Self::Act) -> (crate::Step, crate::record::Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = crate::Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: TestObs { + obs: self.state_init, + }, + }; + return (step, crate::record::Record::empty()); + } + + fn build(config: &Self::Config, _seed: i64) -> anyhow::Result + where + Self: Sized, + { + Ok(Self { + state_init: *config, + state: 0, + }) + } + } + + type ReplayBuffer = + crate::generic_replay_buffer::SimpleReplayBuffer; + + /// Agent for testing. + pub struct TestAgent {} + + #[derive(Clone, Deserialize, Serialize)] + /// Config of agent for testing. 
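+    ///
+    /// A minimal, illustrative sketch of building the test agent from this config
+    /// via [`crate::Configurable`] and sampling an action (variable names are
+    /// hypothetical):
+    ///
+    /// ```ignore
+    /// use border_core::{Configurable, Policy, Obs, test::{TestAgent, TestAgentConfig, TestObs}};
+    ///
+    /// let mut agent = TestAgent::build(TestAgentConfig);
+    /// let act = agent.sample(&TestObs::dummy(1)); // the test agent always returns `TestAct { act: 1 }`
+    /// ```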
+ pub struct TestAgentConfig; + + impl crate::Agent for TestAgent { + fn train(&mut self) {} + + fn is_train(&self) -> bool { + false + } + + fn eval(&mut self) {} + + fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> crate::record::Record { + crate::record::Record::empty() + } + + fn save_params>(&self, _path: T) -> anyhow::Result<()> { + Ok(()) + } + + fn load_params>(&mut self, _path: T) -> anyhow::Result<()> { + Ok(()) + } + } + + impl crate::Policy for TestAgent { + fn sample(&mut self, _obs: &TestObs) -> TestAct { + TestAct { act: 1 } + } + } + + impl crate::Configurable for TestAgent { + type Config = TestAgentConfig; + + fn build(_config: Self::Config) -> Self { + Self {} + } + } +} diff --git a/border-derive/Cargo.toml b/border-derive/Cargo.toml index 53ba7fff..70a4a40f 100644 --- a/border-derive/Cargo.toml +++ b/border-derive/Cargo.toml @@ -25,16 +25,10 @@ border-tch-agent = { version = "0.0.7", path = "../border-tch-agent" } border-candle-agent = { version = "0.0.7", path = "../border-candle-agent" } border-py-gym-env = { version = "0.0.7", path = "../border-py-gym-env" } border-core = { version = "0.0.7", path = "../border-core" } +border-atari-env = { version = "0.0.7", path = "../border-atari-env" } ndarray = { workspace = true } tch = { workspace = true } candle-core = { workspace = true } -# [features] -# default = ["tch"] - -[[example]] -name = "test1" -required-features = ["tch"] - [package.metadata.docs.rs] features = ["doc-only"] diff --git a/border-derive/examples/border_atari_act.rs b/border-derive/examples/border_atari_act.rs new file mode 100644 index 00000000..4f7b002a --- /dev/null +++ b/border-derive/examples/border_atari_act.rs @@ -0,0 +1,8 @@ +use border_atari_env::BorderAtariAct; +use border_derive::Act; + +#[allow(dead_code)] +#[derive(Clone, Debug, Act)] +struct MyAct(BorderAtariAct); + +fn main() {} diff --git a/border-derive/examples/border_gym_cont_act.rs b/border-derive/examples/border_gym_cont_act.rs new file mode 100644 index 00000000..9015aca0 --- /dev/null +++ b/border-derive/examples/border_gym_cont_act.rs @@ -0,0 +1,8 @@ +use border_derive::Act; +use border_py_gym_env::GymContinuousAct; + +#[allow(dead_code)] +#[derive(Clone, Debug, Act)] +struct MyAct(GymContinuousAct); + +fn main() {} diff --git a/border-derive/examples/border_gym_disc_act.rs b/border-derive/examples/border_gym_disc_act.rs new file mode 100644 index 00000000..05d0ea07 --- /dev/null +++ b/border-derive/examples/border_gym_disc_act.rs @@ -0,0 +1,8 @@ +use border_derive::Act; +use border_py_gym_env::GymDiscreteAct; + +#[allow(dead_code)] +#[derive(Clone, Debug, Act)] +struct MyAct(GymDiscreteAct); + +fn main() {} diff --git a/border-derive/examples/border_tensor_batch.rs b/border-derive/examples/border_tensor_batch.rs new file mode 100644 index 00000000..697e7b32 --- /dev/null +++ b/border-derive/examples/border_tensor_batch.rs @@ -0,0 +1,8 @@ +use border_derive::BatchBase; +use border_tch_agent::TensorBatch; + +#[allow(dead_code)] +#[derive(Clone, BatchBase)] +pub struct ObsBatch(TensorBatch); + +fn main() {} diff --git a/border-derive/examples/test1.rs b/border-derive/examples/test1.rs deleted file mode 100644 index 3551ff5a..00000000 --- a/border-derive/examples/test1.rs +++ /dev/null @@ -1,40 +0,0 @@ -use border_derive::{Act, SubBatch}; -use border_py_gym_env::GymDiscreteAct; -use border_tch_agent::TensorBatch; -use ndarray::ArrayD; -use std::convert::TryFrom; -use tch::Tensor; - -#[derive(Debug, Clone)] -struct Obs(ArrayD); - -#[derive(SubBatch)] -struct 
ObsBatch(TensorBatch); - -impl From for Tensor { - fn from(value: Obs) -> Self { - Tensor::try_from(&value.0).unwrap() - } -} - -impl From for ObsBatch { - fn from(obs: Obs) -> Self { - let tensor = obs.into(); - Self(TensorBatch::from_tensor(tensor)) - } -} - -#[derive(Clone, Debug, Act)] -struct Act(GymDiscreteAct); - -#[derive(SubBatch)] -struct ActBatch(TensorBatch); - -impl From for ActBatch { - fn from(act: Act) -> Self { - let tensor = act.into(); - Self(TensorBatch::from_tensor(tensor)) - } -} - -fn main() {} diff --git a/border-derive/src/act.rs b/border-derive/src/act.rs index f18566f3..1ddf9ea5 100644 --- a/border-derive/src/act.rs +++ b/border-derive/src/act.rs @@ -56,7 +56,8 @@ fn py_gym_env_cont_act( .iter() .map(|x| *x as usize) .collect::>(); - let act: Vec = t.into(); + use std::convert::TryInto; + let act: Vec = t.try_into().unwrap(); let act = ndarray::Array1::::from(act).into_shape(ndarray::IxDyn(&shape)).unwrap(); @@ -121,7 +122,8 @@ fn py_gym_env_disc_act( impl From for #ident { fn from(t: tch::Tensor) -> Self { - let data: Vec = t.into(); + use std::convert::TryInto; + let data: Vec = t.try_into().unwrap(); let data: Vec<_> = data.iter().map(|e| *e as i32).collect(); #ident(GymDiscreteAct::new(data)) } diff --git a/border-derive/src/lib.rs b/border-derive/src/lib.rs index 020fc23d..e5874c76 100644 --- a/border-derive/src/lib.rs +++ b/border-derive/src/lib.rs @@ -1,9 +1,207 @@ -//! Derive macros for making newtypes of types that implements -//! `border_core::Obs`, `border_core::Act` and -//! `order_core::replay_buffer::SubBatch`. +//! Derive macros for implementing [`border_core::Act`] and +//! [`border_core::generic_replay_buffer::BatchBase`]. //! -//! These macros will implements some conversion traits for combining -//! interfaces of an environment and an agent. +//! # Examples +//! +//! ## Newtype for [`BorderAtariAct`] +//! +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_atari_env::BorderAtariAct; +//! # +//! #[derive(Clone, Debug, Act)] +//! struct MyAct(BorderAtariAct); +//! ``` +//! +//! The above code will generate the following implementation: +//! +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_atari_env::BorderAtariAct; +//! # +//! #[derive(Clone, Debug)] +//! struct MyAct(BorderAtariAct); +//! impl border_core::Act for MyAct { +//! fn len(&self) -> usize { +//! self.0.len() +//! } +//! } +//! impl Into for MyAct { +//! fn into(self) -> BorderAtariAct { +//! self.0 +//! } +//! } +//! /// The following code is generated when features="tch" is enabled. +//! impl From for tch::Tensor { +//! fn from(act: MyAct) -> tch::Tensor { +//! let v = vec![act.0.act as i64]; +//! let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); +//! t.unsqueeze(0) +//! } +//! } +//! impl From for MyAct { +//! fn from(t: tch::Tensor) -> Self { +//! let data: Vec = { +//! let t = t.to_dtype(tch::Kind::Int64, false, true); +//! let n = t.numel(); +//! let mut data = vec![0i64; n]; +//! t.f_copy_data(&mut data, n).unwrap(); +//! data +//! }; +//! MyAct(BorderAtariAct::new(data[0] as u8)) +//! } +//! } +//! ``` +//! +//! ## Newtype for [`GymContinuousAct`] +//! +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_py_gym_env::GymContinuousAct; +//! # +//! #[derive(Clone, Debug, Act)] +//! struct MyAct(GymContinuousAct); +//! ``` +//! +//! The above code will generate the following implementation: +//! ``` +//! # use border_core::Act; +//! 
# use border_derive::Act; +//! # use border_py_gym_env::GymContinuousAct; +//! # +//! #[derive(Clone, Debug)] +//! struct MyAct(GymContinuousAct); +//! impl border_core::Act for MyAct { +//! fn len(&self) -> usize { +//! self.0.len() +//! } +//! } +//! impl Into for MyAct { +//! fn into(self) -> GymContinuousAct { +//! self.0 +//! } +//! } +//! /// The following code is generated when features="tch" is enabled. +//! impl From for tch::Tensor { +//! fn from(act: MyAct) -> tch::Tensor { +//! let v = act.0.act.iter().map(|e| *e as f32).collect::>(); +//! let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); +//! t.unsqueeze(0) +//! } +//! } +//! impl From for MyAct { +//! /// `t` must be a 1-dimentional tensor of `f32`. +//! fn from(t: tch::Tensor) -> Self { +//! let shape = t.size()[1..].iter().map(|x| *x as usize).collect::>(); +//! use std::convert::TryInto; +//! let act: Vec = t.try_into().unwrap(); +//! let act = ndarray::Array1::::from(act) +//! .into_shape(ndarray::IxDyn(&shape)) +//! .unwrap(); +//! MyAct(GymContinuousAct::new(act)) +//! } +//! } +//! ``` +//! +//! ## Newtype for [`GymDiscreteAct`] +//! +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_py_gym_env::GymDiscreteAct; +//! # +//! #[derive(Clone, Debug, Act)] +//! struct MyAct(GymDiscreteAct); +//! ``` +//! +//! The above code will generate the following implementation: +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_py_gym_env::GymDiscreteAct; +//! # +//! #[derive(Clone, Debug)] +//! struct MyAct(GymDiscreteAct); +//! impl border_core::Act for MyAct { +//! fn len(&self) -> usize { +//! self.0.len() +//! } +//! } +//! impl Into for MyAct { +//! fn into(self) -> GymDiscreteAct { +//! self.0 +//! } +//! } +//! impl From for tch::Tensor { +//! fn from(act: MyAct) -> tch::Tensor { +//! let v = act.0.act.iter().map(|e| *e as i64).collect::>(); +//! let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); +//! t.unsqueeze(0) +//! } +//! } +//! impl From for MyAct { +//! fn from(t: tch::Tensor) -> Self { +//! use std::convert::TryInto; +//! let data: Vec = t.try_into().unwrap(); +//! let data: Vec<_> = data.iter().map(|e| *e as i32).collect(); +//! MyAct(GymDiscreteAct::new(data)) +//! } +//! } +//! ``` +//! +//! ## Newtype for [`TensorBatch`] +//! +//! ``` +//! # use border_derive::BatchBase; +//! # use border_tch_agent::TensorBatch; +//! # +//! #[derive(Clone, BatchBase)] +//! struct MyBatch(TensorBatch); +//! ``` +//! +//! The above code will generate the following implementation: +//! +//! ``` +//! # use border_derive::BatchBase; +//! # use border_tch_agent::TensorBatch; +//! # +//! #[derive(Clone)] +//! struct ObsBatch(TensorBatch); +//! impl border_core::generic_replay_buffer::BatchBase for ObsBatch { +//! fn new(capacity: usize) -> Self { +//! Self(TensorBatch::new(capacity)) +//! } +//! fn push(&mut self, i: usize, data: Self) { +//! self.0.push(i, data.0) +//! } +//! fn sample(&self, ixs: &Vec) -> Self { +//! let buf = self.0.sample(ixs); +//! Self(buf) +//! } +//! } +//! impl From for ObsBatch { +//! fn from(obs: TensorBatch) -> Self { +//! ObsBatch(obs) +//! } +//! } +//! impl From for tch::Tensor { +//! fn from(b: ObsBatch) -> Self { +//! b.0.into() +//! } +//! } +//! ``` +//! +//! [`border_core::Obs`]: border_core::Obs +//! [`border_core::Act`]: border_core::Act +//! [`border_core::generic_replay_buffer::BatchBase`]: border_core::generic_replay_buffer::BatchBase +//! 
[`BorderAtariAct`]: border_atari_env::BorderAtariAct +//! [`GymContinuousAct`]: border_py_gym_env::GymContinuousAct +//! [`GymDiscreteAct`]: border_py_gym_env::GymDiscreteAct +//! [`TensorBatch`]: border_tch_agent::TensorBatch + mod act; mod obs; mod subbatch; @@ -11,18 +209,19 @@ use proc_macro::{self, TokenStream}; /// Implements `border_core::Obs` for the newtype that wraps /// PyGymEnvObs or BorderAtariObs. +#[deprecated] #[proc_macro_derive(Obs, attributes(my_trait))] pub fn derive1(input: TokenStream) -> TokenStream { obs::derive(input) } -/// Implements `border_core::generic_replay_buffer::BatchBase` for the newtype. +/// Implements [`border_core::generic_replay_buffer::BatchBase`] for the newtype. #[proc_macro_derive(BatchBase, attributes(my_trait))] pub fn derive2(input: TokenStream) -> TokenStream { subbatch::derive(input) } -/// Implements `border_core::Act` for the newtype. +/// Implements [`border_core::Act`] for the newtype. #[proc_macro_derive(Act, attributes(my_trait))] pub fn derive3(input: TokenStream) -> TokenStream { act::derive(input) diff --git a/border-derive/src/obs.rs b/border-derive/src/obs.rs index 8b45dda3..a63c4d7c 100644 --- a/border-derive/src/obs.rs +++ b/border-derive/src/obs.rs @@ -5,12 +5,11 @@ use syn::{parse_macro_input, DeriveInput}; pub fn derive(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input); - // let opts = Opts::from_derive_input(&input).expect("Wrong options"); let DeriveInput { ident, data, .. } = input; let field_type = get_field_type(data); let field_type_str = get_type_str( field_type.clone(), - "The item for deriving Obs must be a new type like MyObs(PyGymEnvObs)", + "The item for deriving Obs must be a new type like MyObs(BorderAtariObs)", ); // let output = if field_type_str == "PyGymEnvObs" { diff --git a/border-tch-agent/src/dqn/base.rs b/border-tch-agent/src/dqn/base.rs index b43663ca..3efa4e06 100644 --- a/border-tch-agent/src/dqn/base.rs +++ b/border-tch-agent/src/dqn/base.rs @@ -341,6 +341,10 @@ where record } + /// Save model parameters in the given directory. + /// + /// The parameters of the model are saved as `qnet.pt`. + /// The parameters of the target model are saved as `qnet_tgt.pt`. fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; diff --git a/border-tch-agent/src/dqn/model/base.rs b/border-tch-agent/src/dqn/model/base.rs index 4d67fa52..56180183 100644 --- a/border-tch-agent/src/dqn/model/base.rs +++ b/border-tch-agent/src/dqn/model/base.rs @@ -11,8 +11,12 @@ use serde::{de::DeserializeOwned, Serialize}; use std::{marker::PhantomData, path::Path}; use tch::{nn, Device, Tensor}; -#[allow(clippy::upper_case_acronyms)] -/// Represents value functions for DQN agents. +/// Action value function model for DQN. +/// +/// The architecture of the model is defined by the type parameter `Q`, +/// which should implement [`SubModel`]. +/// This takes [`SubModel::Input`] as input and outputs a tensor. +/// The output tensor should have the same dimension as the number of actions. pub struct DqnModel where Q: SubModel, @@ -75,7 +79,7 @@ where } } - /// Outputs the action-value given an observation. + /// Outputs the action-value given observation(s). 
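+    ///
+    /// The returned tensor is expected to have shape `[batch_size, out_dim]`, one row
+    /// of action values per observation. A minimal, illustrative sketch of greedy
+    /// action selection on top of this output (`model` and `obs` are hypothetical):
+    ///
+    /// ```ignore
+    /// let q = model.forward(&obs);          // action values, shape: [batch_size, out_dim]
+    /// let greedy = q.argmax(-1, false);     // greedy action indices, shape: [batch_size]
+    /// ```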
pub fn forward(&self, x: &Q::Input) -> Tensor { let a = self.q.forward(&x); debug_assert_eq!(a.size().as_slice()[1], self.out_dim); diff --git a/border-tch-agent/src/dqn/model/config.rs b/border-tch-agent/src/dqn/model/config.rs index 554bfa26..ce31802c 100644 --- a/border-tch-agent/src/dqn/model/config.rs +++ b/border-tch-agent/src/dqn/model/config.rs @@ -8,7 +8,7 @@ use std::{ }; #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -/// Configuration of [DqnModel](super::DqnModel). +/// Configuration of [`DqnModel`](super::DqnModel). pub struct DqnModelConfig where // Q: SubModel, diff --git a/border-tch-agent/src/iqn/model/config.rs b/border-tch-agent/src/iqn/model/config.rs index 19d11879..858e1dcd 100644 --- a/border-tch-agent/src/iqn/model/config.rs +++ b/border-tch-agent/src/iqn/model/config.rs @@ -25,7 +25,7 @@ where } #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -/// Configuration of [IqnModel](super::IqnModel). +/// Configuration of [`IqnModel`](super::IqnModel). /// /// The type parameter `F` represents a configuration struct of a feature extractor. /// The type parameter `M` represents a configuration struct of a model for merging diff --git a/border-tch-agent/src/mlp/base.rs b/border-tch-agent/src/mlp/base.rs index 674f426b..40e513d4 100644 --- a/border-tch-agent/src/mlp/base.rs +++ b/border-tch-agent/src/mlp/base.rs @@ -2,7 +2,7 @@ use super::{mlp, MlpConfig}; use crate::model::{SubModel, SubModel2}; use tch::{nn, nn::Module, Device, Tensor}; -/// Multilayer perceptron. +/// Multilayer perceptron with ReLU activation function. pub struct Mlp { config: MlpConfig, device: Device, diff --git a/border-tch-agent/src/model/base.rs b/border-tch-agent/src/model/base.rs index 0b375fca..99753100 100644 --- a/border-tch-agent/src/model/base.rs +++ b/border-tch-agent/src/model/base.rs @@ -51,53 +51,61 @@ pub trait Model2: ModelBase { fn forward(&self, x1s: &Self::Input1, x2s: &Self::Input2) -> Self::Output; } -/// Neural network model that can be initialized with [VarStore] and configuration. +/// Neural network model that can be initialized with [`VarStore`] and configuration. /// /// The purpose of this trait is for modularity of neural network models. -/// Modules, which consists a neural network, should share [VarStore]. -/// To do this, structs implementing this trait can be initialized with a given [VarStore]. -/// This trait also provide the ability to clone with a given [VarStore]. +/// Modules, which constitute a neural network, should share [`VarStore`]. +/// To do this, structs implementing this trait can be initialized with a given [`VarStore`]. +/// This trait also provides the ability to clone with a given [`VarStore`]. /// The ability is useful when creating a target network, used in recent deep learning algorithms in common. +/// +/// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html pub trait SubModel { - /// Configuration from which [SubModel] is constructed. + /// Configuration from which [`SubModel`] is constructed. type Config; - /// Input of the [SubModel]. + /// Input of the [`SubModel`]. type Input; - /// Output of the [SubModel]. + /// Output of the [`SubModel`]. type Output; - /// Builds [SubModel] with [VarStore] and [SubModel::Config]. + /// Builds [`SubModel`] with [`VarStore`] and [`SubModel::Config`]. fn build(var_store: &VarStore, config: Self::Config) -> Self; - /// Clones [SubModel] with [VarStore]. + /// Clones [`SubModel`] with [`VarStore`].
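+    ///
+    /// This is typically used when creating a target network. A minimal, illustrative
+    /// sketch (`qnet` is an existing sub-model and `vs_tgt` is the target's `VarStore`;
+    /// both are hypothetical):
+    ///
+    /// ```ignore
+    /// let qnet_tgt = qnet.clone_with_var_store(&vs_tgt);
+    /// ```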
+ /// + /// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html fn clone_with_var_store(&self, var_store: &VarStore) -> Self; /// A generalized forward function. fn forward(&self, input: &Self::Input) -> Self::Output; } -/// Neural network model that can be initialized with [VarStore] and configuration. +/// Neural network model that can be initialized with [`VarStore`] and configuration. /// -/// The difference from [SubModel] is that this trait takes two inputs. +/// The difference from [`SubModel`] is that this trait takes two inputs. +/// +/// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html pub trait SubModel2 { - /// Configuration from which [SubModel2] is constructed. + /// Configuration from which [`SubModel2`] is constructed. type Config; - /// Input of the [SubModel2]. + /// Input of the [`SubModel2`]. type Input1; - /// Input of the [SubModel2]. + /// Input of the [`SubModel2`]. type Input2; - /// Output of the [SubModel2]. + /// Output of the [`SubModel2`]. type Output; - /// Builds [SubModel2] with [VarStore] and [SubModel2::Config]. + /// Builds [`SubModel2`] with [VarStore] and [SubModel2::Config]. fn build(var_store: &VarStore, config: Self::Config) -> Self; - /// Clones [SubModel2] with [VarStore]. + /// Clones [`SubModel2`] with [`VarStore`]. + /// + /// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html fn clone_with_var_store(&self, var_store: &VarStore) -> Self; /// A generalized forward function. diff --git a/border-tch-agent/src/opt.rs b/border-tch-agent/src/opt.rs index 7b48b9c4..6a077c69 100644 --- a/border-tch-agent/src/opt.rs +++ b/border-tch-agent/src/opt.rs @@ -60,6 +60,8 @@ impl OptimizerConfig { /// Optimizers. /// /// This is a thin wrapper of [tch::nn::Optimizer]. +/// +/// [tch::nn::Optimizer]: https://docs.rs/tch/0.16.0/tch/nn/struct.Optimizer.html pub enum Optimizer { /// Adam optimizer. Adam(Optimizer_), diff --git a/border-tch-agent/src/sac.rs b/border-tch-agent/src/sac.rs index bd2b31ea..bf3e8215 100644 --- a/border-tch-agent/src/sac.rs +++ b/border-tch-agent/src/sac.rs @@ -1,10 +1,156 @@ //! SAC agent. //! -//! Here is an example in `border/examples/sac_pendulum.rs` +//! Here is an example of creating SAC agent: //! -//! ```rust,ignore +//! ```no_run +//! # use anyhow::Result; +//! use border_core::{ +//! # Env as Env_, Obs as Obs_, Act as Act_, Step, test::{ +//! # TestAct as TestAct_, TestActBatch as TestActBatch_, +//! # TestEnv as TestEnv_, +//! # TestObs as TestObs_, TestObsBatch as TestObsBatch_, +//! # }, +//! # record::Record, +//! # generic_replay_buffer::{SimpleReplayBuffer, BatchBase}, +//! Configurable, +//! }; +//! use border_tch_agent::{ +//! sac::{ActorConfig, CriticConfig, Sac, SacConfig}, +//! mlp::{Mlp, Mlp2, MlpConfig}, +//! opt::OptimizerConfig +//! }; +//! +//! # struct TestEnv(TestEnv_); +//! # #[derive(Clone, Debug)] +//! # struct TestObs(TestObs_); +//! # #[derive(Clone, Debug)] +//! # struct TestAct(TestAct_); +//! # struct TestObsBatch(TestObsBatch_); +//! # struct TestActBatch(TestActBatch_); +//! # +//! # impl Obs_ for TestObs { +//! # fn dummy(n: usize) -> Self { +//! # Self(TestObs_::dummy(n)) +//! # } +//! # +//! # fn len(&self) -> usize { +//! # self.0.len() +//! # } +//! # } +//! # +//! # impl Into for TestObs { +//! # fn into(self) -> tch::Tensor { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl BatchBase for TestObsBatch { +//! # fn new(n: usize) -> Self { +//! # Self(TestObsBatch_::new(n)) +//! # } +//! # +//! 
# fn push(&mut self, ix: usize, data: Self) { +//! # self.0.push(ix, data.0); +//! # } +//! # +//! # fn sample(&self, ixs: &Vec) -> Self { +//! # Self(self.0.sample(ixs)) +//! # } +//! # } +//! # +//! # impl BatchBase for TestActBatch { +//! # fn new(n: usize) -> Self { +//! # Self(TestActBatch_::new(n)) +//! # } +//! # +//! # fn push(&mut self, ix: usize, data: Self) { +//! # self.0.push(ix, data.0); +//! # } +//! # +//! # fn sample(&self, ixs: &Vec) -> Self { +//! # Self(self.0.sample(ixs)) +//! # } +//! # } +//! # +//! # impl Act_ for TestAct { +//! # fn len(&self) -> usize { +//! # self.0.len() +//! # } +//! # } +//! # +//! # impl From for TestAct { +//! # fn from(t: tch::Tensor) -> Self { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl Into for TestAct { +//! # fn into(self) -> tch::Tensor { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl Env_ for TestEnv { +//! # type Config = ::Config; +//! # type Obs = TestObs; +//! # type Act = TestAct; +//! # type Info = ::Info; +//! # +//! # fn build(config: &Self::Config, seed: i64) -> Result { +//! # Ok(Self(TestEnv_::build(&config, seed).unwrap())) +//! # } +//! # +//! # fn step(&mut self, act: &TestAct) -> (Step, Record) { +//! # let (step, record) = self.0.step(&act.0); +//! # let step = Step { +//! # obs: TestObs(step.obs), +//! # act: TestAct(step.act), +//! # reward: step.reward, +//! # is_terminated: step.is_terminated, +//! # is_truncated: step.is_truncated, +//! # info: step.info, +//! # init_obs: TestObs(step.init_obs), +//! # }; +//! # (step, record) +//! # } +//! # +//! # fn reset(&mut self, is_done: Option<&Vec>) -> Result { +//! # Ok(TestObs(self.0.reset(is_done).unwrap())) +//! # } +//! # +//! # fn step_with_reset(&mut self, a: &TestAct) -> (Step, Record) { +//! # let (step, record) = self.0.step_with_reset(&a.0); +//! # let step = Step { +//! # obs: TestObs(step.obs), +//! # act: TestAct(step.act), +//! # reward: step.reward, +//! # is_terminated: step.is_terminated, +//! # is_truncated: step.is_truncated, +//! # info: step.info, +//! # init_obs: TestObs(step.init_obs), +//! # }; +//! # (step, record) +//! # } +//! # +//! # fn reset_with_index(&mut self, ix: usize) -> Result { +//! # Ok(TestObs(self.0.reset_with_index(ix).unwrap())) +//! # } +//! # } +//! # +//! # type Env = TestEnv; +//! # type ObsBatch = TestObsBatch; +//! # type ActBatch = TestActBatch; +//! # type ReplayBuffer = SimpleReplayBuffer; +//! # +//! const DIM_OBS: i64 = 3; +//! const DIM_ACT: i64 = 1; +//! const LR_ACTOR: f64 = 1e-3; +//! const LR_CRITIC: f64 = 1e-3; +//! const BATCH_SIZE: usize = 256; +//! //! fn create_agent(in_dim: i64, out_dim: i64) -> Sac { //! let device = tch::Device::cuda_if_available(); +//! //! let actor_config = ActorConfig::default() //! .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) //! .out_dim(out_dim) @@ -12,25 +158,13 @@ //! let critic_config = CriticConfig::default() //! .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) //! .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, true)); -//! let sac_config = SacConfig::default() +//! let sac_config = SacConfig::::default() //! .batch_size(BATCH_SIZE) -//! .min_transitions_warmup(N_TRANSITIONS_WARMUP) //! .actor_config(actor_config) //! .critic_config(critic_config) //! .device(device); //! Sac::build(sac_config) //! } -//! -//! fn train(max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { -//! let trainer = //... -//! let mut agent = create_agent(DIM_OBS, DIM_ACT); -//! 
let mut recorder = TensorboardRecorder::new(model_dir); -//! let mut evaluator = Evaluator::new(&env_config(), 0, N_EPISODES_PER_EVAL)?; -//! -//! trainer.train(&mut agent, &mut recorder, &mut evaluator)?; -//! -//! Ok(()) -//! } //! ``` mod actor; mod base; diff --git a/border-tch-agent/src/sac/actor/base.rs b/border-tch-agent/src/sac/actor/base.rs index 69f5c610..756bbb77 100644 --- a/border-tch-agent/src/sac/actor/base.rs +++ b/border-tch-agent/src/sac/actor/base.rs @@ -10,8 +10,7 @@ use serde::{de::DeserializeOwned, Serialize}; use std::path::Path; use tch::{nn, Device, Tensor}; -#[allow(clippy::upper_case_acronyms)] -/// Represents a stochastic policy for SAC agents. +/// Stochastic policy for SAC agents. pub struct Actor

where P: SubModel, diff --git a/border-tch-agent/src/sac/actor/config.rs b/border-tch-agent/src/sac/actor/config.rs index 8d026a7f..05aec6e8 100644 --- a/border-tch-agent/src/sac/actor/config.rs +++ b/border-tch-agent/src/sac/actor/config.rs @@ -9,7 +9,7 @@ use std::{ #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -/// Configuration of [Actor](super::Actor). +/// Configuration of [`Actor`](super::Actor). pub struct ActorConfig { pub pi_config: Option

, pub opt_config: OptimizerConfig, diff --git a/border-tch-agent/src/sac/config.rs b/border-tch-agent/src/sac/config.rs index fb998577..4686ee32 100644 --- a/border-tch-agent/src/sac/config.rs +++ b/border-tch-agent/src/sac/config.rs @@ -18,8 +18,7 @@ use std::{ }; use tch::Tensor; -/// Constructs [Sac](super::Sac). -#[allow(clippy::upper_case_acronyms)] +/// Configuration of [`Sac`](super::Sac). #[derive(Debug, Deserialize, Serialize, PartialEq)] pub struct SacConfig where diff --git a/border-tch-agent/src/sac/critic/config.rs b/border-tch-agent/src/sac/critic/config.rs index 5d0b2d8c..20045aa4 100644 --- a/border-tch-agent/src/sac/critic/config.rs +++ b/border-tch-agent/src/sac/critic/config.rs @@ -9,7 +9,7 @@ use std::{ #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -/// Configuration of [Critic](super::Critic). +/// Configuration of [`Critic`](super::Critic). pub struct CriticConfig { pub q_config: Option, pub opt_config: OptimizerConfig, diff --git a/border-tch-agent/src/tensor_batch.rs b/border-tch-agent/src/tensor_batch.rs index ebb257ee..e1c3d38c 100644 --- a/border-tch-agent/src/tensor_batch.rs +++ b/border-tch-agent/src/tensor_batch.rs @@ -1,7 +1,9 @@ use border_core::generic_replay_buffer::BatchBase; use tch::{Device, Tensor}; -/// Adds capability of constructing [Tensor] with a static method. +/// Adds capability of constructing [`Tensor`] with a static method. +/// +/// [`Tensor`]: https://docs.rs/tch/0.16.0/tch/struct.Tensor.html pub trait ZeroTensor { /// Constructs zero tensor. fn zeros(shape: &[i64]) -> Tensor; @@ -37,6 +39,8 @@ impl ZeroTensor for i64 { /// where `shape` is obtained from the data pushed at the first time via /// [`TensorBatch::push`] method. `[1..]` means that the first axis of the /// given data is ignored as it might be batch size. +/// +/// [`Tensor`]: https://docs.rs/tch/0.16.0/tch/struct.Tensor.html pub struct TensorBatch { buf: Option, capacity: i64, diff --git a/border-tch-agent/src/util.rs b/border-tch-agent/src/util.rs index 795df663..dcc3d3e1 100644 --- a/border-tch-agent/src/util.rs +++ b/border-tch-agent/src/util.rs @@ -21,9 +21,11 @@ pub enum CriticLoss { SmoothL1, } -/// Apply soft update on a model. +/// Apply soft update on variables. /// /// Variables are identified by their names. +/// +/// dest = tau * src + (1.0 - tau) * dest pub fn track(dest: &mut M, src: &mut M, tau: f64) { let src = &mut src.get_var_store().variables(); let dest = &mut dest.get_var_store().variables(); @@ -47,15 +49,16 @@ pub fn concat_slices(s1: &[i64], s2: &[i64]) -> Vec { v } -/// Returns the dimension of output vectors, i.e., the number of discrete outputs. +/// Interface for handling output dimensions. pub trait OutDim { - /// Returns the dimension of output vectors, i.e., the number of discrete outputs. + /// Returns the output dimension. fn get_out_dim(&self) -> i64; /// Sets the output dimension. fn set_out_dim(&mut self, v: i64); } +/// Returns the mean and standard deviation of the parameters. 
pub fn param_stats(var_store: &VarStore) -> Record { let mut record = Record::empty(); From 83df74852e65b8bb77f6b078e3de7625e0613fcd Mon Sep 17 00:00:00 2001 From: taku-y Date: Sat, 3 Aug 2024 12:35:20 +0900 Subject: [PATCH 12/21] Add test in ci.yml Add test in ci.yml Add test in ci.yml Tweaks to ci.yml Add test in ci.yml Add test in ci.yml --- .github/workflows/ci.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 00fe9056..c1102637 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -101,3 +101,8 @@ jobs: cargo test --example sac_pendulum_tch --features=tch cargo test --example dqn_cartpole --features=candle-core cargo test --example sac_pendulum --features=candle-core + cd border-async-trainer; cargo test; cd .. + cd border-atari-env; cargo test; cd .. + cd border-candle-agent; cargo test; cd .. + cd border-tch-agent; cargo test; cd .. + cd border-policy-no-backend; cargo test; cd .. From bbc66f3b1bd9ca5893f1efc348efd6eca1ca3f15 Mon Sep 17 00:00:00 2001 From: Taku Yoshioka Date: Sat, 3 Aug 2024 11:19:20 +0900 Subject: [PATCH 13/21] Update docker scripts --- docker/amd64/Dockerfile | 10 +--- docker/amd64/build.sh | 1 + docker/amd64/remove.sh | 1 + docker/amd64/run.sh | 22 ++++---- docker/amd64_headless/Dockerfile | 86 +++++++++++++------------------- docker/amd64_headless/build.sh | 1 + docker/amd64_headless/doc.sh | 9 +++- 7 files changed, 58 insertions(+), 72 deletions(-) diff --git a/docker/amd64/Dockerfile b/docker/amd64/Dockerfile index 744586a7..fdbecf3b 100644 --- a/docker/amd64/Dockerfile +++ b/docker/amd64/Dockerfile @@ -73,7 +73,7 @@ RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y RUN cd /root && python3 -m venv venv RUN source /root/venv/bin/activate && pip3 install --upgrade pip RUN source /root/venv/bin/activate && pip3 install pyyaml typing-extensions -RUN source /root/venv/bin/activate && pip3 install torch==1.13.1 +RUN source /root/venv/bin/activate && pip3 install torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu --timeout 300 RUN source /root/venv/bin/activate && pip3 install ipython jupyterlab RUN source /root/venv/bin/activate && pip3 install numpy==1.21.3 RUN source /root/venv/bin/activate && pip3 install mujoco==2.3.7 @@ -95,12 +95,6 @@ RUN cd $HOME && mkdir -p .border/model # # PyBulletGym # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==3.2.5 # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==2.7.1 -# RUN source /home/ubuntu/venv/bin/activate && \ -# cd $HOME && \ -# git clone https://github.com/bulletphysics/bullet3.git && \ -# cd bullet3 && \ -# git checkout -b tmp 2c204c49e56ed15ec5fcfa71d199ab6d6570b3f5 && \ -# ./build_cmake_pybullet_double.sh # RUN cd $HOME && \ # git clone https://github.com/benelot/pybullet-gym.git && \ # cd pybullet-gym && \ @@ -121,7 +115,7 @@ RUN echo 'export CARGO_TARGET_DIR=$HOME/target' >> ~/.bashrc RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc RUN echo 'export MUJOCO_GL=glfw' >> ~/.bashrc RUN echo 'source $HOME/venv/bin/activate' >> ~/.bashrc -RUN echo 'export RUSTFLAGS="-C target-feature=+fp16"' >> ~/.bashrc +# RUN echo 'export RUSTFLAGS="-C target-feature=+fp16"' >> ~/.bashrc RUN rm /bin/sh && mv /bin/sh_tmp /bin/sh diff --git a/docker/amd64/build.sh b/docker/amd64/build.sh index 0eb76e0e..0936264c 100644 --- a/docker/amd64/build.sh +++ b/docker/amd64/build.sh @@ -1,2 +1,3 @@ #!/bin/bash docker build -t border . 
+#podman build -t border . diff --git a/docker/amd64/remove.sh b/docker/amd64/remove.sh index e7a325bc..3872196d 100644 --- a/docker/amd64/remove.sh +++ b/docker/amd64/remove.sh @@ -1 +1,2 @@ docker rm -f border +#podman rm -f border diff --git a/docker/amd64/run.sh b/docker/amd64/run.sh index c02762ca..be2f43f3 100644 --- a/docker/amd64/run.sh +++ b/docker/amd64/run.sh @@ -1,13 +1,13 @@ #!/bin/bash -# nvidia-docker run -it --rm \ -# --env="DISPLAY" \ -# --volume="/tmp/.X11-unix:/tmp/.X11-unix:rw" \ -# --volume="/home/taku-y:/home/taku-y" \ -# --name my_pybullet my_pybullet bash +nvidia-docker run -it --rm \ + --env="DISPLAY" \ + --volume="/tmp/.X11-unix:/tmp/.X11-unix:rw" \ + --volume="/home/taku-y:/home/taku-y" \ + --name my_pybullet my_pybullet bash -docker run -td \ - --name border \ - -p 6080:6080 \ - --shm-size=512m \ - --volume="$(pwd)/../..:/root/border" \ - border +# podman run -td \ +# --name border \ +# -p 6080:6080 \ +# --shm-size=512m \ +# --volume="$(pwd)/../..:/root/border" \ +# border diff --git a/docker/amd64_headless/Dockerfile b/docker/amd64_headless/Dockerfile index d460674d..c61e3ea0 100644 --- a/docker/amd64_headless/Dockerfile +++ b/docker/amd64_headless/Dockerfile @@ -1,13 +1,16 @@ -FROM ubuntu:focal-20221130 +FROM --platform=linux/amd64 ubuntu:22.04 ENV DEBIAN_FRONTEND noninteractive RUN echo "Set disable_coredump false" >> /etc/sudo.conf RUN apt-get update -q && \ apt-get upgrade -yq && \ - apt-get install -yq wget curl git build-essential vim sudo libssl-dev - -# lsb-release locales bash-completion tzdata gosu && \ -# RUN rm -rf /var/lib/apt/lists/* + apt-get install -yq wget +RUN apt-get install -yq curl +RUN apt-get install -yq git +RUN apt-get install -yq build-essential +RUN apt-get install -yq vim +# RUN apt-get install -yq sudo +RUN apt-get install -yq libssl-dev # clang RUN apt install -y -q libclang-dev @@ -18,7 +21,7 @@ RUN apt update -y && \ DEBIAN_FRONTEND=noninteractive && \ apt install -y -q --no-install-recommends \ libsdl2-dev libsdl2-image-dev libsdl2-mixer-dev libsdl2-net-dev libsdl2-ttf-dev \ - libsdl-dev libsdl-image1.2-dev + libsdl-image1.2-dev libsdl1.2-dev # zip RUN apt install -y zip @@ -27,7 +30,7 @@ RUN apt install -y zip RUN apt install -y swig # python -RUN apt install -y python3.8 python3.8-dev python3.8-distutils python3.8-venv python3-pip +RUN apt install -y python3.10 python3.10-dev python3.10-distutils python3.10-venv python3-pip # cmake RUN apt install -y cmake @@ -44,32 +47,25 @@ RUN apt install -y patchelf libglfw3 libglfw3-dev # Cleanup RUN rm -rf /var/lib/apt/lists/* -# COPY test_mujoco_py.py /test_mujoco_py.py -# RUN chmod 777 /test_mujoco_py.py - -# Add user -RUN useradd --create-home --home-dir /home/ubuntu --shell /bin/bash --user-group --groups adm,sudo ubuntu && \ - echo ubuntu:ubuntu | chpasswd && \ - echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers - # Use bash RUN mv /bin/sh /bin/sh_tmp && ln -s /bin/bash /bin/sh -# User settings -USER ubuntu - # rustup RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y # python -RUN cd /home/ubuntu && python3 -m venv venv -RUN source /home/ubuntu/venv/bin/activate && pip3 install --upgrade pip -RUN source /home/ubuntu/venv/bin/activate && pip3 install pyyaml typing-extensions -RUN source /home/ubuntu/venv/bin/activate && pip3 install torch==1.12.0 -RUN source /home/ubuntu/venv/bin/activate && pip3 install ipython jupyterlab -RUN source /home/ubuntu/venv/bin/activate && pip3 install numpy==1.21.3 -RUN source /home/ubuntu/venv/bin/activate && pip3 install 
gymnasium[box2d]==0.29.0 -RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN cd /root && python3 -m venv venv +RUN source /root/venv/bin/activate && pip3 install --upgrade pip +RUN source /root/venv/bin/activate && pip3 install pyyaml typing-extensions +RUN source /root/venv/bin/activate && pip3 install torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu --timeout 300 +RUN source /root/venv/bin/activate && pip3 install ipython jupyterlab +RUN source /root/venv/bin/activate && pip3 install numpy==1.21.3 +RUN source /root/venv/bin/activate && pip3 install mujoco==2.3.7 +RUN source /root/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 +RUN source /root/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN source /root/venv/bin/activate && pip3 install tensorboard==2.16.2 +RUN source /root/venv/bin/activate && pip3 install tabulate==0.9.0 +RUN source /root/venv/bin/activate && pip3 install mlflow-export-import==1.2.0 # RUN source /home/ubuntu/venv/bin/activate && pip3 install robosuite==1.3.2 # RUN source /home/ubuntu/venv/bin/activate && pip3 install -U 'mujoco-py<2.2,>=2.1' # RUN source /home/ubuntu/venv/bin/activate && pip3 install dm-control==1.0.8 @@ -79,20 +75,6 @@ RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium-robotics==1. # border RUN cd $HOME && mkdir -p .border/model -# Mujoco amd64 binary -RUN cd $HOME && \ - mkdir .mujoco && \ - cd .mujoco && \ - wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz -RUN cd $HOME/.mujoco && \ - tar zxf mujoco-2.1.1-linux-x86_64.tar.gz && \ - mkdir -p mujoco210/bin && \ - ln -sf $PWD/mujoco-2.1.1/lib/libmujoco.so.2.1.1 $PWD/mujoco210/bin/libmujoco210.so && \ - ln -sf $PWD/mujoco-2.1.1/lib/libglewosmesa.so $PWD/mujoco210/bin/libglewosmesa.so && \ - ln -sf $PWD/mujoco-2.1.1/include/ $PWD/mujoco210/include && \ - ln -sf $PWD/mujoco-2.1.1/model/ $PWD/mujoco210/model -# RUN cp /*.py $HOME - # # PyBulletGym # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==3.2.5 # # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==2.7.1 @@ -106,22 +88,22 @@ RUN cd $HOME/.mujoco && \ # RUN sed -i 's/return state, sum(self.rewards), bool(done), {}/return state, sum(self.rewards), bool(done), bool(done), {}/g' /home/ubuntu/pybullet-gym/pybulletgym/envs/roboschool/envs/locomotion/walker_base_env.py # RUN sed -i 's/id='\''AntPyBulletEnv-v0'\'',/id='\''AntPyBulletEnv-v0'\'', order_enforce=False,/g' /home/ubuntu/pybullet-gym/pybulletgym/envs/__init__.py -# Env vars -# RUN echo 'export LIBTORCH=$HOME/.local/lib/python3.8/site-packages/torch' >> ~/.bashrc -# RUN echo 'export LD_LIBRARY_PATH=$LIBTORCH/lib' >> ~/.bashrc -# RUN echo 'export PYTHONPATH=$HOME/border/border-py-gym-env/examples:$PYTHONPATH' >> ~/.bashrc -# RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc +# .bashrc +RUN echo 'export LIBTORCH=$HOME/venv/lib/python3.10/site-packages/torch' >> ~/.bashrc +RUN echo 'export LD_LIBRARY_PATH=$LIBTORCH/lib' >> ~/.bashrc +RUN echo 'export LIBTORCH_CXX11_ABI=0' >> ~/.bashrc RUN echo 'export PATH=$HOME/.local/bin:$PATH' >> ~/.bashrc -ENV LIBTORCH_CXX11_ABI 0 -ENV LIBTORCH /home/ubuntu/venv/lib/python3.8/site-packages/torch -ENV LD_LIBRARY_PATH $LIBTORCH/lib -ENV PYTHONPATH /home/ubuntu/border/border-py-gym-env/examples:$PYTHONPATH +RUN echo 'export PYTHONPATH=$HOME/border/border-py-gym-env/examples:$PYTHONPATH' >> ~/.bashrc +RUN echo 'export 
CARGO_TARGET_DIR=$HOME/target' >> ~/.bashrc +RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc +RUN echo 'export MUJOCO_GL=glfw' >> ~/.bashrc +RUN echo 'source $HOME/venv/bin/activate' >> ~/.bashrc +RUN echo 'export RUSTFLAGS="-C target-feature=+fp16"' >> ~/.bashrc -USER root RUN rm /bin/sh && mv /bin/sh_tmp /bin/sh -USER ubuntu -WORKDIR /home/ubuntu/border +# USER root +# WORKDIR /home/ubuntu/border # ENV USER ubuntu # CMD ["/bin/bash", "-l", "-c"] diff --git a/docker/amd64_headless/build.sh b/docker/amd64_headless/build.sh index 860261b3..86de24a2 100644 --- a/docker/amd64_headless/build.sh +++ b/docker/amd64_headless/build.sh @@ -1,2 +1,3 @@ #!/bin/bash docker build -t border_headless . +#podman build -t border_headless . diff --git a/docker/amd64_headless/doc.sh b/docker/amd64_headless/doc.sh index ac4d6098..ec32f92f 100644 --- a/docker/amd64_headless/doc.sh +++ b/docker/amd64_headless/doc.sh @@ -3,4 +3,11 @@ docker run -it --rm \ --shm-size=512m \ --volume="$(pwd)/../..:/home/ubuntu/border" \ border_headless bash -l -c \ - "CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." + "cd /home/ubuntu/border; source /root/venv/bin/activate; LIBTORCH_USE_PYTORCH=1 CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." + +# podman run -it --rm \ +# --name border_headless \ +# --shm-size=512m \ +# --volume="$(pwd)/../..:/home/ubuntu/border" \ +# border_headless bash -l -c \ +# "cd /home/ubuntu/border; source /root/venv/bin/activate; LIBTORCH_USE_PYTORCH=1 CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." From aeb056eceef9a6630b1b5ac920a7d8d619e6a0ba Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 4 Aug 2024 16:56:59 +0900 Subject: [PATCH 14/21] Rename crate: border-policy-no-backend --- Cargo.toml | 2 +- {border-edge-policy => border-policy-no-backend}/Cargo.toml | 2 +- {border-edge-policy => border-policy-no-backend}/src/lib.rs | 1 + {border-edge-policy => border-policy-no-backend}/src/mat.rs | 1 + {border-edge-policy => border-policy-no-backend}/src/mlp.rs | 0 {border-edge-policy => border-policy-no-backend}/tests/test.rs | 2 +- border/Cargo.toml | 2 +- border/examples/gym/convert_sac_policy_to_edge.rs | 2 +- border/examples/gym/pendulum_edge.rs | 2 +- 9 files changed, 8 insertions(+), 6 deletions(-) rename {border-edge-policy => border-policy-no-backend}/Cargo.toml (95%) rename {border-edge-policy => border-policy-no-backend}/src/lib.rs (66%) rename {border-edge-policy => border-policy-no-backend}/src/mat.rs (99%) rename {border-edge-policy => border-policy-no-backend}/src/mlp.rs (100%) rename {border-edge-policy => border-policy-no-backend}/tests/test.rs (93%) diff --git a/Cargo.toml b/Cargo.toml index 96ab8951..551830af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ members = [ "border-derive", "border-atari-env", "border-async-trainer", - "border-edge-policy", + "border-policy-no-backend", "border", ] exclude = ["docker/"] diff --git a/border-edge-policy/Cargo.toml b/border-policy-no-backend/Cargo.toml similarity index 95% rename from border-edge-policy/Cargo.toml rename to border-policy-no-backend/Cargo.toml index e5b58fb4..53bf671d 100644 --- a/border-edge-policy/Cargo.toml +++ b/border-policy-no-backend/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "border-edge-policy" +name = "border-policy-no-backend" version.workspace = true edition.workspace = 
true description.workspace = true diff --git a/border-edge-policy/src/lib.rs b/border-policy-no-backend/src/lib.rs similarity index 66% rename from border-edge-policy/src/lib.rs rename to border-policy-no-backend/src/lib.rs index 35a6a583..93053528 100644 --- a/border-edge-policy/src/lib.rs +++ b/border-policy-no-backend/src/lib.rs @@ -1,3 +1,4 @@ +//! Policy with no backend. mod mat; mod mlp; diff --git a/border-edge-policy/src/mat.rs b/border-policy-no-backend/src/mat.rs similarity index 99% rename from border-edge-policy/src/mat.rs rename to border-policy-no-backend/src/mat.rs index 05184476..5a429cd8 100644 --- a/border-edge-policy/src/mat.rs +++ b/border-policy-no-backend/src/mat.rs @@ -1,3 +1,4 @@ +//! A matrix object. use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] diff --git a/border-edge-policy/src/mlp.rs b/border-policy-no-backend/src/mlp.rs similarity index 100% rename from border-edge-policy/src/mlp.rs rename to border-policy-no-backend/src/mlp.rs diff --git a/border-edge-policy/tests/test.rs b/border-policy-no-backend/tests/test.rs similarity index 93% rename from border-edge-policy/tests/test.rs rename to border-policy-no-backend/tests/test.rs index 5e07eda2..f1b66b07 100644 --- a/border-edge-policy/tests/test.rs +++ b/border-policy-no-backend/tests/test.rs @@ -1,4 +1,4 @@ -use border_edge_policy::Mat; +use border_policy_no_backend::Mat; use tch::Tensor; #[test] diff --git a/border/Cargo.toml b/border/Cargo.toml index ca46e2b2..16124eba 100644 --- a/border/Cargo.toml +++ b/border/Cargo.toml @@ -144,7 +144,7 @@ border-derive = { version = "0.0.7", path = "../border-derive" } border-core = { version = "0.0.7", path = "../border-core" } border-tensorboard = { version = "0.0.7", path = "../border-tensorboard" } border-tch-agent = { version = "0.0.7", path = "../border-tch-agent" } -border-edge-policy = { version = "0.0.7", path = "../border-edge-policy" } +border-policy-no-backend = { version = "0.0.7", path = "../border-policy-no-backend" } border-py-gym-env = { version = "0.0.7", path = "../border-py-gym-env" } border-atari-env = { version = "0.0.7", path = "../border-atari-env" } border-candle-agent = { version = "0.0.7", path = "../border-candle-agent" } diff --git a/border/examples/gym/convert_sac_policy_to_edge.rs b/border/examples/gym/convert_sac_policy_to_edge.rs index 90af3d72..8a7294ca 100644 --- a/border/examples/gym/convert_sac_policy_to_edge.rs +++ b/border/examples/gym/convert_sac_policy_to_edge.rs @@ -1,6 +1,6 @@ use anyhow::Result; use border_core::{Agent, Configurable}; -use border_edge_policy::Mlp; +use border_policy_no_backend::Mlp; use border_tch_agent::{ mlp, model::ModelBase, diff --git a/border/examples/gym/pendulum_edge.rs b/border/examples/gym/pendulum_edge.rs index 7ac44794..e81b4f97 100644 --- a/border/examples/gym/pendulum_edge.rs +++ b/border/examples/gym/pendulum_edge.rs @@ -1,6 +1,6 @@ use anyhow::Result; use border_core::{DefaultEvaluator, Evaluator as _}; -use border_edge_policy::{Mat, Mlp}; +use border_policy_no_backend::{Mat, Mlp}; use border_py_gym_env::{ ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; From 35de283e4fcd8473669b209ce466b20f0fa4f403 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sat, 3 Aug 2024 12:35:20 +0900 Subject: [PATCH 15/21] Add test in ci.yml Add test in ci.yml Add test in ci.yml Tweaks to ci.yml Add test in ci.yml Add test in ci.yml Modify ci.yml --- .github/workflows/ci.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git 
a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 00fe9056..c899d81b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -101,3 +101,8 @@ jobs: cargo test --example sac_pendulum_tch --features=tch cargo test --example dqn_cartpole --features=candle-core cargo test --example sac_pendulum --features=candle-core + cd border-async-trainer; cargo test; cd .. + cd border-atari-env; cargo test; cd .. + cd border-candle-agent; cargo test; cd .. + cd border-tch-agent; cargo test; cd .. + cd border-policy-no-backend; cargo test --features=border-tch-agent; cd .. From 15df04f4946b5f9a3be52bae65c87b934e87a7f0 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 4 Aug 2024 16:56:59 +0900 Subject: [PATCH 16/21] Rename crate: border-policy-no-backend --- Cargo.toml | 2 +- {border-edge-policy => border-policy-no-backend}/Cargo.toml | 2 +- {border-edge-policy => border-policy-no-backend}/src/lib.rs | 1 + {border-edge-policy => border-policy-no-backend}/src/mat.rs | 1 + {border-edge-policy => border-policy-no-backend}/src/mlp.rs | 0 {border-edge-policy => border-policy-no-backend}/tests/test.rs | 2 +- border/Cargo.toml | 2 +- border/examples/gym/convert_sac_policy_to_edge.rs | 2 +- border/examples/gym/pendulum_edge.rs | 2 +- 9 files changed, 8 insertions(+), 6 deletions(-) rename {border-edge-policy => border-policy-no-backend}/Cargo.toml (95%) rename {border-edge-policy => border-policy-no-backend}/src/lib.rs (66%) rename {border-edge-policy => border-policy-no-backend}/src/mat.rs (99%) rename {border-edge-policy => border-policy-no-backend}/src/mlp.rs (100%) rename {border-edge-policy => border-policy-no-backend}/tests/test.rs (93%) diff --git a/Cargo.toml b/Cargo.toml index 96ab8951..551830af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ members = [ "border-derive", "border-atari-env", "border-async-trainer", - "border-edge-policy", + "border-policy-no-backend", "border", ] exclude = ["docker/"] diff --git a/border-edge-policy/Cargo.toml b/border-policy-no-backend/Cargo.toml similarity index 95% rename from border-edge-policy/Cargo.toml rename to border-policy-no-backend/Cargo.toml index e5b58fb4..53bf671d 100644 --- a/border-edge-policy/Cargo.toml +++ b/border-policy-no-backend/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "border-edge-policy" +name = "border-policy-no-backend" version.workspace = true edition.workspace = true description.workspace = true diff --git a/border-edge-policy/src/lib.rs b/border-policy-no-backend/src/lib.rs similarity index 66% rename from border-edge-policy/src/lib.rs rename to border-policy-no-backend/src/lib.rs index 35a6a583..93053528 100644 --- a/border-edge-policy/src/lib.rs +++ b/border-policy-no-backend/src/lib.rs @@ -1,3 +1,4 @@ +//! Policy with no backend. mod mat; mod mlp; diff --git a/border-edge-policy/src/mat.rs b/border-policy-no-backend/src/mat.rs similarity index 99% rename from border-edge-policy/src/mat.rs rename to border-policy-no-backend/src/mat.rs index 05184476..5a429cd8 100644 --- a/border-edge-policy/src/mat.rs +++ b/border-policy-no-backend/src/mat.rs @@ -1,3 +1,4 @@ +//! A matrix object. 
use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] diff --git a/border-edge-policy/src/mlp.rs b/border-policy-no-backend/src/mlp.rs similarity index 100% rename from border-edge-policy/src/mlp.rs rename to border-policy-no-backend/src/mlp.rs diff --git a/border-edge-policy/tests/test.rs b/border-policy-no-backend/tests/test.rs similarity index 93% rename from border-edge-policy/tests/test.rs rename to border-policy-no-backend/tests/test.rs index 5e07eda2..f1b66b07 100644 --- a/border-edge-policy/tests/test.rs +++ b/border-policy-no-backend/tests/test.rs @@ -1,4 +1,4 @@ -use border_edge_policy::Mat; +use border_policy_no_backend::Mat; use tch::Tensor; #[test] diff --git a/border/Cargo.toml b/border/Cargo.toml index ca46e2b2..16124eba 100644 --- a/border/Cargo.toml +++ b/border/Cargo.toml @@ -144,7 +144,7 @@ border-derive = { version = "0.0.7", path = "../border-derive" } border-core = { version = "0.0.7", path = "../border-core" } border-tensorboard = { version = "0.0.7", path = "../border-tensorboard" } border-tch-agent = { version = "0.0.7", path = "../border-tch-agent" } -border-edge-policy = { version = "0.0.7", path = "../border-edge-policy" } +border-policy-no-backend = { version = "0.0.7", path = "../border-policy-no-backend" } border-py-gym-env = { version = "0.0.7", path = "../border-py-gym-env" } border-atari-env = { version = "0.0.7", path = "../border-atari-env" } border-candle-agent = { version = "0.0.7", path = "../border-candle-agent" } diff --git a/border/examples/gym/convert_sac_policy_to_edge.rs b/border/examples/gym/convert_sac_policy_to_edge.rs index 90af3d72..8a7294ca 100644 --- a/border/examples/gym/convert_sac_policy_to_edge.rs +++ b/border/examples/gym/convert_sac_policy_to_edge.rs @@ -1,6 +1,6 @@ use anyhow::Result; use border_core::{Agent, Configurable}; -use border_edge_policy::Mlp; +use border_policy_no_backend::Mlp; use border_tch_agent::{ mlp, model::ModelBase, diff --git a/border/examples/gym/pendulum_edge.rs b/border/examples/gym/pendulum_edge.rs index 7ac44794..e81b4f97 100644 --- a/border/examples/gym/pendulum_edge.rs +++ b/border/examples/gym/pendulum_edge.rs @@ -1,6 +1,6 @@ use anyhow::Result; use border_core::{DefaultEvaluator, Evaluator as _}; -use border_edge_policy::{Mat, Mlp}; +use border_policy_no_backend::{Mat, Mlp}; use border_py_gym_env::{ ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; From 6cc609159d952aa453f3044317a74ec5bcd56aeb Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 4 Aug 2024 20:44:05 +0900 Subject: [PATCH 17/21] Remove obsoleted structs in border-py-gym-env --- border-py-gym-env/src/act_c.rs | 17 -- border-py-gym-env/src/act_c/base.rs | 28 -- border-py-gym-env/src/act_d.rs | 3 - border-py-gym-env/src/act_d/base.rs | 21 -- border-py-gym-env/src/lib.rs | 6 - border-py-gym-env/src/obs.rs | 5 - border-py-gym-env/src/obs/base.rs | 119 --------- .../src/obs/frame_stack_filter.rs | 250 ------------------ 8 files changed, 449 deletions(-) delete mode 100644 border-py-gym-env/src/act_c.rs delete mode 100644 border-py-gym-env/src/act_c/base.rs delete mode 100644 border-py-gym-env/src/act_d.rs delete mode 100644 border-py-gym-env/src/act_d/base.rs delete mode 100644 border-py-gym-env/src/obs/base.rs delete mode 100644 border-py-gym-env/src/obs/frame_stack_filter.rs diff --git a/border-py-gym-env/src/act_c.rs b/border-py-gym-env/src/act_c.rs deleted file mode 100644 index 574f1c13..00000000 --- a/border-py-gym-env/src/act_c.rs +++ /dev/null @@ -1,17 +0,0 @@ 
-//! Continuous action for [`GymEnv`](crate::GymEnv). -mod base; -pub use base::GymContinuousAct; -use ndarray::ArrayD; -use numpy::PyArrayDyn; -use pyo3::{IntoPy, PyObject}; - -/// Convert [`ArrayD`] to [`PyObject`]. -/// -/// This function does not support batch action. -pub fn to_pyobj(act: ArrayD) -> PyObject { - // let act = act.remove_axis(ndarray::Axis(0)); - pyo3::Python::with_gil(|py| { - let act = PyArrayDyn::::from_array(py, &act); - act.into_py(py) - }) -} diff --git a/border-py-gym-env/src/act_c/base.rs b/border-py-gym-env/src/act_c/base.rs deleted file mode 100644 index 4919bd58..00000000 --- a/border-py-gym-env/src/act_c/base.rs +++ /dev/null @@ -1,28 +0,0 @@ -use border_core::Act; -use ndarray::ArrayD; -use std::fmt::Debug; - -/// Represents an action. -#[derive(Clone, Debug)] -pub struct GymContinuousAct { - /// Stores an action. - pub act: ArrayD, -} - -impl GymContinuousAct { - /// Constructs an action. - pub fn new(act: ArrayD) -> Self { - Self { act } - } -} - -impl Act for GymContinuousAct { - fn len(&self) -> usize { - let shape = self.act.shape(); - if shape.len() == 1 { - 1 - } else { - shape[0] - } - } -} diff --git a/border-py-gym-env/src/act_d.rs b/border-py-gym-env/src/act_d.rs deleted file mode 100644 index 7cac4219..00000000 --- a/border-py-gym-env/src/act_d.rs +++ /dev/null @@ -1,3 +0,0 @@ -//! Discrete action for [`GymEnv`](crate::GymEnv). -mod base; -pub use base::GymDiscreteAct; diff --git a/border-py-gym-env/src/act_d/base.rs b/border-py-gym-env/src/act_d/base.rs deleted file mode 100644 index 5afb829b..00000000 --- a/border-py-gym-env/src/act_d/base.rs +++ /dev/null @@ -1,21 +0,0 @@ -use border_core::Act; -use std::fmt::Debug; - -/// Represents action. -#[derive(Clone, Debug)] -pub struct GymDiscreteAct { - pub act: Vec, -} - -impl GymDiscreteAct { - /// Constructs a discrete action. - pub fn new(act: Vec) -> Self { - Self { act } - } -} - -impl Act for GymDiscreteAct { - fn len(&self) -> usize { - self.act.len() - } -} diff --git a/border-py-gym-env/src/lib.rs b/border-py-gym-env/src/lib.rs index c84f0612..ccdca9b9 100644 --- a/border-py-gym-env/src/lib.rs +++ b/border-py-gym-env/src/lib.rs @@ -30,8 +30,6 @@ //! Examples with `border-tch-agents`, which are collections of RL agents implemented with tch-rs, //! are in [here](https://github.com/taku-y/border/blob/main/border/examples). mod act; -mod act_c; -mod act_d; mod atari; mod base; mod config; @@ -41,14 +39,10 @@ mod vec; pub use act::{ ContinuousActFilter, ContinuousActFilterConfig, DiscreteActFilter, DiscreteActFilterConfig, }; -pub use act_c::{to_pyobj, GymContinuousAct}; -pub use act_d::GymDiscreteAct; pub use atari::AtariWrapper; pub use base::{GymActFilter, GymEnv, GymInfo, GymObsFilter}; pub use config::GymEnvConfig; #[allow(deprecated)] pub use obs::{ ArrayDictObsFilter, ArrayDictObsFilterConfig, ArrayObsFilter, ArrayObsFilterConfig, - FrameStackFilter, FrameStackFilterConfig, GymObs, }; -// pub use vec::{PyVecGymEnv, PyVecGymEnvConfig}; diff --git a/border-py-gym-env/src/obs.rs b/border-py-gym-env/src/obs.rs index 7da2730e..f40a5d3e 100644 --- a/border-py-gym-env/src/obs.rs +++ b/border-py-gym-env/src/obs.rs @@ -1,10 +1,5 @@ //! Observation for [`GymEnv`](crate::GymEnv). 
mod array_dict_filter; mod array_filter; -mod base; -mod frame_stack_filter; pub use array_dict_filter::{ArrayDictObsFilter, ArrayDictObsFilterConfig}; pub use array_filter::{ArrayObsFilter, ArrayObsFilterConfig}; -#[allow(deprecated)] -pub use base::GymObs; -pub use frame_stack_filter::{FrameStackFilter, FrameStackFilterConfig}; diff --git a/border-py-gym-env/src/obs/base.rs b/border-py-gym-env/src/obs/base.rs deleted file mode 100644 index b72fb680..00000000 --- a/border-py-gym-env/src/obs/base.rs +++ /dev/null @@ -1,119 +0,0 @@ -use crate::util::pyobj_to_arrayd; -use border_core::Obs; -use ndarray::{ArrayD, IxDyn}; -use num_traits::cast::AsPrimitive; -use numpy::Element; -use pyo3::PyObject; -use std::fmt::Debug; -use std::marker::PhantomData; -#[cfg(feature = "tch")] -use {std::convert::TryFrom, tch::Tensor}; - -/// Observation represented by an [ndarray::ArrayD]. -/// -/// `S` is the shape of an observation, except for batch and process dimensions. -/// `T` is the dtype of ndarray in the Python gym environment. -/// For some reason, the dtype of observations in Python gym environments seems to -/// vary, f32 or f64. To get observations in Rust side, the dtype is specified as a -/// type parameter, instead of checking the dtype of Python array at runtime. -#[deprecated] -#[derive(Clone, Debug)] -pub struct GymObs -where - T1: Element + Debug, - T2: 'static + Copy, -{ - pub obs: ArrayD, - pub(crate) phantom: PhantomData, -} - -#[allow(deprecated)] -impl From> for GymObs -where - T1: Element + Debug, - T2: 'static + Copy, -{ - fn from(obs: ArrayD) -> Self { - Self { - obs, - phantom: PhantomData, - } - } -} - -#[allow(deprecated)] -impl Obs for GymObs -where - T1: Debug + Element, - T2: 'static + Copy + Debug + num_traits::Zero, -{ - fn dummy(_n_procs: usize) -> Self { - // let shape = &mut S::shape().to_vec(); - // shape.insert(0, n_procs as _); - // trace!("Shape of TchPyGymEnvObs: {:?}", shape); - let shape = vec![0]; - Self { - obs: ArrayD::zeros(IxDyn(&shape[..])), - phantom: PhantomData, - } - } - - fn len(&self) -> usize { - self.obs.shape()[0] - } -} - -/// Convert numpy array of Python into [`GymObs`]. -#[allow(deprecated)] -impl From for GymObs -where - T1: Element + AsPrimitive + std::fmt::Debug, - T2: 'static + Copy, -{ - fn from(obs: PyObject) -> Self { - Self { - obs: pyobj_to_arrayd::(obs), - phantom: PhantomData, - } - } -} - -// #[cfg(feature = "tch")] -// impl From> for Tensor -// where -// S: Shape, -// T1: Element + Debug, -// T2: 'static + Copy, -// { -// fn from(obs: PyGymEnvObs) -> Tensor { -// let tmp = &obs.obs; -// Tensor::try_from(tmp).unwrap() -// // Tensor::try_from(&obs.obs).unwrap() -// } -// } - -#[allow(deprecated)] -#[cfg(feature = "tch")] -impl From> for Tensor -where - T1: Element + Debug, -{ - fn from(obs: GymObs) -> Tensor { - let tmp = &obs.obs; - Tensor::try_from(tmp).unwrap() - // Tensor::try_from(&obs.obs).unwrap() - } -} - -#[allow(deprecated)] -#[cfg(feature = "tch")] -impl From> for Tensor -where - T1: Element + Debug, -{ - fn from(obs: GymObs) -> Tensor { - let tmp = &obs.obs; - Tensor::try_from(tmp).unwrap() - // Tensor::try_from(&obs.obs).unwrap() - } -} diff --git a/border-py-gym-env/src/obs/frame_stack_filter.rs b/border-py-gym-env/src/obs/frame_stack_filter.rs deleted file mode 100644 index 4f42f0de..00000000 --- a/border-py-gym-env/src/obs/frame_stack_filter.rs +++ /dev/null @@ -1,250 +0,0 @@ -//! An observation filter with stacking observations (frames). 
-#[allow(deprecated)] -use super::GymObs; -use crate::GymObsFilter; -use border_core::{ - record::{Record, RecordValue}, - Obs, -}; -use ndarray::{ArrayD, Axis, SliceInfoElem}; //, SliceOrIndex}; - // use ndarray::{stack, ArrayD, Axis, IxDyn, SliceInfo, SliceInfoElem}; -use num_traits::cast::AsPrimitive; -use numpy::{Element, PyArrayDyn}; -use pyo3::{PyAny, PyObject}; -// use pyo3::{types::PyList, Py, PyAny, PyObject}; -use serde::{Deserialize, Serialize}; -use std::{fmt::Debug, marker::PhantomData}; -// use std::{convert::TryFrom, fmt::Debug, marker::PhantomData}; - -#[allow(deprecated)] -#[derive(Debug, Serialize, Deserialize)] -/// Configuration of [FrameStackFilter]. -#[derive(Clone)] -pub struct FrameStackFilterConfig { - n_procs: i64, - n_stack: i64, - vectorized: bool, -} - -impl Default for FrameStackFilterConfig { - fn default() -> Self { - Self { - n_procs: 1, - n_stack: 4, - vectorized: false, - } - } -} - -/// An observation filter with stacking sequence of original observations. -/// -/// The first element of the shape `S` denotes the number of stacks (`n_stack`) and the following elements -/// denote the shape of the partial observation, which is the observation of each environment -/// in the vectorized environment. -#[allow(deprecated)] -#[derive(Debug)] -pub struct FrameStackFilter -where - T1: Element + Debug + num_traits::identities::Zero + AsPrimitive, - T2: 'static + Copy + num_traits::Zero, - U: Obs + From>, -{ - // Each element in the vector corresponds to a process. - buffers: Vec>>, - - #[allow(dead_code)] - n_procs: i64, - - n_stack: i64, - - shape: Option>, - - // Verctorized environment is not supported - vectorized: bool, - - phantom: PhantomData<(T1, U)>, -} - -#[allow(deprecated)] -impl FrameStackFilter -where - T1: Element + Debug + num_traits::identities::Zero + AsPrimitive, - T2: 'static + Copy + num_traits::Zero, - U: Obs + From>, -{ - /// Returns the default configuration. - pub fn default_config() -> FrameStackFilterConfig { - FrameStackFilterConfig::default() - } - - /// Create slice for a dynamic array: equivalent to arr[j:(j+1), ::] in numpy. - /// - /// See - fn s(shape: &Option>, j: usize) -> Vec { - // The first index of self.shape corresponds to stacking dimension, - // specific index. - let mut slicer = vec![SliceInfoElem::Index(j as isize)]; - - // For remaining dimensions, all elements will be taken. - let n = shape.as_ref().unwrap().len() - 1; - let (start, end, step) = (0, None, 1); - - slicer.extend(vec![SliceInfoElem::Slice { start, end, step }; n]); - slicer - } - - /// Update the buffer of the stacked observations. - /// - /// * `i` - Index of process. 
- fn update_buffer(&mut self, i: i64, obs: &ArrayD) { - let arr = if let Some(arr) = &mut self.buffers[i as usize] { - arr - } else { - let mut shape = obs.shape().to_vec(); - self.shape = Some(shape.clone()); - shape.insert(0, self.n_stack as _); - self.buffers[i as usize] = Some(ArrayD::zeros(shape)); - self.buffers[i as usize].as_mut().unwrap() - }; - - // Shift stacks frame(j) <- frame(j - 1) for j=1,..,(n_stack - 1) - for j in (1..self.n_stack as usize).rev() { - let dst_slice = Self::s(&self.shape, j); - let src_slice = Self::s(&self.shape, j - 1); - let (mut dst, src) = arr.multi_slice_mut((dst_slice.as_slice(), src_slice.as_slice())); - dst.assign(&src); - } - arr.slice_mut(Self::s(&self.shape, 0).as_slice()) - .assign(obs) - } - - /// Fill the buffer, invoked when resetting - fn fill_buffer(&mut self, i: i64, obs: &ArrayD) { - if let Some(arr) = &mut self.buffers[i as usize] { - for j in (0..self.n_stack as usize).rev() { - let mut dst = arr.slice_mut(Self::s(&self.shape, j).as_slice()); - dst.assign(&obs); - } - } else { - unimplemented!("fill_buffer() was called before receiving the first sample."); - } - } - - /// Get ndarray from pyobj - fn get_ndarray(o: &PyAny) -> ArrayD { - debug_assert_eq!(o.get_type().name().unwrap(), "ndarray"); - let o: &PyArrayDyn = o.extract().unwrap(); - let o = o.to_owned_array(); - let o = o.mapv(|elem| elem.as_()); - o - } -} - -#[allow(deprecated)] -impl GymObsFilter for FrameStackFilter -where - T1: Element + Debug + num_traits::identities::Zero + AsPrimitive, - T2: 'static + Copy + num_traits::Zero + Into, - U: Obs + From>, -{ - type Config = FrameStackFilterConfig; - - fn build(config: &Self::Config) -> anyhow::Result - where - Self: Sized, - { - Ok(FrameStackFilter { - buffers: vec![None; config.n_procs as usize], - n_procs: config.n_procs, - n_stack: config.n_stack, - shape: None, - vectorized: config.vectorized, - phantom: PhantomData, - }) - } - - fn filt(&mut self, obs: PyObject) -> (U, Record) { - if self.vectorized { - unimplemented!(); - // // Processes the input observation to update `self.buffer` - // pyo3::Python::with_gil(|py| { - // debug_assert_eq!(obs.as_ref(py).get_type().name().unwrap(), "list"); - - // let obs: Py = obs.extract(py).unwrap(); - - // for (i, o) in (0..self.n_procs).zip(obs.as_ref(py).iter()) { - // let o = Self::get_ndarray(o); - // self.update_buffer(i, &o); - // } - // }); - - // // Returned values - // let array_views: Vec<_> = self.buffer.iter().map(|a| a.view()).collect(); - // let obs = PyGymEnvObs::from(stack(Axis(0), array_views.as_slice()).unwrap()); - // let obs = U::from(obs); - - // // TODO: add contents in the record - // let record = Record::empty(); - - // (obs, record) - } else { - // Update the buffer with obs - pyo3::Python::with_gil(|py| { - debug_assert_eq!(obs.as_ref(py).get_type().name().unwrap(), "ndarray"); - let o = Self::get_ndarray(obs.as_ref(py)); - self.update_buffer(0, &o); - }); - - // Returns stacked observation in the buffer - // img.shape() = [1, 4, 1, 84, 84] - // [batch_size, n_stack, color_ch, width, height] - let img = self.buffers[0].clone().unwrap().insert_axis(Axis(0)); - let data = img.iter().map(|&e| e.into()).collect::>(); - let shape = [img.shape()[3] * self.n_stack as usize, img.shape()[4]]; - - let obs = GymObs::from(img); - let obs = U::from(obs); - - // TODO: add contents in the record - let mut record = Record::empty(); - record.insert("frame_stack_filter_out", RecordValue::Array2(data, shape)); - - (obs, record) - } - } - - fn reset(&mut self, obs: PyObject) -> 
U { - if self.vectorized { - unimplemented!(); - // pyo3::Python::with_gil(|py| { - // debug_assert_eq!(obs.as_ref(py).get_type().name().unwrap(), "list"); - - // let obs: Py = obs.extract(py).unwrap(); - - // for (i, o) in (0..self.n_procs).zip(obs.as_ref(py).iter()) { - // if o.get_type().name().unwrap() != "NoneType" { - // let o = Self::get_ndarray(o); - // self.fill_buffer(i, &o); - // } - // } - // }); - - // // Returned values - // let array_views: Vec<_> = self.buffer.iter().map(|a| a.view()).collect(); - // O::from(stack(Axis(0), array_views.as_slice()).unwrap()) - } else { - // Update the buffer if obs is not None, otherwise do nothing - pyo3::Python::with_gil(|py| { - if obs.as_ref(py).get_type().name().unwrap() != "NoneType" { - debug_assert_eq!(obs.as_ref(py).get_type().name().unwrap(), "ndarray"); - let o = Self::get_ndarray(obs.as_ref(py)); - self.fill_buffer(0, &o); - } - }); - - // Returns stacked observation in the buffer - let frames = self.buffers[0].clone().unwrap().insert_axis(Axis(0)); - U::from(GymObs::from(frames)) - } - } -} From 8af45012372ec843ae2932b51685d6e7f437befc Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 4 Aug 2024 23:48:23 +0900 Subject: [PATCH 18/21] Improve docstring in border-py-gym-env --- .../src/act/continuous_filter.rs | 2 +- border-py-gym-env/src/atari.rs | 1 + border-py-gym-env/src/base.rs | 6 +- border-py-gym-env/src/lib.rs | 63 +++++++++++-------- border-py-gym-env/src/obs/array_filter.rs | 2 + border-py-gym-env/src/util.rs | 1 + 6 files changed, 46 insertions(+), 29 deletions(-) diff --git a/border-py-gym-env/src/act/continuous_filter.rs b/border-py-gym-env/src/act/continuous_filter.rs index eb0350cf..68251c92 100644 --- a/border-py-gym-env/src/act/continuous_filter.rs +++ b/border-py-gym-env/src/act/continuous_filter.rs @@ -22,7 +22,7 @@ impl Default for ContinuousActFilterConfig { /// Raw filter for continuous actions. /// -/// Type `A` must implements `Into>` +/// Type `A` must implements `Into>`. #[derive(Clone, Debug)] pub struct ContinuousActFilter { // `true` indicates that this filter is used in a vectorized environment. diff --git a/border-py-gym-env/src/atari.rs b/border-py-gym-env/src/atari.rs index 6b64f1ad..cd3b130e 100644 --- a/border-py-gym-env/src/atari.rs +++ b/border-py-gym-env/src/atari.rs @@ -4,6 +4,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] /// Specifies training or evaluation mode. #[derive(Clone)] +// TODO: consider to remove this enum pub enum AtariWrapper { /// Training mode Train, diff --git a/border-py-gym-env/src/base.rs b/border-py-gym-env/src/base.rs index 11eb19f2..e0d09954 100644 --- a/border-py-gym-env/src/base.rs +++ b/border-py-gym-env/src/base.rs @@ -22,6 +22,8 @@ pub struct GymInfo {} impl Info for GymInfo {} /// Convert [`PyObject`] to [`GymEnv`]::Obs with a preprocessing. +/// +/// [`PyObject`]: https://docs.rs/pyo3/0.14.5/pyo3/type.PyObject.html pub trait GymObsFilter { /// Configuration. type Config: Clone + Default + Serialize + DeserializeOwned; @@ -50,7 +52,7 @@ pub trait GymObsFilter { /// Convert [`GymEnv`]::Act to [`PyObject`] with a preprocessing. /// -/// This trait should support vectorized environments. +/// [`PyObject`]: https://docs.rs/pyo3/0.14.5/pyo3/type.PyObject.html pub trait GymActFilter { /// Configuration. type Config: Clone + Default + Serialize + DeserializeOwned; @@ -79,7 +81,7 @@ pub trait GymActFilter { } } -/// An environment in [OpenAI gym](https://github.com/openai/gym). 
+/// A wrapper of [Gymnasium](https://gymnasium.farama.org). #[derive(Debug)] pub struct GymEnv where diff --git a/border-py-gym-env/src/lib.rs b/border-py-gym-env/src/lib.rs index ccdca9b9..80c4b101 100644 --- a/border-py-gym-env/src/lib.rs +++ b/border-py-gym-env/src/lib.rs @@ -4,31 +4,42 @@ //! It has been tested on some of [classic control](https://gymnasium.farama.org/environments/classic_control/) and //! [Gymnasium-Robotics](https://robotics.farama.org) environments. //! -//! ```note -//! In a past, [`Atari`](https://gym.openai.com/envs/#atari), and -//! [`PyBullet`](https://github.com/benelot/pybullet-gym) environments were supported. -//! However, currently they are not tested. -//! ``` -//! -//! This wrapper accepts array-like observation and action -//! ([`Box`](https://github.com/openai/gym/blob/master/gym/spaces/box.py) spaces), and -//! discrete action. In order to interact with Python interpreter where gym is running, -//! [`GymObsFilter`] and [`GymActFilter`] provides interfaces for converting Python object -//! (numpy array) to/from ndarrays in Rust. [`GymObsFilter`], -//! [`ContinuousActFilter`] and [`DiscreteActFilter`] do the conversion for environments -//! where observation and action are arrays. In addition to the data conversion between Python and Rust, -//! we can implements arbitrary preprocessing in these filters. For example, [`FrameStackFilter`] keeps -//! four consevutive observation frames (images) and outputs a stack of these frames. -//! -//! For Atari environments, a tweaked version of -//! [`atari_wrapper.py`](https://github.com/taku-y/border/blob/main/examples/atari_wrappers.py) -//! is required to be in `PYTHONPATH`. The frame stacking preprocessing is implemented in -//! [`FrameStackFilter`] as an [`GymObsFilter`]. -//! -//! Examples with a random controller ([`Policy`](border_core::Policy)) are in -//! [`examples`](https://github.com/taku-y/border/blob/main/border-py-gym-env/examples) directory. -//! Examples with `border-tch-agents`, which are collections of RL agents implemented with tch-rs, -//! are in [here](https://github.com/taku-y/border/blob/main/border/examples). +//! In order to bridge Python and Rust, we need to convert Python objects to Rust objects and vice versa. +//! +//! ## Observation +//! +//! An observation is created in Python and passed to Rust as a Python object. In order to convert +//! a Python object to a Rust object, this crate provides the [`GymObsFilter`] trait. This trait has +//! the [`GymObsFilter::filt`] method, which converts a Python object to a Rust object. +//! The type of the Rust object after conversion corresponds to the type parameter `O` of the trait +//! and this is also the type of the observation in the environment, i.e., [`GymEnv`]`::Obs`. +//! +//! There are two built-in implementations of [`GymObsFilter`]: [`ArrayObsFilter`] and [`ArrayDictObsFilter`]. +//! [`ArrayObsFilter`] is for environments where the observation is an array (e.g., CartPole). +//! Internally, the array is converted to [`ndarray::ArrayD`] from the Python object. +//! Then, the array is converted to the type parameter `O` of the filter. +//! Since `O` must implement [`From`] by trait bound, the conversion is done +//! by calling `array.into()`. +//! +//! [`ArrayDictObsFilter`] is for environments where the observation is a dictionary of arrays (e.g., FetchPickAndPlace). +//! Internally, the dictionary is converted to `Vec<(String, border_py_gym_env::util::Array)>` from the Python object. +//!
Then, `Vec<(String, border_py_gym_env::util::Array)>` is converted to `O` by calling `into()`. +//! +//! ## Action +//! +//! An action is created in [`Policy`] and passed to Python as a Python object. In order to convert +//! a Rust object to a Python object, this crate provides the [`GymActFilter`] trait. This trait has +//! the [`GymActFilter::filt`] method, which converts a Rust object of type `A`, the type parameter of +//! the trait, to a Python object. +//! +//! There are two built-in implementations of [`GymActFilter`]: [`DiscreteActFilter`] and [`ContinuousActFilter`]. +//! [`DiscreteActFilter`] is for environments where the action is discrete (e.g., CartPole). +//! This filter converts `A` to [`Vec`] and then to a Python object. +//! [`ContinuousActFilter`] is for environments where the action is continuous (e.g., Pendulum). +//! This filter converts `A` to [`ArrayD`] and then to a Python object. +//! +//! [`Policy`]: border_core::Policy +//! [`ArrayD`]: https://docs.rs/ndarray/0.15.1/ndarray/type.ArrayD.html mod act; mod atari; mod base; @@ -39,7 +50,7 @@ pub use act::{ ContinuousActFilter, ContinuousActFilterConfig, DiscreteActFilter, DiscreteActFilterConfig, }; -pub use atari::AtariWrapper; +use atari::AtariWrapper; pub use base::{GymActFilter, GymEnv, GymInfo, GymObsFilter}; pub use config::GymEnvConfig; #[allow(deprecated)] diff --git a/border-py-gym-env/src/obs/array_filter.rs b/border-py-gym-env/src/obs/array_filter.rs index 16a5d899..0cb4023a 100644 --- a/border-py-gym-env/src/obs/array_filter.rs +++ b/border-py-gym-env/src/obs/array_filter.rs @@ -24,6 +24,8 @@ impl Default for ArrayObsFilterConfig { /// An observation filter that convertes PyObject of an numpy array. /// /// Type parameter `O` must implements [`From`]`` and [`border_core::Obs`]. +/// +/// [`border_core::Obs`]: border_core::Obs pub struct ArrayObsFilter { /// Marker. pub phantom: PhantomData<(T1, T2, O)>, diff --git a/border-py-gym-env/src/util.rs b/border-py-gym-env/src/util.rs index 15740b8a..117eed69 100644 --- a/border-py-gym-env/src/util.rs +++ b/border-py-gym-env/src/util.rs @@ -1,3 +1,4 @@ +//! Utility functions mainly for data conversion between Python and Rust. use ndarray::{concatenate, ArrayD, Axis}; use num_traits::cast::AsPrimitive; use numpy::{Element, PyArrayDyn}; From de421ff1f1f163d1b4c0512dd0992db93af4ae44 Mon Sep 17 00:00:00 2001 From: taku-y Date: Sun, 4 Aug 2024 23:49:04 +0900 Subject: [PATCH 19/21] Add test in ci.yml --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c899d81b..fa62b3c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -106,3 +106,4 @@ jobs: cd border-candle-agent; cargo test; cd .. cd border-tch-agent; cargo test; cd .. cd border-policy-no-backend; cargo test --features=border-tch-agent; cd .. + cd border-py-gym-env; cargo test; cd ..
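The module documentation rewritten in PATCH 18/21 above describes observations flowing Python -> ndarray -> `O` through [`GymObsFilter`] implementations such as `ArrayObsFilter`, and actions flowing the opposite way through [`GymActFilter`] implementations such as `ContinuousActFilter`. The following is a minimal sketch of the conversions those bounds imply, under stated assumptions: `MyObs` and `MyAct` are hypothetical types that are not part of the border crates, the generic parameters of the documented bounds are stripped in the extracted docstrings, and the exact trait requirements (including `border_core::Obs` for observation types) should be taken from the crate sources.

use ndarray::{ArrayD, IxDyn};

// Hypothetical observation type; a real one must also implement `border_core::Obs`
// (omitted here) to be usable as `GymEnv::Obs`.
#[derive(Clone, Debug)]
struct MyObs(ArrayD<f32>);

// `ArrayObsFilter` converts the numpy observation into an ndarray and then calls
// `.into()`, so the observation type supplies the `From` conversion.
impl From<ArrayD<f32>> for MyObs {
    fn from(arr: ArrayD<f32>) -> Self {
        Self(arr)
    }
}

// Hypothetical continuous action type.
#[derive(Clone, Debug)]
struct MyAct(ArrayD<f32>);

// `ContinuousActFilter` goes the other way: the action type converts into an
// ndarray, which the filter hands back to Python as a numpy array.
impl From<MyAct> for ArrayD<f32> {
    fn from(act: MyAct) -> Self {
        act.0
    }
}

fn main() {
    // Round-trip both conversions on dummy data.
    let obs: MyObs = ArrayD::<f32>::zeros(IxDyn(&[3])).into();
    let act = MyAct(ArrayD::<f32>::zeros(IxDyn(&[1])));
    let arr: ArrayD<f32> = act.into();
    println!("{:?} {:?}", obs, arr);
}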
From fa0d92e130c4f5bf62d5e0ae86073db4f0475b1d Mon Sep 17 00:00:00 2001 From: taku-y Date: Wed, 7 Aug 2024 00:16:18 +0900 Subject: [PATCH 20/21] Update CHANGELOG.md --- CHANGELOG.md | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index abba777b..d912549b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,33 +4,35 @@ ### Added -* Support MLflow tracking (`border-mlflow-tracking`) (https://github.com/taku-y/border/issues/2) -* Add candle agent (`border-candle-agent`) (https://github.com/taku-y/border/issues/1) -* Split policy trait into two traits, one for sampling (`Policy`) and the other for configuration (`Configurable`) (https://github.com/taku-y/border/issues/12) +* Support MLflow tracking (`border-mlflow-tracking`) (https://github.com/taku-y/border/issues/2). +* Add candle agent (`border-candle-agent`) (https://github.com/taku-y/border/issues/1). +* Add `Trainer::train_offline()` method for offline training (`border-core`) (https://github.com/taku-y/border/issues/18). +* Add crate `border-policy-no-backend`. ### Changed -* Take `self` in the signature of `push()` method of replay buffer (`border-core`) -* Fix a bug in `MlpConfig` (`border-tch-agent`) -* Bump the version of tch to 0.16.0 (`border-tch-agent`) -* Change the name of trait `StepProcessorBase` to `StepProcessor` (`border-core`) -* Change the environment API to include terminate/truncate flags (`border-core`) (https://github.com/taku-y/border/issues/10) +* Take `self` in the signature of `push()` method of replay buffer (`border-core`). +* Fix a bug in `MlpConfig` (`border-tch-agent`). +* Bump the version of tch to 0.16.0 (`border-tch-agent`). +* Change the name of trait `StepProcessorBase` to `StepProcessor` (`border-core`). +* Change the environment API to include terminate/truncate flags (`border-core`) (https://github.com/taku-y/border/issues/10). +* Split policy trait into two traits, one for sampling (`Policy`) and the other for configuration (`Configurable`) (https://github.com/taku-y/border/issues/12). ## v0.0.6 (2023-09-19) ### Added * Docker files (`border`). -* Singularity files (`border`) -* Script for GPUSOROBAN (#67) +* Singularity files (`border`). +* Script for GPUSOROBAN (#67). * `Evaluator` trait in `border-core` (#70). It can be used to customize evaluation logic in `Trainer`. * Example of asynchronous trainer for native Atari environment and DQN (`border/examples`). -* Move tensorboard recorder into a separate crate (`border-tensorboard`) +* Move tensorboard recorder into a separate crate (`border-tensorboard`). ### Changed * Bump the version of tch-rs to 0.8.0 (`border-tch-agent`). * Rename agents as following the convention in Rust (`border-tch-agent`). -* Bump the version of gym to 0.26 (`border-py-gym-env`) -* Remove the type parameter for array shape of gym environments (`border-py-gym-env`) -* Interface of Python-Gym interface (`border-py-gym-env`) +* Bump the version of gym to 0.26 (`border-py-gym-env`). +* Remove the type parameter for array shape of gym environments (`border-py-gym-env`). +* Interface of Python-Gym interface (`border-py-gym-env`). 
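One of the `Changed` entries above records splitting the old policy trait into a sampling trait (`Policy`) and a configuration trait (`Configurable`). The sketch below only illustrates what such a split looks like; it is not the actual `border-core` definition, whose exact signatures are not shown in this patch series. Presumably the separation is what lets an inference-only crate such as `border-policy-no-backend` implement sampling without pulling in construction machinery.

// Illustrative sketch only; NOT the real border-core traits.
pub trait Policy<Obs, Act> {
    /// Sample an action for the given observation.
    fn sample(&mut self, obs: &Obs) -> Act;
}

pub trait Configurable: Sized {
    /// Serializable configuration from which the object is built.
    type Config;

    /// Build the object from its configuration.
    fn build(config: Self::Config) -> Self;
}

// A trivial policy implementing only the sampling half.
pub struct ConstantPolicy {
    value: f32,
}

impl Policy<Vec<f32>, Vec<f32>> for ConstantPolicy {
    fn sample(&mut self, obs: &Vec<f32>) -> Vec<f32> {
        vec![self.value; obs.len()]
    }
}

// Training-side code that builds policies from serialized configs can require
// the second trait in addition to the first.
impl Configurable for ConstantPolicy {
    type Config = f32;

    fn build(config: Self::Config) -> Self {
        Self { value: config }
    }
}

fn main() {
    let mut pi = ConstantPolicy::build(0.5);
    println!("{:?}", pi.sample(&vec![0.0, 1.0, 2.0]));
}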
From 1dc32a0ff5e3a5061e96e68fb7513964201e77f4 Mon Sep 17 00:00:00 2001 From: taku-y Date: Wed, 7 Aug 2024 00:24:10 +0900 Subject: [PATCH 21/21] Tweak --- border/Cargo.toml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/border/Cargo.toml b/border/Cargo.toml index 3a6a8d9b..c888053c 100644 --- a/border/Cargo.toml +++ b/border/Cargo.toml @@ -109,19 +109,17 @@ path = "examples/mujoco/sac_mujoco_tch.rs" required-features = ["tch"] test = false -<<<<<<< HEAD [[example]] name = "convert_sac_policy_to_edge" path = "examples/gym/convert_sac_policy_to_edge.rs" required-features = ["border-tch-agent", "tch"] test = false -======= + # [[example]] # name = "sac_ant_async" # path = "examples/mujoco/sac_ant_async.rs" # required-features = ["tch", "border-async-trainer"] # test = false ->>>>>>> origin/dev_0_0_7 [[example]] name = "pendulum_edge" @@ -152,10 +150,7 @@ border-derive = { version = "0.0.7", path = "../border-derive" } border-core = { version = "0.0.7", path = "../border-core" } border-tensorboard = { version = "0.0.7", path = "../border-tensorboard" } border-tch-agent = { version = "0.0.7", path = "../border-tch-agent" } -<<<<<<< HEAD border-policy-no-backend = { version = "0.0.7", path = "../border-policy-no-backend" } -======= ->>>>>>> origin/dev_0_0_7 border-py-gym-env = { version = "0.0.7", path = "../border-py-gym-env" } border-atari-env = { version = "0.0.7", path = "../border-atari-env" } border-candle-agent = { version = "0.0.7", path = "../border-candle-agent" }