StdLog reporter so we can visualize reported TBE stats from stderr #2386

Closed · wants to merge 3 commits
fbgemm_gpu/fbgemm_gpu/runtime_monitor.py (133 changes: 131 additions & 2 deletions)
@@ -8,9 +8,13 @@
# pyre-strict

import abc

import logging
from collections import deque
from dataclasses import dataclass
from typing import Optional
from types import TracebackType
from typing import Callable, Deque, Optional, Tuple, Type, TypeVar

import torch


class TBEStatsReporter(abc.ABC):
@@ -44,6 +48,27 @@ def report_duration(
...


class StdLogStatsReporter(TBEStatsReporter):
def __init__(self, report_interval: int) -> None:
assert report_interval > 0, "Report interval must be positive"
self.report_interval = report_interval

def should_report(self, iteration_step: int) -> bool:
return iteration_step % self.report_interval == 0

def report_duration(
self,
iteration_step: int,
event_name: str,
duration_ms: float,
embedding_id: str = "",
tbe_id: str = "",
) -> None:
logging.info(
f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} ms"
)


@dataclass
class TBEStatsReporterConfig:
"""
@@ -61,3 +86,107 @@ def create_reporter(self) -> Optional[TBEStatsReporter]:
self.interval <= 0
), "Cannot specify interval without an actual implementation of reporter"
return None


@dataclass
class StdLogStatsReporterConfig(TBEStatsReporterConfig):
def create_reporter(self) -> Optional[TBEStatsReporter]:
if self.interval <= 0:
return None
return StdLogStatsReporter(report_interval=self.interval)
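
As a quick illustration of the new reporter path, here is a minimal sketch. The `interval` value, the step numbers, and `tbe_id` below are hypothetical; `interval` is assumed to be the dataclass field inherited from TBEStatsReporterConfig.

```python
import logging

from fbgemm_gpu.runtime_monitor import StdLogStatsReporterConfig

logging.basicConfig(level=logging.INFO)  # logging.info() goes to stderr by default

# Hypothetical settings: report every 100 iterations.
config = StdLogStatsReporterConfig(interval=100)
reporter = config.create_reporter()  # returns None when interval <= 0

step = 200
if reporter is not None and reporter.should_report(step):
    reporter.report_duration(
        iteration_step=step,
        event_name="bwd_wait_for_prefetch",
        duration_ms=1.25,
        tbe_id="tbe_0",
    )
    # Logs: [Batch #200][TBE:tbe_0][Table:] The event bwd_wait_for_prefetch took 1.25 ms
```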


T = TypeVar("T")


class AsyncSeriesTimerRecordedContext:
"""
An easier way to use AsyncSeriesTimer. Example:
```
timer : AsyncSeriesTimer = ...
with timer.recording(ctx):
cuda_kernel1()
cuda_kernel2()
cuda_kernel3()
```
"""

def __init__(
self,
timer: "AsyncSeriesTimer",
context: T,
stream: Optional[torch.cuda.Stream] = None,
) -> None:
self._context = context
self._stream = stream
self._timer = timer

def __enter__(self) -> None:
self._timer.start(self._stream)

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self._timer.stop(self._context, self._stream)


class AsyncSeriesTimer:
"""
A wrapper class on top of torch.cuda.Event that measures the time between a
series of CUDA events. Once initiated, every start() and stop() call pair
measures the GPU time elapsed between them. The caller cannot initiate
another recording while one is already ongoing.

Reporting is asynchronous because the timing result is not ready immediately
at stop(). Instead, we report lazily: all unreported events are checked at
every start() or stop() call.
"""

def __init__(self, report_functor: Callable[[T, float], None]) -> None:
self._events_queue: Deque[Tuple[torch.cuda.Event, torch.cuda.Event, T]] = (
deque()
)
self._active_start_event: Optional[torch.cuda.Event] = None
self._report_functor = report_functor

def start(self, stream: Optional[torch.cuda.Stream] = None) -> None:
assert self._active_start_event is None, "There's an active recording"
self._active_start_event = torch.cuda.Event(enable_timing=True)
self._active_start_event.record(stream)
self._lazy_report()

def stop(self, context: T, stream: Optional[torch.cuda.Stream] = None) -> None:
assert self._active_start_event is not None, "There's no active recording"
active_start_event: torch.cuda.Event = self._active_start_event

active_stop_event = torch.cuda.Event(enable_timing=True)
active_stop_event.record(stream)
self._events_queue.append((active_start_event, active_stop_event, context))
self._active_start_event = None
self._lazy_report()

def recording(
self, context: T, stream: Optional[torch.cuda.Stream] = None
) -> AsyncSeriesTimerRecordedContext:
return AsyncSeriesTimerRecordedContext(self, context, stream)

def _lazy_report(self) -> None:
# Since this is a series of timing events, the earliest recorded event
# finishes first. So we only need to check the leftmost stop event to
# decide whether we can report now.

while len(self._events_queue):
stop_event = self._events_queue[0][1]
if not stop_event.query():
# Even the earliest stop event has not completed on the GPU yet, so
# there is nothing to report.
return
start_event, stop_event, context = self._events_queue.popleft()
assert (
start_event.query()
), "Recording has start event later than stop event"
result = float(start_event.elapsed_time(stop_event))
self._report_functor(context, result)
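
A small sketch of driving AsyncSeriesTimer directly, assuming a CUDA device is available; the callback uses an `int` step as the generic context `T`, and the matmul workload is arbitrary and illustrative.

```python
import torch

from fbgemm_gpu.runtime_monitor import AsyncSeriesTimer

def log_span(step: int, duration_ms: float) -> None:
    # Invoked lazily once the corresponding stop event has completed on the GPU.
    print(f"step {step}: measured {duration_ms:.3f} ms")

timer = AsyncSeriesTimer(log_span)

x = torch.randn(1024, 1024, device="cuda")
for step in range(5):
    with timer.recording(context=step):
        _ = x @ x  # GPU work bracketed by the start/stop CUDA events
# Results surface at later start()/stop() calls, so the last recording may
# only be reported the next time the timer is used.
```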
fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -9,20 +9,25 @@

# pyre-ignore-all-errors[56]

import contextlib
import enum
import functools
import logging
import os
from dataclasses import dataclass, field
from itertools import accumulate
from math import log2
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import torch # usort:skip
from torch import nn, Tensor # usort:skip

import fbgemm_gpu.split_embedding_codegen_lookup_invokers as invokers
from fbgemm_gpu.runtime_monitor import TBEStatsReporter, TBEStatsReporterConfig
from fbgemm_gpu.runtime_monitor import (
AsyncSeriesTimer,
TBEStatsReporter,
TBEStatsReporterConfig,
)
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
BoundsCheckMode,
@@ -453,6 +458,20 @@ def __init__( # noqa C901
stats_reporter_config.create_reporter() if stats_reporter_config else None
)

self.bwd_wait_prefetch_timer: Optional[AsyncSeriesTimer] = None
if self.stats_reporter:
# When stats_reporter is present, we set up an AsyncSeriesTimer to
# measure the GPU time for each tracked event. Each timer is attached
# to a custom callback that reports the collected duration under the
# corresponding event name.
self.bwd_wait_prefetch_timer = AsyncSeriesTimer(
functools.partial(
SplitTableBatchedEmbeddingBagsCodegen._report_wait_prefetch_time,
self,
event_name="bwd_wait_for_prefetch",
)
)

self.int8_emb_row_dim_offset: int = INT8_EMB_ROW_DIM_OFFSET

self.feature_table_map: List[int] = (
@@ -889,6 +908,18 @@ def get_table_wise_cache_miss(self) -> Tensor:

return self.table_wise_cache_miss

# The callback function for AsyncSeriesTimer to report a recorded duration under the given event name
def _report_wait_prefetch_time(
self,
it_step: int,
dur_ms: float,
event_name: str,
) -> None:
assert (
self.stats_reporter
), "We should not be here. AsyncTimer only happens with reporter present."
self.stats_reporter.report_duration(it_step, event_name, dur_ms)

def forward( # noqa: C901
self,
indices: Tensor,
@@ -1815,6 +1846,20 @@ def _apply_cache_state(
f"or {CacheAlgorithm.LFU}"
)

# pyre-ignore
def _recording_to_timer(
self, timer: Optional[AsyncSeriesTimer], **kwargs: Any
) -> Any:
if self.stats_reporter is not None and self.stats_reporter.should_report(
self.step
):
assert (
timer
), "We shouldn't be here, async timer must have been initiated if reporter is present."
return timer.recording(**kwargs)
# No-Op context manager
return contextlib.nullcontext()

def _sync_stream_post_backward(
self,
grad: Tensor,
@@ -1875,7 +1920,12 @@ def _update_cache_counter_and_locations(
if self.prefetch_stream is not None:
# Need to wait for the prefetch of the next batch so that the
# cache states are valid.
torch.cuda.current_stream().wait_stream(self.prefetch_stream)
with self._recording_to_timer(
self.bwd_wait_prefetch_timer,
context=self.step,
stream=torch.cuda.current_stream(),
):
torch.cuda.current_stream().wait_stream(self.prefetch_stream)

torch.ops.fbgemm.lxu_cache_locking_counter_decrement(
self.lxu_cache_locking_counter,
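
For context, a hedged end-to-end sketch of how the new config could be threaded into the TBE module. The `stats_reporter_config` keyword is inferred from the constructor change above; the embedding spec values, the import paths for EmbeddingLocation and ComputeDevice, and the choice of a MANAGED_CACHING table are illustrative assumptions, not taken from this PR.

```python
from fbgemm_gpu.runtime_monitor import StdLogStatsReporterConfig
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# One cached table (1M rows x 128 dims) on GPU; all values are placeholders.
tbe = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[
        (1_000_000, 128, EmbeddingLocation.MANAGED_CACHING, ComputeDevice.CUDA)
    ],
    # Assumed keyword, based on the stats_reporter_config usage in __init__ above.
    stats_reporter_config=StdLogStatsReporterConfig(interval=100),
)
# With a cached table, prefetch() followed by forward/backward exercises the
# bwd_wait_for_prefetch timing path added in this PR.
```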