diff --git a/django/db/models/fields/tuple_lookups.py b/django/db/models/fields/tuple_lookups.py index 51432a666b5f..2f195f25b9cb 100644 --- a/django/db/models/fields/tuple_lookups.py +++ b/django/db/models/fields/tuple_lookups.py @@ -2,7 +2,7 @@ from django.core.exceptions import EmptyResultSet from django.db.models import Field -from django.db.models.expressions import ColPairs, Func, Value +from django.db.models.expressions import ColPairs, Func, ResolvedOuterRef, Value from django.db.models.lookups import ( Exact, GreaterThan, @@ -32,8 +32,11 @@ class TupleLookupMixin: allows_composite_expressions = True def get_prep_lookup(self): - self.check_rhs_is_tuple_or_list() - self.check_rhs_length_equals_lhs_length() + if self.rhs_is_direct_value(): + self.check_rhs_is_tuple_or_list() + self.check_rhs_length_equals_lhs_length() + else: + self.check_rhs_is_outer_ref() return self.rhs def check_rhs_is_tuple_or_list(self): @@ -51,6 +54,15 @@ def check_rhs_length_equals_lhs_length(self): f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements" ) + def check_rhs_is_outer_ref(self): + if not isinstance(self.rhs, ResolvedOuterRef): + lhs_str = self.get_lhs_str() + rhs_cls = self.rhs.__class__.__name__ + raise ValueError( + f"{self.lookup_name!r} subquery lookup of {lhs_str} " + f"only supports OuterRef objects (received {rhs_cls!r})" + ) + def get_lhs_str(self): if isinstance(self.lhs, ColPairs): return repr(self.lhs.field.name) @@ -70,11 +82,19 @@ def process_lhs(self, compiler, connection, lhs=None): return sql, params def process_rhs(self, compiler, connection): - values = [ - Value(val, output_field=col.output_field) - for col, val in zip(self.lhs, self.rhs) - ] - return Tuple(*values).as_sql(compiler, connection) + if self.rhs_is_direct_value(): + args = [ + Value(val, output_field=col.output_field) + for col, val in zip(self.lhs, self.rhs) + ] + return Tuple(*args).as_sql(compiler, connection) + else: + sql, params = compiler.compile(self.rhs) + if not isinstance(self.rhs, ColPairs): + raise ValueError( + "Composite field lookups only work with composite expressions." + ) + return "(%s)" % sql, params class TupleExact(TupleLookupMixin, Exact): diff --git a/tests/composite_pk/test_filter.py b/tests/composite_pk/test_filter.py index 864877483a7c..4edf94742369 100644 --- a/tests/composite_pk/test_filter.py +++ b/tests/composite_pk/test_filter.py @@ -1,5 +1,6 @@ -from django.db.models import F, TextField +from django.db.models import F, FilteredRelation, OuterRef, Q, Subquery, TextField from django.db.models.functions import Cast +from django.db.models.lookups import Exact from django.test import TestCase from .models import Comment, Tenant, User @@ -407,3 +408,51 @@ def test_cannot_cast_pk(self): msg = "Cast does not support composite primary keys." with self.assertRaisesMessage(ValueError, msg): Comment.objects.filter(text__gt=Cast(F("pk"), TextField())).count() + + def test_outer_ref_pk(self): + subquery = Subquery(Comment.objects.filter(pk=OuterRef("pk")).values("id")) + tests = [ + ("", 5), + ("__gt", 0), + ("__gte", 5), + ("__lt", 0), + ("__lte", 5), + ] + for lookup, expected_count in tests: + with self.subTest(f"id{lookup}"): + queryset = Comment.objects.filter(**{f"id{lookup}": subquery}) + self.assertEqual(queryset.count(), expected_count) + + def test_non_outer_ref_subquery(self): + # If rhs is any non-OuterRef object with an as_sql() function. + pk = Exact(F("tenant_id"), 1) + msg = ( + "'exact' subquery lookup of 'pk' only supports OuterRef objects " + "(received 'Exact')" + ) + with self.assertRaisesMessage(ValueError, msg): + Comment.objects.filter(pk=pk) + + def test_outer_ref_not_composite_pk(self): + subquery = Comment.objects.filter(pk=OuterRef("id")).values("id") + queryset = Comment.objects.filter(id=Subquery(subquery)) + + msg = "Composite field lookups only work with composite expressions." + with self.assertRaisesMessage(ValueError, msg): + self.assertEqual(queryset.count(), 5) + + def test_outer_ref_in_filtered_relation(self): + msg = ( + "This queryset contains a reference to an outer query and may only be used " + "in a subquery." + ) + with self.assertRaisesMessage(ValueError, msg): + self.assertSequenceEqual( + Tenant.objects.annotate( + filtered_tokens=FilteredRelation( + "tokens", + condition=Q(tokens__pk__gte=OuterRef("tokens")), + ) + ).filter(filtered_tokens=(1, 1)), + [self.tenant_1], + )