Bump the version of tch to 0.10.0 #92

Merged
merged 51 commits into from
Apr 7, 2024
51 commits
f6ee762
WIP: Add candle-agent
taku-y Sep 28, 2023
1f69ed2
Cargo fmt
taku-y Oct 16, 2023
d1c879a
Fix bug in MlpConfig
taku-y Oct 16, 2023
e6f0090
Tweak
taku-y Mar 12, 2024
c6821da
Update cargo files (#1)
taku-y Mar 15, 2024
4265180
Add dqn with candle (#1)
taku-y Mar 15, 2024
471678b
Tweaks (#1)
taku-y Mar 16, 2024
2b74b30
Tweaks
taku-y Mar 16, 2024
f267f67
Merge branch 'dev_0_0_7' into tmp
taku-y Mar 17, 2024
f1f83ab
Update ci.yml
taku-y Mar 17, 2024
8404d38
Tweaks for test
taku-y Mar 17, 2024
e9fd1e6
Bug fix, tweaks on dqn candle agent (#1)
taku-y Mar 20, 2024
e42977f
WIP: modify dqn_cartpole example (#1)
taku-y Mar 20, 2024
547fc91
Replace input argument with trait object (#4)
taku-y Mar 20, 2024
a0345d8
Bug fix (#1)
taku-y Mar 20, 2024
4017a0c
Support basic authentication (#3)
taku-y Mar 21, 2024
fd8496d
Bump the version of candle (#5)
taku-y Mar 21, 2024
47915a7
Add compile flag (#5)
taku-y Mar 21, 2024
f5464d7
MLflow export import (#3)
taku-y Mar 23, 2024
fd4e034
Support candle-core in atari env (#1)
taku-y Mar 24, 2024
5e35732
Support candle in derive macro (#1)
taku-y Mar 24, 2024
cc51187
Add cnn in candle agent (#1)
taku-y Mar 24, 2024
762cfc1
Rename script (#1)
taku-y Mar 24, 2024
e8ad6f1
Tweaks (#1)
taku-y Mar 24, 2024
62e238a
Change visibility of config fields (#1)
taku-y Mar 25, 2024
dea4040
Refactor atari example (#1)
taku-y Mar 25, 2024
490a37f
Move config files
taku-y Mar 25, 2024
aeaab7e
Remove deprecated parameter
taku-y Mar 25, 2024
258bb9d
Tweaks (#1)
taku-y Mar 25, 2024
94ea9d0
Update Cargo.toml (#1)
taku-y Mar 25, 2024
440b3ec
Modify code for dqn_atari (#1)
taku-y Mar 25, 2024
a9143b6
Default config for dqn-atari trainer (#1)
taku-y Mar 25, 2024
ec36d74
Default config for dqn-atari replay buffer (#1)
taku-y Mar 25, 2024
4d8a998
Default config for dqn-atari agent (#1)
taku-y Mar 30, 2024
b53642b
Add mlflow in installation script (#2, #3)
taku-y Mar 30, 2024
e8ec329
Support tag and creating experiment (#2)
taku-y Mar 30, 2024
aa4980b
Better logging (#3)
taku-y Mar 30, 2024
245e982
Support GPU (#1)
taku-y Mar 30, 2024
1729d15
Tweak
taku-y Mar 30, 2024
52b8dee
Update dqn_atari_tch
taku-y Mar 31, 2024
f43fbf6
Tweak install script
taku-y Mar 31, 2024
f1c572a
Bump the version of tch (#6)
taku-y Mar 31, 2024
7a1c3cb
Tweaks (#1, #6)
taku-y Mar 31, 2024
850ff89
Tweaks for comparing RL code (#1, #3, #6)
taku-y Apr 2, 2024
461dff2
Support MSE loss in Dqn (#3, #6)
taku-y Apr 2, 2024
5e9959e
Support smooth l1 loss for candle agent (#1)
taku-y Apr 2, 2024
c28ad41
Support critic loss type for dqn_atari (#1)
taku-y Apr 2, 2024
294bf4b
Support AdamW in tch agent (#6)
taku-y Apr 2, 2024
06c572d
Tweak (#6)
taku-y Apr 2, 2024
0660e15
Tweaks
taku-y Apr 7, 2024
439199e
Update CHANGELOG
taku-y Apr 7, 2024
8 changes: 5 additions & 3 deletions .github/workflows/ci.yml
@@ -14,7 +14,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]
rust: [1.68.2]
rust: [1.70.0]
python-version: [3.8]
steps:
- uses: actions/checkout@v2
@@ -76,10 +76,12 @@ jobs:
run: cargo test -p border-py-gym-env

- if: matrix.os == 'ubuntu-latest'
name: Test border
name: Test border examples
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends --fix-missing \
libsdl2-dev libsdl2-image-dev libsdl2-mixer-dev libsdl2-net-dev libsdl2-ttf-dev \
libsdl2-dev libsdl-image1.2-dev
cargo test -p border --features=tch
cargo test --example dqn_cartpole_tch --features=tch
cargo test --example iqn_cartpole --features=tch
cargo test --example sac_pendulum_tch --features=tch
2 changes: 2 additions & 0 deletions .gitignore
@@ -35,3 +35,5 @@ __pycache__
.vscode/**
doc/**

out/**
mlruns/**
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -4,11 +4,14 @@

### Added

Support MLflow tracking (`border-mlflow-tracking`) (https://github.com/taku-y/border/issues/2).
* Support MLflow tracking (`border-mlflow-tracking`) (https://github.com/taku-y/border/issues/2).
* Add candle agent (`border-candle-agent`)

### Changed

* Take `self` in the signature of `push()` method of replay buffer (`border-core`)
* Fix a bug in `MlpConfig` (`border-tch-agent`)
* Bump the version of tch to 0.10.0 (`border-tch-agent`)

## v0.0.6 (2023-09-19)

15 changes: 9 additions & 6 deletions Cargo.toml
@@ -18,16 +18,16 @@ version = "0.0.6"
edition = "2018"
rust-version = "1.70"
description = "Reinforcement learning library"
repository = "https://github.com/taku-y/border"
keywords = ["rl"]
repository = "https://github.com/laboroai/border"
keywords = ["Reinforcement learning"]
categories = ["science"]
license = "MIT OR Apache-2.0"

[workspace.dependencies]
clap = "2.33.3"
csv = "1.1.5"
fastrand = "1.4.0"
tch = "0.8.0"
tch = "0.10.0"
anyhow = "1.0.38"
crossbeam-channel = "0.5.1"
serde_yaml = "0.8.7"
@@ -42,11 +42,14 @@ env_logger = "0.8.2"
tempdir = "0.3.7"
num-traits = "0.2.14"
tensorboard-rs = "0.2.4"
pyo3 = { version = "=0.14.5", default-features=false }
pyo3 = { version = "=0.14.5", default-features = false }
ndarray = "0.15.1"
chrono = "0.4"
segment-tree = "2.0.0"
image = "0.23.14"
candle-core = "0.2.2"
candle-nn = "0.2.2"
candle-core = { version = "=0.4.1", feature = ["cuda"] }
candle-nn = "0.4.1"
rand = "0.8.5"
itertools = "0.12.1"
ordered-float = "4.2.0"
reqwest = { version = "0.11.26", features = ["json", "blocking"] }
1 change: 1 addition & 0 deletions README.md
@@ -51,5 +51,6 @@ Crates | License
`border-py-gym-env` | MIT OR Apache-2.0
`border-atari-env` | GPL-2.0-or-later
`border-tch-agent` | MIT OR Apache-2.0
`border-candle-agent` | MIT OR Apache-2.0
`border-async-trainer`| MIT OR Apache-2.0
`border` | GPL-2.0-or-later
2 changes: 1 addition & 1 deletion border-async-trainer/src/actor.rs
@@ -2,4 +2,4 @@
mod base;
mod stat;
pub use base::Actor;
pub use stat::{ActorStat, actor_stats_fmt};
pub use stat::{actor_stats_fmt, ActorStat};
10 changes: 4 additions & 6 deletions border-async-trainer/src/actor/base.rs
@@ -152,12 +152,10 @@ where

// Stop sampling loop
if *self.stop.lock().unwrap() {
*self.stats.lock().unwrap() = Some(
ActorStat {
env_steps,
duration: time.elapsed().unwrap(),
}
);
*self.stats.lock().unwrap() = Some(ActorStat {
env_steps,
duration: time.elapsed().unwrap(),
});
break;
}
}
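For readers who want to see the stop-and-report pattern from this hunk outside the trainer, here is a minimal standalone sketch. It assumes `stop` and `stats` are shared `Mutex`-guarded values and that `time` is a `SystemTime` (suggested by `time.elapsed().unwrap()`). The `ActorStat` stand-in, the `usize` type of `env_steps`, the function name `sample_until_stopped`, and the `stop_after` trigger are all hypothetical.

```rust
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};

/// Stand-in for the `ActorStat` struct in the diff (field types are assumed).
#[derive(Debug, Clone)]
struct ActorStat {
    env_steps: usize,
    duration: Duration,
}

/// Runs a dummy sampling loop until `stop` is set, then publishes the
/// statistics through the shared `stats` slot.
fn sample_until_stopped(
    stop: Arc<Mutex<bool>>,
    stats: Arc<Mutex<Option<ActorStat>>>,
    stop_after: usize,
) {
    let time = SystemTime::now();
    let mut env_steps = 0;
    loop {
        env_steps += 1; // one simulated environment step

        // Hypothetical trigger: in a real trainer another thread sets the flag.
        if env_steps >= stop_after {
            *stop.lock().unwrap() = true;
        }

        // Stop sampling loop, mirroring the reformatted block in the diff.
        if *stop.lock().unwrap() {
            *stats.lock().unwrap() = Some(ActorStat {
                env_steps,
                duration: time.elapsed().unwrap(),
            });
            break;
        }
    }
}
```

The lock on `stop` is released before the `if` body runs, so taking the `stats` lock inside the body does not nest the two guards.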
6 changes: 2 additions & 4 deletions border-async-trainer/src/actor_manager/config.rs
@@ -10,9 +10,7 @@ pub struct ActorManagerConfig {
}

impl Default for ActorManagerConfig {
fn default() -> Self {
Self {
n_buffer: 100,
}
fn default() -> Self {
Self { n_buffer: 100 }
}
}
2 changes: 1 addition & 1 deletion border-async-trainer/src/async_trainer/config.rs
@@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::{
fs::File,
io::{BufReader, Write},
2 changes: 1 addition & 1 deletion border-async-trainer/src/error.rs
@@ -3,5 +3,5 @@ use thiserror::Error;
#[derive(Debug, Error)]
pub enum BorderAsyncTrainerError {
#[error("Error")]
SendMsgForPush
SendMsgForPush,
}
5 changes: 1 addition & 4 deletions border-async-trainer/src/util.rs
@@ -2,10 +2,7 @@
use crate::{
actor_stats_fmt, ActorManager, ActorManagerConfig, AsyncTrainer, AsyncTrainerConfig, SyncModel,
};
use border_core::{
Agent, DefaultEvaluator, Env, ReplayBufferBase,
StepProcessorBase,
};
use border_core::{Agent, DefaultEvaluator, Env, ReplayBufferBase, StepProcessorBase};
use border_tensorboard::TensorboardRecorder;
use crossbeam_channel::unbounded;
use log::info;
3 changes: 2 additions & 1 deletion border-atari-env/Cargo.toml
@@ -17,6 +17,7 @@ dirs = { workspace = true }
border-core = { version = "0.0.6", path = "../border-core" }
image = { workspace = true }
tch = { workspace = true, optional = true }
candle-core = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
itertools = "0.10.1"
fastrand = { workspace = true }
@@ -42,7 +43,7 @@ default = [
"winit_input_helper",
"minifb",
"pixels",
"tch",
# "tch",
]
sdl = ["atari-env-sys/sdl"]

4 changes: 2 additions & 2 deletions border-atari-env/src/act.rs
@@ -28,7 +28,7 @@ impl From<u8> for BorderAtariAct {
}
}

/// Converts `A` to [BorderAtariAct].
/// Converts `A` to [`BorderAtariAct`].
pub trait BorderAtariActFilter<A: Act> {
/// Configuration of the filter.
type Config: Clone + Default;
@@ -46,7 +46,7 @@ pub trait BorderAtariActFilter<A: Act> {
}

#[derive(Debug, Deserialize, Serialize)]
/// Configuration of [BorderAtariActRawFilter].
/// Configuration of [`BorderAtariActRawFilter`].
#[derive(Clone)]
pub struct BorderAtariActRawFilterConfig;

29 changes: 12 additions & 17 deletions border-atari-env/src/env.rs
@@ -1,21 +1,21 @@
mod config;
mod window;
use super::BorderAtariAct;
use anyhow::Result;
use super::{BorderAtariActFilter, BorderAtariObsFilter};
use crate::atari_env::{AtariAction, AtariEnv, EmulatorConfig};
use border_core::{record::Record, Env, Info, Obs, Act, Step};
use anyhow::Result;
use border_core::{record::Record, Act, Env, Info, Obs, Step};
pub use config::BorderAtariEnvConfig;
use image::{
imageops::{/*grayscale,*/ resize, FilterType::Triangle},
ImageBuffer, /*Luma,*/ Rgb,
};
use std::{default::Default, marker::PhantomData};
use itertools::izip;
use std::ptr::copy;
use std::{default::Default, marker::PhantomData};
use window::AtariWindow;
#[cfg(feature = "atari-env-sys")]
use winit::{event_loop::ControlFlow, platform::run_return::EventLoopExtRunReturn};
use super::{BorderAtariObsFilter, BorderAtariActFilter};
use itertools::izip;

/// Empty struct.
pub struct NullInfo;
@@ -74,7 +74,7 @@ where
phantom: PhantomData<(O, A)>,
}

impl<O, A, OF, AF> BorderAtariEnv <O, A, OF, AF>
impl<O, A, OF, AF> BorderAtariEnv<O, A, OF, AF>
where
O: Obs,
A: Act,
@@ -169,9 +169,11 @@ where
let i1 = buf.iter().step_by(3);
let i2 = buf.iter().skip(1).step_by(3);
let i3 = buf.iter().skip(2).step_by(3);
izip![i1, i2, i3].map(|(&b, &g, &r)|
((0.299 * r as f32) + (0.587 * g as f32) + (0.114 * b as f32)) as u8
).collect::<Vec<_>>()
izip![i1, i2, i3]
.map(|(&b, &g, &r)| {
((0.299 * r as f32) + (0.587 * g as f32) + (0.114 * b as f32)) as u8
})
.collect::<Vec<_>>()
};
// let buf = {
// let img: ImageBuffer<Luma<u8>, _> = grayscale(&img);
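A dependency-free sketch of the grayscale conversion in this hunk, for checking the arithmetic in isolation. It assumes the buffer is packed as B, G, R bytes per pixel, which is what the closure pattern `(&b, &g, &r)` implies. `bgr_to_gray` is a hypothetical name, not a function in `border-atari-env`.

```rust
/// Converts a packed 3-byte-per-pixel buffer to 8-bit grayscale.
/// Assumption: channel order is B, G, R, matching the closure in the diff.
fn bgr_to_gray(buf: &[u8]) -> Vec<u8> {
    buf.chunks_exact(3)
        .map(|px| {
            let (b, g, r) = (px[0] as f32, px[1] as f32, px[2] as f32);
            // ITU-R BT.601 luma weights; `as u8` truncates toward zero
            // and saturates at the ends of the 0..=255 range.
            ((0.299 * r) + (0.587 * g) + (0.114 * b)) as u8
        })
        .collect()
}
```

`chunks_exact(3)` drops a trailing partial pixel, which matches the behaviour of `izip!` over the three strided iterators, since `izip!` stops at the shortest one.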
@@ -339,14 +341,7 @@ where
let reward = self.clip_reward(reward); // in training
self.stack_frame(obs);
let (obs, _record) = self.obs_filter.filt(self.frames.clone().into());
let step = Step::new(
obs,
act_org,
reward,
is_done,
NullInfo,
Self::Obs::dummy(1),
);
let step = Step::new(obs, act_org, reward, is_done, NullInfo, Self::Obs::dummy(1));
let record = Record::empty();

if let Some(window) = self.window.as_mut() {
19 changes: 9 additions & 10 deletions border-atari-env/src/env/config.rs
@@ -2,10 +2,10 @@
//!
//! If environment variable `ATARI_ROM_DIR` exists, it is used as the directory
//! from which ROM images of the Atari games are loaded.
use std::{env, default::Default};
use border_core::{Obs, Act};
use super::{BorderAtariObsFilter, BorderAtariActFilter};
use super::{BorderAtariActFilter, BorderAtariObsFilter};
use border_core::{Act, Obs};
use serde::{Deserialize, Serialize};
use std::{default::Default, env};
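The module docs say `ATARI_ROM_DIR` is used when it exists. A minimal sketch of that lookup, assuming nothing about the crate beyond the variable name: `resolve_rom_dir`, `rom_dir`, and the `fallback` parameter are hypothetical, and the crate's actual default directory is not shown in this diff.

```rust
use std::env;

/// Picks the ROM directory from an optional environment value.
/// `fallback` is a hypothetical parameter, not the crate's real default.
fn resolve_rom_dir(env_value: Option<String>, fallback: &str) -> String {
    env_value.unwrap_or_else(|| fallback.to_string())
}

/// Reads `ATARI_ROM_DIR` from the process environment, falling back when
/// the variable is unset or not valid Unicode.
fn rom_dir(fallback: &str) -> String {
    resolve_rom_dir(env::var("ATARI_ROM_DIR").ok(), fallback)
}
```

Keeping the decision in `resolve_rom_dir` makes it testable without mutating the process environment.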

#[derive(Serialize, Deserialize, Debug)]
/// Configurations of [`BorderAtariEnv`](super::BorderAtariEnv).
@@ -16,12 +16,12 @@ where
OF: BorderAtariObsFilter<O>,
AF: BorderAtariActFilter<A>,
{
pub(super) rom_dir: String,
pub(super) name: String,
pub(super) obs_filter_config: OF::Config,
pub(super) act_filter_config: AF::Config,
pub(super) train: bool,
pub(super) render: bool,
pub rom_dir: String,
pub name: String,
pub obs_filter_config: OF::Config,
pub act_filter_config: AF::Config,
pub train: bool,
pub render: bool,
}
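With the fields now `pub` rather than `pub(super)`, code outside the `env` module can build or adjust a config with a plain struct literal. A sketch of what that enables, using a simplified non-generic stand-in: `EnvConfig`, its derived `Default`, and `eval_config` are all hypothetical, and the real type is generic over `O, A, OF, AF` and also carries the filter-config fields.

```rust
mod atari {
    /// Simplified stand-in for `BorderAtariEnvConfig` with public fields.
    #[derive(Debug, Clone, Default)]
    pub struct EnvConfig {
        pub rom_dir: String,
        pub name: String,
        pub train: bool,
        pub render: bool,
    }
}

/// Builds an evaluation config from outside the defining module.
/// Struct-update syntax only compiles here because every field is `pub`.
fn eval_config(name: &str) -> atari::EnvConfig {
    atari::EnvConfig {
        name: name.to_string(),
        render: true,
        ..Default::default()
    }
}
```

With `pub(super)` fields the same literal would be rejected outside the parent module, so callers would need builder-style setters instead.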

impl<O, A, OF, AF> Clone for BorderAtariEnvConfig<O, A, OF, AF>
@@ -43,7 +43,6 @@ where
}
}


impl<O, A, OF, AF> Default for BorderAtariEnvConfig<O, A, OF, AF>
where
O: Obs,
12 changes: 4 additions & 8 deletions border-atari-env/src/env/window.rs
@@ -1,13 +1,13 @@
use anyhow::Result;
use crate::atari_env::AtariEnv;
use anyhow::Result;
#[cfg(feature = "atari-env-sys")]
use {
pixels::{Pixels, SurfaceTexture},
winit::{
event_loop::EventLoop,
// platform::run_return::EventLoopExtRunReturn,
window::{Window, WindowBuilder},
}
},
};

pub(super) struct AtariWindow {
@@ -29,12 +29,8 @@ impl AtariWindow {
.with_inner_size(winit::dpi::LogicalSize::new(128.0, 128.0))
.build(&event_loop)?;
let surface_texture = SurfaceTexture::new(128, 128, &window);
let pixels = Pixels::new(
env.width() as u32,
env.height() as u32,
surface_texture,
)
.unwrap();
let pixels =
Pixels::new(env.width() as u32, env.height() as u32, surface_texture).unwrap();
// event_loop.run_return(move |_event, _, _control_flow| {});

Ok(Self {
10 changes: 5 additions & 5 deletions border-atari-env/src/lib.rs
@@ -12,7 +12,7 @@
//! You need to place Atari ROM directories under the directory specified by the environment variable
//! `ATARI_ROM_DIR`. An easy way to do this is to use [AutoROM](https://pypi.org/project/AutoROM/)
//! Python package.
//!
//!
//! ```bash
//! pip install autorom
//! mkdir $HOME/atari_rom
Expand All @@ -21,7 +21,7 @@
//! ```
//!
//! Here is an example of running the Pong environment with a random policy.
//!
//!
//! ```no_run
//! use anyhow::Result;
//! use border_atari_env::{
@@ -87,10 +87,10 @@
//! ```
//! [`atari_env::AtariEnv::lives()`]: atari_env::AtariEnv::lives
mod act;
mod obs;
pub mod atari_env;
mod env;
mod obs;
pub mod util;
pub mod atari_env;
pub use act::{BorderAtariAct, BorderAtariActFilter, BorderAtariActRawFilter};
pub use obs::{BorderAtariObs, BorderAtariObsFilter, BorderAtariObsRawFilter};
pub use env::{BorderAtariEnv, BorderAtariEnvConfig};
pub use obs::{BorderAtariObs, BorderAtariObsFilter, BorderAtariObsRawFilter};
14 changes: 14 additions & 0 deletions border-atari-env/src/obs.rs
@@ -14,6 +14,8 @@
//! Instead, the scaling is applied in the CNN model.
use anyhow::Result;
use border_core::{record::Record, Obs};
#[cfg(feature = "candle-core")]
use candle_core::{Device::Cpu, Tensor};
use serde::{Deserialize, Serialize};
use std::{default::Default, marker::PhantomData};
#[cfg(feature = "tch")]
@@ -53,6 +55,18 @@ impl From<BorderAtariObs> for Tensor {
}
}

#[cfg(feature = "candle-core")]
impl From<BorderAtariObs> for Tensor {
fn from(obs: BorderAtariObs) -> Tensor {
let tmp = obs.frames;
// Assumes the batch size is 1, implying a non-vectorized environment
Tensor::from_vec(tmp, &[1 * 4 * 1 * 84 * 84], &Cpu)
.unwrap()
.reshape(&[1, 4, 1, 84, 84])
.unwrap()
}
}
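The new impl builds a flat tensor and reshapes it to `[1, 4, 1, 84, 84]`, unwrapping both results, so a frame buffer of the wrong length would presumably surface as a panic. Below is a small candle-free sketch of the length check that conversion relies on. `OBS_SHAPE`, `numel`, and `checked_flat` are hypothetical names, and the axis labels in the comment are an assumption rather than something stated in the diff.

```rust
/// Shape used by the `From<BorderAtariObs>` impl in the diff.
/// Assumed axis meaning: (batch, stacked frames, channels, height, width).
const OBS_SHAPE: [usize; 5] = [1, 4, 1, 84, 84];

/// Number of elements a flat buffer must hold to be reshaped into `shape`.
fn numel(shape: &[usize]) -> usize {
    shape.iter().product()
}

/// Returns the buffer unchanged if its length matches `shape`,
/// otherwise a descriptive error instead of a panic.
fn checked_flat(frames: Vec<u8>, shape: &[usize]) -> Result<Vec<u8>, String> {
    let expected = numel(shape);
    if frames.len() == expected {
        Ok(frames)
    } else {
        Err(format!("expected {} elements, got {}", expected, frames.len()))
    }
}
```

A check like this before the conversion would turn a malformed observation into a recoverable error; whether that is worth the extra branch in the environment's hot path is a design choice for the maintainers.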

/// Converts [`BorderAtariObs`] to `O` with an arbitrary processing.
pub trait BorderAtariObsFilter<O: Obs> {
/// Configuration of the filter.
2 changes: 1 addition & 1 deletion border-atari-env/src/util.rs
@@ -1 +1 @@
pub mod test;
pub mod test;
2 changes: 1 addition & 1 deletion border-atari-env/src/util/test.rs
@@ -179,7 +179,7 @@ impl<R: ReplayBufferBase> Agent_<Env, R> for RandomAgent {
if buffer.len() <= 100 {
None
} else {
// Do nothing
// Do nothing
self.n_opts_steps += 1;
Some(Record::empty())
}