Fix detach; smooth L1 loss; realistic model; cleanups (#13)
yunjhongwu authored Nov 17, 2023
1 parent dc79139 commit 7b23e07
Showing 5 changed files with 178 additions and 60 deletions.
10 changes: 6 additions & 4 deletions in src/agent/dqn.rs
@@ -1,7 +1,8 @@
use crate::base::{Action, Memory, Model, State};
use crate::components::agent::Agent;
use crate::components::env::Environment;
use burn::nn::loss::{MSELoss, Reduction};
use crate::utils::SmoothL1Loss;
use burn::nn::loss::Reduction;
use burn::optim::{GradientsParams, Optimizer};
use burn::tensor::backend::ADBackend;
use burn::tensor::{ElementConversion, Tensor};
@@ -72,7 +73,7 @@ impl<E: Environment, B: ADBackend, M: Model<B>, const EVAL: bool> Agent for Dqn<
}

impl<E: Environment, B: ADBackend, M: Model<B>> Dqn<E, B, M, false> {
pub fn train<const SIZE: usize, T>(
pub fn train<const SAMPLE_SIZE: usize, const SIZE: usize, T>(
&mut self,
mut policy_net: M,
memory: &Memory<E, B, SIZE>,
@@ -86,15 +87,16 @@ impl<E: Environment, B: ADBackend, M: Model<B>> Dqn<E, B, M, false> {
let next_state_values = self
.target_net
.forward(sample.next_state_batch())
.max_dim(1);
.max_dim(1)
.detach();

let not_done_batch = sample.not_done_batch();
let reward_batch = sample.reward_batch();

let expected_state_action_values =
(next_state_values * not_done_batch).mul_scalar(GAMMA) + reward_batch;

let loss = MSELoss::new().forward(
let loss = SmoothL1Loss::default().forward(
state_action_values,
expected_state_action_values,
Reduction::Mean,
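The two edits above are the detach fix from the commit title and the loss swap: the target network's maximum next-state Q-value is now detached, so no gradients flow back into the target network, and the resulting Bellman target is compared against the policy network's Q-values with a smooth L1 loss instead of MSE. A minimal scalar sketch of the target being assembled here (the helper name is illustrative, not part of the crate):

    // Scalar version of the Bellman/TD target built above. The tensor code
    // multiplies by a not-done mask and detaches the target-network value so
    // it is treated as a constant during backpropagation.
    fn td_target(reward: f32, max_next_q: f32, done: bool, gamma: f32) -> f32 {
        let not_done = if done { 0.0 } else { 1.0 };
        reward + gamma * max_next_q * not_done
    }

Without the detach, optimizer updates would also push gradients through the target network, defeating its purpose as a slowly moving copy of the policy network.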
49 changes: 32 additions & 17 deletions in src/base/memory.rs
@@ -6,37 +6,53 @@ use rand::prelude::SliceRandom;
use ringbuffer::{ConstGenericRingBuffer, RingBuffer};
use std::marker::PhantomData;

#[allow(unused)]
pub struct Memory<E: Environment, B: Backend, const CAP: usize> {
state: ConstGenericRingBuffer<E::StateType, CAP>,
next_state: ConstGenericRingBuffer<E::StateType, CAP>,
action: ConstGenericRingBuffer<E::ActionType, CAP>,
state: ConstGenericRingBuffer<Tensor<B, 1>, CAP>,
next_state: ConstGenericRingBuffer<Tensor<B, 1>, CAP>,
action: ConstGenericRingBuffer<u32, CAP>,
reward: ConstGenericRingBuffer<ElemType, CAP>,
done: ConstGenericRingBuffer<bool, CAP>,
environment: PhantomData<E>,
backend: PhantomData<B>,
}

impl<E: Environment, B: Backend, const CAP: usize> Memory<E, B, CAP> {
#[allow(unused)]
pub fn new() -> Self {
impl<E: Environment, B: Backend, const CAP: usize> Default for Memory<E, B, CAP> {
fn default() -> Self {
Self {
state: ConstGenericRingBuffer::new(),
next_state: ConstGenericRingBuffer::new(),
action: ConstGenericRingBuffer::new(),
reward: ConstGenericRingBuffer::new(),
done: ConstGenericRingBuffer::new(),
environment: PhantomData,
backend: PhantomData,
}
}
}

#[allow(unused)]
impl<E: Environment, B: Backend, const CAP: usize> Memory<E, B, CAP> {
pub fn push(
&mut self,
state: E::StateType,
next_state: E::StateType,
action: E::ActionType,
reward: ElemType,
done: bool,
) {
self.state.push(state.data());
self.next_state.push(next_state.data());
self.action.push(action.into());
self.reward.push(reward);
self.done.push(done);
}

pub fn push_tensor(
&mut self,
state: Tensor<B, 1>,
next_state: Tensor<B, 1>,
action: u32,
reward: ElemType,
done: bool,
) {
self.state.push(state);
self.next_state.push(next_state);
@@ -45,16 +61,15 @@ impl<E: Environment, B: Backend, const CAP: usize> Memory<E, B, CAP> {
self.done.push(done);
}

#[allow(unused)]
pub fn sample<const SIZE: usize>(&self) -> Memory<E, B, SIZE> {
let mut rng = rand::thread_rng();
let mut indices: Vec<usize> = (0..self.len()).collect();
indices.shuffle(&mut rng);
let mut memory = Memory::<E, B, SIZE>::new();
let mut memory = Memory::<E, B, SIZE>::default();
for index in indices.iter().take(SIZE).copied() {
memory.push(
self.state[index],
self.next_state[index],
memory.push_tensor(
self.state[index].clone(),
self.next_state[index].clone(),
self.action[index],
self.reward[index],
self.done[index],
@@ -71,14 +86,14 @@ impl<E: Environment, B: Backend, const CAP: usize> Memory<E, B, CAP> {
.reshape([data.len() as i32, -1])
}
pub fn next_state_batch(&self) -> Tensor<B, 2> {
Self::stack(&self.state, |state| state.data())
Self::stack(&self.state, |state| state.clone())
}
pub fn state_batch(&self) -> Tensor<B, 2> {
Self::stack(&self.state, |state| state.data())
Self::stack(&self.state, |state| state.clone())
}
pub fn action_batch(&self) -> Tensor<B, 2, Int> {
Self::stack(&self.action, |action| {
Tensor::<B, 1, Int>::from_ints([(*action).into() as i32])
Tensor::<B, 1, Int>::from_ints([*action as i32])
})
}
pub fn reward_batch(&self) -> Tensor<B, 2> {
@@ -179,7 +194,7 @@ mod tests {

#[test]
fn test_memory() {
let mut memory = Memory::<TestEnv, TestBackend, 16>::new();
let mut memory = Memory::<TestEnv, TestBackend, 16>::default();
for i in 0..20 {
memory.push(
TestState {
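The replay buffer now stores states as already-converted Tensor<B, 1> values and actions as plain u32, so the conversion from E::StateType and E::ActionType happens once in push rather than every time a batch is assembled, and Memory::new is replaced by a Default impl. sample still draws a uniform minibatch without replacement by shuffling the index range and keeping the first SIZE entries; a standalone sketch of that step (the helper name is illustrative):

    use rand::prelude::SliceRandom;

    // Uniform sampling without replacement, as in Memory::sample: shuffle all
    // valid indices and keep the first `size` of them.
    fn sample_indices(len: usize, size: usize) -> Vec<usize> {
        let mut rng = rand::thread_rng();
        let mut indices: Vec<usize> = (0..len).collect();
        indices.shuffle(&mut rng);
        indices.truncate(size.min(len));
        indices
    }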
97 changes: 58 additions & 39 deletions in src/main.rs
@@ -2,12 +2,15 @@ mod agent;
mod base;
mod components;
mod env;
mod utils;

use crate::agent::Dqn;
use crate::base::{Action, ElemType, Memory, Model, State};
use crate::components::agent::Agent;
use crate::components::env::Environment;
use crate::env::cart_pole::CartPole;
use burn::backend::ndarray::NdArrayBackend;
use burn::grad_clipping::GradientClippingConfig;
use burn::module::{Module, Param};
use burn::nn::{Linear, LinearConfig};
use burn::optim::AdamConfig;
@@ -23,13 +26,15 @@ type MyEnv = CartPole;
pub struct DQNModel<B: Backend> {
linear_0: Linear<B>,
linear_1: Linear<B>,
linear_2: Linear<B>,
}

impl<B: ADBackend> DQNModel<B> {
pub fn new(input_size: usize, dense_size: usize, output_size: usize) -> Self {
Self {
linear_0: LinearConfig::new(input_size, dense_size).init(),
linear_1: LinearConfig::new(dense_size, output_size).init(),
linear_1: LinearConfig::new(dense_size, dense_size).init(),
linear_2: LinearConfig::new(dense_size, output_size).init(),
}
}

@@ -40,40 +45,42 @@ impl<B: ADBackend> DQNModel<B> {
) -> Param<Tensor<B, N>> {
let other_weight = that.val();
let self_weight = this.val();
let new_weight = self_weight * tau + other_weight * (1.0 - tau);
let new_weight = self_weight * (1.0 - tau) + other_weight * tau;

Param::from(new_weight.no_grad())
}
fn soft_update_linear(this: &mut Linear<B>, that: &Linear<B>, tau: f64) {
this.weight = Self::soft_update_tensor(&this.weight, &that.weight, tau);
if let (Some(self_bias), Some(other_bias)) = (&mut this.bias, &that.bias) {
this.bias = Some(Self::soft_update_tensor(self_bias, other_bias, tau));
}
}
}

impl<B: ADBackend> Model<B> for DQNModel<B> {
fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let layer_0_output = relu(self.linear_0.forward(input));
let layer_1_output = relu(self.linear_1.forward(layer_0_output));

relu(self.linear_1.forward(layer_0_output))
relu(self.linear_2.forward(layer_1_output))
}

fn soft_update(&mut self, other: &Self, tau: f64) {
self.linear_0.weight =
Self::soft_update_tensor(&self.linear_0.weight, &other.linear_0.weight, tau);
if let (Some(self_bias), Some(other_bias)) = (&mut self.linear_0.bias, &other.linear_0.bias)
{
self.linear_0.bias = Some(Self::soft_update_tensor(self_bias, other_bias, tau));
}
self.linear_1.weight =
Self::soft_update_tensor(&self.linear_1.weight, &other.linear_1.weight, tau);
if let (Some(self_bias), Some(other_bias)) = (&mut self.linear_1.bias, &other.linear_1.bias)
{
self.linear_1.bias = Some(Self::soft_update_tensor(self_bias, other_bias, tau));
}
Self::soft_update_linear(&mut self.linear_0, &other.linear_0, tau);
Self::soft_update_linear(&mut self.linear_1, &other.linear_1, tau);
Self::soft_update_linear(&mut self.linear_2, &other.linear_2, tau);
}
}

const MEMORY_SIZE: usize = 2048;
const BATCH_SIZE: usize = 128;

pub fn main() {
let num_episodes = 512_usize;
let eps_decay = 1000.0;
let eps_start = 0.9;
let eps_end = 0.05;
let dense_size = 16;
let dense_size = 96_usize;
let reward_ewma_decay = 0.95;

let mut env = MyEnv::new(false);
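The soft_update_tensor change above fixes the direction of the Polyak average: the updated parameters now move a fraction tau toward the other network's parameters, new = (1 - tau) * current + tau * other, instead of the reverse, and the per-layer bookkeeping is factored into soft_update_linear. A sketch of the same update on plain slices (assuming, as in standard DQN, that `other` is the policy network; names here are illustrative):

    // Polyak (soft) update: move the target weights a small step toward the
    // policy weights. With tau close to 0 the target network changes slowly.
    fn soft_update(target: &mut [f32], policy: &[f32], tau: f32) {
        for (t, p) in target.iter_mut().zip(policy.iter()) {
            *t = (1.0 - tau) * *t + tau * *p;
        }
    }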
@@ -83,22 +90,26 @@
<<MyEnv as Environment>::ActionType as Action>::size(),
);

let mut agent = agent::Dqn::<MyEnv, DQNBackend, DQNModel<DQNBackend>, false>::new(model);
let mut state = env.state();
let mut step = 0;
let mut agent = Dqn::<MyEnv, DQNBackend, DQNModel<DQNBackend>, false>::new(model);

let mut step = 0_usize;

let mut memory = Memory::<MyEnv, DQNBackend, 256>::new();
let mut optimizer = AdamConfig::new().init::<DQNBackend, DQNModel<DQNBackend>>();
let mut memory = Memory::<MyEnv, DQNBackend, MEMORY_SIZE>::default();
let mut optimizer = AdamConfig::new()
.with_grad_clipping(Some(GradientClippingConfig::Value(100.0)))
.init::<DQNBackend, DQNModel<DQNBackend>>();
let mut policy_net = agent.model().clone();
let mut ewma_reward = 0.0;
for episode in 0..1024 {
let mut done = false;

for episode in 0..num_episodes {
let mut episode_done = false;
let mut episode_duration = 0;
while !done {
let mut state = env.state();
while !episode_done {
let eps_threshold =
eps_end + (eps_start - eps_end) * f64::exp(-(step as f64) / eps_decay);
let action = agent.react_with_exploration(&policy_net, state, eps_threshold);
let mut snapshot = env.step(action);
let snapshot = env.step(action);

memory.push(
state,
@@ -107,27 +118,35 @@
snapshot.reward(),
snapshot.done(),
);
if step > memory.len() {
policy_net = agent.train(policy_net, &memory, &mut optimizer);
if BATCH_SIZE < memory.len() {
policy_net =
agent.train::<BATCH_SIZE, MEMORY_SIZE, _>(policy_net, &memory, &mut optimizer);
}
if snapshot.done() {
snapshot = env.reset();
done = true;
}
state = *snapshot.state();

step += 1;
episode_duration += 1;

if snapshot.done() {
env.reset();
episode_done = true;
ewma_reward = (1.0 - reward_ewma_decay) * episode_duration as f64
+ reward_ewma_decay * ewma_reward;
println!(
"Episode {}: step = {}, EWMA episode duration = {:.4}, epsilon threshold = {:.4}",
episode, step, ewma_reward, eps_threshold
);
} else {
state = *snapshot.state();
}
}
ewma_reward =
(1.0 - reward_ewma_decay) * episode_duration as f64 + reward_ewma_decay * ewma_reward;
println!(
"Episode: {}, step: {}, EWMA episode duration: {}",
episode, step, ewma_reward
);
}

demo_model(agent.to_eval());
}

fn demo_model(agent: Dqn<MyEnv, DQNBackend, DQNModel<DQNBackend>, true>) {
let mut env = MyEnv::new(true);
let agent = agent.to_eval();
let mut state = env.state();
let mut done = false;
while !done {
let action = agent.react(&state);
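Beyond the model and soft-update changes, main.rs now builds a three-layer MLP with 96 hidden units, clips gradients by value at 100.0, uses a 2048-transition replay memory with a batch size of 128, and restructures the episode loop so the state is re-read at the start of each episode and an EWMA of episode duration is printed per episode. The exploration schedule is the exponential decay computed inside the loop; as a standalone helper it would look like this (the function name is illustrative):

    // Exponentially decayed epsilon, matching the expression in the training
    // loop: starts near eps_start and approaches eps_end as `step` grows.
    fn eps_threshold(step: usize, eps_start: f64, eps_end: f64, eps_decay: f64) -> f64 {
        eps_end + (eps_start - eps_end) * f64::exp(-(step as f64) / eps_decay)
    }

With eps_start = 0.9, eps_end = 0.05, and eps_decay = 1000.0, the threshold is roughly 0.57 after 500 steps and 0.36 after 1000 steps.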
3 changes: 3 additions & 0 deletions in src/utils/mod.rs
@@ -0,0 +1,3 @@
mod smooth_l1;

pub use smooth_l1::SmoothL1Loss;
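The new utils module only re-exports SmoothL1Loss; src/utils/smooth_l1.rs itself is not shown on this page. Assuming the conventional definition (a Huber-style loss with threshold beta, as in PyTorch's SmoothL1Loss), a scalar sketch looks like this (illustrative only, not the actual implementation):

    // Smooth L1 / Huber loss: quadratic for small errors, linear for large
    // ones, which damps the effect of outlier TD errors compared with MSE.
    fn smooth_l1(prediction: f32, target: f32, beta: f32) -> f32 {
        let diff = (prediction - target).abs();
        if diff < beta {
            0.5 * diff * diff / beta
        } else {
            diff - 0.5 * beta
        }
    }

A batched, mean-reduced tensor version of this is what the forward(..., Reduction::Mean) call in dqn.rs expects.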