Skip to content

Commit

Permalink
Tweaks on replay buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
taku-y committed Jul 14, 2023
1 parent 813d189 commit 82868dc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
4 changes: 2 additions & 2 deletions border-core/src/replay_buffer/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ where
}

fn batch(&mut self, size: usize) -> anyhow::Result<Self::Batch> {
let (ixs, weight) = if let Some(per_state) = &self.per_state {
let sum_tree = &per_state.sum_tree;
let (ixs, weight) = if let Some(per_state) = &mut self.per_state {
let sum_tree = &mut per_state.sum_tree;
let beta = per_state.iw_scheduler.beta();
let (ixs, weight) = sum_tree.sample(size, beta);
let ixs = ixs.iter().map(|&ix| ix as usize).collect();
Expand Down
41 changes: 23 additions & 18 deletions border-core/src/replay_buffer/base/sum_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
//!
//! Code is adapted from <https://github.com/jaromiru/AI-blog/blob/master/SumTree.py> and
//! <https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py>
use rand::{rngs::StdRng, RngCore};
use segment_tree::{
ops::{MaxIgnoreNaN, MinIgnoreNaN},
SegmentPoint,
Expand All @@ -27,6 +28,7 @@ pub struct SumTree {
min_tree: SegmentPoint<f32, MinIgnoreNaN>,
max_tree: SegmentPoint<f32, MaxIgnoreNaN>,
normalize: WeightNormalizer,
rng: fastrand::Rng,
}

impl SumTree {
Expand All @@ -40,6 +42,7 @@ impl SumTree {
min_tree: SegmentPoint::build(vec![f32::MAX; capacity], MinIgnoreNaN),
max_tree: SegmentPoint::build(vec![1e-8f32; capacity], MaxIgnoreNaN),
normalize,
rng: fastrand::Rng::with_seed(0),
}
}

Expand Down Expand Up @@ -117,29 +120,13 @@ impl SumTree {
///
/// The weight is $w_i=\left(N^{-1}P(i)^{-1}\right)^{\beta}$
/// and it will be normalized by $\max_i w_i$.
pub fn sample(&self, batch_size: usize, beta: f32) -> (Vec<i64>, Vec<f32>) {
pub fn sample(&mut self, batch_size: usize, beta: f32) -> (Vec<i64>, Vec<f32>) {
let p_sum = &self.total();
let ps = (0..batch_size)
.map(|_| p_sum * fastrand::f32())
.collect::<Vec<_>>();
let indices = ps.iter().map(|&p| self.get(p)).collect::<Vec<_>>();
// let indices = (0..batch_size)
// .map(|_| self.get(p_sum * fastrand::f32()))
// .collect::<Vec<_>>();

let n = self.n_samples as f32 / p_sum;
let ws = indices
.iter()
.map(|ix| self.tree[ix + self.capacity - 1])
.map(|p| (n * p).powf(-beta))
.collect::<Vec<_>>();

// normalizer within all samples
let w_max_inv = match self.normalize {
WeightNormalizer::All => (n * self.min_tree.query(0, self.n_samples)).powf(beta),
WeightNormalizer::Batch => 1f32 / ws.iter().fold(0.0 / 0.0, |m, v| v.max(m)),
};
let ws = ws.iter().map(|w| w * w_max_inv).collect::<Vec<f32>>();
let (ws, w_max_inv) = self.weights(&indices, beta);

// debug
// if self.n_samples % 100 == 0 || p_sum.is_nan() || w_max.is_nan() {
Expand Down Expand Up @@ -173,6 +160,24 @@ impl SumTree {
// println!("min = {}", self.min());
println!("total = {}", self.total());
}

/// Computes normalized importance-sampling weights for the given indices.
///
/// For every `ix` in `ixs`, the raw weight is `(n * p).powf(-beta)`, where
/// `p` is the value stored at `self.tree[ix + self.capacity - 1]` and
/// `n = self.n_samples as f32 / self.total()`.
///
/// The raw weights are then multiplied by a scaling factor selected by
/// `self.normalize`:
///
/// * `WeightNormalizer::All` - `(n * p_min).powf(beta)`, where `p_min` is the
///   result of `self.min_tree.query(0, self.n_samples)`.
/// * `WeightNormalizer::Batch` - the reciprocal of the largest raw weight
///   among the entries of `ixs`.
///
/// Returns the scaled weights (in the same order as `ixs`) together with the
/// scaling factor that was applied.
fn weights(&self, ixs: &[usize], beta: f32) -> (Vec<f32>, f32) {
    let n = self.n_samples as f32 / self.total();
    let raw = ixs
        .iter()
        .map(|&ix| (n * self.tree[ix + self.capacity - 1]).powf(-beta))
        .collect::<Vec<_>>();

    // Inverse of the maximum weight used for normalization.
    let w_max_inv = match self.normalize {
        WeightNormalizer::All => (n * self.min_tree.query(0, self.n_samples)).powf(beta),
        // The fold is seeded with NaN because `f32::max` returns the other
        // operand when one of them is NaN, so the seed never wins against a
        // real weight.
        WeightNormalizer::Batch => 1f32 / raw.iter().fold(f32::NAN, |m, &v| v.max(m)),
    };

    // Consume the raw weights instead of borrowing and re-collecting them.
    let ws = raw.into_iter().map(|w| w * w_max_inv).collect::<Vec<f32>>();

    (ws, w_max_inv)
}
}

#[cfg(test)]
Expand Down

0 comments on commit 82868dc

Please sign in to comment.