Skip to content

Commit

Permalink
Merge pull request #1906 from blast-hardcheese/resolve-1642
Browse files Browse the repository at this point in the history
Distinct protocol parameters
  • Loading branch information
blast-hardcheese authored Dec 26, 2023
2 parents 9e64878 + 58f4178 commit 10eb849
Show file tree
Hide file tree
Showing 4 changed files with 363 additions and 223 deletions.
5 changes: 5 additions & 0 deletions modules/sample/src/main/resources/issues/issue222.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@ definitions:
allOf:
- "$ref": "#/definitions/RequestFields"
- type: object
required: [same]
properties:
id:
type: string
same:
type: string
RequestFields:
description: Request fields
type: object
properties:
state:
type: integer
same:
type: string
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import dev.guardrail.generators.protocol.{ ClassChild, ClassHierarchy, ClassPare
import dev.guardrail.generators.scala.circe.CirceProtocolGenerator.WithValidations
import dev.guardrail.generators.scala.{ CirceModelGenerator, ScalaLanguage }
import dev.guardrail.generators.spi.{ ModuleLoadResult, ProtocolGeneratorLoader }
import dev.guardrail.generators.syntax._
import dev.guardrail.generators.{ ProtocolDefinitions, RawParameterName }
import dev.guardrail.terms.framework.FrameworkTerms
import dev.guardrail.terms.protocol.PropertyRequirement
Expand Down Expand Up @@ -808,6 +809,7 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa

def processProperty(name: String, schema: Tracker[Schema[_]]): Target[Option[Either[String, NestedProtocolElems[ScalaLanguage]]]] =
for {
() <- Target.log.debug(s"processProperty: ${name} ${schema.unwrapTracker.showNotNull}")
nestedClassName <- formatTypeName(name).map(formattedName => getClsName(name).append(formattedName))
defn <- schema
.refine[Target[Option[Either[String, NestedProtocolElems[ScalaLanguage]]]]] { case ObjectExtractor(x) => x }(o =>
Expand Down Expand Up @@ -1342,107 +1344,133 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
supportPackage: List[String],
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
) = {
val discriminators = parents.flatMap(_.discriminators)
val discriminatorNames = discriminators.map(_.propertyName).toSet
val parentOpt = if (parents.exists(s => s.discriminators.nonEmpty)) {
parents.headOption
} else {
None
}
val params = (parents.reverse.flatMap(_.params) ++ selfParams).filterNot(param => discriminatorNames.contains(param.term.name.value))
) =
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
parentOpt =
if (parents.exists(s => s.discriminators.nonEmpty)) {
parents.headOption
} else {
None
}
params <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams).map(_.filterNot(param => discriminatorNames.contains(param.term.name.value)))

val terms = params.map(_.term)
terms = params.map(_.term)

val toStringMethod = if (params.exists(_.dataRedaction != DataVisible)) {
def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match {
case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()"
case _ => Lit.String("[redacted]")
}
toStringMethod =
if (params.exists(_.dataRedaction != DataVisible)) {
def mkToStringTerm(param: ProtocolParameter[ScalaLanguage]): Term = param match {
case param if param.dataRedaction == DataVisible => q"${Term.Name(param.term.name.value)}.toString()"
case _ => Lit.String("[redacted]")
}

val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(",")))
val toStringTerms = params.map(p => List(mkToStringTerm(p))).intercalate(List(Lit.String(",")))

List[Defn.Def](
q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${clsName}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}"
)
} else {
List.empty[Defn.Def]
}
List[Defn.Def](
q"override def toString: String = ${toStringTerms.foldLeft[Term](Lit.String(s"${clsName}("))((accum, term) => q"$accum + $term")} + ${Lit.String(")")}"
)
} else {
List.empty[Defn.Def]
}

