Skip to content

Commit

Permalink
fix_path_input
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Jan 17, 2024
1 parent 34bb7c2 commit ba90b6a
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions transformer_nuggets/utils/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def save_memory_snapshot(file_path: Path):
file_path: The path to the folder to save the snapshot to
will create the folder if it doesn't exist
"""
file_path.mkdir(parents=True, exist_ok=True)
if file_path.is_dir():
raise ValueError(f"{file_path} is a directory")

# make parent dir
file_path.parent.mkdir(parents=True, exist_ok=True)
torch.cuda.memory._record_memory_history()
try:
yield
Expand All @@ -116,8 +120,8 @@ def save_memory_snapshot(file_path: Path):
pass
if dist_avail and dist.is_initialized():
local_rank = dist.get_rank()
output_path = file_path / f"trace_plot_rank_{local_rank}.html"
output_path = file_path / f"_rank_{local_rank}.html"
else:
output_path = file_path / "trace_plot.html"
output_path = file_path.with_suffix(".html")
with open(output_path, "w") as f:
f.write(torch.cuda._memory_viz.trace_plot(s))

0 comments on commit ba90b6a

Please sign in to comment.