diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8866d3cb5e..71567f36ab 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.10 + rev: v0.6.3 hooks: - id: ruff-format exclude: ^tests/\w+/snapshots/ @@ -31,7 +31,7 @@ repos: args: ["--branch", "main"] - repo: https://github.com/adamchainz/blacken-docs - rev: 1.16.0 + rev: 1.18.0 hooks: - id: blacken-docs args: [--skip-errors] diff --git a/pyproject.toml b/pyproject.toml index 8f4e3d1b66..3ce106863a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -219,7 +219,7 @@ exclude = [ "dist", "node_modules", "venv", - "tests/codegen/snapshots" + "tests/*/snapshots" ] src = ["strawberry", "tests"] diff --git a/strawberry/dataloader.py b/strawberry/dataloader.py index 8a04628930..ce2d247b27 100644 --- a/strawberry/dataloader.py +++ b/strawberry/dataloader.py @@ -208,14 +208,11 @@ def prime_many(self, data: Mapping[K, T], force: bool = False) -> None: def should_create_new_batch(loader: DataLoader, batch: Batch) -> bool: - if ( + return bool( batch.dispatched or loader.max_batch_size and len(batch) >= loader.max_batch_size - ): - return True - - return False + ) def get_current_batch(loader: DataLoader) -> Batch: diff --git a/strawberry/types/union.py b/strawberry/types/union.py index f1e32fba68..f5d8c6210c 100644 --- a/strawberry/types/union.py +++ b/strawberry/types/union.py @@ -235,10 +235,7 @@ def is_valid_union_type(type_: object) -> bool: if isinstance(type_, StrawberryUnion): return True - if get_origin(type_) is Annotated: - return True - - return False + return get_origin(type_) is Annotated def union( diff --git a/strawberry/utils/typing.py b/strawberry/utils/typing.py index 82712f29ab..1b050bbf4b 100644 --- a/strawberry/utils/typing.py +++ b/strawberry/utils/typing.py @@ -84,7 +84,7 @@ def is_list(annotation: object) -> bool: """Returns True if annotation is a List.""" annotation_origin = getattr(annotation, "__origin__", None) - return annotation_origin == list + return annotation_origin is list def is_union(annotation: object) -> bool: @@ -313,7 +313,7 @@ def _get_namespace_from_ast( # can properly resolve it later type_name = args[0].strip(" '\"\n") for arg in args[1:]: - evaled_arg = eval(arg, globalns, localns) # noqa: PGH001, S307 + evaled_arg = eval(arg, globalns, localns) # noqa: S307 if isinstance(evaled_arg, StrawberryLazyReference): extra[type_name] = evaled_arg.resolve_forward_ref(ForwardRef(type_name)) diff --git a/tests/experimental/pydantic/schema/test_defaults.py b/tests/experimental/pydantic/schema/test_defaults.py index 2412e6b04c..6d87c690a7 100644 --- a/tests/experimental/pydantic/schema/test_defaults.py +++ b/tests/experimental/pydantic/schema/test_defaults.py @@ -64,10 +64,6 @@ class User: ... class Query: a: User = strawberry.field() - @strawberry.field - def a(self) -> User: - return User() - schema = strawberry.Schema(Query) expected = """ @@ -95,10 +91,6 @@ class User: ... class Query: a: User = strawberry.field() - @strawberry.field - def a(self) -> User: - return User() - schema = strawberry.Schema(Query) expected = """ diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index cc9b5a81e0..745e5e90c4 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -228,7 +228,7 @@ class User: assert len(definition.fields) == 1 assert definition.fields[0].python_name == "age" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == int + assert definition.fields[0].type is int def test_can_convert_pydantic_type_with_nested_data_to_strawberry(): diff --git a/tests/fields/test_arguments.py b/tests/fields/test_arguments.py index 90cc15d2ed..f8c911806b 100644 --- a/tests/fields/test_arguments.py +++ b/tests/fields/test_arguments.py @@ -421,7 +421,7 @@ def name( assert argument.python_name == "argument" assert argument.graphql_name is None - assert argument.type == str + assert argument.type is str assert argument.description == "This is a description" assert argument.type is str diff --git a/tests/fields/test_resolvers.py b/tests/fields/test_resolvers.py index 3f6be393a6..c5611e369c 100644 --- a/tests/fields/test_resolvers.py +++ b/tests/fields/test_resolvers.py @@ -35,7 +35,7 @@ class Query: assert definition.fields[0].python_name == "name" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == str + assert definition.fields[0].type is str assert definition.fields[0].base_resolver.wrapped_func == get_name @@ -53,7 +53,7 @@ def name(self) -> str: assert definition.fields[0].python_name == "name" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == str + assert definition.fields[0].type is str assert definition.fields[0].base_resolver(None) == Query().name() @@ -72,7 +72,7 @@ def name() -> str: assert definition.fields[0].python_name == "name" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == str + assert definition.fields[0].type is str assert definition.fields[0].base_resolver() == Query.name() assert Query.name() == "Name" @@ -96,7 +96,7 @@ def val(cls) -> str: assert definition.fields[0].python_name == "val" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == str + assert definition.fields[0].type is str assert definition.fields[0].base_resolver() == Query.val() assert Query.val() == "thingy" @@ -310,13 +310,13 @@ class Query: assert definition.fields[0].python_name == "name" assert definition.fields[0].graphql_name is None assert definition.fields[0].python_name == "name" - assert definition.fields[0].type == str + assert definition.fields[0].type is str assert definition.fields[0].base_resolver.wrapped_func == get_name assert definition.fields[1].python_name == "name_2" assert definition.fields[1].graphql_name is None assert definition.fields[1].python_name == "name_2" - assert definition.fields[1].type == str + assert definition.fields[1].type is str assert definition.fields[1].base_resolver.wrapped_func == get_name diff --git a/tests/http/clients/base.py b/tests/http/clients/base.py index 0ea062c34c..daf1106359 100644 --- a/tests/http/clients/base.py +++ b/tests/http/clients/base.py @@ -290,9 +290,8 @@ async def __aiter__(self) -> AsyncGenerator[Message, None]: class DebuggableGraphQLTransportWSMixin: - @staticmethod def on_init(self) -> None: - """This method can be patched by unittests to get the instance of the + """This method can be patched by unit tests to get the instance of the transport handler when it is initialized. """ diff --git a/tests/objects/test_interfaces.py b/tests/objects/test_interfaces.py index f915cc647a..c7c90c41a2 100644 --- a/tests/objects/test_interfaces.py +++ b/tests/objects/test_interfaces.py @@ -38,7 +38,7 @@ class User(Node): assert definition.fields[1].python_name == "name" assert definition.fields[1].graphql_name is None - assert definition.fields[1].type == str + assert definition.fields[1].type is str assert definition.is_interface is False assert definition.interfaces == [Node.__strawberry_definition__] @@ -68,7 +68,7 @@ class Person(Node): assert definition.fields[1].python_name == "name" assert definition.fields[1].graphql_name is None - assert definition.fields[1].type == str + assert definition.fields[1].type is str assert definition.is_interface is False assert definition.interfaces == [Node.__strawberry_definition__] @@ -84,7 +84,7 @@ class Person(Node): assert definition.fields[1].python_name == "name" assert definition.fields[1].graphql_name is None - assert definition.fields[1].type == str + assert definition.fields[1].type is str assert definition.is_interface is False assert definition.interfaces == [Node.__strawberry_definition__] diff --git a/tests/schema/extensions/test_mask_errors.py b/tests/schema/extensions/test_mask_errors.py index affa8b0e80..6d3c406c40 100644 --- a/tests/schema/extensions/test_mask_errors.py +++ b/tests/schema/extensions/test_mask_errors.py @@ -45,9 +45,7 @@ def hidden_error(self) -> str: def should_mask_error(error: GraphQLError) -> bool: original_error = error.original_error - if original_error and isinstance(original_error, VisibleError): - return False - return True + return not (original_error and isinstance(original_error, VisibleError)) schema = strawberry.Schema( query=Query, extensions=[MaskErrors(should_mask_error=should_mask_error)] diff --git a/tests/schema/test_info.py b/tests/schema/test_info.py index 0c3c23dc3c..3869d0d61d 100644 --- a/tests/schema/test_info.py +++ b/tests/schema/test_info.py @@ -320,7 +320,7 @@ def field(self, info: strawberry.Info) -> return_type: def test_return_type_from_field(): def resolver(info): - assert info.return_type == int + assert info.return_type is int return 0 @strawberry.type diff --git a/tests/schema/test_permission.py b/tests/schema/test_permission.py index 0cec4cde60..23e54230d0 100644 --- a/tests/schema/test_permission.py +++ b/tests/schema/test_permission.py @@ -568,10 +568,7 @@ class IsAuthorized(BasePermission): def has_permission( self, source, info, **kwargs: typing.Any ) -> bool: # pragma: no cover - if kwargs["a_key"] == "secret": - return True - - return False + return kwargs["a_key"] == "secret" @strawberry.type class Query: diff --git a/tests/schema/test_private_field.py b/tests/schema/test_private_field.py index 554f6b59fa..fd14c839af 100644 --- a/tests/schema/test_private_field.py +++ b/tests/schema/test_private_field.py @@ -22,7 +22,7 @@ class Query: assert definition.fields[0].python_name == "name" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == str + assert definition.fields[0].type is str instance = Query(name="Luke", age=22) assert instance.name == "Luke" diff --git a/tests/tools/test_create_type.py b/tests/tools/test_create_type.py index 5f4f99518e..f483df8e49 100644 --- a/tests/tools/test_create_type.py +++ b/tests/tools/test_create_type.py @@ -25,7 +25,7 @@ def name() -> str: assert definition.fields[0].python_name == "name" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == str + assert definition.fields[0].type is str def test_create_type_extend_and_directives(): @@ -52,7 +52,7 @@ def name() -> str: assert definition.fields[0].python_name == "name" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == str + assert definition.fields[0].type is str def test_create_input_type(): @@ -73,7 +73,7 @@ def test_create_input_type(): assert definition.fields[0].python_name == "name" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == str + assert definition.fields[0].type is str def test_create_interface_type(): @@ -95,7 +95,7 @@ def test_create_interface_type(): assert definition.fields[0].python_name == "name" assert definition.fields[0].graphql_name is None - assert definition.fields[0].type == str + assert definition.fields[0].type is str def test_create_variable_type(): @@ -111,7 +111,7 @@ def get_name() -> str: assert definition.fields[0].python_name == "get_name" assert definition.fields[0].graphql_name == "name" - assert definition.fields[0].type == str + assert definition.fields[0].type is str def test_create_type_empty_list(): diff --git a/tests/types/test_annotation.py b/tests/types/test_annotation.py index 6ca9076642..969dddf871 100644 --- a/tests/types/test_annotation.py +++ b/tests/types/test_annotation.py @@ -64,7 +64,7 @@ def __eq__(self, other): def test_eq_on_non_annotation(): - assert StrawberryAnnotation(int) != int + assert StrawberryAnnotation(int) is not int assert StrawberryAnnotation(int) != 123 @@ -72,4 +72,4 @@ def test_set_anntation(): annotation = StrawberryAnnotation(int) annotation.annotation = str - assert annotation.annotation == str + assert annotation.annotation is str diff --git a/tests/types/test_argument_types.py b/tests/types/test_argument_types.py index 3cede4fdd3..6ee0aeb4b5 100644 --- a/tests/types/test_argument_types.py +++ b/tests/types/test_argument_types.py @@ -60,7 +60,7 @@ def get_name(id_: int) -> str: return "Lord Buckethead" argument = get_name.arguments[0] - assert argument.type == int + assert argument.type is int def test_object(): diff --git a/tests/types/test_object_types.py b/tests/types/test_object_types.py index 444d3160ad..ed31aa1423 100644 --- a/tests/types/test_object_types.py +++ b/tests/types/test_object_types.py @@ -63,7 +63,7 @@ class Fabric: field: StrawberryField = get_object_definition(Fabric).fields[0] - assert field.type == str + assert field.type is str def test_object():