Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not strip discriminator property in oneOf generation. #268

Merged
merged 4 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ data class ClassSettings(
NONE,
SUPER,
SUB,
ONE_OF,
}
}

Expand Down Expand Up @@ -102,11 +103,16 @@ object PropertyUtils {
property.addAnnotation(JacksonMetadata.jacksonPropertyAnnotation(oasKey))
property.addValidationAnnotations(this, validationAnnotations)
}

ClassSettings.PolymorphyType.ONE_OF -> {
property.addAnnotation(JacksonMetadata.jacksonPropertyAnnotation(oasKey))
property.addValidationAnnotations(this, validationAnnotations)
}
}

if (isDiscriminatorFieldWithSingleKnownValue(classSettings, schemaName)) {
this as PropertyInfo.Field
if (classSettings.polymorphyType == ClassSettings.PolymorphyType.SUB) {
if (classSettings.polymorphyType in listOf(ClassSettings.PolymorphyType.SUB, ClassSettings.PolymorphyType.ONE_OF)) {
property.initializer(name)
property.addAnnotation(JacksonMetadata.jacksonParameterAnnotation(oasKey))
val constructorParameter: ParameterSpec.Builder = ParameterSpec.builder(name, wrappedType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import com.cjbooms.fabrikt.model.PropertyInfo.Companion.topLevelProperties
import com.cjbooms.fabrikt.model.SchemaInfo
import com.cjbooms.fabrikt.model.SourceApi
import com.cjbooms.fabrikt.model.toEnclosingSchemaInfo
import com.cjbooms.fabrikt.util.KaizenParserExtensions.findOneOfSuperInterface
import com.cjbooms.fabrikt.util.KaizenParserExtensions.getDiscriminatorForInLinedObjectUnderAllOf
import com.cjbooms.fabrikt.util.KaizenParserExtensions.getSchemaRefName
import com.cjbooms.fabrikt.util.KaizenParserExtensions.getSuperType
Expand Down Expand Up @@ -164,8 +165,11 @@ class JacksonModelGenerator(
.filterNot { it.schema.isSimpleType() }
.filterNot { it.schema.isOneOfPolymorphicTypes() }
.flatMap {
val properties = it.schema.topLevelProperties(HTTP_SETTINGS, it.schema)
if (properties.isNotEmpty() || it.typeInfo is KotlinTypeInfo.Enum) {
val properties = it.schema.topLevelProperties(HTTP_SETTINGS, api, it.schema)
if (properties.isNotEmpty() ||
it.typeInfo is KotlinTypeInfo.Enum ||
it.schema.findOneOfSuperInterface(schemas.map { it.schema }).isNotEmpty()
) {
val primaryModel = buildPrimaryModel(api, it, properties, schemas)
val inlinedModels = buildInLinedModels(properties, it.schema, it.schema.getDocumentUrl())
listOf(primaryModel) + inlinedModels
Expand Down Expand Up @@ -196,7 +200,7 @@ class JacksonModelGenerator(
schemaInfo.schema.discriminator,
allSchemas,
schemaInfo.schema.oneOfSchemas,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
)

schemaInfo.schema.isPolymorphicSuperType() && schemaInfo.schema.isPolymorphicSubType(api) ->
Expand All @@ -207,7 +211,7 @@ class JacksonModelGenerator(
checkNotNull(schemaInfo.schema.getDiscriminatorForInLinedObjectUnderAllOf()),
schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) },
schemaInfo.schema.extensions,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
allSchemas,
)

Expand All @@ -217,7 +221,7 @@ class JacksonModelGenerator(
properties,
schemaInfo.schema.discriminator,
schemaInfo.schema.extensions,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
allSchemas,
)

Expand All @@ -227,7 +231,7 @@ class JacksonModelGenerator(
properties,
schemaInfo.schema.getSuperType(api)!!.let { SchemaInfo(it.name, it) },
schemaInfo.schema.extensions,
findOneOfSuperInterface(allSchemas, schemaInfo, options),
schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
)

schemaInfo.typeInfo is KotlinTypeInfo.Enum -> buildEnumClass(schemaInfo.typeInfo)
Expand All @@ -236,40 +240,12 @@ class JacksonModelGenerator(
schemaName = schemaName,
properties = properties,
extensions = schemaInfo.schema.extensions,
oneOfInterfaces = findOneOfSuperInterface(allSchemas, schemaInfo, options),
oneOfInterfaces = schemaInfo.schema.findOneOfSuperInterface(allSchemas.map { it.schema }),
)
}
}

private fun findOneOfSuperInterface(
allSchemas: List<SchemaInfo>,
schema: SchemaInfo,
options: Set<ModelCodeGenOptionType>,
): Set<SchemaInfo> {
if (SEALED_INTERFACES_FOR_ONE_OF !in options) {
return emptySet()
}
return allSchemas
.filter { it.schema.discriminator != null && it.schema.oneOfSchemas.isNotEmpty() }
.mapNotNull { info ->
info.schema.discriminator.mappings
.toList()
.find { (_, ref) ->
ref.endsWith("/${schema.name}")
}
?.let { (key, _) ->
Pair(key!!, info)
}
}
.map { (_, parent) ->
val field = parent.schema.discriminator.propertyName!!
if (!schema.schema.properties.containsKey(field)) {
throw IllegalArgumentException("schema $schema did not have discriminator property")
}
parent
}
.toSet()
}


private fun buildInLinedModels(
topLevelProperties: Collection<PropertyInfo>,
Expand All @@ -285,7 +261,7 @@ class JacksonModelGenerator(
if (it.isInherited) {
emptySet() // Rely on the parent definition
} else {
val props = it.schema.topLevelProperties(HTTP_SETTINGS, enclosingSchema)
val props = it.schema.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3, enclosingSchema)
val currentModel = standardDataClass(
ModelNameRegistry.getOrRegister(it.schema, enclosingSchema.toEnclosingSchemaInfo()),
it.name,
Expand All @@ -308,7 +284,7 @@ class JacksonModelGenerator(
standardDataClass(
modelName = ModelNameRegistry.getOrRegister(it.schema, valueSuffix = it.schema.isInlinedTypedAdditionalProperties()),
schemaName = it.name,
properties = it.schema.topLevelProperties(HTTP_SETTINGS, enclosingSchema),
properties = it.schema.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3, enclosingSchema),
extensions = it.schema.extensions,
oneOfInterfaces = emptySet(),
),
Expand Down Expand Up @@ -347,7 +323,7 @@ class JacksonModelGenerator(
?: enclosingSchema.toEnclosingSchemaInfo()
when {
items.isInlinedObjectDefinition() ->
items.topLevelProperties(HTTP_SETTINGS, enclosingSchema).let { props ->
items.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3, enclosingSchema).let { props ->
buildInLinedModels(
topLevelProperties = props,
enclosingSchema = enclosingSchema,
Expand Down Expand Up @@ -454,7 +430,7 @@ class JacksonModelGenerator(
standardDataClass(
modelName = ModelNameRegistry.getOrRegister(schema, valueSuffix = schema.isInlinedTypedAdditionalProperties()),
schemaName = schema.safeName(),
properties = mapField.schema.additionalPropertiesSchema.topLevelProperties(HTTP_SETTINGS),
properties = mapField.schema.additionalPropertiesSchema.topLevelProperties(HTTP_SETTINGS, sourceApi.openApi3),
extensions = mapField.schema.extensions,
oneOfInterfaces = emptySet(),
)
Expand All @@ -467,24 +443,10 @@ class JacksonModelGenerator(
schemaName: String,
properties: Collection<PropertyInfo>,
extensions: Map<String, Any>,
oneOfInterfaces: Set<SchemaInfo>,
oneOfInterfaces: Set<Schema>,
): TypeSpec {
val filteredProperties = if (oneOfInterfaces.size == 1) {
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the logic in question. I'm not sure of the intent here, but perhaps @pschichtel can weigh in before I revert it

Copy link
Contributor

@pschichtel pschichtel Mar 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the logic is:

if:

  • the super interface' model defines a discriminator field
  • the discriminator mappings map to the current model exactly once

then:

  • remove the discriminator field, since it must always be the same value given by the jackson type info field.

val oneOfInterface = oneOfInterfaces.first()
val discriminatorProp = oneOfInterface.schema.discriminator?.propertyName
val mappingCount =
oneOfInterface.schema.discriminator?.mappings?.values?.count { it.endsWith("/$modelName") }
if (discriminatorProp != null && mappingCount == 1) {
properties.filterNot { it.name == discriminatorProp }
} else {
properties
}
} else {
properties
}

val name = generatedType(packages.base, modelName)
val generateObject = properties.isNotEmpty() && filteredProperties.isEmpty()
val generateObject = properties.isEmpty()
val builder =
if (generateObject) {
TypeSpec.objectBuilder(name)
Expand All @@ -499,15 +461,23 @@ class JacksonModelGenerator(
.addCompanionObject()
for (oneOfInterface in oneOfInterfaces) {
classBuilder
.addSuperinterface(generatedType(packages.base, ModelNameRegistry.getOrRegister(oneOfInterface.schema)))
.addSuperinterface(generatedType(packages.base, ModelNameRegistry.getOrRegister(oneOfInterface)))
}

if (!generateObject) {
filteredProperties.addToClass(
schemaName = schemaName,
classBuilder = classBuilder,
classType = ClassSettings(ClassSettings.PolymorphyType.NONE, extensions.hasJsonMergePatchExtension),
)
if (oneOfInterfaces.size == 1) {
properties.addToClass(
schemaName = schemaName,
classBuilder = classBuilder,
classType = ClassSettings(ClassSettings.PolymorphyType.ONE_OF, extensions.hasJsonMergePatchExtension),
)
} else {
properties.addToClass(
schemaName = schemaName,
classBuilder = classBuilder,
classType = ClassSettings(ClassSettings.PolymorphyType.NONE, extensions.hasJsonMergePatchExtension),
)
}
}
return classBuilder.build()
}
Expand All @@ -519,7 +489,7 @@ class JacksonModelGenerator(
discriminator: Discriminator,
superType: SchemaInfo,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
allSchemas: List<SchemaInfo>,
): TypeSpec = with(FunSpec.constructorBuilder()) {
TypeSpec.classBuilder(generatedType(packages.base, modelName))
Expand Down Expand Up @@ -549,7 +519,7 @@ class JacksonModelGenerator(
discriminator: Discriminator,
allSchemas: List<SchemaInfo>,
members: List<Schema>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
): TypeSpec {
val interfaceBuilder = TypeSpec.interfaceBuilder(generatedType(packages.base, modelName))
.addModifiers(KModifier.SEALED)
Expand Down Expand Up @@ -589,7 +559,7 @@ class JacksonModelGenerator(
properties: Collection<PropertyInfo>,
discriminator: Discriminator,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
allSchemas: List<SchemaInfo>,
): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName))
.buildPolymorphicSuperType(
Expand All @@ -609,7 +579,7 @@ class JacksonModelGenerator(
properties: Collection<PropertyInfo>,
discriminator: Discriminator,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
allSchemas: List<SchemaInfo>,
constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(),
): TypeSpec.Builder {
Expand Down Expand Up @@ -661,7 +631,7 @@ class JacksonModelGenerator(
properties: Collection<PropertyInfo>,
superType: SchemaInfo,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
): TypeSpec = TypeSpec.classBuilder(generatedType(packages.base, modelName))
.buildPolymorphicSubType(schemaName, properties, superType, extensions, oneOfSuperInterfaces).build()

Expand All @@ -670,7 +640,7 @@ class JacksonModelGenerator(
allProperties: Collection<PropertyInfo>,
superType: SchemaInfo,
extensions: Map<String, Any>,
oneOfSuperInterfaces: Set<SchemaInfo>,
oneOfSuperInterfaces: Set<Schema>,
constructorBuilder: FunSpec.Builder = FunSpec.constructorBuilder(),
): TypeSpec.Builder {
this.addSerializableInterface()
Expand Down
22 changes: 13 additions & 9 deletions src/main/kotlin/com/cjbooms/fabrikt/model/PropertyInfo.kt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import com.cjbooms.fabrikt.util.KaizenParserExtensions.safeName
import com.cjbooms.fabrikt.util.KaizenParserExtensions.safeType
import com.cjbooms.fabrikt.util.NormalisedString.camelCase
import com.cjbooms.fabrikt.util.NormalisedString.toEnumName
import com.reprezen.kaizen.oasparser.model3.OpenApi3
import com.reprezen.kaizen.oasparser.model3.Schema

sealed class PropertyInfo {
Expand All @@ -38,7 +39,7 @@ sealed class PropertyInfo {

val HTTP_SETTINGS = Settings()

fun Schema.topLevelProperties(settings: Settings, enclosingSchema: Schema? = null): Collection<PropertyInfo> {
fun Schema.topLevelProperties(settings: Settings, api: OpenApi3, enclosingSchema: Schema? = null): Collection<PropertyInfo> {
val results = mutableListOf<PropertyInfo>() +
allOfSchemas.flatMap {
it.topLevelProperties(
Expand All @@ -47,12 +48,13 @@ sealed class PropertyInfo {
enclosingSchema,
it
),
api,
this
)
} +
(if (oneOfSchemas.isEmpty()) emptyList() else listOf(OneOfAny(oneOfSchemas.first()))) +
anyOfSchemas.flatMap { it.topLevelProperties(settings.copy(markAllOptional = true), this) } +
getInLinedProperties(settings, enclosingSchema)
anyOfSchemas.flatMap { it.topLevelProperties(settings.copy(markAllOptional = true), api, this) } +
getInLinedProperties(settings, api, enclosingSchema)
return results.distinctBy { it.oasKey }
}

Expand All @@ -68,13 +70,14 @@ sealed class PropertyInfo {

private fun Schema.getInLinedProperties(
settings: Settings,
api: OpenApi3,
enclosingSchema: Schema? = null
): Collection<PropertyInfo> {
val mainProperties: List<PropertyInfo> = properties.map { property ->
when (property.value.safeType()) {
OasType.Array.type ->
ListField(
isRequired(property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
isRequired(api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
property.key,
property.value,
settings.markAsInherited,
Expand All @@ -87,6 +90,7 @@ sealed class PropertyInfo {
if (property.value.isSimpleMapDefinition() || property.value.isSchemaLess())
MapField(
isRequired = isRequired(
api,
property,
settings.markReadWriteOnlyOptional,
settings.markAllOptional
Expand All @@ -99,7 +103,7 @@ sealed class PropertyInfo {
else if (property.value.isInlinedObjectDefinition())
ObjectInlinedField(
isRequired = isRequired(
property, settings.markReadWriteOnlyOptional, settings.markAllOptional
api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional
),
oasKey = property.key,
schema = property.value,
Expand All @@ -109,7 +113,7 @@ sealed class PropertyInfo {
)
else
ObjectRefField(
isRequired(property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
isRequired(api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
property.key,
property.value,
settings.markAsInherited,
Expand All @@ -120,13 +124,13 @@ sealed class PropertyInfo {
null
} else {
Field(
isRequired(property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
isRequired(api, property, settings.markReadWriteOnlyOptional, settings.markAllOptional),
oasKey = property.key,
schema = property.value,
isInherited = settings.markAsInherited,
isPolymorphicDiscriminator = isDiscriminatorProperty(property),
isPolymorphicDiscriminator = isDiscriminatorProperty(api, property),
maybeDiscriminator = enclosingSchema?.let {
this.getKeyIfSingleDiscriminatorValue(property, it)
this.getKeyIfSingleDiscriminatorValue(api, property, it)
},
enclosingSchema = if (property.value.isInlinedEnumDefinition()) this else null
)
Expand Down
Loading
Loading