Skip to content

Commit

Permalink
Fixed #36050 -- Added OuterRef support to CompositePrimaryKey.
Browse files Browse the repository at this point in the history
  • Loading branch information
csirmazbendeguz authored and sarahboyce committed Jan 10, 2025
1 parent 97ee8b8 commit 8bee7fa
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 9 deletions.
36 changes: 28 additions & 8 deletions django/db/models/fields/tuple_lookups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.core.exceptions import EmptyResultSet
from django.db.models import Field
from django.db.models.expressions import ColPairs, Func, Value
from django.db.models.expressions import ColPairs, Func, ResolvedOuterRef, Value
from django.db.models.lookups import (
Exact,
GreaterThan,
Expand Down Expand Up @@ -32,8 +32,11 @@ class TupleLookupMixin:
allows_composite_expressions = True

def get_prep_lookup(self):
self.check_rhs_is_tuple_or_list()
self.check_rhs_length_equals_lhs_length()
if self.rhs_is_direct_value():
self.check_rhs_is_tuple_or_list()
self.check_rhs_length_equals_lhs_length()
else:
self.check_rhs_is_outer_ref()
return self.rhs

def check_rhs_is_tuple_or_list(self):
Expand All @@ -51,6 +54,15 @@ def check_rhs_length_equals_lhs_length(self):
f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
)

def check_rhs_is_outer_ref(self):
if not isinstance(self.rhs, ResolvedOuterRef):
lhs_str = self.get_lhs_str()
rhs_cls = self.rhs.__class__.__name__
raise ValueError(
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
f"only supports OuterRef objects (received {rhs_cls!r})"
)

def get_lhs_str(self):
if isinstance(self.lhs, ColPairs):
return repr(self.lhs.field.name)
Expand All @@ -70,11 +82,19 @@ def process_lhs(self, compiler, connection, lhs=None):
return sql, params

def process_rhs(self, compiler, connection):
values = [
Value(val, output_field=col.output_field)
for col, val in zip(self.lhs, self.rhs)
]
return Tuple(*values).as_sql(compiler, connection)
if self.rhs_is_direct_value():
args = [
Value(val, output_field=col.output_field)
for col, val in zip(self.lhs, self.rhs)
]
return Tuple(*args).as_sql(compiler, connection)
else:
sql, params = compiler.compile(self.rhs)
if not isinstance(self.rhs, ColPairs):
raise ValueError(
"Composite field lookups only work with composite expressions."
)
return "(%s)" % sql, params


class TupleExact(TupleLookupMixin, Exact):
Expand Down
51 changes: 50 additions & 1 deletion tests/composite_pk/test_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.db.models import F, TextField
from django.db.models import F, FilteredRelation, OuterRef, Q, Subquery, TextField
from django.db.models.functions import Cast
from django.db.models.lookups import Exact
from django.test import TestCase

from .models import Comment, Tenant, User
Expand Down Expand Up @@ -407,3 +408,51 @@ def test_cannot_cast_pk(self):
msg = "Cast does not support composite primary keys."
with self.assertRaisesMessage(ValueError, msg):
Comment.objects.filter(text__gt=Cast(F("pk"), TextField())).count()

def test_outer_ref_pk(self):
subquery = Subquery(Comment.objects.filter(pk=OuterRef("pk")).values("id"))
tests = [
("", 5),
("__gt", 0),
("__gte", 5),
("__lt", 0),
("__lte", 5),
]
for lookup, expected_count in tests:
with self.subTest(f"id{lookup}"):
queryset = Comment.objects.filter(**{f"id{lookup}": subquery})
self.assertEqual(queryset.count(), expected_count)

def test_non_outer_ref_subquery(self):
# If rhs is any non-OuterRef object with an as_sql() function.
pk = Exact(F("tenant_id"), 1)
msg = (
"'exact' subquery lookup of 'pk' only supports OuterRef objects "
"(received 'Exact')"
)
with self.assertRaisesMessage(ValueError, msg):
Comment.objects.filter(pk=pk)

def test_outer_ref_not_composite_pk(self):
subquery = Comment.objects.filter(pk=OuterRef("id")).values("id")
queryset = Comment.objects.filter(id=Subquery(subquery))

msg = "Composite field lookups only work with composite expressions."
with self.assertRaisesMessage(ValueError, msg):
self.assertEqual(queryset.count(), 5)

def test_outer_ref_in_filtered_relation(self):
msg = (
"This queryset contains a reference to an outer query and may only be used "
"in a subquery."
)
with self.assertRaisesMessage(ValueError, msg):
self.assertSequenceEqual(
Tenant.objects.annotate(
filtered_tokens=FilteredRelation(
"tokens",
condition=Q(tokens__pk__gte=OuterRef("tokens")),
)
).filter(filtered_tokens=(1, 1)),
[self.tenant_1],
)

0 comments on commit 8bee7fa

Please sign in to comment.