From f058e4f60204ac2eedce842e07cd8ced75856d2b Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Mon, 20 Sep 2021 13:04:27 +0100 Subject: [PATCH] Add type annotations to `synapse.metrics` --- changelog.d/10847.misc | 1 + mypy.ini | 3 + synapse/metrics/__init__.py | 93 +++++++++++------ synapse/metrics/_exposition.py | 26 ++--- synapse/metrics/background_process_metrics.py | 99 ++++++++++++++++--- synapse/metrics/jemalloc.py | 10 +- 6 files changed, 172 insertions(+), 60 deletions(-) create mode 100644 changelog.d/10847.misc diff --git a/changelog.d/10847.misc b/changelog.d/10847.misc new file mode 100644 index 000000000000..7933a38dca80 --- /dev/null +++ b/changelog.d/10847.misc @@ -0,0 +1 @@ +Add type annotations to `synapse.metrics`. diff --git a/mypy.ini b/mypy.ini index b21e1555ab7f..45cb1984e5c5 100644 --- a/mypy.ini +++ b/mypy.ini @@ -91,6 +91,9 @@ files = tests/util/test_itertools.py, tests/util/test_stream_change_cache.py +[mypy-synapse.metrics.*] +disallow_untyped_defs = True + [mypy-synapse.rest.*] disallow_untyped_defs = True diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index f237b8a2369e..a4e0b1678102 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -20,10 +20,22 @@ import platform import threading import time -from typing import Callable, Dict, Iterable, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, + cast, +) import attr -from prometheus_client import Counter, Gauge, Histogram +from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric from prometheus_client.core import ( REGISTRY, CounterMetricFamily, @@ -32,6 +44,7 @@ ) from twisted.internet import reactor +from twisted.internet.base import ReactorBase import synapse from synapse.metrics._exposition import ( @@ -53,7 +66,7 @@ class RegistryProxy: @staticmethod - def collect(): + def collect() -> Iterable[Metric]: for metric in REGISTRY.collect(): if not metric.name.startswith("__"): yield metric @@ -69,7 +82,7 @@ class LaterGauge: # or dict mapping from a label tuple to a value caller = attr.ib(type=Callable[[], Union[Dict[Tuple[str, ...], float], float]]) - def collect(self): + def collect(self) -> Iterable[Metric]: g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) @@ -88,10 +101,10 @@ def collect(self): yield g - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: self._register() - def _register(self): + def _register(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) REGISTRY.unregister(all_gauges.pop(self.name)) @@ -100,6 +113,10 @@ def _register(self): all_gauges[self.name] = self +# Placeholder for `InFlightGauge._metrics_class`, which is created dynamically +_MetricsEntry = Any + + class InFlightGauge: """Tracks number of things (e.g. requests, Measure blocks, etc) in flight at any given time. @@ -110,14 +127,19 @@ class InFlightGauge: callbacks. Args: - name (str) - desc (str) - labels (list[str]) - sub_metrics (list[str]): A list of sub metrics that the callbacks - will update. + name + desc + labels + sub_metrics: A list of sub metrics that the callbacks will update. """ - def __init__(self, name, desc, labels, sub_metrics): + def __init__( + self, + name: str, + desc: str, + labels: Sequence[str], + sub_metrics: Sequence[str], + ): self.name = name self.desc = desc self.labels = labels @@ -130,14 +152,20 @@ def __init__(self, name, desc, labels, sub_metrics): ) # Counts number of in flight blocks for a given set of label values - self._registrations: Dict = {} + self._registrations: Dict[ + Tuple[str, ...], Set[Callable[[_MetricsEntry], None]] + ] = {} # Protects access to _registrations self._lock = threading.Lock() self._register_with_collector() - def register(self, key, callback): + def register( + self, + key: Tuple[str, ...], + callback: Callable[[_MetricsEntry], None], + ) -> None: """Registers that we've entered a new block with labels `key`. `callback` gets called each time the metrics are collected. The same @@ -153,13 +181,17 @@ def register(self, key, callback): with self._lock: self._registrations.setdefault(key, set()).add(callback) - def unregister(self, key, callback): + def unregister( + self, + key: Tuple[str, ...], + callback: Callable[[_MetricsEntry], None], + ) -> None: """Registers that we've exited a block with labels `key`.""" with self._lock: self._registrations.setdefault(key, set()).discard(callback) - def collect(self): + def collect(self) -> Iterable[Metric]: """Called by prometheus client when it reads metrics. Note: may be called by a separate thread. @@ -195,7 +227,7 @@ def collect(self): gauge.add_metric(key, getattr(metrics, name)) yield gauge - def _register_with_collector(self): + def _register_with_collector(self) -> None: if self.name in all_gauges.keys(): logger.warning("%s already registered, reregistering" % (self.name,)) REGISTRY.unregister(all_gauges.pop(self.name)) @@ -225,7 +257,7 @@ def __init__( name: str, documentation: str, buckets: Iterable[float], - registry=REGISTRY, + registry: CollectorRegistry = REGISTRY, ): """ Args: @@ -252,12 +284,12 @@ def __init__( registry.register(self) - def collect(self): + def collect(self) -> Iterable[Metric]: # Don't report metrics unless we've already collected some data if self._metric is not None: yield self._metric - def update_data(self, values: Iterable[float]): + def update_data(self, values: Iterable[float]) -> None: """Update the data to be reported by the metric The existing data is cleared, and each measurement in the input is assigned @@ -299,7 +331,7 @@ def _values_to_metric(self, values: Iterable[float]) -> GaugeHistogramMetricFami class CPUMetrics: - def __init__(self): + def __init__(self) -> None: ticks_per_sec = 100 try: # Try and get the system config @@ -309,7 +341,7 @@ def __init__(self): self.ticks_per_sec = ticks_per_sec - def collect(self): + def collect(self) -> Iterable[Metric]: if not HAVE_PROC_SELF_STAT: return @@ -359,7 +391,7 @@ def collect(self): class GCCounts: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) for n, m in enumerate(gc.get_count()): cm.add_metric([str(n)], m) @@ -377,7 +409,7 @@ def collect(self): class PyPyGCStats: - def collect(self): + def collect(self) -> Iterable[Metric]: # @stats is a pretty-printer object with __str__() returning a nice table, # plus some fields that contain data from that table. @@ -524,7 +556,7 @@ def collect(self): class ReactorLastSeenMetric: - def collect(self): + def collect(self) -> Iterable[Metric]: cm = GaugeMetricFamily( "python_twisted_reactor_last_seen", "Seconds since the Twisted reactor was last seen", @@ -543,9 +575,12 @@ def collect(self): _last_gc = [0.0, 0.0, 0.0] -def runUntilCurrentTimer(reactor, func): +F = TypeVar("F", bound=Callable[..., Any]) + + +def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F: @functools.wraps(func) - def f(*args, **kwargs): + def f(*args: Any, **kwargs: Any) -> Any: now = reactor.seconds() num_pending = 0 @@ -608,7 +643,7 @@ def f(*args, **kwargs): return ret - return f + return cast(F, f) try: @@ -636,5 +671,5 @@ def f(*args, **kwargs): "start_http_server", "LaterGauge", "InFlightGauge", - "BucketCollector", + "GaugeBucketCollector", ] diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index bb9bcb5592ed..5f83f99684e9 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -25,12 +25,14 @@ import threading from http.server import BaseHTTPRequestHandler, HTTPServer from socketserver import ThreadingMixIn -from typing import Dict, List +from typing import Any, Dict, List, Type from urllib.parse import parse_qs, urlparse -from prometheus_client import REGISTRY +from prometheus_client import REGISTRY, CollectorRegistry +from prometheus_client.core import Sample from twisted.web.resource import Resource +from twisted.web.server import Request from synapse.util import caches @@ -41,7 +43,7 @@ MINUS_INF = float("-inf") -def floatToGoString(d): +def floatToGoString(d: Any) -> str: d = float(d) if d == INF: return "+Inf" @@ -60,7 +62,7 @@ def floatToGoString(d): return s -def sample_line(line, name): +def sample_line(line: Sample, name: str) -> str: if line.labels: labelstr = "{{{0}}}".format( ",".join( @@ -82,7 +84,7 @@ def sample_line(line, name): return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp) -def generate_latest(registry, emit_help=False): +def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes: # Trigger the cache metrics to be rescraped, which updates the common # metrics but do not produce metrics themselves @@ -187,7 +189,7 @@ class MetricsHandler(BaseHTTPRequestHandler): registry = REGISTRY - def do_GET(self): + def do_GET(self) -> None: registry = self.registry params = parse_qs(urlparse(self.path).query) @@ -207,11 +209,11 @@ def do_GET(self): self.end_headers() self.wfile.write(output) - def log_message(self, format, *args): + def log_message(self, format: str, *args: Any) -> None: """Log nothing.""" @classmethod - def factory(cls, registry): + def factory(cls, registry: CollectorRegistry) -> Type: """Returns a dynamic MetricsHandler class tied to the passed registry. """ @@ -236,7 +238,9 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer): daemon_threads = True -def start_http_server(port, addr="", registry=REGISTRY): +def start_http_server( + port: int, addr: str = "", registry: CollectorRegistry = REGISTRY +) -> None: """Starts an HTTP server for prometheus metrics as a daemon thread""" CustomMetricsHandler = MetricsHandler.factory(registry) httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler) @@ -252,10 +256,10 @@ class MetricsResource(Resource): isLeaf = True - def __init__(self, registry=REGISTRY): + def __init__(self, registry: CollectorRegistry = REGISTRY): self.registry = registry - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii")) response = generate_latest(self.registry) request.setHeader(b"Content-Length", str(len(response))) diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 3a14260752ed..c2ca89b25ee1 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -15,13 +15,33 @@ import logging import threading from functools import wraps -from typing import TYPE_CHECKING, Dict, Optional, Set, Union +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + Optional, + Set, + Type, + TypeVar, + Union, + cast, + overload, +) +from prometheus_client import Metric from prometheus_client.core import REGISTRY, Counter, Gauge from twisted.internet import defer -from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.logging.context import ( + ContextResourceUsage, + LoggingContext, + PreserveLoggingContext, +) from synapse.logging.opentracing import ( SynapseTags, noop_context_manager, @@ -116,7 +136,7 @@ class _Collector: before they are returned. """ - def collect(self): + def collect(self) -> Iterable[Metric]: global _background_processes_active_since_last_scrape # We swap out the _background_processes set with an empty one so that @@ -144,12 +164,12 @@ def collect(self): class _BackgroundProcess: - def __init__(self, desc, ctx): + def __init__(self, desc: str, ctx: LoggingContext): self.desc = desc self._context = ctx - self._reported_stats = None + self._reported_stats: Optional[ContextResourceUsage] = None - def update_metrics(self): + def update_metrics(self) -> None: """Updates the metrics with values from this process.""" new_stats = self._context.get_resource_usage() if self._reported_stats is None: @@ -169,7 +189,41 @@ def update_metrics(self): ) -def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs): +R = TypeVar("R") + + +@overload +def run_as_background_process( # type: ignore[misc] + desc: str, + func: Callable[..., Awaitable[R]], + *args: Any, + bg_start_span: bool = True, + **kwargs: Any, +) -> defer.Deferred[Optional[R]]: + ... + + +@overload +def run_as_background_process( + desc: str, + func: Callable[..., R], + *args: Any, + bg_start_span: bool = True, + **kwargs: Any, +) -> defer.Deferred[Optional[R]]: + ... + + +def run_as_background_process( + desc: str, + func: Union[ + Callable[..., Awaitable[R]], + Callable[..., R], + ], + *args: Any, + bg_start_span: bool = True, + **kwargs: Any, +) -> defer.Deferred[Optional[R]]: """Run the given function in its own logcontext, with resource metrics This should be used to wrap processes which are fired off to run in the @@ -189,11 +243,12 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar args: positional args for func kwargs: keyword args for func - Returns: Deferred which returns the result of func, but note that it does not - follow the synapse logcontext rules. + Returns: Deferred which returns the result of func, or `None` if func raises. + Note that the returned Deferred does not follow the synapse logcontext + rules. """ - async def run(): + async def run() -> Optional[R]: with _bg_metrics_lock: count = _background_process_counts.get(desc, 0) _background_process_counts[desc] = count + 1 @@ -216,6 +271,7 @@ async def run(): "Background process '%s' threw an exception", desc, ) + return None finally: _background_process_in_flight_count.labels(desc).dec() @@ -225,19 +281,25 @@ async def run(): return defer.ensureDeferred(run()) -def wrap_as_background_process(desc): +F = TypeVar("F", bound=Callable[..., Any]) + + +def wrap_as_background_process(desc: str) -> Callable[[F], F]: """Decorator that wraps a function that gets called as a background process. Equivalent of calling the function with `run_as_background_process` """ - def wrap_as_background_process_inner(func): + # NB: Return type is incorrect and should be F with a Deferred[Optional[R]] return + def wrap_as_background_process_inner(func: F) -> F: @wraps(func) - def wrap_as_background_process_inner_2(*args, **kwargs): + def wrap_as_background_process_inner_2( + *args: Any, **kwargs: Any + ) -> defer.Deferred[Optional[R]]: return run_as_background_process(desc, func, *args, **kwargs) - return wrap_as_background_process_inner_2 + return cast(F, wrap_as_background_process_inner_2) return wrap_as_background_process_inner @@ -265,7 +327,7 @@ def __init__(self, name: str, instance_id: Optional[Union[int, str]] = None): super().__init__("%s-%s" % (name, instance_id)) self._proc = _BackgroundProcess(name, self) - def start(self, rusage: "Optional[resource._RUsage]"): + def start(self, rusage: "Optional[resource._RUsage]") -> None: """Log context has started running (again).""" super().start(rusage) @@ -276,7 +338,12 @@ def start(self, rusage: "Optional[resource._RUsage]"): with _bg_metrics_lock: _background_processes_active_since_last_scrape.add(self._proc) - def __exit__(self, type, value, traceback) -> None: + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: """Log context has finished.""" super().__exit__(type, value, traceback) diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py index 29ab6c0229df..98ed9c0829e8 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py @@ -16,14 +16,16 @@ import logging import os import re -from typing import Optional +from typing import Iterable, Optional + +from prometheus_client import Metric from synapse.metrics import REGISTRY, GaugeMetricFamily logger = logging.getLogger(__name__) -def _setup_jemalloc_stats(): +def _setup_jemalloc_stats() -> None: """Checks to see if jemalloc is loaded, and hooks up a collector to record statistics exposed by jemalloc. """ @@ -135,7 +137,7 @@ def _jemalloc_refresh_stats() -> None: class JemallocCollector: """Metrics for internal jemalloc stats.""" - def collect(self): + def collect(self) -> Iterable[Metric]: _jemalloc_refresh_stats() g = GaugeMetricFamily( @@ -185,7 +187,7 @@ def collect(self): logger.debug("Added jemalloc stats") -def setup_jemalloc_stats(): +def setup_jemalloc_stats() -> None: """Try to setup jemalloc stats, if jemalloc is loaded.""" try: