diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index d9012d0e89352e..56771eb188e6c1 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -1095,7 +1095,7 @@ def create_mkldnn_tensor(): stats = run_profiler(create_cuda_tensor) check_metrics( stats, - "cuda_memory_usage", + "device_memory_usage", allocs=[ "test_user_scope_alloc", "aten::to", @@ -1147,7 +1147,7 @@ def create_mkldnn_tensor(): deallocs=["[memory]"], ) if torch.cuda.is_available(): - check_metrics(stats, "cuda_memory_usage", deallocs=["[memory]"]) + check_metrics(stats, "device_memory_usage", deallocs=["[memory]"]) @unittest.skipIf( IS_JETSON, "Jetson has a guard against OOM since host and gpu memory are shared" diff --git a/test/test_autograd.py b/test/test_autograd.py index 95432aaa6a5865..5f2c4d28e4a83f 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -4628,11 +4628,11 @@ def test_profiler_function_event_avg(self): self.assertEqual(avg.count, 4) self.assertEqual(avg.cpu_time_total, 30) self.assertEqual(avg.self_cpu_time_total, 30) - self.assertEqual(avg.cuda_time_total, 0) + self.assertEqual(avg.device_time_total, 0) # average stats self.assertEqual(avg.cpu_time, 7.5) - self.assertEqual(avg.cuda_time_total, 0) + self.assertEqual(avg.device_time_total, 0) def test_profiler_shapes(self): print("") diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi index 7e503a8e90ea52..e6c4c3ec9d59ba 100644 --- a/torch/_C/_autograd.pyi +++ b/torch/_C/_autograd.pyi @@ -15,6 +15,7 @@ from ._profiler import ( class DeviceType(Enum): CPU = ... CUDA = ... + XPU = ... MKLDNN = ... OPENGL = ... OPENCL = ... diff --git a/torch/_C/_profiler.pyi b/torch/_C/_profiler.pyi index e1481dd9c1e2ab..d19e72f57322c4 100644 --- a/torch/_C/_profiler.pyi +++ b/torch/_C/_profiler.pyi @@ -39,6 +39,7 @@ class ActiveProfilerType(Enum): class ProfilerActivity(Enum): CPU = ... CUDA = ... + XPU = ... MTIA = ... PrivateUse1 = ... diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 53319bd2dd82b7..77e84d0829bbcc 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -136,7 +136,7 @@ def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float log.debug("profiling time breakdown") log.debug(actual_events.table(row_limit=-1)) - res = sum(event.cuda_time_total for event in actual_events) / 1000.0 / n_repeat + res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat log.debug("profiling results: %s ms", res) return res diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index ba020fb3cb8e1a..f233277b7e1294 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -7,7 +7,6 @@ import torch import torch.cuda -from torch._C import _get_privateuse1_backend_name from torch._C._profiler import _ExperimentalConfig from torch.autograd import ( @@ -112,8 +111,12 @@ class profile: Args: enabled (bool, optional): Setting this to False makes this context manager a no-op. - use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API. - Adds approximately 4us of overhead to each tensor operation. + use_cuda (bool, optional): Enables timing of CUDA events as well + using the cudaEvent API. (will be deprecated; use ``use_device='cuda'`` instead) + + use_device (str, optional): Enables timing of device events. + Adds approximately 4us of overhead to each tensor operation when using CUDA. + The valid device options are 'cuda', 'xpu' and 'privateuseone'.
record_shapes (bool, optional): If shapes recording is set, information about input dimensions will be collected. This allows one to see which @@ -161,9 +164,9 @@ class profile: .. warning: Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_), - one cannot use the profiler with ``use_cuda = True`` to benchmark + one cannot use the profiler with ``use_device = 'cuda'`` to benchmark DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading, - please use ``use_cuda = False`` or ``num_workers = 0``. + please use ``use_device = None`` or ``num_workers = 0``. Example: >>> # xdoctest: +SKIP @@ -207,9 +210,13 @@ def __init__( if not self.enabled: return self.use_cuda = use_cuda - self.use_device: Optional[str] = ( - use_device if use_device != "privateuseone" else None - ) + if self.use_cuda: + warn( + "The attribute `use_cuda` will be deprecated soon, please use ``use_device = 'cuda'`` instead." + ) + self.use_device: Optional[str] = "cuda" + else: + self.use_device = use_device self.function_events: Optional[EventList] = None self.entered = False self.record_shapes = record_shapes @@ -233,17 +240,19 @@ def __init__( use_kineto ), "Device-only events supported only with Kineto (use_kineto=True)" - if self.use_device == "cuda": - self.use_device = None - self.use_cuda = True - - if self.use_device and self.use_device != _get_privateuse1_backend_name(): - warn(f"{self.use_device} doesn't support profile.") + VALID_DEVICE_OPTIONS = ["cuda", "xpu", "privateuseone"] + if self.use_device is not None and self.use_device not in VALID_DEVICE_OPTIONS: + warn(f"{self.use_device} is not a valid device option.") self.use_device = None - if self.use_cuda and not torch.cuda.is_available(): + if self.use_device == "cuda" and not torch.cuda.is_available(): warn("CUDA is not available, disabling CUDA profiling") self.use_cuda = False + self.use_device = None + + if self.use_device == "xpu" and not torch.xpu.is_available(): + warn("XPU is not available, disabling XPU profiling") + self.use_device = None self.kineto_activities = set() if self.use_cpu: @@ -252,14 +261,18 @@ def __init__( self.kineto_activities.add(ProfilerActivity.MTIA) self.profiler_kind = ProfilerState.KINETO - if self.use_cuda: + if self.use_device == "cuda": if not use_kineto or ProfilerActivity.CUDA not in _supported_activities(): assert self.use_cpu, "Legacy CUDA profiling requires use_cpu=True" self.profiler_kind = ProfilerState.KINETO_GPU_FALLBACK else: self.kineto_activities.add(ProfilerActivity.CUDA) - - if self.use_device: + elif self.use_device == "xpu": + assert ( + use_kineto and ProfilerActivity.XPU in _supported_activities() + ), "Legacy XPU profiling is not supported. Requires use_kineto=True on XPU devices."
+ self.kineto_activities.add(ProfilerActivity.XPU) + elif self.use_device is not None and self.use_device != "privateuseone": if ( not use_kineto or ProfilerActivity.PrivateUse1 not in _supported_activities() @@ -315,8 +328,10 @@ def _start_trace(self): def __exit__(self, exc_type, exc_val, exc_tb): if not self.enabled: return - if self.use_cuda: + if self.use_device == "cuda": torch.cuda.synchronize() + elif self.use_device == "xpu": + torch.xpu.synchronize() t0 = perf_counter_ns() self.kineto_results = _disable_profiler() @@ -332,7 +347,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.function_events = EventList( parsed_results, - use_cuda=self.use_cuda, use_device=self.use_device, profile_memory=self.profile_memory, with_flops=self.with_flops, @@ -445,17 +459,11 @@ def _cpu_memory_usage(mem_record): else 0 ) - def _cuda_memory_usage(mem_record): + def _device_memory_usage(mem_record): return ( mem_record.nbytes() - if mem_record.device_type() in [DeviceType.CUDA, DeviceType.HIP] - else 0 - ) - - def _privateuse1_memory_usage(mem_record): - return ( - mem_record.nbytes() - if mem_record.device_type() in [DeviceType.PrivateUse1] + if mem_record.device_type() + in [DeviceType.CUDA, DeviceType.PrivateUse1, DeviceType.HIP] else 0 ) @@ -471,16 +479,14 @@ def _privateuse1_memory_usage(mem_record): abs_end_ns = kineto_event.start_ns() + kineto_event.duration_ns() cpu_memory_usage = 0 - cuda_memory_usage = 0 - privateuse1_memory_usage = 0 + device_memory_usage = 0 if kineto_event.device_type() == DeviceType.CPU: # find the corresponding memory allocation events for mem_record in mem_records_acc.in_interval( kineto_event.start_ns() / 1000, abs_end_ns / 1000 ): cpu_memory_usage += _cpu_memory_usage(mem_record[0]) - cuda_memory_usage += _cuda_memory_usage(mem_record[0]) - privateuse1_memory_usage += _privateuse1_memory_usage(mem_record[0]) + device_memory_usage += _device_memory_usage(mem_record[0]) mem_record[1] = True is_async = kineto_event.is_async() or ( @@ -505,8 +511,7 @@ def _privateuse1_memory_usage(mem_record): scope=kineto_event.scope(), use_device=self.use_device, cpu_memory_usage=cpu_memory_usage, - cuda_memory_usage=cuda_memory_usage, - privateuse1_memory_usage=privateuse1_memory_usage, + device_memory_usage=device_memory_usage, is_async=is_async, sequence_nr=kineto_event.sequence_nr(), device_type=kineto_event.device_type(), @@ -516,12 +521,12 @@ def _privateuse1_memory_usage(mem_record): ) max_evt_id = max(max_evt_id, fe.id) if fe.device_type == DeviceType.CPU and not fe.is_async: - if self.use_device: + if self.use_device == "privateuseone": privateuse1_time = kineto_event.privateuse1_elapsed_us() if privateuse1_time > 0: fe.append_kernel(fe.name, fe.device_index, privateuse1_time) fe.is_legacy = True - else: + elif self.use_device == "cuda": # Check if we have CUDA time as a fallback cuda_time = kineto_event.cuda_elapsed_us() if cuda_time > 0: @@ -534,7 +539,7 @@ def _privateuse1_memory_usage(mem_record): device_corr_map[corr_id] = [] device_corr_map[corr_id].append(fe) - # associate CUDA kernels and CUDA runtime (CPU) with CPU events + # associate device kernels and device runtime (CPU) with CPU events for fe in function_events: if ( fe.device_type == DeviceType.CPU @@ -549,7 +554,7 @@ def _privateuse1_memory_usage(mem_record): f_evt.time_range.end - f_evt.time_range.start, ) elif f_evt.device_type == DeviceType.CPU: - # make sure that 'thread' of a CPU Kineto (e.g. CUDA Runtime) event is associated + # make sure that 'thread' of a CPU Kineto (e.g. 
Device Runtime) event is associated # with the 'thread' of the corresponding linked PyTorch event to properly track # parents and children f_evt.thread = fe.thread @@ -569,8 +574,7 @@ def createFunctionEventForMemoryEvents(evt): scope=0, # RecordScope::FUNCTION use_device=self.use_device, cpu_memory_usage=_cpu_memory_usage(evt), - cuda_memory_usage=_cuda_memory_usage(evt), - privateuse1_memory_usage=_privateuse1_memory_usage(evt), + device_memory_usage=_device_memory_usage(evt), is_async=False, sequence_nr=-1, device_type=DeviceType.CPU, diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index 32700ffb1cf317..c491f9797afbcd 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -93,7 +93,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): parsed_results = _parse_legacy_records(records) self.function_events = EventList( parsed_results, - use_cuda=self.use_cuda, + use_device="cuda" if self.use_cuda else None, profile_memory=self.profile_memory, with_flops=self.with_flops, ) @@ -251,7 +251,7 @@ def _get_record_key(record): ], scope=start.scope(), cpu_memory_usage=cpu_memory_usage, - cuda_memory_usage=cuda_memory_usage, + device_memory_usage=cuda_memory_usage, is_async=is_async, is_remote=is_remote_event, sequence_nr=start.sequence_nr(), @@ -287,7 +287,7 @@ def _get_record_key(record): end_us=0, stack=[], cpu_memory_usage=record.cpu_memory_usage(), - cuda_memory_usage=record.cuda_memory_usage(), + device_memory_usage=record.cuda_memory_usage(), is_legacy=True, ) functions.append(fe) diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 4db601ad7b0495..6d446d6ade2974 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -26,12 +26,10 @@ class EventList(list): """A list of Events (for pretty printing).""" def __init__(self, *args, **kwargs): - use_cuda = kwargs.pop("use_cuda", True) use_device = kwargs.pop("use_device", None) profile_memory = kwargs.pop("profile_memory", False) with_flops = kwargs.pop("with_flops", False) super().__init__(*args, **kwargs) - self._use_cuda = use_cuda self._use_device = use_device self._profile_memory = profile_memory self._tree_built = False @@ -181,14 +179,16 @@ def table( Args: sort_by (str, optional): Attribute used to sort entries. By default they are printed in the same order as they were registered. - Valid keys include: ``cpu_time``, ``cuda_time``, ``cpu_time_total``, - ``cuda_time_total``, ``cpu_memory_usage``, ``cuda_memory_usage``, - ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, ``count``. + Valid keys include: ``cpu_time``, ``cuda_time``, ``xpu_time``, + ``cpu_time_total``, ``cuda_time_total``, ``xpu_time_total``, + ``cpu_memory_usage``, ``cuda_memory_usage``, ``xpu_memory_usage``, + ``self_cpu_memory_usage``, ``self_cuda_memory_usage``, + ``self_xpu_memory_usage``, ``count``. top_level_events_only(bool, optional): Boolean flag to determine the selection of events to display. If true, the profiler will only display events at top level like top-level invocation of python `lstm`, python `add` or other functions, nested events like low-level - cpu/cuda ops events are omitted for profiler result readability. + cpu/cuda/xpu ops events are omitted for profiler result readability. Returns: A string containing the table. 
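A minimal usage sketch of the reworked legacy profiler entry point above (not part of the patch): it assumes a CUDA-capable build and falls back to CPU-only profiling otherwise; the tensor size and sort key are illustrative. XPU timing is only available through the Kineto path (``torch.profiler``), so it is not exercised here.

```python
import torch
from torch.autograd import profiler

# Sketch only: pick a device string accepted by the new `use_device` argument.
device = "cuda" if torch.cuda.is_available() else None

x = torch.randn(128, 128, device=device or "cpu")

# `use_device="cuda"` replaces the soon-to-be-deprecated `use_cuda=True`.
with profiler.profile(use_device=device, profile_memory=True) as prof:
    (x @ x).sum()

# Device-agnostic accessors introduced in profiler_util.py; the old
# cuda_* names remain as thin aliases slated for deprecation.
for evt in prof.function_events:
    _ = (evt.device_time_total, evt.device_memory_usage)

print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
```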
@@ -267,6 +267,7 @@ def supported_export_stacks_metrics(self): return [ "self_cpu_time_total", "self_cuda_time_total", + "self_xpu_time_total", "self_privateuse1_time_total", ] @@ -280,7 +281,12 @@ def export_stacks(self, path: str, metric: str): with open(path, "w") as f: for evt in self: if evt.stack and len(evt.stack) > 0: - metric_value = getattr(evt, metric) + metric_value = getattr( + evt, + metric.replace("cuda", "device") .replace("xpu", "device") .replace("privateuse1", "device"), + ) if int(metric_value) > 0: stack_str = "" for entry in reversed(evt.stack): @@ -325,7 +331,6 @@ def get_key(event, group_by_input_shapes, group_by_stack_n) -> Tuple[str, ...]: avg_list = EventList( stats.values(), - use_cuda=self._use_cuda, use_device=self._use_device, profile_memory=self._profile_memory, with_flops=self._with_flops, @@ -395,26 +400,23 @@ class FormattedTimesMixin: """ cpu_time_str = _attr_formatter("cpu_time") - cuda_time_str = _attr_formatter("cuda_time") - privateuse1_time_str = _attr_formatter("privateuse1_time") + device_time_str = _attr_formatter("device_time") cpu_time_total_str = _attr_formatter("cpu_time_total") - cuda_time_total_str = _attr_formatter("cuda_time_total") - privateuse1_time_total_str = _attr_formatter("privateuse1_time_total") + device_time_total_str = _attr_formatter("device_time_total") self_cpu_time_total_str = _attr_formatter("self_cpu_time_total") - self_cuda_time_total_str = _attr_formatter("self_cuda_time_total") - self_privateuse1_time_total_str = _attr_formatter("self_privateuse1_time_total") + self_device_time_total_str = _attr_formatter("self_device_time_total") @property def cpu_time(self): return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count # type: ignore[attr-defined] @property - def cuda_time(self): - return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count # type: ignore[attr-defined] + def device_time(self): + return 0.0 if self.count == 0 else 1.0 * self.device_time_total / self.count # type: ignore[attr-defined] @property - def privateuse1_time(self): - return 0.0 if self.count == 0 else 1.0 * self.privateuse1_time_total / self.count # type: ignore[attr-defined] + def cuda_time(self): # To be deprecated + return self.device_time class Interval: @@ -448,8 +450,7 @@ def __init__( scope=0, use_device=None, cpu_memory_usage=0, - cuda_memory_usage=0, - privateuse1_memory_usage=0, + device_memory_usage=0, is_async=False, is_remote=False, sequence_nr=-1, @@ -479,8 +480,7 @@ def __init__( self.scope: int = scope self.use_device: Optional[str] = use_device self.cpu_memory_usage: int = cpu_memory_usage - self.cuda_memory_usage: int = cuda_memory_usage - self.privateuse1_memory_usage: int = privateuse1_memory_usage + self.device_memory_usage: int = device_memory_usage self.is_async: bool = is_async self.is_remote: bool = is_remote self.sequence_nr: int = sequence_nr @@ -530,20 +530,23 @@ def self_cpu_memory_usage(self): ) @property - def self_cuda_memory_usage(self): + def self_device_memory_usage(self): if self.is_async or self.device_type != DeviceType.CPU: return 0 - return self.cuda_memory_usage - sum( - child.cuda_memory_usage for child in self.cpu_children + return self.device_memory_usage - sum( + child.device_memory_usage for child in self.cpu_children ) @property - def self_privateuse1_memory_usage(self): - if self.is_async or self.device_type != DeviceType.CPU: + def self_cuda_memory_usage(self): # To be deprecated + return self.self_device_memory_usage + + @property + def cpu_time_total(self): + if 
self.device_type == DeviceType.CPU: + return self.time_range.elapsed_us() + else: return 0 - return self.privateuse1_memory_usage - sum( - child.privateuse1_memory_usage for child in self.cpu_children - ) @property def self_cpu_time_total(self): @@ -554,84 +557,50 @@ def self_cpu_time_total(self): @property - def cuda_time_total(self): - if self.is_async or self.use_device: + def device_time_total(self): + if self.is_async or not self.use_device: return 0 if self.device_type == DeviceType.CPU: if not self.is_legacy: # account for the kernels in the children ops return sum(kinfo.duration for kinfo in self.kernels) + sum( - ch.cuda_time_total for ch in self.cpu_children + ch.device_time_total for ch in self.cpu_children ) else: # each legacy cpu events has a single (fake) kernel return sum(kinfo.duration for kinfo in self.kernels) else: - assert self.device_type == DeviceType.CUDA + assert self.device_type in [DeviceType.CUDA, DeviceType.PrivateUse1] return self.time_range.elapsed_us() @property - def self_cuda_time_total(self): - if self.is_async or self.use_device: - return 0 - if self.device_type == DeviceType.CPU: - return self.cuda_time_total - sum( - child.cuda_time_total for child in self.cpu_children - ) - else: - assert self.device_type == DeviceType.CUDA - return self.cuda_time_total + def cuda_time_total(self): # To be deprecated + return self.device_time_total @property - def cpu_time_total(self): - if self.device_type == DeviceType.CPU: - return self.time_range.elapsed_us() - else: - return 0 - - @property - def self_privateuse1_time_total(self): + def self_device_time_total(self): if self.is_async or not self.use_device: return 0 if self.device_type == DeviceType.CPU: - return self.privateuse1_time_total - sum( - child.privateuse1_time_total for child in self.cpu_children + return self.device_time_total - sum( + [child.device_time_total for child in self.cpu_children] ) else: - assert self.device_type == DeviceType.CUDA - return self.privateuse1_time_total + assert self.device_type in [DeviceType.CUDA, DeviceType.PrivateUse1] + return self.device_time_total @property - def privateuse1_time_total(self): - if self.is_async or not self.use_device: - return 0 - if self.device_type == DeviceType.CPU: - if not self.is_legacy: - # account for the kernels in the children ops - return sum(kinfo.duration for kinfo in self.kernels) + sum( - ch.privateuse1_time_total for ch in self.cpu_children - ) - else: - # each legacy cpu events has a single (fake) kernel - return sum(kinfo.duration for kinfo in self.kernels) - else: - assert self.device_type == DeviceType.PrivateUse1 - return self.time_range.elapsed_us() + def self_cuda_time_total(self): # To be deprecated + return self.self_device_time_total @property def key(self): return self.name def __repr__(self): - device_name = "cuda" if not self.use_device else self.use_device - device_time = ( - self.cuda_time_str if not self.use_device else self.privateuse1_time_str - ) - device_memory_usage = ( - self.cuda_memory_usage - if not self.use_device - else self.privateuse1_memory_usage - ) + device_name = self.use_device + device_time = self.device_time_str + device_memory_usage = self.device_memory_usage return ( " 0 for event in events) - has_cuda_mem = any(event.self_cuda_memory_usage > 0 for event in events) - has_privateuse1_time = any( - event.self_privateuse1_time_total > 0 for event in events - ) - has_privateuse1_mem = any( - event.self_privateuse1_memory_usage > 0 for event in events - ) + has_device_time = any(event.self_device_time_total > 0 for 
event in events) + has_device_mem = any(event.self_device_memory_usage > 0 for event in events) use_device = events[0].use_device - if not use_device and (has_privateuse1_mem or has_privateuse1_time): - raise RuntimeError( - "use_device is None, but there is private device performance data." - ) + # Running on PrivateUse1 device with profiler but not enable + # ProfilerActivity.PrivateUse1 can also catch privateuse1 memory usage. + # Here only need to check has_privateuse1_time if not use_device. + if not use_device and has_device_time: + raise RuntimeError("use_device is None, but there is device performance data.") has_input_shapes = any( (event.input_shapes is not None and len(event.input_shapes) > 0) @@ -879,8 +825,16 @@ def _build_table( if sort_by is not None: events = EventList( - sorted(events, key=lambda evt: getattr(evt, sort_by), reverse=True), - use_cuda=has_cuda_time, + sorted( + events, + key=lambda evt: getattr( + evt, + sort_by.replace("cuda", "device") + .replace("xpu", "device") + .replace("privateuse1", "device"), + ), + reverse=True, + ), use_device=use_device, profile_memory=profile_memory, with_flops=with_flops, @@ -918,23 +872,14 @@ def _build_table( "CPU total", "CPU time avg", ] - if has_cuda_time: - headers.extend( - [ - "Self CUDA", - "Self CUDA %", - "CUDA total", - "CUDA time avg", - ] - ) - if has_privateuse1_time: - privateuse1 = use_device.upper() + device_name = use_device.upper() if use_device is not None else "None" + if has_device_time: headers.extend( [ - f"Self {privateuse1}", - f"Self {privateuse1} %", - f"{privateuse1} total", - f"{privateuse1} time avg", + f"Self {device_name}", + f"Self {device_name} %", + f"{device_name} total", + f"{device_name} time avg", ] ) if profile_memory: @@ -944,19 +889,11 @@ def _build_table( "Self CPU Mem", ] ) - if has_cuda_mem: + if has_device_mem: headers.extend( [ - "CUDA Mem", - "Self CUDA Mem", - ] - ) - if has_privateuse1_mem: - privateuse1 = use_device.upper() - headers.extend( - [ - f"{privateuse1} Mem", - f"Self {privateuse1} Mem", + f"{device_name} Mem", + f"Self {device_name} Mem", ] ) headers.append("# of Calls") @@ -1030,22 +967,16 @@ def append(s): result.append(s) result.append("\n") # Yes, newline after the end as well - sum_self_cpu_time_total = sum(event.self_cpu_time_total for event in events) - sum_self_cuda_time_total = 0 - sum_self_privateuse1_time_total = 0 + sum_self_cpu_time_total = 0 + sum_self_device_time_total = 0 for evt in events: - if evt.device_type == DeviceType.CPU: + sum_self_cpu_time_total += evt.self_cpu_time_total + if evt.device_type == DeviceType.CPU and evt.is_legacy: # in legacy profiler, kernel info is stored in cpu events - if evt.is_legacy: - if not use_device: - sum_self_cuda_time_total += evt.self_cuda_time_total - else: - sum_self_privateuse1_time_total += evt.self_privateuse1_time_total - elif evt.device_type == DeviceType.CUDA: + sum_self_device_time_total += evt.self_device_time_total + elif evt.device_type in [DeviceType.CUDA, DeviceType.PrivateUse1]: # in kineto profiler, there're events with the correct device type (e.g. 
CUDA) - sum_self_cuda_time_total += evt.self_cuda_time_total - elif evt.device_type == DeviceType.PrivateUse1: - sum_self_privateuse1_time_total += evt.self_privateuse1_time_total + sum_self_device_time_total += evt.self_device_time_total # Actual printing if header is not None: @@ -1090,28 +1021,16 @@ def trim_path(path, src_column_width): evt.cpu_time_total_str, # CPU total evt.cpu_time_str, # CPU time avg ] - if has_cuda_time: + if has_device_time: row_values.extend( [ - evt.self_cuda_time_total_str, - # CUDA time total % + evt.self_device_time_total_str, + # device time total % _format_time_share( - evt.self_cuda_time_total, sum_self_cuda_time_total + evt.self_device_time_total, sum_self_device_time_total ), - evt.cuda_time_total_str, - evt.cuda_time_str, # Cuda time avg - ] - ) - if has_privateuse1_time: - row_values.extend( - [ - evt.self_privateuse1_time_total_str, - # PrivateUse1 time total % - _format_time_share( - evt.self_privateuse1_time_total, sum_self_privateuse1_time_total - ), - evt.privateuse1_time_total_str, - evt.privateuse1_time_str, # PrivateUse1 time avg + evt.device_time_total_str, + evt.device_time_str, # device time avg ] ) if profile_memory: @@ -1123,22 +1042,13 @@ def trim_path(path, src_column_width): _format_memory(evt.self_cpu_memory_usage), ] ) - if has_cuda_mem: - row_values.extend( - [ - # CUDA Mem Total - _format_memory(evt.cuda_memory_usage), - # Self CUDA Mem Total - _format_memory(evt.self_cuda_memory_usage), - ] - ) - if has_privateuse1_mem: + if has_device_mem: row_values.extend( [ - # PrivateUse1 Mem Total - _format_memory(evt.privateuse1_memory_usage), - # Self PrivateUse1 Mem Total - _format_memory(evt.self_privateuse1_memory_usage), + # Device Mem Total + _format_memory(evt.device_memory_usage), + # Self Device Mem Total + _format_memory(evt.self_device_memory_usage), ] ) row_values.append( @@ -1174,10 +1084,9 @@ def trim_path(path, src_column_width): append(header_sep) append(f"Self CPU time total: {_format_time(sum_self_cpu_time_total)}") - if has_cuda_time: - append(f"Self CUDA time total: {_format_time(sum_self_cuda_time_total)}") - if has_privateuse1_time: + if has_device_time: append( - f"Self {use_device.upper()} time total: {_format_time(sum_self_privateuse1_time_total)}" + f"Self {use_device.upper() if use_device is not None else 'None'} " + f"time total: {_format_time(sum_self_device_time_total)}" ) return "".join(result) diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index 85f91bf8b28e52..41561c6f3e8927 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -342,6 +342,7 @@ c10::DeviceType deviceTypeFromActivity(libkineto::ActivityType activity_type) { case libkineto::ActivityType::USER_ANNOTATION: case libkineto::ActivityType::EXTERNAL_CORRELATION: case libkineto::ActivityType::CUDA_RUNTIME: + case libkineto::ActivityType::XPU_RUNTIME: case libkineto::ActivityType::CPU_INSTANT_EVENT: case libkineto::ActivityType::GLOW_RUNTIME: case libkineto::ActivityType::MTIA_RUNTIME: diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index bfc725700a760c..120b2acad2fa4c 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -13,7 +13,6 @@ import torch import torch.autograd.profiler as prof -from torch._C import _get_privateuse1_backend_name from torch._C._profiler import ( _add_execution_trace_observer, _disable_execution_trace_observer, @@ -72,8 +71,10 @@ class _KinetoProfile: Args: activities (iterable): list of activity groups 
(CPU, CUDA) to use in profiling, supported values: - ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``. - Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA. + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``, + ``torch.profiler.ProfilerActivity.XPU``. + Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA + or (when available) ProfilerActivity.XPU. record_shapes (bool): save information about operator's input shapes. profile_memory (bool): track tensor memory allocation/deallocation (see ``export_memory_timeline`` for more details). @@ -126,9 +127,13 @@ def __init__( self.profiler: Optional[prof.profile] = None self.mem_tl: Optional[MemoryProfileTimeline] = None self.use_device = None - privateuse1_backend = _get_privateuse1_backend_name() - if privateuse1_backend != "privateuseone": - self.use_device = privateuse1_backend + if ProfilerActivity.CUDA in self.activities: + self.use_device = "cuda" + elif ProfilerActivity.XPU in self.activities: + self.use_device = "xpu" + else: + self.use_device = "privateuseone" + # user-defined metadata to be amended to the trace self.preset_metadata: Dict[str, str] = dict() @@ -144,7 +149,7 @@ def prepare_trace(self): use_cuda=(ProfilerActivity.CUDA in self.activities), use_cpu=(ProfilerActivity.CPU in self.activities), use_mtia=(ProfilerActivity.MTIA in self.activities), - use_device=None, + use_device=self.use_device, record_shapes=self.record_shapes, with_flops=self.with_flops, profile_memory=self.profile_memory, @@ -444,8 +449,10 @@ class profile(_KinetoProfile): Args: activities (iterable): list of activity groups (CPU, CUDA) to use in profiling, supported values: - ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``. - Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA. + ``torch.profiler.ProfilerActivity.CPU``, ``torch.profiler.ProfilerActivity.CUDA``, + ``torch.profiler.ProfilerActivity.XPU``. + Default value: ProfilerActivity.CPU and (when available) ProfilerActivity.CUDA + or (when available) ProfilerActivity.XPU. schedule (Callable): callable that takes step (int) as a single parameter and returns ``ProfilerAction`` value that specifies the profiler action to perform at each step. 
on_trace_ready (Callable): callable that is called at each step when ``schedule`` diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 25495f0bf88804..9f1a8f8411e271 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -4606,22 +4606,22 @@ def get_name(event): function_events = p.function_events for event in function_events: if event.is_async: - self.assertEqual(0, event.cuda_time_total) + self.assertEqual(0, event.device_time_total) self.assertEqual([], event.kernels) - self.assertEqual(0, event.cuda_time) + self.assertEqual(0, event.device_time) else: if event.node_id == 1: continue self.assertTrue(event.node_id in [dst_cuda_0, dst_cuda_1]) if get_name(event) in EXPECTED_REMOTE_EVENTS: - self.assertGreater(event.cuda_time_total, 0) + self.assertGreater(event.device_time_total, 0) self.assertEqual(1, len(event.kernels)) kernel = event.kernels[0] if event.node_id == dst_cuda_0: self.assertEqual(kernel.device, 0) if event.node_id == dst_cuda_1: self.assertEqual(kernel.device, 1) - self.assertGreater(event.cuda_time, 0) + self.assertGreater(event.device_time, 0) # Validate that EXPECTED_REMOTE_EVENTS is a subset of remotely profiled # events.
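As a companion sketch for the ``torch.profiler`` side of the change (again illustrative and not part of the patch): the new ``ProfilerActivity.XPU`` is selected the same way CUDA is, and the ``torch.xpu`` availability check assumes a build with XPU support; the workload and table options are placeholders.

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Sketch only: select activities based on what the build actually supports.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    activities.append(ProfilerActivity.XPU)

x = torch.randn(256, 256)
with profile(activities=activities, profile_memory=True) as prof:
    (x @ x).sum()

# Aggregates are reported through the renamed, device-agnostic fields
# (device_time_total, device_memory_usage) regardless of the backend.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))
```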