Skip to content

Commit

Permalink
parametrize component registry identity #1288
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed Sep 7, 2024
1 parent 2b4d5ab commit 8f2dfc5
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 14 deletions.
1 change: 1 addition & 0 deletions drf_spectacular/contrib/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def get_name(self, auto_schema, direction):
# of the entry model, we simply use the class name as string for object. This hack may
# create false positive warnings, so turn it off. However, this may suppress correct
# warnings involving the entry class.
# TODO suppression may be migrated to new ComponentIdentity system
set_override(self.target, 'suppress_collision_warning', True)
return self.target.__name__

Expand Down
7 changes: 6 additions & 1 deletion drf_spectacular/contrib/rest_framework_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any

from drf_spectacular.drainage import get_override, has_override
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import get_doc
from drf_spectacular.plumbing import ComponentIdentity, get_doc
from drf_spectacular.utils import Direction


Expand All @@ -18,6 +20,9 @@ def get_name(self):
return get_override(self.target.dataclass_definition.dataclass_type, 'component_name')
return self.target.dataclass_definition.dataclass_type.__name__

def get_identity(self, auto_schema, direction: Direction) -> Any:
return ComponentIdentity(self.target.dataclass_definition.dataclass_type)

def strip_library_doc(self, schema):
"""Strip the DataclassSerializer library documentation from the schema."""
from rest_framework_dataclasses.serializers import DataclassSerializer
Expand Down
5 changes: 3 additions & 2 deletions drf_spectacular/contrib/rest_polymorphic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from drf_spectacular.drainage import warn
from drf_spectacular.extensions import OpenApiSerializerExtension
from drf_spectacular.plumbing import (
ResolvedComponent, build_basic_type, build_object_type, is_patched_serializer,
ComponentIdentity, ResolvedComponent, build_basic_type, build_object_type,
is_patched_serializer,
)
from drf_spectacular.settings import spectacular_settings
from drf_spectacular.types import OpenApiTypes
Expand All @@ -25,7 +26,7 @@ def map_serializer(self, auto_schema, direction):
component = ResolvedComponent(
name=auto_schema._get_serializer_name(sub_serializer, direction),
type=ResolvedComponent.SCHEMA,
object='virtual'
object=ComponentIdentity('virtual')
)
typed_component = self.build_typed_component(
auto_schema=auto_schema,
Expand Down
4 changes: 4 additions & 0 deletions drf_spectacular/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[
""" return str for overriding default name extraction """
return None

def get_identity(self, auto_schema: 'AutoSchema', direction: Direction) -> Any:
""" return anything to compare instances of target. Target will be used by default. """
return None

def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction) -> _SchemaType:
""" override for customized serializer mapping """
return auto_schema._map_serializer(self.target_class, direction, bypass_extensions=True)
Expand Down
19 changes: 15 additions & 4 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,12 +1478,13 @@ def _get_response_for_code(self, serializer, status_code, media_types=None, dire
and is_serializer(serializer)
and (not is_list_serializer(serializer) or is_serializer(serializer.child))
):
paginated_name = self.get_paginated_name(self._get_serializer_name(serializer, "response"))
component = ResolvedComponent(
name=paginated_name,
name=self.get_paginated_name(self._get_serializer_name(serializer, 'response')),
type=ResolvedComponent.SCHEMA,
schema=paginator.get_paginated_response_schema(schema),
object=serializer.child if is_list_serializer(serializer) else serializer,
object=self.get_serializer_identity(
serializer.child if is_list_serializer(serializer) else serializer, 'response'
)
)
self.registry.register_on_missing(component)
schema = component.ref
Expand Down Expand Up @@ -1556,7 +1557,17 @@ def _get_response_headers_for_code(self, status_code, direction='response') -> _

return result

def get_serializer_identity(self, serializer, direction: Direction) -> Any:
serializer_extension = OpenApiSerializerExtension.get_match(serializer)
if serializer_extension:
identity = serializer_extension.get_identity(self, direction)
if identity is not None:
return identity

return serializer

def get_serializer_name(self, serializer: serializers.Serializer, direction: Direction) -> str:
""" override this for custom behaviour """
return serializer.__class__.__name__

def _get_serializer_name(self, serializer, direction, bypass_extensions=False) -> str:
Expand Down Expand Up @@ -1612,7 +1623,7 @@ def resolve_serializer(
component = ResolvedComponent(
name=self._get_serializer_name(serializer, direction, bypass_extensions),
type=ResolvedComponent.SCHEMA,
object=serializer,
object=self.get_serializer_identity(serializer, direction),
)
if component in self.registry:
return self.registry[component] # return component with schema
Expand Down
31 changes: 25 additions & 6 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,17 @@ def ref(self) -> _SchemaType:
return {'$ref': f'#/components/{self.type}/{self.name}'}


class ComponentIdentity:
""" A container class to make object/component comparison explicit """
def __init__(self, obj):
self.obj = obj

def __eq__(self, other):
if isinstance(other, ComponentIdentity):
return self.obj == other.obj
return self.obj == other


class ComponentRegistry:
def __init__(self) -> None:
self._components: Dict[Tuple[str, str], ResolvedComponent] = {}
Expand All @@ -746,17 +757,25 @@ def __contains__(self, component):

query_obj = component.object
registry_obj = self._components[component.key].object
query_class = query_obj if inspect.isclass(query_obj) else query_obj.__class__
registry_class = query_obj if inspect.isclass(registry_obj) else registry_obj.__class__

if isinstance(query_obj, ComponentIdentity) or inspect.isclass(query_obj):
query_id = query_obj
else:
query_id = query_obj.__class__

if isinstance(registry_obj, ComponentIdentity) or inspect.isclass(registry_obj):
registry_id = registry_obj
else:
registry_id = registry_obj.__class__

suppress_collision_warning = (
get_override(registry_class, 'suppress_collision_warning', False)
or get_override(query_class, 'suppress_collision_warning', False)
get_override(registry_id, 'suppress_collision_warning', False)
or get_override(query_id, 'suppress_collision_warning', False)
)
if query_class != registry_class and not suppress_collision_warning:
if query_id != registry_id and not suppress_collision_warning:
warn(
f'Encountered 2 components with identical names "{component.name}" and '
f'different classes {query_class} and {registry_class}. This will very '
f'different identities {query_id} and {registry_id}. This will very '
f'likely result in an incorrect schema. Try renaming one.'
)
return True
Expand Down
46 changes: 46 additions & 0 deletions tests/contrib/test_rest_framework_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,49 @@ def custom_name_via_serializer_decoration(request):
generate_schema(None, patterns=urlpatterns),
'tests/contrib/test_rest_framework_dataclasses.yml'
)


@pytest.mark.contrib('rest_framework_dataclasses')
@pytest.mark.skipif(sys.version_info < (3, 7), reason='dataclass required by package')
def test_rest_framework_dataclasses_class_reuse(no_warnings):
from dataclasses import dataclass

from rest_framework_dataclasses.serializers import DataclassSerializer

@dataclass
class Person:
name: str
age: int

@dataclass
class Party:
person: Person
num_persons: int

class PartySerializer(DataclassSerializer[Party]):
class Meta:
dataclass = Party

class PersonSerializer(DataclassSerializer[Person]):
class Meta:
dataclass = Person

@extend_schema(responses=PartySerializer)
@api_view()
def party(request):
pass # pragma: no cover

@extend_schema(responses=PersonSerializer)
@api_view()
def person(request):
pass # pragma: no cover

urlpatterns = [
path('party', person),
path('person', party),
]

schema = generate_schema(None, patterns=urlpatterns)
# just existence is enough to check since its about no_warnings
assert 'Person' in schema['components']['schemas']
assert 'Party' in schema['components']['schemas']
2 changes: 1 addition & 1 deletion tests/test_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class X2Viewset(mixins.ListModelMixin, viewsets.GenericViewSet):
generate_schema(None, patterns=router.urls)

stderr = capsys.readouterr().err
assert 'Encountered 2 components with identical names "X" and different classes' in stderr
assert 'Encountered 2 components with identical names "X" and different identities' in stderr


def test_owned_serializer_naming_override_with_ref_name_collision(warnings):
Expand Down

0 comments on commit 8f2dfc5

Please sign in to comment.