val code = parentOpt
.fold(q"""case class ${Type.Name(clsName)}(..${terms}) { ..$toStringMethod }""")(parent =>
q"""case class ${Type.Name(clsName)}(..${terms}) extends ..${init"${Type.Name(parent.clsName)}(...$Nil)" :: parent.interfaces.map(a =>
init"${Type.Name(a)}(...$Nil)"
)} { ..$toStringMethod }"""
)
code = parentOpt
.fold(q"""case class ${Type.Name(clsName)}(..${terms}) { ..$toStringMethod }""")(parent =>
q"""case class ${Type.Name(clsName)}(..${terms}) extends ..${init"${Type.Name(parent.clsName)}(...$Nil)" :: parent.interfaces.map(a =>
init"${Type.Name(a)}(...$Nil)"
)} { ..$toStringMethod }"""
)

Target.pure(code)
}
} yield code

private def finalizeParams(params: List[ProtocolParameter[ScalaLanguage]]): Target[List[ProtocolParameter[ScalaLanguage]]] =
for {
reduced <- params
.groupBy(_.name)
.toList
.traverse { case (k, xs) =>
implicit val ord: Ordering[PropertyRequirement] = {
case (PropertyRequirement.Required, _) => -1
case (_, PropertyRequirement.Required) => 1
case _ => 0
}
xs.distinctBy(_.term.syntax).sortBy(_.propertyRequirement) match {
case Nil => Target.raiseUserError(s"Unexpectedly empty parameter group: ${xs}")
case x :: Nil => Target.pure((k, x))
case xs @ (x :: _) if xs.distinctBy(_.baseType.syntax).length == 1 => Target.pure((k, x))
case xs @ (x :: rest) => Target.raiseUserError(s"Type conflicts for ${x.name.value}: ${xs.flatMap(_.term.decltpe.map(_.syntax)).mkString(", ")}")
}
}
.map(_.toMap)
names = params.map(_.name).distinct
} yield names.flatMap(n => reduced.get(n))

private def encodeModel(
clsName: String,
dtoPackage: List[String],
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
) = {
val discriminators = parents.flatMap(_.discriminators)
val discriminatorNames = discriminators.map(_.propertyName).toSet
val allParams = parents.reverse.flatMap(_.params) ++ selfParams
val (discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value))
val readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList
val typeName = Type.Name(clsName)
val encVal = {
def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) =
q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))"""

def encodeRequired(param: ProtocolParameter[ScalaLanguage]) =
q"""(${Lit.String(param.name.value)}, a.${Term.Name(param.term.name.value)}.asJson)"""

def encodeOptional(param: ProtocolParameter[ScalaLanguage]) = {
val name = Lit.String(param.name.value)
q"a.${Term.Name(param.term.name.value)}.fold(ifAbsent = None, ifPresent = value => Some($name -> value.asJson))"
}
) =
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
(discriminatorParams, params) = allParams.partition(param => discriminatorNames.contains(param.name.value))
readOnlyKeys: List[String] = params.flatMap(_.readOnlyKey).toList
typeName = Type.Name(clsName)
encVal = {
def encodeStatic(param: ProtocolParameter[ScalaLanguage], clsName: String) =
q"""(${Lit.String(param.name.value)}, _root_.io.circe.Json.fromString(${Lit.String(clsName)}))"""

def encodeRequired(param: ProtocolParameter[ScalaLanguage]) =
q"""(${Lit.String(param.name.value)}, a.${Term.Name(param.term.name.value)}.asJson)"""

def encodeOptional(param: ProtocolParameter[ScalaLanguage]) = {
val name = Lit.String(param.name.value)
q"a.${Term.Name(param.term.name.value)}.fold(ifAbsent = None, ifPresent = value => Some($name -> value.asJson))"
}

val (optional, pairs): (List[Term.Apply], List[Term.Tuple]) = params.partitionEither { param =>
val name = Lit.String(param.name.value)
param.propertyRequirement match {
case PropertyRequirement.Required | PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy =>
Right(encodeRequired(param))
case PropertyRequirement.Optional | PropertyRequirement.OptionalNullable =>
Left(encodeOptional(param))
case PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) =>
Left(encodeOptional(param))
case PropertyRequirement.Configured(PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy, _) =>
Right(encodeRequired(param))
case PropertyRequirement.Configured(PropertyRequirement.Optional, _) =>
Left(q"""a.${Term.Name(param.term.name.value)}.map(value => (${Lit.String(param.name.value)}, value.asJson))""")
val (optional, pairs): (List[Term.Apply], List[Term.Tuple]) = params.partitionEither { param =>
val name = Lit.String(param.name.value)
param.propertyRequirement match {
case PropertyRequirement.Required | PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy =>
Right(encodeRequired(param))
case PropertyRequirement.Optional | PropertyRequirement.OptionalNullable =>
Left(encodeOptional(param))
case PropertyRequirement.Configured(PropertyRequirement.Optional, PropertyRequirement.Optional) =>
Left(encodeOptional(param))
case PropertyRequirement.Configured(PropertyRequirement.RequiredNullable | PropertyRequirement.OptionalLegacy, _) =>
Right(encodeRequired(param))
case PropertyRequirement.Configured(PropertyRequirement.Optional, _) =>
Left(q"""a.${Term.Name(param.term.name.value)}.map(value => (${Lit.String(param.name.value)}, value.asJson))""")
}
}
}

val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName))
val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})"
val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) =>
q"$acc ++ $field"
}
val pairsWithStatic = pairs ++ discriminatorParams.map(encodeStatic(_, clsName))
val simpleCase = q"_root_.scala.Vector(..${pairsWithStatic})"
val allFields = optional.foldLeft[Term](simpleCase) { (acc, field) =>
q"$acc ++ $field"
}

q"""
q"""
${circeVersion.encoderObjectCompanion}.instance[${Type.Name(clsName)}](a => _root_.io.circe.JsonObject.fromIterable($allFields))
"""
}
val (readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys =>
(
List(q"val readOnlyKeys = _root_.scala.Predef.Set[_root_.scala.Predef.String](..${roKeys.toList.map(Lit.String(_))})"),
encVal => q"$encVal.mapJsonObject(_.filterKeys(key => !(readOnlyKeys contains key)))"
)
}
}
(readOnlyDefn, readOnlyFilter) = NonEmptyList.fromList(readOnlyKeys).fold((List.empty[Stat], identity[Term] _)) { roKeys =>
(
List(q"val readOnlyKeys = _root_.scala.Predef.Set[_root_.scala.Predef.String](..${roKeys.toList.map(Lit.String(_))})"),
encVal => q"$encVal.mapJsonObject(_.filterKeys(key => !(readOnlyKeys contains key)))"
)
}

