From 8ea956c4067f166642336a32c2c478a55dc14c56 Mon Sep 17 00:00:00 2001 From: Anders <6058745+ddabble@users.noreply.github.com> Date: Sun, 8 Sep 2024 01:33:12 +0200 Subject: [PATCH] Added a disable_history() context manager This replaces setting the `skip_history_when_saving` attribute on a model instance, which has been deprecated (and replaced with `disable_history()` or the utils mentioned below where possible), as well as having been removed from the docs and tests. (`HistoricalRecords.post_delete()` does not generate deprecation warnings on it, since support for `skip_history_when_saving` while deleting objects is part of the next, unreleased version.) The context manager was inspired by https://github.com/jazzband/django-simple-history/issues/642#issuecomment-922110914 Also added a couple useful utils related to `disable_history()`: `is_history_disabled()` and a `DisableHistoryInfo` dataclass - see their docstrings in `utils.py`. --- CHANGES.rst | 6 + docs/disabling_history.rst | 48 +- simple_history/manager.py | 9 +- simple_history/models.py | 51 +- simple_history/tests/models.py | 5 +- .../tests/tests/test_deprecation.py | 27 +- simple_history/tests/tests/test_manager.py | 43 ++ simple_history/tests/tests/test_models.py | 43 +- simple_history/tests/tests/test_utils.py | 495 +++++++++++++++++- simple_history/tests/tests/utils.py | 28 +- simple_history/utils.py | 165 +++++- 11 files changed, 815 insertions(+), 105 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 0b567c17..5c3ed6da 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,9 @@ Unreleased updating an object (gh-1262) - Added ``delete_without_historical_record()`` to all history-tracked model objects, which complements ``save_without_historical_record()`` (gh-1387) +- Added a ``disable_history()`` context manager, which disables history record creation + while it's active; see usage in the docs under "Disable Creating Historical Records" + (gh-1387) **Breaking changes:** @@ -23,6 +26,9 @@ Unreleased - Deprecated the undocumented ``HistoricalRecords.thread`` - use ``HistoricalRecords.context`` instead. The former attribute will be removed in version 3.10 (gh-1387) +- Deprecated ``skip_history_when_saving`` in favor of the newly added + ``disable_history()`` context manager. The former attribute will be removed in + version 4.0 (gh-1387) **Fixes and improvements:** diff --git a/docs/disabling_history.rst b/docs/disabling_history.rst index 1e3c16f2..adafd9df 100644 --- a/docs/disabling_history.rst +++ b/docs/disabling_history.rst @@ -8,34 +8,40 @@ These methods are automatically added to a model when registering it for history (i.e. defining a ``HistoricalRecords`` manager on the model), and can be called instead of ``save()`` and ``delete()``, respectively. -Setting the ``skip_history_when_saving`` attribute --------------------------------------------------- +Using the ``disable_history()`` context manager +----------------------------------------------- -If you want to save or delete model objects without triggering the creation of any -historical records, you can do the following: +``disable_history()`` has three ways of being called: -.. code-block:: python - - poll.skip_history_when_saving = True - # It applies both when saving... - poll.save() - # ...and when deleting - poll.delete() - # We recommend deleting the attribute afterward - del poll.skip_history_when_saving +#. With no arguments: This will disable all historical record creation + (as if the ``SIMPLE_HISTORY_ENABLED`` setting was set to ``False``; see below) + within the context manager's ``with`` block. +#. With ``only_for_model``: Only disable history creation for the provided model type. +#. With ``instance_predicate``: Only disable history creation for model instances passing + this predicate. -This also works when creating an object, but only when calling ``save()``: +See some examples below: .. code-block:: python - # Note that `Poll.objects.create()` is not called - poll = Poll(question="Why?") - poll.skip_history_when_saving = True - poll.save() - del poll.skip_history_when_saving + from simple_history.utils import disable_history + + # No historical records are created + with disable_history(): + User.objects.create(...) + Poll.objects.create(...) + + # A historical record is only created for the poll + with disable_history(only_for_model=User): + User.objects.create(...) + Poll.objects.create(...) -.. note:: - Historical records will always be created when calling the ``create()`` manager method. + # A historical record is created for the second poll, but not for the first poll + # (remember to check the instance type in the passed function if you expect + # historical records of more than one model to be created inside the `with` block) + with disable_history(instance_predicate=lambda poll: "ignore" in poll.question): + Poll.objects.create(question="ignore this") + Poll.objects.create(question="what's up?") The ``SIMPLE_HISTORY_ENABLED`` setting -------------------------------------- diff --git a/simple_history/manager.py b/simple_history/manager.py index 03f53034..66b2cb08 100644 --- a/simple_history/manager.py +++ b/simple_history/manager.py @@ -1,4 +1,3 @@ -from django.conf import settings from django.db import models from django.db.models import Exists, OuterRef, Q, QuerySet from django.utils import timezone @@ -216,8 +215,9 @@ def bulk_history_create( If called by bulk_update_with_history, use the update boolean and save the history_type accordingly. """ - if not getattr(settings, "SIMPLE_HISTORY_ENABLED", True): - return + info = utils.DisableHistoryInfo.get() + if info.disabled_globally: + return [] history_type = "+" if update: @@ -225,6 +225,9 @@ def bulk_history_create( historical_instances = [] for instance in objs: + if info.disabled_for(instance): + continue + history_user = getattr( instance, "_history_user", diff --git a/simple_history/models.py b/simple_history/models.py index 27ca59b8..d612b00b 100644 --- a/simple_history/models.py +++ b/simple_history/models.py @@ -7,6 +7,7 @@ from typing import ( TYPE_CHECKING, Any, + Callable, ClassVar, Dict, Iterable, @@ -195,18 +196,22 @@ def contribute_to_class(self, cls, name): warnings.warn(msg, UserWarning) def add_extra_methods(self, cls): + def get_instance_predicate( + self: models.Model, + ) -> Callable[[models.Model], bool]: + def predicate(instance: models.Model) -> bool: + return instance == self + + return predicate + def save_without_historical_record(self: models.Model, *args, **kwargs): """ Save the model instance without creating a historical record. Make sure you know what you're doing before using this method. """ - self.skip_history_when_saving = True - try: - ret = self.save(*args, **kwargs) - finally: - del self.skip_history_when_saving - return ret + with utils.disable_history(instance_predicate=get_instance_predicate(self)): + return self.save(*args, **kwargs) def delete_without_historical_record(self: models.Model, *args, **kwargs): """ @@ -214,12 +219,8 @@ def delete_without_historical_record(self: models.Model, *args, **kwargs): Make sure you know what you're doing before using this method. """ - self.skip_history_when_saving = True - try: - ret = self.delete(*args, **kwargs) - finally: - del self.skip_history_when_saving - return ret + with utils.disable_history(instance_predicate=get_instance_predicate(self)): + return self.delete(*args, **kwargs) cls.save_without_historical_record = save_without_historical_record cls.delete_without_historical_record = delete_without_historical_record @@ -682,18 +683,22 @@ def get_meta_options(self, model): def post_save( self, instance: models.Model, created: bool, using: str = None, **kwargs ): - if not getattr(settings, "SIMPLE_HISTORY_ENABLED", True): - return if hasattr(instance, "skip_history_when_saving"): + warnings.warn( + "Setting 'skip_history_when_saving' has been deprecated in favor of the" + " 'disable_history()' context manager." + " Support for the former attribute will be removed in version 4.0.", + DeprecationWarning, + ) + return + if utils.is_history_disabled(instance): return if not kwargs.get("raw", False): self.create_historical_record(instance, created and "+" or "~", using=using) def post_delete(self, instance: models.Model, using: str = None, **kwargs): - if not getattr(settings, "SIMPLE_HISTORY_ENABLED", True): - return - if hasattr(instance, "skip_history_when_saving"): + if utils.is_history_disabled(instance): return if self.cascade_delete_history: @@ -709,10 +714,16 @@ def get_change_reason_for_object(self, instance, history_type, using): """ return utils.get_change_reason_from_object(instance) - def m2m_changed(self, instance, action, attr, pk_set, reverse, **_): - if not getattr(settings, "SIMPLE_HISTORY_ENABLED", True): - return + def m2m_changed(self, instance: models.Model, action: str, **kwargs): if hasattr(instance, "skip_history_when_saving"): + warnings.warn( + "Setting 'skip_history_when_saving' has been deprecated in favor of the" + " 'disable_history()' context manager." + " Support for the former attribute will be removed in version 4.0.", + DeprecationWarning, + ) + return + if utils.is_history_disabled(instance): return if action in ("post_add", "post_remove", "post_clear"): diff --git a/simple_history/tests/models.py b/simple_history/tests/models.py index f35b5cf6..87e840c7 100644 --- a/simple_history/tests/models.py +++ b/simple_history/tests/models.py @@ -5,10 +5,9 @@ from django.conf import settings from django.db import models from django.db.models.deletion import CASCADE -from django.db.models.fields.related import ForeignKey from django.urls import reverse -from simple_history import register +from simple_history import register, utils from simple_history.manager import HistoricalQuerySet, HistoryManager from simple_history.models import HistoricalRecords, HistoricForeignKey @@ -355,7 +354,7 @@ class Person(models.Model): history = HistoricalRecords() def save(self, *args, **kwargs): - if hasattr(self, "skip_history_when_saving"): + if utils.DisableHistoryInfo.get().disabled_for(self): raise RuntimeError("error while saving") else: super().save(*args, **kwargs) diff --git a/simple_history/tests/tests/test_deprecation.py b/simple_history/tests/tests/test_deprecation.py index 1364f8d1..b3f327c0 100644 --- a/simple_history/tests/tests/test_deprecation.py +++ b/simple_history/tests/tests/test_deprecation.py @@ -1,11 +1,14 @@ -import unittest +from django.test import TestCase +from django.utils import timezone from simple_history import __version__ from simple_history.models import HistoricalRecords from simple_history.templatetags.simple_history_admin_list import display_list +from ..models import Place, PollWithManyToMany -class DeprecationWarningTest(unittest.TestCase): + +class DeprecationWarningTest(TestCase): def test__display_list__warns_deprecation(self): with self.assertWarns(DeprecationWarning): display_list({}) @@ -30,3 +33,23 @@ def test__HistoricalRecords_thread__warns_deprecation(self): # DEV: `_DeprecatedThreadDescriptor` and the `thread` attribute of # `HistoricalRecords` should be removed when 3.10 is released self.assertLess(__version__, "3.10") + + def test__skip_history_when_saving__warns_deprecation(self): + place = Place.objects.create(name="Here") + + poll = PollWithManyToMany(question="why?", pub_date=timezone.now()) + poll.skip_history_when_saving = True + with self.assertWarns(DeprecationWarning): + poll.save() + poll.question = "how?" + with self.assertWarns(DeprecationWarning): + poll.save() + with self.assertWarns(DeprecationWarning): + poll.places.add(place) + self.assertEqual(PollWithManyToMany.history.count(), 0) + self.assertEqual(poll.history.count(), 0) + + # DEV: The `if` statements checking for `skip_history_when_saving` (in the + # `post_save()` and `m2m_changed()` methods of `HistoricalRecords`) + # should be removed when 4.0 is released + self.assertLess(__version__, "4.0") diff --git a/simple_history/tests/tests/test_manager.py b/simple_history/tests/tests/test_manager.py index acb9e025..dbbefcd0 100644 --- a/simple_history/tests/tests/test_manager.py +++ b/simple_history/tests/tests/test_manager.py @@ -6,6 +6,7 @@ from django.test import TestCase, override_settings, skipUnlessDBFeature from simple_history.manager import SIMPLE_HISTORY_REVERSE_ATTR_NAME +from simple_history.utils import disable_history from ..models import Choice, Document, Poll, RankedDocument from .utils import HistoricalTestCase @@ -289,6 +290,23 @@ def test_simple_bulk_history_create_without_history_enabled(self): Poll.history.bulk_history_create(self.data) self.assertEqual(Poll.history.count(), 0) + def test_simple_bulk_history_create_with__disable_history(self): + with disable_history(): + Poll.history.bulk_history_create(self.data) + self.assertEqual(Poll.history.count(), 0) + + def test_simple_bulk_history_create_with__disable_history__only_for_model(self): + with disable_history(only_for_model=Poll): + Poll.history.bulk_history_create(self.data) + self.assertEqual(Poll.history.count(), 0) + + def test_simple_bulk_history_create_with__disable_history__instance_predicate(self): + with disable_history(instance_predicate=lambda poll: poll.id == 2): + Poll.history.bulk_history_create(self.data) + self.assertEqual(Poll.history.count(), 3) + historical_poll_ids = sorted(record.id for record in Poll.history.all()) + self.assertListEqual(historical_poll_ids, [1, 3, 4]) + def test_bulk_history_create_with_change_reason(self): for poll in self.data: poll._change_reason = "reason" @@ -412,6 +430,31 @@ def test_simple_bulk_history_create(self): self.assertEqual(created, []) self.assertEqual(Poll.history.count(), 4) + @override_settings(SIMPLE_HISTORY_ENABLED=False) + def test_simple_bulk_history_update_without_history_enabled(self): + Poll.history.bulk_history_create(self.data, update=True) + self.assertEqual(Poll.history.count(), 0) + + def test_simple_bulk_history_update_with__disable_history(self): + with disable_history(): + Poll.history.bulk_history_create(self.data, update=True) + self.assertEqual(Poll.history.count(), 0) + + def test_simple_bulk_history_update_with__disable_history__only_for_model(self): + with disable_history(only_for_model=Poll): + Poll.history.bulk_history_create(self.data, update=True) + self.assertEqual(Poll.history.count(), 0) + + def test_simple_bulk_history_update_with__disable_history__instance_predicate(self): + with disable_history(instance_predicate=lambda poll: poll.id == 2): + Poll.history.bulk_history_create(self.data, update=True) + self.assertEqual(Poll.history.count(), 3) + historical_poll_ids = sorted(record.id for record in Poll.history.all()) + self.assertListEqual(historical_poll_ids, [1, 3, 4]) + self.assertTrue( + all(record.history_type == "~" for record in Poll.history.all()) + ) + def test_bulk_history_create_with_change_reason(self): for poll in self.data: poll._change_reason = "reason" diff --git a/simple_history/tests/tests/test_models.py b/simple_history/tests/tests/test_models.py index c9daea9e..b9ffb8fa 100644 --- a/simple_history/tests/tests/test_models.py +++ b/simple_history/tests/tests/test_models.py @@ -318,7 +318,7 @@ def test__delete_without_historical_record__creates_no_records(self): ) @override_settings(SIMPLE_HISTORY_ENABLED=False) - def test_save_with_disabled_history(self): + def test_saving_without_history_enabled_creates_no_records(self): anthony = Person.objects.create(name="Anthony Gillard") anthony.name = "something else" anthony.save() @@ -330,7 +330,6 @@ def test_save_raises_exception(self): anthony = Person(name="Anthony Gillard") with self.assertRaises(RuntimeError): anthony.save_without_historical_record() - self.assertFalse(hasattr(anthony, "skip_history_when_saving")) self.assertEqual(Person.history.count(), 0) anthony.save() self.assertEqual(Person.history.count(), 1) @@ -2430,45 +2429,9 @@ def test_m2m_relation(self): self.assertEqual(self.poll.history.all()[0].places.count(), 0) self.assertEqual(poll_2.history.all()[0].places.count(), 2) - def test_skip_history_when_updating_an_object(self): - skip_poll = PollWithManyToMany.objects.create( - question="skip history?", pub_date=today - ) - self.assertEqual(skip_poll.history.all().count(), 1) - self.assertEqual(skip_poll.history.all()[0].places.count(), 0) - - skip_poll.skip_history_when_saving = True - - skip_poll.question = "huh?" - skip_poll.save() - skip_poll.places.add(self.place) - - self.assertEqual(skip_poll.history.all().count(), 1) - self.assertEqual(skip_poll.history.all()[0].places.count(), 0) - - del skip_poll.skip_history_when_saving - place_2 = Place.objects.create(name="Place 2") - - skip_poll.places.add(place_2) - - self.assertEqual(skip_poll.history.all().count(), 2) - self.assertEqual(skip_poll.history.all()[0].places.count(), 2) - - def test_skip_history_when_creating_an_object(self): - initial_poll_count = PollWithManyToMany.objects.count() - - skip_poll = PollWithManyToMany(question="skip history?", pub_date=today) - skip_poll.skip_history_when_saving = True - skip_poll.save() - skip_poll.places.add(self.place) - - self.assertEqual(skip_poll.history.count(), 0) - self.assertEqual(PollWithManyToMany.objects.count(), initial_poll_count + 1) - self.assertEqual(skip_poll.places.count(), 1) - @override_settings(SIMPLE_HISTORY_ENABLED=False) - def test_saving_with_disabled_history_doesnt_create_records(self): - # 1 from `setUp()` + def test_saving_without_history_enabled_creates_no_records(self): + # 1 record from `setUp()` self.assertEqual(PollWithManyToMany.history.count(), 1) poll = PollWithManyToMany.objects.create( diff --git a/simple_history/tests/tests/test_utils.py b/simple_history/tests/tests/test_utils.py index 002e74c9..9ff334fa 100644 --- a/simple_history/tests/tests/test_utils.py +++ b/simple_history/tests/tests/test_utils.py @@ -1,9 +1,11 @@ import unittest +from collections.abc import Callable from dataclasses import dataclass from datetime import datetime -from typing import Optional, Type +from enum import Enum, auto +from typing import Final, List, Optional, Type, Union from unittest import skipUnless -from unittest.mock import Mock, patch +from unittest.mock import ANY, Mock, patch import django from django.contrib.auth import get_user_model @@ -14,16 +16,20 @@ from simple_history.exceptions import AlternativeManagerError, NotHistoricalModelError from simple_history.manager import HistoryManager -from simple_history.models import HistoricalChanges +from simple_history.models import HistoricalChanges, HistoricalRecords from simple_history.utils import ( + DisableHistoryInfo, + _StoredDisableHistoryInfo, bulk_create_with_history, bulk_update_with_history, + disable_history, get_historical_records_of_instance, get_history_manager_for_model, get_history_model_for_model, get_m2m_field_name, get_m2m_reverse_field_name, get_pk_name, + is_history_disabled, update_change_reason, ) @@ -85,10 +91,461 @@ TrackedWithConcreteBase, Voter, ) +from .utils import HistoricalTestCase User = get_user_model() +class DisableHistoryTestCase(HistoricalTestCase): + """Tests related to the ``disable_history()`` context manager.""" + + def test_disable_history_info(self): + """Test that the various utilities for checking the current info on how + historical record creation is disabled, return the expected values. + This includes ``DisableHistoryInfo`` and ``is_history_disabled()``, as well as + the ``_StoredDisableHistoryInfo`` stored through ``HistoricalRecords.context``. + """ + + class DisableHistoryMode(Enum): + NOT_DISABLED = auto() + GLOBALLY = auto() + PREDICATE = auto() + + poll1 = Poll.objects.create(question="question?", pub_date=timezone.now()) + poll2 = Poll.objects.create(question="ignore this", pub_date=timezone.now()) + + def assert_disable_history_info( + mode: DisableHistoryMode, predicate_target: Union[Type[Poll], Poll] = None + ): + # Check the stored info + attr_name = _StoredDisableHistoryInfo.LOCAL_STORAGE_ATTR_NAME + info = getattr(HistoricalRecords.context, attr_name, None) + if mode is DisableHistoryMode.NOT_DISABLED: + self.assertIsNone(info) + elif mode is DisableHistoryMode.GLOBALLY: + self.assertEqual( + info, _StoredDisableHistoryInfo(instance_predicate=None) + ) + elif mode is DisableHistoryMode.PREDICATE: + self.assertEqual( + info, _StoredDisableHistoryInfo(instance_predicate=ANY) + ) + self.assertIsInstance(info.instance_predicate, Callable) + + # Check `DisableHistoryInfo` + info = DisableHistoryInfo.get() + self.assertEqual(info.not_disabled, mode == DisableHistoryMode.NOT_DISABLED) + self.assertEqual( + info.disabled_globally, mode == DisableHistoryMode.GLOBALLY + ) + self.assertEqual( + info.disabled_for(poll1), + predicate_target is Poll or predicate_target == poll1, + ) + self.assertEqual( + info.disabled_for(poll2), + predicate_target is Poll or predicate_target == poll2, + ) + + # Check `is_history_disabled()` + self.assertEqual(is_history_disabled(), mode == DisableHistoryMode.GLOBALLY) + self.assertEqual( + is_history_disabled(poll1), + mode == DisableHistoryMode.GLOBALLY + or predicate_target is Poll + or predicate_target == poll1, + ) + self.assertEqual( + is_history_disabled(poll2), + mode == DisableHistoryMode.GLOBALLY + or predicate_target is Poll + or predicate_target == poll2, + ) + + assert_disable_history_info(DisableHistoryMode.NOT_DISABLED) + + with disable_history(): + assert_disable_history_info(DisableHistoryMode.GLOBALLY) + + assert_disable_history_info(DisableHistoryMode.NOT_DISABLED) + + with disable_history(only_for_model=Poll): + assert_disable_history_info(DisableHistoryMode.PREDICATE, Poll) + + assert_disable_history_info(DisableHistoryMode.NOT_DISABLED) + + with disable_history(instance_predicate=lambda poll: "ignore" in poll.question): + assert_disable_history_info(DisableHistoryMode.PREDICATE, poll2) + + assert_disable_history_info(DisableHistoryMode.NOT_DISABLED) + + @staticmethod + def _test_disable_poll_history(**kwargs): + """Create, update and delete some ``Poll`` instances outside and inside + the context manager.""" + last_pk = 0 + + def manipulate_poll( + poll: Poll = None, *, create=False, update=False, delete=False + ) -> Poll: + if create: + nonlocal last_pk + last_pk += 1 + poll = Poll.objects.create( + pk=last_pk, question=f"qUESTION {last_pk}?", pub_date=timezone.now() + ) + if update: + poll.question = f"Question {poll.pk}!" + poll.save() + if delete: + poll.delete() + return poll + + poll1 = manipulate_poll(create=True, update=True, delete=True) # noqa: F841 + poll2 = manipulate_poll(create=True, update=True) + poll3 = manipulate_poll(create=True) + + with disable_history(**kwargs): + manipulate_poll(poll2, delete=True) + manipulate_poll(poll3, update=True) + poll4 = manipulate_poll(create=True, update=True, delete=True) # noqa: F841 + poll5 = manipulate_poll(create=True, update=True) + poll6 = manipulate_poll(create=True) + + manipulate_poll(poll5, delete=True) + manipulate_poll(poll6, update=True) + poll7 = manipulate_poll(create=True, update=True, delete=True) # noqa: F841 + + expected_poll_records_before_disable: Final = [ + {"id": 1, "question": "qUESTION 1?", "history_type": "+"}, + {"id": 1, "question": "Question 1!", "history_type": "~"}, + {"id": 1, "question": "Question 1!", "history_type": "-"}, + {"id": 2, "question": "qUESTION 2?", "history_type": "+"}, + {"id": 2, "question": "Question 2!", "history_type": "~"}, + {"id": 3, "question": "qUESTION 3?", "history_type": "+"}, + ] + expected_poll_records_during_disable: Final = [ + {"id": 2, "question": "Question 2!", "history_type": "-"}, + {"id": 3, "question": "Question 3!", "history_type": "~"}, + {"id": 4, "question": "qUESTION 4?", "history_type": "+"}, + {"id": 4, "question": "Question 4!", "history_type": "~"}, + {"id": 4, "question": "Question 4!", "history_type": "-"}, + {"id": 5, "question": "qUESTION 5?", "history_type": "+"}, + {"id": 5, "question": "Question 5!", "history_type": "~"}, + {"id": 6, "question": "qUESTION 6?", "history_type": "+"}, + ] + expected_poll_records_after_disable: Final = [ + {"id": 5, "question": "Question 5!", "history_type": "-"}, + {"id": 6, "question": "Question 6!", "history_type": "~"}, + {"id": 7, "question": "qUESTION 7?", "history_type": "+"}, + {"id": 7, "question": "Question 7!", "history_type": "~"}, + {"id": 7, "question": "Question 7!", "history_type": "-"}, + ] + + @staticmethod + def _test_disable_poll_with_m2m_history(**kwargs): + """Create some ``PollWithManyToMany`` instances and add, remove, set and clear + their ``Place`` relations outside and inside the context manager.""" + last_pk = 0 + place1 = Place.objects.create(pk=1, name="1") + place2 = Place.objects.create(pk=2, name="2") + + def manipulate_places( + poll=None, *, add=False, remove=False, set=False, clear=False + ) -> PollWithManyToMany: + if not poll: + nonlocal last_pk + last_pk += 1 + poll = PollWithManyToMany.objects.create( + pk=last_pk, question=f"{last_pk}?", pub_date=timezone.now() + ) + if add: + poll.places.add(place1) + if remove: + poll.places.remove(place1) + if set: + poll.places.set([place2]) + if clear: + poll.places.clear() + return poll + + poll1 = manipulate_places( # noqa: F841 + add=True, remove=True, set=True, clear=True + ) + poll2 = manipulate_places(add=True, remove=True, set=True) + poll3 = manipulate_places(add=True, remove=True) + poll4 = manipulate_places(add=True) + + with disable_history(**kwargs): + manipulate_places(poll2, clear=True) + manipulate_places(poll3, set=True, clear=True) + manipulate_places(poll4, remove=True, set=True, clear=True) + poll5 = manipulate_places( # noqa: F841 + add=True, remove=True, set=True, clear=True + ) + poll6 = manipulate_places(add=True, remove=True, set=True) + poll7 = manipulate_places(add=True, remove=True) + poll8 = manipulate_places(add=True) + + manipulate_places(poll6, clear=True) + manipulate_places(poll7, set=True, clear=True) + manipulate_places(poll8, remove=True, set=True, clear=True) + poll9 = manipulate_places( # noqa: F841 + add=True, remove=True, set=True, clear=True + ) + + expected_poll_with_m2m_records_before_disable: Final = [ + {"id": 1, "question": "1?", "history_type": "+", "places": []}, + {"id": 1, "question": "1?", "history_type": "~", "places": [Place(pk=1)]}, + {"id": 1, "question": "1?", "history_type": "~", "places": []}, + {"id": 1, "question": "1?", "history_type": "~", "places": [Place(pk=2)]}, + {"id": 1, "question": "1?", "history_type": "~", "places": []}, + {"id": 2, "question": "2?", "history_type": "+", "places": []}, + {"id": 2, "question": "2?", "history_type": "~", "places": [Place(pk=1)]}, + {"id": 2, "question": "2?", "history_type": "~", "places": []}, + {"id": 2, "question": "2?", "history_type": "~", "places": [Place(pk=2)]}, + {"id": 3, "question": "3?", "history_type": "+", "places": []}, + {"id": 3, "question": "3?", "history_type": "~", "places": [Place(pk=1)]}, + {"id": 3, "question": "3?", "history_type": "~", "places": []}, + {"id": 4, "question": "4?", "history_type": "+", "places": []}, + {"id": 4, "question": "4?", "history_type": "~", "places": [Place(pk=1)]}, + ] + expected_poll_with_m2m_records_during_disable: Final = [ + {"id": 2, "question": "2?", "history_type": "~", "places": []}, + {"id": 3, "question": "3?", "history_type": "~", "places": [Place(pk=2)]}, + {"id": 3, "question": "3?", "history_type": "~", "places": []}, + {"id": 4, "question": "4?", "history_type": "~", "places": []}, + {"id": 4, "question": "4?", "history_type": "~", "places": [Place(pk=2)]}, + {"id": 4, "question": "4?", "history_type": "~", "places": []}, + {"id": 5, "question": "5?", "history_type": "+", "places": []}, + {"id": 5, "question": "5?", "history_type": "~", "places": [Place(pk=1)]}, + {"id": 5, "question": "5?", "history_type": "~", "places": []}, + {"id": 5, "question": "5?", "history_type": "~", "places": [Place(pk=2)]}, + {"id": 5, "question": "5?", "history_type": "~", "places": []}, + {"id": 6, "question": "6?", "history_type": "+", "places": []}, + {"id": 6, "question": "6?", "history_type": "~", "places": [Place(pk=1)]}, + {"id": 6, "question": "6?", "history_type": "~", "places": []}, + {"id": 6, "question": "6?", "history_type": "~", "places": [Place(pk=2)]}, + {"id": 7, "question": "7?", "history_type": "+", "places": []}, + {"id": 7, "question": "7?", "history_type": "~", "places": [Place(pk=1)]}, + {"id": 7, "question": "7?", "history_type": "~", "places": []}, + {"id": 8, "question": "8?", "history_type": "+", "places": []}, + {"id": 8, "question": "8?", "history_type": "~", "places": [Place(pk=1)]}, + ] + expected_poll_with_m2m_records_after_disable: Final = [ + {"id": 6, "question": "6?", "history_type": "~", "places": []}, + {"id": 7, "question": "7?", "history_type": "~", "places": [Place(pk=2)]}, + {"id": 7, "question": "7?", "history_type": "~", "places": []}, + {"id": 8, "question": "8?", "history_type": "~", "places": []}, + {"id": 8, "question": "8?", "history_type": "~", "places": [Place(pk=2)]}, + {"id": 8, "question": "8?", "history_type": "~", "places": []}, + {"id": 9, "question": "9?", "history_type": "+", "places": []}, + {"id": 9, "question": "9?", "history_type": "~", "places": [Place(pk=1)]}, + {"id": 9, "question": "9?", "history_type": "~", "places": []}, + {"id": 9, "question": "9?", "history_type": "~", "places": [Place(pk=2)]}, + {"id": 9, "question": "9?", "history_type": "~", "places": []}, + ] + + def test__disable_history__with_no_args(self): + """Test that no historical records are created inside the context manager with + no arguments (i.e. history is globally disabled).""" + # Test with `Poll` instances + self._test_disable_poll_history() + expected_records = [ + *self.expected_poll_records_before_disable, + *self.expected_poll_records_after_disable, + ] + self.assert_all_records_of_model_equal(Poll, expected_records) + + # Test with `PollWithManyToMany` instances + self._test_disable_poll_with_m2m_history() + expected_records = [ + *self.expected_poll_with_m2m_records_before_disable, + *self.expected_poll_with_m2m_records_after_disable, + ] + self.assert_all_records_of_model_equal(PollWithManyToMany, expected_records) + + def test__disable_history__with__only_for_model__poll(self): + """Test that no historical records are created for ``Poll`` instances inside + the context manager with ``only_for_model=Poll`` as argument.""" + # Test with `Poll` instances + self._test_disable_poll_history(only_for_model=Poll) + expected_records = [ + *self.expected_poll_records_before_disable, + *self.expected_poll_records_after_disable, + ] + self.assert_all_records_of_model_equal(Poll, expected_records) + + # Test with `PollWithManyToMany` instances + self._test_disable_poll_with_m2m_history(only_for_model=Poll) + expected_records = [ + *self.expected_poll_with_m2m_records_before_disable, + *self.expected_poll_with_m2m_records_during_disable, + *self.expected_poll_with_m2m_records_after_disable, + ] + self.assert_all_records_of_model_equal(PollWithManyToMany, expected_records) + + def test__disable_history__with__only_for_model__poll_with_m2m(self): + """Test that no historical records are created for ``PollWithManyToMany`` + instances inside the context manager with ``only_for_model=PollWithManyToMany`` + as argument.""" + # Test with `Poll` instances + self._test_disable_poll_history(only_for_model=PollWithManyToMany) + expected_records = [ + *self.expected_poll_records_before_disable, + *self.expected_poll_records_during_disable, + *self.expected_poll_records_after_disable, + ] + self.assert_all_records_of_model_equal(Poll, expected_records) + + # Test with `PollWithManyToMany` instances + self._test_disable_poll_with_m2m_history(only_for_model=PollWithManyToMany) + expected_records = [ + *self.expected_poll_with_m2m_records_before_disable, + *self.expected_poll_with_m2m_records_after_disable, + ] + self.assert_all_records_of_model_equal(PollWithManyToMany, expected_records) + + def test__disable_history__with__instance_predicate(self): + """Test that no historical records are created inside the context manager, for + model instances that match the provided ``instance_predicate`` argument.""" + # Test with `Poll` instances + self._test_disable_poll_history(instance_predicate=lambda poll: poll.pk == 4) + expected_records = [ + *self.expected_poll_records_before_disable, + *filter( + lambda poll_dict: poll_dict["id"] != 4, + self.expected_poll_records_during_disable, + ), + *self.expected_poll_records_after_disable, + ] + self.assert_all_records_of_model_equal(Poll, expected_records) + + # Test with `PollWithManyToMany` instances + self._test_disable_poll_with_m2m_history( + instance_predicate=lambda poll: poll.pk == 5 + ) + expected_records = [ + *self.expected_poll_with_m2m_records_before_disable, + *filter( + lambda poll_dict: poll_dict["id"] != 5, + self.expected_poll_with_m2m_records_during_disable, + ), + *self.expected_poll_with_m2m_records_after_disable, + ] + self.assert_all_records_of_model_equal(PollWithManyToMany, expected_records) + + def test__disable_history__for_queryset_delete(self): + """Test that no historical records are created inside the context manager when + deleting objects using the ``delete()`` queryset method.""" + Poll.objects.create(pk=1, question="delete me", pub_date=timezone.now()) + Poll.objects.create(pk=2, question="keep me", pub_date=timezone.now()) + Poll.objects.create(pk=3, question="keep me", pub_date=timezone.now()) + Poll.objects.create(pk=4, question="delete me", pub_date=timezone.now()) + + with disable_history(): + Poll.objects.filter(question__startswith="delete").delete() + + expected_records = [ + {"id": 1, "question": "delete me", "history_type": "+"}, + {"id": 2, "question": "keep me", "history_type": "+"}, + {"id": 3, "question": "keep me", "history_type": "+"}, + {"id": 4, "question": "delete me", "history_type": "+"}, + ] + self.assert_all_records_of_model_equal(Poll, expected_records) + + Poll.objects.all().delete() + expected_records += [ + # Django reverses the order before sending the `post_delete` signals + # while bulk-deleting + {"id": 3, "question": "keep me", "history_type": "-"}, + {"id": 2, "question": "keep me", "history_type": "-"}, + ] + self.assert_all_records_of_model_equal(Poll, expected_records) + + def test__disable_history__for_foreign_key_cascade_delete(self): + """Test that no historical records are created inside the context manager when + indirectly deleting objects through a foreign key relationship with + ``on_delete=CASCADE``.""" + poll1 = Poll.objects.create(pk=1, pub_date=timezone.now()) + poll2 = Poll.objects.create(pk=2, pub_date=timezone.now()) + Choice.objects.create(pk=11, poll=poll1, votes=0) + Choice.objects.create(pk=12, poll=poll1, votes=0) + Choice.objects.create(pk=21, poll=poll2, votes=0) + Choice.objects.create(pk=22, poll=poll2, votes=0) + + with disable_history(): + poll1.delete() + + expected_records = [ + {"id": 11, "poll_id": 1, "history_type": "+"}, + {"id": 12, "poll_id": 1, "history_type": "+"}, + {"id": 21, "poll_id": 2, "history_type": "+"}, + {"id": 22, "poll_id": 2, "history_type": "+"}, + ] + self.assert_all_records_of_model_equal(Choice, expected_records) + + poll2.delete() + expected_records += [ + # Django reverses the order before sending the `post_delete` signals + # while bulk-deleting + {"id": 22, "poll_id": 2, "history_type": "-"}, + {"id": 21, "poll_id": 2, "history_type": "-"}, + ] + self.assert_all_records_of_model_equal(Choice, expected_records) + + def assert_all_records_of_model_equal( + self, model: Type[Model], expected_records: List[dict] + ): + records = model.history.all() + self.assertEqual(len(records), len(expected_records)) + for record, expected_record in zip(reversed(records), expected_records): + with self.subTest(record=record, expected_record=expected_record): + self.assertRecordValues(record, model, expected_record) + + def test_providing_illegal_arguments_fails(self): + """Test that providing various illegal arguments and argument combinations + fails.""" + + def predicate(_): + return True + + # Providing both arguments should fail + with self.assertRaises(ValueError): + with disable_history(only_for_model=Poll, instance_predicate=predicate): + pass + # Providing the arguments individually should not fail + with disable_history(only_for_model=Poll): + pass + with disable_history(instance_predicate=predicate): + pass + + # Passing non-history-tracked models should fail + with self.assertRaises(NotHistoricalModelError): + with disable_history(only_for_model=Place): + pass + # Passing non-history-tracked model instances should fail + place = Place.objects.create() + with self.assertRaises(NotHistoricalModelError): + is_history_disabled(place) + + def test_nesting_fails(self): + """Test that nesting ``disable_history()`` contexts fails.""" + # Nesting (twice or more) should fail + with self.assertRaises(AssertionError): + with disable_history(): + with disable_history(): + pass + with self.assertRaises(AssertionError): + with disable_history(): + with disable_history(): + with disable_history(): + pass + # No nesting should not fail + with disable_history(): + pass + + class UpdateChangeReasonTestCase(TestCase): def test_update_change_reason_with_excluded_fields(self): poll = PollWithExcludeFields( @@ -412,8 +869,17 @@ def test_bulk_create_history(self): self.assertEqual(Poll.history.count(), 5) @override_settings(SIMPLE_HISTORY_ENABLED=False) - def test_bulk_create_history_with_disabled_setting(self): - bulk_create_with_history(self.data, Poll) + def test_bulk_create_history_without_history_enabled(self): + with self.assertNumQueries(1): + bulk_create_with_history(self.data, Poll) + + self.assertEqual(Poll.objects.count(), 5) + self.assertEqual(Poll.history.count(), 0) + + def test_bulk_create_history_with__disable_history(self): + with self.assertNumQueries(1): + with disable_history(only_for_model=Poll): + bulk_create_with_history(self.data, Poll) self.assertEqual(Poll.objects.count(), 5) self.assertEqual(Poll.history.count(), 0) @@ -667,13 +1133,20 @@ def test_bulk_update_history(self): @override_settings(SIMPLE_HISTORY_ENABLED=False) def test_bulk_update_history_without_history_enabled(self): + # 5 records from `setUp()` self.assertEqual(Poll.history.count(), 5) - # because setup called with enabled settings - bulk_update_with_history( - self.data, - Poll, - fields=["question"], - ) + bulk_update_with_history(self.data, Poll, fields=["question"]) + + self.assertEqual(Poll.objects.count(), 5) + self.assertEqual(Poll.objects.get(id=4).question, "Updated question") + self.assertEqual(Poll.history.count(), 5) + self.assertEqual(Poll.history.filter(history_type="~").count(), 0) + + def test_bulk_update_history_with__disable_history(self): + # 5 records from `setUp()` + self.assertEqual(Poll.history.count(), 5) + with disable_history(only_for_model=Poll): + bulk_update_with_history(self.data, Poll, fields=["question"]) self.assertEqual(Poll.objects.count(), 5) self.assertEqual(Poll.objects.get(id=4).question, "Updated question") diff --git a/simple_history/tests/tests/utils.py b/simple_history/tests/tests/utils.py index 3700b95f..314956bc 100644 --- a/simple_history/tests/tests/utils.py +++ b/simple_history/tests/tests/utils.py @@ -1,10 +1,12 @@ from enum import Enum -from typing import Type +from typing import List, Type from django.conf import settings -from django.db.models import Model +from django.db.models import Manager, Model from django.test import TestCase +from simple_history.utils import get_m2m_reverse_field_name + request_middleware = "simple_history.middleware.HistoryRequestMiddleware" OTHER_DB_NAME = "other" @@ -26,15 +28,33 @@ def assertRecordValues(self, record, klass: Type[Model], values_dict: dict): :param klass: The type of the history-tracked class of ``record``. :param values_dict: Field names of ``record`` mapped to their expected values. """ + values_dict_copy = values_dict.copy() for field_name, expected_value in values_dict.items(): - self.assertEqual(getattr(record, field_name), expected_value) + value = getattr(record, field_name) + if isinstance(value, Manager): + # Assuming that the value being a manager means that it's an M2M field + self._assert_m2m_record(record, field_name, expected_value) + # Remove the field, as `history_object` (used below) doesn't currently + # support historical M2M queryset values + values_dict_copy.pop(field_name) + else: + self.assertEqual(value, expected_value) history_object = record.history_object self.assertEqual(history_object.__class__, klass) - for field_name, expected_value in values_dict.items(): + for field_name, expected_value in values_dict_copy.items(): if field_name not in ("history_type", "history_change_reason"): self.assertEqual(getattr(history_object, field_name), expected_value) + def _assert_m2m_record(self, record, field_name: str, expected_value: List[Model]): + value = getattr(record, field_name) + field = record.instance_type._meta.get_field(field_name) + reverse_field_name = get_m2m_reverse_field_name(field) + self.assertListEqual( + [getattr(m2m_record, reverse_field_name) for m2m_record in value.all()], + expected_value, + ) + class TestDbRouter: def db_for_read(self, model, **hints): diff --git a/simple_history/utils.py b/simple_history/utils.py index fe272ff6..b43c2f9f 100644 --- a/simple_history/utils.py +++ b/simple_history/utils.py @@ -1,5 +1,10 @@ -from typing import TYPE_CHECKING, Optional, Type, Union +import sys +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, ClassVar, Iterator, Optional, Type, Union +from asgiref.local import Local +from django.conf import settings from django.db import transaction from django.db.models import Case, ForeignKey, ManyToManyField, Model, Q, When from django.forms.models import model_to_dict @@ -11,6 +16,164 @@ from .models import HistoricalChanges +@contextmanager +def disable_history( + *, + only_for_model: Type[Model] = None, + instance_predicate: Callable[[Model], bool] = None, +) -> Iterator[None]: + """ + Disable creating historical records while this context manager is active. + + Note: ``only_for_model`` and ``instance_predicate`` cannot both be provided. + + :param only_for_model: Only disable history creation for this model type. + :param instance_predicate: Only disable history creation for model instances passing + this predicate. + """ + assert ( # nosec + _StoredDisableHistoryInfo.get() is None + ), "Nesting 'disable_history()' contexts is undefined behavior" + + if only_for_model: + # Raise an error if it's not a history-tracked model + get_history_manager_for_model(only_for_model) + if instance_predicate: + raise ValueError( + "'only_for_model' and 'instance_predicate' cannot both be provided" + ) + else: + + def instance_predicate(instance: Model): + return isinstance(instance, only_for_model) + + info = _StoredDisableHistoryInfo(instance_predicate) + info.set() + try: + yield + finally: + info.delete() + + +@dataclass(frozen=True) +class _StoredDisableHistoryInfo: + """ + Data related to how historical record creation is disabled, stored in + ``HistoricalRecords.context`` through the ``disable_history()`` context manager. + """ + + LOCAL_STORAGE_ATTR_NAME: ClassVar = "disable_history_info" + + instance_predicate: Callable[[Model], bool] = None + + def set(self) -> None: + setattr(self._get_storage(), self.LOCAL_STORAGE_ATTR_NAME, self) + + @classmethod + def get(cls) -> Optional["_StoredDisableHistoryInfo"]: + """ + A return value of ``None`` means that the ``disable_history()`` context manager + is not active. + """ + return getattr(cls._get_storage(), cls.LOCAL_STORAGE_ATTR_NAME, None) + + @classmethod + def delete(cls) -> None: + delattr(cls._get_storage(), cls.LOCAL_STORAGE_ATTR_NAME) + + @staticmethod + def _get_storage() -> Local: + from .models import HistoricalRecords # Avoids circular importing + + return HistoricalRecords.context + + +@dataclass( + frozen=True, + # DEV: Replace this with just `kw_only=True` when the minimum required + # Python version is 3.10 + **({"kw_only": True} if sys.version_info >= (3, 10) else {}), +) +class DisableHistoryInfo: + """ + Provides info on *how* historical record creation is disabled. + + Create a new instance through ``get()`` for updated info. + (The ``__init__()`` method is intended for internal use.) + """ + + _disabled_globally: bool + _instance_predicate: Optional[Callable[[Model], bool]] + + @property + def not_disabled(self) -> bool: + """ + A value of ``True`` means that historical record creation is not disabled + in any way. + If ``False``, check ``disabled_globally`` and ``disabled_for()``. + """ + return not self._disabled_globally and not self._instance_predicate + + @property + def disabled_globally(self) -> bool: + """ + Whether historical record creation is disabled due to + the ``SIMPLE_HISTORY_ENABLED`` setting or the ``disable_history()`` context + manager being active. + """ + return self._disabled_globally + + def disabled_for(self, instance: Model) -> bool: + """ + Returns whether history record creation is disabled for the provided instance + specifically. + Remember to also check ``disabled_globally``! + """ + return bool(self._instance_predicate) and self._instance_predicate(instance) + + @classmethod + def get(cls) -> "DisableHistoryInfo": + """ + Returns an instance of this class. + + Note that this method must be called again every time you want updated info. + """ + stored_info = _StoredDisableHistoryInfo.get() + context_manager_active = bool(stored_info) + instance_predicate = ( + stored_info.instance_predicate if context_manager_active else None + ) + + disabled_globally = not getattr(settings, "SIMPLE_HISTORY_ENABLED", True) or ( + context_manager_active and not instance_predicate + ) + return cls( + _disabled_globally=disabled_globally, + _instance_predicate=instance_predicate, + ) + + +def is_history_disabled(instance: Model = None) -> bool: + """ + Returns whether creating historical records is disabled. + + :param instance: If *not* provided, will return whether history is disabled + globally. Otherwise, will return whether history is disabled for the provided + instance (either globally or due to the arguments passed to + the ``disable_history()`` context manager). + """ + if instance: + # Raise an error if it's not a history-tracked model instance + get_history_manager_for_model(instance) + + info = DisableHistoryInfo.get() + if info.disabled_globally: + return True + if instance and info.disabled_for(instance): + return True + return False + + def get_change_reason_from_object(obj: Model) -> Optional[str]: return getattr(obj, "_change_reason", None)