From 50ab9bdcede91c721327e6d5313809910d2ebd48 Mon Sep 17 00:00:00 2001 From: Levy Zhao Date: Tue, 5 Mar 2024 21:04:10 -0800 Subject: [PATCH] StdLog reporter so we can visualize reported TBE stats from stderr (#2386) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2386 As titled. This is the most basic implementation so we can check reported TBE stats from standard logging Reviewed By: sryap Differential Revision: D52861475 fbshipit-source-id: befe8e368c3d6ba8a97dbe04fde0ce5a89c946dc --- fbgemm_gpu/fbgemm_gpu/runtime_monitor.py | 30 ++++++++++++++++++ fbgemm_gpu/test/runtime_monitor_test.py | 39 ++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py b/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py index 1668f7ca4d..890967c458 100644 --- a/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py +++ b/fbgemm_gpu/fbgemm_gpu/runtime_monitor.py @@ -8,6 +8,7 @@ # pyre-strict import abc +import logging from collections import deque from dataclasses import dataclass from types import TracebackType @@ -47,6 +48,27 @@ def report_duration( ... 
+class StdLogStatsReporter(TBEStatsReporter): + def __init__(self, report_interval: int) -> None: + assert report_interval > 0, "Report interval must be positive" + self.report_interval = report_interval + + def should_report(self, iteration_step: int) -> bool: + return iteration_step % self.report_interval == 0 + + def report_duration( + self, + iteration_step: int, + event_name: str, + duration_ms: float, + embedding_id: str = "", + tbe_id: str = "", + ) -> None: + logging.info( + f"[Batch #{iteration_step}][TBE:{tbe_id}][Table:{embedding_id}] The event {event_name} took {duration_ms} ms" + ) + + @dataclass class TBEStatsReporterConfig: """ @@ -66,6 +88,14 @@ def create_reporter(self) -> Optional[TBEStatsReporter]: return None +@dataclass +class StdLogStatsReporterConfig(TBEStatsReporterConfig): + def create_reporter(self) -> Optional[TBEStatsReporter]: + if self.interval <= 0: + return None + return StdLogStatsReporter(report_interval=self.interval) + + T = TypeVar("T") diff --git a/fbgemm_gpu/test/runtime_monitor_test.py b/fbgemm_gpu/test/runtime_monitor_test.py index a39cb8bffc..bc73b09b0c 100644 --- a/fbgemm_gpu/test/runtime_monitor_test.py +++ b/fbgemm_gpu/test/runtime_monitor_test.py @@ -6,11 +6,16 @@ import unittest -from typing import List, Tuple +from typing import cast, List, Tuple import fbgemm_gpu import torch -from fbgemm_gpu.runtime_monitor import AsyncSeriesTimer +from fbgemm_gpu.runtime_monitor import ( + AsyncSeriesTimer, + StdLogStatsReporter, + StdLogStatsReporterConfig, + TBEStatsReporterConfig, +) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
open_source: bool = getattr(fbgemm_gpu, "open_source", False) @@ -133,3 +138,33 @@ def test_async_series_timer_recording_multi_stream_load(self) -> None: self.assertGreaterEqual(timer.outputs[0][1], 0) self.assertGreaterEqual(timer.outputs[1][1], 0) self.assertGreaterEqual(timer.outputs[2][1], 0) + + def test_noop_reporter(self) -> None: + """ + Test the base reporter config can only create noop-reporter (None) or + fail if interval is positive (which requires a real reporting) + """ + # This config can create a None reporter, because interval is non-positive + config = TBEStatsReporterConfig() + self.assertIsNone(config.create_reporter()) + + # This config cannot, because it provides a positive interval without + # giving any actual implementation + config = TBEStatsReporterConfig(interval=100) + with self.assertRaises(AssertionError): + config.create_reporter() + + def test_stdlog_reporter(self) -> None: + """ + Test std log reporter be created with different config and log as expected + """ + config = StdLogStatsReporterConfig(interval=0) + self.assertIsNone(config.create_reporter()) + + config = StdLogStatsReporterConfig(interval=100) + reporter = config.create_reporter() + self.assertIsInstance(reporter, StdLogStatsReporter) + r = cast(StdLogStatsReporter, reporter) + self.assertTrue(r.should_report(500)) + self.assertFalse(r.should_report(101)) + r.report_duration(iteration_step=500, event_name="test_event", duration_ms=404)