Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change how Private is implemented to better support type checkers #1437

Merged
merged 3 commits into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Release type: minor

This release changes how `strawberry.Private` is implemented to
improve support for type checkers.
6 changes: 4 additions & 2 deletions strawberry/experimental/pydantic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Any, List, Type

from strawberry.experimental.pydantic.exceptions import UnregisteredTypeException
from strawberry.private import Private
from strawberry.private import is_private
from strawberry.utils.typing import (
get_list_annotation,
get_optional_annotation,
Expand Down Expand Up @@ -30,7 +30,9 @@ def get_strawberry_type_from_model(type_: Any):

def get_private_fields(cls: Type) -> List[dataclasses.Field]:
private_fields: List[dataclasses.Field] = []

for field in dataclasses.fields(cls):
if isinstance(field.type, Private):
if is_private(field.type):
private_fields.append(field)

return private_fields
15 changes: 0 additions & 15 deletions strawberry/ext/mypy_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,6 @@ def strawberry_field_hook(ctx: FunctionContext) -> Type:
return AnyType(TypeOfAny.special_form)


def private_type_analyze_callback(ctx: AnalyzeTypeContext) -> Type:
type_name = ctx.type.args[0]
type_ = ctx.api.analyze_type(type_name)

return type_


def _get_named_type(name: str, api: SemanticAnalyzerPluginInterface):
if "." in name:
return api.named_type_or_none(name) # type: ignore
Expand Down Expand Up @@ -638,9 +631,6 @@ def get_type_analyze_hook(self, fullname: str):
if self._is_strawberry_lazy_type(fullname):
return lazy_type_analyze_callback

if self._is_strawberry_private(fullname):
return private_type_analyze_callback

return None

def get_class_decorator_hook(
Expand Down Expand Up @@ -682,11 +672,6 @@ def _is_strawberry_enum(self, fullname: str) -> bool:
def _is_strawberry_lazy_type(self, fullname: str) -> bool:
return fullname == "strawberry.lazy_type.LazyType"

def _is_strawberry_private(self, fullname: str) -> bool:
return fullname == "strawberry.private.Private" or fullname.endswith(
"strawberry.Private"
)

def _is_strawberry_decorator(self, fullname: str) -> bool:
if any(
strawberry_decorator in fullname
Expand Down
46 changes: 25 additions & 21 deletions strawberry/private.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
class Private:
"""Represent a private field that won't be converted into a GraphQL field
from typing import TypeVar

Example:
from typing_extensions import Annotated, get_args, get_origin

>>> import strawberry
>>> @strawberry.type
... class User:
... name: str
... age: strawberry.Private[int]
"""

__slots__ = ("type",)
class StrawberryPrivate:
...

def __init__(self, type):
self.type = type

def __repr__(self):
if isinstance(self.type, type):
type_name = self.type.__name__
else:
# typing objects, e.g. List[int]
type_name = repr(self.type)
return f"strawberry.Private[{type_name}]"
T = TypeVar("T")

def __class_getitem__(cls, type):
return Private(type)
Private = Annotated[T, StrawberryPrivate()]
Private.__doc__ = """Represent a private field that won't be converted into a GraphQL field

Example:

>>> import strawberry
>>> @strawberry.type
... class User:
... name: str
... age: strawberry.Private[int]
"""


def is_private(type_: object) -> bool:
if get_origin(type_) is Annotated:
return any(
isinstance(argument, StrawberryPrivate) for argument in get_args(type_)
)

return False
6 changes: 3 additions & 3 deletions strawberry/types/type_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
PrivateStrawberryFieldError,
)
from strawberry.field import StrawberryField
from strawberry.private import Private
from strawberry.private import is_private

from ..arguments import UNSET

Expand Down Expand Up @@ -81,7 +81,7 @@ class if one is not set by either using an explicit strawberry.field(name=...) o

if isinstance(field, StrawberryField):
# Check that the field type is not Private
if isinstance(field.type, Private):
if is_private(field.type):
raise PrivateStrawberryFieldError(field.python_name, cls.__name__)

# Check that default is not set if a resolver is defined
Expand Down Expand Up @@ -125,7 +125,7 @@ class if one is not set by either using an explicit strawberry.field(name=...) o
# Create a StrawberryField for fields that didn't use strawberry.field
else:
# Only ignore Private fields that weren't defined using StrawberryFields
if isinstance(field.type, Private):
if is_private(field.type):
continue

field_type = field.type
Expand Down
50 changes: 50 additions & 0 deletions tests/pyright/test_private.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from .utils import Result, requires_pyright, run_pyright, skip_on_windows


pytestmark = [skip_on_windows, requires_pyright]


CODE = """
import strawberry
@strawberry.type
class User:
name: str
age: strawberry.Private[int]
patrick = User(name="Patrick", age=1)
User(n="Patrick")
reveal_type(patrick.name)
reveal_type(patrick.age)
"""


def test_pyright():
results = run_pyright(CODE)

assert results == [
Result(
type="error",
message='No parameter named "n" (reportGeneralTypeIssues)',
line=12,
column=6,
),
Result(
type="error",
message=(
"Arguments missing for parameters "
'"name", "age" (reportGeneralTypeIssues)'
),
line=12,
column=1,
),
Result(
type="info", message='Type of "patrick.name" is "str"', line=14, column=13
),
Result(
type="info", message='Type of "patrick.age" is "int"', line=15, column=13
),
]