Skip to content

Commit

Permalink
[nsys-jax] Add ratio of hidden communication time to total communication time (#1241)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfvaroglu authored Jan 16, 2025
1 parent 7036e87 commit eb6d0d2
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 22 deletions.
162 changes: 140 additions & 22 deletions .github/container/nsys_jax/nsys_jax/analyses/communication.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,38 +1,21 @@
#!/usr/bin/env python
import argparse
import csv
from collections import defaultdict

from nsys_jax import (
align_profiler_data_timestamps,
apply_warmup_heuristics,
ensure_compiled_protos_are_importable,
load_profiler_data,
)
from math import sqrt
from prettytable import PrettyTable
import pathlib
from uncertainties import ufloat # type: ignore


def main():
parser = argparse.ArgumentParser(
description="Summarise communication in an nsys-jax report"
)
parser.add_argument("prefix", type=pathlib.Path)
args = parser.parse_args()
# Make sure that the .proto files under protos/ have been compiled to .py, and
# that those generated .py files are importable.
ensure_compiled_protos_are_importable(prefix=args.prefix)
# Load the profiler data; the compilation part is needed for the warmup heuristics
all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
# Align timestamps
all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
# TODO: make this pretty
# print(alignment_metadata)
# Partition the profile data into initialisation and steady-state running
_, steady_state = apply_warmup_heuristics(all_data)
assert len(steady_state.communication), (
"Communication summary was requested but no steady-state communication was "
"identified."
)
def process_communication_data(steady_state):
collective_types = set()
summary_data = defaultdict(dict)
for (collective, message_size), df in steady_state.communication.groupby(
Expand All @@ -52,7 +35,10 @@ def main():
summary_data[message_size][collective] = ufloat(
bandwidth.mean(), bandwidth.std() / sqrt(len(bandwidth))
)
collective_types = sorted(collective_types)
return sorted(collective_types), summary_data


def print_bandwidth_table(collective_types, summary_data):
collective_widths = {
collective: max(
len(collective),
Expand Down Expand Up @@ -96,5 +82,137 @@ def format_bandwidth(data, collective):
)


def process_hidden_ms_to_total_ms(steady_state):
    """Compute the mean hidden-to-total communication time ratio per collective.

    For every row the ratio ``ProjDurHiddenMs / (ProjDurMs + ProjDurHiddenMs)``
    is formed; rows are then grouped by the ``Collective`` column and the
    ratios are averaged within each group.

    Args:
        steady_state: object whose ``communication`` attribute is a
            pandas-style DataFrame with ``Collective``, ``ProjDurMs`` and
            ``ProjDurHiddenMs`` columns.

    Returns:
        ``(None, None)`` if the ``ProjDurHiddenMs`` column sums to zero.
        Otherwise ``(collective_types, summary_data)``, where
        ``collective_types`` is a sorted list of 1-tuples
        ``(collective_name,)`` and ``summary_data`` is a dict mapping each of
        those keys to the mean ratio of its group.
    """
    communication = steady_state.communication
    if communication["ProjDurHiddenMs"].sum() == 0:
        return None, None

    summary_data = {}
    for collective, df in communication.groupby(["Collective"]):
        # Grouping by a one-element list yields 1-tuple keys on recent pandas
        # releases but bare scalars on older ones. Normalise to a 1-tuple so
        # that consumers can always use ``key[0]`` to get the collective name.
        key = collective if isinstance(collective, tuple) else (collective,)
        total_ms = df["ProjDurMs"] + df["ProjDurHiddenMs"]
        summary_data[key] = (df["ProjDurHiddenMs"] / total_ms).mean()
    # A sorted list (not a set) keeps the order of table rows reproducible
    # from one run to the next.
    return sorted(summary_data), summary_data


def print_hidden_ms_to_total_ms_table(
    collective_types, summary_data, overall_hidden_ms_to_total_ms
):
    """Print the per-collective hidden-to-total ratios, then the overall ratio.

    Args:
        collective_types: iterable of keys into ``summary_data``. Keys may be
            1-tuples ``(collective_name,)`` or plain collective names.
        summary_data: mapping from each key to its mean hidden-to-total ratio.
        overall_hidden_ms_to_total_ms: value printed on a separate line after
            the table.
    """
    table = PrettyTable()
    table.field_names = ["Collective", "Mean HiddenToTotalMs"]

    # Sort so that the row order does not depend on set iteration order.
    for collective in sorted(collective_types):
        # Show the bare name; indexing a plain string with [0] would print
        # only its first character.
        label = collective[0] if isinstance(collective, tuple) else collective
        table.add_row([label, summary_data[collective]])

    print(table)
    print("Overall HiddenMs to TotalMs:", overall_hidden_ms_to_total_ms)


def calculate_overall_hidden_ms_to_total_ms(steady_state):
    """Return the hidden share of communication time over all rows.

    The result is ``sum(ProjDurHiddenMs) / sum(ProjDurMs + ProjDurHiddenMs)``
    taken over ``steady_state.communication``.

    Args:
        steady_state: object whose ``communication`` attribute provides
            ``ProjDurMs`` and ``ProjDurHiddenMs`` columns.

    Returns:
        The ratio described above, or ``None`` if the ``ProjDurHiddenMs``
        column sums to zero.
    """
    comm = steady_state.communication
    hidden_sum = comm["ProjDurHiddenMs"].sum()
    if hidden_sum == 0:
        return None

    # Sum the row-wise totals rather than adding two column sums, so the
    # denominator is computed exactly as a per-row total.
    combined_sum = (comm["ProjDurMs"] + comm["ProjDurHiddenMs"]).sum()
    return hidden_sum / combined_sum


def write_to_csv(
    collective_types,
    bandwidth_summary,
    hidden_to_total_summary,
    overall_hidden_ms_to_total_ms,
    output_file,
):
    """Write the bandwidth and hidden-to-total tables to a single CSV file.

    The file contains, in order: a "Bandwidth Table" section, an empty
    separator row and, when ``hidden_to_total_summary`` is not ``None``, a
    "HiddenMs to TotalMs Table" section followed by another separator row and
    (if not ``None``) the overall ratio.

    Args:
        collective_types: iterable of collective names; these are the columns
            of the bandwidth table.
        bandwidth_summary: mapping ``message_size -> {collective: value}``.
            Values are rendered with the ``S`` format spec; a missing entry is
            written as ``"-"``.
        hidden_to_total_summary: mapping from collective key (a 1-tuple
            ``(name,)`` or a plain name) to its mean ratio, or ``None`` to
            omit the section.
        overall_hidden_ms_to_total_ms: overall ratio, or ``None`` to omit it.
        output_file: path of the CSV file; it is created or overwritten.
    """
    # Materialise once: the columns are iterated for the header and again for
    # every row, which would exhaust a one-shot iterable.
    columns = list(collective_types)

    with open(output_file, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)

        # Write bandwidth table
        writer.writerow(["Bandwidth Table"])
        writer.writerow(["Size [B]", *columns])
        for message_size in sorted(bandwidth_summary):
            measured = bandwidth_summary[message_size]
            cells = [
                f"{measured[collective]:S}" if collective in measured else "-"
                for collective in columns
            ]
            writer.writerow([message_size, *cells])

        writer.writerow([])  # Empty row for separation

        # Write hidden to total table if data is available
        if hidden_to_total_summary is None:
            return

        writer.writerow(["HiddenMs to TotalMs Table"])
        writer.writerow(["Collective", "Mean HiddenToTotalMs"])
        for collective, mean_ratio in hidden_to_total_summary.items():
            # Write the bare name; indexing a plain string with [0] would
            # emit only its first character.
            label = collective[0] if isinstance(collective, tuple) else collective
            writer.writerow([label, mean_ratio])

        # NOTE(review): the separator and overall rows are kept inside the
        # "hidden data available" branch - confirm this matches the intended
        # layout when no hidden data is present.
        writer.writerow([])  # Empty row for separation

        if overall_hidden_ms_to_total_ms is not None:
            writer.writerow(
                ["Overall HiddenMs to TotalMs", overall_hidden_ms_to_total_ms]
            )


def main():
    """Summarise communication in an nsys-jax report.

    Parses the ``prefix`` command-line argument, loads and time-aligns the
    profiler data, keeps the steady-state portion, prints the bandwidth table
    and (when hidden communication time was recorded) the hidden-to-total
    table, and finally writes everything to ``communication_summary.csv`` in
    the current working directory.

    Raises:
        AssertionError: if no steady-state communication was identified.
    """
    parser = argparse.ArgumentParser(
        description="Summarise communication in an nsys-jax report"
    )
    parser.add_argument("prefix", type=pathlib.Path)
    args = parser.parse_args()

    # Make sure that the .proto files under protos/ have been compiled to .py, and
    # that those generated .py files are importable.
    ensure_compiled_protos_are_importable(prefix=args.prefix)
    # Load the profiler data; the compilation part is needed for the warmup heuristics
    all_data = load_profiler_data(args.prefix, frames={"communication", "compile"})
    # Align timestamps
    all_data, alignment_metadata = align_profiler_data_timestamps(all_data)
    # TODO: make this pretty
    # print(alignment_metadata)
    # Partition the profile data into initialisation and steady-state running
    _, steady_state = apply_warmup_heuristics(all_data)

    assert len(steady_state.communication), (
        "Communication summary was requested but no steady-state communication was "
        "identified."
    )

    collective_types, bandwidth_summary = process_communication_data(steady_state)
    print_bandwidth_table(collective_types, bandwidth_summary)

    hidden_to_total_collective_types, hidden_to_total_summary = (
        process_hidden_ms_to_total_ms(steady_state)
    )
    # Bind a default first: the CSV export below receives this value even when
    # no hidden communication time was recorded, and would otherwise fail with
    # UnboundLocalError in that case.
    overall_hidden_ms_to_total_ms = None
    if hidden_to_total_summary is not None:
        overall_hidden_ms_to_total_ms = calculate_overall_hidden_ms_to_total_ms(
            steady_state
        )
        print_hidden_ms_to_total_ms_table(
            hidden_to_total_collective_types,
            hidden_to_total_summary,
            overall_hidden_ms_to_total_ms,
        )

    # Write all tables to a single CSV file
    write_to_csv(
        collective_types,
        bandwidth_summary,
        hidden_to_total_summary,
        overall_hidden_ms_to_total_ms,
        "communication_summary.csv",
    )


# Run the analysis only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()
Empty file modified .github/container/nsys_jax/nsys_jax/analysis.py
100644 → 100755
Empty file.
1 change: 1 addition & 0 deletions .github/container/nsys_jax/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ dependencies = [
"pyarrow",
"requests", # for install-protoc
"uncertainties", # communication analysis recipe
"prettytable",
]
requires-python = ">= 3.10"

Expand Down

0 comments on commit eb6d0d2

Please sign in to comment.