Skip to content

Commit

Permalink
Add support for custom get_user option
Browse files Browse the repository at this point in the history
  • Loading branch information
rwlogel committed Apr 15, 2018
1 parent 36efd80 commit 44e0056
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 21 deletions.
33 changes: 31 additions & 2 deletions docs/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,42 @@ referencing the ``changed_by`` field:
def _history_user(self, value):
self.changed_by = value
Admin integration requires that you use a ``_history_user.setter`` attribute with your custom ``_history_user`` property (see :ref:`admin_integration`).
Admin integration requires that you use a ``_history_user.setter`` attribute with
your custom ``_history_user`` property (see :ref:`admin_integration`).

Another option for identifying the change user is by providing a function via ``get_user``.
If provided it will be called everytime that the ``history_user`` needs to be
identified with the following key word arguments:

* ``history``: The current ``HistoricalRecords`` instance
* ``instance``: The current instance being modified
* ``request``: If using the middleware the current request object will be provided if they are authenticated.

This is very helpful when using ``register``:

.. code-block:: python
from django.db import models
from simple_history.models import HistoricalRecords
class Poll(models.Model):
question = models.CharField(max_length=200)
pub_date = models.DateTimeField('date published')
changed_by = models.ForeignKey('auth.User')
def get_poll_user(instance, **kwargs):
return instance.changed_by
register(Poll, get_user=get_poll_user)
Change User Model
------------------------------------

If you need to use a different user model then ``settings.AUTH_USER_MODEL``,
pass in the required model to ``user_model``. Doing this requires ``_history_user`` is provided.
pass in the required model to ``user_model``. Doing this requires ``_history_user``
or ``get_user`` is provided as detailed above.

.. code-block:: python
Expand Down
2 changes: 1 addition & 1 deletion simple_history/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ def register(
records.manager_name = manager_name
records.table_name = table_name
records.module = app and ("%s.models" % app) or model.__module__
records.cls = model
records.add_extra_methods(model)
records.finalize(model)
models.registered_models[model._meta.db_table] = model
32 changes: 19 additions & 13 deletions simple_history/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,27 @@
registered_models = {}


def default_get_user(request, **kwargs):
try:
return request.user
except AttributeError:
return None


class HistoricalRecords(object):
thread = threading.local()

def __init__(self, verbose_name=None, bases=(models.Model,),
user_related_name='+', table_name=None, inherit=False,
history_id_field=None, user_model=None,
excluded_fields=None):
get_user=default_get_user, excluded_fields=None):
self.user_set_verbose_name = verbose_name
self.user_related_name = user_related_name
self.table_name = table_name
self.inherit = inherit
self.history_id_field = history_id_field
self.user_model = user_model
self.get_user = get_user
if excluded_fields is None:
excluded_fields = []
self.excluded_fields = excluded_fields
Expand Down Expand Up @@ -72,15 +80,11 @@ def save_without_historical_record(self, *args, **kwargs):

def finalize(self, sender, **kwargs):
inherited = False
try:
hint_class = self.cls
except AttributeError: # called via `register`
pass
else:
if hint_class is not sender: # set in concrete
inherited = (self.inherit and issubclass(sender, hint_class))
if not inherited:
return # set in abstract
if self.cls is not sender: # set in concrete
inherited = (self.inherit and issubclass(sender, self.cls))
if not inherited:
return # set in abstract

