Refactor Memory #23

Merged 1 commit on Nov 20, 2023
14 changes: 5 additions & 9 deletions burn-rl/src/agent/dqn.rs → burn-rl/src/agent/dqn/agent.rs
@@ -1,23 +1,19 @@
+use crate::agent::{DQNMemory, DQNModel};
 use crate::base::agent::Agent;
 use crate::base::environment::Environment;
-use crate::base::{Action, Memory, Model};
+use crate::base::Action;
 use crate::utils::{convert_state_to_tensor, convert_tenor_to_action};
 use burn::module::ADModule;
 use burn::nn::loss::{MSELoss, Reduction};
 use burn::optim::{GradientsParams, Optimizer};
 use burn::tensor::backend::{ADBackend, Backend};
 use burn::tensor::Tensor;
 use rand::random;
 use std::marker::PhantomData;
 
 const GAMMA: f64 = 0.999;
 const TAU: f64 = 0.005;
 const LR: f64 = 0.001;
 
-pub trait DQNModel<B: Backend>: Model<B, Tensor<B, 2>, Tensor<B, 2>> {
-    fn soft_update(this: &mut Self, that: &Self, tau: f64);
-}
-
 pub struct DQN<E: Environment, B: Backend, M: DQNModel<B>> {
     target_net: M,
     state: PhantomData<E::StateType>,
@@ -29,7 +25,7 @@ impl<E: Environment, B: Backend, M: DQNModel<B>> Agent<E> for DQN<E, B, M> {
     fn react(&self, state: &E::StateType) -> E::ActionType {
         convert_tenor_to_action::<E::ActionType, B>(
             self.target_net
-                .forward(convert_state_to_tensor::<E::StateType, B>(*state)),
+                .forward(convert_state_to_tensor::<E::StateType, B>(*state).unsqueeze()),
         )
     }
 }
@@ -56,7 +52,7 @@ impl<E: Environment, B: ADBackend, M: DQNModel<B>> DQN<E, B, M> {
     ) -> E::ActionType {
         if random::<f64>() > eps_threshold {
             convert_tenor_to_action::<E::ActionType, B>(
-                policy_net.forward(convert_state_to_tensor::<E::StateType, B>(state)),
+                policy_net.forward(convert_state_to_tensor::<E::StateType, B>(state).unsqueeze()),
             )
         } else {
             Action::random()
@@ -68,7 +64,7 @@ impl<E: Environment, B: ADBackend, M: DQNModel<B> + ADModule<B>> DQN<E, B, M> {
     pub fn train<const BATCH_SIZE: usize>(
         &mut self,
         mut policy_net: M,
-        sample: Memory<E, B, BATCH_SIZE>,
+        sample: DQNMemory<E, B, BATCH_SIZE>,
         optimizer: &mut (impl Optimizer<M, B> + Sized),
     ) -> M {
         let state_action_values = policy_net
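
Note on the `.unsqueeze()` calls added above: `convert_state_to_tensor` yields a rank-1 tensor of shape [state_size], while the model's forward pass takes a rank-2 `Tensor<B, 2>`, i.e. a [batch_size, state_size] batch. A minimal standalone sketch of the shape change (not part of this PR; it reuses the NdArrayBackend that the tests below also use):

use burn::backend::NdArrayBackend;
use burn::tensor::Tensor;

type B = NdArrayBackend<f32>;

fn main() {
    // A single observation converts to a rank-1 tensor: shape [2].
    let state = Tensor::<B, 1>::from_floats([0.1, 0.2]);
    // unsqueeze prepends a batch dimension: shape [1, 2].
    let batch: Tensor<B, 2> = state.unsqueeze();
    assert_eq!(batch.shape().dims, [1, 2]);
}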
291 changes: 291 additions & 0 deletions burn-rl/src/agent/dqn/memory.rs
@@ -0,0 +1,291 @@
use crate::base::{Environment, Memory, Transition};
use crate::utils::{convert_state_to_tensor, stack};
use burn::tensor::backend::Backend;
use burn::tensor::{Int, Tensor};
use ringbuffer::{ConstGenericRingBuffer, RingBuffer};
use std::marker::PhantomData;
use std::ops::Index;

#[derive(Debug)]
pub struct DQNTransition<E: Environment, B: Backend> {
    state: E::StateType,
    next_state: E::StateType,
    action: E::ActionType,
    reward: E::RewardType,
    done: bool,
    environment: PhantomData<E>,
    backend: PhantomData<B>,
}

impl<E: Environment, B: Backend> Transition<E> for DQNTransition<E, B> {
    fn state(&self) -> &E::StateType {
        &self.state
    }

    fn action(&self) -> &E::ActionType {
        &self.action
    }

    fn reward(&self) -> &E::RewardType {
        &self.reward
    }

    fn next_state(&self) -> &E::StateType {
        &self.next_state
    }

    fn is_done(&self) -> bool {
        self.done
    }
}

impl<E: Environment, B: Backend> Clone for DQNTransition<E, B> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<E: Environment, B: Backend> Copy for DQNTransition<E, B> {}

impl<E: Environment, B: Backend> DQNTransition<E, B> {
    pub fn new(
        state: E::StateType,
        next_state: E::StateType,
        action: E::ActionType,
        reward: E::RewardType,
        done: bool,
    ) -> Self {
        Self {
            state,
            next_state,
            action,
            reward,
            done,
            environment: PhantomData,
            backend: PhantomData,
        }
    }

    fn state(&self) -> &E::StateType {
        &self.state
    }

    fn next_state(&self) -> &E::StateType {
        &self.next_state
    }

    fn action(&self) -> &E::ActionType {
        &self.action
    }

    fn reward(&self) -> &E::RewardType {
        &self.reward
    }

    fn done(&self) -> &bool {
        &self.done
    }
}

pub struct DQNMemory<E: Environment, B: Backend, const CAP: usize> {
    transitions: ConstGenericRingBuffer<DQNTransition<E, B>, CAP>,
    environment: PhantomData<E>,
    backend: PhantomData<B>,
}

impl<E: Environment, B: Backend, const CAP: usize> Default for DQNMemory<E, B, CAP> {
    fn default() -> Self {
        Self {
            transitions: ConstGenericRingBuffer::new(),
            environment: PhantomData,
            backend: PhantomData,
        }
    }
}

impl<E: Environment, B: Backend, const CAP: usize> Index<usize> for DQNMemory<E, B, CAP> {
    type Output = DQNTransition<E, B>;

    fn index(&self, index: usize) -> &Self::Output {
        &self.transitions[index]
    }
}

impl<E: Environment, B: Backend, const CAP: usize> Memory<E, B, CAP> for DQNMemory<E, B, CAP> {
    type TransitionType = DQNTransition<E, B>;

    fn get(&self, index: usize) -> &Self::TransitionType {
        &self.transitions[index]
    }

    fn push(&mut self, transition: DQNTransition<E, B>) {
        self.transitions.push(transition);
    }

    fn len(&self) -> usize {
        self.transitions.len()
    }

    fn is_empty(&self) -> bool {
        self.transitions.is_empty()
    }
}

impl<E: Environment, B: Backend, const CAP: usize> DQNMemory<E, B, CAP> {
    pub fn next_state_batch(&self) -> Tensor<B, 2> {
        stack(&self.transitions, |transition| {
            convert_state_to_tensor(*transition.next_state())
        })
    }

    pub fn state_batch(&self) -> Tensor<B, 2> {
        stack(&self.transitions, |transition| {
            convert_state_to_tensor(*transition.state())
        })
    }

    pub fn action_batch(&self) -> Tensor<B, 2, Int> {
        stack(&self.transitions, |transition| {
            let action_index: u32 = (*transition.action()).into();
            Tensor::<B, 1, Int>::from_ints([action_index as i32])
        })
    }

    pub fn reward_batch(&self) -> Tensor<B, 2> {
        stack(&self.transitions, |transition| {
            let reward: f32 = (*transition.reward()).into();
            Tensor::from_floats([reward])
        })
    }

    pub fn not_done_batch(&self) -> Tensor<B, 2> {
        stack(&self.transitions, |transition| {
            let not_done = if *transition.done() { 0.0 } else { 1.0 };
            Tensor::from_floats([not_done])
        })
    }
}

#[cfg(test)]
mod tests {
    use crate::agent::{DQNMemory, DQNTransition};
    use crate::base::environment::Environment;
    use crate::base::{sample_memory, Action, ElemType, Memory, Snapshot, State};
    use burn::backend::NdArrayBackend;
    use burn::tensor::backend::Backend;
    use burn::tensor::{Shape, Tensor};
    use std::fmt::Debug;

    #[derive(Debug, Copy, Clone)]
    struct TestAction {
        data: i32,
    }

    impl From<u32> for TestAction {
        fn from(value: u32) -> Self {
            Self {
                data: value as i32,
            }
        }
    }

    impl From<TestAction> for u32 {
        fn from(action: TestAction) -> Self {
            action.data as u32
        }
    }

    impl Action for TestAction {
        fn random() -> Self {
            Self { data: 1 }
        }

        fn enumerate() -> Vec<Self> {
            vec![Self { data: 1 }]
        }
    }

    #[derive(Debug, Copy, Clone)]
    struct TestState {
        data: [ElemType; 2],
    }

    impl State for TestState {
        type Data = [ElemType; 2];

        fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
            Tensor::<B, 1>::from_floats(self.data)
        }

        fn size() -> usize {
            1
        }
    }

    #[derive(Debug)]
    struct TestEnv {}

    type TestBackend = NdArrayBackend<ElemType>;

    impl Environment for TestEnv {
        type StateType = TestState;
        type ActionType = TestAction;
        type RewardType = ElemType;

        fn state(&self) -> Self::StateType {
            todo!()
        }

        fn reset(&mut self) -> Snapshot<Self::StateType> {
            todo!()
        }

        fn step(&mut self, _action: Self::ActionType) -> Snapshot<Self::StateType> {
            todo!()
        }
    }

    #[test]
    fn test_memory() {
        let mut memory = DQNMemory::<TestEnv, TestBackend, 16>::default();
        for i in 0..20 {
            memory.push(DQNTransition::<TestEnv, TestBackend>::new(
                TestState {
                    data: [i as ElemType, (i * 2) as ElemType],
                },
                TestState {
                    data: [-i as ElemType, (i * 3) as ElemType],
                },
                TestAction { data: i },
                0.1,
                false,
            ));
        }
        let sample = sample_memory::<16, 5, _, _, _, DQNMemory<TestEnv, TestBackend, 5>>(&memory);
        assert_eq!(sample.len(), 5);

        let state_batch = sample.state_batch();
        assert_eq!(state_batch.shape(), Shape::new([5, 2]));
        let state_sample = state_batch
            .select(0, Tensor::from_ints([0, 1]))
            .to_data()
            .value;
        assert_eq!(state_sample[0] * 2.0, state_sample[1]);

        let next_state_batch = sample.next_state_batch();
        assert_eq!(next_state_batch.shape(), Shape::new([5, 2]));
        let next_state_sample = next_state_batch
            .select(0, Tensor::from_ints([0, 1]))
            .to_data()
            .value;
        assert_eq!(next_state_sample[0] * -3.0, next_state_sample[1]);

        let action_batch = sample.action_batch();
        assert_eq!(action_batch.shape(), Shape::new([5, 1]));
        assert_eq!(action_batch.to_data().value[0], state_sample[0] as i64);

        let reward_batch = sample.reward_batch();
        assert_eq!(reward_batch.shape(), Shape::new([5, 1]));
        assert_eq!(reward_batch.to_data().value[0], 0.1);

        let not_done_batch = sample.not_done_batch();
        assert_eq!(not_done_batch.shape(), Shape::new([5, 1]));
        assert_eq!(not_done_batch.to_data().value[0], 1.0);
    }
}
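
Reviewer note: DQNMemory's capacity semantics come entirely from ConstGenericRingBuffer. Once CAP transitions are stored, every further push evicts the oldest one, which is why the test above can push 20 transitions into a 16-slot memory. A standalone sketch of that eviction behavior (illustrative only, not part of this PR):

use ringbuffer::{ConstGenericRingBuffer, RingBuffer};

fn main() {
    let mut buf = ConstGenericRingBuffer::<u32, 4>::new();
    for i in 0..6 {
        buf.push(i); // pushing 4 and 5 evicts 0 and 1
    }
    assert_eq!(buf.len(), 4);
    // Iteration runs oldest to newest over the surviving elements.
    assert_eq!(buf.iter().copied().collect::<Vec<_>>(), vec![2, 3, 4, 5]);
}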
3 changes: 3 additions & 0 deletions burn-rl/src/agent/dqn/mod.rs
@@ -0,0 +1,3 @@
pub mod agent;
pub mod memory;
pub mod model;
7 changes: 7 additions & 0 deletions burn-rl/src/agent/dqn/model.rs
@@ -0,0 +1,7 @@
use crate::base::Model;
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

pub trait DQNModel<B: Backend>: Model<B, Tensor<B, 2>, Tensor<B, 2>> {
    fn soft_update(this: &mut Self, that: &Self, tau: f64);
}
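
The trait only declares soft_update; each model is expected to blend its target parameters toward the policy network's, presumably the usual Polyak update theta_target <- tau * theta_policy + (1 - tau) * theta_target, given the TAU = 0.005 constant in agent.rs. A toy sketch of that update rule, with plain Vec<f64> weights standing in for module tensors (hypothetical, illustration only):

struct ToyNet {
    weights: Vec<f64>,
}

// Mirrors the DQNModel::soft_update signature: mutate `this` (the target
// network) toward `that` (the policy network) by a factor of tau.
fn soft_update(this: &mut ToyNet, that: &ToyNet, tau: f64) {
    for (t, p) in this.weights.iter_mut().zip(&that.weights) {
        *t = tau * *p + (1.0 - tau) * *t;
    }
}

fn main() {
    let mut target = ToyNet { weights: vec![0.0] };
    let policy = ToyNet { weights: vec![1.0] };
    soft_update(&mut target, &policy, 0.005);
    assert!((target.weights[0] - 0.005).abs() < 1e-12);
}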
6 changes: 3 additions & 3 deletions burn-rl/src/agent/mod.rs
@@ -1,5 +1,5 @@
 mod dqn;
-mod random;
 
-pub use dqn::{DQNModel, DQN};
-pub use random::Random;
+pub use dqn::agent::DQN;
+pub use dqn::memory::{DQNMemory, DQNTransition};
+pub use dqn::model::DQNModel;
12 changes: 0 additions & 12 deletions burn-rl/src/agent/random.rs

This file was deleted.

5 changes: 4 additions & 1 deletion burn-rl/src/base/environment.rs
@@ -1,8 +1,11 @@
 use crate::base::{Action, Snapshot, State};
+use std::fmt::Debug;
 
-pub trait Environment {
+pub trait Environment: Debug {
     type StateType: State;
     type ActionType: Action;
+    type RewardType: Debug + Copy + Clone + Into<f32>;
 
+    const MAX_STEPS: usize = usize::MAX;
+
     fn state(&self) -> Self::StateType;
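
The new RewardType: Into<f32> bound is what lets DQNMemory::reward_batch convert rewards generically (the `let reward: f32 = (*transition.reward()).into();` line in memory.rs). Any custom reward type just needs a conversion into f32; a hypothetical example (names illustrative, not from this PR):

#[derive(Debug, Copy, Clone)]
struct StepReward(i8); // e.g. -1 per step, 0 at the goal

impl From<StepReward> for f32 {
    fn from(reward: StepReward) -> f32 {
        f32::from(reward.0) // i8 -> f32 is lossless
    }
}

fn main() {
    let reward = StepReward(-1);
    // What reward_batch does per transition before building the tensor:
    let as_float: f32 = reward.into();
    assert_eq!(as_float, -1.0);
}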