diff --git a/triton-profiler/src/triton_profiler.rs b/triton-profiler/src/triton_profiler.rs index 34d08ce69..dee05b14a 100644 --- a/triton-profiler/src/triton_profiler.rs +++ b/triton-profiler/src/triton_profiler.rs @@ -1,4 +1,5 @@ use std::cmp::max; +use std::collections::HashMap; use std::fmt::Display; use std::time::Duration; use std::time::Instant; @@ -20,6 +21,10 @@ struct Task { depth: usize, time: Duration, task_type: TaskType, + + /// The type of work the task is doing. Helps tracking time across specific tasks. For + /// example, if the task is building a Merkle tree, then the category could be "hash". + category: Option, } #[derive(Clone, Debug, PartialEq, Eq)] @@ -109,12 +114,25 @@ impl TritonProfiler { self.total_time.as_secs_f64(), total_tracked_time, ); + + // collect all categories and their total times + // todo: this can count the same category multiple times if it's nested + let mut category_times = HashMap::new(); + for task in self.profile.iter() { + if let Some(category) = &task.category { + category_times + .entry(category.clone()) + .and_modify(|e| *e += task.time) + .or_insert(task.time); + } + } + for (task_index, task) in self.profile.iter().enumerate() { // compute this task's time relative to total duration - let relative_time = task.time.as_secs_f64() / total_tracked_time; + let relative_time = task.time.as_secs_f64() / self.total_time.as_secs_f64(); let weight = match task.task_type { TaskType::AnyOtherIteration => Weight::LikeNothing, - _ => Weight::weigh(task.time.as_secs_f64() / total_tracked_time), + _ => Weight::weigh(task.time.as_secs_f64() / self.total_time.as_secs_f64()), }; let is_last_sibling = !self.has_younger_sibling(task_index); @@ -128,12 +146,19 @@ impl TritonProfiler { } ancestors.reverse(); + let relative_category_time = task.category.clone().map(|category| { + let category_time = category_times.get(&category).unwrap(); + task.time.as_secs_f64() / category_time.as_secs_f64() + }); + report.push(TaskReport { name: task.name.clone(), parent_index: task.parent_index, depth: task.depth, time: task.time, relative_time, + category: task.category.clone(), + relative_category_time, is_last_sibling, ancestors, weight, @@ -162,19 +187,20 @@ impl TritonProfiler { tasks: report, name: self.name.clone(), total_time: self.total_time, + category_times, cycle_count, padded_height, fri_domain_len, } } - pub fn start(&mut self, name: &str) { + pub fn start(&mut self, name: &str, category: Option) { if !self.ignoring() { - self.plain_start(name, TaskType::Generic); + self.plain_start(name, TaskType::Generic, category); } } - fn plain_start(&mut self, name: &str, task_type: TaskType) { + fn plain_start(&mut self, name: &str, task_type: TaskType, category: Option) { let parent_index = self.stack.last().map(|(u, _)| *u); let now = self.timer.elapsed(); @@ -186,6 +212,7 @@ impl TritonProfiler { depth: self.stack.len(), time: now, task_type, + category, }); if std::env::var(GET_PROFILE_OUTPUT_AS_YOU_GO_ENV_VAR_NAME).is_ok() { @@ -193,7 +220,7 @@ impl TritonProfiler { } } - pub fn iteration_zero(&mut self, name: &str) { + pub fn iteration_zero(&mut self, name: &str, category: Option) { if self.ignoring() { return; } @@ -208,7 +235,7 @@ impl TritonProfiler { if top_type != TaskType::IterationZero && top_type != TaskType::AnyOtherIteration { // start - self.plain_start("iteration 0", TaskType::IterationZero); + self.plain_start("iteration 0", TaskType::IterationZero, category); return; } @@ -231,7 +258,11 @@ impl TritonProfiler { self.plain_stop(); // start all other iterations - self.plain_start("all other iterations", TaskType::AnyOtherIteration); + self.plain_start( + "all other iterations", + TaskType::AnyOtherIteration, + category, + ); } // top == *"all other iterations" @@ -259,7 +290,8 @@ impl TritonProfiler { if top == *"iteration 0" || top == *"all other iterations" { assert!( self.stack.len() >= 2, - "To close profiling of zeroth iteration, stack must be at least 2-high, but got stack of height {}.", + "To close profiling of zeroth iteration, stack must be at least 2-high, \ + but got stack of height {}.", self.stack.len(), ); if self.stack[self.stack.len() - 2].1 == *name { @@ -284,7 +316,8 @@ impl Profiler for TritonProfiler { .to_str() .expect("Directory must be valid unicode"); let name = format!("{dir}{benchmark_id}"); - self.start(&name); + let category = None; + self.start(&name, category); } fn stop_profiling(&mut self, benchmark_id: &str, benchmark_dir: &std::path::Path) { @@ -363,6 +396,8 @@ struct TaskReport { depth: usize, time: Duration, relative_time: f64, + category: Option, + relative_category_time: Option, is_last_sibling: bool, ancestors: Vec, weight: Weight, @@ -374,6 +409,7 @@ pub struct Report { name: String, tasks: Vec, total_time: Duration, + category_times: HashMap, cycle_count: Option, padded_height: Option, fri_domain_len: Option, @@ -385,6 +421,7 @@ impl Report { name: "".to_string(), tasks: vec![], total_time: Duration::ZERO, + category_times: HashMap::new(), cycle_count: None, padded_height: None, fri_domain_len: None, @@ -410,16 +447,27 @@ impl Display for Report { .map(|t| t.name.width() + 2 * t.depth) .max() .expect("No tasks to generate report from."); + let max_category_name_width = self + .category_times + .keys() + .map(|k| k.width()) + .max() + .unwrap_or(0); let title = format!("### {}", self.name).bold(); - let max_width = if max_name_width > title.width() { - max_name_width - } else { - title.width() - }; + let max_width = max(max_name_width, title.width()); let total_time_string = Report::display_time_aligned(self.total_time).bold(); let separation = String::from_utf8(vec![b' '; max_width - title.width()]).unwrap(); - writeln!(f, "{title}{separation} {total_time_string}")?; + let share_string = "Share".to_string().bold(); + let category_string = match self.category_times.is_empty() { + true => "".to_string(), + false => "Category".to_string(), + } + .bold(); + writeln!( + f, + "{title}{separation} {total_time_string} {share_string} {category_string}" + )?; for task in self.tasks.iter() { for ancestor_index in task.ancestors.iter() { @@ -449,8 +497,7 @@ impl Display for Report { let padding_length = max_width - task.name.len() - 2 * task.depth; assert!( padding_length < (1 << 60), - "max width: {}, width: {}", - max_name_width, + "max width: {max_name_width}, width: {}", task.name.len(), ); let task_name_colored = task.name.color(task.weight.color()); @@ -463,12 +510,49 @@ impl Display for Report { let relative_time_string = format!("{:>6}", format!("{:2.2}%", 100.0 * task.relative_time)); let relative_time_string_colored = relative_time_string.color(task.weight.color()); + let category_name = task.category.as_deref().unwrap_or(""); + let relative_category_time = task + .relative_category_time + .map(|t| format!("{:>6}", format!("{:2.2}%", 100.0 * t))) + .unwrap_or("".to_string()); + let category_and_relative_time = match task.category.is_some() { + true => format!( + "({category_name: "".to_string(), + }; f.write_fmt(format_args!( "{task_name_colored}{padding} \ - {task_time_colored}{relative_time_string_colored}\n" + {task_time_colored}{relative_time_string_colored} \ + {category_and_relative_time}\n" ))?; } + if !self.category_times.is_empty() { + writeln!(f)?; + let category_title = ("### Categories").to_string().bold(); + writeln!(f, "{category_title}")?; + for (category, &category_time) in self.category_times.iter() { + let category_relative_time = + category_time.as_secs_f64() / self.total_time.as_secs_f64(); + let category_color = Weight::weigh(category_relative_time).color(); + let category_relative_time = + format!("{:>6}", format!("{:2.2}%", 100.0 * category_relative_time)); + let category_time = Report::display_time_aligned(category_time); + + let padding_length = max_category_name_width - category.width(); + let padding = String::from_utf8(vec![b' '; padding_length]).unwrap(); + + let category = category.color(category_color); + let category_time = category_time.color(category_color); + let category_relative_time = category_relative_time.color(category_color); + writeln!( + f, + "{category}{padding} {category_time} {category_relative_time}" + )?; + } + } + if self.cycle_count.is_some() || self.padded_height.is_some() || self.fri_domain_len.is_some() @@ -511,9 +595,14 @@ impl Display for Report { #[macro_export] macro_rules! prof_start { + ($p: ident, $s : expr, $c : expr) => { + if let Some(profiler) = $p.as_mut() { + profiler.start($s, Some($c.to_string())); + } + }; ($p: ident, $s : expr) => { if let Some(profiler) = $p.as_mut() { - profiler.start($s); + profiler.start($s, None); } }; } @@ -529,9 +618,14 @@ macro_rules! prof_stop { #[macro_export] macro_rules! prof_itr0 { - ($p : ident, $s : expr ) => { + ($p : ident, $s : expr, $c : expr) => { if let Some(profiler) = $p.as_mut() { - profiler.iteration_zero($s); + profiler.iteration_zero($s, Some($c.to_string())); + } + }; + ($p : ident, $s : expr) => { + if let Some(profiler) = $p.as_mut() { + profiler.iteration_zero($s, None); } }; } @@ -542,6 +636,7 @@ pub mod triton_profiler_tests { use std::time::Duration; use rand::rngs::ThreadRng; + use rand::Rng; use rand::RngCore; use super::*; @@ -556,6 +651,11 @@ pub mod triton_profiler_tests { .collect() } + fn random_category(rng: &mut ThreadRng) -> String { + let options = vec!["setup", "compute", "drop", "cleanup"]; + options[rng.next_u32() as usize % options.len()].to_string() + } + #[test] fn test_sanity() { let mut rng = rand::thread_rng(); @@ -581,8 +681,12 @@ pub mod triton_profiler_tests { if pushable { let name = random_task_name(&mut rng); + let category = match rng.gen() { + true => Some(random_category(&mut rng)), + false => None, + }; stack.push(name.clone()); - profiler.start(&name); + profiler.start(&name, category); } sleep(Duration::from_micros( @@ -597,7 +701,57 @@ pub mod triton_profiler_tests { profiler.finish(); println!("{}", profiler.report(None, None, None)); + } + + #[test] + fn clk_freq() { + let mut profiler = Some(TritonProfiler::new("Clock Frequency Test")); + prof_start!(profiler, "clk_freq_test"); + sleep(Duration::from_millis(3)); + prof_stop!(profiler, "clk_freq_test"); + let mut profiler = profiler.unwrap(); + profiler.finish(); + println!("{}", profiler.report(None, None, None)); println!("{}", profiler.report(Some(0), Some(0), Some(0))); - println!("{}", profiler.report(Some(5), Some(8), Some(13))); + println!("{}", profiler.report(Some(10), Some(12), Some(13))); + } + + #[test] + fn macros() { + let mut profiler = TritonProfiler::new("Macro Test"); + let mut profiler_ref = Some(&mut profiler); + let mut rng = rand::thread_rng(); + let mut stack = vec![]; + let steps = 100; + + for step in 0..steps { + let steps_left = steps - step; + let pushable = rng.gen() && stack.len() + 1 < steps_left; + let poppable = ((!pushable && rng.gen()) + || (stack.len() == steps_left && steps_left > 0)) + && !stack.is_empty(); + + if pushable { + let name = random_task_name(&mut rng); + let category = random_category(&mut rng); + stack.push(name.clone()); + match rng.gen() { + true => prof_start!(profiler_ref, &name, category), + false => prof_start!(profiler_ref, &name), + } + } + + sleep(Duration::from_micros( + (rng.next_u64() % 10) * (rng.next_u64() % 10) * (rng.next_u64() % 10), + )); + + if poppable { + let name = stack.pop().unwrap(); + prof_stop!(profiler_ref, &name); + } + } + + profiler.finish(); + println!("{}", profiler.report(None, None, None)); } }