[Profiler] Unify the device (CUDA, XPU, PrivateUse1) in torch profiler post processing (pytorch#123247)

This PR unifies the CUDA, XPU, and PrivateUse1 device backends in the torch profiler. CUDA, XPU, and PrivateUse1 are now selected through a single string argument, `use_device`, and share one device path for computing Kineto time durations and memory statistics during post processing.
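
Below is a minimal usage sketch (not part of this PR) of how the unified API reads after the change. It assumes a CUDA-enabled build, and the same pattern should apply with `use_device="xpu"` or `"privateuseone"`. The workload and tensor sizes are illustrative only; the parameter and attribute names (`use_device`, `profile_memory`, `device_time_total`, `device_memory_usage`) are taken from the diff below.

```python
import torch
from torch.autograd import profiler

x = torch.randn(1024, 1024, device="cuda")

# use_device supersedes the boolean use_cuda flag; 'cuda', 'xpu' and
# 'privateuseone' now share one post-processing path in the profiler.
with profiler.profile(use_device="cuda", profile_memory=True) as prof:
    y = x @ x

# Event attributes are device-agnostic after this change:
# device_time_total replaces cuda_time_total, and device_memory_usage
# replaces cuda_memory_usage / privateuse1_memory_usage.
events = prof.function_events
device_time_ms = sum(evt.device_time_total for evt in events) / 1000.0  # us -> ms
device_mem_bytes = sum(evt.device_memory_usage for evt in events)
print(f"device time: {device_time_ms:.3f} ms, device memory delta: {device_mem_bytes} bytes")
```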

#suppress-api-compatibility-check

Co-authored-by: Aaron Enye Shi <enye.shi@gmail.com>
Pull Request resolved: pytorch#123247
Approved by: https://github.com/aaronenyeshi, https://github.com/gujinghui
zejun-chen authored and pytorchmergebot committed Apr 19, 2024
1 parent 803a08f commit 768ce2c
Showing 11 changed files with 183 additions and 260 deletions.
4 changes: 2 additions & 2 deletions test/profiler/test_profiler.py
@@ -1095,7 +1095,7 @@ def create_mkldnn_tensor():
stats = run_profiler(create_cuda_tensor)
check_metrics(
stats,
"cuda_memory_usage",
"device_memory_usage",
allocs=[
"test_user_scope_alloc",
"aten::to",
@@ -1147,7 +1147,7 @@ def create_mkldnn_tensor():
deallocs=["[memory]"],
)
if torch.cuda.is_available():
check_metrics(stats, "cuda_memory_usage", deallocs=["[memory]"])
check_metrics(stats, "device_memory_usage", deallocs=["[memory]"])

@unittest.skipIf(
IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared"
4 changes: 2 additions & 2 deletions test/test_autograd.py
@@ -4628,11 +4628,11 @@ def test_profiler_function_event_avg(self):
self.assertEqual(avg.count, 4)
self.assertEqual(avg.cpu_time_total, 30)
self.assertEqual(avg.self_cpu_time_total, 30)
self.assertEqual(avg.cuda_time_total, 0)
self.assertEqual(avg.device_time_total, 0)

# average stats
self.assertEqual(avg.cpu_time, 7.5)
self.assertEqual(avg.cuda_time_total, 0)
self.assertEqual(avg.device_time_total, 0)

def test_profiler_shapes(self):
print("")
1 change: 1 addition & 0 deletions torch/_C/_autograd.pyi
@@ -15,6 +15,7 @@ from ._profiler import (
class DeviceType(Enum):
CPU = ...
CUDA = ...
XPU = ...
MKLDNN = ...
OPENGL = ...
OPENCL = ...
1 change: 1 addition & 0 deletions torch/_C/_profiler.pyi
@@ -39,6 +39,7 @@ class ActiveProfilerType(Enum):
class ProfilerActivity(Enum):
CPU = ...
CUDA = ...
XPU = ...
MTIA = ...
PrivateUse1 = ...

2 changes: 1 addition & 1 deletion torch/_inductor/utils.py
@@ -136,7 +136,7 @@ def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float
log.debug("profiling time breakdown")
log.debug(actual_events.table(row_limit=-1))

res = sum(event.cuda_time_total for event in actual_events) / 1000.0 / n_repeat
res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
log.debug("profiling results: %s ms", res)
return res

86 changes: 45 additions & 41 deletions torch/autograd/profiler.py
@@ -7,7 +7,6 @@
import torch

import torch.cuda
from torch._C import _get_privateuse1_backend_name
from torch._C._profiler import _ExperimentalConfig

from torch.autograd import (
@@ -112,8 +111,12 @@ class profile:
Args:
enabled (bool, optional): Setting this to False makes this context manager a no-op.
use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
Adds approximately 4us of overhead to each tensor operation.
use_cuda (bool, optional): Enables timing of CUDA events as well
using the cudaEvent API. (will be deprecated)
use_device (str, optional): Enables timing of device events.
            Adds approximately 4us of overhead to each tensor operation when using CUDA.
            The valid device options are 'cuda', 'xpu' and 'privateuseone'.
record_shapes (bool, optional): If shapes recording is set, information
about input dimensions will be collected. This allows one to see which
@@ -161,9 +164,9 @@ class profile:
.. warning:
Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
one cannot use the profiler with ``use_cuda = True`` to benchmark
one cannot use the profiler with ``use_device = 'cuda'`` to benchmark
DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
please use ``use_cuda = False`` or ``num_workers = 0``.
please use ``use_device = None`` or ``num_workers = 0``.
Example:
>>> # xdoctest: +SKIP
@@ -207,9 +210,13 @@ def __init__(
if not self.enabled:
return
self.use_cuda = use_cuda
self.use_device: Optional[str] = (
use_device if use_device != "privateuseone" else None
)
if self.use_cuda:
warn(
"The attribute `use_cuda` will be deprecated soon, please use ``use_device = 'cuda'`` instead."
)
self.use_device: Optional[str] = "cuda"
else:
self.use_device = use_device
self.function_events: Optional[EventList] = None
self.entered = False
self.record_shapes = record_shapes
@@ -233,17 +240,19 @@ def __init__(
use_kineto
), "Device-only events supported only with Kineto (use_kineto=True)"

if self.use_device == "cuda":
self.use_device = None
self.use_cuda = True

if self.use_device and self.use_device != _get_privateuse1_backend_name():
warn(f"{self.use_device} doesn't support profile.")
VALID_DEVICE_OPTIONS = ["cuda", "xpu", "privateuseone"]
if self.use_device not in VALID_DEVICE_OPTIONS:
warn(f"The {self.use_device} is not a valid device option.")
self.use_device = None

if self.use_cuda and not torch.cuda.is_available():
if self.use_device == "cuda" and not torch.cuda.is_available():
warn("CUDA is not available, disabling CUDA profiling")
self.use_cuda = False
self.use_device = None

if self.use_device == "xpu" and not torch.xpu.is_available():
warn("XPU is not available, disabling XPU profiling")
self.use_device = None

self.kineto_activities = set()
if self.use_cpu:
@@ -252,14 +261,18 @@ def __init__(
self.kineto_activities.add(ProfilerActivity.MTIA)

self.profiler_kind = ProfilerState.KINETO
if self.use_cuda:
if self.use_device == "cuda":
if not use_kineto or ProfilerActivity.CUDA not in _supported_activities():
assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True"
self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK
else:
self.kineto_activities.add(ProfilerActivity.CUDA)

if self.use_device:
elif self.use_device == "xpu":
assert (
use_kineto and ProfilerActivity.XPU in _supported_activities()
), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices."
self.kineto_activities.add(ProfilerActivity.XPU)
elif self.use_device is not None and self.use_device != "privateuseone":
if (
not use_kineto
or ProfilerActivity.PrivateUse1 not in _supported_activities()
@@ -315,8 +328,10 @@ def _start_trace(self):
def __exit__(self, exc_type, exc_val, exc_tb):
if not self.enabled:
return
if self.use_cuda:
if self.use_device == "cuda":
torch.cuda.synchronize()
elif self.use_device == "xpu":
torch.xpu.synchronize()

t0 = perf_counter_ns()
self.kineto_results = _disable_profiler()
@@ -332,7 +347,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):

self.function_events = EventList(
parsed_results,
use_cuda=self.use_cuda,
use_device=self.use_device,
profile_memory=self.profile_memory,
with_flops=self.with_flops,
@@ -445,17 +459,11 @@ def _cpu_memory_usage(mem_record):
else 0
)

def _cuda_memory_usage(mem_record):
def _device_memory_usage(mem_record):
return (
mem_record.nbytes()
if mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP]
else 0
)

def _privateuse1_memory_usage(mem_record):
return (
mem_record.nbytes()
if mem_record.device_type() in [DeviceType.PrivateUse1]
if mem_record.device_type()
in [DeviceType.CUDA, DeviceType.PrivateUse1, DeviceType.HIP]
else 0
)

@@ -471,16 +479,14 @@ def _privateuse1_memory_usage(mem_record):
abs_end_ns = kineto_event.start_ns() + kineto_event.duration_ns()

cpu_memory_usage = 0
cuda_memory_usage = 0
privateuse1_memory_usage = 0
device_memory_usage = 0
if kineto_event.device_type() == DeviceType.CPU:
# find the corresponding memory allocation events
for mem_record in mem_records_acc.in_interval(
kineto_event.start_ns() / 1000, abs_end_ns / 1000
):
cpu_memory_usage += _cpu_memory_usage(mem_record[0])
cuda_memory_usage += _cuda_memory_usage(mem_record[0])
privateuse1_memory_usage += _privateuse1_memory_usage(mem_record[0])
device_memory_usage += _device_memory_usage(mem_record[0])
mem_record[1] = True

is_async = kineto_event.is_async() or (
@@ -505,8 +511,7 @@ def _privateuse1_memory_usage(mem_record):
scope=kineto_event.scope(),
use_device=self.use_device,
cpu_memory_usage=cpu_memory_usage,
cuda_memory_usage=cuda_memory_usage,
privateuse1_memory_usage=privateuse1_memory_usage,
device_memory_usage=device_memory_usage,
is_async=is_async,
sequence_nr=kineto_event.sequence_nr(),
device_type=kineto_event.device_type(),
@@ -516,12 +521,12 @@ def _privateuse1_memory_usage(mem_record):
)
max_evt_id = max(max_evt_id, fe.id)
if fe.device_type == DeviceType.CPU and not fe.is_async:
if self.use_device:
if self.use_device == "privateuseone":
privateuse1_time = kineto_event.privateuse1_elapsed_us()
if privateuse1_time > 0:
fe.append_kernel(fe.name, fe.device_index, privateuse1_time)
fe.is_legacy = True
else:
elif self.use_device == "cuda":
# Check if we have CUDA time as a fallback
cuda_time = kineto_event.cuda_elapsed_us()
if cuda_time > 0:
@@ -534,7 +539,7 @@ def _privateuse1_memory_usage(mem_record):
device_corr_map[corr_id] = []
device_corr_map[corr_id].append(fe)

# associate CUDA kernels and CUDA runtime (CPU) with CPU events
# associate device kernels and device runtime (CPU) with CPU events
for fe in function_events:
if (
fe.device_type == DeviceType.CPU
@@ -549,7 +554,7 @@ def _privateuse1_memory_usage(mem_record):
f_evt.time_range.end - f_evt.time_range.start,
)
elif f_evt.device_type == DeviceType.CPU:
# make sure that 'thread' of a CPU Kineto (e.g. CUDA Runtime) event is associated
# make sure that 'thread' of a CPU Kineto (e.g. Device Runtime) event is associated
# with the 'thread' of the corresponding linked PyTorch event to properly track
# parents and children
f_evt.thread = fe.thread
@@ -569,8 +574,7 @@ def createFunctionEventForMemoryEvents(evt):
scope=0, # RecordScope::FUNCTION
use_device=self.use_device,
cpu_memory_usage=_cpu_memory_usage(evt),
cuda_memory_usage=_cuda_memory_usage(evt),
privateuse1_memory_usage=_privateuse1_memory_usage(evt),
device_memory_usage=_device_memory_usage(evt),
is_async=False,
sequence_nr=-1,
device_type=DeviceType.CPU,
6 changes: 3 additions & 3 deletions torch/autograd/profiler_legacy.py
@@ -93,7 +93,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
parsed_results = _parse_legacy_records(records)
self.function_events = EventList(
parsed_results,
use_cuda=self.use_cuda,
use_device="cuda" if self.use_cuda else None,
profile_memory=self.profile_memory,
with_flops=self.with_flops,
)
@@ -251,7 +251,7 @@ def _get_record_key(record):
],
scope=start.scope(),
cpu_memory_usage=cpu_memory_usage,
cuda_memory_usage=cuda_memory_usage,
device_memory_usage=cuda_memory_usage,
is_async=is_async,
is_remote=is_remote_event,
sequence_nr=start.sequence_nr(),
@@ -287,7 +287,7 @@ def _get_record_key(record):
end_us=0,
stack=[],
cpu_memory_usage=record.cpu_memory_usage(),
cuda_memory_usage=record.cuda_memory_usage(),
device_memory_usage=record.cuda_memory_usage(),
is_legacy=True,
)
functions.append(fe)