Merge pull request #94 from taku-y/object_safe
Object safe
taku-y authored Apr 26, 2024
2 parents a9d80f1 + 2524dd0 commit 04c6eb4
Showing 25 changed files with 258 additions and 256 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
@@ -4,8 +4,9 @@

### Added

* Support MLflow tracking (`border-mlflow-tracking`) (https://github.com/taku-y/border/issues/2).
* Add candle agent (`border-candle-agent`)
* Support MLflow tracking (`border-mlflow-tracking`) (https://github.com/taku-y/border/issues/2)
* Add candle agent (`border-candle-agent`) (https://github.com/taku-y/border/issues/1)
* Split policy trait into two traits, one for sampling (`Policy`) and the other for configuration (`Configurable`) (https://github.com/taku-y/border/issues/12)

### Changed

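The shape of the split can be read off the impls further down in this diff. Below is a minimal sketch of the two traits, assuming `border_core::Env` and only the items exercised here; the real definitions in `border-core` may carry additional provided methods.

```rust
// Minimal sketch of the split traits, inferred from the impls in this diff;
// the actual definitions in border-core may differ in detail.
use border_core::Env;

pub trait Policy<E: Env> {
    /// Sample an action for the given observation.
    fn sample(&mut self, obs: &E::Obs) -> E::Act;
}

pub trait Configurable<E: Env> {
    /// Configuration used to construct the implementing type.
    type Config;

    /// Build an instance from its configuration.
    fn build(config: Self::Config) -> Self;
}
```

With `sample` as its only item, `Policy<E>` has no associated types or `Self`-returning constructors, so it can be used as a trait object (e.g. `Box<dyn Policy<E>>`), which is presumably the point of the `object_safe` branch; `Configurable<E>`, with its associated `Config` type and `build` constructor, stays on the concrete type.
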
6 changes: 3 additions & 3 deletions border-async-trainer/src/actor/base.rs
@@ -1,5 +1,5 @@
use crate::{ActorStat, PushedItemMessage, ReplayBufferProxy, ReplayBufferProxyConfig, SyncModel};
use border_core::{Agent, Env, ReplayBufferBase, Sampler, StepProcessor};
use border_core::{Agent, Configurable, Env, ReplayBufferBase, Sampler, StepProcessor};
use crossbeam_channel::Sender;
use log::info;
use std::{
@@ -32,7 +32,7 @@ use std::{
/// [`AsyncTrainer`]: crate::AsyncTrainer
pub struct Actor<A, E, P, R>
where
A: Agent<E, R> + SyncModel,
A: Agent<E, R> + Configurable<E> + SyncModel,
E: Env,
P: StepProcessor<E>,
R: ReplayBufferBase<PushedItem = P::Output>,
@@ -53,7 +53,7 @@ where

impl<A, E, P, R> Actor<A, E, P, R>
where
A: Agent<E, R> + SyncModel,
A: Agent<E, R> + Configurable<E> + SyncModel,
E: Env,
P: StepProcessor<E>,
R: ReplayBufferBase<PushedItem = P::Output>,
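
The extra `Configurable<E>` bound recurs on every async-training type in this PR. Presumably each actor builds its own agent instance from a shared configuration, and that constructor now lives on `Configurable` rather than `Policy`. A hedged sketch of the pattern; the helper and its exact bounds are illustrative, not part of the crate:

```rust
// Illustrative helper, not part of border-async-trainer: build a fresh agent
// inside a worker from a cloned config. This is the kind of call that now
// requires `Configurable<E>` alongside `Agent<E, R>`.
use border_core::{Agent, Configurable, Env, ReplayBufferBase};

fn build_agent_for_worker<A, E, R>(config: &<A as Configurable<E>>::Config) -> A
where
    A: Agent<E, R> + Configurable<E>,
    E: Env,
    R: ReplayBufferBase,
    <A as Configurable<E>>::Config: Clone,
{
    <A as Configurable<E>>::build(config.clone())
}
```
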
6 changes: 3 additions & 3 deletions border-async-trainer/src/actor_manager/base.rs
@@ -1,7 +1,7 @@
use crate::{
Actor, ActorManagerConfig, ActorStat, PushedItemMessage, ReplayBufferProxyConfig, SyncModel,
};
use border_core::{Agent, Env, ReplayBufferBase, StepProcessor};
use border_core::{Agent, Configurable, Env, ReplayBufferBase, StepProcessor};
use crossbeam_channel::{bounded, /*unbounded,*/ Receiver, Sender};
use log::info;
use std::{
@@ -18,7 +18,7 @@ use std::{
/// * From the [`Actor`]s for pushing sample batch to the `LearnerManager`.
pub struct ActorManager<A, E, R, P>
where
A: Agent<E, R> + SyncModel,
A: Agent<E, R> + Configurable<E> + SyncModel,
E: Env,
P: StepProcessor<E>,
R: ReplayBufferBase<PushedItem = P::Output>,
@@ -65,7 +65,7 @@ where

impl<A, E, R, P> ActorManager<A, E, R, P>
where
A: Agent<E, R> + SyncModel,
A: Agent<E, R> + Configurable<E> + SyncModel,
E: Env,
P: StepProcessor<E>,
R: ReplayBufferBase<PushedItem = P::Output> + Send + 'static,
6 changes: 3 additions & 3 deletions border-async-trainer/src/async_trainer/base.rs
@@ -1,7 +1,7 @@
use crate::{AsyncTrainStat, AsyncTrainerConfig, PushedItemMessage, SyncModel};
use border_core::{
record::{Record, RecordValue::Scalar, Recorder},
Agent, Env, Evaluator, ReplayBufferBase,
Agent, Configurable, Env, Evaluator, ReplayBufferBase,
};
use crossbeam_channel::{Receiver, Sender};
use log::info;
@@ -54,7 +54,7 @@ use std::{
/// [`SyncModel::ModelInfo`]: crate::SyncModel::ModelInfo
pub struct AsyncTrainer<A, E, R>
where
A: Agent<E, R> + SyncModel,
A: Agent<E, R> + Configurable<E> + SyncModel,
E: Env,
// R: ReplayBufferBase + Sync + Send + 'static,
R: ReplayBufferBase,
Expand Down Expand Up @@ -101,7 +101,7 @@ where

impl<A, E, R> AsyncTrainer<A, E, R>
where
A: Agent<E, R> + SyncModel,
A: Agent<E, R> + Configurable<E> + SyncModel,
E: Env,
// R: ReplayBufferBase + Sync + Send + 'static,
R: ReplayBufferBase,
4 changes: 2 additions & 2 deletions border-async-trainer/src/util.rs
@@ -2,7 +2,7 @@
use crate::{
actor_stats_fmt, ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel,
};
use border_core::{Agent, DefaultEvaluator, Env, ReplayBufferBase, StepProcessor};
use border_core::{Agent, Configurable, DefaultEvaluator, Env, ReplayBufferBase, StepProcessor};
use border_tensorboard::TensorboardRecorder;
use crossbeam_channel::unbounded;
use log::info;
@@ -40,7 +40,7 @@ pub fn train_async<A, E, R, S, P>(
actor_man_config: &ActorManagerConfig,
async_trainer_config: &AsyncTrainerConfig,
) where
A: Agent<E, R> + SyncModel,
A: Agent<E, R> + Configurable<E> + SyncModel,
E: Env,
R: ReplayBufferBase<PushedItem = S::Output> + Send + 'static,
S: StepProcessor<E>,
12 changes: 7 additions & 5 deletions border-atari-env/examples/random_pong.rs
@@ -3,7 +3,7 @@ use border_atari_env::{
BorderAtariAct, BorderAtariActRawFilter, BorderAtariEnv, BorderAtariEnvConfig, BorderAtariObs,
BorderAtariObsRawFilter,
};
use border_core::{DefaultEvaluator, Env as _, Evaluator, Policy};
use border_core::{Configurable, DefaultEvaluator, Env as _, Evaluator, Policy};

type Obs = BorderAtariObs;
type Act = BorderAtariAct;
@@ -22,17 +22,19 @@ struct RandomPolicy {
}

impl Policy<Env> for RandomPolicy {
fn sample(&mut self, _: &Obs) -> Act {
fastrand::u8(..self.n_acts as u8).into()
}
}

impl Configurable<Env> for RandomPolicy {
type Config = RandomPolicyConfig;

fn build(config: Self::Config) -> Self {
Self {
n_acts: config.n_acts,
}
}

fn sample(&mut self, _: &Obs) -> Act {
fastrand::u8(..self.n_acts as u8).into()
}
}

fn env_config(name: String) -> EnvConfig {
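
With the split, construction and sampling for the example policy come from different traits. A usage sketch; constructing `RandomPolicyConfig` with a struct literal is an assumption (only its `n_acts` field is visible in this diff), and the action count is just an example value:

```rust
// Hypothetical usage of RandomPolicy above; the struct-literal construction
// of RandomPolicyConfig is assumed from the `config.n_acts` access.
use border_core::{Configurable, Policy};

fn pick_action(obs: &Obs) -> Act {
    // `build` now comes from Configurable; `sample` stays on Policy.
    let mut policy = RandomPolicy::build(RandomPolicyConfig { n_acts: 6 });
    policy.sample(obs)
}
```
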
12 changes: 7 additions & 5 deletions border-atari-env/src/util/test.rs
@@ -7,7 +7,7 @@ use anyhow::Result;
use border_core::{
record::Record,
replay_buffer::{SimpleReplayBuffer, SubBatch},
Agent as Agent_, Policy, ReplayBufferBase,
Agent as Agent_, Configurable, Policy, ReplayBufferBase,
};
use std::ptr::copy;

@@ -146,6 +146,12 @@ pub struct RandomAgent {
}

impl Policy<Env> for RandomAgent {
fn sample(&mut self, _: &Obs) -> Act {
fastrand::u8(..self.n_acts as u8).into()
}
}

impl Configurable<Env> for RandomAgent {
type Config = RandomAgentConfig;

fn build(config: Self::Config) -> Self {
@@ -155,10 +161,6 @@ impl Policy<Env> for RandomAgent {
train: true,
}
}

fn sample(&mut self, _: &Obs) -> Act {
fastrand::u8(..self.n_acts as u8).into()
}
}

impl<R: ReplayBufferBase> Agent_<Env, R> for RandomAgent {
86 changes: 41 additions & 45 deletions border-candle-agent/src/dqn/base.rs
@@ -7,7 +7,7 @@ use crate::{
use anyhow::Result;
use border_core::{
record::{Record, RecordValue},
Agent, Env, Policy, ReplayBufferBase, StdBatchBase,
Agent, Configurable, Env, Policy, ReplayBufferBase, StdBatchBase,
};
use candle_core::{shape::D, DType, Device, Tensor};
use candle_nn::loss::mse;
@@ -20,15 +20,8 @@ use std::{fs, marker::PhantomData, path::Path};
/// DQN agent implemented with tch-rs.
pub struct Dqn<E, Q, R>
where
E: Env,
Q: SubModel1<Output = Tensor>,
R: ReplayBufferBase,
E::Obs: Into<Q::Input>,
E::Act: From<Q::Output>,
Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
R::Batch: StdBatchBase,
<R::Batch as StdBatchBase>::ObsBatch: Into<Q::Input>,
<R::Batch as StdBatchBase>::ActBatch: Into<Tensor>,
{
pub(in crate::dqn) soft_update_interval: usize,
pub(in crate::dqn) soft_update_counter: usize,
@@ -58,8 +51,6 @@ where
E: Env,
Q: SubModel1<Output = Tensor>,
R: ReplayBufferBase,
E::Obs: Into<Q::Input>,
E::Act: From<Q::Output>,
Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
R::Batch: StdBatchBase,
<R::Batch as StdBatchBase>::ObsBatch: Into<Q::Input>,
@@ -197,13 +188,49 @@ impl<E, Q, R> Policy<E> for Dqn<E, Q, R>
where
E: Env,
Q: SubModel1<Output = Tensor>,
R: ReplayBufferBase,
E::Obs: Into<Q::Input>,
E::Act: From<Q::Output>,
Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
R::Batch: StdBatchBase,
<R::Batch as StdBatchBase>::ObsBatch: Into<Q::Input>,
<R::Batch as StdBatchBase>::ActBatch: Into<Tensor>,
{
/// In evaluation mode, take a random action with probability 0.01.
fn sample(&mut self, obs: &E::Obs) -> E::Act {
let a = self.qnet.forward(&obs.clone().into()).detach();
let a = if self.train {
self.n_samples_act += 1;
match &mut self.explorer {
DqnExplorer::Softmax(softmax) => softmax.action(&a, &mut self.rng),
DqnExplorer::EpsilonGreedy(egreedy) => {
if self.record_verbose_level >= 2 {
let (act, best) = egreedy.action_with_best(&a, &mut self.rng);
if best {
self.n_samples_best_act += 1;
}
act
} else {
egreedy.action(&a, &mut self.rng)
}
}
}
} else {
if self.rng.gen::<f32>() < 0.01 {
let n_actions = a.dims()[1] as i64;
let a: i64 = self.rng.gen_range(0..n_actions);
Tensor::try_from(vec![a]).unwrap()
} else {
a.argmax(D::Minus1).unwrap().to_dtype(DType::I64).unwrap()
}
};
a.into()
}
}

impl<E, Q, R> Configurable<E> for Dqn<E, Q, R>
where
E: Env,
Q: SubModel1<Output = Tensor>,
E::Obs: Into<Q::Input>,
E::Act: From<Q::Output>,
Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
type Config = DqnConfig<Q>;

@@ -241,37 +268,6 @@ where
rng: SmallRng::seed_from_u64(42),
}
}

/// In evaluation mode, take a random action with probability 0.01.
fn sample(&mut self, obs: &E::Obs) -> E::Act {
let a = self.qnet.forward(&obs.clone().into()).detach();
let a = if self.train {
self.n_samples_act += 1;
match &mut self.explorer {
DqnExplorer::Softmax(softmax) => softmax.action(&a, &mut self.rng),
DqnExplorer::EpsilonGreedy(egreedy) => {
if self.record_verbose_level >= 2 {
let (act, best) = egreedy.action_with_best(&a, &mut self.rng);
if best {
self.n_samples_best_act += 1;
}
act
} else {
egreedy.action(&a, &mut self.rng)
}
}
}
} else {
if self.rng.gen::<f32>() < 0.01 {
let n_actions = a.dims()[1] as i64;
let a: i64 = self.rng.gen_range(0..n_actions);
Tensor::try_from(vec![a]).unwrap()
} else {
a.argmax(D::Minus1).unwrap().to_dtype(DType::I64).unwrap()
}
};
a.into()
}
}

impl<E, Q, R> Agent<E, R> for Dqn<E, Q, R>
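
For reference, the relocated `sample` implements the following evaluation-mode rule with a fixed exploration probability of 0.01, matching the doc comment above; in training mode the configured explorer (softmax or epsilon-greedy) is used instead:

```math
a =
\begin{cases}
\text{uniform over } \{0, \dots, n_{\text{act}} - 1\} & \text{with probability } 0.01,\\
\arg\max_{a'} Q_\theta(s, a') & \text{otherwise.}
\end{cases}
```
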
58 changes: 28 additions & 30 deletions border-candle-agent/src/sac/base.rs
@@ -6,7 +6,7 @@ use crate::{
use anyhow::Result;
use border_core::{
record::{Record, RecordValue},
Agent, Env, Policy, ReplayBufferBase, StdBatchBase,
Agent, Configurable, Env, Policy, ReplayBufferBase, StdBatchBase,
};
use candle_core::{Device, Tensor, D};
use candle_nn::loss::mse;
@@ -27,18 +27,10 @@ fn normal_logp(x: &Tensor) -> Result<Tensor> {
/// Soft actor critic (SAC) agent.
pub struct Sac<E, Q, P, R>
where
E: Env,
Q: SubModel2<Output = ActionValue>,
P: SubModel1<Output = (ActMean, ActStd)>,
R: ReplayBufferBase,
E::Obs: Into<Q::Input1> + Into<P::Input>,
E::Act: Into<Q::Input2>,
Q::Input2: From<ActMean>,
Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
R::Batch: StdBatchBase,
<R::Batch as StdBatchBase>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
<R::Batch as StdBatchBase>::ActBatch: Into<Q::Input2> + Into<Tensor>,
{
pub(super) qnets: Vec<Critic<Q>>,
pub(super) qnets_tgt: Vec<Critic<Q>>,
@@ -225,15 +217,37 @@ where
E: Env,
Q: SubModel2<Output = ActionValue>,
P: SubModel1<Output = (ActMean, ActStd)>,
R: ReplayBufferBase,
E::Obs: Into<Q::Input1> + Into<P::Input>,
E::Act: Into<Q::Input2> + From<Tensor>,
Q::Input2: From<ActMean>,
Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
R::Batch: StdBatchBase,
<R::Batch as StdBatchBase>::ObsBatch: Into<Q::Input1> + Into<P::Input> + Clone,
<R::Batch as StdBatchBase>::ActBatch: Into<Q::Input2> + Into<Tensor>,
{
fn sample(&mut self, obs: &E::Obs) -> E::Act {
let obs = obs.clone().into();
let (mean, lstd) = self.pi.forward(&obs);
let std = lstd
.clamp(self.min_lstd, self.max_lstd)
.unwrap()
.exp()
.unwrap();
let act = if self.train {
((std * mean.randn_like(0., 1.).unwrap()).unwrap() + mean).unwrap()
} else {
mean
};
act.tanh().unwrap().into()
}
}

impl<E, Q, P, R> Configurable<E> for Sac<E, Q, P, R>
where
E: Env,
Q: SubModel2<Output = ActionValue>,
P: SubModel1<Output = (ActMean, ActStd)>,
E::Obs: Into<Q::Input1> + Into<P::Input>,
E::Act: Into<Q::Input2> + From<Tensor>,
Q::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
P::Config: DeserializeOwned + Serialize + OutDim + std::fmt::Debug + PartialEq + Clone,
{
type Config = SacConfig<Q, P>;

@@ -274,22 +288,6 @@ where
phantom: PhantomData,
}
}

fn sample(&mut self, obs: &E::Obs) -> E::Act {
let obs = obs.clone().into();
let (mean, lstd) = self.pi.forward(&obs);
let std = lstd
.clamp(self.min_lstd, self.max_lstd)
.unwrap()
.exp()
.unwrap();
let act = if self.train {
((std * mean.randn_like(0., 1.).unwrap()).unwrap() + mean).unwrap()
} else {
mean
};
act.tanh().unwrap().into()
}
}

impl<E, Q, P, R> Agent<E, R> for Sac<E, Q, P, R>
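
The relocated SAC `sample` draws from a tanh-squashed Gaussian, as the moved code shows. In training mode

```math
a = \tanh\!\bigl(\mu_\theta(s) + \sigma_\theta(s) \odot \epsilon\bigr),
\qquad \epsilon \sim \mathcal{N}(0, I),
\qquad \sigma_\theta(s) = \exp\bigl(\operatorname{clamp}(\log\sigma_\theta(s),\ \mathrm{min\_lstd},\ \mathrm{max\_lstd})\bigr),
```

while in evaluation mode the noise term is dropped and the action is `tanh(mean)`.
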