if hasattr(sender._meta, 'simple_history_manager_attribute'):
raise exceptions.MultipleRegistrationsError(
'{}.{} registered multiple times for history tracking.'.format(
Expand Down Expand Up @@ -303,12 +307,14 @@ def get_history_user(self, instance):
try:
return instance._history_user
except AttributeError:
request = None
try:
if self.thread.request.user.is_authenticated:
return self.thread.request.user
return None
request = self.thread.request
except AttributeError:
return None
pass

return self.get_user(history=self, instance=instance, request=request)


def transform_field(field):
Expand Down
48 changes: 48 additions & 0 deletions simple_history/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,11 @@ class InheritTracking4(TrackedAbstractBaseA):

class BucketMember(models.Model):
name = models.CharField(max_length=30)
user = models.OneToOneField(
User,
related_name="bucket_member",
on_delete=models.CASCADE
)


class BucketData(models.Model):
Expand All @@ -423,6 +428,49 @@ def _history_user(self):
return self.changed_by


def get_bucket_member1(instance, **kwargs):
try:
return instance.changed_by
except AttributeError:
return None


class BucketDataRegister1(models.Model):
changed_by = models.ForeignKey(
BucketMember,
on_delete=models.SET_NULL,
null=True, blank=True,
)


register(
BucketDataRegister1,
user_model=BucketMember,
get_user=get_bucket_member1
)


def get_bucket_member2(request, **kwargs):
try:
return request.user.bucket_member
except AttributeError:
return None


class BucketDataRegister2(models.Model):
data = models.CharField(max_length=30)

def get_absolute_url(self):
return reverse('bucket_data-detail', kwargs={'pk': self.pk})


register(
BucketDataRegister2,
user_model=BucketMember,
get_user=get_bucket_member2
)


class UUIDModel(models.Model):
id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
history = HistoricalRecords(
Expand Down
2 changes: 1 addition & 1 deletion simple_history/tests/tests/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def test_deleteting_user(self):
def test_deleteting_member(self):
"""Test deletes of a BucketMember doesn't cascade delete the history"""
self.login()
member = BucketMember.objects.create(name="member1")
member = BucketMember.objects.create(name="member1", user=self.user)
bucket_data = BucketData(changed_by=member)
bucket_data.save()

Expand Down
22 changes: 21 additions & 1 deletion simple_history/tests/tests/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from django.urls import reverse

from simple_history.tests.custom_user.models import CustomUser
from simple_history.tests.models import Poll
from simple_history.tests.models import (
BucketDataRegister2,
BucketMember,
Poll
)

overridden_settings = {
'MIDDLEWARE': (settings.MIDDLEWARE +
Expand Down Expand Up @@ -171,3 +175,19 @@ def test_user_is_not_set_on_delete_view_when_not_logged_in(self):

self.assertListEqual([ph.history_user_id for ph in poll_history],
[None, None])

@override_settings(**overridden_settings)
def test_bucket_member_is_set_on_create_view_when_logged_in(self):
self.client.force_login(self.user)
member1 = BucketMember.objects.create(name="member1", user=self.user)
data = {
'data': 'Test Data',
}
self.client.post(reverse('bucket_data-add'), data=data)
bucket_datas = BucketDataRegister2.objects.all()
self.assertEqual(bucket_datas.count(), 1)

history = bucket_datas.first().history.all()

self.assertListEqual([h.history_user_id for h in history],
[member1.id])
20 changes: 18 additions & 2 deletions simple_history/tests/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Book,
Bookcase,
BucketData,
BucketDataRegister1,
BucketMember,
Choice,
City,
Expand Down Expand Up @@ -373,8 +374,10 @@ def test_model_with_excluded_fields(self):
self.assertNotIn('pub_date', all_fields_names)

def test_user_model_override(self):
member1 = BucketMember.objects.create(name="member1")
member2 = BucketMember.objects.create(name="member2")
user1 = User.objects.create_user('user1', '1@example.com')
user2 = User.objects.create_user('user2', '1@example.com')
member1 = BucketMember.objects.create(name="member1", user=user1)
member2 = BucketMember.objects.create(name="member2", user=user2)
bucket_data = BucketData.objects.create(changed_by=member1)
bucket_data.changed_by = member2
bucket_data.save()
Expand All @@ -383,6 +386,19 @@ def test_user_model_override(self):
self.assertEqual([d.history_user for d in bucket_data.history.all()],
[None, member2, member1])

def test_user_model_override_registered(self):
user1 = User.objects.create_user('user1', '1@example.com')
user2 = User.objects.create_user('user2', '1@example.com')
member1 = BucketMember.objects.create(name="member1", user=user1)
member2 = BucketMember.objects.create(name="member2", user=user2)
bucket_data = BucketDataRegister1.objects.create(changed_by=member1)
bucket_data.changed_by = member2
bucket_data.save()
bucket_data.changed_by = None
bucket_data.save()
self.assertEqual([d.history_user for d in bucket_data.history.all()],
[None, member2, member1])

def test_uuid_history_id(self):
entry = UUIDModel.objects.create()

Expand Down
6 changes: 6 additions & 0 deletions simple_history/tests/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from django.contrib import admin

from simple_history.tests.view import (
BucketDataRegister2Create,
BucketDataRegister2Detail,
PollCreate,
PollDelete,
PollDetail,
Expand All @@ -17,6 +19,10 @@
urlpatterns = [
url(r'^admin/', admin.site.urls),
url(r'^other-admin/', other_admin.site.urls),
url(r'^bucket_data/add/$', BucketDataRegister2Create.as_view(),
name='bucket_data-add'),
url(r'^bucket_data/(?P<pk>[0-9]+)/$', BucketDataRegister2Detail.as_view(),
name='bucket_data-detail'),
url(r'^poll/add/$', PollCreate.as_view(), name='poll-add'),
url(r'^poll/(?P<pk>[0-9]+)/$', PollUpdate.as_view(), name='poll-update'),
url(r'^poll/(?P<pk>[0-9]+)/delete/$', PollDelete.as_view(),
Expand Down
12 changes: 11 additions & 1 deletion simple_history/tests/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
UpdateView
)

from simple_history.tests.models import Poll
from simple_history.tests.models import BucketDataRegister2, Poll


class PollCreate(CreateView):
Expand All @@ -33,3 +33,13 @@ class PollList(ListView):
class PollDetail(DetailView):
model = Poll
fields = ['question', 'pub_date']


class BucketDataRegister2Create(CreateView):
model = BucketDataRegister2
fields = ['data']


class BucketDataRegister2Detail(DetailView):
model = BucketDataRegister2
fields = ['data']

0 comments on commit 44e0056

Please sign in to comment.