diff --git a/drf_spectacular/contrib/pydantic.py b/drf_spectacular/contrib/pydantic.py index f03dda6a..395a8a9a 100644 --- a/drf_spectacular/contrib/pydantic.py +++ b/drf_spectacular/contrib/pydantic.py @@ -23,6 +23,7 @@ def get_name(self, auto_schema, direction): # of the entry model, we simply use the class name as string for object. This hack may # create false positive warnings, so turn it off. However, this may suppress correct # warnings involving the entry class. + # TODO suppression may be migrated to new ComponentIdentity system set_override(self.target, 'suppress_collision_warning', True) return self.target.__name__ diff --git a/drf_spectacular/contrib/rest_framework_dataclasses.py b/drf_spectacular/contrib/rest_framework_dataclasses.py index 760cdca5..95a8adee 100644 --- a/drf_spectacular/contrib/rest_framework_dataclasses.py +++ b/drf_spectacular/contrib/rest_framework_dataclasses.py @@ -1,6 +1,8 @@ +from typing import Any + from drf_spectacular.drainage import get_override, has_override from drf_spectacular.extensions import OpenApiSerializerExtension -from drf_spectacular.plumbing import get_doc +from drf_spectacular.plumbing import ComponentIdentity, get_doc from drf_spectacular.utils import Direction @@ -18,6 +20,9 @@ def get_name(self): return get_override(self.target.dataclass_definition.dataclass_type, 'component_name') return self.target.dataclass_definition.dataclass_type.__name__ + def get_identity(self, auto_schema, direction: Direction) -> Any: + return ComponentIdentity(self.target.dataclass_definition.dataclass_type) + def strip_library_doc(self, schema): """Strip the DataclassSerializer library documentation from the schema.""" from rest_framework_dataclasses.serializers import DataclassSerializer diff --git a/drf_spectacular/contrib/rest_polymorphic.py b/drf_spectacular/contrib/rest_polymorphic.py index 943e0799..1e725c52 100644 --- a/drf_spectacular/contrib/rest_polymorphic.py +++ b/drf_spectacular/contrib/rest_polymorphic.py @@ -1,7 +1,8 @@ from drf_spectacular.drainage import warn from drf_spectacular.extensions import OpenApiSerializerExtension from drf_spectacular.plumbing import ( - ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer, + ComponentIdentity, ResolvedComponent, build_basic_type, build_object_type, + is_patched_serializer, ) from drf_spectacular.settings import spectacular_settings from drf_spectacular.types import OpenApiTypes @@ -25,7 +26,7 @@ def map_serializer(self, auto_schema, direction): component = ResolvedComponent( name=auto_schema._get_serializer_name(sub_serializer, direction), type=ResolvedComponent.SCHEMA, - object='virtual' + object=ComponentIdentity('virtual') ) typed_component = self.build_typed_component( auto_schema=auto_schema, diff --git a/drf_spectacular/extensions.py b/drf_spectacular/extensions.py index 052be3a0..1eae1c6d 100644 --- a/drf_spectacular/extensions.py +++ b/drf_spectacular/extensions.py @@ -68,6 +68,10 @@ def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[ """ return str for overriding default name extraction """ return None + def get_identity(self, auto_schema: 'AutoSchema', direction: Direction) -> Any: + """ return anything to compare instances of target. Target will be used by default. """ + return None + def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType: """ override for customized serializer mapping """ return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True) diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index f4f14f3b..f53b6493 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -1478,12 +1478,13 @@ def _get_response_for_code(self, serializer, status_code, media_types=None, dire and is_serializer(serializer) and (not is_list_serializer(serializer) or is_serializer(serializer.child)) ): - paginated_name = self.get_paginated_name(self._get_serializer_name(serializer, "response")) component = ResolvedComponent( - name=paginated_name, + name=self.get_paginated_name(self._get_serializer_name(serializer, 'response')), type=ResolvedComponent.SCHEMA, schema=paginator.get_paginated_response_schema(schema), - object=serializer.child if is_list_serializer(serializer) else serializer, + object=self.get_serializer_identity( + serializer.child if is_list_serializer(serializer) else serializer, 'response' + ) ) self.registry.register_on_missing(component) schema = component.ref @@ -1556,7 +1557,17 @@ def _get_response_headers_for_code(self, status_code, direction='response') -> _ return result + def get_serializer_identity(self, serializer, direction: Direction) -> Any: + serializer_extension = OpenApiSerializerExtension.get_match(serializer) + if serializer_extension: + identity = serializer_extension.get_identity(self, direction) + if identity is not None: + return identity + + return serializer + def get_serializer_name(self, serializer: serializers.Serializer, direction: Direction) -> str: + """ override this for custom behaviour """ return serializer.__class__.__name__ def _get_serializer_name(self, serializer, direction, bypass_extensions=False) -> str: @@ -1612,7 +1623,7 @@ def resolve_serializer( component = ResolvedComponent( name=self._get_serializer_name(serializer, direction, bypass_extensions), type=ResolvedComponent.SCHEMA, - object=serializer, + object=self.get_serializer_identity(serializer, direction), ) if component in self.registry: return self.registry[component] # return component with schema diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 74753efa..df1aa03c 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -723,6 +723,17 @@ def ref(self) -> _SchemaType: return {'$ref': f'#/components/{self.type}/{self.name}'} +class ComponentIdentity: + """ A container class to make object/component comparison explicit """ + def __init__(self, obj): + self.obj = obj + + def __eq__(self, other): + if isinstance(other, ComponentIdentity): + return self.obj == other.obj + return self.obj == other + + class ComponentRegistry: def __init__(self) -> None: self._components: Dict[Tuple[str, str], ResolvedComponent] = {} @@ -746,17 +757,25 @@ def __contains__(self, component): query_obj = component.object registry_obj = self._components[component.key].object - query_class = query_obj if inspect.isclass(query_obj) else query_obj.__class__ - registry_class = query_obj if inspect.isclass(registry_obj) else registry_obj.__class__ + + if isinstance(query_obj, ComponentIdentity) or inspect.isclass(query_obj): + query_id = query_obj + else: + query_id = query_obj.__class__ + + if isinstance(registry_obj, ComponentIdentity) or inspect.isclass(registry_obj): + registry_id = registry_obj + else: + registry_id = registry_obj.__class__ suppress_collision_warning = ( - get_override(registry_class, 'suppress_collision_warning', False) - or get_override(query_class, 'suppress_collision_warning', False) + get_override(registry_id, 'suppress_collision_warning', False) + or get_override(query_id, 'suppress_collision_warning', False) ) - if query_class != registry_class and not suppress_collision_warning: + if query_id != registry_id and not suppress_collision_warning: warn( f'Encountered 2 components with identical names "{component.name}" and ' - f'different classes {query_class} and {registry_class}. This will very ' + f'different identities {query_id} and {registry_id}. This will very ' f'likely result in an incorrect schema. Try renaming one.' ) return True diff --git a/tests/contrib/test_rest_framework_dataclasses.py b/tests/contrib/test_rest_framework_dataclasses.py index 0062e122..d0dd8c91 100644 --- a/tests/contrib/test_rest_framework_dataclasses.py +++ b/tests/contrib/test_rest_framework_dataclasses.py @@ -90,3 +90,49 @@ def custom_name_via_serializer_decoration(request): generate_schema(None, patterns=urlpatterns), 'tests/contrib/test_rest_framework_dataclasses.yml' ) + + +@pytest.mark.contrib('rest_framework_dataclasses') +@pytest.mark.skipif(sys.version_info < (3, 7), reason='dataclass required by package') +def test_rest_framework_dataclasses_class_reuse(no_warnings): + from dataclasses import dataclass + + from rest_framework_dataclasses.serializers import DataclassSerializer + + @dataclass + class Person: + name: str + age: int + + @dataclass + class Party: + person: Person + num_persons: int + + class PartySerializer(DataclassSerializer[Party]): + class Meta: + dataclass = Party + + class PersonSerializer(DataclassSerializer[Person]): + class Meta: + dataclass = Person + + @extend_schema(responses=PartySerializer) + @api_view() + def party(request): + pass # pragma: no cover + + @extend_schema(responses=PersonSerializer) + @api_view() + def person(request): + pass # pragma: no cover + + urlpatterns = [ + path('party', person), + path('person', party), + ] + + schema = generate_schema(None, patterns=urlpatterns) + # just existence is enough to check since its about no_warnings + assert 'Person' in schema['components']['schemas'] + assert 'Party' in schema['components']['schemas'] diff --git a/tests/test_warnings.py b/tests/test_warnings.py index 71c9539a..89c8ef44 100644 --- a/tests/test_warnings.py +++ b/tests/test_warnings.py @@ -49,7 +49,7 @@ class X2Viewset(mixins.ListModelMixin, viewsets.GenericViewSet): generate_schema(None, patterns=router.urls) stderr = capsys.readouterr().err - assert 'Encountered 2 components with identical names "X" and different classes' in stderr + assert 'Encountered 2 components with identical names "X" and different identities' in stderr def test_owned_serializer_naming_override_with_ref_name_collision(warnings):