Skip to content

Commit

Permalink
Add the memory timeline profiling support through the PyTorch profile…
Browse files Browse the repository at this point in the history
…r. (#2771)

* v1

* fix issues

* add logs

* change names

* comment

* add device

* uncomment original trace

* add custome plot

* fix pyright

* Update composer/profiler/torch_profiler.py

Co-authored-by: Charles Tang <j316chuck@users.noreply.github.com>

* address comments

* fix code check

* fix formatting

* address comments

* add unit test

* fix check

* fix check

* fix check

* fix check

* fix print

* add test comment

* add test comment

---------

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
Co-authored-by: Charles Tang <j316chuck@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 12, 2023
1 parent f497e60 commit a7cad7c
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 26 deletions.
11 changes: 11 additions & 0 deletions composer/profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def new_profiler_init(self, dummy_ellipsis=None, **kwargs):
torch_prof_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
Additionally supports full object store paths e.g: s3://bucket/path/to/file.
torch_prof_memory_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_memory_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
Additionally supports full object store paths e.g: s3://bucket/path/to/file.
torch_prof_overwrite (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_use_gzip (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_record_shapes (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
Expand All @@ -97,6 +100,9 @@ def __init__(
torch_prof_folder: str = '{run_name}/torch_traces',
torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json',
torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
torch_prof_memory_filename: str = 'rank{rank}.{batch}.pt.memory_trace.html',
torch_prof_memory_remote_file_name: Optional[
str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html',
torch_prof_overwrite: bool = False,
torch_prof_use_gzip: bool = False,
torch_prof_record_shapes: bool = False,
Expand All @@ -116,6 +122,9 @@ def __init__(
if torch_prof_remote_file_name:
self.remote_filenames.append(torch_prof_remote_file_name)
_, _, torch_prof_remote_file_name = parse_uri(torch_prof_remote_file_name)
if torch_prof_memory_remote_file_name:
self.remote_filenames.append(torch_prof_memory_remote_file_name)
_, _, torch_prof_memory_remote_file_name = parse_uri(torch_prof_memory_remote_file_name)
for handler in self._trace_handlers:
if isinstance(handler, JSONTraceHandler):
if handler.remote_file_name:
Expand All @@ -139,6 +148,8 @@ def __init__(
TorchProfiler(filename=torch_prof_filename,
folder=torch_prof_folder,
remote_file_name=torch_prof_remote_file_name,
memory_filename=torch_prof_memory_filename,
memory_remote_file_name=torch_prof_memory_remote_file_name,
num_traces_to_keep=torch_prof_num_traces_to_keep,
overwrite=torch_prof_overwrite,
record_shapes=torch_prof_record_shapes,
Expand Down
145 changes: 119 additions & 26 deletions composer/profiler/torch_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
import textwrap
from typing import TYPE_CHECKING, Optional, OrderedDict

import torch.cuda
import torch.profiler
from packaging import version
from torch.profiler.profiler import ProfilerAction as TorchProfilerAction

from composer.core.callback import Callback
Expand Down Expand Up @@ -92,9 +94,9 @@ class TorchProfiler(Callback): # noqa: D101
Each rank (process) will save traces to::
awesome-training-run/torch_traces/ep1-ba42-rank0.json
awesome-training-run/torch_traces/ep1-ba42-rank1.json
awesome-training-run/torch_traces/ep1-ba42-rank2.json
awesome-training-run/torch_traces/ep1-ba42-rank0.pt.trace.json
awesome-training-run/torch_traces/ep1-ba42-rank1.pt.trace.json
awesome-training-run/torch_traces/ep1-ba42-rank2.pt.trace.json
...
remote_file_name (str, optional): Format string for a Torch Profiler trace file's remote file name.
Expand All @@ -107,6 +109,43 @@ class TorchProfiler(Callback): # noqa: D101
Leading slashes (``'/'``) will be stripped.
To disable uploading trace files, set this parameter to ``None``.
memory_filename (str, optional): A format string describing how to name Torch Profiler memory trace files.
Defaults to ``'rank{{rank}}.{{batch}}.pt.trace.memory.html'``.
At the end of each batch where :meth:`~composer.profiler.Profiler.get_action` returns
:attr:`~composer.profiler._profiler_action.ProfilerAction.ACTIVE_AND_SAVE`, trace files are saved
approximately to ``{{folder.format(...)}}/{{memory_filename.format(...)}}``.
The following format variables are available:
{textwrap.indent(FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, prefix=' ')}
Consider the following scenario, where:
* The :attr:`~.State.run_name` is ``'awesome-training-run'``.
* The default ``trace_folder='{{run_name}}/torch_traces'`` is used.
* The default ``name='rank{{rank}}.{{batch}}.pt.trace.memory.html'`` is used.
* The current epoch count is ``1``.
* The current batch count is ``42``.
Each rank (process) will save traces to::
awesome-training-run/torch_traces/ep1-ba42-rank0.pt.trace.memory.html
awesome-training-run/torch_traces/ep1-ba42-rank1.pt.trace.memory.html
awesome-training-run/torch_traces/ep1-ba42-rank2.pt.trace.memory.html
...
memory_remote_file_name (str, optional): Format string for a Torch Profiler memory trace file's remote file name.
Defaults to ``'{{run_name}}/torch_traces/rank{{rank}}.{{batch}}.pt.trace.memory.json'``.
Whenever a trace file is saved, it is also uploaded as a file according to this format string.
The same format variables as for ``filename`` are available.
.. seealso:: :doc:`Uploading Files</trainer/file_uploading>` for notes for file uploading.
Leading slashes (``'/'``) will be stripped.
To disable uploading trace files, set this parameter to ``None``.
overwrite (bool, optional): Whether to override existing Torch Profiler traces. Defaults to False.
Expand Down Expand Up @@ -146,7 +185,10 @@ def __init__(
folder: str = '{run_name}/torch_traces',
filename: str = 'rank{rank}.{batch}.pt.trace.json',
remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
*,
memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.trace.memory.html',
memory_remote_file_name: Optional[
str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html',
memory_custom_plot: bool = True,
overwrite: bool = False,
use_gzip: bool = False,
record_shapes: bool = False,
Expand All @@ -157,19 +199,34 @@ def __init__(
) -> None:
self.overwrite = overwrite
self.folder = folder
if use_gzip and not filename.endswith('.gz'):
filename += '.gz'

if use_gzip:
if not filename.endswith('.gz'):
filename += '.gz'
self.filename = filename
if use_gzip and remote_file_name is not None and not remote_file_name.endswith('.gz'):
remote_file_name += '.gz'

if use_gzip:
if remote_file_name is not None and not remote_file_name.endswith('.gz'):
remote_file_name += '.gz'
self.remote_file_name = remote_file_name

if memory_filename is not None:
assert memory_filename.endswith('.html'), f'memory_filename must end with .html, got {memory_filename}'
self.memory_filename = memory_filename

if memory_remote_file_name is not None:
assert memory_remote_file_name.endswith(
'.html'), f'memory_remote_file_name must end with .html, got {memory_remote_file_name}'
self.memory_remote_file_name = memory_remote_file_name

self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.with_stack = with_stack
self.with_flops = with_flops
self.num_traces_to_keep = num_traces_to_keep
self.saved_traces = OrderedDict()
self.profiler: Optional[torch.profiler.profile] = None
self.memory_custom_plot = memory_custom_plot

def init(self, state: State, logger: Logger) -> None:
if state.profiler is None:
Expand Down Expand Up @@ -203,27 +260,63 @@ def handler_fn(prof: torch.profiler.profiler.profile):

timestamp = state.timestamp

trace_file_name = os.path.join(
folder_name,
format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp),
)
trace_file_dirname = os.path.dirname(trace_file_name)
if trace_file_dirname:
os.makedirs(trace_file_dirname, exist_ok=True)
prof.export_chrome_trace(trace_file_name)
state.profiler.record_chrome_json_trace_file(trace_file_name)
if self.remote_file_name is not None:
trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name,
run_name=state.run_name,
timestamp=timestamp)
trace_remote_file_name = trace_remote_file_name.lstrip('/')
logger.upload_file(remote_file_name=trace_remote_file_name,
file_path=trace_file_name,
overwrite=self.overwrite)
log.info(f'PyTorch Chrome trace profiler enabled: {self.filename if self.filename else False}')
if self.filename is not None:
trace_file_name = os.path.join(
folder_name,
format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp),
)
trace_file_dirname = os.path.dirname(trace_file_name)
if trace_file_dirname:
os.makedirs(trace_file_dirname, exist_ok=True)
prof.export_chrome_trace(trace_file_name)
state.profiler.record_chrome_json_trace_file(trace_file_name)
if self.remote_file_name is not None:
trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name,
run_name=state.run_name,
timestamp=timestamp)
trace_remote_file_name = trace_remote_file_name.lstrip('/')
logger.upload_file(remote_file_name=trace_remote_file_name,
file_path=trace_file_name,
overwrite=self.overwrite)

log.info(
f'PyTorch memory timeline profiler enabled: {self.memory_filename if self.memory_filename else False}')
if self.memory_filename is not None:
if version.parse(torch.__version__) > version.parse('2.1.0.dev'): # type: ignore
# memory timeline profiling is only supported in torch v2.1.0-rc1 or higher
memory_trace_file_name = os.path.join(
folder_name,
format_name_with_dist_and_time(self.memory_filename,
run_name=state.run_name,
timestamp=timestamp),
)
log.debug(f'Saving memory trace to {memory_trace_file_name}')
memory_trace_file_dirname = os.path.dirname(memory_trace_file_name)
if memory_trace_file_dirname:
os.makedirs(memory_trace_file_dirname, exist_ok=True)
if self.memory_custom_plot:
from composer.profiler.utils import export_memory_timeline_html
export_memory_timeline_html(prof, memory_trace_file_name,
torch.cuda.current_device()) # type: ignore
else:
prof.export_memory_timeline(memory_trace_file_name, torch.cuda.current_device()) # type: ignore
log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}')
if self.memory_remote_file_name is not None:
memory_trace_remote_file_name = format_name_with_dist_and_time(self.memory_remote_file_name,
run_name=state.run_name,
timestamp=timestamp)
memory_trace_remote_file_name = memory_trace_remote_file_name.lstrip('/')
log.debug(
f'Uploading memory trace to {memory_trace_remote_file_name} from {memory_trace_file_name}')
logger.upload_file(remote_file_name=memory_trace_remote_file_name,
file_path=memory_trace_file_name,
overwrite=self.overwrite)
else:
log.warning('Memory timeline is supported after PyTorch 2.1.0. Skipping memory trace.')

if self.num_traces_to_keep >= 0:
while len(self.saved_traces) > self.num_traces_to_keep:

# self.saved_traces is an ordered dict, so the zeroth item will be the oldest checkpoint
timestamp, filepaths = next(iter(self.saved_traces.items()))
if dist.get_global_rank() < len(filepaths):
Expand Down
97 changes: 97 additions & 0 deletions composer/profiler/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utility functions for torch profiler."""

import importlib.util
import logging
from base64 import b64encode
from os import remove
from tempfile import NamedTemporaryFile
from typing import Any, Optional, Union

import numpy as np
import torch
import torch.cuda
from packaging import version
from torch.profiler.profiler import profile as TorchProfile

log = logging.getLogger(__name__)


def export_memory_timeline_html(prof: TorchProfile,
path: str,
device: Optional[str] = None,
figsize=(20, 12),
title=None,
yxis_step_size: float = 1.0,
return_fig: bool = False) -> Optional[Union[None, Any]]:
"""Exports a memory timeline to an HTML file. Similar to the PyTorch plotting function, but with adjusted axis tickers and grids."""
if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
log.warning('export_memory_timeline_html failed because memory timeline is supported after PyTorch 2.1.0.')
return

from torch.profiler._memory_profiler import _CATEGORY_TO_COLORS, _CATEGORY_TO_INDEX, MemoryProfileTimeline

# Default to device 0, if unset. Fallback on cpu.
if device is None and prof.use_device and prof.use_device != 'cuda':
device = prof.use_device + ':0'

if device is None:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Construct the memory timeline plot data
mem_tl = MemoryProfileTimeline(prof._memory_profile())

# Check if user has matplotlib installed, return gracefully if not.
matplotlib_spec = importlib.util.find_spec('matplotlib')
if matplotlib_spec is None:
log.warning('export_memory_timeline_html failed because matplotlib was not found.')
return
import matplotlib.pyplot as plt

mt = mem_tl._coalesce_timeline(device)
times, sizes = np.array(mt[0]), np.array(mt[1])
stacked = np.cumsum(sizes, axis=1) / 1024**3
max_memory_allocated = torch.cuda.max_memory_allocated()
max_memory_reserved = torch.cuda.max_memory_reserved()

# Plot memory timeline as stacked data
fig = plt.figure(figsize=figsize, dpi=80)
axes = fig.gca()
for category, color in _CATEGORY_TO_COLORS.items():
i = _CATEGORY_TO_INDEX[category]
axes.fill_between(times / 1e3, stacked[:, i], stacked[:, i + 1], color=color, alpha=0.7)
fig.legend(['Unknown' if i is None else i.name for i in _CATEGORY_TO_COLORS])
axes.set_xlabel('Time (us)')
axes.set_ylabel('Memory (GB)')
_, end = axes.get_ylim()
axes.grid(True)
axes.set_yticks(np.arange(0, end, yxis_step_size))
title = '\n\n'.join(([title] if title else []) + [
f'Max memory allocated: {max_memory_allocated/(10**9):.2f} GB \n'
f'Max memory reserved: {max_memory_reserved/(10**9):.2f} GB'
])
axes.set_title(title)

if return_fig:
return fig

# Embed the memory timeline image into the HTML file
tmpfile = NamedTemporaryFile('wb', suffix='.png', delete=False)
tmpfile.close()
fig.savefig(tmpfile.name, format='png')

with open(tmpfile.name, 'rb') as tmp:
encoded = b64encode(tmp.read()).decode('utf-8')
html = f"""<html>
<head><meta charset="utf-8" /><title>GPU Memory Timeline HTML</title></head>
<body>
<img src='data:image/png;base64,{encoded}'>
</body>
</html>"""

with open(path, 'w') as f:
f.write(html)
log.debug('Memory timeline exported to', path, '.')
remove(tmpfile.name)
52 changes: 52 additions & 0 deletions tests/profiler/test_memory_timeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

import os
import pathlib

import pytest
import torch
from packaging import version

from composer.profiler.utils import export_memory_timeline_html


@pytest.mark.gpu
def test_memory_timeline(tmp_path: pathlib.Path) -> None:
if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
# memory timeline is supported after PyTorch 2.1.0.
return
import torch.profiler._memory_profiler as _memory_profiler

model = torch.nn.Sequential(
torch.nn.Linear(1024, 1024, bias=True),
torch.nn.ReLU(),
torch.nn.Linear(1024, 1024, bias=False),
torch.nn.Softmax(dim=1),
).to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

x = torch.ones((1024, 1024), device='cuda')
targets = torch.ones((1024, 1024), device='cuda')
with torch.profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof:
y = model(x)
loss = torch.nn.functional.mse_loss(y, targets)
loss.backward()
optimizer.step()
optimizer.zero_grad()

memory_profile = prof._memory_profile()
timeline = memory_profile.timeline

# this checks the default memory timeline event value (t == -1) for preexisting tensors
assert all((t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0) for t, action, _, _ in timeline)

fig = export_memory_timeline_html(
prof,
os.path.join(tmp_path, 'test_memory_timeline.html'),
yxis_step_size=0.01,
return_fig=True,
)
assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True'
_, end = fig.gca().get_ylim()
assert round(end, 2) == 0.06

0 comments on commit a7cad7c

Please sign in to comment.