SAC cleanups (#29)
yunjhongwu authored Nov 24, 2023
1 parent 3073ffd commit 0f5a774
Showing 4 changed files with 62 additions and 60 deletions.
2 changes: 1 addition & 1 deletion burn-rl/src/agent/mod.rs
@@ -10,5 +10,5 @@ pub use ppo::config::PPOTrainingConfig;
pub use ppo::model::{PPOModel, PPOOutput};
pub use sac::agent::SAC;
pub use sac::config::SACTrainingConfig;
pub use sac::model::{SACActor, SACCritic, SACTemperature};
pub use sac::model::{SACActor, SACCritic, SACNets, SACTemperature};
pub use sac::optimizer::SACOptimizer;
68 changes: 33 additions & 35 deletions burn-rl/src/agent/sac/agent.rs
@@ -1,4 +1,4 @@
use crate::agent::sac::model::SACTemperature;
use crate::agent::sac::model::{SACNets, SACTemperature};
use crate::agent::{SACActor, SACCritic, SACOptimizer, SACTrainingConfig};
use crate::base::agent::Agent;
use crate::base::environment::Environment;
@@ -85,12 +85,7 @@ impl<E: Environment, B: ADBackend, Actor: SACActor<B> + ADModule<B>> SAC<E, B, A
#[allow(clippy::too_many_arguments)]
pub fn train<const CAP: usize, Critic: SACCritic<B> + ADModule<B>>(
&mut self,
mut actor: Actor,
mut critic_1: Critic,
mut critic_1_target: Critic,
mut critic_2: Critic,
mut critic_2_target: Critic,
mut temperature: SACTemperature<B>,
mut nets: SACNets<B, Actor, Critic>,
memory: &Memory<E, B, CAP>,
optimizer: &mut SACOptimizer<
B,
@@ -101,24 +96,24 @@
impl Optimizer<SACTemperature<B>, B> + Sized,
>,
config: &SACTrainingConfig,
) -> (Actor, Critic, Critic, Critic, Critic, SACTemperature<B>) {
) -> SACNets<B, Actor, Critic> {
let action_dim = <<E as Environment>::ActionType as Action>::size();
let sample_indices = sample_indices((0..memory.len()).collect(), config.batch_size);
let state_batch = get_batch(memory.states(), &sample_indices, ref_to_state_tensor);

let action_prob = actor.forward(state_batch.clone());
let action_prob = nets.actor.forward(state_batch.clone());
let log_prob = action_prob.clone().clamp_min(config.min_probability).log();
let q1 = critic_1.forward(state_batch.clone());
let q2 = critic_2.forward(state_batch.clone());
let q1 = nets.critic_1.forward(state_batch.clone());
let q2 = nets.critic_2.forward(state_batch.clone());
let q_min = elementwise_min(q1, q2);
let log_alpha = temperature.forward();
let log_alpha = nets.temperature.forward();
let alpha = log_alpha.clone().exp();
let actor_loss = (action_prob.clone() * (alpha.clone() * log_prob.clone() - q_min))
.sum_dim(1)
.mean();
actor = update_parameters(
nets.actor = update_parameters(
actor_loss,
actor,
nets.actor,
&mut optimizer.actor_optimizer,
config.learning_rate.into(),
);
@@ -127,9 +122,9 @@ impl<E: Environment, B: ADBackend, Actor: SACActor<B> + ADModule<B>> SAC<E, B, A
let temperature_loss = -(log_alpha.clone()
* (entropy.clone().sub_scalar(action_dim as ElemType)).detach())
.mean();
temperature = update_parameters(
nets.temperature = update_parameters(
temperature_loss,
temperature,
nets.temperature,
&mut optimizer.temperature_optimizer,
config.learning_rate.into(),
);
@@ -140,49 +135,52 @@ impl<E: Environment, B: ADBackend, Actor: SACActor<B> + ADModule<B>> SAC<E, B, A
let reward_batch = get_batch(memory.rewards(), &sample_indices, ref_to_reward_tensor);
let not_done_batch = get_batch(memory.dones(), &sample_indices, ref_to_not_done_tensor);

let action_prob = actor.clone().no_grad().forward(next_state_batch.clone());
let action_prob = nets
.actor
.clone()
.no_grad()
.forward(next_state_batch.clone());

let q1_target_next = critic_1_target
let q1_target_next = nets
.critic_1_target
.clone()
.no_grad()
.forward(next_state_batch.clone());
let q2_target_next = critic_2_target.clone().no_grad().forward(next_state_batch);
let q2_target_next = nets
.critic_2_target
.clone()
.no_grad()
.forward(next_state_batch);
let q_min_target_next = elementwise_min(q1_target_next, q2_target_next);
let q_next = action_prob.clone() * (q_min_target_next - alpha.clone() * entropy);
let q_target =
reward_batch + not_done_batch.mul_scalar(config.gamma) * q_next.sum_dim(1).no_grad();

let q1 = critic_1
let q1 = nets
.critic_1
.forward(state_batch.clone())
.gather(1, action_batch.clone());
let critic_1_loss = MSELoss::default().forward(q_target.clone(), q1, Reduction::Sum);
critic_1 = update_parameters(
nets.critic_1 = update_parameters(
critic_1_loss,
critic_1,
nets.critic_1,
&mut optimizer.critic_1_optimizer,
config.learning_rate.into(),
);

let q2 = critic_2.forward(state_batch).gather(1, action_batch);
let q2 = nets.critic_2.forward(state_batch).gather(1, action_batch);
let critic_2_loss = MSELoss::default().forward(q_target, q2, Reduction::Sum);
critic_2 = update_parameters(
nets.critic_2 = update_parameters(
critic_2_loss,
critic_2,
nets.critic_2,
&mut optimizer.critic_2_optimizer,
config.learning_rate.into(),
);

SACCritic::soft_update(&mut critic_1_target, &critic_1, config.tau);
SACCritic::soft_update(&mut critic_2_target, &critic_2, config.tau);
SACCritic::soft_update(&mut nets.critic_1_target, &nets.critic_1, config.tau);
SACCritic::soft_update(&mut nets.critic_2_target, &nets.critic_2, config.tau);

(
actor,
critic_1,
critic_1_target,
critic_2,
critic_2_target,
temperature,
)
nets
}

pub fn valid(&self, actor: Actor) -> SAC<E, B::InnerBackend, Actor::InnerModule>
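
For readers skimming the refactor, the training hunks above implement the standard discrete-action SAC objectives; the formulas below are a reading aid, not part of the diff. Here pi(a|s) is action_prob (clamped below at config.min_probability before the log), alpha = exp(log_alpha) comes from SACTemperature, Q_1 and Q_2 are the two online critics, and the entropy target is the action dimension in this code; the entropy term itself is computed outside the hunks shown.

    \mathcal{L}_{\pi}    = \mathbb{E}_{s}\Big[\textstyle\sum_{a} \pi(a \mid s)\,\big(\alpha \log \pi(a \mid s) - \min(Q_1(s,a),\, Q_2(s,a))\big)\Big]
    \mathcal{L}_{\alpha} = -\,\mathbb{E}_{s}\Big[\log \alpha \cdot \big(\mathcal{H}(\pi(\cdot \mid s)) - \bar{\mathcal{H}}\big)\Big] \qquad \text{(the parenthesized term is detached, so only } \log\alpha \text{ receives gradients)}

The critic losses then regress each online critic, gathered at the stored actions, onto a bootstrapped target built from the minimum of the two no-grad target critics over next-state action probabilities, discounted by gamma on non-terminal transitions, using a summed MSE.
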
23 changes: 23 additions & 0 deletions burn-rl/src/agent/sac/model.rs
@@ -27,3 +27,26 @@ impl<B: Backend> SACTemperature<B> {
self.temperature.val()
}
}

pub struct SACNets<B: Backend, Actor: SACActor<B>, Critic: SACCritic<B>> {
pub actor: Actor,
pub critic_1: Critic,
pub critic_1_target: Critic,

pub critic_2: Critic,
pub critic_2_target: Critic,
pub temperature: SACTemperature<B>,
}

impl<B: Backend, Actor: SACActor<B>, Critic: SACCritic<B>> SACNets<B, Actor, Critic> {
pub fn new(actor: Actor, critic_1: Critic, critic_2: Critic) -> Self {
Self {
actor,
critic_1: critic_1.clone(),
critic_1_target: critic_1,
critic_2: critic_2.clone(),
critic_2_target: critic_2,
temperature: SACTemperature::<B>::default(),
}
}
}
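
The refactor keeps both target critics inside SACNets and still refreshes them with SACCritic::soft_update at the end of train. As a reading aid only, here is a minimal, framework-free sketch of the Polyak-style update such a call is assumed to perform; the name soft_update_params and the plain f32 slices are illustrative assumptions, not the burn-rl implementation, which works on burn module parameters.

    // Assumed convention: target <- tau * source + (1 - tau) * target.
    fn soft_update_params(target: &mut [f32], source: &[f32], tau: f32) {
        assert_eq!(target.len(), source.len());
        for (t, &s) in target.iter_mut().zip(source.iter()) {
            *t = tau * s + (1.0 - tau) * *t;
        }
    }

    fn main() {
        // With a small tau (such as config.tau), the target slowly tracks the online critic.
        let mut target = vec![0.0_f32, 0.0, 0.0];
        let online = vec![1.0_f32, 2.0, 3.0];
        soft_update_params(&mut target, &online, 0.005);
        println!("{:?}", target); // approximately [0.005, 0.01, 0.015]
    }

With tau = 1.0 this sketch copies the online critic outright, which mirrors how SACNets::new seeds critic_1_target and critic_2_target as clones of the freshly built critics.
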
29 changes: 5 additions & 24 deletions examples/src/sac.rs
@@ -5,7 +5,7 @@ use burn::optim::AdamWConfig;
use burn::tensor::activation::{relu, softmax};
use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::Tensor;
use burn_rl::agent::{SACActor, SACCritic, SACOptimizer, SACTemperature, SACTrainingConfig, SAC};
use burn_rl::agent::{SACActor, SACCritic, SACNets, SACOptimizer, SACTrainingConfig, SAC};
use burn_rl::base::{Action, Agent, ElemType, Environment, Memory, Model, State};

#[derive(Module, Debug)]
@@ -83,10 +83,8 @@ pub fn run<E: Environment, B: ADBackend>(num_episodes: usize, visualized: bool)

let mut actor = Actor::<B>::new(state_dim, DENSE_SIZE, action_dim);
let mut critic_1 = Critic::<B>::new(state_dim, DENSE_SIZE, action_dim);
let mut critic_1_target = critic_1.clone();
let mut critic_2 = Critic::<B>::new(state_dim, DENSE_SIZE, action_dim);
let mut critic_2_target = critic_2.clone();
let mut temperature = SACTemperature::<B>::default();
let mut nets = SACNets::<B, Actor<B>, Critic<B>>::new(actor, critic_1, critic_2);

let mut agent = MyAgent::default();

@@ -114,7 +112,7 @@ pub fn run<E: Environment, B: ADBackend>(num_episodes: usize, visualized: bool)
let mut state = env.state();

while !episode_done {
let action = MyAgent::<E, _>::react_with_model(&state, &actor);
let action = MyAgent::<E, _>::react_with_model(&state, &nets.actor);
let snapshot = env.step(action);

episode_reward +=
@@ -129,24 +127,7 @@ pub fn run<E: Environment, B: ADBackend>(num_episodes: usize, visualized: bool)
);

if config.batch_size < memory.len() {
(
actor,
critic_1,
critic_1_target,
critic_2,
critic_2_target,
temperature,
) = agent.train::<MEMORY_SIZE, _>(
actor,
critic_1,
critic_1_target,
critic_2,
critic_2_target,
temperature,
&memory,
&mut optimizer,
&config,
);
nets = agent.train::<MEMORY_SIZE, _>(nets, &memory, &mut optimizer, &config);
}

step += 1;
@@ -166,5 +147,5 @@ pub fn run<E: Environment, B: ADBackend>(num_episodes: usize, visualized: bool)
}
}

agent.valid(actor)
agent.valid(nets.actor)
}
