Commit
Merge branch 'dev' into turn-on-system-metrics
dakinggg authored Dec 15, 2023
2 parents 42db526 + 45bb135 commit 7d4acaa
Showing 23 changed files with 713 additions and 115 deletions.
2 changes: 1 addition & 1 deletion composer/_version.py
@@ -3,4 +3,4 @@

"""The Composer Version."""

__version__ = '0.17.1'
__version__ = '0.17.2'
10 changes: 5 additions & 5 deletions composer/profiler/marker.py
@@ -40,7 +40,7 @@ class Marker:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
@@ -57,7 +57,7 @@ class Marker:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
@@ -124,7 +124,7 @@ def start(self) -> None:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
@@ -187,7 +187,7 @@ def instant(self) -> None:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
@@ -213,7 +213,7 @@ def counter(self, values: Dict[str, Union[float, int]]) -> None:
.. testsetup::
from composer.profiler import Profiler, cyclic_schedule
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
.. doctest::
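
Every testsetup block in marker.py now passes torch_prof_memory_filename=None: the memory timeline is enabled by default in this release, so the doctest profiler has to opt out explicitly. A minimal sketch of the updated setup, assuming the `state` object that the doctest fixtures provide, and using only Marker methods whose signatures appear in the hunks above:

    from composer.profiler import Profiler, cyclic_schedule

    profiler = Profiler(
        schedule=cyclic_schedule(),
        trace_handlers=[],
        torch_prof_memory_filename=None,  # opt out of the new memory timeline
    )
    profiler.bind_to_state(state)  # `state` comes from the doctest fixtures

    marker = profiler.marker('my_marker')   # marker name is illustrative
    marker.instant()                        # record an instant event
    marker.counter({'my_metric': 1.0})      # record counter values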
26 changes: 25 additions & 1 deletion composer/profiler/profiler.py
@@ -45,6 +45,7 @@ class Profiler:
def new_profiler_init(self, dummy_ellipsis=None, **kwargs):
if 'trace_handlers' not in kwargs:
kwargs['trace_handlers'] = []
kwargs['torch_prof_memory_filename'] = None
original_profiler_init(self, **kwargs)
Profiler.__init__ = new_profiler_init
@@ -62,6 +63,7 @@ def new_profiler_init(self, dummy_ellipsis=None, **kwargs):
active=4,
repeat=1,
),
torch_prof_memory_filename=None,
)
trace_handlers (TraceHandler | Sequence[TraceHandler]): Trace handlers which record and
@@ -76,6 +78,9 @@ def new_profiler_init(self, dummy_ellipsis=None, **kwargs):
torch_prof_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
Additionally supports full object store paths e.g: s3://bucket/path/to/file.
torch_prof_memory_filename (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_memory_remote_file_name (str, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
Additionally supports full object store paths e.g: s3://bucket/path/to/file.
torch_prof_overwrite (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_use_gzip (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
torch_prof_record_shapes (bool, optional): See :class:`~composer.profiler.torch_profiler.TorchProfiler`.
@@ -97,6 +102,9 @@ def __init__(
torch_prof_folder: str = '{run_name}/torch_traces',
torch_prof_filename: str = 'rank{rank}.{batch}.pt.trace.json',
torch_prof_remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
torch_prof_memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.memory_trace.html',
torch_prof_memory_remote_file_name: Optional[
str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.memory_trace.html',
torch_prof_overwrite: bool = False,
torch_prof_use_gzip: bool = False,
torch_prof_record_shapes: bool = False,
@@ -116,6 +124,9 @@ def __init__(
if torch_prof_remote_file_name:
self.remote_filenames.append(torch_prof_remote_file_name)
_, _, torch_prof_remote_file_name = parse_uri(torch_prof_remote_file_name)
if torch_prof_memory_remote_file_name:
self.remote_filenames.append(torch_prof_memory_remote_file_name)
_, _, torch_prof_memory_remote_file_name = parse_uri(torch_prof_memory_remote_file_name)
for handler in self._trace_handlers:
if isinstance(handler, JSONTraceHandler):
if handler.remote_file_name:
@@ -134,11 +145,24 @@ def __init__(
profile_net=sys_prof_net,
stats_thread_interval_seconds=sys_prof_stats_thread_interval_seconds))

if torch_prof_memory_filename is not None:
if not (torch_prof_with_stack and torch_prof_record_shapes and torch_prof_profile_memory):
raise ValueError(
f'torch_prof_memory_filename is set. Generating the memory timeline graph requires all the three flags torch_prof_with_stack, torch_prof_record_shapes, and torch_prof_profile_memory to be true. Got torch_prof_with_stack={torch_prof_with_stack}, torch_prof_record_shapes={torch_prof_record_shapes}, torch_prof_profile_memory={torch_prof_profile_memory}'
)
log.info(
f'Memory profiling is enabled and uses {torch_prof_memory_filename} as the filename to generate the memory timeline graph. To disable the memory timeline graph generation, explicitly set torch_prof_memory_filename to None.'
)
else:
            log.info(f'torch_prof_memory_filename is explicitly set to None. Memory timeline will not be generated.')

if torch_prof_record_shapes or torch_prof_profile_memory or torch_prof_with_stack or torch_prof_with_flops:
self._callbacks.append(
TorchProfiler(filename=torch_prof_filename,
folder=torch_prof_folder,
remote_file_name=torch_prof_remote_file_name,
memory_filename=torch_prof_memory_filename,
memory_remote_file_name=torch_prof_memory_remote_file_name,
num_traces_to_keep=torch_prof_num_traces_to_keep,
overwrite=torch_prof_overwrite,
record_shapes=torch_prof_record_shapes,
@@ -219,7 +243,7 @@ def marker(
from composer.profiler import Profiler, cyclic_schedule
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[])
profiler = Profiler(schedule=cyclic_schedule(), trace_handlers=[], torch_prof_memory_filename=None)
profiler.bind_to_state(state)
state.profiler = profiler
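
With the new validation above, leaving torch_prof_memory_filename at its default requires torch_prof_with_stack, torch_prof_record_shapes, and torch_prof_profile_memory to all be True; otherwise Profiler raises a ValueError. An illustrative sketch of a profiler configured with the memory timeline enabled (schedule values are arbitrary, the memory filename is the default shown in the signature above):

    from composer.profiler import Profiler, cyclic_schedule

    profiler = Profiler(
        schedule=cyclic_schedule(wait=0, warmup=1, active=4, repeat=1),
        trace_handlers=[],
        # All three flags must be True whenever torch_prof_memory_filename is not None.
        torch_prof_with_stack=True,
        torch_prof_record_shapes=True,
        torch_prof_profile_memory=True,
        torch_prof_memory_filename='rank{rank}.{batch}.pt.memory_trace.html',
    )

The resulting object is then handed to the Trainer through its profiler argument; passing torch_prof_memory_filename=None instead skips the memory timeline entirely, as the doctest setups in this commit do.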
140 changes: 114 additions & 26 deletions composer/profiler/torch_profiler.py
@@ -11,7 +11,9 @@
import textwrap
from typing import TYPE_CHECKING, Optional, OrderedDict

import torch.cuda
import torch.profiler
from packaging import version
from torch.profiler.profiler import ProfilerAction as TorchProfilerAction

from composer.core.callback import Callback
@@ -92,9 +94,9 @@ class TorchProfiler(Callback): # noqa: D101
Each rank (process) will save traces to::
awesome-training-run/torch_traces/ep1-ba42-rank0.json
awesome-training-run/torch_traces/ep1-ba42-rank1.json
awesome-training-run/torch_traces/ep1-ba42-rank2.json
awesome-training-run/torch_traces/ep1-ba42-rank0.pt.trace.json
awesome-training-run/torch_traces/ep1-ba42-rank1.pt.trace.json
awesome-training-run/torch_traces/ep1-ba42-rank2.pt.trace.json
...
remote_file_name (str, optional): Format string for a Torch Profiler trace file's remote file name.
@@ -107,6 +109,43 @@ class TorchProfiler(Callback): # noqa: D101
Leading slashes (``'/'``) will be stripped.
To disable uploading trace files, set this parameter to ``None``.
memory_filename (str, optional): A format string describing how to name Torch Profiler memory trace files.
Defaults to ``'rank{{rank}}.{{batch}}.pt.trace.memory.html'``.
At the end of each batch where :meth:`~composer.profiler.Profiler.get_action` returns
:attr:`~composer.profiler._profiler_action.ProfilerAction.ACTIVE_AND_SAVE`, trace files are saved
approximately to ``{{folder.format(...)}}/{{memory_filename.format(...)}}``.
The following format variables are available:
{textwrap.indent(FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, prefix=' ')}
Consider the following scenario, where:
* The :attr:`~.State.run_name` is ``'awesome-training-run'``.
* The default ``trace_folder='{{run_name}}/torch_traces'`` is used.
* The default ``name='rank{{rank}}.{{batch}}.pt.trace.memory.html'`` is used.
* The current epoch count is ``1``.
* The current batch count is ``42``.
Each rank (process) will save traces to::
awesome-training-run/torch_traces/ep1-ba42-rank0.pt.trace.memory.html
awesome-training-run/torch_traces/ep1-ba42-rank1.pt.trace.memory.html
awesome-training-run/torch_traces/ep1-ba42-rank2.pt.trace.memory.html
...
memory_remote_file_name (str, optional): Format string for a Torch Profiler memory trace file's remote file name.
            Defaults to ``'{{run_name}}/torch_memory_traces/rank{{rank}}.{{batch}}.pt.trace.memory.html'``.
Whenever a trace file is saved, it is also uploaded as a file according to this format string.
The same format variables as for ``filename`` are available.
.. seealso:: :doc:`Uploading Files</trainer/file_uploading>` for notes for file uploading.
Leading slashes (``'/'``) will be stripped.
To disable uploading trace files, set this parameter to ``None``.
overwrite (bool, optional): Whether to override existing Torch Profiler traces. Defaults to False.
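
TorchProfiler is normally constructed for you by the top-level Profiler (see profiler.py above), but the two new memory arguments can also be set on it directly. A sketch using only parameters visible in this diff; memory_filename must end in .html (enforced by the assertions added in __init__ below), and setting memory_remote_file_name to None keeps the HTML report local:

    from composer.profiler.torch_profiler import TorchProfiler

    torch_profiler = TorchProfiler(
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        memory_filename='rank{rank}.{batch}.pt.trace.memory.html',  # must end with .html
        memory_remote_file_name=None,  # skip uploading the memory timeline
    )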
@@ -146,7 +185,9 @@ def __init__(
folder: str = '{run_name}/torch_traces',
filename: str = 'rank{rank}.{batch}.pt.trace.json',
remote_file_name: Optional[str] = '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json',
*,
memory_filename: Optional[str] = 'rank{rank}.{batch}.pt.trace.memory.html',
memory_remote_file_name: Optional[
str] = '{run_name}/torch_memory_traces/rank{rank}.{batch}.pt.trace.memory.html',
overwrite: bool = False,
use_gzip: bool = False,
record_shapes: bool = False,
@@ -157,12 +198,26 @@ def __init__(
) -> None:
self.overwrite = overwrite
self.folder = folder
if use_gzip and not filename.endswith('.gz'):
filename += '.gz'

if use_gzip:
if not filename.endswith('.gz'):
filename += '.gz'
self.filename = filename
if use_gzip and remote_file_name is not None and not remote_file_name.endswith('.gz'):
remote_file_name += '.gz'

if use_gzip:
if remote_file_name is not None and not remote_file_name.endswith('.gz'):
remote_file_name += '.gz'
self.remote_file_name = remote_file_name

if memory_filename is not None:
assert memory_filename.endswith('.html'), f'memory_filename must end with .html, got {memory_filename}'
self.memory_filename = memory_filename

if memory_remote_file_name is not None:
assert memory_remote_file_name.endswith(
'.html'), f'memory_remote_file_name must end with .html, got {memory_remote_file_name}'
self.memory_remote_file_name = memory_remote_file_name

self.record_shapes = record_shapes
self.profile_memory = profile_memory
self.with_stack = with_stack
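
As a quick check of the restructured gzip handling above: the .gz suffix is appended to both the local and remote Chrome-trace names only when it is missing, while the two new assertions pin the memory outputs to .html. A small sketch; the expected strings follow from the defaults shown in this diff:

    from composer.profiler.torch_profiler import TorchProfiler

    cb = TorchProfiler(use_gzip=True)  # every other argument left at its default
    assert cb.filename == 'rank{rank}.{batch}.pt.trace.json.gz'
    assert cb.remote_file_name == '{run_name}/torch_traces/rank{rank}.{batch}.pt.trace.json.gz'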
@@ -203,27 +258,60 @@ def handler_fn(prof: torch.profiler.profiler.profile):

timestamp = state.timestamp

trace_file_name = os.path.join(
folder_name,
format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp),
)
trace_file_dirname = os.path.dirname(trace_file_name)
if trace_file_dirname:
os.makedirs(trace_file_dirname, exist_ok=True)
prof.export_chrome_trace(trace_file_name)
state.profiler.record_chrome_json_trace_file(trace_file_name)
if self.remote_file_name is not None:
trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name,
run_name=state.run_name,
timestamp=timestamp)
trace_remote_file_name = trace_remote_file_name.lstrip('/')
logger.upload_file(remote_file_name=trace_remote_file_name,
file_path=trace_file_name,
overwrite=self.overwrite)
log.info(f'PyTorch Chrome trace profiler enabled: {self.filename if self.filename else False}')
if self.filename is not None:
trace_file_name = os.path.join(
folder_name,
format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=timestamp),
)
trace_file_dirname = os.path.dirname(trace_file_name)
if trace_file_dirname:
os.makedirs(trace_file_dirname, exist_ok=True)
prof.export_chrome_trace(trace_file_name)
state.profiler.record_chrome_json_trace_file(trace_file_name)
if self.remote_file_name is not None:
trace_remote_file_name = format_name_with_dist_and_time(self.remote_file_name,
run_name=state.run_name,
timestamp=timestamp)
trace_remote_file_name = trace_remote_file_name.lstrip('/')
logger.upload_file(remote_file_name=trace_remote_file_name,
file_path=trace_file_name,
overwrite=self.overwrite)

log.info(
f'PyTorch memory timeline profiler enabled: {self.memory_filename if self.memory_filename else False}')
if self.memory_filename is not None:
if version.parse(torch.__version__) > version.parse('2.1.0.dev'): # type: ignore
# memory timeline profiling is only supported in torch v2.1.0-rc1 or higher
memory_trace_file_name = os.path.join(
folder_name,
format_name_with_dist_and_time(self.memory_filename,
run_name=state.run_name,
timestamp=timestamp),
)
log.debug(f'Saving memory trace to {memory_trace_file_name}')
memory_trace_file_dirname = os.path.dirname(memory_trace_file_name)
if memory_trace_file_dirname:
os.makedirs(memory_trace_file_dirname, exist_ok=True)
from composer.profiler.utils import export_memory_timeline_html
export_memory_timeline_html(prof, memory_trace_file_name,
torch.cuda.current_device()) # type: ignore
log.debug(f'Uploaded memory trace to {self.memory_remote_file_name}')
if self.memory_remote_file_name is not None:
memory_trace_remote_file_name = format_name_with_dist_and_time(self.memory_remote_file_name,
run_name=state.run_name,
timestamp=timestamp)
memory_trace_remote_file_name = memory_trace_remote_file_name.lstrip('/')
log.debug(
f'Uploading memory trace to {memory_trace_remote_file_name} from {memory_trace_file_name}')
logger.upload_file(remote_file_name=memory_trace_remote_file_name,
file_path=memory_trace_file_name,
overwrite=self.overwrite)
else:
log.warning('Memory timeline is supported after PyTorch 2.1.0. Skipping memory trace.')

if self.num_traces_to_keep >= 0:
while len(self.saved_traces) > self.num_traces_to_keep:

# self.saved_traces is an ordered dict, so the zeroth item will be the oldest checkpoint
timestamp, filepaths = next(iter(self.saved_traces.items()))
if dist.get_global_rank() < len(filepaths):
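
For reference, the gate-and-export pattern added to handler_fn can be reproduced with plain torch.profiler. This standalone sketch is not Composer's implementation (Composer routes through its own export_memory_timeline_html helper); the workload and file names are illustrative, and the HTML memory export needs torch >= 2.1 plus, in plain torch, matplotlib for the rendering:

    import torch
    import torch.profiler
    from packaging import version

    with torch.profiler.profile(
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
    ) as prof:
        x = torch.randn(256, 256)
        for _ in range(10):
            y = x @ x  # trivial workload to profile

    # Chrome trace export, unchanged by this commit apart from the new filename gate.
    prof.export_chrome_trace('rank0.0.pt.trace.json')

    # The memory timeline export only exists in torch >= 2.1, hence the version check
    # mirrored from handler_fn above.
    if version.parse(torch.__version__) > version.parse('2.1.0.dev') and torch.cuda.is_available():
        prof.export_memory_timeline(
            'rank0.0.pt.trace.memory.html',
            device=f'cuda:{torch.cuda.current_device()}',
        )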
(Diffs for the remaining changed files are not shown.)
