Skip to content

Commit

Permalink
Make discriminators work with multiple keys pointing to the same schema
Browse files Browse the repository at this point in the history
  • Loading branch information
ldej committed Mar 14, 2024
1 parent 9535a16 commit 0d8f5ee
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 7 deletions.
14 changes: 7 additions & 7 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def __apply_discriminator_type(
(pydantic_model.BaseModel, pydantic_model_v2.BaseModel),
):
continue # pragma: no cover
type_name = None
type_names = []
if mapping:
for name, path in mapping.items():
if (
Expand All @@ -765,10 +765,10 @@ def __apply_discriminator_type(
):
# TODO: support external reference
continue
type_name = name
type_names.append(name)
else:
type_name = discriminator_model.path.split('/')[-1]
if not type_name: # pragma: no cover
type_names = [discriminator_model.path.split('/')[-1]]
if not type_names: # pragma: no cover
raise RuntimeError(
f'Discriminator type is not found. {data_type.reference.path}'
)
Expand All @@ -780,7 +780,7 @@ def __apply_discriminator_type(
) != property_name:
continue
literals = discriminator_field.data_type.literals
if len(literals) == 1 and literals[0] == type_name:
if len(literals) == 1 and literals[0] == type_names[0] if type_names else None:
has_one_literal = True
continue
for (
Expand All @@ -789,7 +789,7 @@ def __apply_discriminator_type(
if field_data_type.reference: # pragma: no cover
field_data_type.remove_reference()
discriminator_field.data_type = self.data_type(
literals=[type_name]
literals=type_names
)
discriminator_field.data_type.parent = discriminator_field
discriminator_field.required = True
Expand All @@ -799,7 +799,7 @@ def __apply_discriminator_type(
discriminator_model.fields.append(
self.data_model_field_type(
name=property_name,
data_type=self.data_type(literals=[type_name]),
data_type=self.data_type(literals=type_names),
required=True,
)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# generated by datamodel-codegen:
# filename: discriminator_enum_duplicate.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from enum import Enum
from typing import Literal, Optional, Union

from pydantic import BaseModel, Field


class Cat(BaseModel):
pet_type: Literal['cat'] = Field(..., title='Pet Type')
meows: int = Field(..., title='Meows')


class Dog(BaseModel):
pet_type: Literal['dog'] = Field(..., title='Pet Type')
barks: float = Field(..., title='Barks')


class PetType(Enum):
reptile = 'reptile'
lizard = 'lizard'


class Lizard(BaseModel):
pet_type: Literal['lizard', 'reptile'] = Field(..., title='Pet Type')
scales: bool = Field(..., title='Scales')


class Animal(BaseModel):
pet: Optional[Union[Cat, Dog, Lizard]] = Field(
None, discriminator='pet_type', title='Pet'
)
n: Optional[int] = Field(None, title='N')
64 changes: 64 additions & 0 deletions tests/data/openapi/discriminator_enum_duplicate.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Example from https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
openapi: 3.1.0
components:
schemas:
Cat:
properties:
pet_type:
const: "cat"
title: "Pet Type"
meows:
title: Meows
type: integer
required:
- pet_type
- meows
title: Cat
type: object
Dog:
properties:
pet_type:
const: "dog"
title: "Pet Type"
barks:
title: Barks
type: number
required:
- pet_type
- barks
title: Dog
type: object
Lizard:
properties:
pet_type:
enum:
- reptile
- lizard
title: Pet Type
type: string
scales:
title: Scales
type: boolean
required:
- pet_type
- scales
title: Lizard
type: object
Animal:
properties:
pet:
discriminator:
mapping:
cat: '#/components/schemas/Cat'
dog: '#/components/schemas/Dog'
lizard: '#/components/schemas/Lizard'
reptile: '#/components/schemas/Lizard'
propertyName: pet_type
oneOf:
- $ref: '#/components/schemas/Cat'
- $ref: '#/components/schemas/Dog'
- $ref: '#/components/schemas/Lizard'
title: Pet
'n':
title: 'N'
type: integer
33 changes: 33 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6638,3 +6638,36 @@ def test_main_openapi_discriminator_enum():
EXPECTED_MAIN_PATH / 'main_openapi_discriminator_enum' / 'output.py'
).read_text()
)


@freeze_time('2019-07-26')
@pytest.mark.skipif(
black.__version__.split('.')[0] == '19',
reason="Installed black doesn't support the old style",
)
def test_main_openapi_discriminator_enum_duplicate():
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(OPEN_API_DATA_PATH / 'discriminator_enum_duplicate.yaml'),
'--output',
str(output_file),
'--target-python-version',
'3.10',
'--output-model-type',
'pydantic_v2.BaseModel',
'--input-file-type',
'openapi',
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (
EXPECTED_MAIN_PATH
/ 'main_openapi_discriminator_enum_duplicate'
/ 'output.py'
).read_text()
)

0 comments on commit 0d8f5ee

Please sign in to comment.