Skip to content

Commit

Permalink
count each selected buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
shu5620 committed Oct 1, 2023
1 parent 8aa4bb4 commit 307d4e1
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions border-async-trainer/src/async_trainer/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,18 +304,27 @@ where
info!("Send model info first in AsyncTrainer");
self.sync(&mut agent);

let mut ix_opt_buffer_cnt = vec![0; self.n_div_replaybuffer];

info!("Starts training loop");
loop {

let record = {
let time_tmp = SystemTime::now();
let buf = async_buffer.get_buffer_for_opt();
let (buf, ix_opt_buffer) = async_buffer.get_buffer_for_opt();
let mut buf = buf.lock().unwrap();
println!("time get_buffer_for_opt: {}", time_tmp.elapsed().unwrap().as_secs_f32());
// println!("time get_buffer_for_opt: {}", time_tmp.elapsed().unwrap().as_secs_f32());

let time_tmp = SystemTime::now();
let record = agent.opt(&mut buf);
println!("time opt: {}", time_tmp.elapsed().unwrap().as_secs_f32());
// println!("time opt: {}", time_tmp.elapsed().unwrap().as_secs_f32());

if record.is_some() {
ix_opt_buffer_cnt[ix_opt_buffer] += 1;
println!("ix_opt_buffer: {}", ix_opt_buffer);
println!("ix_opt_buffer_cnt: {:?}", ix_opt_buffer_cnt);
}

record
};

Expand Down Expand Up @@ -364,7 +373,7 @@ where
info!("Sends the trained model info to ActorManager");
self.sync(&agent);
}
println!("time others: {}", time_tmp.elapsed().unwrap().as_secs_f32());
// println!("time others: {}", time_tmp.elapsed().unwrap().as_secs_f32());
}
}
info!("Stopped training loop");
Expand Down Expand Up @@ -414,6 +423,8 @@ where

fn ix_push_buffer(&self) -> usize {
let ixs_free = self.ixs_free_buffer();
println!("ixs_free (in push): {:?}", ixs_free);

if ixs_free.len() >= 3 {
// If there are 3 or more free buffers, use one of the free buffers.
// This is to leave 2 or more choice when selecting a BUFFER in the OPT.
Expand All @@ -434,6 +445,8 @@ where

fn ix_opt_buffer(&self) -> usize {
let ixs_free = self.ixs_free_buffer();
println!("ixs_free (in opt): {:?}", ixs_free);

// println!(
// "ixs_free.len() in get_free_buffer_for_opt: {}",
// ixs_free.len()
Expand Down Expand Up @@ -507,12 +520,18 @@ where
samples_total: Arc<Mutex<usize>>,
receiver: Receiver<PushedItemMessage<R::PushedItem>>,
) {
let mut ix_push_buffer_cnt = vec![0; splitted_buffers.splitted_buffers.len()];

for msg in receiver.iter() {
let samples = msg.pushed_items.len();

*samples_total.lock().unwrap() += samples;

let ix_buffer = splitted_buffers.ix_push_buffer();
ix_push_buffer_cnt[ix_buffer] += 1;
println!("ix_push_buffer: {}", ix_buffer);
println!("ix_push_buffer_cnt: {:?}", ix_push_buffer_cnt);

// println!("ix_buffer: {}", ix_buffer);
let buf = splitted_buffers.splitted_buffers[ix_buffer].clone();
let (s, r) = unbounded();
Expand All @@ -538,12 +557,12 @@ where
}
}

fn get_buffer_for_opt(&self) -> Arc<Mutex<R>> {
fn get_buffer_for_opt(&self) -> (Arc<Mutex<R>>, usize) {
let time_tmp = SystemTime::now();
let ix_buffer = self.splitted_buffers.ix_opt_buffer();
println!("time ix_opt_buffer: {}", time_tmp.elapsed().unwrap().as_secs_f32());
// println!("time ix_opt_buffer: {}", time_tmp.elapsed().unwrap().as_secs_f32());

self.splitted_buffers.splitted_buffers[ix_buffer].clone()
(self.splitted_buffers.splitted_buffers[ix_buffer].clone(), ix_buffer)
}

fn samples_total(&self) -> usize {
Expand Down

0 comments on commit 307d4e1

Please sign in to comment.