Target.pure(Option(q"""
} yield Option(q"""
implicit val ${suffixClsName("encode", clsName)}: ${circeVersion.encoderObject}[${Type.Name(clsName)}] = {
..${readOnlyDefn};
${readOnlyFilter(encVal)}
}
"""))
}
""")

private def decodeModel(
clsName: String,
Expand All @@ -1451,13 +1479,14 @@ class CirceProtocolGenerator private (circeVersion: CirceModelGenerator, applyVa
selfParams: List[ProtocolParameter[ScalaLanguage]],
parents: List[SuperClass[ScalaLanguage]] = Nil
)(implicit Lt: LanguageTerms[ScalaLanguage, Target]): Target[Option[Defn.Val]] = {
val discriminators = parents.flatMap(_.discriminators)
val discriminatorNames = discriminators.map(_.propertyName).toSet
val allParams = parents.reverse.flatMap(_.params) ++ selfParams
val params = allParams.filterNot(param => discriminatorNames.contains(param.name.value))
val needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull)
val paramCount = params.length
for {
() <- Target.pure(())
discriminators = parents.flatMap(_.discriminators)
discriminatorNames = discriminators.map(_.propertyName).toSet
allParams <- finalizeParams(parents.reverse.flatMap(_.params) ++ selfParams)
params = allParams.filterNot(param => discriminatorNames.contains(param.name.value))
needsEmptyToNull: Boolean = params.exists(_.emptyToNull == EmptyIsNull)
paramCount = params.length
presence <- Lt.selectTerm(NonEmptyList.ofInitLast(supportPackage, "Presence"))
decVal <-
if (paramCount == 0) {
Expand Down
Loading

0 comments on commit 10eb849

Please sign in to comment.