From c0ed927daa5dd4e2efce8458a2f18b371e485db5 Mon Sep 17 00:00:00 2001 From: Scott Kyle Date: Wed, 22 Jul 2015 15:51:14 -0700 Subject: [PATCH] Fix bug where order is wrong after adding objects This problem was caused by using the count of the objects in the intermediary join table as the starting sort value. This was insufficient because deleting objects would not update the sort values. So imagine with a fresh database, you added 5 objects then deleted 3. The max sort value would be 4 (the 5th object would use the count prior to adding, hence 4), but when you add a new object, its sort value would be 2 (current count of intermediary objects), NOT 5! This fixes that by querying for the maximum sort value in use by the source object's intermediary objects and using that as the starting sort value. Since bulk_create() was introduced in Django 1.4 and there is a dummy atomic() implementation (that actually was not being used), there is no need to use a different implementation to create intermediary objects on Django 1.5. Also, this ensures the right database is being used for the transaction, which was not the case before. --- sortedm2m/fields.py | 69 ++++++++++++++------------------------------- 1 file changed, 21 insertions(+), 48 deletions(-) diff --git a/sortedm2m/fields.py b/sortedm2m/fields.py index 8c4c1c3..90d1ea1 100644 --- a/sortedm2m/fields.py +++ b/sortedm2m/fields.py @@ -12,7 +12,6 @@ from django.utils.functional import curry from .compat import get_foreignkey_field_kwargs -from .compat import get_model from .compat import get_model_name from .forms import SortedMultipleChoiceField @@ -93,7 +92,7 @@ def _add_items(self, source_field_name, target_field_name, *objs): # *objs - objects to add. Either object instances, or primary keys of object instances. # If there aren't any objects, there is nothing to do. - from django.db.models import Model + from django.db.models import Max, Model if objs: # Django uses a set here, we need to use a list to keep the # correct ordering. @@ -122,7 +121,8 @@ def _add_items(self, source_field_name, target_field_name, *objs): new_ids.append(obj) db = router.db_for_write(self.through, instance=self.instance) - vals = (self.through._default_manager.using(db) + manager = self.through._default_manager.using(db) + vals = (manager .values_list(target_field_name, flat=True) .filter(**{ source_field_name: self._fk_val, @@ -144,27 +144,23 @@ def _add_items(self, source_field_name, target_field_name, *objs): signals.m2m_changed.send(sender=rel.through, action='pre_add', instance=self.instance, reverse=self.reverse, model=self.model, pk_set=new_ids_set, using=db) + # Add the ones that aren't there already - sort_field_name = self.through._sort_field_name - sort_field = self.through._meta.get_field_by_name(sort_field_name)[0] - if django.VERSION < (1, 6): - for obj_id in new_ids: - self.through._default_manager.using(db).create(**{ - '%s_id' % source_field_name: self._fk_val, # Django 1.5 compatibility - '%s_id' % target_field_name: obj_id, - sort_field_name: sort_field.get_default(), + with atomic(using=db): + fk_val = self._fk_val + source_queryset = manager.filter(**{'%s_id' % source_field_name: fk_val}) + sort_field_name = self.through._sort_field_name + sort_value_max = source_queryset.aggregate(max=Max(sort_field_name))['max'] or 0 + + manager.bulk_create([ + self.through(**{ + '%s_id' % source_field_name: fk_val, + '%s_id' % target_field_name: pk, + sort_field_name: sort_value_max + i + 1, }) - else: - with transaction.atomic(): - sort_field_default = sort_field.get_default() - self.through._default_manager.using(db).bulk_create([ - self.through(**{ - '%s_id' % source_field_name: self._fk_val, - '%s_id' % target_field_name: v, - sort_field_name: sort_field_default + i, - }) - for i, v in enumerate(new_ids) - ]) + for i, pk in enumerate(new_ids) + ]) + if self.reverse or source_field_name == self.source_field_name: # Don't send the signal when we are inserting the # duplicate data row for symmetrical reverse entries. @@ -195,9 +191,8 @@ class SortedManyToManyField(ManyToManyField): ''' def __init__(self, to, sorted=True, **kwargs): self.sorted = sorted - self.sort_value_field_name = kwargs.pop( - 'sort_value_field_name', - SORT_VALUE_FIELD_NAME) + self.sort_value_field_name = kwargs.pop('sort_value_field_name', SORT_VALUE_FIELD_NAME) + super(SortedManyToManyField, self).__init__(to, **kwargs) if self.sorted: self.help_text = kwargs.get('help_text', None) @@ -316,31 +311,9 @@ def get_rel_to_model_and_object_name(self, klass): to_object_name = to_model._meta.object_name return to_model, to_object_name - def get_intermediate_model_sort_value_field_default(self, klass): - def default_sort_value(name): - model = get_model(klass._meta.app_label, name) - # Django 1.5 support. - if django.VERSION < (1, 6): - return model._default_manager.count() - else: - from django.db.utils import ProgrammingError, OperationalError - try: - # We need to catch if the model is not yet migrated in the - # database. The default function is still called in this case while - # running the migration. So we mock the return value of 0. - with transaction.atomic(): - return model._default_manager.count() - except (ProgrammingError, OperationalError): - return 0 - - name = self.get_intermediate_model_name(klass) - default_sort_value = curry(default_sort_value, name) - return default_sort_value - def get_intermediate_model_sort_value_field(self, klass): - default_sort_value = self.get_intermediate_model_sort_value_field_default(klass) field_name = self.sort_value_field_name - field = models.IntegerField(default=default_sort_value) + field = models.IntegerField(default=0) return field_name, field def get_intermediate_model_from_field(self, klass):