From 219d5492f9a17598dadf59ecb74f461abfeaa5bb Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Fri, 16 Jun 2023 16:04:50 +0100 Subject: [PATCH 01/23] Add OrderedModel --- engine/apps/base/models/ordered_model.py | 148 +++++++ engine/apps/base/tests/test_ordered_model.py | 381 +++++++++++++++++++ 2 files changed, 529 insertions(+) create mode 100644 engine/apps/base/models/ordered_model.py create mode 100644 engine/apps/base/tests/test_ordered_model.py diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py new file mode 100644 index 0000000000..5bd01c9718 --- /dev/null +++ b/engine/apps/base/models/ordered_model.py @@ -0,0 +1,148 @@ +import logging +import random +import time +from functools import wraps + +from django.db import IntegrityError, OperationalError, connection, models, transaction + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +# TODO: comments +SQL_TO = """ +UPDATE `{db_table}` `t1` +JOIN `{db_table}` `t2` on `t2`.`id` = %(id)s +SET `t1`.`order` = IF(`t1`.`id` = `t2`.`id`, null, IF(`t1`.`order` < `t2`.`order`, `t1`.`order` + 1, `t1`.`order` - 1)) +WHERE {ordering_condition} +AND `t2`.`order` != %(order)s +AND `t1`.`order` >= IF(`t2`.`order` > %(order)s, %(order)s, `t2`.`order`) +AND `t1`.`order` <= IF(`t2`.`order` > %(order)s, `t2`.`order`, %(order)s) +ORDER BY IF(`t1`.`order` <= `t2`.`order`, `t1`.`order`, null) DESC, IF(`t1`.`order` >= `t2`.`order`, `t1`.`order`, null) ASC +""" + +SQL_SWAP = """ +UPDATE `{db_table}` `t1` +JOIN `{db_table}` `t2` on `t2`.`id` = %(id)s +SET `t1`.`order` = IF(`t1`.`id` = `t2`.`id`, null, `t2`.`order`) +WHERE {ordering_condition} +AND `t2`.`order` != %(order)s +AND (`t1`.`id` = `t2`.`id` OR `t1`.`order` = %(order)s) +ORDER BY IF(`t1`.`id` = `t2`.`id`, 0, 1) ASC +""" + + +def _retry(exc, max_attempts=15): + def _retry_with_params(f): + @wraps(f) + def wrapper(*args, **kwargs): + attempts = 0 + while attempts < max_attempts: + try: + return f(*args, **kwargs) + except exc: + logger.debug(f"IntegrityError occurred in {f.__qualname__}. Retrying...") + if attempts == max_attempts - 1: + raise + attempts += 1 + # double the sleep time each time and add some jitter + time.sleep(random.random()) + + return wrapper + + return _retry_with_params + + +class OrderedModel(models.Model): + """ + This class is intended to be used as a mixin for models that need to be ordered. + + Operations: + - create: TODO + - delete: TODO + - move to: TODO + - move to index: TODO + - swap: TODO + - get next: TODO + """ + + order: int = models.PositiveIntegerField(editable=False, db_index=True, null=True) + order_with_respect_to = [] + + class Meta: + abstract = True + ordering = ["order"] + constraints = [ + models.UniqueConstraint(fields=["order"], name="unique_order"), + ] + + def save(self, *args, **kwargs): + if self.order is None: + self._save_no_order_provided() + else: + if self.order < 0: + raise ValueError("Order must be a positive integer.") + super().save() + + @_retry(OperationalError) + def delete(self, using=None, keep_parents=False): + super().delete(using=using, keep_parents=keep_parents) + + @_retry((IntegrityError, OperationalError)) + def _save_no_order_provided(self): + max_order = self._get_ordering_queryset().aggregate(models.Max("order"))["order__max"] + + if max_order is None: + self.order = 0 + else: + self.order = max_order + 1 + + super().save() + + @_retry((IntegrityError, OperationalError)) + def to(self, order): + if order is None or order < 0: + raise ValueError("Order must be a positive integer.") + + sql = SQL_TO.format(db_table=self._meta.db_table, ordering_condition=self._ordering_condition_sql) + params = {"id": self.id, "order": order, **self._ordering_kwargs} + + with transaction.atomic(): + with connection.cursor() as cursor: + cursor.execute(sql, params) + self._meta.model.objects.filter(pk=self.pk).update(order=order) + + self.refresh_from_db() + + def to_index(self, index): + order = self._get_ordering_queryset().values_list("order", flat=True)[index] + self.to(order) + + @_retry((IntegrityError, OperationalError)) + def swap(self, order): + if order is None or order < 0: + raise ValueError("Order must be a positive integer.") + + sql = SQL_SWAP.format(db_table=self._meta.db_table, ordering_condition=self._ordering_condition_sql) + params = {"id": self.id, "order": order, **self._ordering_kwargs} + + with transaction.atomic(): + with connection.cursor() as cursor: + cursor.execute(sql, params) + self._meta.model.objects.filter(pk=self.pk).update(order=order) + + self.refresh_from_db() + + def next(self): + return self._get_ordering_queryset().filter(order__gt=self.order).first() + + @property + def _ordering_kwargs(self): + return {field: getattr(self, field) for field in self.order_with_respect_to} + + def _get_ordering_queryset(self): + return self._meta.model.objects.filter(**self._ordering_kwargs) + + @property + def _ordering_condition_sql(self): + ordering_parts = ["`t1`.`{0}` = %({0})s".format(field) for field in self.order_with_respect_to] + return " AND ".join(ordering_parts) diff --git a/engine/apps/base/tests/test_ordered_model.py b/engine/apps/base/tests/test_ordered_model.py new file mode 100644 index 0000000000..303a3df27a --- /dev/null +++ b/engine/apps/base/tests/test_ordered_model.py @@ -0,0 +1,381 @@ +import random +import threading + +import pytest +from django.db import models + +from apps.base.models.ordered_model import OrderedModel + + +class TestOrderedModel(OrderedModel): + test_field = models.CharField(max_length=255) + extra_field = models.IntegerField(null=True, default=None) + order_with_respect_to = ["test_field"] + + class Meta: + app_label = "base" + ordering = ["order"] + constraints = [ + models.UniqueConstraint(fields=["test_field", "order"], name="unique_test_field_order"), + ] + + +def _get_ids(): + return list(TestOrderedModel.objects.filter(test_field="test").values_list("id", flat=True)) + + +def _get_orders(): + return list(TestOrderedModel.objects.filter(test_field="test").values_list("order", flat=True)) + + +def _orders_are_sequential(): + orders = _get_orders() + return orders == list(range(len(orders))) + + +@pytest.mark.django_db +def test_ordered_model_create(): + first = TestOrderedModel.objects.create(test_field="test") + second = TestOrderedModel.objects.create(test_field="test") + + assert first.order == 0 + assert second.order == 1 + + +@pytest.mark.django_db +def test_ordered_model_delete(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(3)] + + instances[1].delete() + assert instances[1].pk is None + assert _get_ids() == [instances[0].id, instances[2].id] + assert _get_orders() == [0, 2] + + +@pytest.mark.django_db +def test_ordered_model_to(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + + def _ids(indices): + return [instances[i].id for i in indices] + + # move to the end + instances[0].to(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move to the beginning + instances[0].to(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # move to the middle + instances[0].to(2) + assert instances[0].order == 2 + assert _get_ids() == _ids([1, 2, 0, 3, 4]) + assert _orders_are_sequential() + + # move from the middle to the end + instances[0].to(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move from the end to the second position + instances[0].to(1) + assert instances[0].order == 1 + assert _get_ids() == _ids([1, 0, 2, 3, 4]) + assert _orders_are_sequential() + + # move from the second position to the beginning + instances[0].to(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # don't move if the order is the same + for instance in instances: + instance.to(instance.order) + assert instance.order == instance.order + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + +@pytest.mark.django_db +def test_ordered_model_to_index(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + + def _ids(indices): + return [instances[i].id for i in indices] + + # move to the end + instances[0].to_index(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move to the beginning + instances[0].to_index(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # move to the middle + instances[0].to_index(2) + assert instances[0].order == 2 + assert _get_ids() == _ids([1, 2, 0, 3, 4]) + assert _orders_are_sequential() + + # move from the middle to the end + instances[0].to_index(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([1, 2, 3, 4, 0]) + assert _orders_are_sequential() + + # move from the end to the second position + instances[0].to_index(1) + assert instances[0].order == 1 + assert _get_ids() == _ids([1, 0, 2, 3, 4]) + assert _orders_are_sequential() + + # move from the second position to the beginning + instances[0].to_index(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # don't move if the order is the same + for instance in instances: + instance.to_index(instance.order) + assert instance.order == instance.order + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + +@pytest.mark.django_db +def test_ordered_model_swap(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + + def _ids(indices): + return [instances[i].id for i in indices] + + # swap with last + instances[0].swap(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([4, 1, 2, 3, 0]) + assert _orders_are_sequential() + + # swap with first + instances[0].swap(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 1, 2, 3, 4]) + assert _orders_are_sequential() + + # swap with middle + instances[0].swap(2) + assert instances[0].order == 2 + assert _get_ids() == _ids([2, 1, 0, 3, 4]) + assert _orders_are_sequential() + + # swap from the middle to the end + instances[0].swap(4) + assert instances[0].order == 4 + assert _get_ids() == _ids([2, 1, 4, 3, 0]) + assert _orders_are_sequential() + + # swap from the end to the second position + instances[0].swap(1) + assert instances[0].order == 1 + assert _get_ids() == _ids([2, 0, 4, 3, 1]) + assert _orders_are_sequential() + + # swap from the second position to the beginning + instances[0].swap(0) + assert instances[0].order == 0 + assert _get_ids() == _ids([0, 2, 4, 3, 1]) + assert _orders_are_sequential() + + # swap with itself + for instance in instances: + instance.refresh_from_db(fields=["order"]) + instance.swap(instance.order) + assert instance.order == instance.order + assert _get_ids() == _ids([0, 2, 4, 3, 1]) + assert _orders_are_sequential() + + +# Tests below are for checking that concurrent operations are performed correctly. +# They are skipped by default because they might take a lot of time to run. +# It could be useful to run them manually when making changes to the code, making sure +# that the changes don't break parallel operations. + + +@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_create_concurrent(): + LOOPS = 30 + THREADS = 10 + exceptions = [] + + def create(): + for loop in range(LOOPS): + try: + TestOrderedModel.objects.create(test_field="test") + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=create) for _ in range(THREADS)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert not exceptions + assert TestOrderedModel.objects.count() == LOOPS * THREADS + assert _orders_are_sequential() + + +@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_to_concurrent(): + THREADS = 300 + exceptions = [] + + TestOrderedModel.objects.all().delete() # clear table + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)] + + random.seed(42) + positions = [random.randint(0, THREADS - 1) for _ in range(THREADS)] + + def to(idx): + try: + instance = instances[idx] + instance.to(positions[idx]) # swap with next + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=to, args=(idx,)) for idx in range(THREADS - 1)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # can only check that orders are still sequential and that there are no exceptions + # can't check the exact order because it changes depending on the order of execution + assert not exceptions + assert _orders_are_sequential() + + +@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_swap_concurrent(): + THREADS = 300 + exceptions = [] + + TestOrderedModel.objects.all().delete() # clear table + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)] + + # generate random unique orders + random.seed(42) + unique_orders = list(range(THREADS)) + random.shuffle(unique_orders) + + def swap(idx): + try: + instance = instances[idx] + instance.swap(unique_orders[idx]) + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert not exceptions + assert _orders_are_sequential() + + # in case of unique orders, the final order is deterministic + assert list(TestOrderedModel.objects.order_by("id").values_list("order", flat=True)) == unique_orders + + +@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_swap_non_unique_orders_concurrent(): + THREADS = 300 + exceptions = [] + + TestOrderedModel.objects.all().delete() # clear table + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(THREADS)] + + # generate random non-unique orders + random.seed(42) + positions = [random.randint(0, THREADS - 1) for _ in range(THREADS)] + + def swap(idx): + try: + instance = instances[idx] + instance.swap(positions[idx]) + except Exception as e: + exceptions.append(e) + + threads = [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + # can only check that orders are still sequential and that there are no exceptions + # can't check the exact order because it changes depending on the order of execution + assert not exceptions + assert _orders_are_sequential() + + +@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.django_db(transaction=True) +def test_ordered_model_create_swap_and_delete_concurrent(): + """Check that create+swap, swap and delete operations are performed correctly when run concurrently.""" + + THREADS = 100 + exceptions = [] + + instances = [TestOrderedModel.objects.create(test_field="test", extra_field=idx) for idx in range(THREADS * 3)] + + def create_swap(idx): + try: + instance = TestOrderedModel.objects.create(test_field="test", extra_field=idx + 1000) + instance.swap(idx) + except Exception as e: + exceptions.append(("create_swap", e)) + + def swap(idx): + try: + instances[idx].swap(idx + 1) + except Exception as e: + exceptions.append(("swap", e)) + + def delete(idx): + try: + instances[idx].delete() + except Exception as e: + exceptions.append(("delete", e)) + + threads = [threading.Thread(target=create_swap, args=(idx,)) for idx in list(range(THREADS))] + threads += [threading.Thread(target=delete, args=(idx,)) for idx in range(THREADS)] + threads += [threading.Thread(target=swap, args=(idx,)) for idx in range(THREADS, THREADS * 2 - 1)] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + expected_extra_field_values = list(range(1000, 1000 + THREADS)) + expected_extra_field_values += [THREADS * 2 - 1] + list(range(THREADS, THREADS * 2 - 1)) + expected_extra_field_values += [instance.extra_field for instance in instances[THREADS * 2 : THREADS * 3]] + + assert not exceptions + assert _orders_are_sequential() + assert list(TestOrderedModel.objects.values_list("extra_field", flat=True)) == expected_extra_field_values From 362bab5d9bde859d9cf462b8e39398e007187184 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Fri, 16 Jun 2023 16:15:51 +0100 Subject: [PATCH 02/23] Use OrderedModel in UserNotificationPolicy --- .../migrations/0004_auto_20230616_1510.py | 48 +++++++++++++++++++ .../base/models/user_notification_policy.py | 9 +++- 2 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 engine/apps/base/migrations/0004_auto_20230616_1510.py diff --git a/engine/apps/base/migrations/0004_auto_20230616_1510.py b/engine/apps/base/migrations/0004_auto_20230616_1510.py new file mode 100644 index 0000000000..4704fe2e2e --- /dev/null +++ b/engine/apps/base/migrations/0004_auto_20230616_1510.py @@ -0,0 +1,48 @@ +# Generated by Django 3.2.19 on 2023-06-16 15:10 + +from django.db import migrations, models +from django.db.models import Count + +from common.database import get_random_readonly_database_key_if_present_otherwise_default + + +def fix_duplicate_order_user_notification_policy(apps, schema_editor): + UserNotificationPolicy = apps.get_model('base', 'UserNotificationPolicy') + + # it should be safe to use a readonly database because duplicates are pretty infrequent + db = get_random_readonly_database_key_if_present_otherwise_default() + + # find all (user_id, important, order) tuples that have more than one entry (meaning duplicates) + items_with_duplicate_orders = UserNotificationPolicy.objects.using(db).values( + "user_id", "important", "order" + ).annotate(count=Count("order")).order_by().filter(count__gt=1) # use order_by() to reset any existing ordering + + # make sure we don't fix the same (user_id, important) pair more than once + values_to_fix = set((item["user_id"], item["important"]) for item in items_with_duplicate_orders) + + for user_id, important in values_to_fix: + policies = UserNotificationPolicy.objects.filter(user_id=user_id, important=important).order_by("order", "id") + # assign correct sequential order for each policy starting from 0 + for idx, policy in enumerate(policies): + policy.order = idx + UserNotificationPolicy.objects.bulk_update(policies, fields=["order"]) + + +class Migration(migrations.Migration): + + dependencies = [ + ('base', '0003_delete_organizationlogrecord'), + ] + + operations = [ + migrations.AlterField( + model_name='usernotificationpolicy', + name='order', + field=models.PositiveIntegerField(db_index=True, editable=False, null=True), + ), + migrations.RunPython(fix_duplicate_order_user_notification_policy, migrations.RunPython.noop), + migrations.AddConstraint( + model_name='usernotificationpolicy', + constraint=models.UniqueConstraint(fields=('user_id', 'important', 'order'), name='unique_user_notification_policy_order'), + ), + ] diff --git a/engine/apps/base/models/user_notification_policy.py b/engine/apps/base/models/user_notification_policy.py index b439391608..11a9e1b996 100644 --- a/engine/apps/base/models/user_notification_policy.py +++ b/engine/apps/base/models/user_notification_policy.py @@ -7,9 +7,9 @@ from django.core.validators import MinLengthValidator from django.db import models from django.db.models import Q, QuerySet -from ordered_model.models import OrderedModel from apps.base.messaging import get_messaging_backends +from apps.base.models.ordered_model import OrderedModel from apps.user_management.models import User from common.exceptions import UserNotificationPolicyCouldNotBeDeleted from common.public_primary_keys import generate_public_primary_key, increase_public_primary_key_length @@ -103,7 +103,7 @@ def create_important_policies_for_user(self, user: User) -> "QuerySet[UserNotifi class UserNotificationPolicy(OrderedModel): objects = UserNotificationPolicyQuerySet.as_manager() - order_with_respect_to = ("user", "important") + order_with_respect_to = ("user_id", "important") public_primary_key = models.CharField( max_length=20, @@ -145,6 +145,11 @@ class Step(models.IntegerChoices): class Meta: ordering = ("order",) + constraints = [ + models.UniqueConstraint( + fields=["user_id", "important", "order"], name="unique_user_notification_policy_order" + ) + ] def __str__(self): return f"{self.pk}: {self.short_verbal}" From 32dfeb91e0cbf3979e36e26b8f5caf33b6888953 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Fri, 16 Jun 2023 16:46:41 +0100 Subject: [PATCH 03/23] Update internal API view and serializers --- .../serializers/user_notification_policy.py | 30 ++++--------------- .../tests/test_user_notification_policy.py | 25 ++++++---------- .../api/views/user_notification_policy.py | 7 ++++- 3 files changed, 20 insertions(+), 42 deletions(-) diff --git a/engine/apps/api/serializers/user_notification_policy.py b/engine/apps/api/serializers/user_notification_policy.py index 79eb845f8f..ba3e71724f 100644 --- a/engine/apps/api/serializers/user_notification_policy.py +++ b/engine/apps/api/serializers/user_notification_policy.py @@ -7,7 +7,7 @@ from apps.base.models.user_notification_policy import NotificationChannelAPIOptions from apps.user_management.models import User from common.api_helpers.custom_fields import OrganizationFilteredPrimaryKeyRelatedField -from common.api_helpers.exceptions import BadRequest, Forbidden +from common.api_helpers.exceptions import Forbidden from common.api_helpers.mixins import EagerLoadingMixin @@ -34,6 +34,7 @@ class UserNotificationPolicyBaseSerializer(EagerLoadingMixin, serializers.ModelS class Meta: model = UserNotificationPolicy fields = ["id", "step", "order", "notify_by", "wait_delay", "important", "user"] + read_only_fields = ["order"] def to_internal_value(self, data): if data.get("wait_delay", None): @@ -67,7 +68,6 @@ def _notify_by_to_representation(self, instance, result): class UserNotificationPolicySerializer(UserNotificationPolicyBaseSerializer): - prev_step = serializers.CharField(required=False, write_only=True, allow_null=True) user = OrganizationFilteredPrimaryKeyRelatedField( queryset=User.objects, required=False, @@ -80,36 +80,16 @@ class UserNotificationPolicySerializer(UserNotificationPolicyBaseSerializer): default=NotificationChannelAPIOptions.DEFAULT_NOTIFICATION_CHANNEL, ) - class Meta(UserNotificationPolicyBaseSerializer.Meta): - fields = [*UserNotificationPolicyBaseSerializer.Meta.fields, "prev_step"] - read_only_fields = ("order",) - def create(self, validated_data): - prev_step = validated_data.pop("prev_step", None) - - user = validated_data.get("user") + user = validated_data.get("user") or self.context["request"].user organization = self.context["request"].auth.organization - if not user: - user = self.context["request"].user - self_or_admin = user.self_or_admin(user_to_check=self.context["request"].user, organization=organization) if not self_or_admin: raise Forbidden() - if prev_step is not None: - try: - prev_step = UserNotificationPolicy.objects.get(public_primary_key=prev_step) - except UserNotificationPolicy.DoesNotExist: - raise BadRequest(detail="Prev step does not exist") - if prev_step.user != user or prev_step.important != validated_data.get("important", False): - raise BadRequest(detail="UserNotificationPolicy can be created only with the same user and importance") - instance = UserNotificationPolicy.objects.create(**validated_data) - instance.to(prev_step.order + 1) - return instance - else: - instance = UserNotificationPolicy.objects.create(**validated_data) - return instance + instance = UserNotificationPolicy.objects.create(**validated_data) + return instance class UserNotificationPolicyUpdateSerializer(UserNotificationPolicyBaseSerializer): diff --git a/engine/apps/api/tests/test_user_notification_policy.py b/engine/apps/api/tests/test_user_notification_policy.py index 3cda1f0d18..996775cc93 100644 --- a/engine/apps/api/tests/test_user_notification_policy.py +++ b/engine/apps/api/tests/test_user_notification_policy.py @@ -110,7 +110,7 @@ def test_user_cant_create_notification_policy_for_user( @pytest.mark.django_db -def test_create_notification_policy_from_step( +def test_create_notification_policy_order_is_ignored( user_notification_policy_internal_api_setup, make_user_auth_headers, ): @@ -121,7 +121,7 @@ def test_create_notification_policy_from_step( url = reverse("api-internal:notification_policy-list") data = { - "prev_step": wait_notification_step.public_primary_key, + "position": 2023, "step": UserNotificationPolicy.Step.NOTIFY, "notify_by": UserNotificationPolicy.NotificationChannel.SLACK, "wait_delay": None, @@ -130,26 +130,19 @@ def test_create_notification_policy_from_step( } response = client.post(url, data, format="json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_201_CREATED - assert response.data["order"] == 1 + assert response.data["order"] == 2 @pytest.mark.django_db -def test_create_invalid_notification_policy(user_notification_policy_internal_api_setup, make_user_auth_headers): +def test_move_to_position_position_error(user_notification_policy_internal_api_setup, make_user_auth_headers): token, steps, users = user_notification_policy_internal_api_setup - wait_notification_step, _, _, _ = steps admin, _ = users + step = steps[0] client = APIClient() - url = reverse("api-internal:notification_policy-list") + url = reverse("api-internal:notification_policy-move-to-position", kwargs={"pk": step.public_primary_key}) - data = { - "prev_step": wait_notification_step.public_primary_key, - "step": UserNotificationPolicy.Step.NOTIFY, - "notify_by": UserNotificationPolicy.NotificationChannel.SLACK, - "wait_delay": None, - "important": True, - "user": admin.public_primary_key, - } - response = client.post(url, data, format="json", **make_user_auth_headers(admin, token)) + # position value only can be 0 or 1 for this test setup, because there are only 2 steps + response = client.put(f"{url}?position=2", content_type="application/json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -221,7 +214,7 @@ def test_admin_can_move_user_step(user_notification_policy_internal_api_setup, m "api-internal:notification_policy-move-to-position", kwargs={"pk": second_user_step.public_primary_key} ) - response = client.put(f"{url}?position=1", content_type="application/json", **make_user_auth_headers(admin, token)) + response = client.put(f"{url}?position=0", content_type="application/json", **make_user_auth_headers(admin, token)) assert response.status_code == status.HTTP_200_OK diff --git a/engine/apps/api/views/user_notification_policy.py b/engine/apps/api/views/user_notification_policy.py index e05fc121bf..a7efb87fcc 100644 --- a/engine/apps/api/views/user_notification_policy.py +++ b/engine/apps/api/views/user_notification_policy.py @@ -142,7 +142,12 @@ def perform_destroy(self, instance): def move_to_position(self, request, pk): instance = self.get_object() position = get_move_to_position_param(request) - instance.to(position) + + try: + instance.to_index(position) + except IndexError: + raise BadRequest(detail="Invalid position") + return Response(status=status.HTTP_200_OK) @action(detail=False, methods=["get"]) From 21601aa3510ab1d64f071e6106a78814669402c9 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Fri, 16 Jun 2023 17:40:04 +0100 Subject: [PATCH 04/23] Update public API --- engine/apps/base/models/ordered_model.py | 12 +++--- engine/apps/base/tests/test_ordered_model.py | 13 +++--- .../personal_notification_rules.py | 40 ++++++++++--------- 3 files changed, 33 insertions(+), 32 deletions(-) diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py index 5bd01c9718..d05c372ca7 100644 --- a/engine/apps/base/models/ordered_model.py +++ b/engine/apps/base/models/ordered_model.py @@ -89,13 +89,8 @@ def delete(self, using=None, keep_parents=False): @_retry((IntegrityError, OperationalError)) def _save_no_order_provided(self): - max_order = self._get_ordering_queryset().aggregate(models.Max("order"))["order__max"] - - if max_order is None: - self.order = 0 - else: - self.order = max_order + 1 - + max_order = self.max_order() + self.order = max_order + 1 if max_order is not None else 0 super().save() @_retry((IntegrityError, OperationalError)) @@ -135,6 +130,9 @@ def swap(self, order): def next(self): return self._get_ordering_queryset().filter(order__gt=self.order).first() + def max_order(self): + return self._get_ordering_queryset().aggregate(models.Max("order"))["order__max"] + @property def _ordering_kwargs(self): return {field: getattr(self, field) for field in self.order_with_respect_to} diff --git a/engine/apps/base/tests/test_ordered_model.py b/engine/apps/base/tests/test_ordered_model.py index 303a3df27a..61dc782e90 100644 --- a/engine/apps/base/tests/test_ordered_model.py +++ b/engine/apps/base/tests/test_ordered_model.py @@ -209,10 +209,11 @@ def _ids(indices): # Tests below are for checking that concurrent operations are performed correctly. # They are skipped by default because they might take a lot of time to run. # It could be useful to run them manually when making changes to the code, making sure -# that the changes don't break parallel operations. +# that the changes don't break parallel operations. To run the tests, set SKIP_CONCURRENT to False. +SKIP_CONCURRENT = True -@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") @pytest.mark.django_db(transaction=True) def test_ordered_model_create_concurrent(): LOOPS = 30 @@ -237,7 +238,7 @@ def create(): assert _orders_are_sequential() -@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") @pytest.mark.django_db(transaction=True) def test_ordered_model_to_concurrent(): THREADS = 300 @@ -268,7 +269,7 @@ def to(idx): assert _orders_are_sequential() -@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") @pytest.mark.django_db(transaction=True) def test_ordered_model_swap_concurrent(): THREADS = 300 @@ -302,7 +303,7 @@ def swap(idx): assert list(TestOrderedModel.objects.order_by("id").values_list("order", flat=True)) == unique_orders -@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") @pytest.mark.django_db(transaction=True) def test_ordered_model_swap_non_unique_orders_concurrent(): THREADS = 300 @@ -334,7 +335,7 @@ def swap(idx): assert _orders_are_sequential() -@pytest.mark.skip(reason="OrderedModel concurrent tests are skipped to speed up tests") +@pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") @pytest.mark.django_db(transaction=True) def test_ordered_model_create_swap_and_delete_concurrent(): """Check that create+swap, swap and delete operations are performed correctly when run concurrently.""" diff --git a/engine/apps/public_api/serializers/personal_notification_rules.py b/engine/apps/public_api/serializers/personal_notification_rules.py index 8d915da7aa..4fd3c82b9b 100644 --- a/engine/apps/public_api/serializers/personal_notification_rules.py +++ b/engine/apps/public_api/serializers/personal_notification_rules.py @@ -43,14 +43,13 @@ def create(self, validated_data): # that is why step key is used instead of type below if "wait_delay" in validated_data and validated_data["step"] != UserNotificationPolicy.Step.WAIT: raise exceptions.ValidationError({"duration": "Duration can't be set"}) - user = validated_data.pop("user") + + instance = UserNotificationPolicy.objects.create(**validated_data, user=validated_data.pop("user")) + manual_order = validated_data.pop("manual_order") - if not manual_order: - order = validated_data.pop("order", None) - instance = UserNotificationPolicy.objects.create(**validated_data, user=user) - self._change_position(order, instance) - else: - instance = UserNotificationPolicy.objects.create(**validated_data, user=user) + order = validated_data.pop("order", None) + if order is not None: + self._adjust_order(instance, manual_order, order) return instance @@ -117,14 +116,18 @@ def _type_to_step_and_notification_channel(cls, rule_type): raise exceptions.ValidationError({"type": "Invalid type"}) - def _change_position(self, order, instance): - if order is not None: - if order >= 0: - instance.to(order) - elif order == -1: - instance.bottom() - else: - raise BadRequest(detail="Invalid value for position field") + @staticmethod + def _adjust_order(instance, manual_order, order): + if order == -1: + order = instance.max_order() or 0 + + if order < 0: + raise BadRequest(detail="Invalid value for position field") + + if manual_order: + instance.swap(order) + else: + instance.to(order) class PersonalNotificationRuleUpdateSerializer(PersonalNotificationRuleSerializer): @@ -146,9 +149,8 @@ def update(self, instance, validated_data): raise exceptions.ValidationError({"duration": "Duration can't be set"}) manual_order = validated_data.pop("manual_order") - - if not manual_order: - order = validated_data.pop("order", None) - self._change_position(order, instance) + order = validated_data.pop("order", None) + if order is not None: + self._adjust_order(instance, manual_order, order) return super().update(instance, validated_data) From 08bc4a098bc817a0976482abefc6a840f6953250 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Fri, 16 Jun 2023 17:44:07 +0100 Subject: [PATCH 05/23] delete excess if --- engine/apps/base/models/ordered_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py index d05c372ca7..2614aab9f5 100644 --- a/engine/apps/base/models/ordered_model.py +++ b/engine/apps/base/models/ordered_model.py @@ -79,8 +79,6 @@ def save(self, *args, **kwargs): if self.order is None: self._save_no_order_provided() else: - if self.order < 0: - raise ValueError("Order must be a positive integer.") super().save() @_retry(OperationalError) From 230fa50b06df6b5b2357e42d3e381a9c08845ccc Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Fri, 16 Jun 2023 18:55:27 +0100 Subject: [PATCH 06/23] increase dev mysql max_connections --- docker-compose-developer.yml | 2 +- engine/apps/base/tests/test_ordered_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-compose-developer.yml b/docker-compose-developer.yml index 69ab2e3b9c..c44633e48c 100644 --- a/docker-compose-developer.yml +++ b/docker-compose-developer.yml @@ -208,7 +208,7 @@ services: container_name: mysql labels: *oncall-labels image: mysql:8.0.32 - command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci + command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --max_connections=1024 restart: always environment: MYSQL_ROOT_PASSWORD: empty diff --git a/engine/apps/base/tests/test_ordered_model.py b/engine/apps/base/tests/test_ordered_model.py index 61dc782e90..ea3f570eb4 100644 --- a/engine/apps/base/tests/test_ordered_model.py +++ b/engine/apps/base/tests/test_ordered_model.py @@ -209,7 +209,7 @@ def _ids(indices): # Tests below are for checking that concurrent operations are performed correctly. # They are skipped by default because they might take a lot of time to run. # It could be useful to run them manually when making changes to the code, making sure -# that the changes don't break parallel operations. To run the tests, set SKIP_CONCURRENT to False. +# that the changes don't break concurrent operations. To run the tests, set SKIP_CONCURRENT to False. SKIP_CONCURRENT = True From 20ff1db3d126783ace519f19d537c30d40e91512 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 19 Jun 2023 11:41:00 +0100 Subject: [PATCH 07/23] Add comments --- .../serializers/personal_notification_rules.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/engine/apps/public_api/serializers/personal_notification_rules.py b/engine/apps/public_api/serializers/personal_notification_rules.py index 4fd3c82b9b..91a119941f 100644 --- a/engine/apps/public_api/serializers/personal_notification_rules.py +++ b/engine/apps/public_api/serializers/personal_notification_rules.py @@ -44,10 +44,13 @@ def create(self, validated_data): if "wait_delay" in validated_data and validated_data["step"] != UserNotificationPolicy.Step.WAIT: raise exceptions.ValidationError({"duration": "Duration can't be set"}) - instance = UserNotificationPolicy.objects.create(**validated_data, user=validated_data.pop("user")) - + # Remove "manual_order" and "order" fields from validated_data, so they are not passed to create method. + # Policies are always created at the end of the list, and then moved to the desired position by _adjust_order. manual_order = validated_data.pop("manual_order") order = validated_data.pop("order", None) + + instance = UserNotificationPolicy.objects.create(**validated_data) + if order is not None: self._adjust_order(instance, manual_order, order) @@ -118,12 +121,20 @@ def _type_to_step_and_notification_channel(cls, rule_type): @staticmethod def _adjust_order(instance, manual_order, order): + # Passing order=-1 means that the policy should be moved to the end of the list. if order == -1: order = instance.max_order() or 0 + # Negative order is not allowed. if order < 0: raise BadRequest(detail="Invalid value for position field") + # manual_order=True is intended for use by Terraform provider only, and is not a documented feature. + # Orders are swapped instead of moved when using Terraform, because Terraform may issue concurrent requests + # to create / update / delete multiple policies. "Move to" operation is not deterministic in this case, and + # final order of policies may be different depending on the order in which requests are processed. On the other + # hand, the result of concurrent "swap" operations is deterministic and does not depend on the order in + # which requests are processed. if manual_order: instance.swap(order) else: @@ -148,6 +159,7 @@ def update(self, instance, validated_data): if "wait_delay" in validated_data and instance.step != UserNotificationPolicy.Step.WAIT: raise exceptions.ValidationError({"duration": "Duration can't be set"}) + # Remove "manual_order" and "order" fields from validated_data, so they are not passed to update method. manual_order = validated_data.pop("manual_order") order = validated_data.pop("order", None) if order is not None: From af50a5ed09cf9088d9d71b7bd7645b0d7bb90224 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 19 Jun 2023 13:40:07 +0100 Subject: [PATCH 08/23] typing + pk name --- engine/apps/base/models/ordered_model.py | 99 +++++++++++++++--------- 1 file changed, 62 insertions(+), 37 deletions(-) diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py index 2614aab9f5..f616bad1b5 100644 --- a/engine/apps/base/models/ordered_model.py +++ b/engine/apps/base/models/ordered_model.py @@ -1,6 +1,7 @@ import logging import random import time +import typing from functools import wraps from django.db import IntegrityError, OperationalError, connection, models, transaction @@ -11,8 +12,8 @@ # TODO: comments SQL_TO = """ UPDATE `{db_table}` `t1` -JOIN `{db_table}` `t2` on `t2`.`id` = %(id)s -SET `t1`.`order` = IF(`t1`.`id` = `t2`.`id`, null, IF(`t1`.`order` < `t2`.`order`, `t1`.`order` + 1, `t1`.`order` - 1)) +JOIN `{db_table}` `t2` ON `t2`.`{pk_name}` = %(pk)s +SET `t1`.`order` = IF(`t1`.`{pk_name}` = `t2`.`{pk_name}`, null, IF(`t1`.`order` < `t2`.`order`, `t1`.`order` + 1, `t1`.`order` - 1)) WHERE {ordering_condition} AND `t2`.`order` != %(order)s AND `t1`.`order` >= IF(`t2`.`order` > %(order)s, %(order)s, `t2`.`order`) @@ -22,16 +23,16 @@ SQL_SWAP = """ UPDATE `{db_table}` `t1` -JOIN `{db_table}` `t2` on `t2`.`id` = %(id)s -SET `t1`.`order` = IF(`t1`.`id` = `t2`.`id`, null, `t2`.`order`) +JOIN `{db_table}` `t2` ON `t2`.`{pk_name}` = %(pk)s +SET `t1`.`order` = IF(`t1`.`{pk_name}` = `t2`.`{pk_name}`, null, `t2`.`order`) WHERE {ordering_condition} AND `t2`.`order` != %(order)s -AND (`t1`.`id` = `t2`.`id` OR `t1`.`order` = %(order)s) -ORDER BY IF(`t1`.`id` = `t2`.`id`, 0, 1) ASC +AND (`t1`.`{pk_name}` = `t2`.`{pk_name}` OR `t1`.`order` = %(order)s) +ORDER BY IF(`t1`.`{pk_name}` = `t2`.`{pk_name}`, 0, 1) ASC """ -def _retry(exc, max_attempts=15): +def _retry(exc: typing.Type[Exception] | tuple[typing.Type[Exception], ...], max_attempts: int = 15) -> typing.Callable: def _retry_with_params(f): @wraps(f) def wrapper(*args, **kwargs): @@ -55,18 +56,22 @@ def wrapper(*args, **kwargs): class OrderedModel(models.Model): """ This class is intended to be used as a mixin for models that need to be ordered. - - Operations: - - create: TODO - - delete: TODO - - move to: TODO - - move to index: TODO - - swap: TODO - - get next: TODO + It's similar to django-ordered-model: https://github.com/django-ordered-model/django-ordered-model. + The key difference of this implementation is that it allows orders to be unique at the database level and + is designed to work correctly under concurrent load. + + Example usage: + class Step(OrderedModel): + user = models.ForeignKey(User, on_delete=models.CASCADE) + order_with_respect_to = ["user_id"] # steps are ordered per user + + class Meta: + ordering = ["order"] # to make queryset ordering correct and consistent + unique_together = ["user_id", "order"] # orders are unique per user at the database level """ - order: int = models.PositiveIntegerField(editable=False, db_index=True, null=True) - order_with_respect_to = [] + order = models.PositiveIntegerField(editable=False, db_index=True, null=True) + order_with_respect_to: list[str] = [] class Meta: abstract = True @@ -75,70 +80,90 @@ class Meta: models.UniqueConstraint(fields=["order"], name="unique_order"), ] - def save(self, *args, **kwargs): + def save(self, *args, **kwargs) -> None: if self.order is None: self._save_no_order_provided() else: super().save() @_retry(OperationalError) - def delete(self, using=None, keep_parents=False): - super().delete(using=using, keep_parents=keep_parents) + def delete(self, *args, **kwargs) -> None: + super().delete(*args, **kwargs) @_retry((IntegrityError, OperationalError)) - def _save_no_order_provided(self): + def _save_no_order_provided(self) -> None: + """ + TODO: how this works and why it's needed + """ + max_order = self.max_order() self.order = max_order + 1 if max_order is not None else 0 super().save() @_retry((IntegrityError, OperationalError)) - def to(self, order): + def to(self, order: int) -> None: + """ + TODO: how this works and why it's needed + """ + if order is None or order < 0: raise ValueError("Order must be a positive integer.") - sql = SQL_TO.format(db_table=self._meta.db_table, ordering_condition=self._ordering_condition_sql) - params = {"id": self.id, "order": order, **self._ordering_kwargs} + sql = SQL_TO.format( + db_table=self._meta.db_table, pk_name=self._meta.pk.name, ordering_condition=self._ordering_condition_sql + ) + params = {"pk": self.pk, "order": order, **self._ordering_params} with transaction.atomic(): with connection.cursor() as cursor: cursor.execute(sql, params) self._meta.model.objects.filter(pk=self.pk).update(order=order) - self.refresh_from_db() + self.refresh_from_db(fields=["order"]) - def to_index(self, index): + def to_index(self, index: int) -> None: + """ + Might be prone to race conditions. + """ order = self._get_ordering_queryset().values_list("order", flat=True)[index] self.to(order) @_retry((IntegrityError, OperationalError)) - def swap(self, order): + def swap(self, order: int) -> None: + """ + TODO: how this works and why it's needed + """ + if order is None or order < 0: raise ValueError("Order must be a positive integer.") - sql = SQL_SWAP.format(db_table=self._meta.db_table, ordering_condition=self._ordering_condition_sql) - params = {"id": self.id, "order": order, **self._ordering_kwargs} + sql = SQL_SWAP.format( + db_table=self._meta.db_table, pk_name=self._meta.pk.name, ordering_condition=self._ordering_condition_sql + ) + params = {"pk": self.pk, "order": order, **self._ordering_params} with transaction.atomic(): with connection.cursor() as cursor: cursor.execute(sql, params) self._meta.model.objects.filter(pk=self.pk).update(order=order) - self.refresh_from_db() + self.refresh_from_db(fields=["order"]) - def next(self): + def next(self) -> models.Model | None: return self._get_ordering_queryset().filter(order__gt=self.order).first() - def max_order(self): + def max_order(self) -> int | None: return self._get_ordering_queryset().aggregate(models.Max("order"))["order__max"] + def _get_ordering_queryset(self) -> models.QuerySet: + return self._meta.model.objects.filter(**self._ordering_params) + @property - def _ordering_kwargs(self): + def _ordering_params(self) -> dict[str, typing.Any]: return {field: getattr(self, field) for field in self.order_with_respect_to} - def _get_ordering_queryset(self): - return self._meta.model.objects.filter(**self._ordering_kwargs) - @property - def _ordering_condition_sql(self): + def _ordering_condition_sql(self) -> str: + # This doesn't insert actual values into the query, but rather uses placeholders to avoid SQL injections. ordering_parts = ["`t1`.`{0}` = %({0})s".format(field) for field in self.order_with_respect_to] return " AND ".join(ordering_parts) From 0e989a56fdd122d726e4cc5ad98eeb3f8ef1a4f2 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 19 Jun 2023 16:11:37 +0100 Subject: [PATCH 09/23] comment --- engine/apps/base/models/ordered_model.py | 26 +++++++++--------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py index f616bad1b5..04f6ef073a 100644 --- a/engine/apps/base/models/ordered_model.py +++ b/engine/apps/base/models/ordered_model.py @@ -9,7 +9,7 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -# TODO: comments +# Update object's order to NULL and shift other objects' orders accordingly in a single SQL query. SQL_TO = """ UPDATE `{db_table}` `t1` JOIN `{db_table}` `t2` ON `t2`.`{pk_name}` = %(pk)s @@ -21,6 +21,7 @@ ORDER BY IF(`t1`.`order` <= `t2`.`order`, `t1`.`order`, null) DESC, IF(`t1`.`order` >= `t2`.`order`, `t1`.`order`, null) ASC """ +# Update object's order to NULL and set the other object's order to specified value in a single SQL query. SQL_SWAP = """ UPDATE `{db_table}` `t1` JOIN `{db_table}` `t2` ON `t2`.`{pk_name}` = %(pk)s @@ -60,6 +61,12 @@ class OrderedModel(models.Model): The key difference of this implementation is that it allows orders to be unique at the database level and is designed to work correctly under concurrent load. + Notable differences compared to django-ordered-model: + - order can be unique at the database level; + - order can temporarily be set to NULL while performing moving operations; + - instance.delete() only deletes the instance, and doesn't shift other instances' orders; + - some methods are not implemented because they're not used in the codebase; + Example usage: class Step(OrderedModel): user = models.ForeignKey(User, on_delete=models.CASCADE) @@ -68,6 +75,8 @@ class Step(OrderedModel): class Meta: ordering = ["order"] # to make queryset ordering correct and consistent unique_together = ["user_id", "order"] # orders are unique per user at the database level + + It's possible for orders to be non-sequential, e.g. order sequence [100, 150, 400] is totally possible and valid. """ order = models.PositiveIntegerField(editable=False, db_index=True, null=True) @@ -92,20 +101,12 @@ def delete(self, *args, **kwargs) -> None: @_retry((IntegrityError, OperationalError)) def _save_no_order_provided(self) -> None: - """ - TODO: how this works and why it's needed - """ - max_order = self.max_order() self.order = max_order + 1 if max_order is not None else 0 super().save() @_retry((IntegrityError, OperationalError)) def to(self, order: int) -> None: - """ - TODO: how this works and why it's needed - """ - if order is None or order < 0: raise ValueError("Order must be a positive integer.") @@ -122,18 +123,11 @@ def to(self, order: int) -> None: self.refresh_from_db(fields=["order"]) def to_index(self, index: int) -> None: - """ - Might be prone to race conditions. - """ order = self._get_ordering_queryset().values_list("order", flat=True)[index] self.to(order) @_retry((IntegrityError, OperationalError)) def swap(self, order: int) -> None: - """ - TODO: how this works and why it's needed - """ - if order is None or order < 0: raise ValueError("Order must be a positive integer.") From 1c7aab9504d7b830e195b91dc32ced1323434f67 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 19 Jun 2023 16:16:16 +0100 Subject: [PATCH 10/23] comment --- engine/apps/base/models/ordered_model.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py index 04f6ef073a..c83a539910 100644 --- a/engine/apps/base/models/ordered_model.py +++ b/engine/apps/base/models/ordered_model.py @@ -1,4 +1,3 @@ -import logging import random import time import typing @@ -6,9 +5,6 @@ from django.db import IntegrityError, OperationalError, connection, models, transaction -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - # Update object's order to NULL and shift other objects' orders accordingly in a single SQL query. SQL_TO = """ UPDATE `{db_table}` `t1` @@ -42,11 +38,9 @@ def wrapper(*args, **kwargs): try: return f(*args, **kwargs) except exc: - logger.debug(f"IntegrityError occurred in {f.__qualname__}. Retrying...") if attempts == max_attempts - 1: raise attempts += 1 - # double the sleep time each time and add some jitter time.sleep(random.random()) return wrapper From b93cf703169ab11e524820d21c14709343630627 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 19 Jun 2023 16:47:08 +0100 Subject: [PATCH 11/23] --no-migrations --- engine/tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/tox.ini b/engine/tox.ini index 7cabc843d8..330b6eb9d6 100644 --- a/engine/tox.ini +++ b/engine/tox.ini @@ -9,6 +9,6 @@ banned-modules = [pytest] # https://pytest-django.readthedocs.io/en/latest/configuring_django.html#order-of-choosing-settings # https://pytest-django.readthedocs.io/en/latest/database.html -addopts = --color=yes --showlocals +addopts = --no-migrations --color=yes --showlocals # https://pytest-django.readthedocs.io/en/latest/faq.html#my-tests-are-not-being-found-why python_files = tests.py test_*.py *_tests.py From 637b979d00659548cb9c28fe0736d6328045c104 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Mon, 19 Jun 2023 17:08:30 +0100 Subject: [PATCH 12/23] use connection.ops.quote_name --- engine/apps/base/models/ordered_model.py | 54 +++++++++++++----------- 1 file changed, 30 insertions(+), 24 deletions(-) diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py index c83a539910..6b5969d356 100644 --- a/engine/apps/base/models/ordered_model.py +++ b/engine/apps/base/models/ordered_model.py @@ -7,25 +7,25 @@ # Update object's order to NULL and shift other objects' orders accordingly in a single SQL query. SQL_TO = """ -UPDATE `{db_table}` `t1` -JOIN `{db_table}` `t2` ON `t2`.`{pk_name}` = %(pk)s -SET `t1`.`order` = IF(`t1`.`{pk_name}` = `t2`.`{pk_name}`, null, IF(`t1`.`order` < `t2`.`order`, `t1`.`order` + 1, `t1`.`order` - 1)) +UPDATE {db_table} {t1} +JOIN {db_table} {t2} ON {t2}.{pk_name} = %(pk)s +SET {t1}.{order} = IF({t1}.{pk_name} = {t2}.{pk_name}, null, IF({t1}.{order} < {t2}.{order}, {t1}.{order} + 1, {t1}.{order} - 1)) WHERE {ordering_condition} -AND `t2`.`order` != %(order)s -AND `t1`.`order` >= IF(`t2`.`order` > %(order)s, %(order)s, `t2`.`order`) -AND `t1`.`order` <= IF(`t2`.`order` > %(order)s, `t2`.`order`, %(order)s) -ORDER BY IF(`t1`.`order` <= `t2`.`order`, `t1`.`order`, null) DESC, IF(`t1`.`order` >= `t2`.`order`, `t1`.`order`, null) ASC +AND {t2}.{order} != %(order)s +AND {t1}.{order} >= IF({t2}.{order} > %(order)s, %(order)s, {t2}.{order}) +AND {t1}.{order} <= IF({t2}.{order} > %(order)s, {t2}.{order}, %(order)s) +ORDER BY IF({t1}.{order} <= {t2}.{order}, {t1}.{order}, null) DESC, IF({t1}.{order} >= {t2}.{order}, {t1}.{order}, null) ASC """ # Update object's order to NULL and set the other object's order to specified value in a single SQL query. SQL_SWAP = """ -UPDATE `{db_table}` `t1` -JOIN `{db_table}` `t2` ON `t2`.`{pk_name}` = %(pk)s -SET `t1`.`order` = IF(`t1`.`{pk_name}` = `t2`.`{pk_name}`, null, `t2`.`order`) +UPDATE {db_table} {t1} +JOIN {db_table} {t2} ON {t2}.{pk_name} = %(pk)s +SET {t1}.{order} = IF({t1}.{pk_name} = {t2}.{pk_name}, null, {t2}.{order}) WHERE {ordering_condition} -AND `t2`.`order` != %(order)s -AND (`t1`.`{pk_name}` = `t2`.`{pk_name}` OR `t1`.`order` = %(order)s) -ORDER BY IF(`t1`.`{pk_name}` = `t2`.`{pk_name}`, 0, 1) ASC +AND {t2}.{order} != %(order)s +AND ({t1}.{pk_name} = {t2}.{pk_name} OR {t1}.{order} = %(order)s) +ORDER BY IF({t1}.{pk_name} = {t2}.{pk_name}, 0, 1) ASC """ @@ -104,9 +104,7 @@ def to(self, order: int) -> None: if order is None or order < 0: raise ValueError("Order must be a positive integer.") - sql = SQL_TO.format( - db_table=self._meta.db_table, pk_name=self._meta.pk.name, ordering_condition=self._ordering_condition_sql - ) + sql = self._format_sql(SQL_TO) params = {"pk": self.pk, "order": order, **self._ordering_params} with transaction.atomic(): @@ -125,9 +123,7 @@ def swap(self, order: int) -> None: if order is None or order < 0: raise ValueError("Order must be a positive integer.") - sql = SQL_SWAP.format( - db_table=self._meta.db_table, pk_name=self._meta.pk.name, ordering_condition=self._ordering_condition_sql - ) + sql = self._format_sql(SQL_SWAP) params = {"pk": self.pk, "order": order, **self._ordering_params} with transaction.atomic(): @@ -150,8 +146,18 @@ def _get_ordering_queryset(self) -> models.QuerySet: def _ordering_params(self) -> dict[str, typing.Any]: return {field: getattr(self, field) for field in self.order_with_respect_to} - @property - def _ordering_condition_sql(self) -> str: - # This doesn't insert actual values into the query, but rather uses placeholders to avoid SQL injections. - ordering_parts = ["`t1`.`{0}` = %({0})s".format(field) for field in self.order_with_respect_to] - return " AND ".join(ordering_parts) + def _format_sql(self, sql): + ordering_parts = [ + "{t1}.{field} = %({field})s".format(t1=connection.ops.quote_name("t1"), field=field) + for field in self.order_with_respect_to + ] + ordering_condition = " AND ".join(ordering_parts) + + return sql.format( + t1=connection.ops.quote_name("t1"), + t2=connection.ops.quote_name("t2"), + order=connection.ops.quote_name("order"), + db_table=connection.ops.quote_name(self._meta.db_table), + pk_name=connection.ops.quote_name(self._meta.pk.name), + ordering_condition=ordering_condition, + ) From e8cf3d362d399de6a8118dfd852a9dcf67a950e0 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 14:45:25 +0100 Subject: [PATCH 13/23] remove raw SQL --- engine/apps/base/models/ordered_model.py | 242 ++++++++++++++++------- 1 file changed, 170 insertions(+), 72 deletions(-) diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py index 6b5969d356..3d0dc46e0e 100644 --- a/engine/apps/base/models/ordered_model.py +++ b/engine/apps/base/models/ordered_model.py @@ -3,33 +3,15 @@ import typing from functools import wraps -from django.db import IntegrityError, OperationalError, connection, models, transaction - -# Update object's order to NULL and shift other objects' orders accordingly in a single SQL query. -SQL_TO = """ -UPDATE {db_table} {t1} -JOIN {db_table} {t2} ON {t2}.{pk_name} = %(pk)s -SET {t1}.{order} = IF({t1}.{pk_name} = {t2}.{pk_name}, null, IF({t1}.{order} < {t2}.{order}, {t1}.{order} + 1, {t1}.{order} - 1)) -WHERE {ordering_condition} -AND {t2}.{order} != %(order)s -AND {t1}.{order} >= IF({t2}.{order} > %(order)s, %(order)s, {t2}.{order}) -AND {t1}.{order} <= IF({t2}.{order} > %(order)s, {t2}.{order}, %(order)s) -ORDER BY IF({t1}.{order} <= {t2}.{order}, {t1}.{order}, null) DESC, IF({t1}.{order} >= {t2}.{order}, {t1}.{order}, null) ASC -""" - -# Update object's order to NULL and set the other object's order to specified value in a single SQL query. -SQL_SWAP = """ -UPDATE {db_table} {t1} -JOIN {db_table} {t2} ON {t2}.{pk_name} = %(pk)s -SET {t1}.{order} = IF({t1}.{pk_name} = {t2}.{pk_name}, null, {t2}.{order}) -WHERE {ordering_condition} -AND {t2}.{order} != %(order)s -AND ({t1}.{pk_name} = {t2}.{pk_name} OR {t1}.{order} = %(order)s) -ORDER BY IF({t1}.{pk_name} = {t2}.{pk_name}, 0, 1) ASC -""" +from django.db import IntegrityError, OperationalError, models, transaction +from django.db.models import Case, F, Value, When def _retry(exc: typing.Type[Exception] | tuple[typing.Type[Exception], ...], max_attempts: int = 15) -> typing.Callable: + """ + A utility decorator for retrying a function on a given exception(s) up to max_attempts times. + """ + def _retry_with_params(f): @wraps(f) def wrapper(*args, **kwargs): @@ -48,6 +30,9 @@ def wrapper(*args, **kwargs): return _retry_with_params +Self = typing.TypeVar("Self", bound="OrderedModel") + + class OrderedModel(models.Model): """ This class is intended to be used as a mixin for models that need to be ordered. @@ -89,75 +74,188 @@ def save(self, *args, **kwargs) -> None: else: super().save() - @_retry(OperationalError) + @_retry(OperationalError) # retry on deadlock def delete(self, *args, **kwargs) -> None: super().delete(*args, **kwargs) @_retry((IntegrityError, OperationalError)) def _save_no_order_provided(self) -> None: + """ + Save self to DB without an order provided (e.g on creation). + Order is set to the next available order, or 0 if there are no other instances. + Example: + a = OrderedModel.objects.create() + b = OrderedModel.objects.create() + c = OrderedModel.objects.create(order=10) + d = OrderedModel.objects.create() + + assert (a.order, b.order, c.order, d.order) == (0, 1, 10, 11) + """ max_order = self.max_order() self.order = max_order + 1 if max_order is not None else 0 super().save() - @_retry((IntegrityError, OperationalError)) + @_retry(OperationalError) # retry on deadlock def to(self, order: int) -> None: - if order is None or order < 0: - raise ValueError("Order must be a positive integer.") - - sql = self._format_sql(SQL_TO) - params = {"pk": self.pk, "order": order, **self._ordering_params} - + """ + Move self to a given order, adjusting other instances' orders if necessary. + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=2) + c = OrderedModel(order=3) + + a.to(3) # move the first element to the last order + assert (a.order, b.order, c.order) == (3, 1, 2) # [a, b, c] -> [b, c, a] + """ + self._validate_positive_integer(order) with transaction.atomic(): - with connection.cursor() as cursor: - cursor.execute(sql, params) - self._meta.model.objects.filter(pk=self.pk).update(order=order) - - self.refresh_from_db(fields=["order"]) + instances = self._lock_ordering_queryset() + self._move_instances_to_order(instances, order) + @_retry(OperationalError) # retry on deadlock def to_index(self, index: int) -> None: - order = self._get_ordering_queryset().values_list("order", flat=True)[index] - self.to(order) - - @_retry((IntegrityError, OperationalError)) + """ + Move self to a given index, adjusting other instances' orders if necessary. + Similar with to(), but accepts an index instead of an order. + This might be handy as orders might be non-sequential, but most clients assume that they are sequential. + + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=5) + c = OrderedModel(order=10) + + a.to_index(2) # move the first element to the second index (where c is) + assert (a.order, b.order, c.order) == (10, 4, 9) # [a, b, c] -> [b, c, a] + """ + self._validate_positive_integer(index) + with transaction.atomic(): + instances = self._lock_ordering_queryset() + order = instances[index].order # get order of the instance at the given index + self._move_instances_to_order(instances, order) + + def _move_instances_to_order(self, instances: list[Self], order: int) -> None: + """ + Helper method for moving self to a given order, adjusting other instances' orders if necessary. + Must be called within a transaction that locks the ordering queryset. + """ + + # Get the up-to-date instance from the database, because it might have been updated by another transaction. + try: + _self = next(instance for instance in instances if instance.pk == self.pk) + self.order = _self.order + assert self.order is not None + except StopIteration: + raise self.DoesNotExist() + + # If the order is already correct, do nothing. + if self.order == order: + return + + # Figure out instances that need to be moved. + if self.order < order: + instances_to_move = [ + instance + for instance in instances + if instance.order is not None and self.order < instance.order <= order + ] + else: + instances_to_move = [ + instance + for instance in instances + if instance.order is not None and order <= instance.order < self.order + ] + + # Temporarily set self.order to NULL and update other instances' orders in a single SQL command. + if instances_to_move: + order_by = "order" if self.order < order else "-order" + order_delta = -1 if self.order < order else 1 + self._manager.filter(pk__in=[self.pk] + [instance.pk for instance in instances_to_move]).order_by( + order_by + ).update( + order=Case( + When(pk=self.pk, then=Value(None)), + default=F("order") + order_delta, + ) + ) + + # Update self.order from NULL to the correct value. + self.order = order + self.save(update_fields=["order"]) + + @_retry(OperationalError) # retry on deadlock def swap(self, order: int) -> None: - if order is None or order < 0: - raise ValueError("Order must be a positive integer.") - - sql = self._format_sql(SQL_SWAP) - params = {"pk": self.pk, "order": order, **self._ordering_params} - + """ + Swap self with an instance at a given order. + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=2) + c = OrderedModel(order=3) + d = OrderedModel(order=4) + + a.swap(4) # swap the first element with the last element + assert (a.order, b.order, c.order, d.order) == (4, 2, 3, 1) # [a, b, c, d] -> [d, b, c, a] + """ + self._validate_positive_integer(order) with transaction.atomic(): - with connection.cursor() as cursor: - cursor.execute(sql, params) - self._meta.model.objects.filter(pk=self.pk).update(order=order) - - self.refresh_from_db(fields=["order"]) - - def next(self) -> models.Model | None: + instances = self._lock_ordering_queryset() + + # Get the up-to-date instance from the database, because it might have been updated by another transaction. + try: + _self = next(instance for instance in instances if instance.pk == self.pk) + self.order = _self.order + assert self.order is not None + except StopIteration: + raise self.DoesNotExist() + + # If the order is already correct, do nothing. + if self.order == order: + return + + # Get the instance to swap with. + try: + other = next(instance for instance in instances if instance.order == order) + except StopIteration: + other = None + + # Temporarily set self.order to NULL and update the other instance's order in a single SQL command. + if other: + order_by = "order" if self.order < order else "-order" + self._manager.filter(pk__in=[self.pk, other.pk]).order_by(order_by).update( + order=Case( + When(pk=self.pk, then=Value(None)), + default=Value(self.order), + ) + ) + + # Update self.order from NULL to the correct value. + self.order = order + self.save(update_fields=["order"]) + + def next(self) -> Self | None: return self._get_ordering_queryset().filter(order__gt=self.order).first() def max_order(self) -> int | None: return self._get_ordering_queryset().aggregate(models.Max("order"))["order__max"] - def _get_ordering_queryset(self) -> models.QuerySet: - return self._meta.model.objects.filter(**self._ordering_params) + @staticmethod + def _validate_positive_integer(value: int | None) -> None: + if value is None or not isinstance(value, int) or value < 0: + raise ValueError("Value must be a positive integer.") + + def _get_ordering_queryset(self) -> models.QuerySet[Self]: + return self._manager.filter(**self._ordering_params) + + def _lock_ordering_queryset(self) -> list[Self]: + """ + Locks the ordering queryset with SELECT FOR UPDATE and returns the queryset as a list. + This allows to prevent concurrent updates from different transactions. + """ + return list(self._get_ordering_queryset().select_for_update().only("pk", "order")) + + @property + def _manager(self): + return self._meta.default_manager @property def _ordering_params(self) -> dict[str, typing.Any]: return {field: getattr(self, field) for field in self.order_with_respect_to} - - def _format_sql(self, sql): - ordering_parts = [ - "{t1}.{field} = %({field})s".format(t1=connection.ops.quote_name("t1"), field=field) - for field in self.order_with_respect_to - ] - ordering_condition = " AND ".join(ordering_parts) - - return sql.format( - t1=connection.ops.quote_name("t1"), - t2=connection.ops.quote_name("t2"), - order=connection.ops.quote_name("order"), - db_table=connection.ops.quote_name(self._meta.db_table), - pk_name=connection.ops.quote_name(self._meta.pk.name), - ordering_condition=ordering_condition, - ) From b73ff95b8f25bb4d40efa4d5d8f1eb730d2b2365 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 14:57:16 +0100 Subject: [PATCH 14/23] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f1473404c..048518db8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Change mobile shift notifications title and subtitle by @imtoori ([#2288](https://github.com/grafana/oncall/pull/2288)) +## Fixed + +- Fix duplicate orders for user notification policies by @vadimkerr ([#2278](https://github.com/grafana/oncall/pull/2278)) + ## v1.2.45 (2023-06-19) ### Changed From 00fb5f24c916a7ebac1ef7c48aadc7f65d809b6d Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 15:00:44 +0100 Subject: [PATCH 15/23] ignore migration --- engine/apps/base/migrations/0004_auto_20230616_1510.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/engine/apps/base/migrations/0004_auto_20230616_1510.py b/engine/apps/base/migrations/0004_auto_20230616_1510.py index 4704fe2e2e..09878106b9 100644 --- a/engine/apps/base/migrations/0004_auto_20230616_1510.py +++ b/engine/apps/base/migrations/0004_auto_20230616_1510.py @@ -4,6 +4,7 @@ from django.db.models import Count from common.database import get_random_readonly_database_key_if_present_otherwise_default +import django_migration_linter as linter def fix_duplicate_order_user_notification_policy(apps, schema_editor): @@ -35,6 +36,7 @@ class Migration(migrations.Migration): ] operations = [ + linter.IgnoreMigration(), # adding a unique constraint after fixing duplicates should be fine migrations.AlterField( model_name='usernotificationpolicy', name='order', From c790132a3a51da9b71c3286529100dc634211112 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 15:12:59 +0100 Subject: [PATCH 16/23] create_default_policies_for_user --- .../apps/base/models/user_notification_policy.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/engine/apps/base/models/user_notification_policy.py b/engine/apps/base/models/user_notification_policy.py index 11a9e1b996..7646e421e0 100644 --- a/engine/apps/base/models/user_notification_policy.py +++ b/engine/apps/base/models/user_notification_policy.py @@ -6,7 +6,7 @@ from django.core.exceptions import ValidationError from django.core.validators import MinLengthValidator from django.db import models -from django.db.models import Q, QuerySet +from django.db.models import Q from apps.base.messaging import get_messaging_backends from apps.base.models.ordered_model import OrderedModel @@ -67,9 +67,11 @@ def validate_channel_choice(value): class UserNotificationPolicyQuerySet(models.QuerySet): - def create_default_policies_for_user(self, user: User) -> "QuerySet[UserNotificationPolicy]": - model = self.model + def create_default_policies_for_user(self, user: User) -> None: + if user.notification_policies.filter(important=False).exists(): + return + model = self.model policies_to_create = ( model( user=user, @@ -82,11 +84,12 @@ def create_default_policies_for_user(self, user: User) -> "QuerySet[UserNotifica ) super().bulk_create(policies_to_create) - return user.notification_policies.filter(important=False) - def create_important_policies_for_user(self, user: User) -> "QuerySet[UserNotificationPolicy]": - model = self.model + def create_important_policies_for_user(self, user: User) -> None: + if user.notification_policies.filter(important=True).exists(): + return + model = self.model policies_to_create = ( model( user=user, @@ -98,7 +101,6 @@ def create_important_policies_for_user(self, user: User) -> "QuerySet[UserNotifi ) super().bulk_create(policies_to_create) - return user.notification_policies.filter(important=True) class UserNotificationPolicy(OrderedModel): From 5bca007c958d9d8577d1a5d636ccd92140c4a6ec Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 15:17:27 +0100 Subject: [PATCH 17/23] create_default_policies_for_user --- engine/apps/base/models/user_notification_policy.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/engine/apps/base/models/user_notification_policy.py b/engine/apps/base/models/user_notification_policy.py index 7646e421e0..7e93af6375 100644 --- a/engine/apps/base/models/user_notification_policy.py +++ b/engine/apps/base/models/user_notification_policy.py @@ -5,7 +5,7 @@ from django.conf import settings from django.core.exceptions import ValidationError from django.core.validators import MinLengthValidator -from django.db import models +from django.db import IntegrityError, models from django.db.models import Q from apps.base.messaging import get_messaging_backends @@ -83,7 +83,10 @@ def create_default_policies_for_user(self, user: User) -> None: model(user=user, step=model.Step.NOTIFY, notify_by=model.NotificationChannel.PHONE_CALL, order=2), ) - super().bulk_create(policies_to_create) + try: + super().bulk_create(policies_to_create) + except IntegrityError: + pass def create_important_policies_for_user(self, user: User) -> None: if user.notification_policies.filter(important=True).exists(): @@ -100,7 +103,10 @@ def create_important_policies_for_user(self, user: User) -> None: ), ) - super().bulk_create(policies_to_create) + try: + super().bulk_create(policies_to_create) + except IntegrityError: + pass class UserNotificationPolicy(OrderedModel): From 93c55883cc4d7555997f631e4d9bac6e11f50dc7 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 15:25:44 +0100 Subject: [PATCH 18/23] test --- engine/apps/base/tests/test_ordered_model.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/engine/apps/base/tests/test_ordered_model.py b/engine/apps/base/tests/test_ordered_model.py index ea3f570eb4..0e541ad25d 100644 --- a/engine/apps/base/tests/test_ordered_model.py +++ b/engine/apps/base/tests/test_ordered_model.py @@ -206,6 +206,25 @@ def _ids(indices): assert _orders_are_sequential() +@pytest.mark.django_db +def test_order_with_respect_to_isolation(): + instances = [TestOrderedModel.objects.create(test_field="test") for _ in range(5)] + other_instances = [TestOrderedModel.objects.create(test_field="test1") for _ in range(5)] + + assert [i.order for i in instances] == [0, 1, 2, 3, 4] + assert [i.order for i in other_instances] == [0, 1, 2, 3, 4] + + instances[0].to(8) + instances[1].swap(7) + + for idx, instance in enumerate(other_instances): + instance.refresh_from_db() + assert instance.order == idx + + with pytest.raises(IndexError): + instances[0].to_index(6) + + # Tests below are for checking that concurrent operations are performed correctly. # They are skipped by default because they might take a lot of time to run. # It could be useful to run them manually when making changes to the code, making sure From cb4d5bab64a2afa125b693d0bde2ec99f4ff7fb1 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 16:20:31 +0100 Subject: [PATCH 19/23] lock queryset --- engine/apps/base/models/ordered_model.py | 103 +++++++++++++---------- 1 file changed, 57 insertions(+), 46 deletions(-) diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py index 3d0dc46e0e..cfcf922ef8 100644 --- a/engine/apps/base/models/ordered_model.py +++ b/engine/apps/base/models/ordered_model.py @@ -4,7 +4,6 @@ from functools import wraps from django.db import IntegrityError, OperationalError, models, transaction -from django.db.models import Case, F, Value, When def _retry(exc: typing.Type[Exception] | tuple[typing.Type[Exception], ...], max_attempts: int = 15) -> typing.Callable: @@ -76,9 +75,12 @@ def save(self, *args, **kwargs) -> None: @_retry(OperationalError) # retry on deadlock def delete(self, *args, **kwargs) -> None: - super().delete(*args, **kwargs) + with transaction.atomic(): + # lock ordering queryset to prevent deleting instances that are used by other transactions + self._lock_ordering_queryset() + super().delete(*args, **kwargs) - @_retry((IntegrityError, OperationalError)) + @_retry((IntegrityError, OperationalError)) # retry on duplicate order or deadlock def _save_no_order_provided(self) -> None: """ Save self to DB without an order provided (e.g on creation). @@ -91,9 +93,11 @@ def _save_no_order_provided(self) -> None: assert (a.order, b.order, c.order, d.order) == (0, 1, 10, 11) """ - max_order = self.max_order() - self.order = max_order + 1 if max_order is not None else 0 - super().save() + with transaction.atomic(): + instances = self._lock_ordering_queryset() # lock ordering queryset to prevent reading inconsistent data + max_order = max(instance.order for instance in instances) if instances else -1 + self.order = max_order + 1 + super().save() @_retry(OperationalError) # retry on deadlock def to(self, order: int) -> None: @@ -151,36 +155,32 @@ def _move_instances_to_order(self, instances: list[Self], order: int) -> None: if self.order == order: return - # Figure out instances that need to be moved. + # Figure out instances that need to be moved and their new orders. + instances_to_move = [] if self.order < order: - instances_to_move = [ - instance - for instance in instances - if instance.order is not None and self.order < instance.order <= order - ] + for instance in instances: + if instance.order is not None and self.order < instance.order <= order: + instance.order -= 1 + instances_to_move.append(instance) else: - instances_to_move = [ - instance - for instance in instances - if instance.order is not None and order <= instance.order < self.order - ] - - # Temporarily set self.order to NULL and update other instances' orders in a single SQL command. - if instances_to_move: - order_by = "order" if self.order < order else "-order" - order_delta = -1 if self.order < order else 1 - self._manager.filter(pk__in=[self.pk] + [instance.pk for instance in instances_to_move]).order_by( - order_by - ).update( - order=Case( - When(pk=self.pk, then=Value(None)), - default=F("order") + order_delta, - ) - ) - - # Update self.order from NULL to the correct value. + for instance in instances: + if instance.order is not None and order <= instance.order < self.order: + instance.order += 1 + instances_to_move.append(instance) + + # If there's nothing to move, just update self.order and return. + if not instances_to_move: + self.order = order + self.save(update_fields=["order"]) + return + + # Temporarily set order values to NULL to avoid unique constraint violations. + pks = [self.pk] + [instance.pk for instance in instances_to_move] + self._manager.filter(pk__in=pks).update(order=None) + + # Update orders to appropriate unique values. self.order = order - self.save(update_fields=["order"]) + self._manager.filter(pk__in=pks).bulk_update([self] + instances_to_move, fields=["order"]) @_retry(OperationalError) # retry on deadlock def swap(self, order: int) -> None: @@ -217,24 +217,35 @@ def swap(self, order: int) -> None: except StopIteration: other = None - # Temporarily set self.order to NULL and update the other instance's order in a single SQL command. - if other: - order_by = "order" if self.order < order else "-order" - self._manager.filter(pk__in=[self.pk, other.pk]).order_by(order_by).update( - order=Case( - When(pk=self.pk, then=Value(None)), - default=Value(self.order), - ) - ) - - # Update self.order from NULL to the correct value. - self.order = order - self.save(update_fields=["order"]) + # If there's no instance to swap with, just update self.order and return. + if not other: + self.order = order + self.save(update_fields=["order"]) + return + + # Temporarily set order values to NULL to avoid unique constraint violations. + self._manager.filter(pk__in=[self.pk, other.pk]).update(order=None) + + # Swap order values. + self.order, other.order = other.order, self.order + self._manager.filter(pk__in=[self.pk, other.pk]).bulk_update([self, other], fields=["order"]) def next(self) -> Self | None: + """ + Return the next instance in the ordering queryset, or None if there's no next instance. + Example: + a = OrderedModel(order=1) + b = OrderedModel(order=2) + + assert a.next() == b + assert b.next() is None + """ return self._get_ordering_queryset().filter(order__gt=self.order).first() def max_order(self) -> int | None: + """ + Return the maximum order value in the ordering queryset or None if there are no instances. + """ return self._get_ordering_queryset().aggregate(models.Max("order"))["order__max"] @staticmethod From 53a61c143ed71a03dc0f88d5204f3e333f5481b4 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 16:29:52 +0100 Subject: [PATCH 20/23] less retries --- engine/apps/base/models/ordered_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/apps/base/models/ordered_model.py b/engine/apps/base/models/ordered_model.py index cfcf922ef8..b286520de8 100644 --- a/engine/apps/base/models/ordered_model.py +++ b/engine/apps/base/models/ordered_model.py @@ -6,7 +6,7 @@ from django.db import IntegrityError, OperationalError, models, transaction -def _retry(exc: typing.Type[Exception] | tuple[typing.Type[Exception], ...], max_attempts: int = 15) -> typing.Callable: +def _retry(exc: typing.Type[Exception] | tuple[typing.Type[Exception], ...], max_attempts: int = 5) -> typing.Callable: """ A utility decorator for retrying a function on a given exception(s) up to max_attempts times. """ From 004c0952c7602f32f3df2a45d2e43c003e20d98f Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 16:56:18 +0100 Subject: [PATCH 21/23] _adjust_order --- .../serializers/personal_notification_rules.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/engine/apps/public_api/serializers/personal_notification_rules.py b/engine/apps/public_api/serializers/personal_notification_rules.py index 91a119941f..3d0e7df318 100644 --- a/engine/apps/public_api/serializers/personal_notification_rules.py +++ b/engine/apps/public_api/serializers/personal_notification_rules.py @@ -52,7 +52,7 @@ def create(self, validated_data): instance = UserNotificationPolicy.objects.create(**validated_data) if order is not None: - self._adjust_order(instance, manual_order, order) + self._adjust_order(instance, manual_order, order, created=True) return instance @@ -120,10 +120,16 @@ def _type_to_step_and_notification_channel(cls, rule_type): raise exceptions.ValidationError({"type": "Invalid type"}) @staticmethod - def _adjust_order(instance, manual_order, order): + def _adjust_order(instance, manual_order, order, created): # Passing order=-1 means that the policy should be moved to the end of the list. if order == -1: - order = instance.max_order() or 0 + if created: + # The policy was just created, so it is already at the end of the list. + return + + order = instance.max_order() + # max_order() can't be None here because at least one instance exists – the one we are moving. + assert order is not None # Negative order is not allowed. if order < 0: @@ -163,6 +169,6 @@ def update(self, instance, validated_data): manual_order = validated_data.pop("manual_order") order = validated_data.pop("order", None) if order is not None: - self._adjust_order(instance, manual_order, order) + self._adjust_order(instance, manual_order, order, created=False) return super().update(instance, validated_data) From 294a8e6fb2ab0cbc06cb54b388e4a903134f3058 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 17:05:57 +0100 Subject: [PATCH 22/23] more tests --- engine/apps/base/tests/test_ordered_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/engine/apps/base/tests/test_ordered_model.py b/engine/apps/base/tests/test_ordered_model.py index 0e541ad25d..63c8b819bd 100644 --- a/engine/apps/base/tests/test_ordered_model.py +++ b/engine/apps/base/tests/test_ordered_model.py @@ -214,6 +214,9 @@ def test_order_with_respect_to_isolation(): assert [i.order for i in instances] == [0, 1, 2, 3, 4] assert [i.order for i in other_instances] == [0, 1, 2, 3, 4] + assert instances[-1].next() is None + assert instances[-1].max_order() == 4 + instances[0].to(8) instances[1].swap(7) @@ -229,7 +232,7 @@ def test_order_with_respect_to_isolation(): # They are skipped by default because they might take a lot of time to run. # It could be useful to run them manually when making changes to the code, making sure # that the changes don't break concurrent operations. To run the tests, set SKIP_CONCURRENT to False. -SKIP_CONCURRENT = True +SKIP_CONCURRENT = False @pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests") From e9a73bbe57c40e23be60b0e6e19429093e122de8 Mon Sep 17 00:00:00 2001 From: Vadim Stepanov Date: Tue, 20 Jun 2023 17:06:13 +0100 Subject: [PATCH 23/23] more tests --- engine/apps/base/tests/test_ordered_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/apps/base/tests/test_ordered_model.py b/engine/apps/base/tests/test_ordered_model.py index 63c8b819bd..b312d23891 100644 --- a/engine/apps/base/tests/test_ordered_model.py +++ b/engine/apps/base/tests/test_ordered_model.py @@ -232,7 +232,7 @@ def test_order_with_respect_to_isolation(): # They are skipped by default because they might take a lot of time to run. # It could be useful to run them manually when making changes to the code, making sure # that the changes don't break concurrent operations. To run the tests, set SKIP_CONCURRENT to False. -SKIP_CONCURRENT = False +SKIP_CONCURRENT = True @pytest.mark.skipif(SKIP_CONCURRENT, reason="OrderedModel concurrent tests are skipped to speed up tests")