Refactor PyTorch profiler 4/5 #6349

Merged
merged 98 commits on Mar 23, 2021
Changes from 28 commits
Commits
98 commits
aa89fa1
Refactor profilers
carmocca Mar 4, 2021
b66f1f6
Update PassThrough
carmocca Mar 5, 2021
d388d80
WIP - This is broken and will change
carmocca Mar 6, 2021
93aa073
Update pytorch_lightning/profiler/pytorch.py
carmocca Mar 9, 2021
0f8b6ef
resolve tests
tchaton Mar 18, 2021
d219ae1
Merge branch 'master' into update-pytorch-profiler
tchaton Mar 18, 2021
8a3d3cf
resolve tests
tchaton Mar 18, 2021
af865ff
find output
tchaton Mar 18, 2021
6147e6e
try something
tchaton Mar 19, 2021
5053b4f
update
tchaton Mar 19, 2021
3ae2fbf
add support for test and predict
tchaton Mar 19, 2021
1d7ac88
update
tchaton Mar 19, 2021
64f8a58
update
tchaton Mar 19, 2021
655ff46
use getattr
tchaton Mar 19, 2021
3f5c4d6
test
tchaton Mar 19, 2021
ced0d66
test
tchaton Mar 19, 2021
662f8de
update
tchaton Mar 19, 2021
753f63b
tests
tchaton Mar 19, 2021
e6a6917
update
tchaton Mar 19, 2021
a7f7f4c
update
tchaton Mar 19, 2021
c221d7d
update
tchaton Mar 19, 2021
e39f6d7
update
tchaton Mar 19, 2021
e485850
update
tchaton Mar 19, 2021
00d355a
remove file
tchaton Mar 19, 2021
61da390
update
tchaton Mar 19, 2021
002f137
Merge branch 'update-pytorch-profiler' of https://github.com/PyTorchL…
tchaton Mar 19, 2021
e565e8a
update
tchaton Mar 19, 2021
6f2258f
update
tchaton Mar 19, 2021
883b242
update
tchaton Mar 21, 2021
9b957bd
update
tchaton Mar 21, 2021
70cc20d
test
tchaton Mar 21, 2021
226e0be
update#
tchaton Mar 21, 2021
27348c7
update
tchaton Mar 21, 2021
853d36c
update tests
tchaton Mar 21, 2021
9004e92
update
tchaton Mar 21, 2021
9a87c52
Merge branch 'master' into update-pytorch-profiler
carmocca Mar 22, 2021
fbbb9a2
Refactor basic profilers
carmocca Mar 22, 2021
2a82e05
Fixes
carmocca Mar 22, 2021
ff125e2
Unused import
carmocca Mar 22, 2021
01a760e
Introduce setup
carmocca Mar 22, 2021
b31831e
Profile on all ranks. Print to stdout on 0
carmocca Mar 22, 2021
f8a8772
Introduce dirpath + filename
carmocca Mar 22, 2021
aa4b7dd
CHANGELOG
carmocca Mar 22, 2021
8e3034e
Add tests. Address comments
carmocca Mar 22, 2021
e4e0dd6
add `on_run_stage_setup`
tchaton Mar 22, 2021
d0fdbb9
add on_run_stage_setup function
tchaton Mar 22, 2021
1a16bb3
update
tchaton Mar 22, 2021
52fa69b
add test for RegisterRecordFunction
tchaton Mar 22, 2021
63b6988
update lightnng flow direction
tchaton Mar 22, 2021
1cf5a64
move variable to private
tchaton Mar 22, 2021
a05acdd
remove trace
tchaton Mar 22, 2021
af72dff
Merge branch 'master' into refactor-base-profilers
tchaton Mar 22, 2021
59c941b
Undo code that should be in 3/4
carmocca Mar 22, 2021
da0f310
Multi-stage multi-rank
carmocca Mar 22, 2021
59c1b4c
2/5 changes
carmocca Mar 22, 2021
dd1dce0
Pass stage in __del__
carmocca Mar 22, 2021
12d014b
Merge branch 'master' into refactor-base-profilers
carmocca Mar 22, 2021
097a426
Remove TODOs
carmocca Mar 22, 2021
4d529fa
Describe on_evaluation_end. Add tests
carmocca Mar 22, 2021
58dcd4e
Typo
carmocca Mar 22, 2021
c37162f
Address comments
carmocca Mar 22, 2021
4c5f1f3
deepcopy tests
carmocca Mar 22, 2021
5ed73fb
Advanced teardown
carmocca Mar 22, 2021
897f8e5
Fix teardown test
carmocca Mar 22, 2021
e42be2a
Fix tests
carmocca Mar 22, 2021
32c301c
Minor change
carmocca Mar 22, 2021
af0c8ad
Update CHANGELOG.md
carmocca Mar 22, 2021
29a73c5
Fix test
carmocca Mar 22, 2021
d6ede58
Merge branch 'refactor-base-profilers' into update-pytorch-profiler
carmocca Mar 22, 2021
0dc1e06
Quick fixes
carmocca Mar 22, 2021
cb756b8
Fix 6522
carmocca Mar 22, 2021
758b942
resolve ddp tests
tchaton Mar 23, 2021
fca4eb2
resolve tests
tchaton Mar 23, 2021
2919a39
resolve some tests
tchaton Mar 23, 2021
d7ca5fa
update tests
tchaton Mar 23, 2021
34b7991
Merge branch 'refactor-base-profilers' into update-pytorch-profiler
tchaton Mar 23, 2021
b3ffe64
resolve tests
tchaton Mar 23, 2021
e0d8308
Merge branch 'master' into update-pytorch-profiler
tchaton Mar 23, 2021
97c87d3
resolve tests
tchaton Mar 23, 2021
ae31f00
Missed fixes from 3/5
carmocca Mar 23, 2021
78970d5
Fixes
carmocca Mar 23, 2021
ba067fd
Broken refactor
carmocca Mar 23, 2021
f916013
Missed stage
carmocca Mar 23, 2021
9a1b1b6
Minor changes
carmocca Mar 23, 2021
04abd0a
resolve tests
tchaton Mar 23, 2021
ab1218e
Merge branch 'update-pytorch-profiler' of https://github.com/PyTorchL…
tchaton Mar 23, 2021
0230330
Update CHANGELOG
carmocca Mar 23, 2021
cba643b
resolve bug
tchaton Mar 23, 2021
85bfc7d
Merge branch 'update-pytorch-profiler' of https://github.com/PyTorchL…
tchaton Mar 23, 2021
af209f5
remove print
tchaton Mar 23, 2021
bb813bb
Typo
carmocca Mar 23, 2021
624a85f
Cleanup
carmocca Mar 23, 2021
1732ece
resolve ddp test
tchaton Mar 23, 2021
f0c1675
remove barrier
tchaton Mar 23, 2021
a13664d
Merge branch 'master' into update-pytorch-profiler
tchaton Mar 23, 2021
374f0da
update
tchaton Mar 23, 2021
6e948ca
Smaller model
carmocca Mar 23, 2021
e0d8623
add check for emit_nvtx
tchaton Mar 23, 2021
3 changes: 1 addition & 2 deletions pytorch_lightning/callbacks/progress.py
@@ -39,8 +39,7 @@

class tqdm(_tqdm):
"""
Custom tqdm progressbar where we append 0 to floating points/strings to
prevent the progress bar from flickering
Custom tqdm progressbar where we append 0 to floating points/strings to prevent the progress bar from flickering
"""

@staticmethod
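The hunk above only touches the docstring, but for context the padding it describes is typically achieved by overriding tqdm's `format_num` static method. The sketch below is illustrative only, not the code from this PR; the class name `PaddedTqdm` and the `_PAD_SIZE` constant are assumptions.

# Illustrative sketch: pad formatted numbers to a fixed width so the bar
# does not change width (and therefore flicker) between refreshes.
from tqdm import tqdm as _tqdm

_PAD_SIZE = 5  # assumed target width, not necessarily the value Lightning uses


class PaddedTqdm(_tqdm):

    @staticmethod
    def format_num(n) -> str:
        should_be_padded = isinstance(n, (float, str))
        if not isinstance(n, str):
            n = _tqdm.format_num(n)  # tqdm's default number formatting
        if should_be_padded and "e" not in n:
            if "." not in n and len(n) < _PAD_SIZE:
                try:
                    float(n)  # only pad strings that actually represent numbers
                except ValueError:
                    return n
                n += "."
            n += "0" * (_PAD_SIZE - len(n))
        return n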
152 changes: 68 additions & 84 deletions pytorch_lightning/profiler/profilers.py
@@ -21,7 +21,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from contextlib import contextmanager
from typing import Optional, Union
from typing import Optional

import numpy as np

@@ -30,22 +30,8 @@
log = logging.getLogger(__name__)


class BaseProfiler(ABC):
"""
If you wish to write a custom profiler, you should inhereit from this class.
"""

def __init__(self, output_streams: Optional[Union[list, tuple]] = None):
"""
Args:
output_streams: callable
"""
if output_streams:
if not isinstance(output_streams, (list, tuple)):
output_streams = [output_streams]
else:
output_streams = []
self.write_streams = output_streams
class AbstractProfiler(ABC):
"""Specification of a profiler."""

@abstractmethod
def start(self, action_name: str) -> None:
@@ -55,6 +41,38 @@ def start(self, action_name: str) -> None:
def stop(self, action_name: str) -> None:
"""Defines how to record the duration once an action is complete."""

@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""


class BaseProfiler(AbstractProfiler, ABC):
"""
If you wish to write a custom profiler, you should inherit from this class.
"""

def __init__(self, local_rank: Optional[int] = None, log_dir: Optional[str] = None) -> None:
self.output_fname = getattr(self, "output_fname", None)
self.output_file = None
# the profiler can be used outside of lightning
# that's why we call `on_train_start` manually
self.on_train_start(local_rank=local_rank, log_dir=log_dir)

def on_train_start(self, local_rank: Optional[int] = None, log_dir: Optional[str] = None):
"""
This function is used by the Trainer to inject local_rank with `DDP`
and `TensorBoardLogger` log_dir in the profiler.
"""
self.local_rank = local_rank
self.log_dir = log_dir
self.prepare_file()

def prepare_file(self) -> None:
if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")
self.write_streams = [self.output_file.write] if self.output_file else [log.info]

@contextmanager
def profile(self, action_name: str) -> None:
"""
@@ -90,13 +108,25 @@ def describe(self) -> None:
"""Logs a profile report after the conclusion of the training run."""
for write in self.write_streams:
write(self.summary())
if self.output_file:
self.output_file.flush()

@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""

def on_train_start(self, local_rank: Optional[int] = None):
self.local_rank = local_rank
def stats_to_str(self, stats: dict) -> str:
output = ["Profiler Report"]
for action, value in stats.items():
header = f"Profile stats for: {action}"
if getattr(self, "local_rank", None) is not None:
header += f" rank: {self.local_rank}"
output.append(header)
output.append(value)
return os.linesep.join(output)

def __del__(self) -> None:
"""Close profiler's stream."""
try:
self.output_file.close()
except AttributeError:
pass


class PassThroughProfiler(BaseProfiler):
@@ -105,9 +135,6 @@ class PassThroughProfiler(BaseProfiler):
The Trainer uses this class by default.
"""

def __init__(self):
super().__init__(output_streams=None)

def start(self, action_name: str) -> None:
pass

@@ -124,7 +151,7 @@ class SimpleProfiler(BaseProfiler):
the mean duration of each action and the total time spent over the entire training run.
"""

def __init__(self, output_filename: Optional[str] = None, extended=True):
def __init__(self, output_filename: Optional[str] = None, extended: bool = True):
"""
Args:
output_filename: optionally save profile results to file instead of printing
@@ -135,19 +162,12 @@ def __init__(self, output_filename: Optional[str] = None, extended=True):
If you attempt to start an action which has already started, or
if you attempt to stop recording an action which was never started.
"""
self.output_fname = output_filename
self.current_actions = {}
self.recorded_durations = defaultdict(list)
self.extended = extended

self.output_fname = output_filename
self.output_file = None
if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__()
self.start_time = time.monotonic()
super().__init__(output_streams=streaming_out)

def start(self, action_name: str) -> None:
if action_name in self.current_actions:
@@ -169,24 +189,25 @@ def make_report(self):
return report, total_duration

def summary(self) -> str:
output_string = "\n\nProfiler Report\n"
sep = os.linesep
output_string = f"Profiler Report{sep}"

if self.extended:

if len(self.recorded_durations) > 0:
max_key = np.max([len(k) for k in self.recorded_durations.keys()])

def log_row(action, mean, num_calls, total, per):
row = f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t|"
row = f"{sep}{action:<{max_key}s}\t| {mean:<15}\t|"
row += f"{num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
return row

output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
output_string_len = len(output_string)
output_string += f"{os.linesep}{'-' * output_string_len}"
output_string += f"{sep}{'-' * output_string_len}"
report, total_duration = self.make_report()
output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %")
output_string += f"{os.linesep}{'-' * output_string_len}"
output_string += f"{sep}{'-' * output_string_len}"
for action, durations, duration_per in report:
output_string += log_row(
action,
@@ -198,27 +219,16 @@ def log_row(action, mean, num_calls, total, per):
else:

def log_row(action, mean, total):
return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
return f"{sep}{action:<20s}\t| {mean:<15}\t| {total:<15}"

output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
output_string += f"{os.linesep}{'-' * 65}"
output_string += f"{sep}{'-' * 65}"

for action, durations in self.recorded_durations.items():
output_string += log_row(action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}")
output_string += os.linesep
output_string += sep
return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()


class AdvancedProfiler(BaseProfiler):
"""
@@ -240,17 +250,10 @@ def __init__(self, output_filename: Optional[str] = None, line_count_restriction
ValueError:
If you attempt to stop recording an action which was never started.
"""
self.output_fname = output_filename
self.profiled_actions = {}
self.line_count_restriction = line_count_restriction

self.output_fname = output_filename
self.output_file = None
if self.output_fname:
fs = get_filesystem(self.output_fname)
self.output_file = fs.open(self.output_fname, "w")

streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)
super().__init__()

def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions:
@@ -260,9 +263,7 @@ def start(self, action_name: str) -> None:
def stop(self, action_name: str) -> None:
pr = self.profiled_actions.get(action_name)
if pr is None:
raise ValueError( # pragma: no-cover
f"Attempting to stop recording an action ({action_name}) which was never started."
)
raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
pr.disable()

def summary(self) -> str:
@@ -272,21 +273,4 @@ def summary(self) -> str:
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
ps.print_stats(self.line_count_restriction)
recorded_stats[action_name] = s.getvalue()

# log to standard out
output_string = f"{os.linesep}Profiler Report{os.linesep}"
for action, stats in recorded_stats.items():
output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"

return output_string

def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()

def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
return self.stats_to_str(recorded_stats)
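The refactor above establishes a small contract for custom profilers: a subclass sets `output_fname` (if any) before calling `super().__init__()`, implements `start`, `stop` and `summary`, and inherits the output-stream handling, the `profile` context manager, `describe` and `stats_to_str` from `BaseProfiler`. Below is a minimal sketch of a custom profiler written against that contract; the class name `WallClockProfiler` and its private attributes are invented for the example, only the overridden hooks and the helpers they call come from the diff above.

# Minimal custom-profiler sketch against the refactored BaseProfiler API.
import time
from collections import defaultdict
from typing import Optional

from pytorch_lightning.profiler.profilers import BaseProfiler


class WallClockProfiler(BaseProfiler):
    """Records wall-clock duration per action."""

    def __init__(self, output_filename: Optional[str] = None) -> None:
        # BaseProfiler.__init__ reads this via getattr(self, "output_fname", None)
        self.output_fname = output_filename
        self._start_times = {}
        self._durations = defaultdict(list)
        super().__init__()

    def start(self, action_name: str) -> None:
        self._start_times[action_name] = time.monotonic()

    def stop(self, action_name: str) -> None:
        start = self._start_times.pop(action_name, None)
        if start is None:
            raise ValueError(f"Attempting to stop recording an action ({action_name}) which was never started.")
        self._durations[action_name].append(time.monotonic() - start)

    def summary(self) -> str:
        stats = {
            action: f"{sum(durations):.5f}s over {len(durations)} call(s)"
            for action, durations in self._durations.items()
        }
        return self.stats_to_str(stats)

Used standalone, `with profiler.profile("my_action"): ...` followed by `profiler.describe()` writes the report to the configured file, or logs it when no `output_filename` is given; passed as `Trainer(profiler=WallClockProfiler())`, the Trainer injects `local_rank` and `log_dir` through `on_train_start` as described in the docstring above.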