Skip to content

Commit

Permalink
StdLog reporter so we can visualize reported TBE stats from stderr
Browse files Browse the repository at this point in the history
Summary: As titled. This is the most basic implementation, so that we can check reported TBE stats via standard logging.

Differential Revision: D52861475
  • Loading branch information
levythu authored and facebook-github-bot committed Mar 5, 2024
1 parent fa8fc05 commit 5b9f639
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 2 deletions.
30 changes: 30 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/runtime_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# pyre-strict

import abc
import logging
from collections import deque
from dataclasses import dataclass
from types import TracebackType
Expand Down Expand Up @@ -47,6 +48,27 @@ def report_duration(
...


class StdLogStatsReporter(TBEStatsReporter):
    """
    TBE stats reporter that emits each reported duration through the standard
    `logging` module at INFO level.

    A report is considered due on every iteration step that is an exact
    multiple of `report_interval`.
    """

    def __init__(self, report_interval: int) -> None:
        """
        Args:
            report_interval: Number of iteration steps between two reports.
                Must be strictly positive.

        Raises:
            AssertionError: If `report_interval` is not positive (only while
                assertions are enabled).
        """
        assert report_interval > 0, "Report interval must be positive"
        self.report_interval = report_interval

    def should_report(self, iteration_step: int) -> bool:
        """Return True when `iteration_step` is a multiple of the interval."""
        return iteration_step % self.report_interval == 0

    def report_duration(
        self,
        iteration_step: int,
        event_name: str,
        duration_ms: float,
        embedding_id: str = "",
        tbe_id: str = "",
    ) -> None:
        """
        Log one duration measurement as a single INFO line.

        Args:
            iteration_step: Iteration (batch) number the measurement belongs to.
            event_name: Name of the timed event.
            duration_ms: Measured duration, in milliseconds.
            embedding_id: Table identifier printed in the message; may be empty.
            tbe_id: TBE identifier printed in the message; may be empty.
        """
        # Pass the values as lazy %-style arguments instead of pre-building the
        # string: the message is only formatted if an INFO handler actually
        # emits the record, and the rendered text is unchanged.
        logging.info(
            "[Batch #%s][TBE:%s][Table:%s] The event %s took %s ms",
            iteration_step,
            tbe_id,
            embedding_id,
            event_name,
            duration_ms,
        )


@dataclass
class TBEStatsReporterConfig:
"""
Expand All @@ -66,6 +88,14 @@ def create_reporter(self) -> Optional[TBEStatsReporter]:
return None


@dataclass
class StdLogStatsReporterConfig(TBEStatsReporterConfig):
    """
    Reporter config whose `create_reporter` yields a `StdLogStatsReporter`.
    """

    def create_reporter(self) -> Optional[TBEStatsReporter]:
        """
        Build the reporter described by this config.

        Returns:
            None when `self.interval` is zero or negative; otherwise a
            `StdLogStatsReporter` whose report interval is `self.interval`.
        """
        return (
            None
            if self.interval <= 0
            else StdLogStatsReporter(report_interval=self.interval)
        )


T = TypeVar("T")


Expand Down
32 changes: 30 additions & 2 deletions fbgemm_gpu/test/runtime_monitor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,16 @@


import unittest
from typing import List, Tuple
from typing import cast, List, Tuple

import fbgemm_gpu
import torch
from fbgemm_gpu.runtime_monitor import AsyncSeriesTimer
from fbgemm_gpu.runtime_monitor import (
AsyncSeriesTimer,
StdLogStatsReporter,
StdLogStatsReporterConfig,
TBEStatsReporterConfig,
)

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
Expand Down Expand Up @@ -116,3 +121,26 @@ def test_async_series_timer_recording_multi_stream_load(self) -> None:
self.assertGreaterEqual(timer.outputs[0][1], 0)
self.assertGreaterEqual(timer.outputs[1][1], 0)
self.assertGreaterEqual(timer.outputs[2][1], 0)

def test_noop_reporter(self) -> None:
    """Check the base config: no reporter by default, error when enabled."""
    # With no arguments the interval is non-positive, so the base config is
    # allowed to hand back "no reporter".
    default_config = TBEStatsReporterConfig()
    self.assertIsNone(default_config.create_reporter())

    # A positive interval asks for a real reporter, which the base config
    # has no implementation for, so creation is expected to fail.
    enabled_config = TBEStatsReporterConfig(interval=100)
    self.assertRaises(AssertionError, enabled_config.create_reporter)

def test_stdlog_reporter(self) -> None:
    """Check StdLogStatsReporterConfig and the reporter it creates."""
    # A zero interval disables reporting: no reporter is created.
    self.assertIsNone(StdLogStatsReporterConfig(interval=0).create_reporter())

    # A positive interval yields a StdLogStatsReporter.
    created = StdLogStatsReporterConfig(interval=100).create_reporter()
    self.assertIsInstance(created, StdLogStatsReporter)
    std_reporter = cast(StdLogStatsReporter, created)

    # Only multiples of the interval are reporting steps.
    self.assertTrue(std_reporter.should_report(500))
    self.assertFalse(std_reporter.should_report(101))

    # Smoke check: reporting with the default ids must not raise.
    std_reporter.report_duration(
        iteration_step=500,
        event_name="test_event",
        duration_ms=404,
    )

0 comments on commit 5b9f639

Please sign in to comment.