# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
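"""Unit tests for the Composer profiler: the cyclic schedule, profiler
initialization, trace markers, argument validation, and memory-timeline export."""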
import os
import pathlib
from typing import Union
from unittest.mock import MagicMock

import pytest
import torch
from packaging import version
from torch.profiler.profiler import ProfilerAction as TorchProfilerAction

from composer.core import Engine, Event, State, Timestamp
from composer.loggers import Logger
from composer.profiler import Profiler, ProfilerAction, SystemProfiler, TorchProfiler, cyclic_schedule
from composer.profiler.utils import export_memory_timeline_html

@pytest.mark.parametrize('repeat', [1, 0])
def test_cyclic_schedule(dummy_state: State, repeat: int):
    # Tests that the schedule returns the correct action for the given state
    skip_first = 1
    wait = 2
    warmup = 3
    active = 4
    schedule = cyclic_schedule(skip_first=skip_first, wait=wait, warmup=warmup, active=active, repeat=repeat)
    assert schedule(dummy_state) == ProfilerAction.SKIP  # batch 0 falls in the skip_first region

    for _ in range(skip_first):
        dummy_state.timestamp = dummy_state.timestamp.to_next_batch()
    assert schedule(dummy_state) == ProfilerAction.SKIP

    for _ in range(wait):
        dummy_state.timestamp = dummy_state.timestamp.to_next_batch()
    assert schedule(dummy_state) == ProfilerAction.WARMUP

    for _ in range(warmup):
        dummy_state.timestamp = dummy_state.timestamp.to_next_batch()
    assert schedule(dummy_state) == ProfilerAction.ACTIVE

    for _ in range(active + wait + warmup):
        dummy_state.timestamp = dummy_state.timestamp.to_next_batch()
    if repeat == 0:
        # repeat=0 cycles indefinitely, so one full cycle later the schedule is ACTIVE again
        assert schedule(dummy_state) == ProfilerAction.ACTIVE
    else:
        # repeat=1 runs a single cycle, after which the schedule always skips
        assert schedule(dummy_state) == ProfilerAction.SKIP
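
# For reference, with skip_first=1, wait=2, warmup=3, active=4, repeat=1, the
# schedule exercised above maps batch indices to actions as follows:
#   batch 0       -> SKIP    (skip_first)
#   batches 1-2   -> SKIP    (wait)
#   batches 3-5   -> WARMUP
#   batches 6-9   -> ACTIVE
#   batches 10+   -> SKIP    (the single cycle is exhausted)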

def test_profiler_init(minimal_state: State):
    # Construct a profiler and assert that it created the correct callbacks from the arguments
    mock_trace_handler = MagicMock()
    profiler = Profiler(
        trace_handlers=[mock_trace_handler],
        schedule=cyclic_schedule(),
        torch_prof_profile_memory=True,
        torch_prof_memory_filename=None,
        sys_prof_cpu=True,
    )
    profiler.bind_to_state(minimal_state)
    assert any(isinstance(cb, TorchProfiler) for cb in minimal_state.callbacks)
    assert any(isinstance(cb, SystemProfiler) for cb in minimal_state.callbacks)
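
# Note: in a full training run this wiring is typically handled by passing the
# Profiler to Composer's Trainer (e.g. Trainer(..., profiler=Profiler(...))),
# which binds it to the training State the same way bind_to_state does here.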

def test_marker(dummy_state: State):
    mock_trace_handler = MagicMock()
    profiler = Profiler(
        trace_handlers=[mock_trace_handler],
        schedule=cyclic_schedule(),
        torch_prof_memory_filename=None,
    )
    profiler.bind_to_state(dummy_state)
    dummy_state.profiler = profiler
    marker = profiler.marker(
        'name',
        actions=[ProfilerAction.SKIP, ProfilerAction.WARMUP, ProfilerAction.ACTIVE],
        categories=['cat1'],
    )
    marker.start()  # call #1
    with pytest.raises(RuntimeError):
        marker.start()  # cannot call start() twice without finishing
    marker.finish()  # call #2
    with pytest.raises(RuntimeError):
        marker.finish()  # cannot call finish() twice without a preceding start()
    with marker:
        pass  # calls #3 and #4

    @marker
    def func_to_profile(foo: str):
        assert foo == 'hi'

    func_to_profile(foo='hi')  # calls #5 and #6

    @marker()
    def func_to_profile2(bar: int):
        assert bar == 6

    func_to_profile2(bar=6)  # calls #7 and #8

    marker.instant()
    assert mock_trace_handler.process_duration_event.call_count == 8
    assert mock_trace_handler.process_instant_event.call_count == 1
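
# Counting: each start()/finish() pair emits two duration events, and the marker
# is entered and exited four times above (explicit start/finish, the context
# manager, and the two decorator forms), giving 8 duration events. The failed
# start()/finish() calls raise before recording anything, and instant() emits
# exactly one instant event.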

@pytest.mark.parametrize('torch_prof_with_stack', [True, False])
@pytest.mark.parametrize('torch_prof_record_shapes', [True, False])
@pytest.mark.parametrize('torch_prof_profile_memory', [True, False])
@pytest.mark.parametrize('torch_prof_memory_filename', [None, 'test.html'])
def test_profiler_error_message(
    torch_prof_with_stack: bool,
    torch_prof_record_shapes: bool,
    torch_prof_profile_memory: bool,
    torch_prof_memory_filename: Union[None, str],
) -> None:
    # Construct a profiler and assert that it raises a ValueError if the arguments are invalid
    if (
        torch_prof_memory_filename is not None and
        not (torch_prof_with_stack and torch_prof_record_shapes and torch_prof_profile_memory)
    ):
        with pytest.raises(ValueError):
            _ = Profiler(
                trace_handlers=[MagicMock()],
                schedule=cyclic_schedule(),
                torch_prof_with_stack=torch_prof_with_stack,
                torch_prof_record_shapes=torch_prof_record_shapes,
                torch_prof_profile_memory=torch_prof_profile_memory,
                torch_prof_memory_filename=torch_prof_memory_filename,
            )
    else:
        _ = Profiler(
            trace_handlers=[MagicMock()],
            schedule=cyclic_schedule(),
            torch_prof_with_stack=torch_prof_with_stack,
            torch_prof_record_shapes=torch_prof_record_shapes,
            torch_prof_profile_memory=torch_prof_profile_memory,
            torch_prof_memory_filename=torch_prof_memory_filename,
        )
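
# In other words, setting torch_prof_memory_filename (the memory-timeline output)
# is only valid when torch_prof_with_stack, torch_prof_record_shapes, and
# torch_prof_profile_memory are all enabled; any other combination with a
# filename set must raise a ValueError at construction time.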

@pytest.mark.gpu
def test_memory_timeline(tmp_path: pathlib.Path) -> None:
    if version.parse(torch.__version__) <= version.parse('2.1.0.dev'):
        # memory timeline is only supported in PyTorch 2.1.0+
        return
    import torch.profiler._memory_profiler as _memory_profiler

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024, bias=True),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 1024, bias=False),
        torch.nn.Softmax(dim=1),
    ).to('cuda')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    x = torch.ones((1024, 1024), device='cuda')
    targets = torch.ones((1024, 1024), device='cuda')
    with torch.profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof:
        y = model(x)
        loss = torch.nn.functional.mse_loss(y, targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    memory_profile = prof._memory_profile()
    timeline = memory_profile.timeline

    # Checks the default memory-timeline event value (t == -1) for preexisting tensors
    assert all((t == -1) if action == _memory_profiler.Action.PREEXISTING else (t > 0) for t, action, _, _ in timeline)

    fig = export_memory_timeline_html(
        prof,
        os.path.join(tmp_path, 'test_memory_timeline.html'),
        yxis_step_size=0.01,
        return_fig=True,
    )
    assert fig is not None, 'export_memory_timeline_html should return a figure when return_fig=True'
    _, end = fig.gca().get_ylim()
    assert round(end, 2) == 0.06
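
# Note: yxis_step_size is the parameter name as spelled in the
# export_memory_timeline_html signature. The hard-coded 0.06 upper y-limit
# reflects the expected peak memory of this fixed workload (two 1024x1024 Linear
# layers plus Adam state and activations); changing the model, optimizer, or
# batch size above would change this value.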

def test_skip_first_after_resumption(minimal_state: State) -> None:
    skip_first = 1
    wait = 2
    warmup = 3
    active = 4
    repeat = 1
    schedule = cyclic_schedule(skip_first=skip_first, wait=wait, warmup=warmup, active=active, repeat=repeat)
    mock_trace_handler = MagicMock()
    profiler = Profiler(
        trace_handlers=[mock_trace_handler],
        schedule=schedule,
    )
    profiler.bind_to_state(minimal_state)
    minimal_state.profiler = profiler
    assert len(profiler._callbacks) >= 1
    assert isinstance(profiler._callbacks[-1], TorchProfiler)
    torch_profiler = profiler._callbacks[-1]

    # Create the underlying torch.profiler.profile object
    logger = Logger(minimal_state)
    engine = Engine(state=minimal_state, logger=logger)
    engine.run_event(Event.INIT)
    assert torch_profiler.profiler is not None

    minimal_state.timestamp = Timestamp(batch_in_epoch=7)
    assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.RECORD

    # Load a checkpoint at batch 4
    minimal_state.timestamp = Timestamp(batch_in_epoch=4)
    engine.run_event(Event.BEFORE_LOAD)
    engine.run_event(Event.AFTER_LOAD)
    minimal_state.timestamp = Timestamp(batch_in_epoch=7)
    assert torch_profiler.profiler.schedule(0) == TorchProfilerAction.WARMUP
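
# After resumption, cyclic_schedule re-applies skip_first relative to the batch
# at which the run resumed: resuming at batch 4 makes batch 4 the skipped batch
# and batches 5-6 the wait phase, so batch 7 now falls in warmup (WARMUP) rather
# than in the active phase (RECORD) as it did before the checkpoint load.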