Skip to content

Commit

Permalink
Add enum type inference based on choices values (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
mofr authored and axnsan12 committed Dec 7, 2018
1 parent f587785 commit f654465
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/drf_yasg/inspectors/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **

if isinstance(field, serializers.ChoiceField):
enum_type = openapi.TYPE_STRING
enum_values = list(field.choices.keys())

# for ModelSerializer, try to infer the type from the associated model field
serializer = get_parent_serializer(field)
Expand All @@ -596,8 +597,14 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
model_type = get_basic_type_info(model_field)
if model_type:
enum_type = model_type.get('type', enum_type)
else:
# Try to infer field type based on enum values
enum_value_types = {type(v) for v in enum_values}
if len(enum_value_types) == 1:
values_type = get_basic_type_info_from_hint(next(iter(enum_value_types)))
if values_type:
enum_type = values_type.get('type', enum_type)

enum_values = list(field.choices.keys())
if isinstance(field, serializers.MultipleChoiceField):
result = SwaggerType(
type=openapi.TYPE_ARRAY,
Expand Down
29 changes: 29 additions & 0 deletions tests/test_schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,32 @@ def action_post(self, request):
assert action_ops['post']['description'] == 'mapping docstring post'
assert action_ops['get']['description'] == 'mapping docstring get/delete'
assert action_ops['delete']['description'] == 'mapping docstring get/delete'


@pytest.mark.parametrize('choices, expected_type', [
(['A', 'B'], openapi.TYPE_STRING),
([123, 456], openapi.TYPE_INTEGER),
([1.2, 3.4], openapi.TYPE_NUMBER),
(['A', 456], openapi.TYPE_STRING)
])
def test_choice_field(choices, expected_type):
class DetailSerializer(serializers.Serializer):
detail = serializers.ChoiceField(choices)

class DetailViewSet(viewsets.ViewSet):
@swagger_auto_schema(responses={200: openapi.Response("OK", DetailSerializer)})
def retrieve(self, request, pk=None):
return Response({'detail': None})

router = routers.DefaultRouter()
router.register(r'details', DetailViewSet, base_name='details')

generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"),
patterns=router.urls
)

swagger = generator.get_schema(None, True)
property_schema = swagger['definitions']['Detail']['properties']['detail']

assert property_schema == openapi.Schema(title='Detail', type=expected_type, enum=choices)

0 comments on commit f654465

Please sign in to comment.