Skip to content

Commit

Permalink
Move cron utils to their own package
Browse files Browse the repository at this point in the history
  • Loading branch information
rowanseymour committed Dec 11, 2024
1 parent fd02ba6 commit ee2de6c
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 115 deletions.
1 change: 1 addition & 0 deletions temba/utils/crons/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .decorator import cron_task # noqa
45 changes: 4 additions & 41 deletions temba/utils/crons.py → temba/utils/crons/decorator.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,16 @@
import logging
from functools import wraps

from celery import shared_task
from django_redis import get_redis_connection

from django.utils import timezone

from . import analytics, json
from .signals import post_cron_exec

logger = logging.getLogger(__name__)

# for tasks using a redis lock to prevent overlapping this is the default timeout for the lock
DEFAULT_TASK_LOCK_TIMEOUT = 900

STATS_EXPIRES = 60 * 60 * 48 # 2 days
STATS_KEY_BASE = "cron_stats"
STATS_LAST_START_KEY = f"{STATS_KEY_BASE}:last_start"
STATS_LAST_TIME_KEY = f"{STATS_KEY_BASE}:last_time"
STATS_LAST_RESULT_KEY = f"{STATS_KEY_BASE}:last_result"
STATS_CALL_COUNT_KEY = f"{STATS_KEY_BASE}:call_count"
STATS_TOTAL_TIME_KEY = f"{STATS_KEY_BASE}:total_time"
STATS_KEYS = (
STATS_LAST_START_KEY,
STATS_LAST_TIME_KEY,
STATS_LAST_RESULT_KEY,
STATS_CALL_COUNT_KEY,
STATS_TOTAL_TIME_KEY,
)


