Commit

Runnable training (#12)
yunjhongwu authored Nov 17, 2023
1 parent 9e650aa commit dc79139
Showing 12 changed files with 317 additions and 132 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
@@ -5,15 +5,16 @@ edition = "2021"
 publish = false
 
 [dependencies]
-rand = "0.8.4"
+rand = "0.8.5"
 burn = "0.10.0"
 burn-autodiff = "0.10.0"
 burn-ndarray = "0.10.0"
 ndarray = "0.15.6"
 ndarray-linalg = "0.16.0"
 ordered-float = "4.1.1"
 gym-rs = "0.3.0"
-serde = "1.0.130"
+serde = "1.0.192"
+ringbuffer = "0.15.0"
 
 [features]
 default = ["burn/ndarray"]
109 changes: 86 additions & 23 deletions src/agent/dqn.rs
@@ -1,40 +1,39 @@
-use crate::base::{Model, State};
+use crate::base::{Action, Memory, Model, State};
 use crate::components::agent::Agent;
 use crate::components::env::Environment;
-use burn::tensor::backend::Backend;
+use burn::nn::loss::{MSELoss, Reduction};
+use burn::optim::{GradientsParams, Optimizer};
+use burn::tensor::backend::ADBackend;
 use burn::tensor::{ElementConversion, Tensor};
+use rand::random;
 use std::marker::PhantomData;
 
-pub struct Dqn<E: Environment, B: Backend, M: Model<B>> {
-    is_eval: bool,
-    model: M,
+const GAMMA: f64 = 0.999;
+const TAU: f64 = 0.005;
+const LR: f64 = 0.001;
+
+pub struct Dqn<E: Environment, B: ADBackend, M: Model<B>, const EVAL: bool> {
+    target_net: M,
     state: PhantomData<E::StateType>,
     action: PhantomData<E::ActionType>,
     backend: PhantomData<B>,
 }
 
-impl<E: Environment, B: Backend, M: Model<B>> Dqn<E, B, M> {
-    pub(crate) fn new(model: M) -> Self {
+impl<E: Environment, B: ADBackend, M: Model<B>, const EVAL: bool> Dqn<E, B, M, EVAL> {
+    pub fn new(model: M) -> Self {
         Self {
-            is_eval: false,
-            model,
+            target_net: model,
             state: PhantomData,
             action: PhantomData,
             backend: PhantomData,
         }
     }
 
-    fn convert(state: &E::StateType) -> Tensor<B, 2> {
+    fn convert_state_to_tensor(state: <Self as Agent>::StateType) -> Tensor<B, 2> {
         state.data().unsqueeze()
     }
-}
 
-impl<E: Environment, B: Backend, M: Model<B>> Agent for Dqn<E, B, M> {
-    type StateType = E::StateType;
-    type ActionType = E::ActionType;
-
-    fn react(&mut self, state: &Self::StateType) -> Self::ActionType {
-        let output = self.model.forward(Self::convert(state));
+    fn convert_tenor_to_action(output: Tensor<B, 2>) -> <Self as Agent>::ActionType {
         unsafe {
             output
                 .argmax(0)
@@ -46,15 +45,79 @@ impl<E: Environment, B: Backend, M: Model<B>> Agent for Dqn<E, B, M> {
         }
     }
 
-    fn collect(&mut self, _reward: f32, _done: bool) {
-        todo!()
+    pub fn react_with_exploration(
+        &self,
+        policy_net: &M,
+        state: <Self as Agent>::StateType,
+        eps_threshold: f64,
+    ) -> <Self as Agent>::ActionType {
+        if random::<f64>() > eps_threshold {
+            Self::convert_tenor_to_action(policy_net.forward(Self::convert_state_to_tensor(state)))
+        } else {
+            Action::random()
+        }
     }
+}
 
-    fn reset(&mut self) {
-        todo!()
+impl<E: Environment, B: ADBackend, M: Model<B>, const EVAL: bool> Agent for Dqn<E, B, M, EVAL> {
+    type StateType = E::StateType;
+    type ActionType = E::ActionType;
+
+    fn react(&self, state: &Self::StateType) -> Self::ActionType {
+        Self::convert_tenor_to_action(
+            self.target_net
+                .forward(Self::convert_state_to_tensor(*state)),
+        )
     }
+}
 
+impl<E: Environment, B: ADBackend, M: Model<B>> Dqn<E, B, M, false> {
+    pub fn train<const SIZE: usize, T>(
+        &mut self,
+        mut policy_net: M,
+        memory: &Memory<E, B, SIZE>,
+        optimizer: &mut (impl Optimizer<M, B, Record = T> + Sized),
+    ) -> M {
+        let sample = memory.sample::<SIZE>();
+
+        let state_action_values = policy_net
+            .forward(sample.state_batch())
+            .gather(1, sample.action_batch());
+        let next_state_values = self
+            .target_net
+            .forward(sample.next_state_batch())
+            .max_dim(1);
+
-    fn is_eval(&self) -> bool {
-        self.is_eval
+        let not_done_batch = sample.not_done_batch();
+        let reward_batch = sample.reward_batch();
+
+        let expected_state_action_values =
+            (next_state_values * not_done_batch).mul_scalar(GAMMA) + reward_batch;
+
+        let loss = MSELoss::new().forward(
+            state_action_values,
+            expected_state_action_values,
+            Reduction::Mean,
+        );
+
+        let gradients = loss.backward();
+        let gradient_params = GradientsParams::from_grads(gradients, &policy_net);
+
+        policy_net = optimizer.step(LR, policy_net, gradient_params);
+        self.target_net.soft_update(&policy_net, TAU);
+
+        policy_net
     }
+
+    pub fn model(&self) -> &M {
+        &self.target_net
+    }
+    pub fn to_eval(&self) -> Dqn<E, B, M, true> {
+        Dqn::<E, B, M, true> {
+            target_net: self.target_net.clone(),
+            state: PhantomData,
+            action: PhantomData,
+            backend: PhantomData,
+        }
+    }
 }
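
Note: taken together, react_with_exploration and train suggest a driving loop like the sketch below. This is only an illustration: the env.reset()/env.step() calls, the memory.push(...) signature, the Copy bound on states, and the fixed epsilon are assumptions, not part of this commit.

// Hypothetical episode runner built on the API added in this commit.
// Assumes the crate's Dqn, Memory, Model, and Environment items plus
// burn::optim::Optimizer and burn::tensor::backend::ADBackend are in scope.
fn run_episode<E, B, M, O, const CAP: usize>(
    env: &mut E,
    dqn: &mut Dqn<E, B, M, false>,
    mut policy_net: M,
    memory: &mut Memory<E, B, CAP>,
    optimizer: &mut O,
) -> M
where
    E: Environment,
    E::StateType: Copy, // assumed: states are plain Copy data
    B: ADBackend,
    M: Model<B>,
    O: Optimizer<M, B>,
{
    let mut state = env.reset(); // assumed Environment method
    loop {
        // epsilon-greedy action from the policy network (fixed epsilon for brevity)
        let action = dqn.react_with_exploration(&policy_net, state, 0.1);
        let (next_state, reward, done) = env.step(action); // assumed Environment method
        memory.push(state, next_state, action, reward, done); // assumed Memory method

        // One optimization step on the policy net; train() also soft-updates the target net.
        policy_net = dqn.train(policy_net, memory, optimizer);

        if done {
            return policy_net;
        }
        state = next_state;
    }
}
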
10 changes: 1 addition & 9 deletions src/agent/random.rs
@@ -21,15 +21,7 @@ impl<E: Environment> Agent for Random<E> {
     type StateType = E::StateType;
     type ActionType = E::ActionType;
 
-    fn react(&mut self, _state: &Self::StateType) -> Self::ActionType {
+    fn react(&self, _state: &Self::StateType) -> Self::ActionType {
         Self::ActionType::random()
     }
-
-    fn collect(&mut self, _reward: f32, _done: bool) {}
-
-    fn reset(&mut self) {}
-
-    fn is_eval(&self) -> bool {
-        true
-    }
 }
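
Note: after this commit both agents implement only react, and they take &self rather than &mut self. The Agent trait itself is not part of the shown diff, but the two impls imply a shape roughly like the following (an inference, not the actual src/components/agent.rs):

// Inferred shape of the Agent trait after this commit. This is an assumption
// based on the Dqn and Random impls above; the real trait may carry extra
// bounds on the associated types.
pub trait Agent {
    type StateType;
    type ActionType;

    fn react(&self, state: &Self::StateType) -> Self::ActionType;
}
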
6 changes: 5 additions & 1 deletion src/base/action.rs
@@ -1,6 +1,10 @@
 use std::fmt::Debug;
 
-pub trait Action: Debug + Copy + Clone + Default + From<u32> {
+pub trait Action: Debug + Copy + Clone + Default + From<u32> + Into<u32> {
     fn random() -> Self;
     fn enumerate() -> Vec<Self>;
+
+    fn size() -> usize {
+        Self::enumerate().len()
+    }
 }
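
Note: the trait now also requires Into<u32> and gains a default size() based on enumerate(). A hypothetical implementation for a two-action environment might look like this (illustration only; CartPoleAction is not part of this commit):

// Hypothetical Action implementation for a two-action environment.
// Assumes `use crate::base::Action;` and `use rand::random;` are in scope.
#[derive(Debug, Copy, Clone, Default)]
enum CartPoleAction {
    #[default]
    Left,
    Right,
}

impl From<u32> for CartPoleAction {
    fn from(value: u32) -> Self {
        if value == 0 { Self::Left } else { Self::Right }
    }
}

// Provides the Into<u32> bound via the standard blanket impl.
impl From<CartPoleAction> for u32 {
    fn from(action: CartPoleAction) -> u32 {
        match action {
            CartPoleAction::Left => 0,
            CartPoleAction::Right => 1,
        }
    }
}

impl Action for CartPoleAction {
    fn random() -> Self {
        if random::<bool>() { Self::Left } else { Self::Right }
    }

    fn enumerate() -> Vec<Self> {
        vec![Self::Left, Self::Right]
    }
    // size() falls back to the trait default: Self::enumerate().len() == 2
}
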