Add mountain car environment wrapper (#19)
yunjhongwu authored Nov 18, 2023
1 parent 4ef4173 commit 24dc96e
Showing 7 changed files with 136 additions and 22 deletions.
3 changes: 1 addition & 2 deletions burn-rl/Cargo.toml
@@ -7,6 +7,5 @@ publish = false
 [dependencies]
 rand = "0.8.5"
 burn = { version = "0.10.0", features = ["ndarray", "autodiff"] }
-ordered-float = "4.1.1"
 gym-rs = "0.3.0"
-ringbuffer = "0.15.0"
+ringbuffer = "0.15.0"
5 changes: 4 additions & 1 deletion burn-rl/src/base/action.rs
@@ -1,7 +1,10 @@
+use rand::{thread_rng, Rng};
 use std::fmt::Debug;

 pub trait Action: Debug + Copy + Clone + From<u32> + Into<u32> {
-    fn random() -> Self;
+    fn random() -> Self {
+        (thread_rng().gen_range(0..Self::size()) as u32).into()
+    }
     fn enumerate() -> Vec<Self>;

     fn size() -> usize {
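
Note on the trait change: with the new default, an action type only has to supply enumerate() (plus the u32 conversions) and gets uniform random sampling for free. A minimal sketch of the pattern, assuming size() defaults to enumerate().len() (consistent with CartPoleAction dropping its size() below); GridAction is a hypothetical example type, not part of this commit:

use rand::{thread_rng, Rng};

// Stand-in for burn-rl's Action trait after this change.
pub trait Action: std::fmt::Debug + Copy + Clone + From<u32> + Into<u32> {
    fn enumerate() -> Vec<Self>;

    // Assumed default: the number of enumerable actions.
    fn size() -> usize {
        Self::enumerate().len()
    }

    // New default added by this commit: sample a uniform index in
    // 0..size() and map it back through From<u32>.
    fn random() -> Self {
        (thread_rng().gen_range(0..Self::size()) as u32).into()
    }
}

// Hypothetical action type: only enumerate() is written by hand.
#[derive(Debug, Copy, Clone)]
enum GridAction {
    Up,
    Down,
    Left,
    Right,
}

impl From<u32> for GridAction {
    fn from(value: u32) -> Self {
        match value {
            0 => Self::Up,
            1 => Self::Down,
            2 => Self::Left,
            3 => Self::Right,
            _ => panic!("Invalid action"),
        }
    }
}

impl From<GridAction> for u32 {
    fn from(action: GridAction) -> Self {
        action as u32
    }
}

impl Action for GridAction {
    fn enumerate() -> Vec<Self> {
        vec![Self::Up, Self::Down, Self::Left, Self::Right]
    }
}

fn main() {
    println!("{:?}", GridAction::random()); // e.g. Left
}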
1 change: 1 addition & 0 deletions burn-rl/src/base/environment.rs
@@ -3,6 +3,7 @@ use crate::base::{Action, Snapshot, State};
 pub trait Environment {
     type StateType: State;
     type ActionType: Action;
+    const MAX_STEPS: usize = usize::MAX;

     fn state(&self) -> Self::StateType;
     fn reset(&mut self) -> Snapshot<Self::StateType>;
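
Note: MAX_STEPS is an associated const with a "no limit" default, so existing environments compile unchanged and opt into an episode cap explicitly. A tiny self-contained illustration of the pattern (the types here are stand-ins, not burn-rl items):

// Associated const with a default value; implementors may override it.
trait Environment {
    const MAX_STEPS: usize = usize::MAX;
}

struct Capped;
impl Environment for Capped {
    const MAX_STEPS: usize = 200; // explicit episode cap, as CartPole sets below
}

struct Uncapped;
impl Environment for Uncapped {} // inherits usize::MAX, effectively no cap

fn main() {
    assert_eq!(Capped::MAX_STEPS, 200);
    assert_eq!(Uncapped::MAX_STEPS, usize::MAX);
}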
19 changes: 5 additions & 14 deletions burn-rl/src/environment/cart_pole.rs
@@ -6,12 +6,12 @@ use burn::tensor::Tensor;
 use gym_rs::core::Env;
 use gym_rs::envs::classical_control::cartpole::{CartPoleEnv, CartPoleObservation};
 use gym_rs::utils::renderer::RenderMode;
-use rand::random;
 use std::fmt::Debug;

+type StateData = [ElemType; 4];
 #[derive(Debug, Copy, Clone)]
 pub struct CartPoleState {
-    data: [ElemType; 4],
+    data: StateData,
 }

 impl From<CartPoleObservation> for CartPoleState {
@@ -29,7 +29,7 @@ impl From<CartPoleObservation> for CartPoleState {
 }

 impl State for CartPoleState {
-    type Data = [ElemType; 4];
+    type Data = StateData;
     fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
         Tensor::<B, 1>::from_floats(self.data)
     }
@@ -65,20 +65,9 @@ impl From<CartPoleAction> for u32 {
 }

 impl Action for CartPoleAction {
-    fn random() -> Self {
-        if random::<ElemType>() < 0.5 {
-            Self::Left
-        } else {
-            Self::Right
-        }
-    }
-
     fn enumerate() -> Vec<Self> {
         vec![Self::Left, Self::Right]
     }
-    fn size() -> usize {
-        2
-    }
 }

 #[derive(Debug)]
@@ -87,6 +76,7 @@ pub struct CartPole {
 }

 impl CartPole {
+    #[allow(unused)]
     pub fn new(visualized: bool) -> Self {
         Self {
             gym_env: CartPoleEnv::new(if visualized {
@@ -101,6 +91,7 @@ impl CartPole {
 impl Environment for CartPole {
     type StateType = CartPoleState;
     type ActionType = CartPoleAction;
+    const MAX_STEPS: usize = 200;

     fn state(&self) -> Self::StateType {
         self.gym_env.state.into()
2 changes: 2 additions & 0 deletions burn-rl/src/environment/mod.rs
@@ -1,3 +1,5 @@
 mod cart_pole;
+mod mountain_car;

 pub use cart_pole::CartPole;
+pub use mountain_car::MountainCar;
115 changes: 115 additions & 0 deletions burn-rl/src/environment/mountain_car.rs
@@ -0,0 +1,115 @@
+use crate::base::environment::Environment;
+use crate::base::{Action, State};
+use crate::base::{ElemType, Snapshot};
+use burn::tensor::backend::Backend;
+use burn::tensor::Tensor;
+use gym_rs::core::Env;
+use gym_rs::envs::classical_control::mountain_car::{MountainCarEnv, MountainCarObservation};
+use gym_rs::utils::renderer::RenderMode;
+use std::fmt::Debug;
+
+type StateData = [ElemType; 2];
+#[derive(Debug, Copy, Clone)]
+pub struct MountainCarState {
+    data: StateData,
+}
+
+impl From<MountainCarObservation> for MountainCarState {
+    fn from(observation: MountainCarObservation) -> Self {
+        let vec = Vec::<f64>::from(observation);
+        Self {
+            data: [vec[0] as ElemType, vec[1] as ElemType],
+        }
+    }
+}
+
+impl State for MountainCarState {
+    type Data = StateData;
+    fn to_tensor<B: Backend>(&self) -> Tensor<B, 1> {
+        Tensor::<B, 1>::from_floats(self.data)
+    }
+
+    fn size() -> usize {
+        2
+    }
+}
+
+#[derive(Debug, Copy, Clone)]
+pub enum MountainCarAction {
+    AccelerateToLeft,
+    NotAccelerate,
+    AccelerateToRight,
+}
+
+impl From<u32> for MountainCarAction {
+    fn from(value: u32) -> Self {
+        match value {
+            0 => Self::AccelerateToLeft,
+            1 => Self::NotAccelerate,
+            2 => Self::AccelerateToRight,
+            _ => panic!("Invalid action"),
+        }
+    }
+}
+
+impl From<MountainCarAction> for u32 {
+    fn from(action: MountainCarAction) -> Self {
+        match action {
+            MountainCarAction::AccelerateToLeft => 0,
+            MountainCarAction::NotAccelerate => 1,
+            MountainCarAction::AccelerateToRight => 2,
+        }
+    }
+}
+
+impl Action for MountainCarAction {
+    fn enumerate() -> Vec<Self> {
+        vec![
+            Self::AccelerateToLeft,
+            Self::NotAccelerate,
+            Self::AccelerateToRight,
+        ]
+    }
+}
+
+#[derive(Debug)]
+pub struct MountainCar {
+    gym_env: MountainCarEnv,
+}
+
+impl MountainCar {
+    #[allow(unused)]
+    pub fn new(visualized: bool) -> Self {
+        Self {
+            gym_env: MountainCarEnv::new(if visualized {
+                RenderMode::Human
+            } else {
+                RenderMode::None
+            }),
+        }
+    }
+}
+
+impl Environment for MountainCar {
+    type StateType = MountainCarState;
+    type ActionType = MountainCarAction;
+    const MAX_STEPS: usize = 200;
+
+    fn state(&self) -> Self::StateType {
+        self.gym_env.state.into()
+    }
+
+    fn reset(&mut self) -> Snapshot<Self::StateType> {
+        self.gym_env.reset(None, false, None);
+        Snapshot::new(self.gym_env.state.into(), 0.0, false)
+    }
+
+    fn step(&mut self, action: MountainCarAction) -> Snapshot<MountainCarState> {
+        let action_reward = self.gym_env.step(action as usize);
+        Snapshot::new(
+            action_reward.observation.into(),
+            *action_reward.reward as ElemType,
+            action_reward.done,
+        )
+    }
+}
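
Note: the new wrapper can be smoke-tested end to end with random actions. A hedged sketch, not part of this commit; the paths mirror the in-crate use statements above and assume the base modules are public, and Snapshot::done() is the accessor used in dqn.rs below:

use burn_rl::base::environment::Environment;
use burn_rl::base::Action;
use burn_rl::environment::MountainCar;

fn main() {
    // Headless run; passing true would request RenderMode::Human instead.
    let mut env = MountainCar::new(false);
    env.reset();

    let mut steps = 0_usize;
    loop {
        // Uniform random action via the trait default added in action.rs.
        let action = <MountainCar as Environment>::ActionType::random();
        let snapshot = env.step(action);
        steps += 1;
        if snapshot.done() || steps >= MountainCar::MAX_STEPS {
            break;
        }
    }
    println!("episode ended after {steps} steps");
}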
13 changes: 8 additions & 5 deletions examples/src/dqn.rs
@@ -68,7 +68,7 @@ impl<B: Backend> Model<B> for DQNModel<B> {
 const MEMORY_SIZE: usize = 4096;
 const BATCH_SIZE: usize = 128;

-fn demo_model(agent: impl Agent<CartPole>) {
+fn demo_model(agent: impl Agent<MyEnv>) {
     let mut env = MyEnv::new(true);
     let mut state = env.state();
     let mut done = false;
@@ -108,7 +108,8 @@ pub fn run() {

     for episode in 0..num_episodes {
         let mut episode_done = false;
-        let mut episode_duration = 0;
+        let mut episode_reward = 0.0;
+        let mut episode_duration = 0_usize;
         let mut state = env.state();

         while !episode_done {
@@ -117,6 +118,8 @@
             let action = MyAgent::react_with_exploration(&policy_net, state, eps_threshold);
             let snapshot = env.step(action);

+            episode_reward += snapshot.reward();
+
             memory.push(
                 state,
                 *snapshot.state(),
@@ -132,13 +135,13 @@
             step += 1;
             episode_duration += 1;

-            if snapshot.done() || episode_duration >= 500 {
+            if snapshot.done() || episode_duration >= MyEnv::MAX_STEPS {
                 env.reset();
                 episode_done = true;

                 println!(
-                    "{{\"episode\": {}, \"duration\": {:.4}}}",
-                    episode, episode_duration
+                    "{{\"episode\": {}, \"reward\": {:.4}, \"duration\": {}}}",
+                    episode, episode_reward, episode_duration
                 );
             } else {
                 state = *snapshot.state();
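
Note: with this format string, each episode logs a JSON object such as {"episode": 12, "reward": -200.0000, "duration": 200}. Assuming MyEnv is the Mountain Car wrapper and gym-rs keeps classic Gym's -1-per-step Mountain Car reward, an episode truncated at MAX_STEPS reports a reward of -200, so the reward curve, not just duration, tracks learning progress.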