def cron_task(*task_args, **task_kwargs):
"""
Expand Down Expand Up @@ -58,32 +41,12 @@ def wrapper(*exec_args, **exec_kwargs):
with r.lock(lock_key, timeout=lock_timeout):
result = task_func(*exec_args, **exec_kwargs)
finally:
_record_cron_execution(r, task_name, start, end=timezone.now(), result=result)
post_cron_exec.send(
sender=cron_task, task_name=task_name, started=start, ended=timezone.now(), result=result
)

return result

return shared_task(*task_args, **task_kwargs)(wrapper)

return _cron_task


def _record_cron_execution(r, name: str, start, end, result):
pipe = r.pipeline()
pipe.hset(STATS_LAST_START_KEY, name, start.isoformat())
pipe.hset(STATS_LAST_TIME_KEY, name, str((end - start).total_seconds()))
pipe.hset(STATS_LAST_RESULT_KEY, name, json.dumps(result))
pipe.hincrby(STATS_CALL_COUNT_KEY, name, 1)
pipe.hincrbyfloat(STATS_TOTAL_TIME_KEY, name, (end - start).total_seconds())

for key in STATS_KEYS:
pipe.expire(key, STATS_EXPIRES)

pipe.execute()

analytics.gauges({f"temba.cron_{name}": (end - start).total_seconds()})


def clear_cron_stats():
r = get_redis_connection()
for key in STATS_KEYS:
r.delete(key)
3 changes: 3 additions & 0 deletions temba/utils/crons/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from django.dispatch import Signal

post_cron_exec = Signal()
63 changes: 63 additions & 0 deletions temba/utils/crons/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from unittest.mock import patch

from celery.app.task import Task

from temba.tests import TembaTest

from . import cron_task


class CronsTest(TembaTest):
@patch("redis.client.StrictRedis.lock")
@patch("redis.client.StrictRedis.get")
def test_cron_task(self, mock_redis_get, mock_redis_lock):
mock_redis_get.return_value = None
task_calls = []

@cron_task()
def test_task1(foo, bar):
task_calls.append("1-%d-%d" % (foo, bar))
return {"foo": 1}

@cron_task(name="task2", time_limit=100)
def test_task2(foo, bar):
task_calls.append("2-%d-%d" % (foo, bar))
return 1234

@cron_task(name="task3", time_limit=100, lock_timeout=55)
def test_task3(foo, bar):
task_calls.append("3-%d-%d" % (foo, bar))

self.assertIsInstance(test_task1, Task)
self.assertIsInstance(test_task2, Task)
self.assertEqual(test_task2.name, "task2")
self.assertEqual(test_task2.time_limit, 100)
self.assertIsInstance(test_task3, Task)
self.assertEqual(test_task3.name, "task3")
self.assertEqual(test_task3.time_limit, 100)

test_task1(11, 12)
test_task2(21, bar=22)
test_task3(foo=31, bar=32)

mock_redis_get.assert_any_call("celery-task-lock:test_task1")
mock_redis_get.assert_any_call("celery-task-lock:task2")
mock_redis_get.assert_any_call("celery-task-lock:task3")
mock_redis_lock.assert_any_call("celery-task-lock:test_task1", timeout=900)
mock_redis_lock.assert_any_call("celery-task-lock:task2", timeout=100)
mock_redis_lock.assert_any_call("celery-task-lock:task3", timeout=55)

self.assertEqual(task_calls, ["1-11-12", "2-21-22", "3-31-32"])

# simulate task being already running
mock_redis_get.reset_mock()
mock_redis_get.return_value = "xyz"
mock_redis_lock.reset_mock()

# try to run again
test_task1(13, 14)

# check that task is skipped
mock_redis_get.assert_called_once_with("celery-task-lock:test_task1")
self.assertEqual(mock_redis_lock.call_count, 0)
self.assertEqual(task_calls, ["1-11-12", "2-21-22", "3-31-32"])
1 change: 0 additions & 1 deletion temba/utils/dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from . import signals # noqa
from .base import * # noqa
3 changes: 2 additions & 1 deletion temba/utils/management/commands/migrate_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.core.management import BaseCommand

from temba.utils import dynamo
from temba.utils.dynamo.signals import pre_create_table

TABLES = [
{
Expand Down Expand Up @@ -43,7 +44,7 @@ def _migrate_table(self, table: dict):
spec["TableName"] = real_name

# invoke pre-create signal to allow for table modifications
dynamo.signals.pre_create_table.send(self.__class__, spec=spec)
pre_create_table.send(self.__class__, spec=spec)

# ttl isn't actually part of the create call
ttlSpec = spec.pop("TimeToLiveSpecification", None)
Expand Down
72 changes: 0 additions & 72 deletions temba/utils/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,8 @@
from collections import OrderedDict
from datetime import date, timezone as tzone
from decimal import Decimal
from unittest.mock import patch
from zoneinfo import ZoneInfo

from celery.app.task import Task
from django_redis import get_redis_connection

from django import forms
from django.forms import ValidationError
from django.test import TestCase, override_settings
Expand All @@ -21,7 +17,6 @@

from . import countries, format_number, get_nested_key, languages, percentage, redact, set_nested_key, str_to_bool
from .checks import storage
from .crons import clear_cron_stats, cron_task
from .dates import date_range, datetime_to_str, datetime_to_timestamp, timestamp_to_datetime
from .fields import ExternalURLField, NameValidator
from .text import clean_string, generate_secret, generate_token, slugify_with, truncate, unsnakify
Expand Down Expand Up @@ -185,73 +180,6 @@ def test_encode_decode(self):
json.dumps(dict(foo=Exception("invalid")))


class CronsTest(TembaTest):
@patch("redis.client.StrictRedis.lock")
@patch("redis.client.StrictRedis.get")
def test_cron_task(self, mock_redis_get, mock_redis_lock):
clear_cron_stats()

mock_redis_get.return_value = None
task_calls = []

@cron_task()
def test_task1(foo, bar):
task_calls.append("1-%d-%d" % (foo, bar))
return {"foo": 1}

@cron_task(name="task2", time_limit=100)
def test_task2(foo, bar):
task_calls.append("2-%d-%d" % (foo, bar))
return 1234

@cron_task(name="task3", time_limit=100, lock_timeout=55)
def test_task3(foo, bar):
task_calls.append("3-%d-%d" % (foo, bar))

self.assertIsInstance(test_task1, Task)
self.assertIsInstance(test_task2, Task)
self.assertEqual(test_task2.name, "task2")
self.assertEqual(test_task2.time_limit, 100)
self.assertIsInstance(test_task3, Task)
self.assertEqual(test_task3.name, "task3")
self.assertEqual(test_task3.time_limit, 100)

test_task1(11, 12)
test_task2(21, bar=22)
test_task3(foo=31, bar=32)

mock_redis_get.assert_any_call("celery-task-lock:test_task1")
mock_redis_get.assert_any_call("celery-task-lock:task2")
mock_redis_get.assert_any_call("celery-task-lock:task3")
mock_redis_lock.assert_any_call("celery-task-lock:test_task1", timeout=900)
mock_redis_lock.assert_any_call("celery-task-lock:task2", timeout=100)
mock_redis_lock.assert_any_call("celery-task-lock:task3", timeout=55)

self.assertEqual(task_calls, ["1-11-12", "2-21-22", "3-31-32"])

r = get_redis_connection()
self.assertEqual({b"test_task1", b"task2", b"task3"}, set(r.hkeys("cron_stats:last_start")))
self.assertEqual({b"test_task1", b"task2", b"task3"}, set(r.hkeys("cron_stats:last_time")))
self.assertEqual(
{b"test_task1": b'{"foo": 1}', b"task2": b"1234", b"task3": b"null"}, r.hgetall("cron_stats:last_result")
)
self.assertEqual({b"test_task1": b"1", b"task2": b"1", b"task3": b"1"}, r.hgetall("cron_stats:call_count"))
self.assertEqual({b"test_task1", b"task2", b"task3"}, set(r.hkeys("cron_stats:total_time")))

# simulate task being already running
mock_redis_get.reset_mock()
mock_redis_get.return_value = "xyz"
mock_redis_lock.reset_mock()

# try to run again
test_task1(13, 14)

# check that task is skipped
mock_redis_get.assert_called_once_with("celery-task-lock:test_task1")
self.assertEqual(mock_redis_lock.call_count, 0)
self.assertEqual(task_calls, ["1-11-12", "2-21-22", "3-31-32"])


class MiddlewareTest(TembaTest):
def test_org(self):
index_url = reverse("public.public_index")
Expand Down

0 comments on commit ee2de6c

Please sign in to comment.