Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support returning the correct values for the different QuerySet methods when using .values() and .values_list(). #33

Merged
merged 11 commits into from
Mar 10, 2019
Merged
4 changes: 2 additions & 2 deletions django-stubs/db/models/manager.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from django.db.models.query import QuerySet

_T = TypeVar("_T", bound=Model, covariant=True)

class BaseManager(QuerySet[_T]):
class BaseManager(QuerySet[_T, _T]):
creation_counter: int = ...
auto_created: bool = ...
use_in_migrations: bool = ...
Expand All @@ -21,7 +21,7 @@ class BaseManager(QuerySet[_T]):
def _get_queryset_methods(cls, queryset_class: type) -> Dict[str, Any]: ...
def contribute_to_class(self, model: Type[Model], name: str) -> None: ...
def db_manager(self, using: Optional[str] = ..., hints: Optional[Dict[str, Model]] = ...) -> Manager: ...
def get_queryset(self) -> QuerySet[_T]: ...
def get_queryset(self) -> QuerySet[_T, _T]: ...

class Manager(BaseManager[_T]): ...

Expand Down
89 changes: 48 additions & 41 deletions django-stubs/db/models/query.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ from typing import (
TypeVar,
Union,
overload,
Generic,
NamedTuple,
Collection,
)

from django.db.models.base import Model
Expand Down Expand Up @@ -46,7 +49,7 @@ class FlatValuesListIterable(BaseIterable):

_T = TypeVar("_T", bound=models.Model, covariant=True)

