Skip to content

Commit

Permalink
Add metrics instrumentation for celery (#1679)
Browse files Browse the repository at this point in the history
Co-authored-by: Shalev Roda <65566801+shalevr@users.noreply.github.com>
  • Loading branch information
Akochavi and shalevr authored Jun 18, 2023
1 parent 7804083 commit 1dd17ed
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased


### Added

- Make Flask request span attributes available for `start_span`.
Expand All @@ -16,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Instrument all httpx versions >= 0.18. ([#1748](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1748))
- Fix `Invalid type NoneType for attribute X (opentelemetry-instrumentation-aws-lambda)` error when some attributes do not exist
([#1780](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1780))
- Add metric instrumentation for celery
([#1679](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1679))

## Version 1.18.0/0.39b0 (2023-05-10)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def add(x, y):
"""

import logging
from timeit import default_timer
from typing import Collection, Iterable

from celery import signals # pylint: disable=no-name-in-module
Expand All @@ -69,6 +70,7 @@ def add(x, y):
from opentelemetry.instrumentation.celery.package import _instruments
from opentelemetry.instrumentation.celery.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.metrics import get_meter
from opentelemetry.propagate import extract, inject
from opentelemetry.propagators.textmap import Getter
from opentelemetry.semconv.trace import SpanAttributes
Expand Down Expand Up @@ -104,6 +106,11 @@ def keys(self, carrier):


class CeleryInstrumentor(BaseInstrumentor):
def __init__(self):
    """Initialize instrumentor state used for metric recording.

    Both attributes are populated lazily: the timing map fills as tasks
    run, and ``metrics`` is built by ``create_celery_metrics`` when
    ``_instrument`` is called.
    """
    super().__init__()
    # task_id -> timing bookkeeping (start timestamp, later the duration)
    self.task_id_to_start_time = {}
    # instrument-name -> Histogram; None until instrumentation is enabled
    self.metrics = None

def instrumentation_dependencies(self) -> Collection[str]:
    """Return the package requirement specs this instrumentor supports."""
    return _instruments

Expand All @@ -113,6 +120,11 @@ def _instrument(self, **kwargs):
# pylint: disable=attribute-defined-outside-init
self._tracer = trace.get_tracer(__name__, __version__, tracer_provider)

meter_provider = kwargs.get("meter_provider")
meter = get_meter(__name__, __version__, meter_provider)

self.create_celery_metrics(meter)

signals.task_prerun.connect(self._trace_prerun, weak=False)
signals.task_postrun.connect(self._trace_postrun, weak=False)
signals.before_task_publish.connect(
Expand All @@ -139,6 +151,7 @@ def _trace_prerun(self, *args, **kwargs):
if task is None or task_id is None:
return

self.update_task_duration_time(task_id)
request = task.request
tracectx = extract(request, getter=celery_getter) or None

Expand All @@ -153,8 +166,7 @@ def _trace_prerun(self, *args, **kwargs):
activation.__enter__() # pylint: disable=E1101
utils.attach_span(task, task_id, (span, activation))

@staticmethod
def _trace_postrun(*args, **kwargs):
def _trace_postrun(self, *args, **kwargs):
task = utils.retrieve_task(kwargs)
task_id = utils.retrieve_task_id(kwargs)

Expand All @@ -178,6 +190,9 @@ def _trace_postrun(*args, **kwargs):

activation.__exit__(None, None, None)
utils.detach_span(task, task_id)
self.update_task_duration_time(task_id)
labels = {"task": task.name, "worker": task.request.hostname}
self._record_histograms(task_id, labels)

def _trace_before_publish(self, *args, **kwargs):
task = utils.retrieve_task_from_sender(kwargs)
Expand Down Expand Up @@ -277,3 +292,30 @@ def _trace_retry(*args, **kwargs):
# Use `str(reason)` instead of `reason.message` in case we get
# something that isn't an `Exception`
span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason))

def update_task_duration_time(self, task_id):
    """Two-phase timing bookkeeping for a single task run.

    On the first call for ``task_id`` (task prerun) no entry exists, so
    the current timer value is stored as the task's start timestamp.
    On the second call (task postrun) that start timestamp is replaced
    with the elapsed time since it was stored.

    NOTE(review): after the second call the dict holds a *duration*,
    not a start time, despite the attribute name — ``_record_histograms``
    relies on reading that duration.
    """
    now = default_timer()
    started = self.task_id_to_start_time.get(task_id)
    self.task_id_to_start_time[task_id] = (
        now if started is None else now - started
    )

def _record_histograms(self, task_id, metric_attributes):
if task_id is None:
return

self.metrics["flower.task.runtime.seconds"].record(
self.task_id_to_start_time.get(task_id),
attributes=metric_attributes,
)

def create_celery_metrics(self, meter) -> None:
    """Create the instruments used by this instrumentor.

    The instruments are stored in ``self.metrics`` keyed by instrument
    name so ``_record_histograms`` can look them up at recording time.
    NOTE(review): the unit string "seconds" predates UCUM-style units
    ("s") — kept as-is to preserve exported metadata.
    """
    task_runtime_histogram = meter.create_histogram(
        name="flower.task.runtime.seconds",
        unit="seconds",
        description="The time it took to run the task.",
    )
    self.metrics = {"flower.task.runtime.seconds": task_runtime_histogram}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import threading
import time
from timeit import default_timer

from opentelemetry.instrumentation.celery import CeleryInstrumentor
from opentelemetry.test.test_base import TestBase

from .celery_test_tasks import app, task_add


class TestMetrics(TestBase):
    """Integration tests for the Celery metrics instrumentation.

    A real Celery worker runs in a daemon thread with the ``solo`` pool so
    that tasks execute in this process, where the in-memory metric reader
    configured by ``TestBase`` can observe the recorded histogram points.
    """

    def setUp(self):
        super().setUp()
        # "solo" pool keeps task execution in-process; fixed hostname makes
        # the "worker" metric attribute deterministic for assertions.
        self._worker = app.Worker(
            app=app, pool="solo", concurrency=1, hostname="celery@akochavi"
        )
        self._thread = threading.Thread(target=self._worker.start)
        self._thread.daemon = True
        self._thread.start()

    def tearDown(self):
        super().tearDown()
        self._worker.stop()
        self._thread.join()

    def get_metrics(self):
        """Run one task to completion (60 s cap) and return sorted metrics."""
        result = task_add.delay(1, 2)

        timeout = time.time() + 60 * 1  # 1 minute from now
        while not result.ready():
            if time.time() > timeout:
                break
            time.sleep(0.05)
        return self.get_sorted_metrics()

    def test_basic_metric(self):
        """A completed task produces exactly one runtime histogram point."""
        CeleryInstrumentor().instrument()
        start_time = default_timer()
        task_runtime_estimated = (default_timer() - start_time) * 1000

        metrics = self.get_metrics()
        CeleryInstrumentor().uninstrument()
        self.assertEqual(len(metrics), 1)

        task_runtime = metrics[0]
        self.assertEqual(task_runtime.name, "flower.task.runtime.seconds")
        self.assert_metric_expected(
            task_runtime,
            [
                self.create_histogram_data_point(
                    count=1,
                    sum_data_point=task_runtime_estimated,
                    max_data_point=task_runtime_estimated,
                    min_data_point=task_runtime_estimated,
                    attributes={
                        "task": "tests.celery_test_tasks.task_add",
                        "worker": "celery@akochavi",
                    },
                )
            ],
            est_value_delta=200,
        )

    def test_metric_uninstrument(self):
        """After uninstrument, running another task records no new points."""
        CeleryInstrumentor().instrument()
        metrics = self.get_metrics()
        self.assertEqual(len(metrics), 1)
        CeleryInstrumentor().uninstrument()

        metrics = self.get_metrics()
        self.assertEqual(len(metrics), 1)

        # The count must still be 1: only the pre-uninstrument task counted.
        for metric in metrics:
            for point in list(metric.data.data_points):
                self.assertEqual(point.count, 1)

0 comments on commit 1dd17ed

Please sign in to comment.