-
Notifications
You must be signed in to change notification settings - Fork 7
/
__init__.py
158 lines (122 loc) · 4.78 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
from weakref import WeakValueDictionary
from django.core import checks
from django.db import models
from django.db.models.fields import related_descriptors
if TYPE_CHECKING: # pragma: no cover
class DescriptorBase:
field: models.Field
def is_cached(self, instance: models.Model) -> bool: ...
def __get__(
self,
instance: models.Model | None,
instance_type: type[models.Model] | None = None,
) -> Any: ...
else:
DescriptorBase = object
class DescriptorMixin(DescriptorBase):
def _field_name(self) -> str:
return self.field.name
def _is_cached(self, instance: models.Model) -> bool:
return self.is_cached(instance)
def _should_prefetch(self, instance: models.Model | None) -> bool:
return (
instance is not None # getattr on the class passes None to the descriptor
and not self._is_cached(instance) # already loaded
and len(getattr(instance, "_peers", [])) >= 2 # no peers no prefetch
)
def __get__(
self,
instance: models.Model | None,
instance_type: type[models.Model] | None = None,
) -> Any:
if instance is not None and self._should_prefetch(instance):
prefetch = models.query.Prefetch(self._field_name())
peers = [p for p in instance._peers.values() if not self._is_cached(p)]
models.query.prefetch_related_objects(peers, prefetch)
return super().__get__(instance, instance_type)
class ForwardDescriptorMixin(DescriptorMixin):
def _should_prefetch(self, instance: models.Model | None) -> bool:
return super()._should_prefetch(
instance
) and None not in self.field.get_local_related_value(
instance
) # field is null
class ForwardManyToOneDescriptor(
ForwardDescriptorMixin, related_descriptors.ForwardManyToOneDescriptor
):
pass
class ForwardOneToOneDescriptor(
ForwardDescriptorMixin, related_descriptors.ForwardOneToOneDescriptor
):
pass
class ReverseOneToOneDescriptor(
DescriptorMixin, related_descriptors.ReverseOneToOneDescriptor
):
def _is_cached(self, instance: models.Model) -> bool:
return self.related.is_cached(instance)
def _field_name(self) -> str:
return self.related.get_accessor_name()
class ForeignKey(models.ForeignKey):
forward_related_accessor_class = ForwardManyToOneDescriptor
class OneToOneField(models.OneToOneField):
forward_related_accessor_class = ForwardOneToOneDescriptor
related_accessor_class = ReverseOneToOneDescriptor
class QuerySet(models.QuerySet):
def _fetch_all(self) -> None:
set_peers = self._result_cache is None
super()._fetch_all()
# ModelIterable tests for query sets returning model instances vs
# values or value lists etc
if (
set_peers
and issubclass(self._iterable_class, models.query.ModelIterable)
and len(self._result_cache) >= 2
):
peers = WeakValueDictionary((id(o), o) for o in self._result_cache)
for peer in peers.values():
peer._peers = peers
Manager = models.Manager.from_queryset(QuerySet)
class Model(models.Model):
class Meta:
abstract = True
base_manager_name = "prefetch_manager"
objects = Manager()
prefetch_manager = Manager()
def __getstate__(self) -> dict[str, Any]:
# drop the peers info when pickling etc
res = super().__getstate__()
if "_peers" not in res: # pragma: no cover
return res
res = dict(res)
del res["_peers"]
return res
@classmethod
def check(cls, **kwargs: Any) -> list[checks.Error]:
errors: list[checks.Error] = super().check(**kwargs)
errors.extend(cls._check_meta_inheritance())
return errors
@classmethod
def _check_meta_inheritance(cls) -> list[checks.Error]:
errors = []
base_manager_name = cls._meta.base_manager_name
if base_manager_name != "prefetch_manager":
errors.append(
checks.Error(
id="auto_prefetch.E001",
obj=cls,
msg=(
f"{cls.__name__} inherits from auto_prefetch.Model"
+ " but its base_manager_name is not"
+ " 'prefetch_manager'"
),
hint=(
f"The base_manager_name is instead {base_manager_name!r}."
+ " Check the Meta class inherits from"
+ " auto_prefetch.Model.Meta."
),
)
)
return errors