diff --git a/modelcluster/queryset.py b/modelcluster/queryset.py index bd021ac..5125dae 100644 --- a/modelcluster/queryset.py +++ b/modelcluster/queryset.py @@ -3,9 +3,10 @@ import datetime import re +from django.core.exceptions import FieldDoesNotExist from django.db.models import Model, prefetch_related_objects -from modelcluster.utils import extract_field_value, get_model_field, sort_by_fields +from modelcluster.utils import NullRelationshipValueEncountered, extract_field_value, get_model_field, sort_by_fields # Constructor for test functions that determine whether an object passes some boolean condition @@ -13,23 +14,44 @@ def test_exact(model, attribute_name, value): if isinstance(value, Model): if value.pk is None: # comparing against an unsaved model, so objects need to match by reference - return lambda obj: extract_field_value(obj, attribute_name) is value + def _test(obj): + try: + other_value = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False + return other_value is value + + return _test + else: # comparing against a saved model; objects need to match by type and ID. # Additionally, where model inheritance is involved, we need to treat it as a # positive match if one is a subclass of the other def _test(obj): - other_value = extract_field_value(obj, attribute_name) - if not (isinstance(value, other_value.__class__) or isinstance(other_value, value.__class__)): + try: + other_value = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: return False - return value.pk == other_value.pk + return value.pk == other_value.pk and ( + isinstance(value, other_value.__class__) + or isinstance(other_value, value.__class__) + ) + return _test else: field = get_model_field(model, attribute_name) # convert value to the correct python type for this field typed_value = field.to_python(value) + # just a plain Python value = do a normal equality check - return lambda obj: extract_field_value(obj, attribute_name) == typed_value + def _test(obj): + try: + other_value = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False + return other_value == typed_value + + return _test def test_iexact(model, attribute_name, match_value): @@ -37,15 +59,24 @@ def test_iexact(model, attribute_name, match_value): match_value = field.to_python(match_value) if match_value is None: - return lambda obj: getattr(obj, attribute_name) is None + + def _test(obj): + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False + return val is None else: match_value = match_value.upper() def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.upper() == match_value - return _test + return _test def test_contains(model, attribute_name, value): @@ -53,7 +84,10 @@ def test_contains(model, attribute_name, value): match_value = field.to_python(value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and match_value in val return _test @@ -64,7 +98,10 @@ def test_icontains(model, attribute_name, value): match_value = field.to_python(value).upper() def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and match_value in val.upper() return _test @@ -75,7 +112,10 @@ def test_lt(model, attribute_name, value): match_value = field.to_python(value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val < match_value return _test @@ -86,7 +126,10 @@ def test_lte(model, attribute_name, value): match_value = field.to_python(value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val <= match_value return _test @@ -97,7 +140,10 @@ def test_gt(model, attribute_name, value): match_value = field.to_python(value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val > match_value return _test @@ -108,7 +154,10 @@ def test_gte(model, attribute_name, value): match_value = field.to_python(value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val >= match_value return _test @@ -117,7 +166,15 @@ def _test(obj): def test_in(model, attribute_name, value_list): field = get_model_field(model, attribute_name) match_values = set(field.to_python(val) for val in value_list) - return lambda obj: extract_field_value(obj, attribute_name) in match_values + + def _test(obj): + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False + return val in match_values + + return _test def test_startswith(model, attribute_name, value): @@ -125,7 +182,10 @@ def test_startswith(model, attribute_name, value): match_value = field.to_python(value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.startswith(match_value) return _test @@ -136,7 +196,10 @@ def test_istartswith(model, attribute_name, value): match_value = field.to_python(value).upper() def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.upper().startswith(match_value) return _test @@ -147,7 +210,10 @@ def test_endswith(model, attribute_name, value): match_value = field.to_python(value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.endswith(match_value) return _test @@ -158,7 +224,10 @@ def test_iendswith(model, attribute_name, value): match_value = field.to_python(value).upper() def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.upper().endswith(match_value) return _test @@ -170,7 +239,10 @@ def test_range(model, attribute_name, range_val): end_val = field.to_python(range_val[1]) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return (val is not None and val >= start_val and val <= end_val) return _test @@ -178,7 +250,10 @@ def _test(obj): def test_date(model, attribute_name, match_value): def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False if isinstance(val, datetime.datetime): return val.date() == match_value else: @@ -191,7 +266,10 @@ def test_year(model, attribute_name, match_value): match_value = int(match_value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.year == match_value return _test @@ -201,7 +279,10 @@ def test_month(model, attribute_name, match_value): match_value = int(match_value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.month == match_value return _test @@ -211,7 +292,10 @@ def test_day(model, attribute_name, match_value): match_value = int(match_value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.day == match_value return _test @@ -221,7 +305,10 @@ def test_week(model, attribute_name, match_value): match_value = int(match_value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.isocalendar()[1] == match_value return _test @@ -231,7 +318,10 @@ def test_week_day(model, attribute_name, match_value): match_value = int(match_value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.isoweekday() % 7 + 1 == match_value return _test @@ -241,7 +331,10 @@ def test_quarter(model, attribute_name, match_value): match_value = int(match_value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and int((val.month - 1) / 3) + 1 == match_value return _test @@ -249,7 +342,10 @@ def _test(obj): def test_time(model, attribute_name, match_value): def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False if isinstance(val, datetime.datetime): return val.time() == match_value else: @@ -262,7 +358,10 @@ def test_hour(model, attribute_name, match_value): match_value = int(match_value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.hour == match_value return _test @@ -272,7 +371,10 @@ def test_minute(model, attribute_name, match_value): match_value = int(match_value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.minute == match_value return _test @@ -282,24 +384,37 @@ def test_second(model, attribute_name, match_value): match_value = int(match_value) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and val.second == match_value return _test def test_isnull(model, attribute_name, sense): - if sense: - return lambda obj: extract_field_value(obj, attribute_name) is None - else: - return lambda obj: extract_field_value(obj, attribute_name) is not None + def _test(obj): + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False + if sense: + return val is None + else: + return val is not None + + return _test def test_regex(model, attribute_name, regex_string): regex = re.compile(regex_string) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and regex.search(val) return _test @@ -309,7 +424,10 @@ def test_iregex(model, attribute_name, regex_string): regex = re.compile(regex_string, re.I) def _test(obj): - val = extract_field_value(obj, attribute_name) + try: + val = extract_field_value(obj, attribute_name) + except NullRelationshipValueEncountered: + return False return val is not None and regex.search(val) return _test @@ -350,8 +468,15 @@ def _test(obj): def _build_test_function_from_filter(model, key_clauses, val): # Translate a filter kwarg rule (e.g. foo__bar__exact=123) into a function which can # take a model instance and return a boolean indicating whether it passes the rule - if key_clauses[-1] in FILTER_EXPRESSION_TOKENS: - # the last clause indicates the type of test + try: + get_model_field(model, "__".join(key_clauses)) + except FieldDoesNotExist: + # it is safe to assume the last clause indicates the type of test + field_match_found = False + else: + field_match_found = True + + if not field_match_found and key_clauses[-1] in FILTER_EXPRESSION_TOKENS: constructor = FILTER_EXPRESSION_TOKENS[key_clauses.pop()] else: constructor = test_exact @@ -376,7 +501,7 @@ def __iter__(self): field_names = self.queryset.dict_fields or [field.name for field in self.queryset.model._meta.fields] for obj in self.queryset.results: yield { - field_name: extract_field_value(obj, field_name, pk_only=True, suppress_fielddoesnotexist=True) + field_name: extract_field_value(obj, field_name, pk_only=True, suppress_fielddoesnotexist=True, suppress_nullrelationshipvalueencountered=True) for field_name in field_names } @@ -385,14 +510,14 @@ class ValuesListIterable(FakeQuerySetIterable): def __iter__(self): field_names = self.queryset.tuple_fields or [field.name for field in self.queryset.model._meta.fields] for obj in self.queryset.results: - yield tuple([extract_field_value(obj, field_name, pk_only=True, suppress_fielddoesnotexist=True) for field_name in field_names]) + yield tuple([extract_field_value(obj, field_name, pk_only=True, suppress_fielddoesnotexist=True, suppress_nullrelationshipvalueencountered=True) for field_name in field_names]) class FlatValuesListIterable(FakeQuerySetIterable): def __iter__(self): field_name = self.queryset.tuple_fields[0] for obj in self.queryset.results: - yield extract_field_value(obj, field_name, pk_only=True, suppress_fielddoesnotexist=True) + yield extract_field_value(obj, field_name, pk_only=True, suppress_fielddoesnotexist=True, suppress_nullrelationshipvalueencountered=True) class FakeQuerySet(object): diff --git a/modelcluster/utils.py b/modelcluster/utils.py index 243685f..80afb00 100644 --- a/modelcluster/utils.py +++ b/modelcluster/utils.py @@ -1,6 +1,6 @@ from functools import lru_cache from django.core.exceptions import FieldDoesNotExist -from django.db.models import ManyToManyField, ManyToManyRel +from django.db.models import ManyToManyField, ManyToManyRel, Model REL_DELIMETER = "__" @@ -9,6 +9,10 @@ class ManyToManyTraversalError(ValueError): pass +class NullRelationshipValueEncountered(Exception): + pass + + class TraversedRelationship: __slots__ = ['from_model', 'field'] @@ -62,9 +66,18 @@ def get_model_field(model, name): subject_model=subject_model, ) ) - if hasattr(field, "related_model"): + if getattr(field, "related_model", None): traversals.append(TraversedRelationship(subject_model, field)) subject_model = field.related_model + else: + raise FieldDoesNotExist( + "Failed attempting to traverse from {from_field} (a {from_field_type}) to '{to_field}'." + .format( + from_field=subject_model._meta.label + '.' + field.name, + from_field_type=type(field), + to_field=field_name, + ) + ) try: field = subject_model._meta.get_field(field_name) except FieldDoesNotExist: @@ -76,7 +89,7 @@ def get_model_field(model, name): return field -def extract_field_value(obj, key, pk_only=False, suppress_fielddoesnotexist=False): +def extract_field_value(obj, key, pk_only=False, suppress_fielddoesnotexist=False, suppress_nullrelationshipvalueencountered=False): """ Attempts to extract a field value from ``obj`` matching the ``key`` - which, can contain double-underscores (`'__'`) to indicate traversal of relationships @@ -89,12 +102,34 @@ def extract_field_value(obj, key, pk_only=False, suppress_fielddoesnotexist=Fals By default, ``FieldDoesNotExist`` is raised if the key cannot be mapped to a model field. Call the function with ``suppress_fielddoesnotexist=True`` - to get ``None`` values instead. + to instead receive a ``None`` value when this occurs. + + By default, ``NullRelationshipValueEncountered`` is raised if a ``None`` + value is encountered while attempting to traverse relationships in order to + access further fields. Call the function with + ``suppress_nullrelationshipvalueencountered`` to instead receive a ``None`` + value when this occurs. """ source = obj - for attr in key.split(REL_DELIMETER): - if hasattr(source, attr): - value = getattr(source, attr) + latest_obj = obj + segments = key.split(REL_DELIMETER) + for i, segment in enumerate(segments, start=1): + if hasattr(source, segment): + value = getattr(source, segment) + if isinstance(value, Model): + latest_obj = value + if value is None and i < len(segments): + if suppress_nullrelationshipvalueencountered: + return None + raise NullRelationshipValueEncountered( + "'{key}' cannot be reached for {obj} because {model_class}.{field_name} " + "is null.".format( + key=key, + obj=repr(obj), + model_class=latest_obj._meta.label, + field_name=segment, + ) + ) source = value continue elif suppress_fielddoesnotexist: @@ -102,7 +137,7 @@ def extract_field_value(obj, key, pk_only=False, suppress_fielddoesnotexist=Fals else: raise FieldDoesNotExist( "'{name}' is not a valid field name for {model}".format( - name=attr, model=type(source) + name=segment, model=type(source) ) ) if pk_only and hasattr(value, 'pk'): @@ -128,7 +163,7 @@ def sort_by_fields(items, fields): def get_sort_value(item): # Use a tuple of (v is not None, v) as the key, to ensure that None sorts before other values, # as comparing directly with None breaks on python3 - value = extract_field_value(item, key, pk_only=True, suppress_fielddoesnotexist=True) + value = extract_field_value(item, key, pk_only=True, suppress_fielddoesnotexist=True, suppress_nullrelationshipvalueencountered=True) return (value is not None, value) # Sort items diff --git a/tests/migrations/0012_add_record_label.py b/tests/migrations/0012_add_record_label.py new file mode 100644 index 0000000..b37d2e0 --- /dev/null +++ b/tests/migrations/0012_add_record_label.py @@ -0,0 +1,42 @@ +# Generated by Django 4.2.9 on 2024-02-04 06:59 + +from django.db import migrations, models +import django.db.models.deletion +import modelcluster.fields + + +class Migration(migrations.Migration): + + dependencies = [ + ("taggit", "0005_auto_20220424_2025"), + ("tests", "0011_add_room_features"), + ] + + operations = [ + migrations.CreateModel( + name="RecordLabel", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=200)), + ("range", models.SmallIntegerField(blank=True, default=5)), + ], + ), + migrations.AddField( + model_name="album", + name="label", + field=models.ForeignKey( + blank=True, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="tests.recordlabel", + ), + ), + ] diff --git a/tests/models.py b/tests/models.py index 76820bb..c973047 100644 --- a/tests/models.py +++ b/tests/models.py @@ -36,6 +36,7 @@ class Album(ClusterableModel): name = models.CharField(max_length=255) release_date = models.DateField(null=True, blank=True) sort_order = models.IntegerField(null=True, blank=True, editable=False) + label = models.ForeignKey("RecordLabel", blank=True, null=True, on_delete=models.SET_NULL) sort_order_field = 'sort_order' @@ -60,6 +61,14 @@ class Meta: ordering = ['sort_order'] +class RecordLabel(models.Model): + name = models.CharField(max_length=200) + range = models.SmallIntegerField(default=5, blank=True) + + def __str__(self): + return self.name + + class TaggedPlace(TaggedItemBase): content_object = ParentalKey('Place', related_name='tagged_items', on_delete=models.CASCADE) diff --git a/tests/tests/test_cluster.py b/tests/tests/test_cluster.py index 7abd3ad..536e7c8 100644 --- a/tests/tests/test_cluster.py +++ b/tests/tests/test_cluster.py @@ -11,8 +11,9 @@ from modelcluster.queryset import FakeQuerySet from modelcluster.utils import ManyToManyTraversalError -from tests.models import Band, BandMember, Chef, Feature, Place, Restaurant, SeafoodRestaurant, \ - Review, Album, Article, Author, Category, Person, Room, House, Log, Dish, MenuItem, Wine +from tests.models import Band, BandMember, Chef, Feature, Place, Restaurant, \ + Review, Album, Song, RecordLabel, Article, Author, Category, Person, \ + Room, House, Log, Dish, MenuItem, Wine class ClusterTest(TestCase): @@ -142,6 +143,29 @@ def test_can_create_cluster(self): # queries on beatles.members should now revert to SQL self.assertTrue(beatles.members.extra(where=["tests_bandmember.name='John Lennon'"]).exists()) + def test_filter_expression_token_clash_handling(self): + """ + This tests ensures that the field name 'range' should not be mistaken + for the 'range' from FILTER_EXPRESSION_TOKENS when used in filter() + or exclude(). + + Plus, extract_field_value() should not crash when encountering albums + without a 'label' value specified (they should be classed as automatic + test failures and excluded from the result). + """ + label = RecordLabel.objects.create(name="Parlophone", range=7) + beatles = Band( + name="The Beatles", + albums=[ + Album(name='Please Please Me', label=label, sort_order=1), + Album(name='With The Beatles', sort_order=2), + Album(name='A Hard Day\'s Night', sort_order=3), + ], + ) + + self.assertEqual(beatles.albums.filter(label__range=7).count(), 1) + self.assertEqual(beatles.albums.exclude(label__range=7).count(), 2) + def test_values_list(self): beatles = Band( name="The Beatles",