Skip to content

Commit

Permalink
add profiling of categories
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Apr 19, 2023
1 parent 3b20991 commit 30aeabd
Showing 1 changed file with 178 additions and 24 deletions.
202 changes: 178 additions & 24 deletions triton-profiler/src/triton_profiler.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::cmp::max;
use std::collections::HashMap;
use std::fmt::Display;
use std::time::Duration;
use std::time::Instant;
Expand All @@ -20,6 +21,10 @@ struct Task {
depth: usize,
time: Duration,
task_type: TaskType,

/// The type of work the task is doing. Helps track time across specific tasks. For
/// example, if the task is building a Merkle tree, then the category could be "hash".
category: Option<String>,
}

#[derive(Clone, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -109,12 +114,25 @@ impl TritonProfiler {
self.total_time.as_secs_f64(),
total_tracked_time,
);

// collect all categories and their total times
// todo: this can count the same category multiple times if it's nested
let mut category_times = HashMap::new();
for task in self.profile.iter() {
if let Some(category) = &task.category {
category_times
.entry(category.clone())
.and_modify(|e| *e += task.time)
.or_insert(task.time);
}
}

for (task_index, task) in self.profile.iter().enumerate() {
// compute this task's time relative to total duration
let relative_time = task.time.as_secs_f64() / total_tracked_time;
let relative_time = task.time.as_secs_f64() / self.total_time.as_secs_f64();
let weight = match task.task_type {
TaskType::AnyOtherIteration => Weight::LikeNothing,
_ => Weight::weigh(task.time.as_secs_f64() / total_tracked_time),
_ => Weight::weigh(task.time.as_secs_f64() / self.total_time.as_secs_f64()),
};

let is_last_sibling = !self.has_younger_sibling(task_index);
Expand All @@ -128,12 +146,19 @@ impl TritonProfiler {
}
ancestors.reverse();

let relative_category_time = task.category.clone().map(|category| {
let category_time = category_times.get(&category).unwrap();
task.time.as_secs_f64() / category_time.as_secs_f64()
});

report.push(TaskReport {
name: task.name.clone(),
parent_index: task.parent_index,
depth: task.depth,
time: task.time,
relative_time,
category: task.category.clone(),
relative_category_time,
is_last_sibling,
ancestors,
weight,
Expand Down Expand Up @@ -162,19 +187,20 @@ impl TritonProfiler {
tasks: report,
name: self.name.clone(),
total_time: self.total_time,
category_times,
cycle_count,
padded_height,
fri_domain_len,
}
}

/// Starts a new task named `name` of type `TaskType::Generic`, optionally tagged with
/// `category`. Does nothing while `self.ignoring()` returns true.
// NOTE(review): this span comes from a diff view; each of the two line pairs below appears to be
// the pre-change line followed by its replacement. Confirm only the `category` variants remain.
pub fn start(&mut self, name: &str) {
pub fn start(&mut self, name: &str, category: Option<String>) {
if !self.ignoring() {
self.plain_start(name, TaskType::Generic);
self.plain_start(name, TaskType::Generic, category);
}
}

fn plain_start(&mut self, name: &str, task_type: TaskType) {
fn plain_start(&mut self, name: &str, task_type: TaskType, category: Option<String>) {
let parent_index = self.stack.last().map(|(u, _)| *u);
let now = self.timer.elapsed();

Expand All @@ -186,14 +212,15 @@ impl TritonProfiler {
depth: self.stack.len(),
time: now,
task_type,
category,
});

if std::env::var(GET_PROFILE_OUTPUT_AS_YOU_GO_ENV_VAR_NAME).is_ok() {
println!("start: {name}");
}
}

pub fn iteration_zero(&mut self, name: &str) {
pub fn iteration_zero(&mut self, name: &str, category: Option<String>) {
if self.ignoring() {
return;
}
Expand All @@ -208,7 +235,7 @@ impl TritonProfiler {

if top_type != TaskType::IterationZero && top_type != TaskType::AnyOtherIteration {
// start
self.plain_start("iteration 0", TaskType::IterationZero);
self.plain_start("iteration 0", TaskType::IterationZero, category);
return;
}

Expand All @@ -231,7 +258,11 @@ impl TritonProfiler {
self.plain_stop();

// start all other iterations
self.plain_start("all other iterations", TaskType::AnyOtherIteration);
self.plain_start(
"all other iterations",
TaskType::AnyOtherIteration,
category,
);
}

// top == *"all other iterations"
Expand Down Expand Up @@ -259,7 +290,8 @@ impl TritonProfiler {
if top == *"iteration 0" || top == *"all other iterations" {
assert!(
self.stack.len() >= 2,
"To close profiling of zeroth iteration, stack must be at least 2-high, but got stack of height {}.",
"To close profiling of zeroth iteration, stack must be at least 2-high, \
but got stack of height {}.",
self.stack.len(),
);
if self.stack[self.stack.len() - 2].1 == *name {
Expand All @@ -284,7 +316,8 @@ impl Profiler for TritonProfiler {
.to_str()
.expect("Directory must be valid unicode");
let name = format!("{dir}{benchmark_id}");
self.start(&name);
let category = None;
self.start(&name, category);
}

fn stop_profiling(&mut self, benchmark_id: &str, benchmark_dir: &std::path::Path) {
Expand Down Expand Up @@ -363,6 +396,8 @@ struct TaskReport {
depth: usize,
time: Duration,
relative_time: f64,
category: Option<String>,
relative_category_time: Option<f64>,
is_last_sibling: bool,
ancestors: Vec<usize>,
weight: Weight,
Expand All @@ -374,6 +409,7 @@ pub struct Report {
name: String,
tasks: Vec<TaskReport>,
total_time: Duration,
category_times: HashMap<String, Duration>,
cycle_count: Option<usize>,
padded_height: Option<usize>,
fri_domain_len: Option<usize>,
Expand All @@ -385,6 +421,7 @@ impl Report {
name: "".to_string(),
tasks: vec![],
total_time: Duration::ZERO,
category_times: HashMap::new(),
cycle_count: None,
padded_height: None,
fri_domain_len: None,
Expand All @@ -410,16 +447,27 @@ impl Display for Report {
.map(|t| t.name.width() + 2 * t.depth)
.max()
.expect("No tasks to generate report from.");
let max_category_name_width = self
.category_times
.keys()
.map(|k| k.width())
.max()
.unwrap_or(0);

let title = format!("### {}", self.name).bold();
let max_width = if max_name_width > title.width() {
max_name_width
} else {
title.width()
};
let max_width = max(max_name_width, title.width());
let total_time_string = Report::display_time_aligned(self.total_time).bold();
let separation = String::from_utf8(vec![b' '; max_width - title.width()]).unwrap();
writeln!(f, "{title}{separation} {total_time_string}")?;
let share_string = "Share".to_string().bold();
let category_string = match self.category_times.is_empty() {
true => "".to_string(),
false => "Category".to_string(),
}
.bold();
writeln!(
f,
"{title}{separation} {total_time_string} {share_string} {category_string}"
)?;

for task in self.tasks.iter() {
for ancestor_index in task.ancestors.iter() {
Expand Down Expand Up @@ -449,8 +497,7 @@ impl Display for Report {
let padding_length = max_width - task.name.len() - 2 * task.depth;
assert!(
padding_length < (1 << 60),
"max width: {}, width: {}",
max_name_width,
"max width: {max_name_width}, width: {}",
task.name.len(),
);
let task_name_colored = task.name.color(task.weight.color());
Expand All @@ -463,12 +510,49 @@ impl Display for Report {
let relative_time_string =
format!("{:>6}", format!("{:2.2}%", 100.0 * task.relative_time));
let relative_time_string_colored = relative_time_string.color(task.weight.color());
let category_name = task.category.as_deref().unwrap_or("");
let relative_category_time = task
.relative_category_time
.map(|t| format!("{:>6}", format!("{:2.2}%", 100.0 * t)))
.unwrap_or("".to_string());
let category_and_relative_time = match task.category.is_some() {
true => format!(
"({category_name:<max_category_name_width$} – {relative_category_time})"
),
false => "".to_string(),
};
f.write_fmt(format_args!(
"{task_name_colored}{padding} \
{task_time_colored}{relative_time_string_colored}\n"
{task_time_colored}{relative_time_string_colored} \
{category_and_relative_time}\n"
))?;
}

if !self.category_times.is_empty() {
writeln!(f)?;
let category_title = ("### Categories").to_string().bold();
writeln!(f, "{category_title}")?;
for (category, &category_time) in self.category_times.iter() {
let category_relative_time =
category_time.as_secs_f64() / self.total_time.as_secs_f64();
let category_color = Weight::weigh(category_relative_time).color();
let category_relative_time =
format!("{:>6}", format!("{:2.2}%", 100.0 * category_relative_time));
let category_time = Report::display_time_aligned(category_time);

let padding_length = max_category_name_width - category.width();
let padding = String::from_utf8(vec![b' '; padding_length]).unwrap();

let category = category.color(category_color);
let category_time = category_time.color(category_color);
let category_relative_time = category_relative_time.color(category_color);
writeln!(
f,
"{category}{padding} {category_time} {category_relative_time}"
)?;
}
}

if self.cycle_count.is_some()
|| self.padded_height.is_some()
|| self.fri_domain_len.is_some()
Expand Down Expand Up @@ -511,9 +595,14 @@ impl Display for Report {

/// Starts a task on an optional profiler.
///
/// `$p` is an identifier on which `as_mut()` is called; the task is only started when that
/// yields `Some(profiler)`. `$s` is the task name handed to `start`. The optional `$c` is
/// converted with `to_string()` and handed to `start` as `Some(category)`.
#[macro_export]
macro_rules! prof_start {
// Arm with a category.
($p: ident, $s : expr, $c : expr) => {
if let Some(profiler) = $p.as_mut() {
profiler.start($s, Some($c.to_string()));
}
};
// Arm without a category.
($p: ident, $s : expr) => {
if let Some(profiler) = $p.as_mut() {
// NOTE(review): the next two lines look like the pre- and post-change versions of the same
// call as rendered by the diff view; confirm only the two-argument call is in the file.
profiler.start($s);
profiler.start($s, None);
}
};
}
Expand All @@ -529,9 +618,14 @@ macro_rules! prof_stop {

/// Calls `iteration_zero` on an optional profiler.
///
/// `$p` is an identifier on which `as_mut()` is called; the call only happens when that yields
/// `Some(profiler)`. `$s` is the name handed to `iteration_zero`. The optional `$c` is converted
/// with `to_string()` and handed over as `Some(category)`; without it, `None` is passed.
#[macro_export]
macro_rules! prof_itr0 {
// NOTE(review): the next two lines, and the two `iteration_zero` calls below them, look like
// pre-/post-change pairs as rendered by the diff view; confirm only the later line of each
// pair is in the file.
($p : ident, $s : expr ) => {
($p : ident, $s : expr, $c : expr) => {
if let Some(profiler) = $p.as_mut() {
profiler.iteration_zero($s);
profiler.iteration_zero($s, Some($c.to_string()));
}
};
// Arm without a category.
($p : ident, $s : expr) => {
if let Some(profiler) = $p.as_mut() {
profiler.iteration_zero($s, None);
}
};
}
Expand All @@ -542,6 +636,7 @@ pub mod triton_profiler_tests {
use std::time::Duration;

use rand::rngs::ThreadRng;
use rand::Rng;
use rand::RngCore;

use super::*;
Expand All @@ -556,6 +651,11 @@ pub mod triton_profiler_tests {
.collect()
}

/// Returns one of a fixed set of category names, chosen using `rng`.
///
/// The candidates are kept in a fixed-size array: the list never changes, so allocating a
/// `Vec` on every call is unnecessary.
fn random_category(rng: &mut ThreadRng) -> String {
    let options = ["setup", "compute", "drop", "cleanup"];
    // The modulo maps the random draw onto a valid index into `options`.
    options[rng.next_u32() as usize % options.len()].to_string()
}

#[test]
fn test_sanity() {
let mut rng = rand::thread_rng();
Expand All @@ -581,8 +681,12 @@ pub mod triton_profiler_tests {

if pushable {
let name = random_task_name(&mut rng);
let category = match rng.gen() {
true => Some(random_category(&mut rng)),
false => None,
};
stack.push(name.clone());
profiler.start(&name);
profiler.start(&name, category);
}

sleep(Duration::from_micros(
Expand All @@ -597,7 +701,57 @@ pub mod triton_profiler_tests {

profiler.finish();
println!("{}", profiler.report(None, None, None));
}

/// Smoke test: profiles one task that sleeps for 3 ms via the macros, then prints the report
/// for several combinations of the three optional `report` arguments.
#[test]
fn clk_freq() {
    let mut maybe_profiler = Some(TritonProfiler::new("Clock Frequency Test"));
    prof_start!(maybe_profiler, "clk_freq_test");
    sleep(Duration::from_millis(3));
    prof_stop!(maybe_profiler, "clk_freq_test");

    let mut finished_profiler = maybe_profiler.unwrap();
    finished_profiler.finish();

    // One entry per printed report, in the order the reports are printed.
    let report_arguments = [
        (None, None, None),
        (Some(0), Some(0), Some(0)),
        (Some(5), Some(8), Some(13)),
        (Some(10), Some(12), Some(13)),
    ];
    for (first, second, third) in report_arguments {
        println!("{}", finished_profiler.report(first, second, third));
    }
}

/// Drives the profiler exclusively through `prof_start!` / `prof_stop!` with a random sequence
/// of task starts and stops, then prints the resulting report.
#[test]
fn macros() {
    let mut profiler = TritonProfiler::new("Macro Test");
    let mut profiler_ref = Some(&mut profiler);
    let mut rng = rand::thread_rng();
    let mut open_tasks = vec![];
    let steps = 100;

    for step in 0..steps {
        let remaining = steps - step;

        // The random draws happen in the same order as the conditions are written:
        // one for `pushable`, then one for `poppable` only if nothing is pushed.
        let pushable = rng.gen() && open_tasks.len() + 1 < remaining;
        let must_pop = open_tasks.len() == remaining && remaining > 0;
        let poppable = ((!pushable && rng.gen()) || must_pop) && !open_tasks.is_empty();

        if pushable {
            let name = random_task_name(&mut rng);
            let category = random_category(&mut rng);
            // Randomly exercise the macro arm with and without a category.
            if rng.gen::<bool>() {
                prof_start!(profiler_ref, &name, category);
            } else {
                prof_start!(profiler_ref, &name);
            }
            open_tasks.push(name);
        }

        // Sleep for a product of three random factors in 0..10, in microseconds.
        let factor_a = rng.next_u64() % 10;
        let factor_b = rng.next_u64() % 10;
        let factor_c = rng.next_u64() % 10;
        sleep(Duration::from_micros(factor_a * factor_b * factor_c));

        if poppable {
            let name = open_tasks.pop().unwrap();
            prof_stop!(profiler_ref, &name);
        }
    }

    profiler.finish();
    println!("{}", profiler.report(None, None, None));
}
}

0 comments on commit 30aeabd

Please sign in to comment.