diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e4564d04..fa62b3c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,8 +14,8 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest, macOS-latest] - rust: [1.68.2] - python-version: [3.8] + rust: [1.76.0] + python-version: ["3.11"] steps: - uses: actions/checkout@v2 @@ -40,16 +40,22 @@ jobs: - if: matrix.os == 'ubuntu-latest' name: Install gym (Ubuntu) run: | + pip install --upgrade pip + pip install swig==4.2.1 pip install mujoco==2.3.7 - pip install gymnasium[box2d]==0.29.0 + pip install gymnasium==0.29.1 pip install gymnasium-robotics==1.2.2 pip install pybullet==3.2.5 + pip install torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu - if: matrix.os != 'ubuntu-latest' name: Install gym (Windows/Mac) run: | - pip install gymnasium[box2d]==0.29.0 + pip install --upgrade pip + pip install swig==4.2.1 + pip install gymnasium==0.29.1 pip install pybullet==3.2.5 + pip install torch==2.3.0 - if: matrix.os == 'ubuntu-latest' name: Install pybullet-gym @@ -76,10 +82,28 @@ jobs: run: cargo test -p border-py-gym-env - if: matrix.os == 'ubuntu-latest' - name: Test border + name: Check env vars + run: printenv + + - if: matrix.os == 'ubuntu-latest' + name: Test border examples + env: + LIBTORCH_USE_PYTORCH: 1 run: | + export LD_LIBRARY_PATH=`pip show torch | awk '/Location/ {print $2}'`/torch/lib:$LD_LIBRARY_PATH + printenv | grep LD_ sudo apt-get update sudo apt-get install -y --no-install-recommends --fix-missing \ libsdl2-dev libsdl2-image-dev libsdl2-mixer-dev libsdl2-net-dev libsdl2-ttf-dev \ libsdl2-dev libsdl-image1.2-dev - cargo test -p border --features=tch + cargo test --example dqn_cartpole_tch --features=tch + cargo test --example iqn_cartpole_tch --features=tch + cargo test --example sac_pendulum_tch --features=tch + cargo test --example dqn_cartpole --features=candle-core + cargo test --example sac_pendulum --features=candle-core + cd border-async-trainer; cargo test; cd .. + cd border-atari-env; cargo test; cd .. + cd border-candle-agent; cargo test; cd .. + cd border-tch-agent; cargo test; cd .. + cd border-policy-no-backend; cargo test --features=border-tch-agent; cd .. + cd border-py-gym-env; cargo test; cd .. diff --git a/.gitignore b/.gitignore index 19098b1f..6c28e921 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,5 @@ __pycache__ .vscode/** doc/** +out/** +mlruns/** \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index b5dcd0a8..d912549b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,20 +1,38 @@ # Changelog -## v0.0.6 (20??-??-??) +## v0.0.7 (20??-??-??) + +### Added + +* Support MLflow tracking (`border-mlflow-tracking`) (https://github.com/taku-y/border/issues/2). +* Add candle agent (`border-candle-agent`) (https://github.com/taku-y/border/issues/1). +* Add `Trainer::train_offline()` method for offline training (`border-core`) (https://github.com/taku-y/border/issues/18). +* Add crate `border-policy-no-backend`. + +### Changed + +* Take `self` in the signature of `push()` method of replay buffer (`border-core`). +* Fix a bug in `MlpConfig` (`border-tch-agent`). +* Bump the version of tch to 0.16.0 (`border-tch-agent`). +* Change the name of trait `StepProcessorBase` to `StepProcessor` (`border-core`). +* Change the environment API to include terminate/truncate flags (`border-core`) (https://github.com/taku-y/border/issues/10). 
+* Split policy trait into two traits, one for sampling (`Policy`) and the other for configuration (`Configurable`) (https://github.com/taku-y/border/issues/12). + +## v0.0.6 (2023-09-19) ### Added * Docker files (`border`). -* Singularity files (`border`) -* Script for GPUSOROBAN (#67) +* Singularity files (`border`). +* Script for GPUSOROBAN (#67). * `Evaluator` trait in `border-core` (#70). It can be used to customize evaluation logic in `Trainer`. * Example of asynchronous trainer for native Atari environment and DQN (`border/examples`). -* Move tensorboard recorder into a separate crate (`border-tensorboard`) +* Move tensorboard recorder into a separate crate (`border-tensorboard`). ### Changed * Bump the version of tch-rs to 0.8.0 (`border-tch-agent`). * Rename agents as following the convention in Rust (`border-tch-agent`). -* Bump the version of gym to 0.26 (`border-py-gym-env`) -* Remove the type parameter for array shape of gym environments (`border-py-gym-env`) -* Interface of Python-Gym interface (`border-py-gym-env`) +* Bump the version of gym to 0.26 (`border-py-gym-env`). +* Remove the type parameter for array shape of gym environments (`border-py-gym-env`). +* Interface of Python-Gym interface (`border-py-gym-env`). diff --git a/Cargo.toml b/Cargo.toml index 059177e3..5180ceb1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,20 +2,33 @@ members = [ "border-core", "border-tensorboard", + "border-mlflow-tracking", "border-py-gym-env", "border-tch-agent", + "border-candle-agent", "border-derive", "border-atari-env", "border-async-trainer", + "border-policy-no-backend", "border", ] exclude = ["docker/"] +[workspace.package] +version = "0.0.7" +edition = "2018" +rust-version = "1.76" +description = "Reinforcement learning library" +repository = "https://github.com/laboroai/border" +keywords = ["Reinforcement learning"] +categories = ["science"] +license = "MIT OR Apache-2.0" + [workspace.dependencies] -clap = "2.33.3" +clap = { version = "4.5.8", features = ["derive"] } csv = "1.1.5" fastrand = "1.4.0" -tch = "0.8.0" +tch = "0.16.0" anyhow = "1.0.38" crossbeam-channel = "0.5.1" serde_yaml = "0.8.7" @@ -23,14 +36,24 @@ aquamarine = "0.1" log = "0.4" dirs = "3.0.2" thiserror = "1.0" -serde = "=1.0.126" +serde = "1.0.194" +serde_json = "^1.0.114" numpy = "0.14.1" env_logger = "0.8.2" tempdir = "0.3.7" num-traits = "0.2.14" tensorboard-rs = "0.2.4" -pyo3 = { version = "=0.14.5", default-features=false } +pyo3 = { version = "=0.14.5", default-features = false } ndarray = "0.15.1" chrono = "0.4" segment-tree = "2.0.0" image = "0.23.14" +candle-core = { version = "=0.4.1", feature = ["cuda", "cudnn"] } +candle-nn = "0.4.1" +rand = { version = "0.8.5", features = ["small_rng"] } +itertools = "0.12.1" +ordered-float = "4.2.0" +reqwest = { version = "0.11.26", features = ["json", "blocking"] } +xxhash-rust = { version = "0.8.10", features = ["xxh3"] } +candle-optimisers = "0.4.0" +bincode = "1.3.3" diff --git a/README.md b/README.md index e15b084b..e37344ec 100644 --- a/README.md +++ b/README.md @@ -9,18 +9,19 @@ A reinforcement learning library in Rust. Border consists of the following crates: -* [border-core](https://crates.io/crates/border-core) provides basic traits and functions generic to environments and reinforcmenet learning (RL) agents. -* [border-tensorboard](https://crates.io/crates/border-tensorboard) has `TensorboardRecorder` struct to write records which can be shown in Tensorboard. It is based on [tensorboard-rs](https://crates.io/crates/tensorboard-rs). 
-* [border-py-gym-env](https://crates.io/crates/border-py-gym-env) is a wrapper of the [Gymnasium](https://gymnasium.farama.org) environments written in Python. -* [border-atari-env](https://crates.io/crates/border-atari-env) is a wrapper of [atari-env](https://crates.io/crates/atari-env), which is a part of [gym-rs](https://crates.io/crates/gym-rs). -* [border-tch-agent](https://crates.io/crates/border-tch-agent) is a collection of RL agents based on [tch](https://crates.io/crates/tch), including Deep Q network (DQN), implicit quantile network (IQN), and soft actor critic (SAC). -* [border-async-trainer](https://crates.io/crates/border-async-trainer) defines some traits and functions for asynchronous training of RL agents by multiple actors, which runs sampling processes in parallel. In each sampling process, an agent interacts with an environment to collect samples to be sent to a shared replay buffer. - -You can use a part of these crates for your purposes, though [border-core](https://crates.io/crates/border-core) is mandatory. [This crate](https://crates.io/crates/border) is just a collection of examples. See [Documentation](https://docs.rs/border) for more details. - -## News - -The owner of this repository will be changed from [taku-y](https://github.com/taku-y) to [laboroai](https://github.com/laboroai). +* Core and utility + * [border-core](https://crates.io/crates/border-core) provides basic traits and functions generic to environments and reinforcmenet learning (RL) agents. + * [border-tensorboard](https://crates.io/crates/border-tensorboard) has `TensorboardRecorder` struct to write records which can be shown in Tensorboard. It is based on [tensorboard-rs](https://crates.io/crates/tensorboard-rs). + * [border-mlflow-tracking](https://crates.io/crates/border-mlflow-tracking) support MLflow tracking to log metrices during training via REST API. + * [border-async-trainer](https://crates.io/crates/border-async-trainer) defines some traits and functions for asynchronous training of RL agents by multiple actors, which runs sampling processes in parallel. In each sampling process, an agent interacts with an environment to collect samples to be sent to a shared replay buffer. + * [border](https://crates.io/crates/border) is just a collection of examples. +* Environment + * [border-py-gym-env](https://crates.io/crates/border-py-gym-env) is a wrapper of the [Gymnasium](https://gymnasium.farama.org) environments written in Python. + * [border-atari-env](https://crates.io/crates/border-atari-env) is a wrapper of [atari-env](https://crates.io/crates/atari-env), which is a part of [gym-rs](https://crates.io/crates/gym-rs). +* Agent + * [border-tch-agent](https://crates.io/crates/border-tch-agent) includes RL agents based on [tch](https://crates.io/crates/tch), including Deep Q network (DQN), implicit quantile network (IQN), and soft actor critic (SAC). + * [border-candle-agent](https://crates.io/crates/border-candle-agent) includes RL agents based on [candle](https://crates.io/crates/candle-core) + * [border-policy-no-backend](https://crates.io/crates/border-policy-no-backend) includes a policy that is independent of any deep learning backend, such as Torch. ## Status @@ -34,21 +35,17 @@ There are some example sctipts in `border/examples` directory. These are tested In `docker` directory, there are scripts for running a Docker container, in which you can try the examples described above. Currently, only `aarch64` is mainly used for the development. 
-## Tests - -The following command has been tested in the Docker container running on M2 Macbook air. - -```bash -cargo test --features=tch -``` - ## License -Crates | License -----------------------|------------------ -`border-core` | MIT OR Apache-2.0 -`border-py-gym-env` | MIT OR Apache-2.0 -`border-atari-env` | GPL-2.0-or-later -`border-tch-agent` | MIT OR Apache-2.0 -`border-async-trainer`| MIT OR Apache-2.0 -`border` | GPL-2.0-or-later +Crates | License +--------------------------|------------------ +`border-core` | MIT OR Apache-2.0 +`border-tensorboard` | MIT OR Apache-2.0 +`border-mlflow-tracking` | MIT OR Apache-2.0 +`border-async-trainer` | MIT OR Apache-2.0 +`border-py-gym-env` | MIT OR Apache-2.0 +`border-atari-env` | GPL-2.0-or-later +`border-tch-agent` | MIT OR Apache-2.0 +`border-candle-agent` | MIT OR Apache-2.0 +`border-policy-no-backend`| MIT OR Apache-2.0 +`border` | GPL-2.0-or-later diff --git a/border-async-trainer/Cargo.toml b/border-async-trainer/Cargo.toml index 51d5a128..be81813b 100644 --- a/border-async-trainer/Cargo.toml +++ b/border-async-trainer/Cargo.toml @@ -1,23 +1,19 @@ [package] name = "border-async-trainer" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Atari environment based on gym-rs" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] anyhow = { workspace = true } aquamarine = { workspace = true } -border-core = { version = "0.0.6", path = "../border-core" } -border-tensorboard = { version = "0.0.6", path = "../border-tensorboard" } +border-core = { version = "0.0.7", path = "../border-core" } +border-tensorboard = { version = "0.0.7", path = "../border-tensorboard" } serde = { workspace = true, features = ["derive"] } log = { workspace = true } tokio = { version = "1.14.0", features = ["full"] } diff --git a/border-async-trainer/src/actor.rs b/border-async-trainer/src/actor.rs index 9537a127..93cceb55 100644 --- a/border-async-trainer/src/actor.rs +++ b/border-async-trainer/src/actor.rs @@ -2,4 +2,4 @@ mod base; mod stat; pub use base::Actor; -pub use stat::{ActorStat, actor_stats_fmt}; +pub use stat::{actor_stats_fmt, ActorStat}; diff --git a/border-async-trainer/src/actor/base.rs b/border-async-trainer/src/actor/base.rs index daf9ee03..95d0fdb3 100644 --- a/border-async-trainer/src/actor/base.rs +++ b/border-async-trainer/src/actor/base.rs @@ -1,5 +1,7 @@ use crate::{ActorStat, PushedItemMessage, ReplayBufferProxy, ReplayBufferProxyConfig, SyncModel}; -use border_core::{Agent, Env, ReplayBufferBase, StepProcessorBase, SyncSampler}; +use border_core::{ + Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, Sampler, StepProcessor, +}; use crossbeam_channel::Sender; use log::info; use std::{ @@ -8,7 +10,7 @@ use std::{ }; #[cfg_attr(doc, aquamarine::aquamarine)] -/// Runs interaction between an [`Agent`] and an [`Env`], then generates transitions. +/// Generate transitions by running [`Agent`] in [`Env`]. 
/// /// ```mermaid /// flowchart TB @@ -16,26 +18,29 @@ use std::{ /// subgraph D[Actor] /// A[Agent]-->|Env::Act|B[Env] /// B-->|Env::Obs|A -/// B-->|Step<E: Env>|C[StepProcessorBase] +/// B-->|Step<E: Env>|C[StepProcessor] /// end /// C-->|ReplayBufferBase::PushedItem|F[ReplayBufferProxy] /// ``` /// -/// This diagram shows interaction of [`Agent`], [`Env`] and [`StepProcessorBase`], -/// as shown in [`border_core::Trainer`]. However, this diagram also shows that +/// In [`Actor`], an [`Agent`] runs on an [`Env`] and generates [`Step`] objects. +/// These objects are processed with [`StepProcessor`] and sent to [`ReplayBufferProxy`]. /// The [`Agent`] in the [`Actor`] periodically synchronizes with the [`Agent`] in -/// [`AsyncTrainer`] via [`SyncModel::ModelInfo`], and the transitions generated by -/// [`StepProcessorBase`] are sent to the [`ReplayBufferProxy`]. +/// [`AsyncTrainer`] via [`SyncModel::ModelInfo`]. /// /// See also the diagram in [`AsyncTrainer`]. /// /// [`AsyncTrainer`]: crate::AsyncTrainer +/// [`Agent`]: border_core::Agent +/// [`Env`]: border_core::Env +/// [`StepProcessor`]: border_core::StepProcessor +/// [`Step`]: border_core::Step pub struct Actor where - A: Agent + SyncModel, + A: Agent + Configurable + SyncModel, E: Env, - P: StepProcessorBase, - R: ReplayBufferBase, + P: StepProcessor, + R: ExperienceBufferBase + ReplayBufferBase, { /// Stops sampling process if this field is set to `true`. id: usize, @@ -53,10 +58,10 @@ where impl Actor where - A: Agent + SyncModel, + A: Agent + Configurable + SyncModel, E: Env, - P: StepProcessorBase, - R: ReplayBufferBase, + P: StepProcessor, + R: ExperienceBufferBase + ReplayBufferBase, { pub fn build( id: usize, @@ -68,6 +73,7 @@ where env_seed: i64, stats: Arc>>, ) -> Self { + log::info!("Create actor {}", id); Self { id, stop, @@ -109,7 +115,7 @@ where /// When finishes, this method sets [ActorStat]. pub fn run( &mut self, - sender: Sender>, + sender: Sender>, model_info: Arc>, guard: Arc>, guard_init_model: Arc>, @@ -122,7 +128,7 @@ where let env = E::build(&self.env_config, self.env_seed).unwrap(); let step_proc = P::build(&self.step_proc_config); *tmp = true; - SyncSampler::new(env, step_proc) + Sampler::new(env, step_proc) }; info!("Starts actor {:?}", self.id); @@ -152,12 +158,10 @@ where // Stop sampling loop if *self.stop.lock().unwrap() { - *self.stats.lock().unwrap() = Some( - ActorStat { - env_steps, - duration: time.elapsed().unwrap(), - } - ); + *self.stats.lock().unwrap() = Some(ActorStat { + env_steps, + duration: time.elapsed().unwrap(), + }); break; } } diff --git a/border-async-trainer/src/actor/stat.rs b/border-async-trainer/src/actor/stat.rs index bc989ffc..3fb26199 100644 --- a/border-async-trainer/src/actor/stat.rs +++ b/border-async-trainer/src/actor/stat.rs @@ -1,12 +1,12 @@ use std::time::Duration; -/// Stats of sampling process in each [`Actor`](crate::Actor). +/// Stats of sampling process in an [`Actor`](crate::Actor). #[derive(Clone, Debug)] pub struct ActorStat { /// The number of steps for interaction between agent and env. pub env_steps: usize, - /// Duration of sampling loop in [`Actor`](crate::Actor). + /// Duration of sampling loop in the [`Actor`](crate::Actor). pub duration: Duration, } diff --git a/border-async-trainer/src/actor_manager.rs b/border-async-trainer/src/actor_manager.rs index 60304b11..546371ff 100644 --- a/border-async-trainer/src/actor_manager.rs +++ b/border-async-trainer/src/actor_manager.rs @@ -1,4 +1,4 @@ -//! A manager of [Actor]()s. +//! 
A manager of [`Actor`](crate::Actor)s. mod base; mod config; pub use base::ActorManager; diff --git a/border-async-trainer/src/actor_manager/base.rs b/border-async-trainer/src/actor_manager/base.rs index 409b80fe..2170921b 100644 --- a/border-async-trainer/src/actor_manager/base.rs +++ b/border-async-trainer/src/actor_manager/base.rs @@ -1,7 +1,9 @@ use crate::{ Actor, ActorManagerConfig, ActorStat, PushedItemMessage, ReplayBufferProxyConfig, SyncModel, }; -use border_core::{Agent, Env, ReplayBufferBase, StepProcessorBase}; +use border_core::{ + Agent, Configurable, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor, +}; use crossbeam_channel::{bounded, /*unbounded,*/ Receiver, Sender}; use log::info; use std::{ @@ -13,20 +15,22 @@ use std::{ /// Manages [`Actor`]s. /// /// This struct handles the following requests: -/// * From the [LearnerManager]() for updating the latest model info, stored in this struct. +/// * From the [`AsyncTrainer`] for updating the latest model info, stored in this struct. /// * From the [`Actor`]s for getting the latest model info. /// * From the [`Actor`]s for pushing sample batch to the `LearnerManager`. +/// +/// [`AsyncTrainer`]: crate::AsyncTrainer pub struct ActorManager where - A: Agent + SyncModel, + A: Agent + Configurable + SyncModel, E: Env, - P: StepProcessorBase, - R: ReplayBufferBase, + P: StepProcessor, + R: ExperienceBufferBase + ReplayBufferBase, { - /// Configurations of [Agent]s. + /// Configurations of [`Agent`]s. agent_configs: Vec, - /// Configuration of [Env]. + /// Configuration of [`Env`]. env_config: E::Config, /// Configuration of a `StepProcessor`. @@ -44,10 +48,10 @@ where stop: Arc>, /// Receiver of [PushedItemMessage]s from [Actor]. - batch_message_receiver: Option>>, + batch_message_receiver: Option>>, /// Sender of [PushedItemMessage]s to [AsyncTrainer](crate::AsyncTrainer). - pushed_item_message_sender: Sender>, + pushed_item_message_sender: Sender>, /// Information of the model. /// @@ -65,23 +69,23 @@ where impl ActorManager where - A: Agent + SyncModel, + A: Agent + Configurable + SyncModel, E: Env, - P: StepProcessorBase, - R: ReplayBufferBase + Send + 'static, + P: StepProcessor, + R: ExperienceBufferBase + Send + 'static + ReplayBufferBase, A::Config: Send + 'static, E::Config: Send + 'static, P::Config: Send + 'static, - R::PushedItem: Send + 'static, + R::Item: Send + 'static, A::ModelInfo: Send + 'static, { - /// Builds a [ActorManager]. + /// Builds a [`ActorManager`]. pub fn build( config: &ActorManagerConfig, agent_configs: &Vec, env_config: &E::Config, step_proc_config: &P::Config, - pushed_item_message_sender: Sender>, + pushed_item_message_sender: Sender>, model_info_receiver: Receiver<(usize, A::ModelInfo)>, stop: Arc>, ) -> Self { @@ -101,10 +105,10 @@ where } } - /// Runs threads for [Actor]s and a thread for sending samples into the replay buffer. + /// Runs threads for [`Actor`]s and a thread for sending samples into the replay buffer. /// - /// A thread will wait for the initial [SyncModel::ModelInfo] from [AsyncTrainer](crate::AsyncTrainer), - /// which blocks execution of [Actor] threads. + /// Each thread is blocked until receiving the initial [`SyncModel::ModelInfo`] + /// from [`AsyncTrainer`](crate::AsyncTrainer). pub fn run(&mut self, guard_init_env: Arc>) { // Guard for sync of the initial model let guard_init_model = Arc::new(Mutex::new(true)); @@ -207,9 +211,9 @@ where /// Loop waiting [PushedItemMessage] from [Actor]s. 
fn handle_message( - receiver: Receiver>, + receiver: Receiver>, stop: Arc>, - sender: Sender>, + sender: Sender>, ) { let mut _n_samples = 0; @@ -218,10 +222,11 @@ where // TODO: error handling, timeout // TODO: caching // TODO: stats - let msg = receiver.recv().unwrap(); - _n_samples += 1; - sender.try_send(msg).unwrap(); - // println!("{:?}", (_msg.id, n_samples)); + let msg = receiver.recv(); + if msg.is_ok() { + _n_samples += 1; + sender.try_send(msg.unwrap()).unwrap(); + } // Stop the loop if *stop.lock().unwrap() { diff --git a/border-async-trainer/src/actor_manager/config.rs b/border-async-trainer/src/actor_manager/config.rs index 2a5552fe..ec5ad90c 100644 --- a/border-async-trainer/src/actor_manager/config.rs +++ b/border-async-trainer/src/actor_manager/config.rs @@ -10,9 +10,7 @@ pub struct ActorManagerConfig { } impl Default for ActorManagerConfig { - fn default() -> Self { - Self { - n_buffer: 100, - } + fn default() -> Self { + Self { n_buffer: 100 } } } diff --git a/border-async-trainer/src/async_trainer/base.rs b/border-async-trainer/src/async_trainer/base.rs index a4f82c87..ffcc001d 100644 --- a/border-async-trainer/src/async_trainer/base.rs +++ b/border-async-trainer/src/async_trainer/base.rs @@ -1,14 +1,14 @@ use crate::{AsyncTrainStat, AsyncTrainerConfig, PushedItemMessage, SyncModel}; use border_core::{ - record::{Record, RecordValue::Scalar, Recorder}, - Agent, Env, Evaluator, ReplayBufferBase, + record::{AggregateRecorder, Record, RecordValue::Scalar}, + Agent, Configurable, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, }; use crossbeam_channel::{Receiver, Sender}; use log::info; use std::{ marker::PhantomData, sync::{Arc, Mutex}, - time::SystemTime, + time::{Duration, SystemTime}, }; #[cfg_attr(doc, aquamarine::aquamarine)] @@ -33,79 +33,92 @@ use std::{ /// end /// ``` /// -/// * In [`ActorManager`] (right), [`Actor`]s sample transitions, which have type -/// [`ReplayBufferBase::PushedItem`], in parallel and push the transitions into -/// [`ReplayBufferProxy`]. It should be noted that [`ReplayBufferProxy`] has a -/// type parameter of [`ReplayBufferBase`] and the proxy accepts -/// [`ReplayBufferBase::PushedItem`]. -/// * The proxy sends the transitions into the replay buffer, implementing -/// [`ReplayBufferBase`], in the [`AsyncTrainer`]. -/// * The [`Agent`] in [`AsyncTrainer`] trains its model parameters by using batches +/// * The [`Agent`] in [`AsyncTrainer`] (left) is trained with batches /// of type [`ReplayBufferBase::Batch`], which are taken from the replay buffer. /// * The model parameters of the [`Agent`] in [`AsyncTrainer`] are wrapped in /// [`SyncModel::ModelInfo`] and periodically sent to the [`Agent`]s in [`Actor`]s. -/// [`Agent`] must implement [`SyncModel`] to synchronize its model. +/// [`Agent`] must implement [`SyncModel`] to synchronize the model parameters. +/// * In [`ActorManager`] (right), [`Actor`]s sample transitions, which have type +/// [`ReplayBufferBase::Item`], and push the transitions into +/// [`ReplayBufferProxy`]. +/// * [`ReplayBufferProxy`] has a type parameter of [`ReplayBufferBase`] and the proxy accepts +/// [`ReplayBufferBase::Item`]. +/// * The proxy sends the transitions into the replay buffer in the [`AsyncTrainer`]. 
/// /// [`ActorManager`]: crate::ActorManager /// [`Actor`]: crate::Actor -/// [`ReplayBufferBase::PushedItem`]: border_core::ReplayBufferBase::PushedItem +/// [`ReplayBufferBase::Item`]: border_core::ReplayBufferBase::PushedItem +/// [`ReplayBufferBase::Batch`]: border_core::ReplayBufferBase::PushedBatch /// [`ReplayBufferProxy`]: crate::ReplayBufferProxy /// [`ReplayBufferBase`]: border_core::ReplayBufferBase /// [`SyncModel::ModelInfo`]: crate::SyncModel::ModelInfo +/// [`Agent`]: border_core::Agent pub struct AsyncTrainer where - A: Agent + SyncModel, + A: Agent + Configurable + SyncModel, E: Env, // R: ReplayBufferBase + Sync + Send + 'static, - R: ReplayBufferBase, - R::PushedItem: Send + 'static, + R: ExperienceBufferBase + ReplayBufferBase, + R::Item: Send + 'static, { + /// Configuration of [`Env`]. Note that it is used only for evaluation, not for training. + env_config: E::Config, + + /// Configuration of the replay buffer. + replay_buffer_config: R::Config, + /// Where to save the trained model. model_dir: Option, - /// Interval of recording in training steps. - record_interval: usize, + /// Interval of recording computational cost in optimization steps. + record_compute_cost_interval: usize, + + /// Interval of flushing records in optimization steps. + flush_records_interval: usize, /// Interval of evaluation in training steps. eval_interval: usize, - /// The maximal number of training steps. - max_train_steps: usize, - /// Interval of saving the model in optimization steps. save_interval: usize, + /// The maximal number of optimization steps. + max_opts: usize, + + /// Optimization steps for computing optimization steps per second. + opt_steps_for_ops: usize, + + /// Timer for computing for optimization steps per second. + timer_for_ops: Duration, + + /// Warmup period, for filling replay buffer, in environment steps + warmup_period: usize, + /// Interval of synchronizing model parameters in training steps. sync_interval: usize, /// Receiver of pushed items. - r_bulk_pushed_item: Receiver>, + r_bulk_pushed_item: Receiver>, /// If `false`, stops the actor threads. stop: Arc>, - /// Configuration of [Agent]. + /// Configuration of [`Agent`]. agent_config: A::Config, - /// Configuration of [Env]. Note that it is used only for evaluation, not for training. - env_config: E::Config, - /// Sender of model info. model_info_sender: Sender<(usize, A::ModelInfo)>, - /// Configuration of replay buffer. - replay_buffer_config: R::Config, - phantom: PhantomData<(A, E, R)>, } impl AsyncTrainer where - A: Agent + SyncModel, + A: Agent + Configurable + SyncModel, E: Env, // R: ReplayBufferBase + Sync + Send + 'static, - R: ReplayBufferBase, - R::PushedItem: Send + 'static, + R: ExperienceBufferBase + ReplayBufferBase, + R::Item: Send + 'static, { /// Creates [`AsyncTrainer`]. 
pub fn build( @@ -113,29 +126,33 @@ where agent_config: &A::Config, env_config: &E::Config, replay_buffer_config: &R::Config, - r_bulk_pushed_item: Receiver>, + r_bulk_pushed_item: Receiver>, model_info_sender: Sender<(usize, A::ModelInfo)>, stop: Arc>, ) -> Self { Self { model_dir: config.model_dir.clone(), - record_interval: config.record_interval, eval_interval: config.eval_interval, - max_train_steps: config.max_train_steps, + max_opts: config.max_opts, + record_compute_cost_interval: config.record_compute_cost_interval, + flush_records_interval: config.flush_record_interval, save_interval: config.save_interval, sync_interval: config.sync_interval, + warmup_period: config.warmup_period, agent_config: agent_config.clone(), env_config: env_config.clone(), replay_buffer_config: replay_buffer_config.clone(), r_bulk_pushed_item, model_info_sender, stop, + opt_steps_for_ops: 0, + timer_for_ops: Duration::new(0, 0), phantom: PhantomData, } } fn save_model(agent: &A, model_dir: String) { - match agent.save(&model_dir) { + match agent.save_params(&model_dir) { Ok(()) => info!("Saved the model in {:?}.", &model_dir), Err(_) => info!("Failed to save model in {:?}.", &model_dir), } @@ -146,48 +163,56 @@ where Self::save_model(agent, model_dir); } - /// Record. - #[inline] - fn record( - &mut self, - record: &mut Record, - opt_steps_: &mut usize, - samples: &mut usize, - time: &mut SystemTime, - samples_total: usize, - ) { - let duration = time.elapsed().unwrap().as_secs_f32(); - let ops = (*opt_steps_ as f32) / duration; - let sps = (*samples as f32) / duration; - let spo = (*samples as f32) / (*opt_steps_ as f32); - record.insert("samples_total", Scalar(samples_total as _)); - record.insert("opt_steps_per_sec", Scalar(ops)); - record.insert("samples_per_sec", Scalar(sps)); - record.insert("samples_per_opt_steps", Scalar(spo)); - // info!("Collected samples per optimization step = {}", spo); - - // Reset counter - *opt_steps_ = 0; - *samples = 0; - *time = SystemTime::now(); + /// Save model. + fn save_model_with_steps(agent: &A, model_dir: String, steps: usize) { + let model_dir = model_dir + format!("/{}", steps).as_str(); + Self::save_model(agent, model_dir); } - /// Flush record. - #[inline] - fn flush(&mut self, opt_steps: usize, mut record: Record, recorder: &mut impl Recorder) { - record.insert("opt_steps", Scalar(opt_steps as _)); - recorder.write(record); + /// Returns optimization steps per second, then reset the internal counter. + fn opt_steps_per_sec(&mut self) -> f32 { + let osps = 1000. * self.opt_steps_for_ops as f32 / (self.timer_for_ops.as_millis() as f32); + self.opt_steps_for_ops = 0; + self.timer_for_ops = Duration::new(0, 0); + osps } - /// Save model. + // /// Record. 
+ // #[inline] + // fn record( + // &mut self, + // record: &mut Record, + // opt_steps_: &mut usize, + // samples: &mut usize, + // time: &mut SystemTime, + // samples_total: usize, + // ) { + // let duration = time.elapsed().unwrap().as_secs_f32(); + // let ops = (*opt_steps_ as f32) / duration; + // let sps = (*samples as f32) / duration; + // let spo = (*samples as f32) / (*opt_steps_ as f32); + // record.insert("samples_total", Scalar(samples_total as _)); + // record.insert("opt_steps_per_sec", Scalar(ops)); + // record.insert("samples_per_sec", Scalar(sps)); + // record.insert("samples_per_opt_steps", Scalar(spo)); + // // info!("Collected samples per optimization step = {}", spo); + + // // Reset counter + // *opt_steps_ = 0; + // *samples = 0; + // *time = SystemTime::now(); + // } + #[inline] - fn save(&mut self, opt_steps: usize, agent: &A) { - let model_dir = - self.model_dir.as_ref().unwrap().clone() + format!("/{}", opt_steps).as_str(); - Self::save_model(agent, model_dir); + fn train_step(&mut self, agent: &mut A, buffer: &mut R) -> Record { + let timer = SystemTime::now(); + let record = agent.opt_with_record(buffer); + self.opt_steps_for_ops += 1; + self.timer_for_ops += timer.elapsed().unwrap(); + record } - /// Sync model. + /// Synchronize model. #[inline] fn sync(&mut self, agent: &A) { let model_info = agent.model_info(); @@ -195,40 +220,39 @@ where self.model_info_sender.send(model_info).unwrap(); } - // /// Run a thread for replay buffer. - // fn run_replay_buffer_thread(&self, buffer: Arc>) { - // let r = self.r_bulk_pushed_item.clone(); - // let stop = self.stop.clone(); - - // std::thread::spawn(move || loop { - // let msg = r.recv().unwrap(); - // { - // let mut buffer = buffer.lock().unwrap(); - // buffer.push(msg.pushed_item); - // } - // if *stop.lock().unwrap() { - // break; - // } - // std::thread::sleep(std::time::Duration::from_millis(100)); - // }); - // } + #[inline] + fn update_replay_buffer( + &mut self, + buffer: &mut R, + samples: &mut usize, + samples_total: &mut usize, + ) { + let msgs: Vec<_> = self.r_bulk_pushed_item.try_iter().collect(); + msgs.into_iter().for_each(|msg| { + *samples += msg.pushed_items.len(); + *samples_total += msg.pushed_items.len(); + msg.pushed_items + .into_iter() + .for_each(|pushed_item| buffer.push(pushed_item).unwrap()) + }); + } /// Runs training loop. /// /// In the training loop, the following values will be pushed into the given recorder: /// /// * `samples_total` - Total number of samples pushed into the replay buffer. - /// Here, a "sample" is an item in [`ExperienceBufferBase::PushedItem`]. + /// Here, a "sample" is an item in [`ExperienceBufferBase::Item`]. /// * `opt_steps_per_sec` - The number of optimization steps per second. /// * `samples_per_sec` - The number of samples per second. /// * `samples_per_opt_steps` - The number of samples per optimization step. /// /// These values will typically be monitored with tensorboard. 
/// - /// [`ExperienceBufferBase::PushedItem`]: border_core::ExperienceBufferBase::PushedItem + /// [`ExperienceBufferBase::Item`]: border_core::ExperienceBufferBase::Item pub fn train( &mut self, - recorder: &mut impl Recorder, + recorder: &mut Box, evaluator: &mut D, guard_init_env: Arc>, ) -> AsyncTrainStat @@ -243,100 +267,79 @@ where }; let mut agent = A::build(self.agent_config.clone()); let mut buffer = R::build(&self.replay_buffer_config); - // let buffer = Arc::new(Mutex::new(R::build(&self.replay_buffer_config))); agent.train(); - // self.run_replay_buffer_thread(buffer.clone()); - let mut max_eval_reward = f32::MIN; let mut opt_steps = 0; - let mut opt_steps_ = 0; let mut samples = 0; let time_total = SystemTime::now(); let mut samples_total = 0; - let mut time = SystemTime::now(); info!("Send model info first in AsyncTrainer"); self.sync(&mut agent); info!("Starts training loop"); loop { - // Update replay buffer - let msgs: Vec<_> = self.r_bulk_pushed_item.try_iter().collect(); - msgs.into_iter().for_each(|msg| { - samples += msg.pushed_items.len(); - samples_total += msg.pushed_items.len(); - msg.pushed_items - .into_iter() - .for_each(|pushed_item| buffer.push(pushed_item).unwrap()) - }); - - let record = agent.opt(&mut buffer); - - if let Some(mut record) = record { - opt_steps += 1; - opt_steps_ += 1; - - let do_eval = opt_steps % self.eval_interval == 0; - let do_record = opt_steps % self.record_interval == 0; - let do_flush = do_eval || do_record; - let do_save = opt_steps % self.save_interval == 0; - let do_sync = opt_steps % self.sync_interval == 0; - - // Do evaluation - if do_eval { - info!("Starts evaluation of the trained model"); - agent.eval(); - let eval_reward = evaluator.evaluate(&mut agent).unwrap(); - agent.train(); - record.insert("eval_reward", Scalar(eval_reward)); - - // Save the best model up to the current iteration - if eval_reward > max_eval_reward { - max_eval_reward = eval_reward; - let model_dir = self.model_dir.as_ref().unwrap().clone(); - Self::save_best_model(&agent, model_dir) - } - } + self.update_replay_buffer(&mut buffer, &mut samples, &mut samples_total); - // Record - if do_record { - info!("Records training logs"); - self.record( - &mut record, - &mut opt_steps_, - &mut samples, - &mut time, - samples_total, - ); - } + if buffer.len() < self.warmup_period { + std::thread::sleep(Duration::from_millis(100)); + continue; + } - // Flush record to the recorder - if do_flush { - info!("Flushes records"); - self.flush(opt_steps, record, recorder); - } + let mut record = self.train_step(&mut agent, &mut buffer); + opt_steps += 1; - // Save the current model - if do_save { - info!("Saves the trained model"); - self.save(opt_steps, &mut agent); - } + // Add stats wrt computation cost + if opt_steps % self.record_compute_cost_interval == 0 { + record.insert("opt_steps_per_sec", Scalar(self.opt_steps_per_sec())); + } - // Finish the training loop - if opt_steps == self.max_train_steps { - // Flush channels - *self.stop.lock().unwrap() = true; - let _: Vec<_> = self.r_bulk_pushed_item.try_iter().collect(); - self.sync(&agent); - break; + // Evaluation + if opt_steps % self.eval_interval == 0 { + info!("Starts evaluation of the trained model"); + agent.eval(); + let eval_reward = evaluator.evaluate(&mut agent).unwrap(); + agent.train(); + record.insert("eval_reward", Scalar(eval_reward)); + + // Save the best model up to the current iteration + if eval_reward > max_eval_reward { + max_eval_reward = eval_reward; + let model_dir = 
self.model_dir.as_ref().unwrap().clone(); + Self::save_best_model(&agent, model_dir) } + } - // Sync the current model - if do_sync { - info!("Sends the trained model info to ActorManager"); - self.sync(&agent); - } + // Save the current model + if (self.save_interval > 0) && (opt_steps % self.save_interval == 0) { + let model_dir = self.model_dir.as_ref().unwrap().clone(); + Self::save_model_with_steps(&agent, model_dir, opt_steps); + } + + // Finish the training loop + if opt_steps == self.max_opts { + // Flush channels + *self.stop.lock().unwrap() = true; + let _: Vec<_> = self.r_bulk_pushed_item.try_iter().collect(); + self.sync(&agent); + break; + } + + // Sync the current model + if opt_steps % self.sync_interval == 0 { + info!("Sends the trained model info to ActorManager"); + self.sync(&agent); + } + + // Store record to the recorder + if !record.is_empty() { + recorder.store(record); + } + + // Flush records + if opt_steps % self.flush_records_interval == 0 { + recorder.flush(opt_steps as _); } } info!("Stopped training loop"); @@ -344,7 +347,7 @@ where let duration = time_total.elapsed().unwrap(); let time_total = duration.as_secs_f32(); let samples_per_sec = samples_total as f32 / time_total; - let opt_per_sec = self.max_train_steps as f32 / time_total; + let opt_per_sec = self.max_opts as f32 / time_total; AsyncTrainStat { samples_per_sec, duration, diff --git a/border-async-trainer/src/async_trainer/config.rs b/border-async-trainer/src/async_trainer/config.rs index e192b323..2f01ac21 100644 --- a/border-async-trainer/src/async_trainer/config.rs +++ b/border-async-trainer/src/async_trainer/config.rs @@ -1,34 +1,91 @@ -use serde::{Deserialize, Serialize}; use anyhow::Result; +use serde::{Deserialize, Serialize}; use std::{ fs::File, io::{BufReader, Write}, path::Path, }; -/// Configuration of [AsyncTrainer](crate::AsyncTrainer) +/// Configuration of [`AsyncTrainer`](crate::AsyncTrainer). #[derive(Clone, Debug, Deserialize, Serialize)] pub struct AsyncTrainerConfig { + /// The maximum number of optimization steps. + pub max_opts: usize, + /// Where to save the trained model. pub model_dir: Option, - /// Interval of recording in training steps. - pub record_interval: usize, - /// Interval of evaluation in training steps. pub eval_interval: usize, - /// The maximal number of training steps. - pub max_train_steps: usize, + /// Interval of flushing records in optimization steps. + pub flush_record_interval: usize, + + /// Interval of recording agent information in optimization steps. + pub record_compute_cost_interval: usize, /// Interval of saving the model in optimization steps. pub save_interval: usize, /// Interval of synchronizing model parameters in training steps. pub sync_interval: usize, + + /// Warmup period, for filling replay buffer, in environment steps + pub warmup_period: usize, } impl AsyncTrainerConfig { + /// Sets the number of optimization steps. + pub fn max_opts(mut self, v: usize) -> Result { + self.max_opts = v; + Ok(self) + } + + /// Sets the interval of evaluation in optimization steps. + pub fn eval_interval(mut self, v: usize) -> Result { + self.eval_interval = v; + Ok(self) + } + + /// Sets the directory the trained model being saved. + pub fn model_dir>(mut self, model_dir: T) -> Result { + self.model_dir = Some(model_dir.into()); + Ok(self) + } + + /// Sets the interval of computation cost in optimization steps. 
+ pub fn record_compute_cost_interval( + mut self, + record_compute_cost_interval: usize, + ) -> Result { + self.record_compute_cost_interval = record_compute_cost_interval; + Ok(self) + } + + /// Sets the interval of flushing recordd in optimization steps. + pub fn flush_record_interval(mut self, flush_record_interval: usize) -> Result { + self.flush_record_interval = flush_record_interval; + Ok(self) + } + + /// Sets warmup period in environment steps. + pub fn warmup_period(mut self, warmup_period: usize) -> Result { + self.warmup_period = warmup_period; + Ok(self) + } + + /// Sets the interval of saving in optimization steps. + pub fn save_interval(mut self, save_interval: usize) -> Result { + self.save_interval = save_interval; + Ok(self) + } + + /// Sets the interval of synchronizing model parameters in training steps. + pub fn sync_interval(mut self, sync_interval: usize) -> Result { + self.sync_interval = sync_interval; + Ok(self) + } + /// Constructs [AsyncTrainerConfig] from YAML file. pub fn load(path: impl AsRef) -> Result { let file = File::open(path)?; @@ -44,3 +101,19 @@ impl AsyncTrainerConfig { Ok(()) } } + +impl Default for AsyncTrainerConfig { + /// There is no special intention behind these initial values. + fn default() -> Self { + Self { + max_opts: 10, //000, + model_dir: None, + eval_interval: 5000, + flush_record_interval: 5000, + record_compute_cost_interval: 5000, + save_interval: 50000, + sync_interval: 100, + warmup_period: 10000, + } + } +} diff --git a/border-async-trainer/src/error.rs b/border-async-trainer/src/error.rs index dde38909..0a55237e 100644 --- a/border-async-trainer/src/error.rs +++ b/border-async-trainer/src/error.rs @@ -3,5 +3,5 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum BorderAsyncTrainerError { #[error("Error")] - SendMsgForPush + SendMsgForPush, } diff --git a/border-async-trainer/src/lib.rs b/border-async-trainer/src/lib.rs index d6747291..ce775418 100644 --- a/border-async-trainer/src/lib.rs +++ b/border-async-trainer/src/lib.rs @@ -2,61 +2,127 @@ //! //! The code might look like below. //! -//! ```ignore -//! fn train() { -//! let agent_configs: Vec<_> = vec![agent_config()]; -//! let env_config_train = env_config(name); -//! let env_config_eval = env_config(name).eval(); -//! let replay_buffer_config = load_replay_buffer_config(model_dir.as_str())?; -//! let step_proc_config = SimpleStepProcessorConfig::default(); -//! let actor_man_config = ActorManagerConfig::default(); -//! let async_trainer_config = load_async_trainer_config(model_dir.as_str())?; -//! let mut recorder = TensorboardRecorder::new(model_dir); -//! let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?; -//! -//! // Shared flag to stop actor threads -//! let stop = Arc::new(Mutex::new(false)); +//! ``` +//! # use serde::{Deserialize, Serialize}; +//! # use border_core::test::{ +//! # TestAgent, TestAgentConfig, TestEnv, TestObs, TestObsBatch, +//! # TestAct, TestActBatch +//! # }; +//! # use border_async_trainer::{ +//! # //test::{TestAgent, TestAgentConfig, TestEnv}, +//! # ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, +//! # }; +//! # use border_core::{ +//! # generic_replay_buffer::{ +//! # SimpleReplayBuffer, SimpleReplayBufferConfig, +//! # SimpleStepProcessorConfig, SimpleStepProcessor +//! # }, +//! # record::{AggregateRecorder, NullRecorder}, DefaultEvaluator, +//! # }; +//! # +//! # fn agent_config() -> TestAgentConfig { +//! # TestAgentConfig +//! # } +//! # +//! # fn env_config() -> usize { +//! # 0 +//! # } +//! +//! 
type Env = TestEnv; +//! type ObsBatch = TestObsBatch; +//! type ActBatch = TestActBatch; +//! type ReplayBuffer = SimpleReplayBuffer; +//! type StepProcessor = SimpleStepProcessor; +//! +//! // Create a new agent by wrapping the existing agent in order to implement SyncModel. +//! struct TestAgent2(TestAgent); +//! +//! impl border_core::Configurable for TestAgent2 { +//! type Config = TestAgentConfig; +//! +//! fn build(config: Self::Config) -> Self { +//! Self(TestAgent::build(config)) +//! } +//! } +//! +//! impl border_core::Agent for TestAgent2 { +//! // Boilerplate code to delegate the method calls to the inner agent. +//! fn train(&mut self) { +//! self.0.train(); +//! } //! -//! // Creates channels -//! let (item_s, item_r) = unbounded(); // items pushed to replay buffer -//! let (model_s, model_r) = unbounded(); // model_info +//! // For other methods ... +//! # fn is_train(&self) -> bool { +//! # self.0.is_train() +//! # } +//! # +//! # fn eval(&mut self) { +//! # self.0.eval(); +//! # } +//! # +//! # fn opt_with_record(&mut self, buffer: &mut ReplayBuffer) -> border_core::record::Record { +//! # self.0.opt_with_record(buffer) +//! # } +//! # +//! # fn save_params>(&self, path: T) -> anyhow::Result<()> { +//! # self.0.save_params(path) +//! # } +//! # +//! # fn load_params>(&mut self, path: T) -> anyhow::Result<()> { +//! # self.0.load_params(path) +//! # } +//! # +//! # fn opt(&mut self, buffer: &mut ReplayBuffer) { +//! # self.0.opt_with_record(buffer); +//! # } +//! } //! -//! // guard for initialization of envs in multiple threads -//! let guard_init_env = Arc::new(Mutex::new(true)); +//! impl border_core::Policy for TestAgent2 { +//! // Boilerplate code to delegate the method calls to the inner agent. +//! // ... +//! # fn sample(&mut self, obs: &TestObs) -> TestAct { +//! # self.0.sample(obs) +//! # } +//! } +//! +//! impl border_async_trainer::SyncModel for TestAgent2{ +//! // Self::ModelInfo shold include the model parameters. +//! type ModelInfo = usize; +//! //! -//! // Actor manager and async trainer -//! let mut actors = ActorManager::build( -//! &actor_man_config, -//! &agent_configs, -//! &env_config_train, -//! &step_proc_config, -//! item_s, -//! model_r, -//! stop.clone(), -//! ); -//! let mut trainer = AsyncTrainer::build( -//! &async_trainer_config, -//! &agent_config, -//! &env_config_eval, -//! &replay_buffer_config, -//! item_r, -//! model_s, -//! stop.clone(), -//! ); +//! fn model_info(&self) -> (usize, Self::ModelInfo) { +//! // Extracts the model parameters and returns them as Self::ModelInfo. +//! // The first element of the tuple is the number of optimization steps. +//! (0, 0) +//! } //! -//! // Set the number of threads -//! tch::set_num_threads(1); +//! fn sync_model(&mut self, _model_info: &Self::ModelInfo) { +//! // implements synchronization of the model based on the _model_info +//! } +//! } //! -//! // Starts sampling and training -//! actors.run(guard_init_env.clone()); -//! let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env); -//! println!("Stats of async trainer"); -//! println!("{}", stats.fmt()); +//! let agent_configs: Vec<_> = vec![agent_config()]; +//! let env_config_train = env_config(); +//! let env_config_eval = env_config(); +//! let replay_buffer_config = SimpleReplayBufferConfig::default(); +//! let step_proc_config = SimpleStepProcessorConfig::default(); +//! let actor_man_config = ActorManagerConfig::default(); +//! let async_trainer_config = AsyncTrainerConfig::default(); +//! 
let mut recorder: Box = Box::new(NullRecorder {}); +//! let mut evaluator = DefaultEvaluator::::new(&env_config_eval, 0, 1).unwrap(); //! -//! let stats = actors.stop_and_join(); -//! println!("Stats of generated samples in actors"); -//! println!("{}", actor_stats_fmt(&stats)); -//! } +//! border_async_trainer::util::train_async::<_, _, _, StepProcessor>( +//! &agent_config(), +//! &agent_configs, +//! &env_config_train, +//! &env_config_eval, +//! &step_proc_config, +//! &replay_buffer_config, +//! &actor_man_config, +//! &async_trainer_config, +//! &mut recorder, +//! &mut evaluator, +//! ); //! ``` //! //! Training process consists of the following two components: @@ -89,6 +155,7 @@ mod messages; mod replay_buffer_proxy; mod sync_model; pub mod util; + pub use actor::{actor_stats_fmt, Actor, ActorStat}; pub use actor_manager::{ActorManager, ActorManagerConfig}; pub use async_trainer::{AsyncTrainStat, AsyncTrainer, AsyncTrainerConfig}; @@ -96,3 +163,226 @@ pub use error::BorderAsyncTrainerError; pub use messages::PushedItemMessage; pub use replay_buffer_proxy::{ReplayBufferProxy, ReplayBufferProxyConfig}; pub use sync_model::SyncModel; + +/// Agent and Env for testing. +#[cfg(test)] +pub mod test { + use serde::{Deserialize, Serialize}; + + /// Obs for testing. + #[derive(Clone, Debug)] + pub struct TestObs { + obs: usize, + } + + impl border_core::Obs for TestObs { + fn dummy(_n: usize) -> Self { + Self { obs: 0 } + } + + fn len(&self) -> usize { + 1 + } + } + + /// Batch of obs for testing. + pub struct TestObsBatch { + obs: Vec, + } + + impl border_core::generic_replay_buffer::BatchBase for TestObsBatch { + fn new(capacity: usize) -> Self { + Self { + obs: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.obs[i] = data.obs[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let obs = ixs.iter().map(|ix| self.obs[*ix]).collect(); + Self { obs } + } + } + + impl From for TestObsBatch { + fn from(obs: TestObs) -> Self { + Self { obs: vec![obs.obs] } + } + } + + /// Act for testing. + #[derive(Clone, Debug)] + pub struct TestAct { + act: usize, + } + + impl border_core::Act for TestAct {} + + /// Batch of act for testing. + pub struct TestActBatch { + act: Vec, + } + + impl From for TestActBatch { + fn from(act: TestAct) -> Self { + Self { act: vec![act.act] } + } + } + + impl border_core::generic_replay_buffer::BatchBase for TestActBatch { + fn new(capacity: usize) -> Self { + Self { + act: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.act[i] = data.act[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let act = ixs.iter().map(|ix| self.act[*ix]).collect(); + Self { act } + } + } + + /// Info for testing. + pub struct TestInfo {} + + impl border_core::Info for TestInfo {} + + /// Environment for testing. 
+ pub struct TestEnv { + state_init: usize, + state: usize, + } + + impl border_core::Env for TestEnv { + type Config = usize; + type Obs = TestObs; + type Act = TestAct; + type Info = TestInfo; + + fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn step_with_reset( + &mut self, + a: &Self::Act, + ) -> (border_core::Step, border_core::record::Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = border_core::Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: TestObs { + obs: self.state_init, + }, + }; + return (step, border_core::record::Record::empty()); + } + + fn step(&mut self, a: &Self::Act) -> (border_core::Step, border_core::record::Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = border_core::Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: TestObs { + obs: self.state_init, + }, + }; + return (step, border_core::record::Record::empty()); + } + + fn build(config: &Self::Config, _seed: i64) -> anyhow::Result + where + Self: Sized, + { + Ok(Self { + state_init: *config, + state: 0, + }) + } + } + + type ReplayBuffer = + border_core::generic_replay_buffer::SimpleReplayBuffer; + + /// Agent for testing. + pub struct TestAgent {} + + #[derive(Clone, Deserialize, Serialize)] + /// Config of agent for testing. + pub struct TestAgentConfig; + + impl border_core::Agent for TestAgent { + fn train(&mut self) {} + + fn is_train(&self) -> bool { + false + } + + fn eval(&mut self) {} + + fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> border_core::record::Record { + border_core::record::Record::empty() + } + + fn save_params>(&self, _path: T) -> anyhow::Result<()> { + Ok(()) + } + + fn load_params>(&mut self, _path: T) -> anyhow::Result<()> { + Ok(()) + } + } + + impl border_core::Policy for TestAgent { + fn sample(&mut self, _obs: &TestObs) -> TestAct { + TestAct { act: 1 } + } + } + + impl border_core::Configurable for TestAgent { + type Config = TestAgentConfig; + + fn build(_config: Self::Config) -> Self { + Self {} + } + } + + impl crate::SyncModel for TestAgent { + type ModelInfo = usize; + + fn model_info(&self) -> (usize, Self::ModelInfo) { + (0, 0) + } + + fn sync_model(&mut self, _model_info: &Self::ModelInfo) { + // nothing to do + } + } +} diff --git a/border-async-trainer/src/messages.rs b/border-async-trainer/src/messages.rs index ff65aeef..32070a91 100644 --- a/border-async-trainer/src/messages.rs +++ b/border-async-trainer/src/messages.rs @@ -1,4 +1,4 @@ -/// Message containing a [`ReplayBufferBase`](border_core::ReplayBufferBase)`::PushedItem`. +/// Message containing a [`ReplayBufferBase`](border_core::ReplayBufferBase)`::Item`. /// /// It will be sent from [`Actor`](crate::Actor) to [`ActorManager`](crate::ActorManager). 
pub struct PushedItemMessage { diff --git a/border-async-trainer/src/replay_buffer_proxy.rs b/border-async-trainer/src/replay_buffer_proxy.rs index dbe09e6e..263c5beb 100644 --- a/border-async-trainer/src/replay_buffer_proxy.rs +++ b/border-async-trainer/src/replay_buffer_proxy.rs @@ -9,31 +9,31 @@ use std::marker::PhantomData; pub struct ReplayBufferProxyConfig { /// Number of samples buffered until sent to the trainer. /// - /// Here, a sample corresponds to a `R::PushedItem` for [`ReplayBufferProxy`]``. + /// A sample is a `R::Item` for [`ReplayBufferProxy`]``. pub n_buffer: usize, } /// A wrapper of replay buffer for asynchronous trainer. -pub struct ReplayBufferProxy { +pub struct ReplayBufferProxy { id: usize, /// Sender of [PushedItemMessage]. - sender: Sender>, + sender: Sender>, /// Number of samples buffered until sent to the trainer. n_buffer: usize, - /// Buffer of `R::PushedItem`s. - buffer: Vec, + /// Buffer of `R::Item`s. + buffer: Vec, phantom: PhantomData, } -impl ReplayBufferProxy { +impl ReplayBufferProxy { pub fn build_with_sender( id: usize, config: &ReplayBufferProxyConfig, - sender: Sender>, + sender: Sender>, ) -> Self { let n_buffer = config.n_buffer; Self { @@ -46,10 +46,10 @@ impl ReplayBufferProxy { } } -impl ExperienceBufferBase for ReplayBufferProxy { - type PushedItem = R::PushedItem; +impl ExperienceBufferBase for ReplayBufferProxy { + type Item = R::Item; - fn push(&mut self, tr: Self::PushedItem) -> Result<()> { + fn push(&mut self, tr: Self::Item) -> Result<()> { self.buffer.push(tr); if self.buffer.len() == self.n_buffer { let mut buffer = Vec::with_capacity(self.n_buffer); @@ -76,7 +76,7 @@ impl ExperienceBufferBase for ReplayBufferProxy { } } -impl ReplayBufferBase for ReplayBufferProxy { +impl ReplayBufferBase for ReplayBufferProxy { type Config = ReplayBufferProxyConfig; type Batch = R::Batch; diff --git a/border-async-trainer/src/util.rs b/border-async-trainer/src/util.rs index d57b9708..3e97800d 100644 --- a/border-async-trainer/src/util.rs +++ b/border-async-trainer/src/util.rs @@ -3,16 +3,12 @@ use crate::{ actor_stats_fmt, ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel, }; use border_core::{ - Agent, DefaultEvaluator, Env, ReplayBufferBase, - StepProcessorBase, + record::AggregateRecorder, Agent, Configurable, Env, Evaluator, ExperienceBufferBase, + ReplayBufferBase, StepProcessor, }; -use border_tensorboard::TensorboardRecorder; use crossbeam_channel::unbounded; use log::info; -use std::{ - path::Path, - sync::{Arc, Mutex}, -}; +use std::sync::{Arc, Mutex}; /// Runs asynchronous training. /// @@ -32,8 +28,7 @@ use std::{ /// * `replay_buffer_config` - Configuration of the replay buffer. /// * `actor_man_config` - Configuration of [`ActorManager`]. /// * `async_trainer_config` - Configuration of [`AsyncTrainer`]. 
-pub fn train_async( - model_dir: &P, +pub fn train_async( agent_config: &A::Config, agent_configs: &Vec, env_config_train: &E::Config, @@ -42,21 +37,19 @@ pub fn train_async( replay_buffer_config: &R::Config, actor_man_config: &ActorManagerConfig, async_trainer_config: &AsyncTrainerConfig, + recorder: &mut Box, + evaluator: &mut impl Evaluator, ) where - A: Agent + SyncModel, + A: Agent + Configurable + SyncModel, E: Env, - R: ReplayBufferBase + Send + 'static, - S: StepProcessorBase, + R: ExperienceBufferBase + Send + 'static + ReplayBufferBase, + S: StepProcessor, A::Config: Send + 'static, E::Config: Send + 'static, S::Config: Send + 'static, - R::PushedItem: Send + 'static, + R::Item: Send + 'static, A::ModelInfo: Send + 'static, - P: AsRef, { - let mut recorder = TensorboardRecorder::new(model_dir); - let mut evaluator = DefaultEvaluator::new(env_config_eval, 0, 1).unwrap(); - // Shared flag to stop actor threads let stop = Arc::new(Mutex::new(false)); @@ -89,7 +82,7 @@ pub fn train_async( // Starts sampling and training actors.run(guard_init_env.clone()); - let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env); + let stats = trainer.train(recorder, evaluator, guard_init_env); info!("Stats of async trainer"); info!("{}", stats.fmt()); diff --git a/border-atari-env/Cargo.toml b/border-atari-env/Cargo.toml index e7abc683..a29fbc6b 100644 --- a/border-atari-env/Cargo.toml +++ b/border-atari-env/Cargo.toml @@ -1,26 +1,23 @@ [package] name = "border-atari-env" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Atari environment based on gym-rs" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "GPL-2.0-or-later" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +package.license = "GPL-2.0-or-later" readme = "README.md" -autoexamples = false [dependencies] anyhow = { workspace = true } pixels = { version = "0.2.0", optional = true } winit = { version = "0.24.0", optional = true } dirs = { workspace = true } -border-core = { version = "0.0.6", path = "../border-core" } +border-core = { version = "0.0.7", path = "../border-core" } image = { workspace = true } tch = { workspace = true, optional = true } +candle-core = { workspace = true, optional = true } serde = { workspace = true, features = ["derive"] } itertools = "0.10.1" fastrand = { workspace = true } @@ -46,7 +43,7 @@ default = [ "winit_input_helper", "minifb", "pixels", - "tch", + # "tch", ] sdl = ["atari-env-sys/sdl"] diff --git a/border-atari-env/examples/random_pong.rs b/border-atari-env/examples/random_pong.rs index 5bd859e5..07fcb88d 100644 --- a/border-atari-env/examples/random_pong.rs +++ b/border-atari-env/examples/random_pong.rs @@ -3,7 +3,8 @@ use border_atari_env::{ BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs, BorderAtariObsRawFilter, }; -use border_core::{DefaultEvaluator, Env as _, Evaluator, Policy}; +use border_core::{Configurable, DefaultEvaluator, Env as _, Evaluator, Policy}; +use serde::Deserialize; type Obs = BorderAtariObs; type Act = BorderAtariAct; @@ -12,7 +13,7 @@ type ActFilter = BorderAtariActRawFilter; type EnvConfig = BorderAtariEnvConfig; type Env = BorderAtariEnv; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig { pub n_acts: usize, } @@ -22,6 +23,12 @@ struct RandomPolicy { } impl 
Policy for RandomPolicy { + fn sample(&mut self, _: &Obs) -> Act { + fastrand::u8(..self.n_acts as u8).into() + } +} + +impl Configurable for RandomPolicy { type Config = RandomPolicyConfig; fn build(config: Self::Config) -> Self { @@ -29,10 +36,6 @@ impl Policy for RandomPolicy { n_acts: config.n_acts, } } - - fn sample(&mut self, _: &Obs) -> Act { - fastrand::u8(..self.n_acts as u8).into() - } } fn env_config(name: String) -> EnvConfig { diff --git a/border-atari-env/src/act.rs b/border-atari-env/src/act.rs index d21ff36f..6feae6bf 100644 --- a/border-atari-env/src/act.rs +++ b/border-atari-env/src/act.rs @@ -5,7 +5,9 @@ use serde::{Deserialize, Serialize}; use std::{default::Default, marker::PhantomData}; #[derive(Debug, Clone)] -/// Action for [BorderAtariEnv](crate::BorderAtariEnv) +/// Action for [`BorderAtariEnv`](crate::BorderAtariEnv). +/// +/// This action is a discrete action and denotes pushing a button. pub struct BorderAtariAct { pub act: u8, } @@ -28,7 +30,7 @@ impl From for BorderAtariAct { } } -/// Converts `A` to [BorderAtariAct]. +/// Converts action of type `A` to [`BorderAtariAct`]. pub trait BorderAtariActFilter { /// Configuration of the filter. type Config: Clone + Default; @@ -46,7 +48,7 @@ pub trait BorderAtariActFilter { } #[derive(Debug, Deserialize, Serialize)] -/// Configuration of [BorderAtariActRawFilter]. +/// Configuration of [`BorderAtariActRawFilter`]. #[derive(Clone)] pub struct BorderAtariActRawFilterConfig; @@ -56,7 +58,7 @@ impl Default for BorderAtariActRawFilterConfig { } } -/// A filter without any processing. +/// A filter that performs no processing. pub struct BorderAtariActRawFilter { phantom: PhantomData, } diff --git a/border-atari-env/src/atari_env.rs b/border-atari-env/src/atari_env.rs index ec4d0758..b2731cd2 100644 --- a/border-atari-env/src/atari_env.rs +++ b/border-atari-env/src/atari_env.rs @@ -1,3 +1,4 @@ +//! Atari environment for reinforcement learning. pub mod ale; use std::path::Path; diff --git a/border-atari-env/src/env.rs b/border-atari-env/src/env.rs index b16b8a06..5ed7ba20 100644 --- a/border-atari-env/src/env.rs +++ b/border-atari-env/src/env.rs @@ -1,21 +1,21 @@ mod config; mod window; use super::BorderAtariAct; -use anyhow::Result; +use super::{BorderAtariActFilter, BorderAtariObsFilter}; use crate::atari_env::{AtariAction, AtariEnv, EmulatorConfig}; -use border_core::{record::Record, Env, Info, Obs, Act, Step}; +use anyhow::Result; +use border_core::{record::Record, Act, Env, Info, Obs, Step}; pub use config::BorderAtariEnvConfig; use image::{ imageops::{/*grayscale,*/ resize, FilterType::Triangle}, ImageBuffer, /*Luma,*/ Rgb, }; -use std::{default::Default, marker::PhantomData}; +use itertools::izip; use std::ptr::copy; +use std::{default::Default, marker::PhantomData}; use window::AtariWindow; #[cfg(feature = "atari-env-sys")] use winit::{event_loop::ControlFlow, platform::run_return::EventLoopExtRunReturn}; -use super::{BorderAtariObsFilter, BorderAtariActFilter}; -use itertools::izip; /// Empty struct. 
pub struct NullInfo; @@ -74,7 +74,7 @@ where phantom: PhantomData<(O, A)>, } -impl BorderAtariEnv +impl BorderAtariEnv where O: Obs, A: Act, @@ -103,37 +103,40 @@ where let actions = self.env.minimal_actions(); let ix = a.act; let reward = self.env.step(actions[ix as usize]) as f32; - let mut done = self.env.is_game_over(); - self.was_real_done = done; + + let is_terminated = match self.env.is_game_over() { + true => 1, + false => 0, + }; + self.was_real_done = is_terminated == 1; let lives = self.env.lives(); - if self.train && lives < self.lives && lives > 0 { - done = true; - } + // if self.train && lives < self.lives && lives > 0 { + // done = true; + // } self.lives = lives; - let done = if done { 1 } else { 0 }; let (w, h) = (self.env.width(), self.env.height()); let mut obs = vec![0u8; w * h * 3]; self.env.render_rgb24(&mut obs); - (obs, reward, done) + (obs, reward, is_terminated) } fn skip_and_max(&mut self, a: &BorderAtariAct) -> (Vec, f32, Vec) { let mut total_reward = 0f32; - let mut done = 0; + let mut is_terminated = 0; for i in 0..4 { - let (obs, reward, done_) = self.episodic_life_env_step(a); + let (obs, reward, is_terminated_) = self.episodic_life_env_step(a); total_reward += reward; - done = done_; + is_terminated = is_terminated_; if i == 2 { self.obs_buffer[0] = obs; } else if i == 3 { self.obs_buffer[1] = obs; } - if done_ == 1 { + if is_terminated_ == 1 { break; } } @@ -145,7 +148,7 @@ where .map(|(&a, &b)| a.max(b)) .collect::>(); - (obs, total_reward, vec![done]) + (obs, total_reward, vec![is_terminated]) } fn clip_reward(&self, r: f32) -> Vec { @@ -169,9 +172,11 @@ where let i1 = buf.iter().step_by(3); let i2 = buf.iter().skip(1).step_by(3); let i3 = buf.iter().skip(2).step_by(3); - izip![i1, i2, i3].map(|(&b, &g, &r)| - ((0.299 * r as f32) + (0.587 * g as f32) + (0.114 * b as f32)) as u8 - ).collect::>() + izip![i1, i2, i3] + .map(|(&b, &g, &r)| { + ((0.299 * r as f32) + (0.587 * g as f32) + (0.114 * b as f32)) as u8 + }) + .collect::>() }; // let buf = { // let img: ImageBuffer, _> = grayscale(&img); @@ -307,14 +312,15 @@ where Self: Sized, { let (step, record) = self.step(a); - assert_eq!(step.is_done.len(), 1); - let step = if step.is_done[0] == 1 { + assert_eq!(step.is_terminated.len(), 1); + let step = if step.is_done() { let init_obs = self.reset(None).unwrap(); Step { act: step.act, obs: step.obs, reward: step.reward, - is_done: step.is_done, + is_terminated: step.is_terminated, + is_truncated: step.is_truncated, info: step.info, init_obs, } @@ -333,7 +339,8 @@ where { let act_org = act.clone(); let (act, _record) = self.act_filter.filt(act_org.clone()); - let (obs, reward, is_done) = self.skip_and_max(&act); + let (obs, reward, is_terminated) = self.skip_and_max(&act); + let is_truncated = vec![0]; // not compatible with the official implementation let (w, h) = (self.env.width() as u32, self.env.height() as u32); let obs = Self::warp_and_grayscale(w, h, obs); let reward = self.clip_reward(reward); // in training @@ -343,7 +350,8 @@ where obs, act_org, reward, - is_done, + is_terminated, + is_truncated, NullInfo, Self::Obs::dummy(1), ); diff --git a/border-atari-env/src/env/config.rs b/border-atari-env/src/env/config.rs index 338f0656..9abb742b 100644 --- a/border-atari-env/src/env/config.rs +++ b/border-atari-env/src/env/config.rs @@ -2,13 +2,13 @@ //! //! If environment variable `ATARI_ROM_DIR` exists, it is used as the directory //! from which ROM images of the Atari games is loaded. 
-use std::{env, default::Default}; -use border_core::{Obs, Act}; -use super::{BorderAtariObsFilter, BorderAtariActFilter}; +use super::{BorderAtariActFilter, BorderAtariObsFilter}; +use border_core::{Act, Obs}; use serde::{Deserialize, Serialize}; +use std::{default::Default, env}; #[derive(Serialize, Deserialize, Debug)] -/// Configurations of [`BorderAtariEnv`](super::BorderAtariEnv). +/// Configuration of [`BorderAtariEnv`](super::BorderAtariEnv). pub struct BorderAtariEnvConfig where O: Obs, @@ -16,12 +16,12 @@ where OF: BorderAtariObsFilter, AF: BorderAtariActFilter, { - pub(super) rom_dir: String, - pub(super) name: String, - pub(super) obs_filter_config: OF::Config, - pub(super) act_filter_config: AF::Config, - pub(super) train: bool, - pub(super) render: bool, + pub rom_dir: String, + pub name: String, + pub obs_filter_config: OF::Config, + pub act_filter_config: AF::Config, + pub train: bool, + pub render: bool, } impl Clone for BorderAtariEnvConfig @@ -43,7 +43,6 @@ where } } - impl Default for BorderAtariEnvConfig where O: Obs, diff --git a/border-atari-env/src/env/window.rs b/border-atari-env/src/env/window.rs index 9e140108..ce98765c 100644 --- a/border-atari-env/src/env/window.rs +++ b/border-atari-env/src/env/window.rs @@ -1,5 +1,5 @@ -use anyhow::Result; use crate::atari_env::AtariEnv; +use anyhow::Result; #[cfg(feature = "atari-env-sys")] use { pixels::{Pixels, SurfaceTexture}, @@ -7,7 +7,7 @@ use { event_loop::EventLoop, // platform::run_return::EventLoopExtRunReturn, window::{Window, WindowBuilder}, - } + }, }; pub(super) struct AtariWindow { @@ -29,12 +29,8 @@ impl AtariWindow { .with_inner_size(winit::dpi::LogicalSize::new(128.0, 128.0)) .build(&event_loop)?; let surface_texture = SurfaceTexture::new(128, 128, &window); - let pixels = Pixels::new( - env.width() as u32, - env.height() as u32, - surface_texture, - ) - .unwrap(); + let pixels = + Pixels::new(env.width() as u32, env.height() as u32, surface_texture).unwrap(); // event_loop.run_return(move |_event, _, _control_flow| {}); Ok(Self { diff --git a/border-atari-env/src/lib.rs b/border-atari-env/src/lib.rs index 2a8412e5..7d47560e 100644 --- a/border-atari-env/src/lib.rs +++ b/border-atari-env/src/lib.rs @@ -1,18 +1,18 @@ -//! A thin wrapper of [atari-env](https://crates.io/crates/atari-env) for [Border](https://crates.io/crates/border). +//! A thin wrapper of [`atari-env`](https://crates.io/crates/atari-env) for [`Border`](https://crates.io/crates/border). //! //! The code under [atari_env] is adapted from the -//! [atari-env](https://crates.io/crates/atari-env) crate +//! [`atari-env`](https://crates.io/crates/atari-env) crate //! (rev = `0ef0422f953d79e96b32ad14284c9600bd34f335`), //! because the crate registered in crates.io does not implement //! [`atari_env::AtariEnv::lives()`] method, which is required for episodic life environments. //! //! This environment applies some preprocessing to observation as in -//! [atari_wrapper.py](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py). +//! [`atari_wrapper.py`](https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py). //! //! You need to place Atari Rom directories under the directory specified by environment variable //! `ATARI_ROM_DIR`. An easy way to do this is to use [AutoROM](https://pypi.org/project/AutoROM/) //! Python package. -//! +//! //! ```bash //! pip install autorom //! mkdir $HOME/atari_rom @@ -21,62 +21,57 @@ //! ``` //! //! 
Here is an example of running Pong environment with a random policy. -//! +//! //! ```no_run //! use anyhow::Result; //! use border_atari_env::{ //! BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, //! BorderAtariObs, BorderAtariObsRawFilter, //! }; -//! use border_core::{util, Env as _, Policy, DefaultEvaluator, Evaluator as _}; -//! -//! type Obs = BorderAtariObs; -//! type Act = BorderAtariAct; -//! type ObsFilter = BorderAtariObsRawFilter; -//! type ActFilter = BorderAtariActRawFilter; -//! type EnvConfig = BorderAtariEnvConfig; -//! type Env = BorderAtariEnv; -//! -//! #[derive(Clone)] -//! struct RandomPolicyConfig { -//! pub n_acts: usize, -//! } -//! -//! struct RandomPolicy { -//! n_acts: usize, -//! } -//! -//! impl Policy for RandomPolicy { -//! type Config = RandomPolicyConfig; -//! -//! fn build(config: Self::Config) -> Self { -//! Self { -//! n_acts: config.n_acts, -//! } -//! } -//! -//! fn sample(&mut self, _: &Obs) -> Act { -//! fastrand::u8(..self.n_acts as u8).into() -//! } -//! } -//! -//! fn env_config(name: String) -> EnvConfig { -//! EnvConfig::default().name(name) -//! } -//! +//! use border_core::{Env as _, Policy, DefaultEvaluator, Evaluator as _}; +//! +//! # type Obs = BorderAtariObs; +//! # type Act = BorderAtariAct; +//! # type ObsFilter = BorderAtariObsRawFilter; +//! # type ActFilter = BorderAtariActRawFilter; +//! # type EnvConfig = BorderAtariEnvConfig; +//! # type Env = BorderAtariEnv; +//! # +//! # #[derive(Clone)] +//! # struct RandomPolicyConfig { +//! # pub n_acts: usize, +//! # } +//! # +//! # struct RandomPolicy { +//! # n_acts: usize, +//! # } +//! # +//! # impl RandomPolicy { +//! # pub fn build(n_acts: usize) -> Self { +//! # Self { n_acts } +//! # } +//! # } +//! # +//! # impl Policy for RandomPolicy { +//! # fn sample(&mut self, _: &Obs) -> Act { +//! # fastrand::u8(..self.n_acts as u8).into() +//! # } +//! # } +//! # +//! # fn env_config(name: String) -> EnvConfig { +//! # EnvConfig::default().name(name) +//! # } +//! # //! fn main() -> Result<()> { -//! env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); -//! fastrand::seed(42); -//! +//! # env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); +//! # fastrand::seed(42); +//! # //! // Creates Pong environment //! let env_config = env_config("pong".to_string()); //! //! // Creates a random policy -//! let n_acts = 4; // number of actions; -//! let policy_config = RandomPolicyConfig { -//! n_acts: n_acts as _, -//! }; -//! let mut policy = RandomPolicy::build(policy_config); +//! let n_acts = 4; +//! let mut policy = RandomPolicy::build(n_acts); //! //! // Runs evaluation //! let env_config = env_config.render(true); @@ -87,10 +82,10 @@ //! ``` //! [`atari_env::AtariEnv::lives()`]: atari_env::AtariEnv::lives mod act; -mod obs; +pub mod atari_env; mod env; +mod obs; pub mod util; -pub mod atari_env; pub use act::{BorderAtariAct, BorderAtariActFilter, BorderAtariActRawFilter}; -pub use obs::{BorderAtariObs, BorderAtariObsFilter, BorderAtariObsRawFilter}; pub use env::{BorderAtariEnv, BorderAtariEnvConfig}; +pub use obs::{BorderAtariObs, BorderAtariObsFilter, BorderAtariObsRawFilter}; diff --git a/border-atari-env/src/obs.rs b/border-atari-env/src/obs.rs index 05ca77a3..2a5e4b59 100644 --- a/border-atari-env/src/obs.rs +++ b/border-atari-env/src/obs.rs @@ -14,6 +14,8 @@ //! Instead, the scaling is applied in CNN model. 
use anyhow::Result; use border_core::{record::Record, Obs}; +#[cfg(feature = "candle-core")] +use candle_core::{Device::Cpu, Tensor}; use serde::{Deserialize, Serialize}; use std::{default::Default, marker::PhantomData}; #[cfg(feature = "tch")] @@ -53,7 +55,19 @@ impl From for Tensor { } } -/// Converts [`BorderAtariObs`] to `O` with an arbitrary processing. +#[cfg(feature = "candle-core")] +impl From for Tensor { + fn from(obs: BorderAtariObs) -> Tensor { + let tmp = obs.frames; + // Assumes the batch size is 1, implying non-vectorized environment + Tensor::from_vec(tmp, &[1 * 4 * 1 * 84 * 84], &Cpu) + .unwrap() + .reshape(&[1, 4, 1, 84, 84]) + .unwrap() + } +} + +/// Converts [`BorderAtariObs`] to observation of type `O` with an arbitrary processing. pub trait BorderAtariObsFilter { /// Configuration of the filter. type Config: Clone + Default; @@ -84,7 +98,7 @@ impl Default for BorderAtariObsRawFilterConfig { } } -/// A filter without any processing. +/// A filter that performs no processing. pub struct BorderAtariObsRawFilter { phantom: PhantomData, } diff --git a/border-atari-env/src/util.rs b/border-atari-env/src/util.rs index 3f408e81..8788d89b 100644 --- a/border-atari-env/src/util.rs +++ b/border-atari-env/src/util.rs @@ -1 +1,2 @@ -pub mod test; \ No newline at end of file +//! Utility functions for testing. +pub mod test; diff --git a/border-atari-env/src/util/test.rs b/border-atari-env/src/util/test.rs index 3e481f27..5c7c8064 100644 --- a/border-atari-env/src/util/test.rs +++ b/border-atari-env/src/util/test.rs @@ -5,10 +5,11 @@ use crate::{ }; use anyhow::Result; use border_core::{ + generic_replay_buffer::{BatchBase, SimpleReplayBuffer}, record::Record, - replay_buffer::{SimpleReplayBuffer, SubBatch}, - Agent as Agent_, Policy, ReplayBufferBase, + Agent as Agent_, Configurable, Policy, ReplayBufferBase, }; +use serde::Deserialize; use std::ptr::copy; pub type Obs = BorderAtariObs; @@ -34,7 +35,7 @@ pub struct ObsBatch { pub buf: Vec, } -impl SubBatch for ObsBatch { +impl BatchBase for ObsBatch { fn new(capacity: usize) -> Self { let m = 4 * FRAME_IN_BYTES; Self { @@ -45,7 +46,7 @@ impl SubBatch for ObsBatch { } #[inline] - fn push(&mut self, i: usize, data: &Self) { + fn push(&mut self, i: usize, data: Self) { unsafe { let src: *const u8 = &data.buf[0]; let dst: *mut u8 = &mut self.buf[i * self.m]; @@ -89,7 +90,7 @@ pub struct ActBatch { pub buf: Vec, } -impl SubBatch for ActBatch { +impl BatchBase for ActBatch { fn new(capacity: usize) -> Self { let m = 1; Self { @@ -100,7 +101,7 @@ impl SubBatch for ActBatch { } #[inline] - fn push(&mut self, i: usize, data: &Self) { + fn push(&mut self, i: usize, data: Self) { unsafe { let src: *const u8 = &data.buf[0]; let dst: *mut u8 = &mut self.buf[i * self.m]; @@ -132,8 +133,8 @@ impl From for ActBatch { } } -#[derive(Clone)] -/// Configuration of [RandomAgent]. +#[derive(Clone, Deserialize)] +/// Configuration of [`RandomAgent``]. 
pub struct RandomAgentConfig { pub n_acts: usize, } @@ -146,6 +147,12 @@ pub struct RandomAgent { } impl Policy for RandomAgent { + fn sample(&mut self, _: &Obs) -> Act { + fastrand::u8(..self.n_acts as u8).into() + } +} + +impl Configurable for RandomAgent { type Config = RandomAgentConfig; fn build(config: Self::Config) -> Self { @@ -155,10 +162,6 @@ impl Policy for RandomAgent { train: true, } } - - fn sample(&mut self, _: &Obs) -> Act { - fastrand::u8(..self.n_acts as u8).into() - } } impl Agent_ for RandomAgent { @@ -174,23 +177,18 @@ impl Agent_ for RandomAgent { self.train } - fn opt(&mut self, buffer: &mut R) -> Option { - // Check warmup period - if buffer.len() <= 100 { - None - } else { + fn opt_with_record(&mut self, _buffer: &mut R) -> border_core::record::Record { // Do nothing - self.n_opts_steps += 1; - Some(Record::empty()) - } + self.n_opts_steps += 1; + Record::empty() } - fn save>(&self, _path: T) -> Result<()> { + fn save_params>(&self, _path: T) -> Result<()> { println!("save() was invoked"); Ok(()) } - fn load>(&mut self, _path: T) -> Result<()> { + fn load_params>(&mut self, _path: T) -> Result<()> { println!("load() was invoked"); Ok(()) } diff --git a/border-candle-agent/Cargo.toml b/border-candle-agent/Cargo.toml new file mode 100644 index 00000000..e06c2a16 --- /dev/null +++ b/border-candle-agent/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "border-candle-agent" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +border-core = { version = "0.0.7", path = "../border-core" } +border-async-trainer = { version = "0.0.7", path = "../border-async-trainer", optional = true } +serde = { workspace = true, features = ["derive"] } +serde_yaml = { workspace = true } +tensorboard-rs = { workspace = true } +log = { workspace = true } +thiserror = { workspace = true } +anyhow = { workspace = true } +chrono = { workspace = true } +aquamarine = { workspace = true } +candle-core = { workspace = true } +candle-nn = { workspace = true } +fastrand = { workspace = true } +segment-tree = { workspace = true } +rand = { workspace = true } +itertools = { workspace = true } +ordered-float = { workspace = true } +candle-optimisers = { workspace = true } + +[dev-dependencies] +tempdir = { workspace = true } + +# [package.metadata.docs.rs] +# features = ["doc-only"] + +# [features] +# doc-only = ["tch/doc-only"] diff --git a/border-candle-agent/README.md b/border-candle-agent/README.md new file mode 100644 index 00000000..e69de29b diff --git a/border-candle-agent/src/cnn.rs b/border-candle-agent/src/cnn.rs new file mode 100644 index 00000000..2d0bff3e --- /dev/null +++ b/border-candle-agent/src/cnn.rs @@ -0,0 +1,9 @@ +//! Convolutional neural network. +//! +//! The architecture is the same in the DQN Nature paper. +//! It should be noted that the input array will be scaled by 1 / 255 for normalizing +//! pixel intensities with casting from `u8` to `f32`. 
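// Editor's aside (illustrative, not part of the diff): with the Nature-DQN layout
// used by `Cnn` in the `base` submodule below (kernels 8/4/3, strides 4/2/1, no
// padding), an 84x84 input shrinks to 7x7 over 64 channels, so the flattened
// feature size feeding the first linear layer is 64 * 7 * 7 = 3136.
// `conv_out` is a hypothetical helper for the standard output-size formula.
fn conv_out(size: usize, kernel: usize, stride: usize) -> usize {
    (size - kernel) / stride + 1
}

fn main() {
    let s = conv_out(conv_out(conv_out(84, 8, 4), 4, 2), 3, 1); // 84 -> 20 -> 9 -> 7
    assert_eq!(s, 7);
    assert_eq!(64 * s * s, 3136);
}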
+mod base; +mod config; +pub use base::Cnn; +pub use config::CnnConfig; diff --git a/border-candle-agent/src/cnn/base.rs b/border-candle-agent/src/cnn/base.rs new file mode 100644 index 00000000..96093a6b --- /dev/null +++ b/border-candle-agent/src/cnn/base.rs @@ -0,0 +1,92 @@ +use super::CnnConfig; +use crate::model::SubModel1; +use anyhow::Result; +use candle_core::{DType::F32, Device, Tensor}; +use candle_nn::{ + conv::Conv2dConfig, + conv2d, linear, + sequential::{seq, Sequential}, + Module, VarBuilder, +}; + +#[allow(clippy::upper_case_acronyms)] +#[allow(dead_code)] +/// Convolutional neural network, which has the same architecture of the DQN paper. +pub struct Cnn { + n_stack: i64, + out_dim: i64, + device: Device, + seq: Sequential, + skip_linear: bool, +} + +impl Cnn { + fn stride(s: i64) -> Conv2dConfig { + Conv2dConfig { + stride: s as _, + ..Default::default() + } + } + + fn create_net(vb: &VarBuilder, n_stack: i64, out_dim: i64) -> Result { + let seq = seq() + .add_fn(|xs| xs.squeeze(2)?.to_dtype(F32)? / 255.0) + .add(conv2d(n_stack as _, 32, 8, Self::stride(4), vb.pp("c1"))?) + .add_fn(|xs| xs.relu()) + .add(conv2d(32, 64, 4, Self::stride(2), vb.pp("c2"))?) + .add_fn(|xs| xs.relu()) + .add(conv2d(64, 64, 3, Self::stride(1), vb.pp("c3"))?) + .add_fn(|xs| xs.relu()?.flatten_from(1)) + .add(linear(3136, 512, vb.pp("l1"))?) + .add_fn(|xs| xs.relu()) + .add(linear(512, out_dim as _, vb.pp("l2"))?); + + Ok(seq) + } + + fn create_net_wo_linear(vb: &VarBuilder, n_stack: i64) -> Result { + let seq = seq() + .add_fn(|xs| xs.squeeze(2)?.to_dtype(F32)? / 255.0) + .add(conv2d(n_stack as _, 32, 8, Self::stride(4), vb.pp("c1"))?) + .add_fn(|xs| xs.relu()) + .add(conv2d(32, 64, 4, Self::stride(2), vb.pp("c2"))?) + .add_fn(|xs| xs.relu()) + .add(conv2d(64, 64, 3, Self::stride(1), vb.pp("c3"))?) + .add_fn(|xs| xs.relu()?.flatten_from(1)); + + Ok(seq) + } +} + +impl SubModel1 for Cnn { + type Config = CnnConfig; + type Input = Tensor; + type Output = Tensor; + + fn forward(&self, x: &Self::Input) -> Tensor { + self.seq + .forward(&x.to_device(&self.device).unwrap()) + .unwrap() + } + + fn build(vb: VarBuilder, config: Self::Config) -> Self { + let n_stack = config.n_stack; + let out_dim = config.out_dim; + let device = vb.device().clone(); + let skip_linear = config.skip_linear; + let seq = if config.skip_linear { + Self::create_net_wo_linear(&vb, n_stack) + } else { + Self::create_net(&vb, n_stack, out_dim) + } + .unwrap(); + + Self { + n_stack, + out_dim, + device, + seq, + skip_linear, + } + } +} diff --git a/border-candle-agent/src/cnn/config.rs b/border-candle-agent/src/cnn/config.rs new file mode 100644 index 00000000..6cb5982f --- /dev/null +++ b/border-candle-agent/src/cnn/config.rs @@ -0,0 +1,46 @@ +use crate::util::OutDim; +use serde::{Deserialize, Serialize}; + +fn default_skip_linear() -> bool { + false +} + +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +/// Configuration of [`Cnn`](super::Cnn). +/// +/// If `skip_linear` is `true`, `out_dim` is not used. +pub struct CnnConfig { + pub n_stack: i64, + pub out_dim: i64, + #[serde(default = "default_skip_linear")] + pub skip_linear: bool, +} + +impl CnnConfig { + /// Constructs [`CnnConfig`] + pub fn new(n_stack: i64, out_dim: i64) -> Self { + Self { + n_stack, + out_dim, + skip_linear: false, + } + } + + pub fn skip_linear(mut self, skip_linear: bool) -> Self { + self.skip_linear = skip_linear; + self + } +} + +impl OutDim for CnnConfig { + /// Gets output dimension. 
+ fn get_out_dim(&self) -> i64 { + self.out_dim + } + + /// Sets output dimension. + fn set_out_dim(&mut self, v: i64) { + self.out_dim = v; + } +} diff --git a/border-candle-agent/src/dqn.rs b/border-candle-agent/src/dqn.rs new file mode 100644 index 00000000..2929aeb4 --- /dev/null +++ b/border-candle-agent/src/dqn.rs @@ -0,0 +1,9 @@ +//! DQN agent. +mod base; +mod config; +mod explorer; +mod model; +pub use base::Dqn; +pub use config::DqnConfig; +pub use explorer::{DqnExplorer, EpsilonGreedy, Softmax}; +pub use model::{DqnModel, DqnModelConfig}; diff --git a/border-candle-agent/src/dqn/base.rs b/border-candle-agent/src/dqn/base.rs new file mode 100644 index 00000000..f67055b0 --- /dev/null +++ b/border-candle-agent/src/dqn/base.rs @@ -0,0 +1,385 @@ +//! DQN agent implemented with candle. +use super::{config::DqnConfig, explorer::DqnExplorer, model::DqnModel}; +use crate::{ + model::SubModel1, + util::{smooth_l1_loss, track, CriticLoss, OutDim}, +}; +use anyhow::Result; +use border_core::{ + record::{Record, RecordValue}, + Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, +}; +use candle_core::{shape::D, DType, Device, Tensor}; +use candle_nn::loss::mse; +use rand::{rngs::SmallRng, Rng, SeedableRng}; +use serde::{de::DeserializeOwned, Serialize}; +use std::convert::TryFrom; +use std::{fs, marker::PhantomData, path::Path}; + +#[allow(clippy::upper_case_acronyms, dead_code)] +/// DQN agent implemented with candle. +pub struct Dqn +where + Q: SubModel1, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + pub(in crate::dqn) soft_update_interval: usize, + pub(in crate::dqn) soft_update_counter: usize, + pub(in crate::dqn) n_updates_per_opt: usize, + pub(in crate::dqn) batch_size: usize, + pub(in crate::dqn) qnet: DqnModel, + pub(in crate::dqn) qnet_tgt: DqnModel, + pub(in crate::dqn) train: bool, + pub(in crate::dqn) phantom: PhantomData<(E, R)>, + pub(in crate::dqn) discount_factor: f64, + pub(in crate::dqn) tau: f64, + pub(in crate::dqn) explorer: DqnExplorer, + pub(in crate::dqn) device: Device, + pub(in crate::dqn) n_opts: usize, + pub(in crate::dqn) double_dqn: bool, + pub(in crate::dqn) _clip_reward: Option, + pub(in crate::dqn) clip_td_err: Option<(f64, f64)>, + pub(in crate::dqn) critic_loss: CriticLoss, + n_samples_act: usize, + n_samples_best_act: usize, + record_verbose_level: usize, + rng: SmallRng, +} + +impl Dqn +where + E: Env, + Q: SubModel1, + R: ReplayBufferBase, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + R::Batch: TransitionBatch, + ::ObsBatch: Into, + ::ActBatch: Into, +{ + fn update_critic(&mut self, buffer: &mut R) -> Record { + let mut record = Record::empty(); + let batch = buffer.batch(self.batch_size).unwrap(); + let (obs, act, next_obs, reward, is_terminated, _is_truncated, _ixs, weight) = + batch.unpack(); + let obs = obs.into(); + let act = act.into().to_device(&self.device).unwrap(); + let next_obs = next_obs.into(); + let reward = Tensor::from_slice(&reward[..], &[reward.len()], &self.device).unwrap(); + let is_not_terminated = { + let is_not_terminated = is_terminated + .into_iter() + .map(|v| (1 - v) as f32) + .collect::>(); + Tensor::from_slice( + &is_not_terminated[..], + &[is_not_terminated.len()], + &self.device, + ) + .unwrap() + }; + let pred = { + let x = self.qnet.forward(&obs); + x.gather(&act, D::Minus1) + .unwrap() + .squeeze(D::Minus1) + .unwrap() + }; + + if self.record_verbose_level >= 2 { + record.insert( + "pred_mean", + 
RecordValue::Scalar(pred.mean_all().unwrap().to_vec0::().unwrap()), + ); + } + + if self.record_verbose_level >= 2 { + let reward_mean: f32 = reward.mean_all().unwrap().to_vec0().unwrap(); + record.insert("reward_mean", RecordValue::Scalar(reward_mean)); + } + + let tgt = { + let q = if self.double_dqn { + let x = self.qnet.forward(&next_obs); + let y = x.argmax(D::Minus1).unwrap(); + let tgt = self.qnet_tgt.forward(&next_obs); + tgt.gather(&y, D::Minus1).unwrap() + } else { + let x = self.qnet_tgt.forward(&next_obs); + let y = x.argmax(D::Minus1).unwrap(); + x.gather(&y.unsqueeze(D::Minus1).unwrap(), D::Minus1) + .unwrap() + }; + + reward + is_not_terminated * self.discount_factor * q.squeeze(D::Minus1).unwrap() + } + .unwrap() + .detach(); + + if self.record_verbose_level >= 2 { + record.insert( + "tgt_mean", + RecordValue::Scalar(tgt.mean_all().unwrap().to_vec0::().unwrap()), + ); + let tgt_minus_pred_mean: f32 = (&tgt - &pred) + .unwrap() + .mean_all() + .unwrap() + .to_vec0() + .unwrap(); + record.insert( + "tgt_minus_pred_mean", + RecordValue::Scalar(tgt_minus_pred_mean), + ); + } + + let loss = if let Some(_ws) = weight { + // Prioritized weighting loss, will be implemented later + panic!(); + // let n = ws.len() as i64; + // let td_errs = match self.clip_td_err { + // None => (&pred - &tgt).abs(), + // Some((min, max)) => (&pred - &tgt).abs().clip(min, max), + // }; + // let loss = Tensor::of_slice(&ws[..]).to(self.device) * &td_errs; + // let loss = loss.smooth_l1_loss( + // &Tensor::zeros(&[n], tch::kind::FLOAT_CPU).to(self.device), + // tch::Reduction::Mean, + // 1.0, + // ); + // self.qnet.backward_step(&loss); + // let td_errs = Vec::::from(td_errs); + // buffer.update_priority(&ixs, &Some(td_errs)); + // loss + } else { + match self.critic_loss { + CriticLoss::Mse => mse(&pred, &tgt).unwrap(), + CriticLoss::SmoothL1 => smooth_l1_loss(&pred, &tgt).unwrap(), + } + }; + + // Backprop + self.qnet.backward_step(&loss).unwrap(); + + record.insert( + "loss", + RecordValue::Scalar(loss.to_scalar::().unwrap()), + ); + + record + // f32::from(loss.to_scalar::().unwrap()) + } + + fn opt_(&mut self, buffer: &mut R) -> Record { + let mut record_ = Record::empty(); + + for _ in 0..self.n_updates_per_opt { + let record = self.update_critic(buffer); + record_ = record_.merge(record); + } + + self.soft_update_counter += 1; + if self.soft_update_counter == self.soft_update_interval { + self.soft_update_counter = 0; + let _ = track(self.qnet_tgt.get_varmap(), self.qnet.get_varmap(), self.tau); + } + + self.n_opts += 1; + + record_ + // Record::from_slice(&[("loss", RecordValue::Scalar(loss_critic))]) + } +} + +impl Policy for Dqn +where + E: Env, + Q: SubModel1, + E::Obs: Into, + E::Act: From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + /// In evaluation mode, take a random action with probability 0.01. 
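// Editor's aside (illustrative, not part of the diff) on the TD target formed in
// `update_critic` above:
//   y = r + gamma * (1 - terminated) * Q_tgt(s', a*),
// where a* = argmax_a Q_online(s', a) when `double_dqn` is set, and
// a* = argmax_a Q_tgt(s', a) (i.e. a plain max over the target network) otherwise.
// `td_target` below is a hypothetical scalar version of the same computation.
fn td_target(reward: f32, terminated: bool, gamma: f32, q_tgt_at_best: f32) -> f32 {
    let not_done = if terminated { 0.0 } else { 1.0 };
    reward + not_done * gamma * q_tgt_at_best
}

fn main() {
    // Non-terminal transition: the target bootstraps from the (target) Q value.
    assert_eq!(td_target(1.0, false, 0.99, 2.0), 1.0 + 0.99 * 2.0);
    // Terminal transition: the target is just the reward.
    assert_eq!(td_target(1.0, true, 0.99, 2.0), 1.0);
}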
+ fn sample(&mut self, obs: &E::Obs) -> E::Act { + let a = self.qnet.forward(&obs.clone().into()).detach(); + let a = if self.train { + self.n_samples_act += 1; + match &mut self.explorer { + DqnExplorer::Softmax(softmax) => softmax.action(&a, &mut self.rng), + DqnExplorer::EpsilonGreedy(egreedy) => { + if self.record_verbose_level >= 2 { + let (act, best) = egreedy.action_with_best(&a, &mut self.rng); + if best { + self.n_samples_best_act += 1; + } + act + } else { + egreedy.action(&a, &mut self.rng) + } + } + } + } else { + if self.rng.gen::() < 0.01 { + let n_actions = a.dims()[1] as i64; + let a: i64 = self.rng.gen_range(0..n_actions); + Tensor::try_from(vec![a]).unwrap() + } else { + a.argmax(D::Minus1).unwrap().to_dtype(DType::I64).unwrap() + } + }; + a.into() + } +} + +impl Configurable for Dqn +where + E: Env, + Q: SubModel1, + E::Obs: Into, + E::Act: From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + type Config = DqnConfig; + + /// Constructs DQN agent. + fn build(config: Self::Config) -> Self { + let device: Device = config + .device + .expect("No device is given for DQN agent") + .into(); + let qnet = DqnModel::build(config.model_config.clone(), device.clone()).unwrap(); + let qnet_tgt = DqnModel::build(config.model_config.clone(), device.clone()).unwrap(); + let _ = track(qnet_tgt.get_varmap(), qnet.get_varmap(), 1.0); + + Dqn { + qnet, + qnet_tgt, + soft_update_interval: config.soft_update_interval, + soft_update_counter: 0, + n_updates_per_opt: config.n_updates_per_opt, + batch_size: config.batch_size, + discount_factor: config.discount_factor, + tau: config.tau, + train: config.train, + explorer: config.explorer, + device, + n_opts: 0, + _clip_reward: config.clip_reward, + double_dqn: config.double_dqn, + clip_td_err: config.clip_td_err, + critic_loss: config.critic_loss, + phantom: PhantomData, + n_samples_act: 0, + n_samples_best_act: 0, + record_verbose_level: config.record_verbose_level, + rng: SmallRng::seed_from_u64(42), + } + } +} + +impl Agent for Dqn +where + E: Env, + Q: SubModel1, + R: ReplayBufferBase, + E::Obs: Into, + E::Act: From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + R::Batch: TransitionBatch, + ::ObsBatch: Into, + ::ActBatch: Into, +{ + fn train(&mut self) { + self.train = true; + } + + fn eval(&mut self) { + self.train = false; + } + + fn is_train(&self) -> bool { + self.train + } + + fn opt(&mut self, buffer: &mut R) { + self.opt_(buffer); + } + + fn opt_with_record(&mut self, buffer: &mut R) -> Record { + let mut record = { + let record = self.opt_(buffer); + + match self.record_verbose_level >= 2 { + true => { + let record_weights = self.qnet.param_stats(); + let record = record.merge(record_weights); + record + } + false => record, + } + }; + + // Best action ratio for epsilon greedy + let ratio = match self.n_samples_act == 0 { + true => 0f32, + false => self.n_samples_best_act as f32 / self.n_samples_act as f32, + }; + record.insert("ratio_best_act", RecordValue::Scalar(ratio)); + self.n_samples_act = 0; + self.n_samples_best_act = 0; + + record + } + + /// Save model parameters in the given directory. + /// + /// The parameters of the model are saved as `qnet.pt`. + /// The parameters of the target model are saved as `qnet_tgt.pt`. 
+ fn save_params>(&self, path: T) -> Result<()> { + // TODO: consider to rename the path if it already exists + fs::create_dir_all(&path)?; + self.qnet.save(&path.as_ref().join("qnet.pt").as_path())?; + self.qnet_tgt + .save(&path.as_ref().join("qnet_tgt.pt").as_path())?; + Ok(()) + } + + fn load_params>(&mut self, path: T) -> Result<()> { + self.qnet.load(&path.as_ref().join("qnet.pt").as_path())?; + self.qnet_tgt + .load(&path.as_ref().join("qnet_tgt.pt").as_path())?; + Ok(()) + } +} + +#[cfg(feature = "border-async-trainer")] +use {crate::util::NamedTensors, border_async_trainer::SyncModel}; + +#[cfg(feature = "border-async-trainer")] +impl SyncModel for Dqn +where + E: Env, + Q: SubModel1, + R: ReplayBufferBase, + E::Obs: Into, + E::Act: From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + R::Batch: TransitionBatch, + ::ObsBatch: Into, + ::ActBatch: Into, +{ + type ModelInfo = NamedTensors; + + fn model_info(&self) -> (usize, Self::ModelInfo) { + unimplemented!(); + // ( + // self.n_opts, + // NamedTensors::copy_from(self.qnet.get_var_store()), + // ) + } + + fn sync_model(&mut self, _model_info: &Self::ModelInfo) { + unimplemented!(); + // let vs = self.qnet.get_var_store_mut(); + // model_info.copy_to(vs); + } +} diff --git a/border-candle-agent/src/dqn/config.rs b/border-candle-agent/src/dqn/config.rs new file mode 100644 index 00000000..aaa694e2 --- /dev/null +++ b/border-candle-agent/src/dqn/config.rs @@ -0,0 +1,212 @@ +//! Configuration of DQN agent. +use super::{ + explorer::{DqnExplorer, Softmax}, + DqnModelConfig, +}; +use crate::{ + model::SubModel1, + util::{CriticLoss, OutDim}, + Device, +}; +use anyhow::Result; +use candle_core::Tensor; +use log::info; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::{ + default::Default, + fs::File, + io::{BufReader, Write}, + marker::PhantomData, + path::Path, +}; + +/// Configuration of [`Dqn`](super::Dqn) agent. +#[derive(Debug, Deserialize, Serialize, PartialEq)] +pub struct DqnConfig +where + Q: SubModel1, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + pub model_config: DqnModelConfig, + pub soft_update_interval: usize, + pub n_updates_per_opt: usize, + pub batch_size: usize, + pub discount_factor: f64, + pub tau: f64, + pub train: bool, + pub explorer: DqnExplorer, + #[serde(default)] + pub clip_reward: Option, + #[serde(default)] + pub double_dqn: bool, + pub clip_td_err: Option<(f64, f64)>, + pub device: Option, + pub critic_loss: CriticLoss, + pub record_verbose_level: usize, + pub phantom: PhantomData, +} + +impl Clone for DqnConfig +where + Q: SubModel1, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + fn clone(&self) -> Self { + Self { + model_config: self.model_config.clone(), + soft_update_interval: self.soft_update_interval, + n_updates_per_opt: self.n_updates_per_opt, + batch_size: self.batch_size, + discount_factor: self.discount_factor, + tau: self.tau, + train: self.train, + explorer: self.explorer.clone(), + clip_reward: self.clip_reward, + double_dqn: self.double_dqn, + clip_td_err: self.clip_td_err, + device: self.device.clone(), + critic_loss: self.critic_loss.clone(), + record_verbose_level: self.record_verbose_level, + phantom: PhantomData, + } + } +} + +impl Default for DqnConfig +where + Q: SubModel1, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + /// Constructs DQN builder with default parameters. 
+ fn default() -> Self { + Self { + model_config: Default::default(), + soft_update_interval: 1, + n_updates_per_opt: 1, + batch_size: 1, + discount_factor: 0.99, + tau: 0.005, + train: false, + // replay_burffer_capacity: 100, + explorer: DqnExplorer::Softmax(Softmax::new()), + // expr_sampling: ExperienceSampling::Uniform, + clip_reward: None, + double_dqn: false, + clip_td_err: None, + device: None, + critic_loss: CriticLoss::Mse, + record_verbose_level: 0, + phantom: PhantomData, + } + } +} + +impl DqnConfig +where + Q: SubModel1, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + /// Sets soft update interval. + pub fn soft_update_interval(mut self, v: usize) -> Self { + self.soft_update_interval = v; + self + } + + /// Sets the numper of parameter update steps per optimization step. + pub fn n_updates_per_opt(mut self, v: usize) -> Self { + self.n_updates_per_opt = v; + self + } + + /// Batch size. + pub fn batch_size(mut self, v: usize) -> Self { + self.batch_size = v; + self + } + + /// Discount factor. + pub fn discount_factor(mut self, v: f64) -> Self { + self.discount_factor = v; + self + } + + /// Soft update coefficient. + pub fn tau(mut self, v: f64) -> Self { + self.tau = v; + self + } + + /// Explorer. + pub fn explorer(mut self, v: DqnExplorer) -> Self { + self.explorer = v; + self + } + + /// Sets the configuration of the model. + pub fn model_config(mut self, model_config: DqnModelConfig) -> Self { + self.model_config = model_config; + self + } + + /// Sets the output dimention of the dqn model of the DQN agent. + pub fn out_dim(mut self, out_dim: i64) -> Self { + let model_config = self.model_config.clone(); + self.model_config = model_config.out_dim(out_dim as _); + self + } + + /// Reward clipping. + pub fn clip_reward(mut self, clip_reward: Option) -> Self { + self.clip_reward = clip_reward; + self + } + + /// Double DQN + pub fn double_dqn(mut self, double_dqn: bool) -> Self { + self.double_dqn = double_dqn; + self + } + + /// TD-error clipping. + pub fn clip_td_err(mut self, clip_td_err: Option<(f64, f64)>) -> Self { + self.clip_td_err = clip_td_err; + self + } + + /// Device. + pub fn device(mut self, device: candle_core::Device) -> Self { + self.device = Some(device.into()); + self + } + + /// Sets critic loss. + pub fn critic_loss(mut self, v: CriticLoss) -> Self { + self.critic_loss = v; + self + } + + /// Sets verbose level. + pub fn record_verbose_level(mut self, v: usize) -> Self { + self.record_verbose_level = v; + self + } + + /// Loads [`DqnConfig`] from YAML file. + pub fn load(path: impl AsRef) -> Result { + let path_ = path.as_ref().to_owned(); + let file = File::open(path)?; + let rdr = BufReader::new(file); + let b = serde_yaml::from_reader(rdr)?; + info!("Load config of DQN agent from {}", path_.to_str().unwrap()); + Ok(b) + } + + /// Saves [`DqnConfig`]. + pub fn save(&self, path: impl AsRef) -> Result<()> { + let path_ = path.as_ref().to_owned(); + let mut file = File::create(path)?; + file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; + info!("Save config of DQN agent into {}", path_.to_str().unwrap()); + Ok(()) + } +} diff --git a/border-candle-agent/src/dqn/explorer.rs b/border-candle-agent/src/dqn/explorer.rs new file mode 100644 index 00000000..98a9f655 --- /dev/null +++ b/border-candle-agent/src/dqn/explorer.rs @@ -0,0 +1,148 @@ +//! Exploration strategies of DQN. 
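// Editor's aside (hypothetical usage, not part of the diff): the builder methods of
// `DqnConfig` shown in dqn/config.rs above could be chained like this for an agent
// with an MLP Q-network (assumes candle-core is available as a direct dependency;
// note that `Dqn::build` panics if `device` is left unset).
use border_candle_agent::{
    dqn::{DqnConfig, DqnModelConfig},
    mlp::{Mlp, MlpConfig},
};

fn main() {
    let model_config = DqnModelConfig::default()
        // 4-dimensional observation, two hidden layers of 64 units, 2 actions.
        .q_config(MlpConfig::new(4, vec![64, 64], 2, false));
    let config: DqnConfig<Mlp> = DqnConfig::default()
        .model_config(model_config)
        .batch_size(32)
        .discount_factor(0.99)
        .device(candle_core::Device::Cpu);
    assert_eq!(config.batch_size, 32);
}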
+use candle_core::{shape::D, DType, Tensor}; +use candle_nn::ops::softmax; +use rand::{distributions::WeightedIndex, Rng}; +use serde::{Deserialize, Serialize}; + +/// Explorers for DQN. +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub enum DqnExplorer { + /// Softmax action selection. + Softmax(Softmax), + + /// Epsilon-greedy action selection. + EpsilonGreedy(EpsilonGreedy), +} + +/// Softmax explorer for DQN. +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub struct Softmax {} + +#[allow(clippy::new_without_default)] +impl Softmax { + /// Constructs softmax explorer. + pub fn new() -> Self { + Self {} + } + + /// Takes an action based on action values, returns i64 tensor. + /// + /// * `a` - action values. + pub fn action(&mut self, a: &Tensor, rng: &mut impl Rng) -> Tensor { + let device = a.device(); + let probs = softmax(a, 1).unwrap().to_vec2::().unwrap(); + let n_samples = probs.len(); + let data = probs + .into_iter() + .map(|p| rng.sample(WeightedIndex::new(&p).unwrap()) as i64) + .collect::>(); + Tensor::from_vec(data, &[n_samples], device).unwrap() + } +} + +/// Epsilon-greedy explorer for DQN. +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub struct EpsilonGreedy { + pub n_opts: usize, + pub eps_start: f64, + pub eps_final: f64, + pub final_step: usize, +} + +#[allow(clippy::new_without_default)] +impl EpsilonGreedy { + /// Constructs epsilon-greedy explorer. + pub fn new() -> Self { + Self { + n_opts: 0, + eps_start: 1.0, + eps_final: 0.02, + final_step: 100_000, + } + } + + /// Constructs epsilon-greedy explorer. + /// + /// TODO: improve interface. + pub fn with_final_step(final_step: usize) -> DqnExplorer { + DqnExplorer::EpsilonGreedy(Self { + n_opts: 0, + eps_start: 1.0, + eps_final: 0.02, + final_step, + }) + } + + /// Takes an action based on action values, returns i64 tensor. + /// + /// * `a` - action values. + pub fn action(&mut self, a: &Tensor, rng: &mut impl Rng) -> Tensor { + let d = (self.eps_start - self.eps_final) / (self.final_step as f64); + let eps = (self.eps_start - d * self.n_opts as f64).max(self.eps_final); + let r = rng.gen::(); + let is_random = r < eps as f32; + self.n_opts += 1; + + if is_random { + let n_samples = a.dims()[0]; + let n_actions = a.dims()[1] as u64; + Tensor::from_slice( + (0..n_samples) + .map(|_| (rng.gen::() % n_actions) as i64) + .collect::>() + .as_slice(), + &[n_samples], + a.device(), + ) + .unwrap() + } else { + a.argmax(D::Minus1).unwrap().to_dtype(DType::I64).unwrap() + } + } + + /// Takes an action based on action values, returns i64 tensor. + /// + /// * `a` - action values. + pub fn action_with_best(&mut self, a: &Tensor, rng: &mut impl Rng) -> (Tensor, bool) { + let d = (self.eps_start - self.eps_final) / (self.final_step as f64); + let eps = (self.eps_start - d * self.n_opts as f64).max(self.eps_final); + let r = rng.gen::(); + let is_random = r < eps as f32; + self.n_opts += 1; + + let best = a.argmax(D::Minus1).unwrap().to_dtype(DType::I64).unwrap(); + + if is_random { + let n_samples = a.dims()[0]; + let n_actions = a.dims()[1] as u64; + let act = Tensor::from_slice( + (0..n_samples) + .map(|_| (rng.gen::() % n_actions) as i64) + .collect::>() + .as_slice(), + &[n_samples], + a.device(), + ) + .unwrap(); + let act_: Vec = act.to_vec1().unwrap(); + let best_: Vec = best.to_vec1().unwrap(); + (act, act_ == best_) + } else { + (best, true) + } + } + + /// Set the epsilon value at the final step. 
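// Editor's aside (illustrative, not part of the diff): the `eps_start`/`eps_final`
// setters below tune the schedule used by `action`/`action_with_best` above, which
// decays epsilon linearly from `eps_start` to `eps_final` over `final_step`
// optimization steps. `epsilon` below is a hypothetical standalone version.
fn epsilon(n_opts: usize, eps_start: f64, eps_final: f64, final_step: usize) -> f64 {
    let d = (eps_start - eps_final) / (final_step as f64);
    (eps_start - d * n_opts as f64).max(eps_final)
}

fn main() {
    // Defaults used by `EpsilonGreedy::new()`: 1.0 -> 0.02 over 100_000 steps.
    assert!((epsilon(0, 1.0, 0.02, 100_000) - 1.0).abs() < 1e-9);
    assert!((epsilon(50_000, 1.0, 0.02, 100_000) - 0.51).abs() < 1e-9);
    // Past `final_step`, epsilon stays clamped at `eps_final`.
    assert!((epsilon(200_000, 1.0, 0.02, 100_000) - 0.02).abs() < 1e-9);
}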
+ pub fn eps_final(self, v: f64) -> Self { + let mut s = self; + s.eps_final = v; + s + } + + /// Set the epsilon value at the start. + pub fn eps_start(self, v: f64) -> Self { + let mut s = self; + s.eps_start = v; + s + } +} diff --git a/border-candle-agent/src/dqn/model.rs b/border-candle-agent/src/dqn/model.rs new file mode 100644 index 00000000..90efee8b --- /dev/null +++ b/border-candle-agent/src/dqn/model.rs @@ -0,0 +1,228 @@ +use crate::{ + model::SubModel1, + opt::{Optimizer, OptimizerConfig}, + util::OutDim, +}; +use anyhow::{Context, Result}; +use border_core::record::Record; +use candle_core::{DType, Device, Tensor}; +use candle_nn::{VarBuilder, VarMap}; +use log::info; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::{ + fs::File, + io::{BufReader, Write}, + path::Path, +}; + +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +/// Configuration of [`DqnModel`]. +pub struct DqnModelConfig +where + Q: OutDim, +{ + pub q_config: Option, + #[serde(default)] + pub opt_config: OptimizerConfig, +} + +impl Default for DqnModelConfig +where + Q: OutDim, +{ + fn default() -> Self { + Self { + q_config: None, + opt_config: OptimizerConfig::default(), + } + } +} + +impl DqnModelConfig +where + Q: DeserializeOwned + Serialize + OutDim, +{ + /// Sets configurations for action-value function. + pub fn q_config(mut self, v: Q) -> Self { + self.q_config = Some(v); + self + } + + /// Sets output dimension of the model. + pub fn out_dim(mut self, v: i64) -> Self { + match &mut self.q_config { + None => {} + Some(q_config) => q_config.set_out_dim(v), + }; + self + } + + /// Sets optimizer configuration. + pub fn opt_config(mut self, v: OptimizerConfig) -> Self { + self.opt_config = v; + self + } + + /// Constructs [`DqnModelConfig`] from YAML file. + pub fn load(path: impl AsRef) -> Result { + let file = File::open(path)?; + let rdr = BufReader::new(file); + let b = serde_yaml::from_reader(rdr)?; + Ok(b) + } + + /// Saves [`DqnModelConfig`] to as a YAML file. + pub fn save(&self, path: impl AsRef) -> Result<()> { + let mut file = File::create(path)?; + file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; + Ok(()) + } +} + +/// Action value function model for DQN. +/// +/// The architecture of the model is defined by the type parameter `Q`, +/// which should implement [`SubModel1`]. +/// This takes [`SubModel1::Input`] as input and outputs a tensor. +/// The output tensor should have the same dimension as the number of actions. +pub struct DqnModel +where + Q: SubModel1, + Q::Config: DeserializeOwned + Serialize + OutDim, +{ + device: Device, + varmap: VarMap, + + // Dimension of the output vector (equal to the number of actions). + pub(super) out_dim: i64, + + // Action-value function + q: Q, + + // Optimizer + opt_config: OptimizerConfig, + q_config: Q::Config, + opt: Optimizer, +} + +impl DqnModel +where + Q: SubModel1, + Q::Config: DeserializeOwned + Serialize + OutDim + Clone, +{ + /// Constructs [`DqnModel`]. 
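// Editor's aside (hypothetical usage, not part of the diff): the `config` passed to
// `build` below could be assembled from a CNN sub-model configuration and an
// optimizer configuration like this (4 stacked frames, 6 actions).
use border_candle_agent::{cnn::CnnConfig, dqn::DqnModelConfig, opt::OptimizerConfig};

fn main() {
    let config: DqnModelConfig<CnnConfig> = DqnModelConfig::default()
        .q_config(CnnConfig::new(4, 6))
        .opt_config(OptimizerConfig::Adam { lr: 1e-4 });
    assert_eq!(config.q_config.as_ref().unwrap().out_dim, 6);
}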
+ pub fn build(config: DqnModelConfig, device: Device) -> Result { + let out_dim = config.q_config.as_ref().unwrap().get_out_dim(); + let q_config = config.q_config.context("q_config is not set.")?; + let opt_config = config.opt_config; + let varmap = VarMap::new(); + let q = { + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); + Q::build(vb, q_config.clone()) + }; + + Ok(Self::_build( + device, + out_dim as _, + opt_config, + q_config, + q, + varmap, + None, + )) + } + + fn _build( + device: Device, + out_dim: i64, + opt_config: OptimizerConfig, + q_config: Q::Config, + q: Q, + mut varmap: VarMap, + varmap_src: Option<&VarMap>, + ) -> Self { + // Optimizer + let opt = opt_config.build(varmap.all_vars()).unwrap(); + + // Copy varmap + if let Some(varmap_src) = varmap_src { + varmap.clone_from(varmap_src); + } + + Self { + device, + out_dim, + opt_config, + varmap, + opt, + q, + q_config, + } + } + + /// Outputs the action-value given observation(s). + pub fn forward(&self, obs: &Q::Input) -> Tensor { + self.q.forward(obs) + } + + pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> { + // Consider to use gradient clipping, below code + // let mut grads = loss.backward()?; + // for (_, var) in self.varmap.data().lock().unwrap().iter() { + // let g1 = grads.get(var).unwrap(); + // let g2 = g1.clamp(-1.0, 1.0)?; + // let _ = grads.remove(&var).unwrap(); + // let _ = grads.insert(&var, g2); + // } + // self.opt.step(&grads) + self.opt.backward_step(loss) + } + + pub fn get_varmap(&self) -> &VarMap { + &self.varmap + } + + pub fn save>(&self, path: T) -> Result<()> { + self.varmap.save(&path)?; + info!("Save dqnmodel to {:?}", path.as_ref()); + Ok(()) + } + + pub fn load>(&mut self, path: T) -> Result<()> { + self.varmap.load(&path)?; + info!("Load dqnmodel from {:?}", path.as_ref()); + Ok(()) + } + + pub fn param_stats(&self) -> Record { + crate::util::param_stats(&self.varmap) + } +} + +impl Clone for DqnModel +where + Q: SubModel1, + Q::Config: DeserializeOwned + Serialize + OutDim + Clone, +{ + fn clone(&self) -> Self { + let device = self.device.clone(); + let out_dim = self.out_dim; + let opt_config = self.opt_config.clone(); + let q_config = self.q_config.clone(); + let varmap = VarMap::new(); + let q = { + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); + Q::build(vb, self.q_config.clone()) + }; + + Self::_build( + device, + out_dim, + opt_config, + q_config, + q, + varmap, + Some(&self.varmap), + ) + } +} diff --git a/border-candle-agent/src/lib.rs b/border-candle-agent/src/lib.rs new file mode 100644 index 00000000..9809a488 --- /dev/null +++ b/border-candle-agent/src/lib.rs @@ -0,0 +1,52 @@ +//! RL agents implemented with [candle](https://crates.io/crates/candle-core). +pub mod cnn; +pub mod dqn; +// pub mod iqn; +pub mod mlp; +pub mod model; +pub mod opt; +pub mod sac; +mod tensor_batch; +pub mod util; +use candle_core::{backend::BackendDevice, DeviceLocation}; +use serde::{Deserialize, Serialize}; +pub use tensor_batch::{TensorBatch, ZeroTensor}; + +#[derive(Clone, Debug, Copy, Deserialize, Serialize, PartialEq)] +/// Device for using candle. +/// +/// This enum is added because [`candle_core::Device`] does not support serialization. +/// +/// [`candle_core::Device`]: https://docs.rs/candle-core/0.4.1/candle_core/enum.Device.html +pub enum Device { + /// The main CPU device. + Cpu, + + /// The main GPU device. 
+ Cuda(usize), +} + +impl From for Device { + fn from(device: candle_core::Device) -> Self { + match device { + candle_core::Device::Cpu => Self::Cpu, + candle_core::Device::Cuda(cuda_device) => { + let loc = cuda_device.location(); + match loc { + DeviceLocation::Cuda { gpu_id } => Self::Cuda(gpu_id), + _ => panic!(), + } + } + _ => unimplemented!(), + } + } +} + +impl Into for Device { + fn into(self) -> candle_core::Device { + match self { + Self::Cpu => candle_core::Device::Cpu, + Self::Cuda(n) => candle_core::Device::new_cuda(n).unwrap(), + } + } +} diff --git a/border-candle-agent/src/mlp.rs b/border-candle-agent/src/mlp.rs new file mode 100644 index 00000000..7881edbe --- /dev/null +++ b/border-candle-agent/src/mlp.rs @@ -0,0 +1,52 @@ +//! Multilayer perceptron. +mod base; +mod config; +mod mlp2; +pub use base::Mlp; +use candle_core::Tensor; +use candle_nn::{Linear, Module}; +pub use config::MlpConfig; +pub use mlp2::Mlp2; + +fn mlp_forward(xs: Tensor, layers: &Vec) -> Tensor { + let n_layers = layers.len(); + let mut xs = xs; + + for i in 0..=n_layers - 2 { + xs = layers[i].forward(&xs).unwrap().relu().unwrap(); + } + + layers[n_layers - 1].forward(&xs).unwrap() +} + +// for (i, &n) in config.units.iter().enumerate() { +// seq = seq.add(nn::linear( +// p / format!("{}{}", prefix, i + 1), +// in_dim, +// n, +// Default::default(), +// )); +// seq = seq.add_fn(|x| x.relu()); +// in_dim = n; +// } + +// } + +// fn mlp(prefix: &str, var_store: &nn::VarStore, config: &MlpConfig) -> nn::Sequential { +// let mut seq = nn::seq(); +// let mut in_dim = config.in_dim; +// let p = &var_store.root(); + +// for (i, &n) in config.units.iter().enumerate() { +// seq = seq.add(nn::linear( +// p / format!("{}{}", prefix, i + 1), +// in_dim, +// n, +// Default::default(), +// )); +// seq = seq.add_fn(|x| x.relu()); +// in_dim = n; +// } + +// seq +// } diff --git a/border-candle-agent/src/mlp/base.rs b/border-candle-agent/src/mlp/base.rs new file mode 100644 index 00000000..f8f50fca --- /dev/null +++ b/border-candle-agent/src/mlp/base.rs @@ -0,0 +1,87 @@ +use super::{mlp_forward, MlpConfig}; +use crate::model::{SubModel1, SubModel2}; +use anyhow::Result; +use candle_core::{Device, Tensor, D}; +use candle_nn::{linear, Linear, VarBuilder}; + +/// Returns vector of linear modules from [`MlpConfig`]. +fn create_linear_layers(prefix: &str, vs: VarBuilder, config: &MlpConfig) -> Result> { + let mut in_out_pairs: Vec<(i64, i64)> = (0..config.units.len() - 1) + .map(|i| (config.units[i], config.units[i + 1])) + .collect(); + in_out_pairs.insert(0, (config.in_dim, config.units[0])); + in_out_pairs.push((*config.units.last().unwrap(), config.out_dim)); + let vs = vs.pp(prefix); + + Ok(in_out_pairs + .iter() + .enumerate() + .map(|(i, &(in_dim, out_dim))| { + linear(in_dim as _, out_dim as _, vs.pp(format!("ln{}", i))).unwrap() + }) + .collect()) +} + +/// Multilayer perceptron with ReLU activation function. 
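// Editor's aside (illustrative, not part of the diff): for
// MlpConfig { in_dim: 4, units: vec![64, 64], out_dim: 2, .. } the
// `create_linear_layers` helper above produces linear layers shaped
// (4 -> 64), (64 -> 64), (64 -> 2). `layer_shapes` is a hypothetical
// re-statement of that bookkeeping.
fn layer_shapes(in_dim: i64, units: &[i64], out_dim: i64) -> Vec<(i64, i64)> {
    let mut pairs: Vec<(i64, i64)> = (0..units.len() - 1)
        .map(|i| (units[i], units[i + 1]))
        .collect();
    pairs.insert(0, (in_dim, units[0]));
    pairs.push((*units.last().unwrap(), out_dim));
    pairs
}

fn main() {
    assert_eq!(layer_shapes(4, &[64, 64], 2), vec![(4, 64), (64, 64), (64, 2)]);
}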
+pub struct Mlp { + config: MlpConfig, + device: Device, + layers: Vec, +} + +fn _build(vs: VarBuilder, config: MlpConfig) -> Mlp { + let device = vs.device().clone(); + let layers = create_linear_layers("mlp", vs, &config).unwrap(); + + Mlp { + config, + device, + layers, + } +} + +impl SubModel1 for Mlp { + type Config = MlpConfig; + type Input = Tensor; + type Output = Tensor; + + fn forward(&self, xs: &Self::Input) -> Tensor { + let xs = xs.to_device(&self.device).unwrap(); + let xs = mlp_forward(xs, &self.layers); + + match self.config.activation_out { + false => xs, + true => xs.relu().unwrap(), + } + } + + fn build(vs: VarBuilder, config: Self::Config) -> Self { + _build(vs, config) + } +} + +impl SubModel2 for Mlp { + type Config = MlpConfig; + type Input1 = Tensor; + type Input2 = Tensor; + type Output = Tensor; + + fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output { + let input1: Tensor = input1.to_device(&self.device).unwrap(); + let input2: Tensor = input2.to_device(&self.device).unwrap(); + let input = Tensor::cat(&[input1, input2], D::Minus1) + .unwrap() + .to_device(&self.device) + .unwrap(); + let xs = mlp_forward(input, &self.layers); + + match self.config.activation_out { + false => xs, + true => xs.relu().unwrap(), + } + } + + fn build(vs: VarBuilder, config: Self::Config) -> Self { + _build(vs, config) + } +} diff --git a/border-candle-agent/src/mlp/config.rs b/border-candle-agent/src/mlp/config.rs new file mode 100644 index 00000000..8ebaace0 --- /dev/null +++ b/border-candle-agent/src/mlp/config.rs @@ -0,0 +1,35 @@ +use crate::util::OutDim; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +/// Configuration of [`Mlp`](super::Mlp). +pub struct MlpConfig { + pub in_dim: i64, + pub units: Vec, + pub out_dim: i64, + pub activation_out: bool, +} + +impl MlpConfig { + /// Creates configuration of MLP. + /// + /// * `activation_out` - If `true`, activation function is added in the final layer. + pub fn new(in_dim: i64, units: Vec, out_dim: i64, activation_out: bool) -> Self { + Self { + in_dim, + units, + out_dim, + activation_out, + } + } +} + +impl OutDim for MlpConfig { + fn get_out_dim(&self) -> i64 { + self.out_dim + } + + fn set_out_dim(&mut self, out_dim: i64) { + self.out_dim = out_dim; + } +} diff --git a/border-candle-agent/src/mlp/mlp2.rs b/border-candle-agent/src/mlp/mlp2.rs new file mode 100644 index 00000000..e2aa36d4 --- /dev/null +++ b/border-candle-agent/src/mlp/mlp2.rs @@ -0,0 +1,65 @@ +use super::{mlp_forward, MlpConfig}; +use crate::model::SubModel1; +use anyhow::Result; +use candle_core::{Device, Module, Tensor}; +use candle_nn::{linear, Linear, VarBuilder}; + +/// Returns vector of linear modules from [`MlpConfig`]. +fn create_linear_layers(prefix: &str, vs: VarBuilder, config: &MlpConfig) -> Result> { + let mut in_out_pairs: Vec<(i64, i64)> = (0..config.units.len() - 1) + .map(|i| (config.units[i], config.units[i + 1])) + .collect(); + in_out_pairs.insert(0, (config.in_dim, config.units[0])); + let vs = vs.pp(prefix); + + Ok(in_out_pairs + .iter() + .enumerate() + .map(|(i, &(in_dim, out_dim))| { + linear(in_dim as _, out_dim as _, vs.pp(format!("ln{}", i))).unwrap() + }) + .collect()) +} + +/// Multilayer perceptron that outputs two tensors of the same size. 
+pub struct Mlp2 { + _config: MlpConfig, + device: Device, + head1: Linear, + head2: Linear, + layers: Vec, +} + +impl SubModel1 for Mlp2 { + type Config = MlpConfig; + type Input = Tensor; + type Output = (Tensor, Tensor); + + fn forward(&self, xs: &Self::Input) -> Self::Output { + let xs = xs.to_device(&self.device).unwrap(); + let xs = mlp_forward(xs, &self.layers).relu().unwrap(); + let mean = self.head1.forward(&xs).unwrap(); + let std = self.head2.forward(&xs).unwrap().exp().unwrap(); + (mean, std) + } + + fn build(vs: VarBuilder, config: Self::Config) -> Self { + let device = vs.device().clone(); + let layers = create_linear_layers("mlp", vs.clone(), &config).unwrap(); + let (head1, head2) = { + let in_dim = *config.units.last().unwrap(); + let out_dim = config.out_dim; + let head1 = linear(in_dim as _, out_dim as _, vs.pp(format!("mean"))).unwrap(); + let head2 = linear(in_dim as _, out_dim as _, vs.pp(format!("std"))).unwrap(); + (head1, head2) + }; + + Self { + _config: config, + device, + head1, + head2, + layers, + } + } +} diff --git a/border-candle-agent/src/model.rs b/border-candle-agent/src/model.rs new file mode 100644 index 00000000..6fbacc6a --- /dev/null +++ b/border-candle-agent/src/model.rs @@ -0,0 +1,156 @@ +//! Interface of neural networks used in RL agents. +// use anyhow::Result; +// use candle_core::Tensor; +use candle_nn::VarBuilder; +// use std::path::Path; +// use tch::{nn, nn::VarStore, Tensor}; + +/// Neural network model not owing its [`VarMap`] internally. +/// +/// [`VarMap`]: https://docs.rs/candle-nn/0.4.1/candle_nn/var_map/struct.VarMap.html +pub trait SubModel1 { + /// Configuration from which [`SubModel1`] is constructed. + type Config; + + /// Input of the [`SubModel1`]. + type Input; + + /// Output of the [`SubModel1`]. + type Output; + + /// Builds [`SubModel1`] with [`VarBuilder`] and [`SubModel1::Config`]. + /// + /// [`VarBuilder`]: https://docs.rs/candle-nn/0.4.1/candle_nn/var_builder/type.VarBuilder.html + fn build(vb: VarBuilder, config: Self::Config) -> Self; + + /// A generalized forward function. + fn forward(&self, input: &Self::Input) -> Self::Output; +} + +/// Neural network model not owing its [`VarMap`] internally. +/// +/// The difference from [`SubModel1`] is that this trait takes two inputs. +/// +/// [`VarMap`]: https://docs.rs/candle-nn/0.4.1/candle_nn/var_map/struct.VarMap.html +pub trait SubModel2 { + /// Configuration from which [`SubModel2`] is constructed. + type Config; + + /// Input of the [`SubModel2`]. + type Input1; + + /// Input of the [`SubModel2`]. + type Input2; + + /// Output of the [`SubModel2`]. + type Output; + + /// Builds [`SubModel2`]. + fn build(vb: VarBuilder, config: Self::Config) -> Self; + + /// A generalized forward function. + fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output; +} + +// /// Base interface of a neural nrtwork model owing its [`VarMap`]. +// /// +// /// [`VarMap`]: candle_nn::VarMap +// pub trait ModelBase { +// /// Trains the network given a loss. +// fn backward_step(&mut self, loss: &Tensor); + +// /// Returns `var_store` as mutable reference. +// fn get_var_store_mut(&mut self) -> &mut nn::VarStore; + +// /// Returns `var_store`. +// fn get_var_store(&self) -> &nn::VarStore; + +// /// Save parameters of the neural network. +// fn save>(&self, path: T) -> Result<()>; + +// /// Load parameters of the neural network. +// fn load>(&mut self, path: T) -> Result<()>; +// } + +// /// Neural networks with a single input and a single output. 
+// pub trait Model1: ModelBase { +// /// The input of the neural network. +// type Input; +// /// The output of the neural network. +// type Output; + +// /// Performs forward computation given an input. +// fn forward(&self, xs: &Self::Input) -> Self::Output; + +// // /// TODO: check places this method is used in code. +// // fn in_shape(&self) -> &[usize]; + +// // /// TODO: check places this method is used in code. +// // fn out_dim(&self) -> usize; +// } + +// /// Neural networks with double inputs and a single output. +// pub trait Model2: ModelBase { +// /// An input of the neural network. +// type Input1; +// /// The other input of the neural network. +// type Input2; +// /// The output of the neural network. +// type Output; + +// /// Performs forward computation given a pair of inputs. +// fn forward(&self, x1s: &Self::Input1, x2s: &Self::Input2) -> Self::Output; +// } + +// /// Neural network model that can be initialized with [VarStore] and configuration. +// /// +// /// The purpose of this trait is for modularity of neural network models. +// /// Modules, which consists a neural network, should share [VarStore]. +// /// To do this, structs implementing this trait can be initialized with a given [VarStore]. +// /// This trait also provide the ability to clone with a given [VarStore]. +// /// The ability is useful when creating a target network, used in recent deep learning algorithms in common. +// pub trait SubModel { +// /// Configuration from which [SubModel] is constructed. +// type Config; + +// /// Input of the [SubModel]. +// type Input; + +// /// Output of the [SubModel]. +// type Output; + +// /// Builds [SubModel] with [VarStore] and [SubModel::Config]. +// fn build(var_store: &VarStore, config: Self::Config) -> Self; + +// /// Clones [SubModel] with [VarStore]. +// fn clone_with_var_store(&self, var_store: &VarStore) -> Self; + +// /// A generalized forward function. +// fn forward(&self, input: &Self::Input) -> Self::Output; +// } + +// /// Neural network model that can be initialized with [VarStore] and configuration. +// /// +// /// The difference from [SubModel] is that this trait takes two inputs. +// pub trait SubModel2 { +// /// Configuration from which [SubModel2] is constructed. +// type Config; + +// /// Input of the [SubModel2]. +// type Input1; + +// /// Input of the [SubModel2]. +// type Input2; + +// /// Output of the [SubModel2]. +// type Output; + +// /// Builds [SubModel2] with [VarStore] and [SubModel2::Config]. +// fn build(var_store: &VarStore, config: Self::Config) -> Self; + +// /// Clones [SubModel2] with [VarStore]. +// fn clone_with_var_store(&self, var_store: &VarStore) -> Self; + +// /// A generalized forward function. +// fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output; +// } diff --git a/border-candle-agent/src/opt.rs b/border-candle-agent/src/opt.rs new file mode 100644 index 00000000..0dff9522 --- /dev/null +++ b/border-candle-agent/src/opt.rs @@ -0,0 +1,140 @@ +//! Optimizers. +use anyhow::Result; +use candle_core::{Tensor, Var}; +use candle_nn::{AdamW, Optimizer as _, ParamsAdamW}; +use candle_optimisers::adam::{Adam, ParamsAdam}; +use serde::{Deserialize, Serialize}; + +/// Configuration of optimizer for training neural networks in an RL agent. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub enum OptimizerConfig { + /// AdamW optimizer. 
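+    ///
+    /// Fields other than `lr` may be omitted in serialized configs; they fall back to
+    /// candle's `ParamsAdamW` defaults via the `default_*` functions below.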
+ AdamW { + lr: f64, + #[serde(default = "default_beta1")] + beta1: f64, + #[serde(default = "default_beta2")] + beta2: f64, + #[serde(default = "default_eps")] + eps: f64, + #[serde(default = "default_weight_decay")] + weight_decay: f64, + }, + + /// Adam optimizer. + Adam { + /// Learning rate. + lr: f64, + }, +} + +fn default_beta1() -> f64 { + ParamsAdamW::default().beta1 +} + +fn default_beta2() -> f64 { + ParamsAdamW::default().beta2 +} + +fn default_eps() -> f64 { + ParamsAdamW::default().eps +} + +fn default_weight_decay() -> f64 { + ParamsAdamW::default().weight_decay +} + +impl OptimizerConfig { + /// Constructs [`AdamW`] optimizer. + pub fn build(&self, vars: Vec) -> Result { + match &self { + OptimizerConfig::AdamW { + lr, + beta1, + beta2, + eps, + weight_decay, + } => { + let params = ParamsAdamW { + lr: *lr, + beta1: *beta1, + beta2: *beta2, + eps: *eps, + weight_decay: *weight_decay, + }; + let opt = AdamW::new(vars, params)?; + Ok(Optimizer::AdamW(opt)) + } + OptimizerConfig::Adam { lr } => { + let params = ParamsAdam { + lr: *lr, + ..ParamsAdam::default() + }; + let opt = Adam::new(vars, params)?; + Ok(Optimizer::Adam(opt)) + } + } + } + + /// Override learning rate. + pub fn learning_rate(self, lr: f64) -> Self { + match self { + Self::AdamW { + lr: _, + beta1, + beta2, + eps, + weight_decay, + } => Self::AdamW { + lr, + beta1, + beta2, + eps, + weight_decay, + }, + Self::Adam { lr: _ } => Self::Adam { lr }, + } + } +} + +impl Default for OptimizerConfig { + fn default() -> Self { + let params = ParamsAdamW::default(); + Self::AdamW { + lr: params.lr, + beta1: params.beta1, + beta2: params.beta2, + eps: params.eps, + weight_decay: params.weight_decay, + } + } +} + +/// Optimizers. +/// +/// This is a thin wrapper of [`candle_nn::optim::Optimizer`]. +/// +/// [`candle_nn::optim::Optimizer`]: https://docs.rs/candle-nn/0.4.1/candle_nn/optim/trait.Optimizer.html +pub enum Optimizer { + /// Adam optimizer. + AdamW(AdamW), + + Adam(Adam), +} + +impl Optimizer { + /// Applies a backward step pass. + pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> { + match self { + Self::AdamW(opt) => Ok(opt.backward_step(loss)?), + Self::Adam(opt) => Ok(opt.backward_step(loss)?), + } + } + + pub fn step(&mut self, grads: &candle_core::backprop::GradStore) -> Result<()> { + match self { + Self::AdamW(opt) => Ok(opt.step(grads)?), + Self::Adam(opt) => Ok(opt.step(grads)?), + } + } +} diff --git a/border-candle-agent/src/sac.rs b/border-candle-agent/src/sac.rs new file mode 100644 index 00000000..89164dfd --- /dev/null +++ b/border-candle-agent/src/sac.rs @@ -0,0 +1,178 @@ +//! SAC agent. +//! +//! Here is an example of creating SAC agent: +//! +//! ```no_run +//! # use anyhow::Result; +//! use border_core::{ +//! # Env as Env_, Obs as Obs_, Act as Act_, Step, test::{ +//! # TestAct as TestAct_, TestActBatch as TestActBatch_, +//! # TestEnv as TestEnv_, +//! # TestObs as TestObs_, TestObsBatch as TestObsBatch_, +//! # }, +//! # record::Record, +//! # generic_replay_buffer::{SimpleReplayBuffer, BatchBase}, +//! Configurable, +//! }; +//! use border_candle_agent::{ +//! sac::{ActorConfig, CriticConfig, Sac, SacConfig}, +//! mlp::{Mlp, Mlp2, MlpConfig}, +//! opt::OptimizerConfig +//! }; +//! +//! # struct TestEnv(TestEnv_); +//! # #[derive(Clone, Debug)] +//! # struct TestObs(TestObs_); +//! # #[derive(Clone, Debug)] +//! # struct TestAct(TestAct_); +//! # struct TestObsBatch(TestObsBatch_); +//! # struct TestActBatch(TestActBatch_); +//! # +//! # impl Obs_ for TestObs { +//! 
# fn dummy(n: usize) -> Self { +//! # Self(TestObs_::dummy(n)) +//! # } +//! # +//! # fn len(&self) -> usize { +//! # self.0.len() +//! # } +//! # } +//! # +//! # impl Into for TestObs { +//! # fn into(self) -> candle_core::Tensor { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl BatchBase for TestObsBatch { +//! # fn new(n: usize) -> Self { +//! # Self(TestObsBatch_::new(n)) +//! # } +//! # +//! # fn push(&mut self, ix: usize, data: Self) { +//! # self.0.push(ix, data.0); +//! # } +//! # +//! # fn sample(&self, ixs: &Vec) -> Self { +//! # Self(self.0.sample(ixs)) +//! # } +//! # } +//! # +//! # impl BatchBase for TestActBatch { +//! # fn new(n: usize) -> Self { +//! # Self(TestActBatch_::new(n)) +//! # } +//! # +//! # fn push(&mut self, ix: usize, data: Self) { +//! # self.0.push(ix, data.0); +//! # } +//! # +//! # fn sample(&self, ixs: &Vec) -> Self { +//! # Self(self.0.sample(ixs)) +//! # } +//! # } +//! # +//! # impl Act_ for TestAct { +//! # fn len(&self) -> usize { +//! # self.0.len() +//! # } +//! # } +//! # +//! # impl From for TestAct { +//! # fn from(t: candle_core::Tensor) -> Self { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl Into for TestAct { +//! # fn into(self) -> candle_core::Tensor { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl Env_ for TestEnv { +//! # type Config = ::Config; +//! # type Obs = TestObs; +//! # type Act = TestAct; +//! # type Info = ::Info; +//! # +//! # fn build(config: &Self::Config, seed: i64) -> Result { +//! # Ok(Self(TestEnv_::build(&config, seed).unwrap())) +//! # } +//! # +//! # fn step(&mut self, act: &TestAct) -> (Step, Record) { +//! # let (step, record) = self.0.step(&act.0); +//! # let step = Step { +//! # obs: TestObs(step.obs), +//! # act: TestAct(step.act), +//! # reward: step.reward, +//! # is_terminated: step.is_terminated, +//! # is_truncated: step.is_truncated, +//! # info: step.info, +//! # init_obs: TestObs(step.init_obs), +//! # }; +//! # (step, record) +//! # } +//! # +//! # fn reset(&mut self, is_done: Option<&Vec>) -> Result { +//! # Ok(TestObs(self.0.reset(is_done).unwrap())) +//! # } +//! # +//! # fn step_with_reset(&mut self, a: &TestAct) -> (Step, Record) { +//! # let (step, record) = self.0.step_with_reset(&a.0); +//! # let step = Step { +//! # obs: TestObs(step.obs), +//! # act: TestAct(step.act), +//! # reward: step.reward, +//! # is_terminated: step.is_terminated, +//! # is_truncated: step.is_truncated, +//! # info: step.info, +//! # init_obs: TestObs(step.init_obs), +//! # }; +//! # (step, record) +//! # } +//! # +//! # fn reset_with_index(&mut self, ix: usize) -> Result { +//! # Ok(TestObs(self.0.reset_with_index(ix).unwrap())) +//! # } +//! # } +//! # +//! # type Env = TestEnv; +//! # type ObsBatch = TestObsBatch; +//! # type ActBatch = TestActBatch; +//! # type ReplayBuffer = SimpleReplayBuffer; +//! # +//! const DIM_OBS: i64 = 3; +//! const DIM_ACT: i64 = 1; +//! const LR_ACTOR: f64 = 1e-3; +//! const LR_CRITIC: f64 = 1e-3; +//! const BATCH_SIZE: usize = 256; +//! +//! fn create_agent(in_dim: i64, out_dim: i64) -> Sac { +//! let device = candle_core::Device::cuda_if_available(0).unwrap(); +//! +//! let actor_config = ActorConfig::default() +//! .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) +//! .out_dim(out_dim) +//! .pi_config(MlpConfig::new(in_dim, vec![64, 64], out_dim, true)); +//! let critic_config = CriticConfig::default() +//! .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) +//! .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, true)); +//! 
let sac_config = SacConfig::::default() +//! .batch_size(BATCH_SIZE) +//! .actor_config(actor_config) +//! .critic_config(critic_config) +//! .device(device); +//! Sac::build(sac_config) +//! } +//! ``` +mod actor; +mod base; +mod config; +mod critic; +mod ent_coef; +pub use actor::{Actor, ActorConfig}; +pub use base::Sac; +pub use config::SacConfig; +pub use critic::{Critic, CriticConfig}; +pub use ent_coef::{EntCoef, EntCoefMode}; diff --git a/border-candle-agent/src/sac/actor.rs b/border-candle-agent/src/sac/actor.rs new file mode 100644 index 00000000..c967f385 --- /dev/null +++ b/border-candle-agent/src/sac/actor.rs @@ -0,0 +1,227 @@ +//! Actor of SAC agent. +use crate::util::OutDim; +use crate::{ + model::SubModel1, + opt::{Optimizer, OptimizerConfig}, +}; +use anyhow::{Context, Result}; +use candle_core::{DType, Device, Tensor}; +use candle_nn::{VarBuilder, VarMap}; +use log::info; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::{ + fs::File, + io::{BufReader, Write}, + path::Path, +}; + +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +/// Configuration of [`Actor`]. +pub struct ActorConfig { + pi_config: Option

<P>,
+    opt_config: OptimizerConfig,
+}
+
+impl<P> Default for ActorConfig<P> {
+    fn default() -> Self {
+        Self {
+            pi_config: None,
+            opt_config: OptimizerConfig::default(),
+        }
+    }
+}
+
+impl<P> ActorConfig<P>
+where + P: DeserializeOwned + Serialize + OutDim, +{ + /// Sets configurations for action-value function. + pub fn pi_config(mut self, v: P) -> Self { + self.pi_config = Some(v); + self + } + + /// Sets output dimension of the model. + pub fn out_dim(mut self, v: i64) -> Self { + match &mut self.pi_config { + None => {} + Some(pi_config) => pi_config.set_out_dim(v), + }; + self + } + + /// Sets optimizer configuration. + pub fn opt_config(mut self, v: OptimizerConfig) -> Self { + self.opt_config = v; + self + } + + /// Constructs [`ActorConfig`] from YAML file. + pub fn load(path: impl AsRef) -> Result { + let file = File::open(path)?; + let rdr = BufReader::new(file); + let b = serde_yaml::from_reader(rdr)?; + Ok(b) + } + + /// Saves [`ActorConfig`] as YAML file. + pub fn save(&self, path: impl AsRef) -> Result<()> { + let mut file = File::create(path)?; + file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; + Ok(()) + } +} + +/// Stochastic policy for SAC agents. +pub struct Actor
<P>
+where + P: SubModel1, + P::Config: DeserializeOwned + Serialize + OutDim + Clone, +{ + device: Device, + varmap: VarMap, + + // Dimension of the action vector. + out_dim: i64, + + // Action-value function + pi_config: P::Config, + pi: P, + + // Optimizer + opt_config: OptimizerConfig, + opt: Optimizer, +} + +impl

<P> Actor<P>
+where + P: SubModel1, + P::Config: DeserializeOwned + Serialize + OutDim + Clone, +{ + /// Constructs [`Actor`]. + pub fn build(config: ActorConfig, device: Device) -> Result> { + let pi_config = config.pi_config.context("pi_config is not set.")?; + let out_dim = pi_config.get_out_dim(); + let varmap = VarMap::new(); + let pi = { + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); + P::build(vb, pi_config.clone()) + }; + let opt_config = config.opt_config; + + Ok(Actor::_build( + device, out_dim, opt_config, pi_config, pi, varmap, None, + )) + } + + fn _build( + device: Device, + out_dim: i64, + opt_config: OptimizerConfig, + pi_config: P::Config, + pi: P, + mut varmap: VarMap, + varmap_src: Option<&VarMap>, + ) -> Self { + // Optimizer + let opt = opt_config.build(varmap.all_vars()).unwrap(); + + // Copy varmap + if let Some(varmap_src) = varmap_src { + varmap.clone_from(varmap_src); + } + + Self { + device, + out_dim, + opt_config, + varmap, + opt, + pi, + pi_config, + } + } + + /// Outputs the parameters of Gaussian distribution given an observation. + pub fn forward(&self, x: &P::Input) -> (Tensor, Tensor) { + let (mean, std) = self.pi.forward(&x); + debug_assert_eq!(mean.dims()[1], self.out_dim as usize); + debug_assert_eq!(std.dims()[1], self.out_dim as usize); + (mean, std) + } + + pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> { + self.opt.backward_step(loss)?; + Ok(()) + } + + pub fn save>(&self, path: T) -> Result<()> { + self.varmap.save(&path)?; + info!("Save actor to {:?}", path.as_ref()); + Ok(()) + } + + pub fn load>(&mut self, path: T) -> Result<()> { + self.varmap.load(&path)?; + info!("Load actor from {:?}", path.as_ref()); + Ok(()) + } +} + +impl

<P> Clone for Actor<P>
+where + P: SubModel1, + P::Config: DeserializeOwned + Serialize + OutDim + Clone, +{ + fn clone(&self) -> Self { + let device = self.device.clone(); + let opt_config = self.opt_config.clone(); + let varmap = VarMap::new(); + let pi_config = self.pi_config.clone(); + let pi = { + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); + P::build(vb, pi_config.clone()) + }; + let out_dim = self.out_dim; + + Self::_build( + device, + out_dim, + opt_config, + pi_config, + pi, + varmap, + Some(&self.varmap), + ) + } +} + +// impl

<P> ModelBase for Actor<P>
+// where +// P: SubModel1, +// P::Config: DeserializeOwned + Serialize + OutDim + Clone, +// { +// fn backward_step(&mut self, loss: &Tensor) { +// self.opt.backward_step(loss); +// } + +// // fn get_var_store_mut(&mut self) -> &mut nn::VarStore { +// // &mut self.var_store +// // } + +// // fn get_var_store(&self) -> &nn::VarStore { +// // &self.var_store +// // } + +// fn save>(&self, path: T) -> Result<()> { +// self.varmap.save(&path)?; +// info!("Save actor to {:?}", path.as_ref()); +// Ok(()) +// } + +// fn load>(&mut self, path: T) -> Result<()> { +// self.varmap.load(&path)?; +// info!("Load actor from {:?}", path.as_ref()); +// Ok(()) +// } +// } diff --git a/border-candle-agent/src/sac/base.rs b/border-candle-agent/src/sac/base.rs new file mode 100644 index 00000000..05e4ea9e --- /dev/null +++ b/border-candle-agent/src/sac/base.rs @@ -0,0 +1,381 @@ +use super::{Actor, Critic, EntCoef, SacConfig}; +use crate::{ + model::{SubModel1, SubModel2}, + util::{smooth_l1_loss, track, CriticLoss, OutDim}, +}; +use anyhow::Result; +use border_core::{ + record::{Record, RecordValue}, + Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, +}; +use candle_core::{Device, Tensor, D}; +use candle_nn::loss::mse; +use log::trace; +use serde::{de::DeserializeOwned, Serialize}; +use std::{fs, marker::PhantomData, path::Path}; + +type ActionValue = Tensor; +type ActMean = Tensor; +type ActStd = Tensor; + +fn normal_logp(x: &Tensor) -> Result { + let tmp: Tensor = + ((-0.5 * (2.0 * std::f32::consts::PI).ln() as f64) - (0.5 * x.powf(2.0)?)?)?; + Ok(tmp.sum(D::Minus1)?) +} + +/// Soft actor critic (SAC) agent. +pub struct Sac +where + Q: SubModel2, + P: SubModel1, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + pub(super) qnets: Vec>, + pub(super) qnets_tgt: Vec>, + pub(super) pi: Actor
<P>
, + pub(super) gamma: f64, + pub(super) tau: f64, + pub(super) ent_coef: EntCoef, + pub(super) epsilon: f64, + pub(super) min_lstd: f64, + pub(super) max_lstd: f64, + pub(super) n_updates_per_opt: usize, + pub(super) batch_size: usize, + pub(super) train: bool, + pub(super) reward_scale: f32, + pub(super) n_opts: usize, + pub(super) critic_loss: CriticLoss, + pub(super) phantom: PhantomData<(E, R)>, + pub(super) device: Device, +} + +impl Sac +where + E: Env, + Q: SubModel2, + P: SubModel1, + R: ReplayBufferBase, + E::Obs: Into + Into, + E::Act: Into, + Q::Input2: From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + R::Batch: TransitionBatch, + ::ObsBatch: Into + Into + Clone, + ::ActBatch: Into + Into, +{ + /// Returns action and its log probability under the Normal distributioni. + fn action_logp(&self, o: &P::Input) -> Result<(Tensor, Tensor)> { + let (mean, lstd) = self.pi.forward(o); + let std = lstd.clamp(self.min_lstd, self.max_lstd)?.exp()?; + let z = Tensor::randn(0f32, 1f32, mean.dims(), &self.device)?; + let a = (&std * &z + &mean)?.tanh()?; + let log_p = (normal_logp(&z)? + - 1f64 + * ((1f64 - a.powf(2.0)?)? + self.epsilon)? + .log()? + .sum(D::Minus1)?)? + .squeeze(D::Minus1)?; + + debug_assert_eq!(a.dims()[0], self.batch_size); + debug_assert_eq!(log_p.dims(), [self.batch_size]); + + Ok((a, log_p)) + } + + fn qvals(&self, qnets: &[Critic], obs: &Q::Input1, act: &Q::Input2) -> Vec { + qnets + .iter() + .map(|qnet| qnet.forward(obs, act).squeeze(D::Minus1).unwrap()) + .collect() + } + + /// Returns the minimum values of q values over critics + fn qvals_min(&self, qnets: &[Critic], obs: &Q::Input1, act: &Q::Input2) -> Result { + let qvals = self.qvals(qnets, obs, act); + let qvals = Tensor::stack(&qvals, 0)?; + let qvals_min = qvals.min(0)?.squeeze(D::Minus1)?; + + debug_assert_eq!(qvals_min.dims(), [self.batch_size]); + + Ok(qvals_min) + } + + fn update_critic(&mut self, batch: R::Batch) -> Result { + let losses = { + let (obs, act, next_obs, reward, is_terminated, _is_truncated, _, _) = batch.unpack(); + let batch_size = reward.len(); + let reward = Tensor::from_slice(&reward[..], (batch_size,), &self.device)?; + let is_terminated = { + let is_terminated = is_terminated.iter().map(|e| *e as f32).collect::>(); + Tensor::from_slice(&is_terminated[..], (batch_size,), &self.device)? + }; + + let preds = self.qvals(&self.qnets, &obs.into(), &act.into()); + let tgt = { + let next_q = { + let (next_a, next_log_p) = self.action_logp(&next_obs.clone().into())?; + let next_q = + self.qvals_min(&self.qnets_tgt, &next_obs.into(), &next_a.into())?; + (next_q - self.ent_coef.alpha()?.broadcast_mul(&next_log_p))? + }; + ((self.reward_scale as f64) * reward)? + + (1f64 - &is_terminated)? * self.gamma * next_q + }? 
+ .detach(); + + debug_assert_eq!(tgt.dims(), [self.batch_size]); + + let losses: Vec<_> = match self.critic_loss { + CriticLoss::Mse => preds + .iter() + .map(|pred| mse(&pred.squeeze(D::Minus1).unwrap(), &tgt).unwrap()) + .collect(), + CriticLoss::SmoothL1 => preds + .iter() + .map(|pred| smooth_l1_loss(&pred, &tgt).unwrap()) + .collect(), + }; + losses + }; + + for (qnet, loss) in self.qnets.iter_mut().zip(&losses) { + qnet.backward_step(&loss).unwrap(); + } + + Ok(losses + .iter() + .map(|loss| loss.to_scalar::().unwrap()) + .sum::() + / (self.qnets.len() as f32)) + } + + fn update_actor(&mut self, batch: &R::Batch) -> Result { + let loss = { + let o = batch.obs().clone(); + let (a, log_p) = self.action_logp(&o.into())?; + + // Update the entropy coefficient + self.ent_coef.update(&log_p.detach())?; + + let o = batch.obs().clone(); + let qval = self.qvals_min(&self.qnets, &o.into(), &a.into())?; + ((self.ent_coef.alpha()?.detach().broadcast_mul(&log_p))? - &qval)?.mean_all()? + }; + + self.pi.backward_step(&loss)?; + + Ok(loss.to_scalar::()?) + } + + fn soft_update(&mut self) -> Result<()> { + for (qnet_tgt, qnet) in self.qnets_tgt.iter().zip(&mut self.qnets) { + track(qnet_tgt.get_varmap(), qnet.get_varmap(), self.tau)?; + } + Ok(()) + } + + fn opt_(&mut self, buffer: &mut R) -> Result { + let mut loss_critic = 0f32; + let mut loss_actor = 0f32; + + for _ in 0..self.n_updates_per_opt { + trace!("batch()"); + let batch = buffer.batch(self.batch_size).unwrap(); + + trace!("update_actor()"); + loss_actor += self.update_actor(&batch)?; + + trace!("update_critic()"); + loss_critic += self.update_critic(batch)?; + + trace!("soft_update()"); + self.soft_update()?; + + self.n_opts += 1; + } + + loss_critic /= self.n_updates_per_opt as f32; + loss_actor /= self.n_updates_per_opt as f32; + + let record = Record::from_slice(&[ + ("loss_critic", RecordValue::Scalar(loss_critic)), + ("loss_actor", RecordValue::Scalar(loss_actor)), + ( + "ent_coef", + RecordValue::Scalar(self.ent_coef.alpha()?.to_vec1::()?[0]), + ), + ]); + + Ok(record) + } +} + +impl Policy for Sac +where + E: Env, + Q: SubModel2, + P: SubModel1, + E::Obs: Into + Into, + E::Act: Into + From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + fn sample(&mut self, obs: &E::Obs) -> E::Act { + let obs = obs.clone().into(); + let (mean, lstd) = self.pi.forward(&obs); + let std = lstd + .clamp(self.min_lstd, self.max_lstd) + .unwrap() + .exp() + .unwrap(); + let act = if self.train { + ((std * mean.randn_like(0., 1.).unwrap()).unwrap() + mean).unwrap() + } else { + mean + }; + act.tanh().unwrap().into() + } +} + +impl Configurable for Sac +where + E: Env, + Q: SubModel2, + P: SubModel1, + E::Obs: Into + Into, + E::Act: Into + From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +{ + type Config = SacConfig; + + /// Constructs [`Sac`] agent. 
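+    ///
+    /// Panics if `config.device` is `None`.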
+ fn build(config: Self::Config) -> Self { + let device = config.device.expect("No device is given for SAC agent"); + let n_critics = config.n_critics; + let pi = Actor::build(config.actor_config, device.clone().into()).unwrap(); + let mut qnets = vec![]; + let mut qnets_tgt = vec![]; + for _ in 0..n_critics { + qnets.push(Critic::build(config.critic_config.clone(), device.clone().into()).unwrap()); + qnets_tgt + .push(Critic::build(config.critic_config.clone(), device.clone().into()).unwrap()); + } + + // if let Some(seed) = config.seed.as_ref() { + // tch::manual_seed(*seed); + // } + + Sac { + qnets, + qnets_tgt, + pi, + gamma: config.gamma, + tau: config.tau, + ent_coef: EntCoef::new(config.ent_coef_mode, device.into()).unwrap(), + epsilon: config.epsilon, + min_lstd: config.min_lstd, + max_lstd: config.max_lstd, + n_updates_per_opt: config.n_updates_per_opt, + batch_size: config.batch_size, + train: config.train, + reward_scale: config.reward_scale, + critic_loss: config.critic_loss, + n_opts: 0, + device: device.into(), + phantom: PhantomData, + } + } +} + +impl Agent for Sac +where + E: Env, + Q: SubModel2, + P: SubModel1, + R: ReplayBufferBase, + E::Obs: Into + Into, + E::Act: Into + From, + Q::Input2: From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + R::Batch: TransitionBatch, + ::ObsBatch: Into + Into + Clone, + ::ActBatch: Into + Into, +{ + fn train(&mut self) { + self.train = true; + } + + fn eval(&mut self) { + self.train = false; + } + + fn is_train(&self) -> bool { + self.train + } + + fn opt_with_record(&mut self, buffer: &mut R) -> Record { + self.opt_(buffer).expect("Failed in Sac::opt_()") + } + + fn save_params>(&self, path: T) -> Result<()> { + // TODO: consider to rename the path if it already exists + fs::create_dir_all(&path)?; + for (i, (qnet, qnet_tgt)) in self.qnets.iter().zip(&self.qnets_tgt).enumerate() { + qnet.save(&path.as_ref().join(format!("qnet_{}.pt", i)).as_path())?; + qnet_tgt.save(&path.as_ref().join(format!("qnet_tgt_{}.pt", i)).as_path())?; + } + self.pi.save(&path.as_ref().join("pi.pt").as_path())?; + self.ent_coef + .save(&path.as_ref().join("ent_coef.pt").as_path())?; + Ok(()) + } + + fn load_params>(&mut self, path: T) -> Result<()> { + for (i, (qnet, qnet_tgt)) in self.qnets.iter_mut().zip(&mut self.qnets_tgt).enumerate() { + qnet.load(&path.as_ref().join(format!("qnet_{}.pt", i)).as_path())?; + qnet_tgt.load(&path.as_ref().join(format!("qnet_tgt_{}.pt", i)).as_path())?; + } + self.pi.load(&path.as_ref().join("pi.pt").as_path())?; + self.ent_coef + .load(&path.as_ref().join("ent_coef.pt").as_path())?; + Ok(()) + } +} + +// #[cfg(feature = "border-async-trainer")] +// use {crate::util::NamedTensors, border_async_trainer::SyncModel}; + +// #[cfg(feature = "border-async-trainer")] +// impl SyncModel for Sac +// where +// E: Env, +// Q: SubModel2, +// P: SubModel, +// R: ReplayBufferBase, +// E::Obs: Into + Into, +// E::Act: Into + From, +// Q::Input2: From, +// Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +// P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, +// R::Batch: TransitionBatch, +// ::ObsBatch: Into + Into + Clone, +// ::ActBatch: Into + Into, +// { +// type ModelInfo = NamedTensors; + +// fn model_info(&self) -> (usize, Self::ModelInfo) { +// ( +// self.n_opts, +// NamedTensors::copy_from(self.pi.get_var_store()), +// ) 
+// } + +// fn sync_model(&mut self, model_info: &Self::ModelInfo) { +// model_info.copy_to(self.pi.get_var_store_mut()); +// } +// } diff --git a/border-candle-agent/src/sac/config.rs b/border-candle-agent/src/sac/config.rs new file mode 100644 index 00000000..3256740f --- /dev/null +++ b/border-candle-agent/src/sac/config.rs @@ -0,0 +1,206 @@ +//! Configuration of SAC agent. +use super::{ActorConfig, CriticConfig}; +use crate::{ + model::{SubModel1, SubModel2}, + sac::ent_coef::EntCoefMode, + util::CriticLoss, + util::OutDim, + Device, +}; +use anyhow::Result; +use candle_core::Tensor; +use log::info; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::{ + fmt::Debug, + fs::File, + io::{BufReader, Write}, + path::Path, +}; + +/// Configuration of [`Sac`](super::Sac). +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Deserialize, Serialize, PartialEq)] +pub struct SacConfig +where + Q: SubModel2, + Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone, + P: SubModel1, + P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone, +{ + pub actor_config: ActorConfig, + pub critic_config: CriticConfig, + pub gamma: f64, + pub tau: f64, + pub ent_coef_mode: EntCoefMode, + pub epsilon: f64, + pub min_lstd: f64, + pub max_lstd: f64, + pub n_updates_per_opt: usize, + pub batch_size: usize, + pub train: bool, + pub critic_loss: CriticLoss, + pub reward_scale: f32, + pub n_critics: usize, + pub seed: Option, + pub device: Option, +} + +impl Clone for SacConfig +where + Q: SubModel2, + Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone, + P: SubModel1, + P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone, +{ + fn clone(&self) -> Self { + Self { + actor_config: self.actor_config.clone(), + critic_config: self.critic_config.clone(), + gamma: self.gamma.clone(), + tau: self.tau.clone(), + ent_coef_mode: self.ent_coef_mode.clone(), + epsilon: self.epsilon.clone(), + min_lstd: self.min_lstd.clone(), + max_lstd: self.max_lstd.clone(), + n_updates_per_opt: self.n_updates_per_opt.clone(), + batch_size: self.batch_size.clone(), + train: self.train.clone(), + critic_loss: self.critic_loss.clone(), + reward_scale: self.reward_scale.clone(), + n_critics: self.n_critics.clone(), + seed: self.seed.clone(), + device: self.device.clone(), + } + } +} + +impl Default for SacConfig +where + Q: SubModel2, + Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone, + P: SubModel1, + P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone, +{ + fn default() -> Self { + Self { + actor_config: Default::default(), + critic_config: Default::default(), + gamma: 0.99, + tau: 0.005, + ent_coef_mode: EntCoefMode::Fix(1.0), + epsilon: 1e-4, + min_lstd: -20.0, + max_lstd: 2.0, + n_updates_per_opt: 1, + batch_size: 1, + train: false, + critic_loss: CriticLoss::Mse, + reward_scale: 1.0, + n_critics: 1, + seed: None, + device: None, + } + } +} + +impl SacConfig +where + Q: SubModel2, + Q::Config: DeserializeOwned + Serialize + Debug + PartialEq + Clone, + P: SubModel1, + P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone, +{ + /// Sets the numper of parameter update steps per optimization step. + pub fn n_updates_per_opt(mut self, v: usize) -> Self { + self.n_updates_per_opt = v; + self + } + + /// Batch size. + pub fn batch_size(mut self, v: usize) -> Self { + self.batch_size = v; + self + } + + /// Discount factor. 
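+    ///
+    /// Defaults to 0.99 (see the `Default` implementation of `SacConfig`).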
+ pub fn discount_factor(mut self, v: f64) -> Self { + self.gamma = v; + self + } + + /// Sets soft update coefficient. + pub fn tau(mut self, v: f64) -> Self { + self.tau = v; + self + } + + /// SAC-alpha. + pub fn ent_coef_mode(mut self, v: EntCoefMode) -> Self { + self.ent_coef_mode = v; + self + } + + /// Reward scale. + /// + /// It works for obtaining target values, not the values in logs. + pub fn reward_scale(mut self, v: f32) -> Self { + self.reward_scale = v; + self + } + + /// Critic loss. + pub fn critic_loss(mut self, v: CriticLoss) -> Self { + self.critic_loss = v; + self + } + + /// Configuration of actor. + pub fn actor_config(mut self, actor_config: ActorConfig) -> Self { + self.actor_config = actor_config; + self + } + + /// Configuration of critic. + pub fn critic_config(mut self, critic_config: CriticConfig) -> Self { + self.critic_config = critic_config; + self + } + + /// The number of critics. + pub fn n_critics(mut self, n_critics: usize) -> Self { + self.n_critics = n_critics; + self + } + + /// Random seed. + pub fn seed(mut self, seed: i64) -> Self { + self.seed = Some(seed); + self + } + + /// Device. + pub fn device(mut self, device: candle_core::Device) -> Self { + self.device = Some(device.into()); + self + } + + /// Constructs [`SacConfig`] from YAML file. + pub fn load(path: impl AsRef) -> Result { + let path_ = path.as_ref().to_owned(); + let file = File::open(path)?; + let rdr = BufReader::new(file); + let b = serde_yaml::from_reader(rdr)?; + info!("Load config of SAC agent from {}", path_.to_str().unwrap()); + Ok(b) + } + + /// Saves [`SacConfig`]. + pub fn save(&self, path: impl AsRef) -> Result<()> { + let path_ = path.as_ref().to_owned(); + let mut file = File::create(path)?; + file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; + info!("Save config of SAC agent into {}", path_.to_str().unwrap()); + Ok(()) + } +} diff --git a/border-candle-agent/src/sac/critic.rs b/border-candle-agent/src/sac/critic.rs new file mode 100644 index 00000000..f9922619 --- /dev/null +++ b/border-candle-agent/src/sac/critic.rs @@ -0,0 +1,183 @@ +//! Critic of SAC agent. +use crate::{ + model::SubModel2, + opt::{Optimizer, OptimizerConfig}, +}; +use anyhow::{Context, Result}; +use candle_core::{DType, Device, Tensor}; +use candle_nn::{VarBuilder, VarMap}; +use log::info; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use std::{ + fs::File, + io::{BufReader, Write}, + path::Path, +}; + +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +/// Configuration of [`Critic`]. +pub struct CriticConfig { + pub q_config: Option, + pub opt_config: OptimizerConfig, +} + +impl Default for CriticConfig { + fn default() -> Self { + Self { + q_config: None, + opt_config: OptimizerConfig::default(), + } + } +} + +impl CriticConfig +where + Q: DeserializeOwned + Serialize, +{ + /// Sets configurations for action-value function. + pub fn q_config(mut self, v: Q) -> Self { + self.q_config = Some(v); + self + } + + /// Sets optimizer configuration. + pub fn opt_config(mut self, v: OptimizerConfig) -> Self { + self.opt_config = v; + self + } + + /// Constructs [CriticConfig] from YAML file. + pub fn load(path: impl AsRef) -> Result { + let file = File::open(path)?; + let rdr = BufReader::new(file); + let b = serde_yaml::from_reader(rdr)?; + Ok(b) + } + + /// Saves [CriticConfig]. 
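+    ///
+    /// The configuration is written to the given path as YAML.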
+ pub fn save(&self, path: impl AsRef) -> Result<()> { + let mut file = File::create(path)?; + file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; + Ok(()) + } +} + +/// Represents soft critic for SAC agents. +/// +/// It takes observations and actions as inputs and outputs action values. +pub struct Critic +where + Q: SubModel2, + Q::Config: DeserializeOwned + Serialize, +{ + device: Device, + varmap: VarMap, + + /// Action-value function + q: Q, + q_config: Q::Config, + + opt_config: OptimizerConfig, + opt: Optimizer, +} + +impl Critic +where + Q: SubModel2, + Q::Config: DeserializeOwned + Serialize + Clone, +{ + /// Constructs [`Critic`]. + pub fn build(config: CriticConfig, device: Device) -> Result> { + let q_config = config.q_config.context("q_config is not set.")?; + let opt_config = config.opt_config; + let varmap = VarMap::new(); + let q = { + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); + Q::build(vb, q_config.clone()) + }; + + Ok(Critic::_build( + device, opt_config, q_config, q, varmap, None, + )) + } + + fn _build( + device: Device, + opt_config: OptimizerConfig, + q_config: Q::Config, + q: Q, + mut varmap: VarMap, + varmap_src: Option<&VarMap>, + ) -> Self { + // Optimizer + let opt = opt_config.build(varmap.all_vars()).unwrap(); + + // Copy varmap + if let Some(varmap_src) = varmap_src { + varmap.clone_from(varmap_src); + } + + Self { + device, + opt_config, + varmap, + opt, + q, + q_config, + } + } + + /// Outputs the action-value given observations and actions. + pub fn forward(&self, obs: &Q::Input1, act: &Q::Input2) -> Tensor { + self.q.forward(obs, act) + } +} + +impl Clone for Critic +where + Q: SubModel2, + Q::Config: DeserializeOwned + Serialize + Clone, +{ + fn clone(&self) -> Self { + let device = self.device.clone(); + let opt_config = self.opt_config.clone(); + let varmap = VarMap::new(); + let q_config = self.q_config.clone(); + let q = { + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); + Q::build(vb, q_config.clone()) + }; + + Self::_build(device, opt_config, q_config, q, varmap, Some(&self.varmap)) + } +} + +impl Critic +where + Q: SubModel2, + Q::Config: DeserializeOwned + Serialize, +{ + pub fn backward_step(&mut self, loss: &Tensor) -> Result<()> { + self.opt.backward_step(loss) + } + + // fn get_var_store_mut(&mut self) -> &mut nn::VarStore { + // &mut self.var_store + // } + + pub fn get_varmap(&self) -> &VarMap { + &self.varmap + } + + pub fn save>(&self, path: T) -> Result<()> { + self.varmap.save(&path)?; + info!("Save critic to {:?}", path.as_ref()); + Ok(()) + } + + pub fn load>(&mut self, path: T) -> Result<()> { + self.varmap.load(&path)?; + info!("Load critic from {:?}", path.as_ref()); + Ok(()) + } +} diff --git a/border-candle-agent/src/sac/ent_coef.rs b/border-candle-agent/src/sac/ent_coef.rs new file mode 100644 index 00000000..57c773d2 --- /dev/null +++ b/border-candle-agent/src/sac/ent_coef.rs @@ -0,0 +1,99 @@ +//! Entropy coefficient of SAC. +use std::convert::TryFrom; + +use crate::opt::{Optimizer, OptimizerConfig}; +use anyhow::Result; +use candle_core::{DType, Device, Tensor}; +use candle_nn::{init::Init, VarBuilder, VarMap}; +use log::info; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +/// Mode of the entropy coefficient of SAC. +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub enum EntCoefMode { + /// Use a constant as alpha. + Fix(f64), + /// Automatic tuning given `(target_entropy, learning_rate)`. + Auto(f64, f64), +} + +/// The entropy coefficient of SAC. 
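+///
+/// With [`EntCoefMode::Fix`], `log_alpha` stays constant; with [`EntCoefMode::Auto`],
+/// it is optimized so that the policy entropy approaches the given target entropy.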
+pub struct EntCoef { + varmap: VarMap, + log_alpha: Tensor, + target_entropy: Option, + opt: Option, +} + +impl EntCoef { + /// Constructs an instance of `EntCoef`. + pub fn new(mode: EntCoefMode, device: Device) -> Result { + let varmap = VarMap::new(); + let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device); + let (log_alpha, target_entropy, opt) = match mode { + EntCoefMode::Fix(alpha) => { + let init = Init::Const(alpha.ln()); + let log_alpha = vb.get_with_hints(1, "log_alpha", init)?; + (log_alpha, None, None) + } + EntCoefMode::Auto(target_entropy, learning_rate) => { + let init = Init::Const(0.0); + let log_alpha = vb.get_with_hints(1, "log_alpha", init)?; + let opt = OptimizerConfig::default() + .learning_rate(learning_rate) + .build(varmap.all_vars())?; + (log_alpha, Some(target_entropy), Some(opt)) + } + }; + + Ok(Self { + varmap, + log_alpha, + opt, + target_entropy, + }) + } + + /// Returns the entropy coefficient. + pub fn alpha(&self) -> Result { + Ok(self.log_alpha.detach().exp()?) + } + + /// Does an optimization step given a loss. + pub fn backward_step(&mut self, loss: &Tensor) { + if let Some(opt) = &mut self.opt { + opt.backward_step(loss).unwrap(); + } + } + + /// Update the parameter given an action probability vector. + pub fn update(&mut self, logp: &Tensor) -> Result<()> { + if let Some(target_entropy) = &self.target_entropy { + let target_entropy = + Tensor::try_from(*target_entropy as f32)?.to_device(logp.device())?; + let loss = { + // let tmp = ((&self.log_alpha * (logp + target_entropy)?.detach())? * -1f64)?; + let tmp = (&self.log_alpha * -1f64)? + .broadcast_mul(&logp.broadcast_add(&target_entropy)?.detach())?; + tmp.mean(0)? + }; + self.backward_step(&loss); + } + Ok(()) + } + + /// Save the parameter into a file. + pub fn save>(&self, path: T) -> Result<()> { + self.varmap.save(&path)?; + info!("Save entropy coefficient to {:?}", path.as_ref()); + Ok(()) + } + + /// Save the parameter from a file. + pub fn load>(&mut self, path: T) -> Result<()> { + self.varmap.load(&path)?; + info!("Load entropy coefficient from {:?}", path.as_ref()); + Ok(()) + } +} diff --git a/border-candle-agent/src/tensor_batch.rs b/border-candle-agent/src/tensor_batch.rs new file mode 100644 index 00000000..410ac023 --- /dev/null +++ b/border-candle-agent/src/tensor_batch.rs @@ -0,0 +1,96 @@ +use border_core::generic_replay_buffer::BatchBase; +use candle_core::{error::Result, DType, Device, Tensor}; + +/// Adds capability of constructing [`Tensor`] with a static method. +/// +/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html +pub trait ZeroTensor { + /// Constructs zero tensor. + fn zeros(shape: &[usize]) -> Result; +} + +impl ZeroTensor for u8 { + fn zeros(shape: &[usize]) -> Result { + Tensor::zeros(shape, DType::U8, &Device::Cpu) + } +} + +impl ZeroTensor for f32 { + fn zeros(shape: &[usize]) -> Result { + Tensor::zeros(shape, DType::F32, &Device::Cpu) + } +} + +impl ZeroTensor for i64 { + fn zeros(shape: &[usize]) -> Result { + Tensor::zeros(shape, DType::I64, &Device::Cpu) + } +} + +/// A buffer consisting of a [`Tensor`]. +/// +/// The internal buffer is `Vec`. 
+/// +/// [`Tensor`]: https://docs.rs/candle-core/0.4.1/candle_core/struct.Tensor.html +#[derive(Clone, Debug)] +pub struct TensorBatch { + buf: Vec, + capacity: usize, +} + +impl TensorBatch { + pub fn from_tensor(t: Tensor) -> Self { + let capacity = t.dims()[0] as _; + Self { + buf: vec![t], + capacity, + } + } +} + +impl BatchBase for TensorBatch { + fn new(capacity: usize) -> Self { + Self { + buf: Vec::with_capacity(capacity), + capacity: capacity, + } + } + + /// Pushes given data. + /// + /// if ix + data.buf.len() exceeds the self.capacity, + /// the tail samples in data is placed in the head of the buffer of self. + fn push(&mut self, ix: usize, data: Self) { + if self.buf.len() == self.capacity { + for (i, sample) in data.buf.into_iter().enumerate() { + let ix_ = (ix + i) % self.capacity; + self.buf[ix_] = sample; + } + } else if self.buf.len() < self.capacity { + for (i, sample) in data.buf.into_iter().enumerate() { + if self.buf.len() < self.capacity { + self.buf.push(sample); + } else { + let ix_ = (ix + i) % self.capacity; + self.buf[ix_] = sample; + } + } + } else { + panic!("The length of the buffer is SubBatch is larger than its capacity."); + } + } + + fn sample(&self, ixs: &Vec) -> Self { + let buf = ixs.iter().map(|&ix| self.buf[ix].clone()).collect(); + Self { + buf, + capacity: ixs.len(), + } + } +} + +impl From for Tensor { + fn from(b: TensorBatch) -> Self { + Tensor::cat(&b.buf[..], 0).unwrap() + } +} diff --git a/border-candle-agent/src/util.rs b/border-candle-agent/src/util.rs new file mode 100644 index 00000000..1e967546 --- /dev/null +++ b/border-candle-agent/src/util.rs @@ -0,0 +1,157 @@ +//! Utilities. +// use crate::model::ModelBase; +use anyhow::Result; +use candle_core::{DType, Tensor}; +use candle_nn::VarMap; +use log::trace; +use serde::{Deserialize, Serialize}; +mod named_tensors; +mod quantile_loss; +use border_core::record::{Record, RecordValue}; +pub use named_tensors::NamedTensors; +pub use quantile_loss::quantile_huber_loss; +use std::convert::TryFrom; + +/// Critic loss type. +#[allow(clippy::upper_case_acronyms)] +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub enum CriticLoss { + /// Mean squared error. + Mse, + + /// Smooth L1 loss. + SmoothL1, +} + +/// Apply soft update on variables. +/// +/// Variables are identified by their names. +/// +/// dest = tau * src + (1.0 - tau) * dest +pub fn track(dest: &VarMap, src: &VarMap, tau: f64) -> Result<()> { + trace!("dest"); + let dest = dest.data().lock().unwrap(); + trace!("src"); + let src = src.data().lock().unwrap(); + + dest.iter().for_each(|(k_dest, v_dest)| { + let v_src = src.get(k_dest).unwrap(); + let t_src = v_src.as_tensor(); + let t_dest = v_dest.as_tensor(); + let t_dest = ((tau * t_src).unwrap() + (1.0 - tau) * t_dest).unwrap(); + v_dest.set(&t_dest).unwrap(); + }); + + Ok(()) +} + +// /// Concatenates slices. +// pub fn concat_slices(s1: &[i64], s2: &[i64]) -> Vec { +// let mut v = Vec::from(s1); +// v.append(&mut Vec::from(s2)); +// v +// } + +/// Interface for handling output dimensions. +pub trait OutDim { + /// Returns the output dimension. + fn get_out_dim(&self) -> i64; + + /// Sets the output dimension. 
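+    ///
+    /// Agent configurations call this to propagate the action dimension into a
+    /// network configuration (e.g. `ActorConfig::out_dim`).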
+ fn set_out_dim(&mut self, v: i64); +} + +#[test] +fn test_track() -> Result<()> { + use candle_core::{DType, Device, Tensor}; + use candle_nn::Init; + + let tau = 0.7; + let t_src = Tensor::from_slice(&[1.0f32, 2.0, 3.0], (3,), &Device::Cpu)?; + let t_dest = Tensor::from_slice(&[4.0f32, 5.0, 6.0], (3,), &Device::Cpu)?; + let t = ((tau * &t_src).unwrap() + (1.0 - tau) * &t_dest).unwrap(); + + let vm_src = { + let vm = VarMap::new(); + let init = Init::Randn { + mean: 0.0, + stdev: 1.0, + }; + vm.get((3,), "var1", init, DType::F32, &Device::Cpu)?; + vm.data().lock().unwrap().get("var1").unwrap().set(&t_src)?; + vm + }; + let vm_dest = { + let vm = VarMap::new(); + let init = Init::Randn { + mean: 0.0, + stdev: 1.0, + }; + vm.get((3,), "var1", init, DType::F32, &Device::Cpu)?; + vm.data() + .lock() + .unwrap() + .get("var1") + .unwrap() + .set(&t_dest)?; + vm + }; + track(&vm_dest, &vm_src, tau)?; + + let t_ = vm_dest + .data() + .lock() + .unwrap() + .get("var1") + .unwrap() + .as_tensor() + .clone(); + + println!("{:?}", t); + println!("{:?}", t_); + assert!((t - t_)?.abs()?.sum(0)?.to_scalar::()? < 1e-32); + + Ok(()) +} + +/// See . +pub fn smooth_l1_loss(x: &Tensor, y: &Tensor) -> Result { + let device = x.device(); + let d = (x - y)?.abs()?; + let m1 = d.lt(1.0)?.to_dtype(DType::F32)?.to_device(&device)?; + let m2 = Tensor::try_from(1f32)? + .to_device(&device)? + .broadcast_sub(&m1)?; + (((0.5 * m1)? * d.powf(2.0))? + m2 * (d - 0.5))?.mean_all() +} + +/// Returns the standard deviation of a tensor. +pub fn std(t: &Tensor) -> f32 { + t.broadcast_sub(&t.mean_all().unwrap()) + .unwrap() + .powf(2f64) + .unwrap() + .mean_all() + .unwrap() + .sqrt() + .unwrap() + .to_vec0::() + .unwrap() +} + +/// Returns the mean and standard deviation of the parameters. +pub fn param_stats(varmap: &VarMap) -> Record { + let mut record = Record::empty(); + + for (k, v) in varmap.data().lock().unwrap().iter() { + let m: f32 = v.mean_all().unwrap().to_vec0().unwrap(); + let k_mean = format!("{}_mean", &k); + record.insert(k_mean, RecordValue::Scalar(m)); + + let m: f32 = std(v.as_tensor()); + let k_std = format!("{}_std", &k); + record.insert(k_std, RecordValue::Scalar(m)); + } + + record +} diff --git a/border-candle-agent/src/util/named_tensors.rs b/border-candle-agent/src/util/named_tensors.rs new file mode 100644 index 00000000..a6a5eada --- /dev/null +++ b/border-candle-agent/src/util/named_tensors.rs @@ -0,0 +1,104 @@ +use candle_core::Tensor; +use candle_nn::VarMap; +use std::collections::HashMap; +// use tch::{nn::VarStore, Device::Cpu, Tensor}; + +/// Named tensors to send model parameters using a channel. +pub struct NamedTensors { + pub named_tensors: HashMap, +} + +impl NamedTensors { + /// Copy data of [`VarMap`] to CPU. + pub fn copy_from(_vs: &VarMap) -> Self { + unimplemented!(); + + // let src = vs.variables(); + + // tch::no_grad(|| NamedTensors { + // named_tensors: HashMap::from_iter(src.iter().map(|(k, v)| { + // let v = v.detach().to(Cpu).data(); + // (k.clone(), v) + // })), + // }) + } + + /// Copy named tensors to [`VarMap`]. 
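+    ///
+    /// Not yet implemented for the candle backend; calling this method panics.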
+ pub fn copy_to(&self, _vs: &mut VarMap) { + unimplemented!(); + + // let src = &self.named_tensors; + // let dest = &mut vs.variables(); + // // let device = vs.device(); + // debug_assert_eq!(src.len(), dest.len()); + + // tch::no_grad(|| { + // for (name, src) in src.iter() { + // let dest = dest.get_mut(name).unwrap(); + // dest.copy_(src); + // } + // }); + } +} + +impl Clone for NamedTensors { + fn clone(&self) -> Self { + unimplemented!(); + + // let src = &self.named_tensors; + + // tch::no_grad(|| NamedTensors { + // named_tensors: HashMap::from_iter(src.iter().map(|(k, v)| { + // let v = v.detach().to(Cpu).data(); + // (k.clone(), v) + // })), + // }) + } +} + +#[cfg(test)] +mod test { + // use super::NamedTensors; + // use std::convert::{TryFrom, TryInto}; + // use tch::{ + // nn::{self, Module}, + // Device::Cpu, + // Tensor, + // }; + + #[test] + fn test_named_tensors() { + // tch::manual_seed(42); + + // let tensor1 = Tensor::try_from(vec![1., 2., 3.]) + // .unwrap() + // .internal_cast_float(false); + + // let vs1 = nn::VarStore::new(Cpu); + // let model1 = nn::seq() + // .add(nn::linear(&vs1.root() / "layer1", 3, 8, Default::default())) + // .add(nn::linear(&vs1.root() / "layer2", 8, 2, Default::default())); + + // let mut vs2 = nn::VarStore::new(tch::Device::cuda_if_available()); + // let model2 = nn::seq() + // .add(nn::linear(&vs2.root() / "layer1", 3, 8, Default::default())) + // .add(nn::linear(&vs2.root() / "layer2", 8, 2, Default::default())); + // let device = vs2.device(); + + // let t1: Vec = model1.forward(&tensor1).try_into().unwrap(); + // let t2: Vec = model2.forward(&tensor1.to(device)).try_into().unwrap(); + + // let nt = NamedTensors::copy_from(&vs1); + // nt.copy_to(&mut vs2); + + // let t3: Vec = model2.forward(&tensor1.to(device)).try_into().unwrap(); + + // for i in 0..2 { + // assert!((t1[i] - t2[i]).abs() >= t1[i].abs() * 0.001); + // assert!((t1[i] - t3[i]).abs() < t1[i].abs() * 0.001); + // } + // // println!("{:?}", t1); + // // println!("{:?}", t2); + // // println!("{:?}", t3); + } +} diff --git a/border-candle-agent/src/util/quantile_loss.rs b/border-candle-agent/src/util/quantile_loss.rs new file mode 100644 index 00000000..112712e3 --- /dev/null +++ b/border-candle-agent/src/util/quantile_loss.rs @@ -0,0 +1,15 @@ +//! Quantile loss. +use candle_core::Tensor; + +/// Returns the quantile huber loss. +/// +/// `x` and `tau` has the same shape. 
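+///
+/// Intended definition (a sketch based on the commented-out tch code below):
+/// `|tau - 1_{x < 0}| * smooth_l1(x, 0)`. The current body is a placeholder and panics.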
+pub fn quantile_huber_loss(_x: &Tensor, _tau: &Tensor) -> Tensor { + // TODO: implements this + panic!(); + // debug_assert_eq!(x.size().as_slice(), tau.size().as_slice()); + + // let lt_0 = &x.lt(0.0).detach(); + // let loss = x.smooth_l1_loss(&Tensor::zeros_like(x), tch::Reduction::None, 1.0); + // (tau - Tensor::where_scalar(lt_0, 1., 0.)).abs() * loss +} diff --git a/border-core/Cargo.toml b/border-core/Cargo.toml index f8028f9f..b6273e7b 100644 --- a/border-core/Cargo.toml +++ b/border-core/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border-core" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] serde = { workspace = true, features = ["derive"] } @@ -23,6 +19,7 @@ chrono = { workspace = true } aquamarine = { workspace = true } fastrand = { workspace = true } segment-tree = { workspace = true } +xxhash-rust = { workspace = true } # Consider to replace with fastrand rand = "0.8.4" diff --git a/border-core/src/base.rs b/border-core/src/base.rs index 91a552ee..5d739f26 100644 --- a/border-core/src/base.rs +++ b/border-core/src/base.rs @@ -6,12 +6,12 @@ mod policy; mod replay_buffer; mod step; pub use agent::Agent; -pub use batch::StdBatchBase; +pub use batch::TransitionBatch; pub use env::Env; -pub use policy::Policy; +pub use policy::{Configurable, Policy}; pub use replay_buffer::{ExperienceBufferBase, ReplayBufferBase}; use std::fmt::Debug; -pub use step::{Info, Step, StepProcessorBase}; +pub use step::{Info, Step, StepProcessor}; /// A set of observations of an environment. /// diff --git a/border-core/src/base/agent.rs b/border-core/src/base/agent.rs index ae880a9d..c6ab3ac0 100644 --- a/border-core/src/base/agent.rs +++ b/border-core/src/base/agent.rs @@ -15,15 +15,23 @@ pub trait Agent: Policy { /// Return if it is in training mode. fn is_train(&self) -> bool; - /// Do an optimization step. - fn opt(&mut self, buffer: &mut R) -> Option; + /// Performs an optimization step. + /// + /// `buffer` is a replay buffer from which transitions will be taken + /// for updating model parameters. + fn opt(&mut self, buffer: &mut R) { + let _ = self.opt_with_record(buffer); + } - /// Save the agent in the given directory. + /// Performs an optimization step and returns some information. + fn opt_with_record(&mut self, buffer: &mut R) -> Record; + + /// Save the parameters of the agent in the given directory. /// This method commonly creates a number of files consisting the agent /// in the directory. For example, the DQN agent in `border_tch_agent` crate saves /// two Q-networks corresponding to the original and target networks. - fn save>(&self, path: T) -> Result<()>; + fn save_params>(&self, path: T) -> Result<()>; - /// Load the agent from the given directory. - fn load>(&mut self, path: T) -> Result<()>; + /// Load the parameters of the agent from the given directory. 
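+    ///
+    /// The directory layout is expected to match the one produced by [`Agent::save_params`].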
+ fn load_params>(&mut self, path: T) -> Result<()>; } diff --git a/border-core/src/base/batch.rs b/border-core/src/base/batch.rs index 188a76dd..597c0c33 100644 --- a/border-core/src/base/batch.rs +++ b/border-core/src/base/batch.rs @@ -10,14 +10,14 @@ /// /// The type of `o` and `o'` is the associated type `ObsBatch`. /// The type of `a` is the associated type `ActBatch`. -pub trait StdBatchBase { +pub trait TransitionBatch { /// A set of observation in a batch. type ObsBatch; /// A set of observation in a batch. type ActBatch; - /// Unpack the data `(o_t, a_t, o_t+n, r_t, is_done_t)`. + /// Unpack the data `(o_t, a_t, o_t+n, r_t, is_terminated_t, is_truncated_t)`. /// /// Optionally, the return value has sample indices in the replay buffer and /// thier weights. Those are used for prioritized experience replay (PER). @@ -29,6 +29,7 @@ pub trait StdBatchBase { Self::ObsBatch, Vec, Vec, + Vec, Option>, Option>, ); @@ -38,25 +39,4 @@ pub trait StdBatchBase { /// Returns `o_t`. fn obs(&self) -> &Self::ObsBatch; - - /// Returns `a_t`. - fn act(&self) -> &Self::ActBatch; - - /// Returns `o_t+1`. - fn next_obs(&self) -> &Self::ObsBatch; - - /// Returns `r_t`. - fn reward(&self) -> &Vec; - - /// Returns `is_done_t`. - fn is_done(&self) -> &Vec; - - /// Returns `weight`. It is used for PER. - fn weight(&self) -> &Option>; - - /// Returns `ix_sample`. It is used for PER. - fn ix_sample(&self) -> &Option>; - - /// Creates an empty batch. - fn empty() -> Self; } diff --git a/border-core/src/base/policy.rs b/border-core/src/base/policy.rs index 2fdb97fa..090cc0ba 100644 --- a/border-core/src/base/policy.rs +++ b/border-core/src/base/policy.rs @@ -1,18 +1,34 @@ //! Policy. use super::Env; -// use anyhow::Result; +use anyhow::Result; +use serde::de::DeserializeOwned; +use std::path::Path; /// A policy on an environment. /// /// Policy is a mapping from an observation to an action. /// The mapping can be either of deterministic or stochastic. pub trait Policy { - /// Configuration of the policy. - type Config: Clone; + /// Sample an action given an observation. + fn sample(&mut self, obs: &E::Obs) -> E::Act; +} - /// Builds the policy. +/// A configurable object, having type parameter. +pub trait Configurable { + /// Configuration. + type Config: Clone + DeserializeOwned; + + /// Builds the object. fn build(config: Self::Config) -> Self; - /// Sample an action given an observation. - fn sample(&mut self, obs: &E::Obs) -> E::Act; + /// Build the object with the configuration in the yaml file of the given path. + fn build_from_path(path: impl AsRef) -> Result + where + Self: Sized, + { + let file = std::fs::File::open(path)?; + let rdr = std::io::BufReader::new(file); + let config = serde_yaml::from_reader(rdr)?; + Ok(Self::build(config)) + } } diff --git a/border-core/src/base/replay_buffer.rs b/border-core/src/base/replay_buffer.rs index c328fbb3..6725df99 100644 --- a/border-core/src/base/replay_buffer.rs +++ b/border-core/src/base/replay_buffer.rs @@ -1,5 +1,4 @@ //! Replay buffer. -//use super::StdBatchBase; use anyhow::Result; /// Interface of buffers of experiences from environments. @@ -8,10 +7,10 @@ use anyhow::Result; /// This trait is usually required by processes sampling experiences. pub trait ExperienceBufferBase { /// Items pushed into the buffer. - type PushedItem; + type Item; /// Pushes a transition into the buffer. - fn push(&mut self, tr: Self::PushedItem) -> Result<()>; + fn push(&mut self, tr: Self::Item) -> Result<()>; /// The number of samples in the buffer. 
fn len(&self) -> usize; @@ -19,9 +18,9 @@ pub trait ExperienceBufferBase { /// Interface of replay buffers. /// -/// Ones implementing this trait generates a [ReplayBufferBase::Batch], +/// Ones implementing this trait generates a [`ReplayBufferBase::Batch`], /// which is used to train agents. -pub trait ReplayBufferBase: ExperienceBufferBase { +pub trait ReplayBufferBase { /// Configuration of the replay buffer. type Config: Clone; diff --git a/border-core/src/base/step.rs b/border-core/src/base/step.rs index ea0cca88..bd759500 100644 --- a/border-core/src/base/step.rs +++ b/border-core/src/base/step.rs @@ -9,10 +9,6 @@ pub trait Info {} /// /// An environment emits [`Step`] object at every interaction steps. /// This object might be used to create transitions `(o_t, a_t, o_t+1, r_t)`. -/// -/// Old versions of the library support veectorized environments, which requires -/// elements in [`Step`] to be able to handle multiple values. -/// This is why `reward` and `is_done` are vector. pub struct Step { /// Action. pub act: E::Act, @@ -23,8 +19,11 @@ pub struct Step { /// Reward. pub reward: Vec, - /// Flag denoting if episode is done. - pub is_done: Vec, + /// Flag denoting if episode is terminated. + pub is_terminated: Vec, + + /// Flag denoting if episode is truncated. + pub is_truncated: Vec, /// Information defined by user. pub info: E::Info, @@ -39,7 +38,8 @@ impl Step { obs: E::Obs, act: E::Act, reward: Vec, - is_done: Vec, + is_terminated: Vec, + is_truncated: Vec, info: E::Info, init_obs: E::Obs, ) -> Self { @@ -47,11 +47,18 @@ impl Step { act, obs, reward, - is_done, + is_terminated, + is_truncated, info, init_obs, } } + + #[inline] + /// Terminated or truncated. + pub fn is_done(&self) -> bool { + self.is_terminated[0] == 1 || self.is_truncated[0] == 1 + } } /// Process [`Step`] and output an item [`Self::Output`]. @@ -59,11 +66,11 @@ impl Step { /// This trait is used in [`Trainer`](crate::Trainer). [`Step`] object is transformed to /// [`Self::Output`], which will be pushed into a replay buffer implementing /// [`ExperienceBufferBase`](crate::ExperienceBufferBase). -/// The type [`Self::Output`] should be the same with [`ExperienceBufferBase::PushedItem`]. +/// The type [`Self::Output`] should be the same with [`ExperienceBufferBase::Item`]. /// -/// [`Self::Output`]: StepProcessorBase::Output -/// [`ExperienceBufferBase::PushedItem`]: crate::ExperienceBufferBase::PushedItem -pub trait StepProcessorBase { +/// [`Self::Output`]: StepProcessor::Output +/// [`ExperienceBufferBase::Item`]: crate::ExperienceBufferBase::Item +pub trait StepProcessor { /// Configuration. type Config: Clone; diff --git a/border-core/src/evaluator.rs b/border-core/src/evaluator.rs index 20ee5af1..88b3c531 100644 --- a/border-core/src/evaluator.rs +++ b/border-core/src/evaluator.rs @@ -1,12 +1,12 @@ -//! Evaluate [`Policy`](crate::Policy). +//! Evaluate [`Policy`]. use crate::{Env, Policy}; use anyhow::Result; mod default_evaluator; pub use default_evaluator::DefaultEvaluator; -/// Evaluate [`Policy`](crate::Policy). +/// Evaluate [`Policy`]. pub trait Evaluator> { - /// Evaluate [`Policy`](crate::Policy). + /// Evaluate [`Policy`]. /// /// The caller of this method needs to handle the internal state of `policy`, /// like training/evaluation mode. 
diff --git a/border-core/src/evaluator/default_evaluator.rs b/border-core/src/evaluator/default_evaluator.rs index 9a0086ea..d6937336 100644 --- a/border-core/src/evaluator/default_evaluator.rs +++ b/border-core/src/evaluator/default_evaluator.rs @@ -27,7 +27,7 @@ where let act = policy.sample(&prev_obs); let (step, _) = self.env.step(&act); r_total += step.reward[0]; - if step.is_done[0] == 1 { + if step.is_done() { break; } prev_obs = step.obs; diff --git a/border-core/src/generic_replay_buffer.rs b/border-core/src/generic_replay_buffer.rs new file mode 100644 index 00000000..2724364d --- /dev/null +++ b/border-core/src/generic_replay_buffer.rs @@ -0,0 +1,9 @@ +//! A generic implementation of replay buffer. +mod base; +mod batch; +mod config; +mod step_proc; +pub use base::{IwScheduler, SimpleReplayBuffer, WeightNormalizer}; +pub use batch::{BatchBase, GenericTransitionBatch}; +pub use config::{PerConfig, SimpleReplayBufferConfig}; +pub use step_proc::{SimpleStepProcessor, SimpleStepProcessorConfig}; diff --git a/border-core/src/replay_buffer/base.rs b/border-core/src/generic_replay_buffer/base.rs similarity index 73% rename from border-core/src/replay_buffer/base.rs rename to border-core/src/generic_replay_buffer/base.rs index 36c55c8b..c79ece1a 100644 --- a/border-core/src/replay_buffer/base.rs +++ b/border-core/src/generic_replay_buffer/base.rs @@ -1,8 +1,8 @@ //! Simple generic replay buffer. mod iw_scheduler; mod sum_tree; -use super::{config::PerConfig, StdBatch, SimpleReplayBufferConfig, SubBatch}; -use crate::{StdBatchBase, ExperienceBufferBase, ReplayBufferBase}; +use super::{config::PerConfig, BatchBase, GenericTransitionBatch, SimpleReplayBufferConfig}; +use crate::{ExperienceBufferBase, ReplayBufferBase, TransitionBatch}; use anyhow::Result; pub use iw_scheduler::IwScheduler; use rand::{rngs::StdRng, RngCore, SeedableRng}; @@ -30,8 +30,8 @@ impl PerState { /// A simple generic replay buffer. pub struct SimpleReplayBuffer where - O: SubBatch, - A: SubBatch, + O: BatchBase, + A: BatchBase, { capacity: usize, i: usize, @@ -40,15 +40,16 @@ where act: A, next_obs: O, reward: Vec, - is_done: Vec, + is_terminated: Vec, + is_truncated: Vec, rng: StdRng, per_state: Option, } impl SimpleReplayBuffer where - O: SubBatch, - A: SubBatch, + O: BatchBase, + A: BatchBase, { #[inline] fn push_reward(&mut self, i: usize, b: &Vec) { @@ -63,10 +64,21 @@ where } #[inline] - fn push_is_done(&mut self, i: usize, b: &Vec) { + fn push_is_terminated(&mut self, i: usize, b: &Vec) { let mut j = i; for d in b.iter() { - self.is_done[j] = *d; + self.is_terminated[j] = *d; + j += 1; + if j == self.capacity { + j = 0; + } + } + } + + fn push_is_truncated(&mut self, i: usize, b: &Vec) { + let mut j = i; + for d in b.iter() { + self.is_truncated[j] = *d; j += 1; if j == self.capacity { j = 0; @@ -78,8 +90,12 @@ where ixs.iter().map(|ix| self.reward[*ix]).collect() } - fn sample_is_done(&self, ixs: &Vec) -> Vec { - ixs.iter().map(|ix| self.is_done[*ix]).collect() + fn sample_is_terminated(&self, ixs: &Vec) -> Vec { + ixs.iter().map(|ix| self.is_terminated[*ix]).collect() + } + + fn sample_is_truncated(&self, ixs: &Vec) -> Vec { + ixs.iter().map(|ix| self.is_truncated[*ix]).collect() } /// Sets priorities for the added samples. 
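The `push_reward` / `push_is_terminated` / `push_is_truncated` helpers above all share the same circular-write pattern: values are written starting at index `i` and wrap to the front of the storage once `capacity` is reached. A standalone sketch of that pattern (the function name and types here are illustrative, not part of the crate):

```rust
/// Writes `data` into `buf` starting at `start`, wrapping around at `capacity`.
/// This mirrors the wrap-around logic used by the replay buffer's push helpers.
fn push_circular<T: Copy>(buf: &mut [T], capacity: usize, start: usize, data: &[T]) {
    let mut j = start;
    for d in data {
        buf[j] = *d;
        j += 1;
        if j == capacity {
            j = 0;
        }
    }
}

fn main() {
    let mut rewards = vec![0.0_f32; 4];
    // Pushing 3 values starting at index 2 wraps around to index 0.
    push_circular(&mut rewards, 4, 2, &[1.0, 2.0, 3.0]);
    assert_eq!(rewards, vec![3.0, 0.0, 1.0, 2.0]);
}
```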
@@ -96,23 +112,24 @@ where impl ExperienceBufferBase for SimpleReplayBuffer where - O: SubBatch, - A: SubBatch, + O: BatchBase, + A: BatchBase, { - type PushedItem = StdBatch; + type Item = GenericTransitionBatch; fn len(&self) -> usize { self.size } - fn push(&mut self, tr: Self::PushedItem) -> Result<()> { + fn push(&mut self, tr: Self::Item) -> Result<()> { let len = tr.len(); // batch size - let (obs, act, next_obs, reward, is_done, _, _) = tr.unpack(); - self.obs.push(self.i, &obs); - self.act.push(self.i, &act); - self.next_obs.push(self.i, &next_obs); + let (obs, act, next_obs, reward, is_terminated, is_truncated, _, _) = tr.unpack(); + self.obs.push(self.i, obs); + self.act.push(self.i, act); + self.next_obs.push(self.i, next_obs); self.push_reward(self.i, &reward); - self.push_is_done(self.i, &is_done); + self.push_is_terminated(self.i, &is_terminated); + self.push_is_truncated(self.i, &is_truncated); if self.per_state.is_some() { self.set_priority(len) @@ -128,14 +145,13 @@ where } } - impl ReplayBufferBase for SimpleReplayBuffer where - O: SubBatch, - A: SubBatch, + O: BatchBase, + A: BatchBase, { type Config = SimpleReplayBufferConfig; - type Batch = StdBatch; + type Batch = GenericTransitionBatch; fn build(config: &Self::Config) -> Self { let capacity = config.capacity; @@ -152,7 +168,8 @@ where act: A::new(capacity), next_obs: O::new(capacity), reward: vec![0.; capacity], - is_done: vec![0; capacity], + is_terminated: vec![0; capacity], + is_truncated: vec![0; capacity], // rng: Rng::with_seed(config.seed), rng: StdRng::seed_from_u64(config.seed as _), per_state, @@ -180,7 +197,8 @@ where act: self.act.sample(&ixs), next_obs: self.next_obs.sample(&ixs), reward: self.sample_reward(&ixs), - is_done: self.sample_is_done(&ixs), + is_terminated: self.sample_is_terminated(&ixs), + is_truncated: self.sample_is_truncated(&ixs), ix_sample: Some(ixs), weight, }) diff --git a/border-core/src/replay_buffer/base/iw_scheduler.rs b/border-core/src/generic_replay_buffer/base/iw_scheduler.rs similarity index 90% rename from border-core/src/replay_buffer/base/iw_scheduler.rs rename to border-core/src/generic_replay_buffer/base/iw_scheduler.rs index fe27331b..13c02ce0 100644 --- a/border-core/src/replay_buffer/base/iw_scheduler.rs +++ b/border-core/src/generic_replay_buffer/base/iw_scheduler.rs @@ -20,7 +20,12 @@ pub struct IwScheduler { impl IwScheduler { /// Creates a scheduler. pub fn new(beta_0: f32, beta_final: f32, n_opts_final: usize) -> Self { - Self { beta_0, beta_final, n_opts_final, n_opts: 0 } + Self { + beta_0, + beta_final, + n_opts_final, + n_opts: 0, + } } /// Gets the exponents of importance sampling weight. diff --git a/border-core/src/replay_buffer/base/sum_tree.rs b/border-core/src/generic_replay_buffer/base/sum_tree.rs similarity index 100% rename from border-core/src/replay_buffer/base/sum_tree.rs rename to border-core/src/generic_replay_buffer/base/sum_tree.rs diff --git a/border-core/src/replay_buffer/batch.rs b/border-core/src/generic_replay_buffer/batch.rs similarity index 51% rename from border-core/src/replay_buffer/batch.rs rename to border-core/src/generic_replay_buffer/batch.rs index 3d0a9476..e97a8adb 100644 --- a/border-core/src/replay_buffer/batch.rs +++ b/border-core/src/generic_replay_buffer/batch.rs @@ -1,12 +1,23 @@ -//! A generic implementation of [`StdBatchBase`](crate::StdBatchBase). -use super::SubBatch; -use crate::StdBatchBase; +//! A generic implementation of [`TransitionBatch`]. 
+use crate::TransitionBatch; -/// A generic implementation of [`StdBatchBase`](`crate::StdBatchBase`). -pub struct StdBatch +/// A generic implementation of a batch of items. +pub trait BatchBase { + /// Builds a subbatch with a capacity. + fn new(capacity: usize) -> Self; + + /// Pushes the samples in `data`. + fn push(&mut self, ix: usize, data: Self); + + /// Takes samples in the batch. + fn sample(&self, ixs: &Vec) -> Self; +} + +/// A generic implementation of [`TransitionBatch`](`crate::TransitionBatch`). +pub struct GenericTransitionBatch where - O: SubBatch, - A: SubBatch, + O: BatchBase, + A: BatchBase, { /// Observations. pub obs: O, @@ -20,8 +31,11 @@ where /// Rewards. pub reward: Vec, - /// Done flags. - pub is_done: Vec, + /// Termination flags. + pub is_terminated: Vec, + + /// Truncation flags. + pub is_truncated: Vec, /// Priority weights. pub weight: Option>, @@ -30,10 +44,10 @@ where pub ix_sample: Option>, } -impl StdBatchBase for StdBatch +impl TransitionBatch for GenericTransitionBatch where - O: SubBatch, - A: SubBatch, + O: BatchBase, + A: BatchBase, { type ObsBatch = O; type ActBatch = A; @@ -46,6 +60,7 @@ where Self::ObsBatch, Vec, Vec, + Vec, Option>, Option>, ) { @@ -54,7 +69,8 @@ where self.act, self.next_obs, self.reward, - self.is_done, + self.is_terminated, + self.is_truncated, self.ix_sample, self.weight, ) @@ -67,48 +83,12 @@ where fn obs(&self) -> &Self::ObsBatch { &self.obs } - - fn act(&self) -> &Self::ActBatch { - &self.act - } - - fn next_obs(&self) -> &Self::ObsBatch { - &self.next_obs - } - - fn reward(&self) -> &Vec { - &self.reward - } - - fn is_done(&self) -> &Vec { - &self.is_done - } - - fn weight(&self) -> &Option> { - &self.weight - } - - fn ix_sample(&self) -> &Option> { - &self.ix_sample - } - - fn empty() -> Self { - Self { - obs: O::new(0), - act: A::new(0), - next_obs: O::new(0), - reward: vec![], - is_done: vec![], - ix_sample: None, - weight: None, - } - } } -impl StdBatch +impl GenericTransitionBatch where - O: SubBatch, - A: SubBatch, + O: BatchBase, + A: BatchBase, { /// Creates new batch with the given capacity. pub fn with_capacity(capacity: usize) -> Self { @@ -117,9 +97,10 @@ where act: A::new(capacity), next_obs: O::new(capacity), reward: vec![0.0; capacity], - is_done: vec![0; capacity], + is_terminated: vec![0; capacity], + is_truncated: vec![0; capacity], ix_sample: None, weight: None, } } -} \ No newline at end of file +} diff --git a/border-core/src/replay_buffer/config.rs b/border-core/src/generic_replay_buffer/config.rs similarity index 89% rename from border-core/src/replay_buffer/config.rs rename to border-core/src/generic_replay_buffer/config.rs index 25c97774..cc32a190 100644 --- a/border-core/src/replay_buffer/config.rs +++ b/border-core/src/generic_replay_buffer/config.rs @@ -1,4 +1,5 @@ //! Configuration of [SimpleReplayBuffer](super::SimpleReplayBuffer). +use super::{WeightNormalizer, WeightNormalizer::All}; use anyhow::Result; use serde::{Deserialize, Serialize}; use std::{ @@ -7,20 +8,20 @@ use std::{ io::{BufReader, Write}, path::Path, }; -use super::{WeightNormalizer, WeightNormalizer::All}; /// Configuration for prioritized experience replay. #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct PerConfig { - pub(super) alpha: f32, + /// Exponent for prioritization. + pub alpha: f32, /// Initial value of $\beta$. - pub(super) beta_0: f32, + pub beta_0: f32, /// Final value of $\beta$. - pub(super) beta_final: f32, + pub beta_final: f32, /// Optimization step when beta reaches its final value. 
- pub(super) n_opts_final: usize, + pub n_opts_final: usize, /// How to normalize the weights. - pub(super) normalize: WeightNormalizer, + pub normalize: WeightNormalizer, } impl Default for PerConfig { @@ -70,9 +71,14 @@ impl PerConfig { /// Configuration of [SimpleReplayBuffer](super::SimpleReplayBuffer). #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct SimpleReplayBufferConfig { - pub(super) capacity: usize, - pub(super) seed: u64, - pub(super) per_config: Option, + /// Capacity of the buffer. + pub capacity: usize, + + /// Random seed for sampling. + pub seed: u64, + + /// Config for prioritized sampling. + pub per_config: Option, } impl Default for SimpleReplayBufferConfig { diff --git a/border-core/src/replay_buffer/step_proc.rs b/border-core/src/generic_replay_buffer/step_proc.rs similarity index 56% rename from border-core/src/replay_buffer/step_proc.rs rename to border-core/src/generic_replay_buffer/step_proc.rs index 70d78397..d7095c67 100644 --- a/border-core/src/replay_buffer/step_proc.rs +++ b/border-core/src/generic_replay_buffer/step_proc.rs @@ -1,12 +1,11 @@ -//! A generic implementation of [StepProcessorBase](crate::StepProcessorBase). +//! A generic implementation of [`StepProcessor`]. +use super::{BatchBase, GenericTransitionBatch}; +use crate::{Env, Obs, StepProcessor}; use std::{default::Default, marker::PhantomData}; -use crate::{Obs, Env, StepProcessorBase}; -use super::{StdBatch, SubBatch}; -/// Configuration of [SimpleStepProcessor]. +/// Configuration of [`SimpleStepProcessor`]. #[derive(Clone, Debug)] -pub struct SimpleStepProcessorConfig { -} +pub struct SimpleStepProcessorConfig {} impl Default for SimpleStepProcessorConfig { fn default() -> Self { @@ -14,28 +13,28 @@ impl Default for SimpleStepProcessorConfig { } } -/// A generic implementation of [StepProcessorBase](crate::StepProcessorBase). +/// A generic implementation of [`StepProcessor`]. /// /// It supports 1-step TD backup for non-vectorized environment: /// `E::Obs.len()` must be 1. pub struct SimpleStepProcessor { prev_obs: Option, - phantom: PhantomData<(E, A)> + phantom: PhantomData<(E, A)>, } -impl StepProcessorBase for SimpleStepProcessor +impl StepProcessor for SimpleStepProcessor where E: Env, - O: SubBatch + From, - A: SubBatch + From, + O: BatchBase + From, + A: BatchBase + From, { type Config = SimpleStepProcessorConfig; - type Output = StdBatch; + type Output = GenericTransitionBatch; fn build(_config: &Self::Config) -> Self { Self { prev_obs: None, - phantom: PhantomData + phantom: PhantomData, } } @@ -49,19 +48,30 @@ where let batch = if self.prev_obs.is_none() { panic!("prev_obs is not set. Forgot to call reset()?"); } else { + let is_done = step.is_done(); let next_obs = step.obs.clone().into(); let obs = self.prev_obs.replace(step.obs.into()).unwrap(); let act = step.act.into(); let reward = step.reward; - let is_done = step.is_done; + let is_terminated = step.is_terminated; + let is_truncated = step.is_truncated; let ix_sample = None; let weight = None; - if is_done[0] == 1 { + if is_done { self.prev_obs.replace(step.init_obs.into()); } - StdBatch {obs, act, next_obs, reward, is_done, ix_sample, weight} + GenericTransitionBatch { + obs, + act, + next_obs, + reward, + is_terminated, + is_truncated, + ix_sample, + weight, + } }; batch diff --git a/border-core/src/lib.rs b/border-core/src/lib.rs index 705666ab..8490fc14 100644 --- a/border-core/src/lib.rs +++ b/border-core/src/lib.rs @@ -4,26 +4,25 @@ //! # Observation and action //! //! 
[`Obs`] and [`Act`] traits are abstractions of observation and action in environments. -//! These traits can handle two or more samples for implementing vectorized environments. +//! These traits can handle two or more samples for implementing vectorized environments, +//! although there is currently no implementation of vectorized environment. //! //! # Environment //! //! [`Env`] trait is an abstraction of environments. It has four associated types: //! `Config`, `Obs`, `Act` and `Info`. `Obs` and `Act` are concrete types of //! observation and action of the environment. -//! These must implement [`Obs`] and [`Act`] traits, respectively. +//! These types must implement [`Obs`] and [`Act`] traits, respectively. //! The environment that implements [`Env`] generates [`Step`] object //! at every environment interaction step with [`Env::step()`] method. -//! -//! `Info` stores some information at every step of interactions of an agent and +//! [`Info`] stores some information at every step of interactions of an agent and //! the environment. It could be empty (zero-sized struct). `Config` represents //! configurations of the environment and is used to build. //! //! # Policy //! -//! [`Policy`] represents a policy, from which actions are sampled for -//! environment `E`. [`Policy::sample()`] takes `E::Obs` and emits `E::Act`. -//! It could be probabilistic or deterministic. +//! [`Policy`] represents a policy. [`Policy::sample()`] takes `E::Obs` and +//! generates `E::Act`. It could be probabilistic or deterministic. //! //! # Agent //! @@ -32,65 +31,293 @@ //! the agent's policy might be probabilistic for exploration, while in evaluation mode, //! the policy might be deterministic. //! -//! [`Agent::opt()`] method does a single optimization step. The definition of an -//! optimization step depends on each agent. It might be multiple stochastic gradient +//! The [`Agent::opt()`] method performs a single optimization step. The definition of an +//! optimization step varies for each agent. It might be multiple stochastic gradient //! steps in an optimization step. Samples for training are taken from //! [`R: ReplayBufferBase`][`ReplayBufferBase`]. //! -//! This trait also has methods for saving/loading the trained policy -//! in the given directory. +//! This trait also has methods for saving/loading parameters of the trained policy +//! in a directory. +//! +//! # Batch +//! +//! [`TransitionBatch`] is a trait of a batch of transitions `(o_t, r_t, a_t, o_t+1)`. +//! This trait is used to train [`Agent`]s using an RL algorithm. //! -//! # Replay buffer +//! # Replay buffer and experience buffer //! -//! [`ReplayBufferBase`] trait is an abstraction of replay buffers. For handling samples, -//! there are two associated types: `PushedItem` and `Batch`. `PushedItem` is a type -//! representing samples pushed to the buffer. These samples might be generated from -//! [`Step`]. [`StepProcessorBase`] trait provides the interface -//! for converting [`Step`] into `PushedItem`. +//! [`ReplayBufferBase`] trait is an abstraction of replay buffers. +//! One of the associated type [`ReplayBufferBase::Batch`] represents samples taken from +//! the buffer for training [`Agent`]s. Agents must implements [`Agent::opt()`] method, +//! where [`ReplayBufferBase::Batch`] has an appropriate type or trait bound(s) to train +//! the agent. //! -//! `Batch` is a type of samples taken from the buffer for training [`Agent`]s. -//! The user implements [`Agent::opt()`] method such that it handles `Batch` objects -//! 
for doing an optimization step. +//! As explained above, [`ReplayBufferBase`] trait has an ability to generates batches +//! of samples with which agents are trained. On the other hand, [`ExperienceBufferBase`] +//! trait has an ability to store samples. [`ExperienceBufferBase::push()`] is used to push +//! samples of type [`ExperienceBufferBase::Item`], which might be obtained via interaction +//! steps with an environment. //! //! ## A reference implementation //! -//! [`SimpleReplayBuffer`] implementats [`ReplayBufferBase`]. +//! [`SimpleReplayBuffer`] implementats both [`ReplayBufferBase`] and [`ExperienceBufferBase`]. //! This type has two parameters `O` and `A`, which are representation of //! observation and action in the replay buffer. `O` and `A` must implement -//! [`SubBatch`], which has the functionality of storing samples, like `Vec`, -//! for observation and action. The associated types `PushedItem` and `Batch` -//! are the same type, [`StdBatch`], representing sets of `(o_t, r_t, a_t, o_t+1)`. +//! [`BatchBase`], which has the functionality of storing samples, like `Vec`, +//! for observation and action. The associated types `Item` and `Batch` +//! are the same type, [`GenericTransitionBatch`], representing sets of `(o_t, r_t, a_t, o_t+1)`. //! //! [`SimpleStepProcessor`] might be used with [`SimpleReplayBuffer`]. -//! It converts `E::Obs` and `E::Act` into [`SubBatch`]s of respective types -//! and generates [`StdBatch`]. The conversion process relies on trait bounds, +//! It converts `E::Obs` and `E::Act` into [`BatchBase`]s of respective types +//! and generates [`GenericTransitionBatch`]. The conversion process relies on trait bounds, //! `O: From` and `A: From`. //! //! # Trainer //! //! [`Trainer`] manages training loop and related objects. The [`Trainer`] object is -//! built with configurations of [`Env`], [`ReplayBufferBase`], [`StepProcessorBase`] -//! and some training parameters. Then, [`Trainer::train`] method starts training loop with -//! given [`Agent`] and [`Recorder`](crate::record::Recorder). -//! +//! built with configurations of training parameters such as the maximum number of +//! optimization steps, model directory to save parameters of the agent during training, etc. +//! [`Trainer::train`] method executes online training of an agent on an environment. +//! In the training loop of this method, the agent interacts with the environment to +//! take samples and perform optimization steps. Some metrices are recorded at the same time. +//! +//! # Evaluator +//! +//! [`Evaluator`] is used to evaluate the policy's (`P`) performance in the environment (`E`). +//! The object of this type is given to the [`Trainer`] object to evaluate the policy during training. +//! [`DefaultEvaluator`] is a default implementation of [`Evaluator`]. +//! This evaluator runs the policy in the environment for a certain number of episodes. +//! At the start of each episode, the environment is reset using [`Env::reset_with_index()`] +//! to control specific conditions for evaluation. +//! //! [`SimpleReplayBuffer`]: replay_buffer::SimpleReplayBuffer -//! [`SimpleReplayBuffer`]: replay_buffer::SimpleReplayBuffer -//! [`SubBatch`]: replay_buffer::SubBatch -//! [`StdBatch`]: replay_buffer::StdBatch +//! [`SimpleReplayBuffer`]: generic_replay_buffer::SimpleReplayBuffer +//! [`BatchBase`]: generic_replay_buffer::BatchBase +//! [`GenericTransitionBatch`]: generic_replay_buffer::GenericTransitionBatch //! [`SimpleStepProcessor`]: replay_buffer::SimpleStepProcessor -//! 
[`SimpleStepProcessor`]: replay_buffer::SimpleStepProcessor +//! [`SimpleStepProcessor`]: generic_replay_buffer::SimpleStepProcessor pub mod error; mod evaluator; +pub mod generic_replay_buffer; pub mod record; -pub mod replay_buffer; -pub mod util; mod base; pub use base::{ - Act, Agent, Env, ExperienceBufferBase, Info, Obs, Policy, ReplayBufferBase, StdBatchBase, Step, - StepProcessorBase, + Act, Agent, Configurable, Env, ExperienceBufferBase, Info, Obs, Policy, ReplayBufferBase, Step, + StepProcessor, TransitionBatch, }; mod trainer; -pub use evaluator::{Evaluator, DefaultEvaluator}; -pub use trainer::{SyncSampler, Trainer, TrainerConfig}; +pub use evaluator::{DefaultEvaluator, Evaluator}; +pub use trainer::{Sampler, Trainer, TrainerConfig}; + +// TODO: Consider to compile this module only for tests. +/// Agent and Env for testing. +pub mod test { + use serde::{Deserialize, Serialize}; + + /// Obs for testing. + #[derive(Clone, Debug)] + pub struct TestObs { + obs: usize, + } + + impl crate::Obs for TestObs { + fn dummy(_n: usize) -> Self { + Self { obs: 0 } + } + + fn len(&self) -> usize { + 1 + } + } + + /// Batch of obs for testing. + pub struct TestObsBatch { + obs: Vec, + } + + impl crate::generic_replay_buffer::BatchBase for TestObsBatch { + fn new(capacity: usize) -> Self { + Self { + obs: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.obs[i] = data.obs[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let obs = ixs.iter().map(|ix| self.obs[*ix]).collect(); + Self { obs } + } + } + + impl From for TestObsBatch { + fn from(obs: TestObs) -> Self { + Self { obs: vec![obs.obs] } + } + } + + /// Act for testing. + #[derive(Clone, Debug)] + pub struct TestAct { + act: usize, + } + + impl crate::Act for TestAct {} + + /// Batch of act for testing. + pub struct TestActBatch { + act: Vec, + } + + impl From for TestActBatch { + fn from(act: TestAct) -> Self { + Self { act: vec![act.act] } + } + } + + impl crate::generic_replay_buffer::BatchBase for TestActBatch { + fn new(capacity: usize) -> Self { + Self { + act: vec![0; capacity], + } + } + + fn push(&mut self, i: usize, data: Self) { + self.act[i] = data.act[0]; + } + + fn sample(&self, ixs: &Vec) -> Self { + let act = ixs.iter().map(|ix| self.act[*ix]).collect(); + Self { act } + } + } + + /// Info for testing. + pub struct TestInfo {} + + impl crate::Info for TestInfo {} + + /// Environment for testing. 
+ pub struct TestEnv { + state_init: usize, + state: usize, + } + + impl crate::Env for TestEnv { + type Config = usize; + type Obs = TestObs; + type Act = TestAct; + type Info = TestInfo; + + fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { + self.state = self.state_init; + Ok(TestObs { obs: self.state }) + } + + fn step_with_reset( + &mut self, + a: &Self::Act, + ) -> (crate::Step, crate::record::Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = crate::Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: TestObs { + obs: self.state_init, + }, + }; + return (step, crate::record::Record::empty()); + } + + fn step(&mut self, a: &Self::Act) -> (crate::Step, crate::record::Record) + where + Self: Sized, + { + self.state = self.state + a.act; + let step = crate::Step { + obs: TestObs { obs: self.state }, + act: a.clone(), + reward: vec![0.0], + is_terminated: vec![0], + is_truncated: vec![0], + info: TestInfo {}, + init_obs: TestObs { + obs: self.state_init, + }, + }; + return (step, crate::record::Record::empty()); + } + + fn build(config: &Self::Config, _seed: i64) -> anyhow::Result + where + Self: Sized, + { + Ok(Self { + state_init: *config, + state: 0, + }) + } + } + + type ReplayBuffer = + crate::generic_replay_buffer::SimpleReplayBuffer; + + /// Agent for testing. + pub struct TestAgent {} + + #[derive(Clone, Deserialize, Serialize)] + /// Config of agent for testing. + pub struct TestAgentConfig; + + impl crate::Agent for TestAgent { + fn train(&mut self) {} + + fn is_train(&self) -> bool { + false + } + + fn eval(&mut self) {} + + fn opt_with_record(&mut self, _buffer: &mut ReplayBuffer) -> crate::record::Record { + crate::record::Record::empty() + } + + fn save_params>(&self, _path: T) -> anyhow::Result<()> { + Ok(()) + } + + fn load_params>(&mut self, _path: T) -> anyhow::Result<()> { + Ok(()) + } + } + + impl crate::Policy for TestAgent { + fn sample(&mut self, _obs: &TestObs) -> TestAct { + TestAct { act: 1 } + } + } + + impl crate::Configurable for TestAgent { + type Config = TestAgentConfig; + + fn build(_config: Self::Config) -> Self { + Self {} + } + } +} diff --git a/border-core/src/record.rs b/border-core/src/record.rs index 8e941953..d479cacd 100644 --- a/border-core/src/record.rs +++ b/border-core/src/record.rs @@ -1,6 +1,6 @@ //! Types for recording various values obtained during training and evaluation. //! -//! [Record] is a [HashMap], where its key and values represents various values obtained during training and +//! [`Record`] is a [`HashMap`], where its key and values represents various values obtained during training and //! evaluation. A record may contains multiple types of values. //! //! ```no_run @@ -19,194 +19,17 @@ //! //! A typical usecase is to record internal values obtained in training processes. //! [Trainer::train](crate::Trainer::train), which executes a training loop, writes a record -//! in a [Recorder] given as an input argument. +//! in a [`Recorder`] given as an input argument. //! 
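To make the `Record` usage described above concrete, here is a small sketch of the `Record`/`RecordValue` API (assuming the usual `border_core::record` re-exports; keys and values are arbitrary):

```rust
use border_core::record::{Record, RecordValue::{Scalar, String as RecString}};

fn main() {
    // Build a record from key-value pairs of mixed value types...
    let mut record = Record::from_slice(&[
        ("loss", Scalar(0.42)),
        ("phase", RecString("train".to_string())),
    ]);

    // ...insert additional entries...
    record.insert("opt_steps", Scalar(1000.0));
    assert!(!record.is_empty());

    // ...and read typed values back.
    assert_eq!(record.get_scalar("loss").unwrap(), 0.42);
    assert_eq!(record.get_string("phase").unwrap(), "train");
}
```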
-use chrono::prelude::{DateTime, Local}; -use std::{ - collections::{ - hash_map::{IntoIter, Iter, Keys}, - HashMap, - }, - convert::Into, - iter::IntoIterator, -}; - -use crate::error::LrrError; - -#[derive(Debug, Clone)] -/// Represents possible types of values in a [`Record`]. -pub enum RecordValue { - /// Represents a scalar, e.g., optimization steps and loss value. - Scalar(f32), - - /// Represents a datetime. - DateTime(DateTime), - - /// A 1-dimensional array - Array1(Vec), - - /// A 2-dimensional array - Array2(Vec, [usize; 2]), - - /// A 3-dimensional array - Array3(Vec, [usize; 3]), - - /// String - String(String), -} - -#[derive(Debug)] -/// Represents a record. -pub struct Record(HashMap); - -impl Record { - /// Construct empty record. - pub fn empty() -> Self { - Self { 0: HashMap::new() } - } - - /// Create `Record` from slice of `(Into, RecordValue)`. - pub fn from_slice + Clone>(s: &[(K, RecordValue)]) -> Self { - Self( - s.iter() - .map(|(k, v)| (k.clone().into(), v.clone())) - .collect(), - ) - } - - /// Get keys. - pub fn keys(&self) -> Keys { - self.0.keys() - } - - /// Insert a key-value pair into the record. - pub fn insert(&mut self, k: impl Into, v: RecordValue) { - self.0.insert(k.into(), v); - } - - /// Return an iterator over key-value pairs in the record. - pub fn iter(&self) -> Iter<'_, String, RecordValue> { - self.0.iter() - } - - /// Return an iterator over key-value pairs in the record. - pub fn into_iter_in_record(self) -> IntoIter { - self.0.into_iter() - } - - /// Get the value of the given key. - pub fn get(&self, k: &str) -> Option<&RecordValue> { - self.0.get(k) - } - - /// Merge records. - pub fn merge(self, record: Record) -> Self { - Record(self.0.into_iter().chain(record.0).collect()) - } - - /// Get scalar value. - /// - /// * `key` - The key of an entry in the record. - pub fn get_scalar(&self, k: &str) -> Result { - if let Some(v) = self.0.get(k) { - match v { - RecordValue::Scalar(v) => Ok(*v as _), - _ => Err(LrrError::RecordValueTypeError("Scalar".to_string())), - } - } else { - Err(LrrError::RecordKeyError(k.to_string())) - } - } - - /// Get Array1 value. - pub fn get_array1(&self, k: &str) -> Result, LrrError> { - if let Some(v) = self.0.get(k) { - match v { - RecordValue::Array1(v) => Ok(v.clone()), - _ => Err(LrrError::RecordValueTypeError("Array1".to_string())), - } - } else { - Err(LrrError::RecordKeyError(k.to_string())) - } - } - - /// Get Array2 value. - pub fn get_array2(&self, k: &str) -> Result<(Vec, [usize; 2]), LrrError> { - if let Some(v) = self.0.get(k) { - match v { - RecordValue::Array2(v, s) => Ok((v.clone(), s.clone())), - _ => Err(LrrError::RecordValueTypeError("Array2".to_string())), - } - } else { - Err(LrrError::RecordKeyError(k.to_string())) - } - } - - /// Get Array3 value. - pub fn get_array3(&self, k: &str) -> Result<(Vec, [usize; 3]), LrrError> { - if let Some(v) = self.0.get(k) { - match v { - RecordValue::Array3(v, s) => Ok((v.clone(), s.clone())), - _ => Err(LrrError::RecordValueTypeError("Array3".to_string())), - } - } else { - Err(LrrError::RecordKeyError(k.to_string())) - } - } - - /// Get String value. - pub fn get_string(&self, k: &str) -> Result { - if let Some(v) = self.0.get(k) { - match v { - RecordValue::String(s) => Ok(s.clone()), - _ => Err(LrrError::RecordValueTypeError("String".to_string())), - } - } else { - Err(LrrError::RecordKeyError(k.to_string())) - } - } -} - -/// Process records provided with [`Recorder::write`] -pub trait Recorder { - /// Write a record to the [`Recorder`]. 
- fn write(&mut self, record: Record); -} - -/// A recorder that ignores any record. This struct is used just for debugging. -pub struct NullRecorder {} - -impl NullRecorder {} - -impl Recorder for NullRecorder { - /// Discard the given record. - fn write(&mut self, _record: Record) {} -} - -/// Buffered recorder. -/// -/// This is used for recording sequences of observation and action -/// during evaluation runs in [`crate::util::eval_with_recorder`]. -#[derive(Default)] -pub struct BufferedRecorder(Vec); - -impl BufferedRecorder { - /// Construct the recorder. - pub fn new() -> Self { - Self(Vec::default()) - } - - /// Returns an iterator over the records. - pub fn iter(&self) -> std::slice::Iter { - self.0.iter() - } -} - -impl Recorder for BufferedRecorder { - /// Write a [`Record`] to the buffer. - /// - /// TODO: Consider if it is worth making the method public. - fn write(&mut self, record: Record) { - self.0.push(record); - } -} +//! [`HashMap`]: std::collections::HashMap +mod base; +mod buffered_recorder; +mod null_recorder; +mod recorder; +mod storage; + +pub use base::{Record, RecordValue}; +pub use buffered_recorder::BufferedRecorder; +pub use null_recorder::NullRecorder; +pub use recorder::{AggregateRecorder, Recorder}; +pub use storage::RecordStorage; diff --git a/border-core/src/record/base.rs b/border-core/src/record/base.rs new file mode 100644 index 00000000..78744576 --- /dev/null +++ b/border-core/src/record/base.rs @@ -0,0 +1,149 @@ +use crate::error::LrrError; +use chrono::prelude::{DateTime, Local}; +use std::{ + collections::{ + hash_map::{IntoIter, Iter, Keys}, + HashMap, + }, + convert::Into, + iter::IntoIterator, +}; + +#[derive(Debug, Clone)] +/// Represents possible types of values in a [`Record`]. +pub enum RecordValue { + /// Represents a scalar, e.g., optimization steps and loss value. + Scalar(f32), + + /// Represents a datetime. + DateTime(DateTime), + + /// A 1-dimensional array + Array1(Vec), + + /// A 2-dimensional array + Array2(Vec, [usize; 2]), + + /// A 3-dimensional array + Array3(Vec, [usize; 3]), + + /// String + String(String), +} + +#[derive(Debug)] +/// Represents a record. +pub struct Record(HashMap); + +impl Record { + /// Construct empty record. + pub fn empty() -> Self { + Self { 0: HashMap::new() } + } + + /// Create `Record` from slice of `(Into, RecordValue)`. + pub fn from_slice + Clone>(s: &[(K, RecordValue)]) -> Self { + Self( + s.iter() + .map(|(k, v)| (k.clone().into(), v.clone())) + .collect(), + ) + } + + /// Get keys. + pub fn keys(&self) -> Keys { + self.0.keys() + } + + /// Insert a key-value pair into the record. + pub fn insert(&mut self, k: impl Into, v: RecordValue) { + self.0.insert(k.into(), v); + } + + /// Return an iterator over key-value pairs in the record. + pub fn iter(&self) -> Iter<'_, String, RecordValue> { + self.0.iter() + } + + /// Return an iterator over key-value pairs in the record. + pub fn into_iter_in_record(self) -> IntoIter { + self.0.into_iter() + } + + /// Get the value of the given key. + pub fn get(&self, k: &str) -> Option<&RecordValue> { + self.0.get(k) + } + + /// Merge records. + pub fn merge(self, record: Record) -> Self { + Record(self.0.into_iter().chain(record.0).collect()) + } + + /// Get scalar value. + /// + /// * `key` - The key of an entry in the record. 
+ pub fn get_scalar(&self, k: &str) -> Result { + if let Some(v) = self.0.get(k) { + match v { + RecordValue::Scalar(v) => Ok(*v as _), + _ => Err(LrrError::RecordValueTypeError("Scalar".to_string())), + } + } else { + Err(LrrError::RecordKeyError(k.to_string())) + } + } + + /// Get Array1 value. + pub fn get_array1(&self, k: &str) -> Result, LrrError> { + if let Some(v) = self.0.get(k) { + match v { + RecordValue::Array1(v) => Ok(v.clone()), + _ => Err(LrrError::RecordValueTypeError("Array1".to_string())), + } + } else { + Err(LrrError::RecordKeyError(k.to_string())) + } + } + + /// Get Array2 value. + pub fn get_array2(&self, k: &str) -> Result<(Vec, [usize; 2]), LrrError> { + if let Some(v) = self.0.get(k) { + match v { + RecordValue::Array2(v, s) => Ok((v.clone(), s.clone())), + _ => Err(LrrError::RecordValueTypeError("Array2".to_string())), + } + } else { + Err(LrrError::RecordKeyError(k.to_string())) + } + } + + /// Get Array3 value. + pub fn get_array3(&self, k: &str) -> Result<(Vec, [usize; 3]), LrrError> { + if let Some(v) = self.0.get(k) { + match v { + RecordValue::Array3(v, s) => Ok((v.clone(), s.clone())), + _ => Err(LrrError::RecordValueTypeError("Array3".to_string())), + } + } else { + Err(LrrError::RecordKeyError(k.to_string())) + } + } + + /// Get String value. + pub fn get_string(&self, k: &str) -> Result { + if let Some(v) = self.0.get(k) { + match v { + RecordValue::String(s) => Ok(s.clone()), + _ => Err(LrrError::RecordValueTypeError("String".to_string())), + } + } else { + Err(LrrError::RecordKeyError(k.to_string())) + } + } + + /// Returns true if the record is empty. + pub fn is_empty(&self) -> bool { + self.0.len() == 0 + } +} diff --git a/border-core/src/record/buffered_recorder.rs b/border-core/src/record/buffered_recorder.rs new file mode 100644 index 00000000..cfe10486 --- /dev/null +++ b/border-core/src/record/buffered_recorder.rs @@ -0,0 +1,29 @@ +use super::{Record, Recorder}; + +/// Buffered recorder. +/// +/// This is used for recording sequences of observation and action +/// during evaluation runs. +#[derive(Default)] +pub struct BufferedRecorder(Vec); + +impl BufferedRecorder { + /// Construct the recorder. + pub fn new() -> Self { + Self(Vec::default()) + } + + /// Returns an iterator over the records. + pub fn iter(&self) -> std::slice::Iter { + self.0.iter() + } +} + +impl Recorder for BufferedRecorder { + /// Write a [`Record`] to the buffer. + /// + /// TODO: Consider if it is worth making the method public. + fn write(&mut self, record: Record) { + self.0.push(record); + } +} diff --git a/border-core/src/record/null_recorder.rs b/border-core/src/record/null_recorder.rs new file mode 100644 index 00000000..dabe4ed8 --- /dev/null +++ b/border-core/src/record/null_recorder.rs @@ -0,0 +1,17 @@ +use super::{AggregateRecorder, Record, Recorder}; + +/// A recorder that ignores any record. This struct is used just for debugging. +pub struct NullRecorder {} + +impl NullRecorder {} + +impl Recorder for NullRecorder { + /// Discard the given record. + fn write(&mut self, _record: Record) {} +} + +impl AggregateRecorder for NullRecorder { + fn store(&mut self, _record: Record) {} + + fn flush(&mut self, _step: i64) {} +} diff --git a/border-core/src/record/recorder.rs b/border-core/src/record/recorder.rs new file mode 100644 index 00000000..13ab4e57 --- /dev/null +++ b/border-core/src/record/recorder.rs @@ -0,0 +1,16 @@ +use super::Record; + +/// Writes a record to an output destination with [`Recorder::write`]. 
+pub trait Recorder { + /// Write a record to the [`Recorder`]. + fn write(&mut self, record: Record); +} + +/// Stores records, then aggregates them and writes to an output destination. +pub trait AggregateRecorder { + /// Store the record. + fn store(&mut self, record: Record); + + /// Writes values aggregated from the stored records. + fn flush(&mut self, step: i64); +} diff --git a/border-core/src/record/storage.rs b/border-core/src/record/storage.rs new file mode 100644 index 00000000..039fa66f --- /dev/null +++ b/border-core/src/record/storage.rs @@ -0,0 +1,177 @@ +use super::{Record, RecordValue}; +use std::collections::HashSet; +use xxhash_rust::xxh3::Xxh3Builder; + +/// Store records and aggregates them. +pub struct RecordStorage { + data: Vec, +} + +fn min(vs: &Vec) -> RecordValue { + RecordValue::Scalar(*vs.iter().min_by(|x, y| x.total_cmp(y)).unwrap()) +} + +fn max(vs: &Vec) -> RecordValue { + RecordValue::Scalar(*vs.iter().min_by(|x, y| y.total_cmp(x)).unwrap()) +} + +fn mean(vs: &Vec) -> RecordValue { + RecordValue::Scalar(vs.iter().map(|v| *v).sum::() / vs.len() as f32) +} + +fn median(mut vs: Vec) -> RecordValue { + vs.sort_by(|x, y| x.partial_cmp(y).unwrap()); + RecordValue::Scalar(vs[vs.len() / 2]) +} + +impl RecordStorage { + fn get_keys(&self) -> HashSet { + let mut keys = HashSet::::default(); + for record in self.data.iter() { + for k in record.keys() { + keys.insert(k.clone()); + } + } + keys + } + + /// Returns a reference to a value having given key. + fn find(&self, key: &String) -> &RecordValue { + for record in self.data.iter() { + if let Some(value) = record.get(key) { + return value; + } + } + panic!("Key '{}' was not found. ", key); + } + + // Takes value from the last record. + fn datetime(&self, key: &String) -> Record { + for record in self.data.iter().rev() { + if let Some(value) = record.get(key) { + match value { + RecordValue::DateTime(..) => { + return Record::from_slice(&[(key, value.clone())]); + } + _ => panic!("Expect RecordValue::DateTime for {}", key), + } + } + } + panic!("Unexpected"); + } + + // Takes value from the last record. + fn array1(&self, key: &String) -> Record { + for record in self.data.iter().rev() { + if let Some(value) = record.get(key) { + match value { + RecordValue::Array1(..) => { + return Record::from_slice(&[(key, value.clone())]); + } + _ => panic!("Expect RecordValue::Array1 for {}", key), + } + } + } + panic!("Unexpected"); + } + + // Takes value from the last record. + fn array2(&self, key: &String) -> Record { + for record in self.data.iter().rev() { + if let Some(value) = record.get(key) { + match value { + RecordValue::Array2(..) => { + return Record::from_slice(&[(key, value.clone())]); + } + _ => panic!("Expect RecordValue::Array2 for {}", key), + } + } + } + panic!("Unexpected"); + } + + // Takes value from the last record. + fn array3(&self, key: &String) -> Record { + for record in self.data.iter().rev() { + if let Some(value) = record.get(key) { + match value { + RecordValue::Array3(..) => { + return Record::from_slice(&[(key, value.clone())]); + } + _ => panic!("Expect RecordValue::Array3 for {}", key), + } + } + } + panic!("Unexpected"); + } + + // Takes value from the last record. + fn string(&self, key: &String) -> Record { + for record in self.data.iter().rev() { + if let Some(value) = record.get(key) { + match value { + RecordValue::String(..) 
=> { + return Record::from_slice(&[(key, value.clone())]); + } + _ => panic!("Expect RecordValue::String for {}", key), + } + } + } + panic!("Unexpected"); + } + + // Mean, Median, Min, Max + fn scalar(&self, key: &String) -> Record { + let vs: Vec = self + .data + .iter() + .filter_map(|record| match record.get(key) { + Some(v) => match v { + RecordValue::Scalar(v) => Some(*v), + _ => panic!("Expect RecordValue::Scalar for {}", key), + }, + None => None, + }) + .collect(); + + Record::from_slice(&[ + (format!("{}_min", key), min(&vs)), + (format!("{}_max", key), max(&vs)), + (format!("{}_mean", key), mean(&vs)), + (format!("{}_median", key), median(vs)), + ]) + } + + /// Creates the storage. + pub fn new() -> Self { + Self { data: vec![] } + } + + /// Store the given record. + pub fn store(&mut self, record: Record) { + self.data.push(record); + } + + /// Returns aggregated record and clear the storage. + pub fn aggregate(&mut self) -> Record { + let mut record = Record::empty(); + + for key in self.get_keys().iter() { + let value = self.find(key); + let r = match value { + RecordValue::DateTime(..) => self.datetime(key), + RecordValue::Array1(..) => self.array1(key), + RecordValue::Array2(..) => self.array2(key), + RecordValue::Array3(..) => self.array3(key), + RecordValue::String(..) => self.string(key), + RecordValue::Scalar(..) => self.scalar(key), + }; + // record = record.merge(r); + record = record.merge(r); + } + + self.data = vec![]; + + record + } +} diff --git a/border-core/src/replay_buffer.rs b/border-core/src/replay_buffer.rs deleted file mode 100644 index e73db710..00000000 --- a/border-core/src/replay_buffer.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! A generic implementation of replay buffer. -mod base; -mod batch; -mod config; -mod subbatch; -mod step_proc; -pub use base::{SimpleReplayBuffer, WeightNormalizer, IwScheduler}; -pub use batch::StdBatch; -pub use config::{SimpleReplayBufferConfig, PerConfig}; -pub use subbatch::SubBatch; -pub use step_proc::{SimpleStepProcessor, SimpleStepProcessorConfig}; diff --git a/border-core/src/replay_buffer/subbatch.rs b/border-core/src/replay_buffer/subbatch.rs deleted file mode 100644 index 0fcfd321..00000000 --- a/border-core/src/replay_buffer/subbatch.rs +++ /dev/null @@ -1,13 +0,0 @@ -//! SubBatch, which consists [`StdBatchBase`](`crate::StdBatchBase`). - -/// Represents a SubBatch, which consists [`StdBatchBase`](`crate::StdBatchBase`). -pub trait SubBatch { - /// Builds a subbatch with a capacity. - fn new(capacity: usize) -> Self; - - /// Pushes the samples in `data`. - fn push(&mut self, i: usize, data: &Self); - - /// Takes samples in the batch. - fn sample(&self, ixs: &Vec) -> Self; -} diff --git a/border-core/src/trainer.rs b/border-core/src/trainer.rs index eb3ae14d..8d7ffdda 100644 --- a/border-core/src/trainer.rs +++ b/border-core/src/trainer.rs @@ -1,14 +1,16 @@ -//! Train [`Agent`](crate::Agent). +//! Train [`Agent`]. mod config; mod sampler; +use std::time::{Duration, SystemTime}; + use crate::{ - record::{Record, Recorder}, - Agent, Env, ReplayBufferBase, StepProcessorBase, Evaluator, + record::{AggregateRecorder, Record, RecordValue::Scalar}, + Agent, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, StepProcessor, }; use anyhow::Result; pub use config::TrainerConfig; use log::info; -pub use sampler::SyncSampler; +pub use sampler::Sampler; #[cfg_attr(doc, aquamarine::aquamarine)] /// Manages training loop and related objects. @@ -19,7 +21,7 @@ pub use sampler::SyncSampler; /// /// 0. 
Given an agent implementing [`Agent`] and a recorder implementing [`Recorder`]. /// 1. Initialize the objects used in the training loop, involving instances of [`Env`], -/// [`StepProcessorBase`], [`SyncSampler`]. +/// [`StepProcessor`], [`Sampler`]. /// * Reset a counter of the environment steps: `env_steps = 0` /// * Reset a counter of the optimization steps: `opt_steps = 0` /// * Reset objects for computing optimization steps per sec (OSPS): @@ -69,10 +71,10 @@ pub use sampler::SyncSampler; /// referred to as an *environment step*. /// * Next, [`Step`] will be created with the next observation `o_t+1`, /// reward `r_t`, and `a_t`. -/// * The [`Step`] object will be processed by [`StepProcessorBase`] and -/// creates [`ReplayBufferBase::PushedItem`], typically representing a transition +/// * The [`Step`] object will be processed by [`StepProcessor`] and +/// creates [`ReplayBufferBase::Item`], typically representing a transition /// `(o_t, a_t, o_t+1, r_t)`, where `o_t` is kept in the -/// [`StepProcessorBase`], while other items in the given [`Step`]. +/// [`StepProcessor`], while other items in the given [`Step`]. /// * Finally, the transitions pushed to the [`ReplayBufferBase`] will be used to create /// batches, each of which implementing [`BatchBase`]. These batches will be used in /// *optimization step*s, where the agent updates its parameters using sampled @@ -82,29 +84,22 @@ pub use sampler::SyncSampler; /// [`Act`]: crate::Act /// [`BatchBase`]: crate::BatchBase /// [`Step`]: crate::Step -pub struct Trainer -where - E: Env, - P: StepProcessorBase, - R: ReplayBufferBase, -{ - /// Configuration of the environment for training. - env_config_train: E::Config, - - /// Configuration of the transition producer. - step_proc_config: P::Config, - - /// Configuration of the replay buffer. - replay_buffer_config: R::Config, - +pub struct Trainer { /// Where to save the trained model. model_dir: Option, /// Interval of optimization in environment steps. + /// This is ignored for offline training. opt_interval: usize, - /// Interval of recording in optimization steps. - record_interval: usize, + /// Interval of recording computational cost in optimization steps. + record_compute_cost_interval: usize, + + /// Interval of recording agent information in optimization steps. + record_agent_info_interval: usize, + + /// Interval of flushing records in optimization steps. + flush_records_interval: usize, /// Interval of evaluation in optimization steps. eval_interval: usize, @@ -114,187 +109,273 @@ where /// The maximal number of optimization steps. max_opts: usize, + + /// Optimization steps for computing optimization steps per second. + opt_steps_for_ops: usize, + + /// Timer for computing for optimization steps per second. + timer_for_ops: Duration, + + /// Warmup period, for filling replay buffer, in environment steps. + /// This is ignored for offline training. + warmup_period: usize, + + /// Max value of evaluation reward. + max_eval_reward: f32, + + /// Environment steps during online training. + env_steps: usize, + + /// Optimization steps during training. + opt_steps: usize, } -impl Trainer -where - E: Env, - P: StepProcessorBase, - R: ReplayBufferBase, -{ +impl Trainer { /// Constructs a trainer. 
- pub fn build( - config: TrainerConfig, - env_config_train: E::Config, - step_proc_config: P::Config, - replay_buffer_config: R::Config, - ) -> Self { + pub fn build(config: TrainerConfig) -> Self { Self { - env_config_train, - step_proc_config, - replay_buffer_config, model_dir: config.model_dir, opt_interval: config.opt_interval, - record_interval: config.record_interval, + record_compute_cost_interval: config.record_compute_cost_interval, + record_agent_info_interval: config.record_agent_info_interval, + flush_records_interval: config.flush_record_interval, eval_interval: config.eval_interval, save_interval: config.save_interval, max_opts: config.max_opts, + warmup_period: config.warmup_period, + opt_steps_for_ops: 0, + timer_for_ops: Duration::new(0, 0), + max_eval_reward: f32::MIN, + env_steps: 0, + opt_steps: 0, } } - fn save_model>(agent: &A, model_dir: String) { - match agent.save(&model_dir) { + fn save_model(agent: &A, model_dir: String) + where + E: Env, + A: Agent, + R: ReplayBufferBase, + { + match agent.save_params(&model_dir) { Ok(()) => info!("Saved the model in {:?}.", &model_dir), Err(_) => info!("Failed to save model in {:?}.", &model_dir), } } - fn save_best_model>(agent: &A, model_dir: String) { + fn save_best_model(agent: &A, model_dir: String) + where + E: Env, + A: Agent, + R: ReplayBufferBase, + { let model_dir = model_dir + "/best"; Self::save_model(agent, model_dir); } - fn save_model_with_steps>(agent: &A, model_dir: String, steps: usize) { + fn save_model_with_steps(agent: &A, model_dir: String, steps: usize) + where + E: Env, + A: Agent, + R: ReplayBufferBase, + { let model_dir = model_dir + format!("/{}", steps).as_str(); Self::save_model(agent, model_dir); } - // /// Run episodes with the given agent and returns the average of cumulative reward. - // fn evaluate(&mut self, agent: &mut A) -> Result - // where - // A: Agent, - // { - // agent.eval(); - - // let env_config = if self.env_config_eval.is_none() { - // &self.env_config_train - // } else { - // &self.env_config_eval.as_ref().unwrap() - // }; - - // let mut env = E::build(env_config, 0)?; // TODO use eval_env_config - // let mut r_total = 0f32; - - // for ix in 0..self.eval_episodes { - // let mut prev_obs = env.reset_with_index(ix)?; - // assert_eq!(prev_obs.len(), 1); // env must be non-vectorized - - // loop { - // let act = agent.sample(&prev_obs); - // let (step, _) = env.step(&act); - // r_total += step.reward[0]; - // if step.is_done[0] == 1 { - // break; - // } - // prev_obs = step.obs; - // } - // } - - // agent.train(); - - // Ok(r_total / self.eval_episodes as f32) - // } + /// Returns optimization steps per second, then reset the internal counter. + fn opt_steps_per_sec(&mut self) -> f32 { + let osps = 1000. * self.opt_steps_for_ops as f32 / (self.timer_for_ops.as_millis() as f32); + self.opt_steps_for_ops = 0; + self.timer_for_ops = Duration::new(0, 0); + osps + } /// Performs a training step. - pub fn train_step>( - &self, + /// + /// First, it performes an environment step once and pushes a transition + /// into the given buffer with [`Sampler`]. Then, if the number of environment steps + /// reaches the optimization interval `opt_interval`, performes an optimization + /// step. + /// + /// The second return value in the tuple is if an optimization step is done (`true`). 
+ // pub fn train_step( + pub fn train_step(&mut self, agent: &mut A, buffer: &mut R) -> Result<(Record, bool)> + where + E: Env, + A: Agent, + R: ReplayBufferBase, + { + if self.env_steps < self.warmup_period { + Ok((Record::empty(), false)) + } else if self.env_steps % self.opt_interval != 0 { + // skip optimization step + Ok((Record::empty(), false)) + } else if (self.opt_steps + 1) % self.record_agent_info_interval == 0 { + // Do optimization step with record + let timer = SystemTime::now(); + let record_agent = agent.opt_with_record(buffer); + self.opt_steps += 1; + self.timer_for_ops += timer.elapsed()?; + self.opt_steps_for_ops += 1; + Ok((record_agent, true)) + } else { + // Do optimization step without record + let timer = SystemTime::now(); + agent.opt(buffer); + self.opt_steps += 1; + self.timer_for_ops += timer.elapsed()?; + self.opt_steps_for_ops += 1; + Ok((Record::empty(), true)) + } + } + + fn post_process( + &mut self, + agent: &mut A, + evaluator: &mut D, + record: &mut Record, + fps: f32, + ) -> Result<()> + where + E: Env, + A: Agent, + R: ReplayBufferBase, + D: Evaluator, + { + // Add stats wrt computation cost + if self.opt_steps % self.record_compute_cost_interval == 0 { + record.insert("fps", Scalar(fps)); + record.insert("opt_steps_per_sec", Scalar(self.opt_steps_per_sec())); + } + + // Evaluation + if self.opt_steps % self.eval_interval == 0 { + info!("Starts evaluation of the trained model"); + agent.eval(); + let eval_reward = evaluator.evaluate(agent)?; + agent.train(); + record.insert("eval_reward", Scalar(eval_reward)); + + // Save the best model up to the current iteration + if eval_reward > self.max_eval_reward { + self.max_eval_reward = eval_reward; + let model_dir = self.model_dir.as_ref().unwrap().clone(); + Self::save_best_model(agent, model_dir) + } + }; + + // Save the current model + if (self.save_interval > 0) && (self.opt_steps % self.save_interval == 0) { + let model_dir = self.model_dir.as_ref().unwrap().clone(); + Self::save_model_with_steps(agent, model_dir, self.opt_steps); + } + + Ok(()) + } + + fn loop_step( + &mut self, agent: &mut A, buffer: &mut R, - sampler: &mut SyncSampler, - env_steps: &mut usize, - ) -> Result> + recorder: &mut Box, + evaluator: &mut D, + record: Record, + fps: f32, + ) -> Result where + E: Env, A: Agent, + R: ReplayBufferBase, + D: Evaluator, { - // Sample transition(s) and push it into the replay buffer - let record_ = sampler.sample_and_push(agent, buffer)?; + // Performe optimization step(s) + let (mut record, is_opt) = { + let (r, is_opt) = self.train_step(agent, buffer)?; + (record.merge(r), is_opt) + }; + + // Postprocessing after each training step + if is_opt { + self.post_process(agent, evaluator, &mut record, fps)?; + + // End loop + if self.opt_steps == self.max_opts { + return Ok(true); + } + } - // Do optimization step - *env_steps += 1; + // Store record to the recorder + if !record.is_empty() { + recorder.store(record); + } - if *env_steps % self.opt_interval == 0 { - let record = agent.opt(buffer).map_or(None, |r| Some(record_.merge(r))); - Ok(record) - } else { - Ok(None) + // Flush records + if is_opt && ((self.opt_steps - 1) % self.flush_records_interval == 0) { + recorder.flush(self.opt_steps as _); } + + Ok(false) } - /// Train the agent. - pub fn train(&mut self, agent: &mut A, recorder: &mut S, evaluator: &mut D) -> Result<()> + /// Train the agent online. 
+ pub fn train( + &mut self, + env: E, + step_proc: P, + agent: &mut A, + buffer: &mut R, + recorder: &mut Box, + evaluator: &mut D, + ) -> Result<()> where + E: Env, A: Agent, - S: Recorder, + P: StepProcessor, + R: ExperienceBufferBase + ReplayBufferBase, D: Evaluator, { - let env = E::build(&self.env_config_train, 0)?; - let producer = P::build(&self.step_proc_config); - let mut buffer = R::build(&self.replay_buffer_config); - let mut sampler = SyncSampler::new(env, producer); - let mut max_eval_reward = f32::MIN; - let mut env_steps: usize = 0; - let mut opt_steps: usize = 0; - let mut opt_steps_ops: usize = 0; // optimizations per second - let mut timer = std::time::SystemTime::now(); - sampler.reset(); + let mut sampler = Sampler::new(env, step_proc); + sampler.reset_fps_counter(); agent.train(); loop { - let record = self.train_step(agent, &mut buffer, &mut sampler, &mut env_steps)?; - - // Postprocessing after each training step - if let Some(mut record) = record { - use crate::record::RecordValue::Scalar; - - opt_steps += 1; - opt_steps_ops += 1; - let do_eval = opt_steps % self.eval_interval == 0; - let do_rec = opt_steps % self.record_interval == 0; - - // Do evaluation - if do_eval { - let eval_reward = evaluator.evaluate(agent)?; - record.insert("eval_reward", Scalar(eval_reward)); - - // Save the best model up to the current iteration - if eval_reward > max_eval_reward { - max_eval_reward = eval_reward; - let model_dir = self.model_dir.as_ref().unwrap().clone(); - Self::save_best_model(agent, model_dir) - } - }; - - // Record - if do_rec { - record.insert("env_steps", Scalar(env_steps as f32)); - record.insert("fps", Scalar(sampler.fps())); - sampler.reset(); - let time = timer.elapsed()?.as_secs_f32(); - let osps = opt_steps_ops as f32 / time; - record.insert("opt_steps_per_sec", Scalar(osps)); - opt_steps_ops = 0; - timer = std::time::SystemTime::now(); - } - - // Flush record to the recorder - if do_eval || do_rec { - record.insert("opt_steps", Scalar(opt_steps as _)); - recorder.write(record); - } - - // Save the current model - if (self.save_interval > 0) && (opt_steps % self.save_interval == 0) { - let model_dir = self.model_dir.as_ref().unwrap().clone(); - Self::save_model_with_steps(agent, model_dir, opt_steps); - } - - // End loop - if opt_steps == self.max_opts { - break; - } + let record = sampler.sample_and_push(agent, buffer)?; + let fps = sampler.fps(); + self.env_steps += 1; + + if self.loop_step(agent, buffer, recorder, evaluator, record, fps)? { + return Ok(()); } } + } - Ok(()) + /// Train the agent offline. + pub fn train_offline( + &mut self, + agent: &mut A, + buffer: &mut R, + recorder: &mut Box, + evaluator: &mut D, + ) -> Result<()> + where + E: Env, + A: Agent, + R: ReplayBufferBase, + D: Evaluator, + { + // Return empty record + self.warmup_period = 0; + self.opt_interval = 1; + agent.train(); + let fps = 0f32; + + loop { + let record = Record::empty(); + + if self.loop_step(agent, buffer, recorder, evaluator, record, fps)? { + return Ok(()); + } + } } } diff --git a/border-core/src/trainer/config.rs b/border-core/src/trainer/config.rs index ea1742fa..26b4d3b4 100644 --- a/border-core/src/trainer/config.rs +++ b/border-core/src/trainer/config.rs @@ -10,13 +10,32 @@ use std::{ /// Configuration of [`Trainer`](super::Trainer). 
#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct TrainerConfig { - pub(super) max_opts: usize, - pub(super) eval_threshold: Option, - pub(super) model_dir: Option, - pub(super) opt_interval: usize, - pub(super) eval_interval: usize, - pub(super) record_interval: usize, - pub(super) save_interval: usize, + /// The maximum number of optimization steps. + pub max_opts: usize, + + /// Directory where model parameters will be saved. + pub model_dir: Option, + + /// Interval of optimization steps in environment steps. + pub opt_interval: usize, + + /// Interval of evaluation in optimization steps. + pub eval_interval: usize, + + /// Interval of flushing records in optimization steps. + pub flush_record_interval: usize, + + /// Interval of recording agent information in optimization steps. + pub record_compute_cost_interval: usize, + + /// Interval of recording agent information in optimization steps. + pub record_agent_info_interval: usize, + + /// Warmup period, for filling replay buffer, in environment steps + pub warmup_period: usize, + + /// Intercal of saving model parameters in optimization steps. + pub save_interval: usize, } impl Default for TrainerConfig { @@ -24,10 +43,13 @@ impl Default for TrainerConfig { Self { max_opts: 0, eval_interval: 0, - eval_threshold: None, + // eval_threshold: None, model_dir: None, opt_interval: 1, - record_interval: usize::MAX, + flush_record_interval: usize::MAX, + record_compute_cost_interval: usize::MAX, + record_agent_info_interval: usize::MAX, + warmup_period: 0, save_interval: usize::MAX, } } @@ -46,10 +68,11 @@ impl TrainerConfig { self } - /// Sets the evaluation threshold. - pub fn eval_threshold(mut self, v: f32) -> Self { - self.eval_threshold = Some(v); - self + /// (Deprecated) Sets the evaluation threshold. + pub fn eval_threshold(/*mut */ self, _v: f32) -> Self { + unimplemented!(); + // self.eval_threshold = Some(v); + // self } /// Sets the directory the trained model being saved. @@ -64,9 +87,27 @@ impl TrainerConfig { self } - /// Sets the interval of recording in optimization steps. - pub fn record_interval(mut self, record_interval: usize) -> Self { - self.record_interval = record_interval; + /// Sets the interval of flushing recordd in optimization steps. + pub fn flush_record_interval(mut self, flush_record_interval: usize) -> Self { + self.flush_record_interval = flush_record_interval; + self + } + + /// Sets the interval of computation cost in optimization steps. + pub fn record_compute_cost_interval(mut self, record_compute_cost_interval: usize) -> Self { + self.record_compute_cost_interval = record_compute_cost_interval; + self + } + + /// Sets the interval of recording agent information in optimization steps. + pub fn record_agent_info_interval(mut self, record_agent_info_interval: usize) -> Self { + self.record_agent_info_interval = record_agent_info_interval; + self + } + + /// Sets warmup period in environment steps. + pub fn warmup_period(mut self, warmup_period: usize) -> Self { + self.warmup_period = warmup_period; self } @@ -76,7 +117,7 @@ impl TrainerConfig { self } - /// Constructs [TrainerConfig] from YAML file. + /// Constructs [`TrainerConfig`] from YAML file. pub fn load(path: impl AsRef) -> Result { let file = File::open(path)?; let rdr = BufReader::new(file); @@ -84,7 +125,7 @@ impl TrainerConfig { Ok(b) } - /// Saves [TrainerConfig]. + /// Saves [`TrainerConfig`]. 
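With the fields now public and builder-style setters added for the new intervals, a configuration can be assembled fluently. An illustrative sketch using only setters visible in this diff, with arbitrary values; the `border_core::TrainerConfig` import path is assumed from the crate's re-exports, the argument type of `model_dir` is not shown in this hunk and is assumed to accept a string, and setters for the remaining fields (for example `max_opts`) are expected to follow the same pattern in the unchanged part of the file:

```rust
use border_core::TrainerConfig;

fn main() {
    let _config = TrainerConfig::default()
        .warmup_period(2_500)                // fill the replay buffer before the first optimization
        .flush_record_interval(1_000)        // flush aggregated records every 1000 optimization steps
        .record_compute_cost_interval(1_000) // log fps / opt_steps_per_sec every 1000 optimization steps
        .record_agent_info_interval(10_000)  // request agent info records every 10000 optimization steps
        .model_dir("./model");               // directory for saved model parameters
}
```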
pub fn save(&self, path: impl AsRef) -> Result<()> { let mut file = File::create(path)?; file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; diff --git a/border-core/src/trainer/sampler.rs b/border-core/src/trainer/sampler.rs index e67af5b6..6298fe29 100644 --- a/border-core/src/trainer/sampler.rs +++ b/border-core/src/trainer/sampler.rs @@ -1,97 +1,150 @@ //! Samples transitions and pushes them into a replay buffer. -use crate::{Env, Agent, ReplayBufferBase, StepProcessorBase, record::Record}; +use crate::{record::Record, Agent, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor}; use anyhow::Result; -/// Gets an [`Agent`] interacts with an [`Env`] and takes samples. +/// Encapsulates sampling steps. Specifically it does the followint steps: /// -/// TODO: Rename to `Sampler`. -pub struct SyncSampler +/// 1. Samples an action from the [`Agent`], apply to the [`Env`] and takes [`Step`]. +/// 2. Convert [`Step`] into a transition (typically a batch) with [`StepProcessor`]. +/// 3. Pushes the trainsition to [`ReplayBufferBase`]. +/// 4. Count episode length and pushes to [`Record`]. +/// +/// TODO: being able to set `interval_env_record` +/// +/// [`Step`]: crate::Step +/// [`StepProcessor`]: crate::StepProcessor +pub struct Sampler where E: Env, - P: StepProcessorBase, + P: StepProcessor, { env: E, prev_obs: Option, - producer: P, - n_frames: usize, + step_processor: P, + /// Number of environment steps for counting frames per second. + n_env_steps_for_fps: usize, + + /// Total time of takes n_frames. time: f32, + + /// Number of environment steps in an episode. + n_env_steps_in_episode: usize, + + /// Total number of environment steps. + n_env_steps_total: usize, + + /// Interval of recording from the environment in environment steps. + /// + /// Default to None (record from environment discarded) + interval_env_record: Option, } -impl SyncSampler +impl Sampler where E: Env, - P: StepProcessorBase, + P: StepProcessor, { /// Creates a sampler. - pub fn new(env: E, producer: P) -> Self { + pub fn new(env: E, step_processor: P) -> Self { Self { env, prev_obs: None, - producer, - n_frames: 0, + step_processor, + n_env_steps_for_fps: 0, time: 0f32, + n_env_steps_in_episode: 0, + n_env_steps_total: 0, + interval_env_record: None, } } /// Samples transitions and pushes them into the replay buffer. /// /// The replay buffer `R_`, to which samples will be pushed, has to accept - /// `PushedItem` that are the same with `Agent::R`. + /// `Item` that are the same with `Agent::R`. 
pub fn sample_and_push(&mut self, agent: &mut A, buffer: &mut R_) -> Result where A: Agent, - R: ReplayBufferBase, - R_: ReplayBufferBase, + R: ExperienceBufferBase + ReplayBufferBase, + R_: ExperienceBufferBase, { let now = std::time::SystemTime::now(); - + // Reset environment(s) if required if self.prev_obs.is_none() { // For a vectorized environments, reset all environments in `env` // by giving `None` to reset() method self.prev_obs = Some(self.env.reset(None)?); - self.producer.reset(self.prev_obs.as_ref().unwrap().clone()); + self.step_processor + .reset(self.prev_obs.as_ref().unwrap().clone()); } - // Sample action(s) and apply it to environment(s) - let act = agent.sample(self.prev_obs.as_ref().unwrap()); - let (step, record) = self.env.step_with_reset(&act); - let terminate_episode = step.is_done[0] == 1; // not support vectorized env + // Sample an action and apply it to the environment + let (step, mut record, is_done) = { + let act = agent.sample(self.prev_obs.as_ref().unwrap()); + let (step, mut record) = self.env.step_with_reset(&act); + self.n_env_steps_in_episode += 1; + self.n_env_steps_total += 1; + let is_done = step.is_done(); // not support vectorized env + if let Some(interval) = &self.interval_env_record { + if self.n_env_steps_total % interval != 0 { + record = Record::empty(); + } + } else { + record = Record::empty(); + } + (step, record, is_done) + }; // Update previouos observation - self.prev_obs = if terminate_episode { - Some(step.init_obs.clone()) - } else { - Some(step.obs.clone()) + self.prev_obs = match is_done { + true => Some(step.init_obs.clone()), + false => Some(step.obs.clone()), }; - // Create and push transition(s) - let transition = self.producer.process(step); + // Produce transition + let transition = self.step_processor.process(step); + + // Push transition buffer.push(transition)?; - // Reset producer - if terminate_episode { - self.producer.reset(self.prev_obs.as_ref().unwrap().clone()); + // Reset step processor + if is_done { + self.step_processor + .reset(self.prev_obs.as_ref().unwrap().clone()); + record.insert( + "episode_length", + crate::record::RecordValue::Scalar(self.n_env_steps_in_episode as _), + ); + self.n_env_steps_in_episode = 0; } - // For counting FPS + // Count environment steps if let Ok(time) = now.elapsed() { - self.n_frames += 1; + self.n_env_steps_for_fps += 1; self.time += time.as_millis() as f32; } Ok(record) } - /// Returns frames per second, including taking action, applying it to the environment, + /// Returns frames (environment steps) per second, then resets the internal counter. + /// + /// A frame involves taking action, applying it to the environment, /// producing transition, and pushing it into the replay buffer. - pub fn fps(&self) -> f32 { - self.n_frames as f32 / self.time * 1000f32 + pub fn fps(&mut self) -> f32 { + if self.time == 0f32 { + 0f32 + } else { + let fps = self.n_env_steps_for_fps as f32 / self.time * 1000f32; + self.reset_fps_counter(); + fps + } } /// Reset stats for computing FPS. - pub fn reset(&mut self) { - self.n_frames = 0; + pub fn reset_fps_counter(&mut self) { + self.n_env_steps_for_fps = 0; self.time = 0f32; } } diff --git a/border-core/src/util.rs b/border-core/src/util.rs index 404a6917..016fb361 100644 --- a/border-core/src/util.rs +++ b/border-core/src/util.rs @@ -2,7 +2,7 @@ //! Utilities for interaction of agents and environments. 
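Note that `fps()` now resets its own counters, so each returned value covers only the interval since the previous call, and 0 is returned before any time has been accumulated. A standalone sketch of the same bookkeeping; this `FpsCounter` is illustrative and not the `Sampler` itself:

```rust
// Accumulate elapsed milliseconds per environment step and convert to steps/second.
struct FpsCounter {
    n_steps: usize,
    time_ms: f32,
}

impl FpsCounter {
    fn new() -> Self {
        Self { n_steps: 0, time_ms: 0.0 }
    }

    fn record_step(&mut self, elapsed_ms: f32) {
        self.n_steps += 1;
        self.time_ms += elapsed_ms;
    }

    // Mirrors `Sampler::fps()`: 0 before any time is measured, otherwise
    // steps per second, resetting the counter afterwards.
    fn fps(&mut self) -> f32 {
        if self.time_ms == 0.0 {
            0.0
        } else {
            let fps = self.n_steps as f32 / self.time_ms * 1000.0;
            self.n_steps = 0;
            self.time_ms = 0.0;
            fps
        }
    }
}

fn main() {
    let mut c = FpsCounter::new();
    for _ in 0..100 {
        c.record_step(2.0); // 100 steps, 2 ms each
    }
    assert_eq!(c.fps(), 500.0); // 100 steps / 200 ms = 500 steps per second
    assert_eq!(c.fps(), 0.0);   // the counter has been reset
}
```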
use crate::{ record::{RecordValue, Recorder}, - Env, Policy + Env, Policy, }; use anyhow::Result; @@ -55,12 +55,12 @@ where /// Runs environment steps with a given policy and recorder. /// /// This function does not support vectorized environments. -/// +/// /// * `n_steps` - The maximum number of environment steps. /// The interaction loop is terminated when is_done is true before reaching `n_steps` environment steps. /// * `prev_obs` - The observation, applied to the policy at the first step of interaction. /// If `None`, `env.reset_with_index(0)` is invoked. -/// +/// /// This function returns the sum of rewards during interaction. #[deprecated] pub fn eval_with_recorder2( diff --git a/border-derive/Cargo.toml b/border-derive/Cargo.toml index 3c7ad2ef..70a4a40f 100644 --- a/border-derive/Cargo.toml +++ b/border-derive/Cargo.toml @@ -1,17 +1,13 @@ [package] name = "border-derive" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Derive macros for observation and action in RL environments of border" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" -# readme = "README.md" -autoexamples = false +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" [lib] proc-macro = true @@ -22,19 +18,17 @@ proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", features = ["full"] } tch = { workspace = true, optional = true } +candle-core = { workspace = true, optional = true } [dev-dependencies] -border-tch-agent = { version = "0.0.6", path = "../border-tch-agent" } -border-py-gym-env = { version = "0.0.6", path = "../border-py-gym-env" } -border-core = { version = "0.0.6", path = "../border-core" } +border-tch-agent = { version = "0.0.7", path = "../border-tch-agent" } +border-candle-agent = { version = "0.0.7", path = "../border-candle-agent" } +border-py-gym-env = { version = "0.0.7", path = "../border-py-gym-env" } +border-core = { version = "0.0.7", path = "../border-core" } +border-atari-env = { version = "0.0.7", path = "../border-atari-env" } ndarray = { workspace = true } tch = { workspace = true } - -[features] -default = ["tch"] - -[[example]] -name = "test1" +candle-core = { workspace = true } [package.metadata.docs.rs] features = ["doc-only"] diff --git a/border-derive/README.md b/border-derive/README.md new file mode 100644 index 00000000..e69de29b diff --git a/border-derive/examples/border_atari_act.rs b/border-derive/examples/border_atari_act.rs new file mode 100644 index 00000000..4f7b002a --- /dev/null +++ b/border-derive/examples/border_atari_act.rs @@ -0,0 +1,8 @@ +use border_atari_env::BorderAtariAct; +use border_derive::Act; + +#[allow(dead_code)] +#[derive(Clone, Debug, Act)] +struct MyAct(BorderAtariAct); + +fn main() {} diff --git a/border-derive/examples/border_gym_cont_act.rs b/border-derive/examples/border_gym_cont_act.rs new file mode 100644 index 00000000..9015aca0 --- /dev/null +++ b/border-derive/examples/border_gym_cont_act.rs @@ -0,0 +1,8 @@ +use border_derive::Act; +use border_py_gym_env::GymContinuousAct; + +#[allow(dead_code)] +#[derive(Clone, Debug, Act)] +struct MyAct(GymContinuousAct); + +fn main() {} diff --git a/border-derive/examples/border_gym_disc_act.rs b/border-derive/examples/border_gym_disc_act.rs new file mode 100644 index 00000000..05d0ea07 --- /dev/null +++ 
b/border-derive/examples/border_gym_disc_act.rs @@ -0,0 +1,8 @@ +use border_derive::Act; +use border_py_gym_env::GymDiscreteAct; + +#[allow(dead_code)] +#[derive(Clone, Debug, Act)] +struct MyAct(GymDiscreteAct); + +fn main() {} diff --git a/border-derive/examples/border_tensor_batch.rs b/border-derive/examples/border_tensor_batch.rs new file mode 100644 index 00000000..697e7b32 --- /dev/null +++ b/border-derive/examples/border_tensor_batch.rs @@ -0,0 +1,8 @@ +use border_derive::BatchBase; +use border_tch_agent::TensorBatch; + +#[allow(dead_code)] +#[derive(Clone, BatchBase)] +pub struct ObsBatch(TensorBatch); + +fn main() {} diff --git a/border-derive/examples/test1.rs b/border-derive/examples/test1.rs deleted file mode 100644 index 8c475aa7..00000000 --- a/border-derive/examples/test1.rs +++ /dev/null @@ -1,40 +0,0 @@ -use border_derive::{SubBatch, Act}; -use border_py_gym_env::GymDiscreteAct; -use std::convert::TryFrom; -use border_tch_agent::TensorSubBatch; -use ndarray::ArrayD; -use tch::Tensor; - -#[derive(Debug, Clone)] -struct Obs(ArrayD); - -#[derive(SubBatch)] -struct ObsBatch(TensorSubBatch); - -impl From for Tensor { - fn from(value: Obs) -> Self { - Tensor::try_from(&value.0).unwrap() - } -} - -impl From for ObsBatch { - fn from(obs: Obs) -> Self { - let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } -} - -#[derive(Clone, Debug, Act)] -struct Act(GymDiscreteAct); - -#[derive(SubBatch)] -struct ActBatch(TensorSubBatch); - -impl From for ActBatch { - fn from(act: Act) -> Self { - let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } -} - -fn main() {} diff --git a/border-derive/src/act.rs b/border-derive/src/act.rs index 0492bcc9..1ddf9ea5 100644 --- a/border-derive/src/act.rs +++ b/border-derive/src/act.rs @@ -56,7 +56,39 @@ fn py_gym_env_cont_act( .iter() .map(|x| *x as usize) .collect::>(); - let act: Vec = t.into(); + use std::convert::TryInto; + let act: Vec = t.try_into().unwrap(); + + let act = ndarray::Array1::::from(act).into_shape(ndarray::IxDyn(&shape)).unwrap(); + + #ident(GymContinuousAct::new(act)) + } + } + }.into_iter()); + + #[cfg(feature = "candle-core")] + output.extend(quote! { + impl From<#ident> for candle_core::Tensor { + fn from(act: #ident) -> candle_core::Tensor { + let v = act.0.act.iter().map(|e| *e as f32).collect::>(); + let n = v.len(); + let t = candle_core::Tensor::from_vec(v, &[n], &candle_core::Device::Cpu).unwrap(); + + // The first dimension of the action tensor is the number of processes, + // which is 1 for the non-vectorized environment. + t.unsqueeze(0).unwrap() + } + } + + impl From for #ident { + /// `t` must be a 1-dimentional tensor of `f32`. + fn from(t: candle_core::Tensor) -> Self { + // In non-vectorized environment, the batch dimension is not required, thus dropped. + let shape = t.size()[1..] + .iter() + .map(|x| *x as usize) + .collect::>(); + let act: Vec = t.to_vec1().unwrap(); let act = ndarray::Array1::::from(act).into_shape(ndarray::IxDyn(&shape)).unwrap(); @@ -76,57 +108,116 @@ fn py_gym_env_disc_act( let mut output = common(ident.clone(), field_type.clone()); #[cfg(feature = "tch")] - output.extend(quote! { - impl From<#ident> for tch::Tensor { - fn from(act: #ident) -> tch::Tensor { - let v = act.0.act.iter().map(|e| *e as i64).collect::>(); - let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); + output.extend( + quote! 
{ + impl From<#ident> for tch::Tensor { + fn from(act: #ident) -> tch::Tensor { + let v = act.0.act.iter().map(|e| *e as i64).collect::>(); + let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); + + // The first dimension is for the batch + t.unsqueeze(0) + } + } - // The first dimension is for the batch - t.unsqueeze(0) + impl From for #ident { + fn from(t: tch::Tensor) -> Self { + use std::convert::TryInto; + let data: Vec = t.try_into().unwrap(); + let data: Vec<_> = data.iter().map(|e| *e as i32).collect(); + #ident(GymDiscreteAct::new(data)) + } } } + .into_iter(), + ); - impl From for #ident { - fn from(t: tch::Tensor) -> Self { - let data: Vec = t.into(); - let data: Vec<_> = data.iter().map(|e| *e as i32).collect(); - #ident(GymDiscreteAct::new(data)) + #[cfg(feature = "candle-core")] + output.extend( + quote! { + impl From<#ident> for candle_core::Tensor { + fn from(act: #ident) -> candle_core::Tensor { + let v = act.0.act.iter().map(|e| *e as i64).collect::>(); + let n = v.len(); + let t = candle_core::Tensor::from_vec(v, &[n], &candle_core::Device::Cpu).unwrap(); + + // The first dimension is for the batch + t.unsqueeze(0).unwrap() + } + } + + impl From for #ident { + fn from(t: candle_core::Tensor) -> Self { + let data: Vec = t.to_vec1().unwrap(); + let data: Vec<_> = data.iter().map(|e| *e as i32).collect(); + #ident(GymDiscreteAct::new(data)) + } } } - }.into_iter()); + .into_iter(), + ); output } -fn atari_env_act( - ident: proc_macro2::Ident, - field_type: syn::Type, -) -> proc_macro2::TokenStream { +fn atari_env_act(ident: proc_macro2::Ident, field_type: syn::Type) -> proc_macro2::TokenStream { #[allow(unused_mut)] let mut output = common(ident.clone(), field_type.clone()); #[cfg(feature = "tch")] - output.extend(quote! { - impl From<#ident> for tch::Tensor { - fn from(act: #ident) -> tch::Tensor { - // let v = act.0.act.iter().map(|e| *e as i64).collect::>(); - let v = vec![act.0.act as i64]; - let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); + output.extend( + quote! { + impl From<#ident> for tch::Tensor { + fn from(act: #ident) -> tch::Tensor { + let v = vec![act.0.act as i64]; + let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); + + // The first dimension is for the batch + t.unsqueeze(0) + } + } - // The first dimension is for the batch - t.unsqueeze(0) + impl From for #ident { + fn from(t: tch::Tensor) -> Self { + let data: Vec = { + let t = t.to_dtype(tch::Kind::Int64, false, true); + let n = t.numel(); + let mut data = vec![0i64; n]; + t.f_copy_data(&mut data, n).unwrap(); + data + }; + // Non-vectorized environment + #ident(BorderAtariAct::new(data[0] as u8)) + } } } + .into_iter(), + ); - impl From for #ident { - fn from(t: tch::Tensor) -> Self { - let data: Vec = t.into(); - // Non-vectorized environment - #ident(BorderAtariAct::new(data[0] as u8)) + #[cfg(feature = "candle-core")] + output.extend( + quote! 
{ + impl From<#ident> for candle_core::Tensor { + fn from(act: #ident) -> candle_core::Tensor { + let v = vec![act.0.act as i64]; + let n = v.len(); + let t = candle_core::Tensor::from_vec(v, &[n], &candle_core::Device::Cpu).unwrap(); + + // The first dimension is for the batch + t.unsqueeze(0).unwrap() + } + } + + impl From for #ident { + fn from(t: candle_core::Tensor) -> Self { + let data: Vec = t.to_vec1().unwrap(); + // Non-vectorized environment + #ident(BorderAtariAct::new(data[0] as u8)) + } } } - }.into_iter()); + .into_iter(), + ); output } @@ -145,4 +236,4 @@ fn common(ident: proc_macro2::Ident, field_type: syn::Type) -> proc_macro2::Toke } } } -} \ No newline at end of file +} diff --git a/border-derive/src/lib.rs b/border-derive/src/lib.rs index d02afce2..e5874c76 100644 --- a/border-derive/src/lib.rs +++ b/border-derive/src/lib.rs @@ -1,28 +1,227 @@ -//! Derive macros for making newtypes of types that implements -//! `border_core::Obs`, `border_core::Act` and -//! `order_core::replay_buffer::SubBatch`. +//! Derive macros for implementing [`border_core::Act`] and +//! [`border_core::generic_replay_buffer::BatchBase`]. //! -//! These macros will implements some conversion traits for combining -//! interfaces of an environment and an agent. +//! # Examples +//! +//! ## Newtype for [`BorderAtariAct`] +//! +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_atari_env::BorderAtariAct; +//! # +//! #[derive(Clone, Debug, Act)] +//! struct MyAct(BorderAtariAct); +//! ``` +//! +//! The above code will generate the following implementation: +//! +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_atari_env::BorderAtariAct; +//! # +//! #[derive(Clone, Debug)] +//! struct MyAct(BorderAtariAct); +//! impl border_core::Act for MyAct { +//! fn len(&self) -> usize { +//! self.0.len() +//! } +//! } +//! impl Into for MyAct { +//! fn into(self) -> BorderAtariAct { +//! self.0 +//! } +//! } +//! /// The following code is generated when features="tch" is enabled. +//! impl From for tch::Tensor { +//! fn from(act: MyAct) -> tch::Tensor { +//! let v = vec![act.0.act as i64]; +//! let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); +//! t.unsqueeze(0) +//! } +//! } +//! impl From for MyAct { +//! fn from(t: tch::Tensor) -> Self { +//! let data: Vec = { +//! let t = t.to_dtype(tch::Kind::Int64, false, true); +//! let n = t.numel(); +//! let mut data = vec![0i64; n]; +//! t.f_copy_data(&mut data, n).unwrap(); +//! data +//! }; +//! MyAct(BorderAtariAct::new(data[0] as u8)) +//! } +//! } +//! ``` +//! +//! ## Newtype for [`GymContinuousAct`] +//! +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_py_gym_env::GymContinuousAct; +//! # +//! #[derive(Clone, Debug, Act)] +//! struct MyAct(GymContinuousAct); +//! ``` +//! +//! The above code will generate the following implementation: +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_py_gym_env::GymContinuousAct; +//! # +//! #[derive(Clone, Debug)] +//! struct MyAct(GymContinuousAct); +//! impl border_core::Act for MyAct { +//! fn len(&self) -> usize { +//! self.0.len() +//! } +//! } +//! impl Into for MyAct { +//! fn into(self) -> GymContinuousAct { +//! self.0 +//! } +//! } +//! /// The following code is generated when features="tch" is enabled. +//! impl From for tch::Tensor { +//! fn from(act: MyAct) -> tch::Tensor { +//! let v = act.0.act.iter().map(|e| *e as f32).collect::>(); +//! 
let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); +//! t.unsqueeze(0) +//! } +//! } +//! impl From for MyAct { +//! /// `t` must be a 1-dimentional tensor of `f32`. +//! fn from(t: tch::Tensor) -> Self { +//! let shape = t.size()[1..].iter().map(|x| *x as usize).collect::>(); +//! use std::convert::TryInto; +//! let act: Vec = t.try_into().unwrap(); +//! let act = ndarray::Array1::::from(act) +//! .into_shape(ndarray::IxDyn(&shape)) +//! .unwrap(); +//! MyAct(GymContinuousAct::new(act)) +//! } +//! } +//! ``` +//! +//! ## Newtype for [`GymDiscreteAct`] +//! +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_py_gym_env::GymDiscreteAct; +//! # +//! #[derive(Clone, Debug, Act)] +//! struct MyAct(GymDiscreteAct); +//! ``` +//! +//! The above code will generate the following implementation: +//! ``` +//! # use border_core::Act; +//! # use border_derive::Act; +//! # use border_py_gym_env::GymDiscreteAct; +//! # +//! #[derive(Clone, Debug)] +//! struct MyAct(GymDiscreteAct); +//! impl border_core::Act for MyAct { +//! fn len(&self) -> usize { +//! self.0.len() +//! } +//! } +//! impl Into for MyAct { +//! fn into(self) -> GymDiscreteAct { +//! self.0 +//! } +//! } +//! impl From for tch::Tensor { +//! fn from(act: MyAct) -> tch::Tensor { +//! let v = act.0.act.iter().map(|e| *e as i64).collect::>(); +//! let t: tch::Tensor = std::convert::TryFrom::>::try_from(v).unwrap(); +//! t.unsqueeze(0) +//! } +//! } +//! impl From for MyAct { +//! fn from(t: tch::Tensor) -> Self { +//! use std::convert::TryInto; +//! let data: Vec = t.try_into().unwrap(); +//! let data: Vec<_> = data.iter().map(|e| *e as i32).collect(); +//! MyAct(GymDiscreteAct::new(data)) +//! } +//! } +//! ``` +//! +//! ## Newtype for [`TensorBatch`] +//! +//! ``` +//! # use border_derive::BatchBase; +//! # use border_tch_agent::TensorBatch; +//! # +//! #[derive(Clone, BatchBase)] +//! struct MyBatch(TensorBatch); +//! ``` +//! +//! The above code will generate the following implementation: +//! +//! ``` +//! # use border_derive::BatchBase; +//! # use border_tch_agent::TensorBatch; +//! # +//! #[derive(Clone)] +//! struct ObsBatch(TensorBatch); +//! impl border_core::generic_replay_buffer::BatchBase for ObsBatch { +//! fn new(capacity: usize) -> Self { +//! Self(TensorBatch::new(capacity)) +//! } +//! fn push(&mut self, i: usize, data: Self) { +//! self.0.push(i, data.0) +//! } +//! fn sample(&self, ixs: &Vec) -> Self { +//! let buf = self.0.sample(ixs); +//! Self(buf) +//! } +//! } +//! impl From for ObsBatch { +//! fn from(obs: TensorBatch) -> Self { +//! ObsBatch(obs) +//! } +//! } +//! impl From for tch::Tensor { +//! fn from(b: ObsBatch) -> Self { +//! b.0.into() +//! } +//! } +//! ``` +//! +//! [`border_core::Obs`]: border_core::Obs +//! [`border_core::Act`]: border_core::Act +//! [`border_core::generic_replay_buffer::BatchBase`]: border_core::generic_replay_buffer::BatchBase +//! [`BorderAtariAct`]: border_atari_env::BorderAtariAct +//! [`GymContinuousAct`]: border_py_gym_env::GymContinuousAct +//! [`GymDiscreteAct`]: border_py_gym_env::GymDiscreteAct +//! [`TensorBatch`]: border_tch_agent::TensorBatch + +mod act; mod obs; mod subbatch; -mod act; use proc_macro::{self, TokenStream}; /// Implements `border_core::Obs` for the newtype that wraps /// PyGymEnvObs or BorderAtariObs. 
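For migration: the `SubBatch` derive is renamed to `BatchBase` and `TensorSubBatch` to `TensorBatch`, so a typical setup now wraps `TensorBatch` in a pair of newtypes for observation and action batches, as in the new `border_tensor_batch` example. Sketch only; the surrounding replay-buffer construction is omitted:

```rust
use border_derive::BatchBase;
use border_tch_agent::TensorBatch;

// The derive generates `border_core::generic_replay_buffer::BatchBase`
// implementations for both newtypes, so they can serve as the observation
// and action batch types of a generic replay buffer.
#[allow(dead_code)]
#[derive(Clone, BatchBase)]
struct ObsBatch(TensorBatch);

#[allow(dead_code)]
#[derive(Clone, BatchBase)]
struct ActBatch(TensorBatch);

fn main() {}
```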
+#[deprecated] #[proc_macro_derive(Obs, attributes(my_trait))] pub fn derive1(input: TokenStream) -> TokenStream { obs::derive(input) } -/// Implements `border_core::replay_buffer::SubBatch` for the newtype. -#[proc_macro_derive(SubBatch, attributes(my_trait))] +/// Implements [`border_core::generic_replay_buffer::BatchBase`] for the newtype. +#[proc_macro_derive(BatchBase, attributes(my_trait))] pub fn derive2(input: TokenStream) -> TokenStream { subbatch::derive(input) } -/// Implements `border_core::Act` for the newtype. +/// Implements [`border_core::Act`] for the newtype. #[proc_macro_derive(Act, attributes(my_trait))] pub fn derive3(input: TokenStream) -> TokenStream { act::derive(input) diff --git a/border-derive/src/obs.rs b/border-derive/src/obs.rs index f37e64af..a63c4d7c 100644 --- a/border-derive/src/obs.rs +++ b/border-derive/src/obs.rs @@ -5,12 +5,11 @@ use syn::{parse_macro_input, DeriveInput}; pub fn derive(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input); - // let opts = Opts::from_derive_input(&input).expect("Wrong options"); let DeriveInput { ident, data, .. } = input; let field_type = get_field_type(data); let field_type_str = get_type_str( field_type.clone(), - "The item for deriving Obs must be a new type like MyObs(PyGymEnvObs)", + "The item for deriving Obs must be a new type like MyObs(BorderAtariObs)", ); // let output = if field_type_str == "PyGymEnvObs" { @@ -18,7 +17,10 @@ pub fn derive(input: TokenStream) -> TokenStream { let output = if field_type_str == "BorderAtariObs" { atari_env_obs(ident, field_type) } else { - panic!("Deriving Obs supports PyGymEnvObs or BorderAtariObs, given {:?}", field_type_str); + panic!( + "Deriving Obs supports PyGymEnvObs or BorderAtariObs, given {:?}", + field_type_str + ); }; output.into() @@ -66,16 +68,19 @@ fn atari_env_obs(ident: proc_macro2::Ident, field_type: syn::Type) -> proc_macro }; #[cfg(feature = "tch")] - output.extend(quote! { - use std::convert::TryFrom as _; - - impl From<#ident> for tch::Tensor { - fn from(obs: #ident) -> tch::Tensor { - // `BorderAtariObs` implements Into when feature = "tch" - tch::Tensor::try_from(obs.0).unwrap() + output.extend( + quote! 
{ + use std::convert::TryFrom as _; + + impl From<#ident> for tch::Tensor { + fn from(obs: #ident) -> tch::Tensor { + // `BorderAtariObs` implements Into when feature = "tch" + tch::Tensor::try_from(obs.0).unwrap() + } } } - }.into_iter()); + .into_iter(), + ); output } diff --git a/border-derive/src/subbatch.rs b/border-derive/src/subbatch.rs index 89fed9b8..56a307cb 100644 --- a/border-derive/src/subbatch.rs +++ b/border-derive/src/subbatch.rs @@ -9,14 +9,14 @@ pub fn derive(input: TokenStream) -> TokenStream { let field_type = get_field_type(data); let field_type_str = get_type_str( field_type.clone(), - "The item for deriving SubBatch must be a new type like SubBatch(TensorSubBatch)", + "The item for deriving BatchBase must be a new type like Batch(TensorSubBatch)", ); - let output = if field_type_str == "TensorSubBatch" { - tensor_sub_batch(ident, field_type) + let output = if field_type_str == "TensorBatch" { + tensor_batch(ident, field_type) } else { panic!( - "Deriving ObsBatch support TensorSubBatch, given {:?}", + "Deriving ObsBatch support TensorBatch, given {:?}", field_type_str ); }; @@ -24,16 +24,16 @@ pub fn derive(input: TokenStream) -> TokenStream { output.into() } -fn tensor_sub_batch(ident: proc_macro2::Ident, field_type: syn::Type) -> proc_macro2::TokenStream { +fn tensor_batch(ident: proc_macro2::Ident, field_type: syn::Type) -> proc_macro2::TokenStream { #[allow(unused_mut)] let mut output = quote! { - impl border_core::replay_buffer::SubBatch for #ident { + impl border_core::generic_replay_buffer::BatchBase for #ident { fn new(capacity: usize) -> Self { - Self(TensorSubBatch::new(capacity)) + Self(TensorBatch::new(capacity)) } - fn push(&mut self, i: usize, data: &Self) { - self.0.push(i, &data.0) + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) } fn sample(&self, ixs: &Vec) -> Self { @@ -50,13 +50,28 @@ fn tensor_sub_batch(ident: proc_macro2::Ident, field_type: syn::Type) -> proc_ma }; #[cfg(feature = "tch")] - output.extend(quote! { - impl From<#ident> for tch::Tensor { - fn from(b: #ident) -> Self { - b.0.into() + output.extend( + quote! { + impl From<#ident> for tch::Tensor { + fn from(b: #ident) -> Self { + b.0.into() + } } } - }.into_iter()); + .into_iter(), + ); + + #[cfg(feature = "candle-core")] + output.extend( + quote! 
{ + impl From<#ident> for candle_core::Tensor { + fn from(b: #ident) -> Self { + b.0.into() + } + } + } + .into_iter(), + ); output } diff --git a/border-mlflow-tracking/Cargo.toml b/border-mlflow-tracking/Cargo.toml new file mode 100644 index 00000000..08016031 --- /dev/null +++ b/border-mlflow-tracking/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "border-mlflow-tracking" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +border-core = { version = "0.0.7", path = "../border-core" } +reqwest = { workspace = true } +anyhow = { workspace = true } +serde = { workspace = true, features = ["derive"] } +log = { workspace = true } +serde_json = { workspace = true } +flatten-serde-json = "0.1.0" +chrono = { workspace = true } + +[dev-dependencies] +env_logger = { workspace = true } + +[[example]] +name = "tracking_basic" +# test = true diff --git a/border-mlflow-tracking/README.md b/border-mlflow-tracking/README.md new file mode 100644 index 00000000..dab7cc1d --- /dev/null +++ b/border-mlflow-tracking/README.md @@ -0,0 +1,105 @@ +Support [MLflow](https://mlflow.org) tracking to manage experiments. + +Before running the program using this crate, run a tracking server with the following command: + +```bash +mlflow server --host 127.0.0.1 --port 8080 +``` + +Then, training configurations and metrices can be logged to the tracking server. +The following code provides an example. Nested configuration parameters will be flattened, +logged like `hyper_params.param1`, `hyper_params.param2`. + +```rust +use anyhow::Result; +use border_core::record::{Record, RecordValue, Recorder}; +use border_mlflow_tracking::MlflowTrackingClient; +use serde::Serialize; + +// Nested Configuration struct +#[derive(Debug, Serialize)] +struct Config { + env_params: String, + hyper_params: HyperParameters, +} + +#[derive(Debug, Serialize)] +struct HyperParameters { + param1: i64, + param2: Param2, + param3: Param3, +} + +#[derive(Debug, Serialize)] +enum Param2 { + Variant1, + Variant2(f32), +} + +#[derive(Debug, Serialize)] +struct Param3 { + dataset_name: String, +} + +fn main() -> Result<()> { + env_logger::init(); + + let config1 = Config { + env_params: "env1".to_string(), + hyper_params: HyperParameters { + param1: 0, + param2: Param2::Variant1, + param3: Param3 { + dataset_name: "a".to_string(), + }, + }, + }; + let config2 = Config { + env_params: "env2".to_string(), + hyper_params: HyperParameters { + param1: 0, + param2: Param2::Variant2(3.0), + param3: Param3 { + dataset_name: "a".to_string(), + }, + }, + }; + + // Set experiment for runs + let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; + + // Create recorders for logging + let mut recorder_run1 = client.create_recorder("")?; + let mut recorder_run2 = client.create_recorder("")?; + recorder_run1.log_params(&config1)?; + recorder_run2.log_params(&config2)?; + + // Logging while training + for opt_steps in 0..100 { + let opt_steps = opt_steps as f32; + + // Create a record + let mut record = Record::empty(); + record.insert("opt_steps", RecordValue::Scalar(opt_steps)); + record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp())); + + // Log metrices in the record + recorder_run1.write(record); + } + + // Logging while training + for opt_steps in 0..100 { + let opt_steps = opt_steps as f32; + + // Create a record + let mut 
record = Record::empty(); + record.insert("opt_steps", RecordValue::Scalar(opt_steps)); + record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp())); + + // Log metrices in the record + recorder_run2.write(record); + } + + Ok(()) +} +``` diff --git a/border-mlflow-tracking/examples/tracking_basic.rs b/border-mlflow-tracking/examples/tracking_basic.rs new file mode 100644 index 00000000..98d22f9d --- /dev/null +++ b/border-mlflow-tracking/examples/tracking_basic.rs @@ -0,0 +1,91 @@ +use anyhow::Result; +use border_core::record::{Record, RecordValue, Recorder}; +use border_mlflow_tracking::MlflowTrackingClient; +use serde::Serialize; + +// Nested Configuration struct +#[derive(Debug, Serialize)] +struct Config { + env_params: String, + hyper_params: HyperParameters, +} + +#[derive(Debug, Serialize)] +struct HyperParameters { + param1: i64, + param2: Param2, + param3: Param3, +} + +#[derive(Debug, Serialize)] +enum Param2 { + Variant1, + Variant2(f32), +} + +#[derive(Debug, Serialize)] +struct Param3 { + dataset_name: String, +} + +fn main() -> Result<()> { + env_logger::init(); + + let config1 = Config { + env_params: "env1".to_string(), + hyper_params: HyperParameters { + param1: 0, + param2: Param2::Variant1, + param3: Param3 { + dataset_name: "a".to_string(), + }, + }, + }; + let config2 = Config { + env_params: "env2".to_string(), + hyper_params: HyperParameters { + param1: 0, + param2: Param2::Variant2(3.0), + param3: Param3 { + dataset_name: "a".to_string(), + }, + }, + }; + + // Set experiment for runs + let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; + + // Create recorders for logging + let mut recorder_run1 = client.create_recorder("")?; + let mut recorder_run2 = client.create_recorder("")?; + recorder_run1.log_params(&config1)?; + recorder_run2.log_params(&config2)?; + + // Logging while training + for opt_steps in 0..100 { + let opt_steps = opt_steps as f32; + + // Create a record + let mut record = Record::empty(); + record.insert("opt_steps", RecordValue::Scalar(opt_steps)); + record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp())); + + // Log metrices in the record + recorder_run1.write(record); + } + + // Logging while training + for opt_steps in 0..100 { + let opt_steps = opt_steps as f32; + + // Create a record + let mut record = Record::empty(); + record.insert("opt_steps", RecordValue::Scalar(opt_steps)); + record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp())); + + // Log metrices in the record + recorder_run2.write(record); + } + + Ok(()) +} diff --git a/border-mlflow-tracking/src/client.rs b/border-mlflow-tracking/src/client.rs new file mode 100644 index 00000000..db0f129b --- /dev/null +++ b/border-mlflow-tracking/src/client.rs @@ -0,0 +1,224 @@ +// use anyhow::Result; +use crate::{system_time_as_millis, Experiment, MlflowTrackingRecorder, Run}; +use anyhow::Result; +use log::info; +use reqwest::blocking::Client; +use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::fmt::Display; + +#[derive(Debug, Deserialize)] +/// Internally used. +struct Experiment_ { + pub(crate) experiment: Experiment, +} + +#[derive(Debug, Deserialize)] +/// Internally used. 
+struct Run_ { + run: Run, +} + +#[derive(Debug, Clone)] +pub struct GetExperimentIdError; + +impl Display for GetExperimentIdError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Failed to get experiment ID") + } +} + +impl Error for GetExperimentIdError {} + +#[derive(Debug, Serialize)] +/// Parameters adapted from . +/// +/// TODO: Support parameters in API, if required. +struct CreateRunParams { + experiment_id: String, + start_time: i64, + run_name: String, +} + +#[derive(Debug, Serialize)] +struct CreateExperimentParams { + name: String, +} + +/// Provides access to a MLflow tracking server via REST API. +/// +/// Support Mlflow API version 2.0. +pub struct MlflowTrackingClient { + client: Client, + + /// Base URL. + base_url: String, + + /// Current experiment ID. + experiment_id: Option, + + /// User name of the tracking server. + user_name: String, + + /// Password. + password: String, +} + +impl MlflowTrackingClient { + pub fn new(base_url: impl AsRef) -> Self { + Self { + client: Client::new(), + base_url: base_url.as_ref().to_string(), + experiment_id: None, + user_name: "".to_string(), + password: "".to_string(), + } + } + + /// Set user name and password for basic authentication of the tracking server. + pub fn basic_auth(self, user_name: impl AsRef, password: impl AsRef) -> Self { + Self { + client: self.client, + base_url: self.base_url, + experiment_id: self.experiment_id, + user_name: user_name.as_ref().to_string(), + password: password.as_ref().to_string(), + } + } + + /// Set the experiment ID to this struct. + /// + /// Get the ID from the tracking server by `name`. + pub fn set_experiment_id(self, name: impl AsRef) -> Result { + let experiment_id = { + self.get_experiment(name.as_ref()) + .expect(format!("Failed to get experiment: {:?}", name.as_ref()).as_str()) + .experiment_id + }; + + info!( + "For experiment '{}', id={} is set in MlflowTrackingClient", + name.as_ref(), + experiment_id + ); + + Ok(Self { + client: self.client, + base_url: self.base_url, + experiment_id: Some(experiment_id), + user_name: self.user_name, + password: self.password, + }) + + // let experiment_id = self.get_experiment_id(&name); + // match experiment_id { + // None => Err(GetExperimentIdError), + // Some(experiment_id) => Ok(Self { + // client: self.client, + // base_url: self.base_url, + // experiment_id: Some(experiment_id), + // }), + // } + } + + /// Get [`Experiment`] by name from the tracking server. + /// + /// If the experiment with given name does not exist in the trackingserver, + /// it will be created. 
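For a tracking server behind basic authentication, credentials are set before selecting the experiment; `set_experiment_id` resolves the experiment by name (creating it if it does not exist), and an empty run name passed to `create_recorder` lets the server generate one. A short sketch; the URL and credentials are placeholders and a running MLflow server is assumed:

```rust
use anyhow::Result;
use border_mlflow_tracking::MlflowTrackingClient;

fn main() -> Result<()> {
    // `basic_auth` consumes and returns the client, so it chains with
    // `set_experiment_id`, which looks up (or creates) the experiment by name.
    let client = MlflowTrackingClient::new("http://localhost:8080")
        .basic_auth("mlflow_user", "mlflow_password")
        .set_experiment_id("Default")?;

    // Empty run name: the tracking server generates one automatically.
    let _recorder = client.create_recorder("")?;
    Ok(())
}
```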
+ /// + /// TODO: Better error handling + pub fn get_experiment(&self, name: impl AsRef) -> Option { + let resp = match self.get( + self.url("experiments/get-by-name"), + &[("experiment_name", name.as_ref())], + ) { + Ok(resp) => { + if resp.status().is_success() { + resp + } else { + // if the experiment does not exist, create it + self.post( + self.url("experiments/create"), + &CreateExperimentParams { + name: name.as_ref().into(), + }, + ) + .unwrap(); + self.get( + self.url("experiments/get-by-name"), + &[("experiment_name", name.as_ref())], + ) + .unwrap() + } + } + Err(_) => { + panic!(); + } + }; + let experiment: Experiment_ = serde_json::from_str(resp.text().unwrap().as_str()).unwrap(); + + Some(experiment.experiment) + } + + fn url(&self, api: impl AsRef) -> String { + format!("{}/api/2.0/mlflow/{}", self.base_url, api.as_ref()) + } + + fn get( + &self, + url: String, + query: &impl Serialize, + ) -> reqwest::Result { + self.client + .get(url) + .basic_auth(&self.user_name, Some(&self.password)) + .query(query) + .send() + } + + fn post( + &self, + url: String, + params: &impl Serialize, + ) -> reqwest::Result { + self.client + .post(url) + .basic_auth(&self.user_name, Some(&self.password)) + .json(¶ms) // auto serialize + .send() + } + + /// Create [`MlflowTrackingRecorder`] corresponding to a run. + /// + /// If `name` is empty (`""`), a run name is generated by the tracking server. + /// + /// Need to call [`MlflowTrackingClient::set_experiment_id()`] before calling this method. + pub fn create_recorder(&self, run_name: impl AsRef) -> Result { + let not_given_name = run_name.as_ref().len() == 0; + let experiment_id = self.experiment_id.as_ref().expect("Needs experiment_id"); + let resp = self + .post( + self.url("runs/create"), + &CreateRunParams { + experiment_id: experiment_id.to_string(), + start_time: system_time_as_millis() as i64, + run_name: run_name.as_ref().to_string(), + }, + ) + .unwrap(); + // TODO: Check the response from the tracking server here + + let run = { + let run: Run_ = + serde_json::from_str(&resp.text().unwrap()).expect("Failed to deserialize Run"); + run.run + }; + if not_given_name { + info!( + "Run name '{}' has been automatically generated", + run.info.run_name + ); + } + MlflowTrackingRecorder::new(&self.base_url, &experiment_id, &run) + } +} diff --git a/border-mlflow-tracking/src/experiment.rs b/border-mlflow-tracking/src/experiment.rs new file mode 100644 index 00000000..75cf7425 --- /dev/null +++ b/border-mlflow-tracking/src/experiment.rs @@ -0,0 +1,19 @@ +use serde::Deserialize; + +#[derive(Debug, Deserialize)] +pub struct ExperimentTag { + pub key: String, + pub value: String, +} + +#[derive(Debug, Deserialize)] +/// all fields taken from . +pub struct Experiment { + pub experiment_id: String, + pub name: String, + pub artifact_location: String, + pub lifecycle_stage: String, + pub last_update_time: i64, + pub creation_time: i64, + pub tags: Option>, +} diff --git a/border-mlflow-tracking/src/lib.rs b/border-mlflow-tracking/src/lib.rs new file mode 100644 index 00000000..8d9a0684 --- /dev/null +++ b/border-mlflow-tracking/src/lib.rs @@ -0,0 +1,122 @@ +//! Support [MLflow](https://mlflow.org) tracking to manage experiments. +//! +//! Before running the program using this crate, run a tracking server with the following command: +//! +//! ```bash +//! mlflow server --host 127.0.0.1 --port 8080 +//! ``` +//! +//! Then, training configurations and metrices can be logged to the tracking server. +//! The following code is an example. 
Nested configuration parameters will be flattened, +//! logged like `hyper_params.param1`, `hyper_params.param2`. +//! +//! ```no_run +//! use anyhow::Result; +//! use border_core::record::{Record, RecordValue, Recorder}; +//! use border_mlflow_tracking::MlflowTrackingClient; +//! use serde::Serialize; +//! +//! // Nested Configuration struct +//! #[derive(Debug, Serialize)] +//! struct Config { +//! env_params: String, +//! hyper_params: HyperParameters, +//! } +//! +//! #[derive(Debug, Serialize)] +//! struct HyperParameters { +//! param1: i64, +//! param2: Param2, +//! param3: Param3, +//! } +//! +//! #[derive(Debug, Serialize)] +//! enum Param2 { +//! Variant1, +//! Variant2(f32), +//! } +//! +//! #[derive(Debug, Serialize)] +//! struct Param3 { +//! dataset_name: String, +//! } +//! +//! fn main() -> Result<()> { +//! env_logger::init(); +//! +//! let config1 = Config { +//! env_params: "env1".to_string(), +//! hyper_params: HyperParameters { +//! param1: 0, +//! param2: Param2::Variant1, +//! param3: Param3 { +//! dataset_name: "a".to_string(), +//! }, +//! }, +//! }; +//! let config2 = Config { +//! env_params: "env2".to_string(), +//! hyper_params: HyperParameters { +//! param1: 0, +//! param2: Param2::Variant2(3.0), +//! param3: Param3 { +//! dataset_name: "a".to_string(), +//! }, +//! }, +//! }; +//! +//! // Set experiment for runs +//! let client = MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Default")?; +//! +//! // Create recorders for logging +//! let mut recorder_run1 = client.create_recorder("")?; +//! let mut recorder_run2 = client.create_recorder("")?; +//! recorder_run1.log_params(&config1)?; +//! recorder_run2.log_params(&config2)?; +//! +//! // Logging while training +//! for opt_steps in 0..100 { +//! let opt_steps = opt_steps as f32; +//! +//! // Create a record +//! let mut record = Record::empty(); +//! record.insert("opt_steps", RecordValue::Scalar(opt_steps)); +//! record.insert("Loss", RecordValue::Scalar((-1f32 * opt_steps).exp())); +//! +//! // Log metrices in the record +//! recorder_run1.write(record); +//! } +//! +//! // Logging while training +//! for opt_steps in 0..100 { +//! let opt_steps = opt_steps as f32; +//! +//! // Create a record +//! let mut record = Record::empty(); +//! record.insert("opt_steps", RecordValue::Scalar(opt_steps)); +//! record.insert("Loss", RecordValue::Scalar((-0.5f32 * opt_steps).exp())); +//! +//! // Log metrices in the record +//! recorder_run2.write(record); +//! } +//! +//! Ok(()) +//! } +//! ``` +mod client; +mod experiment; +mod run; +mod writer; +pub use client::{GetExperimentIdError, MlflowTrackingClient}; +use experiment::Experiment; +pub use run::Run; +use std::time::{SystemTime, UNIX_EPOCH}; +pub use writer::MlflowTrackingRecorder; + +/// Code adapted from . 
+fn system_time_as_millis() -> u128 { + let time = SystemTime::now(); + time.duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis() +} diff --git a/border-mlflow-tracking/src/run.rs b/border-mlflow-tracking/src/run.rs new file mode 100644 index 00000000..79acaac5 --- /dev/null +++ b/border-mlflow-tracking/src/run.rs @@ -0,0 +1,58 @@ +use serde::Deserialize; + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +pub struct Run { + pub info: RunInfo, + data: Option, + inputs: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +pub struct RunInfo { + pub run_id: String, + pub run_name: String, + experiment_id: String, + status: Option, + start_time: i64, + end_time: Option, + artifact_uri: Option, + lifecycle_stage: Option, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct RunData { + metrics: Option>, + params: Option>, + tags: Option>, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +/// TODO: implement +struct RunInputs {} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct RunTag { + key: String, + value: String, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct Param { + key: String, + value: String, +} + +#[allow(dead_code)] +#[derive(Debug, Deserialize)] +struct Metric { + key: String, + value: String, + timestamp: i64, + step: i64, +} diff --git a/border-mlflow-tracking/src/writer.rs b/border-mlflow-tracking/src/writer.rs new file mode 100644 index 00000000..361e4cd6 --- /dev/null +++ b/border-mlflow-tracking/src/writer.rs @@ -0,0 +1,232 @@ +use crate::{system_time_as_millis, Run}; +use anyhow::Result; +use border_core::record::{AggregateRecorder, RecordStorage, RecordValue, Recorder}; +use chrono::{DateTime, Local, SecondsFormat, Duration}; +use reqwest::blocking::Client; +use serde::Serialize; +use serde_json::Value; + +#[derive(Debug, Serialize)] +struct LogParamParams<'a> { + run_id: &'a String, + key: &'a String, + value: String, +} + +#[derive(Debug, Serialize)] +struct LogMetricParams<'a> { + run_id: &'a String, + key: &'a String, + value: f64, + timestamp: i64, + step: i64, +} + +#[derive(Debug, Serialize)] +struct UpdateRunParams<'a> { + run_id: &'a String, + status: String, + end_time: i64, + run_name: &'a String, +} + +#[derive(Debug, Serialize)] +struct SetTagParams<'a> { + run_id: &'a String, + key: &'a String, + value: &'a String, +} + +#[allow(dead_code)] +/// Record metrics to the MLflow tracking server during training. +/// +/// Before training, you can use [`MlflowTrackingRecorder::log_params()`] to log parameters +/// of the run like hyperparameters of the algorithm, the name of environment on which the +/// agent is trained, etc. +/// +/// [`MlflowTrackingRecorder::write()`] method logs [`RecordValue::Scalar`] values in the record +/// as metrics. As an exception, `opt_steps` is treated as the `step` field of Mlflow's metric data +/// (). +/// +/// Other types of values like [`RecordValue::Array1`] will be ignored. +/// +/// When dropped, this struct updates run's status to "FINISHED" +/// (). +/// +/// [`RecordValue::Scalar`]: border_core::record::RecordValue::Scalar +/// [`RecordValue::Array1`]: border_core::record::RecordValue::Array1 +pub struct MlflowTrackingRecorder { + client: Client, + base_url: String, + experiment_id: String, + run_id: String, + run_name: String, + user_name: String, + storage: RecordStorage, + password: String, + start_time: DateTime, +} + +impl MlflowTrackingRecorder { + /// Create a new instance of `MlflowTrackingRecorder`. 
+ /// + /// This method adds a tag "host_start_time" with the current time. + /// This tag is useful when using mlflow-export-import: it losts the original time. + /// See https://github.com/mlflow/mlflow-export-import/issues/72 + pub fn new(base_url: &String, experiment_id: &String, run: &Run) -> Result { + let client = Client::new(); + let start_time = Local::now(); + let recorder = Self { + client, + base_url: base_url.clone(), + experiment_id: experiment_id.to_string(), + run_id: run.info.run_id.clone(), + run_name: run.info.run_name.clone(), + user_name: "".to_string(), + password: "".to_string(), + storage: RecordStorage::new(), + start_time: start_time.clone(), + }; + + // Record current time as tag "host_start_time" + recorder.set_tag( + "host_start_time", + start_time.to_rfc3339_opts(SecondsFormat::Secs, true), + )?; + + Ok(recorder) + } + + pub fn log_params(&self, params: impl Serialize) -> Result<()> { + let url = format!("{}/api/2.0/mlflow/runs/log-parameter", self.base_url); + let flatten_map = { + let map = match serde_json::to_value(params).unwrap() { + Value::Object(map) => map, + _ => panic!("Failed to parse object"), + }; + flatten_serde_json::flatten(&map) + }; + for (key, value) in flatten_map.iter() { + let params = LogParamParams { + run_id: &self.run_id, + key, + value: value.to_string(), + }; + let _resp = self + .client + .post(&url) + .basic_auth(&self.user_name, Some(&self.password)) + .json(¶ms) // auto serialize + .send() + .unwrap(); + // TODO: error handling caused by API call + } + + Ok(()) + } + + pub fn set_tag(&self, key: impl AsRef, value: impl AsRef) -> Result<()> { + let url = format!("{}/api/2.0/mlflow/runs/set-tag", self.base_url); + let params = SetTagParams { + run_id: &self.run_id, + key: &key.as_ref().to_string(), + value: &value.as_ref().to_string(), + }; + let _resp = self + .client + .post(&url) + .basic_auth(&self.user_name, Some(&self.password)) + .json(¶ms) + .send() + .unwrap(); + + Ok(()) + } +} + +impl Recorder for MlflowTrackingRecorder { + fn write(&mut self, record: border_core::record::Record) { + let url = format!("{}/api/2.0/mlflow/runs/log-metric", self.base_url); + let timestamp = system_time_as_millis() as i64; + let step = record.get_scalar("opt_steps").unwrap() as i64; + + for (key, value) in record.iter() { + if *key != "opt_steps" { + match value { + RecordValue::Scalar(v) => { + let value = *v as f64; + let params = LogMetricParams { + run_id: &self.run_id, + key, + value, + timestamp, + step, + }; + let _resp = self + .client + .post(&url) + .basic_auth(&self.user_name, Some(&self.password)) + .json(¶ms) // auto serialize + .send() + .unwrap(); + // TODO: error handling caused by API call + } + _ => {} // ignore record value + } + } + } + } +} + +impl Drop for MlflowTrackingRecorder { + /// Update run's status to "FINISHED" when dropped. + /// + /// It also adds tags "host_end_time" and "host_duration" with the current time and duration. 
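A practical consequence of `write()` above: every record passed to this recorder must contain an `opt_steps` scalar (it is unwrapped and used as the MLflow metric step), every other `Scalar` entry is logged as a metric at that step, and non-scalar values are silently dropped. A standalone illustration of that mapping, using a plain `HashMap` as a stand-in for `Record`:

```rust
use std::collections::HashMap;

// Plain-HashMap stand-in for `Record` with only scalar entries; the real
// recorder additionally skips non-scalar values such as arrays.
fn to_mlflow_metrics(record: &HashMap<&str, f64>) -> (i64, Vec<(String, f64)>) {
    // "opt_steps" is mandatory and becomes the MLflow metric step.
    let step = record["opt_steps"] as i64;
    let mut metrics: Vec<(String, f64)> = record
        .iter()
        .filter(|(k, _)| **k != "opt_steps")
        .map(|(k, v)| (k.to_string(), *v))
        .collect();
    metrics.sort_by(|a, b| a.0.cmp(&b.0)); // deterministic order for the assertion below
    (step, metrics)
}

fn main() {
    let mut r = HashMap::new();
    r.insert("opt_steps", 100.0);
    r.insert("loss_actor", 0.25);
    r.insert("loss_critic", 1.5);
    let (step, metrics) = to_mlflow_metrics(&r);
    assert_eq!(step, 100);
    assert_eq!(
        metrics,
        vec![("loss_actor".to_string(), 0.25), ("loss_critic".to_string(), 1.5)]
    );
}
```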
+ fn drop(&mut self) { + let end_time = Local::now(); + let duration = end_time.signed_duration_since(self.start_time); + self.set_tag( + "host_end_time", + end_time.to_rfc3339_opts(SecondsFormat::Secs, true), + ) + .unwrap(); + self.set_tag("host_duration", format_duration(&duration)).unwrap(); + + let url = format!("{}/api/2.0/mlflow/runs/update", self.base_url); + let params = UpdateRunParams { + run_id: &self.run_id, + status: "FINISHED".to_string(), + end_time: end_time.timestamp_millis(), + run_name: &self.run_name, + }; + let _resp = self + .client + .post(&url) + .basic_auth(&self.user_name, Some(&self.password)) + .json(¶ms) // auto serialize + .send() + .unwrap(); + // TODO: error handling caused by API call + } +} + +impl AggregateRecorder for MlflowTrackingRecorder { + fn flush(&mut self, step: i64) { + let mut record = self.storage.aggregate(); + record.insert("opt_steps", RecordValue::Scalar(step as _)); + self.write(record); + } + + fn store(&mut self, record: border_core::record::Record) { + self.storage.store(record); + } +} + +fn format_duration(dt: &Duration) -> String { + let mut seconds = dt.num_seconds(); + let mut minutes = seconds / 60; + seconds %= 60; + let hours = minutes / 60; + minutes %= 60; + format!("{:02}:{:02}:{:02}", hours, minutes, seconds) +} diff --git a/border-policy-no-backend/Cargo.toml b/border-policy-no-backend/Cargo.toml new file mode 100644 index 00000000..53bf671d --- /dev/null +++ b/border-policy-no-backend/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "border-policy-no-backend" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true +readme = "README.md" + +[dependencies] +border-core = { version = "0.0.7", path = "../border-core" } +border-tch-agent = { version = "0.0.7", path = "../border-tch-agent", optional = true } +serde = { workspace = true, features = ["derive"] } +log = { workspace = true } +anyhow = { workspace = true } +tch = { workspace = true, optional = true } + +[dev-dependencies] +tempdir = { workspace = true } +tch = { workspace = true } + + +[features] +border-tch-agent = ["dep:border-tch-agent", "dep:tch"] diff --git a/border-policy-no-backend/src/lib.rs b/border-policy-no-backend/src/lib.rs new file mode 100644 index 00000000..93053528 --- /dev/null +++ b/border-policy-no-backend/src/lib.rs @@ -0,0 +1,6 @@ +//! Policy with no backend. +mod mat; +mod mlp; + +pub use mat::Mat; +pub use mlp::Mlp; diff --git a/border-policy-no-backend/src/mat.rs b/border-policy-no-backend/src/mat.rs new file mode 100644 index 00000000..5a429cd8 --- /dev/null +++ b/border-policy-no-backend/src/mat.rs @@ -0,0 +1,107 @@ +//! A matrix object. 
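As a side note on the MLflow writer above: because the run is closed in `Drop`, letting the recorder go out of scope is enough to mark the run FINISHED and stamp the `host_end_time` and `host_duration` tags, the latter rendered as HH:MM:SS by `format_duration`. A quick standalone check of that arithmetic; `format_secs` is a local stand-in that takes whole seconds instead of a chrono `Duration`:

```rust
// Mirrors the integer math of `format_duration` above on a plain second count.
fn format_secs(total: i64) -> String {
    let mut seconds = total;
    let mut minutes = seconds / 60;
    seconds %= 60;
    let hours = minutes / 60;
    minutes %= 60;
    format!("{:02}:{:02}:{:02}", hours, minutes, seconds)
}

fn main() {
    assert_eq!(format_secs(3_661), "01:01:01"); // 1 hour, 1 minute, 1 second
    assert_eq!(format_secs(59), "00:00:59");
}
```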
+use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +pub struct Mat { + pub data: Vec, + pub shape: Vec, +} + +#[cfg(feature = "border-tch-agent")] +impl From for Mat { + fn from(x: tch::Tensor) -> Self { + let shape: Vec = x.size().iter().map(|e| *e as i32).collect(); + let (n, shape) = match shape.len() { + 1 => (shape[0] as usize, vec![shape[0], 1]), + 2 => ((shape[0] * shape[1]) as usize, shape), + _ => panic!("Invalid matrix size: {:?}", shape), + }; + let mut data: Vec = vec![0f32; n]; + x.f_copy_data(&mut data, n).unwrap(); + Self { data, shape } + } +} + +impl Mat { + pub fn matmul(&self, x: &Mat) -> Self { + let (m, l, n) = ( + self.shape[0] as usize, + self.shape[1] as usize, + x.shape[1] as usize, + ); + // println!("{:?}", (m, l, x.shape[0], n)); + let mut data = vec![0.0f32; (m * n) as usize]; + for i in 0..m as usize { + for j in 0..n as usize { + let kk = i * n as usize + j; + for k in 0..l as usize { + data[kk] += self.data[i * l + k] * x.data[k * n + j]; + } + } + } + + Self { + shape: vec![m as _, n as _], + data, + } + } + + pub fn add(&self, x: &Mat) -> Self { + if self.shape[0] != x.shape[0] || self.shape[1] != x.shape[1] { + panic!( + "Trying to add matrices of different sizes: {:?}", + (&self.shape, &x.shape) + ); + } + + let data = self + .data + .iter() + .zip(x.data.iter()) + .map(|(a, b)| *a + *b) + .collect(); + + Mat { + data, + shape: self.shape.clone(), + } + } + + pub fn relu(&self) -> Self { + let data = self + .data + .iter() + .map(|a| match *a < 0. { + true => 0., + false => *a, + }) + .collect(); + + Self { + data, + shape: self.shape.clone(), + } + } + + pub fn empty() -> Self { + Self { + data: vec![], + shape: vec![0, 0], + } + } + + pub fn shape(&self) -> &Vec { + &self.shape + } + + pub fn new(data: Vec, shape: Vec) -> Self { + Self { data, shape } + } +} + +impl From> for Mat { + fn from(x: Vec) -> Self { + let shape = vec![x.len() as i32, 1]; + Self { shape, data: x } + } +} diff --git a/border-policy-no-backend/src/mlp.rs b/border-policy-no-backend/src/mlp.rs new file mode 100644 index 00000000..8805a50d --- /dev/null +++ b/border-policy-no-backend/src/mlp.rs @@ -0,0 +1,44 @@ +use crate::Mat; +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "border-tch-agent")] +use tch::nn::VarStore; + +#[derive(Clone, Debug, Deserialize, Serialize)] +/// Multilayer perceptron with ReLU activation function. +pub struct Mlp { + /// Weights of layers. + ws: Vec, + + /// Biases of layers. 
+ bs: Vec, +} + +impl Mlp { + pub fn forward(&self, x: &Mat) -> Mat { + let n_layers = self.ws.len(); + let mut x = x.clone(); + for i in 0..n_layers { + x = self.ws[i].matmul(&x).add(&self.bs[i]); + if i != n_layers - 1 { + x = x.relu(); + } + } + x + } + + #[cfg(feature = "border-tch-agent")] + pub fn from_varstore(vs: &VarStore, w_names: &[&str], b_names: &[&str]) -> Self { + let vars = vs.variables(); + let ws: Vec = w_names + .iter() + .map(|name| vars[&name.to_string()].copy().into()) + .collect(); + let bs: Vec = b_names + .iter() + .map(|name| vars[&name.to_string()].copy().into()) + .collect(); + + Self { ws, bs } + } +} diff --git a/border-policy-no-backend/tests/test.rs b/border-policy-no-backend/tests/test.rs new file mode 100644 index 00000000..f1b66b07 --- /dev/null +++ b/border-policy-no-backend/tests/test.rs @@ -0,0 +1,24 @@ +use border_policy_no_backend::Mat; +use tch::Tensor; + +#[test] +fn test_matmul() { + let x1 = Tensor::from_slice2(&[&[1.0f32, 2., 3.], &[4., 5., 6.]]); + let y1 = Tensor::from_slice(&[7.0f32, 8., 9.]); + let z1 = x1.matmul(&y1); + + let x2: Mat = x1.into(); + let y2: Mat = y1.into(); + let z2 = x2.matmul(&y2); + + let z3 = { + let mut data = vec![0.0f32; 2]; + z1.f_copy_data(&mut data, 2).unwrap(); + Mat { + shape: vec![2 as _, 1 as _], + data, + } + }; + + assert_eq!(z2, z3) +} diff --git a/border-py-gym-env/Cargo.toml b/border-py-gym-env/Cargo.toml index b1592ee2..a996eccb 100644 --- a/border-py-gym-env/Cargo.toml +++ b/border-py-gym-env/Cargo.toml @@ -1,20 +1,16 @@ [package] name = "border-py-gym-env" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] -border-core = { version = "0.0.6", path = "../border-core" } +border-core = { version = "0.0.7", path = "../border-core" } numpy = { workspace = true } pyo3 = { workspace = true, default-features = false, features = [ "auto-initialize", @@ -26,14 +22,15 @@ ndarray = { workspace = true, features = ["serde"] } anyhow = { workspace = true } tch = { workspace = true, optional = true } image = { workspace = true } +candle-core = { workspace = true, optional = true } [dev-dependencies] fastrand = { workspace = true } env_logger = { workspace = true } csv = { workspace = true } -[features] -default = ["tch"] +# [features] +# default = ["tch"] [[example]] name = "random_cartpole" @@ -41,7 +38,7 @@ test = true [[example]] name = "random_lunarlander_cont" -test = true +test = false # due to box2d installation issue in github action [[example]] name = "random_fetch_reach" diff --git a/border-py-gym-env/examples/atari_wrappers.py b/border-py-gym-env/examples/atari_wrappers.py index e7fff735..04b13c93 100644 --- a/border-py-gym-env/examples/atari_wrappers.py +++ b/border-py-gym-env/examples/atari_wrappers.py @@ -215,7 +215,7 @@ def __init__(self, env=None): def observation(self, observation): return observation.transpose(2, 0, 1) - + # vecenv.py class VecEnv(object): """ diff --git a/border-py-gym-env/examples/random_ant_pybullet.rs b/border-py-gym-env/examples/backup/random_ant_pybullet.rs similarity index 100% rename from 
border-py-gym-env/examples/random_ant_pybullet.rs rename to border-py-gym-env/examples/backup/random_ant_pybullet.rs diff --git a/border-py-gym-env/examples/pybullet_pyo3.rs b/border-py-gym-env/examples/pybullet_pyo3.rs index 4e9fb69e..af96932e 100644 --- a/border-py-gym-env/examples/pybullet_pyo3.rs +++ b/border-py-gym-env/examples/pybullet_pyo3.rs @@ -1,7 +1,7 @@ //! This program is used to quickly check pybullet works properly with pyo3. use anyhow::Result; -use pyo3::{Python, types::IntoPyDict}; +use pyo3::{types::IntoPyDict, Python}; // With a version of tch, commenting out the following line causes segmentation fault. // use tch::Tensor; diff --git a/border-py-gym-env/examples/random_ant.rs b/border-py-gym-env/examples/random_ant.rs index 7dd43bc3..12b27bbd 100644 --- a/border-py-gym-env/examples/random_ant.rs +++ b/border-py-gym-env/examples/random_ant.rs @@ -1,9 +1,10 @@ use anyhow::Result; -use border_core::{DefaultEvaluator, Evaluator as _, Policy}; +use border_core::{Configurable, DefaultEvaluator, Evaluator as _, Policy}; use border_py_gym_env::{ ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; use ndarray::{Array, ArrayD, IxDyn}; +use serde::Deserialize; use std::default::Default; mod obs { @@ -58,18 +59,12 @@ type ActFilter = ContinuousActFilter; type Env = GymEnv; type Evaluator = DefaultEvaluator; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig; struct RandomPolicy; impl Policy for RandomPolicy { - type Config = RandomPolicyConfig; - - fn build(_config: Self::Config) -> Self { - Self - } - fn sample(&mut self, _: &Obs) -> Act { Act::new( Array::from( @@ -82,6 +77,14 @@ impl Policy for RandomPolicy { } } +impl Configurable for RandomPolicy { + type Config = RandomPolicyConfig; + + fn build(_config: Self::Config) -> Self { + Self + } +} + fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); fastrand::seed(42); diff --git a/border-py-gym-env/examples/random_cartpole.rs b/border-py-gym-env/examples/random_cartpole.rs index 28b0cf10..11ac54d5 100644 --- a/border-py-gym-env/examples/random_cartpole.rs +++ b/border-py-gym-env/examples/random_cartpole.rs @@ -1,9 +1,9 @@ use anyhow::Result; -use border_core::{record::Record, DefaultEvaluator, Evaluator as _, Policy}; +use border_core::{record::Record, Configurable, DefaultEvaluator, Evaluator as _, Policy}; use border_py_gym_env::{ ArrayObsFilter, DiscreteActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::convert::TryFrom; type PyObsDtype = f32; @@ -60,22 +60,24 @@ type ActFilter = DiscreteActFilter; type Env = GymEnv; type Evaluator = DefaultEvaluator; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig; struct RandomPolicy; impl Policy for RandomPolicy { + fn sample(&mut self, _: &Obs) -> Act { + let v = fastrand::u32(..=1); + Act::new(vec![v as i32]) + } +} + +impl Configurable for RandomPolicy { type Config = RandomPolicyConfig; fn build(_config: Self::Config) -> Self { Self } - - fn sample(&mut self, _: &Obs) -> Act { - let v = fastrand::u32(..=1); - Act::new(vec![v as i32]) - } } #[derive(Debug, Serialize)] diff --git a/border-py-gym-env/examples/random_fetch_reach.rs b/border-py-gym-env/examples/random_fetch_reach.rs index fd114292..c7f8b5f4 100644 --- a/border-py-gym-env/examples/random_fetch_reach.rs +++ b/border-py-gym-env/examples/random_fetch_reach.rs @@ -1,5 +1,5 @@ use anyhow::Result; 
-use border_core::{DefaultEvaluator, Evaluator as _, Policy}; +use border_core::{Configurable, DefaultEvaluator, Evaluator as _, Policy}; use border_py_gym_env::{ util::ArrayType, ArrayDictObsFilter, ArrayDictObsFilterConfig, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, @@ -55,24 +55,19 @@ mod act { use act::Act; use obs::Obs; +use serde::Deserialize; type ObsFilter = ArrayDictObsFilter; type ActFilter = ContinuousActFilter; type Env = GymEnv; type Evaluator = DefaultEvaluator; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig; struct RandomPolicy; impl Policy for RandomPolicy { - type Config = RandomPolicyConfig; - - fn build(_config: Self::Config) -> Self { - Self - } - fn sample(&mut self, _: &Obs) -> Act { let x = 2. * fastrand::f32() - 1.; // let y = 2. * fastrand::f32() - 1.; @@ -80,6 +75,14 @@ impl Policy for RandomPolicy { } } +impl Configurable for RandomPolicy { + type Config = RandomPolicyConfig; + + fn build(_config: Self::Config) -> Self { + Self + } +} + fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); fastrand::seed(42); diff --git a/border-py-gym-env/examples/random_lunarlander_cont.rs b/border-py-gym-env/examples/random_lunarlander_cont.rs index c30252c1..8af6c5e7 100644 --- a/border-py-gym-env/examples/random_lunarlander_cont.rs +++ b/border-py-gym-env/examples/random_lunarlander_cont.rs @@ -1,5 +1,5 @@ use anyhow::Result; -use border_core::{DefaultEvaluator, Evaluator as _, Policy}; +use border_core::{Configurable, DefaultEvaluator, Evaluator as _, Policy}; use border_py_gym_env::{ ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, }; @@ -51,24 +51,19 @@ mod act { use act::Act; use obs::Obs; +use serde::Deserialize; type ObsFilter = ArrayObsFilter; type ActFilter = ContinuousActFilter; type Env = GymEnv; type Evaluator = DefaultEvaluator; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig; struct RandomPolicy; impl Policy for RandomPolicy { - type Config = RandomPolicyConfig; - - fn build(_config: Self::Config) -> Self { - Self - } - fn sample(&mut self, _: &Obs) -> Act { let x = 2. * fastrand::f32() - 1.; let y = 2. * fastrand::f32() - 1.; @@ -76,6 +71,14 @@ impl Policy for RandomPolicy { } } +impl Configurable for RandomPolicy { + type Config = RandomPolicyConfig; + + fn build(_config: Self::Config) -> Self { + Self + } +} + fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); fastrand::seed(42); diff --git a/border-py-gym-env/src/act/continuous_filter.rs b/border-py-gym-env/src/act/continuous_filter.rs index eb0350cf..68251c92 100644 --- a/border-py-gym-env/src/act/continuous_filter.rs +++ b/border-py-gym-env/src/act/continuous_filter.rs @@ -22,7 +22,7 @@ impl Default for ContinuousActFilterConfig { /// Raw filter for continuous actions. /// -/// Type `A` must implements `Into>` +/// Type `A` must implements `Into>`. #[derive(Clone, Debug)] pub struct ContinuousActFilter { // `true` indicates that this filter is used in a vectorized environment. diff --git a/border-py-gym-env/src/act_c.rs b/border-py-gym-env/src/act_c.rs deleted file mode 100644 index 574f1c13..00000000 --- a/border-py-gym-env/src/act_c.rs +++ /dev/null @@ -1,17 +0,0 @@ -//! Continuous action for [`GymEnv`](crate::GymEnv). 
-mod base; -pub use base::GymContinuousAct; -use ndarray::ArrayD; -use numpy::PyArrayDyn; -use pyo3::{IntoPy, PyObject}; - -/// Convert [`ArrayD`] to [`PyObject`]. -/// -/// This function does not support batch action. -pub fn to_pyobj(act: ArrayD) -> PyObject { - // let act = act.remove_axis(ndarray::Axis(0)); - pyo3::Python::with_gil(|py| { - let act = PyArrayDyn::::from_array(py, &act); - act.into_py(py) - }) -} diff --git a/border-py-gym-env/src/act_c/base.rs b/border-py-gym-env/src/act_c/base.rs deleted file mode 100644 index 0901634d..00000000 --- a/border-py-gym-env/src/act_c/base.rs +++ /dev/null @@ -1,30 +0,0 @@ -use border_core::Act; -use ndarray::ArrayD; -use std::fmt::Debug; - -/// Represents an action. -#[derive(Clone, Debug)] -pub struct GymContinuousAct { - /// Stores an action. - pub act: ArrayD, -} - -impl GymContinuousAct { - /// Constructs an action. - pub fn new(act: ArrayD) -> Self { - Self { - act, - } - } -} - -impl Act for GymContinuousAct { - fn len(&self) -> usize { - let shape = self.act.shape(); - if shape.len() == 1 { - 1 - } else { - shape[0] - } - } -} diff --git a/border-py-gym-env/src/act_d.rs b/border-py-gym-env/src/act_d.rs deleted file mode 100644 index 7cac4219..00000000 --- a/border-py-gym-env/src/act_d.rs +++ /dev/null @@ -1,3 +0,0 @@ -//! Discrete action for [`GymEnv`](crate::GymEnv). -mod base; -pub use base::GymDiscreteAct; diff --git a/border-py-gym-env/src/act_d/base.rs b/border-py-gym-env/src/act_d/base.rs deleted file mode 100644 index 5afb829b..00000000 --- a/border-py-gym-env/src/act_d/base.rs +++ /dev/null @@ -1,21 +0,0 @@ -use border_core::Act; -use std::fmt::Debug; - -/// Represents action. -#[derive(Clone, Debug)] -pub struct GymDiscreteAct { - pub act: Vec, -} - -impl GymDiscreteAct { - /// Constructs a discrete action. - pub fn new(act: Vec) -> Self { - Self { act } - } -} - -impl Act for GymDiscreteAct { - fn len(&self) -> usize { - self.act.len() - } -} diff --git a/border-py-gym-env/src/atari.rs b/border-py-gym-env/src/atari.rs index ffc833a9..cd3b130e 100644 --- a/border-py-gym-env/src/atari.rs +++ b/border-py-gym-env/src/atari.rs @@ -1,9 +1,10 @@ //! Parameters of atari environments -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] /// Specifies training or evaluation mode. #[derive(Clone)] +// TODO: consider to remove this enum pub enum AtariWrapper { /// Training mode Train, diff --git a/border-py-gym-env/src/base.rs b/border-py-gym-env/src/base.rs index debdd2fe..e0d09954 100644 --- a/border-py-gym-env/src/base.rs +++ b/border-py-gym-env/src/base.rs @@ -2,7 +2,10 @@ #![allow(clippy::float_cmp)] use crate::{AtariWrapper, GymEnvConfig}; use anyhow::Result; -use border_core::{record::Record, Act, Env, Info, Obs, Step}; +use border_core::{ + record::{Record, RecordValue::Scalar}, + Act, Env, Info, Obs, Step, +}; use log::{info, trace}; // use pyo3::IntoPy; use pyo3::types::{IntoPyDict, PyTuple}; @@ -19,6 +22,8 @@ pub struct GymInfo {} impl Info for GymInfo {} /// Convert [`PyObject`] to [`GymEnv`]::Obs with a preprocessing. +/// +/// [`PyObject`]: https://docs.rs/pyo3/0.14.5/pyo3/type.PyObject.html pub trait GymObsFilter { /// Configuration. type Config: Clone + Default + Serialize + DeserializeOwned; @@ -47,7 +52,7 @@ pub trait GymObsFilter { /// Convert [`GymEnv`]::Act to [`PyObject`] with a preprocessing. /// -/// This trait should support vectorized environments. 
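+///
+/// Implementations convert an agent-side action of type `A` into the Python object
+/// handed to `env.step()`, and may also report diagnostics about the conversion
+/// through the returned [`Record`](border_core::record::Record).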
+/// [`PyObject`]: https://docs.rs/pyo3/0.14.5/pyo3/type.PyObject.html pub trait GymActFilter { /// Configuration. type Config: Clone + Default + Serialize + DeserializeOwned; @@ -76,7 +81,7 @@ pub trait GymActFilter { } } -/// An environment in [OpenAI gym](https://github.com/openai/gym). +/// An wrapper of [Gymnasium](https://gymnasium.farama.org). #[derive(Debug)] pub struct GymEnv where @@ -171,14 +176,15 @@ where Self: Sized, { let (step, record) = self.step(a); - assert_eq!(step.is_done.len(), 1); - let step = if step.is_done[0] == 1 { + assert_eq!(step.is_terminated.len(), 1); + let step = if step.is_done() { let init_obs = self.reset(None).unwrap(); Step { act: step.act, obs: step.obs, reward: step.reward, - is_done: step.is_done, + is_terminated: step.is_terminated, + is_truncated: step.is_truncated, info: step.info, init_obs, } @@ -266,17 +272,19 @@ where /// It returns [`Step`] and [`Record`] objects. /// The [`Record`] is composed of [`Record`]s constructed in [`GymObsFilter`] and /// [`GymActFilter`]. - fn step(&mut self, a: &A) -> (Step, Record) { - fn is_done(step: &PyTuple) -> i8 { + fn step(&mut self, act: &A) -> (Step, Record) { + fn is_done(step: &PyTuple) -> (i8, i8) { // terminated or truncated - let terminated: bool = step.get_item(2).extract().unwrap(); - let truncated: bool = step.get_item(3).extract().unwrap(); + let is_terminated = match step.get_item(2).extract().unwrap() { + true => 1, + false => 0, + }; + let is_truncated = match step.get_item(3).extract().unwrap() { + true => 1, + false => 0, + }; - if terminated | truncated { - 1 - } else { - 0 - } + (is_terminated, is_truncated) } trace!("PyGymEnv::step()"); @@ -296,26 +304,68 @@ where std::thread::sleep(self.wait); } - let (a_py, record_a) = self.act_filter.filt(a.clone()); - let ret = self.env.call_method(py, "step", (a_py,), None).unwrap(); - let step: &PyTuple = ret.extract(py).unwrap(); - let obs = step.get_item(0).to_owned(); - let (obs, record_o) = self.obs_filter.filt(obs.to_object(py)); - let reward: Vec = vec![step.get_item(1).extract().unwrap()]; - let mut is_done: Vec = vec![is_done(step)]; + // State transition + let ( + act, + next_obs, + reward, + is_terminated, + mut is_truncated, + mut record, + info, + init_obs, + ) = { + let (a_py, record_a) = self.act_filter.filt(act.clone()); + let ret = self.env.call_method(py, "step", (a_py,), None).unwrap(); + let step: &PyTuple = ret.extract(py).unwrap(); + let next_obs = step.get_item(0).to_owned(); + let (next_obs, record_o) = self.obs_filter.filt(next_obs.to_object(py)); + let reward: Vec = vec![step.get_item(1).extract().unwrap()]; + let (is_terminated, is_truncated) = is_done(step); + let is_terminated = vec![is_terminated]; + let is_truncated = vec![is_truncated]; + let record = record_o.merge(record_a); + let info = GymInfo {}; + let init_obs = O::dummy(1); + let act = act.clone(); + + ( + act, + next_obs, + reward, + is_terminated, + is_truncated, + record, + info, + init_obs, + ) + }; - // let c = *self.count_steps.borrow(); self.count_steps += 1; //.replace(c + 1); + + // Terminated or truncated if let Some(max_steps) = self.max_steps { if self.count_steps >= max_steps { - is_done[0] = 1; - self.count_steps = 0; + is_truncated[0] = 1; } }; + if (is_terminated[0] | is_truncated[0]) == 1 { + record.insert("episode_length", Scalar(self.count_steps as _)); + self.count_steps = 0; + } + ( - Step::::new(obs, a.clone(), reward, is_done, GymInfo {}, O::dummy(1)), - record_o.merge(record_a), + Step::new( + next_obs, + act, + reward, + is_terminated, 
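+                    // Termination and truncation are passed on separately;
+                    // agents (e.g. DQN) stop bootstrapping only on `is_terminated`.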
+ is_truncated, + info, + init_obs, + ), + record, ) }) } diff --git a/border-py-gym-env/src/config.rs b/border-py-gym-env/src/config.rs index 7ad50f40..3f5a6aab 100644 --- a/border-py-gym-env/src/config.rs +++ b/border-py-gym-env/src/config.rs @@ -39,7 +39,7 @@ where pub act_filter_config: Option, /// Wait time at every interaction steps. - pub wait: Duration + pub wait: Duration, } impl Clone for GymEnvConfig diff --git a/border-py-gym-env/src/lib.rs b/border-py-gym-env/src/lib.rs index 2b2560ed..f4ee5451 100644 --- a/border-py-gym-env/src/lib.rs +++ b/border-py-gym-env/src/lib.rs @@ -4,34 +4,43 @@ //! It has been tested on some of [classic control](https://gymnasium.farama.org/environments/classic_control/) and //! [Gymnasium-Robotics](https://robotics.farama.org) environments. //! -//! ```note -//! In a past, [`Atari`](https://gym.openai.com/envs/#atari), and -//! [`PyBullet`](https://github.com/benelot/pybullet-gym) environments were supported. -//! However, currently they are not tested. -//! ``` -//! -//! This wrapper accepts array-like observation and action -//! ([`Box`](https://github.com/openai/gym/blob/master/gym/spaces/box.py) spaces), and -//! discrete action. In order to interact with Python interpreter where gym is running, -//! [`GymObsFilter`] and [`GymActFilter`] provides interfaces for converting Python object -//! (numpy array) to/from ndarrays in Rust. [`GymObsFilter`], -//! [`ContinuousActFilter`] and [`DiscreteActFilter`] do the conversion for environments -//! where observation and action are arrays. In addition to the data conversion between Python and Rust, -//! we can implements arbitrary preprocessing in these filters. For example, [`FrameStackFilter`] keeps -//! four consevutive observation frames (images) and outputs a stack of these frames. -//! -//! For Atari environments, a tweaked version of -//! [`atari_wrapper.py`](https://github.com/taku-y/border/blob/main/examples/atari_wrappers.py) -//! is required to be in `PYTHONPATH`. The frame stacking preprocessing is implemented in -//! [`FrameStackFilter`] as an [`GymObsFilter`]. -//! -//! Examples with a random controller ([`Policy`](border_core::Policy)) are in -//! [`examples`](https://github.com/taku-y/border/blob/main/border-py-gym-env/examples) directory. -//! Examples with `border-tch-agents`, which are collections of RL agents implemented with tch-rs, -//! are in [here](https://github.com/taku-y/border/blob/main/border/examples). +//! In order to bridge Python and Rust, we need to convert Python objects to Rust objects and vice versa. +//! +//! ## Observation +//! +//! Observation is created in Python and passed to Rust as a Python object. In order to convert +//! Python object to Rust object, this crate provides [`GymObsFilter`] trait. This trait has +//! [`GymObsFilter::filt`] method which converts Python object to Rust object. +//! The type of the Rust object after conversion corresponds to the type parameter `O` of the trait +//! and this is also the type of the observation in the environment, i.e., [`GymEnv`]`::Obs`. +//! +//! There are two built-in implementations of [`GymObsFilter`]: [`ArrayObsFilter`] and [`ArrayDictObsFilter`]. +//! [`ArrayObsFilter`] is for environments where observation is an array (e.g., CartPole). +//! Internally, the array is converted to [`ndarray::ArrayD`] from Python object. +//! Then, the array is converted to the type parameter `O` of the filter. +//! Since `O` must implement [`From`] by trait bound, the conversion is done +//! by calling `array.into()`. +//! +//! 
[`ArrayDictObsFilter`] is for environments where observation is a dictionary of arrays (e.g., FetchPickAndPlace). +//! Internally, the dictionary is converted to `Vec<(String, border_py_gym_env:util::Array)>` from Python object. +//! Then, `Vec<(String, border_py_gym_env:util::Array)>` is converted to `O` by calling `into()`. +//! +//! ## Action +//! +//! Action is created in [`Policy`] and passed to Python as a Python object. In order to convert +//! Rust object to Python object, this crate provides [`GymActFilter`] trait. This trait has +//! [`GymActFilter::filt`] method which converts Rust object of type `A`, which is the type parameter of +//! the trait, to Python object. +//! +//! There are two built-in implementations of [`GymActFilter`]: [`DiscreteActFilter`] and [`ContinuousActFilter`]. +//! [`DiscreteActFilter`] is for environments where action is discrete (e.g., CartPole). +//! This filter converts `A` to [`Vec`] and then to Python object. +//! [`ContinuousActFilter`] is for environments where action is continuous (e.g., Pendulum). +//! This filter converts `A` to [`ArrayD`] and then to Python object. +//! +//! [`Policy`]: border_core::Policy +//! [`ArrayD`]: https://docs.rs/ndarray/0.15.1/ndarray/type.ArrayD.html mod act; -mod act_c; -mod act_d; mod atari; mod base; mod config; @@ -41,13 +50,10 @@ mod vec; pub use act::{ ContinuousActFilter, ContinuousActFilterConfig, DiscreteActFilter, DiscreteActFilterConfig, }; -pub use act_c::{to_pyobj, GymContinuousAct}; -pub use act_d::GymDiscreteAct; -pub use atari::AtariWrapper; +use atari::AtariWrapper; pub use base::{GymActFilter, GymEnv, GymInfo, GymObsFilter}; pub use config::GymEnvConfig; #[allow(deprecated)] pub use obs::{ - ArrayDictObsFilter, ArrayDictObsFilterConfig, ArrayObsFilter, FrameStackFilter, GymObs, + ArrayDictObsFilter, ArrayDictObsFilterConfig, ArrayObsFilter, ArrayObsFilterConfig, }; -// pub use vec::{PyVecGymEnv, PyVecGymEnvConfig}; diff --git a/border-py-gym-env/src/obs.rs b/border-py-gym-env/src/obs.rs index 40a6b2a7..f40a5d3e 100644 --- a/border-py-gym-env/src/obs.rs +++ b/border-py-gym-env/src/obs.rs @@ -1,10 +1,5 @@ //! Observation for [`GymEnv`](crate::GymEnv). -mod base; -mod frame_stack_filter; -mod array_filter; mod array_dict_filter; -#[allow(deprecated)] -pub use base::GymObs; -pub use frame_stack_filter::{FrameStackFilter, FrameStackFilterConfig}; -pub use array_filter::{ArrayObsFilter, ArrayObsFilterConfig}; +mod array_filter; pub use array_dict_filter::{ArrayDictObsFilter, ArrayDictObsFilterConfig}; +pub use array_filter::{ArrayObsFilter, ArrayObsFilterConfig}; diff --git a/border-py-gym-env/src/obs/array_dict_filter.rs b/border-py-gym-env/src/obs/array_dict_filter.rs index b68226af..6f5a2445 100644 --- a/border-py-gym-env/src/obs/array_dict_filter.rs +++ b/border-py-gym-env/src/obs/array_dict_filter.rs @@ -28,9 +28,10 @@ impl Default for ArrayDictObsFilterConfig { impl ArrayDictObsFilterConfig { pub fn add_key_and_types(self, key_and_types: Vec<(impl Into, ArrayType)>) -> Self { - let key_and_types = key_and_types.into_iter().map(|(k, t)| { - (k.into(), t) - }).collect::>(); + let key_and_types = key_and_types + .into_iter() + .map(|(k, t)| (k.into(), t)) + .collect::>(); let mut config = self; config.key_and_types.extend(key_and_types); config @@ -84,7 +85,7 @@ where /// observation, for either of single and vectorized environments. 
fn filt(&mut self, obs: PyObject) -> (O, Record) where - O: From> + O: From>, { let obs = pyo3::Python::with_gil(|py| { self.config @@ -103,10 +104,10 @@ where obs.iter().for_each(|(key, arr)| { if keys.contains(key) { let v = arr.to_flat_vec::(); - record.insert(key, RecordValue::Array1(v)) + record.insert(key, RecordValue::Array1(v)) } }); - record + record } }; (obs.into(), record) diff --git a/border-py-gym-env/src/obs/array_filter.rs b/border-py-gym-env/src/obs/array_filter.rs index 1717a50f..0cb4023a 100644 --- a/border-py-gym-env/src/obs/array_filter.rs +++ b/border-py-gym-env/src/obs/array_filter.rs @@ -1,4 +1,4 @@ -use crate::{GymObsFilter, util::pyobj_to_arrayd}; +use crate::{util::pyobj_to_arrayd, GymObsFilter}; use border_core::{ record::{Record, RecordValue}, Obs, @@ -24,6 +24,8 @@ impl Default for ArrayObsFilterConfig { /// An observation filter that convertes PyObject of an numpy array. /// /// Type parameter `O` must implements [`From`]`` and [`border_core::Obs`]. +/// +/// [`border_core::Obs`]: border_core::Obs pub struct ArrayObsFilter { /// Marker. pub phantom: PhantomData<(T1, T2, O)>, diff --git a/border-py-gym-env/src/obs/base.rs b/border-py-gym-env/src/obs/base.rs deleted file mode 100644 index b72fb680..00000000 --- a/border-py-gym-env/src/obs/base.rs +++ /dev/null @@ -1,119 +0,0 @@ -use crate::util::pyobj_to_arrayd; -use border_core::Obs; -use ndarray::{ArrayD, IxDyn}; -use num_traits::cast::AsPrimitive; -use numpy::Element; -use pyo3::PyObject; -use std::fmt::Debug; -use std::marker::PhantomData; -#[cfg(feature = "tch")] -use {std::convert::TryFrom, tch::Tensor}; - -/// Observation represented by an [ndarray::ArrayD]. -/// -/// `S` is the shape of an observation, except for batch and process dimensions. -/// `T` is the dtype of ndarray in the Python gym environment. -/// For some reason, the dtype of observations in Python gym environments seems to -/// vary, f32 or f64. To get observations in Rust side, the dtype is specified as a -/// type parameter, instead of checking the dtype of Python array at runtime. -#[deprecated] -#[derive(Clone, Debug)] -pub struct GymObs -where - T1: Element + Debug, - T2: 'static + Copy, -{ - pub obs: ArrayD, - pub(crate) phantom: PhantomData, -} - -#[allow(deprecated)] -impl From> for GymObs -where - T1: Element + Debug, - T2: 'static + Copy, -{ - fn from(obs: ArrayD) -> Self { - Self { - obs, - phantom: PhantomData, - } - } -} - -#[allow(deprecated)] -impl Obs for GymObs -where - T1: Debug + Element, - T2: 'static + Copy + Debug + num_traits::Zero, -{ - fn dummy(_n_procs: usize) -> Self { - // let shape = &mut S::shape().to_vec(); - // shape.insert(0, n_procs as _); - // trace!("Shape of TchPyGymEnvObs: {:?}", shape); - let shape = vec![0]; - Self { - obs: ArrayD::zeros(IxDyn(&shape[..])), - phantom: PhantomData, - } - } - - fn len(&self) -> usize { - self.obs.shape()[0] - } -} - -/// Convert numpy array of Python into [`GymObs`]. 
-#[allow(deprecated)] -impl From for GymObs -where - T1: Element + AsPrimitive + std::fmt::Debug, - T2: 'static + Copy, -{ - fn from(obs: PyObject) -> Self { - Self { - obs: pyobj_to_arrayd::(obs), - phantom: PhantomData, - } - } -} - -// #[cfg(feature = "tch")] -// impl From> for Tensor -// where -// S: Shape, -// T1: Element + Debug, -// T2: 'static + Copy, -// { -// fn from(obs: PyGymEnvObs) -> Tensor { -// let tmp = &obs.obs; -// Tensor::try_from(tmp).unwrap() -// // Tensor::try_from(&obs.obs).unwrap() -// } -// } - -#[allow(deprecated)] -#[cfg(feature = "tch")] -impl From> for Tensor -where - T1: Element + Debug, -{ - fn from(obs: GymObs) -> Tensor { - let tmp = &obs.obs; - Tensor::try_from(tmp).unwrap() - // Tensor::try_from(&obs.obs).unwrap() - } -} - -#[allow(deprecated)] -#[cfg(feature = "tch")] -impl From> for Tensor -where - T1: Element + Debug, -{ - fn from(obs: GymObs) -> Tensor { - let tmp = &obs.obs; - Tensor::try_from(tmp).unwrap() - // Tensor::try_from(&obs.obs).unwrap() - } -} diff --git a/border-py-gym-env/src/obs/frame_stack_filter.rs b/border-py-gym-env/src/obs/frame_stack_filter.rs deleted file mode 100644 index 4f42f0de..00000000 --- a/border-py-gym-env/src/obs/frame_stack_filter.rs +++ /dev/null @@ -1,250 +0,0 @@ -//! An observation filter with stacking observations (frames). -#[allow(deprecated)] -use super::GymObs; -use crate::GymObsFilter; -use border_core::{ - record::{Record, RecordValue}, - Obs, -}; -use ndarray::{ArrayD, Axis, SliceInfoElem}; //, SliceOrIndex}; - // use ndarray::{stack, ArrayD, Axis, IxDyn, SliceInfo, SliceInfoElem}; -use num_traits::cast::AsPrimitive; -use numpy::{Element, PyArrayDyn}; -use pyo3::{PyAny, PyObject}; -// use pyo3::{types::PyList, Py, PyAny, PyObject}; -use serde::{Deserialize, Serialize}; -use std::{fmt::Debug, marker::PhantomData}; -// use std::{convert::TryFrom, fmt::Debug, marker::PhantomData}; - -#[allow(deprecated)] -#[derive(Debug, Serialize, Deserialize)] -/// Configuration of [FrameStackFilter]. -#[derive(Clone)] -pub struct FrameStackFilterConfig { - n_procs: i64, - n_stack: i64, - vectorized: bool, -} - -impl Default for FrameStackFilterConfig { - fn default() -> Self { - Self { - n_procs: 1, - n_stack: 4, - vectorized: false, - } - } -} - -/// An observation filter with stacking sequence of original observations. -/// -/// The first element of the shape `S` denotes the number of stacks (`n_stack`) and the following elements -/// denote the shape of the partial observation, which is the observation of each environment -/// in the vectorized environment. -#[allow(deprecated)] -#[derive(Debug)] -pub struct FrameStackFilter -where - T1: Element + Debug + num_traits::identities::Zero + AsPrimitive, - T2: 'static + Copy + num_traits::Zero, - U: Obs + From>, -{ - // Each element in the vector corresponds to a process. - buffers: Vec>>, - - #[allow(dead_code)] - n_procs: i64, - - n_stack: i64, - - shape: Option>, - - // Verctorized environment is not supported - vectorized: bool, - - phantom: PhantomData<(T1, U)>, -} - -#[allow(deprecated)] -impl FrameStackFilter -where - T1: Element + Debug + num_traits::identities::Zero + AsPrimitive, - T2: 'static + Copy + num_traits::Zero, - U: Obs + From>, -{ - /// Returns the default configuration. - pub fn default_config() -> FrameStackFilterConfig { - FrameStackFilterConfig::default() - } - - /// Create slice for a dynamic array: equivalent to arr[j:(j+1), ::] in numpy. 
- /// - /// See - fn s(shape: &Option>, j: usize) -> Vec { - // The first index of self.shape corresponds to stacking dimension, - // specific index. - let mut slicer = vec![SliceInfoElem::Index(j as isize)]; - - // For remaining dimensions, all elements will be taken. - let n = shape.as_ref().unwrap().len() - 1; - let (start, end, step) = (0, None, 1); - - slicer.extend(vec![SliceInfoElem::Slice { start, end, step }; n]); - slicer - } - - /// Update the buffer of the stacked observations. - /// - /// * `i` - Index of process. - fn update_buffer(&mut self, i: i64, obs: &ArrayD) { - let arr = if let Some(arr) = &mut self.buffers[i as usize] { - arr - } else { - let mut shape = obs.shape().to_vec(); - self.shape = Some(shape.clone()); - shape.insert(0, self.n_stack as _); - self.buffers[i as usize] = Some(ArrayD::zeros(shape)); - self.buffers[i as usize].as_mut().unwrap() - }; - - // Shift stacks frame(j) <- frame(j - 1) for j=1,..,(n_stack - 1) - for j in (1..self.n_stack as usize).rev() { - let dst_slice = Self::s(&self.shape, j); - let src_slice = Self::s(&self.shape, j - 1); - let (mut dst, src) = arr.multi_slice_mut((dst_slice.as_slice(), src_slice.as_slice())); - dst.assign(&src); - } - arr.slice_mut(Self::s(&self.shape, 0).as_slice()) - .assign(obs) - } - - /// Fill the buffer, invoked when resetting - fn fill_buffer(&mut self, i: i64, obs: &ArrayD) { - if let Some(arr) = &mut self.buffers[i as usize] { - for j in (0..self.n_stack as usize).rev() { - let mut dst = arr.slice_mut(Self::s(&self.shape, j).as_slice()); - dst.assign(&obs); - } - } else { - unimplemented!("fill_buffer() was called before receiving the first sample."); - } - } - - /// Get ndarray from pyobj - fn get_ndarray(o: &PyAny) -> ArrayD { - debug_assert_eq!(o.get_type().name().unwrap(), "ndarray"); - let o: &PyArrayDyn = o.extract().unwrap(); - let o = o.to_owned_array(); - let o = o.mapv(|elem| elem.as_()); - o - } -} - -#[allow(deprecated)] -impl GymObsFilter for FrameStackFilter -where - T1: Element + Debug + num_traits::identities::Zero + AsPrimitive, - T2: 'static + Copy + num_traits::Zero + Into, - U: Obs + From>, -{ - type Config = FrameStackFilterConfig; - - fn build(config: &Self::Config) -> anyhow::Result - where - Self: Sized, - { - Ok(FrameStackFilter { - buffers: vec![None; config.n_procs as usize], - n_procs: config.n_procs, - n_stack: config.n_stack, - shape: None, - vectorized: config.vectorized, - phantom: PhantomData, - }) - } - - fn filt(&mut self, obs: PyObject) -> (U, Record) { - if self.vectorized { - unimplemented!(); - // // Processes the input observation to update `self.buffer` - // pyo3::Python::with_gil(|py| { - // debug_assert_eq!(obs.as_ref(py).get_type().name().unwrap(), "list"); - - // let obs: Py = obs.extract(py).unwrap(); - - // for (i, o) in (0..self.n_procs).zip(obs.as_ref(py).iter()) { - // let o = Self::get_ndarray(o); - // self.update_buffer(i, &o); - // } - // }); - - // // Returned values - // let array_views: Vec<_> = self.buffer.iter().map(|a| a.view()).collect(); - // let obs = PyGymEnvObs::from(stack(Axis(0), array_views.as_slice()).unwrap()); - // let obs = U::from(obs); - - // // TODO: add contents in the record - // let record = Record::empty(); - - // (obs, record) - } else { - // Update the buffer with obs - pyo3::Python::with_gil(|py| { - debug_assert_eq!(obs.as_ref(py).get_type().name().unwrap(), "ndarray"); - let o = Self::get_ndarray(obs.as_ref(py)); - self.update_buffer(0, &o); - }); - - // Returns stacked observation in the buffer - // img.shape() = [1, 4, 1, 
84, 84] - // [batch_size, n_stack, color_ch, width, height] - let img = self.buffers[0].clone().unwrap().insert_axis(Axis(0)); - let data = img.iter().map(|&e| e.into()).collect::>(); - let shape = [img.shape()[3] * self.n_stack as usize, img.shape()[4]]; - - let obs = GymObs::from(img); - let obs = U::from(obs); - - // TODO: add contents in the record - let mut record = Record::empty(); - record.insert("frame_stack_filter_out", RecordValue::Array2(data, shape)); - - (obs, record) - } - } - - fn reset(&mut self, obs: PyObject) -> U { - if self.vectorized { - unimplemented!(); - // pyo3::Python::with_gil(|py| { - // debug_assert_eq!(obs.as_ref(py).get_type().name().unwrap(), "list"); - - // let obs: Py = obs.extract(py).unwrap(); - - // for (i, o) in (0..self.n_procs).zip(obs.as_ref(py).iter()) { - // if o.get_type().name().unwrap() != "NoneType" { - // let o = Self::get_ndarray(o); - // self.fill_buffer(i, &o); - // } - // } - // }); - - // // Returned values - // let array_views: Vec<_> = self.buffer.iter().map(|a| a.view()).collect(); - // O::from(stack(Axis(0), array_views.as_slice()).unwrap()) - } else { - // Update the buffer if obs is not None, otherwise do nothing - pyo3::Python::with_gil(|py| { - if obs.as_ref(py).get_type().name().unwrap() != "NoneType" { - debug_assert_eq!(obs.as_ref(py).get_type().name().unwrap(), "ndarray"); - let o = Self::get_ndarray(obs.as_ref(py)); - self.fill_buffer(0, &o); - } - }); - - // Returns stacked observation in the buffer - let frames = self.buffers[0].clone().unwrap().insert_axis(Axis(0)); - U::from(GymObs::from(frames)) - } - } -} diff --git a/border-py-gym-env/src/util.rs b/border-py-gym-env/src/util.rs index 31ffdc08..117eed69 100644 --- a/border-py-gym-env/src/util.rs +++ b/border-py-gym-env/src/util.rs @@ -1,3 +1,4 @@ +//! Utility functions mainly for data conversion between Python and Rust. 
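+//!
+//! For example, the `tch`-backed helpers (available when the crate is built with
+//! the `tch` feature) can round-trip between `ndarray` and `tch::Tensor`. The
+//! snippet below is a sketch and is not compiled as a doc test:
+//!
+//! ```ignore
+//! use border_py_gym_env::util::{arrayd_to_tensor, tensor_to_arrayd};
+//! use ndarray::ArrayD;
+//!
+//! // A 1-D f32 observation, converted to a tensor with a leading batch dimension.
+//! let obs: ArrayD<f32> = ndarray::arr1(&[0.1f32, 0.2, 0.3]).into_dyn();
+//! let t = arrayd_to_tensor::<f32, f32>(obs, true); // shape [1, 3]
+//! let back: ArrayD<f32> = tensor_to_arrayd(t, true); // batch dimension removed
+//! ```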
use ndarray::{concatenate, ArrayD, Axis}; use num_traits::cast::AsPrimitive; use numpy::{Element, PyArrayDyn}; @@ -34,57 +35,118 @@ pub fn arrayd_to_pyobj(act: ArrayD) -> PyObject { } #[cfg(feature = "tch")] -use {std::convert::TryFrom, tch::Tensor}; +mod _tch { + use super::*; + use {std::convert::TryFrom, tch::Tensor}; -#[cfg(feature = "tch")] -pub fn vec_to_tensor(v: Vec, add_batch_dim: bool) -> Tensor -where - T1: num_traits::AsPrimitive, - T2: Copy + 'static + tch::kind::Element, -{ - let v = v.iter().map(|e| e.as_()).collect::>(); - let t: Tensor = TryFrom::>::try_from(v).unwrap(); + pub fn vec_to_tensor(v: Vec, add_batch_dim: bool) -> Tensor + where + T1: num_traits::AsPrimitive, + T2: Copy + 'static + tch::kind::Element, + { + let v = v.iter().map(|e| e.as_()).collect::>(); + let t: Tensor = TryFrom::>::try_from(v).unwrap(); - match add_batch_dim { - true => t.unsqueeze(0), - false => t, + match add_batch_dim { + true => t.unsqueeze(0), + false => t, + } } -} -#[cfg(feature = "tch")] -pub fn arrayd_to_tensor(a: ArrayD, add_batch_dim: bool) -> Tensor -where - T1: num_traits::AsPrimitive, - T2: Copy + 'static + tch::kind::Element, -{ - let v = a.iter().map(|e| e.as_()).collect::>(); - let t: Tensor = TryFrom::>::try_from(v).unwrap(); + pub fn arrayd_to_tensor(a: ArrayD, add_batch_dim: bool) -> Tensor + where + T1: num_traits::AsPrimitive, + T2: Copy + 'static + tch::kind::Element, + { + let v = a.iter().map(|e| e.as_()).collect::>(); + let t: Tensor = TryFrom::>::try_from(v).unwrap(); + + match add_batch_dim { + true => t.unsqueeze(0), + false => t, + } + } - match add_batch_dim { - true => t.unsqueeze(0), - false => t, + pub fn tensor_to_arrayd(t: Tensor, delete_batch_dim: bool) -> ArrayD + where + T: tch::kind::Element + Copy, + { + let shape = match delete_batch_dim { + false => t.size()[..].iter().map(|x| *x as usize).collect::>(), + true => t.size()[1..] + .iter() + .map(|x| *x as usize) + .collect::>(), + }; + let v = + Vec::::try_from(&t.flatten(0, -1)).expect("Failed to convert from Tensor to Vec"); + + ndarray::Array1::::from(v) + .into_shape(ndarray::IxDyn(&shape)) + .unwrap() } } #[cfg(feature = "tch")] -pub fn tensor_to_arrayd(t: Tensor, delete_batch_dim: bool) -> ArrayD -where - T: tch::kind::Element, -{ - let shape = match delete_batch_dim { - false => t.size()[..].iter().map(|x| *x as usize).collect::>(), - true => t.size()[1..] 
- .iter() - .map(|x| *x as usize) - .collect::>(), - }; - let v: Vec = t.into(); - - ndarray::Array1::::from(v) - .into_shape(ndarray::IxDyn(&shape)) - .unwrap() +pub use _tch::*; + +#[cfg(feature = "candle-core")] +mod _candle { + use super::*; + use anyhow::Result; + use candle_core::{Tensor, WithDType}; + use std::convert::TryFrom; + + pub fn vec_to_tensor(v: Vec, add_batch_dim: bool) -> Result + where + T1: num_traits::AsPrimitive, + T2: WithDType, + { + let v = v.iter().map(|e| e.as_()).collect::>(); + let t: Tensor = TryFrom::>::try_from(v).unwrap(); + + match add_batch_dim { + true => Ok(t.unsqueeze(0)?), + false => Ok(t), + } + } + + pub fn arrayd_to_tensor(a: ArrayD, add_batch_dim: bool) -> Result + where + T1: num_traits::AsPrimitive, + T2: WithDType, + { + let shape = a.shape(); + let v = a.iter().map(|e| e.as_()).collect::>(); + let t: Tensor = TryFrom::>::try_from(v)?; + let t = t.reshape(shape)?; + + match add_batch_dim { + true => Ok(t.unsqueeze(0)?), + false => Ok(t), + } + } + + pub fn tensor_to_arrayd(t: Tensor, delete_batch_dim: bool) -> Result> + where + T: WithDType, //tch::kind::Element, + { + let shape = match delete_batch_dim { + false => t.dims()[..].iter().map(|x| *x as usize).collect::>(), + true => t.dims()[1..] + .iter() + .map(|x| *x as usize) + .collect::>(), + }; + let v: Vec = t.flatten_all()?.to_vec1()?; + + Ok(ndarray::Array1::::from(v).into_shape(ndarray::IxDyn(&shape))?) + } } +#[cfg(feature = "candle-core")] +pub use _candle::*; + #[derive(Clone, Debug, Deserialize, Serialize)] pub enum ArrayType { F32Array, diff --git a/border-py-gym-env/src/vec.rs b/border-py-gym-env/src/vec.rs index 2c7659a4..7509f809 100644 --- a/border-py-gym-env/src/vec.rs +++ b/border-py-gym-env/src/vec.rs @@ -4,5 +4,5 @@ mod base; #[allow(dead_code)] mod config; // mod config; -pub use base::PyVecGymEnv; +// pub use base::PyVecGymEnv; pub use config::PyVecGymEnvConfig; diff --git a/border-py-gym-env/src/vec/base.rs b/border-py-gym-env/src/vec/base.rs index 155e8be0..ff2c7b00 100644 --- a/border-py-gym-env/src/vec/base.rs +++ b/border-py-gym-env/src/vec/base.rs @@ -1,14 +1,14 @@ //! Vectorized environment using multiprocess module in Python. #![allow(unused_variables, unreachable_code)] -use crate::AtariWrapper; use super::PyVecGymEnvConfig; -use crate::{GymActFilter, GymObsFilter, GymInfo}; +use crate::AtariWrapper; +use crate::{GymActFilter, GymInfo, GymObsFilter}; use anyhow::Result; use border_core::{record::Record, Act, Env, Obs, Step}; use log::trace; use pyo3::{ - types::{IntoPyDict, PyTuple}, - PyObject, ToPyObject, + types::{IntoPyDict /*PyTuple*/}, + PyObject, /*ToPyObject,*/ }; use std::{fmt::Debug, marker::PhantomData}; @@ -81,10 +81,12 @@ where AtariWrapper::Eval => false, }; // gym.call("make", (name, true, mode, config.n_procs), None)? - gym.getattr("make")?.call((name, true, mode, config.n_procs), None)? + gym.getattr("make")? + .call((name, true, mode, config.n_procs), None)? } else { // gym.call("make", (name, false, false, config.n_procs), None)? - gym.getattr("make")?.call((name, false, false, config.n_procs), None)? + gym.getattr("make")? + .call((name, false, false, config.n_procs), None)? 
}; Ok(PyVecGymEnv { @@ -129,37 +131,31 @@ where } fn step(&mut self, a: &A) -> (Step, Record) { - trace!("PyVecGymEnv::step()"); - trace!("{:?}", &a); - - pyo3::Python::with_gil(|py| { - // Does not support render - - let (a_py, record_a) = self.act_filter.filt(a.clone()); - let ret = self.env.call_method(py, "step", (a_py,), None).unwrap(); - let step: &PyTuple = ret.extract(py).unwrap(); - let obs = step.get_item(0).to_object(py); - let (obs, record_o) = self.obs_filter.filt(obs); - - // Reward and is_done - let reward = step.get_item(1).to_object(py); - let reward: Vec = reward.extract(py).unwrap(); - let is_done = step.get_item(2).to_object(py); - let is_done: Vec = is_done.extract(py).unwrap(); - let is_done: Vec = is_done.into_iter().map(|x| x as i8).collect(); - let n = obs.len(); - - let step = Step::::new( - obs, - a.clone(), - reward, - is_done, - GymInfo {}, - O::dummy(n), - ); - let record = record_o.merge(record_a); - - (step, record) - }) + unimplemented!(); + // trace!("PyVecGymEnv::step()"); + // trace!("{:?}", &a); + + // pyo3::Python::with_gil(|py| { + // // Does not support render + + // let (a_py, record_a) = self.act_filter.filt(a.clone()); + // let ret = self.env.call_method(py, "step", (a_py,), None).unwrap(); + // let step: &PyTuple = ret.extract(py).unwrap(); + // let obs = step.get_item(0).to_object(py); + // let (obs, record_o) = self.obs_filter.filt(obs); + + // // Reward and is_done + // let reward = step.get_item(1).to_object(py); + // let reward: Vec = reward.extract(py).unwrap(); + // let is_done = step.get_item(2).to_object(py); + // let is_done: Vec = is_done.extract(py).unwrap(); + // let is_done: Vec = is_done.into_iter().map(|x| x as i8).collect(); + // let n = obs.len(); + + // let step = Step::::new(obs, a.clone(), reward, is_done, GymInfo {}, O::dummy(n)); + // let record = record_o.merge(record_a); + + // (step, record) + // }) } } diff --git a/border-py-gym-env/src/vec/config.rs b/border-py-gym-env/src/vec/config.rs index e885c7b6..45b25a25 100644 --- a/border-py-gym-env/src/vec/config.rs +++ b/border-py-gym-env/src/vec/config.rs @@ -14,13 +14,13 @@ where AF: GymActFilter, { // Name of the environment - pub(super) name: String, - pub(super) max_steps: Option, - pub(super) atari_wrapper: Option, + pub name: String, + pub max_steps: Option, + pub atari_wrapper: Option, // The number of processes - pub(super) n_procs: usize, - pub(super) obs_filter_config: Option, - pub(super) act_filter_config: Option, + pub n_procs: usize, + pub obs_filter_config: Option, + pub act_filter_config: Option, phantom: PhantomData<(O, A, OF, AF)>, } diff --git a/border-tch-agent/Cargo.toml b/border-tch-agent/Cargo.toml index 230775ee..b3ec2531 100644 --- a/border-tch-agent/Cargo.toml +++ b/border-tch-agent/Cargo.toml @@ -1,21 +1,17 @@ [package] name = "border-tch-agent" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] -border-core = { version = "0.0.6", path = "../border-core" } -border-async-trainer = { version = "0.0.6", path = "../border-async-trainer", optional = true } +border-core = { version = "0.0.7", 
path = "../border-core" } +border-async-trainer = { version = "0.0.7", path = "../border-async-trainer", optional = true } serde = { workspace = true, features = ["derive"] } serde_yaml = { workspace = true } tensorboard-rs = { workspace = true } diff --git a/border-tch-agent/src/cnn/base.rs b/border-tch-agent/src/cnn/base.rs index 73f5a27e..169854b1 100644 --- a/border-tch-agent/src/cnn/base.rs +++ b/border-tch-agent/src/cnn/base.rs @@ -1,6 +1,6 @@ +use super::CnnConfig; use crate::model::SubModel; use tch::{nn, nn::Module, Device, Tensor}; -use super::CnnConfig; #[allow(clippy::upper_case_acronyms)] /// Convolutional neural network, which has the same architecture of the DQN paper. @@ -68,6 +68,16 @@ impl SubModel for Cnn { Self::create_net(var_store, n_stack, out_dim) }; + // // Debug: check weight scale + // for (k, v) in var_store.variables() { + // if k.starts_with("c") { + // let m: f32 = v.mean(tch::Kind::Float).into(); + // let s: f32 = v.std(false).into(); + // println!("{}: mean={}, std={}", k, m, s); + // } + // } + // panic!(); + Self { n_stack, out_dim, @@ -96,4 +106,4 @@ impl SubModel for Cnn { skip_linear, } } -} \ No newline at end of file +} diff --git a/border-tch-agent/src/cnn/config.rs b/border-tch-agent/src/cnn/config.rs index f65c024b..6cb5982f 100644 --- a/border-tch-agent/src/cnn/config.rs +++ b/border-tch-agent/src/cnn/config.rs @@ -20,7 +20,11 @@ pub struct CnnConfig { impl CnnConfig { /// Constructs [`CnnConfig`] pub fn new(n_stack: i64, out_dim: i64) -> Self { - Self { n_stack, out_dim, skip_linear: false } + Self { + n_stack, + out_dim, + skip_linear: false, + } } pub fn skip_linear(mut self, skip_linear: bool) -> Self { diff --git a/border-tch-agent/src/dqn/base.rs b/border-tch-agent/src/dqn/base.rs index c1057f41..3efa4e06 100644 --- a/border-tch-agent/src/dqn/base.rs +++ b/border-tch-agent/src/dqn/base.rs @@ -2,35 +2,32 @@ use super::{config::DqnConfig, explorer::DqnExplorer, model::DqnModel}; use crate::{ model::{ModelBase, SubModel}, - util::{track, OutDim}, + util::{track, CriticLoss, OutDim}, }; use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Env, Policy, ReplayBufferBase, StdBatchBase, + Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, }; use serde::{de::DeserializeOwned, Serialize}; -use std::{fs, marker::PhantomData, path::Path}; +use std::{ + convert::{TryFrom, TryInto}, + fs, + marker::PhantomData, + path::Path, +}; use tch::{no_grad, Device, Tensor}; #[allow(clippy::upper_case_acronyms)] /// DQN agent implemented with tch-rs. 
pub struct Dqn where - E: Env, Q: SubModel, - R: ReplayBufferBase, - E::Obs: Into, - E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into, - ::ActBatch: Into, { pub(in crate::dqn) soft_update_interval: usize, pub(in crate::dqn) soft_update_counter: usize, pub(in crate::dqn) n_updates_per_opt: usize, - pub(in crate::dqn) min_transitions_warmup: usize, pub(in crate::dqn) batch_size: usize, pub(in crate::dqn) qnet: DqnModel, pub(in crate::dqn) qnet_tgt: DqnModel, @@ -44,6 +41,10 @@ where pub(in crate::dqn) double_dqn: bool, pub(in crate::dqn) _clip_reward: Option, pub(in crate::dqn) clip_td_err: Option<(f64, f64)>, + pub(in crate::dqn) critic_loss: CriticLoss, + n_samples_act: usize, + n_samples_best_act: usize, + record_verbose_level: usize, } impl Dqn @@ -51,28 +52,43 @@ where E: Env, Q: SubModel, R: ReplayBufferBase, - E::Obs: Into, - E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into, - ::ActBatch: Into, + R::Batch: TransitionBatch, + ::ObsBatch: Into, + ::ActBatch: Into, { - fn update_critic(&mut self, buffer: &mut R) -> f32 { + fn update_critic(&mut self, buffer: &mut R) -> Record { + let mut record = Record::empty(); let batch = buffer.batch(self.batch_size).unwrap(); - let (obs, act, next_obs, reward, is_done, ixs, weight) = batch.unpack(); + let (obs, act, next_obs, reward, is_terminated, _is_truncated, ixs, weight) = + batch.unpack(); let obs = obs.into(); let act = act.into().to(self.device); let next_obs = next_obs.into(); - let reward = Tensor::of_slice(&reward[..]).to(self.device); - let is_done = Tensor::of_slice(&is_done[..]).to(self.device); + let reward = Tensor::from_slice(&reward[..]).to(self.device); + let is_terminated = Tensor::from_slice(&is_terminated[..]).to(self.device); let pred = { let x = self.qnet.forward(&obs); x.gather(-1, &act, false).squeeze() }; - let tgt = no_grad(|| { + if self.record_verbose_level >= 2 { + record.insert( + "pred_mean", + RecordValue::Scalar( + f32::try_from(pred.mean(tch::Kind::Float)) + .expect("Failed to convert Tensor to f32"), + ), + ); + } + + if self.record_verbose_level >= 2 { + let reward_mean: f32 = reward.mean(tch::Kind::Float).try_into().unwrap(); + record.insert("reward_mean", RecordValue::Scalar(reward_mean)); + } + + let tgt: Tensor = no_grad(|| { let q = if self.double_dqn { let x = self.qnet.forward(&next_obs); let y = x.argmax(-1, false).unsqueeze(-1); @@ -85,40 +101,90 @@ where let y = x.argmax(-1, false).unsqueeze(-1); x.gather(-1, &y, false).squeeze() }; - reward + (1 - is_done) * self.discount_factor * q + reward + (1 - is_terminated) * self.discount_factor * q }); + if self.record_verbose_level >= 2 { + record.insert( + "tgt_mean", + RecordValue::Scalar( + f32::try_from(tgt.mean(tch::Kind::Float)) + .expect("Failed to convert Tensor to f32"), + ), + ); + let tgt_minus_pred_mean: f32 = + (&tgt - &pred).mean(tch::Kind::Float).try_into().unwrap(); + record.insert( + "tgt_minus_pred_mean", + RecordValue::Scalar(tgt_minus_pred_mean), + ); + } + let loss = if let Some(ws) = weight { let n = ws.len() as i64; let td_errs = match self.clip_td_err { None => (&pred - &tgt).abs(), Some((min, max)) => (&pred - &tgt).abs().clip(min, max), }; - let loss = Tensor::of_slice(&ws[..]).to(self.device) * &td_errs; - let loss = loss.smooth_l1_loss( - &Tensor::zeros(&[n], tch::kind::FLOAT_CPU).to(self.device), - tch::Reduction::Mean, - 1.0, - ); + let 
loss = Tensor::from_slice(&ws[..]).to(self.device) * &td_errs; + let loss = match self.critic_loss { + CriticLoss::SmoothL1 => loss.smooth_l1_loss( + &Tensor::zeros(&[n], tch::kind::FLOAT_CPU).to(self.device), + tch::Reduction::Mean, + 1.0, + ), + CriticLoss::Mse => loss.mse_loss( + &Tensor::zeros(&[n], tch::kind::FLOAT_CPU).to(self.device), + tch::Reduction::Mean, + ), + }; self.qnet.backward_step(&loss); - let td_errs = Vec::::from(td_errs); + let td_errs = Vec::::try_from(td_errs).expect("Failed to convert Tensor to f32"); buffer.update_priority(&ixs, &Some(td_errs)); loss } else { - let loss = pred.smooth_l1_loss(&tgt, tch::Reduction::Mean, 1.0); + let loss = match self.critic_loss { + CriticLoss::SmoothL1 => pred.smooth_l1_loss(&tgt, tch::Reduction::Mean, 1.0), + CriticLoss::Mse => pred.mse_loss(&tgt, tch::Reduction::Mean), + }; self.qnet.backward_step(&loss); loss }; - f32::from(loss) + record.insert( + "loss", + RecordValue::Scalar(f32::try_from(loss).expect("Failed to convert Tensor to f32")), + ); + + record } + // fn opt_(&mut self, buffer: &mut R) -> Record { + // let mut loss = 0f32; + + // for _ in 0..self.n_updates_per_opt { + // loss += self.update_critic(buffer); + // } + + // self.soft_update_counter += 1; + // if self.soft_update_counter == self.soft_update_interval { + // self.soft_update_counter = 0; + // track(&mut self.qnet_tgt, &mut self.qnet, self.tau); + // } + + // loss /= self.n_updates_per_opt as f32; + + // self.n_opts += 1; + + // Record::from_slice(&[("loss", RecordValue::Scalar(loss))]) + // } + fn opt_(&mut self, buffer: &mut R) -> Record { - let mut loss_critic = 0f32; + let mut record_ = Record::empty(); for _ in 0..self.n_updates_per_opt { - let loss = self.update_critic(buffer); - loss_critic += loss; + let record = self.update_critic(buffer); + record_ = record_.merge(record); } self.soft_update_counter += 1; @@ -127,11 +193,10 @@ where track(&mut self.qnet_tgt, &mut self.qnet, self.tau); } - loss_critic /= self.n_updates_per_opt as f32; - self.n_opts += 1; - Record::from_slice(&[("loss_critic", RecordValue::Scalar(loss_critic))]) + record_ + // Record::from_slice(&[("loss", RecordValue::Scalar(loss_critic))]) } } @@ -139,13 +204,50 @@ impl Policy for Dqn where E: Env, Q: SubModel, - R: ReplayBufferBase, E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into, - ::ActBatch: Into, +{ + fn sample(&mut self, obs: &E::Obs) -> E::Act { + no_grad(|| { + let a = self.qnet.forward(&obs.clone().into()); + let a = if self.train { + self.n_samples_act += 1; + match &mut self.explorer { + DqnExplorer::Softmax(softmax) => softmax.action(&a), + DqnExplorer::EpsilonGreedy(egreedy) => { + if self.record_verbose_level >= 2 { + let (act, best) = egreedy.action_with_best(&a); + if best { + self.n_samples_best_act += 1; + } + act + } else { + egreedy.action(&a) + } + } + } + } else { + if fastrand::f32() < 0.01 { + let n_actions = a.size()[1] as i32; + let a = fastrand::i32(0..n_actions); + Tensor::from(a) + } else { + a.argmax(-1, true) + } + }; + a.into() + }) + } +} + +impl Configurable for Dqn +where + E: Env, + Q: SubModel, + E::Obs: Into, + E::Act: From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, { type Config = DqnConfig; @@ -164,7 +266,6 @@ where soft_update_interval: config.soft_update_interval, soft_update_counter: 0, n_updates_per_opt: config.n_updates_per_opt, - min_transitions_warmup: 
config.min_transitions_warmup, batch_size: config.batch_size, discount_factor: config.discount_factor, tau: config.tau, @@ -175,30 +276,13 @@ where _clip_reward: config.clip_reward, double_dqn: config.double_dqn, clip_td_err: config.clip_td_err, + critic_loss: config.critic_loss, + n_samples_act: 0, + n_samples_best_act: 0, + record_verbose_level: config.record_verbose_level, phantom: PhantomData, } } - - fn sample(&mut self, obs: &E::Obs) -> E::Act { - no_grad(|| { - let a = self.qnet.forward(&obs.clone().into()); - let a = if self.train { - match &mut self.explorer { - DqnExplorer::Softmax(softmax) => softmax.action(&a), - DqnExplorer::EpsilonGreedy(egreedy) => egreedy.action(&a), - } - } else { - if fastrand::f32() < 0.01 { - let n_actions = a.size()[1] as i32; - let a = fastrand::i32(0..n_actions); - Tensor::from(a) - } else { - a.argmax(-1, true) - } - }; - a.into() - }) - } } impl Agent for Dqn @@ -209,9 +293,9 @@ where E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into, - ::ActBatch: Into, + R::Batch: TransitionBatch, + ::ObsBatch: Into, + ::ActBatch: Into, { fn train(&mut self) { self.train = true; @@ -225,27 +309,57 @@ where self.train } - fn opt(&mut self, buffer: &mut R) -> Option { - if buffer.len() >= self.min_transitions_warmup { - Some(self.opt_(buffer)) - } else { - None + fn opt(&mut self, buffer: &mut R) { + self.opt_(buffer); + } + + fn opt_with_record(&mut self, buffer: &mut R) -> Record { + let mut record = { + let record = self.opt_(buffer); + + match self.record_verbose_level >= 2 { + true => { + let record_weights = self.qnet.param_stats(); + let record = record.merge(record_weights); + record + } + false => record, + } + }; + + // Best action ratio for epsilon greedy + if self.record_verbose_level >= 2 { + let ratio = match self.n_samples_act == 0 { + true => 0f32, + false => self.n_samples_best_act as f32 / self.n_samples_act as f32, + }; + record.insert("ratio_best_act", RecordValue::Scalar(ratio)); + self.n_samples_act = 0; + self.n_samples_best_act = 0; } + + record } - fn save>(&self, path: T) -> Result<()> { + /// Save model parameters in the given directory. + /// + /// The parameters of the model are saved as `qnet.pt`. + /// The parameters of the target model are saved as `qnet_tgt.pt`. 
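+    /// (Both files are written with a `.tch` suffix, i.e. `qnet.pt.tch` and
+    /// `qnet_tgt.pt.tch`.)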
+ fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; - self.qnet.save(&path.as_ref().join("qnet.pt").as_path())?; + self.qnet + .save(&path.as_ref().join("qnet.pt.tch").as_path())?; self.qnet_tgt - .save(&path.as_ref().join("qnet_tgt.pt").as_path())?; + .save(&path.as_ref().join("qnet_tgt.pt.tch").as_path())?; Ok(()) } - fn load>(&mut self, path: T) -> Result<()> { - self.qnet.load(&path.as_ref().join("qnet.pt").as_path())?; + fn load_params>(&mut self, path: T) -> Result<()> { + self.qnet + .load(&path.as_ref().join("qnet.pt.tch").as_path())?; self.qnet_tgt - .load(&path.as_ref().join("qnet_tgt.pt").as_path())?; + .load(&path.as_ref().join("qnet_tgt.pt.tch").as_path())?; Ok(()) } } @@ -262,9 +376,9 @@ where E::Obs: Into, E::Act: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into, - ::ActBatch: Into, + R::Batch: TransitionBatch, + ::ObsBatch: Into, + ::ActBatch: Into, { type ModelInfo = NamedTensors; diff --git a/border-tch-agent/src/dqn/config.rs b/border-tch-agent/src/dqn/config.rs index 084a6fc4..c7fc58f8 100644 --- a/border-tch-agent/src/dqn/config.rs +++ b/border-tch-agent/src/dqn/config.rs @@ -3,7 +3,12 @@ use super::{ explorer::{DqnExplorer, Softmax}, DqnModelConfig, }; -use crate::{model::SubModel, util::OutDim, Device}; +use crate::{ + model::SubModel, + opt::OptimizerConfig, + util::{CriticLoss, OutDim}, + Device, +}; use anyhow::Result; use log::info; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -16,29 +21,30 @@ use std::{ }; use tch::Tensor; -/// Configuration of [Dqn](super::Dqn) agent. +/// Configuration of [`Dqn`](super::Dqn) agent. #[derive(Debug, Deserialize, Serialize, PartialEq)] pub struct DqnConfig where Q: SubModel, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, { - pub(super) model_config: DqnModelConfig, - pub(super) soft_update_interval: usize, - pub(super) n_updates_per_opt: usize, - pub(super) min_transitions_warmup: usize, - pub(super) batch_size: usize, - pub(super) discount_factor: f64, - pub(super) tau: f64, - pub(super) train: bool, - pub(super) explorer: DqnExplorer, + pub model_config: DqnModelConfig, + pub soft_update_interval: usize, + pub n_updates_per_opt: usize, + pub batch_size: usize, + pub discount_factor: f64, + pub tau: f64, + pub train: bool, + pub explorer: DqnExplorer, #[serde(default)] - pub(super) clip_reward: Option, + pub clip_reward: Option, #[serde(default)] - pub(super) double_dqn: bool, - pub(super) clip_td_err: Option<(f64, f64)>, + pub double_dqn: bool, + pub clip_td_err: Option<(f64, f64)>, pub device: Option, - phantom: PhantomData, + pub critic_loss: CriticLoss, + pub record_verbose_level: usize, + pub phantom: PhantomData, } impl Clone for DqnConfig @@ -51,7 +57,6 @@ where model_config: self.model_config.clone(), soft_update_interval: self.soft_update_interval, n_updates_per_opt: self.n_updates_per_opt, - min_transitions_warmup: self.min_transitions_warmup, batch_size: self.batch_size, discount_factor: self.discount_factor, tau: self.tau, @@ -61,7 +66,9 @@ where double_dqn: self.double_dqn, clip_td_err: self.clip_td_err, device: self.device.clone(), - phantom: PhantomData, + critic_loss: self.critic_loss.clone(), + record_verbose_level: self.record_verbose_level, + phantom: PhantomData, } } } @@ -77,7 +84,6 @@ where model_config: Default::default(), soft_update_interval: 1, n_updates_per_opt: 1, - 
min_transitions_warmup: 1, batch_size: 1, discount_factor: 0.99, tau: 0.005, @@ -89,6 +95,8 @@ where double_dqn: false, clip_td_err: None, device: None, + critic_loss: CriticLoss::Mse, + record_verbose_level: 0, phantom: PhantomData, } } @@ -111,12 +119,6 @@ where self } - /// Interval before starting optimization. - pub fn min_transitions_warmup(mut self, v: usize) -> Self { - self.min_transitions_warmup = v; - self - } - /// Batch size. pub fn batch_size(mut self, v: usize) -> Self { self.batch_size = v; @@ -147,6 +149,12 @@ where self } + /// Sets the configration of the optimizer. + pub fn opt_config(mut self, opt_config: OptimizerConfig) -> Self { + self.model_config = self.model_config.opt_config(opt_config); + self + } + /// Sets the output dimention of the dqn model of the DQN agent. pub fn out_dim(mut self, out_dim: i64) -> Self { let model_config = self.model_config.clone(); @@ -178,7 +186,19 @@ where self } - /// Loads [DqnConfig] from YAML file. + /// Sets critic loss. + pub fn critic_loss(mut self, v: CriticLoss) -> Self { + self.critic_loss = v; + self + } + + /// Sets verbose level. + pub fn record_verbose_level(mut self, v: usize) -> Self { + self.record_verbose_level = v; + self + } + + /// Loads [`DqnConfig`] from YAML file. pub fn load(path: impl AsRef) -> Result { let path_ = path.as_ref().to_owned(); let file = File::open(path)?; @@ -188,7 +208,7 @@ where Ok(b) } - /// Saves [DqnConfig]. + /// Saves [`DqnConfig`]. pub fn save(&self, path: impl AsRef) -> Result<()> { let path_ = path.as_ref().to_owned(); let mut file = File::create(path)?; diff --git a/border-tch-agent/src/dqn/explorer.rs b/border-tch-agent/src/dqn/explorer.rs index 05040941..2c0852d9 100644 --- a/border-tch-agent/src/dqn/explorer.rs +++ b/border-tch-agent/src/dqn/explorer.rs @@ -1,4 +1,6 @@ //! Exploration strategies of DQN. +use std::convert::TryInto; + use serde::{Deserialize, Serialize}; use tch::Tensor; @@ -32,10 +34,10 @@ impl Softmax { /// Epsilon-greedy explorer for DQN. #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub struct EpsilonGreedy { - n_opts: usize, - eps_start: f64, - eps_final: f64, - final_step: usize, + pub n_opts: usize, + pub eps_start: f64, + pub eps_final: f64, + pub final_step: usize, } #[allow(clippy::new_without_default)] @@ -70,17 +72,50 @@ impl EpsilonGreedy { let is_random = r < eps; self.n_opts += 1; + let best = a.argmax(-1, true); + + if is_random { + let n_procs = a.size()[0] as u32; + let n_actions = a.size()[1] as u32; + let act = Tensor::from_slice( + (0..n_procs) + .map(|_| fastrand::u32(..n_actions) as i32) + .collect::>() + .as_slice(), + ); + act + } else { + best + } + } + + /// Takes an action based on the observation and the critic. 
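The new `critic_loss`, `record_verbose_level`, and `opt_config` setters follow the consuming-builder style already used by `DqnConfig`. A std-only sketch of that pattern with simplified types (not the real config struct):

```rust
// Consuming-builder sketch: each setter takes `self` by value and returns it,
// so calls can be chained, as in the DqnConfig setters above.
#[derive(Debug, Clone, PartialEq)]
enum CriticLoss {
    Mse,
    SmoothL1,
}

#[derive(Debug)]
struct ConfigSketch {
    critic_loss: CriticLoss,
    record_verbose_level: usize,
}

impl ConfigSketch {
    fn new() -> Self {
        Self { critic_loss: CriticLoss::Mse, record_verbose_level: 0 }
    }

    fn critic_loss(mut self, v: CriticLoss) -> Self {
        self.critic_loss = v;
        self
    }

    fn record_verbose_level(mut self, v: usize) -> Self {
        self.record_verbose_level = v;
        self
    }
}

fn main() {
    let config = ConfigSketch::new()
        .critic_loss(CriticLoss::SmoothL1)
        .record_verbose_level(2);
    assert_eq!(config.critic_loss, CriticLoss::SmoothL1);
    assert_eq!(config.record_verbose_level, 2);
}
```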
+ pub fn action_with_best(&mut self, a: &Tensor) -> (Tensor, bool) { + let d = (self.eps_start - self.eps_final) / (self.final_step as f64); + let eps = (self.eps_start - d * self.n_opts as f64).max(self.eps_final); + let r = fastrand::f64(); + let is_random = r < eps; + self.n_opts += 1; + + let best = a.argmax(-1, true); + if is_random { let n_procs = a.size()[0] as u32; let n_actions = a.size()[1] as u32; - Tensor::of_slice( + let act = Tensor::from_slice( (0..n_procs) .map(|_| fastrand::u32(..n_actions) as i32) .collect::>() .as_slice(), - ) + ); + let diff: i64 = (&act - &best.to(tch::Device::Cpu)) + .abs() + .sum(tch::Kind::Int64) + .try_into() + .unwrap(); + (act, diff == 0) } else { - a.argmax(-1, true) + (best, true) } } diff --git a/border-tch-agent/src/dqn/model/base.rs b/border-tch-agent/src/dqn/model/base.rs index 63d1efba..56180183 100644 --- a/border-tch-agent/src/dqn/model/base.rs +++ b/border-tch-agent/src/dqn/model/base.rs @@ -5,13 +5,18 @@ use crate::{ util::OutDim, }; use anyhow::Result; +use border_core::record::Record; use log::{info, trace}; use serde::{de::DeserializeOwned, Serialize}; use std::{marker::PhantomData, path::Path}; use tch::{nn, Device, Tensor}; -#[allow(clippy::upper_case_acronyms)] -/// Represents value functions for DQN agents. +/// Action value function model for DQN. +/// +/// The architecture of the model is defined by the type parameter `Q`, +/// which should implement [`SubModel`]. +/// This takes [`SubModel::Input`] as input and outputs a tensor. +/// The output tensor should have the same dimension as the number of actions. pub struct DqnModel where Q: SubModel, @@ -74,12 +79,16 @@ where } } - /// Outputs the action-value given an observation. + /// Outputs the action-value given observation(s). pub fn forward(&self, x: &Q::Input) -> Tensor { let a = self.q.forward(&x); debug_assert_eq!(a.size().as_slice()[1], self.out_dim); a } + + pub fn param_stats(&self) -> Record { + crate::util::param_stats(&self.var_store) + } } impl Clone for DqnModel diff --git a/border-tch-agent/src/dqn/model/config.rs b/border-tch-agent/src/dqn/model/config.rs index d6697dc3..ce31802c 100644 --- a/border-tch-agent/src/dqn/model/config.rs +++ b/border-tch-agent/src/dqn/model/config.rs @@ -8,15 +8,15 @@ use std::{ }; #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -/// Configuration of [DqnModel](super::DqnModel). +/// Configuration of [`DqnModel`](super::DqnModel). pub struct DqnModelConfig where // Q: SubModel, // Q::Config: DeserializeOwned + Serialize + OutDim, Q: OutDim, { - pub(super) q_config: Option, - pub(super) opt_config: OptimizerConfig, + pub q_config: Option, + pub opt_config: OptimizerConfig, } // impl> Default for DQNModelConfig @@ -41,7 +41,6 @@ where // Q::Config: DeserializeOwned + Serialize + OutDim, Q: DeserializeOwned + Serialize + OutDim, { - /// Sets configurations for action-value function. 
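`action_with_best` above uses the same linear epsilon schedule as `action`. A std-only sketch of that schedule, assuming the field names shown in the hunk (`eps_start`, `eps_final`, `final_step`, `n_opts`):

```rust
// Epsilon decays linearly from eps_start to eps_final over final_step
// optimization steps, then stays at eps_final.
fn epsilon(n_opts: usize, eps_start: f64, eps_final: f64, final_step: usize) -> f64 {
    let d = (eps_start - eps_final) / final_step as f64;
    (eps_start - d * n_opts as f64).max(eps_final)
}

fn main() {
    assert!((epsilon(0, 1.0, 0.02, 1000) - 1.0).abs() < 1e-9);
    assert!((epsilon(500, 1.0, 0.02, 1000) - 0.51).abs() < 1e-9);
    assert!((epsilon(2000, 1.0, 0.02, 1000) - 0.02).abs() < 1e-9);
}
```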
// pub fn q_config(mut self, v: Q::Config) -> Self { pub fn q_config(mut self, v: Q) -> Self { diff --git a/border-tch-agent/src/iqn.rs b/border-tch-agent/src/iqn.rs index d65945b9..57c90bb3 100644 --- a/border-tch-agent/src/iqn.rs +++ b/border-tch-agent/src/iqn.rs @@ -5,5 +5,5 @@ mod explorer; mod model; pub use base::Iqn; pub use config::IqnConfig; -pub use explorer::{EpsilonGreedy, IqnExplorer}; -pub use model::{IqnModel, IqnModelConfig, IqnSample, average}; +pub use explorer::{EpsilonGreedy, IqnExplorer, Softmax}; +pub use model::{average, IqnModel, IqnModelConfig, IqnSample}; diff --git a/border-tch-agent/src/iqn/base.rs b/border-tch-agent/src/iqn/base.rs index 41a7ac4b..bd549d26 100644 --- a/border-tch-agent/src/iqn/base.rs +++ b/border-tch-agent/src/iqn/base.rs @@ -1,5 +1,5 @@ //! IQN agent implemented with tch-rs. -use super::{average, IqnExplorer, IqnConfig, IqnModel, IqnSample}; +use super::{average, IqnConfig, IqnExplorer, IqnModel, IqnSample}; use crate::{ model::{ModelBase, SubModel}, util::{quantile_huber_loss, track, OutDim}, @@ -7,11 +7,11 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, StdBatchBase, Env, Policy, ReplayBufferBase, + Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, }; use log::trace; use serde::{de::DeserializeOwned, Serialize}; -use std::{fs, marker::PhantomData, path::Path}; +use std::{convert::TryFrom, fs, marker::PhantomData, path::Path}; use tch::{no_grad, Device, Tensor}; /// IQN agent implemented with tch-rs. @@ -20,22 +20,14 @@ use tch::{no_grad, Device, Tensor}; /// `M::Input` and returns feature vectors. pub struct Iqn where - E: Env, F: SubModel, M: SubModel, - R: ReplayBufferBase, - E::Obs: Into, - E::Act: From, F::Config: DeserializeOwned + Serialize, M::Config: DeserializeOwned + Serialize, - R::Batch: StdBatchBase, - ::ObsBatch: Into, - ::ActBatch: Into, { pub(in crate::iqn) soft_update_interval: usize, pub(in crate::iqn) soft_update_counter: usize, pub(in crate::iqn) n_updates_per_opt: usize, - pub(in crate::iqn) min_transitions_warmup: usize, pub(in crate::iqn) batch_size: usize, pub(in crate::iqn) iqn: IqnModel, pub(in crate::iqn) iqn_tgt: IqnModel, @@ -57,30 +49,33 @@ where F: SubModel, M: SubModel, R: ReplayBufferBase, - E::Obs: Into, - E::Act: From, F::Config: DeserializeOwned + Serialize, M::Config: DeserializeOwned + Serialize + OutDim, - R::Batch: StdBatchBase, - ::ObsBatch: Into, - ::ActBatch: Into, + R::Batch: TransitionBatch, + ::ObsBatch: Into, + ::ActBatch: Into, { fn update_critic(&mut self, buffer: &mut R) -> f32 { trace!("IQN::update_critic()"); let batch = buffer.batch(self.batch_size).unwrap(); - let (obs, act, next_obs, reward, is_done, _ixs, _weight) = batch.unpack(); + let (obs, act, next_obs, reward, is_terminated, _is_truncated, _ixs, _weight) = + batch.unpack(); let obs = obs.into(); let act = act.into().to(self.device); let next_obs = next_obs.into(); - let reward = Tensor::of_slice(&reward[..]).to(self.device).unsqueeze(-1); - let is_done = Tensor::of_slice(&is_done[..]).to(self.device).unsqueeze(-1); + let reward = Tensor::from_slice(&reward[..]) + .to(self.device) + .unsqueeze(-1); + let is_terminated = Tensor::from_slice(&is_terminated[..]) + .to(self.device) + .unsqueeze(-1); let batch_size = self.batch_size as _; let n_percent_points_pred = self.sample_percents_pred.n_percent_points(); let n_percent_points_tgt = self.sample_percents_tgt.n_percent_points(); debug_assert_eq!(reward.size().as_slice(), &[batch_size, 1]); - 
debug_assert_eq!(is_done.size().as_slice(), &[batch_size, 1]); + debug_assert_eq!(is_terminated.size().as_slice(), &[batch_size, 1]); debug_assert_eq!(act.size().as_slice(), &[batch_size, 1]); let loss = { @@ -130,7 +125,9 @@ where ); // argmax_a z(s,a), where z are averaged over tau - let y = z.copy().mean_dim(&[1], false, tch::Kind::Float); + let y = z + .copy() + .mean_dim(Some([1].as_slice()), false, tch::Kind::Float); let a = y.argmax(-1, false).unsqueeze(-1).unsqueeze(-1).repeat(&[ 1, n_percent_points, @@ -143,7 +140,7 @@ where debug_assert_eq!(z.size().as_slice(), &[batch_size, n_percent_points]); // target value - let tgt: Tensor = reward + (1 - is_done) * self.discount_factor * z; + let tgt: Tensor = reward + (1 - is_terminated) * self.discount_factor * z; debug_assert_eq!(tgt.size().as_slice(), &[batch_size, n_percent_points]); tgt.unsqueeze(-1) @@ -164,7 +161,7 @@ where self.iqn.backward_step(&loss); - f32::from(loss) + f32::try_from(loss).expect("Failed to convert Tensor to f32") } fn opt_(&mut self, buffer: &mut R) -> Record { @@ -194,18 +191,51 @@ where E: Env, F: SubModel, M: SubModel, - R: ReplayBufferBase, E::Obs: Into, E::Act: From, F::Config: DeserializeOwned + Serialize + Clone, M::Config: DeserializeOwned + Serialize + Clone + OutDim, - R::Batch: StdBatchBase, - ::ObsBatch: Into, - ::ActBatch: Into, +{ + fn sample(&mut self, obs: &E::Obs) -> E::Act { + // Do not support vectorized env + let batch_size = 1; + + let a = no_grad(|| { + let action_value = average( + batch_size, + &obs.clone().into(), + &self.iqn, + &self.sample_percents_act, + self.device, + ); + + if self.train { + match &mut self.explorer { + IqnExplorer::Softmax(softmax) => softmax.action(&action_value), + IqnExplorer::EpsilonGreedy(egreedy) => egreedy.action(action_value), + } + } else { + action_value.argmax(-1, true) + } + }); + + a.into() + } +} + +impl Configurable for Iqn +where + E: Env, + F: SubModel, + M: SubModel, + E::Obs: Into, + E::Act: From, + F::Config: DeserializeOwned + Serialize + Clone, + M::Config: DeserializeOwned + Serialize + Clone + OutDim, { type Config = IqnConfig; - /// Constructs [Iqn] agent. + /// Constructs [`Iqn`] agent. 
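The target in `update_critic` above masks bootstrapping on termination, i.e. `y = r + (1 - terminated) * gamma * z`. A std-only scalar sketch of that computation (an illustration, not the tensor code in the patch):

```rust
// Element-wise termination-masked Bellman target.
fn bellman_target(reward: &[f32], is_terminated: &[f32], gamma: f32, q_next: &[f32]) -> Vec<f32> {
    reward
        .iter()
        .zip(is_terminated)
        .zip(q_next)
        .map(|((r, t), q)| r + (1.0 - t) * gamma * q)
        .collect()
}

fn main() {
    // Second transition is terminal, so its target is just the reward.
    let tgt = bellman_target(&[1.0, 0.5], &[0.0, 1.0], 0.99, &[2.0, 2.0]);
    assert_eq!(tgt, vec![1.0 + 0.99 * 2.0, 0.5]);
}
```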
fn build(config: Self::Config) -> Self { let device = config .device @@ -220,7 +250,6 @@ where soft_update_interval: config.soft_update_interval, soft_update_counter: 0, n_updates_per_opt: config.n_updates_per_opt, - min_transitions_warmup: config.min_transitions_warmup, batch_size: config.batch_size, discount_factor: config.discount_factor, tau: config.tau, @@ -234,32 +263,6 @@ where phantom: PhantomData, } } - - fn sample(&mut self, obs: &E::Obs) -> E::Act { - // Do not support vectorized env - let batch_size = 1; - - let a = no_grad(|| { - let obs = obs.clone().into(); - let action_value = average( - batch_size, - &obs, - &self.iqn, - &self.sample_percents_act, - self.device, - ); - - if self.train { - match &mut self.explorer { - IqnExplorer::EpsilonGreedy(egreedy) => egreedy.action(action_value), - } - } else { - action_value.argmax(-1, true) - } - }); - - a.into() - } } impl Agent for Iqn @@ -272,9 +275,9 @@ where E::Act: From, F::Config: DeserializeOwned + Serialize + Clone, M::Config: DeserializeOwned + Serialize + Clone + OutDim, - R::Batch: StdBatchBase, - ::ObsBatch: Into, - ::ActBatch: Into, + R::Batch: TransitionBatch, + ::ObsBatch: Into, + ::ActBatch: Into, { fn train(&mut self) { self.train = true; @@ -288,74 +291,23 @@ where self.train } - fn opt(&mut self, buffer: &mut R) -> Option { - if buffer.len() >= self.min_transitions_warmup { - Some(self.opt_(buffer)) - } else { - None - } + fn opt_with_record(&mut self, buffer: &mut R) -> Record { + self.opt_(buffer) } - // /// Update model parameters. - // /// - // /// When the return value is `Some(Record)`, it includes: - // /// * `loss_critic`: Loss of critic - // fn observe(&mut self, step: Step) -> Option { - // trace!("DQN::observe()"); - - // // Check if doing optimization - // let do_optimize = self.opt_interval_counter.do_optimize(&step.is_done) - // && self.replay_buffer.len() + 1 >= self.min_transitions_warmup; - - // // Push transition to the replay buffer - // self.push_transition(step); - // trace!("Push transition"); - - // // Do optimization - // if do_optimize { - // let mut loss_critic = 0f32; - - // for _ in 0..self.n_updates_per_opt { - // let batch = self - // .replay_buffer - // .random_batch(self.batch_size, 0f32) - // .unwrap(); - // trace!("Sample random batch"); - - // loss_critic += self.update_critic(batch); - // } - - // self.soft_update_counter += 1; - // if self.soft_update_counter == self.soft_update_interval { - // self.soft_update_counter = 0; - // self.soft_update(); - // trace!("Update target network"); - // } - - // loss_critic /= self.n_updates_per_opt as f32; - - // Some(Record::from_slice(&[( - // "loss_critic", - // RecordValue::Scalar(loss_critic), - // )])) - // } else { - // None - // } - // } - - fn save>(&self, path: T) -> Result<()> { + fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; - self.iqn.save(&path.as_ref().join("iqn.pt").as_path())?; + self.iqn.save(&path.as_ref().join("iqn.pt.tch").as_path())?; self.iqn_tgt - .save(&path.as_ref().join("iqn_tgt.pt").as_path())?; + .save(&path.as_ref().join("iqn_tgt.pt.tch").as_path())?; Ok(()) } - fn load>(&mut self, path: T) -> Result<()> { - self.iqn.load(&path.as_ref().join("iqn.pt").as_path())?; + fn load_params>(&mut self, path: T) -> Result<()> { + self.iqn.load(&path.as_ref().join("iqn.pt.tch").as_path())?; self.iqn_tgt - .load(&path.as_ref().join("iqn_tgt.pt").as_path())?; + .load(&path.as_ref().join("iqn_tgt.pt.tch").as_path())?; Ok(()) } } diff 
--git a/border-tch-agent/src/iqn/config.rs b/border-tch-agent/src/iqn/config.rs index 212a3575..c02e3109 100644 --- a/border-tch-agent/src/iqn/config.rs +++ b/border-tch-agent/src/iqn/config.rs @@ -1,7 +1,7 @@ //! Configuration of IQN agent. use super::{IqnModelConfig, IqnSample}; use crate::{ - iqn::{EpsilonGreedy, IqnExplorer}, + iqn::{IqnExplorer, Softmax}, model::SubModel, util::OutDim, Device, @@ -17,7 +17,7 @@ use std::{ }; #[derive(Debug, Deserialize, Serialize, PartialEq)] -/// Configuration of [Iqn](super::Iqn) agent. +/// Configuration of [`Iqn`](super::Iqn) agent. pub struct IqnConfig where F: SubModel, @@ -25,18 +25,17 @@ where F::Config: DeserializeOwned + Serialize + Clone, M::Config: DeserializeOwned + Serialize + Clone + OutDim, { - pub(super) model_config: IqnModelConfig, - pub(super) soft_update_interval: usize, - pub(super) n_updates_per_opt: usize, - pub(super) min_transitions_warmup: usize, - pub(super) batch_size: usize, - pub(super) discount_factor: f64, - pub(super) tau: f64, - pub(super) train: bool, - pub(super) explorer: IqnExplorer, - pub(super) sample_percents_pred: IqnSample, - pub(super) sample_percents_tgt: IqnSample, - pub(super) sample_percents_act: IqnSample, + pub model_config: IqnModelConfig, + pub soft_update_interval: usize, + pub n_updates_per_opt: usize, + pub batch_size: usize, + pub discount_factor: f64, + pub tau: f64, + pub train: bool, + pub explorer: IqnExplorer, + pub sample_percents_pred: IqnSample, + pub sample_percents_tgt: IqnSample, + pub sample_percents_act: IqnSample, pub device: Option, phantom: PhantomData<(F, M)>, } @@ -53,15 +52,15 @@ where model_config: Default::default(), soft_update_interval: 1, n_updates_per_opt: 1, - min_transitions_warmup: 1, batch_size: 1, discount_factor: 0.99, tau: 0.005, - sample_percents_pred: IqnSample::Uniform64, - sample_percents_tgt: IqnSample::Uniform64, - sample_percents_act: IqnSample::Uniform32, // Const10, + sample_percents_pred: IqnSample::Uniform8, + sample_percents_tgt: IqnSample::Uniform8, + sample_percents_act: IqnSample::Const32, train: false, - explorer: IqnExplorer::EpsilonGreedy(EpsilonGreedy::default()), + explorer: IqnExplorer::Softmax(Softmax::new()), + // explorer: IqnExplorer::EpsilonGreedy(EpsilonGreedy::default()), device: None, phantom: PhantomData, } @@ -93,12 +92,6 @@ where self } - /// Interval before starting optimization. - pub fn min_transitions_warmup(mut self, v: usize) -> Self { - self.min_transitions_warmup = v; - self - } - /// Batch size. pub fn batch_size(mut self, v: usize) -> Self { self.batch_size = v; @@ -154,7 +147,7 @@ where self } - /// Constructs [IqnConfig] from YAML file. + /// Constructs [`IqnConfig`] from YAML file. pub fn load(path: impl AsRef) -> Result { let file = File::open(path)?; let rdr = BufReader::new(file); @@ -162,7 +155,7 @@ where Ok(b) } - /// Saves [IqnConfig]. + /// Saves [`IqnConfig`]. 
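The `tau` and `soft_update_interval` settings above control the soft update of the target network, presumably the usual Polyak form `tgt <- tau * src + (1 - tau) * tgt` (the `track` helper itself is not shown in this hunk). A std-only sketch under that assumption:

```rust
// Polyak (soft) update of target parameters toward the online parameters.
fn soft_update(tgt: &mut [f32], src: &[f32], tau: f32) {
    for (t, s) in tgt.iter_mut().zip(src) {
        *t = tau * s + (1.0 - tau) * *t;
    }
}

fn main() {
    let mut tgt = vec![0.0, 1.0];
    let src = vec![1.0, 1.0];
    soft_update(&mut tgt, &src, 0.005);
    assert!((tgt[0] - 0.005).abs() < 1e-6);
    assert!((tgt[1] - 1.0).abs() < 1e-6);
}
```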
pub fn save(&self, path: impl AsRef) -> Result<()> { let mut file = File::create(path)?; file.write_all(serde_yaml::to_string(&self)?.as_bytes())?; @@ -226,7 +219,6 @@ where model_config: self.model_config.clone(), soft_update_interval: self.soft_update_interval, n_updates_per_opt: self.n_updates_per_opt, - min_transitions_warmup: self.min_transitions_warmup, batch_size: self.batch_size, discount_factor: self.discount_factor, tau: self.tau, diff --git a/border-tch-agent/src/iqn/explorer.rs b/border-tch-agent/src/iqn/explorer.rs index 5c01d22f..152b72a6 100644 --- a/border-tch-agent/src/iqn/explorer.rs +++ b/border-tch-agent/src/iqn/explorer.rs @@ -7,28 +7,28 @@ use tch::Tensor; #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] /// Explorers for IQN. pub enum IqnExplorer { - // /// Softmax action selection. - // Softmax(Softmax), + /// Softmax action selection. + Softmax(Softmax), /// Epsilon-greedy action selection. EpsilonGreedy(EpsilonGreedy), } -// /// Softmax explorer for IQN. -// pub struct Softmax {} +/// Softmax explorer for IQN. +#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] +pub struct Softmax {} -// #[allow(clippy::new_without_default)] -// impl Softmax { -// /// Constructs softmax explorer. -// pub fn new() -> Self { Self {} } +#[allow(clippy::new_without_default)] +impl Softmax { + /// Constructs softmax explorer. + pub fn new() -> Self { + Self {} + } -// /// Takes an action based on the observation and the critic. -// pub fn action(&mut self, qnet: &M, obs: &Tensor) -> Tensor where -// M: Model1, -// { -// let a = qnet.forward(obs); -// a.softmax(-1, tch::Kind::Float).multinomial(1, true) -// } -// } + /// Takes an action based on the observation and the critic. + pub fn action(&mut self, a: &Tensor) -> Tensor { + a.softmax(-1, tch::Kind::Float).multinomial(1, true) + } +} #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] /// Epsilon-greedy explorer for IQN. @@ -85,7 +85,7 @@ impl EpsilonGreedy { if is_random { let batch_size = action_value.size()[0]; let n_actions = action_value.size()[1] as u32; - Tensor::of_slice( + Tensor::from_slice( (0..batch_size) .map(|_| fastrand::u32(..n_actions) as i32) .collect::>() diff --git a/border-tch-agent/src/iqn/model.rs b/border-tch-agent/src/iqn/model.rs index 023a01fd..aa28a048 100644 --- a/border-tch-agent/src/iqn/model.rs +++ b/border-tch-agent/src/iqn/model.rs @@ -1,5 +1,5 @@ //! IQN model. mod base; mod config; -pub use base::{IqnModel, IqnSample, average}; +pub use base::{average, IqnModel, IqnSample}; pub use config::IqnModelConfig; diff --git a/border-tch-agent/src/iqn/model/base.rs b/border-tch-agent/src/iqn/model/base.rs index 31ab4506..4cef9afa 100644 --- a/border-tch-agent/src/iqn/model/base.rs +++ b/border-tch-agent/src/iqn/model/base.rs @@ -330,6 +330,9 @@ pub enum IqnSample { /// The precent points are constants. Const10, + /// The precent points are constants. + Const32, + /// 10 samples from uniform distribution. Uniform10, @@ -350,16 +353,20 @@ impl IqnSample { /// Returns samples of percent points. 
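The `Softmax` explorer introduced above samples an action from `softmax(q)` via `multinomial`. A std-only sketch of the same idea, written out as an explicit softmax plus inverse-CDF sampling (the real code uses tch tensor ops):

```rust
// Numerically stable softmax over action values.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

// Samples an action index given probabilities and a uniform random number in [0, 1).
fn sample_categorical(probs: &[f32], u: f32) -> usize {
    let mut cum = 0.0;
    for (i, p) in probs.iter().enumerate() {
        cum += p;
        if u < cum {
            return i;
        }
    }
    probs.len() - 1
}

fn main() {
    let probs = softmax(&[1.0, 2.0, 3.0]);
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    // With u near 1, the highest-probability action (index 2) is selected.
    assert_eq!(sample_categorical(&probs, 0.95), 2);
}
```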
pub fn sample(&self, batch_size: i64) -> Tensor { match self { - Self::Const10 => Tensor::of_slice(&[ + Self::Const10 => Tensor::from_slice(&[ 0.05_f32, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, ]) .unsqueeze(0) .repeat(&[batch_size, 1]), + Self::Const32 => { + let t: Tensor = (1.0 / 32.0) * Tensor::range(0, 32, tch::kind::FLOAT_CPU); + t.unsqueeze(0).repeat(&[batch_size, 1]) + } Self::Uniform10 => Tensor::rand(&[batch_size, 10], tch::kind::FLOAT_CPU), Self::Uniform8 => Tensor::rand(&[batch_size, 8], tch::kind::FLOAT_CPU), Self::Uniform32 => Tensor::rand(&[batch_size, 32], tch::kind::FLOAT_CPU), Self::Uniform64 => Tensor::rand(&[batch_size, 64], tch::kind::FLOAT_CPU), - Self::Median => Tensor::of_slice(&[0.5_f32]) + Self::Median => Tensor::from_slice(&[0.5_f32]) .unsqueeze(0) .repeat(&[batch_size, 1]), } @@ -369,6 +376,7 @@ impl IqnSample { pub fn n_percent_points(&self) -> i64 { match self { Self::Const10 => 10, + Self::Const32 => 32, Self::Uniform10 => 10, Self::Uniform8 => 8, Self::Uniform32 => 32, @@ -397,7 +405,9 @@ where M::Config: DeserializeOwned + Serialize + OutDim, { let tau = mode.sample(batch_size).to(device); - let averaged_action_value = iqn.forward(obs, &tau).mean_dim(&[1], false, Float); + let averaged_action_value = iqn + .forward(obs, &tau) + .mean_dim(Some([1].as_slice()), false, Float); let batch_size = averaged_action_value.size()[0]; let n_action = iqn.out_dim; debug_assert_eq!( @@ -487,7 +497,7 @@ mod test { .feature_dim(feature_dim) .embed_dim(embed_dim) .learning_rate(learning_rate); - + IqnModel::build_with_submodel_configs(config, fe_config, m_config, device) } diff --git a/border-tch-agent/src/iqn/model/config.rs b/border-tch-agent/src/iqn/model/config.rs index 57773b37..858e1dcd 100644 --- a/border-tch-agent/src/iqn/model/config.rs +++ b/border-tch-agent/src/iqn/model/config.rs @@ -1,10 +1,7 @@ //! IQN model. -use crate::{ - opt::OptimizerConfig, - util::OutDim, -}; +use crate::{opt::OptimizerConfig, util::OutDim}; use anyhow::Result; -use serde::{Deserialize, de::DeserializeOwned, Serialize}; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{ default::Default, fs::File, @@ -12,7 +9,6 @@ use std::{ path::Path, }; -#[cfg(not(feature = "adam_eps"))] impl IqnModelConfig where F: DeserializeOwned + Serialize, @@ -22,29 +18,14 @@ where pub fn learning_rate(mut self, v: f64) -> Self { match &self.opt_config { OptimizerConfig::Adam { lr: _ } => self.opt_config = OptimizerConfig::Adam { lr: v }, + _ => unimplemented!(), }; self } } -// #[cfg(feature = "adam_eps")] -// impl IqnModelConfig -// where -// F::Config: DeserializeOwned + Serialize, -// M::Config: DeserializeOwned + Serialize, -// { -// /// Sets the learning rate. -// pub fn learning_rate(mut self, v: f64) -> Self { -// match &self.opt_config { -// OptimizerConfig::Adam { lr: _ } => self.opt_config = OptimizerConfig::Adam { lr: v }, -// _ => unimplemented!(), -// }; -// self -// } -// } - #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -/// Configuration of [IqnModel](super::IqnModel). +/// Configuration of [`IqnModel`](super::IqnModel). /// /// The type parameter `F` represents a configuration struct of a feature extractor. 
/// The type parameter `M` represents a configuration struct of a model for merging diff --git a/border-tch-agent/src/lib.rs b/border-tch-agent/src/lib.rs index ad7be004..be04d0ab 100644 --- a/border-tch-agent/src/lib.rs +++ b/border-tch-agent/src/lib.rs @@ -10,7 +10,7 @@ mod tensor_batch; // pub mod replay_buffer; pub mod util; use serde::{Deserialize, Serialize}; -pub use tensor_batch::{TensorSubBatch, ZeroTensor}; +pub use tensor_batch::{TensorBatch, ZeroTensor}; #[derive(Clone, Debug, Copy, Deserialize, Serialize, PartialEq)] /// Device for using tch-rs. @@ -29,6 +29,8 @@ impl From for Device { match device { tch::Device::Cpu => Self::Cpu, tch::Device::Cuda(n) => Self::Cuda(n), + tch::Device::Mps => unimplemented!(), + tch::Device::Vulkan => unimplemented!(), } } } diff --git a/border-tch-agent/src/mlp.rs b/border-tch-agent/src/mlp.rs index 3ef73a55..10ee405c 100644 --- a/border-tch-agent/src/mlp.rs +++ b/border-tch-agent/src/mlp.rs @@ -10,11 +10,11 @@ use tch::nn; fn mlp(prefix: &str, var_store: &nn::VarStore, config: &MlpConfig) -> nn::Sequential { let mut seq = nn::seq(); let mut in_dim = config.in_dim; - let p = &var_store.root(); + let p = &(var_store.root() / "mlp"); for (i, &n) in config.units.iter().enumerate() { seq = seq.add(nn::linear( - p / format!("{}{}", prefix, i + 1), + p / format!("{}{}", prefix, i), in_dim, n, Default::default(), diff --git a/border-tch-agent/src/mlp/base.rs b/border-tch-agent/src/mlp/base.rs index 5692ff1d..40e513d4 100644 --- a/border-tch-agent/src/mlp/base.rs +++ b/border-tch-agent/src/mlp/base.rs @@ -2,7 +2,7 @@ use super::{mlp, MlpConfig}; use crate::model::{SubModel, SubModel2}; use tch::{nn, nn::Module, Device, Tensor}; -/// Multilayer perceptron. +/// Multilayer perceptron with ReLU activation function. pub struct Mlp { config: MlpConfig, device: Device, @@ -11,13 +11,13 @@ pub struct Mlp { impl Mlp { fn create_net(var_store: &nn::VarStore, config: &MlpConfig) -> nn::Sequential { - let p = &var_store.root(); + let p = &(var_store.root() / "mlp"); let mut seq = nn::seq(); let mut in_dim = config.in_dim; for (i, &out_dim) in config.units.iter().enumerate() { seq = seq.add(nn::linear( - p / format!("{}{}", "cl", i + 1), + p / format!("{}{}", "ln", i), in_dim, out_dim, Default::default(), @@ -27,13 +27,13 @@ impl Mlp { } seq = seq.add(nn::linear( - p / format!("{}{}", "cl", config.units.len() + 1), + p / format!("{}{}", "ln", config.units.len()), in_dim, config.out_dim, Default::default(), )); - if !config.activation_out { + if config.activation_out { seq = seq.add_fn(|x| x.relu()); } @@ -91,9 +91,9 @@ impl SubModel2 for Mlp { let units = &config.units; let in_dim = *units.last().unwrap_or(&config.in_dim); let out_dim = config.out_dim; - let p = &var_store.root(); - let seq = mlp("cl", var_store, &config).add(nn::linear( - p / format!("cl{}", units.len() + 1), + let p = &(var_store.root() / "mlp"); + let seq = mlp("ln", var_store, &config).add(nn::linear( + p / format!("ln{}", units.len()), in_dim, out_dim, Default::default(), diff --git a/border-tch-agent/src/mlp/config.rs b/border-tch-agent/src/mlp/config.rs index d873eefd..561f92f0 100644 --- a/border-tch-agent/src/mlp/config.rs +++ b/border-tch-agent/src/mlp/config.rs @@ -5,10 +5,10 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] /// Configuration of [`Mlp`](super::Mlp). 
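A std-only sketch of how the linear-layer shapes follow from an `MlpConfig`-style `(in_dim, units, out_dim)` triple, matching the loops in `mlp` and `create_net` above: each hidden layer maps the previous width to the next entry of `units`, and a final linear layer maps to `out_dim`.

```rust
// Computes the (input, output) dimensions of each linear layer.
fn layer_shapes(in_dim: i64, units: &[i64], out_dim: i64) -> Vec<(i64, i64)> {
    let mut shapes = Vec::new();
    let mut prev = in_dim;
    for &n in units {
        shapes.push((prev, n));
        prev = n;
    }
    shapes.push((prev, out_dim));
    shapes
}

fn main() {
    // e.g. an MLP with 3 inputs, two hidden layers of 64 units, and 1 output.
    assert_eq!(
        layer_shapes(3, &[64, 64], 1),
        vec![(3, 64), (64, 64), (64, 1)]
    );
}
```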
pub struct MlpConfig { - pub(super) in_dim: i64, - pub(super) units: Vec, - pub(super) out_dim: i64, - pub(super) activation_out: bool, + pub in_dim: i64, + pub units: Vec, + pub out_dim: i64, + pub activation_out: bool, } impl MlpConfig { diff --git a/border-tch-agent/src/mlp/mlp2.rs b/border-tch-agent/src/mlp/mlp2.rs index a5b7c5a8..5ce89f15 100644 --- a/border-tch-agent/src/mlp/mlp2.rs +++ b/border-tch-agent/src/mlp/mlp2.rs @@ -1,6 +1,6 @@ +use super::{mlp, MlpConfig}; use crate::model::SubModel; use tch::{nn, nn::Module, Device, Tensor}; -use super::{MlpConfig, mlp}; #[allow(clippy::clippy::upper_case_acronyms)] /// Multilayer perceptron that outputs two tensors of the same size. diff --git a/border-tch-agent/src/model/base.rs b/border-tch-agent/src/model/base.rs index 0b375fca..508ed6fc 100644 --- a/border-tch-agent/src/model/base.rs +++ b/border-tch-agent/src/model/base.rs @@ -22,6 +22,7 @@ pub trait ModelBase { } /// Neural networks with a single input and a single output. +#[allow(dead_code)] pub trait Model1: ModelBase { /// The input of the neural network. type Input; @@ -39,6 +40,7 @@ pub trait Model1: ModelBase { } /// Neural networks with double inputs and a single output. +#[allow(dead_code)] pub trait Model2: ModelBase { /// An input of the neural network. type Input1; @@ -51,53 +53,61 @@ pub trait Model2: ModelBase { fn forward(&self, x1s: &Self::Input1, x2s: &Self::Input2) -> Self::Output; } -/// Neural network model that can be initialized with [VarStore] and configuration. +/// Neural network model that can be initialized with [`VarStore`] and configuration. /// /// The purpose of this trait is for modularity of neural network models. -/// Modules, which consists a neural network, should share [VarStore]. -/// To do this, structs implementing this trait can be initialized with a given [VarStore]. -/// This trait also provide the ability to clone with a given [VarStore]. +/// Modules, which consists a neural network, should share [`VarStore`]. +/// To do this, structs implementing this trait can be initialized with a given [`VarStore`]. +/// This trait also provide the ability to clone with a given [`VarStore`]. /// The ability is useful when creating a target network, used in recent deep learning algorithms in common. +/// +/// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html pub trait SubModel { - /// Configuration from which [SubModel] is constructed. + /// Configuration from which [`SubModel`] is constructed. type Config; - /// Input of the [SubModel]. + /// Input of the [`SubModel`]. type Input; - /// Output of the [SubModel]. + /// Output of the [`SubModel`]. type Output; - /// Builds [SubModel] with [VarStore] and [SubModel::Config]. + /// Builds [`SubModel`] with [`VarStore`] and [`SubModel::Config`]. fn build(var_store: &VarStore, config: Self::Config) -> Self; - /// Clones [SubModel] with [VarStore]. + /// Clones [`SubModel`] with [`VarStore`]. + /// + /// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html fn clone_with_var_store(&self, var_store: &VarStore) -> Self; /// A generalized forward function. fn forward(&self, input: &Self::Input) -> Self::Output; } -/// Neural network model that can be initialized with [VarStore] and configuration. +/// Neural network model that can be initialized with [`VarStore`] and configuration. /// -/// The difference from [SubModel] is that this trait takes two inputs. +/// The difference from [`SubModel`] is that this trait takes two inputs. 
+/// +/// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html pub trait SubModel2 { - /// Configuration from which [SubModel2] is constructed. + /// Configuration from which [`SubModel2`] is constructed. type Config; - /// Input of the [SubModel2]. + /// Input of the [`SubModel2`]. type Input1; - /// Input of the [SubModel2]. + /// Input of the [`SubModel2`]. type Input2; - /// Output of the [SubModel2]. + /// Output of the [`SubModel2`]. type Output; - /// Builds [SubModel2] with [VarStore] and [SubModel2::Config]. + /// Builds [`SubModel2`] with [VarStore] and [SubModel2::Config]. fn build(var_store: &VarStore, config: Self::Config) -> Self; - /// Clones [SubModel2] with [VarStore]. + /// Clones [`SubModel2`] with [`VarStore`]. + /// + /// [`VarStore`]: https://docs.rs/tch/0.16.0/tch/nn/struct.VarStore.html fn clone_with_var_store(&self, var_store: &VarStore) -> Self; /// A generalized forward function. diff --git a/border-tch-agent/src/opt.rs b/border-tch-agent/src/opt.rs index 95aea45f..6a077c69 100644 --- a/border-tch-agent/src/opt.rs +++ b/border-tch-agent/src/opt.rs @@ -4,23 +4,11 @@ use core::f64; use serde::{Deserialize, Serialize}; use tch::{ // nn, - nn::{Adam, Optimizer as Optimizer_, OptimizerConfig as OptimizerConfig_, VarStore}, + nn::{Adam, AdamW, Optimizer as Optimizer_, OptimizerConfig as OptimizerConfig_, VarStore}, Tensor, }; /// Configures an optimizer for training neural networks in an RL agent. -#[cfg(not(feature = "adam_eps"))] -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] -pub enum OptimizerConfig { - /// Adam optimizer. - Adam { - /// Learning rate. - lr: f64, - }, -} - -/// Configures an optimizer for training neural networks in an RL agent. -#[cfg(feature = "adam_eps")] #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] pub enum OptimizerConfig { /// Adam optimizer. @@ -29,16 +17,16 @@ pub enum OptimizerConfig { lr: f64, }, - /// Adam optimizer with the epsilon parameter. - AdamEps { - /// Learning rate. + AdamW { lr: f64, - /// Epsilon parameter. + beta1: f64, + beta2: f64, + wd: f64, eps: f64, + amsgrad: bool, }, } -#[cfg(not(feature = "adam_eps"))] impl OptimizerConfig { /// Constructs an optimizer. pub fn build(&self, vs: &VarStore) -> Result { @@ -47,24 +35,23 @@ impl OptimizerConfig { let opt = Adam::default().build(vs, *lr)?; Ok(Optimizer::Adam(opt)) } - } - } -} - -#[cfg(feature = "adam_eps")] -impl OptimizerConfig { - /// Constructs an optimizer. - pub fn build(&self, vs: &VarStore) -> Result { - match &self { - OptimizerConfig::Adam { lr } => { - let opt = Adam::default().build(vs, *lr)?; - Ok(Optimizer::Adam(opt)) - } - OptimizerConfig::AdamEps { lr, eps } => { - let mut opt = Adam::default(); - opt.eps = *eps; - let opt = opt.build(vs, *lr)?; - Ok(Optimizer::Adam(opt)) + OptimizerConfig::AdamW { + lr, + beta1, + beta2, + wd, + eps, + amsgrad, + } => { + let opt = AdamW { + beta1: *beta1, + beta2: *beta2, + wd: *wd, + eps: *eps, + amsgrad: *amsgrad, + } + .build(vs, *lr)?; + Ok(Optimizer::AdamW(opt)) } } } @@ -73,9 +60,13 @@ impl OptimizerConfig { /// Optimizers. /// /// This is a thin wrapper of [tch::nn::Optimizer]. +/// +/// [tch::nn::Optimizer]: https://docs.rs/tch/0.16.0/tch/nn/struct.Optimizer.html pub enum Optimizer { /// Adam optimizer. 
Adam(Optimizer_), + + AdamW(Optimizer_), } impl Optimizer { @@ -85,6 +76,9 @@ impl Optimizer { Self::Adam(opt) => { opt.backward_step(loss); } + Self::AdamW(opt) => { + opt.backward_step(loss); + } } } } diff --git a/border-tch-agent/src/sac.rs b/border-tch-agent/src/sac.rs index bd2b31ea..bf3e8215 100644 --- a/border-tch-agent/src/sac.rs +++ b/border-tch-agent/src/sac.rs @@ -1,10 +1,156 @@ //! SAC agent. //! -//! Here is an example in `border/examples/sac_pendulum.rs` +//! Here is an example of creating SAC agent: //! -//! ```rust,ignore +//! ```no_run +//! # use anyhow::Result; +//! use border_core::{ +//! # Env as Env_, Obs as Obs_, Act as Act_, Step, test::{ +//! # TestAct as TestAct_, TestActBatch as TestActBatch_, +//! # TestEnv as TestEnv_, +//! # TestObs as TestObs_, TestObsBatch as TestObsBatch_, +//! # }, +//! # record::Record, +//! # generic_replay_buffer::{SimpleReplayBuffer, BatchBase}, +//! Configurable, +//! }; +//! use border_tch_agent::{ +//! sac::{ActorConfig, CriticConfig, Sac, SacConfig}, +//! mlp::{Mlp, Mlp2, MlpConfig}, +//! opt::OptimizerConfig +//! }; +//! +//! # struct TestEnv(TestEnv_); +//! # #[derive(Clone, Debug)] +//! # struct TestObs(TestObs_); +//! # #[derive(Clone, Debug)] +//! # struct TestAct(TestAct_); +//! # struct TestObsBatch(TestObsBatch_); +//! # struct TestActBatch(TestActBatch_); +//! # +//! # impl Obs_ for TestObs { +//! # fn dummy(n: usize) -> Self { +//! # Self(TestObs_::dummy(n)) +//! # } +//! # +//! # fn len(&self) -> usize { +//! # self.0.len() +//! # } +//! # } +//! # +//! # impl Into for TestObs { +//! # fn into(self) -> tch::Tensor { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl BatchBase for TestObsBatch { +//! # fn new(n: usize) -> Self { +//! # Self(TestObsBatch_::new(n)) +//! # } +//! # +//! # fn push(&mut self, ix: usize, data: Self) { +//! # self.0.push(ix, data.0); +//! # } +//! # +//! # fn sample(&self, ixs: &Vec) -> Self { +//! # Self(self.0.sample(ixs)) +//! # } +//! # } +//! # +//! # impl BatchBase for TestActBatch { +//! # fn new(n: usize) -> Self { +//! # Self(TestActBatch_::new(n)) +//! # } +//! # +//! # fn push(&mut self, ix: usize, data: Self) { +//! # self.0.push(ix, data.0); +//! # } +//! # +//! # fn sample(&self, ixs: &Vec) -> Self { +//! # Self(self.0.sample(ixs)) +//! # } +//! # } +//! # +//! # impl Act_ for TestAct { +//! # fn len(&self) -> usize { +//! # self.0.len() +//! # } +//! # } +//! # +//! # impl From for TestAct { +//! # fn from(t: tch::Tensor) -> Self { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl Into for TestAct { +//! # fn into(self) -> tch::Tensor { +//! # unimplemented!(); +//! # } +//! # } +//! # +//! # impl Env_ for TestEnv { +//! # type Config = ::Config; +//! # type Obs = TestObs; +//! # type Act = TestAct; +//! # type Info = ::Info; +//! # +//! # fn build(config: &Self::Config, seed: i64) -> Result { +//! # Ok(Self(TestEnv_::build(&config, seed).unwrap())) +//! # } +//! # +//! # fn step(&mut self, act: &TestAct) -> (Step, Record) { +//! # let (step, record) = self.0.step(&act.0); +//! # let step = Step { +//! # obs: TestObs(step.obs), +//! # act: TestAct(step.act), +//! # reward: step.reward, +//! # is_terminated: step.is_terminated, +//! # is_truncated: step.is_truncated, +//! # info: step.info, +//! # init_obs: TestObs(step.init_obs), +//! # }; +//! # (step, record) +//! # } +//! # +//! # fn reset(&mut self, is_done: Option<&Vec>) -> Result { +//! # Ok(TestObs(self.0.reset(is_done).unwrap())) +//! # } +//! # +//! 
# fn step_with_reset(&mut self, a: &TestAct) -> (Step, Record) { +//! # let (step, record) = self.0.step_with_reset(&a.0); +//! # let step = Step { +//! # obs: TestObs(step.obs), +//! # act: TestAct(step.act), +//! # reward: step.reward, +//! # is_terminated: step.is_terminated, +//! # is_truncated: step.is_truncated, +//! # info: step.info, +//! # init_obs: TestObs(step.init_obs), +//! # }; +//! # (step, record) +//! # } +//! # +//! # fn reset_with_index(&mut self, ix: usize) -> Result { +//! # Ok(TestObs(self.0.reset_with_index(ix).unwrap())) +//! # } +//! # } +//! # +//! # type Env = TestEnv; +//! # type ObsBatch = TestObsBatch; +//! # type ActBatch = TestActBatch; +//! # type ReplayBuffer = SimpleReplayBuffer; +//! # +//! const DIM_OBS: i64 = 3; +//! const DIM_ACT: i64 = 1; +//! const LR_ACTOR: f64 = 1e-3; +//! const LR_CRITIC: f64 = 1e-3; +//! const BATCH_SIZE: usize = 256; +//! //! fn create_agent(in_dim: i64, out_dim: i64) -> Sac { //! let device = tch::Device::cuda_if_available(); +//! //! let actor_config = ActorConfig::default() //! .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) //! .out_dim(out_dim) @@ -12,25 +158,13 @@ //! let critic_config = CriticConfig::default() //! .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) //! .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, true)); -//! let sac_config = SacConfig::default() +//! let sac_config = SacConfig::::default() //! .batch_size(BATCH_SIZE) -//! .min_transitions_warmup(N_TRANSITIONS_WARMUP) //! .actor_config(actor_config) //! .critic_config(critic_config) //! .device(device); //! Sac::build(sac_config) //! } -//! -//! fn train(max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { -//! let trainer = //... -//! let mut agent = create_agent(DIM_OBS, DIM_ACT); -//! let mut recorder = TensorboardRecorder::new(model_dir); -//! let mut evaluator = Evaluator::new(&env_config(), 0, N_EPISODES_PER_EVAL)?; -//! -//! trainer.train(&mut agent, &mut recorder, &mut evaluator)?; -//! -//! Ok(()) -//! } //! ``` mod actor; mod base; diff --git a/border-tch-agent/src/sac/actor.rs b/border-tch-agent/src/sac/actor.rs index 88d13af0..9360b089 100644 --- a/border-tch-agent/src/sac/actor.rs +++ b/border-tch-agent/src/sac/actor.rs @@ -1,5 +1,5 @@ //! Actor of SAC agent. -mod config; mod base; -pub use config::ActorConfig; +mod config; pub use base::Actor; +pub use config::ActorConfig; diff --git a/border-tch-agent/src/sac/actor/base.rs b/border-tch-agent/src/sac/actor/base.rs index 85b067b8..756bbb77 100644 --- a/border-tch-agent/src/sac/actor/base.rs +++ b/border-tch-agent/src/sac/actor/base.rs @@ -7,13 +7,10 @@ use crate::{ use anyhow::{Context, Result}; use log::{info, trace}; use serde::{de::DeserializeOwned, Serialize}; -use std::{ - path::Path, -}; +use std::path::Path; use tch::{nn, Device, Tensor}; -#[allow(clippy::upper_case_acronyms)] -/// Represents a stochastic policy for SAC agents. +/// Stochastic policy for SAC agents. pub struct Actor

<P>
where P: SubModel, @@ -38,7 +35,7 @@ where P: SubModel, P::Config: DeserializeOwned + Serialize + OutDim, { - /// Constructs [Actor]. + /// Constructs [`Actor`]. pub fn build(config: ActorConfig, device: Device) -> Result> { let pi_config = config.pi_config.context("pi_config is not set.")?; let out_dim = pi_config.get_out_dim(); diff --git a/border-tch-agent/src/sac/actor/config.rs b/border-tch-agent/src/sac/actor/config.rs index ad84308f..05aec6e8 100644 --- a/border-tch-agent/src/sac/actor/config.rs +++ b/border-tch-agent/src/sac/actor/config.rs @@ -9,10 +9,10 @@ use std::{ #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -/// Configuration of [Actor](super::Actor). +/// Configuration of [`Actor`](super::Actor). pub struct ActorConfig { - pub(super) pi_config: Option
<P>
, - pub(super) opt_config: OptimizerConfig, + pub pi_config: Option
<P>
, + pub opt_config: OptimizerConfig, } impl Default for ActorConfig
<P>
{ diff --git a/border-tch-agent/src/sac/base.rs b/border-tch-agent/src/sac/base.rs index 401c87e6..03b4bffc 100644 --- a/border-tch-agent/src/sac/base.rs +++ b/border-tch-agent/src/sac/base.rs @@ -6,11 +6,11 @@ use crate::{ use anyhow::Result; use border_core::{ record::{Record, RecordValue}, - Agent, Env, Policy, ReplayBufferBase, StdBatchBase, + Agent, Configurable, Env, Policy, ReplayBufferBase, TransitionBatch, }; use serde::{de::DeserializeOwned, Serialize}; // use log::info; -use std::{fs, marker::PhantomData, path::Path}; +use std::{convert::TryFrom, fs, marker::PhantomData, path::Path}; use tch::{no_grad, Tensor}; type ActionValue = Tensor; @@ -20,25 +20,16 @@ type ActStd = Tensor; fn normal_logp(x: &Tensor) -> Tensor { let tmp: Tensor = Tensor::from(-0.5 * (2.0 * std::f32::consts::PI).ln() as f32) - 0.5 * x.pow_tensor_scalar(2); - tmp.sum_dim_intlist(&[-1], false, tch::Kind::Float) + tmp.sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float) } /// Soft actor critic (SAC) agent. -#[allow(clippy::upper_case_acronyms)] pub struct Sac where - E: Env, Q: SubModel2, P: SubModel, - R: ReplayBufferBase, - E::Obs: Into + Into, - E::Act: Into, - Q::Input2: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into + Into + Clone, - ::ActBatch: Into + Into, { pub(super) qnets: Vec>, pub(super) qnets_tgt: Vec>, @@ -50,7 +41,6 @@ where pub(super) min_lstd: f64, pub(super) max_lstd: f64, pub(super) n_updates_per_opt: usize, - pub(super) min_transitions_warmup: usize, pub(super) batch_size: usize, pub(super) train: bool, pub(super) reward_scale: f32, @@ -71,9 +61,9 @@ where Q::Input2: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into + Into + Clone, - ::ActBatch: Into + Into, + R::Batch: TransitionBatch, + ::ObsBatch: Into + Into + Clone, + ::ActBatch: Into + Into, { fn action_logp(&self, o: &P::Input) -> (Tensor, Tensor) { let (mean, lstd) = self.pi.forward(o); @@ -83,7 +73,7 @@ where let log_p = normal_logp(&z) - (Tensor::from(1f32) - a.pow_tensor_scalar(2.0) + Tensor::from(self.epsilon)) .log() - .sum_dim_intlist(&[-1], false, tch::Kind::Float); + .sum_dim_intlist(Some([-1].as_slice()), false, tch::Kind::Float); debug_assert_eq!(a.size().as_slice()[0], self.batch_size as i64); debug_assert_eq!(log_p.size().as_slice(), [self.batch_size as i64]); @@ -111,9 +101,9 @@ where fn update_critic(&mut self, batch: R::Batch) -> f32 { let losses = { - let (obs, act, next_obs, reward, is_done, _, _) = batch.unpack(); - let reward = Tensor::of_slice(&reward[..]).to(self.device); - let is_done = Tensor::of_slice(&is_done[..]).to(self.device); + let (obs, act, next_obs, reward, is_terminated, _is_truncated, _, _) = batch.unpack(); + let reward = Tensor::from_slice(&reward[..]).to(self.device); + let is_terminated = Tensor::from_slice(&is_terminated[..]).to(self.device); let preds = self.qvals(&self.qnets, &obs.into(), &act.into()); let tgt = { @@ -122,13 +112,14 @@ where let next_q = self.qvals_min(&self.qnets_tgt, &next_obs.into(), &next_a.into()); next_q - self.ent_coef.alpha() * next_log_p }); - self.reward_scale * reward + (1f32 - &is_done) * Tensor::from(self.gamma) * next_q + self.reward_scale * reward + + (1f32 - &is_terminated) * 
Tensor::from(self.gamma) * next_q }; debug_assert_eq!(tgt.size().as_slice(), [self.batch_size as i64]); let losses: Vec<_> = match self.critic_loss { - CriticLoss::MSE => preds + CriticLoss::Mse => preds .iter() .map(|pred| pred.mse_loss(&tgt, tch::Reduction::Mean)) .collect(), @@ -144,7 +135,12 @@ where qnet.backward_step(&loss); } - losses.iter().map(f32::from).sum::() / (self.qnets.len() as f32) + losses + .iter() + .map(f32::try_from) + .map(|a| a.expect("Failed to convert Tensor to f32")) + .sum::() + / (self.qnets.len() as f32) } fn update_actor(&mut self, batch: &R::Batch) -> f32 { @@ -153,16 +149,16 @@ where let (a, log_p) = self.action_logp(&o.into()); // Update the entropy coefficient - self.ent_coef.update(&log_p); + self.ent_coef.update(&log_p.detach()); let o = batch.obs().clone(); let qval = self.qvals_min(&self.qnets, &o.into(), &a.into()); - (self.ent_coef.alpha() * &log_p - &qval).mean(tch::Kind::Float) + (self.ent_coef.alpha().detach() * &log_p - &qval).mean(tch::Kind::Float) }; self.pi.backward_step(&loss); - f32::from(loss) + f32::try_from(loss).expect("Failed to convert Tensor to f32") } fn soft_update(&mut self) { @@ -195,6 +191,10 @@ where ), ]) } + + pub fn get_policy_net(&self) -> &Actor
<P>
{ + &self.pi + } } impl Policy for Sac @@ -202,19 +202,37 @@ where E: Env, Q: SubModel2, P: SubModel, - R: ReplayBufferBase, E::Obs: Into + Into, E::Act: Into + From, - Q::Input2: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into + Into + Clone, - ::ActBatch: Into + Into, +{ + fn sample(&mut self, obs: &E::Obs) -> E::Act { + let obs = obs.clone().into(); + let (mean, lstd) = self.pi.forward(&obs); + let std = lstd.clip(self.min_lstd, self.max_lstd).exp(); + let act = if self.train { + std * Tensor::randn(&mean.size(), tch::kind::FLOAT_CPU).to(self.device) + mean + } else { + mean + }; + act.tanh().into() + } +} + +impl Configurable for Sac +where + E: Env, + Q: SubModel2, + P: SubModel, + E::Obs: Into + Into, + E::Act: Into + From, + Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, + P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, { type Config = SacConfig; - /// Constructs [Sac] agent. + /// Constructs [`Sac`] agent. fn build(config: Self::Config) -> Self { let device = config .device @@ -245,7 +263,6 @@ where min_lstd: config.min_lstd, max_lstd: config.max_lstd, n_updates_per_opt: config.n_updates_per_opt, - min_transitions_warmup: config.min_transitions_warmup, batch_size: config.batch_size, train: config.train, reward_scale: config.reward_scale, @@ -255,18 +272,6 @@ where phantom: PhantomData, } } - - fn sample(&mut self, obs: &E::Obs) -> E::Act { - let obs = obs.clone().into(); - let (mean, lstd) = self.pi.forward(&obs); - let std = lstd.clip(self.min_lstd, self.max_lstd).exp(); - let act = if self.train { - std * Tensor::randn(&mean.size(), tch::kind::FLOAT_CPU).to(self.device) + mean - } else { - mean - }; - act.tanh().into() - } } impl Agent for Sac @@ -280,9 +285,9 @@ where Q::Input2: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into + Into + Clone, - ::ActBatch: Into + Into, + R::Batch: TransitionBatch, + ::ObsBatch: Into + Into + Clone, + ::ActBatch: Into + Into, { fn train(&mut self) { self.train = true; @@ -296,35 +301,41 @@ where self.train } - fn opt(&mut self, buffer: &mut R) -> Option { - if buffer.len() >= self.min_transitions_warmup { - Some(self.opt_(buffer)) - } else { - None - } + fn opt_with_record(&mut self, buffer: &mut R) -> Record { + self.opt_(buffer) } - fn save>(&self, path: T) -> Result<()> { + fn save_params>(&self, path: T) -> Result<()> { // TODO: consider to rename the path if it already exists fs::create_dir_all(&path)?; for (i, (qnet, qnet_tgt)) in self.qnets.iter().zip(&self.qnets_tgt).enumerate() { - qnet.save(&path.as_ref().join(format!("qnet_{}.pt", i)).as_path())?; - qnet_tgt.save(&path.as_ref().join(format!("qnet_tgt_{}.pt", i)).as_path())?; + qnet.save(&path.as_ref().join(format!("qnet_{}.pt.tch", i)).as_path())?; + qnet_tgt.save( + &path + .as_ref() + .join(format!("qnet_tgt_{}.pt.tch", i)) + .as_path(), + )?; } - self.pi.save(&path.as_ref().join("pi.pt").as_path())?; + self.pi.save(&path.as_ref().join("pi.pt.tch").as_path())?; self.ent_coef - .save(&path.as_ref().join("ent_coef.pt").as_path())?; + .save(&path.as_ref().join("ent_coef.pt.tch").as_path())?; Ok(()) } - fn load>(&mut self, path: T) -> Result<()> { + fn 
load_params>(&mut self, path: T) -> Result<()> { for (i, (qnet, qnet_tgt)) in self.qnets.iter_mut().zip(&mut self.qnets_tgt).enumerate() { - qnet.load(&path.as_ref().join(format!("qnet_{}.pt", i)).as_path())?; - qnet_tgt.load(&path.as_ref().join(format!("qnet_tgt_{}.pt", i)).as_path())?; + qnet.load(&path.as_ref().join(format!("qnet_{}.pt.tch", i)).as_path())?; + qnet_tgt.load( + &path + .as_ref() + .join(format!("qnet_tgt_{}.pt.tch", i)) + .as_path(), + )?; } - self.pi.load(&path.as_ref().join("pi.pt").as_path())?; + self.pi.load(&path.as_ref().join("pi.pt.tch").as_path())?; self.ent_coef - .load(&path.as_ref().join("ent_coef.pt").as_path())?; + .load(&path.as_ref().join("ent_coef.pt.tch").as_path())?; Ok(()) } } @@ -344,9 +355,9 @@ where Q::Input2: From, Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone, - R::Batch: StdBatchBase, - ::ObsBatch: Into + Into + Clone, - ::ActBatch: Into + Into, + R::Batch: TransitionBatch, + ::ObsBatch: Into + Into + Clone, + ::ActBatch: Into + Into, { type ModelInfo = NamedTensors; diff --git a/border-tch-agent/src/sac/config.rs b/border-tch-agent/src/sac/config.rs index c74b96a4..4686ee32 100644 --- a/border-tch-agent/src/sac/config.rs +++ b/border-tch-agent/src/sac/config.rs @@ -18,8 +18,7 @@ use std::{ }; use tch::Tensor; -/// Constructs [Sac](super::Sac). -#[allow(clippy::upper_case_acronyms)] +/// Configuration of [`Sac`](super::Sac). #[derive(Debug, Deserialize, Serialize, PartialEq)] pub struct SacConfig where @@ -28,23 +27,21 @@ where P: SubModel, P::Config: DeserializeOwned + Serialize + OutDim + Debug + PartialEq + Clone, { - pub(super) actor_config: ActorConfig, - pub(super) critic_config: CriticConfig, - pub(super) gamma: f64, - pub(super) tau: f64, - pub(super) ent_coef_mode: EntCoefMode, - pub(super) epsilon: f64, - pub(super) min_lstd: f64, - pub(super) max_lstd: f64, - pub(super) n_updates_per_opt: usize, - pub(super) min_transitions_warmup: usize, - pub(super) batch_size: usize, - pub(super) train: bool, - pub(super) critic_loss: CriticLoss, - pub(super) reward_scale: f32, - pub(super) replay_burffer_capacity: usize, - pub(super) n_critics: usize, - pub(super) seed: Option, + pub actor_config: ActorConfig, + pub critic_config: CriticConfig, + pub gamma: f64, + pub tau: f64, + pub ent_coef_mode: EntCoefMode, + pub epsilon: f64, + pub min_lstd: f64, + pub max_lstd: f64, + pub n_updates_per_opt: usize, + pub batch_size: usize, + pub train: bool, + pub critic_loss: CriticLoss, + pub reward_scale: f32, + pub n_critics: usize, + pub seed: Option, pub device: Option, // expr_sampling: ExperienceSampling, } @@ -67,15 +64,13 @@ where min_lstd: self.min_lstd.clone(), max_lstd: self.max_lstd.clone(), n_updates_per_opt: self.n_updates_per_opt.clone(), - min_transitions_warmup: self.min_transitions_warmup.clone(), batch_size: self.batch_size.clone(), train: self.train.clone(), critic_loss: self.critic_loss.clone(), reward_scale: self.reward_scale.clone(), - replay_burffer_capacity: self.replay_burffer_capacity.clone(), n_critics: self.n_critics.clone(), seed: self.seed.clone(), - device: self.device.clone() + device: self.device.clone(), } } } @@ -98,12 +93,10 @@ where min_lstd: -20.0, max_lstd: 2.0, n_updates_per_opt: 1, - min_transitions_warmup: 1, batch_size: 1, train: false, - critic_loss: CriticLoss::MSE, + critic_loss: CriticLoss::Mse, reward_scale: 1.0, - replay_burffer_capacity: 100, n_critics: 1, seed: None, device: None, 
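Related to the SAC hunks above: `action_logp` computes the log-probability of a tanh-squashed Gaussian sample as `normal_logp(z) - sum(log(1 - tanh(z)^2 + eps))`. A std-only scalar sketch of that formula (illustration only, not the tensor code in the patch):

```rust
// Log-density of a standard normal sample z (summed over action dimensions).
fn normal_logp(z: &[f32]) -> f32 {
    let c = -0.5 * (2.0 * std::f32::consts::PI).ln();
    z.iter().map(|zi| c - 0.5 * zi * zi).sum()
}

// Log-density of a = tanh(z): the change of variables adds a
// -log(1 - tanh(z)^2 + eps) correction per action dimension.
fn squashed_logp(z: &[f32], epsilon: f32) -> f32 {
    let correction: f32 = z
        .iter()
        .map(|zi| (1.0 - zi.tanh().powi(2) + epsilon).ln())
        .sum();
    normal_logp(z) - correction
}

fn main() {
    let z = [0.1_f32, -0.2];
    // Near z = 0 the tanh squashing is nearly linear, so the correction is small.
    assert!(squashed_logp(&z, 1e-6) > normal_logp(&z) - 0.1);
}
```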
@@ -125,12 +118,6 @@ where self } - /// Interval before starting optimization. - pub fn min_transitions_warmup(mut self, v: usize) -> Self { - self.min_transitions_warmup = v; - self - } - /// Batch size. pub fn batch_size(mut self, v: usize) -> Self { self.batch_size = v; @@ -155,12 +142,6 @@ where self } - /// Replay buffer capacity. - pub fn replay_burffer_capacity(mut self, v: usize) -> Self { - self.replay_burffer_capacity = v; - self - } - /// Reward scale. /// /// It works for obtaining target values, not the values in logs. diff --git a/border-tch-agent/src/sac/critic/config.rs b/border-tch-agent/src/sac/critic/config.rs index 6c5ecd47..20045aa4 100644 --- a/border-tch-agent/src/sac/critic/config.rs +++ b/border-tch-agent/src/sac/critic/config.rs @@ -9,10 +9,10 @@ use std::{ #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] -/// Configuration of [Critic](super::Critic). +/// Configuration of [`Critic`](super::Critic). pub struct CriticConfig { - pub(super) q_config: Option, - pub(super) opt_config: OptimizerConfig, + pub q_config: Option, + pub opt_config: OptimizerConfig, } impl Default for CriticConfig { diff --git a/border-tch-agent/src/sac/ent_coef.rs b/border-tch-agent/src/sac/ent_coef.rs index ca4f816f..c76841eb 100644 --- a/border-tch-agent/src/sac/ent_coef.rs +++ b/border-tch-agent/src/sac/ent_coef.rs @@ -2,7 +2,7 @@ use anyhow::Result; use log::{info, trace}; use serde::{Deserialize, Serialize}; -use std::{borrow::Borrow, path::Path}; +use std::{/*borrow::Borrow,*/ path::Path}; use tch::{nn, nn::OptimizerConfig, Tensor}; /// Mode of the entropy coefficient of SAC. @@ -30,12 +30,14 @@ impl EntCoef { let (log_alpha, target_entropy, opt) = match mode { EntCoefMode::Fix(alpha) => { let init = nn::Init::Const(alpha.ln()); - let log_alpha = path.borrow().var("log_alpha", &[1], init); + // let log_alpha = path.borrow().var("log_alpha", &[1], init); + let log_alpha = path.var("log_alpha", &[1], init); (log_alpha, None, None) } EntCoefMode::Auto(target_entropy, learning_rate) => { let init = nn::Init::Const(0.0); - let log_alpha = path.borrow().var("log_alpha", &[1], init); + // let log_alpha = path.borrow().var("log_alpha", &[1], init); + let log_alpha = path.var("log_alpha", &[1], init); let opt = nn::Adam::default() .build(&var_store, learning_rate) .unwrap(); diff --git a/border-tch-agent/src/tensor_batch.rs b/border-tch-agent/src/tensor_batch.rs index ea7805f5..e1c3d38c 100644 --- a/border-tch-agent/src/tensor_batch.rs +++ b/border-tch-agent/src/tensor_batch.rs @@ -1,7 +1,9 @@ -use border_core::replay_buffer::SubBatch; +use border_core::generic_replay_buffer::BatchBase; use tch::{Device, Tensor}; -/// Adds capability of constructing [Tensor] with a static method. +/// Adds capability of constructing [`Tensor`] with a static method. +/// +/// [`Tensor`]: https://docs.rs/tch/0.16.0/tch/struct.Tensor.html pub trait ZeroTensor { /// Constructs zero tensor. 
fn zeros(shape: &[i64]) -> Tensor; @@ -9,40 +11,42 @@ pub trait ZeroTensor { impl ZeroTensor for u8 { fn zeros(shape: &[i64]) -> Tensor { - Tensor::zeros(&shape, (tch::kind::Kind::Uint8, Device::Cpu)) + Tensor::zeros(shape, (tch::kind::Kind::Uint8, Device::Cpu)) } } impl ZeroTensor for i32 { fn zeros(shape: &[i64]) -> Tensor { - Tensor::zeros(&shape, (tch::kind::Kind::Int, Device::Cpu)) + Tensor::zeros(shape, (tch::kind::Kind::Int, Device::Cpu)) } } impl ZeroTensor for f32 { fn zeros(shape: &[i64]) -> Tensor { - Tensor::zeros(&shape, tch::kind::FLOAT_CPU) + Tensor::zeros(shape, tch::kind::FLOAT_CPU) } } impl ZeroTensor for i64 { fn zeros(shape: &[i64]) -> Tensor { - Tensor::zeros(&shape, (tch::kind::Kind::Int64, Device::Cpu)) + Tensor::zeros(shape, (tch::kind::Kind::Int64, Device::Cpu)) } } -/// A buffer consisting of a [`Tensor`](tch::Tensor). +/// A buffer consisting of a [`Tensor`]. /// /// The internal buffer of this struct has the shape of `[n_capacity, shape[1..]]`, /// where `shape` is obtained from the data pushed at the first time via -/// [`TensorSubBatch::push`] method. `[1..]` means that the first axis of the +/// [`TensorBatch::push`] method. `[1..]` means that the first axis of the /// given data is ignored as it might be batch size. -pub struct TensorSubBatch { +/// +/// [`Tensor`]: https://docs.rs/tch/0.16.0/tch/struct.Tensor.html +pub struct TensorBatch { buf: Option, capacity: i64, } -impl Clone for TensorSubBatch { +impl Clone for TensorBatch { fn clone(&self) -> Self { let buf = match self.buf.is_none() { true => None, @@ -56,7 +60,7 @@ impl Clone for TensorSubBatch { } } -impl TensorSubBatch { +impl TensorBatch { pub fn from_tensor(t: Tensor) -> Self { let capacity = t.size()[0] as _; Self { @@ -66,7 +70,7 @@ impl TensorSubBatch { } } -impl SubBatch for TensorSubBatch { +impl BatchBase for TensorBatch { fn new(capacity: usize) -> Self { // let capacity = capacity as i64; // let mut shape: Vec<_> = S::shape().to_vec().iter().map(|e| *e as i64).collect(); @@ -83,7 +87,7 @@ impl SubBatch for TensorSubBatch { /// /// If the internal buffer is empty, it will be initialized with the shape /// `[capacity, data.buf.size()[1..]]`. - fn push(&mut self, index: usize, data: &Self) { + fn push(&mut self, index: usize, data: Self) { if data.buf.is_none() { return; } @@ -112,7 +116,7 @@ impl SubBatch for TensorSubBatch { fn sample(&self, ixs: &Vec) -> Self { let ixs = ixs.iter().map(|&ix| ix as i64).collect::>(); - let batch_indexes = Tensor::of_slice(&ixs); + let batch_indexes = Tensor::from_slice(&ixs); let buf = Some(self.buf.as_ref().unwrap().index_select(0, &batch_indexes)); Self { buf, @@ -121,8 +125,8 @@ impl SubBatch for TensorSubBatch { } } -impl From for Tensor { - fn from(b: TensorSubBatch) -> Self { +impl From for Tensor { + fn from(b: TensorBatch) -> Self { b.buf.unwrap() } } diff --git a/border-tch-agent/src/util.rs b/border-tch-agent/src/util.rs index be1bcb42..dcc3d3e1 100644 --- a/border-tch-agent/src/util.rs +++ b/border-tch-agent/src/util.rs @@ -2,25 +2,30 @@ use crate::model::ModelBase; use log::trace; use serde::{Deserialize, Serialize}; -mod quantile_loss; mod named_tensors; -pub use quantile_loss::quantile_huber_loss; +mod quantile_loss; +use border_core::record::{Record, RecordValue}; pub use named_tensors::NamedTensors; +pub use quantile_loss::quantile_huber_loss; +use std::convert::TryFrom; +use tch::nn::VarStore; /// Critic loss type. 
#[allow(clippy::upper_case_acronyms)] #[derive(Debug, Deserialize, Serialize, PartialEq, Clone)] pub enum CriticLoss { /// Mean squared error. - MSE, + Mse, /// Smooth L1 loss. SmoothL1, } -/// Apply soft update on a model. +/// Apply soft update on variables. /// /// Variables are identified by their names. +/// +/// dest = tau * src + (1.0 - tau) * dest pub fn track(dest: &mut M, src: &mut M, tau: f64) { let src = &mut src.get_var_store().variables(); let dest = &mut dest.get_var_store().variables(); @@ -44,11 +49,29 @@ pub fn concat_slices(s1: &[i64], s2: &[i64]) -> Vec { v } -/// Returns the dimension of output vectors, i.e., the number of discrete outputs. +/// Interface for handling output dimensions. pub trait OutDim { - /// Returns the dimension of output vectors, i.e., the number of discrete outputs. + /// Returns the output dimension. fn get_out_dim(&self) -> i64; /// Sets the output dimension. fn set_out_dim(&mut self, v: i64); } + +/// Returns the mean and standard deviation of the parameters. +pub fn param_stats(var_store: &VarStore) -> Record { + let mut record = Record::empty(); + + for (k, v) in var_store.variables() { + // let m: f32 = v.mean(tch::Kind::Float).into(); + let m = f32::try_from(v.mean(tch::Kind::Float)).expect("Failed to convert Tensor to f32"); + let k_mean = format!("{}_mean", &k); + record.insert(k_mean, RecordValue::Scalar(m)); + + let m = f32::try_from(v.std(false)).expect("Failed to convert Tensor to f32"); + let k_std = format!("{}_std", k); + record.insert(k_std, RecordValue::Scalar(m)); + } + + record +} diff --git a/border-tch-agent/src/util/named_tensors.rs b/border-tch-agent/src/util/named_tensors.rs index d27a301c..5ef684a4 100644 --- a/border-tch-agent/src/util/named_tensors.rs +++ b/border-tch-agent/src/util/named_tensors.rs @@ -25,13 +25,13 @@ impl NamedTensors { let dest = &mut vs.variables(); // let device = vs.device(); debug_assert_eq!(src.len(), dest.len()); - + tch::no_grad(|| { for (name, src) in src.iter() { let dest = dest.get_mut(name).unwrap(); dest.copy_(src); } - }); + }); } } @@ -52,31 +52,37 @@ impl Clone for NamedTensors { mod test { use super::NamedTensors; use std::convert::{TryFrom, TryInto}; - use tch::{Tensor, nn::{self, Module}, Device::Cpu}; + use tch::{ + nn::{self, Module}, + Device::Cpu, + Tensor, + }; #[test] fn test_named_tensors() { tch::manual_seed(42); - let tensor1 = Tensor::try_from(vec![1., 2., 3.]).unwrap().internal_cast_float(false); + let tensor1 = Tensor::try_from(vec![1., 2., 3.]) + .unwrap() + .internal_cast_float(false); let vs1 = nn::VarStore::new(Cpu); let model1 = nn::seq() .add(nn::linear(&vs1.root() / "layer1", 3, 8, Default::default())) .add(nn::linear(&vs1.root() / "layer2", 8, 2, Default::default())); - let mut vs2 = nn::VarStore::new(tch::Device::cuda_if_available()); + let mut vs2 = nn::VarStore::new(tch::Device::cuda_if_available()); let model2 = nn::seq() .add(nn::linear(&vs2.root() / "layer1", 3, 8, Default::default())) .add(nn::linear(&vs2.root() / "layer2", 8, 2, Default::default())); let device = vs2.device(); - + let t1: Vec = model1.forward(&tensor1).try_into().unwrap(); let t2: Vec = model2.forward(&tensor1.to(device)).try_into().unwrap(); - + let nt = NamedTensors::copy_from(&vs1); nt.copy_to(&mut vs2); - + let t3: Vec = model2.forward(&tensor1.to(device)).try_into().unwrap(); for i in 0..2 { @@ -86,5 +92,5 @@ mod test { // println!("{:?}", t1); // println!("{:?}", t2); // println!("{:?}", t3); - } + } } diff --git a/border-tensorboard/Cargo.toml b/border-tensorboard/Cargo.toml index 
2bb44c7f..c7b709d2 100644 --- a/border-tensorboard/Cargo.toml +++ b/border-tensorboard/Cargo.toml @@ -1,18 +1,14 @@ [package] name = "border-tensorboard" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "MIT OR Apache-2.0" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true readme = "README.md" -autoexamples = false [dependencies] -border-core = { version = "0.0.6", path = "../border-core" } +border-core = { version = "0.0.7", path = "../border-core" } tensorboard-rs = { workspace = true } diff --git a/border-tensorboard/src/lib.rs b/border-tensorboard/src/lib.rs index 428450aa..e958b02b 100644 --- a/border-tensorboard/src/lib.rs +++ b/border-tensorboard/src/lib.rs @@ -1,4 +1,4 @@ -use border_core::record::{Record, RecordValue, Recorder}; +use border_core::record::{AggregateRecorder, Record, RecordValue, Recorder}; use std::path::Path; use tensorboard_rs::summary_writer::SummaryWriter; @@ -6,6 +6,7 @@ use tensorboard_rs::summary_writer::SummaryWriter; pub struct TensorboardRecorder { writer: SummaryWriter, step_key: String, + latest_record: Option, ignore_unsupported_value: bool, } @@ -18,6 +19,7 @@ impl TensorboardRecorder { writer: SummaryWriter::new(logdir), step_key: "opt_steps".to_string(), ignore_unsupported_value: true, + latest_record: None, } } @@ -29,6 +31,7 @@ impl TensorboardRecorder { writer: SummaryWriter::new(logdir), step_key: "opt_steps".to_string(), ignore_unsupported_value: false, + latest_record: None, } } } @@ -75,3 +78,17 @@ impl Recorder for TensorboardRecorder { } } } + +impl AggregateRecorder for TensorboardRecorder { + fn store(&mut self, record: Record) { + self.latest_record = Some(record); + } + + fn flush(&mut self, step: i64) { + if self.latest_record.is_some() { + let mut record = self.latest_record.take().unwrap(); + record.insert("opt_steps", RecordValue::Scalar(step as _)); + self.write(record); + } + } +} diff --git a/border/Cargo.toml b/border/Cargo.toml index bb57f4e8..bcdbe56d 100644 --- a/border/Cargo.toml +++ b/border/Cargo.toml @@ -1,46 +1,65 @@ [package] name = "border" -version = "0.0.6" -authors = ["Taku Yoshioka "] -edition = "2018" -rust-version = "1.68.2" - -description = "Reinforcement learning library" -repository = "https://github.com/taku-y/border" -keywords = ["rl"] -categories = ["science"] -license = "GPL-2.0-or-later" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +package.license = "GPL-2.0-or-later" readme = "README.md" -autoexamples = false [dependencies] aquamarine = { workspace = true } tch = { workspace = true, optional = true } -border-async-trainer = { version = "0.0.6", path = "../border-async-trainer", optional = true } +candle-core = { workspace = true, optional = true } +border-async-trainer = { version = "0.0.7", path = "../border-async-trainer", optional = true } anyhow = { workspace = true } log = { workspace = true } dirs = { workspace = true } zip = "0.5.12" -reqwest = { version = "0.11.3", features = ["blocking"] } -border-core = { version = "0.0.6", path = "../border-core" } +reqwest = { workspace = true } +border-core = { version = "0.0.7", path = "../border-core" 
} [[example]] name = "dqn_cartpole" +path = "examples/gym/dqn_cartpole.rs" +required-features = ["candle-core"] +test = false + +[[example]] +name = "dqn_cartpole_tch" +path = "examples/gym/dqn_cartpole_tch.rs" required-features = ["tch"] test = true [[example]] -name = "iqn_cartpole" +name = "iqn_cartpole_tch" +path = "examples/gym/iqn_cartpole_tch.rs" required-features = ["tch"] test = true [[example]] -name = "sac_pendulum" +name = "sac_pendulum_tch" +path = "examples/gym/sac_pendulum_tch.rs" required-features = ["tch"] test = true +[[example]] +name = "sac_pendulum" +path = "examples/gym/sac_pendulum.rs" +required-features = ["candle-core"] +test = true + [[example]] name = "sac_lunarlander_cont" +path = "examples/gym/sac_lunarlander_cont.rs" +required-features = ["candle-core"] +test = false + +[[example]] +name = "sac_lunarlander_cont_tch" +path = "examples/gym/sac_lunarlander_cont_tch.rs" required-features = ["tch"] test = false @@ -52,12 +71,18 @@ test = false [[example]] name = "dqn_atari" path = "examples/atari/dqn_atari.rs" +required-features = ["candle-core"] +test = false + +[[example]] +name = "dqn_atari_tch" +path = "examples/atari/dqn_atari_tch.rs" required-features = ["tch"] test = false [[example]] -name = "dqn_atari_async" -path = "examples/atari/dqn_atari_async.rs" +name = "dqn_atari_async_tch" +path = "examples/atari/dqn_atari_async_tch.rs" required-features = ["tch", "border-async-trainer"] test = false @@ -67,41 +92,69 @@ path = "examples/gym-robotics/sac_fetch_reach.rs" required-features = ["tch"] test = false +# [[example]] +# name = "iqn_atari_rs" +# required-features = ["tch"] +# test = false + [[example]] -name = "iqn_atari_rs" -required-features = ["tch"] +name = "sac_mujoco" +path = "examples/mujoco/sac_mujoco.rs" +required-features = ["candle-core"] test = false [[example]] -name = "sac_ant" +name = "sac_mujoco_tch" +path = "examples/mujoco/sac_mujoco_tch.rs" required-features = ["tch"] test = false [[example]] -name = "sac_ant_async" -required-features = ["tch", "border-async-trainer"] +name = "convert_sac_policy_to_edge" +path = "examples/gym/convert_sac_policy_to_edge.rs" +required-features = ["border-tch-agent", "tch"] test = false [[example]] -name = "make_cfg_dqn_atari" -required-features = ["border-async-trainer"] +name = "sac_mujoco_async_tch" +path = "examples/mujoco/sac_mujoco_async_tch.rs" +required-features = ["tch", "border-async-trainer"] test = false [[example]] -name = "make_cfg_iqn_atari" -required-features = ["border-async-trainer"] +name = "pendulum_edge" +path = "examples/gym/pendulum_edge.rs" test = false +# [[example]] +# name = "sac_ant_async" +# path = "examples/mujoco/sac_ant_async.rs" +# required-features = ["tch", "border-async-trainer"] +# test = false + +# [[example]] +# name = "make_cfg_dqn_atari" +# required-features = ["border-async-trainer"] +# test = false + +# [[example]] +# name = "make_cfg_iqn_atari" +# required-features = ["border-async-trainer"] +# test = false + [dev-dependencies] clap = { workspace = true } csv = { workspace = true } tempdir = { workspace = true } -border-derive = { version = "0.0.6", path = "../border-derive" } -border-core = { version = "0.0.6", path = "../border-core" } -border-tensorboard = { version = "0.0.6", path = "../border-tensorboard" } -border-tch-agent = { version = "0.0.6", path = "../border-tch-agent" } -border-py-gym-env = { version = "0.0.6", path = "../border-py-gym-env" } -border-atari-env = { version = "0.0.6", path = "../border-atari-env" } +border-derive = { version = "0.0.7", 
path = "../border-derive" } +border-core = { version = "0.0.7", path = "../border-core" } +border-tensorboard = { version = "0.0.7", path = "../border-tensorboard" } +border-tch-agent = { version = "0.0.7", path = "../border-tch-agent" } +border-policy-no-backend = { version = "0.0.7", path = "../border-policy-no-backend" } +border-py-gym-env = { version = "0.0.7", path = "../border-py-gym-env" } +border-atari-env = { version = "0.0.7", path = "../border-atari-env" } +border-candle-agent = { version = "0.0.7", path = "../border-candle-agent" } +border-mlflow-tracking = { version = "0.0.7", path = "../border-mlflow-tracking" } serde = { workspace = true, features = ["derive"] } crossbeam-channel = { workspace = true } env_logger = { workspace = true } @@ -114,14 +167,13 @@ chrono = { workspace = true } tensorboard-rs = { workspace = true } thiserror = { workspace = true } serde_yaml = { workspace = true } +bincode = { workspace = true } [package.metadata.docs.rs] features = ["doc-only"] [features] -# default = [ "adam_eps" ] doc-only = ["tch/doc-only"] -adam_eps = [] - -#[target.'cfg(feature="adam_eps")'.patch.crates-io] -#tch = { git = "https://github.com/taku-y/tch-rs", branch = "adam_eps" } +cuda = ["candle-core/cuda"] +cudnn = ["candle-core/cudnn"] +border-tch-agent = [] diff --git a/border/README.md b/border/README.md index 317e0f30..e37344ec 100644 --- a/border/README.md +++ b/border/README.md @@ -9,13 +9,19 @@ A reinforcement learning library in Rust. Border consists of the following crates: -* [border-core](https://crates.io/crates/border-core) provides basic traits and functions generic to environments and reinforcmenet learning (RL) agents. -* [border-py-gym-env](https://crates.io/crates/border-py-gym-env) is a wrapper of the [Gym](https://gym.openai.com) environments written in Python, with the support of [pybullet-gym](https://github.com/benelot/pybullet-gym) and [atari](https://github.com/mgbellemare/Arcade-Learning-Environment). -* [border-atari-env](https://crates.io/crates/border-atari-env) is a wrapper of [atari-env](https://crates.io/crates/atari-env), which is a part of [gym-rs](https://crates.io/crates/gym-rs). -* [border-tch-agent](https://crates.io/crates/border-tch-agent) is a collection of RL agents based on [tch](https://crates.io/crates/tch). Deep Q network (DQN), implicit quantile network (IQN), and soft actor critic (SAC) are includes. -* [border-async-trainer](https://crates.io/crates/border-async-trainer) defines some traits and functions for asynchronous training of RL agents by multiple actors, each of which runs a sampling process of an agent and an environment in parallel. - -You can use a part of these crates for your purposes, though [border-core](https://crates.io/crates/border-core) is mandatory. [This crate](https://crates.io/crates/border) is just a collection of examples. See [Documentation](https://docs.rs/border) for more details. +* Core and utility + * [border-core](https://crates.io/crates/border-core) provides basic traits and functions generic to environments and reinforcmenet learning (RL) agents. + * [border-tensorboard](https://crates.io/crates/border-tensorboard) has `TensorboardRecorder` struct to write records which can be shown in Tensorboard. It is based on [tensorboard-rs](https://crates.io/crates/tensorboard-rs). + * [border-mlflow-tracking](https://crates.io/crates/border-mlflow-tracking) support MLflow tracking to log metrices during training via REST API. 
+ * [border-async-trainer](https://crates.io/crates/border-async-trainer) defines some traits and functions for asynchronous training of RL agents by multiple actors that run sampling processes in parallel. In each sampling process, an agent interacts with an environment to collect samples to be sent to a shared replay buffer. + * [border](https://crates.io/crates/border) is just a collection of examples. +* Environment + * [border-py-gym-env](https://crates.io/crates/border-py-gym-env) is a wrapper of the [Gymnasium](https://gymnasium.farama.org) environments written in Python. + * [border-atari-env](https://crates.io/crates/border-atari-env) is a wrapper of [atari-env](https://crates.io/crates/atari-env), which is a part of [gym-rs](https://crates.io/crates/gym-rs). +* Agent + * [border-tch-agent](https://crates.io/crates/border-tch-agent) includes RL agents based on [tch](https://crates.io/crates/tch): Deep Q network (DQN), implicit quantile network (IQN), and soft actor critic (SAC). + * [border-candle-agent](https://crates.io/crates/border-candle-agent) includes RL agents based on [candle](https://crates.io/crates/candle-core). + * [border-policy-no-backend](https://crates.io/crates/border-policy-no-backend) includes a policy that is independent of any deep learning backend, such as Torch. ## Status @@ -23,15 +29,23 @@ Border is experimental and currently under development. API is unstable. ## Examples -In examples directory, you can see how to run some examples. Python>=3.7 and [gym](https://gym.openai.com) must be installed for running examples using [border-py-gym-env](https://crates.io/crates/border-py-gym-env). Some examples requires [PyBullet Gym](https://github.com/benelot/pybullet-gym). As the agents used in the examples are based on [tch-rs](https://github.com/LaurentMazare/tch-rs), libtorch is required to be installed. +There are example scripts in the `border/examples` directory. These are tested in Docker containers, specifically the one in the `aarch64` directory, on an M2 MacBook Air. Some scripts take a few days to finish training; they were tested on an Ubuntu 22.04 virtual machine in [GPUSOROBAN](https://soroban.highreso.jp), a computing cloud. A typical invocation is sketched below, after the Docker section. + +## Docker + +The `docker` directory contains scripts for running a Docker container in which you can try the examples described above. Currently, only the `aarch64` setup is actively used for development. 
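For reference, the sketch below shows roughly how one of these example scripts can be invoked. This is a minimal illustration, not part of the diff: the example names and feature flags come from the `[[example]]` entries in `border/Cargo.toml`, and the `tch` variant assumes libtorch is installed and discoverable; the exact per-script CLI flags may differ.

```bash
# DQN on CartPole with the tch backend (requires libtorch).
cargo run --release --example dqn_cartpole_tch --features=tch

# The same task with the candle backend.
cargo run --release --example dqn_cartpole --features=candle-core
```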
## License -Crates | License ----------------------|------------------ -`border-core` | MIT OR Apache-2.0 -`border-py-gym-env` | MIT OR Apache-2.0 -`border-atari-env` | GPL-2.0-or-later -`border-tch-agent` | MIT OR Apache-2.0 -`border-async-trainer`| MIT OR Apache-2.0 -`border` | GPL-2.0-or-later +Crates | License +--------------------------|------------------ +`border-core` | MIT OR Apache-2.0 +`border-tensorboard` | MIT OR Apache-2.0 +`border-mlflow-tracking` | MIT OR Apache-2.0 +`border-async-trainer` | MIT OR Apache-2.0 +`border-py-gym-env` | MIT OR Apache-2.0 +`border-atari-env` | GPL-2.0-or-later +`border-tch-agent` | MIT OR Apache-2.0 +`border-candle-agent` | MIT OR Apache-2.0 +`border-policy-no-backend`| MIT OR Apache-2.0 +`border` | GPL-2.0-or-later diff --git a/border/examples/.gitignore b/border/examples/.gitignore new file mode 100644 index 00000000..233d6f93 --- /dev/null +++ b/border/examples/.gitignore @@ -0,0 +1,10 @@ +# !.gitignore +# */* +# * +*.pt +*.pt.tch +events* +*.csv +*.zip +*.gz +**/best diff --git a/border/examples/README.md b/border/examples/README.md index 46349cf5..3d37d7d3 100644 --- a/border/examples/README.md +++ b/border/examples/README.md @@ -1,3 +1,10 @@ +The following directories contain example scripts. + +* `gym` - Classic control environments in [Gymnasium](https://gymnasium.farama.org) based on [border-py-gym-env](https://crates.io/crates/border-py-gym-env). +* `gym-robotics` - A robotic environment (fetch-reach) in [Gymnasium-Robotics](https://robotics.farama.org/) based on [border-py-gym-env](https://crates.io/crates/border-py-gym-env). +* `mujoco` - Mujoco environments in [Gymnasium](https://gymnasium.farama.org) based on [border-py-gym-env](https://crates.io/crates/border-py-gym-env). +* `atari` - Atari environments based on [border-atari-env](https://crates.io/crates/border-atari-env), a wrapper of [atari-env](https://crates.io/crates/atari-env), which is part of [gym-rs](https://crates.io/crates/gym-rs). + ## Gym You may need to set PYTHONPATH as `PYTHONPATH=./border-py-gym-env/examples`. diff --git a/border/examples/atari/.gitignore b/border/examples/atari/.gitignore index 5b229866..9b2be0bf 100644 --- a/border/examples/atari/.gitignore +++ b/border/examples/atari/.gitignore @@ -7,3 +7,4 @@ events* *.zip *.gz dqn_pong_tmp +**/best diff --git a/border/examples/atari/README.md b/border/examples/atari/README.md index ccd6aa86..19e43a22 100644 --- a/border/examples/atari/README.md +++ b/border/examples/atari/README.md @@ -2,3 +2,14 @@ This directory contains examples using Atari environments. 
+## tch agent + +```bash +cargo run --release --example dqn_atari_tch --features=tch -- pong --mlflow +``` + +## candle agent + +```bash +cargo run --release --example dqn_atari --features=candle-core,cuda,cudnn -- pong --mlflow +``` diff --git a/border/examples/atari/dqn_atari.rs b/border/examples/atari/dqn_atari.rs index 02574518..829f2094 100644 --- a/border/examples/atari/dqn_atari.rs +++ b/border/examples/atari/dqn_atari.rs @@ -4,239 +4,314 @@ use border_atari_env::{ BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs, BorderAtariObsRawFilter, }; +use border_candle_agent::{ + cnn::Cnn, + dqn::{Dqn as Dqn_, DqnConfig}, + TensorBatch, +}; use border_core::{ - replay_buffer::{ + generic_replay_buffer::{ SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }, - Agent, DefaultEvaluator, Env as _, Evaluator as _, Policy, Trainer, TrainerConfig, -}; -use border_derive::{Act, SubBatch}; -use border_tch_agent::{ - cnn::Cnn, - dqn::{Dqn as Dqn_, DqnConfig}, - TensorSubBatch, + record::AggregateRecorder, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, }; +use border_derive::{Act, BatchBase}; +use border_mlflow_tracking::MlflowTrackingClient; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; -use util_dqn_atari::{model_dir as model_dir_, Params}; +use clap::Parser; -type Obs = BorderAtariObs; +mod obs_act_types { + use super::*; -#[derive(Clone, SubBatch)] -struct ObsBatch(TensorSubBatch); + pub type Obs = BorderAtariObs; -impl From for ObsBatch { - fn from(obs: Obs) -> Self { - let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } } -} -#[derive(SubBatch)] -// struct ActBatch(TensorSubBatch); -struct ActBatch(TensorSubBatch); + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); -impl From for ActBatch { - fn from(act: Act) -> Self { - let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) + impl From for ActBatch { + fn from(act: Act) -> Self { + let tensor = act.into(); + Self(TensorBatch::from_tensor(tensor)) + } } -} -// Wrap `BorderAtariAct` to make a new type. -// Act also implements Into. -// TODO: Consider to implement Into on BorderAtariAct when feature=tch. -#[derive(Debug, Clone, Act)] -struct Act(BorderAtariAct); - -type ObsFilter = BorderAtariObsRawFilter; -type ActFilter = BorderAtariActRawFilter; -type EnvConfig = BorderAtariEnvConfig; -type Env = BorderAtariEnv; -type StepProc = SimpleStepProcessor; -type ReplayBuffer = SimpleReplayBuffer; -type Dqn = Dqn_; -type Evaluator = DefaultEvaluator; - -fn env_config(name: impl Into) -> EnvConfig { - BorderAtariEnvConfig::default().name(name.into()) + // Wrap `BorderAtariAct` to make a new type. + // Act also implements Into. + // TODO: Consider to implement Into on BorderAtariAct when feature=tch. 
+ #[derive(Debug, Clone, Act)] + pub struct Act(BorderAtariAct); + + pub type ObsFilter = BorderAtariObsRawFilter; + pub type ActFilter = BorderAtariActRawFilter; + pub type EnvConfig = BorderAtariEnvConfig; + pub type Env = BorderAtariEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Dqn = Dqn_; + pub type Evaluator = DefaultEvaluator; } -fn init<'a>() -> ArgMatches<'a> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - tch::manual_seed(42); - - let matches = App::new("dqn_atari") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("name") - .long("name") - .takes_value(true) - .required(true) - .index(1) - .help("The name of the atari rom (e.g., pong)"), - ) - .arg( - Arg::with_name("play") - .long("play") - .takes_value(true) - .help("Play with the trained model of the given path"), - ) - .arg( - Arg::with_name("play-gdrive") - .long("play-gdrive") - .takes_value(false) - .help("Play with the trained model downloaded from google drive"), - ) - .arg( - Arg::with_name("per") - .long("per") - .takes_value(false) - .help("Train/play with prioritized experience replay"), - ) - .arg( - Arg::with_name("ddqn") - .long("ddqn") - .takes_value(false) - .help("Train/play with double DQN"), - ) - .arg( - Arg::with_name("debug") - .long("debug") - .takes_value(false) - .help("Run with debug configuration"), - ) - .arg( - Arg::with_name("wait") - .long("wait") - .takes_value(true) - .default_value("25") - .help("Waiting time in milliseconds between frames when playing"), - ) - .arg( - Arg::with_name("show-config") - .long("show-config") - .takes_value(false) - .help("Showing configuration loaded from files"), - ) - .get_matches(); - - matches -} +use config::DqnAtariConfig; +use obs_act_types::*; -fn show_config( - env_config: &EnvConfig, - agent_config: &DqnConfig, - trainer_config: &TrainerConfig, -) { - println!("Device: {:?}", tch::Device::cuda_if_available()); - println!("{}", serde_yaml::to_string(&env_config).unwrap()); - println!("{}", serde_yaml::to_string(&agent_config).unwrap()); - println!("{}", serde_yaml::to_string(&trainer_config).unwrap()); +fn cuda_if_available() -> candle_core::Device { + candle_core::Device::cuda_if_available(0).unwrap() } -fn model_dir(matches: &ArgMatches) -> Result { - let name = matches - .value_of("name") - .expect("The name of the environment was not given") - .to_string(); - let mut params = Params::default(); +mod config { + use self::util_dqn_atari::{ + DqnAtariAgentConfig, DqnAtariReplayBufferConfig, DqnAtariTrainerConfig, + }; + use serde::Serialize; + use std::io::Write; + + use super::*; - if matches.is_present("ddqn") { - params = params.ddqn(); + pub fn env_config(name: impl Into) -> EnvConfig { + BorderAtariEnvConfig::default().name(name.into()) } - if matches.is_present("per") { - params = params.per(); + pub fn show_config( + env_config: &EnvConfig, + agent_config: &DqnConfig, + trainer_config: &TrainerConfig, + ) { + println!("Device: {:?}", cuda_if_available()); + println!("{}", serde_yaml::to_string(&env_config).unwrap()); + println!("{}", serde_yaml::to_string(&agent_config).unwrap()); + println!("{}", serde_yaml::to_string(&trainer_config).unwrap()); } - if matches.is_present("debug") { - params = params.debug(); + pub fn load_dqn_config<'a>(model_dir: impl Into<&'a str>) -> Result> { + let config_path = format!("{}/agent.yaml", model_dir.into()); + let file = std::fs::File::open(config_path.clone())?; + let rdr = 
std::io::BufReader::new(file); + let config: DqnAtariAgentConfig = serde_yaml::from_reader(rdr)?; + println!("Load agent config: {}", config_path); + Ok(config.into()) } - model_dir_(name, ¶ms) -} + pub fn load_trainer_config<'a>(model_dir: impl Into<&'a str>) -> Result { + let config_path = format!("{}/trainer.yaml", model_dir.into()); + let file = std::fs::File::open(config_path.clone())?; + let rdr = std::io::BufReader::new(file); + let config: DqnAtariTrainerConfig = serde_yaml::from_reader(rdr)?; + println!("Load trainer config: {}", config_path); + Ok(config.into()) + } -fn model_dir_for_play(matches: &ArgMatches) -> String { - matches.value_of("play").unwrap().to_string() -} + pub fn load_replay_buffer_config<'a>( + model_dir: impl Into<&'a str>, + ) -> Result { + let config_path = format!("{}/replay_buffer.yaml", model_dir.into()); + let file = std::fs::File::open(config_path.clone())?; + let rdr = std::io::BufReader::new(file); + let config: DqnAtariReplayBufferConfig = serde_yaml::from_reader(rdr)?; + println!("Load replay buffer config: {}", config_path); + Ok(config.into()) + } -fn n_actions(env_config: &EnvConfig) -> Result { - Ok(Env::build(env_config, 0)?.get_num_actions_atari() as usize) -} + pub fn create_trainer_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); + let config = util_dqn_atari::DqnAtariTrainerConfig::default(); + let path = model_dir + "/trainer.yaml"; + let mut file = std::fs::File::create(path.clone())?; + file.write_all(serde_yaml::to_string(&config)?.as_bytes())?; + println!("Create trainer config file: {}", path); + Ok(()) + } -fn load_dqn_config<'a>(model_dir: impl Into<&'a str>) -> Result> { - let config_path = format!("{}/agent.yaml", model_dir.into()); - DqnConfig::::load(config_path) + pub fn create_replay_buffer_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); + let config = util_dqn_atari::DqnAtariReplayBufferConfig::default(); + let path = model_dir + "/replay_buffer.yaml"; + let mut file = std::fs::File::create(path.clone())?; + file.write_all(serde_yaml::to_string(&config)?.as_bytes())?; + println!("Create replay buffer config file: {}", path); + Ok(()) + } + + pub fn create_agent_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); + let config = util_dqn_atari::DqnAtariAgentConfig::default(); + let path = model_dir + "/agent.yaml"; + let mut file = std::fs::File::create(path.clone())?; + file.write_all(serde_yaml::to_string(&config)?.as_bytes())?; + println!("Create agent config file: {}", path); + Ok(()) + } + + #[derive(Serialize)] + pub struct DqnAtariConfig { + pub trainer: TrainerConfig, + pub replay_buffer: SimpleReplayBufferConfig, + pub agent: DqnConfig, + } } -fn load_trainer_config<'a>(model_dir: impl Into<&'a str>) -> Result { - let config_path = format!("{}/trainer.yaml", model_dir.into()); - TrainerConfig::load(config_path) +mod utils { + use super::*; + + pub fn model_dir(args: &Args) -> String { + let name = &args.name; + format!("./border/examples/atari/model/candle/dqn_{}", name) + + // let mut params = Params::default(); + + // if matches.is_present("ddqn") { + // params = params.ddqn(); + // } + + // if matches.is_present("per") { + // params = params.per(); + // } + + // if matches.is_present("debug") { + // params = params.debug(); + // } + + // model_dir_(name, ¶ms) + } + + pub fn model_dir_for_eval(args: &Args) -> String { + model_dir(args) + } + + pub fn n_actions(env_config: &EnvConfig) -> Result { + Ok(Env::build(env_config, 
0)?.get_num_actions_atari() as usize) + } + + pub fn create_recorder( + args: &Args, + model_dir: &str, + config: &DqnAtariConfig, + ) -> Result> { + match args.mlflow { + true => { + let name = &args.name; + let client = MlflowTrackingClient::new("http://localhost:8080") + .set_experiment_id("Atari")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", name)?; + recorder_run.set_tag("algo", "dqn")?; + recorder_run.set_tag("backend", "candle")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } } -fn load_replay_buffer_config<'a>( - model_dir: impl Into<&'a str>, -) -> Result { - let config_path = format!("{}/replay_buffer.yaml", model_dir.into()); - SimpleReplayBufferConfig::load(config_path) +/// Train/eval DQN agent in atari environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Name of the game + name: String, + + /// Train DQN agent, not evaluate + #[arg(long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(long, default_value_t = false)] + eval: bool, + + /// Create config files + #[arg(long, default_value_t = false)] + create_config: bool, + + /// Show config + #[arg(long, default_value_t = false)] + show_config: bool, + + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames when evaluation + #[arg(long, default_value_t = 25)] + wait: u64, } -fn train(matches: ArgMatches) -> Result<()> { +fn train(args: &Args) -> Result<()> { // Configurations - let name = matches.value_of("name").unwrap(); - let model_dir = model_dir(&matches)?; - let env_config_train = env_config(name); - let env_config_eval = env_config(name).eval(); - let n_actions = n_actions(&env_config_train)?; + let name = &args.name; + let model_dir = utils::model_dir(&args); + let env_config_train = config::env_config(name); + let env_config_eval = config::env_config(name).eval(); + let n_actions = utils::n_actions(&env_config_train)?; let agent_config = { - let agent_config = load_dqn_config(model_dir.as_str())? + let agent_config = config::load_dqn_config(model_dir.as_str())? 
.out_dim(n_actions as _) - .device(tch::Device::cuda_if_available()); + .device(cuda_if_available()); agent_config }; - let trainer_config = load_trainer_config(model_dir.as_str())?; - let replay_buffer_config = load_replay_buffer_config(model_dir.as_str())?; + let trainer_config = + config::load_trainer_config(model_dir.as_str())?.model_dir(model_dir.clone()); + let replay_buffer_config = config::load_replay_buffer_config(model_dir.as_str())?; let step_proc_config = SimpleStepProcessorConfig {}; // Show configs or train - if matches.is_present("show-config") { - show_config(&env_config_train, &agent_config, &trainer_config); + if args.show_config { + config::show_config(&env_config_train, &agent_config, &trainer_config); } else { - let mut trainer = Trainer::::build( - trainer_config, - env_config_train, - step_proc_config, - replay_buffer_config, - ); + let config = DqnAtariConfig { + trainer: trainer_config.clone(), + replay_buffer: replay_buffer_config.clone(), + agent: agent_config.clone(), + }; + let mut trainer = Trainer::build(trainer_config); + let env = Env::build(&env_config_train, 0)?; + let step_proc = StepProc::build(&step_proc_config); let mut agent = Dqn::build(agent_config); - let mut recorder = TensorboardRecorder::new(model_dir); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut recorder = utils::create_recorder(&args, &model_dir, &config)?; let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?; - trainer.train(&mut agent, &mut recorder, &mut evaluator)?; + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; } Ok(()) } -fn play(matches: ArgMatches) -> Result<()> { - let name = matches.value_of("name").unwrap(); - let model_dir = model_dir_for_play(&matches); +fn eval(args: &Args) -> Result<()> { + let name = &args.name; + let model_dir = utils::model_dir_for_eval(&args); let (env_config, n_actions) = { - let env_config = env_config(name).render(true); - let n_actions = n_actions(&env_config)?; + let env_config = config::env_config(name).render(true); + let n_actions = utils::n_actions(&env_config)?; (env_config, n_actions) }; let mut agent = { - let device = tch::Device::cuda_if_available(); - let agent_config = load_dqn_config(model_dir.as_str())? + let device = cuda_if_available(); + let agent_config = config::load_dqn_config(model_dir.as_str())? 
.out_dim(n_actions as _) .device(device); let mut agent = Dqn::build(agent_config); - agent.load(model_dir + "/best")?; + agent.load_params(model_dir + "/best")?; agent.eval(); agent }; @@ -247,13 +322,22 @@ fn play(matches: ArgMatches) -> Result<()> { Ok(()) } +fn create_config(args: &Args) -> Result<()> { + config::create_trainer_config(&args)?; + config::create_replay_buffer_config(&args)?; + config::create_agent_config(&args)?; + Ok(()) +} + fn main() -> Result<()> { - let matches = init(); + let args = Args::parse(); - if matches.is_present("play") || matches.is_present("play-gdrive") { - play(matches)?; + if args.eval { + eval(&args)?; + } else if args.create_config { + create_config(&args)?; } else { - train(matches)?; + train(&args)?; } Ok(()) diff --git a/border/examples/atari/dqn_atari_async.rs b/border/examples/atari/dqn_atari_async.rs deleted file mode 100644 index 7fbaa6bb..00000000 --- a/border/examples/atari/dqn_atari_async.rs +++ /dev/null @@ -1,300 +0,0 @@ -mod util_dqn_atari; -use anyhow::Result; -use border_async_trainer::{ - actor_stats_fmt, ActorManager as ActorManager_, ActorManagerConfig, - AsyncTrainer as AsyncTrainer_, AsyncTrainerConfig, -}; -use border_atari_env::{ - BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs, - BorderAtariObsRawFilter, -}; -use border_core::{ - replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - DefaultEvaluator, Env as _, -}; -use border_derive::{Act, SubBatch}; -use border_tch_agent::{ - cnn::Cnn, - dqn::{Dqn, DqnConfig, DqnExplorer, EpsilonGreedy}, - TensorSubBatch, -}; -use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; -use crossbeam_channel::unbounded; -use std::{ - default::Default, - sync::{Arc, Mutex}, -}; -use util_dqn_atari::{model_dir_async as model_dir_async_, Params}; - -type Obs = BorderAtariObs; - -#[derive(Clone, SubBatch)] -struct ObsBatch(TensorSubBatch); - -impl From for ObsBatch { - fn from(obs: Obs) -> Self { - let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } -} - -#[derive(SubBatch)] -struct ActBatch(TensorSubBatch); - -impl From for ActBatch { - fn from(act: Act) -> Self { - let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } -} - -// Wrap `BorderAtariAct` to make a new type. -// Act also implements Into. -// TODO: Consider to implement Into on BorderAtariAct when feature=tch. 
-#[derive(Debug, Clone, Act)] -struct Act(BorderAtariAct); - -type ObsFilter = BorderAtariObsRawFilter; -type ActFilter = BorderAtariActRawFilter; -type EnvConfig = BorderAtariEnvConfig; -type Env = BorderAtariEnv; -type StepProc = SimpleStepProcessor; -type ReplayBuffer = SimpleReplayBuffer; -type Agent = Dqn; -type ActorManager = ActorManager_; -type AsyncTrainer = AsyncTrainer_; -type Evaluator = DefaultEvaluator; - -fn env_config(name: impl Into) -> EnvConfig { - BorderAtariEnvConfig::default().name(name.into()) -} - -fn parse_args<'a>() -> ArgMatches<'a> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - tch::manual_seed(42); - - let matches = App::new("dqn_atari_async") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("name") - .long("name") - .takes_value(true) - .required(true) - .index(1) - .help("The name of the atari environment (e.g., PongNoFrameskip-v4)"), - ) - .arg( - Arg::with_name("per") - .long("per") - .takes_value(false) - .help("Train/play with prioritized experience replay"), - ) - .arg( - Arg::with_name("ddqn") - .long("ddqn") - .takes_value(false) - .help("Train/play with double DQN"), - ) - .arg( - Arg::with_name("debug") - .long("debug") - .takes_value(false) - .help("Run with debug configuration"), - ) - .arg( - Arg::with_name("show-config") - .long("show-config") - .takes_value(false) - .help("Showing configuration loaded from files"), - ) - .arg( - Arg::with_name("n-actors") - .long("n-actors") - .takes_value(true) - .default_value("6") - .help("The number of actors"), - ) - .arg( - Arg::with_name("eps-min") - .long("eps-min") - .takes_value(true) - .default_value("0.001") - .help("The minimum value of exploration noise probability"), - ) - .arg( - Arg::with_name("eps-max") - .long("eps-max") - .takes_value(true) - .default_value("0.4") - .help("The maximum value of exploration noise probability"), - ) - .get_matches(); - - matches -} - -fn show_config( - env_config: &EnvConfig, - agent_config: &DqnConfig, - actor_man_config: &ActorManagerConfig, - trainer_config: &AsyncTrainerConfig, -) { - println!("Device: {:?}", tch::Device::cuda_if_available()); - println!("{}", serde_yaml::to_string(&env_config).unwrap()); - println!("{}", serde_yaml::to_string(&agent_config).unwrap()); - println!("{}", serde_yaml::to_string(&actor_man_config).unwrap()); - println!("{}", serde_yaml::to_string(&trainer_config).unwrap()); -} - -fn model_dir_async(matches: &ArgMatches) -> Result { - let name = matches - .value_of("name") - .expect("The name of the environment was not given") - .to_string(); - let mut params = Params::default(); - - if matches.is_present("ddqn") { - params = params.ddqn(); - } - - if matches.is_present("per") { - params = params.per(); - } - - if matches.is_present("debug") { - params = params.debug(); - } - - let model_dir = model_dir_async_(name, ¶ms)?; - - Ok(model_dir) -} - -fn n_actions(env_config: &EnvConfig) -> Result { - Ok(Env::build(env_config, 0)?.get_num_actions_atari() as usize) -} - -fn load_dqn_config<'a>(model_dir: impl Into<&'a str>) -> Result> { - let config_path = format!("{}/agent.yaml", model_dir.into()); - DqnConfig::::load(config_path) -} - -fn load_async_trainer_config<'a>(model_dir: impl Into<&'a str>) -> Result { - let config_path = format!("{}/trainer.yaml", model_dir.into()); - AsyncTrainerConfig::load(config_path) -} - -fn load_replay_buffer_config<'a>( - model_dir: impl Into<&'a str>, -) -> Result { - let config_path = format!("{}/replay_buffer.yaml", 
model_dir.into()); - SimpleReplayBufferConfig::load(config_path) -} - -fn train(matches: ArgMatches) -> Result<()> { - let name = matches.value_of("name").unwrap(); - let model_dir = model_dir_async(&matches)?; - let env_config_train = env_config(name); - let n_actions = n_actions(&env_config_train)?; - - // exploration parameters - let n_actors = matches - .value_of("n-actors") - .unwrap() - .parse::() - .unwrap(); - let eps_min = matches.value_of("eps-min").unwrap().parse::().unwrap(); - let eps_max = matches.value_of("eps-max").unwrap().parse::().unwrap(); - - // Configurations - let agent_config = load_dqn_config(model_dir.as_str())? - .out_dim(n_actions as _) - .device(tch::Device::cuda_if_available()); - let agent_configs = (0..n_actors) - .map(|ix| { - let n = ix as f64 / ((n_actors - 1) as f64); - let eps = (eps_max - eps_min) * n + eps_min; - let explorer = - DqnExplorer::EpsilonGreedy(EpsilonGreedy::new().eps_start(eps).eps_final(eps)); - agent_config - .clone() - .device(tch::Device::Cpu) - .explorer(explorer) - }) - .collect::>(); - let env_config_eval = env_config(name).eval(); - let replay_buffer_config = load_replay_buffer_config(model_dir.as_str())?; - let step_proc_config = SimpleStepProcessorConfig::default(); - let actor_man_config = ActorManagerConfig::default(); - let async_trainer_config = load_async_trainer_config(model_dir.as_str())?; - - if matches.is_present("show-config") { - show_config( - &env_config_train, - &agent_config, - &actor_man_config, - &async_trainer_config, - ); - } else { - let mut recorder = TensorboardRecorder::new(model_dir); - let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?; - - // Shared flag to stop actor threads - let stop = Arc::new(Mutex::new(false)); - - // Creates channels - let (item_s, item_r) = unbounded(); // items pushed to replay buffer - let (model_s, model_r) = unbounded(); // model_info - - // guard for initialization of envs in multiple threads - let guard_init_env = Arc::new(Mutex::new(true)); - - // Actor manager and async trainer - let mut actors = ActorManager::build( - &actor_man_config, - &agent_configs, - &env_config_train, - &step_proc_config, - item_s, - model_r, - stop.clone(), - ); - let mut trainer = AsyncTrainer::build( - &async_trainer_config, - &agent_config, - &env_config_eval, - &replay_buffer_config, - item_r, - model_s, - stop.clone(), - ); - - // Set the number of threads - tch::set_num_threads(1); - - // Starts sampling and training - actors.run(guard_init_env.clone()); - let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env); - println!("Stats of async trainer"); - println!("{}", stats.fmt()); - - let stats = actors.stop_and_join(); - println!("Stats of generated samples in actors"); - println!("{}", actor_stats_fmt(&stats)); - } - - Ok(()) -} - -fn main() -> Result<()> { - let matches = parse_args(); - - train(matches)?; - - Ok(()) -} diff --git a/border/examples/atari/dqn_atari_async_tch.rs b/border/examples/atari/dqn_atari_async_tch.rs new file mode 100644 index 00000000..097810f8 --- /dev/null +++ b/border/examples/atari/dqn_atari_async_tch.rs @@ -0,0 +1,346 @@ +mod util_dqn_atari; +use anyhow::Result; +use border_async_trainer::{ + util::train_async, /*ActorManager as ActorManager_,*/ ActorManagerConfig, + /*AsyncTrainer as AsyncTrainer_,*/ AsyncTrainerConfig, +}; +use border_atari_env::{ + BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs, + BorderAtariObsRawFilter, +}; +use border_core::{ + generic_replay_buffer::{ + 
SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::AggregateRecorder, + DefaultEvaluator, Env as _, +}; +use border_derive::{Act, BatchBase}; +use border_mlflow_tracking::MlflowTrackingClient; +use border_tch_agent::{ + cnn::Cnn, + dqn::{Dqn, DqnConfig, DqnExplorer, EpsilonGreedy}, + TensorBatch, +}; +use border_tensorboard::TensorboardRecorder; +use clap::Parser; + +mod obs_act_types { + use super::*; + + pub type Obs = BorderAtariObs; + + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); + + impl From for ActBatch { + fn from(act: Act) -> Self { + let tensor = act.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + // Wrap `BorderAtariAct` to make a new type. + // Act also implements Into. + // TODO: Consider to implement Into on BorderAtariAct when feature=tch. + #[derive(Debug, Clone, Act)] + pub struct Act(BorderAtariAct); + + pub type ObsFilter = BorderAtariObsRawFilter; + pub type ActFilter = BorderAtariActRawFilter; + pub type EnvConfig = BorderAtariEnvConfig; + pub type Env = BorderAtariEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Agent = Dqn; + pub type Evaluator = DefaultEvaluator; +} + +use config::DqnAtariAsyncConfig; +use obs_act_types::*; + +mod config { + use self::util_dqn_atari::{ + DqnAtariAgentConfig, DqnAtariAsyncTrainerConfig, DqnAtariReplayBufferConfig, + }; + use serde::Serialize; + use std::io::Write; + + use super::*; + + pub fn env_config(name: impl Into) -> EnvConfig { + BorderAtariEnvConfig::default().name(name.into()) + } + + pub fn show_config( + env_config: &EnvConfig, + agent_config: &DqnConfig, + actor_man_config: &ActorManagerConfig, + trainer_config: &AsyncTrainerConfig, + ) { + println!("Device: {:?}", tch::Device::cuda_if_available()); + println!("{}", serde_yaml::to_string(&env_config).unwrap()); + println!("{}", serde_yaml::to_string(&agent_config).unwrap()); + println!("{}", serde_yaml::to_string(&actor_man_config).unwrap()); + println!("{}", serde_yaml::to_string(&trainer_config).unwrap()); + } + + pub fn load_dqn_config<'a>(model_dir: impl Into<&'a str>) -> Result> { + let config_path = format!("{}/agent.yaml", model_dir.into()); + let file = std::fs::File::open(config_path.clone())?; + let rdr = std::io::BufReader::new(file); + let config: DqnAtariAgentConfig = serde_yaml::from_reader(rdr)?; + println!("Load agent config: {}", config_path); + Ok(config.into()) + } + + pub fn load_async_trainer_config<'a>( + model_dir: impl Into<&'a str>, + ) -> Result { + let config_path = format!("{}/trainer.yaml", model_dir.into()); + let file = std::fs::File::open(config_path.clone())?; + let rdr = std::io::BufReader::new(file); + let config: DqnAtariAsyncTrainerConfig = serde_yaml::from_reader(rdr)?; + println!("Load async trainer config: {}", config_path); + Ok(config.into()) + } + + pub fn load_replay_buffer_config<'a>( + model_dir: impl Into<&'a str>, + ) -> Result { + let config_path = format!("{}/replay_buffer.yaml", model_dir.into()); + let file = std::fs::File::open(config_path.clone())?; + let rdr = std::io::BufReader::new(file); + let config: DqnAtariReplayBufferConfig = serde_yaml::from_reader(rdr)?; + println!("Load replay buffer config: {}", config_path); + Ok(config.into()) + } + + pub fn 
create_async_trainer_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); + let config = DqnAtariAsyncTrainerConfig::default(); + let path = model_dir + "/trainer.yaml"; + let mut file = std::fs::File::create(path.clone())?; + file.write_all(serde_yaml::to_string(&config)?.as_bytes())?; + println!("Create trainer config file: {}", path); + Ok(()) + } + + pub fn create_replay_buffer_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); + let config = util_dqn_atari::DqnAtariReplayBufferConfig::default(); + let path = model_dir + "/replay_buffer.yaml"; + let mut file = std::fs::File::create(path.clone())?; + file.write_all(serde_yaml::to_string(&config)?.as_bytes())?; + println!("Create replay buffer config file: {}", path); + Ok(()) + } + + pub fn create_agent_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); + let config = util_dqn_atari::DqnAtariAgentConfig::default(); + let path = model_dir + "/agent.yaml"; + let mut file = std::fs::File::create(path.clone())?; + file.write_all(serde_yaml::to_string(&config)?.as_bytes())?; + println!("Create agent config file: {}", path); + Ok(()) + } + + #[derive(Serialize)] + pub struct DqnAtariAsyncConfig { + pub trainer: AsyncTrainerConfig, + pub replay_buffer: SimpleReplayBufferConfig, + pub agent: DqnConfig, + } +} + +mod utils { + use super::*; + + pub fn model_dir(args: &Args) -> String { + let name = &args.name; + format!("./border/examples/atari/model/tch/dqn_{}_async", name) + + // let name = matches + // .value_of("name") + // .expect("The name of the environment was not given") + // .to_string(); + // let mut params = Params::default(); + + // if matches.is_present("ddqn") { + // params = params.ddqn(); + // } + + // if matches.is_present("per") { + // params = params.per(); + // } + + // if matches.is_present("debug") { + // params = params.debug(); + // } + + // let model_dir = model_dir_async_(name, ¶ms)?; + + // Ok(model_dir) + } + + pub fn n_actions(env_config: &EnvConfig) -> Result { + Ok(Env::build(env_config, 0)?.get_num_actions_atari() as usize) + } + + pub fn create_recorder( + args: &Args, + model_dir: &str, + config: &DqnAtariAsyncConfig, + ) -> Result> { + match args.mlflow { + true => { + let name = &args.name; + let client = MlflowTrackingClient::new("http://localhost:8080") + .set_experiment_id("Atari")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", name)?; + recorder_run.set_tag("algo", "dqn_async")?; + recorder_run.set_tag("backend", "tch")?; + recorder_run.set_tag("n_actors", args.n_actors.to_string())?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } +} + +/// Train DQN agent in atari environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Name of the game + name: String, + + /// Create config files + #[arg(long, default_value_t = false)] + create_config: bool, + + /// Show config + #[arg(long, default_value_t = false)] + show_config: bool, + + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames when evaluation + #[arg(long, default_value_t = 25)] + wait: u64, + + /// Number of actors, default to 6 + #[arg(long, default_value_t = 6)] + n_actors: usize, + + /// The minimum value of exploration noise probability, default to 0.001 + #[arg(long, default_value_t = 0.001)] + eps_min: f64, + + /// The maximum value 
of exploration noise probability, default to 0.4 + #[arg(long, default_value_t = 0.4)] + eps_max: f64, +} + +fn train(args: &Args) -> Result<()> { + let name = &args.name; + let model_dir = utils::model_dir(&args); + let env_config_train = config::env_config(name); + let n_actions = utils::n_actions(&env_config_train)?; + + // exploration parameters + let n_actors = args.n_actors; + let eps_min = &args.eps_min; + let eps_max = &args.eps_max; + + // Configurations + let agent_config = config::load_dqn_config(model_dir.as_str())? + .out_dim(n_actions as _) + .device(tch::Device::cuda_if_available()); + let agent_configs = (0..n_actors) + .map(|ix| { + let n = ix as f64 / ((n_actors - 1) as f64); + let eps = (eps_max - eps_min) * n + eps_min; + let explorer = + DqnExplorer::EpsilonGreedy(EpsilonGreedy::new().eps_start(eps).eps_final(eps)); + agent_config + .clone() + .device(tch::Device::Cpu) + .explorer(explorer) + }) + .collect::>(); + let env_config_eval = config::env_config(name).eval(); + let replay_buffer_config = config::load_replay_buffer_config(model_dir.as_str())?; + let step_proc_config = SimpleStepProcessorConfig::default(); + let actor_man_config = ActorManagerConfig::default(); + let async_trainer_config = + config::load_async_trainer_config(model_dir.as_str())?.model_dir(model_dir.as_str())?; + + if args.show_config { + config::show_config( + &env_config_train, + &agent_config, + &actor_man_config, + &async_trainer_config, + ); + } else { + let config = config::DqnAtariAsyncConfig { + trainer: async_trainer_config.clone(), + replay_buffer: replay_buffer_config.clone(), + agent: agent_config.clone(), + }; + let mut recorder = utils::create_recorder(&args, &model_dir, &config)?; + let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?; + + train_async::( + &agent_config, + &agent_configs, + &env_config_train, + &env_config_eval, + &step_proc_config, + &replay_buffer_config, + &actor_man_config, + &async_trainer_config, + &mut recorder, + &mut evaluator, + ); + } + + Ok(()) +} + +fn create_config(args: &Args) -> Result<()> { + config::create_async_trainer_config(&args)?; + config::create_replay_buffer_config(&args)?; + config::create_agent_config(&args)?; + Ok(()) +} + +fn main() -> Result<()> { + tch::set_num_threads(1); + let args = Args::parse(); + + if args.create_config { + create_config(&args)?; + } else { + train(&args)?; + } + + Ok(()) +} diff --git a/border/examples/atari/dqn_atari_tch.rs b/border/examples/atari/dqn_atari_tch.rs new file mode 100644 index 00000000..8c6c52bf --- /dev/null +++ b/border/examples/atari/dqn_atari_tch.rs @@ -0,0 +1,344 @@ +mod util_dqn_atari; +use anyhow::Result; +use border_atari_env::{ + BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs, + BorderAtariObsRawFilter, +}; +use border_core::{ + generic_replay_buffer::{ + SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::AggregateRecorder, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, +}; +use border_derive::{Act, BatchBase}; +use border_mlflow_tracking::MlflowTrackingClient; +use border_tch_agent::{ + cnn::Cnn, + dqn::{Dqn as Dqn_, DqnConfig}, + TensorBatch, +}; +use border_tensorboard::TensorboardRecorder; +use clap::Parser; + +mod obs_act_types { + use super::*; + + pub type Obs = BorderAtariObs; + + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); + + impl From for ObsBatch { + fn from(obs: 
Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); + + impl From for ActBatch { + fn from(act: Act) -> Self { + let tensor = act.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + // Wrap `BorderAtariAct` to make a new type. + // Act also implements Into. + // TODO: Consider to implement Into on BorderAtariAct when feature=tch. + #[derive(Debug, Clone, Act)] + pub struct Act(BorderAtariAct); + + pub type ObsFilter = BorderAtariObsRawFilter; + pub type ActFilter = BorderAtariActRawFilter; + pub type EnvConfig = BorderAtariEnvConfig; + pub type Env = BorderAtariEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Dqn = Dqn_; + pub type Evaluator = DefaultEvaluator; +} + +use config::DqnAtariConfig; +use obs_act_types::*; + +fn cuda_if_available() -> tch::Device { + tch::Device::cuda_if_available() +} + +mod config { + use self::util_dqn_atari::{ + DqnAtariAgentConfig, DqnAtariReplayBufferConfig, DqnAtariTrainerConfig, + }; + use serde::Serialize; + use std::io::Write; + + use super::*; + + pub fn env_config(name: impl Into) -> EnvConfig { + BorderAtariEnvConfig::default().name(name.into()) + } + + pub fn show_config( + env_config: &EnvConfig, + agent_config: &DqnConfig, + trainer_config: &TrainerConfig, + ) { + println!("Device: {:?}", cuda_if_available()); + println!("{}", serde_yaml::to_string(&env_config).unwrap()); + println!("{}", serde_yaml::to_string(&agent_config).unwrap()); + println!("{}", serde_yaml::to_string(&trainer_config).unwrap()); + } + + pub fn load_dqn_config<'a>(model_dir: impl Into<&'a str>) -> Result> { + let config_path = format!("{}/agent.yaml", model_dir.into()); + let file = std::fs::File::open(config_path.clone())?; + let rdr = std::io::BufReader::new(file); + let config: DqnAtariAgentConfig = serde_yaml::from_reader(rdr)?; + println!("Load agent config: {}", config_path); + Ok(config.into()) + } + + pub fn load_trainer_config<'a>(model_dir: impl Into<&'a str>) -> Result { + let config_path = format!("{}/trainer.yaml", model_dir.into()); + let file = std::fs::File::open(config_path.clone())?; + let rdr = std::io::BufReader::new(file); + let config: DqnAtariTrainerConfig = serde_yaml::from_reader(rdr)?; + println!("Load trainer config: {}", config_path); + Ok(config.into()) + } + + pub fn load_replay_buffer_config<'a>( + model_dir: impl Into<&'a str>, + ) -> Result { + let config_path = format!("{}/replay_buffer.yaml", model_dir.into()); + let file = std::fs::File::open(config_path.clone())?; + let rdr = std::io::BufReader::new(file); + let config: DqnAtariReplayBufferConfig = serde_yaml::from_reader(rdr)?; + println!("Load replay buffer config: {}", config_path); + Ok(config.into()) + } + + pub fn create_trainer_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); + let config = util_dqn_atari::DqnAtariTrainerConfig::default(); + let path = model_dir + "/trainer.yaml"; + let mut file = std::fs::File::create(path.clone())?; + file.write_all(serde_yaml::to_string(&config)?.as_bytes())?; + println!("Create trainer config file: {}", path); + Ok(()) + } + + pub fn create_replay_buffer_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); + let config = util_dqn_atari::DqnAtariReplayBufferConfig::default(); + let path = model_dir + "/replay_buffer.yaml"; + let mut file = std::fs::File::create(path.clone())?; + 
file.write_all(serde_yaml::to_string(&config)?.as_bytes())?; + println!("Create replay buffer config file: {}", path); + Ok(()) + } + + pub fn create_agent_config(args: &Args) -> Result<()> { + let model_dir = utils::model_dir(args); + let config = util_dqn_atari::DqnAtariAgentConfig::default(); + let path = model_dir + "/agent.yaml"; + let mut file = std::fs::File::create(path.clone())?; + file.write_all(serde_yaml::to_string(&config)?.as_bytes())?; + println!("Create agent config file: {}", path); + Ok(()) + } + + #[derive(Serialize)] + pub struct DqnAtariConfig { + pub trainer: TrainerConfig, + pub replay_buffer: SimpleReplayBufferConfig, + pub agent: DqnConfig, + } +} + +mod utils { + use super::*; + + pub fn model_dir(args: &Args) -> String { + let name = &args.name; + format!("./border/examples/atari/model/tch/dqn_{}", name) + + // let mut params = Params::default(); + + // if matches.is_present("ddqn") { + // params = params.ddqn(); + // } + + // if matches.is_present("per") { + // params = params.per(); + // } + + // if matches.is_present("debug") { + // params = params.debug(); + // } + + // model_dir_(name, ¶ms) + } + + pub fn model_dir_for_eval(args: &Args) -> String { + model_dir(args) + } + + pub fn n_actions(env_config: &EnvConfig) -> Result { + Ok(Env::build(env_config, 0)?.get_num_actions_atari() as usize) + } + + pub fn create_recorder( + args: &Args, + model_dir: &str, + config: &DqnAtariConfig, + ) -> Result> { + match args.mlflow { + true => { + let name = &args.name; + let client = MlflowTrackingClient::new("http://localhost:8080") + .set_experiment_id("Atari")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", name)?; + recorder_run.set_tag("algo", "dqn")?; + recorder_run.set_tag("backend", "tch")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } +} + +/// Train/eval DQN agent in atari environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Name of the game + name: String, + + /// Train DQN agent, not evaluate + #[arg(long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(long, default_value_t = false)] + eval: bool, + + /// Create config files + #[arg(long, default_value_t = false)] + create_config: bool, + + /// Show config + #[arg(long, default_value_t = false)] + show_config: bool, + + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames when evaluation + #[arg(long, default_value_t = 25)] + wait: u64, +} + +fn train(args: &Args) -> Result<()> { + // Configurations + let name = &args.name; + let model_dir = utils::model_dir(&args); + let env_config_train = config::env_config(name); + let env_config_eval = config::env_config(name).eval(); + let n_actions = utils::n_actions(&env_config_train)?; + let agent_config = { + let agent_config = config::load_dqn_config(model_dir.as_str())? 
+ .out_dim(n_actions as _) + .device(cuda_if_available()); + agent_config + }; + let trainer_config = + config::load_trainer_config(model_dir.as_str())?.model_dir(model_dir.clone()); + let replay_buffer_config = config::load_replay_buffer_config(model_dir.as_str())?; + let step_proc_config = SimpleStepProcessorConfig {}; + + // Show configs or train + if args.show_config { + config::show_config(&env_config_train, &agent_config, &trainer_config); + } else { + let config = DqnAtariConfig { + trainer: trainer_config.clone(), + replay_buffer: replay_buffer_config.clone(), + agent: agent_config.clone(), + }; + let mut trainer = Trainer::build(trainer_config); + let env = Env::build(&env_config_train, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut agent = Dqn::build(agent_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut recorder = utils::create_recorder(&args, &model_dir, &config)?; + let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; + } + + Ok(()) +} + +fn eval(args: &Args) -> Result<()> { + let name = &args.name; + let model_dir = utils::model_dir_for_eval(&args); + + let (env_config, n_actions) = { + let env_config = config::env_config(name).render(true); + let n_actions = utils::n_actions(&env_config)?; + (env_config, n_actions) + }; + let mut agent = { + let device = cuda_if_available(); + let agent_config = config::load_dqn_config(model_dir.as_str())? + .out_dim(n_actions as _) + .device(device); + let mut agent = Dqn::build(agent_config); + agent.load_params(model_dir + "/best")?; + agent.eval(); + agent + }; + // let mut recorder = BufferedRecorder::new(); + + let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); + + Ok(()) +} + +fn create_config(args: &Args) -> Result<()> { + config::create_trainer_config(&args)?; + config::create_replay_buffer_config(&args)?; + config::create_agent_config(&args)?; + Ok(()) +} + +fn main() -> Result<()> { + let args = Args::parse(); + + if args.eval { + eval(&args)?; + } else if args.create_config { + create_config(&args)?; + } else { + train(&args)?; + } + + Ok(()) +} diff --git a/border/examples/iqn_atari_rs.rs b/border/examples/atari/iqn_atari.rs similarity index 98% rename from border/examples/iqn_atari_rs.rs rename to border/examples/atari/iqn_atari.rs index 150154f7..78eada97 100644 --- a/border/examples/iqn_atari_rs.rs +++ b/border/examples/atari/iqn_atari.rs @@ -16,7 +16,7 @@ use border_tch_agent::{ cnn::Cnn, iqn::{Iqn as Iqn_, IqnConfig as IqnConfig_}, mlp::Mlp, - TensorSubBatch, + TensorBatch, }; use border_tensorboard::TensorboardRecorder; use clap::{App, Arg, ArgMatches}; @@ -27,22 +27,22 @@ use util_iqn_atari::{model_dir as model_dir_, Params}; type Obs = BorderAtariObs; #[derive(Clone, SubBatch)] -struct ObsBatch(TensorSubBatch); +struct ObsBatch(TensorBatch); impl From for ObsBatch { fn from(obs: Obs) -> Self { let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) + Self(TensorBatch::from_tensor(tensor)) } } #[derive(SubBatch)] -struct ActBatch(TensorSubBatch); +struct ActBatch(TensorBatch); impl From for ActBatch { fn from(act: Act) -> Self { let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) + Self(TensorBatch::from_tensor(tensor)) } } @@ -213,7 +213,7 @@ fn play(matches: ArgMatches) -> Result<()> { .out_dim(n_actions as _) .device(device); let mut agent = Iqn::build(agent_config); - agent.load(model_dir + "/best")?; + 
agent.load_params(model_dir + "/best")?; agent.eval(); agent }; diff --git a/border/examples/make_cfg_dqn_atari.rs b/border/examples/atari/make_cfg_dqn_atari.rs similarity index 97% rename from border/examples/make_cfg_dqn_atari.rs rename to border/examples/atari/make_cfg_dqn_atari.rs index e8a992b4..766b6d41 100644 --- a/border/examples/make_cfg_dqn_atari.rs +++ b/border/examples/atari/make_cfg_dqn_atari.rs @@ -6,7 +6,7 @@ use border_core::{ TrainerConfig, }; use border_tch_agent::{ - cnn::{CnnConfig, Cnn}, + cnn::{Cnn, CnnConfig}, dqn::{DqnConfig, DqnModelConfig}, //, EpsilonGreedy, DQNExplorer}, opt::OptimizerConfig, }; @@ -65,7 +65,6 @@ fn make_trainer_config(env_name: String, params: &Params) -> Result Result IqnConfig { let n_stack = 4; let out_dim = 0; // Set before training/evaluation - // let lr = if params.per { - // params.lr / 4.0 - // } else { - // params.lr - // }; - // let clip_td_err = if params.per { Some((-1.0, 1.0)) } else { None }; + // let lr = if params.per { + // params.lr / 4.0 + // } else { + // params.lr + // }; + // let clip_td_err = if params.per { Some((-1.0, 1.0)) } else { None }; let feature_dim = params.feature_dim; let hidden_dim = params.hidden_dim; - let f_config = CnnConfig::new(n_stack, feature_dim) - .skip_linear(true); - let m_config = MlpConfig::new(feature_dim, vec![hidden_dim], out_dim); + let f_config = CnnConfig::new(n_stack, feature_dim).skip_linear(true); + let m_config = MlpConfig::new(feature_dim, vec![hidden_dim], out_dim, false); let model_config = IqnModelConfig::default() .feature_dim(feature_dim) .embed_dim(params.embed_dim) @@ -69,7 +68,6 @@ fn make_trainer_config(env_name: String, params: &Params) -> Result { - // Agent parameters - pub replay_buffer_capacity: usize, - pub per: bool, - pub double_dqn: bool, - pub optimizer: &'a str, - pub batch_size: usize, - pub discount_factor: f64, - pub min_transition_warmup: usize, - pub soft_update_interval: usize, - pub lr: f64, - pub clip_reward: Option, - pub explorer: DqnExplorer, - pub tau: f64, - - // Trainer parameters - pub max_opts: usize, - pub eval_interval: usize, - pub eval_episodes: usize, - pub opt_interval: usize, - pub record_interval: usize, - pub save_interval: usize, - - // Debug parameters - pub debug: bool, +mod trainer_config { + use border_core::TrainerConfig; + use serde::{Deserialize, Serialize}; + + #[derive(Deserialize, Serialize)] + pub struct DqnAtariTrainerConfig { + pub model_dir: String, + + #[serde( + default = "default_max_opts", + skip_serializing_if = "is_default_max_opts" + )] + pub max_opts: usize, + + #[serde( + default = "default_opt_interval", + skip_serializing_if = "is_default_opt_interval" + )] + pub opt_interval: usize, + + #[serde( + default = "default_eval_interval", + skip_serializing_if = "is_default_eval_interval" + )] + pub eval_interval: usize, + + #[serde( + default = "default_flush_record_interval", + skip_serializing_if = "is_default_flush_record_interval" + )] + pub flush_record_interval: usize, + + #[serde( + default = "default_record_agent_info_interval", + skip_serializing_if = "is_default_record_agent_info_interval" + )] + pub record_agent_info_interval: usize, + + #[serde( + default = "default_record_compute_cost_interval", + skip_serializing_if = "is_default_record_compute_cost_interval" + )] + pub record_compute_cost_interval: usize, + + #[serde( + default = "default_warmup_period", + skip_serializing_if = "is_default_warmup_period" + )] + pub warmup_period: usize, + + #[serde( + default = "default_save_interval", + 
skip_serializing_if = "is_default_save_interval" + )] + pub save_interval: usize, + } + + impl Default for DqnAtariTrainerConfig { + fn default() -> Self { + Self { + model_dir: "".to_string(), + max_opts: 3000000, + opt_interval: 1, + eval_interval: 5000, + record_agent_info_interval: 5000, + record_compute_cost_interval: 5000, + flush_record_interval: 5000, + warmup_period: 2500, + save_interval: 500000, + // // For debug + // model_dir: "".to_string(), + // max_opts: 3000000, + // opt_interval: 1, + // eval_interval: 10, + // record_agent_info_interval: 10, + // record_compute_cost_interval: 10, + // flush_record_interval: 10, + // warmup_period: 32, + // save_interval: 100, + } + } + } + + fn default_max_opts() -> usize { + DqnAtariTrainerConfig::default().max_opts + } + + fn default_opt_interval() -> usize { + DqnAtariTrainerConfig::default().opt_interval + } + + fn default_eval_interval() -> usize { + DqnAtariTrainerConfig::default().eval_interval + } + + fn default_flush_record_interval() -> usize { + DqnAtariTrainerConfig::default().flush_record_interval + } + + fn default_record_agent_info_interval() -> usize { + DqnAtariTrainerConfig::default().record_agent_info_interval + } + + fn default_record_compute_cost_interval() -> usize { + DqnAtariTrainerConfig::default().record_compute_cost_interval + } + + fn default_warmup_period() -> usize { + DqnAtariTrainerConfig::default().warmup_period + } + + fn default_save_interval() -> usize { + DqnAtariTrainerConfig::default().save_interval + } + + fn is_default_max_opts(v: &usize) -> bool { + *v == default_max_opts() + } + + fn is_default_opt_interval(v: &usize) -> bool { + *v == default_opt_interval() + } + + fn is_default_eval_interval(v: &usize) -> bool { + *v == default_eval_interval() + } + + fn is_default_flush_record_interval(v: &usize) -> bool { + *v == default_flush_record_interval() + } + + fn is_default_record_agent_info_interval(v: &usize) -> bool { + *v == default_record_agent_info_interval() + } + + fn is_default_record_compute_cost_interval(v: &usize) -> bool { + *v == default_record_compute_cost_interval() + } + + fn is_default_warmup_period(v: &usize) -> bool { + *v == default_warmup_period() + } + + fn is_default_save_interval(v: &usize) -> bool { + *v == default_save_interval() + } + + impl Into for DqnAtariTrainerConfig { + fn into(self) -> TrainerConfig { + TrainerConfig { + model_dir: Some(self.model_dir), + max_opts: self.max_opts, + opt_interval: self.opt_interval, + eval_interval: self.eval_interval, + flush_record_interval: self.flush_record_interval, + record_agent_info_interval: self.record_agent_info_interval, + record_compute_cost_interval: self.record_compute_cost_interval, + warmup_period: self.warmup_period, + save_interval: self.save_interval, + } + } + } } -impl<'a> Default for Params<'a> { - fn default() -> Self { - Self { - // Agent parameters - replay_buffer_capacity: 50_000, - per: false, - double_dqn: false, - optimizer: "adam", - batch_size: 32, - discount_factor: 0.99, - min_transition_warmup: 2500, - soft_update_interval: 10_000, - lr: 1e-4, - clip_reward: Some(1.0), - explorer: EpsilonGreedy::with_final_step(1_000_000), - tau: 1.0, - - // Trainer parameters - max_opts: 3_000_000, - eval_interval: 50_000, - eval_episodes: 1, - opt_interval: 1, - record_interval: 50_000, - save_interval: 500_000, - - // Debug parameters - debug: false, +#[cfg(feature = "border-async-trainer")] +mod async_trainer_config { + use border_async_trainer::AsyncTrainerConfig; + use serde::{Deserialize, Serialize}; + + 
#[derive(Deserialize, Serialize)] + pub struct DqnAtariAsyncTrainerConfig { + pub model_dir: Option, + + #[serde( + default = "default_max_opts", + skip_serializing_if = "is_default_max_opts" + )] + pub max_opts: usize, + + #[serde( + default = "default_eval_interval", + skip_serializing_if = "is_default_eval_interval" + )] + pub eval_interval: usize, + + #[serde( + default = "default_flush_record_interval", + skip_serializing_if = "is_default_flush_record_interval" + )] + pub flush_record_interval: usize, + + #[serde( + default = "default_record_compute_cost_interval", + skip_serializing_if = "is_default_record_compute_cost_interval" + )] + pub record_compute_cost_interval: usize, + + #[serde( + default = "default_save_interval", + skip_serializing_if = "is_default_save_interval" + )] + pub save_interval: usize, + + #[serde( + default = "default_sync_interval", + skip_serializing_if = "is_default_sync_interval" + )] + pub sync_interval: usize, + + #[serde( + default = "default_warmup_period", + skip_serializing_if = "is_default_warmup_period" + )] + pub warmup_period: usize, + } + + impl Default for DqnAtariAsyncTrainerConfig { + fn default() -> Self { + Self { + model_dir: None, + max_opts: 3000000, + eval_interval: 5000, + flush_record_interval: 5000, + record_compute_cost_interval: 5000, + sync_interval: 1, + save_interval: 500000, + warmup_period: 10000, + } + } + } + + fn default_max_opts() -> usize { + DqnAtariAsyncTrainerConfig::default().max_opts + } + + fn default_eval_interval() -> usize { + DqnAtariAsyncTrainerConfig::default().eval_interval + } + + fn default_flush_record_interval() -> usize { + DqnAtariAsyncTrainerConfig::default().flush_record_interval + } + + fn default_record_compute_cost_interval() -> usize { + DqnAtariAsyncTrainerConfig::default().record_compute_cost_interval + } + + fn default_sync_interval() -> usize { + DqnAtariAsyncTrainerConfig::default().sync_interval + } + + fn default_save_interval() -> usize { + DqnAtariAsyncTrainerConfig::default().save_interval + } + + fn default_warmup_period() -> usize { + DqnAtariAsyncTrainerConfig::default().warmup_period + } + + fn is_default_max_opts(v: &usize) -> bool { + *v == default_max_opts() + } + + fn is_default_eval_interval(v: &usize) -> bool { + *v == default_eval_interval() + } + + fn is_default_flush_record_interval(v: &usize) -> bool { + *v == default_flush_record_interval() + } + + fn is_default_record_compute_cost_interval(v: &usize) -> bool { + *v == default_record_compute_cost_interval() + } + + fn is_default_sync_interval(v: &usize) -> bool { + *v == default_sync_interval() + } + + fn is_default_save_interval(v: &usize) -> bool { + *v == default_save_interval() + } + + fn is_default_warmup_period(v: &usize) -> bool { + *v == default_warmup_period() + } + + impl Into for DqnAtariAsyncTrainerConfig { + fn into(self) -> AsyncTrainerConfig { + AsyncTrainerConfig { + model_dir: self.model_dir, + max_opts: self.max_opts, + eval_interval: self.eval_interval, + flush_record_interval: self.flush_record_interval, + record_compute_cost_interval: self.record_compute_cost_interval, + save_interval: self.save_interval, + sync_interval: self.sync_interval, + warmup_period: self.warmup_period, + } } } } -impl<'a> Params<'a> { - #[allow(dead_code)] - pub fn per(mut self) -> Self { - self.per = true; - self +mod replay_buffer_config { + use border_core::generic_replay_buffer::{PerConfig, SimpleReplayBufferConfig}; + use serde::{Deserialize, Serialize}; + + #[derive(Deserialize, Serialize)] + pub struct 
DqnAtariReplayBufferConfig { + #[serde( + default = "default_capacity", + skip_serializing_if = "is_default_capacity" + )] + pub capacity: usize, + + #[serde(default = "default_seed", skip_serializing_if = "is_default_seed")] + pub seed: u64, + + /// Currently, fixed to None + #[serde( + default = "default_per_config", + skip_serializing_if = "is_default_per_config" + )] + pub per_config: Option, + } + + impl Default for DqnAtariReplayBufferConfig { + fn default() -> Self { + Self { + capacity: 262144, + seed: 42, + per_config: None, + } + } } - #[allow(dead_code)] - pub fn ddqn(mut self) -> Self { - self.double_dqn = true; - self + fn default_capacity() -> usize { + DqnAtariReplayBufferConfig::default().capacity } - #[allow(dead_code)] - pub fn debug(mut self) -> Self { - self.debug = true; - self + fn default_seed() -> u64 { + DqnAtariReplayBufferConfig::default().seed } - #[allow(dead_code)] - pub fn replay_buffer_capacity(mut self, replay_buffer_capacity: usize) -> Self { - self.replay_buffer_capacity = replay_buffer_capacity; - self + fn default_per_config() -> Option { + DqnAtariReplayBufferConfig::default().per_config } - #[allow(dead_code)] - pub fn max_opts(mut self, max_opts: usize) -> Self { - self.max_opts = max_opts; - self + fn is_default_capacity(v: &usize) -> bool { + *v == default_capacity() } - #[allow(dead_code)] - pub fn save_interval(mut self, save_interval: usize) -> Self { - self.save_interval = save_interval; - self + fn is_default_seed(v: &u64) -> bool { + *v == default_seed() } - #[allow(dead_code)] - pub fn eval_interval(mut self, eval_interval: usize) -> Self { - self.eval_interval = eval_interval; - self + fn is_default_per_config(v: &Option) -> bool { + *v == default_per_config() } - #[allow(dead_code)] - pub fn optimizer(mut self, optimizer: &'a str) -> Self { - self.optimizer = optimizer; - self + impl Into for DqnAtariReplayBufferConfig { + fn into(self) -> SimpleReplayBufferConfig { + SimpleReplayBufferConfig { + capacity: self.capacity, + seed: self.seed, + per_config: self.per_config, + } + } } } -#[allow(dead_code)] -pub fn model_dir(env_name: String, params: &Params) -> Result { - let per = params.per; - let ddqn = params.double_dqn; - let debug = params.debug; +#[cfg(feature = "tch")] +mod tch_dqn_config { + use std::marker::PhantomData; + + use border_tch_agent::{ + cnn::{Cnn, CnnConfig}, + dqn::{DqnConfig, DqnExplorer, DqnModelConfig, EpsilonGreedy}, + opt::OptimizerConfig, + util::CriticLoss, + Device, + }; + use serde::{Deserialize, Serialize}; + + #[derive(Deserialize, Serialize)] + pub struct DqnAtariAgentConfig { + #[serde( + default = "default_model_config", + skip_serializing_if = "is_default_model_config" + )] + pub model_config: DqnModelConfig, + + #[serde( + default = "default_soft_update_interval", + skip_serializing_if = "is_default_soft_update_interval" + )] + pub soft_update_interval: usize, + + #[serde( + default = "default_n_updates_per_opt", + skip_serializing_if = "is_default_n_updates_per_opt" + )] + pub n_updates_per_opt: usize, + + #[serde( + default = "default_batch_size", + skip_serializing_if = "is_default_batch_size" + )] + pub batch_size: usize, + + #[serde( + default = "default_discount_factor", + skip_serializing_if = "is_default_discount_factor" + )] + pub discount_factor: f64, + + #[serde(default = "default_tau", skip_serializing_if = "is_default_tau")] + pub tau: f64, + + #[serde(default = "default_train", skip_serializing_if = "is_default_train")] + pub train: bool, + + #[serde( + default = "default_explorer", + 
skip_serializing_if = "is_default_explorer" + )] + pub explorer: DqnExplorer, + + #[serde( + default = "default_clip_reward", + skip_serializing_if = "is_default_clip_reward" + )] + pub clip_reward: Option, + + #[serde( + default = "default_double_dqn", + skip_serializing_if = "is_default_double_dqn" + )] + pub double_dqn: bool, + + #[serde( + default = "default_clip_td_err", + skip_serializing_if = "is_default_clip_td_err" + )] + pub clip_td_err: Option<(f64, f64)>, + + #[serde( + default = "default_critic_loss", + skip_serializing_if = "is_default_critic_loss" + )] + pub critic_loss: CriticLoss, + + #[serde( + default = "default_record_verbose_level", + skip_serializing_if = "is_default_record_verbose_level" + )] + pub record_verbose_level: usize, + + #[serde(default = "default_device", skip_serializing_if = "is_default_device")] + pub device: Option, + // phantom: PhantomData, + } + + impl Default for DqnAtariAgentConfig { + fn default() -> Self { + DqnAtariAgentConfig { + model_config: DqnModelConfig { + q_config: Some(CnnConfig { + n_stack: 4, + out_dim: 0, + skip_linear: false, + }), + opt_config: OptimizerConfig::Adam { lr: 0.0001 }, + }, + soft_update_interval: 10000, + n_updates_per_opt: 1, + batch_size: 32, + discount_factor: 0.99, + tau: 1.0, + train: false, + explorer: DqnExplorer::EpsilonGreedy(EpsilonGreedy { + n_opts: 0, + eps_start: 1.0, + eps_final: 0.02, + final_step: 1000000, + }), + clip_reward: Some(1.0), + double_dqn: false, + clip_td_err: None, + critic_loss: CriticLoss::Mse, + record_verbose_level: 0, + device: None, + // phantom: PhantomData, + } + } + } + + fn default_model_config() -> DqnModelConfig { + DqnAtariAgentConfig::default().model_config + } + + fn default_soft_update_interval() -> usize { + DqnAtariAgentConfig::default().soft_update_interval + } + + fn default_n_updates_per_opt() -> usize { + DqnAtariAgentConfig::default().n_updates_per_opt + } + + fn default_batch_size() -> usize { + DqnAtariAgentConfig::default().batch_size + } + + fn default_discount_factor() -> f64 { + DqnAtariAgentConfig::default().discount_factor + } + + fn default_tau() -> f64 { + DqnAtariAgentConfig::default().tau + } + + fn default_train() -> bool { + DqnAtariAgentConfig::default().train + } + + fn default_explorer() -> DqnExplorer { + DqnAtariAgentConfig::default().explorer + } + + fn default_clip_reward() -> Option { + DqnAtariAgentConfig::default().clip_reward + } - let mut model_dir = format!("./border/examples/atari/model/dqn_{}", env_name); - if ddqn { - model_dir.push_str("_ddqn"); + fn default_double_dqn() -> bool { + DqnAtariAgentConfig::default().double_dqn } - if per { - model_dir.push_str("_per"); + fn default_clip_td_err() -> Option<(f64, f64)> { + DqnAtariAgentConfig::default().clip_td_err } - if debug { - model_dir.push_str("_debug"); + fn default_critic_loss() -> CriticLoss { + DqnAtariAgentConfig::default().critic_loss } - if !Path::new(&model_dir).exists() { - std::fs::create_dir(Path::new(&model_dir))?; + fn default_record_verbose_level() -> usize { + DqnAtariAgentConfig::default().record_verbose_level } - Ok(model_dir) + fn default_device() -> Option { + DqnAtariAgentConfig::default().device + } + + fn is_default_model_config(config: &DqnModelConfig) -> bool { + config == &default_model_config() + } + + fn is_default_soft_update_interval(soft_update_interval: &usize) -> bool { + soft_update_interval == &default_soft_update_interval() + } + + fn is_default_n_updates_per_opt(n_updates_per_opt: &usize) -> bool { + n_updates_per_opt == 
&default_n_updates_per_opt() + } + + fn is_default_batch_size(batch_size: &usize) -> bool { + batch_size == &default_batch_size() + } + + fn is_default_discount_factor(discount_factor: &f64) -> bool { + discount_factor == &default_discount_factor() + } + + fn is_default_tau(tau: &f64) -> bool { + tau == &default_tau() + } + + fn is_default_train(train: &bool) -> bool { + train == &default_train() + } + + fn is_default_explorer(explorer: &DqnExplorer) -> bool { + explorer == &default_explorer() + } + + fn is_default_clip_reward(clip_reward: &Option) -> bool { + clip_reward == &default_clip_reward() + } + + fn is_default_double_dqn(double_dqn: &bool) -> bool { + double_dqn == &default_double_dqn() + } + + fn is_default_clip_td_err(clip_td_err: &Option<(f64, f64)>) -> bool { + clip_td_err == &default_clip_td_err() + } + + fn is_default_critic_loss(critic_loss: &CriticLoss) -> bool { + critic_loss == &default_critic_loss() + } + + fn is_default_record_verbose_level(record_verbose_level: &usize) -> bool { + record_verbose_level == &default_record_verbose_level() + } + + fn is_default_device(device: &Option) -> bool { + device == &default_device() + } + + impl Into> for DqnAtariAgentConfig { + fn into(self) -> DqnConfig { + DqnConfig { + model_config: self.model_config, + soft_update_interval: self.soft_update_interval, + n_updates_per_opt: self.n_updates_per_opt, + batch_size: self.batch_size, + discount_factor: self.discount_factor, + tau: self.tau, + train: self.train, + explorer: self.explorer, + clip_reward: self.clip_reward, + double_dqn: self.double_dqn, + clip_td_err: self.clip_td_err, + device: self.device, + critic_loss: self.critic_loss, + record_verbose_level: self.record_verbose_level, + phantom: PhantomData, + } + } + } } -#[allow(dead_code)] -pub fn model_dir_async(env_name: String, params: &Params) -> Result { - let per = params.per; - let ddqn = params.double_dqn; - let debug = params.debug; +#[cfg(feature = "candle-core")] +mod candle_dqn_config { + use std::marker::PhantomData; + + use border_candle_agent::{ + cnn::{Cnn, CnnConfig}, + dqn::{DqnConfig, DqnExplorer, DqnModelConfig, EpsilonGreedy}, + opt::OptimizerConfig, + util::CriticLoss, + Device, + }; + use serde::{Deserialize, Serialize}; + + #[derive(Deserialize, Serialize)] + pub struct DqnAtariAgentConfig { + #[serde( + default = "default_model_config", + skip_serializing_if = "is_default_model_config" + )] + pub model_config: DqnModelConfig, + + #[serde( + default = "default_soft_update_interval", + skip_serializing_if = "is_default_soft_update_interval" + )] + pub soft_update_interval: usize, + + #[serde( + default = "default_n_updates_per_opt", + skip_serializing_if = "is_default_n_updates_per_opt" + )] + pub n_updates_per_opt: usize, + + #[serde( + default = "default_batch_size", + skip_serializing_if = "is_default_batch_size" + )] + pub batch_size: usize, + + #[serde( + default = "default_discount_factor", + skip_serializing_if = "is_default_discount_factor" + )] + pub discount_factor: f64, + + #[serde(default = "default_tau", skip_serializing_if = "is_default_tau")] + pub tau: f64, + + #[serde(default = "default_train", skip_serializing_if = "is_default_train")] + pub train: bool, + + #[serde( + default = "default_explorer", + skip_serializing_if = "is_default_explorer" + )] + pub explorer: DqnExplorer, + + #[serde( + default = "default_clip_reward", + skip_serializing_if = "is_default_clip_reward" + )] + pub clip_reward: Option, + + #[serde( + default = "default_double_dqn", + skip_serializing_if = 
"is_default_double_dqn" + )] + pub double_dqn: bool, + + #[serde( + default = "default_clip_td_err", + skip_serializing_if = "is_default_clip_td_err" + )] + pub clip_td_err: Option<(f64, f64)>, - let mut model_dir = format!("./border/examples/atari/model/dqn_{}", env_name); - if ddqn { - model_dir.push_str("_ddqn"); + #[serde( + default = "default_critic_loss", + skip_serializing_if = "is_default_critic_loss" + )] + pub critic_loss: CriticLoss, + + #[serde( + default = "default_record_verbose_level", + skip_serializing_if = "is_default_record_verbose_level" + )] + pub record_verbose_level: usize, + + #[serde(default = "default_device", skip_serializing_if = "is_default_device")] + pub device: Option, + // phantom: PhantomData, + } + + impl Default for DqnAtariAgentConfig { + fn default() -> Self { + DqnAtariAgentConfig { + model_config: DqnModelConfig { + q_config: Some(CnnConfig { + n_stack: 4, + out_dim: 0, + skip_linear: false, + }), + opt_config: OptimizerConfig::Adam { lr: 0.0001 }, + }, + soft_update_interval: 10000, + n_updates_per_opt: 1, + batch_size: 32, + discount_factor: 0.99, + tau: 1.0, + train: false, + explorer: DqnExplorer::EpsilonGreedy(EpsilonGreedy { + n_opts: 0, + eps_start: 1.0, + eps_final: 0.02, + final_step: 1000000, + }), + clip_reward: Some(1.0), + double_dqn: false, + clip_td_err: None, + critic_loss: CriticLoss::Mse, + record_verbose_level: 0, + device: None, + // phantom: PhantomData, + } + } + } + + fn default_model_config() -> DqnModelConfig { + DqnAtariAgentConfig::default().model_config + } + + fn default_soft_update_interval() -> usize { + DqnAtariAgentConfig::default().soft_update_interval + } + + fn default_n_updates_per_opt() -> usize { + DqnAtariAgentConfig::default().n_updates_per_opt + } + + fn default_batch_size() -> usize { + DqnAtariAgentConfig::default().batch_size + } + + fn default_discount_factor() -> f64 { + DqnAtariAgentConfig::default().discount_factor } - if per { - model_dir.push_str("_per"); + fn default_tau() -> f64 { + DqnAtariAgentConfig::default().tau } - if debug { - model_dir.push_str("_debug"); + fn default_train() -> bool { + DqnAtariAgentConfig::default().train } - model_dir.push_str("_async"); + fn default_explorer() -> DqnExplorer { + DqnAtariAgentConfig::default().explorer + } + + fn default_clip_reward() -> Option { + DqnAtariAgentConfig::default().clip_reward + } + + fn default_double_dqn() -> bool { + DqnAtariAgentConfig::default().double_dqn + } + + fn default_clip_td_err() -> Option<(f64, f64)> { + DqnAtariAgentConfig::default().clip_td_err + } + + fn default_critic_loss() -> CriticLoss { + DqnAtariAgentConfig::default().critic_loss + } + + fn default_record_verbose_level() -> usize { + DqnAtariAgentConfig::default().record_verbose_level + } + + fn default_device() -> Option { + DqnAtariAgentConfig::default().device + } + + fn is_default_model_config(config: &DqnModelConfig) -> bool { + config == &default_model_config() + } - if !Path::new(&model_dir).exists() { - std::fs::create_dir(Path::new(&model_dir))?; + fn is_default_soft_update_interval(soft_update_interval: &usize) -> bool { + soft_update_interval == &default_soft_update_interval() } - Ok(model_dir) + fn is_default_n_updates_per_opt(n_updates_per_opt: &usize) -> bool { + n_updates_per_opt == &default_n_updates_per_opt() + } + + fn is_default_batch_size(batch_size: &usize) -> bool { + batch_size == &default_batch_size() + } + + fn is_default_discount_factor(discount_factor: &f64) -> bool { + discount_factor == &default_discount_factor() + } + + fn 
is_default_tau(tau: &f64) -> bool { + tau == &default_tau() + } + + fn is_default_train(train: &bool) -> bool { + train == &default_train() + } + + fn is_default_explorer(explorer: &DqnExplorer) -> bool { + explorer == &default_explorer() + } + + fn is_default_clip_reward(clip_reward: &Option) -> bool { + clip_reward == &default_clip_reward() + } + + fn is_default_double_dqn(double_dqn: &bool) -> bool { + double_dqn == &default_double_dqn() + } + + fn is_default_clip_td_err(clip_td_err: &Option<(f64, f64)>) -> bool { + clip_td_err == &default_clip_td_err() + } + + fn is_default_critic_loss(critic_loss: &CriticLoss) -> bool { + critic_loss == &default_critic_loss() + } + + fn is_default_record_verbose_level(record_verbose_level: &usize) -> bool { + record_verbose_level == &default_record_verbose_level() + } + + fn is_default_device(device: &Option) -> bool { + device == &default_device() + } + + impl Into> for DqnAtariAgentConfig { + fn into(self) -> DqnConfig { + DqnConfig { + model_config: self.model_config, + soft_update_interval: self.soft_update_interval, + n_updates_per_opt: self.n_updates_per_opt, + batch_size: self.batch_size, + discount_factor: self.discount_factor, + tau: self.tau, + train: self.train, + explorer: self.explorer, + clip_reward: self.clip_reward, + double_dqn: self.double_dqn, + clip_td_err: self.clip_td_err, + device: self.device, + critic_loss: self.critic_loss, + record_verbose_level: self.record_verbose_level, + phantom: PhantomData, + } + } + } } + +#[allow(unused_imports)] +pub use replay_buffer_config::DqnAtariReplayBufferConfig; +#[allow(unused_imports)] +pub use trainer_config::DqnAtariTrainerConfig; + +#[cfg(feature = "border-async-trainer")] +pub use async_trainer_config::DqnAtariAsyncTrainerConfig; + +#[allow(unused_imports)] +#[cfg(feature = "candle-core")] +pub use candle_dqn_config::DqnAtariAgentConfig; +#[allow(unused_imports)] +#[cfg(feature = "tch")] +pub use tch_dqn_config::DqnAtariAgentConfig; diff --git a/border/examples/util_iqn_atari.rs b/border/examples/atari/util_iqn_atari.rs similarity index 94% rename from border/examples/util_iqn_atari.rs rename to border/examples/atari/util_iqn_atari.rs index 73774131..61ad2e6f 100644 --- a/border/examples/util_iqn_atari.rs +++ b/border/examples/atari/util_iqn_atari.rs @@ -1,5 +1,8 @@ use anyhow::Result; -use border_tch_agent::{iqn::{IqnExplorer, EpsilonGreedy}, opt::OptimizerConfig}; +use border_tch_agent::{ + iqn::{EpsilonGreedy, IqnExplorer}, + opt::OptimizerConfig, +}; use std::{default::Default, path::Path}; #[derive(Clone)] @@ -43,7 +46,7 @@ impl Default for Params { lr: 1e-5, // eps: 0.01 / 32.0 }, - + // Agent parameters replay_buffer_capacity: 1_048_576, per: false, @@ -52,11 +55,7 @@ impl Default for Params { min_transition_warmup: 2500, soft_update_interval: 10_000, tau: 1.0, - explorer: EpsilonGreedy::with_params( - 1.0, - 0.02, - 1_000_000, - ), + explorer: EpsilonGreedy::with_params(1.0, 0.02, 1_000_000), // Trainer parameters max_opts: 50_000_000, diff --git a/border/examples/dqn_atari_model.rs b/border/examples/backup/dqn_atari_model.rs similarity index 100% rename from border/examples/dqn_atari_model.rs rename to border/examples/backup/dqn_atari_model.rs diff --git a/border/examples/backup/random_atari.rs b/border/examples/backup/random_atari.rs index 99f37c4e..0b36ca3f 100644 --- a/border/examples/backup/random_atari.rs +++ b/border/examples/backup/random_atari.rs @@ -5,7 +5,7 @@ use border_py_gym_env::{ FrameStackFilter, PyGymEnv, PyGymEnvActFilter, PyGymEnvConfig, PyGymEnvDiscreteAct, 
PyGymEnvDiscreteActRawFilter, PyGymEnvObs, }; -use border_tch_agent::TensorSubBatch; +use border_tch_agent::TensorBatch; // use clap::{App, Arg}; use std::convert::TryFrom; // use tch::Tensor; @@ -16,13 +16,12 @@ type PyObsDtype = u8; struct Obs(PyGymEnvObs); #[derive(Clone, SubBatch)] -// struct ObsBatch(TensorSubBatch); -struct ObsBatch(TensorSubBatch); +struct ObsBatch(TensorBatch); impl From for ObsBatch { fn from(obs: Obs) -> Self { let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) + Self(TensorBatch::from_tensor(tensor)) } } @@ -30,13 +29,12 @@ impl From for ObsBatch { struct Act(PyGymEnvDiscreteAct); #[derive(SubBatch)] -// struct ActBatch(TensorSubBatch); -struct ActBatch(TensorSubBatch); +struct ActBatch(TensorBatch); impl From for ActBatch { fn from(act: Act) -> Self { let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) + Self(TensorBatch::from_tensor(tensor)) } } @@ -44,7 +42,7 @@ type ObsFilter = FrameStackFilter; type ActFilter = PyGymEnvDiscreteActRawFilter; type Env = PyGymEnv; -#[derive(Clone)] +#[derive(Clone, Deserialize)] struct RandomPolicyConfig { pub n_acts: usize, } diff --git a/border/examples/dqn_cartpole.rs b/border/examples/dqn_cartpole.rs deleted file mode 100644 index c8bb853a..00000000 --- a/border/examples/dqn_cartpole.rs +++ /dev/null @@ -1,334 +0,0 @@ -use anyhow::Result; -use border_core::{ - record::Record, - replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, SubBatch, - }, - Agent, DefaultEvaluator, Evaluator as _, Policy, Trainer, TrainerConfig, -}; -use border_py_gym_env::{ - util::vec_to_tensor, ArrayObsFilter, DiscreteActFilter, GymActFilter, GymEnv, GymEnvConfig, - GymObsFilter, -}; -use border_tch_agent::{ - dqn::{Dqn, DqnConfig, DqnModelConfig}, - mlp::{Mlp, MlpConfig}, - TensorSubBatch, -}; -use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg}; -// use csv::WriterBuilder; -use ndarray::{ArrayD, IxDyn}; -use serde::Serialize; -use std::convert::TryFrom; //, fs::File}; -use tch::Tensor; - -const DIM_OBS: i64 = 4; -const DIM_ACT: i64 = 2; -const LR_CRITIC: f64 = 0.001; -const DISCOUNT_FACTOR: f64 = 0.99; -const BATCH_SIZE: usize = 64; -const N_TRANSITIONS_WARMUP: usize = 100; -const N_UPDATES_PER_OPT: usize = 1; -const TAU: f64 = 0.005; -const OPT_INTERVAL: usize = 50; -const MAX_OPTS: usize = 1000; -const EVAL_INTERVAL: usize = 50; -const REPLAY_BUFFER_CAPACITY: usize = 10000; -const N_EPISODES_PER_EVAL: usize = 5; -const MODEL_DIR: &str = "./border/examples/model/dqn_cartpole"; - -type PyObsDtype = f32; - -mod obs { - use super::*; - - #[derive(Clone, Debug)] - pub struct Obs(ArrayD); - - impl border_core::Obs for Obs { - fn dummy(_n: usize) -> Self { - Self(ArrayD::zeros(IxDyn(&[0]))) - } - - fn len(&self) -> usize { - self.0.shape()[0] - } - } - - impl From> for Obs { - fn from(obs: ArrayD) -> Self { - Obs(obs) - } - } - - impl From for Tensor { - fn from(obs: Obs) -> Tensor { - Tensor::try_from(&obs.0).unwrap() - } - } - - pub struct ObsBatch(TensorSubBatch); - - impl SubBatch for ObsBatch { - fn new(capacity: usize) -> Self { - Self(TensorSubBatch::new(capacity)) - } - - fn push(&mut self, i: usize, data: &Self) { - self.0.push(i, &data.0) - } - - fn sample(&self, ixs: &Vec) -> Self { - let buf = self.0.sample(ixs); - Self(buf) - } - } - - impl From for ObsBatch { - fn from(obs: Obs) -> Self { - let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } - } - - impl From for Tensor { - fn from(b: ObsBatch) -> Self { - 
b.0.into() - } - } -} - -mod act { - use super::*; - - #[derive(Clone, Debug)] - pub struct Act(Vec); - - impl border_core::Act for Act {} - - impl From for Vec { - fn from(value: Act) -> Self { - value.0 - } - } - - impl From for Act { - // `t` must be a 1-dimentional tensor of `f32` - fn from(t: Tensor) -> Self { - let data: Vec = t.into(); - let data = data.iter().map(|&e| e as i32).collect(); - Act(data) - } - } - - pub struct ActBatch(TensorSubBatch); - - impl SubBatch for ActBatch { - fn new(capacity: usize) -> Self { - Self(TensorSubBatch::new(capacity)) - } - - fn push(&mut self, i: usize, data: &Self) { - self.0.push(i, &data.0) - } - - fn sample(&self, ixs: &Vec) -> Self { - let buf = self.0.sample(ixs); - Self(buf) - } - } - - impl From for ActBatch { - fn from(act: Act) -> Self { - let t = vec_to_tensor::<_, i64>(act.0, true); - Self(TensorSubBatch::from_tensor(t)) - } - } - - // Required by Dqn - impl From for Tensor { - fn from(act: ActBatch) -> Self { - act.0.into() - } - } -} - -use act::{Act, ActBatch}; -use obs::{Obs, ObsBatch}; - -type ObsFilter = ArrayObsFilter; -type ActFilter = DiscreteActFilter; -type EnvConfig = GymEnvConfig; -type Env = GymEnv; -type StepProc = SimpleStepProcessor; -type ReplayBuffer = SimpleReplayBuffer; -type Evaluator = DefaultEvaluator>; - -#[derive(Debug, Serialize)] -struct CartpoleRecord { - episode: usize, - step: usize, - reward: f32, - obs: Vec, -} - -impl TryFrom<&Record> for CartpoleRecord { - type Error = anyhow::Error; - - fn try_from(record: &Record) -> Result { - Ok(Self { - episode: record.get_scalar("episode")? as _, - step: record.get_scalar("step")? as _, - reward: record.get_scalar("reward")?, - obs: record - .get_array1("obs")? - .iter() - .map(|v| *v as f64) - .collect(), - }) - } -} - -fn create_agent(in_dim: i64, out_dim: i64) -> Dqn { - let device = tch::Device::cuda_if_available(); - let config = { - let opt_config = border_tch_agent::opt::OptimizerConfig::Adam { lr: LR_CRITIC }; - let mlp_config = MlpConfig::new(in_dim, vec![256, 256], out_dim, true); - let model_config = DqnModelConfig::default() - .q_config(mlp_config) - .out_dim(out_dim) - .opt_config(opt_config); - DqnConfig::default() - .n_updates_per_opt(N_UPDATES_PER_OPT) - .min_transitions_warmup(N_TRANSITIONS_WARMUP) - .batch_size(BATCH_SIZE) - .discount_factor(DISCOUNT_FACTOR) - .tau(TAU) - .model_config(model_config) - .device(device) - }; - - Dqn::build(config) -} - -fn env_config() -> EnvConfig { - EnvConfig::default() - .name("CartPole-v0".to_string()) - .obs_filter_config(ObsFilter::default_config()) - .act_filter_config(ActFilter::default_config()) -} - -fn create_evaluator(env_config: &EnvConfig) -> Result { - Evaluator::new(env_config, 0, N_EPISODES_PER_EVAL) -} - -fn train(max_opts: usize, model_dir: &str) -> Result<()> { - let mut trainer = { - let env_config = env_config(); - let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = - SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let config = TrainerConfig::default() - .max_opts(max_opts) - .opt_interval(OPT_INTERVAL) - .eval_interval(EVAL_INTERVAL) - .record_interval(EVAL_INTERVAL) - .save_interval(EVAL_INTERVAL) - .model_dir(model_dir); - let trainer = Trainer::::build( - config, - env_config, - step_proc_config, - replay_buffer_config, - ); - - trainer - }; - let mut agent = create_agent(DIM_OBS, DIM_ACT); - let mut recorder = TensorboardRecorder::new(model_dir); - let mut evaluator = create_evaluator(&env_config())?; - - trainer.train(&mut agent, 
&mut recorder, &mut evaluator)?; - - Ok(()) -} - -fn eval(model_dir: &str, render: bool) -> Result<()> { - let env_config = { - let mut env_config = env_config(); - if render { - env_config = env_config - .render_mode(Some("human".to_string())) - .set_wait_in_millis(10); - } - env_config - }; - let mut agent = { - let mut agent = create_agent(DIM_OBS, DIM_ACT); - agent.load(model_dir)?; - agent.eval(); - agent - }; - // let mut recorder = BufferedRecorder::new(); - - let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); - - Ok(()) -} - -fn main() -> Result<()> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - tch::manual_seed(42); - - let matches = App::new("dqn_cartpole") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .get_matches(); - - let do_train = (matches.is_present("train") && !matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); - let do_eval = (!matches.is_present("train") && matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); - - if do_train { - train(MAX_OPTS, MODEL_DIR)?; - } - if do_eval { - eval(&(MODEL_DIR.to_owned() + "/best"), true)?; - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::{eval, train}; - use anyhow::Result; - use tempdir::TempDir; - - #[test] - fn test_dqn_cartpole() -> Result<()> { - let tmp_dir = TempDir::new("dqn_cartpole")?; - let model_dir = match tmp_dir.as_ref().to_str() { - Some(s) => s, - None => panic!("Failed to get string of temporary directory"), - }; - train(100, model_dir)?; - eval(&(model_dir.to_owned() + "/best"), false)?; - Ok(()) - } -} diff --git a/border/examples/gym-robotics/sac_fetch_reach.rs b/border/examples/gym-robotics/sac_fetch_reach.rs index 4284d2a6..d4084f99 100644 --- a/border/examples/gym-robotics/sac_fetch_reach.rs +++ b/border/examples/gym-robotics/sac_fetch_reach.rs @@ -1,13 +1,15 @@ use anyhow::Result; use border_core::{ - record::{Record, RecordValue}, - replay_buffer::{ + generic_replay_buffer::{ SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }, - Agent, DefaultEvaluator, Evaluator as _, Policy, Trainer, TrainerConfig, + record::{AggregateRecorder, Record, RecordValue}, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, }; -use border_derive::SubBatch; +use border_derive::BatchBase; +use border_mlflow_tracking::MlflowTrackingClient; use border_py_gym_env::{ util::{arrayd_to_pyobj, arrayd_to_tensor, tensor_to_arrayd, ArrayType}, ArrayDictObsFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, @@ -17,10 +19,10 @@ use border_tch_agent::{ opt::OptimizerConfig, sac::{ActorConfig, CriticConfig, EntCoefMode, Sac, SacConfig}, util::CriticLoss, - TensorSubBatch, + TensorBatch, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg}; +use clap::Parser; // use csv::WriterBuilder; use ndarray::ArrayD; use pyo3::PyObject; @@ -44,16 +46,21 @@ const TAU: f64 = 0.02; const TARGET_ENTROPY: f64 = -(DIM_ACT as f64); const LR_ENT_COEF: f64 = 3e-4; const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; +const MODEL_DIR: &str = "./border/examples/gym-robotics/model/tch/sac_fetch_reach/"; -mod obs { +fn cuda_if_available() -> tch::Device { + 
tch::Device::cuda_if_available() +} + +mod obs_act_types { use super::*; use border_py_gym_env::util::Array; #[derive(Clone, Debug)] pub struct Obs(Vec<(String, Array)>); - #[derive(Clone, SubBatch)] - pub struct ObsBatch(TensorSubBatch); + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); impl border_core::Obs for Obs { fn dummy(_n: usize) -> Self { @@ -85,13 +92,9 @@ mod obs { impl From for ObsBatch { fn from(obs: Obs) -> Self { let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) + Self(TensorBatch::from_tensor(tensor)) } } -} - -mod act { - use super::*; #[derive(Clone, Debug)] pub struct Act(ArrayD); @@ -117,13 +120,13 @@ mod act { } } - #[derive(SubBatch)] - pub struct ActBatch(TensorSubBatch); + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); impl From for ActBatch { fn from(act: Act) -> Self { let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) + Self(TensorBatch::from_tensor(tensor)) } } @@ -156,84 +159,136 @@ mod act { (arrayd_to_pyobj(act_filt), record) } } -} -use act::{Act, ActBatch, ActFilter}; -use obs::{Obs, ObsBatch}; - -type ObsFilter = ArrayDictObsFilter; -type Env = GymEnv; -type StepProc = SimpleStepProcessor; -type ReplayBuffer = SimpleReplayBuffer; -type Evaluator = DefaultEvaluator>; - -fn create_agent(in_dim: i64, out_dim: i64) -> Sac { - let device = tch::Device::cuda_if_available(); - let actor_config = ActorConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) - .out_dim(out_dim) - .pi_config(MlpConfig::new(in_dim, vec![64, 64], out_dim, true)); - let critic_config = CriticConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) - .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, true)); - let sac_config = SacConfig::default() - .batch_size(BATCH_SIZE) - .min_transitions_warmup(N_TRANSITIONS_WARMUP) - .actor_config(actor_config) - .critic_config(critic_config) - .tau(TAU) - .critic_loss(CRITIC_LOSS) - .n_critics(N_CRITICS) - .ent_coef_mode(EntCoefMode::Auto(TARGET_ENTROPY, LR_ENT_COEF)) - .device(device); - Sac::build(sac_config) + pub type ObsFilter = ArrayDictObsFilter; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Evaluator = DefaultEvaluator>; } -fn env_config() -> GymEnvConfig { - GymEnvConfig::::default() - .name("FetchReach-v2".to_string()) - .obs_filter_config(ObsFilter::default_config().add_key_and_types(vec![ - ("observation", ArrayType::F32Array), - ("desired_goal", ArrayType::F32Array), - ("achieved_goal", ArrayType::F32Array), - ])) - .act_filter_config(ActFilter::default_config()) -} +use obs_act_types::*; -fn train(max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { - let mut trainer = { - let env_config = env_config(); - let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = - SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let config = TrainerConfig::default() +mod config { + use super::*; + + pub fn env_config() -> GymEnvConfig { + GymEnvConfig::::default() + .name("FetchReach-v2".to_string()) + .obs_filter_config(ObsFilter::default_config().add_key_and_types(vec![ + ("observation", ArrayType::F32Array), + ("desired_goal", ArrayType::F32Array), + ("achieved_goal", ArrayType::F32Array), + ])) + .act_filter_config(ActFilter::default_config()) + } + + pub fn create_trainer_config(max_opts: usize, model_dir: &str) -> TrainerConfig { + TrainerConfig::default() .max_opts(max_opts) 
.opt_interval(OPT_INTERVAL) - .eval_interval(eval_interval) - .record_interval(eval_interval) - .save_interval(eval_interval) - .model_dir(model_dir); - let trainer = Trainer::::build( - config, - env_config, - step_proc_config, - replay_buffer_config, - ); - - trainer - }; - let mut agent = create_agent(DIM_OBS, DIM_ACT); - let mut recorder = TensorboardRecorder::new(model_dir); - let mut evaluator = Evaluator::new(&env_config(), 0, N_EPISODES_PER_EVAL)?; + .eval_interval(EVAL_INTERVAL) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(N_TRANSITIONS_WARMUP) + .model_dir(model_dir) + } - trainer.train(&mut agent, &mut recorder, &mut evaluator)?; + pub fn create_sac_config(dim_obs: i64, dim_act: i64, target_ent: f64) -> SacConfig { + let device = cuda_if_available(); + let actor_config = ActorConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) + .out_dim(dim_act) + .pi_config(MlpConfig::new(dim_obs, vec![64, 64], dim_act, false)); + let critic_config = CriticConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) + .q_config(MlpConfig::new(dim_obs + dim_act, vec![64, 64], 1, false)); + + SacConfig::default() + .batch_size(BATCH_SIZE) + .actor_config(actor_config) + .critic_config(critic_config) + .tau(TAU) + .critic_loss(CRITIC_LOSS) + .n_critics(N_CRITICS) + .ent_coef_mode(EntCoefMode::Auto(target_ent, LR_ENT_COEF)) + .device(device) + } +} + +mod utils { + use super::*; + + pub fn create_recorder( + model_dir: &str, + mlflow: bool, + config: &TrainerConfig, + ) -> Result> { + match mlflow { + true => { + let client = MlflowTrackingClient::new("http://localhost:8080") + .set_experiment_id("Fetch")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", "reach")?; + recorder_run.set_tag("algo", "sac")?; + recorder_run.set_tag("backend", "tch")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } +} + +/// Train/eval SAC agent in fetch environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn train(max_opts: usize, model_dir: &str, mlflow: bool) -> Result<()> { + let trainer_config = config::create_trainer_config(max_opts, model_dir); + let env_config = config::env_config(); + let step_proc_config = SimpleStepProcessorConfig {}; + let sac_config = config::create_sac_config(DIM_OBS, DIM_ACT, TARGET_ENTROPY); + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + + let env = Env::build(&env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut recorder = utils::create_recorder(model_dir, mlflow, &trainer_config)?; + let mut trainer = Trainer::build(trainer_config); + let mut agent = Sac::build(sac_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut evaluator = Evaluator::new(&config::env_config(), 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; Ok(()) } fn eval(n_episodes: usize, render: bool, model_dir: &str) -> 
Result<()> { let env_config = { - let mut env_config = env_config(); + let mut env_config = config::env_config(); if render { env_config = env_config .render_mode(Some("human".to_string())) @@ -242,8 +297,8 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { env_config }; let mut agent = { - let mut agent = create_agent(DIM_OBS, DIM_ACT); - agent.load(model_dir)?; + let mut agent = Sac::build(config::create_sac_config(DIM_OBS, DIM_ACT, TARGET_ENTROPY)); + agent.load_params(model_dir)?; agent.eval(); agent }; @@ -258,40 +313,15 @@ fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); tch::manual_seed(42); - let matches = App::new("sac_fetch_reach") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .get_matches(); - - let do_train = matches.is_present("train"); - let do_eval = matches.is_present("eval"); - - if !do_train && !do_eval { - println!("You need to give either --train or --eval in the command line argument."); - return Ok(()); - } + let args = Args::parse(); - if do_train { - train( - MAX_OPTS, - "./border/examples/model/sac_fetch_reach", - EVAL_INTERVAL, - )?; - } - if do_eval { - eval(5, true, "./border/examples/model/sac_fetch_reach/best")?; + if args.train { + train(MAX_OPTS, MODEL_DIR, args.mlflow)?; + } else if args.eval { + eval(5, true, format!("{}/best", MODEL_DIR).as_str())?; + } else { + train(MAX_OPTS, MODEL_DIR, args.mlflow)?; + eval(5, true, format!("{}/best", MODEL_DIR).as_str())?; } Ok(()) @@ -308,7 +338,7 @@ mod test { let model_dir = TempDir::new("sac_fetch_reach")?; let model_dir = model_dir.path().to_str().unwrap(); - train(100, model_dir, 100)?; + train(100, model_dir, false)?; eval(1, false, (model_dir.to_string() + "/best").as_str())?; Ok(()) diff --git a/border/examples/gym/README.md b/border/examples/gym/README.md new file mode 100644 index 00000000..63ee7ac3 --- /dev/null +++ b/border/examples/gym/README.md @@ -0,0 +1,45 @@ +## Gym + +You need to set PYTHONPATH as `PYTHONPATH=./border-py-gym-env/examples`. + +### DQN + +```bash +cargo run --example dqn_cartpole_tch --features="tch" +``` + +```bash +cargo run --example dqn_cartpole --features="candle-core,cuda,cudnn" +``` + +### SAC + +```bash +cargo run --example sac_pendulum_tch --features="tch" +``` + +```bash +cargo run --example sac_lunarlander_cont_tch --features="tch" +``` + +```bash +cargo run --example sac_pendulum --features="candle-core,cuda,cudnn" +``` + +```bash +cargo run --example sac_lunarlander_cont --features="candle-core,cuda,cudnn" +``` + +### border-policy-no-backend + +`convert_sac_policy_to_edge` converts model parameters obtained with `sac_pendulum_tch`. + +```bash +cargo run --example convert_sac_policy_to_edge +``` + +The converted model parameters can be used with border-policy-no-backend crate. 
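For reference, a minimal sketch of how the serialized parameters can be loaded back, mirroring what the `pendulum_edge` example does (the helper name `load_policy` is only illustrative):

```rust
use anyhow::Result;
use border_policy_no_backend::Mlp;
use std::{fs, io::Read};

/// Loads the bincode-serialized MLP written by `convert_sac_policy_to_edge`
/// (the examples use ./border/examples/gym/model/edge/sac_pendulum/best/mlp.bincode).
fn load_policy(path: &str) -> Result<Mlp> {
    let mut buf = Vec::<u8>::new();
    let mut file = fs::OpenOptions::new().read(true).open(path)?;
    file.read_to_end(&mut buf)?;
    Ok(bincode::deserialize(&buf[..])?)
}
```

The `pendulum_edge` example wraps the loaded `Mlp` in a `Policy` implementation and evaluates it without any backend: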
+ +```bash +cargo run --example pendulum_edge +``` diff --git a/border/examples/gym/convert_sac_policy_to_edge.rs b/border/examples/gym/convert_sac_policy_to_edge.rs new file mode 100644 index 00000000..8a7294ca --- /dev/null +++ b/border/examples/gym/convert_sac_policy_to_edge.rs @@ -0,0 +1,215 @@ +use anyhow::Result; +use border_core::{Agent, Configurable}; +use border_policy_no_backend::Mlp; +use border_tch_agent::{ + mlp, + model::ModelBase, + sac::{ActorConfig, CriticConfig, SacConfig}, +}; +use std::{fs, io::Write}; + +const DIM_OBS: i64 = 3; +const DIM_ACT: i64 = 1; + +// Dummy types +mod dummy { + use super::mlp::{Mlp, Mlp2}; + use border_tch_agent::sac::Sac as Sac_; + + #[derive(Clone, Debug)] + pub struct DummyObs; + + impl border_core::Obs for DummyObs { + fn dummy(_n: usize) -> Self { + unimplemented!(); + } + + fn len(&self) -> usize { + unimplemented!(); + } + } + + impl Into for DummyObs { + fn into(self) -> tch::Tensor { + unimplemented!(); + } + } + + #[derive(Clone, Debug)] + pub struct DummyAct; + + impl border_core::Act for DummyAct { + fn len(&self) -> usize { + unimplemented!(); + } + } + + impl Into for DummyAct { + fn into(self) -> tch::Tensor { + unimplemented!(); + } + } + + impl From for DummyAct { + fn from(_value: tch::Tensor) -> Self { + unimplemented!(); + } + } + + #[derive(Clone)] + pub struct DummyInnerBatch; + + impl Into for DummyInnerBatch { + fn into(self) -> tch::Tensor { + unimplemented!(); + } + } + + pub struct DummyBatch; + + impl border_core::TransitionBatch for DummyBatch { + type ObsBatch = DummyInnerBatch; + type ActBatch = DummyInnerBatch; + + fn len(&self) -> usize { + unimplemented!(); + } + + fn obs(&self) -> &Self::ObsBatch { + unimplemented!(); + } + + fn unpack( + self, + ) -> ( + Self::ObsBatch, + Self::ActBatch, + Self::ObsBatch, + Vec, + Vec, + Vec, + Option>, + Option>, + ) { + unimplemented!(); + } + } + + pub struct DummyReplayBuffer; + + impl border_core::ReplayBufferBase for DummyReplayBuffer { + type Batch = DummyBatch; + type Config = usize; + + fn batch(&mut self, _size: usize) -> anyhow::Result { + unimplemented!(); + } + + fn build(_config: &Self::Config) -> Self { + unimplemented!(); + } + + fn update_priority(&mut self, _ixs: &Option>, _td_err: &Option>) { + unimplemented!(); + } + } + + #[derive(Clone, Debug)] + pub struct DummyInfo; + + impl border_core::Info for DummyInfo {} + + pub struct DummyEnv; + + impl border_core::Env for DummyEnv { + type Config = usize; + type Act = DummyAct; + type Obs = DummyObs; + type Info = DummyInfo; + + fn build(_config: &Self::Config, _seed: i64) -> anyhow::Result + where + Self: Sized, + { + unimplemented!(); + } + + fn reset(&mut self, _is_done: Option<&Vec>) -> anyhow::Result { + unimplemented!(); + } + + fn reset_with_index(&mut self, _ix: usize) -> anyhow::Result { + unimplemented!(); + } + + fn step(&mut self, _a: &Self::Act) -> (border_core::Step, border_core::record::Record) + where + Self: Sized, + { + unimplemented!(); + } + + fn step_with_reset( + &mut self, + _a: &Self::Act, + ) -> (border_core::Step, border_core::record::Record) + where + Self: Sized, + { + unimplemented!(); + } + } + + pub type Env = DummyEnv; + pub type Sac = Sac_; +} + +use dummy::Sac; + +fn create_sac_config() -> SacConfig { + // Omit learning related parameters + let actor_config = ActorConfig::default() + .out_dim(DIM_ACT) + .pi_config(mlp::MlpConfig::new(DIM_OBS, vec![64, 64], DIM_ACT, false)); + let critic_config = CriticConfig::default().q_config(mlp::MlpConfig::new( + DIM_OBS + DIM_ACT, + 
vec![64, 64], + 1, + false, + )); + SacConfig::default() + .actor_config(actor_config) + .critic_config(critic_config) + .device(tch::Device::Cpu) +} + +fn main() -> Result<()> { + let src_path = "./border/examples/gym/model/tch/sac_pendulum/best"; + let dest_path = "./border/examples/gym/model/edge/sac_pendulum/best/mlp.bincode"; + + // Load Sac model + let sac = { + let config = create_sac_config(); + let mut sac = Sac::build(config); + sac.load_params(src_path)?; + sac + }; + + // Create Mlp + let mlp = { + let vs = sac.get_policy_net().get_var_store(); + let w_names = ["mlp.al0.weight", "mlp.al1.weight", "ml.weight"]; + let b_names = ["mlp.al0.bias", "mlp.al1.bias", "ml.bias"]; + Mlp::from_varstore(vs, &w_names, &b_names) + }; + + // Serialize to file + let encoded = bincode::serialize(&mlp)?; + let mut file = fs::OpenOptions::new() + .create(true) + .write(true) + .open(&dest_path)?; + file.write_all(&encoded)?; + + Ok(()) +} diff --git a/border/examples/gym/dqn_cartpole.rs b/border/examples/gym/dqn_cartpole.rs new file mode 100644 index 00000000..ea0bf88c --- /dev/null +++ b/border/examples/gym/dqn_cartpole.rs @@ -0,0 +1,361 @@ +use anyhow::Result; +use border_candle_agent::{ + dqn::{Dqn, DqnConfig, DqnModelConfig}, + mlp::{Mlp, MlpConfig}, + opt::OptimizerConfig, + util::CriticLoss, + TensorBatch, +}; +use border_core::{ + generic_replay_buffer::{ + BatchBase, SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::AggregateRecorder, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, +}; +use border_mlflow_tracking::MlflowTrackingClient; +use border_py_gym_env::{ + util::{arrayd_to_tensor, vec_to_tensor}, + ArrayObsFilter, DiscreteActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, +}; +use border_tensorboard::TensorboardRecorder; +use candle_core::{Device, Tensor}; +use clap::Parser; +use ndarray::{ArrayD, IxDyn}; +use serde::Serialize; + +const DIM_OBS: i64 = 4; +const DIM_ACT: i64 = 2; +const LR_CRITIC: f64 = 0.001; +const DISCOUNT_FACTOR: f64 = 0.99; +const BATCH_SIZE: usize = 64; +const WARMUP_PERIOD: usize = 100; +const N_UPDATES_PER_OPT: usize = 1; +const TAU: f64 = 0.01; +const OPT_INTERVAL: usize = 1; +const MAX_OPTS: usize = 30000; +const EVAL_INTERVAL: usize = 1000; +const REPLAY_BUFFER_CAPACITY: usize = 10000; +const N_EPISODES_PER_EVAL: usize = 5; +const CRITIC_LOSS: CriticLoss = CriticLoss::Mse; +const MODEL_DIR: &str = "./border/examples/gym/model/candle/dqn_cartpole"; + +mod obs_act_types { + use super::*; + + #[derive(Clone, Debug)] + pub struct Obs(ArrayD); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(ArrayD::zeros(IxDyn(&[0]))) + } + + fn len(&self) -> usize { + self.0.shape()[0] + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + Obs(obs) + } + } + + impl From for Tensor { + fn from(obs: Obs) -> Tensor { + arrayd_to_tensor::<_, f32>(obs.0, false).unwrap() + } + } + + pub struct ObsBatch(TensorBatch); + + impl BatchBase for ObsBatch { + fn new(capacity: usize) -> Self { + Self(TensorBatch::new(capacity)) + } + + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) + } + + fn sample(&self, ixs: &Vec) -> Self { + let buf = self.0.sample(ixs); + Self(buf) + } + } + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + impl From for Tensor { + fn from(b: ObsBatch) -> Self { + b.0.into() + } 
+ } + + #[derive(Clone, Debug)] + pub struct Act(Vec); + + impl border_core::Act for Act {} + + impl From for Vec { + fn from(value: Act) -> Self { + value.0 + } + } + + impl From for Act { + // `t` must be a 1-dimentional tensor of `i64` + fn from(t: Tensor) -> Self { + let data = t.to_vec1::().expect("Failed to convert Tensor to Act"); + let data = data.iter().map(|&e| e as i32).collect(); + Self(data) + } + } + + pub struct ActBatch(TensorBatch); + + impl BatchBase for ActBatch { + fn new(capacity: usize) -> Self { + Self(TensorBatch::new(capacity)) + } + + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) + } + + fn sample(&self, ixs: &Vec) -> Self { + let buf = self.0.sample(ixs); + Self(buf) + } + } + + impl From for ActBatch { + fn from(act: Act) -> Self { + let t = + vec_to_tensor::<_, i64>(act.0, true).expect("Failed to convert Act to ActBatch"); + Self(TensorBatch::from_tensor(t)) + } + } + + // Required by Dqn + impl From for Tensor { + fn from(act: ActBatch) -> Self { + act.0.into() + } + } + + type PyObsDtype = f32; + pub type ObsFilter = ArrayObsFilter; + pub type ActFilter = DiscreteActFilter; + pub type EnvConfig = GymEnvConfig; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Evaluator = DefaultEvaluator>; +} + +use obs_act_types::*; + +mod config { + use super::*; + + #[derive(Serialize)] + pub struct DqnCartpoleConfig { + pub env_config: EnvConfig, + pub agent_config: DqnConfig, + pub trainer_config: TrainerConfig, + } + + impl DqnCartpoleConfig { + pub fn new( + in_dim: i64, + out_dim: i64, + max_opts: usize, + model_dir: &str, + eval_interval: usize, + ) -> Self { + let env_config = create_env_config(); + let agent_config = create_agent_config(in_dim, out_dim); + let trainer_config = TrainerConfig::default() + .max_opts(max_opts) + .opt_interval(OPT_INTERVAL) + .eval_interval(eval_interval) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(WARMUP_PERIOD) + .model_dir(model_dir); + Self { + env_config, + agent_config, + trainer_config, + } + } + } + + pub fn create_env_config() -> EnvConfig { + EnvConfig::default() + .name("CartPole-v0".to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) + } + + pub fn create_agent_config(in_dim: i64, out_dim: i64) -> DqnConfig { + let device = Device::cuda_if_available(0).unwrap(); + let opt_config = OptimizerConfig::default().learning_rate(LR_CRITIC); + let mlp_config = MlpConfig::new(in_dim, vec![256, 256], out_dim, false); + let model_config = DqnModelConfig::default() + .q_config(mlp_config) + .out_dim(out_dim) + .opt_config(opt_config); + DqnConfig::default() + .n_updates_per_opt(N_UPDATES_PER_OPT) + .batch_size(BATCH_SIZE) + .discount_factor(DISCOUNT_FACTOR) + .tau(TAU) + .model_config(model_config) + .device(device) + .critic_loss(CRITIC_LOSS) + } +} + +use config::{create_agent_config, create_env_config, DqnCartpoleConfig}; + +mod utils { + use super::*; + + pub fn create_recorder( + args: &Args, + model_dir: &str, + config: &DqnCartpoleConfig, + ) -> Result> { + match args.mlflow { + true => { + let client = + MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", "cartpole")?; + recorder_run.set_tag("algo", 
"dqn")?; + recorder_run.set_tag("backend", "candle")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } +} + +/// Train/eval DQN agent in cartpole environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train DQN agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { + let config = DqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, model_dir, eval_interval); + let step_proc_config = SimpleStepProcessorConfig {}; + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let mut recorder = utils::create_recorder(&args, model_dir, &config)?; + let mut trainer = Trainer::build(config.trainer_config.clone()); + + let env = Env::build(&config.env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut agent = Dqn::build(config.agent_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut evaluator = Evaluator::new(&config.env_config, 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; + + Ok(()) +} + +fn eval(model_dir: &str, render: bool) -> Result<()> { + let env_config = { + let mut env_config = create_env_config(); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(10); + } + env_config + }; + let mut agent = { + let mut agent = Dqn::build(create_agent_config(DIM_OBS, DIM_ACT)); + agent.load_params(model_dir)?; + agent.eval(); + agent + }; + + let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); + + Ok(()) +} + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + // TODO: set seed + + let args = Args::parse(); + + if args.train { + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + } else if args.eval { + eval(&(MODEL_DIR.to_owned() + "/best"), true)?; + } else { + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + eval(&(MODEL_DIR.to_owned() + "/best"), true)?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::{eval, train, Args}; + use anyhow::Result; + use tempdir::TempDir; + + #[test] + fn test_dqn_cartpole() -> Result<()> { + let tmp_dir = TempDir::new("dqn_cartpole")?; + let model_dir = match tmp_dir.as_ref().to_str() { + Some(s) => s, + None => panic!("Failed to get string of temporary directory"), + }; + let args = Args { + train: false, + eval: false, + mlflow: false, + }; + train(&args, 100, model_dir, 100)?; + eval(&(model_dir.to_owned() + "/best"), false)?; + Ok(()) + } +} diff --git a/border/examples/gym/dqn_cartpole_tch.rs b/border/examples/gym/dqn_cartpole_tch.rs new file mode 100644 index 00000000..b6604476 --- /dev/null +++ b/border/examples/gym/dqn_cartpole_tch.rs @@ -0,0 +1,361 @@ +use anyhow::Result; +use border_core::{ + generic_replay_buffer::{ + BatchBase, SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::AggregateRecorder, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, +}; +use 
border_mlflow_tracking::MlflowTrackingClient; +use border_py_gym_env::{ + util::vec_to_tensor, ArrayObsFilter, DiscreteActFilter, GymActFilter, GymEnv, GymEnvConfig, + GymObsFilter, +}; +use border_tch_agent::{ + dqn::{Dqn, DqnConfig, DqnModelConfig}, + mlp::{Mlp, MlpConfig}, + util::CriticLoss, + TensorBatch, +}; +use border_tensorboard::TensorboardRecorder; +use clap::Parser; +use ndarray::{ArrayD, IxDyn}; +use serde::Serialize; +use std::convert::TryFrom; +use tch::Tensor; + +const DIM_OBS: i64 = 4; +const DIM_ACT: i64 = 2; +const LR_CRITIC: f64 = 0.001; +const DISCOUNT_FACTOR: f64 = 0.99; +const BATCH_SIZE: usize = 64; +const WARMUP_PERIOD: usize = 100; +const N_UPDATES_PER_OPT: usize = 1; +const TAU: f64 = 0.01; +const OPT_INTERVAL: usize = 1; +const MAX_OPTS: usize = 30000; +const EVAL_INTERVAL: usize = 1000; +const REPLAY_BUFFER_CAPACITY: usize = 10000; +const N_EPISODES_PER_EVAL: usize = 5; +const CRITIC_LOSS: CriticLoss = CriticLoss::Mse; +const MODEL_DIR: &str = "./border/examples/gym/model/tch/dqn_cartpole"; + +mod obs_act_types { + use super::*; + + #[derive(Clone, Debug)] + pub struct Obs(ArrayD); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(ArrayD::zeros(IxDyn(&[0]))) + } + + fn len(&self) -> usize { + self.0.shape()[0] + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + Obs(obs) + } + } + + impl From for Tensor { + fn from(obs: Obs) -> Tensor { + Tensor::try_from(&obs.0).unwrap() + } + } + + pub struct ObsBatch(TensorBatch); + + impl BatchBase for ObsBatch { + fn new(capacity: usize) -> Self { + Self(TensorBatch::new(capacity)) + } + + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) + } + + fn sample(&self, ixs: &Vec) -> Self { + let buf = self.0.sample(ixs); + Self(buf) + } + } + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + impl From for Tensor { + fn from(b: ObsBatch) -> Self { + b.0.into() + } + } + + #[derive(Clone, Debug)] + pub struct Act(Vec); + + impl border_core::Act for Act {} + + impl From for Vec { + fn from(value: Act) -> Self { + value.0 + } + } + + impl From for Act { + // `t` must be a 1-dimentional tensor of `f32` + fn from(t: Tensor) -> Self { + let data = Vec::::try_from(&t.flatten(0, -1)) + .expect("Failed to convert from Tensor to Vec"); + let data = data.iter().map(|&e| e as i32).collect(); + Act(data) + } + } + + pub struct ActBatch(TensorBatch); + + impl BatchBase for ActBatch { + fn new(capacity: usize) -> Self { + Self(TensorBatch::new(capacity)) + } + + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) + } + + fn sample(&self, ixs: &Vec) -> Self { + let buf = self.0.sample(ixs); + Self(buf) + } + } + + impl From for ActBatch { + fn from(act: Act) -> Self { + let t = vec_to_tensor::<_, i64>(act.0, true); + Self(TensorBatch::from_tensor(t)) + } + } + + // Required by Dqn + impl From for Tensor { + fn from(act: ActBatch) -> Self { + act.0.into() + } + } + + type PyObsDtype = f32; + pub type ObsFilter = ArrayObsFilter; + pub type ActFilter = DiscreteActFilter; + pub type EnvConfig = GymEnvConfig; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Evaluator = DefaultEvaluator>; +} + +use obs_act_types::*; + +mod config { + use super::*; + + #[derive(Serialize)] + pub struct DqnCartpoleConfig { + pub env_config: EnvConfig, + pub agent_config: DqnConfig, + pub trainer_config: TrainerConfig, + } + + impl 
DqnCartpoleConfig { + pub fn new( + in_dim: i64, + out_dim: i64, + max_opts: usize, + model_dir: &str, + eval_interval: usize, + ) -> Self { + let env_config = create_env_config(); + let agent_config = create_agent_config(in_dim, out_dim); + let trainer_config = TrainerConfig::default() + .max_opts(max_opts) + .opt_interval(OPT_INTERVAL) + .eval_interval(eval_interval) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(WARMUP_PERIOD) + .model_dir(model_dir); + Self { + env_config, + agent_config, + trainer_config, + } + } + } + + pub fn create_env_config() -> EnvConfig { + EnvConfig::default() + .name("CartPole-v0".to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) + } + + pub fn create_agent_config(in_dim: i64, out_dim: i64) -> DqnConfig { + let device = tch::Device::cuda_if_available(); + let opt_config = border_tch_agent::opt::OptimizerConfig::Adam { lr: LR_CRITIC }; + let mlp_config = MlpConfig::new(in_dim, vec![256, 256], out_dim, false); + let model_config = DqnModelConfig::default() + .q_config(mlp_config) + .out_dim(out_dim) + .opt_config(opt_config); + DqnConfig::default() + .n_updates_per_opt(N_UPDATES_PER_OPT) + .batch_size(BATCH_SIZE) + .discount_factor(DISCOUNT_FACTOR) + .tau(TAU) + .model_config(model_config) + .device(device) + .critic_loss(CRITIC_LOSS) + } +} + +use config::{create_agent_config, create_env_config, DqnCartpoleConfig}; + +mod utils { + use super::*; + + pub fn create_recorder( + args: &Args, + model_dir: &str, + config: &DqnCartpoleConfig, + ) -> Result> { + match args.mlflow { + true => { + let client = + MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", "cartpole")?; + recorder_run.set_tag("algo", "dqn")?; + recorder_run.set_tag("backend", "tch")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } +} + +/// Train/eval DQN agent in cartpole environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train DQN agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate DQN agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { + let config = DqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, model_dir, eval_interval); + let step_proc_config = SimpleStepProcessorConfig {}; + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let mut recorder = utils::create_recorder(&args, model_dir, &config)?; + let mut trainer = Trainer::build(config.trainer_config.clone()); + + let env = Env::build(&config.env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut agent = Dqn::build(config.agent_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut evaluator = Evaluator::new(&config.env_config, 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; + + Ok(()) +} + +fn eval(model_dir: &str, render: bool) -> Result<()> { + let env_config = 
{ + let mut env_config = create_env_config(); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(10); + } + env_config + }; + let mut agent = { + let mut agent = Dqn::build(create_agent_config(DIM_OBS, DIM_ACT)); + agent.load_params(model_dir)?; + agent.eval(); + agent + }; + + let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); + + Ok(()) +} + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + tch::manual_seed(42); + + let args = Args::parse(); + + if args.train { + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + } else if args.eval { + eval(&(MODEL_DIR.to_owned() + "/best"), true)?; + } else { + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + eval(&(MODEL_DIR.to_owned() + "/best"), true)?; + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::{eval, train, Args}; + use anyhow::Result; + use tempdir::TempDir; + + #[test] + fn test_dqn_cartpole() -> Result<()> { + let tmp_dir = TempDir::new("dqn_cartpole")?; + let model_dir = match tmp_dir.as_ref().to_str() { + Some(s) => s, + None => panic!("Failed to get string of temporary directory"), + }; + let args = Args { + train: false, + eval: false, + mlflow: false, + }; + train(&args, 100, model_dir, 100)?; + eval(&(model_dir.to_owned() + "/best"), false)?; + Ok(()) + } +} diff --git a/border/examples/gym/iqn_cartpole_tch.rs b/border/examples/gym/iqn_cartpole_tch.rs new file mode 100644 index 00000000..70b9be41 --- /dev/null +++ b/border/examples/gym/iqn_cartpole_tch.rs @@ -0,0 +1,370 @@ +use anyhow::Result; +use border_core::{ + generic_replay_buffer::{ + BatchBase, SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::AggregateRecorder, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, +}; +use border_mlflow_tracking::MlflowTrackingClient; +use border_py_gym_env::{ + util::vec_to_tensor, ArrayObsFilter, DiscreteActFilter, GymActFilter, GymEnv, GymEnvConfig, + GymObsFilter, +}; +use border_tch_agent::{ + iqn::{EpsilonGreedy, Iqn as Iqn_, IqnConfig, IqnModelConfig}, + mlp::{Mlp, MlpConfig}, + TensorBatch, +}; +use border_tensorboard::TensorboardRecorder; +use clap::Parser; +use ndarray::{ArrayD, IxDyn}; +use serde::Serialize; +use std::convert::TryFrom; +use tch::Tensor; + +const DIM_OBS: i64 = 4; +const DIM_ACT: i64 = 2; +const LR_CRITIC: f64 = 0.001; +const DIM_FEATURE: i64 = 256; +const DIM_EMBED: i64 = 64; +const DISCOUNT_FACTOR: f64 = 0.99; +const BATCH_SIZE: usize = 64; +const WARMUP_PERIOD: usize = 100; +const N_UPDATES_PER_OPT: usize = 1; +const TAU: f64 = 0.01; +const SOFT_UPDATE_INTERVAL: usize = 1; +const OPT_INTERVAL: usize = 1; +const MAX_OPTS: usize = 30000; +const EVAL_INTERVAL: usize = 1000; +const REPLAY_BUFFER_CAPACITY: usize = 10000; +const N_EPISODES_PER_EVAL: usize = 5; +const EPS_START: f64 = 1.0; +const EPS_FINAL: f64 = 0.02; +const FINAL_STEP: usize = MAX_OPTS / 3; +const MODEL_DIR: &str = "border/examples/gym/model/tch/iqn_cartpole"; + +mod obs_act_types { + use super::*; + + #[derive(Clone, Debug)] + pub struct Obs(ArrayD); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(ArrayD::zeros(IxDyn(&[0]))) + } + + fn len(&self) -> usize { + self.0.shape()[0] + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + Obs(obs) + } + } + + impl From for Tensor { + fn from(obs: Obs) -> Tensor { + 
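// ndarray -> tch::Tensor conversion; unwrap panics if the conversion fails.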
Tensor::try_from(&obs.0).unwrap() + } + } + + pub struct ObsBatch(TensorBatch); + + impl BatchBase for ObsBatch { + fn new(capacity: usize) -> Self { + Self(TensorBatch::new(capacity)) + } + + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) + } + + fn sample(&self, ixs: &Vec) -> Self { + let buf = self.0.sample(ixs); + Self(buf) + } + } + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + impl From for Tensor { + fn from(b: ObsBatch) -> Self { + b.0.into() + } + } + + #[derive(Clone, Debug)] + pub struct Act(Vec); + + impl border_core::Act for Act {} + + impl From for Vec { + fn from(value: Act) -> Self { + value.0 + } + } + + impl From for Act { + // `t` must be a 1-dimentional tensor of `f32` + fn from(t: Tensor) -> Self { + let data = Vec::::try_from(&t.flatten(0, -1)) + .expect("Failed to convert from Tensor to Vec"); + let data = data.iter().map(|&e| e as i32).collect(); + Act(data) + } + } + + pub struct ActBatch(TensorBatch); + + impl BatchBase for ActBatch { + fn new(capacity: usize) -> Self { + Self(TensorBatch::new(capacity)) + } + + fn push(&mut self, i: usize, data: Self) { + self.0.push(i, data.0) + } + + fn sample(&self, ixs: &Vec) -> Self { + let buf = self.0.sample(ixs); + Self(buf) + } + } + + impl From for ActBatch { + fn from(act: Act) -> Self { + let t = vec_to_tensor::<_, i64>(act.0, true); + Self(TensorBatch::from_tensor(t)) + } + } + + // Required by Iqn + impl From for Tensor { + fn from(act: ActBatch) -> Self { + act.0.into() + } + } + + type PyObsDtype = f32; + pub type ObsFilter = ArrayObsFilter; + pub type ActFilter = DiscreteActFilter; + pub type EnvConfig = GymEnvConfig; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Iqn = Iqn_; + pub type Evaluator = DefaultEvaluator; +} + +use obs_act_types::*; + +mod config { + use super::*; + + #[derive(Serialize)] + pub struct IqnCartpoleConfig { + pub env_config: EnvConfig, + pub agent_config: IqnConfig, + pub trainer_config: TrainerConfig, + } + + impl IqnCartpoleConfig { + pub fn new( + in_dim: i64, + out_dim: i64, + max_opts: usize, + model_dir: &str, + eval_interval: usize, + ) -> Self { + let env_config = env_config(); + let agent_config = agent_config(in_dim, out_dim); + let trainer_config = TrainerConfig::default() + .max_opts(max_opts) + .opt_interval(OPT_INTERVAL) + .eval_interval(eval_interval) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(WARMUP_PERIOD) + .model_dir(model_dir); + Self { + env_config, + agent_config, + trainer_config, + } + } + } + + pub fn env_config() -> EnvConfig { + GymEnvConfig::::default() + .name("CartPole-v0".to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) + } + + pub fn agent_config(in_dim: i64, out_dim: i64) -> IqnConfig { + let device = tch::Device::cuda_if_available(); + let opt_config = border_tch_agent::opt::OptimizerConfig::Adam { lr: LR_CRITIC }; + let f_config = MlpConfig::new(in_dim, vec![256], DIM_FEATURE, true); + let m_config = MlpConfig::new(DIM_FEATURE, vec![256], out_dim, false); + let model_config = IqnModelConfig::default() + .feature_dim(DIM_FEATURE) + .embed_dim(DIM_EMBED) + .opt_config(opt_config) + .f_config(f_config) + .m_config(m_config); + + IqnConfig::default() + 
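// The explorer below anneals epsilon from EPS_START to EPS_FINAL over FINAL_STEP (= MAX_OPTS / 3) steps.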
.n_updates_per_opt(N_UPDATES_PER_OPT) + .batch_size(BATCH_SIZE) + .discount_factor(DISCOUNT_FACTOR) + .tau(TAU) + .model_config(model_config) + .explorer(EpsilonGreedy::with_params(EPS_START, EPS_FINAL, FINAL_STEP)) + .soft_update_interval(SOFT_UPDATE_INTERVAL) + .device(device) + } +} + +mod utils { + use super::*; + + pub fn create_recorder( + args: &Args, + model_dir: &str, + config: &config::IqnCartpoleConfig, + ) -> Result> { + match args.mlflow { + true => { + let client = + MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", "cartpole")?; + recorder_run.set_tag("algo", "iqn")?; + recorder_run.set_tag("backend", "tch")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } +} + +/// Train/eval IQN agent in cartpole environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train IQN agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate IQN agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn train(args: &Args, max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { + let config = + config::IqnCartpoleConfig::new(DIM_OBS, DIM_ACT, max_opts, model_dir, eval_interval); + let step_proc_config = SimpleStepProcessorConfig {}; + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let mut recorder = utils::create_recorder(args, model_dir, &config)?; + let mut trainer = Trainer::build(config.trainer_config.clone()); + + let env = Env::build(&config.env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut agent = Iqn::build(config.agent_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut evaluator = Evaluator::new(&config.env_config, 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; + + Ok(()) +} + +fn eval(model_dir: &str, render: bool) -> Result<()> { + let env_config = { + let mut env_config = config::env_config(); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(10); + } + env_config + }; + let mut agent = { + let mut agent = Iqn::build(config::agent_config(DIM_OBS, DIM_ACT)); + agent.load_params(model_dir)?; + agent.eval(); + agent + }; + + let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); + + Ok(()) +} + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + tch::manual_seed(42); + + let args = Args::parse(); + + if args.train { + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + } else if args.eval { + eval(&(MODEL_DIR.to_owned() + "/best"), true)?; + } else { + train(&args, MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; + eval(&(MODEL_DIR.to_owned() + "/best"), true)?; + } + + Ok(()) +} + +#[cfg(test)] +mod test { + use super::{eval, train, Args}; + use anyhow::Result; + use tempdir::TempDir; + + #[test] + fn test_iqn_cartpole() -> Result<()> { + let tmp_dir = TempDir::new("iqn_cartpole")?; + let model_dir = match tmp_dir.as_ref().to_str() { + Some(s) => s, + None => panic!("Failed to get string of temporary directory"), + }; + let args = Args { + train: false, + eval: 
false, + mlflow: false, + }; + train(&args, 100, model_dir, 100)?; + eval(&(model_dir.to_owned() + "/best"), false)?; + Ok(()) + } +} diff --git a/border/examples/gym/model/.gitkeep b/border/examples/gym/model/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/border/examples/gym/pendulum_edge.rs b/border/examples/gym/pendulum_edge.rs new file mode 100644 index 00000000..e81b4f97 --- /dev/null +++ b/border/examples/gym/pendulum_edge.rs @@ -0,0 +1,190 @@ +use anyhow::Result; +use border_core::{DefaultEvaluator, Evaluator as _}; +use border_policy_no_backend::{Mat, Mlp}; +use border_py_gym_env::{ + ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, +}; +use clap::Parser; +use ndarray::ArrayD; +use std::fs; + +type PyObsDtype = f32; + +mod obs_act_types { + use super::*; + + #[derive(Clone, Debug)] + /// Observation type. + pub struct Obs(Mat); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(Mat::empty()) + } + + fn len(&self) -> usize { + self.0.shape()[0] as _ + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + let obs = obs.t().to_owned(); + let shape = obs.shape().iter().map(|e| *e as i32).collect(); + let data = obs.into_raw_vec(); + Self(Mat::new(data, shape)) + } + } + + impl From for Mat { + fn from(obs: Obs) -> Mat { + obs.0 + } + } + + #[derive(Clone, Debug)] + pub struct Act(Mat); + + impl border_core::Act for Act {} + + impl From for ArrayD { + fn from(value: Act) -> Self { + // let shape: Vec<_> = value.0.shape.iter().map(|e| *e as usize).collect(); + let shape = vec![(value.0.shape[0] * value.0.shape[1]) as usize]; + // let data = value.0.data; + let data: Vec = value.0.data.iter().map(|e| 2f32 * *e).collect(); + let t = ArrayD::from_shape_vec(shape, data).unwrap(); + t + } + } + + impl Into for Mat { + fn into(self) -> Act { + Act(self) + } + } +} + +mod policy { + use std::{io::Read, path::Path}; + + use super::*; + use border_core::Policy; + + pub struct MlpPolicy { + mlp: Mlp, + } + + impl Policy for MlpPolicy { + fn sample(&mut self, obs: &Obs) -> Act { + self.mlp.forward(&obs.clone().into()).into() + } + } + + impl MlpPolicy { + pub fn from_serialized_path(path: impl AsRef) -> Result { + let mut file = fs::OpenOptions::new().read(true).open(&path)?; + let mut buf = Vec::::new(); + let _ = file.read_to_end(&mut buf).unwrap(); + let mlp: Mlp = bincode::deserialize(&buf[..])?; + Ok(Self { mlp }) + } + } +} + +use obs_act_types::*; +use policy::*; + +type ObsFilter = ArrayObsFilter; +type ActFilter = ContinuousActFilter; +type Env = GymEnv; +type Evaluator = DefaultEvaluator; + +fn env_config() -> GymEnvConfig { + GymEnvConfig::::default() + .name("Pendulum-v1".to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) +} + +fn eval(n_episodes: usize, render: bool) -> Result<()> { + let env_config = { + let mut env_config = env_config(); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(10); + }; + env_config + }; + let mut policy = MlpPolicy::from_serialized_path( + "./border/examples/gym/model/edge/sac_pendulum/best/mlp.bincode", + )?; + + let _ = Evaluator::new(&env_config, 0, n_episodes)?.evaluate(&mut policy); + + Ok(()) +} + +/// Train/eval SAC agent in pendulum environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not 
train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + let _ = eval(5, true)?; + + // let args = Args::parse(); + + // if args.train { + // train( + // MAX_OPTS, + // "./border/examples/gym/model/tch/sac_pendulum", + // EVAL_INTERVAL, + // args.mlflow, + // )?; + // } else if args.eval { + // eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; + // } else { + // train( + // MAX_OPTS, + // "./border/examples/gym/model/tch/sac_pendulum", + // EVAL_INTERVAL, + // args.mlflow, + // )?; + // eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; + // } + + Ok(()) +} + +// #[cfg(test)] +// mod test { +// use super::*; +// use tempdir::TempDir; + +// #[test] +// fn test_sac_pendulum() -> Result<()> { +// tch::manual_seed(42); + +// let model_dir = TempDir::new("sac_pendulum_tch")?; +// let model_dir = model_dir.path().to_str().unwrap(); +// train(100, model_dir, 100, false)?; +// eval(1, false, (model_dir.to_string() + "/best").as_str())?; + +// Ok(()) +// } +// } diff --git a/border/examples/gym/sac_lunarlander_cont.rs b/border/examples/gym/sac_lunarlander_cont.rs new file mode 100644 index 00000000..fa128e9c --- /dev/null +++ b/border/examples/gym/sac_lunarlander_cont.rs @@ -0,0 +1,314 @@ +use anyhow::Result; +use border_candle_agent::{ + mlp::{Mlp, Mlp2, MlpConfig}, + opt::OptimizerConfig, + sac::{ActorConfig, CriticConfig, Sac, SacConfig}, + TensorBatch, +}; +use border_core::{ + generic_replay_buffer::{ + SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::{AggregateRecorder, Record}, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, +}; +use border_derive::BatchBase; +use border_py_gym_env::{ + util::{arrayd_to_tensor, tensor_to_arrayd}, + ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, +}; +use border_tensorboard::TensorboardRecorder; +use clap::Parser; +//use csv::WriterBuilder; +use border_mlflow_tracking::MlflowTrackingClient; +use candle_core::Tensor; +use ndarray::{ArrayD, IxDyn}; +use serde::Serialize; +use std::convert::TryFrom; + +const DIM_OBS: i64 = 8; +const DIM_ACT: i64 = 2; +const LR_ACTOR: f64 = 3e-4; +const LR_CRITIC: f64 = 3e-4; +const BATCH_SIZE: usize = 128; +const WARMUP_PERIOD: usize = 1000; +const OPT_INTERVAL: usize = 1; +const MAX_OPTS: usize = 200_000; +const EVAL_INTERVAL: usize = 10_000; +const REPLAY_BUFFER_CAPACITY: usize = 100_000; +const N_EPISODES_PER_EVAL: usize = 5; +const MODEL_DIR: &str = "./border/examples/gym/model/candle/sac_lunarlander_cont"; + +fn cuda_if_available() -> candle_core::Device { + candle_core::Device::cuda_if_available(0).unwrap() +} + +mod obs_act_types { + use super::*; + + #[derive(Clone, Debug)] + pub struct Obs(ArrayD); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(ArrayD::zeros(IxDyn(&[0]))) + } + + fn len(&self) -> usize { + self.0.shape()[0] + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + Obs(obs) + } + } + + impl From for Tensor { + fn from(obs: Obs) -> Tensor { + arrayd_to_tensor::<_, f32>(obs.0, false).unwrap() + } + } + + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + 
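// Convert the observation into a candle Tensor and wrap it in a TensorBatch.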
let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + #[derive(Clone, Debug)] + pub struct Act(ArrayD); + + impl border_core::Act for Act {} + + impl From for ArrayD { + fn from(value: Act) -> Self { + value.0 + } + } + + impl From for Act { + fn from(t: Tensor) -> Self { + Self(tensor_to_arrayd(t, true).unwrap()) + } + } + + // Required by Sac + impl From for Tensor { + fn from(value: Act) -> Self { + arrayd_to_tensor::<_, f32>(value.0, true).unwrap() + } + } + + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); + + impl From for ActBatch { + fn from(act: Act) -> Self { + let tensor = act.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + type PyObsDtype = f32; + pub type ObsFilter = ArrayObsFilter; + pub type ActFilter = ContinuousActFilter; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Evaluator = DefaultEvaluator>; +} + +use obs_act_types::*; + +mod config { + use serde::Serialize; + + use super::*; + + #[derive(Serialize)] + pub struct SacLunarLanderConfig { + pub trainer_config: TrainerConfig, + pub replay_buffer_config: SimpleReplayBufferConfig, + pub agent_config: SacConfig, + } + + pub fn env_config() -> GymEnvConfig { + GymEnvConfig::::default() + .name("LunarLanderContinuous-v2".to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) + } + + pub fn trainer_config(max_opts: usize, eval_interval: usize) -> TrainerConfig { + TrainerConfig::default() + .max_opts(max_opts) + .opt_interval(OPT_INTERVAL) + .eval_interval(eval_interval) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(WARMUP_PERIOD) + .model_dir(MODEL_DIR) + } + + pub fn agent_config(in_dim: i64, out_dim: i64) -> SacConfig { + let device = cuda_if_available(); + let actor_config = ActorConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) + .out_dim(out_dim) + .pi_config(MlpConfig::new(in_dim, vec![64, 64], out_dim, false)); + let critic_config = CriticConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) + .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, false)); + + SacConfig::default() + .batch_size(BATCH_SIZE) + .actor_config(actor_config) + .critic_config(critic_config) + .device(device) + } +} + +#[derive(Debug, Serialize)] +struct LunarlanderRecord { + episode: usize, + step: usize, + reward: f32, + obs: Vec, + act: Vec, +} + +impl TryFrom<&Record> for LunarlanderRecord { + type Error = anyhow::Error; + + fn try_from(record: &Record) -> Result { + Ok(Self { + episode: record.get_scalar("episode")? as _, + step: record.get_scalar("step")? 
as _, + reward: record.get_scalar("reward")?, + obs: record.get_array1("obs")?.to_vec(), + act: record.get_array1("act")?.to_vec(), + }) + } +} + +mod utils { + use super::*; + + pub fn create_recorder( + args: &Args, + config: &config::SacLunarLanderConfig, + ) -> Result> { + match args.mlflow { + true => { + let client = + MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", "lunarlander")?; + recorder_run.set_tag("algo", "sac")?; + recorder_run.set_tag("backend", "candle")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(MODEL_DIR))), + } + } +} + +/// Train/eval SAC agent in lunarlander environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn train(args: &Args, max_opts: usize) -> Result<()> { + let env_config = config::env_config(); + let trainer_config = config::trainer_config(max_opts, EVAL_INTERVAL); + let step_proc_config = SimpleStepProcessorConfig {}; + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let agent_config = config::agent_config(DIM_OBS, DIM_ACT); + let config = config::SacLunarLanderConfig { + agent_config: agent_config.clone(), + replay_buffer_config: replay_buffer_config.clone(), + trainer_config, + }; + let mut recorder = utils::create_recorder(&args, &config)?; + let mut trainer = Trainer::build(config.trainer_config.clone()); + + let env = Env::build(&env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut agent = Sac::build(config.agent_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut evaluator = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; + + Ok(()) +} + +fn eval(render: bool) -> Result<()> { + let model_dir = MODEL_DIR.to_owned() + "/best"; + let env_config = { + let mut env_config = config::env_config(); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(10); + } + env_config + }; + let mut agent = { + let mut agent = Sac::build(config::agent_config(DIM_OBS, DIM_ACT)); + agent.load_params(model_dir)?; + agent.eval(); + agent + }; + + let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); + + Ok(()) +} + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + + let args = Args::parse(); + + if args.eval { + eval(true)?; + } else if args.train { + train(&args, MAX_OPTS)?; + } else { + train(&args, MAX_OPTS)?; + eval(true)?; + } + + Ok(()) +} diff --git a/border/examples/gym/sac_lunarlander_cont_tch.rs b/border/examples/gym/sac_lunarlander_cont_tch.rs new file mode 100644 index 00000000..afbcef35 --- /dev/null +++ b/border/examples/gym/sac_lunarlander_cont_tch.rs @@ -0,0 +1,316 @@ +use anyhow::Result; +use border_core::{ + generic_replay_buffer::{ + SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::{AggregateRecorder, Record}, + Agent, Configurable, 
DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, +}; +use border_derive::BatchBase; +use border_py_gym_env::{ + util::{arrayd_to_tensor, tensor_to_arrayd}, + ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, +}; +use border_tch_agent::{ + mlp::{Mlp, Mlp2, MlpConfig}, + opt::OptimizerConfig, + sac::{ActorConfig, CriticConfig, Sac, SacConfig}, + TensorBatch, +}; +use border_tensorboard::TensorboardRecorder; +use clap::Parser; +//use csv::WriterBuilder; +use border_mlflow_tracking::MlflowTrackingClient; +use ndarray::{ArrayD, IxDyn}; +use serde::Serialize; +use std::convert::TryFrom; +use tch::Tensor; + +const DIM_OBS: i64 = 8; +const DIM_ACT: i64 = 2; +const LR_ACTOR: f64 = 3e-4; +const LR_CRITIC: f64 = 3e-4; +const BATCH_SIZE: usize = 128; +const WARMUP_PERIOD: usize = 1000; +const OPT_INTERVAL: usize = 1; +const MAX_OPTS: usize = 200_000; +const EVAL_INTERVAL: usize = 10_000; +const REPLAY_BUFFER_CAPACITY: usize = 100_000; +const N_EPISODES_PER_EVAL: usize = 5; +const MODEL_DIR: &str = "./border/examples/gym/model/tch/sac_lunarlander_cont"; + +fn cuda_if_available() -> tch::Device { + tch::Device::cuda_if_available() +} + +mod obs_act_types { + use super::*; + + type PyObsDtype = f32; + + #[derive(Clone, Debug)] + pub struct Obs(ArrayD); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(ArrayD::zeros(IxDyn(&[0]))) + } + + fn len(&self) -> usize { + self.0.shape()[0] + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + Obs(obs) + } + } + + impl From for Tensor { + fn from(obs: Obs) -> Tensor { + Tensor::try_from(&obs.0).unwrap() + } + } + + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + #[derive(Clone, Debug)] + pub struct Act(ArrayD); + + impl border_core::Act for Act {} + + impl From for ArrayD { + fn from(value: Act) -> Self { + value.0 + } + } + + impl From for Act { + fn from(t: Tensor) -> Self { + Self(tensor_to_arrayd(t, true)) + } + } + + // Required by Sac + impl From for Tensor { + fn from(value: Act) -> Self { + arrayd_to_tensor::<_, f32>(value.0, true) + } + } + + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); + + impl From for ActBatch { + fn from(act: Act) -> Self { + let tensor = act.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + pub type ObsFilter = ArrayObsFilter; + pub type ActFilter = ContinuousActFilter; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Evaluator = DefaultEvaluator>; +} + +use obs_act_types::*; + +mod config { + use serde::Serialize; + + use super::*; + + #[derive(Serialize)] + pub struct SacLunarLanderConfig { + pub trainer_config: TrainerConfig, + pub replay_buffer_config: SimpleReplayBufferConfig, + pub agent_config: SacConfig, + } + + pub fn env_config() -> GymEnvConfig { + GymEnvConfig::::default() + .name("LunarLanderContinuous-v2".to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) + } + + pub fn trainer_config(max_opts: usize, eval_interval: usize) -> TrainerConfig { + TrainerConfig::default() + .max_opts(max_opts) + .opt_interval(OPT_INTERVAL) + .eval_interval(eval_interval) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + 
.save_interval(EVAL_INTERVAL) + .warmup_period(WARMUP_PERIOD) + .model_dir(MODEL_DIR) + } + + pub fn agent_config(in_dim: i64, out_dim: i64) -> SacConfig { + let device = cuda_if_available(); + let actor_config = ActorConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) + .out_dim(out_dim) + .pi_config(MlpConfig::new(in_dim, vec![64, 64], out_dim, false)); + let critic_config = CriticConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) + .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, false)); + + SacConfig::default() + .batch_size(BATCH_SIZE) + .actor_config(actor_config) + .critic_config(critic_config) + .device(device) + } +} + +#[derive(Debug, Serialize)] +struct LunarlanderRecord { + episode: usize, + step: usize, + reward: f32, + obs: Vec, + act: Vec, +} + +impl TryFrom<&Record> for LunarlanderRecord { + type Error = anyhow::Error; + + fn try_from(record: &Record) -> Result { + Ok(Self { + episode: record.get_scalar("episode")? as _, + step: record.get_scalar("step")? as _, + reward: record.get_scalar("reward")?, + obs: record.get_array1("obs")?.to_vec(), + act: record.get_array1("act")?.to_vec(), + }) + } +} + +mod utils { + use super::*; + + pub fn create_recorder( + args: &Args, + config: &config::SacLunarLanderConfig, + ) -> Result> { + match args.mlflow { + true => { + let client = + MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", "lunarlander")?; + recorder_run.set_tag("algo", "sac")?; + recorder_run.set_tag("backend", "tch")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(MODEL_DIR))), + } + } +} + +/// Train/eval SAC agent in lunarlander environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn train(args: &Args, max_opts: usize) -> Result<()> { + let env_config = config::env_config(); + let trainer_config = config::trainer_config(max_opts, EVAL_INTERVAL); + let step_proc_config = SimpleStepProcessorConfig {}; + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let agent_config = config::agent_config(DIM_OBS, DIM_ACT); + let config = config::SacLunarLanderConfig { + agent_config: agent_config.clone(), + replay_buffer_config: replay_buffer_config.clone(), + trainer_config, + }; + let mut recorder = utils::create_recorder(&args, &config)?; + let mut trainer = Trainer::build(config.trainer_config.clone()); + + let env = Env::build(&env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut agent = Sac::build(config.agent_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut evaluator = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; + + Ok(()) +} + +fn eval(render: bool) -> Result<()> { + let model_dir = MODEL_DIR.to_owned() + "/best"; + let env_config = { + let mut env_config = config::env_config(); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(10); + } + 
env_config + }; + let mut agent = { + let mut agent = Sac::build(config::agent_config(DIM_OBS, DIM_ACT)); + agent.load_params(model_dir)?; + agent.eval(); + agent + }; + + let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); + + Ok(()) +} + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + tch::manual_seed(42); + + let args = Args::parse(); + + if args.eval { + eval(true)?; + } else if args.train { + train(&args, MAX_OPTS)?; + } else { + train(&args, MAX_OPTS)?; + eval(true)?; + } + + Ok(()) +} diff --git a/border/examples/gym/sac_pendulum.rs b/border/examples/gym/sac_pendulum.rs new file mode 100644 index 00000000..6921a18f --- /dev/null +++ b/border/examples/gym/sac_pendulum.rs @@ -0,0 +1,359 @@ +use anyhow::Result; +use border_candle_agent::{ + mlp::{Mlp, Mlp2, MlpConfig}, + opt::OptimizerConfig, + sac::{ActorConfig, CriticConfig, Sac, SacConfig}, + TensorBatch, +}; +use border_core::{ + generic_replay_buffer::{ + SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::{AggregateRecorder, Record, RecordValue}, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, +}; +use border_derive::BatchBase; +use border_py_gym_env::{ + util::{arrayd_to_pyobj, arrayd_to_tensor, tensor_to_arrayd}, + ArrayObsFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, +}; +use border_tensorboard::TensorboardRecorder; +use clap::Parser; +// use csv::WriterBuilder; +use border_mlflow_tracking::MlflowTrackingClient; +use candle_core::{Device, Tensor}; +use ndarray::{ArrayD, IxDyn}; +use pyo3::PyObject; +use serde::Serialize; +use std::convert::TryFrom; + +const DIM_OBS: i64 = 3; +const DIM_ACT: i64 = 1; +const LR_ACTOR: f64 = 3e-4; +const LR_CRITIC: f64 = 3e-4; +const BATCH_SIZE: usize = 128; +const WARMUP_PERIOD: usize = 1000; +const OPT_INTERVAL: usize = 1; +const MAX_OPTS: usize = 40_000; +const EVAL_INTERVAL: usize = 2_000; +const REPLAY_BUFFER_CAPACITY: usize = 100_000; +const N_EPISODES_PER_EVAL: usize = 5; + +type PyObsDtype = f32; + +mod obs { + use super::*; + + #[derive(Clone, Debug)] + pub struct Obs(ArrayD); + + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(ArrayD::zeros(IxDyn(&[0]))) + } + + fn len(&self) -> usize { + self.0.shape()[0] + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + Obs(obs) + } + } + + impl From for Tensor { + fn from(obs: Obs) -> Tensor { + arrayd_to_tensor::<_, f32>(obs.0, false).unwrap() + } + } + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } +} + +mod act { + use super::*; + + #[derive(Clone, Debug)] + pub struct Act(ArrayD); + + impl border_core::Act for Act {} + + impl From for ArrayD { + fn from(value: Act) -> Self { + value.0 + } + } + + impl From for Act { + fn from(t: Tensor) -> Self { + Self(tensor_to_arrayd(t, true).unwrap()) + } + } + + // Required by Sac + impl From for Tensor { + fn from(value: Act) -> Self { + arrayd_to_tensor::<_, f32>(value.0, true).unwrap() + } + } + + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); + + impl From for ActBatch { + fn from(act: Act) -> Self { + let tensor = act.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + // Custom activation filter + #[derive(Clone, Debug)] + pub struct ActFilter {} + + impl 
GymActFilter for ActFilter { + type Config = (); + + fn build(_config: &Self::Config) -> Result + where + Self: Sized, + { + Ok(Self {}) + } + + fn filt(&mut self, act: Act) -> (PyObject, Record) { + let act_filt = 2f32 * &act.0; + let record = Record::from_slice(&[ + ( + "act_org", + RecordValue::Array1(act.0.iter().cloned().collect()), + ), + ( + "act_filt", + RecordValue::Array1(act_filt.iter().cloned().collect()), + ), + ]); + (arrayd_to_pyobj(act_filt), record) + } + } +} + +use act::{Act, ActBatch, ActFilter}; +use obs::{Obs, ObsBatch}; + +type ObsFilter = ArrayObsFilter; +type Env = GymEnv; +type StepProc = SimpleStepProcessor; +type ReplayBuffer = SimpleReplayBuffer; +type Evaluator = DefaultEvaluator>; + +#[derive(Debug, Serialize)] +struct PendulumRecord { + episode: usize, + step: usize, + reward: f32, + obs: Vec, + act_org: Vec, + act_filt: Vec, +} + +impl TryFrom<&Record> for PendulumRecord { + type Error = anyhow::Error; + + fn try_from(record: &Record) -> Result { + Ok(Self { + episode: record.get_scalar("episode")? as _, + step: record.get_scalar("step")? as _, + reward: record.get_scalar("reward")?, + obs: record.get_array1("obs")?.to_vec(), + act_org: record.get_array1("act_org")?.to_vec(), + act_filt: record.get_array1("act_filt")?.to_vec(), + }) + } +} + +fn create_agent(in_dim: i64, out_dim: i64) -> Result> { + let device = Device::cuda_if_available(0)?; + let actor_config = ActorConfig::default() + .opt_config(OptimizerConfig::default().learning_rate(LR_ACTOR)) + .out_dim(out_dim) + .pi_config(MlpConfig::new(in_dim, vec![64, 64], out_dim, false)); + let critic_config = CriticConfig::default() + .opt_config(OptimizerConfig::default().learning_rate(LR_CRITIC)) + .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, false)); + let sac_config = SacConfig::default() + .batch_size(BATCH_SIZE) + .actor_config(actor_config) + .critic_config(critic_config) + .device(device); + Ok(Sac::build(sac_config)) +} + +fn env_config() -> GymEnvConfig { + GymEnvConfig::::default() + .name("Pendulum-v1".to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) +} + +fn create_recorder( + model_dir: &str, + mlflow: bool, + config: &TrainerConfig, +) -> Result> { + match mlflow { + true => { + let client = MlflowTrackingClient::new("http://localhost:8080") + //.basic_auth("user_name", "password") // when using basic authentication + .set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", "pendulum")?; + recorder_run.set_tag("algo", "sac")?; + recorder_run.set_tag("backend", "candle")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } +} + +/// Train/eval SAC agent in pendulum environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + +fn train(max_opts: usize, model_dir: &str, eval_interval: usize, mlflow: bool) -> Result<()> { + let env_config = env_config(); + let step_proc_config = SimpleStepProcessorConfig {}; + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let (mut trainer, config) = { + let config = TrainerConfig::default() 
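+                // Schedule note: with OPT_INTERVAL = 1 one optimization step runs per environment step once WARMUP_PERIOD transitions have been collected; evaluation runs every `eval_interval` steps, and recording/checkpointing every EVAL_INTERVAL steps.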
+ .max_opts(max_opts) + .opt_interval(OPT_INTERVAL) + .eval_interval(eval_interval) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(WARMUP_PERIOD) + .model_dir(model_dir); + let trainer = Trainer::build(config.clone()); + + (trainer, config) + }; + let env = Env::build(&env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut agent = create_agent(DIM_OBS, DIM_ACT)?; + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut recorder = create_recorder(model_dir, mlflow, &config)?; + let mut evaluator = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; + + Ok(()) +} + +fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { + let env_config = { + let mut env_config = env_config(); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(10); + }; + env_config + }; + let mut agent = { + let mut agent = create_agent(DIM_OBS, DIM_ACT)?; + agent.load_params(model_dir)?; + agent.eval(); + agent + }; + // let mut recorder = BufferedRecorder::new(); + + let _ = Evaluator::new(&env_config, 0, n_episodes)?.evaluate(&mut agent); + + // // Vec<_> field in a struct does not support writing a header in csv crate, so disable it. + // let mut wtr = WriterBuilder::new() + // .has_headers(false) + // .from_writer(File::create(model_dir.to_string() + "/eval.csv")?); + // for record in recorder.iter() { + // wtr.serialize(PendulumRecord::try_from(record)?)?; + // } + + Ok(()) +} + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + + let args = Args::parse(); + + if args.train { + train( + MAX_OPTS, + "./border/examples/gym/model/tch/sac_pendulum", + EVAL_INTERVAL, + args.mlflow, + )?; + } else if args.eval { + eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; + } else { + train( + MAX_OPTS, + "./border/examples/gym/model/tch/sac_pendulum", + EVAL_INTERVAL, + args.mlflow, + )?; + eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; + } + + Ok(()) +} + +#[cfg(test)] +mod test { + use super::*; + use tempdir::TempDir; + + #[test] + fn test_sac_pendulum() -> Result<()> { + let model_dir = TempDir::new("sac_pendulum")?; + let model_dir = model_dir.path().to_str().unwrap(); + train(100, model_dir, 100, false)?; + eval(1, false, (model_dir.to_string() + "/best").as_str())?; + + Ok(()) + } +} diff --git a/border/examples/sac_pendulum.rs b/border/examples/gym/sac_pendulum_tch.rs similarity index 66% rename from border/examples/sac_pendulum.rs rename to border/examples/gym/sac_pendulum_tch.rs index 9b73ec5f..de543af7 100644 --- a/border/examples/sac_pendulum.rs +++ b/border/examples/gym/sac_pendulum_tch.rs @@ -1,13 +1,14 @@ use anyhow::Result; use border_core::{ - record::{Record, RecordValue}, - replay_buffer::{ + generic_replay_buffer::{ SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, SimpleStepProcessorConfig, }, - Agent, DefaultEvaluator, Evaluator as _, Policy, Trainer, TrainerConfig, + record::{AggregateRecorder, Record, RecordValue}, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, }; -use border_derive::SubBatch; +use border_derive::BatchBase; use border_py_gym_env::{ 
util::{arrayd_to_pyobj, arrayd_to_tensor, tensor_to_arrayd}, ArrayObsFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, @@ -16,11 +17,12 @@ use border_tch_agent::{ mlp::{Mlp, Mlp2, MlpConfig}, opt::OptimizerConfig, sac::{ActorConfig, CriticConfig, Sac, SacConfig}, - TensorSubBatch, + TensorBatch, }; use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg}; +use clap::Parser; // use csv::WriterBuilder; +use border_mlflow_tracking::MlflowTrackingClient; use ndarray::{ArrayD, IxDyn}; use pyo3::PyObject; use serde::Serialize; @@ -32,7 +34,7 @@ const DIM_ACT: i64 = 1; const LR_ACTOR: f64 = 3e-4; const LR_CRITIC: f64 = 3e-4; const BATCH_SIZE: usize = 128; -const N_TRANSITIONS_WARMUP: usize = 1000; +const WARMUP_PERIOD: usize = 1000; const OPT_INTERVAL: usize = 1; const MAX_OPTS: usize = 40_000; const EVAL_INTERVAL: usize = 2_000; @@ -47,8 +49,8 @@ mod obs { #[derive(Clone, Debug)] pub struct Obs(ArrayD); - #[derive(Clone, SubBatch)] - pub struct ObsBatch(TensorSubBatch); + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); impl border_core::Obs for Obs { fn dummy(_n: usize) -> Self { @@ -75,7 +77,7 @@ mod obs { impl From for ObsBatch { fn from(obs: Obs) -> Self { let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) + Self(TensorBatch::from_tensor(tensor)) } } } @@ -107,13 +109,13 @@ mod act { } } - #[derive(SubBatch)] - pub struct ActBatch(TensorSubBatch); + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); impl From for ActBatch { fn from(act: Act) -> Self { let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) + Self(TensorBatch::from_tensor(tensor)) } } @@ -187,13 +189,12 @@ fn create_agent(in_dim: i64, out_dim: i64) -> Sac let actor_config = ActorConfig::default() .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) .out_dim(out_dim) - .pi_config(MlpConfig::new(in_dim, vec![64, 64], out_dim, true)); + .pi_config(MlpConfig::new(in_dim, vec![64, 64], out_dim, false)); let critic_config = CriticConfig::default() .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) - .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, true)); + .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, false)); let sac_config = SacConfig::default() .batch_size(BATCH_SIZE) - .min_transitions_warmup(N_TRANSITIONS_WARMUP) .actor_config(actor_config) .critic_config(critic_config) .device(device); @@ -207,33 +208,60 @@ fn env_config() -> GymEnvConfig { .act_filter_config(ActFilter::default_config()) } -fn train(max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { - let mut trainer = { - let env_config = env_config(); - let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = - SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); +fn create_recorder( + model_dir: &str, + mlflow: bool, + config: &TrainerConfig, +) -> Result> { + match mlflow { + true => { + let client = + MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", "pendulum")?; + recorder_run.set_tag("algo", "sac")?; + recorder_run.set_tag("backend", "tch")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } +} + +fn train(max_opts: usize, model_dir: &str, eval_interval: usize, mlflow: bool) -> Result<()> { + let env_config = env_config(); + let step_proc_config = SimpleStepProcessorConfig {}; + let replay_buffer_config = 
SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let (mut trainer, config) = { let config = TrainerConfig::default() .max_opts(max_opts) .opt_interval(OPT_INTERVAL) .eval_interval(eval_interval) - .record_interval(eval_interval) - .save_interval(eval_interval) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(WARMUP_PERIOD) .model_dir(model_dir); - let trainer = Trainer::::build( - config, - env_config, - step_proc_config, - replay_buffer_config, - ); - - trainer + let trainer = Trainer::build(config.clone()); + + (trainer, config) }; + let env = Env::build(&env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); let mut agent = create_agent(DIM_OBS, DIM_ACT); - let mut recorder = TensorboardRecorder::new(model_dir); - let mut evaluator = Evaluator::new(&env_config(), 0, N_EPISODES_PER_EVAL)?; - - trainer.train(&mut agent, &mut recorder, &mut evaluator)?; + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut recorder = create_recorder(model_dir, mlflow, &config)?; + let mut evaluator = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; Ok(()) } @@ -250,7 +278,7 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { }; let mut agent = { let mut agent = create_agent(DIM_OBS, DIM_ACT); - agent.load(model_dir)?; + agent.load_params(model_dir)?; agent.eval(); agent }; @@ -269,41 +297,46 @@ fn eval(n_episodes: usize, render: bool, model_dir: &str) -> Result<()> { Ok(()) } +/// Train/eval SAC agent in pendulum environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Train SAC agent, not evaluate + #[arg(short, long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(short, long, default_value_t = false)] + eval: bool, + + /// Log metrics with MLflow + #[arg(short, long, default_value_t = false)] + mlflow: bool, +} + fn main() -> Result<()> { env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); tch::manual_seed(42); - let matches = App::new("sac_pendulum") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .get_matches(); - - let do_train = (matches.is_present("train") && !matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); - let do_eval = (!matches.is_present("train") && matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); - - if do_train { + let args = Args::parse(); + + if args.train { train( MAX_OPTS, - "./border/examples/model/sac_pendulum", + "./border/examples/gym/model/tch/sac_pendulum", EVAL_INTERVAL, + args.mlflow, )?; - } - if do_eval { - eval(5, true, "./border/examples/model/sac_pendulum/best")?; + } else if args.eval { + eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; + } else { + train( + MAX_OPTS, + "./border/examples/gym/model/tch/sac_pendulum", + EVAL_INTERVAL, + args.mlflow, + )?; + eval(5, true, "./border/examples/gym/model/tch/sac_pendulum/best")?; } Ok(()) @@ -318,9 +351,9 @@ mod test { fn test_sac_pendulum() -> Result<()> { 
tch::manual_seed(42); - let model_dir = TempDir::new("sac_pendulum")?; + let model_dir = TempDir::new("sac_pendulum_tch")?; let model_dir = model_dir.path().to_str().unwrap(); - train(100, model_dir, 100)?; + train(100, model_dir, 100, false)?; eval(1, false, (model_dir.to_string() + "/best").as_str())?; Ok(()) diff --git a/border/examples/iqn_cartpole.rs b/border/examples/iqn_cartpole.rs deleted file mode 100644 index 1f20fda0..00000000 --- a/border/examples/iqn_cartpole.rs +++ /dev/null @@ -1,334 +0,0 @@ -use anyhow::Result; -use border_core::{ - record::Record, - replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, SubBatch, - }, - Agent, DefaultEvaluator, Evaluator as _, Policy, Trainer, TrainerConfig, -}; -use border_py_gym_env::{ - util::vec_to_tensor, ArrayObsFilter, DiscreteActFilter, GymActFilter, GymEnv, GymEnvConfig, - GymObsFilter, -}; -use border_tch_agent::{ - iqn::{EpsilonGreedy, Iqn as Iqn_, IqnConfig, IqnModelConfig}, - mlp::{Mlp, MlpConfig}, - TensorSubBatch, -}; -use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg}; -// use csv::WriterBuilder; -use ndarray::{ArrayD, IxDyn}; -use serde::Serialize; -use std::convert::TryFrom; -use tch::Tensor; - -const DIM_OBS: i64 = 4; -const DIM_ACT: i64 = 2; -const LR_CRITIC: f64 = 0.001; -const DIM_FEATURE: i64 = 256; -const DIM_EMBED: i64 = 64; -const DISCOUNT_FACTOR: f64 = 0.99; -const BATCH_SIZE: usize = 64; -const N_TRANSITIONS_WARMUP: usize = 100; -const N_UPDATES_PER_OPT: usize = 1; -const TAU: f64 = 0.1; //0.005; -const SOFT_UPDATE_INTERVAL: usize = 100; -const OPT_INTERVAL: usize = 50; -const MAX_OPTS: usize = 10000; -const EVAL_INTERVAL: usize = 500; -const REPLAY_BUFFER_CAPACITY: usize = 10000; -const N_EPISODES_PER_EVAL: usize = 5; -const EPS_START: f64 = 1.0; -const EPS_FINAL: f64 = 0.1; -const FINAL_STEP: usize = MAX_OPTS; -const MODEL_DIR: &str = "border/examples/model/iqn_cartpole"; - -type PyObsDtype = f32; - -mod obs { - use super::*; - - #[derive(Clone, Debug)] - pub struct Obs(ArrayD); - - impl border_core::Obs for Obs { - fn dummy(_n: usize) -> Self { - Self(ArrayD::zeros(IxDyn(&[0]))) - } - - fn len(&self) -> usize { - self.0.shape()[0] - } - } - - impl From> for Obs { - fn from(obs: ArrayD) -> Self { - Obs(obs) - } - } - - impl From for Tensor { - fn from(obs: Obs) -> Tensor { - Tensor::try_from(&obs.0).unwrap() - } - } - - pub struct ObsBatch(TensorSubBatch); - - impl SubBatch for ObsBatch { - fn new(capacity: usize) -> Self { - Self(TensorSubBatch::new(capacity)) - } - - fn push(&mut self, i: usize, data: &Self) { - self.0.push(i, &data.0) - } - - fn sample(&self, ixs: &Vec) -> Self { - let buf = self.0.sample(ixs); - Self(buf) - } - } - - impl From for ObsBatch { - fn from(obs: Obs) -> Self { - let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } - } - - impl From for Tensor { - fn from(b: ObsBatch) -> Self { - b.0.into() - } - } -} - -mod act { - use super::*; - - #[derive(Clone, Debug)] - pub struct Act(Vec); - - impl border_core::Act for Act {} - - impl From for Vec { - fn from(value: Act) -> Self { - value.0 - } - } - - impl From for Act { - // `t` must be a 1-dimentional tensor of `f32` - fn from(t: Tensor) -> Self { - let data: Vec = t.into(); - let data = data.iter().map(|&e| e as i32).collect(); - Act(data) - } - } - - pub struct ActBatch(TensorSubBatch); - - impl SubBatch for ActBatch { - fn new(capacity: usize) -> Self { - Self(TensorSubBatch::new(capacity)) - } - - fn push(&mut self, i: usize, data: &Self) 
{ - self.0.push(i, &data.0) - } - - fn sample(&self, ixs: &Vec) -> Self { - let buf = self.0.sample(ixs); - Self(buf) - } - } - - impl From for ActBatch { - fn from(act: Act) -> Self { - let t = vec_to_tensor::<_, i64>(act.0, true); - Self(TensorSubBatch::from_tensor(t)) - } - } - - // Required by Dqn - impl From for Tensor { - fn from(act: ActBatch) -> Self { - act.0.into() - } - } -} - -use act::{Act, ActBatch}; -use obs::{Obs, ObsBatch}; - -type ObsFilter = ArrayObsFilter; -type ActFilter = DiscreteActFilter; -type EnvConfig = GymEnvConfig; -type Env = GymEnv; -type StepProc = SimpleStepProcessor; -type ReplayBuffer = SimpleReplayBuffer; -type Iqn = Iqn_; -type Evaluator = DefaultEvaluator; - -#[derive(Debug, Serialize)] -struct CartpoleRecord { - episode: usize, - step: usize, - reward: f32, - obs: Vec, -} - -impl TryFrom<&Record> for CartpoleRecord { - type Error = anyhow::Error; - - fn try_from(record: &Record) -> Result { - Ok(Self { - episode: record.get_scalar("episode")? as _, - step: record.get_scalar("step")? as _, - reward: record.get_scalar("reward")?, - obs: record - .get_array1("obs")? - .iter() - .map(|v| *v as f64) - .collect(), - }) - } -} - -fn env_config() -> EnvConfig { - GymEnvConfig::::default() - .name("CartPole-v0".to_string()) - .obs_filter_config(ObsFilter::default_config()) - .act_filter_config(ActFilter::default_config()) -} - -fn create_evaluator(env_config: &EnvConfig) -> Result { - Evaluator::new(env_config, 0, N_EPISODES_PER_EVAL) -} - -fn create_agent(in_dim: i64, out_dim: i64) -> Iqn { - let device = tch::Device::cuda_if_available(); - let config = { - let opt_config = border_tch_agent::opt::OptimizerConfig::Adam { lr: LR_CRITIC }; - let f_config = MlpConfig::new(in_dim, vec![], DIM_FEATURE, true); - let m_config = MlpConfig::new(DIM_FEATURE, vec![], out_dim, false); - let model_config = IqnModelConfig::default() - .feature_dim(DIM_FEATURE) - .embed_dim(DIM_EMBED) - .opt_config(opt_config) - .f_config(f_config) - .m_config(m_config); - - IqnConfig::default() - .n_updates_per_opt(N_UPDATES_PER_OPT) - .min_transitions_warmup(N_TRANSITIONS_WARMUP) - .batch_size(BATCH_SIZE) - .discount_factor(DISCOUNT_FACTOR) - .tau(TAU) - .model_config(model_config) - .explorer(EpsilonGreedy::with_params(EPS_START, EPS_FINAL, FINAL_STEP)) - .soft_update_interval(SOFT_UPDATE_INTERVAL) - .device(device) - }; - - Iqn::build(config) -} - -fn train(max_opts: usize, model_dir: &str, eval_interval: usize) -> Result<()> { - let mut trainer = { - let env_config = env_config(); - let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = - SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let config = TrainerConfig::default() - .max_opts(max_opts) - .opt_interval(OPT_INTERVAL) - .eval_interval(eval_interval) - .record_interval(eval_interval) - .save_interval(eval_interval) - .model_dir(model_dir); - let trainer = Trainer::::build( - config, - env_config, - step_proc_config, - replay_buffer_config, - ); - - trainer - }; - let mut agent = create_agent(DIM_OBS, DIM_ACT); - let mut recorder = TensorboardRecorder::new(model_dir); - let mut evaluator = create_evaluator(&env_config())?; - - trainer.train(&mut agent, &mut recorder, &mut evaluator)?; - - Ok(()) -} - -fn eval(model_dir: &str, render: bool) -> Result<()> { - let env_config = { - let mut env_config = env_config(); - if render { - env_config = env_config - .render_mode(Some("human".to_string())) - .set_wait_in_millis(10); - } - env_config - }; - let mut agent = { - let mut agent = 
create_agent(DIM_OBS, DIM_ACT); - agent.load(model_dir)?; - agent.eval(); - agent - }; - // let mut recorder = BufferedRecorder::new(); - - let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); - - Ok(()) -} - -fn main() -> Result<()> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - tch::manual_seed(42); - - let matches = App::new("dqn_cartpole") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("skip training") - .long("skip-training") - .takes_value(false) - .help("Skip training"), - ) - .get_matches(); - - if !matches.is_present("skip training") { - train(MAX_OPTS, MODEL_DIR, EVAL_INTERVAL)?; - } - - eval(&(MODEL_DIR.to_owned() + "/best"), true)?; - - Ok(()) -} - -#[cfg(test)] -mod test { - use super::*; - use tempdir::TempDir; - - #[test] - fn test_iqn_cartpole() -> Result<()> { - tch::manual_seed(42); - - let model_dir = TempDir::new("sac_pendulum")?; - let model_dir = model_dir.path().to_str().unwrap(); - train(100, model_dir, 100)?; - eval((model_dir.to_string() + "/best").as_str(), false)?; - - Ok(()) - } -} diff --git a/border/examples/model/dqn_HeroNoFrameskip-v4/agent.yaml b/border/examples/model/dqn_HeroNoFrameskip-v4/agent.yaml deleted file mode 100644 index 9e9cf22d..00000000 --- a/border/examples/model/dqn_HeroNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,26 +0,0 @@ ---- -model_config: - q_config: - n_stack: 4 - out_dim: 0 - opt_config: - Adam: - lr: 0.0001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -clip_reward: 1.0 -double_dqn: false -clip_td_err: ~ -device: ~ -phantom: ~ diff --git a/border/examples/model/dqn_HeroNoFrameskip-v4/model.yaml b/border/examples/model/dqn_HeroNoFrameskip-v4/model.yaml deleted file mode 100644 index b5f3800b..00000000 --- a/border/examples/model/dqn_HeroNoFrameskip-v4/model.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -q_config: - n_stack: 4 - out_dim: 0 -opt_config: - Adam: - lr: 0.0001 -phantom: ~ diff --git a/border/examples/model/dqn_HeroNoFrameskip-v4/replay_buffer.yaml b/border/examples/model/dqn_HeroNoFrameskip-v4/replay_buffer.yaml deleted file mode 100644 index 079309ba..00000000 --- a/border/examples/model/dqn_HeroNoFrameskip-v4/replay_buffer.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -capacity: 1048576 -seed: 42 -per_config: ~ diff --git a/border/examples/model/dqn_HeroNoFrameskip-v4/trainer.yaml b/border/examples/model/dqn_HeroNoFrameskip-v4/trainer.yaml deleted file mode 100644 index c2c148c7..00000000 --- a/border/examples/model/dqn_HeroNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 50000000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: "./border/examples/model/dqn_HeroNoFrameskip-v4" -opt_interval: 1 -eval_interval: 500000 -record_interval: 50000 -save_interval: 10000000 diff --git a/border/examples/model/dqn_PongNoFrameskip-v4/agent.yaml b/border/examples/model/dqn_PongNoFrameskip-v4/agent.yaml deleted file mode 100644 index 9e9cf22d..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,26 +0,0 @@ ---- -model_config: - q_config: - n_stack: 4 - out_dim: 0 - opt_config: - Adam: - lr: 0.0001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - 
eps_final: 0.02 - final_step: 1000000 -clip_reward: 1.0 -double_dqn: false -clip_td_err: ~ -device: ~ -phantom: ~ diff --git a/border/examples/model/dqn_PongNoFrameskip-v4/model.yaml b/border/examples/model/dqn_PongNoFrameskip-v4/model.yaml deleted file mode 100644 index b5f3800b..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4/model.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -q_config: - n_stack: 4 - out_dim: 0 -opt_config: - Adam: - lr: 0.0001 -phantom: ~ diff --git a/border/examples/model/dqn_PongNoFrameskip-v4/replay_buffer.yaml b/border/examples/model/dqn_PongNoFrameskip-v4/replay_buffer.yaml deleted file mode 100644 index 1e0ce1e7..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4/replay_buffer.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -capacity: 65536 -seed: 42 -per_config: ~ diff --git a/border/examples/model/dqn_PongNoFrameskip-v4/trainer.yaml b/border/examples/model/dqn_PongNoFrameskip-v4/trainer.yaml deleted file mode 100644 index aa9b153c..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 3000000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: "./border/examples/model/dqn_PongNoFrameskip-v4" -opt_interval: 1 -eval_interval: 50000 -record_interval: 50000 -save_interval: 500000 diff --git a/border/examples/model/dqn_PongNoFrameskip-v4_debug/agent.yaml b/border/examples/model/dqn_PongNoFrameskip-v4_debug/agent.yaml deleted file mode 100644 index 9e9cf22d..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4_debug/agent.yaml +++ /dev/null @@ -1,26 +0,0 @@ ---- -model_config: - q_config: - n_stack: 4 - out_dim: 0 - opt_config: - Adam: - lr: 0.0001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -clip_reward: 1.0 -double_dqn: false -clip_td_err: ~ -device: ~ -phantom: ~ diff --git a/border/examples/model/dqn_PongNoFrameskip-v4_debug/replay_buffer.yaml b/border/examples/model/dqn_PongNoFrameskip-v4_debug/replay_buffer.yaml deleted file mode 100644 index 1e0ce1e7..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4_debug/replay_buffer.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -capacity: 65536 -seed: 42 -per_config: ~ diff --git a/border/examples/model/dqn_PongNoFrameskip-v4_debug/trainer.yaml b/border/examples/model/dqn_PongNoFrameskip-v4_debug/trainer.yaml deleted file mode 100644 index 9fc0f649..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4_debug/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 1000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: "./border/examples/model/dqn_PongNoFrameskip-v4_debug" -opt_interval: 100 -eval_interval: 50000 -record_interval: 100 -save_interval: 500000 diff --git a/border/examples/model/dqn_PongNoFrameskip-v4_per/agent.yaml b/border/examples/model/dqn_PongNoFrameskip-v4_per/agent.yaml deleted file mode 100644 index 40b79f77..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4_per/agent.yaml +++ /dev/null @@ -1,24 +0,0 @@ ---- -model_config: - q_config: - n_stack: 4 - out_dim: 0 - opt_config: - Adam: - lr: 0.0001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -clip_reward: 1.0 -double_dqn: false -phantom: ~ diff --git 
a/border/examples/model/dqn_PongNoFrameskip-v4_per/replay_buffer.yaml b/border/examples/model/dqn_PongNoFrameskip-v4_per/replay_buffer.yaml deleted file mode 100644 index 942608b4..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4_per/replay_buffer.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -capacity: 65536 -seed: 42 -per_config: - alpha: 0.6000000238418579 - beta_0: 0.4000000059604645 - beta_final: 1.0 - n_opts_final: 500000 diff --git a/border/examples/model/dqn_PongNoFrameskip-v4_per/trainer.yaml b/border/examples/model/dqn_PongNoFrameskip-v4_per/trainer.yaml deleted file mode 100644 index fe02da59..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4_per/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 3000000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: "./border/examples/model/dqn_PongNoFrameskip-v4_per" -opt_interval: 1 -eval_interval: 50000 -record_interval: 50000 -save_interval: 500000 diff --git a/border/examples/model/dqn_PongNoFrameskip-v4_vec/agent.yaml b/border/examples/model/dqn_PongNoFrameskip-v4_vec/agent.yaml deleted file mode 100644 index db95a127..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4_vec/agent.yaml +++ /dev/null @@ -1,19 +0,0 @@ ---- -opt_interval_counter: - opt_interval: - Steps: 1 - count: 0 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -replay_burffer_capacity: 50000 diff --git a/border/examples/model/dqn_PongNoFrameskip-v4_vec/model.yaml b/border/examples/model/dqn_PongNoFrameskip-v4_vec/model.yaml deleted file mode 100644 index b5f3800b..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4_vec/model.yaml +++ /dev/null @@ -1,8 +0,0 @@ ---- -q_config: - n_stack: 4 - out_dim: 0 -opt_config: - Adam: - lr: 0.0001 -phantom: ~ diff --git a/border/examples/model/dqn_PongNoFrameskip-v4_vec/trainer.yaml b/border/examples/model/dqn_PongNoFrameskip-v4_vec/trainer.yaml deleted file mode 100644 index a8e92db7..00000000 --- a/border/examples/model/dqn_PongNoFrameskip-v4_vec/trainer.yaml +++ /dev/null @@ -1,6 +0,0 @@ ---- -max_opts: 3000000 -eval_interval: 10000 -n_episodes_per_eval: 1 -eval_threshold: ~ -model_dir: "./examples/model/dqn_PongNoFrameskip-v4_vec" diff --git a/border/examples/model/dqn_SeaquestNoFrameskip-v4/agent.yaml b/border/examples/model/dqn_SeaquestNoFrameskip-v4/agent.yaml deleted file mode 100644 index 55f99184..00000000 --- a/border/examples/model/dqn_SeaquestNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,19 +0,0 @@ ---- -opt_interval_counter: - opt_interval: - Steps: 1 - count: 0 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -train: false -discount_factor: 0.99 -tau: 1.0 -replay_burffer_capacity: 1000000 -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 diff --git a/border/examples/model/dqn_SeaquestNoFrameskip-v4/trainer.yaml b/border/examples/model/dqn_SeaquestNoFrameskip-v4/trainer.yaml deleted file mode 100644 index 9ff6f869..00000000 --- a/border/examples/model/dqn_SeaquestNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,6 +0,0 @@ ---- -max_opts: 50000000 -eval_interval: 10000 -n_episodes_per_eval: 1 -eval_threshold: ~ -model_dir: "./examples/model/dqn_SeaquestNoFrameskip-v4" diff --git a/border/examples/model/iqn_PongNoFrameskip-v4/agent.yaml b/border/examples/model/iqn_PongNoFrameskip-v4/agent.yaml deleted 
file mode 100644 index 01f6f654..00000000 --- a/border/examples/model/iqn_PongNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,22 +0,0 @@ ---- -opt_interval_counter: - opt_interval: - Steps: 1 - count: 0 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -sample_percents_pred: Uniform8 -sample_percents_tgt: Uniform8 -sample_percents_act: Uniform32 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -replay_buffer_capacity: 50000 diff --git a/border/examples/model/iqn_PongNoFrameskip-v4/model.yaml b/border/examples/model/iqn_PongNoFrameskip-v4/model.yaml deleted file mode 100644 index c7258234..00000000 --- a/border/examples/model/iqn_PongNoFrameskip-v4/model.yaml +++ /dev/null @@ -1,15 +0,0 @@ ---- -feature_dim: 3136 -embed_dim: 64 -learning_rate: 0.0001 -f_config: - n_stack: 4 - feature_dim: 3136 -m_config: - in_dim: 3136 - hidden_dim: 512 - out_dim: 0 -phantom: ~ -opt_config: - Adam: - lr: 0.001 diff --git a/border/examples/model/iqn_PongNoFrameskip-v4/trainer.yaml b/border/examples/model/iqn_PongNoFrameskip-v4/trainer.yaml deleted file mode 100644 index c23fa53c..00000000 --- a/border/examples/model/iqn_PongNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,6 +0,0 @@ ---- -max_opts: 5000000 -eval_interval: 10000 -n_episodes_per_eval: 1 -eval_threshold: ~ -model_dir: "./examples/model/iqn_PongNoFrameskip-v4" diff --git a/border/examples/model/iqn_SeaquestNoFrameskip-v4/agent.yaml b/border/examples/model/iqn_SeaquestNoFrameskip-v4/agent.yaml deleted file mode 100644 index b400c5dd..00000000 --- a/border/examples/model/iqn_SeaquestNoFrameskip-v4/agent.yaml +++ /dev/null @@ -1,34 +0,0 @@ ---- -model_config: - feature_dim: 3136 - embed_dim: 64 - f_config: - n_stack: 4 - out_dim: 3136 - skip_linear: true - m_config: - in_dim: 3136 - units: - - 512 - out_dim: 0 - opt_config: - Adam: - lr: 0.00001 -soft_update_interval: 10000 -n_updates_per_opt: 1 -min_transitions_warmup: 2500 -batch_size: 32 -discount_factor: 0.99 -tau: 1.0 -train: false -explorer: - EpsilonGreedy: - n_opts: 0 - eps_start: 1.0 - eps_final: 0.02 - final_step: 1000000 -sample_percents_pred: Uniform64 -sample_percents_tgt: Uniform64 -sample_percents_act: Uniform32 -device: ~ -phantom: ~ diff --git a/border/examples/model/iqn_SeaquestNoFrameskip-v4/replay_buffer.yaml b/border/examples/model/iqn_SeaquestNoFrameskip-v4/replay_buffer.yaml deleted file mode 100644 index 079309ba..00000000 --- a/border/examples/model/iqn_SeaquestNoFrameskip-v4/replay_buffer.yaml +++ /dev/null @@ -1,4 +0,0 @@ ---- -capacity: 1048576 -seed: 42 -per_config: ~ diff --git a/border/examples/model/iqn_SeaquestNoFrameskip-v4/trainer.yaml b/border/examples/model/iqn_SeaquestNoFrameskip-v4/trainer.yaml deleted file mode 100644 index ebc92b01..00000000 --- a/border/examples/model/iqn_SeaquestNoFrameskip-v4/trainer.yaml +++ /dev/null @@ -1,9 +0,0 @@ ---- -max_opts: 50000000 -eval_episodes: 1 -eval_threshold: ~ -model_dir: "./border/examples/model/iqn_SeaquestNoFrameskip-v4" -opt_interval: 1 -eval_interval: 500000 -record_interval: 500000 -save_interval: 10000000 diff --git a/border/examples/model/.gitignore b/border/examples/mujoco/.gitignore similarity index 87% rename from border/examples/model/.gitignore rename to border/examples/mujoco/.gitignore index f447ca1c..c3460a02 100644 --- a/border/examples/model/.gitignore +++ b/border/examples/mujoco/.gitignore @@ -6,4 +6,4 @@ events* *.csv *.zip *.gz -backup +**/best diff --git 
a/border/examples/mujoco/README.md b/border/examples/mujoco/README.md new file mode 100644 index 00000000..f9af8898 --- /dev/null +++ b/border/examples/mujoco/README.md @@ -0,0 +1,17 @@ +# Mujoco environment + +This directory contains examples using Mujoco environments. + +## tch agent + +```bash +cargo run --release --example sac_mujoco_tch --features=tch -- --env ant --mlflow +``` + +`env` option can be `ant`, `cheetah`, `walker`, or `hopper`. + +## candle agent + +```bash +cargo run --release --example sac_mujoco --features=candle-core,cuda,cudnn -- --env ant --mlflow +``` diff --git a/border/examples/mujoco/model/candle/.gitkeep b/border/examples/mujoco/model/candle/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/border/examples/mujoco/model/tch/.gitkeep b/border/examples/mujoco/model/tch/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/border/examples/mujoco/sac_mujoco.rs b/border/examples/mujoco/sac_mujoco.rs new file mode 100644 index 00000000..50dff870 --- /dev/null +++ b/border/examples/mujoco/sac_mujoco.rs @@ -0,0 +1,355 @@ +use anyhow::Result; +// use border::util::get_model_from_url; +use border_candle_agent::{ + mlp::{Mlp, Mlp2, MlpConfig}, + opt::OptimizerConfig, + sac::{ActorConfig, CriticConfig, EntCoefMode, Sac, SacConfig}, + util::CriticLoss, + TensorBatch, +}; +use border_core::{ + generic_replay_buffer::{ + SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::AggregateRecorder, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, +}; +use border_derive::BatchBase; +use border_mlflow_tracking::MlflowTrackingClient; +use border_py_gym_env::{ + util::{arrayd_to_tensor, tensor_to_arrayd}, + ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, +}; +use border_tensorboard::TensorboardRecorder; +use candle_core::Tensor; +use clap::Parser; +// use log::info; +use ndarray::{ArrayD, IxDyn}; + +const LR_ACTOR: f64 = 3e-4; +const LR_CRITIC: f64 = 3e-4; +const BATCH_SIZE: usize = 256; +const WARMUP_PERIOD: usize = 10_000; +const OPT_INTERVAL: usize = 1; +const MAX_OPTS: usize = 3_000_000; +const EVAL_INTERVAL: usize = 5_000; +const REPLAY_BUFFER_CAPACITY: usize = 300_000; +const N_EPISODES_PER_EVAL: usize = 5; +const N_CRITICS: usize = 2; +const TAU: f64 = 0.02; +const LR_ENT_COEF: f64 = 3e-4; +const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; +const MODEL_DIR_BASE: &str = "./border/examples/mujoco/model/candle"; + +fn cuda_if_available() -> candle_core::Device { + candle_core::Device::cuda_if_available(0).unwrap() +} + +mod obs_act_types { + use super::*; + + #[derive(Clone, Debug)] + pub struct Obs(ArrayD); + + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(ArrayD::zeros(IxDyn(&[0]))) + } + + fn len(&self) -> usize { + self.0.shape()[0] + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + Obs(obs) + } + } + + impl From for Tensor { + fn from(obs: Obs) -> Tensor { + arrayd_to_tensor::<_, f32>(obs.0, false).unwrap() + } + } + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + #[derive(Clone, Debug)] + pub struct Act(ArrayD); + + impl border_core::Act for Act {} + + impl From for ArrayD { + fn from(value: Act) -> Self { + value.0 + } + } + + impl From for Act { + fn from(t: Tensor) -> Self 
{ + Self(tensor_to_arrayd(t, true).unwrap()) + } + } + + // Required by Sac + impl From for Tensor { + fn from(value: Act) -> Self { + arrayd_to_tensor::<_, f32>(value.0, true).unwrap() + } + } + + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); + + impl From for ActBatch { + fn from(act: Act) -> Self { + let tensor = act.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + type PyObsDtype = f32; + pub type ObsFilter = ArrayObsFilter; + pub type ActFilter = ContinuousActFilter; + pub type EnvConfig = GymEnvConfig; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Evaluator = DefaultEvaluator>; +} + +use obs_act_types::*; + +mod config { + use serde::Serialize; + + use super::*; + + #[derive(Serialize)] + pub struct SacAntConfig { + pub trainer: TrainerConfig, + pub replay_buffer: SimpleReplayBufferConfig, + pub agent: SacConfig, + } + + pub fn env_config(env_name: &str) -> EnvConfig { + GymEnvConfig::::default() + .name(env_name.to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) + } + + pub fn create_trainer_config(model_dir: &str) -> TrainerConfig { + TrainerConfig::default() + .max_opts(MAX_OPTS) + .opt_interval(OPT_INTERVAL) + .eval_interval(EVAL_INTERVAL) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(WARMUP_PERIOD) + .model_dir(model_dir) + } + + pub fn create_sac_config(dim_obs: i64, dim_act: i64, target_ent: f64) -> SacConfig { + let device = cuda_if_available(); + let actor_config = ActorConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) + .out_dim(dim_act) + .pi_config(MlpConfig::new(dim_obs, vec![400, 300], dim_act, false)); + let critic_config = CriticConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) + .q_config(MlpConfig::new(dim_obs + dim_act, vec![400, 300], 1, false)); + + SacConfig::default() + .batch_size(BATCH_SIZE) + .actor_config(actor_config) + .critic_config(critic_config) + .tau(TAU) + .critic_loss(CRITIC_LOSS) + .n_critics(N_CRITICS) + .ent_coef_mode(EntCoefMode::Auto(target_ent, LR_ENT_COEF)) + .device(device) + } +} + +mod utils { + use super::*; + + pub fn create_recorder( + args: &Args, + config: &config::SacAntConfig, + ) -> Result> { + let env_name = &args.env; + let (_, _, _, _, model_dir) = env_params(&args); + match args.mlflow { + true => { + let client = + MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", env_name)?; + recorder_run.set_tag("algo", "sac")?; + recorder_run.set_tag("backend", "candle")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } + + /// Returns (dim_obs, dim_act, target_ent, env_name, model_dir) + pub fn env_params<'a>(args: &Args) -> (i64, i64, f64, &'a str, String) { + let env_name = &args.env; + let model_dir = format!("{}/sac_{}", MODEL_DIR_BASE, env_name); + match args.env.as_str() { + "ant" => (27, 8, -8., "Ant-v4", model_dir), + "cheetah" => (17, 6, -6., "HalfCheetah-v4", model_dir), + "walker" => (17, 6, -6., "Walker2d-v4", model_dir), + "hopper" => (11, 3, -3., "Hopper-v4", model_dir), + env => panic!("Unsupported env {:?}", env), + } + } +} + +/// Train/eval SAC agent in Mujoco environment 
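+///
+/// Usage sketch (flags as defined below; feature flags as in the README in this directory):
+/// `cargo run --release --example sac_mujoco --features=candle-core,cuda,cudnn -- --env ant --train`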
+#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Environment name (ant, cheetah, walker, hopper) + #[arg(long)] + env: String, + + /// Train SAC agent, not evaluate + #[arg(long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(long, default_value_t = false)] + eval: bool, + // #[arg(long, default_value_t = String::new())] + // eval: String, + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames during evaluation + #[arg(long, default_value_t = 25)] + wait: u64, +} + +fn train(args: &Args) -> Result<()> { + let (dim_obs, dim_act, target_ent, env_name, model_dir) = utils::env_params(&args); + let env_config = config::env_config(env_name); + let step_proc_config = SimpleStepProcessorConfig {}; + let agent_config = config::create_sac_config(dim_obs, dim_act, target_ent); + let trainer_config = config::create_trainer_config(&model_dir); + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let mut trainer = Trainer::build(trainer_config.clone()); + + let config = config::SacAntConfig { + trainer: trainer_config, + replay_buffer: replay_buffer_config.clone(), + agent: agent_config.clone(), + }; + let env = Env::build(&env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut agent = Sac::build(agent_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut recorder = utils::create_recorder(&args, &config)?; + let mut evaluator = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; + + Ok(()) +} + +fn eval(args: &Args, model_dir: &str, render: bool, wait: u64) -> Result<()> { + let (dim_obs, dim_act, target_ent, env_name, _) = utils::env_params(&args); + let env_config = { + let mut env_config = config::env_config(&env_name); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(wait); + }; + env_config + }; + let mut agent = { + let agent_config = config::create_sac_config(dim_obs, dim_act, target_ent); + let mut agent = Sac::build(agent_config); + match agent.load_params(model_dir) { + Ok(_) => {} + Err(_) => println!("Failed to load model parameters from {:?}", model_dir), + } + agent.eval(); + agent + }; + // let mut recorder = BufferedRecorder::new(); + + let _ = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?.evaluate(&mut agent); + + Ok(()) +} + +fn eval1(args: &Args) -> Result<()> { + let model_dir = { + let env_name = &args.env; + format!("{}/sac_{}/best", MODEL_DIR_BASE, env_name) + }; + let render = true; + let wait = args.wait; + eval(&args, &model_dir, render, wait) +} + +// fn eval2(matches: ArgMatches) -> Result<()> { +// let model_dir = { +// let file_base = "sac_ant_20210324_ec2_smoothl1"; +// let url = +// "https://drive.google.com/uc?export=download&id=1XvFi2nJD5OhpTvs-Et3YREuoqy8c3Vkq"; +// let model_dir = get_model_from_url(url, file_base)?; +// info!("Download the model in {:?}", model_dir.as_ref().to_str()); +// model_dir.as_ref().to_str().unwrap().to_string() +// }; +// let render = true; +// let wait = matches.value_of("wait").unwrap().parse().unwrap(); +// eval(&matches, &model_dir, render, wait) +// } + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + fastrand::seed(42); + + let args = Args::parse();
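+    // `--train` only trains, `--eval` only evaluates the saved "best" model, and with neither flag the example trains and then evaluates.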
+ + if args.train { + train(&args)?; + } else if args.eval { + eval1(&args)?; + } else { + train(&args)?; + eval1(&args)?; + } + // } else if matches.is_present("play-gdrive") { + // eval2(matches)?; + + Ok(()) +} diff --git a/border/examples/mujoco/sac_mujoco_async_tch.rs b/border/examples/mujoco/sac_mujoco_async_tch.rs new file mode 100644 index 00000000..07000e1e --- /dev/null +++ b/border/examples/mujoco/sac_mujoco_async_tch.rs @@ -0,0 +1,329 @@ +use anyhow::Result; +use border_async_trainer::{ + util::train_async, /*ActorManager as ActorManager_,*/ ActorManagerConfig, + /*AsyncTrainer as AsyncTrainer_,*/ AsyncTrainerConfig, +}; +use border_core::{ + generic_replay_buffer::{ + SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::AggregateRecorder, + DefaultEvaluator, +}; +use border_derive::BatchBase; +use border_mlflow_tracking::MlflowTrackingClient; +use border_py_gym_env::{ + util::{arrayd_to_tensor, tensor_to_arrayd}, + ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, +}; +use border_tch_agent::{ + mlp::{Mlp, Mlp2, MlpConfig}, + opt::OptimizerConfig, + sac::{ActorConfig, CriticConfig, EntCoefMode, Sac, SacConfig}, + util::CriticLoss, + TensorBatch, +}; +use border_tensorboard::TensorboardRecorder; +use clap::Parser; +use ndarray::{ArrayD, IxDyn}; +use std::{convert::TryFrom, default::Default}; +use tch::Tensor; + +const LR_ACTOR: f64 = 3e-4; +const LR_CRITIC: f64 = 3e-4; +const BATCH_SIZE: usize = 256; +const WARMUP_PERIOD: usize = 10_000; +const MAX_OPTS: usize = 3_000_000; +const EVAL_INTERVAL: usize = 5_000; +const SAVE_INTERVAL: usize = 500_000; +const REPLAY_BUFFER_CAPACITY: usize = 300_000; +const N_CRITICS: usize = 2; +const TAU: f64 = 0.02; +const LR_ENT_COEF: f64 = 3e-4; +const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; +const MODEL_DIR_BASE: &str = "./border/examples/mujoco/model/tch"; + +fn cuda_if_available() -> tch::Device { + tch::Device::cuda_if_available() +} + +mod obs_act_types { + use super::*; + + #[derive(Clone, Debug)] + pub struct Obs(ArrayD); + + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(ArrayD::zeros(IxDyn(&[0]))) + } + + fn len(&self) -> usize { + self.0.shape()[0] + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + Obs(obs) + } + } + + impl From for Tensor { + fn from(obs: Obs) -> Tensor { + Tensor::try_from(&obs.0).unwrap() + } + } + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + #[derive(Clone, Debug)] + pub struct Act(ArrayD); + + impl border_core::Act for Act {} + + impl From for ArrayD { + fn from(value: Act) -> Self { + value.0 + } + } + + impl From for Act { + fn from(t: Tensor) -> Self { + Self(tensor_to_arrayd(t, true)) + } + } + + // Required by Sac + impl From for Tensor { + fn from(value: Act) -> Self { + arrayd_to_tensor::<_, f32>(value.0, true) + } + } + + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); + + impl From for ActBatch { + fn from(act: Act) -> Self { + let tensor = act.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + type PyObsDtype = f32; + pub type ObsFilter = ArrayObsFilter; + pub type ActFilter = ContinuousActFilter; + pub type EnvConfig = GymEnvConfig; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Evaluator = DefaultEvaluator>; 
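+    // Summary: the aliases above bind the concrete observation/action types and filters to the generic Gym environment, step processor, replay buffer, and evaluator used by the async trainer below.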
+} + +use obs_act_types::*; + +mod config { + use serde::Serialize; + + use super::*; + + pub fn env_config(env_name: &str) -> EnvConfig { + GymEnvConfig::::default() + .name(env_name.to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) + } + + pub fn create_async_trainer_config(model_dir: &str) -> Result { + Ok(AsyncTrainerConfig::default() + .model_dir(model_dir)? + .max_opts(MAX_OPTS)? + .eval_interval(EVAL_INTERVAL)? + .flush_record_interval(EVAL_INTERVAL)? + .record_compute_cost_interval(EVAL_INTERVAL)? + .sync_interval(1)? + .save_interval(SAVE_INTERVAL)? + .warmup_period(WARMUP_PERIOD)?) + } + + pub fn create_sac_config(dim_obs: i64, dim_act: i64, target_ent: f64) -> SacConfig { + let device = cuda_if_available(); + let actor_config = ActorConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) + .out_dim(dim_act) + .pi_config(MlpConfig::new(dim_obs, vec![400, 300], dim_act, false)); + let critic_config = CriticConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) + .q_config(MlpConfig::new(dim_obs + dim_act, vec![400, 300], 1, false)); + + SacConfig::default() + .batch_size(BATCH_SIZE) + .actor_config(actor_config) + .critic_config(critic_config) + .tau(TAU) + .critic_loss(CRITIC_LOSS) + .n_critics(N_CRITICS) + .ent_coef_mode(EntCoefMode::Auto(target_ent, LR_ENT_COEF)) + .device(device) + } + + pub fn show_config( + env_config: &EnvConfig, + agent_config: &SacConfig, + actor_man_config: &ActorManagerConfig, + trainer_config: &AsyncTrainerConfig, + ) { + println!("Device: {:?}", tch::Device::cuda_if_available()); + println!("{}", serde_yaml::to_string(&env_config).unwrap()); + println!("{}", serde_yaml::to_string(&agent_config).unwrap()); + println!("{}", serde_yaml::to_string(&actor_man_config).unwrap()); + println!("{}", serde_yaml::to_string(&trainer_config).unwrap()); + } + + #[derive(Serialize)] + pub struct SacMujocoAsyncConfig { + pub trainer: AsyncTrainerConfig, + pub replay_buffer: SimpleReplayBufferConfig, + pub agent: SacConfig, + } +} + +mod utils { + use super::*; + + pub fn model_dir(args: &Args) -> String { + let name = &args.env; + format!("./border/examples/mujoco/model/tch/sac_{}_async", name) + } + + pub fn create_recorder( + args: &Args, + config: &config::SacMujocoAsyncConfig, + ) -> Result> { + let env_name = &args.env; + let (_, _, _, _, model_dir) = env_params(&args); + match args.mlflow { + true => { + let client = + MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", env_name)?; + recorder_run.set_tag("algo", "sac")?; + recorder_run.set_tag("backend", "tch")?; + recorder_run.set_tag("n_actors", args.n_actors.to_string())?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } + + /// Returns (dim_obs, dim_act, target_ent, env_name, model_dir) + pub fn env_params<'a>(args: &Args) -> (i64, i64, f64, &'a str, String) { + let env_name = &args.env; + let model_dir = format!("{}/sac_{}", MODEL_DIR_BASE, env_name); + match args.env.as_str() { + "ant" => (27, 8, -8., "Ant-v4", model_dir), + "cheetah" => (17, 6, -6., "HalfCheetah-v4", model_dir), + "walker" => (17, 6, -6., "Walker2d-v4", model_dir), + "hopper" => (11, 3, -3., "Hopper-v4", model_dir), + env => panic!("Unsupported env {:?}", env), + } + } +} + +/// Train SAC agent in Mujoco environment +#[derive(Parser, Debug)] 
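+// Note: each actor runs a CPU copy of the SAC policy feeding the shared replay buffer, while the learner trains on the CUDA device when available (see `train` below).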
+#[command(version, about)] +struct Args { + /// Name of the environment + env: String, + + /// Create config files + #[arg(long, default_value_t = false)] + create_config: bool, + + /// Show config + #[arg(long, default_value_t = false)] + show_config: bool, + + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames when evaluation + #[arg(long, default_value_t = 25)] + wait: u64, + + /// Number of actors, default to 6 + #[arg(long, default_value_t = 6)] + n_actors: usize, +} + +fn train(args: &Args) -> Result<()> { + let (dim_obs, dim_act, target_ent, env_name, _model_dir) = utils::env_params(args); + let env_config_train = config::env_config(env_name); + let model_dir = utils::model_dir(&args); + let n_actors = args.n_actors; + + // Configurations + let agent_config = + config::create_sac_config(dim_obs, dim_act, target_ent).device(cuda_if_available()); + let agent_configs = (0..n_actors) + .map(|_| agent_config.clone().device(tch::Device::Cpu)) + .collect::>(); + let env_config_eval = config::env_config(env_name); + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let step_proc_config = SimpleStepProcessorConfig::default(); + let actor_man_config = ActorManagerConfig::default(); + let async_trainer_config = config::create_async_trainer_config(model_dir.as_str())?; + + if args.show_config { + config::show_config( + &env_config_train, + &agent_config, + &actor_man_config, + &async_trainer_config, + ); + } else { + let mut evaluator = Evaluator::new(&env_config_eval, 0, 1)?; + let config = config::SacMujocoAsyncConfig { + trainer: async_trainer_config.clone(), + replay_buffer: replay_buffer_config.clone(), + agent: agent_config.clone(), + }; + let mut recorder = utils::create_recorder(&args, &config)?; + + train_async::<_, Env, ReplayBuffer, StepProc>( + &agent_config, + &agent_configs, + &env_config_train, + &env_config_eval, + &step_proc_config, + &replay_buffer_config, + &actor_man_config, + &async_trainer_config, + &mut recorder, + &mut evaluator, + ); + } + + Ok(()) +} + +fn main() -> Result<()> { + tch::set_num_threads(1); + let args = Args::parse(); + train(&args)?; + Ok(()) +} diff --git a/border/examples/mujoco/sac_mujoco_tch.rs b/border/examples/mujoco/sac_mujoco_tch.rs new file mode 100644 index 00000000..14afb7ac --- /dev/null +++ b/border/examples/mujoco/sac_mujoco_tch.rs @@ -0,0 +1,357 @@ +use anyhow::Result; +// use border::util::get_model_from_url; +use border_core::{ + generic_replay_buffer::{ + SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, + SimpleStepProcessorConfig, + }, + record::AggregateRecorder, + Agent, Configurable, DefaultEvaluator, Env as _, Evaluator as _, ReplayBufferBase, + StepProcessor, Trainer, TrainerConfig, +}; +use border_derive::BatchBase; +use border_mlflow_tracking::MlflowTrackingClient; +use border_py_gym_env::{ + util::{arrayd_to_tensor, tensor_to_arrayd}, + ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, +}; +use border_tch_agent::{ + mlp::{Mlp, Mlp2, MlpConfig}, + opt::OptimizerConfig, + sac::{ActorConfig, CriticConfig, EntCoefMode, Sac, SacConfig}, + util::CriticLoss, + TensorBatch, +}; +use border_tensorboard::TensorboardRecorder; +use clap::Parser; +// use log::info; +use ndarray::{ArrayD, IxDyn}; +use std::convert::TryFrom; +use tch::Tensor; + +const LR_ACTOR: f64 = 3e-4; +const LR_CRITIC: f64 = 3e-4; +const BATCH_SIZE: usize = 256; +const WARMUP_PERIOD: 
usize = 10_000; +const OPT_INTERVAL: usize = 1; +const MAX_OPTS: usize = 3_000_000; +const EVAL_INTERVAL: usize = 5_000; +const REPLAY_BUFFER_CAPACITY: usize = 300_000; +const N_EPISODES_PER_EVAL: usize = 5; +const N_CRITICS: usize = 2; +const TAU: f64 = 0.02; +const LR_ENT_COEF: f64 = 3e-4; +const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; +const MODEL_DIR_BASE: &str = "./border/examples/mujoco/model/tch"; + +fn cuda_if_available() -> tch::Device { + tch::Device::cuda_if_available() +} + +mod obs_act_types { + use super::*; + + #[derive(Clone, Debug)] + pub struct Obs(ArrayD); + + #[derive(Clone, BatchBase)] + pub struct ObsBatch(TensorBatch); + + impl border_core::Obs for Obs { + fn dummy(_n: usize) -> Self { + Self(ArrayD::zeros(IxDyn(&[0]))) + } + + fn len(&self) -> usize { + self.0.shape()[0] + } + } + + impl From> for Obs { + fn from(obs: ArrayD) -> Self { + Obs(obs) + } + } + + impl From for Tensor { + fn from(obs: Obs) -> Tensor { + Tensor::try_from(&obs.0).unwrap() + } + } + + impl From for ObsBatch { + fn from(obs: Obs) -> Self { + let tensor = obs.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + #[derive(Clone, Debug)] + pub struct Act(ArrayD); + + impl border_core::Act for Act {} + + impl From for ArrayD { + fn from(value: Act) -> Self { + value.0 + } + } + + impl From for Act { + fn from(t: Tensor) -> Self { + Self(tensor_to_arrayd(t, true)) + } + } + + // Required by Sac + impl From for Tensor { + fn from(value: Act) -> Self { + arrayd_to_tensor::<_, f32>(value.0, true) + } + } + + #[derive(BatchBase)] + pub struct ActBatch(TensorBatch); + + impl From for ActBatch { + fn from(act: Act) -> Self { + let tensor = act.into(); + Self(TensorBatch::from_tensor(tensor)) + } + } + + type PyObsDtype = f32; + pub type ObsFilter = ArrayObsFilter; + pub type ActFilter = ContinuousActFilter; + pub type EnvConfig = GymEnvConfig; + pub type Env = GymEnv; + pub type StepProc = SimpleStepProcessor; + pub type ReplayBuffer = SimpleReplayBuffer; + pub type Evaluator = DefaultEvaluator>; +} + +use obs_act_types::*; + +mod config { + use serde::Serialize; + + use super::*; + + #[derive(Serialize)] + pub struct SacAntConfig { + pub trainer: TrainerConfig, + pub replay_buffer: SimpleReplayBufferConfig, + pub agent: SacConfig, + } + + pub fn env_config(env_name: &str) -> EnvConfig { + GymEnvConfig::::default() + .name(env_name.to_string()) + .obs_filter_config(ObsFilter::default_config()) + .act_filter_config(ActFilter::default_config()) + } + + pub fn create_trainer_config(model_dir: &str) -> TrainerConfig { + TrainerConfig::default() + .max_opts(MAX_OPTS) + .opt_interval(OPT_INTERVAL) + .eval_interval(EVAL_INTERVAL) + .record_agent_info_interval(EVAL_INTERVAL) + .record_compute_cost_interval(EVAL_INTERVAL) + .flush_record_interval(EVAL_INTERVAL) + .save_interval(EVAL_INTERVAL) + .warmup_period(WARMUP_PERIOD) + .model_dir(model_dir) + } + + pub fn create_sac_config(dim_obs: i64, dim_act: i64, target_ent: f64) -> SacConfig { + let device = cuda_if_available(); + let actor_config = ActorConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) + .out_dim(dim_act) + .pi_config(MlpConfig::new(dim_obs, vec![400, 300], dim_act, false)); + let critic_config = CriticConfig::default() + .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) + .q_config(MlpConfig::new(dim_obs + dim_act, vec![400, 300], 1, false)); + + SacConfig::default() + .batch_size(BATCH_SIZE) + .actor_config(actor_config) + .critic_config(critic_config) + .tau(TAU) + .critic_loss(CRITIC_LOSS) + 
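+            // EntCoefMode::Auto below tunes the entropy coefficient toward `target_ent` with learning rate LR_ENT_COEF.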
.n_critics(N_CRITICS) + .ent_coef_mode(EntCoefMode::Auto(target_ent, LR_ENT_COEF)) + .device(device) + } +} + +mod utils { + use super::*; + + pub fn create_recorder( + args: &Args, + config: &config::SacAntConfig, + ) -> Result<Box<dyn AggregateRecorder>> { + let env_name = &args.env; + let (_, _, _, _, model_dir) = env_params(&args); + match args.mlflow { + true => { + let client = + MlflowTrackingClient::new("http://localhost:8080").set_experiment_id("Gym")?; + let recorder_run = client.create_recorder("")?; + recorder_run.log_params(&config)?; + recorder_run.set_tag("env", env_name)?; + recorder_run.set_tag("algo", "sac")?; + recorder_run.set_tag("backend", "tch")?; + Ok(Box::new(recorder_run)) + } + false => Ok(Box::new(TensorboardRecorder::new(model_dir))), + } + } + + /// Returns (dim_obs, dim_act, target_ent, env_name, model_dir) + pub fn env_params<'a>(args: &Args) -> (i64, i64, f64, &'a str, String) { + let env_name = &args.env; + let model_dir = format!("{}/sac_{}", MODEL_DIR_BASE, env_name); + match args.env.as_str() { + "ant" => (27, 8, -8., "Ant-v4", model_dir), + "cheetah" => (17, 6, -6., "HalfCheetah-v4", model_dir), + "walker" => (17, 6, -6., "Walker2d-v4", model_dir), + "hopper" => (11, 3, -3., "Hopper-v4", model_dir), + env => panic!("Unsupported env {:?}", env), + } + } +} + +/// Train/eval SAC agent in a Mujoco environment +#[derive(Parser, Debug)] +#[command(version, about)] +struct Args { + /// Environment name (ant, cheetah, walker, hopper) + #[arg(long)] + env: String, + + /// Train SAC agent, not evaluate + #[arg(long, default_value_t = false)] + train: bool, + + /// Evaluate SAC agent, not train + #[arg(long, default_value_t = false)] + eval: bool, + // #[arg(long, default_value_t = String::new())] + // eval: String, + /// Log metrics with MLflow + #[arg(long, default_value_t = false)] + mlflow: bool, + + /// Waiting time in milliseconds between frames during evaluation + #[arg(long, default_value_t = 25)] + wait: u64, +} + +fn train(args: &Args) -> Result<()> { + let (dim_obs, dim_act, target_ent, env_name, model_dir) = utils::env_params(args); + let env_config = config::env_config(env_name); + let step_proc_config = SimpleStepProcessorConfig {}; + let agent_config = config::create_sac_config(dim_obs, dim_act, target_ent); + let trainer_config = config::create_trainer_config(&model_dir); + let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); + let mut trainer = Trainer::build(trainer_config.clone()); + + let config = config::SacAntConfig { + trainer: trainer_config, + replay_buffer: replay_buffer_config.clone(), + agent: agent_config.clone(), + }; + let env = Env::build(&env_config, 0)?; + let step_proc = StepProc::build(&step_proc_config); + let mut agent = Sac::build(agent_config); + let mut buffer = ReplayBuffer::build(&replay_buffer_config); + let mut recorder = utils::create_recorder(args, &config)?; + let mut evaluator = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?; + + trainer.train( + env, + step_proc, + &mut agent, + &mut buffer, + &mut recorder, + &mut evaluator, + )?; + + Ok(()) +} + +fn eval(args: &Args, model_dir: &str, render: bool, wait: u64) -> Result<()> { + let (dim_obs, dim_act, target_ent, env_name, _) = utils::env_params(&args); + let env_config = { + let mut env_config = config::env_config(&env_name); + if render { + env_config = env_config + .render_mode(Some("human".to_string())) + .set_wait_in_millis(wait); + }; + env_config + }; + let mut agent = { + let agent_config = config::create_sac_config(dim_obs, dim_act, 
target_ent); + let mut agent = Sac::build(agent_config); + match agent.load_params(model_dir) { + Ok(_) => {} + Err(_) => println!("Failed to load model parameters from {:?}", model_dir), + } + agent.eval(); + agent + }; + // let mut recorder = BufferedRecorder::new(); + + let _ = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?.evaluate(&mut agent); + + Ok(()) +} + +fn eval1(args: &Args) -> Result<()> { + let model_dir = { + let env_name = &args.env; + format!("{}/sac_{}/best", MODEL_DIR_BASE, env_name) + }; + let render = true; + let wait = args.wait; + eval(&args, &model_dir, render, wait) +} + +// fn eval2(matches: ArgMatches) -> Result<()> { +// let model_dir = { +// let file_base = "sac_ant_20210324_ec2_smoothl1"; +// let url = +// "https://drive.google.com/uc?export=download&id=1XvFi2nJD5OhpTvs-Et3YREuoqy8c3Vkq"; +// let model_dir = get_model_from_url(url, file_base)?; +// info!("Download the model in {:?}", model_dir.as_ref().to_str()); +// model_dir.as_ref().to_str().unwrap().to_string() +// }; +// let render = true; +// let wait = matches.value_of("wait").unwrap().parse().unwrap(); +// eval(&matches, &model_dir, render, wait) +// } + +fn main() -> Result<()> { + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); + tch::manual_seed(42); + fastrand::seed(42); + + let args = Args::parse(); + + if args.train { + train(&args)?; + } else if args.eval { + eval1(&args)?; + } else { + train(&args)?; + eval1(&args)?; + } + // } else if matches.is_present("play-gdrive") { + // eval2(matches)?; + + Ok(()) +} diff --git a/border/examples/sac_ant.rs b/border/examples/sac_ant.rs deleted file mode 100644 index 3cf53382..00000000 --- a/border/examples/sac_ant.rs +++ /dev/null @@ -1,270 +0,0 @@ -use anyhow::Result; -use border::util::get_model_from_url; -use border_core::{ - replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - Agent, DefaultEvaluator, Evaluator as _, Policy, Trainer, TrainerConfig, -}; -use border_derive::SubBatch; -use border_py_gym_env::{ - util::{arrayd_to_tensor, tensor_to_arrayd}, - ArrayObsFilter, ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, -}; -use border_tch_agent::{ - mlp::{Mlp, Mlp2, MlpConfig}, - opt::OptimizerConfig, - sac::{ActorConfig, CriticConfig, EntCoefMode, Sac, SacConfig}, - util::CriticLoss, - TensorSubBatch, -}; -use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg}; -use log::info; -use ndarray::{ArrayD, IxDyn}; -use std::convert::TryFrom; -use tch::Tensor; - -const DIM_OBS: i64 = 27; -const DIM_ACT: i64 = 8; -const LR_ACTOR: f64 = 3e-4; -const LR_CRITIC: f64 = 3e-4; -const BATCH_SIZE: usize = 256; -const N_TRANSITIONS_WARMUP: usize = 10_000; -const OPT_INTERVAL: usize = 1; -const MAX_OPTS: usize = 3_000_000; -const EVAL_INTERVAL: usize = 10_000; -const REPLAY_BUFFER_CAPACITY: usize = 300_000; -const N_EPISODES_PER_EVAL: usize = 5; -const N_CRITICS: usize = 2; -const TAU: f64 = 0.02; -const TARGET_ENTROPY: f64 = -(DIM_ACT as f64); -const LR_ENT_COEF: f64 = 3e-4; -const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; -const MODEL_DIR: &str = "./border/examples/model/sac_ant"; - -type PyObsDtype = f32; - -mod obs { - use super::*; - - #[derive(Clone, Debug)] - pub struct Obs(ArrayD); - - #[derive(Clone, SubBatch)] - pub struct ObsBatch(TensorSubBatch); - - impl border_core::Obs for Obs { - fn dummy(_n: usize) -> Self { - Self(ArrayD::zeros(IxDyn(&[0]))) - } - - fn len(&self) -> usize { - self.0.shape()[0] - } 
- } - - impl From> for Obs { - fn from(obs: ArrayD) -> Self { - Obs(obs) - } - } - - impl From for Tensor { - fn from(obs: Obs) -> Tensor { - Tensor::try_from(&obs.0).unwrap() - } - } - - impl From for ObsBatch { - fn from(obs: Obs) -> Self { - let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } - } -} - -mod act { - use super::*; - - #[derive(Clone, Debug)] - pub struct Act(ArrayD); - - impl border_core::Act for Act {} - - impl From for ArrayD { - fn from(value: Act) -> Self { - value.0 - } - } - - impl From for Act { - fn from(t: Tensor) -> Self { - Self(tensor_to_arrayd(t, true)) - } - } - - // Required by Sac - impl From for Tensor { - fn from(value: Act) -> Self { - arrayd_to_tensor::<_, f32>(value.0, true) - } - } - - #[derive(SubBatch)] - pub struct ActBatch(TensorSubBatch); - - impl From for ActBatch { - fn from(act: Act) -> Self { - let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } - } -} - -use act::{Act, ActBatch}; -use obs::{Obs, ObsBatch}; - -type ObsFilter = ArrayObsFilter; -type ActFilter = ContinuousActFilter; -type Env = GymEnv; -type StepProc = SimpleStepProcessor; -type ReplayBuffer = SimpleReplayBuffer; -type Evaluator = DefaultEvaluator>; - -fn create_agent(in_dim: i64, out_dim: i64) -> Sac { - let device = tch::Device::cuda_if_available(); - let actor_config = ActorConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) - .out_dim(out_dim) - .pi_config(MlpConfig::new(in_dim, vec![400, 300], out_dim, true)); - let critic_config = CriticConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) - .q_config(MlpConfig::new(in_dim + out_dim, vec![400, 300], 1, true)); - let sac_config = SacConfig::default() - .batch_size(BATCH_SIZE) - .min_transitions_warmup(N_TRANSITIONS_WARMUP) - .actor_config(actor_config) - .critic_config(critic_config) - .tau(TAU) - .critic_loss(CRITIC_LOSS) - .n_critics(N_CRITICS) - .ent_coef_mode(EntCoefMode::Auto(TARGET_ENTROPY, LR_ENT_COEF)) - .device(device); - Sac::build(sac_config) -} - -fn env_config() -> GymEnvConfig { - GymEnvConfig::::default() - .name("Ant-v4".to_string()) - .obs_filter_config(ObsFilter::default_config()) - .act_filter_config(ActFilter::default_config()) -} - -fn train(max_opts: usize, model_dir: &str) -> Result<()> { - let mut trainer = { - let env_config = env_config(); - let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = - SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let config = TrainerConfig::default() - .max_opts(max_opts) - .opt_interval(OPT_INTERVAL) - .eval_interval(EVAL_INTERVAL) - .record_interval(EVAL_INTERVAL) - .save_interval(EVAL_INTERVAL) - .model_dir(model_dir); - let trainer = Trainer::::build( - config, - env_config, - step_proc_config, - replay_buffer_config, - ); - - trainer - }; - let mut agent = create_agent(DIM_OBS, DIM_ACT); - let mut recorder = TensorboardRecorder::new(model_dir); - let mut evaluator = Evaluator::new(&env_config(), 0, N_EPISODES_PER_EVAL)?; - - trainer.train(&mut agent, &mut recorder, &mut evaluator)?; - - Ok(()) -} - -fn eval(model_dir: &str, render: bool, wait: u64) -> Result<()> { - let env_config = { - let mut env_config = env_config(); - if render { - env_config = env_config - .render_mode(Some("human".to_string())) - .set_wait_in_millis(wait); - }; - env_config - }; - let mut agent = { - let mut agent = create_agent(DIM_OBS, DIM_ACT); - agent.load(model_dir)?; - agent.eval(); - agent - }; - // let mut recorder = BufferedRecorder::new(); - - let 
_ = Evaluator::new(&env_config, 0, N_EPISODES_PER_EVAL)?.evaluate(&mut agent); - - Ok(()) -} - -fn main() -> Result<()> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - tch::manual_seed(42); - fastrand::seed(42); - - let matches = App::new("sac_ant") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("play") - .long("play") - .takes_value(true) - .help("Play with the trained model of the given path"), - ) - .arg( - Arg::with_name("play-gdrive") - .long("play-gdrive") - .takes_value(false) - .help("Play with the trained model downloaded from google drive"), - ) - .arg( - Arg::with_name("wait") - .long("wait") - .takes_value(true) - .default_value("25") - .help("Waiting time in milliseconds between frames when playing"), - ) - .get_matches(); - - if !(matches.is_present("play") || matches.is_present("play-gdrive")) { - train(MAX_OPTS, MODEL_DIR)?; - } else { - let model_dir = if matches.is_present("play") { - let model_dir = matches - .value_of("play") - .expect("Failed to parse model directory"); - format!("{}{}", model_dir, "/best").to_owned() - } else { - let file_base = "sac_ant_20210324_ec2_smoothl1"; - let url = - "https://drive.google.com/uc?export=download&id=1XvFi2nJD5OhpTvs-Et3YREuoqy8c3Vkq"; - let model_dir = get_model_from_url(url, file_base)?; - info!("Download the model in {:?}", model_dir.as_ref().to_str()); - model_dir.as_ref().to_str().unwrap().to_string() - }; - - let wait = matches.value_of("wait").unwrap().parse().unwrap(); - eval(model_dir.as_str(), true, wait)?; - } - - Ok(()) -} diff --git a/border/examples/sac_ant_async.rs b/border/examples/sac_ant_async.rs deleted file mode 100644 index 60ce75b8..00000000 --- a/border/examples/sac_ant_async.rs +++ /dev/null @@ -1,323 +0,0 @@ -use anyhow::Result; -use border_async_trainer::{ - actor_stats_fmt, ActorManager as ActorManager_, ActorManagerConfig, - AsyncTrainer as AsyncTrainer_, AsyncTrainerConfig, -}; -use border_atari_env::{ - BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs, - BorderAtariObsRawFilter, -}; -use border_core::{ - replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - DefaultEvaluator, Env as _, -}; -use border_derive::{Act, Obs, SubBatch}; -use border_py_gym_env::{ - util::{arrayd_to_tensor, tensor_to_arrayd}, - ArrayObsFilter, ContinuousActFilter, GymActFilter, GymContinuousAct, GymEnv, GymEnvConfig, - GymObs, GymObsFilter, -}; -use border_tch_agent::{ - mlp::{Mlp, Mlp2, MlpConfig}, - opt::OptimizerConfig, - sac::{ActorConfig, CriticConfig, EntCoefMode, Sac, SacConfig}, - util::CriticLoss, - TensorSubBatch, -}; -use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg, ArgMatches}; -use crossbeam_channel::unbounded; -use ndarray::{ArrayD, IxDyn}; -use std::{ - convert::TryFrom, - default::Default, - sync::{Arc, Mutex}, -}; -use tch::Tensor; - -type PyObsDtype = f32; - -const DIM_OBS: i64 = 27; -const DIM_ACT: i64 = 8; -const LR_ACTOR: f64 = 3e-4; -const LR_CRITIC: f64 = 3e-4; -const BATCH_SIZE: usize = 256; -const N_TRANSITIONS_WARMUP: usize = 10_000; -const MAX_OPTS: usize = 3_000_000; -const EVAL_INTERVAL: usize = 10_000; -const RECORD_INTERVAL: usize = 10_000; -const SAVE_INTERVAL: usize = 500_000; -const REPLAY_BUFFER_CAPACITY: usize = 300_000; -const SYNC_INTERVAL: usize = 100; -const EVAL_EPISODES: usize = 5; -const N_CRITICS: usize = 2; -const TAU: f64 = 0.02; -const TARGET_ENTROPY: f64 = -(DIM_ACT as f64); 
-const LR_ENT_COEF: f64 = 3e-4; -const CRITIC_LOSS: CriticLoss = CriticLoss::SmoothL1; -const MODEL_DIR: &str = "./border/examples/model/sac_ant_async"; - -mod obs { - use super::*; - - #[derive(Clone, Debug)] - pub struct Obs(ArrayD); - - #[derive(Clone, SubBatch)] - pub struct ObsBatch(TensorSubBatch); - - impl border_core::Obs for Obs { - fn dummy(_n: usize) -> Self { - Self(ArrayD::zeros(IxDyn(&[0]))) - } - - fn len(&self) -> usize { - self.0.shape()[0] - } - } - - impl From> for Obs { - fn from(obs: ArrayD) -> Self { - Obs(obs) - } - } - - impl From for Tensor { - fn from(obs: Obs) -> Tensor { - Tensor::try_from(&obs.0).unwrap() - } - } - - impl From for ObsBatch { - fn from(obs: Obs) -> Self { - let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } - } -} - -mod act { - use super::*; - - #[derive(Clone, Debug)] - pub struct Act(ArrayD); - - impl border_core::Act for Act {} - - impl From for ArrayD { - fn from(value: Act) -> Self { - value.0 - } - } - - impl From for Act { - fn from(t: Tensor) -> Self { - Self(tensor_to_arrayd(t, true)) - } - } - - // Required by Sac - impl From for Tensor { - fn from(value: Act) -> Self { - arrayd_to_tensor::<_, f32>(value.0, true) - } - } - - #[derive(SubBatch)] - pub struct ActBatch(TensorSubBatch); - - impl From for ActBatch { - fn from(act: Act) -> Self { - let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } - } -} - -use act::{Act, ActBatch}; -use obs::{Obs, ObsBatch}; - -type ObsFilter = ArrayObsFilter; -type ActFilter = ContinuousActFilter; -type Env = GymEnv; -type EnvConfig = GymEnvConfig; -type StepProc = SimpleStepProcessor; -type ReplayBuffer = SimpleReplayBuffer; -type Agent = Sac; -type AgentConfig = SacConfig; -type ActorManager = ActorManager_; -type AsyncTrainer = AsyncTrainer_; -type Evaluator = DefaultEvaluator; - -mod config { - use super::*; - - pub fn env_config() -> GymEnvConfig { - GymEnvConfig::::default() - .name("Ant-v4".to_string()) - .obs_filter_config(ObsFilter::default_config()) - .act_filter_config(ActFilter::default_config()) - } - - pub fn agent_config(in_dim: i64, out_dim: i64) -> AgentConfig { - let actor_config = ActorConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) - .out_dim(out_dim) - .pi_config(MlpConfig::new(in_dim, vec![400, 300], out_dim, true)); - let critic_config = CriticConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) - .q_config(MlpConfig::new(in_dim + out_dim, vec![400, 300], 1, true)); - - SacConfig::default() - .batch_size(BATCH_SIZE) - .min_transitions_warmup(N_TRANSITIONS_WARMUP) - .actor_config(actor_config) - .critic_config(critic_config) - .tau(TAU) - .critic_loss(CRITIC_LOSS) - .n_critics(N_CRITICS) - .ent_coef_mode(EntCoefMode::Auto(TARGET_ENTROPY, LR_ENT_COEF)) - } - - pub fn async_trainer_config() -> AsyncTrainerConfig { - AsyncTrainerConfig { - model_dir: Some(MODEL_DIR.to_string()), - record_interval: RECORD_INTERVAL, - eval_interval: EVAL_INTERVAL, - max_train_steps: MAX_OPTS, - save_interval: SAVE_INTERVAL, - sync_interval: SYNC_INTERVAL, - } - } -} - -use config::{agent_config, async_trainer_config, env_config}; - -fn parse_args<'a>() -> ArgMatches<'a> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - tch::manual_seed(42); - - let matches = App::new("dqn_atari_async") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("show-config") - .long("show-config") - .takes_value(false) - .help("Showing configuration loaded from files"), - ) 
- .arg( - Arg::with_name("n-actors") - .long("n-actors") - .takes_value(true) - .default_value("6") - .help("The number of actors"), - ) - .get_matches(); - - matches -} - -fn show_config( - env_config: &EnvConfig, - agent_config: &AgentConfig, - actor_man_config: &ActorManagerConfig, - trainer_config: &AsyncTrainerConfig, -) { - println!("Device: {:?}", tch::Device::cuda_if_available()); - println!("{}", serde_yaml::to_string(&env_config).unwrap()); - println!("{}", serde_yaml::to_string(&agent_config).unwrap()); - println!("{}", serde_yaml::to_string(&actor_man_config).unwrap()); - println!("{}", serde_yaml::to_string(&trainer_config).unwrap()); -} - -fn train(matches: ArgMatches) -> Result<()> { - // exploration parameters - let n_actors = matches - .value_of("n-actors") - .unwrap() - .parse::() - .unwrap(); - let env_config_train = env_config(); - let env_config_eval = env_config(); - let agent_configs = (0..n_actors) - .map(|ix| { - agent_config(DIM_OBS, DIM_ACT) - .seed(ix as i64) - .device(tch::Device::Cpu) - }) - .collect::>(); - let agent_config = agent_config(DIM_OBS, DIM_ACT).device(tch::Device::cuda_if_available()); - let replay_buffer_config = SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let step_proc_config = SimpleStepProcessorConfig::default(); - let actor_man_config = ActorManagerConfig::default(); - let async_trainer_config = async_trainer_config(); - - if matches.is_present("show-config") { - show_config( - &env_config_train, - &agent_config, - &actor_man_config, - &async_trainer_config, - ); - } else { - let mut recorder = TensorboardRecorder::new(MODEL_DIR); - let mut evaluator = Evaluator::new(&env_config_eval, 0, EVAL_EPISODES)?; - - // Shared flag to stop actor threads - let stop = Arc::new(Mutex::new(false)); - - // Creates channels - let (item_s, item_r) = unbounded(); // items pushed to replay buffer - let (model_s, model_r) = unbounded(); // model_info - - // guard for initialization of envs in multiple threads - let guard_init_env = Arc::new(Mutex::new(true)); - - // Actor manager and async trainer - let mut actors = ActorManager::build( - &actor_man_config, - &agent_configs, - &env_config_train, - &step_proc_config, - item_s, - model_r, - stop.clone(), - ); - let mut trainer = AsyncTrainer::build( - &async_trainer_config, - &agent_config, - &env_config_eval, - &replay_buffer_config, - item_r, - model_s, - stop.clone(), - ); - - // Set the number of threads - tch::set_num_threads(1); - - // Starts sampling and training - actors.run(guard_init_env.clone()); - let stats = trainer.train(&mut recorder, &mut evaluator, guard_init_env); - println!("Stats of async trainer"); - println!("{}", stats.fmt()); - - let stats = actors.stop_and_join(); - println!("Stats of generated samples in actors"); - println!("{}", actor_stats_fmt(&stats)); - } - - Ok(()) -} - -fn main() -> Result<()> { - let matches = parse_args(); - - train(matches)?; - - Ok(()) -} diff --git a/border/examples/sac_lunarlander_cont.rs b/border/examples/sac_lunarlander_cont.rs deleted file mode 100644 index a6374107..00000000 --- a/border/examples/sac_lunarlander_cont.rs +++ /dev/null @@ -1,275 +0,0 @@ -use anyhow::Result; -use border_core::{ - record::Record, - replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - Agent, DefaultEvaluator, Evaluator as _, Policy, Trainer, TrainerConfig, -}; -use border_derive::SubBatch; -use border_py_gym_env::{ - util::{arrayd_to_tensor, tensor_to_arrayd}, - ArrayObsFilter, 
ContinuousActFilter, GymActFilter, GymEnv, GymEnvConfig, GymObsFilter, -}; -use border_tch_agent::{ - mlp::{Mlp, Mlp2, MlpConfig}, - opt::OptimizerConfig, - sac::{ActorConfig, CriticConfig, Sac, SacConfig}, - TensorSubBatch, -}; -use border_tensorboard::TensorboardRecorder; -use clap::{App, Arg}; -//use csv::WriterBuilder; -use ndarray::{ArrayD, IxDyn}; -use serde::Serialize; -use std::convert::TryFrom; -use tch::Tensor; - -const DIM_OBS: i64 = 8; -const DIM_ACT: i64 = 2; -const LR_ACTOR: f64 = 3e-4; -const LR_CRITIC: f64 = 3e-4; -const BATCH_SIZE: usize = 128; -const N_TRANSITIONS_WARMUP: usize = 1000; -const OPT_INTERVAL: usize = 1; -const MAX_OPTS: usize = 200_000; -const EVAL_INTERVAL: usize = 10_000; -const REPLAY_BUFFER_CAPACITY: usize = 100_000; -const N_EPISODES_PER_EVAL: usize = 5; -const MODEL_DIR: &str = "./border/examples/model/sac_lunarlander_cont"; - -type PyObsDtype = f32; - -mod obs { - use super::*; - - #[derive(Clone, Debug)] - pub struct Obs(ArrayD); - - impl border_core::Obs for Obs { - fn dummy(_n: usize) -> Self { - Self(ArrayD::zeros(IxDyn(&[0]))) - } - - fn len(&self) -> usize { - self.0.shape()[0] - } - } - - impl From> for Obs { - fn from(obs: ArrayD) -> Self { - Obs(obs) - } - } - - impl From for Tensor { - fn from(obs: Obs) -> Tensor { - Tensor::try_from(&obs.0).unwrap() - } - } - - #[derive(Clone, SubBatch)] - pub struct ObsBatch(TensorSubBatch); - - impl From for ObsBatch { - fn from(obs: Obs) -> Self { - let tensor = obs.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } - } -} - -mod act { - use super::*; - - #[derive(Clone, Debug)] - pub struct Act(ArrayD); - - impl border_core::Act for Act {} - - impl From for ArrayD { - fn from(value: Act) -> Self { - value.0 - } - } - - impl From for Act { - fn from(t: Tensor) -> Self { - Self(tensor_to_arrayd(t, true)) - } - } - - // Required by Sac - impl From for Tensor { - fn from(value: Act) -> Self { - arrayd_to_tensor::<_, f32>(value.0, true) - } - } - - #[derive(SubBatch)] - pub struct ActBatch(TensorSubBatch); - - impl From for ActBatch { - fn from(act: Act) -> Self { - let tensor = act.into(); - Self(TensorSubBatch::from_tensor(tensor)) - } - } -} - -use act::{Act, ActBatch}; -use obs::{Obs, ObsBatch}; - -type ObsFilter = ArrayObsFilter; -type ActFilter = ContinuousActFilter; -type Env = GymEnv; -type StepProc = SimpleStepProcessor; -type ReplayBuffer = SimpleReplayBuffer; -type Evaluator = DefaultEvaluator>; - -#[derive(Debug, Serialize)] -struct LunarlanderRecord { - episode: usize, - step: usize, - reward: f32, - obs: Vec, - act: Vec, -} - -impl TryFrom<&Record> for LunarlanderRecord { - type Error = anyhow::Error; - - fn try_from(record: &Record) -> Result { - Ok(Self { - episode: record.get_scalar("episode")? as _, - step: record.get_scalar("step")? 
as _, - reward: record.get_scalar("reward")?, - obs: record.get_array1("obs")?.to_vec(), - act: record.get_array1("act")?.to_vec(), - }) - } -} - -fn create_agent(in_dim: i64, out_dim: i64) -> Sac { - let device = tch::Device::cuda_if_available(); - let actor_config = ActorConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_ACTOR }) - .out_dim(out_dim) - .pi_config(MlpConfig::new(in_dim, vec![64, 64], out_dim, true)); - let critic_config = CriticConfig::default() - .opt_config(OptimizerConfig::Adam { lr: LR_CRITIC }) - .q_config(MlpConfig::new(in_dim + out_dim, vec![64, 64], 1, true)); - let sac_config = SacConfig::default() - .batch_size(BATCH_SIZE) - .min_transitions_warmup(N_TRANSITIONS_WARMUP) - .actor_config(actor_config) - .critic_config(critic_config) - .device(device); - Sac::build(sac_config) -} - -fn env_config() -> GymEnvConfig { - GymEnvConfig::::default() - .name("LunarLanderContinuous-v2".to_string()) - .obs_filter_config(ObsFilter::default_config()) - .act_filter_config(ActFilter::default_config()) -} - -fn train(max_opts: usize, model_dir: &str) -> Result<()> { - let mut trainer = { - let env_config = env_config(); - let step_proc_config = SimpleStepProcessorConfig {}; - let replay_buffer_config = - SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); - let config = TrainerConfig::default() - .max_opts(max_opts) - .opt_interval(OPT_INTERVAL) - .eval_interval(EVAL_INTERVAL) - .record_interval(EVAL_INTERVAL) - .save_interval(EVAL_INTERVAL) - .model_dir(model_dir); - let trainer = Trainer::::build( - config, - env_config, - step_proc_config, - replay_buffer_config, - ); - - trainer - }; - let mut agent = create_agent(DIM_OBS, DIM_ACT); - let mut recorder = TensorboardRecorder::new(model_dir); - let mut evaluator = Evaluator::new(&env_config(), 0, N_EPISODES_PER_EVAL)?; - - trainer.train(&mut agent, &mut recorder, &mut evaluator)?; - - Ok(()) -} - -fn eval(model_dir: &str, render: bool) -> Result<()> { - let env_config = { - let mut env_config = env_config(); - if render { - env_config = env_config - .render_mode(Some("human".to_string())) - .set_wait_in_millis(10); - } - env_config - }; - let mut agent = { - let mut agent = create_agent(DIM_OBS, DIM_ACT); - agent.load(model_dir)?; - agent.eval(); - agent - }; - // let mut recorder = BufferedRecorder::new(); - - let _ = Evaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); - - // // Vec<_> field in a struct does not support writing a header in csv crate, so disable it. 
- // let mut wtr = WriterBuilder::new() - // .has_headers(false) - // .from_writer(File::create(model_dir.to_string() + "/eval.csv")?); - // for record in recorder.iter() { - // wtr.serialize(LunarlanderRecord::try_from(record)?)?; - // } - - Ok(()) -} - -fn main() -> Result<()> { - env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")).init(); - tch::manual_seed(42); - - let matches = App::new("sac_lunarlander_cont") - .version("0.1.0") - .author("Taku Yoshioka ") - .arg( - Arg::with_name("train") - .long("train") - .takes_value(false) - .help("Do training only"), - ) - .arg( - Arg::with_name("eval") - .long("eval") - .takes_value(false) - .help("Do evaluation only"), - ) - .get_matches(); - - let do_train = (matches.is_present("train") && !matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); - let do_eval = (!matches.is_present("train") && matches.is_present("eval")) - || (!matches.is_present("train") && !matches.is_present("eval")); - - if do_train { - train(MAX_OPTS, MODEL_DIR)?; - } - if do_eval { - eval(&(MODEL_DIR.to_owned() + "/best"), true)?; - } - - Ok(()) -} diff --git a/border/examples/test_async_trainer.rs b/border/examples/test_async_trainer.rs deleted file mode 100644 index 5b7ce326..00000000 --- a/border/examples/test_async_trainer.rs +++ /dev/null @@ -1,104 +0,0 @@ -use super::{ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel}; -use border_atari_env::util::test::*; -use border_core::{ - record::BufferedRecorder, - replay_buffer::{ - SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, - SimpleStepProcessorConfig, - }, - Env as _, -}; -use crossbeam_channel::unbounded; -use log::info; -use std::sync::{Arc, Mutex}; -use test_log::test; - -fn replay_buffer_config() -> SimpleReplayBufferConfig { - SimpleReplayBufferConfig::default() -} - -fn actor_man_config() -> ActorManagerConfig { - ActorManagerConfig::default() -} - -fn async_trainer_config() -> AsyncTrainerConfig { - AsyncTrainerConfig { - model_dir: Some("".to_string()), - record_interval: 5, - eval_interval: 105, - max_train_steps: 15, - save_interval: 5, - sync_interval: 5, - eval_episodes: 1, - } -} - -impl SyncModel for RandomAgent { - type ModelInfo = (); - - fn model_info(&self) -> (usize, Self::ModelInfo) { - info!("Returns the current model info"); - (self.n_opts_steps(), ()) - } - - fn sync_model(&mut self, _model_info: &Self::ModelInfo) { - info!("Sync model"); - } -} - -#[test] -fn test_async_trainer() { - type Agent = RandomAgent; - type StepProc = SimpleStepProcessor; - type ReplayBuffer = SimpleReplayBuffer; - type ActorManager_ = ActorManager; - type AsyncTrainer_ = AsyncTrainer; - - let env_config = env_config("pong".to_string()); - let env = Env::build(&env_config, 0).unwrap(); - let n_acts = env.get_num_actions_atari() as _; - let agent_config = RandomAgentConfig { n_acts }; - let step_proc_config = SimpleStepProcessorConfig::default(); - let replay_buffer_config = replay_buffer_config(); - let actor_man_config = actor_man_config(); - let async_trainer_config = async_trainer_config(); - let agent_configs = vec![agent_config.clone(); 2]; - - let mut recorder = BufferedRecorder::new(); - - // Shared flag to stop actor threads - let stop = Arc::new(Mutex::new(false)); - - // Pushed items into replay buffer - let (item_s, item_r) = unbounded(); - - // Synchronizing model - let (model_s, model_r) = unbounded(); - - // Prevents simlutaneous initialization of env - let guard_init_env = 
Arc::new(Mutex::new(true)); - - let mut actors = ActorManager_::build( - &actor_man_config, - &agent_configs, - &env_config, - &step_proc_config, - item_s, - model_r, - stop.clone(), - ); - let mut trainer = AsyncTrainer_::build( - &async_trainer_config, - &agent_config, - &env_config, - &replay_buffer_config, - item_r, - model_s, - stop, - ); - - actors.run(guard_init_env.clone()); - trainer.train(&mut recorder, guard_init_env); - - actors.stop_and_join(); -} diff --git a/border/src/lib.rs b/border/src/lib.rs index 264d2c99..5917b439 100644 --- a/border/src/lib.rs +++ b/border/src/lib.rs @@ -1,136 +1,34 @@ -//! Border is a reinforcement learning library. -//! -//! This crate is a collection of examples using the crates below. -//! -//! * [`border-core`](https://crates.io/crates/border-core) provides basic traits and functions -//! generic to environments and reinforcmenet learning (RL) agents. -//! * [`border-py-gym-env`](https://crates.io/crates/border-py-gym-env) is a wrapper of the -//! [Gym](https://gym.openai.com) environments written in Python, with the support of -//! [pybullet-gym](https://github.com/benelot/pybullet-gym) and -//! [atari](https://github.com/mgbellemare/Arcade-Learning-Environment). -//! * [`border-atari-env`](https://crates.io/crates/border-atari-env) is a wrapper of -//! [atari-env](https://crates.io/crates/atari-env), which is a part of -//! [gym-rs](https://crates.io/crates/gym-rs). -//! * [`border-tch-agent`](https://crates.io/crates/border-tch-agent) is a collection of RL agents -//! based on [tch](https://crates.io/crates/tch). Deep Q network (DQN), implicit quantile network -//! (IQN), and soft actor critic (SAC) are includes. -//! * [`border-async-trainer`](https://crates.io/crates/border-async-trainer) defines some traits and -//! functions for asynchronous training of RL agents by multiple actors, each of which runs -//! a sampling process of an agent and an environment in parallel. -//! -//! You can use a part of these crates for your purposes. -//! -//! # Environment -//! -//! [`border-core`](https://crates.io/crates/border-core) abstracts environments as [`Env`]. -//! [`Env`] has associated types [`Env::Obs`] and [`Env::Act`] for observation and action of -//! the envirnoment. [`Env::Config`] should be configrations of the concrete type. -//! -//! # Policy and agent -//! -//! In this crate, [`Policy`] is a controller for an environment implementing [`Env`] trait. -//! [`Agent`] trait abstracts a trainable [`Policy`] and has methods for save/load of -//! parameters, and its training. -//! -//! # Evaluation -//! -//! Structs that implements [`Evaluator`] trait can be used to run episodes with a given [`Env`] -//! and [`Policy`]. -//! The code might look like below. Here we use [`DefaultEvaluator`], a built-in implementation -//! of [`Evaluator`]. -//! -//! ```ignore -//! type E = TYPE_OF_ENV; -//! type P = TYPE_OF_POLICY; -//! -//! fn eval(model_dir: &str, render: bool) -> Result<()> { -//! let env_config: E::Config = { -//! let mut env_config = env_config() -//! .render_mode(Some("human".to_string())) -//! .set_wait_in_millis(10); -//! env_config -//! }; -//! let mut agent: P = { -//! let mut agent = create_agent(); -//! agent.load(model_dir)?; -//! agent.eval(); -//! agent -//! }; -//! -//! let _ = DefaultEvaluator::new(&env_config, 0, 5)?.evaluate(&mut agent); -//! } -//! ``` -//! -//! Users can customize the way the policy is evaluated by implementing a custom [`Evaluator`]. -//! -//! # Training -//! -//! 
You can train RL [`Agent`]s by using [`Trainer`] struct. -//! -//! ```ignore -//! fn train(max_opts: usize, model_dir: &str) -> Result<()> { -//! let mut trainer = { -//! let env_config = env_config(); // configration of the environment -//! let step_proc_config = SimpleStepProcessorConfig {}; -//! let replay_buffer_config = -//! SimpleReplayBufferConfig::default().capacity(REPLAY_BUFFER_CAPACITY); -//! let config = TrainerConfig::default() -//! .max_opts(max_opts); -//! // followed by methods to set training parameters -//! -//! trainer = Trainer::::build( -//! config, -//! env_config, -//! step_proc_config, -//! replay_buffer_config, -//! ) -//! }; -//! let mut agent = create_agent(); -//! let mut recorder = TensorboardRecorder::new(model_dir); -//! let mut evaluator = create_evaluator(&env_config())?; -//! -//! trainer.train(&mut agent, &mut recorder, &mut evaluator)?; -//! -//! Ok(()) -//! } -//! ``` -//! In the above code, [`SimpleStepProcessorConfig`] is configurations of -//! [`SimpleStepProcessor`], which implements [`StepProcessorBase`] trait. -//! [`StepProcessorBase`] abstracts the way how [`Step`] object is processed before pushed to -//! a replay buffer. Users can customize implementation of [`StepProcessorBase`] for their -//! purpose. For example, n-step TD samples or samples having Monte Carlo returns after the end -//! of episode can be computed with a statefull implementation of [`StepProcessorBase`]. -//! -//! It should be noted that a replay buffer is not a part of [`Agent`], but owned by -//! [`Trainer`]. In the above code, the configuration of a replay buffer is given to -//! [`Trainer`]. The design choise allows [`Agent`]s to separate sampling and optimization -//! processes. -//! -//! [`border-core`]: https://crates.io/crates/border-core -//! [`Env`]: border_core::Env -//! [`Env::Obs`]: border_core::Env::Obs -//! [`Env::Act`]: border_core::Env::Act -//! [`Env::Config`]: border_core::Env::Config -//! [`Policy`]: border_core::Policy -//! [`Recorder`]: border_core::record::Recorder -//! [`eval_with_recorder`]: border_core::util::eval_with_recorder -//! [`border_py_gym_env/examples/random_cartpols.rs`]: (https://github.com/taku-y/border/blob/982ef2d25a0ade93fb71cab3bb85e5062b6f769c/border-py-gym-env/examples/random_cartpole.rs) -//! [`Agent`]: border_core::Agent -//! [`StepProcessorBase`]: border_core::ReplayBufferBase -//! [`SimpleStepProcessor`]: border_core::replay_buffer::SimpleStepProcessor -//! [`SimpleStepProcessorConfig`]: border_core::replay_buffer::SimpleStepProcessorConfig -//! [`Step`]: border_core::Step -//! [`ReplayBufferBase`]: border_core::ReplayBufferBase -//! [`ReplayBufferBase::Batch`]: border_core::ReplayBufferBase::Batch -//! [`StdBatchBase`]: border_core::StdBatchBase -//! [`ReplayBufferBase::Batch`]: border_core::ReplayBufferBase::Batch -//! [`Agent::opt()`]: border_core::Agent::opt -//! [`ExperienceBufferBase`]: border_core::ExperienceBufferBase -//! [`ExperienceBufferBase::PushedItem`]: border_core::ExperienceBufferBase::PushedItem -//! [`SimpleReplayBuffer`]: border_core::replay_buffer::SimpleReplayBuffer -//! [`Evaluator`]: border_core::Evaluator -//! [`DefaultEvaluator`]: border_core::DefaultEvaluator -//! [`Trainer`]: border_core::Trainer -//! [`Step`]: border_core::Step +//! A reinforcement learning library in Rust. +//! +//! Border consists of the following crates: +//! +//! * Core and utility +//! * [border-core](https://crates.io/crates/border-core) provides basic traits +//! 
and functions generic to environments and reinforcement learning (RL) agents. +//! * [border-tensorboard](https://crates.io/crates/border-tensorboard) has +//! the `TensorboardRecorder` struct to write records that can be shown in Tensorboard. +//! It is based on [tensorboard-rs](https://crates.io/crates/tensorboard-rs). +//! * [border-mlflow-tracking](https://crates.io/crates/border-mlflow-tracking) +//! supports MLflow tracking to log metrics during training via its REST API. +//! * [border-async-trainer](https://crates.io/crates/border-async-trainer) defines +//! some traits and functions for asynchronous training of RL agents by multiple +//! actors, each of which runs a sampling process in parallel. In each sampling process, +//! an agent interacts with an environment to collect samples to be sent to a shared +//! replay buffer. +//! * [border](https://crates.io/crates/border) is just a collection of examples. +//! * Environment +//! * [border-py-gym-env](https://crates.io/crates/border-py-gym-env) is a wrapper of the +//! [Gymnasium](https://gymnasium.farama.org) environments written in Python. +//! * [border-atari-env](https://crates.io/crates/border-atari-env) is a wrapper of +//! [atari-env](https://crates.io/crates/atari-env), which is a part of +//! [gym-rs](https://crates.io/crates/gym-rs). +//! * Agent +//! * [border-tch-agent](https://crates.io/crates/border-tch-agent) includes RL agents +//! based on [tch](https://crates.io/crates/tch), including Deep Q network (DQN), +//! implicit quantile network (IQN), and soft actor critic (SAC). +//! * [border-candle-agent](https://crates.io/crates/border-candle-agent) includes RL +//! agents based on [candle](https://crates.io/crates/candle-core). +//! * [border-policy-no-backend](https://crates.io/crates/border-policy-no-backend) +//! includes a policy that is independent of any deep learning backend, such as Torch. pub mod util; diff --git a/border/src/util.rs b/border/src/util.rs index 35624f1a..dcb3d180 100644 --- a/border/src/util.rs +++ b/border/src/util.rs @@ -1,3 +1,3 @@ //! 
Utilities mod url; -pub use url::{get_model_from_url}; +pub use url::get_model_from_url; diff --git a/border/src/util/url.rs b/border/src/util/url.rs index b81575b5..18043311 100644 --- a/border/src/util/url.rs +++ b/border/src/util/url.rs @@ -26,7 +26,7 @@ pub fn get_model_from_url>( if path_zip.as_path().exists() { info!("Exists zip file {:?}, skips download", path_zip); path_dir.push(&file_base); - return Ok(path_dir) + return Ok(path_dir); } let mut zip_file = File::create(&path_zip.as_path()).context(format!( @@ -90,4 +90,4 @@ mod tests { Ok(()) } -} \ No newline at end of file +} diff --git a/border/tests/test_async_trainer.rs b/border/tests/test_async_trainer.rs new file mode 100644 index 00000000..52ad7d12 --- /dev/null +++ b/border/tests/test_async_trainer.rs @@ -0,0 +1,104 @@ +// use border_async_trainer::{ActorManagerConfig, AsyncTrainerConfig, SyncModel}; +// use border_atari_env::util::test::*; +// use border_core::{ +// replay_buffer::SimpleReplayBufferConfig, +// // record::BufferedRecorder, +// // replay_buffer::{ +// // SimpleReplayBuffer, SimpleReplayBufferConfig, SimpleStepProcessor, +// // SimpleStepProcessorConfig, +// // }, +// // Env as _, +// }; +// use log::info; +// // use std::sync::{Arc, Mutex}; +// // use test_log::test; + +// fn replay_buffer_config() -> SimpleReplayBufferConfig { +// SimpleReplayBufferConfig::default() +// } + +// fn actor_man_config() -> ActorManagerConfig { +// ActorManagerConfig::default() +// } + +// fn async_trainer_config() -> AsyncTrainerConfig { +// AsyncTrainerConfig { +// model_dir: Some("".to_string()), +// record_interval: 5, +// eval_interval: 105, +// max_train_steps: 15, +// save_interval: 5, +// sync_interval: 5, +// eval_episodes: 1, +// } +// } + +// impl SyncModel for RandomAgent { +// type ModelInfo = (); + +// fn model_info(&self) -> (usize, Self::ModelInfo) { +// info!("Returns the current model info"); +// (self.n_opts_steps(), ()) +// } + +// fn sync_model(&mut self, _model_info: &Self::ModelInfo) { +// info!("Sync model"); +// } +// } + +// #[test] +// fn test_async_trainer() { +// type Agent = RandomAgent; +// type StepProc = SimpleStepProcessor; +// type ReplayBuffer = SimpleReplayBuffer; +// type ActorManager_ = ActorManager; +// type AsyncTrainer_ = AsyncTrainer; + +// let env_config = env_config("pong".to_string()); +// let env = Env::build(&env_config, 0).unwrap(); +// let n_acts = env.get_num_actions_atari() as _; +// let agent_config = RandomAgentConfig { n_acts }; +// let step_proc_config = SimpleStepProcessorConfig::default(); +// let replay_buffer_config = replay_buffer_config(); +// let actor_man_config = actor_man_config(); +// let async_trainer_config = async_trainer_config(); +// let agent_configs = vec![agent_config.clone(); 2]; + +// let mut recorder = BufferedRecorder::new(); + +// // Shared flag to stop actor threads +// let stop = Arc::new(Mutex::new(false)); + +// // Pushed items into replay buffer +// let (item_s, item_r) = unbounded(); + +// // Synchronizing model +// let (model_s, model_r) = unbounded(); + +// // Prevents simlutaneous initialization of env +// let guard_init_env = Arc::new(Mutex::new(true)); + +// let mut actors = ActorManager_::build( +// &actor_man_config, +// &agent_configs, +// &env_config, +// &step_proc_config, +// item_s, +// model_r, +// stop.clone(), +// ); +// let mut trainer = AsyncTrainer_::build( +// &async_trainer_config, +// &agent_config, +// &env_config, +// &replay_buffer_config, +// item_r, +// model_s, +// stop, +// ); + +// 
actors.run(guard_init_env.clone()); +// trainer.train(&mut recorder, guard_init_env); + +// actors.stop_and_join(); +// } diff --git a/docker/aarch64/Dockerfile b/docker/aarch64/Dockerfile index 2a6ba1a8..9ffc784a 100644 --- a/docker/aarch64/Dockerfile +++ b/docker/aarch64/Dockerfile @@ -72,12 +72,16 @@ RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y RUN cd /root && python3 -m venv venv RUN source /root/venv/bin/activate && pip3 install --upgrade pip RUN source /root/venv/bin/activate && pip3 install pyyaml typing-extensions -RUN source /root/venv/bin/activate && pip3 install torch==1.12.0 +RUN source /root/venv/bin/activate && pip3 install torch==2.3.0 RUN source /root/venv/bin/activate && pip3 install ipython jupyterlab RUN source /root/venv/bin/activate && pip3 install numpy==1.21.3 RUN source /root/venv/bin/activate && pip3 install mujoco==2.3.7 RUN source /root/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 RUN source /root/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN source /root/venv/bin/activate && pip3 install tensorboard==2.16.2 +RUN source /root/venv/bin/activate && pip3 install mlflow==2.11.1 +RUN source /root/venv/bin/activate && pip3 install tabulate==0.9.0 +RUN source /root/venv/bin/activate && pip3 install mlflow-export-import==1.2.0 # RUN source /home/ubuntu/venv/bin/activate && pip3 install robosuite==1.3.2 # RUN source /home/ubuntu/venv/bin/activate && pip3 install -U 'mujoco-py<2.2,>=2.1' # RUN source /home/ubuntu/venv/bin/activate && pip3 install dm-control==1.0.8 @@ -110,6 +114,7 @@ RUN echo 'export CARGO_TARGET_DIR=$HOME/target' >> ~/.bashrc RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc RUN echo 'export MUJOCO_GL=glfw' >> ~/.bashrc RUN echo 'source $HOME/venv/bin/activate' >> ~/.bashrc +RUN echo 'export RUSTFLAGS="-C target-feature=+fp16"' >> ~/.bashrc RUN rm /bin/sh && mv /bin/sh_tmp /bin/sh diff --git a/docker/aarch64_doc/Dockerfile b/docker/aarch64_doc/Dockerfile deleted file mode 100644 index 42455746..00000000 --- a/docker/aarch64_doc/Dockerfile +++ /dev/null @@ -1,51 +0,0 @@ -FROM ubuntu:focal-20221130 - -ENV DEBIAN_FRONTEND noninteractive -RUN echo "Set disable_coredump false" >> /etc/sudo.conf -RUN apt-get update -q && \ - apt-get upgrade -yq && \ - apt-get install -yq wget curl git build-essential vim sudo libssl-dev -RUN apt install -y -q libclang-dev zip cmake llvm pkg-config libx11-dev libxkbcommon-dev - -# # swig -# RUN apt install -y swig - -# # python -# RUN apt install -y python3.8 python3.8-dev python3.8-distutils python3.8-venv python3-pip - -# # headers required for building libtorch -# RUN apt install -y libgoogle-glog-dev libgflags-dev - -# # llvm, mesa for robosuite -# RUN apt install -y llvm libosmesa6-dev - -# # Used for Mujoco -# RUN apt install -y patchelf libglfw3 libglfw3-dev - -# Cleanup -RUN rm -rf /var/lib/apt/lists/* - -# COPY test_mujoco_py.py /test_mujoco_py.py -# RUN chmod 777 /test_mujoco_py.py - -# Add user -RUN useradd --create-home --home-dir /home/ubuntu --shell /bin/bash --user-group --groups adm,sudo ubuntu && \ - echo ubuntu:ubuntu | chpasswd && \ - echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers - -# Use bash -RUN mv /bin/sh /bin/sh_tmp && ln -s /bin/bash /bin/sh - -# User settings -USER ubuntu - -# rustup -RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y - -RUN echo 'export PATH=$HOME/.local/bin:$PATH' >> ~/.bashrc - -USER root -RUN rm /bin/sh && mv /bin/sh_tmp /bin/sh - -USER ubuntu 
-WORKDIR /home/ubuntu/border diff --git a/docker/aarch64_doc/README.md b/docker/aarch64_doc/README.md index 2e3eeb1e..6a9ef17c 100644 --- a/docker/aarch64_doc/README.md +++ b/docker/aarch64_doc/README.md @@ -1,39 +1,7 @@ -# Docker container for training - -This directory contains scripts to build and run a docker container for training. - -## Build - -The following command creates a container image locally, named `border_headless`. +Run the following command to create docs in border/doc. ```bash -cd $REPO/docker/aarch64_headless -sh build.sh +# in border/aarch64_doc +sh doc.sh ``` -# Build document - -The following commands builds the document and places it as `$REPO/doc`. - -## Run - -The following commands runs a program for training an agent. -The trained model will be saved in `$REPO/border/examples/model` directory, -which is mounted in the container. - -### DQN - -* Cartpole - - ```bash - cd $REPO/docker/aarch64_headless - sh run.sh "source /home/ubuntu/venv/bin/activate && cargo run --example dqn_cartpole --features='tch' -- --train" - ``` - - * Use a directory, not mounted on the host, as a cargo target directory, - making compile faster on Mac, where access to mounted directories is slow. - - ```bash - cd $REPO/docker/aarch64_headless - sh run.sh "source /home/ubuntu/venv/bin/activate && CARGO_TARGET_DIR=/home/ubuntu/target cargo run --example dqn_cartpole --features='tch' -- --train" - ``` diff --git a/docker/aarch64_doc/doc.sh b/docker/aarch64_doc/doc.sh index a5609c6b..c823279f 100644 --- a/docker/aarch64_doc/doc.sh +++ b/docker/aarch64_doc/doc.sh @@ -2,5 +2,5 @@ docker run -it --rm \ --name border_headless \ --shm-size=512m \ --volume="$(pwd)/../..:/home/ubuntu/border" \ - border_doc bash -l -c \ - "CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." + border_headless bash -l -c \ + "cd /home/ubuntu/border; source /root/venv/bin/activate; LIBTORCH_USE_PYTORCH=1 LD_LIBRARY_PATH=$HOME/venv/lib/python3.10/site-packages/torch/lib:$LD_LIBRARY_PATH CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." 
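Note: the doc build above links the tch-based crates against the PyTorch wheel installed in the Python venv (`LIBTORCH_USE_PYTORCH=1` plus the `LD_LIBRARY_PATH` entry set in doc.sh). A minimal, hypothetical smoke test (not part of this diff) that can be compiled in the same container to confirm the linkage before running `cargo doc`:

```rust
// Hypothetical check; assumes a crate with `tch` as a dependency and the same
// environment variables as doc.sh (LIBTORCH_USE_PYTORCH=1, LD_LIBRARY_PATH).
fn main() {
    tch::manual_seed(42);
    let device = tch::Device::cuda_if_available();
    // Creating a tensor fails at startup if libtorch cannot be loaded, which is
    // the usual symptom of a missing LD_LIBRARY_PATH entry.
    let t = tch::Tensor::zeros(&[2i64, 3], (tch::Kind::Float, device));
    println!("tch linked OK on {:?}, tensor size = {:?}", device, t.size());
}
```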
diff --git a/docker/aarch64_headless/Dockerfile b/docker/aarch64_headless/Dockerfile index 5d99d528..c3fc15b0 100644 --- a/docker/aarch64_headless/Dockerfile +++ b/docker/aarch64_headless/Dockerfile @@ -1,13 +1,16 @@ -FROM ubuntu:focal-20221130 +FROM --platform=linux/aarch64 ubuntu:22.04 ENV DEBIAN_FRONTEND noninteractive RUN echo "Set disable_coredump false" >> /etc/sudo.conf RUN apt-get update -q && \ apt-get upgrade -yq && \ - apt-get install -yq wget curl git build-essential vim sudo libssl-dev - -# lsb-release locales bash-completion tzdata gosu && \ -# RUN rm -rf /var/lib/apt/lists/* + apt-get install -yq wget +RUN apt-get install -yq curl +RUN apt-get install -yq git +RUN apt-get install -yq build-essential +RUN apt-get install -yq vim +# RUN apt-get install -yq sudo +RUN apt-get install -yq libssl-dev # clang RUN apt install -y -q libclang-dev @@ -18,7 +21,7 @@ RUN apt update -y && \ DEBIAN_FRONTEND=noninteractive && \ apt install -y -q --no-install-recommends \ libsdl2-dev libsdl2-image-dev libsdl2-mixer-dev libsdl2-net-dev libsdl2-ttf-dev \ - libsdl-dev libsdl-image1.2-dev + libsdl-image1.2-dev libsdl1.2-dev # zip RUN apt install -y zip @@ -27,7 +30,7 @@ RUN apt install -y zip RUN apt install -y swig # python -RUN apt install -y python3.8 python3.8-dev python3.8-distutils python3.8-venv python3-pip +RUN apt install -y python3.10 python3.10-dev python3.10-distutils python3.10-venv python3-pip # cmake RUN apt install -y cmake @@ -44,32 +47,25 @@ RUN apt install -y patchelf libglfw3 libglfw3-dev # Cleanup RUN rm -rf /var/lib/apt/lists/* -# COPY test_mujoco_py.py /test_mujoco_py.py -# RUN chmod 777 /test_mujoco_py.py - -# Add user -RUN useradd --create-home --home-dir /home/ubuntu --shell /bin/bash --user-group --groups adm,sudo ubuntu && \ - echo ubuntu:ubuntu | chpasswd && \ - echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers - # Use bash RUN mv /bin/sh /bin/sh_tmp && ln -s /bin/bash /bin/sh -# User settings -USER ubuntu - # rustup RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y # python -RUN cd /home/ubuntu && python3 -m venv venv -RUN source /home/ubuntu/venv/bin/activate && pip3 install --upgrade pip -RUN source /home/ubuntu/venv/bin/activate && pip3 install pyyaml typing-extensions -RUN source /home/ubuntu/venv/bin/activate && pip3 install torch==1.12.0 -RUN source /home/ubuntu/venv/bin/activate && pip3 install ipython jupyterlab -RUN source /home/ubuntu/venv/bin/activate && pip3 install numpy==1.21.3 -RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 -RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN cd /root && python3 -m venv venv +RUN source /root/venv/bin/activate && pip3 install --upgrade pip +RUN source /root/venv/bin/activate && pip3 install pyyaml typing-extensions +RUN source /root/venv/bin/activate && pip3 install torch==2.3.0 +RUN source /root/venv/bin/activate && pip3 install ipython jupyterlab +RUN source /root/venv/bin/activate && pip3 install numpy==1.21.3 +RUN source /root/venv/bin/activate && pip3 install mujoco==2.3.7 +RUN source /root/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 +RUN source /root/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN source /root/venv/bin/activate && pip3 install tensorboard==2.16.2 +RUN source /root/venv/bin/activate && pip3 install tabulate==0.9.0 +RUN source /root/venv/bin/activate && pip3 install mlflow-export-import==1.2.0 # RUN source /home/ubuntu/venv/bin/activate && pip3 install 
robosuite==1.3.2 # RUN source /home/ubuntu/venv/bin/activate && pip3 install -U 'mujoco-py<2.2,>=2.1' # RUN source /home/ubuntu/venv/bin/activate && pip3 install dm-control==1.0.8 @@ -79,20 +75,6 @@ RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium-robotics==1. # border RUN cd $HOME && mkdir -p .border/model -# Mujoco aarch64 binary -RUN cd $HOME && \ - mkdir .mujoco && \ - cd .mujoco && \ - wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-aarch64.tar.gz -RUN cd $HOME/.mujoco && \ - tar zxf mujoco-2.1.1-linux-aarch64.tar.gz && \ - mkdir -p mujoco210/bin && \ - ln -sf $PWD/mujoco-2.1.1/lib/libmujoco.so.2.1.1 $PWD/mujoco210/bin/libmujoco210.so && \ - ln -sf $PWD/mujoco-2.1.1/lib/libglewosmesa.so $PWD/mujoco210/bin/libglewosmesa.so && \ - ln -sf $PWD/mujoco-2.1.1/include/ $PWD/mujoco210/include && \ - ln -sf $PWD/mujoco-2.1.1/model/ $PWD/mujoco210/model -# RUN cp /*.py $HOME - # # PyBulletGym # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==3.2.5 # # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==2.7.1 @@ -106,22 +88,22 @@ RUN cd $HOME/.mujoco && \ # RUN sed -i 's/return state, sum(self.rewards), bool(done), {}/return state, sum(self.rewards), bool(done), bool(done), {}/g' /home/ubuntu/pybullet-gym/pybulletgym/envs/roboschool/envs/locomotion/walker_base_env.py # RUN sed -i 's/id='\''AntPyBulletEnv-v0'\'',/id='\''AntPyBulletEnv-v0'\'', order_enforce=False,/g' /home/ubuntu/pybullet-gym/pybulletgym/envs/__init__.py -# Env vars -# RUN echo 'export LIBTORCH=$HOME/.local/lib/python3.8/site-packages/torch' >> ~/.bashrc -# RUN echo 'export LD_LIBRARY_PATH=$LIBTORCH/lib' >> ~/.bashrc -# RUN echo 'export PYTHONPATH=$HOME/border/border-py-gym-env/examples:$PYTHONPATH' >> ~/.bashrc -# RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc +# .bashrc +RUN echo 'export LIBTORCH=$HOME/venv/lib/python3.10/site-packages/torch' >> ~/.bashrc +RUN echo 'export LD_LIBRARY_PATH=$LIBTORCH/lib' >> ~/.bashrc +RUN echo 'export LIBTORCH_CXX11_ABI=0' >> ~/.bashrc RUN echo 'export PATH=$HOME/.local/bin:$PATH' >> ~/.bashrc -ENV LIBTORCH_CXX11_ABI 0 -ENV LIBTORCH /home/ubuntu/venv/lib/python3.8/site-packages/torch -ENV LD_LIBRARY_PATH $LIBTORCH/lib -ENV PYTHONPATH /home/ubuntu/border/border-py-gym-env/examples:$PYTHONPATH +RUN echo 'export PYTHONPATH=$HOME/border/border-py-gym-env/examples:$PYTHONPATH' >> ~/.bashrc +RUN echo 'export CARGO_TARGET_DIR=$HOME/target' >> ~/.bashrc +RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc +RUN echo 'export MUJOCO_GL=glfw' >> ~/.bashrc +RUN echo 'source $HOME/venv/bin/activate' >> ~/.bashrc +RUN echo 'export RUSTFLAGS="-C target-feature=+fp16"' >> ~/.bashrc -USER root RUN rm /bin/sh && mv /bin/sh_tmp /bin/sh -USER ubuntu -WORKDIR /home/ubuntu/border +# USER root +# WORKDIR /home/ubuntu/border # ENV USER ubuntu # CMD ["/bin/bash", "-l", "-c"] diff --git a/docker/amd64/Dockerfile b/docker/amd64/Dockerfile index e0c9b114..fdbecf3b 100644 --- a/docker/amd64/Dockerfile +++ b/docker/amd64/Dockerfile @@ -1,4 +1,20 @@ -FROM dorowu/ubuntu-desktop-lxde-vnc:focal +FROM --platform=linux/amd64 ubuntu:22.04 + +# Adapted from https://qiita.com/takahashiakari/items/f096e5bcdfecf3d5ba90 +ENV DEBIAN_FRONTEND=noninteractive +RUN apt update -y && apt install --no-install-recommends --fix-missing -y xfce4 xfce4-goodies tigervnc-standalone-server novnc websockify sudo xterm init systemd snapd vim net-tools curl wget git tzdata +RUN apt 
update -y && apt install -y dbus-x11 x11-utils x11-xserver-utils x11-apps +RUN apt install software-properties-common -y +RUN add-apt-repository ppa:mozillateam/ppa -y +RUN echo 'Package: *' >> /etc/apt/preferences.d/mozilla-firefox +RUN echo 'Pin: release o=LP-PPA-mozillateam' >> /etc/apt/preferences.d/mozilla-firefox +RUN echo 'Pin-Priority: 1001' >> /etc/apt/preferences.d/mozilla-firefox +RUN echo 'Unattended-Upgrade::Allowed-Origins:: "LP-PPA-mozillateam:jammy";' | tee /etc/apt/apt.conf.d/51unattended-upgrades-firefox +RUN apt update -y && apt install -y firefox +RUN apt update -y && apt install -y xubuntu-icon-theme +RUN touch /root/.Xauthority +EXPOSE 5901 +EXPOSE 6080 ENV DEBIAN_FRONTEND noninteractive RUN echo "Set disable_coredump false" >> /etc/sudo.conf @@ -7,9 +23,6 @@ RUN apt-get update -q && \ apt-get upgrade -yq && \ apt-get install -yq wget curl git build-essential vim sudo libssl-dev -# lsb-release locales bash-completion tzdata gosu && \ -# RUN rm -rf /var/lib/apt/lists/* - # clang RUN apt install -y -q libclang-dev @@ -19,7 +32,7 @@ RUN apt update -y && \ DEBIAN_FRONTEND=noninteractive && \ apt install -y -q --no-install-recommends \ libsdl2-dev libsdl2-image-dev libsdl2-mixer-dev libsdl2-net-dev libsdl2-ttf-dev \ - libsdl-dev libsdl-image1.2-dev + libsdl-image1.2-dev # zip RUN apt install -y zip @@ -28,7 +41,7 @@ RUN apt install -y zip RUN apt install -y swig # python -RUN apt install -y python3.8 python3.8-dev python3.8-distutils python3.8-venv python3-pip +RUN apt install -y python3.10 python3.10-dev python3.10-distutils python3.10-venv python3-pip # cmake RUN apt install -y cmake @@ -53,33 +66,23 @@ RUN chmod 777 /test_dmc_viewer.py # Use bash RUN mv /bin/sh /bin/sh_tmp && ln -s /bin/bash /bin/sh -# Add user -RUN useradd --create-home --home-dir /home/ubuntu --shell /bin/bash --user-group --groups adm,sudo ubuntu && \ - echo ubuntu:ubuntu | chpasswd && \ - echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers -COPY lxterminal.conf /lxterminal.conf -COPY desktop.conf /desktop.conf -RUN chmod 777 /*.conf - -# User desktop configs -USER ubuntu -RUN mkdir -p /home/ubuntu/.config/lxterminal && \ - cp /lxterminal.conf /home/ubuntu/.config/lxterminal/ -RUN mkdir -p /home/ubuntu/.config/lxsession/LXDE && \ - cp /desktop.conf /home/ubuntu/.config/lxsession/LXDE/ - # rustup RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y # python -RUN cd /home/ubuntu && python -m venv venv -RUN source /home/ubuntu/venv/bin/activate && pip3 install --upgrade pip -RUN source /home/ubuntu/venv/bin/activate && pip3 install pyyaml typing-extensions -RUN source /home/ubuntu/venv/bin/activate && pip3 install torch==1.12.0 -RUN source /home/ubuntu/venv/bin/activate && pip3 install ipython jupyterlab -RUN source /home/ubuntu/venv/bin/activate && pip3 install numpy==1.21.3 -RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 -RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN cd /root && python3 -m venv venv +RUN source /root/venv/bin/activate && pip3 install --upgrade pip +RUN source /root/venv/bin/activate && pip3 install pyyaml typing-extensions +RUN source /root/venv/bin/activate && pip3 install torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu --timeout 300 +RUN source /root/venv/bin/activate && pip3 install ipython jupyterlab +RUN source /root/venv/bin/activate && pip3 install numpy==1.21.3 +RUN source /root/venv/bin/activate && pip3 install mujoco==2.3.7 +RUN source /root/venv/bin/activate 
&& pip3 install gymnasium[box2d]==0.29.0 +RUN source /root/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN source /root/venv/bin/activate && pip3 install tensorboard==2.16.2 +RUN source /root/venv/bin/activate && pip3 install mlflow==2.11.1 +RUN source /root/venv/bin/activate && pip3 install tabulate==0.9.0 +RUN source /root/venv/bin/activate && pip3 install mlflow-export-import==1.2.0 # RUN source /home/ubuntu/venv/bin/activate && pip3 install robosuite==1.3.2 # RUN source /home/ubuntu/venv/bin/activate && pip3 install -U 'mujoco-py<2.2,>=2.1' # RUN source /home/ubuntu/venv/bin/activate && pip3 install dm-control==1.0.8 @@ -89,29 +92,9 @@ RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium-robotics==1. # border RUN cd $HOME && mkdir -p .border/model -# Mujoco amd64 binary -RUN cd $HOME && \ - mkdir .mujoco && \ - cd .mujoco && \ - wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz -RUN cd $HOME/.mujoco && \ - tar zxf mujoco-2.1.1-linux-x86_64.tar.gz && \ - mkdir -p mujoco210/bin && \ - ln -sf $PWD/mujoco-2.1.1/lib/libmujoco.so.2.1.1 $PWD/mujoco210/bin/libmujoco210.so && \ - ln -sf $PWD/mujoco-2.1.1/lib/libglewosmesa.so $PWD/mujoco210/bin/libglewosmesa.so && \ - ln -sf $PWD/mujoco-2.1.1/include/ $PWD/mujoco210/include && \ - ln -sf $PWD/mujoco-2.1.1/model/ $PWD/mujoco210/model -RUN cp /*.py $HOME - # # PyBulletGym # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==3.2.5 # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==2.7.1 -# RUN source /home/ubuntu/venv/bin/activate && \ -# cd $HOME && \ -# git clone https://github.com/bulletphysics/bullet3.git && \ -# cd bullet3 && \ -# git checkout -b tmp 2c204c49e56ed15ec5fcfa71d199ab6d6570b3f5 && \ -# ./build_cmake_pybullet_double.sh # RUN cd $HOME && \ # git clone https://github.com/benelot/pybullet-gym.git && \ # cd pybullet-gym && \ @@ -123,7 +106,7 @@ RUN cp /*.py $HOME # RUN sed -i 's/id='\''AntPyBulletEnv-v0'\'',/id='\''AntPyBulletEnv-v0'\'', order_enforce=False,/g' /home/ubuntu/pybullet-gym/pybulletgym/envs/__init__.py # .bashrc -RUN echo 'export LIBTORCH=$HOME/venv/lib/python3.8/site-packages/torch' >> ~/.bashrc +RUN echo 'export LIBTORCH=$HOME/venv/lib/python3.10/site-packages/torch' >> ~/.bashrc RUN echo 'export LD_LIBRARY_PATH=$LIBTORCH/lib' >> ~/.bashrc RUN echo 'export LIBTORCH_CXX11_ABI=0' >> ~/.bashrc RUN echo 'export PATH=$HOME/.local/bin:$PATH' >> ~/.bashrc @@ -132,9 +115,8 @@ RUN echo 'export CARGO_TARGET_DIR=$HOME/target' >> ~/.bashrc RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc RUN echo 'export MUJOCO_GL=glfw' >> ~/.bashrc RUN echo 'source $HOME/venv/bin/activate' >> ~/.bashrc -# RUN echo 'export PYTHONPATH=$HOME/bullet3/build_cmake/examples/pybullet:$PYTHONPATH' >> ~/.bashrc +# RUN echo 'export RUSTFLAGS="-C target-feature=+fp16"' >> ~/.bashrc -USER root RUN rm /bin/sh && mv /bin/sh_tmp /bin/sh -ENV USER ubuntu +CMD bash -c "vncserver -localhost no -SecurityTypes None -geometry 1024x768 --I-KNOW-THIS-IS-INSECURE && openssl req -new -subj "/C=JP" -x509 -days 365 -nodes -out self.pem -keyout self.pem && websockify -D --web=/usr/share/novnc/ --cert=self.pem 6080 localhost:5901 && tail -f /dev/null" diff --git a/docker/amd64/build.sh b/docker/amd64/build.sh index 0eb76e0e..0936264c 100644 --- a/docker/amd64/build.sh +++ b/docker/amd64/build.sh @@ -1,2 +1,3 @@ #!/bin/bash docker build -t border . +#podman build -t border . 
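For readers following the desktop-image changes above, here is a minimal usage sketch. It is not part of the diff: the run flags mirror the commented-out podman command in `run.sh` (changed just below), and the `vnc.html` page is an assumption about the noVNC files served by websockify on port 6080.

```bash
# Build the desktop image (same as build.sh) and start it detached.
docker build -t border .
docker run -td --name border -p 6080:6080 --shm-size=512m \
    --volume="$(pwd)/../..:/root/border" \
    border
# The image's CMD starts a VNC server on 5901 and websockify/noVNC on 6080,
# so the desktop should then be reachable at http://localhost:6080/vnc.html
```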
diff --git a/docker/amd64/remove.sh b/docker/amd64/remove.sh index e7a325bc..3872196d 100644 --- a/docker/amd64/remove.sh +++ b/docker/amd64/remove.sh @@ -1 +1,2 @@ docker rm -f border +#podman rm -f border diff --git a/docker/amd64/run.sh b/docker/amd64/run.sh index ec60eac0..be2f43f3 100644 --- a/docker/amd64/run.sh +++ b/docker/amd64/run.sh @@ -1,13 +1,13 @@ #!/bin/bash -# nvidia-docker run -it --rm \ -# --env="DISPLAY" \ -# --volume="/tmp/.X11-unix:/tmp/.X11-unix:rw" \ -# --volume="/home/taku-y:/home/taku-y" \ -# --name my_pybullet my_pybullet bash +nvidia-docker run -it --rm \ + --env="DISPLAY" \ + --volume="/tmp/.X11-unix:/tmp/.X11-unix:rw" \ + --volume="/home/taku-y:/home/taku-y" \ + --name my_pybullet my_pybullet bash -docker run -td \ - --name border \ - -p 6080:80 \ - --shm-size=512m \ - --volume="$(pwd)/../..:/home/ubuntu/border" \ - border +# podman run -td \ +# --name border \ +# -p 6080:6080 \ +# --shm-size=512m \ +# --volume="$(pwd)/../..:/root/border" \ +# border diff --git a/docker/amd64_doc/Dockerfile b/docker/amd64_doc/Dockerfile new file mode 100644 index 00000000..c61e3ea0 --- /dev/null +++ b/docker/amd64_doc/Dockerfile @@ -0,0 +1,110 @@ +FROM --platform=linux/amd64 ubuntu:22.04 + +ENV DEBIAN_FRONTEND noninteractive +RUN echo "Set disable_coredump false" >> /etc/sudo.conf +RUN apt-get update -q && \ + apt-get upgrade -yq && \ + apt-get install -yq wget +RUN apt-get install -yq curl +RUN apt-get install -yq git +RUN apt-get install -yq build-essential +RUN apt-get install -yq vim +# RUN apt-get install -yq sudo +RUN apt-get install -yq libssl-dev + +# clang +RUN apt install -y -q libclang-dev + +# sdl +RUN apt update -y && \ + apt upgrade -y && \ + DEBIAN_FRONTEND=noninteractive && \ + apt install -y -q --no-install-recommends \ + libsdl2-dev libsdl2-image-dev libsdl2-mixer-dev libsdl2-net-dev libsdl2-ttf-dev \ + libsdl-image1.2-dev libsdl1.2-dev + +# zip +RUN apt install -y zip + +# swig +RUN apt install -y swig + +# python +RUN apt install -y python3.10 python3.10-dev python3.10-distutils python3.10-venv python3-pip + +# cmake +RUN apt install -y cmake + +# headers required for building libtorch +RUN apt install -y libgoogle-glog-dev libgflags-dev + +# llvm, mesa for robosuite +RUN apt install -y llvm libosmesa6-dev + +# Used for Mujoco +RUN apt install -y patchelf libglfw3 libglfw3-dev + +# Cleanup +RUN rm -rf /var/lib/apt/lists/* + +# Use bash +RUN mv /bin/sh /bin/sh_tmp && ln -s /bin/bash /bin/sh + +# rustup +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + +# python +RUN cd /root && python3 -m venv venv +RUN source /root/venv/bin/activate && pip3 install --upgrade pip +RUN source /root/venv/bin/activate && pip3 install pyyaml typing-extensions +RUN source /root/venv/bin/activate && pip3 install torch==2.3.0 --index-url https://download.pytorch.org/whl/cpu --timeout 300 +RUN source /root/venv/bin/activate && pip3 install ipython jupyterlab +RUN source /root/venv/bin/activate && pip3 install numpy==1.21.3 +RUN source /root/venv/bin/activate && pip3 install mujoco==2.3.7 +RUN source /root/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 +RUN source /root/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN source /root/venv/bin/activate && pip3 install tensorboard==2.16.2 +RUN source /root/venv/bin/activate && pip3 install tabulate==0.9.0 +RUN source /root/venv/bin/activate && pip3 install mlflow-export-import==1.2.0 +# RUN source /home/ubuntu/venv/bin/activate && pip3 install robosuite==1.3.2 +# RUN source 
/home/ubuntu/venv/bin/activate && pip3 install -U 'mujoco-py<2.2,>=2.1' +# RUN source /home/ubuntu/venv/bin/activate && pip3 install dm-control==1.0.8 +# RUN source /home/ubuntu/venv/bin/activate && pip3 install pyrender==0.1.45 +# RUN source /home/ubuntu/venv/bin/activate && pip3 install dm2gym==0.2.0 + +# border +RUN cd $HOME && mkdir -p .border/model + +# # PyBulletGym +# RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==3.2.5 +# # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==2.7.1 +# RUN cd $HOME && \ +# git clone https://github.com/benelot/pybullet-gym.git && \ +# cd pybullet-gym && \ +# git checkout -b tmp bc68201c8101c4e30dde95f425647a0709ee2f29 && \ +# source /home/ubuntu/venv/bin/activate && \ +# pip install -e . +# # Tweaks for version incompatibility of gym and pybullet-gym +# RUN sed -i 's/return state, sum(self.rewards), bool(done), {}/return state, sum(self.rewards), bool(done), bool(done), {}/g' /home/ubuntu/pybullet-gym/pybulletgym/envs/roboschool/envs/locomotion/walker_base_env.py +# RUN sed -i 's/id='\''AntPyBulletEnv-v0'\'',/id='\''AntPyBulletEnv-v0'\'', order_enforce=False,/g' /home/ubuntu/pybullet-gym/pybulletgym/envs/__init__.py + +# .bashrc +RUN echo 'export LIBTORCH=$HOME/venv/lib/python3.10/site-packages/torch' >> ~/.bashrc +RUN echo 'export LD_LIBRARY_PATH=$LIBTORCH/lib' >> ~/.bashrc +RUN echo 'export LIBTORCH_CXX11_ABI=0' >> ~/.bashrc +RUN echo 'export PATH=$HOME/.local/bin:$PATH' >> ~/.bashrc +RUN echo 'export PYTHONPATH=$HOME/border/border-py-gym-env/examples:$PYTHONPATH' >> ~/.bashrc +RUN echo 'export CARGO_TARGET_DIR=$HOME/target' >> ~/.bashrc +RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc +RUN echo 'export MUJOCO_GL=glfw' >> ~/.bashrc +RUN echo 'source $HOME/venv/bin/activate' >> ~/.bashrc +RUN echo 'export RUSTFLAGS="-C target-feature=+fp16"' >> ~/.bashrc + +RUN rm /bin/sh && mv /bin/sh_tmp /bin/sh + +# USER root +# WORKDIR /home/ubuntu/border + +# ENV USER ubuntu +# CMD ["/bin/bash", "-l", "-c"] +# CMD source /home/ubuntu/.bashrc
diff --git a/docker/amd64_doc/README.md b/docker/amd64_doc/README.md new file mode 100644 index 00000000..092a9229 --- /dev/null +++ b/docker/amd64_doc/README.md @@ -0,0 +1,17 @@ +# Docker container for building documents + +## Build a Docker image + +```bash +# cd $REPO/docker/amd64_doc +sh build.sh +``` + +## Build document + +The following command builds the document and places it in `$REPO/doc`. + +```bash +# cd $REPO/docker/amd64_doc +sh doc.sh +```
diff --git a/docker/aarch64_doc/build.sh b/docker/amd64_doc/build.sh similarity index 57% rename from docker/aarch64_doc/build.sh rename to docker/amd64_doc/build.sh index 12e19a79..8040a4ff 100644 --- a/docker/aarch64_doc/build.sh +++ b/docker/amd64_doc/build.sh @@ -1,2 +1,3 @@ #!/bin/bash docker build -t border_doc . +#podman build -t border_doc .
diff --git a/docker/amd64_doc/doc.sh b/docker/amd64_doc/doc.sh new file mode 100644 index 00000000..5f70fbf2 --- /dev/null +++ b/docker/amd64_doc/doc.sh @@ -0,0 +1,13 @@ +docker run -it --rm \ + --name border_doc \ + --shm-size=512m \ + --volume="$(pwd)/../..:/home/ubuntu/border" \ + border_headless bash -l -c \ + "cd /home/ubuntu/border; source /root/venv/bin/activate; LIBTORCH_USE_PYTORCH=1 CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ."
+ +# podman run -it --rm \ +# --name border_doc \ +# --shm-size=512m \ +# --volume="$(pwd)/../..:/home/ubuntu/border" \ +# border_headless bash -l -c \ +# "cd /home/ubuntu/border; source /root/venv/bin/activate; LIBTORCH_USE_PYTORCH=1 CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." diff --git a/docker/amd64_doc/run.sh b/docker/amd64_doc/run.sh new file mode 100644 index 00000000..0ab00ccc --- /dev/null +++ b/docker/amd64_doc/run.sh @@ -0,0 +1,5 @@ +docker run -it --rm \ + --name border_headless \ + --shm-size=512m \ + --volume="$(pwd)/../..:/home/ubuntu/border" \ + border_headless bash -l -c "$@" diff --git a/docker/amd64_headless/.gitignore b/docker/amd64_headless/.gitignore new file mode 100644 index 00000000..cbf82bc0 --- /dev/null +++ b/docker/amd64_headless/.gitignore @@ -0,0 +1 @@ +log_** diff --git a/docker/amd64_headless/Dockerfile b/docker/amd64_headless/Dockerfile index d460674d..49249a54 100644 --- a/docker/amd64_headless/Dockerfile +++ b/docker/amd64_headless/Dockerfile @@ -1,13 +1,16 @@ -FROM ubuntu:focal-20221130 +FROM --platform=linux/amd64 nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 ENV DEBIAN_FRONTEND noninteractive RUN echo "Set disable_coredump false" >> /etc/sudo.conf RUN apt-get update -q && \ apt-get upgrade -yq && \ - apt-get install -yq wget curl git build-essential vim sudo libssl-dev - -# lsb-release locales bash-completion tzdata gosu && \ -# RUN rm -rf /var/lib/apt/lists/* + apt-get install -yq wget +RUN apt-get install -yq curl +RUN apt-get install -yq git +RUN apt-get install -yq build-essential +RUN apt-get install -yq vim +# RUN apt-get install -yq sudo +RUN apt-get install -yq libssl-dev # clang RUN apt install -y -q libclang-dev @@ -18,7 +21,7 @@ RUN apt update -y && \ DEBIAN_FRONTEND=noninteractive && \ apt install -y -q --no-install-recommends \ libsdl2-dev libsdl2-image-dev libsdl2-mixer-dev libsdl2-net-dev libsdl2-ttf-dev \ - libsdl-dev libsdl-image1.2-dev + libsdl-image1.2-dev libsdl1.2-dev # zip RUN apt install -y zip @@ -27,7 +30,7 @@ RUN apt install -y zip RUN apt install -y swig # python -RUN apt install -y python3.8 python3.8-dev python3.8-distutils python3.8-venv python3-pip +RUN apt install -y python3.10 python3.10-dev python3.10-distutils python3.10-venv python3-pip # cmake RUN apt install -y cmake @@ -44,32 +47,26 @@ RUN apt install -y patchelf libglfw3 libglfw3-dev # Cleanup RUN rm -rf /var/lib/apt/lists/* -# COPY test_mujoco_py.py /test_mujoco_py.py -# RUN chmod 777 /test_mujoco_py.py - -# Add user -RUN useradd --create-home --home-dir /home/ubuntu --shell /bin/bash --user-group --groups adm,sudo ubuntu && \ - echo ubuntu:ubuntu | chpasswd && \ - echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers - # Use bash RUN mv /bin/sh /bin/sh_tmp && ln -s /bin/bash /bin/sh -# User settings -USER ubuntu - # rustup RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y # python -RUN cd /home/ubuntu && python3 -m venv venv -RUN source /home/ubuntu/venv/bin/activate && pip3 install --upgrade pip -RUN source /home/ubuntu/venv/bin/activate && pip3 install pyyaml typing-extensions -RUN source /home/ubuntu/venv/bin/activate && pip3 install torch==1.12.0 -RUN source /home/ubuntu/venv/bin/activate && pip3 install ipython jupyterlab -RUN source /home/ubuntu/venv/bin/activate && pip3 install numpy==1.21.3 -RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 -RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 
+RUN cd /root && python3 -m venv venv +RUN source /root/venv/bin/activate && pip3 install --upgrade pip +RUN source /root/venv/bin/activate && pip3 install pyyaml typing-extensions +RUN source /root/venv/bin/activate && pip3 install torch==2.3.0 --index-url https://download.pytorch.org/whl/test/cu118 --timeout 300 +RUN source /root/venv/bin/activate && pip3 install ipython jupyterlab +RUN source /root/venv/bin/activate && pip3 install numpy==1.21.3 +RUN source /root/venv/bin/activate && pip3 install mujoco==2.3.7 +RUN source /root/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 +RUN source /root/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +RUN source /root/venv/bin/activate && pip3 install tensorboard==2.16.2 +RUN source /root/venv/bin/activate && pip3 install mlflow==2.11.1 +RUN source /root/venv/bin/activate && pip3 install tabulate==0.9.0 +RUN source /root/venv/bin/activate && pip3 install mlflow-export-import==1.2.0 # RUN source /home/ubuntu/venv/bin/activate && pip3 install robosuite==1.3.2 # RUN source /home/ubuntu/venv/bin/activate && pip3 install -U 'mujoco-py<2.2,>=2.1' # RUN source /home/ubuntu/venv/bin/activate && pip3 install dm-control==1.0.8 @@ -79,20 +76,6 @@ RUN source /home/ubuntu/venv/bin/activate && pip3 install gymnasium-robotics==1. # border RUN cd $HOME && mkdir -p .border/model -# Mujoco amd64 binary -RUN cd $HOME && \ - mkdir .mujoco && \ - cd .mujoco && \ - wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-x86_64.tar.gz -RUN cd $HOME/.mujoco && \ - tar zxf mujoco-2.1.1-linux-x86_64.tar.gz && \ - mkdir -p mujoco210/bin && \ - ln -sf $PWD/mujoco-2.1.1/lib/libmujoco.so.2.1.1 $PWD/mujoco210/bin/libmujoco210.so && \ - ln -sf $PWD/mujoco-2.1.1/lib/libglewosmesa.so $PWD/mujoco210/bin/libglewosmesa.so && \ - ln -sf $PWD/mujoco-2.1.1/include/ $PWD/mujoco210/include && \ - ln -sf $PWD/mujoco-2.1.1/model/ $PWD/mujoco210/model -# RUN cp /*.py $HOME - # # PyBulletGym # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==3.2.5 # # RUN source /home/ubuntu/venv/bin/activate && pip3 install pybullet==2.7.1 @@ -106,22 +89,22 @@ RUN cd $HOME/.mujoco && \ # RUN sed -i 's/return state, sum(self.rewards), bool(done), {}/return state, sum(self.rewards), bool(done), bool(done), {}/g' /home/ubuntu/pybullet-gym/pybulletgym/envs/roboschool/envs/locomotion/walker_base_env.py # RUN sed -i 's/id='\''AntPyBulletEnv-v0'\'',/id='\''AntPyBulletEnv-v0'\'', order_enforce=False,/g' /home/ubuntu/pybullet-gym/pybulletgym/envs/__init__.py -# Env vars -# RUN echo 'export LIBTORCH=$HOME/.local/lib/python3.8/site-packages/torch' >> ~/.bashrc -# RUN echo 'export LD_LIBRARY_PATH=$LIBTORCH/lib' >> ~/.bashrc -# RUN echo 'export PYTHONPATH=$HOME/border/border-py-gym-env/examples:$PYTHONPATH' >> ~/.bashrc -# RUN echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc +# .bashrc +RUN echo 'export LIBTORCH=$HOME/venv/lib/python3.10/site-packages/torch' >> ~/.bashrc +RUN echo 'export LD_LIBRARY_PATH=$LIBTORCH/lib' >> ~/.bashrc +RUN echo 'export LIBTORCH_CXX11_ABI=0' >> ~/.bashrc RUN echo 'export PATH=$HOME/.local/bin:$PATH' >> ~/.bashrc -ENV LIBTORCH_CXX11_ABI 0 -ENV LIBTORCH /home/ubuntu/venv/lib/python3.8/site-packages/torch -ENV LD_LIBRARY_PATH $LIBTORCH/lib -ENV PYTHONPATH /home/ubuntu/border/border-py-gym-env/examples:$PYTHONPATH +RUN echo 'export PYTHONPATH=$HOME/border/border-py-gym-env/examples:$PYTHONPATH' >> ~/.bashrc +RUN echo 'export CARGO_TARGET_DIR=$HOME/target' >> ~/.bashrc +RUN echo 'export 
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin' >> ~/.bashrc +RUN echo 'export MUJOCO_GL=glfw' >> ~/.bashrc +RUN echo 'source $HOME/venv/bin/activate' >> ~/.bashrc +RUN echo 'export RUSTFLAGS="-C target-feature=+fp16"' >> ~/.bashrc -USER root RUN rm /bin/sh && mv /bin/sh_tmp /bin/sh -USER ubuntu -WORKDIR /home/ubuntu/border +# USER root +# WORKDIR /home/ubuntu/border # ENV USER ubuntu # CMD ["/bin/bash", "-l", "-c"] diff --git a/docker/amd64_headless/README.md b/docker/amd64_headless/README.md index 2e3eeb1e..c63c3617 100644 --- a/docker/amd64_headless/README.md +++ b/docker/amd64_headless/README.md @@ -2,38 +2,32 @@ This directory contains scripts to build and run a docker container for training. -## Build - -The following command creates a container image locally, named `border_headless`. +## Build a Docker image ```bash -cd $REPO/docker/aarch64_headless +# cd $REPO/docker/aarch64_headless sh build.sh ``` -# Build document - -The following commands builds the document and places it as `$REPO/doc`. - -## Run +## Run training The following commands runs a program for training an agent. -The trained model will be saved in `$REPO/border/examples/model` directory, -which is mounted in the container. +The trained model will be saved in `$REPO/border/examples/model` directory. ### DQN * Cartpole + Note that this command starts an MLflow server, accessible via a web browser at $IP:8080 without any authentication. + ```bash - cd $REPO/docker/aarch64_headless - sh run.sh "source /home/ubuntu/venv/bin/activate && cargo run --example dqn_cartpole --features='tch' -- --train" + # cd $REPO/docker/amd64_headless + sh dqn_cartpole.sh ``` - * Use a directory, not mounted on the host, as a cargo target directory, - making compile faster on Mac, where access to mounted directories is slow. +## Start MLflow server for checking logs - ```bash - cd $REPO/docker/aarch64_headless - sh run.sh "source /home/ubuntu/venv/bin/activate && CARGO_TARGET_DIR=/home/ubuntu/target cargo run --example dqn_cartpole --features='tch' -- --train" - ``` +```bash +# cd $REPO/docker/amd64_headless +sh mlflow.sh +``` diff --git a/docker/amd64_headless/build.sh b/docker/amd64_headless/build.sh index 860261b3..86de24a2 100644 --- a/docker/amd64_headless/build.sh +++ b/docker/amd64_headless/build.sh @@ -1,2 +1,3 @@ #!/bin/bash docker build -t border_headless . +#podman build -t border_headless . diff --git a/docker/amd64_headless/doc.sh b/docker/amd64_headless/doc.sh index ac4d6098..ec32f92f 100644 --- a/docker/amd64_headless/doc.sh +++ b/docker/amd64_headless/doc.sh @@ -3,4 +3,11 @@ docker run -it --rm \ --shm-size=512m \ --volume="$(pwd)/../..:/home/ubuntu/border" \ border_headless bash -l -c \ - "CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." + "cd /home/ubuntu/border; source /root/venv/bin/activate; LIBTORCH_USE_PYTORCH=1 CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." + +# podman run -it --rm \ +# --name border_headless \ +# --shm-size=512m \ +# --volume="$(pwd)/../..:/home/ubuntu/border" \ +# border_headless bash -l -c \ +# "cd /home/ubuntu/border; source /root/venv/bin/activate; LIBTORCH_USE_PYTORCH=1 CARGO_TARGET_DIR=/home/ubuntu/target cargo doc --no-deps --document-private-items; cp -r /home/ubuntu/target/doc ." 
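A quick sanity check for the headless workflow documented in the README above (an editorial sketch, not part of the diff; it assumes the detached container started by the helper scripts added below is named `border_headless` and that MLflow listens on port 8080 on the host network, since `run_detach.sh` uses `--net host`):

```bash
# Confirm the MLflow tracking server answers on port 8080.
curl -s -o /dev/null -w "%{http_code}\n" http://localhost:8080
# Follow the output of the detached training container.
docker logs -f border_headless
```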
diff --git a/docker/amd64_headless/dqn_atari_async_tch.sh b/docker/amd64_headless/dqn_atari_async_tch.sh new file mode 100644 index 00000000..9f898b93 --- /dev/null +++ b/docker/amd64_headless/dqn_atari_async_tch.sh @@ -0,0 +1,17 @@ +rm -f log_dqn_atari_async_tch.txt +sh run_detach.sh " + source /root/venv/bin/activate && \ + pip3 install autorom && \ + mkdir $HOME/atari_rom && \ + AutoROM --install-dir $HOME/atari_rom --accept-license && \ + cd /home/ubuntu/border && \ + mlflow server --host 0.0.0.0 --port 8080 & \ + sleep 5 && \ + cd /home/ubuntu/border; \ + source /root/venv/bin/activate && \ + LIBTORCH_USE_PYTORCH=1 \ + LD_LIBRARY_PATH=/root/venv/lib/python3.10/site-packages/torch/lib \ + ATARI_ROM_DIR=$HOME/atari_rom \ + cargo run --release --example dqn_atari_async_tch --features=tch,border-async-trainer \ + -- ${1} --mlflow --n-actors 6 > \ + $HOME/border/docker/amd64_headless/log_dqn_atari_async_tch_${1}.txt 2>&1" diff --git a/docker/amd64_headless/dqn_atari_tch.sh b/docker/amd64_headless/dqn_atari_tch.sh new file mode 100644 index 00000000..c6320661 --- /dev/null +++ b/docker/amd64_headless/dqn_atari_tch.sh @@ -0,0 +1,17 @@ +rm -f log_dqn_atari_tch.txt +sh run_detach.sh " + source /root/venv/bin/activate && \ + pip3 install autorom && \ + mkdir $HOME/atari_rom && \ + AutoROM --install-dir $HOME/atari_rom --accept-license && \ + cd /home/ubuntu/border && \ + mlflow server --host 0.0.0.0 --port 8080 & \ + sleep 5 && \ + cd /home/ubuntu/border; \ + source /root/venv/bin/activate && \ + LIBTORCH_USE_PYTORCH=1 \ + LD_LIBRARY_PATH=/root/venv/lib/python3.10/site-packages/torch/lib \ + ATARI_ROM_DIR=$HOME/atari_rom \ + cargo run --release --example dqn_atari_tch --features=tch \ + -- ${1} --mlflow > \ + $HOME/border/docker/amd64_headless/log_dqn_atari_tch_${1}.txt 2>&1" diff --git a/docker/amd64_headless/dqn_cartpole.sh b/docker/amd64_headless/dqn_cartpole.sh new file mode 100644 index 00000000..dacf1726 --- /dev/null +++ b/docker/amd64_headless/dqn_cartpole.sh @@ -0,0 +1,12 @@ +sh run_detach.sh " + cd /home/ubuntu/border && \ + source /root/venv/bin/activate && \ + mlflow server --host 0.0.0.0 --port 8080 & \ + sleep 5 && \ + cd /home/ubuntu/border; \ + source /root/venv/bin/activate && \ + LIBTORCH_USE_PYTORCH=1 \ + PYTHONPATH=/home/ubuntu/border/border-py-gym-env/examples \ + cargo run --release --example dqn_cartpole --features=candle-core,cuda,cudnn \ + -- --train --mlflow > \ + $HOME/border/docker/amd64_headless/log_dqn_cartpole.txt 2>&1" diff --git a/docker/amd64_headless/mlflow.sh b/docker/amd64_headless/mlflow.sh new file mode 100644 index 00000000..2475ce09 --- /dev/null +++ b/docker/amd64_headless/mlflow.sh @@ -0,0 +1,5 @@ +sh run.sh " + cd /home/ubuntu/border && \ + source /root/venv/bin/activate && \ + mlflow server --host 0.0.0.0 --port 8080 +" \ No newline at end of file diff --git a/docker/amd64_headless/run.sh b/docker/amd64_headless/run.sh index 0ab00ccc..f8d56238 100644 --- a/docker/amd64_headless/run.sh +++ b/docker/amd64_headless/run.sh @@ -1,4 +1,6 @@ docker run -it --rm \ + --net host \ + --gpus all \ --name border_headless \ --shm-size=512m \ --volume="$(pwd)/../..:/home/ubuntu/border" \ diff --git a/docker/amd64_headless/run_detach.sh b/docker/amd64_headless/run_detach.sh new file mode 100644 index 00000000..962c93d7 --- /dev/null +++ b/docker/amd64_headless/run_detach.sh @@ -0,0 +1,7 @@ +docker run -dt --rm \ + --net host \ + --gpus all \ + --name border_headless \ + --shm-size=512m \ + --volume="$(pwd)/../..:/home/ubuntu/border" \ + border_headless bash -l -c 
"$@" diff --git a/docker/amd64_headless/sac_mujoco_async_tch.sh b/docker/amd64_headless/sac_mujoco_async_tch.sh new file mode 100644 index 00000000..5604dbe4 --- /dev/null +++ b/docker/amd64_headless/sac_mujoco_async_tch.sh @@ -0,0 +1,14 @@ +rm -f log_sac_mujoco_async_tch_${1}.txt +sh run_detach.sh " + source /root/venv/bin/activate && \ + cd /home/ubuntu/border && \ + mlflow server --host 0.0.0.0 --port 8080 & \ + sleep 5 && \ + cd /home/ubuntu/border; \ + source /root/venv/bin/activate && \ + LIBTORCH_USE_PYTORCH=1 \ + LD_LIBRARY_PATH=/root/venv/lib/python3.10/site-packages/torch/lib \ + PYTHONPATH=/home/ubuntu/border/border-py-gym-env/examples \ + cargo run --release --example sac_mujoco_async_tch --features=tch,border-async-trainer \ + -- ${1} --mlflow --n-actors 6 > \ + $HOME/border/docker/amd64_headless/log_sac_mujoco_async_tch_${1}.txt 2>&1" diff --git a/docker/amd64_headless/sac_mujoco_tch.sh b/docker/amd64_headless/sac_mujoco_tch.sh new file mode 100644 index 00000000..e09dbed8 --- /dev/null +++ b/docker/amd64_headless/sac_mujoco_tch.sh @@ -0,0 +1,14 @@ +rm -f log_sac_mujoco_tch_${1}.txt +sh run_detach.sh " + source /root/venv/bin/activate && \ + cd /home/ubuntu/border && \ + mlflow server --host 0.0.0.0 --port 8080 & \ + sleep 5 && \ + cd /home/ubuntu/border; \ + source /root/venv/bin/activate && \ + LIBTORCH_USE_PYTORCH=1 \ + LD_LIBRARY_PATH=/root/venv/lib/python3.10/site-packages/torch/lib \ + PYTHONPATH=/home/ubuntu/border/border-py-gym-env/examples \ + cargo run --release --example sac_mujoco_tch --features=tch \ + -- --env ${1} --mlflow > \ + $HOME/border/docker/amd64_headless/log_sac_mujoco_tch_${1}.txt 2>&1" diff --git a/ec2/install.sh b/ec2/install.sh new file mode 100644 index 00000000..06b3cdc4 --- /dev/null +++ b/ec2/install.sh @@ -0,0 +1,67 @@ +echo "\$nrconf{restart} = 'a';" | sudo tee /etc/needrestart/conf.d/50local.conf +$nrconf{restart} = 'a'; + +export DEBIAN_FRONTEND=noninteractive +sudo apt update -qy +sudo apt upgrade -qy +sudo apt install --no-install-recommends -y xfce4 xfce4-goodies tigervnc-standalone-server novnc websockify sudo xterm init systemd snapd vim net-tools curl wget git tzdata +sudo apt install -y dbus-x11 x11-utils x11-xserver-utils x11-apps +sudo apt install software-properties-common -y +sudo apt install -y build-essential libssl-dev swig cmake tmux htop libxkbcommon-dev + +# clang +sudo apt install -y -q libclang-dev + +# sdl +sudo DEBIAN_FRONTEND=noninteractive \ + apt install -y -q --no-install-recommends \ + libsdl2-dev libsdl2-image-dev libsdl2-mixer-dev libsdl2-net-dev libsdl2-ttf-dev \ + libsdl-image-1.2-dev + +# python +# sudo apt install -y python3.8 python3.8-dev python3.8-distutils python3.8-venv python3-pip +sudo apt install -y python3 python3-dev python3-distutils python3-venv python3-pip + +# headers required for building libtorch +sudo apt install -y libgoogle-glog-dev libgflags-dev + +# llvm, mesa for robosuite +sudo apt install -y llvm libosmesa6-dev + +# Used for Mujoco +sudo apt install -y patchelf libglfw3 libglfw3-dev + +# Cleanup +sudo rm -rf /var/lib/apt/lists/* + +# rustup +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + +# python +cd $HOME && python3 -m venv venv +source $HOME/venv/bin/activate && pip3 install --upgrade pip +source $HOME/venv/bin/activate && pip3 install pyyaml typing-extensions +# source $HOME/venv/bin/activate && pip3 install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir +source $HOME/venv/bin/activate && pip3 install 
torch==2.3.0 --index-url https://download.pytorch.org/whl/cu121 --no-cache-dir +# source $HOME/venv/bin/activate && pip3 install torch==2.0.1 +source $HOME/venv/bin/activate && pip3 install ipython jupyterlab +source $HOME/venv/bin/activate && pip3 install numpy==1.21.3 +source $HOME/venv/bin/activate && pip3 install mujoco==2.3.7 +source $HOME/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 +source $HOME/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 +source $HOME/venv/bin/activate && pip3 install tensorboard==2.16.2 +source $HOME/venv/bin/activate && pip3 install mlflow==2.11.1 +source $HOME/venv/bin/activate && pip3 install tabulate==0.9.0 +source $HOME/venv/bin/activate && pip3 install mlflow-export-import==1.2.0 + +echo 'export LIBTORCH=$HOME/venv/lib/python3.10/site-packages/torch' >> ~/.bashrc +echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$LIBTORCH/lib' >> ~/.bashrc +echo 'export LIBTORCH_CXX11_ABI=0' >> ~/.bashrc +echo 'export PATH=$PATH:$HOME/.local/bin:$PATH' >> ~/.bashrc +echo 'export PYTHONPATH=$HOME/border/border-py-gym-env/examples:$PYTHONPATH' >> ~/.bashrc +echo 'source "$HOME/.cargo/env"' >> ~/.bashrc +echo 'source $HOME/venv/bin/activate' >> ~/.bashrc +echo 'export ATARI_ROM_DIR=$HOME/atari_rom' >> ~/.bashrc +echo 'alias tml="tmux list-sessions"' >> ~/.bashrc +echo 'alias tma="tmux a -t"' >> ~/.bashrc +echo 'alias tms="tmux new -s"' >> ~/.bashrc diff --git a/gpusoroban/.gitignore b/gpusoroban/.gitignore new file mode 100644 index 00000000..02b51cfe --- /dev/null +++ b/gpusoroban/.gitignore @@ -0,0 +1 @@ +export**/** diff --git a/gpusoroban/README.md b/gpusoroban/README.md index de248a52..ef475d7e 100644 --- a/gpusoroban/README.md +++ b/gpusoroban/README.md @@ -13,10 +13,24 @@ cd $HOME/border RUST_LOG=info PYTHONPATH=./border-py-gym-env/examples cargo run --example random_cartpole ``` -## Copy trained model parameter file in remote to local +## Copy trained model parameter file from remote to local ```bash -scp -r -i ~/.ssh/mykey.txt -P 20122 user@localhost:/home/user/border/border/examples/atari/model/dqn_pong border/examples/atari/model +sh scp_results.sh +``` + +## Export and copy MLflow experiments + +Experiment logs will be copied to `$PWD/export`. 
+ +```bash +sh export_and_scp_expr.sh +``` + +## Import MLflow experiments + +```bash +sh import_expr.sh ``` ## Install Atari ROM (optional) diff --git a/gpusoroban/export_and_scp_expr.sh b/gpusoroban/export_and_scp_expr.sh new file mode 100644 index 00000000..539b9455 --- /dev/null +++ b/gpusoroban/export_and_scp_expr.sh @@ -0,0 +1,14 @@ +# Export +ssh -i ~/.ssh/mykey.txt -p 20122 user@localhost 'mkdir ~/export' +ssh -i ~/.ssh/mykey.txt -p 20122 user@localhost \ + 'echo " +export MLFLOW_TRACKING_URI=http://localhost:8080 +/home/user/venv/bin/export-experiment --experiment Gym --output-dir /home/user/export/Gym +#/home/user/venv/bin/export-experiment --experiment Atari --output-dir /home/user/export/Atari +" > tmp.sh' + +ssh -i ~/.ssh/mykey.txt -p 20122 user@localhost 'bash tmp.sh' + +# Remote copy +rm -fr $PWD/export +scp -r -i ~/.ssh/mykey.txt -P 20122 user@localhost:/home/user/export $PWD diff --git a/gpusoroban/import_expr.sh b/gpusoroban/import_expr.sh new file mode 100644 index 00000000..0c1759a3 --- /dev/null +++ b/gpusoroban/import_expr.sh @@ -0,0 +1,4 @@ +export MLFLOW_TRACKING_URI=http://localhost:8080 + +import-experiment --experiment-name Gym --input-dir export/Gym +#import-experiment --experiment-name Atari --input-dir export/Atari diff --git a/gpusoroban/install.sh b/gpusoroban/install.sh index 95fa6178..baa158a0 100644 --- a/gpusoroban/install.sh +++ b/gpusoroban/install.sh @@ -36,18 +36,23 @@ curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y cd /home/user && python3 -m venv venv source /home/user/venv/bin/activate && pip3 install --upgrade pip source /home/user/venv/bin/activate && pip3 install pyyaml typing-extensions -source /home/user/venv/bin/activate && pip3 install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir +# source /home/user/venv/bin/activate && pip3 install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir +source /home/user/venv/bin/activate && pip3 install torch==2.3.0 --index-url https://download.pytorch.org/whl/cu121 --no-cache-dir +# source /home/user/venv/bin/activate && pip3 install torch==2.0.1 source /home/user/venv/bin/activate && pip3 install ipython jupyterlab source /home/user/venv/bin/activate && pip3 install numpy==1.21.3 source /home/user/venv/bin/activate && pip3 install mujoco==2.3.7 source /home/user/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0 source /home/user/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2 -source /home/user/venv/bin/activate && pip3 install tensorboard +source /home/user/venv/bin/activate && pip3 install tensorboard==2.16.2 +source /home/user/venv/bin/activate && pip3 install mlflow==2.11.1 +source /home/user/venv/bin/activate && pip3 install tabulate==0.9.0 +source /home/user/venv/bin/activate && pip3 install mlflow-export-import==1.2.0 echo 'export LIBTORCH=$HOME/venv/lib/python3.10/site-packages/torch' >> ~/.bashrc -echo 'export LD_LIBRARY_PATH=$LIBTORCH/lib' >> ~/.bashrc +echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$LIBTORCH/lib' >> ~/.bashrc echo 'export LIBTORCH_CXX11_ABI=0' >> ~/.bashrc -echo 'export PATH=$HOME/.local/bin:$PATH' >> ~/.bashrc +echo 'export PATH=$PATH:$HOME/.local/bin:$PATH' >> ~/.bashrc echo 'export PYTHONPATH=$HOME/border/border-py-gym-env/examples:$PYTHONPATH' >> ~/.bashrc echo 'source "$HOME/.cargo/env"' >> ~/.bashrc echo 'source $HOME/venv/bin/activate' >> ~/.bashrc diff --git a/gpusoroban/scp_results.sh b/gpusoroban/scp_results.sh new file mode 100644 index 
00000000..b889c04b --- /dev/null +++ b/gpusoroban/scp_results.sh @@ -0,0 +1,18 @@ +REMOTE_BASE_DIR=/home/user/border +LOCAL_BASE_DIR=$PWD/.. + +function copy_best_result() { + SRC=$REMOTE_BASE_DIR/$1/best + DST=$LOCAL_BASE_DIR/$1 + echo ==================================== + echo Copy result to $DST + scp -r -i ~/.ssh/mykey.txt -P 20122 user@localhost:$SRC $DST +} + +# DQN Pong +copy_best_result border/examples/atari/model/candle/dqn_pong +copy_best_result border/examples/atari/model/tch/dqn_pong + +# SAC Ant +copy_best_result border/examples/ant/model/candle +copy_best_result border/examples/ant/model/tch
diff --git a/singularity/abci/README.md b/singularity/abci/README.md index d7ccd744..28e0984d 100644 --- a/singularity/abci/README.md +++ b/singularity/abci/README.md @@ -2,14 +2,84 @@ This directory contains scripts to build and run a singularity container for training on [ABCI](https://abci.ai/). -## Build +## Preparation + +### Login to the interactive node + +```bash +# Login to the access server +ssh -i $ABCI_IDENTITY_FILE -L 10022:es:22 -l $ABCI_USER_NAME as.abci.ai +``` + +```bash +# Login to the interactive node +ssh -i $ABCI_IDENTITY_FILE -p 10022 -l $ABCI_USER_NAME localhost +``` + +### Build singularity container image + +Build the SIF image on an interactive node. ```bash +# in $HOME +git clone https://github.com/taku-y/border.git +``` + +```bash +cd border/singularity/abci sh build.sh ``` -## Run +### Install MLflow (optional) ```bash -qsub -g group [option] dqn_cartpole.sh +cd $HOME +module load python/3.10 +python3 -m venv venv +source venv/bin/activate +pip3 install mlflow ``` + +### Install AutoROM (optional) + +```bash +cd $HOME +source venv/bin/activate +pip3 install autorom +mkdir atari_rom +AutoROM --install-dir atari_rom +``` + +## Run training + +### Submit training job + +```bash +cd dqn_cartpole +qsub -g [group_id] dqn_cartpole.sh +``` + +## Open MLflow in an interactive node + +### Login to the compute node + +```bash +qrsh -g $ABCI_GROUP_NAME -l rt_F=1 +``` + +### Run MLflow server + +```bash +module load python/3.10 +source venv/bin/activate +cd border +mlflow server --host 0.0.0.0 --port 8080 +``` + +### Port forwarding + +```bash +ssh -N -L 8080:host_name:8080 -l $ABCI_USER_NAME -i $ABCI_IDENTITY_FILE -p 10022 localhost +``` + +Access `localhost:8080` in your browser to open the MLflow UI.
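One detail the README above leaves implicit: `host_name` in the port-forward command is the compute node on which the MLflow server runs. A small sketch of how it could be filled in and the tunnel verified (generic `hostname`/`curl` usage, an editorial assumption rather than part of the diff):

```bash
# On the compute node (inside the qrsh session): print the name to substitute
# for host_name in the ssh -N -L command above.
hostname
# On the local machine, once the tunnel is up, the MLflow UI should respond:
curl -sI http://localhost:8080 | head -n 1
```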
diff --git a/singularity/abci/border.def b/singularity/abci/border.def index afa5bea3..33f83a29 100644 --- a/singularity/abci/border.def +++ b/singularity/abci/border.def @@ -5,13 +5,20 @@ From: border_base.sif export CARGO_HOME=/opt/.cargo export RUSTUP_HOME=/opt/.rust export PATH=$PATH:$CARGO_HOME/bin + export PATH=/usr/local/cuda-11.8/bin:$PATH + export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64:$LD_LIBRARY_PATH + export LIBTORCH=/opt/venv/lib64/python3.10/site-packages/torch + export LD_LIBRARY_PATH=$LIBTORCH/lib:$LD_LIBRARY_PATH + export LIBTORCH_CXX11_ABI=0 + export PYTHONPATH=/root/border/border-py-gym-env/examples %runscript cp -r /opt/.cargo /root export LIBTORCH_CXX11_ABI=0 export CARGO_HOME=/root/.cargo export CARGO_TARGET_DIR=/root/target - export LIBTORCH=/opt/venv/lib/python3.8/site-packages/torch + export LIBTORCH=/opt/venv/lib/python3.10/site-packages/torch export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$LIBTORCH/lib export PYTHONPATH=/root/border/border-py-gym-env/examples + export PYO3_PYTHON=python3.10 bash -c "cd /root/border; source /opt/venv/bin/activate; $@"
diff --git a/singularity/abci/border_base.def b/singularity/abci/border_base.def index ac1b973c..3cce8d23 100644 --- a/singularity/abci/border_base.def +++ b/singularity/abci/border_base.def @@ -1,6 +1,5 @@ Bootstrap: docker -From: ubuntu:20.04 -# From: nvidia/cuda:12.0.0-devel-ubuntu20.04 +From: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 %environment export CARGO_HOME=/opt/.cargo @@ -9,10 +8,17 @@ From: ubuntu:20.04 %post export DEBIAN_FRONTEND=noninteractive - echo "Set disable_coredump false" >> /etc/sudo.conf apt-get update -q && \ - apt-get upgrade -yq && \ + apt-get upgrade -yq + + # python + apt install -y gnupg2 curl + apt install software-properties-common -y + add-apt-repository ppa:deadsnakes/ppa -y + apt install -y python3.10 python3.10-dev python3.10-distutils python3.10-venv python3-pip + apt-get install -yq wget curl git build-essential vim sudo libssl-dev + echo "Set disable_coredump false" >> /etc/sudo.conf # clang apt install -y -q libclang-dev @@ -20,11 +26,10 @@ From: ubuntu:20.04 # sdl apt install -y -q --no-install-recommends \ libsdl2-dev libsdl2-image-dev libsdl2-mixer-dev libsdl2-net-dev libsdl2-ttf-dev \ - libsdl-dev libsdl-image1.2-dev + libsdl-image1.2-dev - # zip - apt install -y zip swig python3.8 python3.8-dev python3.8-distutils python3.8-venv python3-pip \ - cmake libgoogle-glog-dev libgflags-dev + # zip, swig, python, cmake, misc + apt install -y zip swig cmake libgoogle-glog-dev libgflags-dev # llvm, mesa for robosuite apt install -y llvm libosmesa6-dev @@ -38,27 +43,19 @@ From: ubuntu:20.04 curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y # Python - python3 -m venv /opt/venv - bash -c "source /opt/venv/bin/activate && \ - pip3 install --upgrade pip && \ - pip3 install pyyaml typing-extensions torch==1.12.0 ipython jupyterlab numpy==1.21.3 \ - gym[box2d]==0.26.2 robosuite==1.3.2 pybullet==3.2.5 - " - # pip3 install -U 'mujoco-py<2.2,>=2.1' dm-control==1.0.9 pyrender==0.1.45 dm2gym==0.2.0 - # " - - # # Mujoco aarch64 binary - # cd /opt && \ - # mkdir .mujoco && \ - # cd .mujoco && \ - # wget https://github.com/deepmind/mujoco/releases/download/2.1.1/mujoco-2.1.1-linux-aarch64.tar.gz - # cd /opt/.mujoco && \ - # tar zxf mujoco-2.1.1-linux-aarch64.tar.gz && \ - # mkdir -p mujoco210/bin && \ - # ln -sf $PWD/mujoco-2.1.1/lib/libmujoco.so.2.1.1 $PWD/mujoco210/bin/libmujoco210.so && \ - # ln -sf $PWD/mujoco-2.1.1/lib/libglewosmesa.so $PWD/mujoco210/bin/libglewosmesa.so
&& \ - # ln -sf $PWD/mujoco-2.1.1/include/ $PWD/mujoco210/include && \ - # ln -sf $PWD/mujoco-2.1.1/model/ $PWD/mujoco210/model + python3.10 -m venv /opt/venv + bash -c "source /opt/venv/bin/activate && pip3 install --upgrade pip" + bash -c "source /opt/venv/bin/activate && pip3 install pyyaml typing-extensions" + bash -c "source /opt/venv/bin/activate && pip3 install torch==2.3.0 --index-url https://download.pytorch.org/whl/test/cu118" + bash -c "source /opt/venv/bin/activate && pip3 install ipython jupyterlab" + bash -c "source /opt/venv/bin/activate && pip3 install numpy==1.21.3" + bash -c "source /opt/venv/bin/activate && pip3 install mujoco==2.3.7" + bash -c "source /opt/venv/bin/activate && pip3 install gymnasium[box2d]==0.29.0" + bash -c "source /opt/venv/bin/activate && pip3 install gymnasium-robotics==1.2.2" + bash -c "source /opt/venv/bin/activate && pip3 install tensorboard==2.16.2" + bash -c "source /opt/venv/bin/activate && pip3 install mlflow==2.11.1" + bash -c "source /opt/venv/bin/activate && pip3 install tabulate==0.9.0" + bash -c "source /opt/venv/bin/activate && pip3 install mlflow-export-import==1.2.0" # PyBullet Gym cd /opt && \
diff --git a/singularity/abci/build.sh b/singularity/abci/build.sh index 06f7e161..b93bdd5d 100644 --- a/singularity/abci/build.sh +++ b/singularity/abci/build.sh @@ -1,2 +1,4 @@ -sudo singularity build --fakeroot border_base.sif border_base.def -sudo singularity build --fakeroot border.sif border.def +module load singularitypro +export SINGULARITY_TMPDIR=$SGE_LOCALDIR +singularity build --fakeroot border_base.sif border_base.def +singularity build --fakeroot border.sif border.def
diff --git a/singularity/abci/dqn_atari_tch/dqn_atari_tch.sh b/singularity/abci/dqn_atari_tch/dqn_atari_tch.sh new file mode 100644 index 00000000..d46de177 --- /dev/null +++ b/singularity/abci/dqn_atari_tch/dqn_atari_tch.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +#$-l rt_G.small=1 +#$-l h_rt=48:00:00 +#$-j y +#$-cwd + +source $HOME/.bashrc +PATH_TO_BORDER=$HOME/border +source /etc/profile.d/modules.sh +module load singularitypro +cd $PATH_TO_BORDER/singularity/abci +sh run.sh "mlflow server --host 127.0.0.1 --port 8080 & \ + sleep 5 && \ + ATARI_ROM_DIR=$HOME/atari_rom cargo run --release --example dqn_atari_tch --features=tch -- $1 --mlflow"
diff --git a/singularity/abci/dqn_cartpole/dqn_cartpole.sh b/singularity/abci/dqn_cartpole/dqn_cartpole.sh index 7a5cce4c..fdb421cf 100644 --- a/singularity/abci/dqn_cartpole/dqn_cartpole.sh +++ b/singularity/abci/dqn_cartpole/dqn_cartpole.sh @@ -5,7 +5,10 @@ #$-cwd source $HOME/.bashrc +PATH_TO_BORDER=$HOME/border source /etc/profile.d/modules.sh module load singularitypro cd $PATH_TO_BORDER/singularity/abci -sh run.sh "cargo run --release --example dqn_cartpole --features=tch" +sh run.sh "mlflow server --host 127.0.0.1 --port 8080 & \ + sleep 5 && \ + cargo run --release --example dqn_cartpole --features=candle-core,cuda,cudnn -- --train --mlflow"
diff --git a/singularity/abci/dqn_cartpole_tch/dqn_cartpole_tch.sh b/singularity/abci/dqn_cartpole_tch/dqn_cartpole_tch.sh new file mode 100644 index 00000000..ea2e16c2 --- /dev/null +++ b/singularity/abci/dqn_cartpole_tch/dqn_cartpole_tch.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#$-l rt_G.small=1 +#$-j y +#$-cwd + +source $HOME/.bashrc +PATH_TO_BORDER=$HOME/border +source /etc/profile.d/modules.sh +module load singularitypro +cd $PATH_TO_BORDER/singularity/abci +sh run.sh "mlflow server --host 127.0.0.1 --port 8080 & \ + sleep 5 && \ + cargo run --release --example dqn_cartpole_tch
--features=tch -- --train --mlflow" diff --git a/singularity/abci/run.sh b/singularity/abci/run.sh index 4db06bc7..cfd826b5 100644 --- a/singularity/abci/run.sh +++ b/singularity/abci/run.sh @@ -1 +1,3 @@ -singularity run --fakeroot border.sif "$1" +module load cuda/11.8 +module load cudnn/8.8.1 +SINGULARITY_TMPDIR=$SGE_LOCALDIR singularity run --nv --fakeroot border.sif "$1" diff --git a/singularity/abci/sac_mujoco/sac_mujoco.sh b/singularity/abci/sac_mujoco/sac_mujoco.sh new file mode 100644 index 00000000..af3afb95 --- /dev/null +++ b/singularity/abci/sac_mujoco/sac_mujoco.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +#$-l rt_G.small=1 +#$-l h_rt=24:00:00 +#$-j y +#$-cwd + +source $HOME/.bashrc +PATH_TO_BORDER=$HOME/border +source /etc/profile.d/modules.sh +module load singularitypro +cd $PATH_TO_BORDER/singularity/abci +sh run.sh "mlflow server --host 127.0.0.1 --port 8080 & \ + sleep 5 && \ + cargo run --release --example sac_mujoco --features=candle-core,cuda,cudnn -- --train --mlflow --env $1" diff --git a/singularity/abci/sac_mujoco_tch/sac_mujoco_tch.sh b/singularity/abci/sac_mujoco_tch/sac_mujoco_tch.sh new file mode 100644 index 00000000..bbfe82ad --- /dev/null +++ b/singularity/abci/sac_mujoco_tch/sac_mujoco_tch.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +#$-l rt_G.small=1 +#$-l h_rt=24:00:00 +#$-j y +#$-cwd + +source $HOME/.bashrc +PATH_TO_BORDER=$HOME/border +source /etc/profile.d/modules.sh +module load singularitypro +cd $PATH_TO_BORDER/singularity/abci +sh run.sh "mlflow server --host 127.0.0.1 --port 8080 & \ + sleep 5 && \ + cargo run --release --example sac_mujoco_tch --features=tch -- --train --mlflow --env $1"
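The SAC job scripts above take the environment name as their first argument and forward it to the example binary via `--env`. A usage sketch following the `qsub` convention shown in the README (the group ID placeholder and the `Ant-v4` environment name are assumptions, not taken from the diff):

```bash
# From $REPO/singularity/abci: submit a tch-based SAC training job on ABCI.
cd sac_mujoco_tch
qsub -g [group_id] sac_mujoco_tch.sh Ant-v4
```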