class QuerySet(Iterable[_T], Sized):
class QuerySet(Generic[_T, _Row], Collection[_Row]):
query: Query
def __init__(
self,
Expand All @@ -58,32 +61,31 @@ class QuerySet(Iterable[_T], Sized):
@classmethod
def as_manager(cls) -> Manager[Any]: ...
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __iter__(self) -> Iterator[_Row]: ...
def __contains__(self, x: object) -> bool: ...
@overload
def __getitem__(self, i: int) -> _Row: ...
@overload
def __getitem__(self, s: slice) -> QuerySet[_T, _Row]: ...
def __bool__(self) -> bool: ...
def __class_getitem__(cls, item: Type[_T]):
pass
def __getstate__(self) -> Dict[str, Any]: ...
@overload
def __getitem__(self, k: int) -> _T: ...
@overload
def __getitem__(self, k: str) -> Any: ...
@overload
def __getitem__(self, k: slice) -> QuerySet[_T]: ...
def __and__(self, other: QuerySet) -> QuerySet: ...
def __or__(self, other: QuerySet) -> QuerySet: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_T]: ...
def iterator(self, chunk_size: int = ...) -> Iterator[_Row]: ...
def aggregate(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: ...
def get(self, *args: Any, **kwargs: Any) -> _T: ...
def get(self, *args: Any, **kwargs: Any) -> _Row: ...
def create(self, **kwargs: Any) -> _T: ...
def bulk_create(self, objs: Iterable[Model], batch_size: Optional[int] = ...) -> List[_T]: ...
def get_or_create(self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any) -> Tuple[_T, bool]: ...
def update_or_create(
self, defaults: Optional[MutableMapping[str, Any]] = ..., **kwargs: Any
) -> Tuple[_T, bool]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _T: ...
def first(self) -> Optional[_T]: ...
def last(self) -> Optional[_T]: ...
def earliest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def latest(self, *fields: Any, field_name: Optional[Any] = ...) -> _Row: ...
def first(self) -> Optional[_Row]: ...
def last(self) -> Optional[_Row]: ...
def in_bulk(self, id_list: Iterable[Any] = ..., *, field_name: str = ...) -> Dict[Any, _T]: ...
def delete(self) -> Tuple[int, Dict[str, int]]: ...
def update(self, **kwargs: Any) -> int: ...
Expand All @@ -93,31 +95,36 @@ class QuerySet(Iterable[_T], Sized):
def raw(
self, raw_query: str, params: Any = ..., translations: Optional[Dict[str, str]] = ..., using: None = ...
) -> RawQuerySet: ...
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet: ...
def values_list(self, *fields: Union[str, Combinable], flat: bool = ..., named: bool = ...) -> QuerySet: ...
# @overload
# def values_list(self, *fields: Union[str, Combinable], named: Literal[True]) -> NamedValuesListIterable: ...
# @overload
# def values_list(self, *fields: Union[str, Combinable], flat: Literal[True]) -> FlatValuesListIterable: ...
# @overload
# def values_list(self, *fields: Union[str, Combinable]) -> ValuesListIterable: ...
def values(self, *fields: Union[str, Combinable], **expressions: Any) -> QuerySet[_T, Dict[str, Any]]: ...
@overload
def values_list(
self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[True]
) -> QuerySet[_T, NamedTuple]: ...
@overload
def values_list(
self, *fields: Union[str, Combinable], flat: Literal[True], named: Literal[False] = ...
) -> QuerySet[_T, Any]: ...
@overload
def values_list(
self, *fields: Union[str, Combinable], flat: Literal[False] = ..., named: Literal[False] = ...
) -> QuerySet[_T, Tuple]: ...
def dates(self, field_name: str, kind: str, order: str = ...) -> QuerySet: ...
def datetimes(self, field_name: str, kind: str, order: str = ..., tzinfo: None = ...) -> QuerySet: ...
def none(self) -> QuerySet[_T]: ...
def all(self) -> QuerySet[_T]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def complex_filter(self, filter_obj: Any) -> QuerySet[_T]: ...
def none(self) -> QuerySet[_T, _Row]: ...
def all(self) -> QuerySet[_T, _Row]: ...
def filter(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
def exclude(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
def complex_filter(self, filter_obj: Any) -> QuerySet[_T, _Row]: ...
def count(self) -> int: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T]: ...
def union(self, *other_qs: Any, all: bool = ...) -> QuerySet[_T, _Row]: ...
def intersection(self, *other_qs: Any) -> QuerySet[_T, _Row]: ...
def difference(self, *other_qs: Any) -> QuerySet[_T, _Row]: ...
def select_for_update(self, nowait: bool = ..., skip_locked: bool = ..., of: Tuple = ...) -> QuerySet: ...
def select_related(self, *fields: Any) -> QuerySet[_T]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T]: ...
def select_related(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def prefetch_related(self, *lookups: Any) -> QuerySet[_T, _Row]: ...
def annotate(self, *args: Any, **kwargs: Any) -> QuerySet[_T, _Row]: ...
def order_by(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
def distinct(self, *field_names: Any) -> QuerySet[_T, _Row]: ...
def extra(
self,
select: Optional[Dict[str, Any]] = ...,
Expand All @@ -126,11 +133,11 @@ class QuerySet(Iterable[_T], Sized):
tables: Optional[List[str]] = ...,
order_by: Optional[Sequence[str]] = ...,
select_params: Optional[Sequence[Any]] = ...,
) -> QuerySet[_T]: ...
def reverse(self) -> QuerySet[_T]: ...
def defer(self, *fields: Any) -> QuerySet[_T]: ...
def only(self, *fields: Any) -> QuerySet[_T]: ...
def using(self, alias: Optional[str]) -> QuerySet[_T]: ...
) -> QuerySet[_T, _Row]: ...
def reverse(self) -> QuerySet[_T, _Row]: ...
def defer(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def only(self, *fields: Any) -> QuerySet[_T, _Row]: ...
def using(self, alias: Optional[str]) -> QuerySet[_T, _Row]: ...
@property
def ordered(self) -> bool: ...
@property
Expand Down Expand Up @@ -159,7 +166,7 @@ class RawQuerySet(Iterable[_T], Sized):
@overload
def __getitem__(self, k: str) -> Any: ...
@overload
def __getitem__(self, k: slice) -> QuerySet[_T]: ...
def __getitem__(self, k: slice) -> RawQuerySet[_T]: ...
@property
def columns(self) -> List[str]: ...
@property
Expand Down
4 changes: 2 additions & 2 deletions django-stubs/shortcuts.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ def redirect(

_T = TypeVar("_T", bound=Model)

def get_object_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T]], *args: Any, **kwargs: Any) -> _T: ...
def get_list_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T]], *args: Any, **kwargs: Any) -> List[_T]: ...
def get_object_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T, _T]], *args: Any, **kwargs: Any) -> _T: ...
def get_list_or_404(klass: Union[Type[_T], Manager[_T], QuerySet[_T, _T]], *args: Any, **kwargs: Any) -> List[_T]: ...
def resolve_url(to: Union[Callable, Model, str], *args: Any, **kwargs: Any) -> str: ...
24 changes: 6 additions & 18 deletions scripts/typecheck_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,10 @@
'Argument "is_dst" to "localize" of "BaseTzInfo" has incompatible type "None"; expected "bool"'
],
'aggregation': [
'Incompatible types in assignment (expression has type "QuerySet[Any]", variable has type "List[Any]")',
'"as_sql" undefined in superclass',
'Incompatible types in assignment (expression has type "FlatValuesListIterable", '
+ 'variable has type "ValuesListIterable")',
'Incompatible type for "contact" of "Book" (got "Optional[Author]", expected "Union[Author, Combinable]")',
'Incompatible type for "publisher" of "Book" (got "Optional[Publisher]", expected "Union[Publisher, Combinable]")'
],
'aggregation_regress': [
'Incompatible types in assignment (expression has type "List[str]", variable has type "QuerySet[Author]")',
'Incompatible types in assignment (expression has type "FlatValuesListIterable", variable has type "QuerySet[Any]")',
'Too few arguments for "count" of "Sequence"'
],
'apps': [
'Incompatible types in assignment (expression has type "str", target has type "type")',
'"Callable[[bool, bool], List[Type[Model]]]" has no attribute "cache_clear"'
Expand Down Expand Up @@ -152,9 +144,6 @@
'db_typecasts': [
'"object" has no attribute "__iter__"; maybe "__str__" or "__dir__"? (not iterable)'
],
'expressions': [
'Argument 1 to "Subquery" has incompatible type "Sequence[Dict[str, Any]]"; expected "QuerySet[Any]"'
],
'from_db_value': [
'has no attribute "vendor"'
],
Expand Down Expand Up @@ -192,9 +181,9 @@
],
'get_object_or_404': [
'Argument 1 to "get_object_or_404" has incompatible type "str"; '
+ 'expected "Union[Type[<nothing>], Manager[<nothing>], QuerySet[<nothing>]]"',
+ 'expected "Union[Type[<nothing>], Manager[<nothing>], QuerySet[<nothing>, <nothing>]]"',
'Argument 1 to "get_list_or_404" has incompatible type "List[Type[Article]]"; '
+ 'expected "Union[Type[<nothing>], Manager[<nothing>], QuerySet[<nothing>]]"',
+ 'expected "Union[Type[<nothing>], Manager[<nothing>], QuerySet[<nothing>, <nothing>]]"',
'CustomClass'
],
'get_or_create': [
Expand All @@ -221,9 +210,6 @@
'many_to_one': [
'Incompatible type for "parent" of "Child" (got "None", expected "Union[Parent, Combinable]")'
],
'model_inheritance_regress': [
'Incompatible types in assignment (expression has type "List[Supplier]", variable has type "QuerySet[Supplier]")'
],
'model_meta': [
'"object" has no attribute "items"',
'"Field" has no attribute "many_to_many"'
Expand Down Expand Up @@ -296,7 +282,9 @@
],
'queries': [
'Incompatible types in assignment (expression has type "None", variable has type "str")',
'Invalid index type "Optional[str]" for "Dict[str, int]"; expected type "str"'
'Invalid index type "Optional[str]" for "Dict[str, int]"; expected type "str"',
'No overload variant of "values_list" of "QuerySet" matches argument types "str", "bool", "bool"',
"note: "
],
'requests': [
'Incompatible types in assignment (expression has type "Dict[str, str]", variable has type "QueryDict")'
Expand All @@ -305,7 +293,7 @@
'Argument 1 to "TextIOWrapper" has incompatible type "HttpResponse"; expected "IO[bytes]"'
],
'prefetch_related': [
'Incompatible types in assignment (expression has type "List[Room]", variable has type "QuerySet[Room]")',
'Incompatible types in assignment (expression has type "List[Room]", variable has type "QuerySet[Room, Room]")',
'"None" has no attribute "__iter__"',
'has no attribute "read_by"'
],
Expand Down
75 changes: 75 additions & 0 deletions test-data/typecheck/queryset.test
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,79 @@ reveal_type(Blog.objects.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any,
reveal_type(Blog.objects.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(Blog.objects.in_bulk(['beatles_blog'], field_name='slug')) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'


qs = Blog.objects.all()
reveal_type(qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, main.Blog*]'
reveal_type(iter(qs)) # E: Revealed type is 'typing.Iterator[main.Blog*]'
reveal_type(qs.iterator()) # E: Revealed type is 'typing.Iterator[main.Blog*]'
reveal_type(qs.first()) # E: Revealed type is 'Union[main.Blog*, None]'
reveal_type(qs.last()) # E: Revealed type is 'Union[main.Blog*, None]'
reveal_type(qs.earliest()) # E: Revealed type is 'main.Blog*'
reveal_type(qs.latest()) # E: Revealed type is 'main.Blog*'
reveal_type(qs[0]) # E: Revealed type is 'main.Blog*'
reveal_type(qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, main.Blog*]'
reveal_type(qs.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(qs.in_bulk(['beatles_blog'], field_name='slug')) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'


values_qs = Blog.objects.values()
reveal_type(values_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.dict[builtins.str, Any]]'
reveal_type(iter(values_qs)) # E: Revealed type is 'typing.Iterator[builtins.dict*[builtins.str, Any]]'
reveal_type(values_qs.iterator()) # E: Revealed type is 'typing.Iterator[builtins.dict*[builtins.str, Any]]'
reveal_type(values_qs.first()) # E: Revealed type is 'Union[builtins.dict*[builtins.str, Any], None]'
reveal_type(values_qs.last()) # E: Revealed type is 'Union[builtins.dict*[builtins.str, Any], None]'
reveal_type(values_qs.earliest()) # E: Revealed type is 'builtins.dict*[builtins.str, Any]'
reveal_type(values_qs.latest()) # E: Revealed type is 'builtins.dict*[builtins.str, Any]'
reveal_type(values_qs[0]) # E: Revealed type is 'builtins.dict*[builtins.str, Any]'
reveal_type(values_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.dict*[builtins.str, Any]]'
reveal_type(values_qs.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(values_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(values_qs.in_bulk(['beatles_blog'], field_name='slug')) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'


values_list_qs = Blog.objects.values_list('id', 'slug')
reveal_type(values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.tuple[Any]]'
reveal_type(iter(values_list_qs)) # E: Revealed type is 'typing.Iterator[builtins.tuple*[Any]]'
reveal_type(values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[builtins.tuple*[Any]]'
reveal_type(values_list_qs.first()) # E: Revealed type is 'Union[builtins.tuple*[Any], None]'
reveal_type(values_list_qs.last()) # E: Revealed type is 'Union[builtins.tuple*[Any], None]'
reveal_type(values_list_qs.earliest()) # E: Revealed type is 'builtins.tuple*[Any]'
reveal_type(values_list_qs.latest()) # E: Revealed type is 'builtins.tuple*[Any]'
reveal_type(values_list_qs[0]) # E: Revealed type is 'builtins.tuple*[Any]'
reveal_type(values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, builtins.tuple*[Any]]'
reveal_type(values_list_qs.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(values_list_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(values_list_qs.in_bulk(['beatles_blog'], field_name='slug')) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'


flat_values_list_qs = Blog.objects.values_list('id', flat=True)
reveal_type(flat_values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Any]'
reveal_type(iter(flat_values_list_qs)) # E: Revealed type is 'typing.Iterator[Any]'
reveal_type(flat_values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[Any]'
reveal_type(flat_values_list_qs.first()) # E: Revealed type is 'Union[Any, None]'
reveal_type(flat_values_list_qs.last()) # E: Revealed type is 'Union[Any, None]'
reveal_type(flat_values_list_qs.earliest()) # E: Revealed type is 'Any'
reveal_type(flat_values_list_qs.latest()) # E: Revealed type is 'Any'
reveal_type(flat_values_list_qs[0]) # E: Revealed type is 'Any'
reveal_type(flat_values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, Any]'
reveal_type(flat_values_list_qs.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(flat_values_list_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(flat_values_list_qs.in_bulk(['beatles_blog'], field_name='slug')) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'


named_values_list_qs = Blog.objects.values_list('id', named=True)
reveal_type(named_values_list_qs) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, typing.NamedTuple]'
reveal_type(iter(named_values_list_qs)) # E: Revealed type is 'typing.Iterator[typing.NamedTuple*]'
reveal_type(named_values_list_qs.iterator()) # E: Revealed type is 'typing.Iterator[typing.NamedTuple*]'
reveal_type(named_values_list_qs.first()) # E: Revealed type is 'Union[typing.NamedTuple*, None]'
reveal_type(named_values_list_qs.last()) # E: Revealed type is 'Union[typing.NamedTuple*, None]'
reveal_type(named_values_list_qs.earliest()) # E: Revealed type is 'typing.NamedTuple*'
reveal_type(named_values_list_qs.latest()) # E: Revealed type is 'typing.NamedTuple*'
reveal_type(named_values_list_qs[0]) # E: Revealed type is 'typing.NamedTuple*'
reveal_type(named_values_list_qs[:9]) # E: Revealed type is 'django.db.models.query.QuerySet[main.Blog*, typing.NamedTuple*]'
reveal_type(named_values_list_qs.in_bulk([1])) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(named_values_list_qs.in_bulk()) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'
reveal_type(named_values_list_qs.in_bulk(['beatles_blog'], field_name='slug')) # E: Revealed type is 'builtins.dict[Any, main.Blog*]'

[out]