diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 116097ac7..760bde99d 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -173,16 +173,37 @@ def get_class_fullname(klass: type) -> str: def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]: """ Return the expression for the specific argument. + + It first checks the function definition for the presence of the argument in its parameters, returning the associated + expression for the argument if found. This handles the case of getting the expression of `null` in + + >>> def first_case(null: bool=False) -> Any: ... + >>> first_case(null=True) + + If not found, it is searched for in the names of actual arguments in the call expression, returning the expression + for the argument if found. This handles the case of getting the expression of `null` in + + >>> def second_case(*args, **kwargs) -> Any: ... + >>> second_case(null=True) + This helper should only be used with non-star arguments. """ - if name not in ctx.callee_arg_names: - return None - idx = ctx.callee_arg_names.index(name) - args = ctx.args[idx] - if len(args) != 1: - # Either an error or no value passed. - return None - return args[0] + # first check for named arg on function definition + if name in ctx.callee_arg_names: + idx = ctx.callee_arg_names.index(name) + args = ctx.args[idx] + if len(args) != 1: + # Either an error or no value passed. + return None + return args[0] + + # check for named arg in function call keyword args + for arg_group_idx, arg_group in enumerate(ctx.arg_names): + if name in arg_group: + arg_name_idx = arg_group.index(name) + return ctx.args[arg_group_idx][arg_name_idx] + + return None def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]: diff --git a/tests/typecheck/fields/test_custom_fields.yml b/tests/typecheck/fields/test_custom_fields.yml new file mode 100644 index 000000000..e106fc149 --- /dev/null +++ b/tests/typecheck/fields/test_custom_fields.yml @@ -0,0 +1,44 @@ +- case: test_custom_model_fields_with_passthrough_constructor + main: | + from myapp.models import User + user = User() + reveal_type(user.id) # N: Revealed type is "builtins.int" + reveal_type(user.my_custom_field1) # N: Revealed type is "Union[builtins.int, None]" + reveal_type(user.my_custom_field2) # N: Revealed type is "builtins.int" + reveal_type(user.my_custom_field3) # N: Revealed type is "builtins.int" + reveal_type(user.my_custom_field4) # N: Revealed type is "Union[builtins.int, None]" + reveal_type(user.my_custom_field5) # N: Revealed type is "builtins.int" + reveal_type(user.my_custom_field6) # N: Revealed type is "builtins.int" + monkeypatch: true + installed_apps: + - myapp + out: | + myapp/models:15: error: "__init__" of "Field" gets multiple values for keyword argument "blank" [misc] + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + from django.db.models import fields + + from typing import Any, TypeVar + + _ST = TypeVar("_ST", contravariant=True) + _GT = TypeVar("_GT", covariant=True) + + class MyIntegerField(fields.IntegerField[_ST, _GT]): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + class MyIntegerFieldBlank(fields.IntegerField[_ST, _GT]): + def __init__(self, *args: Any, blank: bool = True, **kwargs: Any) -> None: + super().__init__(*args, blank=blank, **kwargs) + + class User(models.Model): + id = models.AutoField(primary_key=True) + my_custom_field1 = MyIntegerField(null=True) + my_custom_field2 = MyIntegerField(null=False) + my_custom_field3 = MyIntegerField() + my_custom_field4 = MyIntegerFieldBlank(null=True) + my_custom_field5 = MyIntegerFieldBlank(null=False) + my_custom_field6 = MyIntegerFieldBlank() diff --git a/tests/typecheck/managers/querysets/test_custom_queryset.yml b/tests/typecheck/managers/querysets/test_custom_queryset.yml new file mode 100644 index 000000000..770661cb4 --- /dev/null +++ b/tests/typecheck/managers/querysets/test_custom_queryset.yml @@ -0,0 +1,42 @@ +- case: test_custom_queryset_with_passthrough_values_list + main: | + from typing import Any, TypeVar + from django.db.models.base import Model + from django.db.models.query import QuerySet + from myapp.models import MyUser + + _Model = TypeVar("_Model", bound=Model, covariant=True) + + class CustomQuerySet(QuerySet[_Model]): + def values_list(self, *args: Any, **kwargs: Any) -> QuerySet[_Model]: + return super().values_list(*args, **kwargs) + + qs = CustomQuerySet[MyUser](model=MyUser) + + # checking that the CustomQuerySet returns same types as MyUser's qs when using "flat" and "named" args which use + # "get_call_argument_by_name" helper function in plugin + reveal_type(MyUser.objects.values_list('name').get()) # N: Revealed type is "Tuple[builtins.str]" + reveal_type(qs.values_list('name').get()) # N: Revealed type is "Tuple[builtins.str]" + + reveal_type(MyUser.objects.values_list('name', flat=True).get()) # N: Revealed type is "builtins.str" + reveal_type(qs.values_list('name', flat=True).get()) # N: Revealed type is "builtins.str" + + reveal_type(MyUser.objects.values_list('name', named=True).get()) # N: Revealed type is "Tuple[builtins.str, fallback=main.Row]" + reveal_type(qs.values_list('name', named=True).get()) # N: Revealed type is "Tuple[builtins.str, fallback=main.Row1]" + + reveal_type(MyUser.objects.values_list('name', flat=True, named=True).get()) + reveal_type(qs.values_list('name', flat=True, named=True).get()) + out: | + main:25: error: 'flat' and 'named' can't be used together [misc] + main:25: note: Revealed type is "Any" + main:26: error: 'flat' and 'named' can't be used together [misc] + main:26: note: Revealed type is "Any" + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyUser(models.Model): + name = models.CharField(max_length=100)